In [1]:
import itertools
import os

import numpy as np
import pandas as pd
import torch

from train_cnn import *
from loaddata import *

DATA_DIR = '../data'
# Prefer the second GPU (the setup these experiments were run on), but fall back
# to the first GPU on single-GPU machines and to the CPU when CUDA is unavailable;
# a hard-coded 'cuda:1' cannot be used on a machine with only one GPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    DEVICE = torch.device('cpu')
BATCHSIZE = 1000

PATH_MNISTC = '../data/MNIST_C/'
# Order matters: evaluate_cnn_on_mnistc_mini uses a corruption's position in this
# list to locate its slice of the mnist_c_mini test stream.
CORRUPTION_TYPES = ['identity',
         'shot_noise', 'impulse_noise','glass_blur','motion_blur',
         'shear', 'scale',  'rotate',  'brightness',  'translate',
         'stripe', 'fog','spatter','dotted_line', 'zigzag',
         'canny_edges']


# Number of test samples per corruption in the mnist_c_mini split (assumed by the
# mini evaluator; should be a multiple of BATCHSIZE).
N_MINI_PER_CORRUPTION = 1000
    
@torch.no_grad()
def evaluate_cnn_on_mnistc_mini(corruption, cnn, max_batch_num=None):
    """Evaluate ``cnn`` on the batches of one corruption in the ``mnist_c_mini`` test set.

    The mini test set is consumed as a single stream of batches in which each
    corruption of ``CORRUPTION_TYPES`` is assumed to occupy
    ``N_MINI_PER_CORRUPTION`` consecutive samples, in list order (this relies on
    ``fetch_dataloader`` not shuffling the test split -- TODO confirm). The
    batches of earlier corruptions are skipped to reach ``corruption``.

    Args:
        corruption: corruption name; must be an entry of ``CORRUPTION_TYPES``.
        cnn: model called on each batch of inputs (expected to live on ``DEVICE``).
        max_batch_num: cap on how many of this corruption's batches to evaluate.
            ``None``, 0 or a negative value evaluates all of them.

    Returns:
        ``(x_all, y_all, class_prob_all, pred_all, acc_all)`` concatenated over the
        evaluated batches: inputs, integer targets (argmax of the labels), raw
        model outputs, predicted class indices, and per-sample correctness as
        0./1. floats.

    Raises:
        ValueError: if ``corruption`` is unknown, if ``N_MINI_PER_CORRUPTION`` is
            not a multiple of ``BATCHSIZE`` (batches would mix two corruptions and
            the skip offset would be wrong), or if the dataloader runs out before
            the requested batches.
    """
    if N_MINI_PER_CORRUPTION % BATCHSIZE != 0:
        raise ValueError(
            f'N_MINI_PER_CORRUPTION ({N_MINI_PER_CORRUPTION}) must be a multiple of '
            f'BATCHSIZE ({BATCHSIZE}); otherwise batches straddle two corruptions')

    # Position of the corruption in the stream; list.index raises ValueError if unknown.
    corruption_id = CORRUPTION_TYPES.index(corruption)
    num_batch_required = N_MINI_PER_CORRUPTION // BATCHSIZE  # e.g. batchsize 100 -> 10 batches

    # Same cap semantics as before (only a positive value smaller than the number of
    # batches limits the run), but decided up front so no extra batch is loaded.
    num_batch_eval = num_batch_required
    if max_batch_num is not None and 0 < max_batch_num < num_batch_required:
        num_batch_eval = max_batch_num

    dataloader = fetch_dataloader('mnist_c_mini', DATA_DIR, DEVICE, BATCHSIZE, train=False)
    first_batch = corruption_id * num_batch_required
    # islice skips the batches of the preceding corruptions (they are still loaded,
    # just discarded) and stops right after the last batch that is needed.
    batches = itertools.islice(dataloader, first_batch, first_batch + num_batch_eval)

    x_all, y_all, pred_all, acc_all, class_prob_all = [], [], [], [], []
    cnn.eval()

    for x, y in batches:
        data, target = x.to(DEVICE), y.to(DEVICE)
        # labels are expected one-hot with classes on dim 1 -> class index
        target = target.argmax(dim=1, keepdim=True)
        output = cnn(data)
        pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
        acc = pred.eq(target.view_as(pred))

        x_all.append(data)
        y_all.append(target.flatten())
        pred_all.append(pred.flatten())
        acc_all.append(acc.flatten().float())
        class_prob_all.append(output)

    if len(x_all) != num_batch_eval:
        raise ValueError(
            f'mnist_c_mini dataloader ran out before batch {first_batch + num_batch_eval} '
            f'(needed for corruption {corruption!r})')

    return (torch.cat(x_all, dim=0),
            torch.cat(y_all, dim=0),
            torch.cat(class_prob_all, dim=0),
            torch.cat(pred_all, dim=0),
            torch.cat(acc_all, dim=0))



@torch.no_grad()
def evaluate_cnn_on_mnistc_original(corruption, cnn):
    """Evaluate ``cnn`` on the full MNIST-C test split of one corruption.

    Loads ``test_images.npy`` and ``test_labels.npy`` from
    ``PATH_MNISTC/<corruption>/`` and runs the model over them in order, in
    batches of ``BATCHSIZE`` (the last batch may be smaller).

    Args:
        corruption: name of the corruption sub-directory under ``PATH_MNISTC``.
        cnn: model called on each batch of inputs (expected to live on ``DEVICE``).

    Returns:
        ``(x_all, y_all, class_prob_all, pred_all, acc_all)`` concatenated over all
        batches: inputs, integer targets, raw model outputs, predicted class
        indices, and per-sample correctness as 0./1. floats.
    """
    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    # convert to torch
    images = np.load(path_images)
    labels = np.load(path_labels)
    to_tensor = T.ToTensor()  # build the transform once, not per image
    images_tensorized = torch.stack([to_tensor(im) for im in images])
    # one-hot float labels over 10 classes; reshape(-1) accepts both (N,) and (N, 1) label arrays
    labels_tensorized = torch.nn.functional.one_hot(
        torch.as_tensor(labels, dtype=torch.long).reshape(-1), num_classes=10).float()
    # print(images_tensorized.shape) #torch.Size([10000, 1, 28, 28])
    # print(labels_tensorized.shape) #torch.Size([10000, 10])

    # DEVICE is a torch.device such as cuda:1, which never compares equal to the
    # string 'cuda'; compare its type so the CUDA-only loader options take effect.
    use_cuda = torch.device(DEVICE).type == 'cuda'
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    dataset = TensorDataset(images_tensorized, labels_tensorized)
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    x_all, y_all, pred_all, acc_all, class_prob_all = [], [], [], [], []
    cnn.eval()

    for x, y in dataloader:
        data, target = x.to(DEVICE), y.to(DEVICE)
        target = target.argmax(dim=1, keepdim=True)  # one-hot -> class index
        output = cnn(data)
        pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
        acc = pred.eq(target.view_as(pred))

        x_all.append(data)
        y_all.append(target.flatten())
        pred_all.append(pred.flatten())
        acc_all.append(acc.flatten().float())
        class_prob_all.append(output)

    x_all = torch.cat(x_all, dim=0)
    y_all = torch.cat(y_all, dim=0)
    pred_all = torch.cat(pred_all, dim=0)
    acc_all = torch.cat(acc_all, dim=0)
    class_prob_all = torch.cat(class_prob_all, dim=0)

    return x_all, y_all, class_prob_all, pred_all, acc_all

# Model evaluation on MNIST-C

In [6]:
# Evaluate every checkpoint in `modellist` on each corruption and tabulate accuracy (%).
# task = 'mnist_c_mini'       # N_MINI_PER_CORRUPTION samples per corruption from the mini split
task = 'mnist_c_original'     # full test split of each corruption under PATH_MNISTC

# Checkpoint groups; select one for `modellist` below.
MODEL_SETS = {
    'legacy': [
        './models/cnn/run1_epoch50.pt',
        './models/cnn/run2_epoch50.pt',
        './models/cnn/run3_epoch50.pt',
        './models/cnn/run4_epoch50.pt',
        './models/cnn/run5_epoch50.pt',
    ],
    'legacy_shift': [
        './results/mnist/cnn_shift/epoch50_99.50666666666666.pt',
        './results/mnist/cnn_shift_epoch1000/earlystopatepoch38_0.9950666666666667.pt',
        './results/mnist/cnn_shift_epoch1000/earlystopatepoch36_0.9967166666666667.pt',
    ],
    'clean': [
        './results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt',
        './results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt',
        './results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt',
        './results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt',
        './results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt',
    ],
    'shift': [
        './results/mnist/Apr29_1404_cnn_shift_run1/best_epoch14_0.9946.pt',
        './results/mnist/Apr29_1407_cnn_shift_run2/best_epoch14_0.9950.pt',
        './results/mnist/Apr29_1410_cnn_shift_run3/best_epoch18_0.9951.pt',
        './results/mnist/Apr29_1414_cnn_shift_run4/best_epoch15_0.9946.pt',
        './results/mnist/Apr29_1417_cnn_shift_run5/best_epoch15_0.9944.pt',
    ],
}
modellist = MODEL_SETS['clean']

# One evaluator per task; each returns (x, y, class_prob, pred, acc) with acc last.
EVALUATORS = {
    'mnist_c_mini': lambda corruption, model: evaluate_cnn_on_mnistc_mini(corruption, model, max_batch_num=None),
    'mnist_c_original': evaluate_cnn_on_mnistc_original,
}
if task not in EVALUATORS:
    # fail fast: an unknown task would otherwise leave no accuracy to record
    raise ValueError(f'unknown task {task!r}; expected one of {sorted(EVALUATORS)}')
evaluate = EVALUATORS[task]

results = {'corruption': CORRUPTION_TYPES}
for load_model_path in modellist:
    print(f'start analysis on {load_model_path}')
    cnn = Net().to(DEVICE)
    # map_location remaps tensors saved on another device (e.g. a different GPU) onto DEVICE
    cnn.load_state_dict(torch.load(load_model_path, map_location=DEVICE))
    cnn.eval()
    results[load_model_path] = [
        100 * evaluate(corruption, cnn)[-1].float().mean().item()
        for corruption in CORRUPTION_TYPES
    ]

df = pd.DataFrame(results)
df.index = np.arange(1, len(df) + 1)
# numeric_only=True: 'corruption' holds strings; averaging it triggers a FutureWarning
# on older pandas and a TypeError on pandas >= 2.0.
df.loc['AVERAGE'] = df.mean(numeric_only=True)
df.loc['AVERAGE', 'corruption'] = 'AVERAGE'  # keep the summary row identifiable without the index

df

start analysis on ./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt
start analysis on ./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt
start analysis on ./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt
start analysis on ./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt
start analysis on ./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt


  df.loc['AVERAGE'] = df.mean()


Unnamed: 0,corruption,./results/mnist/Apr29_1330_cnn_clean_run1/best_epoch15_0.9963.pt,./results/mnist/Apr29_1336_cnn_clean_run2/best_epoch14_0.9969.pt,./results/mnist/Apr29_1342_cnn_clean_run3/best_epoch16_0.9968.pt,./results/mnist/Apr29_1348_cnn_clean_run4/best_epoch14_0.9960.pt,./results/mnist/Apr29_1355_cnn_clean_run5/best_epoch13_0.9969.pt
1,identity,99.059999,99.079996,99.149996,99.159998,99.1
2,shot_noise,97.369999,97.099996,97.469997,97.329998,97.319996
3,impulse_noise,90.509999,91.81,92.85,91.259998,87.949997
4,glass_blur,90.139997,91.159999,90.129995,89.029998,90.41
5,motion_blur,90.389997,92.409998,94.409996,92.379999,91.189998
6,shear,97.039998,97.419995,97.169995,97.249997,97.389996
7,scale,93.299997,94.059998,94.119996,93.649995,93.779999
8,rotate,90.749997,91.889995,91.179997,91.329998,91.359997
9,brightness,83.419997,80.680001,87.689996,85.519999,80.329996
10,translate,46.3,48.109999,47.459999,46.899998,47.409999


In [10]:
# save to csv (an existing results file is never overwritten)
path_df = 'model-results-cnn-clean.csv'
if os.path.isfile(path_df):
    print(f'test done! file {path_df} already exists, df is not saved')
else:
    # index=False drops the row labels. Rows appended with numeric values only
    # (e.g. the 'AVERAGE' summary row) may have no 'corruption' entry, so copy the
    # row label into that column first; otherwise the row is unidentifiable in the CSV.
    df_to_save = df.assign(corruption=df['corruption'].fillna(df.index.to_series().astype(str)))
    df_to_save.to_csv(path_df, index=False)
    print(f'test done! df is saved to csv as {path_df}')

test done! df is saved to csv as model-results-cnn.csv


# Model evaluation on a single batch (per-trial predictions)

In [40]:
task = 'mnist_c_mini'
test_dataloader = fetch_dataloader(task, DATA_DIR, DEVICE, BATCHSIZE, train=False)

load_model_path = './results/mnist/cnn3/epoch50.pt'

model = Net().to(DEVICE)
# map_location remaps tensors saved on another device onto DEVICE
model.load_state_dict(torch.load(load_model_path, map_location=DEVICE))
model.eval()


batchnum = 3 * (N_MINI_PER_CORRUPTION // BATCHSIZE) - 1
# Calling next() `batchnum` times and keeping the last result selects the batch at
# 0-based index batchnum - 1.
# NOTE(review): if the intent was the last batch of corruption id 2
# (index 3 * batches_per_corruption - 1), use batch_index = batchnum -- confirm.
batch_index = batchnum - 1
x, y = next(itertools.islice(iter(test_dataloader), batch_index, None))

accs_mini = []
pred_mini = []
target_mini = []

with torch.no_grad():
    data, target = x.to(DEVICE), y.to(DEVICE)
    target = target.argmax(dim=1, keepdim=True)  # one-hot -> class index
    output = model(data)
    pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
    correct = pred.eq(target.view_as(pred)).sum().item()
    # divide by the actual number of samples: a trailing batch can be smaller than BATCHSIZE
    acc = correct / target.size(0)
accs_mini.append(acc)
pred_mini.append(pred)
target_mini.append(target)

print(np.mean(accs_mini))

predictions = torch.cat(pred_mini, dim=0).flatten()
targets = torch.cat(target_mini, dim=0).flatten()

# List only the misclassified trials instead of printing every sample.
mistakes = [
    (i, p, t)
    for i, (p, t) in enumerate(zip(predictions.tolist(), targets.tolist()))
    if p != t
]
print(f'{len(mistakes)} / {len(targets)} trials misclassified')
for i, p, t in mistakes:
    print(f'trial: {i}, prediction: {p}, target: {t} ***')

0.91
