# Script to evaluate models fine-tuned with `finetune.py` on OOD benchmarks

In [None]:
# Change into the repository root (IPython `cd` magic), presumably so that the
# project modules imported in the next cell (model_builders, loaders, finetune,
# eval, eval_args) and the relative `base_exp_path` resolve from there.
cd ~/plp-official-tmlr2024/

In [None]:
import argparse
import glob
import json
import os
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn

from model_builders import load_backbone, get_embed_dim
from loaders.datasets import get_ood, get_num_classes
from finetune import get_args_parser, build_transform, get_head
from eval import *
from eval.utils import *
from eval_args import *

# Provide the paths to the checkpoints written by `finetune.py` experiments.
# To evaluate just a single experiment, set `paths` to a list holding that
# experiment's .pth path, e.g. paths = [Path(".../checkpoint_teacher.pth")].
base_path = Path("base_exp_path")  # root dir: evaluate all finetune experiments together
# Recursively collect every teacher checkpoint below `base_path`.
# pathlib glob patterns must be relative (an anchored pattern such as
# "/*/..." raises NotImplementedError), and rglob already descends into all
# subdirectories. Sorting makes the evaluation order, and therefore the row
# order of results.csv, deterministic.
paths = sorted(base_path.rglob("checkpoint_teacher*.pth"))
results = []

In [None]:
# Evaluate every checkpoint in `paths` on the OOD benchmark of its training
# dataset and collect one flattened result row per checkpoint in `results`.
if not paths:
    print(f"No checkpoints found under {base_path}; nothing to evaluate.")

for path in paths:
    print(path)
    # Hyper-parameters stored next to the checkpoint by the finetune run.
    with open(path.parent / "hp.json", "r", encoding="utf-8") as f:
        hp = json.load(f)

    # Start from the finetune defaults (no CLI arguments) and override what
    # this evaluation needs.
    args = get_args_parser().parse_args([])
    args.arch = hp["arch"]
    args.dataset = hp["dataset"]
    # Only the first OOD dataset returned by get_ood is evaluated.
    args.out_dist = get_ood(args.dataset)[0]
    args.head = False
    args.precomputed = False
    args.train_backbone = True
    args.batch_size = 64

    # Columns that identify this experiment in results.csv.
    exp_params = {
        "dataset": args.dataset,
        "out_dist": args.out_dist,
        "arch": args.arch,
        # NOTE(review): the column is named "freeze_percent" but reads
        # hp["freeze_blocks"]; confirm which quantity finetune.py stores.
        "freeze_percent": hp["freeze_blocks"],
    }

    args.nb_classes = get_num_classes(args.dataset)

    # Load onto the CPU so checkpoints saved on any device can be read; the
    # assembled model is moved to the GPU after its weights are restored.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True;
    # if the checkpoint pickles non-tensor objects, pass weights_only=False
    # (only for trusted files).
    model_ckpt = torch.load(path, map_location="cpu")
    backbone, _ = load_backbone(args.arch)
    transform = build_transform(False, args)
    embed_dim = get_embed_dim(args=None, model=backbone)
    head = get_head(embed_dim, args.nb_classes, args.init_scale)
    model = nn.Sequential(backbone, head)
    # Evaluate the EMA weights (use "model" instead for the raw weights);
    # strict=True raises if the checkpoint does not match the architecture.
    model.load_state_dict(model_ckpt["model_ema"], strict=True)
    model.cuda()
    epoch = model_ckpt["epoch"]

    # Enable every OOD scoring method for this evaluation.
    args_eval = get_eval_args_parser().parse_args([])
    args_eval.eval_ood_knn = True
    args_eval.eval_ood_maha = True
    args_eval.eval_ood_norm = True
    args_eval.eval_ood_logits = True
    args_eval.dataset, args_eval.out_dist = args.dataset, args.out_dist

    extractor = FeatureExtractionPipeline(
        args, cache_backbone=False, model=model, transform=transform
    )
    # In-distribution features/labels (test labels are not needed here).
    train_features, test_features, train_labels, _ = extractor.get_train_logits(
        return_feats=True
    )
    # OOD test features (OOD labels are not needed here).
    ood_dl = extractor.get_dataloader(args.out_dist, train=False)
    test_features_ood, _ = extractor.get_logits(ood_dl)

    res_dict_ood_ckpt = eval_ood(
        args_eval, epoch, test_features, test_features_ood, train_features, train_labels
    )
    res = flatten_result(res_dict_ood_ckpt)
    res.update(exp_params)
    results.append(res)

    # Drop every reference to this checkpoint's model, weights and features
    # before the next iteration. Otherwise the previous model stays alive
    # (e.g. through `extractor`) while the next one is moved to the GPU, and
    # empty_cache() cannot release its memory.
    del model, backbone, head, model_ckpt, extractor, ood_dl
    del train_features, test_features, test_features_ood, train_labels
    torch.cuda.empty_cache()

df = pd.DataFrame(results)
if results:
    df.to_csv(base_path / "results.csv")

In [None]:
# Display the collected per-checkpoint OOD results (last expression of the cell).
df