In [2]:
# load required libraries & modules
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm
import pprint
import time
import warnings
# warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch

from utils import *
from loaddata import *
from visualization import *
from rrcapsnet_original import *

# Autograd is disabled globally: this notebook only evaluates trained checkpoints.
torch.set_grad_enabled(False)
torch.set_printoptions(sci_mode=False)

DATA_DIR = '../data'
# Prefer the second GPU (the original setup) but fall back to the first GPU on
# single-GPU machines, where 'cuda:1' would be an invalid device ordinal.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    DEVICE = torch.device('cpu')
# DEVICE = torch.device('cpu')

BATCHSIZE = 1000

PATH_MNISTC = '../data/MNIST_C/'
# 'identity' (uncorrupted images) followed by the 15 MNIST-C benchmark corruptions;
# each name is a sub-directory of PATH_MNISTC.
CORRUPTION_TYPES = ['identity',
                    'shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur',
                    'shear', 'scale', 'rotate', 'brightness', 'translate',
                    'stripe', 'fog', 'spatter', 'dotted_line', 'zigzag',
                    'canny_edges']

# Passed as `acc_type` to `test()` in every evaluation below.
ACC_TYPE = "hypothesis"

#################
# model load
################
def load_model(args):
    """Build an RRCapsNet from `args` and load the checkpoint at `args.load_model_path`.

    Args:
        args: run-arguments object (as returned by `load_args`); it is passed to
            `RRCapsNet` as-is and must provide `device` and `load_model_path`.

    Returns:
        The RRCapsNet instance, moved to `args.device`, with the checkpoint weights loaded.
    """
    model = RRCapsNet(args).to(args.device)
    # Without map_location, tensors saved from a GPU are restored onto the GPU they were
    # saved from: this fails on CPU-only machines (DEVICE falls back to 'cpu') and
    # otherwise allocates on a GPU other than args.device.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(args.load_model_path, map_location=args.device)
    model.load_state_dict(state_dict)
    return model

def load_args(load_model_path, args_to_update, verbose=False):
    """Rebuild the run arguments stored next to a checkpoint and apply overrides.

    Reads `params.txt` from the checkpoint's directory via `parse_params_wremove`
    (asking it to skip the saved `device` entry), applies `args_to_update` with
    `update_args`, and records `load_model_path` on the result.

    Args:
        load_model_path: path to the checkpoint file.
        args_to_update: overrides forwarded to `update_args`.
        verbose: if True, pretty-print the resulting arguments.

    Returns:
        The updated arguments object with `load_model_path` set.

    Raises:
        FileNotFoundError: if there is no `params.txt` next to the checkpoint.
    """
    # os.path.join handles a bare filename (empty dirname) correctly; the previous
    # string concatenation produced '/params.txt' (the filesystem root) in that case.
    params_filename = os.path.join(os.path.dirname(load_model_path), 'params.txt')
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if not os.path.isfile(params_filename):
        raise FileNotFoundError(f"No param file exists: {params_filename}")
    args = parse_params_wremove(params_filename, removelist=['device'])
    args = update_args(args, args_to_update)
    args.load_model_path = load_model_path
    if verbose:
        pprint.pprint(args.__dict__, sort_dicts=False)
    return args

############
# testing 
############
def test_model(task, model, args, verbose=False):
    """Evaluate `model` on the test split of `task`.

    Args:
        task: dataset/task name forwarded to `fetch_dataloader`.
        model: network to evaluate; it is switched to eval mode here.
        args: run arguments forwarded to `test`.
        verbose: if True, print `args` before and the metrics after evaluation.

    Returns:
        Tuple (test_loss, test_loss_class, test_loss_recon, test_acc) from `test`.
    """
    if verbose:
        pprint.pprint(args.__dict__, sort_dicts=False)

    model.eval()
    loader = fetch_dataloader(task, DATA_DIR, DEVICE, BATCHSIZE, train=False)
    metrics = test(model, loader, args, acc_type=ACC_TYPE)
    test_loss, test_loss_class, test_loss_recon, test_acc = metrics

    if verbose:
        report = (test_loss, test_loss_class, test_loss_recon, test_acc)
        print("==> test_loss=%.5f, test_loss_class=%.5f, test_loss_recon=%.5f, test_acc=%.4f"
              % report)
    return test_loss, test_loss_class, test_loss_recon, test_acc

def test_model_mnistc(path_mnistc, corruptionlist, model, verbose=False):
    # set task and print setting
    if verbose:
        pprint.pprint(args.__dict__, sort_dicts=False)
    
    # get average test results over corruptionlist   
    losses, classlosses, reconlosses, accs = [], [], [], []
    for corruption in corruptionlist:
        test_loss, test_loss_class, test_loss_recon, test_acc = test_model_on_each_corruption(path_mnistc, corruption, model, verbose)

        losses.append(test_loss)
        classlosses.append(test_loss_class)
        reconlosses.append(test_loss_recon)
        accs.append(test_acc)
    
    avgtest_loss = sum(losses)/len(corruptionlist)
    avgtest_loss_class = sum(classlosses)/len(corruptionlist)
    avgtest_loss_recon = sum(reconlosses)/len(corruptionlist)
    avgtest_acc = sum(accs)/len(corruptionlist)
    
    if verbose:
        print("==> average test_loss=%.5f, test_loss_class=%.5f, test_loss_recon=%.5f, test_acc=%.4f"
              % (avgtest_loss, avgtest_loss_class, avgtest_loss_recon, avgtest_acc))
        
    return avgtest_loss, avgtest_loss_class, avgtest_loss_recon, avgtest_acc
    
    
def _load_mnistc_dataset(path_mnistc, corruption, num_classes=10):
    """Load one MNIST-C corruption's test split as a TensorDataset of (image, one-hot label)."""
    images = np.load(os.path.join(path_mnistc, corruption, 'test_images.npy'))
    labels = np.load(os.path.join(path_mnistc, corruption, 'test_labels.npy'))

    to_tensor = T.ToTensor()  # one transform instance, reused for every image
    images_tensorized = torch.stack([to_tensor(im) for im in images])
    # one_hot needs int64 class indices; reshape(-1) also flattens an (N, 1) label array.
    labels_tensorized = torch.nn.functional.one_hot(
        torch.as_tensor(labels, dtype=torch.long).reshape(-1), num_classes=num_classes
    ).to(torch.float)
    return TensorDataset(images_tensorized, labels_tensorized)


def test_model_on_each_corruption(path_mnistc, corruption, model, verbose=False, args=None):
    """Evaluate `model` on the test split of a single MNIST-C corruption.

    Args:
        path_mnistc: MNIST-C root directory.
        corruption: corruption name, i.e. the sub-directory that holds
            `test_images.npy` and `test_labels.npy`.
        model: network to evaluate; it is switched to eval mode here.
        verbose: if True, print the metrics for this corruption.
        args: run arguments forwarded to `test`. Defaults to the notebook-global
            `args`, which is what callers relied on before this parameter existed.

    Returns:
        Tuple (test_loss, test_loss_class, test_loss_recon, test_acc) from `test`.

    Raises:
        NameError: if `args` is not given and no global `args` is defined.
    """
    if args is None:
        try:
            args = globals()['args']
        except KeyError:
            raise NameError("pass `args` explicitly or define a global `args` first") from None

    dataset = _load_mnistc_dataset(path_mnistc, corruption)

    # DEVICE is a torch.device (e.g. 'cuda:1'); comparing it with the str 'cuda' never
    # matches, so compare the device *type* to actually enable pinned memory on GPU runs.
    use_cuda = torch.device(DEVICE).type == 'cuda'
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    model.eval()
    test_loss, test_loss_class, test_loss_recon, test_acc = test(model, dataloader, args, acc_type=ACC_TYPE)

    if verbose:
        print("==> individual test_loss=%.5f, test_loss_class=%.5f, test_loss_recon=%.5f, test_acc=%.4f"
              % (test_loss, test_loss_class, test_loss_recon, test_acc))
    return test_loss, test_loss_class, test_loss_recon, test_acc


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# On original MNIST-C

In [3]:
# ------------------------------------------------------
# Best-model comparison across all MNIST-C corruptions
# ------------------------------------------------------
verbose, print_args = False, False  # don't print per-corruption metrics / loaded args

# Overrides passed to load_args, applied on top of each checkpoint's params.txt.
args_to_update = dict(
    device=DEVICE,
    batch_size=BATCHSIZE,
    time_steps=1,
    routings=1,
    mask_threshold=0.1,
)


# Checkpoints to evaluate on MNIST-C. Only the uncommented entries (the "cnn 2 layers - lsf"
# group at the bottom) are part of the list; the commented-out groups record earlier runs.
# NOTE(review): several commented entries lack trailing commas -- add them when re-enabling.
modelpathlist = [
# './models/rrcapsnet/run1_epoch50_acc0.9917.pt',
# './models/rrcapsnet/run2_epoch50_acc0.9915.pt',
# './models/rrcapsnet/run3_epoch50_acc0.9907.pt',
# './models/rrcapsnet/run4_epoch50_acc0.9905.pt',
# './models/rrcapsnet/run5_epoch50_acc0.9907.pt',

# './results/mnist/Apr28_1352_adam_clr_128/archive_model_epoch50_acc0.9987.pt'
# './results/mnist/Apr28_1408_adam_clr_512/archive_model_epoch50_acc0.9953.pt' # increasing batch size was worse
# './results/mnist/Apr28_1438_adam_clr_128_wde5/archive_model_epoch50_acc0.9986.pt'
# './results/mnist/Apr28_1458_adam_clr_128_wde4/archive_model_epoch50_acc0.9963.pt' # increasing weight decay was worse
# './results/mnist/Apr28_1550_adamw_clr_128/archive_model_epoch50_acc0.9988.pt'
# './results/mnist/Apr28_1626_rmsprop_clr_128/archive_model_epoch50_acc0.9936.pt'
# './results/mnist/Apr28_1641_adam_exp_128/archive_model_epoch50_acc0.9979.pt'
# './results/mnist/Apr28_1719_adam_exp_128_lre3/archive_model_epoch50_acc0.9982.pt'
# './results/mnist/Apr28_1745_adam_exp_128_lre3_lamrecon5/archive_model_epoch50_acc0.9986.pt'
# './results/mnist/Apr28_1815_capsnetencoder_cycle/archive_model_epoch220_acc0.9961.pt'

# './results/mnist/Apr28_2001_shift_adamclr/archive_model_epoch50_acc0.9860.pt'

# Clean-CLR
# './results/mnist/Apr29_0212_clean_clr_run1/best_epoch174_acc0.9993.pt',
# './results/mnist/Apr29_0321_clean_clr_run2/best_epoch163_acc0.9993.pt',
# './results/mnist/Apr29_0427_clean_clr_run3/best_epoch186_acc0.9994.pt',
# './results/mnist/Apr29_0540_clean_clr_run4/best_epoch212_acc0.9995.pt',
# './results/mnist/Apr29_0703_clean_clr_run5/earlystop_245_acc0.9995.pt',

    
# Clean-WD 0.0005, 0.98
# './results/mnist/Apr29_1353_clean_wd2_run1/best_epoch85_acc0.9983.pt',
# './results/mnist/Apr29_1435_clean_wd2_run2/best_epoch95_acc0.9986.pt',
# './results/mnist/Apr29_1704_clean_wd2_run3/best_epoch105_acc0.9988.pt',
# './results/mnist/Apr29_1749_clean_wd2_run4/best_epoch92_acc0.9983.pt',
# './results/mnist/Apr29_1829_clean_wd2_run5/best_epoch115_acc0.9987.pt',
    

    
# Shift-WD
# './results/mnist/Apr30_0304_shift_run1/best_epoch103_acc0.9966.pt',
# './results/mnist/Apr30_0400_shift_run2/best_epoch78_acc0.9977.pt',
# './results/mnist/Apr30_0425_shift_run3/best_epoch85_acc0.9973.pt',
# './results/mnist/Apr30_0450_shift_run4/best_epoch80_acc0.9958.pt',
# './results/mnist/Apr30_0513_shift_run5/best_epoch64_acc0.9978.pt'

    
# Clean-Aug-WD
# './results/mnist/May25_0235_clean_aug_run1/best_epoch84_acc0.9662.pt',
# './results/mnist/May25_0323_clean_aug_run2/best_epoch83_acc0.9714.pt',
# './results/mnist/May25_0411_clean_aug_run3/best_epoch99_acc0.9713.pt',
# './results/mnist/May25_0506_clean_aug_run4/best_epoch81_acc0.9664.pt',
# './results/mnist/May25_0554_clean_aug_run5/best_epoch75_acc0.9713.pt'

    
    
# Clean-WD
# './results/mnist/Apr29_0213_clean_wd_run1/best_epoch75_acc0.9981.pt',
# './results/mnist/Apr29_0247_clean_wd_run2/best_epoch81_acc0.9987.pt',
# './results/mnist/Apr29_0322_clean_wd_run3/best_epoch88_acc0.9987.pt',
# './results/mnist/Apr29_0400_clean_wd_run4/best_epoch62_acc0.9980.pt',
# './results/mnist/Apr29_0429_clean_wd_run5/best_epoch71_acc0.9987.pt',
    
# Recon-blur-WD
# './results/mnist/Apr30_0144_recon_run1/best_epoch81_acc0.9983.pt',
# './results/mnist/Apr30_0208_recon_run2/best_epoch73_acc0.9986.pt',
# './results/mnist/Apr30_0230_recon_run3/best_epoch77_acc0.9986.pt',
# './results/mnist/Apr30_0253_recon_run4/best_epoch63_acc0.9981.pt',
# './results/mnist/Apr30_0321_recon_run5/best_epoch77_acc0.9987.pt',
    
# resnet-blur version
# './results/mnist/May29_0416_blur_res_run1/best_epoch21_acc1.0000.pt',
# './results/mnist/May29_0427_blur_res_run2/best_epoch30_acc1.0000.pt',
# './results/mnist/May29_0440_blur_res_run3/best_epoch37_acc1.0000.pt',
# './results/mnist/May29_0455_blur_res_run4/best_epoch36_acc1.0000.pt',
# './results/mnist/May29_0510_blur_res_run5/best_epoch29_acc1.0000.pt'

# resnet 4 layers -blur
# './results/mnist/May29_2045_blur_res4_run1/best_epoch32_acc1.0000.pt',
# './results/mnist/May29_2056_blur_res4_run2/best_epoch32_acc1.0000.pt',
# './results/mnist/May29_2109_blur_res4_run3/best_epoch32_acc1.0000.pt',
# './results/mnist/May29_2121_blur_res4_run4/best_epoch30_acc1.0000.pt',
# './results/mnist/May29_2132_blur_res4_run5/best_epoch30_acc1.0000.pt'

# resnet 4 layers - lsf
# './results/mnist/Aug14_0508_lsf_res4_run1/best_epoch30_acc1.0000.pt',
# './results/mnist/Aug14_0524_lsf_res4_run2/best_epoch29_acc1.0000.pt',
# './results/mnist/Aug14_0540_lsf_res4_run3/best_epoch26_acc1.0000.pt',
# './results/mnist/Aug14_0556_lsf_res4_run4/best_epoch29_acc1.0000.pt',
# './results/mnist/Aug14_0612_lsf_res4_run5/best_epoch28_acc1.0000.pt'

# resnet 4 layers - hsf
# './results/mnist/Aug14_0508_hsf_res4_run1/best_epoch32_acc1.0000.pt',
# './results/mnist/Aug14_0525_hsf_res4_run2/best_epoch25_acc1.0000.pt',
# './results/mnist/Aug14_0540_hsf_res4_run3/best_epoch27_acc1.0000.pt',
# './results/mnist/Aug14_0556_hsf_res4_run4/best_epoch30_acc1.0000.pt',
# './results/mnist/Aug14_0612_hsf_res4_run5/best_epoch28_acc1.0000.pt'
    
# cnn 2 layers - lsf  (the only active entries)
'./results/mnist/Aug14_0805_lsf_conv_run5/best_epoch88_acc0.9988.pt',
'./results/mnist/Aug14_0740_lsf_conv_run4/best_epoch86_acc0.9982.pt',
'./results/mnist/Aug14_0721_lsf_conv_run3/best_epoch60_acc0.9987.pt',
'./results/mnist/Aug14_0658_lsf_conv_run2/best_epoch78_acc0.9987.pt',
'./results/mnist/Aug14_0635_lsf_conv_run1/best_epoch74_acc0.9981.pt'
]



# Accuracy table: one row per MNIST-C corruption, one column (accuracy in %) per checkpoint.
df = pd.DataFrame()
df['corruption'] = CORRUPTION_TYPES

for load_model_path in modelpathlist:
#     modelname = '-'.join( os.path.dirname(load_model_path).split('_')[-3:-1]) #'recon-step3'
    modelname = load_model_path.split('/')[-1]  # checkpoint filename, for progress messages
    print(f'test starts on {modelname}')

    # load args and model (`args` stays global: test_model_on_each_corruption can read it)
    args = load_args(load_model_path, args_to_update, print_args)
    model = load_model(args)

    acclist = []
    for corruption in CORRUPTION_TYPES:
        test_loss, test_loss_class, test_loss_recon, test_acc = test_model_on_each_corruption(PATH_MNISTC, corruption, model, verbose)
        # float() keeps the column numeric even if `test` returns a 0-d tensor.
        acclist.append(float(test_acc) * 100)
    df[load_model_path] = acclist


df.index = np.arange(1, len(df)+1)
# `corruption` is a string column: pandas >= 2.0 raises TypeError from df.mean() unless
# non-numeric columns are excluded explicitly (older pandas silently dropped them).
df.loc['AVERAGE'] = df.mean(numeric_only=True)
df

test starts on best_epoch88_acc0.9988.pt

TASK: mnist_recon_low (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 1
ENCODER: two-conv-layer w/ None projection
...resulting primary caps #: 1152, dim: 8
ROUTINGS # 1
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True

test starts on best_epoch86_acc0.9982.pt

TASK: mnist_recon_low (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 1
ENCODER: two-conv-layer w/ None projection
...resulting primary caps #: 1152, dim: 8
ROUTINGS # 1
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True




KeyboardInterrupt



In [3]:
# Save the MNIST-C accuracy table. Earlier result files, kept for reference:
#   'model-results-rrcapsnet-clean-G4L3.csv'
#   'model-results-rrcapsnet-shift-G4L3.csv'
#   'model-results-rrcapsnet-recon-G1L1.csv'
#   'model-results-rrcapsnet-clean-aug-G4L3.csv'
#   'model-results-rrcapsnet-blur-resnet4-G1L1.csv'
#   'model-results-rrcapsnet-lowpass-resnet4-G4L1.csv'
#   'model-results-rrcapsnet-highpass-resnet4-G1L1.csv'
path_df = 'model-results-rrcapsnet-lowpass-conv-G1L1.csv'

overwrite = True  # set to False to leave an existing file untouched
# NOTE(review): index=False drops the row labels, including the 'AVERAGE' label.
if not overwrite and os.path.isfile(path_df):
    print(f'test done! file {path_df} already exists, df is not saved')
else:
    df.to_csv(path_df, index=False)
    print(f'test done! df is saved to csv as {path_df}')

test done! df is saved to csv as model-results-rrcapsnet-lowpass-conv-G1L1.csv


# On shape-perturbed datasets

In [13]:
# ------------------------------------------------------
# Best-model comparison on the shape-perturbed MNIST tasks
# ------------------------------------------------------
tasklist = ['mnist', 'mnist_occlusion', 'mnist_flipped', 'mnist_random']
verbose, print_args = False, False

# Overrides passed to load_args, applied on top of each checkpoint's params.txt
# (note: 3 routing iterations here, versus 1 in the MNIST-C comparison).
args_to_update = dict(device=DEVICE, batch_size=BATCHSIZE,
                      time_steps=1, routings=3, mask_threshold=0.1)


# Checkpoints to evaluate on the shape-perturbed tasks.
# NOTE(review): every entry below is commented out, so this list is currently EMPTY and
# the evaluation loop does nothing; the output saved under this cell comes from an
# earlier run on the Recon-WD checkpoints. Re-enable a group (adding any missing commas).
modelpathlist = [
# './models/rrcapsnet/run1_epoch50_acc0.9917.pt',
# './models/rrcapsnet/run2_epoch50_acc0.9915.pt',
# './models/rrcapsnet/run3_epoch50_acc0.9907.pt',
# './models/rrcapsnet/run4_epoch50_acc0.9905.pt',
# './models/rrcapsnet/run5_epoch50_acc0.9907.pt',

# './results/mnist/Apr28_1352_adam_clr_128/archive_model_epoch50_acc0.9987.pt'
# './results/mnist/Apr28_1408_adam_clr_512/archive_model_epoch50_acc0.9953.pt' # increasing batch size was worse
# './results/mnist/Apr28_1438_adam_clr_128_wde5/archive_model_epoch50_acc0.9986.pt'
# './results/mnist/Apr28_1458_adam_clr_128_wde4/archive_model_epoch50_acc0.9963.pt' # increasing weight decay was worse
# './results/mnist/Apr28_1550_adamw_clr_128/archive_model_epoch50_acc0.9988.pt'
# './results/mnist/Apr28_1626_rmsprop_clr_128/archive_model_epoch50_acc0.9936.pt'
# './results/mnist/Apr28_1641_adam_exp_128/archive_model_epoch50_acc0.9979.pt'
# './results/mnist/Apr28_1719_adam_exp_128_lre3/archive_model_epoch50_acc0.9982.pt'
# './results/mnist/Apr28_1745_adam_exp_128_lre3_lamrecon5/archive_model_epoch50_acc0.9986.pt'
# './results/mnist/Apr28_1815_capsnetencoder_cycle/archive_model_epoch220_acc0.9961.pt'

# './results/mnist/Apr28_2001_shift_adamclr/archive_model_epoch50_acc0.9860.pt'

# Clean-CLR
# './results/mnist/Apr29_0212_clean_clr_run1/best_epoch174_acc0.9993.pt',
# './results/mnist/Apr29_0321_clean_clr_run2/best_epoch163_acc0.9993.pt',
# './results/mnist/Apr29_0427_clean_clr_run3/best_epoch186_acc0.9994.pt',
# './results/mnist/Apr29_0540_clean_clr_run4/best_epoch212_acc0.9995.pt',
# './results/mnist/Apr29_0703_clean_clr_run5/earlystop_245_acc0.9995.pt',
    
# Clean-WD
# './results/mnist/Apr29_0213_clean_wd_run1/best_epoch75_acc0.9981.pt',
# './results/mnist/Apr29_0247_clean_wd_run2/best_epoch81_acc0.9987.pt',
# './results/mnist/Apr29_0322_clean_wd_run3/best_epoch88_acc0.9987.pt',
# './results/mnist/Apr29_0400_clean_wd_run4/best_epoch62_acc0.9980.pt',
# './results/mnist/Apr29_0429_clean_wd_run5/best_epoch71_acc0.9987.pt',
    
# Clean-WD 0.0005, 0.98
# './results/mnist/Apr29_1353_clean_wd2_run1/best_epoch85_acc0.9983.pt',
# './results/mnist/Apr29_1435_clean_wd2_run2/best_epoch95_acc0.9986.pt',
# './results/mnist/Apr29_1704_clean_wd2_run3/best_epoch105_acc0.9988.pt',
# './results/mnist/Apr29_1749_clean_wd2_run4/best_epoch92_acc0.9983.pt',
# './results/mnist/Apr29_1829_clean_wd2_run5/best_epoch115_acc0.9987.pt',
    
# Recon-WD
# './results/mnist/Apr30_0144_recon_run1/best_epoch81_acc0.9983.pt',
# './results/mnist/Apr30_0208_recon_run2/best_epoch73_acc0.9986.pt',
# './results/mnist/Apr30_0230_recon_run3/best_epoch77_acc0.9986.pt',
# './results/mnist/Apr30_0253_recon_run4/best_epoch63_acc0.9981.pt',
# './results/mnist/Apr30_0321_recon_run5/best_epoch77_acc0.9987.pt',
    
# Shift-WD
# './results/mnist/Apr30_0304_shift_run1/best_epoch103_acc0.9966.pt',
# './results/mnist/Apr30_0400_shift_run2/best_epoch78_acc0.9977.pt',
# './results/mnist/Apr30_0425_shift_run3/best_epoch85_acc0.9973.pt',
# './results/mnist/Apr30_0450_shift_run4/best_epoch80_acc0.9958.pt',
# './results/mnist/Apr30_0513_shift_run5/best_epoch64_acc0.9978.pt'

]



# Accuracy table: one row per task, one column (accuracy in %) per checkpoint.
df = pd.DataFrame({'task': tasklist})

for load_model_path in modelpathlist:
    modelname = load_model_path.rsplit('/', 1)[-1]  # checkpoint filename
    print(f'test starts on {modelname}')

    # load args and model
    args = load_args(load_model_path, args_to_update, print_args)
    model = load_model(args)

    # test_model returns (loss, loss_class, loss_recon, acc); keep only the accuracy.
    df[load_model_path] = [test_model(task, model, args)[3] * 100 for task in tasklist]

df

test starts on best_epoch81_acc0.9983.pt

TASK: mnist_recon (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 1
ENCODER: two-conv-layer w/ None projection
...resulting primary caps #: 1152, dim: 8
ROUTINGS # 3
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True

original mnist dataset
test starts on best_epoch73_acc0.9986.pt

TASK: mnist_recon (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 1
ENCODER: two-conv-layer w/ None projection
...resulting primary caps #: 1152, dim: 8
ROUTINGS # 3
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True

original mnist dataset
test starts on best_epoch77_acc0.9986.pt

TASK: mnist_recon (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 1
ENCODER: two-conv-layer w/ None projection
...resulting primary caps #: 1152, dim: 8
ROUTINGS # 3
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object 

Unnamed: 0,task,./results/mnist/Apr30_0144_recon_run1/best_epoch81_acc0.9983.pt,./results/mnist/Apr30_0208_recon_run2/best_epoch73_acc0.9986.pt,./results/mnist/Apr30_0230_recon_run3/best_epoch77_acc0.9986.pt,./results/mnist/Apr30_0253_recon_run4/best_epoch63_acc0.9981.pt,./results/mnist/Apr30_0321_recon_run5/best_epoch77_acc0.9987.pt
0,mnist,99.030004,99.150004,99.060004,99.150005,99.110004
1,mnist_occlusion,91.280004,91.820006,91.260004,90.870005,91.760005
2,mnist_flipped,70.540003,70.250003,69.910004,70.620003,71.050004
3,mnist_random,36.190002,36.950002,35.900002,36.960001,36.730002


# Comparing models within one experiment folder
Obtain the best checkpoint from each experiment and compare overall accuracy.
- All pretrained models to be compared must be saved under the same experiment folder.
- The results are collected in a DataFrame and saved to CSV.

In [3]:
##################
# set up info for model testing
##################
# task = 'mnist_recon'  # alternative: train on mnist_recon, test on mnistc_mini
task = 'mnist_c'  # test on 15 benchmark corruptions + 1 identity

# get best model files under experiment path
path_experiment = './results/mnist/experiment-recon-norecon'
bestfiles, expnames = get_bestmodel_paths(path_experiment)
# '_'-separated fields written after the timestamp in each experiment name
expname_format = ['use_recon', 'n_step', 'seed']
# expname_format = ['use_recon','n_step', 'inputmatch', 'seed']

# arguments to update (applied on top of each checkpoint's params.txt by load_args)
args_to_update = {'device': DEVICE, 'batch_size': BATCHSIZE,
                  'time_steps': 4, 'routings': 3, 'routing_type': 'pd-recon',  # or 'original'
                  'min_coup': 0.5, 'min_rscore': 0.5, 'mask_threshold': 0.1}
print_args = False


##################
# main - model testing
##################

# One row per best checkpoint; the experiment name is split into its '_'-separated fields.
df = pd.DataFrame()
df['model_path'] = bestfiles
df['exp_name'] = expnames
df[expname_format] = df['exp_name'].str.split('_', expand=True)
df['exp_name'] = df['exp_name'].str.split('_').str[:-1].str.join('_')  # expname without seed

metric_cols = ['test_loss', 'test_loss_class', 'test_loss_recon', 'test_acc']
for i, row in df.iterrows():
    # "model i/N" instead of the previous "{i+1}th", which printed "1th", "2th", "3th".
    print(f'test begin on model {i+1}/{len(df)}')

    # load model (`args` stays global: the MNIST-C helpers read the global `args`)
    load_model_path = row['model_path']
    args = load_args(load_model_path, args_to_update, print_args)
    model = load_model(args)

    # test model
    if task == 'mnist_c':
        metrics = test_model_mnistc(PATH_MNISTC, CORRUPTION_TYPES, model)
    else:
        metrics = test_model(task, model, args)
    for col, value in zip(metric_cols, metrics):
        df.loc[i, col] = value

print('========== tests are done =============')

# save df -- never discard the results of the (long) test loop above
path_savefile = task + '-' + os.path.basename(path_experiment) + '.csv'
if os.path.isfile(path_savefile):
    # Keep the existing file and write these results next to it under a timestamped name
    # (previously the new results were silently dropped in this case).
    root, ext = os.path.splitext(path_savefile)
    path_savefile = f"{root}-{time.strftime('%Y%m%d-%H%M%S')}{ext}"
    print(f'csv file already exists; saving to {path_savefile} instead')
df.to_csv(path_savefile)
print(f'csv file saved: {path_savefile}')