In [8]:
# set current directory (where this repo is located)
import os
PROJECT_ROOT = '/home/young/workspace/reconstruction/recon-mnistc'
os.chdir(PROJECT_ROOT)
print('current directory:', os.getcwd())

current directory: /home/young/workspace/reconstruction/recon-mnistc


In [20]:
import torch
from train_cnn import *

from loaddata import *
import pandas as pd
from evaluation import topkacc



DATA_DIR = './data'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
BATCHSIZE = 100

    
@torch.no_grad()
def evaluate_cnn(model, task, num_targets):
    model.eval() 

    # load dataloader and iterator
    dataloader = fetch_dataloader(task, DATA_DIR, DEVICE, BATCHSIZE, train=False)    
    diter = iter(dataloader)
    
    # run and save output
    x_all, y_all, pred_all, obj_accs_all, img_accs_all= [],[],[],[],[]
    for x, y in diter:

        data, target = x.to(DEVICE),  y.to(DEVICE)
        output = model(data)
        output = torch.sigmoid(output)     #<--- since you use BCEWithLogitsLoss
    
        obj_accs = topkacc(output, target, topk=num_targets)
        img_accs = obj_accs == 1

        # import pdb; pdb.set_trace()
        x_all.append(data)
        y_all.append(target)
        pred_all.append(output)
        obj_accs_all.append(obj_accs)
        img_accs_all.append(img_accs)   


    # concat and add to outputs dictionary
    x_all = torch.cat(x_all, dim=0)
    y_all = torch.cat(y_all, dim=0)
    pred_all = torch.cat(pred_all, dim=0)
    obj_accs_all = torch.cat(obj_accs_all, dim=0)
    img_accs_all = torch.cat(img_accs_all, dim=0)
    

    return x_all, y_all, pred_all, obj_accs_all, img_accs_all


# Model evaluation on MNIST-C


In [21]:
task ='mnist_multi_high'
cnn_type = 'resnet' #'2conv' or 'resnet' 
num_classes = 10
num_targets = 2

if cnn_type == '2conv':
    modellist = [
    ]
elif cnn_type == 'resnet':
    modellist = [
        './results/mnist/Feb22_1645_resnet_mnist_multi_high/best_epoch32_0.9584.pt',

    ]


df = pd.DataFrame()
for i, load_model_path in enumerate(modellist):

    print(f'start analysis on {load_model_path}')
    
    if cnn_type=='2conv':
        cnn = Net().to(DEVICE)
    elif cnn_type=='resnet':
        cnn = ResNet(in_channels=1, resblock= ResBlock, outputs=10).to(DEVICE)
    else:
        raise NotImplementedError 
    
    # load the trained weights
    cnn.load_state_dict(torch.load(load_model_path))

    x_all, y_all, pred_all, obj_accs_all, img_accs_all = evaluate_cnn(cnn, task, num_targets)

    obj_acc = obj_accs_all.mean().item()
    img_acc = img_accs_all.mean().item()

    df[load_model_path]= [obj_acc, img_acc]

df.index = ['obj_acc', 'img_acc']
# df.index = np.arange(1, len(df)+1)
# df.loc['AVERAGE'] = df.mean()

df

start analysis on ./results/mnist/Feb22_1645_resnet_mnist_multi_high/best_epoch32_0.9584.pt
mnist_highoverlap_4pix_nodup_1fold_36width_2obj_train.pt mnist_highoverlap_4pix_nodup_1fold_36width_2obj_test.pt


RuntimeError: "topk_out_cuda" not implemented for 'Bool'

In [3]:
# save to csv
path_df = 'model-results-cnn-clean-resnet4.csv'
if os.path.isfile(path_df):
    print(f'test done! file {path_df} already exists, df is not saved')
else: 
    df.to_csv(path_df, index=False)
    print(f'test done! df is saved to csv as {path_df}')

test done! df is saved to csv as model-results-cnn-clean-resnet4.csv


# model evaluation on shape-dataset

In [2]:
tasklist = ['mnist', 'mnist_occlusion', 'mnist_flipped', 'mnist_random']

modellist = [

]

df = pd.DataFrame()
df['task'] = tasklist

for i, load_model_path in enumerate(modellist):
    print(f'start analysis on {load_model_path}')
    cnn = Net().to(DEVICE)
    cnn.load_state_dict(torch.load(load_model_path))
    cnn.eval()
    accs = []
    for task in tasklist:
        data_cnn, target_cnn, logsoft_cnn, pred_cnn, acc_cnn = evaluate_cnn(task)
        accs.append(100*acc_cnn.float().mean().item())

    df[load_model_path]=accs

df

start analysis on ./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt
original mnist dataset


Unnamed: 0,task,./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt,./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt,./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt,./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt,./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt
0,mnist,99.059999,99.079996,99.149996,99.159998,99.1
1,mnist_occlusion,88.849998,88.559997,88.629997,88.909996,87.909997
2,mnist_flipped,59.670001,62.580001,61.559999,62.470001,60.939997
3,mnist_random,34.529999,34.419999,34.41,34.509999,33.949998


# Model evaluation on each batch

In [40]:
task ='mnist_c_mini'
test_dataloader = fetch_dataloader(task, DATA_DIR, DEVICE, BATCHSIZE, train=False)

load_model_path = './results/mnist/cnn3/epoch50.pt'

model = Net().to(DEVICE)
model.load_state_dict(torch.load(load_model_path))
model.eval()


batchnum=3*int(1000/BATCHSIZE)-1

accs_mini = []
pred_mini =[]
target_mini =[]

diter = iter(test_dataloader)
for i in range(batchnum):
    x, y = next(diter)

with torch.no_grad():            
    data, target = x.to(DEVICE),  y.to(DEVICE)
    target = target.argmax(dim=1, keepdim=True)
    output = model(data)
#                 test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct = pred.eq(target.view_as(pred)).sum().item()
    acc = correct /BATCHSIZE
accs_mini.append(acc)
pred_mini.append(pred)
target_mini.append(target)

print(np.mean(accs_mini))

predictions = torch.cat(pred_mini, dim=0).flatten()
targets =  torch.cat(target_mini, dim=0).flatten()

## print predictions and acc
for i in range(len(predictions)):
    if predictions[i] != targets[i]:
        sign= "***"
    else: sign=None
    print(f'trial: {i}, prediction: {predictions[i]}, {sign}')

0.91
