## Visualize and compare models on ID/OOD Performance

In this notebook, you can compare models across different training and architecture paradigms. Provided that the results of each model run (such as a checkpoint and a counts file) have been saved, different model configurations can be compared through graphs and visualizations of their score distributions. Below, we provide code to compare a learned-temperature model scored with AbeT against a standard model scored with Standardized Max Logit, Max Logit, Entropy, and Max Softmax Probability.

In [None]:
%%capture
!pip install h5py Cython

In [None]:
from config import config_evaluation_setup
from src.imageaugmentations import Compose, Normalize, ToTensor
from evaluation import eval_pixels
from viz_utils import twod_to_threed, get_crops, plot_curve, plot_curve_comparisons, plot_score_graphs
from src.helper import counts_array_to_data_list
from src.model_utils import inference
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json

In [None]:
##### ROOT ARGS
# Defaults shared by every comparison; each entry further down overrides a
# subset of these keys.
args = dict(
    TRAINSET=None,
    VALSET='LostAndFound',
    # VALSET='RoadAnomaly',
    split='test',
    MODEL=None,
    val_epoch=None,
    pareto_alpha=None,
    pixel_eval=True,
    segment_eval=False,
    temperature_model="none",
    checkpoint="REQUIRE_CHECKPONT.PTH",
    score_function='entropy',
    ood_finetune='FALSE',
    name=None,
)

comparisons = []

for overrides in (
    #########################
    # TEMPERATURE MODEL
    #########################
    {
        "temperature_model": "learned",
        "checkpoint": " /your/path/to/Abet/weights/full_learned_temp_best.pth",
        "score_function": "abet",
        "name": "AbeT (Ours)",
    },
    ###################################
    # Comparison models and scoring methods: cityscapes_best.pth from
    # https://github.com/NVIDIA/semantic-segmentation/tree/sdcnet
    ###################################
    {
        "checkpoint": " /your/path/to/Abet/weights/cityscapes_best.pth",
        "score_function": "sml",
        "temperature_model": "none",
        "name": "Standardized\nMax Logits",
    },
    {
        "checkpoint": " /your/path/to/Abet/weights/cityscapes_best.pth",
        "score_function": "max_logit",
        "temperature_model": "none",
        "name": "Max Logit",
    },
    {
        "checkpoint": " /your/path/to/Abet/weights/cityscapes_best.pth",
        "score_function": "entropy",
        "temperature_model": "none",
        "name": "Entropy",
    },
    {
        "checkpoint": " /your/path/to/Abet/weights/cityscapes_best.pth",
        "score_function": "msp",
        "temperature_model": "none",
        "name": "MSP",
    },
):
    # Shallow-merge the overrides on top of the defaults; key order follows
    # `args` because every overridden key already exists there.
    these_args = {**args, **overrides}
    comparisons.append(these_args)

print(comparisons)


In [None]:
class EvaluationRun:
    """Bundle everything needed to score and visualize one comparison entry.

    On construction this builds the project's ``eval_pixels`` and
    ``inference`` helpers for the run, loads the saved ``data.json`` from the
    evaluator's ``save_path_data`` directory, and resolves the decision
    threshold used by the visualizations.

    Attributes set here:
        name, config, args, datloader, run_name_str: the constructor
            arguments, stored unchanged.
        eval_pixels: evaluator built from ``config`` and ``args``.
        inference: inference helper built from ``config`` and ``datloader``.
        root_path: ``eval_pixels.save_path_data``.
        scores_dict: parsed contents of ``<root_path>/data.json``.
        threshold: ``scores_dict["threshold"]``; when the file does not store
            one, the ROC threshold whose TPR is closest to 0.95.
    """

    def __init__(self, name, config, args, datloader, run_name_str) -> None:
        self.name = name
        self.config = config
        self.args = args
        self.run_name_str = run_name_str
        self.datloader = datloader

        self.eval_pixels = eval_pixels(
            self.config.params,
            self.config.roots,
            self.config.dataset,
            self.args,
            run_name_str=self.run_name_str,
        )  # params, roots, dataset, args, run_name_str
        self.inference = inference(
            self.config.params,
            self.config.roots,
            self.datloader,
            self.datloader.num_eval_classes,
            self.run_name_str,
        )
        self.root_path = self.eval_pixels.save_path_data

        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(os.path.join(self.root_path, "data.json"), "r") as scores_file:
            self.scores_dict = json.load(scores_file)

        if "threshold" not in self.scores_dict:
            # Fall back to the operating point closest to 95% TPR on the
            # stored ROC curve.
            ind = (np.abs(np.array(self.scores_dict["roc_tpr"]) - 0.95)).argmin()
            self.scores_dict["threshold"] = self.scores_dict["roc_thresholds"][ind]
        self.threshold = self.scores_dict["threshold"]
            

In [None]:
# Display name -> EvaluationRun, in the order the comparisons were declared.
comparison_eval_runs = {}

for comparison_args in comparisons:
    print("=" * 50)
    config = config_evaluation_setup(comparison_args)
    config.params.temperature_model = comparison_args["temperature_model"]
    config.params.checkpoint = comparison_args["checkpoint"]

    # Checkpoint file name with its last four characters (e.g. ".pth") dropped.
    ckpt_stem = config.params.checkpoint.split('/')[-1][:-4]
    run_name_str = (
        f"valset_{comparison_args['VALSET']}_"
        f"split_{comparison_args['split']}_"
        f"{comparison_args['temperature_model']}_temperature_model_"
        f"{comparison_args['score_function']}_"
        f"OODFT_{comparison_args['ood_finetune']}"
        f"_ckpt_{ckpt_stem}"
    )
    print(run_name_str)

    # Same preprocessing for every run: tensor conversion followed by
    # normalization with the dataset's own statistics.
    transform = Compose([ToTensor(), Normalize(config.dataset.mean, config.dataset.std)])
    datloader = config.dataset(
        root=config.roots.eval_dataset_root,
        transform=transform,
        split=comparison_args["split"],
    )

    # ASSERT SCORING FUNCTION WORKS WITH MODEL TYPE
    assert (
        comparison_args["score_function"] not in ["energy", "godin"]
        or "learned" == comparison_args["temperature_model"]
    )

    eval_run = EvaluationRun(
        comparison_args['name'], config, comparison_args, datloader, run_name_str
    )
    comparison_eval_runs[eval_run.name] = eval_run
    

To reproduce our figure in the AbeT paper, evaluate the results on indices 80 and 1117 in LostAndFound and 3 and 23 in RoadAnomaly.

In [None]:
from collections import defaultdict

# Dataset indices of the images to visualize.
eval_inds = np.arange(4)
print(eval_inds)

im_paths = []  # image paths of the kept images, recorded once (first model)
name_to_scores = defaultdict(list)  # model name -> per-image score maps
name_to_labels = defaultdict(list)  # model name -> per-image train-id label maps

for i, (name, eval_run) in enumerate(comparison_eval_runs.items()):
    for eval_ind in eval_inds:
        outputs, gt_train, gt_label, im_path = eval_run.inference.probs_gt_load(eval_ind)

        # Skip images with fewer than 50 OOD pixels.
        if np.sum(gt_train == datloader.train_id_out) < 50:
            continue

        if isinstance(outputs, tuple):
            logits, numerators, temperatures = outputs
        else:
            logits = outputs
            # Models that return only logits have no numerators/temperatures.
            # (The original assigned to a misspelled `nummerators`, leaving
            # `numerators` undefined or stale from a previous model.)
            numerators, temperatures = None, None

        scores = eval_run.eval_pixels.score_fn(
            logits, numerators, temperatures, datloader.num_eval_classes
        )
        name_to_scores[name].append(scores)
        name_to_labels[name].append(gt_train)

        if i == 0:
            im_paths.append(im_path)

In [None]:
def get_crops(image, train_labels, scores, ood_ind=2, border=50, aspect_ratio=1.7777777):
    """Crop image, labels and scores to a window around the OOD region.

    The window spans the bounding box of all pixels equal to ``ood_ind`` plus
    ``border`` pixels, clamped to the array along axis 0. Along axis 1 it is
    then re-centred and resized to ``aspect_ratio`` times its axis-0 extent
    (the default 1.7777777 makes LostAndFound crops match RoadAnomaly).
    If the resized window sticks out of the array along axis 1, it is shifted
    back inside; it is only shrunk when it is wider than the array itself.

    Note: "x" below refers to axis 0 and "y" to axis 1 of ``train_labels``.

    Args:
        image: array indexed as ``image[x, y, channel]``.
        train_labels: 2-D label map.
        scores: 2-D score map indexed like ``train_labels``.
        ood_ind: label value marking OOD pixels.
        border: margin, in pixels, added around the OOD bounding box.
        aspect_ratio: axis-1 extent divided by axis-0 extent of the crop.

    Returns:
        Tuple ``(image_crop, labels_crop, scores_crop)`` sharing one window.

    Raises:
        ValueError: if ``train_labels`` contains no ``ood_ind`` pixel.
    """
    size_x, size_y = train_labels.shape
    xx, yy = np.where(train_labels == ood_ind)
    if xx.size == 0:
        raise ValueError(
            f"train_labels contains no pixels with label {ood_ind}; nothing to crop around"
        )

    minx = max(0, np.min(xx) - border)
    maxx = min(np.max(xx) + border, size_x)
    miny = max(0, np.min(yy) - border)
    maxy = min(np.max(yy) + border, size_y)

    # Resize the axis-1 extent around its centre to reach the target ratio.
    x_width = maxx - minx
    y_half_width = x_width / 2.0 * aspect_ratio
    y_center = (maxy + miny) / 2.0
    maxy = int(y_center + y_half_width)
    miny = int(y_center - y_half_width)

    # Keep the window inside the array. A negative start would otherwise be
    # interpreted by numpy as "count from the end" and yield an empty or
    # unrelated crop.
    if miny < 0:
        maxy -= miny
        miny = 0
    if maxy > size_y:
        miny -= maxy - size_y
        maxy = size_y
    miny = max(miny, 0)

    return (
        image[minx:maxx, miny:maxy, :],
        train_labels[minx:maxx, miny:maxy],
        scores[minx:maxx, miny:maxy],
    )

def get_rgb_norm_scores(scores, train_labels, in_id, out_id, thresh, bottom_clip=0, top_clip=1):
    """Render above-threshold scores as intensities in the red channel.

    Scores are shifted by ``thresh``, clipped to ``[0, max(scores)]``,
    min-max normalized, and written to channel 0 of the array returned by
    ``twod_to_threed``. Pixels whose label is neither ``in_id`` nor ``out_id``
    (VOID) are set to black.

    When the clipped scores are constant (e.g. nothing exceeds the
    threshold), the normalized map is all zeros rather than the NaNs a 0/0
    division would produce.

    Args:
        scores: 2-D score map.
        train_labels: label map indexed like ``scores``.
        in_id: label value of in-distribution pixels.
        out_id: label value of OOD pixels.
        thresh: score threshold subtracted before clipping.
        bottom_clip: unused; kept for call compatibility.
        top_clip: unused; kept for call compatibility.

    Returns:
        The 3-channel array with normalized scores in channel 0.
    """
    rgb_scores = twod_to_threed(np.zeros_like(scores).astype(float))
    shifted = np.clip(scores - thresh, 0, np.max(scores))

    lowest = np.min(shifted)
    span = np.max(shifted) - lowest
    if span > 0:
        norm_scores = (shifted - lowest) / span
    else:
        # Constant map: avoid dividing by zero and show nothing highlighted.
        norm_scores = np.zeros_like(shifted, dtype=float)

    rgb_scores[:, :, 0] = norm_scores
    # rgb_scores = twod_to_threed(norm_scores) # for white masks
    rgb_scores[(train_labels != in_id) & (train_labels != out_id)] = [0, 0, 0]  # set VOID to black

    return rgb_scores

def get_rgb_labels(train_labels, in_id, out_id):
    """Build a 3-channel uint8 label image with OOD pixels painted red.

    ``in_id`` is accepted for signature symmetry with the other helpers but
    is not used; every pixel not equal to ``out_id`` keeps the value produced
    by ``twod_to_threed`` on an all-zero map.
    """
    blank = np.zeros_like(train_labels, dtype=np.uint8)
    rgb_labels = twod_to_threed(blank)
    ood_mask = train_labels == out_id
    rgb_labels[ood_mask] = [255, 0, 0]
    # rgb_labels[ood_mask] = [255, 255, 255]  # white masks
    return rgb_labels
    

In [None]:
from sklearn.metrics import average_precision_score, precision_recall_curve
from src.calc import calc_precision_recall, calc_sensitivity_specificity, get_tpr95_ind
from src.helper import counts_array_to_data_list

def run_eval(scores, labels, thresh, datloader):
    """Compute ROC / PR statistics for a single score map.

    Scores of in-distribution and OOD pixels are binned into 100 equal-width
    bins over [0, 1]; the resulting counts feed the project's ROC and PR
    helpers.

    Args:
        scores: per-pixel score array.
        labels: per-pixel label array indexed like ``scores``.
        thresh: unused; kept for call compatibility.
        datloader: object exposing ``train_id_in`` and ``train_id_out``.

    Returns:
        Tuple ``(roc_stats, pr_stats, fpr95, data)``. ``data`` is the dict of
        "in" / "out" histogram counts. When there are no OOD pixels the first
        three entries are ``None``.
    """
    data = {"in": np.zeros(100), "out": np.zeros(100)}
    bins = np.linspace(0, 1, 101)

    in_scores = scores[labels == datloader.train_id_in]
    out_scores = scores[labels == datloader.train_id_out]

    data["in"] += np.histogram(in_scores, bins=bins, density=False)[0]

    if len(out_scores) == 0:
        # Nothing to separate: no curve can be computed. The original branch
        # unpacked three values into two names and left `fpr95` undefined.
        return None, None, None, data

    data["out"] += np.histogram(out_scores, bins=bins, density=False)[0]

    roc_fpr, roc_tpr, roc_thresholds, auroc = calc_sensitivity_specificity(data, balance=True)
    roc_stats = dict(fprs=roc_fpr, tprs=roc_tpr, thresholds=roc_thresholds, auroc=auroc)

    pr_precision, pr_recall, pr_thresholds, auprc = calc_precision_recall(data)
    pr_stats = dict(precisions=pr_precision, recalls=pr_recall, thresholds=pr_thresholds, auprc=auprc)

    _, _, fpr95 = get_tpr95_ind(roc_fpr, roc_tpr)
    return roc_stats, pr_stats, fpr95, data

In [None]:
def counts_data_to_lists(data):
    """Expand the "in" / "out" histogram counts into two sample lists.

    Each count array is passed to ``counts_array_to_data_list`` together with
    a budget of ``1e7`` scaled by that class's share of the total count.

    Returns:
        Tuple ``(in_list, out_list)``.
    """
    total_in = np.sum(data["in"])
    total_out = np.sum(data["out"])
    ratio_in = total_in / (total_in + total_out)
    ratio_out = 1 - ratio_in
    in_list = counts_array_to_data_list(np.array(data["in"]), 1e7 * ratio_in)
    out_list = counts_array_to_data_list(np.array(data["out"]), 1e7 * ratio_out)
    return in_list, out_list


def plot_counts_histogram(data, ax, num_bins, thresh):
    """Draw per-class normalized ID / OOD score histograms on ``ax``.

    Args:
        data: dict with "in" and "out" count arrays.
        ax: matplotlib axes to draw on.
        num_bins: number of equal-width bins over [0, 100].
        thresh: optional threshold; drawn as a black vertical line at
            ``thresh * 100`` when not ``None``.
    """
    id_samples, ood_samples = counts_data_to_lists(data)
    edges = np.linspace(0, 100, num_bins + 1)
    # Weight every sample by 1/len so each class's bars sum to one.
    per_class_weights = [
        np.ones(len(samples)) / len(samples) for samples in (id_samples, ood_samples)
    ]
    ax.hist(
        [id_samples, ood_samples],
        edges,
        label=["ID", "OOD"],
        weights=per_class_weights,
    )
    if thresh is not None:
        ax.axvline(thresh * 100, color="black")
    ax.legend()

In [None]:
# Crop around the OOD object only when the dataset name contains "F"
# ('LostAndFound' does, 'RoadAnomaly' does not).
use_crop = "F" in comparisons[0]['VALSET']
show_hist = False  # show histogram of scores
# NOTE: the per-image caption code (AUROC / AUPRC / FPR95) is disabled, so
# this flag currently has no effect.
show_captions = False  # print the scores for each image and model
f, ax = plt.subplots(len(eval_inds) + show_hist, len(name_to_scores) + 1 + show_hist, figsize=(12, 8))

# cols: column 0 holds the input image, column 1 the optional histogram, and
# the remaining columns hold one score map per compared method.
for j, name in enumerate(name_to_scores.keys()):
    scores_list = name_to_scores[name]
    labels_list = name_to_labels[name]

    # rows: one per kept image
    for i, (scores, labels) in enumerate(zip(scores_list, labels_list)):
        print(i)
        im_path = im_paths[i]
        im = np.array(Image.open(im_path).convert("RGB"))  # [0,255]
        thresh = comparison_eval_runs[name].threshold

        if use_crop:
            cropped_im, cropped_labels, cropped_scores = get_crops(im, labels, scores, ood_ind=datloader.train_id_out, border=40)
        else:
            # Without cropping, the "cropped" views are the full arrays. The
            # original left `cropped_scores` undefined on this path, which
            # raised a NameError as soon as `show_hist` was enabled.
            cropped_im, cropped_labels, cropped_scores = im, labels, scores

        cropped_rgb_norm_scores = get_rgb_norm_scores(
            cropped_scores,
            cropped_labels,
            in_id=datloader.train_id_in,
            out_id=datloader.train_id_out,
            thresh=thresh,
            bottom_clip=0,
            top_clip=np.max(scores),
        )
        ax[i, j + 1 + show_hist].imshow(cropped_rgb_norm_scores)

        if j == 0:
            # Draw the input image once, tinting OOD pixels red and blacking
            # out pixels that are neither in-distribution nor OOD.
            masked_im = cropped_im.copy() / 255.
            ood_delta = np.where(cropped_labels == datloader.train_id_out, 0.2, 0.0)

            masked_im[:, :, 0] += ood_delta
            masked_im[:, :, 1] -= ood_delta
            masked_im[:, :, 2] -= ood_delta

            masked_im[(cropped_labels != datloader.train_id_in) & (cropped_labels != datloader.train_id_out)] = [0, 0, 0]
            masked_im = np.clip(masked_im, 0, 1)
            ax[i, 0].imshow(masked_im)
            if show_hist:
                cropped_roc_stats, cropped_pr_stats, cropped_fpr95, cropped_data = run_eval(cropped_scores, cropped_labels, thresh, datloader)
                plot_counts_histogram(cropped_data, ax[i, 1], 100, thresh)


# Hide ticks and spines everywhere except on the histogram column.
for i, row_ax in enumerate(ax):
    for j, row_col_ax in enumerate(row_ax):
        if j == 1 and show_hist:
            continue
        row_col_ax.tick_params(axis='both', left=False, bottom=False, labelleft=False, labelbottom=False)
        for loc, spine in row_col_ax.spines.items():
            spine.set_color("white")

plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)

# Column titles go under the bottom row.
ax[-1, 0].set_xlabel("Image with \nOOD Label", fontweight="bold", fontsize="x-large")
if show_hist:
    ax[-1, 1].set_xlabel("counts hist")
for j, (name, eval_run) in enumerate(comparison_eval_runs.items()):
    ax[-1, j + 1 + show_hist].set_xlabel(f"{eval_run.name}", fontweight="bold", fontsize="x-large")

all_names_str = "_".join(name_to_scores.keys()).replace(" ", "_").replace('\n', '')
dataset_str = f"{comparisons[0]['VALSET']}_{comparisons[0]['split']}"
save_dir = "/your/path/to/Abet/io/viz_plots/"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"{all_names_str}_{dataset_str}_matrix_viz.png")
plt.savefig(save_path, bbox_inches="tight")
print(f"saved to {save_path}")
plt.show()

    
