# Prerequisites

In [1]:
import sys

# Add the parent directory to the module search path; the `dataset`, `perlin`
# and `util` imports below presumably resolve from there — TODO confirm layout.
sys.path.append("..")

In [2]:
import argparse
import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import display
from torchvision.models import resnet18, resnet50
from tqdm.notebook import tqdm

from dataset import (
    get_dloader,
    normalize_hw,
    normalize_hw_mask,
    normalize_inv_hw,
    normalize_inv_hw_mask,
)
from perlin import get_rgb_fractal_noise
from util import eval_step, get_obj_score, get_performance, get_saliency

In [3]:
# Directory holding the dataset, relative to this notebook.
data_dir = "../data"
# Class names; the number of entries sets the output width of the classifier
# head built by get_model.
class_legend = ("Siberian Husky", "Grey Wolf")
# Short architecture keys (as accepted by get_model) mapped to display names.
model_types = {"r18": "ResNet 18", "r50": "ResNet 50"}

# Prefer the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")


def get_model(model_type, device="cpu", seed=191510):
    """Build an untrained ResNet whose final layer matches ``class_legend``.

    Args:
        model_type: Architecture key, ``"r18"`` (``resnet18``) or ``"r50"``
            (``resnet50``).
        device: Device the model is moved to before it is returned.
        seed: Value passed to ``torch.manual_seed`` before the network is
            constructed. Note that this reseeds torch's global RNG.

    Returns:
        The network, with ``fc`` replaced by a ``torch.nn.Linear`` layer that
        has ``len(class_legend)`` outputs, placed on ``device``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported keys.
    """
    # Validate before touching the global RNG so that a bad key (for example
    # ``None`` from a config without "model_type") fails loudly and without
    # side effects, instead of surfacing later as an unbound ``model``.
    if model_type not in ("r18", "r50"):
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'r18' or 'r50'."
        )

    torch.manual_seed(seed)
    if model_type == "r18":
        model = resnet18(weights=None)
    else:
        model = resnet50(weights=None)
    model.fc = torch.nn.Linear(
        in_features=model.fc.in_features, out_features=len(class_legend), bias=True
    )
    model.to(device)
    return model

Using device: cuda


# Evaluation

In [4]:
# Evaluate the latest checkpoint of every model variant ("a", "b", "c") of
# every run under `exp_dir`, writing one CSV with accuracy and mean object
# score per (run, variant) and, for runs configured with `spurious`, a second
# CSV computed on the spurious data loader.
exp_dir = "../models/hw-final"
split = "test"

# Normalisation function applied to the inputs of each model variant.
norms = {"a": normalize_hw, "b": normalize_hw_mask, "c": normalize_hw_mask}
dloader = get_dloader(
    split, batch_size=1, data_dir=data_dir, noise=True, num_workers=0
)
dloader_spurious = get_dloader(
    split, batch_size=1, data_dir=data_dir, noise=True, spurious=True, num_workers=0
)


def _checkpoint_epoch(cpt_name):
    """Parse the epoch number out of a checkpoint file name.

    Takes the second "_"-separated field, cuts it at the first ".", drops its
    first character and converts the remainder to ``int``.
    """
    return int(cpt_name.split("_")[1].split(".")[0][1:])


for exp in os.listdir(exp_dir):
    exp_runs = f"{exp_dir}/{exp}"
    if not os.path.isdir(exp_runs):
        # Ignore stray files sitting next to the experiment folders.
        continue
    for run in os.listdir(exp_runs):
        run = f"{exp_runs}/{run}"
        if not os.path.isdir(run):
            continue
        with open(f"{run}/config.json", "r") as f:
            config = json.load(f)
        spurious = config.get("spurious", False)
        model_type = config.get("model_type")
        for m in "abc":
            m_dir = f"{run}/{m}"
            if not os.path.isdir(m_dir):
                # Treated like a variant without checkpoints: nothing to do.
                continue
            cpts = os.listdir(m_dir)
            if len(cpts) == 0:
                continue

            dst = f"{run}/eval_{m}.csv"
            dst_spurious = f"{run}/eval_spurious_{m}.csv"

            # Decide what still has to be computed *before* building the model,
            # so re-running the cell does not reload checkpoints for nothing.
            dloaders = [dloader]
            dsts = [dst]
            if spurious:
                dloaders.append(dloader_spurious)
                dsts.append(dst_spurious)
            pending = [
                (loader, path)
                for loader, path in zip(dloaders, dsts)
                if not os.path.exists(path)
            ]
            if not pending:
                continue

            # Checkpoint with the highest epoch (first one wins on ties).
            last_cpt = max(cpts, key=_checkpoint_epoch)

            model = get_model(model_type, device=device)
            # map_location lets checkpoints saved on a GPU load on a CPU-only host.
            checkpoint = torch.load(f"{m_dir}/{last_cpt}", map_location=device)
            state_dict = checkpoint["model_state_dict"]
            # Drop "feature_extractor" entries in place before loading them
            # into the plain ResNet returned by get_model.
            for k in [k for k in state_dict if "feature_extractor" in k]:
                del state_dict[k]
            model.load_state_dict(state_dict)
            model.eval()
            del checkpoint, state_dict

            norm = norms[m]
            for dloader_, dst_ in pending:
                # Keys below name the *input* variant fed to the model:
                #   "a": normalised image,
                #   "b": normalised image multiplied by the mask,
                #   "c": image with noise substituted where the mask is off,
                #        then normalised.
                preds = {"a": [], "b": [], "c": []}
                obj_scores = {"a": [], "b": [], "c": []}
                gt_labels = []
                for imgs, labels, masks, noise in tqdm(dloader_):
                    inputs = {
                        "a": norm(imgs),
                        "b": norm(imgs) * masks,
                        # `~masks` assumes a boolean mask tensor — TODO confirm.
                        "c": norm(imgs * masks + noise * (~masks)),
                    }

                    for k, imgs_in in inputs.items():
                        # Runs configured with `spurious` are only scored on
                        # the unmodified input variant.
                        if k != "a" and spurious:
                            continue
                        slc, _, indices, _ = get_saliency(
                            model, imgs_in, device=device
                        )
                        obj_scores[k].append(get_obj_score(slc, masks))
                        preds[k].append(indices.item())
                    gt_labels.append(labels.item())
                gt_labels = np.array(gt_labels)

                results = {}
                for k, p in preds.items():
                    if len(p) == 0:
                        # Input variant was skipped for this run.
                        continue
                    p = np.array(p)
                    acc = (p == gt_labels).mean()
                    obj_score = np.mean(obj_scores[k])
                    results[k] = {"Accuracy": acc, "Object Score": obj_score}
                results = pd.DataFrame.from_dict(results, orient="index")
                results.to_csv(dst_)

  0%|          | 0/629 [00:00<?, ?it/s]

  0%|          | 0/629 [00:00<?, ?it/s]

  0%|          | 0/629 [00:00<?, ?it/s]