In [8]:
# Set the working directory to the repository root so that the relative paths
# used below ('./data', './results/...') resolve correctly.
# Override the default without editing the notebook by exporting RECON_MNISTC_ROOT.
import os
PROJECT_ROOT = os.environ.get(
    'RECON_MNISTC_ROOT',
    '/home/young/workspace/reconstruction/recon-mnistc',
)
os.chdir(PROJECT_ROOT)
print('current directory:', os.getcwd())

current directory: /home/young/workspace/reconstruction/recon-mnistc


In [24]:
import torch
from train_cnn import *

from loaddata import *
import pandas as pd
from evaluation import topkacc



# Evaluation configuration
BATCHSIZE = 100        # samples per test batch
DATA_DIR = './data'    # dataset root, relative to the current working directory
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    
@torch.no_grad()
def evaluate_cnn(model, task, num_targets, data_dir=None, device=None, batch_size=None):
    """Evaluate ``model`` on the test split of ``task`` and collect its outputs.

    Gradients are disabled for the whole call (``torch.no_grad``). The model is
    put into eval mode and is NOT switched back to train mode afterwards.

    Args:
        model: trained network whose raw outputs are logits (it was trained
            with BCEWithLogitsLoss). It is not moved here, so it should already
            live on ``device``.
        task: dataset name forwarded to ``fetch_dataloader``.
        num_targets: forwarded as ``topk`` to ``topkacc``; presumably the
            number of objects per image (TODO confirm).
        data_dir: dataset root; defaults to the module-level ``DATA_DIR``.
        device: device the batches are moved to; defaults to ``DEVICE``.
        batch_size: test batch size; defaults to ``BATCHSIZE``.

    Returns:
        Tuple ``(x_all, y_all, pred_all, obj_accs_all, img_accs_all)``. Each is
        the dim-0 concatenation over all test batches of, respectively: the
        inputs, the targets, the sigmoid outputs, the ``topkacc`` scores, and a
        float mask that is 1.0 where the ``topkacc`` score equals 1.

    Raises:
        RuntimeError: if the test dataloader yields no batches.
    """
    # Fall back to the notebook-level configuration when no override is given.
    data_dir = DATA_DIR if data_dir is None else data_dir
    device = DEVICE if device is None else device
    batch_size = BATCHSIZE if batch_size is None else batch_size

    model.eval()
    dataloader = fetch_dataloader(task, data_dir, device, batch_size, train=False)

    # NOTE(review): inputs, targets and predictions for the whole test set are
    # accumulated on `device`; for large test sets consider moving them to CPU.
    x_all, y_all, pred_all, obj_accs_all, img_accs_all = [], [], [], [], []
    for x, y in dataloader:
        data, target = x.to(device), y.to(device)
        # The model emits logits (trained with BCEWithLogitsLoss) -> probabilities.
        output = torch.sigmoid(model(data))

        obj_accs = topkacc(output, target, topk=num_targets)
        # An image counts as correct only where its topkacc score is exactly 1.
        img_accs = (obj_accs == 1).float()

        x_all.append(data)
        y_all.append(target)
        pred_all.append(output)
        obj_accs_all.append(obj_accs)
        img_accs_all.append(img_accs)

    if not x_all:
        # torch.cat on an empty list would fail with an unhelpful message.
        raise RuntimeError(f'evaluate_cnn: test dataloader for task {task!r} yielded no batches')

    return tuple(
        torch.cat(parts, dim=0)
        for parts in (x_all, y_all, pred_all, obj_accs_all, img_accs_all)
    )


# Model evaluation


In [31]:
# Evaluation settings
task = 'mnist_multi'  # alternative: 'mnist_multi_high'
cnn_type = 'resnet'   # '2conv' or 'resnet'
num_classes = 10
num_targets = 2

# Checkpoints to evaluate. Validate cnn_type here so an unknown value fails
# with a clear error instead of a NameError on `modellist` further down.
if cnn_type == '2conv':
    modellist = [
        './results/mnist/Feb20_2331_cnn_mnist_shift_single_36/best_epoch24_0.9936.pt',
    ]
elif cnn_type == 'resnet':
    modellist = [
        './results/mnist/Feb23_0417_resnet_mnist_shift_single_36/best_epoch9_1.0000.pt',  # single shift
        # './results/mnist/Feb21_0709_resnet_mnist_multi/best_epoch24_0.9788.pt',  # multi
        # './results/mnist/Feb22_1645_resnet_mnist_multi_high/best_epoch32_0.9584.pt',  # multi high overlap
    ]
else:
    raise NotImplementedError(f'unknown cnn_type: {cnn_type!r}')

# Collect one [obj_acc, img_acc] column per checkpoint, then build the table once.
results = {}
for load_model_path in modellist:
    print(f'start analysis on {load_model_path}')

    if cnn_type == '2conv':
        cnn = Net(feature_size_after_conv=16384, num_classes=num_classes).to(DEVICE)
    else:  # 'resnet' (validated above)
        cnn = ResNet(in_channels=1, resblock=ResBlock, outputs=num_classes).to(DEVICE)

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    cnn.load_state_dict(torch.load(load_model_path, map_location=DEVICE))

    x_all, y_all, pred_all, obj_accs_all, img_accs_all = evaluate_cnn(cnn, task, num_targets)

    obj_acc = obj_accs_all.mean().item()
    img_acc = img_accs_all.mean().item()
    results[load_model_path] = [obj_acc, img_acc]

df = pd.DataFrame(results, index=['obj_acc', 'img_acc'])
df

start analysis on ./results/mnist/Feb23_0417_resnet_mnist_shift_single_36/best_epoch9_1.0000.pt
mnist_overlap_4pix_nodup_1fold_36width_2obj_train.pt mnist_overlap_4pix_nodup_1fold_36width_2obj_test.pt


Unnamed: 0,./results/mnist/Feb23_0417_resnet_mnist_shift_single_36/best_epoch9_1.0000.pt
obj_acc,0.5547
img_acc,0.2279
