In [None]:
# necessary imports
import torch
import pickle
from openood.evaluation_api import Evaluator
from openood.networks import ResNet18_32x32
from openood.networks.lenet import LeNet # just a wrapper around the ResNet

from collections import OrderedDict

In [None]:
# Path to the trained checkpoint (placeholder - fill in before running).
file_path = ".pth.tar"

# Model definition: ResNet-18 for 32x32 inputs with 10 output classes.
net = ResNet18_32x32(num_classes=10)

# Load the checkpoint onto the CPU. This lets the file be opened without a GPU,
# and it avoids keeping a second copy of the weights in GPU memory:
# `checkpoint` / `new_state_dict` stay alive as notebook globals after this cell.
checkpoint = torch.load(file_path, map_location="cpu")

# Some checkpoints wrap the weights as {"state_dict": ...}; otherwise the
# loaded object is assumed to already be a plain state dict.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]

# Strip the leading 'module.' that nn.DataParallel adds to every key. Only the
# prefix is removed, so a 'module.' occurring elsewhere in a key is left intact.
PREFIX = "module."
new_state_dict = OrderedDict(
    (key[len(PREFIX):] if key.startswith(PREFIX) else key, value)
    for key, value in checkpoint.items()
)

# Copy the weights into the model (strict key matching, the default).
net.load_state_dict(new_state_dict)

# Move the model to the GPU and switch to inference mode.
net.cuda()
net.eval()

In [None]:
# OpenOOD postprocessors to benchmark, in evaluation order.
OOD_list = ['msp', 'odin', 'ebo', 'gradnorm', 'react', 'mls', 'klm', 'vim', 'knn',
            'dice', 'rankfeat', 'ash', 'she', 'mds', 'rmds', 'gram', 'mds_ensemble',
            'temp_scaling', 'openmax']

# Evaluation settings shared by every postprocessor.
ID_NAME = 'cifar10'    # the target ID dataset
DATA_ROOT = './data'   # change if necessary
BATCH_SIZE = 200       # for certain methods the results can be slightly affected by batch size
NUM_WORKERS = 2        # could use more num_workers outside colab


def evaluate_method(model, method):
    """Run the standard (non-FSOOD) OpenOOD benchmark for one postprocessor.

    Args:
        model: the trained classifier to evaluate.
        method: OpenOOD postprocessor name, e.g. 'msp'.

    Returns:
        A tuple (near_auroc, far_auroc, accuracy) read from the metrics that
        `Evaluator.eval_ood` returns: AUROC of the 'nearood' and 'farood'
        entries, and the 'ACC' value of the 'nearood' entry.
    """
    evaluator = Evaluator(
        model,
        id_name=ID_NAME,
        data_root=DATA_ROOT,
        config_root=None,          # see the OpenOOD Evaluator notes
        preprocessor=None,         # default preprocessing for the target ID dataset
        postprocessor_name=method,
        postprocessor=None,        # no custom postprocessor object
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )
    metrics = evaluator.eval_ood(fsood=False)
    return (metrics['AUROC']['nearood'],
            metrics['AUROC']['farood'],
            metrics['ACC']['nearood'])


# Without at least one method there is no accuracy to record.
if not OOD_list:
    raise ValueError("OOD_list is empty: nothing to evaluate")

res = []
for method in OOD_list:
    print(f"-------------------{method}----------------------")
    near, far, accuracy = evaluate_method(net, method)
    res.append([method, near, far])

# Keep the original output layout: one [method, near_auroc, far_auroc] row per
# method, followed by a single trailing accuracy value.
# NOTE(review): this is the ACC reported for the *last* method in OOD_list;
# confirm whether ACC can differ between postprocessors.
res.append(accuracy)

# Destination of the pickled results (placeholder - fill in before running).
results_path = '.pkl'

with open(results_path, 'wb') as f:
    pickle.dump(res, f)

print(f"Results have been saved to {results_path}")