In [3]:
import torch
from train_cnn import *

from loaddata import *
import pandas as pd

# ---- Evaluation configuration ------------------------------------------------
# Dataset root passed to fetch_dataloader by the evaluation functions below.
DATA_DIR = '../data'

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Batch size used by every dataloader created in this notebook.
BATCHSIZE = 1000

# Folder containing one sub-directory per corruption type.
PATH_MNISTC = '../data/MNIST_C/'

# Corruption names. The position of a name in this list is used as its
# corruption id when locating its batches in the 'mnist_c_mini' dataloader.
CORRUPTION_TYPES = [
    'identity',
    'shot_noise',
    'impulse_noise',
    'glass_blur',
    'motion_blur',
    'shear',
    'scale',
    'rotate',
    'brightness',
    'translate',
    'stripe',
    'fog',
    'spatter',
    'dotted_line',
    'zigzag',
    'canny_edges',
]

# Number of 'mnist_c_mini' samples per corruption; divided by BATCHSIZE to get
# the number of batches that belong to one corruption.
N_MINI_PER_CORRUPTION = 1000
    
@torch.no_grad()
def _collect_cnn_outputs(cnn, batches):
    """Run ``cnn`` over ``batches`` and gather inputs, labels and predictions.

    Shared evaluation loop used by the three public ``evaluate_cnn*``
    functions in this cell.

    Args:
        cnn: model to evaluate; it is switched to eval mode here.
        batches: iterable of ``(x, y)`` pairs. ``y`` is reduced with
            ``argmax(dim=1)``, so it is assumed to be one-hot / per-class
            scores of shape (batch, n_classes) -- TODO confirm for every
            dataset returned by ``fetch_dataloader``.

    Returns:
        Tuple ``(x_all, y_all, class_prob_all, pred_all, acc_all)`` where each
        element is the concatenation over all batches (dim 0) of: the inputs
        moved to ``DEVICE``, the target class indices, the raw model outputs,
        the predicted class indices, and a float 0/1 correctness flag per
        sample.

    Raises:
        Whatever ``torch.cat`` raises on an empty list when ``batches``
        yields nothing.
    """
    cnn.eval()
    x_all, y_all, pred_all, acc_all, class_prob_all = [], [], [], [], []

    for x, y in batches:
        data, target = x.to(DEVICE), y.to(DEVICE)
        # Convert the per-class target vector into a class index.
        target = target.argmax(dim=1, keepdim=True)
        output = cnn(data)
        # Index of the largest model output is the predicted class.
        pred = output.argmax(dim=1, keepdim=True)
        acc = pred.eq(target.view_as(pred))

        x_all.append(data)
        y_all.append(target.flatten())
        pred_all.append(pred.flatten())
        acc_all.append(acc.flatten().float())
        class_prob_all.append(output)

    return (
        torch.cat(x_all, dim=0),
        torch.cat(y_all, dim=0),
        torch.cat(class_prob_all, dim=0),
        torch.cat(pred_all, dim=0),
        torch.cat(acc_all, dim=0),
    )


@torch.no_grad()
def evaluate_cnn_on_mnistc_mini(corruption, cnn, max_batch_num=None):
    """Evaluate ``cnn`` on the 'mnist_c_mini' batches of one corruption.

    The 'mnist_c_mini' test dataloader is read sequentially; the batches that
    belong to ``corruption`` are located by skipping
    ``corruption_id * (N_MINI_PER_CORRUPTION // BATCHSIZE)`` batches, i.e. the
    loader is assumed to be unshuffled and ordered like ``CORRUPTION_TYPES``
    -- TODO confirm against ``fetch_dataloader``.

    Args:
        corruption: name of the corruption; must be in ``CORRUPTION_TYPES``.
        cnn: model to evaluate.
        max_batch_num: optional cap on the number of batches evaluated.
            ``None`` or ``0`` (and any value not smaller than the number of
            batches per corruption) means "evaluate all of them".

    Returns:
        ``(x_all, y_all, class_prob_all, pred_all, acc_all)``; see
        ``_collect_cnn_outputs``.

    Raises:
        ValueError: if ``corruption`` is unknown, or if ``BATCHSIZE`` is larger
            than ``N_MINI_PER_CORRUPTION`` so that no full batch exists.
        StopIteration: if the dataloader has fewer batches than required.
    """
    corruption_id = CORRUPTION_TYPES.index(corruption)
    # e.g. BATCHSIZE == 100 -> 10 batches per corruption.
    num_batch_required = N_MINI_PER_CORRUPTION // BATCHSIZE
    if num_batch_required < 1:
        raise ValueError(
            f'BATCHSIZE ({BATCHSIZE}) must not exceed '
            f'N_MINI_PER_CORRUPTION ({N_MINI_PER_CORRUPTION})'
        )

    # Decide up front how many batches to evaluate so that no batch is
    # fetched only to be thrown away by the limit check.
    num_batch_to_eval = num_batch_required
    if max_batch_num is not None and 0 < max_batch_num < num_batch_required:
        num_batch_to_eval = max_batch_num

    dataloader = fetch_dataloader('mnist_c_mini', DATA_DIR, DEVICE, BATCHSIZE, train=False)
    diter = iter(dataloader)

    # Skip the batches of all corruptions that come before this one
    # (id 0 -> skip nothing, id 1 -> skip num_batch_required batches, ...).
    for _ in range(corruption_id * num_batch_required):
        next(diter)

    batches = [next(diter) for _ in range(num_batch_to_eval)]
    return _collect_cnn_outputs(cnn, batches)


@torch.no_grad()
def evaluate_cnn(task, cnn=None):
    """Evaluate a CNN on the full test split of ``task``.

    Args:
        task: dataset name forwarded to ``fetch_dataloader``.
        cnn: model to evaluate. When omitted, the notebook-level global
            ``cnn`` is used, which keeps existing ``evaluate_cnn(task)``
            calls working; passing the model explicitly is preferred because
            it does not depend on hidden kernel state.

    Returns:
        ``(x_all, y_all, class_prob_all, pred_all, acc_all)``; see
        ``_collect_cnn_outputs``.

    Raises:
        NameError: if ``cnn`` is not passed and no global ``cnn`` exists.
    """
    if cnn is None:
        try:
            cnn = globals()['cnn']
        except KeyError:
            raise NameError(
                "name 'cnn' is not defined; pass the model as evaluate_cnn(task, cnn)"
            ) from None

    dataloader = fetch_dataloader(task, DATA_DIR, DEVICE, BATCHSIZE, train=False)
    return _collect_cnn_outputs(cnn, dataloader)


@torch.no_grad()
def evaluate_cnn_on_mnistc_original(corruption, cnn):
    """Evaluate ``cnn`` on the MNIST-C test set of one corruption.

    Loads ``test_images.npy`` and ``test_labels.npy`` from
    ``PATH_MNISTC/<corruption>/``, converts them to tensors and evaluates the
    model batch by batch (``BATCHSIZE`` samples, original order, last partial
    batch kept).

    Args:
        corruption: sub-directory name under ``PATH_MNISTC``.
        cnn: model to evaluate.

    Returns:
        ``(x_all, y_all, class_prob_all, pred_all, acc_all)``; see
        ``_collect_cnn_outputs``.
    """
    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    # Convert to torch: images through ToTensor, labels to 10-way one-hot.
    # Assumes labels are integer class ids in [0, 10) -- TODO confirm.
    images = np.load(path_images)
    labels = np.load(path_labels)
    transform_tohot = T.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    images_tensorized = torch.stack([T.ToTensor()(im) for im in images])
    labels_tensorized = torch.stack([transform_tohot(label) for label in labels])
    # print(images_tensorized.shape) #torch.Size([10000, 1, 28, 28])
    # print(labels_tensorized.shape) #torch.Size([10000, 10])

    # DEVICE is a torch.device such as 'cuda:0', so compare its type; comparing
    # the device object with the string 'cuda' never matches, which left the
    # CUDA loader options permanently disabled.
    on_cuda = torch.device(DEVICE).type == 'cuda'
    kwargs = {'num_workers': 1, 'pin_memory': True} if on_cuda else {}
    dataset = TensorDataset(images_tensorized, labels_tensorized)
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    return _collect_cnn_outputs(cnn, dataloader)

# Model evaluation on MNIST-C

In [4]:
# Evaluate every checkpoint in `modellist` on each MNIST-C corruption and
# collect the per-corruption accuracy (in percent) in `df`, one column per
# checkpoint, plus a final 'AVERAGE' row.
# task ='mnist_c_mini'
task ='mnist_c_original'
modellist = [
# './models/cnn/run1_epoch50.pt',
# './models/cnn/run2_epoch50.pt',
# './models/cnn/run3_epoch50.pt',
# './models/cnn/run4_epoch50.pt',
# './models/cnn/run5_epoch50.pt',
# './results/mnist/cnn_shift/epoch50_99.50666666666666.pt'
# './results/mnist/cnn_shift_epoch1000/earlystopatepoch38_0.9950666666666667.pt'
# './results/mnist/cnn_shift_epoch1000/earlystopatepoch36_0.9967166666666667.pt'

# clean
# './results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt',
# './results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt',
# './results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt',
# './results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt',
# './results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt',
    
# shift
# './results/mnist/Apr29_1404_cnn_shift_run1/best_epoch14_0.9946.pt',
# './results/mnist/Apr29_1407_cnn_shift_run2/best_epoch14_0.9950.pt',
# './results/mnist/Apr29_1410_cnn_shift_run3/best_epoch18_0.9951.pt',
# './results/mnist/Apr29_1414_cnn_shift_run4/best_epoch15_0.9946.pt',
# './results/mnist/Apr29_1417_cnn_shift_run5/best_epoch15_0.9944.pt'
    
# AUG
'./results/mnist/May25_0239_cnn_aug_run1/best_epoch22_0.9708.pt',
'./results/mnist/May25_0251_cnn_aug_run2/best_epoch33_0.9718.pt',
'./results/mnist/May25_0306_cnn_aug_run3/best_epoch34_0.9711.pt',
'./results/mnist/May25_0321_cnn_aug_run4/best_epoch17_0.9709.pt',
'./results/mnist/May25_0331_cnn_aug_run5/best_epoch23_0.9720.pt'
]

# Fail early on a mistyped task instead of hitting an undefined (or stale)
# `acc_cnn` inside the loop.
if task not in ('mnist_c_mini', 'mnist_c_original'):
    raise ValueError(f'unknown task: {task!r}')

df = pd.DataFrame()
df['corruption'] = CORRUPTION_TYPES

for load_model_path in modellist:
    print(f'start analysis on {load_model_path}')
    cnn = Net().to(DEVICE)
    # map_location lets checkpoints saved on a GPU load when DEVICE is the CPU.
    cnn.load_state_dict(torch.load(load_model_path, map_location=DEVICE))
    cnn.eval()
    accs = []
    for corruption in CORRUPTION_TYPES:

        if task == 'mnist_c_mini':
            data_cnn, target_cnn, logsoft_cnn, pred_cnn, acc_cnn \
            = evaluate_cnn_on_mnistc_mini(corruption, cnn, max_batch_num=None)
        else:  # 'mnist_c_original'
            data_cnn, target_cnn, logsoft_cnn, pred_cnn, acc_cnn \
            =  evaluate_cnn_on_mnistc_original(corruption, cnn)

        # Mean of the per-sample 0/1 correctness flags, as a percentage.
        accs.append(100*acc_cnn.float().mean().item())

    df[load_model_path]=accs

# Number the corruption rows from 1, then append the column-wise mean.
df.index = np.arange(1, len(df)+1)
# numeric_only=True skips the string column 'corruption'; without it pandas
# warns (older versions) or raises (pandas >= 2.0).
df.loc['AVERAGE'] = df.mean(numeric_only=True)

df

start analysis on ./results/mnist/May25_0239_cnn_aug_run1/best_epoch22_0.9708.pt
start analysis on ./results/mnist/May25_0251_cnn_aug_run2/best_epoch33_0.9718.pt
start analysis on ./results/mnist/May25_0306_cnn_aug_run3/best_epoch34_0.9711.pt
start analysis on ./results/mnist/May25_0321_cnn_aug_run4/best_epoch17_0.9709.pt
start analysis on ./results/mnist/May25_0331_cnn_aug_run5/best_epoch23_0.9720.pt


  df.loc['AVERAGE'] = df.mean()


Unnamed: 0,corruption,./results/mnist/May25_0239_cnn_aug_run1/best_epoch22_0.9708.pt,./results/mnist/May25_0251_cnn_aug_run2/best_epoch33_0.9718.pt,./results/mnist/May25_0306_cnn_aug_run3/best_epoch34_0.9711.pt,./results/mnist/May25_0321_cnn_aug_run4/best_epoch17_0.9709.pt,./results/mnist/May25_0331_cnn_aug_run5/best_epoch23_0.9720.pt
1,identity,98.899996,98.929995,98.92,98.909998,98.839998
2,shot_noise,94.269997,91.670001,92.209995,92.479998,93.269998
3,impulse_noise,48.33,54.170001,53.95,44.56,56.580001
4,glass_blur,79.749995,64.15,60.499996,70.879996,65.509999
5,motion_blur,85.929996,92.899996,88.409996,93.57,88.859999
6,shear,98.069996,98.209995,98.109996,98.079997,98.229998
7,scale,98.549998,98.579997,98.509997,98.679996,98.6
8,rotate,97.969997,98.189998,98.189998,98.13,98.089999
9,brightness,24.049999,18.509999,16.97,34.630001,15.42
10,translate,88.76,88.009995,88.76,88.529998,88.849998


In [5]:
# Persist the results table as CSV; an existing file is never overwritten.
# NOTE(review): index=False drops the row index, so a row whose only label is
# its index (such as an 'AVERAGE' row added via df.loc) loses that label in
# the file -- confirm this is intended.
path_df = 'model-results-cnn-clean-aug.csv'
if not os.path.isfile(path_df):
    df.to_csv(path_df, index=False)
    print(f'test done! df is saved to csv as {path_df}')
else:
    print(f'test done! file {path_df} already exists, df is not saved')

test done! df is saved to csv as model-results-cnn-clean-aug.csv


# Model evaluation on the shape datasets

In [2]:
# Evaluate every checkpoint in `modellist` on each dataset in `tasklist` and
# collect the accuracy (in percent) in `df`, one column per checkpoint.
tasklist = ['mnist', 'mnist_occlusion', 'mnist_flipped', 'mnist_random']

modellist = [
# './models/cnn/run1_epoch50.pt',
# './models/cnn/run2_epoch50.pt',
# './models/cnn/run3_epoch50.pt',
# './models/cnn/run4_epoch50.pt',
# './models/cnn/run5_epoch50.pt',
# './results/mnist/cnn_shift/epoch50_99.50666666666666.pt'
# './results/mnist/cnn_shift_epoch1000/earlystopatepoch38_0.9950666666666667.pt'
# './results/mnist/cnn_shift_epoch1000/earlystopatepoch36_0.9967166666666667.pt'

# clean
'./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt',
'./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt',
'./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt',
'./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt',
'./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt',
    
# shift
# './results/mnist/Apr29_1404_cnn_shift_run1/best_epoch14_0.9946.pt',
# './results/mnist/Apr29_1407_cnn_shift_run2/best_epoch14_0.9950.pt',
# './results/mnist/Apr29_1410_cnn_shift_run3/best_epoch18_0.9951.pt',
# './results/mnist/Apr29_1414_cnn_shift_run4/best_epoch15_0.9946.pt',
# './results/mnist/Apr29_1417_cnn_shift_run5/best_epoch15_0.9944.pt'
]

df = pd.DataFrame()
df['task'] = tasklist

for load_model_path in modellist:
    print(f'start analysis on {load_model_path}')
    # NOTE: the model must be bound to the global name `cnn`, because
    # evaluate_cnn(task) reads the model from that global.
    cnn = Net().to(DEVICE)
    # map_location lets checkpoints saved on a GPU load when DEVICE is the CPU.
    cnn.load_state_dict(torch.load(load_model_path, map_location=DEVICE))
    cnn.eval()
    accs = []
    for task in tasklist:
        data_cnn, target_cnn, logsoft_cnn, pred_cnn, acc_cnn = evaluate_cnn(task)
        # Mean of the per-sample 0/1 correctness flags, as a percentage.
        accs.append(100*acc_cnn.float().mean().item())

    df[load_model_path]=accs

df

start analysis on ./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt
original mnist dataset
start analysis on ./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt
original mnist dataset


Unnamed: 0,task,./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt,./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt,./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt,./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt,./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt
0,mnist,99.059999,99.079996,99.149996,99.159998,99.1
1,mnist_occlusion,88.849998,88.559997,88.629997,88.909996,87.909997
2,mnist_flipped,59.670001,62.580001,61.559999,62.470001,60.939997
3,mnist_random,34.529999,34.419999,34.41,34.509999,33.949998


# Model evaluation on a single batch

In [40]:
# Inspect one 'mnist_c_mini' batch: report its accuracy and list every
# prediction, marking the wrong ones with '***'.
task ='mnist_c_mini'
test_dataloader = fetch_dataloader(task, DATA_DIR, DEVICE, BATCHSIZE, train=False)

load_model_path = './results/mnist/cnn3/epoch50.pt'

model = Net().to(DEVICE)
# map_location lets a checkpoint saved on a GPU load when DEVICE is the CPU.
model.load_state_dict(torch.load(load_model_path, map_location=DEVICE))
model.eval()


# Number of batches to draw from the loader; only the last one drawn is
# evaluated below.
# NOTE(review): with BATCHSIZE == 1000 this selects the 2nd batch -- confirm
# that this is the batch meant to be inspected.
batchnum=3*(N_MINI_PER_CORRUPTION//BATCHSIZE)-1
if batchnum < 1:
    raise ValueError(f'batchnum must be at least 1, got {batchnum}')

accs_mini = []
pred_mini =[]
target_mini =[]

diter = iter(test_dataloader)
for i in range(batchnum):
    x, y = next(diter)

with torch.no_grad():            
    data, target = x.to(DEVICE),  y.to(DEVICE)
    # Convert the per-class target vector into a class index.
    target = target.argmax(dim=1, keepdim=True)
    output = model(data)
    pred = output.argmax(dim=1, keepdim=True)  # index of the largest model output
    correct = pred.eq(target.view_as(pred)).sum().item()
    # Divide by the real number of samples in this batch, which may be
    # smaller than BATCHSIZE for a final partial batch.
    acc = correct / data.shape[0]
accs_mini.append(acc)
pred_mini.append(pred)
target_mini.append(target)

print(np.mean(accs_mini))

predictions = torch.cat(pred_mini, dim=0).flatten()
targets =  torch.cat(target_mini, dim=0).flatten()

## print predictions and acc
for i in range(len(predictions)):
    if predictions[i] != targets[i]:
        sign= "***"
    else: sign=None
    print(f'trial: {i}, prediction: {predictions[i]}, {sign}')

0.91
