<a href="https://colab.research.google.com/github/HAVIGILI/OOD-Co-Act/blob/main/Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Open this in Colab. Restart the kernel after the numpy installation and then again after the _distutils_hack installation.
# NOTE(review): this cell pins numpy==1.24.3, torch 2.0.0+cu118 / torchvision 0.15.0+cu118 and
# mmcv-full 1.7.2 (wheel index built for cu118 / torch2.0.0). gdown, openmim (-U) and torchmetrics
# are NOT pinned, so a fresh run may pick up newer versions — pin them if reproducibility matters.
# NOTE(review): why `accelerate` is uninstalled is not visible here — presumably it conflicts with
# the pinned torch; confirm and document.

%pip uninstall -y accelerate

%pip install numpy==1.24.3
%pip install --quiet gdown

# Install OpenMIM tool (provides the `mim` command used below)
%pip install -U openmim

# Install a compatible PyTorch (for example, 2.0.0+cu118)
%pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html

%pip install torchmetrics

# Then install mmcv-full pre-built for that combination (the index URL encodes cu118 + torch2.0.0,
# so it must be changed together with the torch pin above):
!mim install "mmcv-full==1.7.2" -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.0.0/index.html

In [None]:
import sys

# Fetch this project's repository and put it on sys.path so that `calibrator` and
# `anomalydetector` (imported in the next cell) can be found.
!git clone https://github.com/HAVIGILI/OOD-Co-Act.git
sys.path.append("/content/OOD-Co-Act")

# MMSegmentation, branch 0.x, installed in editable mode.
# NOTE(review): branch 0.x is presumably chosen to match the mmcv-full 1.7.2 pin — confirm.
# The %cd matters: later cells load configs via paths relative to this checkout ('configs/...').
# NOTE(review): this cell is not re-runnable as-is — `git clone` fails if the folder exists and a
# second `%cd mmsegmentation` looks for a nested folder. Also `!pip` (not `%pip`) may target a
# different environment than the kernel.
!git clone -b 0.x --depth 1 https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -e .

In [None]:
import os
import torch
from mmseg.apis import init_segmentor, inference_segmentor
import mmcv
from IPython.display import Image, display
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Tuple
from calibrator import Calibrator
from anomalydetector import AnomalyDetector

In [None]:
# Fishyscapes is imported below. Keep this commented out.
# Alternative dataset sources kept for reference only (RoadAnomaly21, Cityscapes gtFine,
# Lost & Found images, Fishyscapes L&F from the official hosts). Uncomment only if you
# want to fetch these instead of the Google Drive copies used in the next cells.

# RoadAnomaly21:
# !unzip -q /content/dataset_RoadAnomalyTrack.zip -d RoadAnomaly
# !unzip -q /content/gtFine_trainvaltest.zip -d gtFine_trainvaltest
# !wget -q http://wwwlehre.dhbw-stuttgart.de/~sgehrig/lostAndFoundDataset/leftImg8bit.zip
# !unzip -q leftImg8bit.zip -d /content/dataset_root/
# !wget -q http://robotics.ethz.ch/~asl-datasets/Fishyscapes/fishyscapes_lostandfound.zip
# !unzip -q fishyscapes_lostandfound.zip -d /content/dataset_root/

In [None]:
# Mount Google Drive at /content/drive. The dataset and calibrator-state paths used in later
# cells all live under /content/drive/MyDrive, so this must run before them.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# To get fishyscapes ground truth and pictures.

LEFTZIP_ID = "1BihdEUYOCRKpLfo8VvajTbuUtV3RyFtH"
FISHYZIP_ID = "1aR_qSjuykKWuKvM1dayecfb37cNYfZC3"

# download both
!gdown https://drive.google.com/uc?id=$LEFTZIP_ID -O leftImg8bit.zip
!gdown https://drive.google.com/uc?id=$FISHYZIP_ID -O fishyscapes.zip

# unzip into the exact paths your code expects
!unzip -q leftImg8bit.zip -d /content/drive/MyDrive/
!unzip -q fishyscapes.zip -d /content/drive/MyDrive/

images_root      = "/content/drive/MyDrive/leftImg8bit"
annotations_root = "/content/drive/MyDrive/fishyscapes_lostandfound"

In [None]:
# Match every Fishyscapes Lost & Found label image with its source leftImg8bit frame.
#
# A label's image name is derived by dropping the first LABEL_PREFIX_LEN characters of the label
# filename and swapping "_labels.png" for "_leftImg8bit.png". That name is then looked up among
# the files found under images_root.
# NOTE(review): the 5-character prefix is dataset-specific (presumably a running index such as
# "0000_") — confirm against the archive contents.

import os

LABEL_SUFFIX = "_labels.png"
IMAGE_SUFFIX = "_leftImg8bit.png"
LABEL_PREFIX_LEN = 5

# image filename -> annotation path.
# Built in a single pass, replacing two parallel lists from two separate os.listdir() calls whose
# orders were only coincidentally equal. Also replaces a list.index() scan per image (O(n) each)
# with an O(1) lookup. Files that are not label images are skipped.
annotation_for_image = {}
for label_name in sorted(os.listdir(annotations_root)):
    if not label_name.endswith(LABEL_SUFFIX):
        continue
    image_name = label_name[LABEL_PREFIX_LEN:].replace(LABEL_SUFFIX, IMAGE_SUFFIX)
    # On a name collision keep the first label (in sorted order), as list.index() did.
    annotation_for_image.setdefault(image_name, os.path.join(annotations_root, label_name))

matched_annotations = []
matched_images = []

# Walk the image tree in sorted order so matched_images has a reproducible order. The start:end
# slice and the images_to_be_plotted indices in the grid-search cell depend on that order.
for dir_path, dir_names, filenames in os.walk(images_root):
    dir_names.sort()  # in-place sort makes os.walk visit sub-directories in sorted order
    for filename in sorted(filenames):
        annotation_path = annotation_for_image.get(filename)
        if annotation_path is None:
            continue
        matched_annotations.append(annotation_path)
        matched_images.append(os.path.join(dir_path, filename))

print("Matched pairs:")
for annotation, image in zip(matched_annotations, matched_images):
    print("annotation", annotation, "matched with", image)


In [None]:
# Segmentation network handed to AnomalyDetector in the grid search: DeepLabV3+ pretrained on
# Cityscapes ("r101" in the config name = ResNet-101 backbone). The config path is relative to
# the mmsegmentation checkout this notebook %cd'd into. The checkpoint is fetched by URL.
# Runs on the first GPU when one is available, otherwise on the CPU.
config_file = 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_40k_cityscapes.py'
checkpoint_file = (
    'https://download.openmmlab.com/mmsegmentation/v0.5/'
    'deeplabv3plus/deeplabv3plus_r101-d8_512x1024_40k_cityscapes/'
    'deeplabv3plus_r101-d8_512x1024_40k_cityscapes_20200605_094614-3769eecf.pth'
)
model = init_segmentor(
    config_file,
    checkpoint_file,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [None]:

# --------------------------------------------------------------------
# 1.  Collect training image / GT paths
# --------------------------------------------------------------------
def cityscapes_paths(img_root: str, gt_root: str) -> Tuple[List[str], List[str]]:
    """Collect Cityscapes-style image paths and their label paths, in parallel lists.

    Each sub-directory of ``img_root`` (a city) is visited in sorted order. Within it, every
    ``*_leftImg8bit.png`` file (also sorted) is paired with the path of the same-named
    ``*_gtFine_labelIds.png`` file in the matching city folder under ``gt_root``. Plain files
    directly in ``img_root`` and non-matching filenames are ignored.

    Args:
        img_root: Folder containing one sub-folder of images per city.
        gt_root: Folder with the same per-city layout holding the label images.

    Returns:
        ``(images, gts)``: two equally long lists where ``gts[i]`` is the label path for
        ``images[i]``. Label paths are built from names only; whether they exist is not checked.
    """
    img_suffix = '_leftImg8bit.png'
    image_paths: List[str] = []
    label_paths: List[str] = []

    for city_name in sorted(os.listdir(img_root)):
        city_img_dir = os.path.join(img_root, city_name)
        if os.path.isdir(city_img_dir):
            city_gt_dir = os.path.join(gt_root, city_name)
            frames = [name for name in sorted(os.listdir(city_img_dir)) if name.endswith(img_suffix)]
            for frame in frames:
                image_paths.append(os.path.join(city_img_dir, frame))
                label_name = frame.replace(img_suffix, '_gtFine_labelIds.png')
                label_paths.append(os.path.join(city_gt_dir, label_name))

    return image_paths, label_paths


# Cityscapes *train* split (fine annotations) used to calibrate the detector statistics.
img_root = '/content/drive/MyDrive/leftImg8bit_trainvaltest/leftImg8bit/train'
gt_root  = '/content/drive/MyDrive/gtFine_trainvaltest/gtFine/train'
images, gts = cityscapes_paths(img_root, gt_root)
print(f'Found {len(images)} images')

# --------------------------------------------------------------------
# 2.  MMSeg model settings (same DeepLabV3+ checkpoint as the model above)
# --------------------------------------------------------------------
CFG  = 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x1024_40k_cityscapes.py'
CKPT = ('https://download.openmmlab.com/mmsegmentation/v0.5/'
        'deeplabv3plus/deeplabv3plus_r101-d8_512x1024_40k_cityscapes/'
        'deeplabv3plus_r101-d8_512x1024_40k_cityscapes_20200605_094614-3769eecf.pth')

DEV = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# --------------------------------------------------------------------
# 3.  Calibrate (or reload previous statistics)
# --------------------------------------------------------------------
SAVE = '/content/drive/MyDrive/calibrator_state_ground_truth.pt'

if os.path.exists(SAVE):
    # Calibrator.load builds a *fresh* model internally from CFG/CKPT. No segmentor is built
    # here: the old code also loaded a second ResNet-101 onto the device and never used it on
    # this path. The name is still defined, for any later code that refers to it.
    mmseg_model = None
    calib = Calibrator.load(SAVE, mmseg_cfg=CFG, mmseg_ckpt=CKPT, device=DEV)
else:
    # Refuse to calibrate on nothing (e.g. Drive not mounted / wrong path). The result is saved
    # to SAVE and silently reloaded by every later run, so an empty pass would poison them all.
    if not images:
        raise FileNotFoundError(
            f'No *_leftImg8bit.png images found under {img_root}; refusing to calibrate.')
    mmseg_model = init_segmentor(CFG, CKPT, device=DEV)   # raw MMSeg model for the calibrator
    calib = Calibrator(mmseg_model)    # hand the model in
    calib.run(images, gts)             # long pass over the dataset
    calib.save(SAVE)


In [None]:
# ╔════════════════════════════════════════╗
# ║  Grid-search for LINe + Co-activation  ║
# ╚════════════════════════════════════════╝

import itertools
import gc

# Evaluate matched Fishyscapes pairs start..end-1 (original note: max 100 — presumably the number
# of matched pairs; slicing beyond it just yields fewer images).
start, end = 0, 100     # Max 100
matched_images_short      = matched_images[start:end]
matched_annotations_short = matched_annotations[start:end]
# NOTE(review): presumably indices into matched_images_short of the images to draw figures for —
# confirm against AnomalyDetector. [] = no figures, [0] = first image only,
# list(range(3, 8)) = images 4-8 (1-based).
images_to_be_plotted      = []  # [] = no figures; e.g. [0] or list(range(3, 8))

# Hyper-parameter lists (edit to sweep more values).
# Every list is one axis of the Cartesian grid built below. Three pairs of lists are zip()-ed
# rather than crossed, so each pair must have equal length (zip silently truncates):
# wt_thresholds / wt_u_thresholds, co_blur_sizes / co_sigmas,
# baseline_and_line_blur_ksizes / baseline_and_line_sigmas.
# With the values below the grid has 2 (temp_co) * 2 (wt pairs) * 3 (clips) * 2 (inv2ones)
# * 2 (bin_or_not) * 2 (c_u_ratio) = 96 combinations, each a full inference pass.
activation_clippings            = [100]
activation_prunings             = [0]
weight_prunings                 = [0]
temperatures_line               = [1]

inverse_convert_to_ones         = [False, True]
binary_or_not                   = [False, True]
wt_thresholds                   = [0, 0]
wt_u_thresholds                 = [1, 0.1]
temperatures_co                 = [0.3, 1]

baseline_and_line_blur_ksizes   = [(3, 3)]
baseline_and_line_sigmas        = [0]

co_blur_sizes                   = [23]
co_sigmas                       = [5]
max_pools                       = [11]

clips                           = [(0, 1000000), (250000, 10000000), (0, 0)]
co_weighteds                    = [0]
weight_modes                    = ['softmax']
use_c_u_ratios                  = [False, True]

plot_pixel_acts   = True
plot_pair_ratios  = True
# Two reference pixels, given as (coord, coord, colour): place one manually on an in-distribution
# region (lime) and one on an OOD object (red).
# NOTE(review): the coordinate axis order (row/col vs x/y) is defined by AnomalyDetector — confirm.
id_dot, ood_dot   = (850, 1900, 'lime'), (485, 1360, 'red') # (coord, coord, colour) — see note above

# Calibration statistics shared by every AnomalyDetector created in the sweep below.
state = calib.state_dict()

# One entry per hyper-parameter axis. itertools.product walks these in insertion order, so the
# first key varies slowest and the last key fastest. Paired settings are zipped so their values
# move together instead of being crossed.
grid = dict(
    temp_co=temperatures_co,
    activation_clipping=activation_clippings,
    activation_pruning=activation_prunings,
    weight_pruning=weight_prunings,
    wt_pair=list(zip(wt_thresholds, wt_u_thresholds)),
    co_blur_pair=list(zip(co_blur_sizes, co_sigmas)),
    clip_val=clips,
    temp_line=temperatures_line,
    co_weighted=co_weighteds,
    inv2ones=inverse_convert_to_ones,
    bin_or_not=binary_or_not,
    baseline_pair=list(zip(baseline_and_line_blur_ksizes, baseline_and_line_sigmas)),
    max_pool=max_pools,
    weight_mode=weight_modes,
    c_u_ratio=use_c_u_ratios,
)

# Parallel tuples: axis names and the list of candidate values for each axis.
param_names = tuple(grid.keys())
param_values = tuple(grid.values())

# Score lists (keys of det.get_ood_score_lists()) to evaluate, with their plot titles.
# This does not change between grid points, so it is built once, outside the loop.
methods = {
    "ood_scores_coact_bin":       "Co-act count",
    "ood_scores_coact_wt":        "Co-act magnitude",
    "ood_scores_coact_any":       "Co-act any",
    "ood_scores_nr_activation":   "#activations",
}

# 4. Tracking best metrics.
#    Start from +/-inf rather than -1 / 1, so the first evaluated run always records a tag.
#    With best_fpr = 1, a sweep where every run had FPR95 == 1.0 printed "→ None".
best_ap    = -float("inf")
best_fpr   =  float("inf")
best_auc   = -float("inf")
best_tag_ap  = best_tag_fpr = best_tag_auc = None

# 5. Iterate over the Cartesian product of all hyper-params
for combo in itertools.product(*param_values):
    P = dict(zip(param_names, combo))

    # unpack zipped pairs
    wt_thr,  wt_thr_u      = P["wt_pair"]
    co_blur_ksize, co_sigma = P["co_blur_pair"]
    baseline_blur, baseline_sigma = P["baseline_pair"]

    # create & configure a fresh detector for this grid point (deleted at the end of the iteration)
    det = AnomalyDetector(matched_images_short,
                          matched_annotations_short,
                          model,
                          state)
    det.plot_pixel_activations = plot_pixel_acts
    det.plot_pair_ratios       = plot_pair_ratios
    det.images_to_be_plotted   = images_to_be_plotted
    det.id_dot, det.ood_dot    = id_dot, ood_dot

    # set thresholds & params
    det.activation_clipping     = P["activation_clipping"]
    det.activation_pruning      = P["activation_pruning"]
    det.weight_pruning          = P["weight_pruning"]
    det.temperature_co          = P["temp_co"]
    det.temperature_line        = P["temp_line"]
    det.inverse_convert_to_ones = P["inv2ones"]
    det.make_feat_binary        = P["bin_or_not"]
    det.wt_threshold            = wt_thr
    det.wt_u_threshold          = wt_thr_u
    det.co_blur_ksize           = co_blur_ksize
    det.co_sigma                = co_sigma
    det.blur_ksize              = baseline_blur
    det.sigma                   = baseline_sigma
    det.max_pool                = P["max_pool"]
    det.clips                   = P["clip_val"]
    det.co_weighted             = P["co_weighted"]
    det.weight_mode             = P["weight_mode"]
    det.c_u_ratio               = P["c_u_ratio"]

    # run inference
    det.ood_inference()
    IDs, ood_scores = det.get_ood_score_lists()

    # build a reusable run tag ("k=v | k=v | ...")
    base_tag = " | ".join(f"{k}={P[k]}" for k in param_names)

    for method, title in methods.items():
        # Bug fix: the line that assigned `scores` had been commented out, so the
        # normalization below raised NameError on the very first evaluation.
        # NOTE(review): the method's raw scores are used as-is; confirm calculate_metrics expects
        # "higher = more anomalous". An earlier experiment negated them and divided by
        # ood_scores["ood_scores_line"].
        scores = np.asarray(ood_scores[method], dtype=float)
        # Min-max normalize to [0, 1]; the 1e-9 guards against constant score arrays.
        # np.ptp is used because ndarray.ptp was removed in NumPy 2.0.
        scores = (scores - scores.min()) / (np.ptp(scores) + 1e-9)

        # Fixed the malformed tag ("method | =<method><params>") to match base_tag's k=v style.
        run_tag = f"method={method} | {base_tag}"
        print(run_tag)

        auc, fpr, ap = det.calculate_metrics(IDs, scores,
                                             plot_hist=False,
                                             title=title)

        # Higher AP / AUROC is better; lower FPR95 is better.
        if ap > best_ap:
            best_ap, best_tag_ap = ap, run_tag
        if fpr < best_fpr:
            best_fpr, best_tag_fpr = fpr, run_tag
        if auc > best_auc:
            best_auc, best_tag_auc = auc, run_tag

    # cleanup GPU memory
    del det
    gc.collect()
    torch.cuda.empty_cache()

# 6. Print summary of best combinations
print("\n══════ BEST RESULTS ══════")
print(f"• Best AP    = {best_ap:.4f}  → {best_tag_ap}")
print(f"• Best FPR95 = {best_fpr:.4f}  → {best_tag_fpr}")
print(f"• Best AUROC = {best_auc:.4f}  → {best_tag_auc}")


In [None]:
# Free the last AnomalyDetector and PyTorch's cached GPU memory. Set `delete = False` to skip.
# `det` only survives if the grid-search loop was interrupted, because the loop deletes it after
# every completed run. The old unconditional `del det` therefore raised NameError after a normal
# sweep and broke "Run All".
delete = True
if delete:
    if "det" in globals():
        del det
    import gc
    gc.collect()
    import torch
    torch.cuda.empty_cache()