In [4]:
# load required libraries & modules
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm
import pprint
import time
import warnings
# warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch

from utils import *
from loaddata import *
from visualization import *
from rrcapsnet_original import *

# Notebook-wide torch settings: inference only (autograd disabled globally) and
# fixed-point tensor printing.
torch.set_grad_enabled(False)
torch.set_printoptions(sci_mode=False)

DATA_DIR = '../data'

# Prefer the second GPU ('cuda:1') when one exists; on single-GPU machines
# 'cuda:1' does not exist and .to(DEVICE) would fail, so fall back to 'cuda:0',
# and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
# DEVICE = torch.device('cpu')

# Used as the DataLoader batch size and as the model's `batch_size` argument.
BATCHSIZE = 1000

# Root of MNIST-C; each corruption is a sub-directory holding
# test_images.npy and test_labels.npy.
PATH_MNISTC = '../data/MNIST_C/'
# Corruption sub-directory names. The experiment cells index this list with a
# 1-based `corruption_index` (CORRUPTION_TYPES[corruption_index - 1]).
CORRUPTION_TYPES = ['identity',
                    'shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur',
                    'shear', 'scale', 'rotate', 'brightness', 'translate',
                    'stripe', 'fog', 'spatter', 'dotted_line', 'zigzag',
                    'canny_edges']

# NOTE(review): not referenced by the cells visible here.
N_MINI_PER_CORRUPTION = 1000

# Accuracy / stopping criterion passed to evaluate() and used to pick the
# accuracy helper in the experiment cells: 'entropy' or 'hypothesis'.
ACC_TYPE = "entropy"

# general helper functions for model testing
def load_model(args):
    """Build an RRCapsNet from ``args`` and load the checkpoint weights into it.

    Args:
        args: argument namespace (as returned by ``load_args``). It is passed
            whole to ``RRCapsNet`` and must provide ``device`` and
            ``load_model_path``.

    Returns:
        The ``RRCapsNet`` on ``args.device`` with the checkpoint's state dict loaded.
    """
    model = RRCapsNet(args).to(args.device)
    # map_location lets a checkpoint saved on one device (e.g. another GPU)
    # load onto args.device, including on CPU-only machines.
    state_dict = torch.load(args.load_model_path, map_location=args.device)
    model.load_state_dict(state_dict)
    return model

def load_args(load_model_path, args_to_update, verbose=False):
    """Load the parameters stored next to a checkpoint and apply overrides.

    Reads ``params.txt`` from the directory containing ``load_model_path`` with
    ``parse_params_wremove(..., removelist=['device'])``, then updates the
    result with ``args_to_update`` via ``update_args``.

    Args:
        load_model_path: path to the model checkpoint file.
        args_to_update: mapping of argument names to override values.
        verbose: if True, pretty-print the resulting arguments.

    Returns:
        The argument namespace, with ``load_model_path`` set to the given path.

    Raises:
        AssertionError: if ``params.txt`` does not exist next to the checkpoint.
    """
    # os.path.join keeps a bare filename like 'model.pt' resolving to
    # 'params.txt' in the current directory; plain "+ '/params.txt'" would
    # point at the filesystem root.
    params_filename = os.path.join(os.path.dirname(load_model_path), 'params.txt')
    assert os.path.isfile(params_filename), f"No params file exists: {params_filename}"
    args = parse_params_wremove(params_filename, removelist=['device'])
    args = update_args(args, args_to_update)
    args.load_model_path = load_model_path
    if verbose:
        pprint.pprint(args.__dict__, sort_dicts=False)
    return args


###########################
# evaluate on mnist-c original version
############################
@torch.no_grad()
def evaluate_model_on_mnistc_original(corruption, model, verbose=False, save_hooks=False, max_batch_num=None, args=None):
    """Evaluate a capsule model on one corruption of the original MNIST-C test set.

    Loads ``PATH_MNISTC/<corruption>/test_images.npy`` and ``test_labels.npy``,
    batches them (``BATCHSIZE``, no shuffling) and calls
    ``evaluate(model, x, y, args, acc_type=ACC_TYPE, gtx=None)`` per batch.

    Args:
        corruption: MNIST-C corruption sub-directory name (e.g. 'fog').
        model: model to evaluate; it is switched to eval mode.
        verbose: if True, print the losses and mean accuracy of every batch.
        save_hooks: if True, attach forward hooks to ``model.input_window`` and
            ``model.capsule_routing`` and additionally return their outputs.
        max_batch_num: if truthy, stop after this many batches.
        args: argument namespace forwarded to ``evaluate``. When omitted, the
            global ``args`` is used, as in earlier versions.

    Returns:
        ``(x_all, gtx_all, y_all, acc_all, objcaps_len_step_all, x_recon_step_all)``.
        The per-batch results are concatenated along dim 0. ``gtx_all`` is None
        because no ground-truth images are passed. When ``save_hooks`` is True,
        a dict of concatenated hook outputs is appended as a 7th element.

    Raises:
        NameError: if ``args`` is not given and no global ``args`` exists.
    """
    if args is None:
        # Backward compatibility: this function used to read the notebook-level
        # global `args` implicitly; keep the fallback but make it explicit.
        args = globals().get('args')
        if args is None:
            raise NameError("no `args` given and no global `args` is defined")

    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    # convert to torch: each image through ToTensor, each label to a length-10 one-hot float vector
    images = np.load(path_images)
    labels = np.load(path_labels)
    transform_tohot = T.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    images_tensorized = torch.stack([T.ToTensor()(im) for im in images])
    labels_tensorized = torch.stack([transform_tohot(label) for label in labels])

    # Create the dataloader. DEVICE is a torch.device (e.g. cuda:1), so compare
    # its .type; the old `DEVICE == 'cuda'` test never matched.
    use_cuda = torch.device(DEVICE).type == 'cuda'
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    dataset = TensorDataset(images_tensorized, labels_tensorized)
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    # Hook bookkeeping (only used with save_hooks=True). Keys are in the order
    # they appear in the returned `outputs` dict. The optional keys appear only
    # if model.capsule_routing reported them.
    hook_keys = ('x_input', 'x_mask', 'objcaps', 'coups', 'betas',
                 'rscores', 'recon_coups', 'outcaps_len', 'outcaps_len_before')
    optional_keys = ('rscores', 'recon_coups', 'outcaps_len', 'outcaps_len_before')
    step_buffers = {key: [] for key in hook_keys}   # tensors recorded during the current batch
    batch_buffers = {key: [] for key in hook_keys}  # one stacked tensor per batch

    def attention_hook(module, inputs, output):
        # output[0] / output[1] of model.input_window are recorded as mask / input
        step_buffers['x_mask'].append(output[0].detach())
        step_buffers['x_input'].append(output[1].detach())

    def capsule_hook(module, inputs, output):
        # output[0]: object capsules; output[1]: dict of lists of tensors
        routing = output[1]
        step_buffers['objcaps'].append(output[0].detach())
        step_buffers['coups'].append(torch.stack(routing['coups'], dim=1))
        step_buffers['betas'].append(torch.stack(routing['betas'], dim=1))
        for key in optional_keys:
            if key in routing.keys():
                step_buffers[key].append(torch.stack(routing[key], dim=1))

    x_all, y_all, gtx_all, acc_all, objcaps_len_step_all, x_recon_step_all = [], [], [], [], [], []

    model.eval()

    for batch_idx, (x, y) in enumerate(dataloader):
        if max_batch_num and batch_idx == max_batch_num:
            break
        gtx = None  # no ground-truth images are provided for MNIST-C here

        handles = []
        if save_hooks:
            for key in hook_keys:
                step_buffers[key].clear()
            handles = [model.input_window.register_forward_hook(attention_hook),
                       model.capsule_routing.register_forward_hook(capsule_hook)]

        try:
            losses, acc, objcaps_len_step, x_recon_step = evaluate(model, x, y, args, acc_type=ACC_TYPE, gtx=gtx)
        finally:
            # always detach the hooks, even if evaluate() raises, so they do not
            # pile up on the model across calls
            for handle in handles:
                handle.remove()

        if verbose:
            # acc is a per-sample tensor (it is concatenated along dim 0 below),
            # so report its mean rather than formatting the tensor with %f
            print("==> On this sigle test batch: test_loss=%.5f, test_loss_class=%.5f, test_loss_recon=%.5f, test_acc=%.4f"
                  % (losses[0], losses[1], losses[2], acc.float().mean().item()))

        # main input and output append
        x_all.append(x)
        y_all.append(y)
        if gtx is not None:
            gtx_all.append(gtx)
        acc_all.append(acc)
        objcaps_len_step_all.append(objcaps_len_step)
        x_recon_step_all.append(x_recon_step)

        if save_hooks:
            for key in hook_keys:
                if key in optional_keys and not step_buffers[key]:
                    continue  # the routing module did not report this output
                batch_buffers[key].append(torch.stack(step_buffers[key], dim=1))

    # concatenate batches along the sample dimension
    x_all = torch.cat(x_all, dim=0)
    y_all = torch.cat(y_all, dim=0)
    gtx_all = torch.cat(gtx_all, dim=0) if gtx_all else None
    acc_all = torch.cat(acc_all, dim=0)
    objcaps_len_step_all = torch.cat(objcaps_len_step_all, dim=0)
    x_recon_step_all = torch.cat(x_recon_step_all, dim=0)

    if not save_hooks:
        return x_all, gtx_all, y_all, acc_all, objcaps_len_step_all, x_recon_step_all

    outputs = {key: torch.cat(batch_buffers[key], dim=0)
               for key in hook_keys
               if key not in optional_keys or batch_buffers[key]}
    return x_all, gtx_all, y_all, acc_all, objcaps_len_step_all, x_recon_step_all, outputs
    
    

@torch.no_grad()
def evaluate_cnn_on_mnistc_original(corruption, cnn, max_batch_num=None):
    """Run a baseline CNN on one corruption of the original MNIST-C test set.

    Args:
        corruption: MNIST-C corruption sub-directory name (e.g. 'fog').
        cnn: classifier module; it is switched to eval mode and receives
            batches moved to ``DEVICE``.
        max_batch_num: if truthy, stop after this many batches.

    Returns:
        ``(x_all, y_all, class_prob_all, pred_all, acc_all)``, all concatenated
        along dim 0 and left on ``DEVICE``:
        - ``x_all``: the input images.
        - ``y_all``: the ground-truth class indices.
        - ``class_prob_all``: the raw ``cnn`` outputs. The caller exponentiates
          them, i.e. treats them as log-probabilities.
        - ``pred_all``: the argmax predictions.
        - ``acc_all``: 1.0 where the prediction is correct, else 0.0.
    """
    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    # convert to torch: each image through ToTensor, each label to a length-10 one-hot float vector
    images = np.load(path_images)
    labels = np.load(path_labels)
    transform_tohot = T.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    images_tensorized = torch.stack([T.ToTensor()(im) for im in images])
    labels_tensorized = torch.stack([transform_tohot(label) for label in labels])

    # Create the dataloader. DEVICE is a torch.device (e.g. cuda:1), so compare
    # its .type; the old `DEVICE == 'cuda'` test never matched.
    use_cuda = torch.device(DEVICE).type == 'cuda'
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    dataset = TensorDataset(images_tensorized, labels_tensorized)
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    x_all, y_all, pred_all, acc_all, class_prob_all = [], [], [], [], []
    cnn.eval()

    for batch_idx, (x, y) in enumerate(dataloader):
        if max_batch_num and batch_idx == max_batch_num:
            break

        inputs, target = x.to(DEVICE), y.to(DEVICE)
        target = target.argmax(dim=1, keepdim=True)  # one-hot -> class index, shape (B, 1)
        output = cnn(inputs)
        pred = output.argmax(dim=1, keepdim=True)    # index of the highest score
        correct = pred.eq(target.view_as(pred))

        x_all.append(inputs)
        y_all.append(target.flatten())
        pred_all.append(pred.flatten())
        acc_all.append(correct.flatten().float())
        class_prob_all.append(output)

    # concatenate batches along the sample dimension
    x_all = torch.cat(x_all, dim=0)
    y_all = torch.cat(y_all, dim=0)
    pred_all = torch.cat(pred_all, dim=0)
    acc_all = torch.cat(acc_all, dim=0)
    class_prob_all = torch.cat(class_prob_all, dim=0)

    return x_all, y_all, class_prob_all, pred_all, acc_all

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import matplotlib
import collections
import json
import random

def save_imgarr(imgarr, filename='test.png', scale=8):
    """Save an image array as a borderless PNG, enlarging each pixel ``scale`` times.

    The image is drawn with the reversed grayscale colormap ``'gray_r'``.

    Args:
        imgarr: image array of shape (H, W) or (H, W, C). A trailing channel
            axis of size 1 is dropped before plotting.
        filename: output PNG path.
        scale: output pixels per input pixel along each axis.
    """
    if imgarr.ndim == 3 and imgarr.shape[-1] == 1:
        imgarr = imgarr[..., 0]
    h, w = imgarr.shape[:2]
    # figsize is (width, height) in inches; saving with dpi=1 makes one inch
    # one pixel, so the PNG is (w*scale) x (h*scale) pixels.
    fig, axes = plt.subplots(figsize=(w * scale, h * scale))
    fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0)
    axes.imshow(imgarr, cmap='gray_r')
    axes.axis('off')
    fig.savefig(filename, dpi=1, format='png')
    plt.close(fig)

# experiment 1: timestep vs RT

In [None]:
import gc
#############################
# Experiment 1: for one MNIST-C corruption, bucket the correctly classified
# test trials by the number of time steps the model needed (nstep). Then save up
# to MAX_TRIAL randomly sampled images from the step1 and step2 buckets.
#############################
task = 'mnist_c_original'
train = False  # train or test dataset
print_args = False


# load model
args_to_update = {'device': DEVICE, 'batch_size': BATCHSIZE,
                  'time_steps': 5, 'routings': 3, 'mask_threshold': 0.1}

# load_model_path = './models/rrcapsnet/rrcapsnet_best.pt'
load_model_path = './results/mnist/Aug14_0508_lsf_res4_run1/best_epoch30_acc1.0000.pt'

args = load_args(load_model_path, args_to_update, print_args)
model = load_model(args)

path_save = './stimuli/stimuli-exp1-step5/'
os.makedirs(path_save, exist_ok=True)  # savefig does not create directories

MAX_TRIAL = 20  # max number of images sampled per step bucket
d_triallist = {}

#  obtain model prediction
for corruption_index in [13]:

    # corruption_index is 1-based into CORRUPTION_TYPES
    corruption = CORRUPTION_TYPES[corruption_index - 1]
    d_triallist[corruption] = {}

    if task == 'mnist_c_original':
        print("original is used")
        x, gtx, y_hot, acc_model, objcaps_len_step, x_recon_step = \
            evaluate_model_on_mnistc_original(corruption, model, verbose=False, save_hooks=False)
        print(f'==> corruption type: {corruption}, this batch acc: {acc_model.mean().item()}')
    else:
        raise NotImplementedError

    # keep only the class-capsule lengths (first num_classes capsules)
    objcaps_len_step_narrow = objcaps_len_step.narrow(dim=2, start=0, length=args.num_classes)
    y_true = y_hot.max(dim=1)[1]  # ground-truth class index per trial

    # get model prediction, correctness and number of steps used per trial
    if ACC_TYPE not in ('hypothesis', 'entropy'):
        raise NotImplementedError(f'unknown ACC_TYPE: {ACC_TYPE}')
    if args.time_steps == 1:
        # Single step: predict the class with the longest capsule at the only
        # step, and every trial stops at step 1.
        # NOTE(review): confirm this matches the helpers' single-step semantics.
        pred_model = objcaps_len_step_narrow[:, -1].max(dim=-1)[1].cpu()
        acc_model_check = pred_model.eq(y_true.cpu())
        nstep = torch.ones_like(pred_model)
    elif ACC_TYPE == 'hypothesis':
        acc_model_check, pred_model, nstep = compute_hypothesis_based_acc(objcaps_len_step_narrow, y_hot, only_acc=False)
    else:  # 'entropy'
        acc_model_check, pred_model, nstep, no_stop_condition, entropy_model = compute_entropy_based_acc(
            objcaps_len_step_narrow, y_hot, threshold=0.6, use_cumulative=False, only_acc=False)

    # sanity check: recomputed accuracy agrees with the evaluation accuracy
    assert round(acc_model.mean().item(), 4) == round(acc_model_check.float().mean().item(), 4)

    # nstep for correct trials; incorrect trials are masked to 0
    nstep_masked = acc_model.cpu().numpy() * nstep.cpu().numpy()
    id_incorrect = list(np.where(nstep_masked == 0)[0])
    id_correct = list(np.where(nstep_masked != 0)[0])
    print('correct: ', len(id_correct))
    print('incorrect: ', len(id_incorrect))

    # save trial lists to dictionary, one bucket per number of steps used
    for step_num in range(1, args.time_steps + 1):
        ids_step = list(np.where(nstep_masked == step_num)[0])
        print(f'step{step_num}: ', len(ids_step))
        d_triallist[corruption][f'step{step_num}'] = ids_step
    d_triallist[corruption]['correct'] = id_correct
    d_triallist[corruption]['incorrect'] = id_incorrect

    # sample up to MAX_TRIAL trials from each bucket (each bucket uses its own size)
    steps_to_visualize = ['step1', 'step2']
    sampled = {}
    for step in steps_to_visualize:
        bucket = d_triallist[corruption][step]
        sampled[step] = random.sample(bucket, min(len(bucket), MAX_TRIAL))

    for step in steps_to_visualize:
        for trialid in sampled[step]:
            our_pred = pred_model[trialid].cpu().item()
            gt = y_true[trialid].cpu().item()

            # save image, each original pixel enlarged 8x
            imgarray = x[trialid].numpy()
            filename = f'{corruption}_{step}_t{trialid}_g{gt}_o{our_pred}.png'
            save_imgarr(np.transpose(imgarray, (1, 2, 0)), filename=os.path.join(path_save, filename))

    print('all images are saved')
    gc.collect()
    


In [14]:
## change filenames
# import os
# path = './stimuli/stimuli-exp1-step4/'

# for f in os.listdir(path):
#     if not f.startswith('.'):
#         fsplit = f.split('_')
#         stepsize = fsplit[-1].split('.')[0]
#         corruption =  fsplit[:-4]
#         newf = '_'.join(corruption + [stepsize] + fsplit[-4:-1] ) + '.png'
#         os.rename(path+f, path+newf)

# experiment 2: compare with cnn

In [11]:
#############################
# Experiment 2: for one MNIST-C corruption, find test trials where both our
# model and the baseline CNN are wrong and also disagree with each other. Save
# up to MAX_TRIAL of them as images, plus a JSON file of per-trial predictions.
#############################
task = 'mnist_c_original'

train = False  # train or test dataset
print_args = False


# load our model
args_to_update = {'device': DEVICE, 'batch_size': BATCHSIZE,
                  'time_steps': 5, 'routings': 3, 'mask_threshold': 0.1}

# load_model_path = './models/rrcapsnet/rrcapsnet_best.pt'
load_model_path = './results/mnist/Aug14_0508_lsf_res4_run1/best_epoch30_acc1.0000.pt'

args = load_args(load_model_path, args_to_update, print_args)
model = load_model(args)

# load cnn (map_location lets a checkpoint saved on another device load on DEVICE)
from train_cnn import *
path_cnn = './models/cnn/cnn_best.pt'
cnn = Net().to(DEVICE)
cnn.load_state_dict(torch.load(path_cnn, map_location=DEVICE))
cnn.eval()

path_save = './stimuli/stimuli-exp2-step5/'
os.makedirs(path_save, exist_ok=True)  # savefig / open() do not create directories
MAX_TRIAL = 20  # max number of trials saved per corruption


# obtain model predictions
for corruption_index in [12]:

    # corruption_index is 1-based into CORRUPTION_TYPES
    corruption = CORRUPTION_TYPES[corruption_index - 1]

    if task != 'mnist_c_original':
        raise NotImplementedError

    ##################
    # get our model's prediction
    ##################
    print("original is used")
    x, gtx, y_hot, acc_model, objcaps_len_step, x_recon_step = \
        evaluate_model_on_mnistc_original(corruption, model, verbose=False, save_hooks=False)
    print(f'==> corruption type: {corruption}, this batch acc: {acc_model.mean().item()}')

    # keep only the class-capsule lengths (first num_classes capsules)
    objcaps_len_step_narrow = objcaps_len_step.narrow(dim=2, start=0, length=args.num_classes)

    if ACC_TYPE not in ('hypothesis', 'entropy'):
        raise NotImplementedError(f'unknown ACC_TYPE: {ACC_TYPE}')
    if args.time_steps == 1:
        # Single step: predict the class with the longest capsule at the only step.
        # NOTE(review): confirm this matches the helpers' single-step semantics.
        pred_model = objcaps_len_step_narrow[:, -1].max(dim=-1)[1].cpu()
        acc_model_check = pred_model.eq(y_hot.max(dim=1)[1].cpu())
    elif ACC_TYPE == 'hypothesis':
        acc_model_check, pred_model, nstep = compute_hypothesis_based_acc(objcaps_len_step_narrow, y_hot, only_acc=False)
    else:  # 'entropy'
        acc_model_check, pred_model, nstep, no_stop_condition, entropy_model = compute_entropy_based_acc(
            objcaps_len_step_narrow, y_hot, threshold=0.6, use_cumulative=False, only_acc=False)

    # sanity check: recomputed accuracy agrees with the evaluation accuracy
    assert round(acc_model.mean().item(), 4) == round(acc_model_check.float().mean().item(), 4)

    ##################
    # get cnn prediction
    ##################
    print("original is used")
    data_cnn, target_cnn, logsoft_cnn, pred_cnn, acc_cnn = evaluate_cnn_on_mnistc_original(corruption, cnn)
    print(f'==> corruption type: {corruption}, this batch acc: {acc_cnn.float().mean().item()}')

    #######################
    # get trial ids where both models fail and disagree with each other
    #######################
    # move everything to the CPU so the masks combine tensors on one device
    target_cnn = target_cnn.cpu()
    pred_cnn = pred_cnn.cpu()
    pred_model = pred_model.cpu()
    correct_model = acc_model.bool().cpu()
    correct_cnn = acc_cnn.bool().cpu()
    assert (target_cnn == y_hot.max(dim=1)[1].cpu()).all()

    bool_bothincorrect = ~correct_model & ~correct_cnn
    bool_diffanswer = pred_model != pred_cnn
    bool_onlycnncorrect = ~correct_model & correct_cnn  # used by the alternative selection below
    trialid_interest = torch.nonzero(bool_bothincorrect & bool_diffanswer).flatten().tolist()
    # trialid_interest = torch.nonzero(bool_onlycnncorrect & bool_diffanswer).flatten().tolist()
    print(len(trialid_interest))

    #####################
    # save images (reversed grayscale) and per-trial info
    #####################
    d = {}
    trialid_to_visualize = random.sample(trialid_interest, min(len(trialid_interest), MAX_TRIAL))

    for trialid in trialid_to_visualize:
        cnn_pred = pred_cnn[trialid].item()
        our_pred = pred_model[trialid].item()
        gt = target_cnn[trialid].item()

        # save image, each original pixel enlarged 8x
        imgarray = x[trialid].numpy()
        filename = f'{corruption}_t{trialid}_g{gt}_c{cnn_pred}_o{our_pred}.png'
        save_imgarr(np.transpose(imgarray, (1, 2, 0)), filename=os.path.join(path_save, filename))

        # trial info; exp() assumes the cnn outputs log-probabilities (log_softmax)
        d[filename] = {
            'id': trialid,
            'imgarray': imgarray.tolist(),
            'cnn_softmax': torch.exp(logsoft_cnn[trialid]).cpu().numpy().tolist(),
            'our_objlen': objcaps_len_step_narrow[trialid, -1].cpu().numpy().tolist(),
            'cnn_pred': cnn_pred,
            'our_pred': our_pred,
            'gt': gt,
        }

    # save dictionary
    with open(os.path.join(path_save, f'{corruption}.json'), 'w') as fp:
        json.dump(d, fp)
    print('trialinfo saved to disk!')
    # to open
    # with open('./stimuli/glass_blur.json') as json_file:
    #     test = json.load(json_file)


TASK: mnist_recon_low (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 5
ENCODER: resnet w/ None projection
...resulting primary caps #: 288, dim: 8
ROUTINGS # 3
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True
...use recon mask for attention: True
...with mask type bool, threshold 0.1, apply_method match

original is used
==> corruption type: fog, this batch acc: 0.9733999967575073
original is used
==> corruption type: fog, this batch acc: 0.840999960899353
58
trialinfo saved to disk!
