In [1]:
# load required libraries & modules
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm
import pprint
import time
import warnings
# warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch

from utils import *
from loaddata import *
from visualization import *
from rrcapsnet_original import *

torch.set_grad_enabled(False)
torch.set_printoptions(sci_mode=False)

DATA_DIR = '../data'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('cpu')

BATCHSIZE = 1000

PATH_MNISTC = '../data/MNIST_C/'
CORRUPTION_TYPES = ['identity', 
         'shot_noise', 'impulse_noise','glass_blur','motion_blur',
         'shear', 'scale',  'rotate',  'brightness',  'translate',
         'stripe', 'fog','spatter','dotted_line', 'zigzag',
         'canny_edges']

N_MINI_PER_CORRUPTION = 1000

ACC_TYPE = "entropy"

# general helper funtions for model testing
def load_model(args):
    # load model
    model = RRCapsNet(args).to(args.device) 
    model.load_state_dict(torch.load(args.load_model_path))
    return model

def load_args(load_model_path, args_to_update, verbose=False):
    params_filename = os.path.dirname(load_model_path) + '/params.txt'
    assert os.path.isfile(params_filename), "No param flie exists"
    args = parse_params_wremove(params_filename, removelist = ['device']) 
    args = update_args(args, args_to_update)
    args.load_model_path = load_model_path
    if verbose:
        pprint.pprint(args.__dict__, sort_dicts=False)
    return args


###########################
# evaluate on mnist-c original version
############################
@torch.no_grad()
def evaluate_model_on_mnistc_original(corruption, model, verbose=False, save_hooks=False,  max_batch_num=None):
    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    # convert to torch
    images = np.load(path_images)
    labels = np.load(path_labels)
    transform_tohot = T.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    images_tensorized = torch.stack([T.ToTensor()(im) for im in images])
    labels_tensorized = torch.stack([transform_tohot(label) for label in labels])
    # print(images_tensorized.shape) #torch.Size([10000, 1, 28, 28])
    # print(labels_tensorized.shape) #torch.Size([10000, 10])

    # create dataloader
    kwargs = {'num_workers': 1, 'pin_memory': True} if DEVICE == 'cuda' else {}
    dataset = TensorDataset(images_tensorized, labels_tensorized)
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    if save_hooks:
        def get_attention_outputs():
            def hook(model, input, output):
                x_mask_step.append(output[0].detach())
                x_input_step.append(output[1].detach())
            return hook

        def get_capsule_outputs():
            def hook(model, input, output):
                objcaps_step.append(output[0].detach())
                coups_step.append(torch.stack(output[1]['coups'], dim=1))
                betas_step.append(torch.stack(output[1]['betas'], dim=1)) 
                if 'rscores' in output[1].keys():
                    rscores_step.append(torch.stack(output[1]['rscores'], dim=1))
                if 'recon_coups' in output[1].keys():
                    recon_coups_step.append(torch.stack(output[1]['recon_coups'], dim=1))
                if 'outcaps_len' in output[1].keys():
                    outcaps_len_step.append(torch.stack(output[1]['outcaps_len'], dim=1))
                if 'outcaps_len_before' in output[1].keys():
                    outcaps_len_before_step.append(torch.stack(output[1]['outcaps_len_before'], dim=1))
            return hook
        
        outputs = {}

        x_input_step_all = []; x_mask_step_all = []; objcaps_step_all = []

        coups_step_all = []; betas_step_all= []; rscores_step_all=[]; recon_coups_step_all=[] 
        outcaps_len_step_all=[]; outcaps_len_before_step_all=[]

    x_all, y_all, gtx_all, loss_all, acc_all, objcaps_len_step_all, x_recon_step_all = [],[],[],[],[],[],[]
    
    model.eval()      
    
    # get input and gt
    i=0
    for data in dataloader:
        x, y = data
        gtx = None
        
        if max_batch_num:
            if i == max_batch_num:
                break
                
#         if i == max_batch_num:
#             x, y = data
#             gtx = None

        # for hooks over other model output
        x_input_step = []; x_mask_step = []; objcaps_step = []

        if save_hooks:

            # for hooks over dynamic routing
            coups_step = []; betas_step= []; rscores_step=[]; recon_coups_step=[] 
            outcaps_len_step=[]; outcaps_len_before_step=[]

            hook1 = model.input_window.register_forward_hook(get_attention_outputs())
            hook2 = model.capsule_routing.register_forward_hook(get_capsule_outputs())

        # evaluate and append results 
        losses, acc, objcaps_len_step, x_recon_step = evaluate(model, x, y, args, acc_type=ACC_TYPE, gtx=gtx)

        if verbose:
            print("==> On this sigle test batch: test_loss=%.5f, test_loss_class=%.5f, test_loss_recon=%.5f, test_acc=%.4f"
                  % (losses[0], losses[1], losses[2], acc))   

        # main input and output append
        x_all.append(x)
        y_all.append(y)
        if gtx:
            gtx_all.append(gtx)
        #         loss_all.append(losses[0])
        acc_all.append(acc)
        objcaps_len_step_all.append(objcaps_len_step)
        x_recon_step_all.append(x_recon_step)

        if save_hooks:

            # hook variables append
            x_input_step_all.append(torch.stack(x_input_step, dim=1))
            x_mask_step_all.append(torch.stack(x_mask_step, dim=1))
            objcaps_step_all.append(torch.stack(objcaps_step, dim=1))

            coups_step_all.append(torch.stack(coups_step, dim=1))
            betas_step_all.append(torch.stack(betas_step, dim=1))
            if rscores_step:
                rscores_step_all.append(torch.stack(rscores_step, dim=1))
            if recon_coups_step:
                recon_coups_step_all.append(torch.stack(recon_coups_step, dim=1))
            if outcaps_len_step:
                outcaps_len_step_all.append(torch.stack(outcaps_len_step, dim=1))
            if outcaps_len_before_step:
                outcaps_len_before_step_all.append(torch.stack(outcaps_len_before_step, dim=1))

            hook1.remove()
            hook2.remove()        
        
        
        i+=1
        

    
        
    # concat and add to outputs dictionary
    x_all = torch.cat(x_all, dim=0)
    y_all = torch.cat(y_all, dim=0)
    if gtx:
        gtx_all = torch.cat(gtx_all, dim=0)
    else:
        gtx_all = gtx
    acc_all = torch.cat(acc_all, dim=0)
    objcaps_len_step_all = torch.cat(objcaps_len_step_all, dim=0)
    x_recon_step_all = torch.cat(x_recon_step_all, dim=0)
    
    if save_hooks:
        outputs['x_input']= torch.cat(x_input_step_all, dim=0)
        outputs['x_mask']= torch.cat(x_mask_step_all, dim=0)
        outputs['objcaps']= torch.cat(objcaps_step_all, dim=0)

        outputs['coups'] = torch.cat(coups_step_all, dim=0)
        outputs['betas'] = torch.cat(betas_step_all, dim=0)
        if rscores_step_all:
            outputs['rscores'] = torch.cat(rscores_step_all, dim=0)
        if recon_coups_step_all:
            outputs['recon_coups'] = torch.cat(recon_coups_step_all, dim=0)
        if outcaps_len_step_all:
            outputs['outcaps_len'] = torch.cat(outcaps_len_step_all, dim=0)
        if outcaps_len_before_step_all:
            outputs['outcaps_len_before'] = torch.cat(outcaps_len_before_step_all, dim=0)
            
        return x_all, gtx_all, y_all, acc_all, objcaps_len_step_all, x_recon_step_all, outputs  

    else:
        return x_all, gtx_all, y_all, acc_all, objcaps_len_step_all, x_recon_step_all
    
    

@torch.no_grad()
def evaluate_cnn_on_mnistc_original(corruption, cnn, max_batch_num=None):
    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    # convert to torch
    images = np.load(path_images)
    labels = np.load(path_labels)
    transform_tohot = T.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    images_tensorized = torch.stack([T.ToTensor()(im) for im in images])
    labels_tensorized = torch.stack([transform_tohot(label) for label in labels])
    # print(images_tensorized.shape) #torch.Size([10000, 1, 28, 28])
    # print(labels_tensorized.shape) #torch.Size([10000, 10])

    # create dataloader
    kwargs = {'num_workers': 1, 'pin_memory': True} if DEVICE == 'cuda' else {}
    dataset = TensorDataset(images_tensorized, labels_tensorized)
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    # save output
    x_all, y_all, pred_all, acc_all, class_prob_all = [],[],[], [],[]
    cnn.eval() 

    # get input and gt
    i=0
    for data in dataloader:
        x, y = data
        gtx = None
        
        if max_batch_num:
            if i == max_batch_num:
                break

#         if i == max_batch_num:
#             x, y = data
#             gtx = None
                    
        data, target = x.to(DEVICE),  y.to(DEVICE)
        target = target.argmax(dim=1, keepdim=True)
        output = cnn(data)
        #                 test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        acc = pred.eq(target.view_as(pred))

        x_all.append(data)
        y_all.append(target.flatten())
        pred_all.append(pred.flatten())
        acc_all.append(acc.flatten().float())
        class_prob_all.append(output)
        
        i+=1


    # concat and add to outputs dictionary
    x_all = torch.cat(x_all, dim=0)
    y_all = torch.cat(y_all, dim=0)
    pred_all = torch.cat(pred_all, dim=0)
    acc_all = torch.cat(acc_all, dim=0)
    class_prob_all = torch.cat(class_prob_all, dim=0)

    return x_all, y_all, class_prob_all, pred_all, acc_all

In [2]:
import matplotlib
import collections
import json
import random

def save_imgarr(imgarr, filename='test.png', scale=8):
    h, w, _ = imgarr.shape
    fig, axes = plt.subplots(figsize=(h*scale, w*scale))
    fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0) 
    axes.imshow(imgarr, cmap='gray_r')
    axes.axis('off')
    plt.savefig(filename, dpi=1, format='png') 
    plt.close(fig)

# experiment 2: analyze errors and compare with cnn

In [14]:
#############################
# test a single model, and visualize outputs
#############################
task='mnist_c_original'

train=False #train or test dataset
print_args=False


# load our model
args_to_update = {'device':DEVICE, 'batch_size':BATCHSIZE, 
                 'time_steps': 5, 'routings': 3,'mask_threshold': 0.1}

load_model_path = './models/rrcapsnet/rrcapsnet_best.pt'
# load_model_path = './results/mnist/Aug14_0508_lsf_res4_run1/best_epoch30_acc1.0000.pt'

args = load_args(load_model_path, args_to_update, print_args)
model = load_model(args)

# load cnn
from train_cnn import *
path_cnn = './models/cnn/cnn_best.pt'
cnn = Net().to(DEVICE)
cnn.load_state_dict(torch.load(path_cnn))
cnn.eval()

path_save = './stimuli/stimuli-exp2-step5/'



# obtain model predictions
# CORRUPTION_INTEREST = ['identity', 'glass_blur','motion_blur', 'impulse_noise','shot_noise',
#         'fog','dotted_line','spatter', 'zigzag']
CORRUPTION_INTEREST = ['motion_blur']

for corruption in CORRUPTION_INTEREST :

    if task =='mnist_c_original':
        print("original is used")
        x, gtx, y_hot, acc_model, objcaps_len_step, x_recon_step = \
        evaluate_model_on_mnistc_original(corruption, model, verbose=False, save_hooks=False)
        print(f'==> corruption type: {corruption}, this batch acc: {acc_model.mean().item()}')
    else:
        raise NotImplementedError

    # get model prediction
    objcaps_len_step_narrow = objcaps_len_step.narrow(dim=2,start=0, length=args.num_classes)
    # pred_model = objcaps_len_step_narrow.max(dim=-1)[1][:,-1] #torch.Size([1000, 3])


    if ACC_TYPE=='hypothesis':
        if args.time_steps==1:
            y_pred = objcaps_len_step_narrow[:,-1]
            accs = topkacc(y_pred, y_true, topk=1)
        else:
            acc_model_check, pred_model, nstep  = compute_hypothesis_based_acc(objcaps_len_step_narrow, y_hot, only_acc=False)

    elif ACC_TYPE == 'entropy':    
        if args.time_steps==1:
            y_pred = objcaps_len_step_narrow[:,-1]
            accs = topkacc(y_pred, y_true, topk=1)
        else: 
            acc_model_check, pred_model, nstep, no_stop_condition, entropy_model =compute_entropy_based_acc(objcaps_len_step_narrow, y_hot, threshold=0.6, use_cumulative = False, only_acc= False)

    assert round(acc_model.mean().item(), 4) == round(acc_model_check.float().mean().item(), 4)


    ##################
    # get cnn prediction
    ##################
    if task =='mnist_c_original':
        print("original is used")
        data_cnn, target_cnn, logsoft_cnn, pred_cnn, acc_cnn \
        =  evaluate_cnn_on_mnistc_original(corruption, cnn)
        print(f'==> corruption type: {corruption}, this batch acc: {acc_cnn.float().mean().item()}')
        
        
    #######################
    # get trials id where both model fails & disagree
    #######################
    assert (target_cnn.cpu() == y_hot.max(dim=1)[1].cpu()).all()
    bool_bothincorrect = ~(acc_model.bool())* ~(acc_cnn.bool())
    bool_diffanswer = (pred_model!= pred_cnn)
    bool_onlycnncorrect = ~(acc_model.bool())* (acc_cnn.bool())
    # idx_bothincorrect = torch.nonzero(bool_bothincorrect)
    trialid_interest = torch.nonzero(bool_bothincorrect*bool_diffanswer).flatten().tolist()
    # trialid_interest = torch.nonzero(bool_onlycnncorrect*bool_diffanswer).flatten().tolist()
    print(len(trialid_interest))
    
    #####################
    # save image (in reverted grayscale) and trialinfo
    #####################
    d = collections.defaultdict(dict)
    MAX_TRIAL = 20

    trialid_to_visualize = random.sample(trialid_interest, min(len(trialid_interest), MAX_TRIAL))

    for trialid in trialid_to_visualize:
        cnn_pred = pred_cnn[trialid].cpu().item()
        our_pred = pred_model[trialid].cpu().item()
        gt =target_cnn[trialid].cpu().item()

        # save image x8 original pixel size
        imgarray = x[trialid].numpy()
        filename = f'{corruption}_t{trialid}_g{gt}_c{cnn_pred}_o{our_pred}.png'
        save_imgarr(np.transpose(imgarray,(1,2,0)), filename=path_save+filename)

        # save trial info to dictionary
        d[filename]['id'] = trialid
        d[filename]['imgarray'] = imgarray.tolist()
        d[filename]['cnn_softmax'] = torch.exp(logsoft_cnn[trialid]).cpu().numpy().tolist()
        d[filename]['our_objlen'] = objcaps_len_step_narrow[trialid, -1].cpu().numpy().tolist()
        d[filename]['cnn_pred'] = cnn_pred
        d[filename]['our_pred'] = our_pred
        d[filename]['gt'] =  gt


    # save dictionary
    with open(path_save+ f'{corruption}.json', 'w') as fp:
        json.dump(d, fp)
    print('trialinfo saved to disk!')    



TASK: mnist_recon_low (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 5
ENCODER: resnet w/ None projection
...resulting primary caps #: 288, dim: 8
ROUTINGS # 3
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True
...use recon mask for attention: True
...with mask type bool, threshold 0.1, apply_method match

original is used
==> corruption type: motion_blur, this batch acc: 0.9667999744415283
original is used
==> corruption type: motion_blur, this batch acc: 0.9440999627113342
55
trialinfo saved to disk!


In [17]:
# how many trials we have for each corruption
# CORRUPTION_INTEREST = ['identity', 'glass_blur','motion_blur', 'impulse_noise','shot_noise',
#         'fog','dotted_line','spatter', 'zigzag']

CORRUPTION_INTEREST = ['identity', 'glass_blur','motion_blur', 'impulse_noise','shot_noise', 'fog', 'zigzag']

for corruption in CORRUPTION_INTEREST :
    with open(f'./stimuli/stimuli-exp2-step5/{corruption}.json') as json_file:
        jfile = json.load(json_file)
        
    print(corruption, len(jfile))

# the following condition won't be included because not enough trials 
# dotted_line 12
# spatter 15

identity 5
glass_blur 20
motion_blur 20
impulse_noise 20
shot_noise 20
fog 20
zigzag 20


# Create experiment files

In [1]:
import numpy as np
import pandas as pd
import os
import random

In [2]:
############################
# create prac trials
############################
# colnames = ['trialtype', 'corruption', 'imgID', 'gt', 'cnn', 'our', 'imgpath', 'maskpath',
#            'q1', 'q2', 'q1_text', 'q2_text', 'q1_from', 'q2_from']
colnames = ['trialtype', 'corruption', 'imgID', 'gt', 'cnn', 'our', 'imgpath', 
           'q1', 'q2', 'q1_text', 'q2_text', 'q1_from', 'q2_from']
df = pd.DataFrame(columns= colnames )

###########
# create trial info from filenames
###########
path_exp = 'stimuli/stimuli-exp2-step5/'
# masklist = ['stimuli/stimuli-mask/'+f for f in os.listdir('stimuli/stimuli-mask/') if not f.startswith('.')]
# masklist = [f for f in os.listdir('stimuli-mask/') if not f.startswith('.')]

i=1
for f in os.listdir(path_exp):
    if f.endswith('png'):
        corruption = '_'.join(f.split('_')[:-4])
        if corruption == 'identity': #use as practice trials
            trialtype = 'prac'
        else:
            trialtype = 'exp'        
        imgID = int(f.split('_')[-4][1:])
        gt = int(f.split('_')[-3][1:])
        cnn = int(f.split('_')[-2][1])
        our = int(f.split('_')[-1][1])
        imgpath = 'stimuli/' + f #path_exp + f
#         imgpath = f
#         maskpath= random.sample(masklist,1)[0]
        
        # counterbalance the question rating order
        if i%2==0: #even case-> rating1 is cnn's answer and rating 2 is ours
            q1 = cnn
            q2 = our
            q1_text = f'How likely is this digit to be {q1}?'
            q2_text =f'How likely is this digit to be {our}?'
            q1_from = 'cnn'
            q2_from = 'our'
        else: # odd case
            q1 = our
            q2 = cnn
            q1_text = f'How likely is this digit to be {our}?'
            q2_text =f'How likely is this digit to be {cnn}?'            
            q1_from = 'our'
            q2_from = 'cnn'            

#         df.loc[i] = [trialtype, corruption, imgID, gt, cnn, our, imgpath, maskpath,
#                        q1, q2, q1_text, q2_text, q1_from, q2_from]   
        df.loc[i] = [trialtype, corruption, imgID, gt, cnn, our, imgpath,
                       q1, q2, q1_text, q2_text, q1_from, q2_from]   
        i+=1
        


In [3]:
df

Unnamed: 0,trialtype,corruption,imgID,gt,cnn,our,imgpath,q1,q2,q1_text,q2_text,q1_from,q2_from
1,exp,impulse_noise,2732,6,5,8,stimuli/impulse_noise_t2732_g6_c5_o8.png,8,5,How likely is this digit to be 8?,How likely is this digit to be 5?,our,cnn
2,exp,zigzag,4676,1,2,4,stimuli/zigzag_t4676_g1_c2_o4.png,2,4,How likely is this digit to be 2?,How likely is this digit to be 4?,cnn,our
3,exp,shot_noise,6632,9,5,8,stimuli/shot_noise_t6632_g9_c5_o8.png,8,5,How likely is this digit to be 8?,How likely is this digit to be 5?,our,cnn
4,exp,fog,6785,2,8,4,stimuli/fog_t6785_g2_c8_o4.png,8,4,How likely is this digit to be 8?,How likely is this digit to be 4?,cnn,our
5,exp,glass_blur,1247,9,7,5,stimuli/glass_blur_t1247_g9_c7_o5.png,5,7,How likely is this digit to be 5?,How likely is this digit to be 7?,our,cnn
...,...,...,...,...,...,...,...,...,...,...,...,...,...
121,exp,impulse_noise,3926,9,8,3,stimuli/impulse_noise_t3926_g9_c8_o3.png,3,8,How likely is this digit to be 3?,How likely is this digit to be 8?,our,cnn
122,exp,fog,4176,2,4,8,stimuli/fog_t4176_g2_c4_o8.png,4,8,How likely is this digit to be 4?,How likely is this digit to be 8?,cnn,our
123,exp,zigzag,96,1,4,7,stimuli/zigzag_t96_g1_c4_o7.png,7,4,How likely is this digit to be 7?,How likely is this digit to be 4?,our,cnn
124,exp,motion_blur,6577,7,1,8,stimuli/motion_blur_t6577_g7_c1_o8.png,1,8,How likely is this digit to be 1?,How likely is this digit to be 8?,cnn,our


In [4]:
df = df.sort_values(by=['trialtype', 'corruption']).reset_index(drop=True)
df_prac = df[df['trialtype']=='prac'].copy()
df_exp = df[df['trialtype']=='exp'].copy()

print('df_prac length ', len(df_prac), ' df_exp length ', len(df_exp))
print(df_exp.corruption.unique())

df_prac length  5  df_exp length  120
['fog' 'glass_blur' 'impulse_noise' 'motion_blur' 'shot_noise' 'zigzag']


In [5]:
###############
# separate into n unique sets; 2sets * 60 images (10*6 corruptions)
###############
N_SET = 4
df_exp['cumcount'] = df_exp.groupby(['corruption']).cumcount()+1
df_exp['setnum'] = df_exp['cumcount'].apply(lambda x: int(x%N_SET +1))
df_exp = df_exp.drop(columns =['cumcount'])


for i in range(1, N_SET+1):
    expset = df_exp[df_exp.setnum==i] 
    expset = expset.sample(frac=1).reset_index(drop=True) #shuffle
    combined = df_prac.merge(expset, how='outer') # merge with prac
    combined = combined.reset_index(drop=True)
    combined['setnum'] = i # since prac parts has no set numbers

    combined.to_csv(f'exp2_source{i}.csv', index=False)
    
print('source csvs are saved to disk')

source csvs are saved to disk


In [7]:
# sanity check; load all expsource and confirm all unique images are covered
cats = []
for i in range(1, N_SET+1):
    test = pd.read_csv(f'exp2_source{i}.csv')
    print(len(test)) 
    test = test[test.trialtype=='exp']
    cats.append(test)
    
c =pd.concat(cats)
print(len(c.imgpath.unique()))
print(c.corruption.unique())
len(c.corruption.unique())

35
35
35
35
120
['zigzag' 'fog' 'glass_blur' 'shot_noise' 'motion_blur' 'impulse_noise']


6