-
Notifications
You must be signed in to change notification settings - Fork 4
/
cdfsl_test.py
146 lines (98 loc) · 4.57 KB
/
cdfsl_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Evaluate ProtoNet checkpoints of W&B runs on CD-FSL benchmarks.

For every run id listed in ``other/runs.csv`` the script resumes the W&B run,
restores its ``best_model.tar`` and ``last_model.tar`` checkpoints, evaluates
the backbone on the ISIC, EuroSAT and Chest few-shot test sets, logs the
accuracies to W&B and appends them to ``other/cdfsl_results.txt`` (errors go to
``other/cdfsl_results_logs.txt``).
"""
import warnings
# Silence every Python warning for the whole run.
warnings.filterwarnings("ignore")
import numpy as np
import time
import os
import glob
import random
import sys
from utils.io_utils import set_seed, parse_args
# Parse CLI options in 'test' mode. NOTE(review): this happens before torch is
# imported, presumably so parse_args can affect torch setup (e.g. device
# environment variables) -- confirm in utils.io_utils before reordering.
params = parse_args('test')
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
# Seed the run from the --seed option (presumably python/numpy/torch RNGs --
# see utils.io_utils.set_seed).
set_seed(params.seed)
import config.configs as configs
import models.backbone as backbone
from data.datamgr_2loss import SimpleDataManager, SetDataManager
from methods.protonet_2loss import ProtoNet
from utils.io_utils import model_dict, get_resume_file, get_best_file, get_assigned_file
import json
from models.model_resnet import *
from utils.utils import RunningAverage, Logger, wandb_restore_models
from tqdm import tqdm
import wandb
# NOTE(review): several imports above (glob, random, json, nn, Variable,
# lr_scheduler, configs, backbone, SimpleDataManager, tqdm, ...) appear unused
# in this script.
from data.cdfsl import Chest_few_shot
# NOTE(review): CropDisease_few_shot is imported but not evaluated below.
from data.cdfsl import CropDisease_few_shot
from data.cdfsl import EuroSAT_few_shot
from data.cdfsl import ISIC_few_shot
import csv
# Result lines and error messages are appended to these files across runs.
out_file = open("other/cdfsl_results.txt", "a")
log_file = open("other/cdfsl_results_logs.txt", "a")
# A single timestamp for this invocation, stamped on every result line.
timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())

# CD-FSL test sets to evaluate on, keyed by the name used in logs/W&B keys.
datamanagers = {
    "ISIC": ISIC_few_shot.SetDataManager,
    "EuroSAT": EuroSAT_few_shot.SetDataManager,
    "Chest": Chest_few_shot.SetDataManager,
}

# dataloaders[dataset_name][str(image_size)] -> episodic test loader built with
# n_way=5, n_support=5, n_query=16 and 600 episodes, without augmentation.
# Loaders exist for both image sizes (224 and 84) a run may have used.
dataloaders = {}
for dset, manager_cls in datamanagers.items():
    dataloaders[dset] = {}
    for size in (224, 84):
        datamgr = manager_cls(size, n_query=16, n_eposide=600, n_way=5, n_support=5)
        dataloaders[dset][str(size)] = datamgr.get_data_loader(aug=False)
# Number of episodes the CD-FSL data managers are built with (n_eposide above);
# used for the 95% confidence interval of the mean accuracy.
_N_EPISODES = 600


def _backbone_state(state):
    """Return only the backbone weights of a ProtoNet checkpoint state dict.

    The full model keeps its backbone under the attribute ``feature``, so a key
    such as ``feature.trunk.0.weight`` is renamed to ``trunk.0.weight`` and can
    be loaded straight into ``model.feature``. Keys without ``feature.`` are
    dropped. A new dict is built instead of popping/re-inserting in place, so a
    renamed key can never be clobbered or removed by a later key of the same name.
    """
    return {
        key.replace("feature.", ""): value
        for key, value in state.items()
        if "feature." in key
    }


def _evaluate_checkpoint(model, run_id, checkpoint_dir, file_name, image_size):
    """Restore one checkpoint from the active W&B run and evaluate it on CD-FSL.

    Args:
        model: ProtoNet whose ``feature`` backbone receives the checkpoint weights.
        run_id: W&B run id, used in the result/log lines.
        checkpoint_dir: checkpoint directory relative to the run's files.
        file_name: ``"best_model.tar"`` or ``"last_model.tar"``.
        image_size: training image size; selects the 224 or 84 dataloaders.

    Accuracies are logged to W&B as ``cdfsl/<dataset>_<best|last>`` and appended
    to ``out_file``. The downloaded checkpoint file is removed afterwards even if
    loading or evaluation fails. Exceptions propagate to the caller.
    """
    pth = wandb.restore(os.path.join(checkpoint_dir, file_name))
    if pth is None:
        # wandb.restore returns None when the run has no such file.
        print("Missing %s" % (file_name))
        log_file.write("Missing %s for %s\n" % (file_name, run_id))
        return
    local_path = pth.name
    pth.close()  # only the path is needed; torch.load opens the file itself
    print("Restored %s" % (local_path))
    try:
        state = torch.load(local_path)['state']
        model.feature.load_state_dict(_backbone_state(state))
        model.cuda()
        model.feature.cuda()
        model.feature.eval()
        model.eval()
        tag = "best" if file_name == "best_model.tar" else "last"
        for dset in datamanagers:
            print(dset, end=": ")
            acc_mean, acc_std = model.test_loop(dataloaders[dset][str(image_size)], proto_only=True)
            acc_str_c = '%4.2f%% +- %4.2f%%' % (acc_mean, 1.96 * acc_std / np.sqrt(_N_EPISODES))
            wandb.log({"cdfsl/%s_%s" % (dset, tag): acc_str_c})
            exp_setting = 'Time: %s, W&B ID: %s, Dataset: %s' % (timestamp, run_id, dset)
            acc_str = 'Test Acc: %s' % (acc_str_c)
            out_file.write('%s %s\n' % (exp_setting, acc_str))
    finally:
        print("Removed %s" % (local_path))
        os.remove(local_path)


with open('other/runs.csv') as csv_file:
    for row in csv.reader(csv_file, delimiter=','):
        # Column 0 holds the W&B run id; blank rows / empty ids are skipped
        # before any W&B run is created.
        run_id = row[0] if row else ""
        if not run_id:
            continue
        print(run_id)
        wandb.init(project="Table-2", entity="meta-learners", id=run_id, resume=True)  # NOTE: Change when project="CDFSL"
        try:
            config = wandb.config
            checkpoint_dir = config["checkpoint_dir"]
            # Keep only the part of the path from "results" onwards (presumably
            # the relative path the checkpoints were saved to W&B under).
            if "results" not in checkpoint_dir:
                print("No 'results' in checkpoint_dir %r" % (checkpoint_dir,))
                log_file.write("Skipped %s: checkpoint_dir %r has no 'results'\n" % (run_id, checkpoint_dir))
                continue  # the finally clause still closes the W&B run
            checkpoint_dir = checkpoint_dir[checkpoint_dir.index("results"):]
            model = ProtoNet(model_dict[config["model"]], n_way=5, n_support=5,
                             use_bn=(not config["no_bn"]), pretrain=config["pretrain"],
                             tracking=config["tracking"])
            for file_name in ("best_model.tar", "last_model.tar"):
                _evaluate_checkpoint(model, run_id, checkpoint_dir, file_name, config["image_size"])
        except ValueError as err:
            print(err)
            log_file.write("ValueError for %s: %s\n" % (run_id, err))
        except RuntimeError as err:
            print(err)
            log_file.write("RuntimeError for %s: %s\n" % (run_id, err))
        except Exception as err:
            # Catch-all keeps the sweep going but, unlike a bare except, lets
            # KeyboardInterrupt/SystemExit stop the script.
            print("Unexpected error:", type(err))
            log_file.write("Unexpected for %s: %s: %s\n" % (run_id, type(err), err))
        finally:
            wandb.finish()
            # Persist results after every run so a crash loses nothing.
            out_file.flush()
            log_file.flush()

out_file.close()
log_file.close()