In [49]:
# load required libraries & modules
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm
import pprint
import time
import warnings
# warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch

from utils import *
from loaddata import *
from visualization import *
from rrcapsnet_original import *

# disable autograd globally: every cell below only evaluates models
torch.set_grad_enabled(False)
torch.set_printoptions(sci_mode=False)

# --- paths ---------------------------------------------------------------
DATA_DIR = '../data'
PATH_MNISTC = f'{DATA_DIR}/MNIST_C/'  # -> '../data/MNIST_C/'

# --- compute -------------------------------------------------------------
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('cpu')  # uncomment to force CPU evaluation
BATCHSIZE = 1000

# --- evaluation ----------------------------------------------------------
# Sub-directory names under PATH_MNISTC, one per corruption; cells below pick
# one with a 1-based index (CORRUPTION_TYPES[index - 1]).
CORRUPTION_TYPES = [
    'identity',
    'shot_noise', 'impulse_noise', 'glass_blur', 'motion_blur',
    'shear', 'scale', 'rotate', 'brightness', 'translate',
    'stripe', 'fog', 'spatter', 'dotted_line', 'zigzag',
    'canny_edges',
]

# accuracy type passed to evaluate() as acc_type
ACC_TYPE = "entropy"

#################
# model load
################
def load_model(args):
    """Build an RRCapsNet from ``args`` and load its saved weights.

    Args:
        args: namespace used to construct RRCapsNet; must also provide
            ``device`` (target device) and ``load_model_path`` (checkpoint
            file containing a state dict).

    Returns:
        The RRCapsNet instance on ``args.device`` with the checkpoint loaded.
    """
    model = RRCapsNet(args).to(args.device)
    # map_location remaps stored tensors onto the requested device, so a
    # checkpoint saved on GPU still loads when DEVICE falls back to CPU.
    state_dict = torch.load(args.load_model_path, map_location=args.device)
    model.load_state_dict(state_dict)
    return model

def load_args(load_model_path, args_to_update, verbose=False):
    """Rebuild the saved run arguments for a checkpoint and apply overrides.

    Reads ``params.txt`` from the checkpoint's directory (dropping its saved
    ``device`` entry), applies ``args_to_update`` on top, and records the
    checkpoint path on the result as ``load_model_path``.

    Args:
        load_model_path: path to the checkpoint file; ``params.txt`` must sit
            in the same directory.
        args_to_update: overrides passed to ``update_args``.
        verbose: if True, pretty-print the resulting arguments.

    Returns:
        The args object produced by ``parse_params_wremove``/``update_args``.

    Raises:
        AssertionError: if ``params.txt`` does not exist.
    """
    # os.path.join avoids producing '/params.txt' (the filesystem root) when
    # load_model_path has no directory component.
    params_filename = os.path.join(os.path.dirname(load_model_path), 'params.txt')
    assert os.path.isfile(params_filename), f"No param file exists: {params_filename}"
    args = parse_params_wremove(params_filename, removelist=['device'])
    args = update_args(args, args_to_update)
    args.load_model_path = load_model_path
    if verbose:
        pprint.pprint(args.__dict__, sort_dicts=False)
    return args

def recon_loss_step(x_recon_step, x, time_steps):
    '''
    Per-time-step reconstruction (MSE) loss.

    For each of the first `time_steps` steps t, takes the element-wise squared
    error between x_recon_step[:, t] and x and averages it over every
    non-batch dimension.

    Args:
        x_recon_step: reconstructions indexed as [batch, step, ...].
        x: target images; expected to match the shape of one step slice.
        time_steps: number of leading steps to score.
    Returns:
        tensor stacked along dim 1, i.e. (batch, time_steps): the per-sample
        mean squared error at each step.
    '''
    squared_error = nn.MSELoss(reduction='none')
    per_step = [
        squared_error(x_recon_step[:, step], x).flatten(start_dim=1).mean(dim=1)
        for step in range(time_steps)
    ]
    return torch.stack(per_step, dim=1)

def _load_mnistc_test_dataset(corruption):
    """Load one MNIST-C corruption's test split as a TensorDataset of (image, one-hot label)."""
    path_images = os.path.join(PATH_MNISTC, corruption, 'test_images.npy')
    path_labels = os.path.join(PATH_MNISTC, corruption, 'test_labels.npy')

    images = np.load(path_images)
    labels = np.load(path_labels)

    # one-hot encode each label into a length-10 float vector
    transform_tohot = T.Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    images_tensorized = torch.stack([T.ToTensor()(im) for im in images])  # e.g. torch.Size([10000, 1, 28, 28])
    labels_tensorized = torch.stack([transform_tohot(label) for label in labels])  # e.g. torch.Size([10000, 10])
    return TensorDataset(images_tensorized, labels_tensorized)


@torch.no_grad()
def evaluate_model_on_mnistc_original(corruption, model, max_batch_num=None, eval_args=None):
    """Evaluate ``model`` on the test split of one MNIST-C corruption.

    Args:
        corruption: sub-directory name under PATH_MNISTC (see CORRUPTION_TYPES).
        model: model passed to ``evaluate``; put into eval mode here.
        max_batch_num: evaluate at most this many batches of BATCHSIZE.
            None (or 0) evaluates every batch.
        eval_args: args object passed to ``evaluate``. Defaults to the
            notebook-global ``args`` (previous behavior).

    Returns:
        (objcaps_len_step_all, losses_all, acc_all): per-batch outputs of
        ``evaluate`` / ``recon_loss_step`` concatenated along dim 0, in
        dataset order (the loader does not shuffle).
    """
    import itertools

    if eval_args is None:
        eval_args = args  # fall back to the notebook-global args

    dataset = _load_mnistc_test_dataset(corruption)
    # DEVICE is a torch.device: compare its .type, since
    # `torch.device('cuda:0') == 'cuda'` is never true.
    kwargs = {'num_workers': 1, 'pin_memory': True} if DEVICE.type == 'cuda' else {}
    dataloader = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=False, drop_last=False, **kwargs)

    losses_all, objcaps_len_step_all, acc_all = [], [], []

    model.eval()

    # islice(..., None) iterates everything, matching the old "falsy means no limit" check
    for x, y in itertools.islice(dataloader, max_batch_num or None):
        _, acc, objcaps_len_step, x_recon_step = evaluate(model, x, y, eval_args, acc_type=ACC_TYPE, gtx=None)

        # per-sample, per-step reconstruction loss on the reconstructions' device
        losses = recon_loss_step(x_recon_step, x.to(x_recon_step.device), objcaps_len_step.shape[1])

        losses_all.append(losses)
        acc_all.append(acc)
        objcaps_len_step_all.append(objcaps_len_step)

    losses_all = torch.cat(losses_all, dim=0)
    acc_all = torch.cat(acc_all, dim=0)
    objcaps_len_step_all = torch.cat(objcaps_len_step_all, dim=0)

    return objcaps_len_step_all, losses_all, acc_all
    


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:
#############################
# arguments to change
#############################

# task and dataset
task = 'mnist_c_original'

train = False  # train or test dataset

# 1-based index into CORRUPTION_TYPES (12 -> 'fog')
corruption_index = 12
corruption = CORRUPTION_TYPES[corruption_index - 1]
print('analysis on ', corruption)

# model and args load
print_args = False
args_to_update = {'device': DEVICE, 'batch_size': BATCHSIZE,
                  'time_steps': 5, 'routings': 3, 'mask_threshold': 0.1}

# Checkpoint under analysis. Previously tried checkpoints, kept for reference:
#   './models/rrcapsnet/rrcapsnet_best.pt'
#   './results/mnist/May25_2049_recon_edge_run1/best_epoch68_acc0.9980.pt'
#   './results/mnist/May25_2131_recon_edge_run3/best_epoch70_acc0.9987.pt'
#   './results/mnist/May29_2045_blur_res4_run1/best_epoch32_acc1.0000.pt'
#   './results/mnist/Aug14_0524_lsf_res4_run2/best_epoch29_acc1.0000.pt'
load_model_path = './results/mnist/Aug14_0508_lsf_res4_run1/best_epoch30_acc1.0000.pt'

###############################
# load model and model output
###############################
args = load_args(load_model_path, args_to_update, print_args)
model = load_model(args)

# 10000 samples per corruption
print("original is used")
objcaps_len_step, losses, acc = evaluate_model_on_mnistc_original(corruption, model, max_batch_num=None)
# acc is concatenated over every evaluated batch, not a single batch
print(f'==> corruption type: {corruption}, mean acc over {len(acc)} samples: {acc.mean().item()}')



analysis on  fog

TASK: mnist_recon_low (# targets: 1, # classes: 10, # background: 0)
TIMESTEPS #: 5
ENCODER: resnet w/ None projection
...resulting primary caps #: 288, dim: 8
ROUTINGS # 3
Object #: 10, BG Capsule #: 0
DECODER: fcn, w/ None projection
...recon only one object capsule: True
...use recon mask for attention: True
...with mask type bool, threshold 0.1, apply_method match

original is used
==> corruption type: fog, this batch acc: 0.9733999967575073


In [None]:
#################################
# model entropy stepwise
##################################
# keep the first args.num_classes entries along dim 2 of the object-capsule lengths
objcaps_len_step_narrow = objcaps_len_step.narrow(dim=2, start=0, length=args.num_classes)

# Ground-truth one-hot labels, rebuilt from the same test_labels.npy the
# evaluation read. Its dataloader uses shuffle=False, so rows line up with
# objcaps_len_step; the slice covers a max_batch_num cut-off.
labels_np = np.load(os.path.join(PATH_MNISTC, corruption, 'test_labels.npy'))
y_hot = torch.nn.functional.one_hot(
    torch.as_tensor(labels_np, dtype=torch.long), num_classes=args.num_classes
).float()[:objcaps_len_step.shape[0]].to(objcaps_len_step.device)

acc_model_check, pred_model, nstep, no_stop_condition, entropy_model = compute_entropy_based_acc(
    objcaps_len_step_narrow, y_hot, threshold=0.6, use_cumulative=False, only_acc=False)

# NOTE(review): `acc_model` was never defined in this notebook; assuming it meant
# `acc` from the evaluation cell (computed with acc_type=ACC_TYPE, i.e. "entropy").
# TODO confirm evaluate() uses the same entropy threshold (0.6) so this check is meaningful.
acc_model = acc
assert round(acc_model.float().mean().item(), 4) == round(acc_model_check.float().mean().item(), 4)

