# ResNet Demo
This notebook shows how to train or test the model. Also included are methods to display model outputs and attention mechanimsms 

In [1]:
# load required libraries & modules
import os
from tqdm.notebook import tqdm
import pprint
import time

import torch

from utils import *
# from ocra import *
from loaddata import *

import warnings
# warnings.filterwarnings('ignore')

# Instantiate Model for Training

In [2]:
############################################
# set hyperparams using param file 
############################################
# param file can be 1) multimnist_params.txt or 2) multimnist_cluttered_params.txt
# params_filename = 'multimnist_cluttered_params.txt'
# params_filename = 'multisvhn_params.txt'
params_filename = 'resnet_svrt_task1_params.txt'
args = parse_params(params_filename)

# if you have a checkpoint to restore, specify restore file (in the orginal param file or here)
# args.restore_file =  'results/multimnist_cluttered/Aug28_4014__step7_1/state_checkpoint.pt'
if args.restore_file: 
    # if you want to pick up from save-point, reload param files
    print("param file will be reloaded from your save point folder")
    path_savepoint = os.path.dirname(args.restore_file)
    params_filename = path_savepoint + '/params.txt'  
    assert os.path.isfile(params_filename), "No param flie exists"
    
    # remove the arguments that cannot be translated into literal
    removelist = ['device'] 
    args = parse_params_wremove(params_filename, removelist) 
    
    # reassign path_savepoint to restorefile
    args.restore_file = path_savepoint + '/state_checkpoint.pt'
            
    
# setup output directory where log folder should be created 
if not os.path.isdir(args.output_dir):
    os.makedirs(args.output_dir)

# set device
# args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

pprint.pprint(args.__dict__, sort_dicts=False)

{'task': 'svrt_task1',
 'num_classes': 2,
 'num_targets': 1,
 'image_dims': (1, 64, 64),
 'cat_dup': False,
 'n_epochs': 50,
 'lr': 0.001,
 'train_batch_size': 128,
 'test_batch_size': 128,
 'cuda': 0,
 'data_dir': './data/',
 'output_dir': './results/svrt_task1/',
 'restore_file': None,
 'save_checkpoint': True,
 'record_gradnorm': False,
 'record_attn_hooks': False,
 'validate_after_howmany_epochs': 1,
 'best_val_acc': 0,
 'verbose': True,
 'device': device(type='cuda', index=0)}


In [3]:
import sys
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time
import math

from utils import *
from loaddata import *


def loss_fn(y_pred, y_true, args, writer=None, step=None):
    """
    Args
        y_true -- groundtruth y in n-hot format 
        y_pred -- model predictions; (normalized) cumulative length of class capsules 
    """
    #######   Classification Margin Loss      #######
    
    if len(y_true.shape) == 3: 
        y_true_1d = torch.clamp(torch.sum(y_true, dim=1), max=1)
    else:
        y_true_1d = y_true
         
#     # classification error: margin error for class capsules -- allows for both GT and prediction to have any number for any class 
#     m_neg = 0.1 # margin loss allowed for negative case (for absent digits)
#     lam_abs = 0.5 # down-weighting loss for absent digits (prevent the initial learning from shrinking the lengths of the class capsules    
#     L_present =  torch.clamp(y_true_1d, min=0., max=1.) * torch.clamp((y_true_1d-m_neg) - y_pred, min=0.) ** 2   
# #     L_present =  y_true_1d * torch.clamp((y_true_1d-m_neg) - y_pred, min=0.) ** 2   # not clamped version
#     L_absent = lam_abs * torch.clamp(1 - y_true_1d, min=0.) * torch.clamp(y_pred-m_neg, min=0.) ** 2
#     L_margin = (L_present+L_absent).sum(dim=1).mean()
    y_true_1d = y_true_1d.argmax(-1)
    loss_fn = nn.CrossEntropyLoss()
    L_margin = loss_fn(y_pred, y_true_1d)
    return L_margin 


# --------------------
# Train and Test
# --------------------

def get_topkacc(y_pred: torch.Tensor, y_true:  torch.Tensor, topk=1):
    """
    Get indices of topk model predictions and gather how many of them are correct (in percentage).
    e.g., if 1 correct out of top2 prediction --> 0.5
    
    Input: torch tensor
        - y_pred should be a vector of prediction score 
        - y_true should be in multi-hot encoding format (one or zero; can't deal with duplicates)

    Return: 
        - a vector of accuracy from each image --> [n_images,]
        - average acc
    """
    n_images = y_pred.size(0)
    topk_indices = y_pred.data.topk(topk, sorted=True)[1] 
    accs = torch.gather(y_true, dim=1, index=topk_indices).sum(dim=1)/topk
    average_acc = accs.cpu().sum().item()/n_images

    return average_acc, accs

def get_exactmatch(y_pred_hot: torch.Tensor, y_true: torch.Tensor):
    """
    See if y_pred and y_true matches exactly
    e.g., if match acc=1, not match acc=0
    
    Input: torch tensor 
        - both y_pred and y_true should be in the same format
        e.g., if y_true is multi-hot, then y_pred should be made in multi-hot as well
    Return: 
        - a vector of accuracy from each image --> [n_images,]
        - average acc
    """
    n_images = y_pred_hot.size(0)
    accs = (y_pred_hot == y_true).all(dim=1).float()
    average_acc = accs.cpu().sum().item()/n_images
    
    return average_acc, accs

def cal_accs(y_pred_nar, y_true, args):
    """ calculate accuracy
    1) when args.cat_dup == True --> use criterion value to indicate duplicate classification
    2) when args.cat_dup == False --> just topk based accuracy
    exact match: if one of digits incorrect  --> 0
    partial match: if one of digits incorrect --> 0.5
    
    return:
        - accuracy = accuracy sum over the whole batch
        - accs = a list of correct score for each image 
    """
    n_targets = args.num_targets
    
    
    if args.task=='multisvhn':  # when there are two targets with duplicates allowed
    
        col_11 = torch.zeros((y_true.size(0),5, 1)).to(args.device)
        col_11[:,1:5, 0] = torch.flip(torch.squeeze(y_true[:,0,10:args.num_classes]), (1,)) 
        y_true = torch.cat( (y_true[:,:,0:10], col_11), dim=2)


        readout_logits = readout_logits.view(-1, 5,11)
        out, pred_digits = torch.max(readout_logits,dim=2)
        out, true_digits = torch.max(y_true,dim=2)
        
        partial_accs = torch.sum(1*(pred_digits == true_digits), dim=1) / float(args.num_targets)        
        partial_accuracy = partial_accs.cpu().sum().item() 
        exact_accs = (partial_accs == 1)
        exact_accuracy =  exact_accs.cpu().sum().item() 
        
        y_pred_hot = torch.sum(readout_logits, dim=1)
        
        #exact_accuracy = torch.sum(1* (torch.sum(1*(pred_digits == true_digits), dim=1) == 5)).type('torch.FloatTensor').cpu()  # 
        
    
    elif args.cat_dup == True:  # when there are two targets with duplicates allowed
        # get bool indices for an image with duplicates
        dup_t = n_targets - 0.2# 1.8 for target=2; criterion value indicating duplicates (two targets are from the same category)
        bool_above = (y_pred_nar >= dup_t).any(dim=1)

        # when any of predictions are above dup_t, apply the following

        y_pred_dup = (y_pred_nar == torch.max(y_pred_nar, dim= 1)[0].reshape(-1,1)) # get y_pred for duplicates; the largest value --> mark as 1, e.g., y_pred  [2.1, 0.2, 0.1] --> [1, 0, 0] 
        # maxid = y_pred_nar.argmax(dim=1)
        # y_pred_dup = torch.zeros(y_pred_nar.shape).scatter(1, maxid.unsqueeze(dim=1), 1.0)
        dupaccs = torch.sum(y_true*y_pred_dup, dim=1)/2.0 # compare with y_true and get acc, 
        # e.g, when y_pred_dup = [1, 0, 0], acc=1 when y_true = [2,0,0] and and acc=0.5 when y_true = [1,1,0]

        # when none of predictions are above dup_t, apply the following
        y_true_clip = torch.clip(y_true,0,1) # y_true [2,0] --> [1,0]
        y_pred_nodup = (y_pred_nar >= y_pred_nar.topk(n_targets)[0][:,n_targets-1:n_targets]) # bool indices of predictions higher/equal than n_target highest value  y_pred = [1.0,0.9,0.1] --> [1, 1, 0]
        nodupaccs = torch.sum(y_true_clip*y_pred_nodup, dim=1)/float(n_targets)  # compare with y_true and get acc, 
        # e.g., when y_pred = [1.0,0.9,0.1], acc=1, when y_true = [1,1,0] and acc=0.5, when y_true = [2,0,0]

        # combine match for both no duplicates and yes duplicates

        y_pred_hot = n_targets*(bool_above.reshape(-1,1)*y_pred_dup)+ (~bool_above).reshape(-1,1)*y_pred_nodup
        partial_accs = bool_above*dupaccs + (~bool_above)*nodupaccs
        exact_accs = (partial_accs >= 1)
        partial_accuracy = partial_accs.cpu().sum().item() # if gt = 1, 2 and pred = 2, 3 --> 50 % acc
        exact_accuracy =  exact_accs.cpu().sum().item()  # if gt = 1, 2 and pred = 2, 3 --> 0 % acc
        
                    
    elif args.cat_dup == False: # when no duplicates
        y_pred_hot = (y_pred_nar >= y_pred_nar.topk(n_targets)[0][:,n_targets-1:n_targets]) # bool indices of predictions higher/equal than n_target highest value y_pred = [1.0,0.9,0.1] --> [1, 1, 0]
        partial_accs = torch.sum(y_pred_hot*y_true, dim=1)/float(n_targets) 
        partial_accuracy = partial_accs.cpu().sum().item() 
        exact_accs = (partial_accs >= 1)
        exact_accuracy =  exact_accs.cpu().sum().item() 
        
    return y_pred_hot, partial_accuracy, partial_accs, exact_accuracy, exact_accs  
            
    #     #####################
    #     # alternative version: using groundtruth to know whether duplicates trial or not
    #     """
    #     for multimnist, acc should be the same as the version above, acc was ~94% after 3 epoch
    #     for cluttered task, acc was 80% after 3 epoch (higher than acc from the version above; ~72-75%%, and 80% was reached around 10 epoch)
    #     """
    #     # get bool indices for an image with duplicates
    #     n_targets = args.num_targets
    #     bool_duplicate = (y_true==n_targets).any(dim=1) 

    #     # when no duplicates in the image, apply topk predictions
    #     _, top2accs = get_topkacc(y_pred_nar, y_true, topk=2)

    #     # when yes duplicates in the image, apply exact match 
    #     dup_t = 1.85 
    #     y_pred_dup = n_targets*(y_pred_nar >= dup_t) # if prediction > dup_t, we consider it as prediction for duplicates, e.g, prediction [1.9, 1.5, 0.5] -> [2, 0, 0]
    #     _, matchaccs = get_exactmatch(y_pred_dup, y_true)

    #     # combine topk (for no duplicates) + exactmatch (for yes duplicates) and get total sum
    #     combaccs = (~bool_duplicate)*top2accs + (bool_duplicate)*matchaccs
    #     accuracy = combaccs.cpu().sum().item()
    #     #######################

    
@torch.no_grad()
def evaluate(model, x, y_true, loss_fn, args, epoch=None):
    """
    Run model prediction on testing dataset and compute loss/acc 
    
    Args
        model -- trained model to be evaluated with no_grad()
        recon_mask -- A mask that is the normalized sum of all the read operatoin to focus the erorr reconstructio
        x -- input
        y_true -- input y in n hot format 
        y_pred -- (normalized) cumulative length of the class capsules 
    """

    # evaluate
    model.eval()
    
    # load testing dataset on device
    x = torch.cat((x,x,x), dim=1).to(args.device)
    
#     x = x.view(-1, self.C, self.H, self.W)
#     x = x.view(x.shape[0], -1).float().to(args.device)
    
    if args.task == 'mnist_ctrv':
        x = 1.0 - x
                
    y_true = y_true.to(args.device)
    yloss_w = torch.ones((len(x),args.num_classes)).to(args.device)
    # run model with testing data and get predictions
    y_pred  = model(x)
        
#     # Do not normalize the sum of object caps lengths if multiple items can be from the same category
#     if (not args.cat_dup) and (args.task != 'mnist_ctrv') and (torch.min(torch.max(y_pred, dim=1, keepdim=True)[0]) != 0):
#         y_pred = y_pred / torch.max(y_pred, dim=1, keepdim=True)[0] #self.num_objectcaps #objectcaps.norm(dim=-1)   
        
    # compute accuracy sum over whole batch

    _, partial_accuracy, _ , exact_accuracy, _ = cal_accs(y_pred, y_true, args)
    
    # compute loss    
    loss = loss_fn(y_pred, y_true, args)        


    return loss, exact_accuracy, partial_accuracy, y_pred


def test(model, dataloader, args):
    """
    for each batch:
        - evaluate loss & acc ('evaluate')
    log average loss & acc  
    """   
    test_loss = 0
    test_L_recon = 0
    test_L_margin = 0
    test_acc_partial = 0
    test_acc_exact = 0
    
    # load batch data
    for x, y in dataloader:
        
        # if one target and y is not in one-hot format, convert it to one-hot encoding
#         if args.num_targets == 1:
#             if len(y.shape) < 2: 
#                 y = y.type(torch.int64)
#                 y = torch.zeros(y.size(0), args.num_classes).scatter_(1, y.view(-1, 1), 1.)  
        
        # evaluate
        batch_loss, batch_acc_exact, batch_acc_partial, y_pred =  evaluate(model, x, y, loss_fn, args)

        # aggregate loss and acc
        test_loss += args.test_batch_size * batch_loss
        test_acc_partial += batch_acc_partial
        test_acc_exact += batch_acc_exact

    # get average loss and acc
    test_loss /= len(dataloader.dataset)
    test_acc_partial /= (len(dataloader.dataset))
    test_acc_exact /= (len(dataloader.dataset))
    return test_loss, test_acc_partial, test_acc_exact



def train_epoch(model, train_dataloader, loss_fn, optimizer, epoch, writer, args):
    """
    for each batch:
        - forward pass  
        - compute loss
        - param update
    log average train_loss  
    """    
    model.train() 
    with tqdm(total=len(train_dataloader), desc='epoch {} of {}'.format(epoch, args.n_epochs)) as pbar:
#     time.sleep(0.1)        
        training_loss = 0.0
        
        # load batch from dataloader 
        for i, (x, y, yloss_w) in enumerate(train_dataloader):
            global_step = (epoch-1) * len(train_dataloader) + i + 1 #global batch number
            
            # if one target and y is not in one-hot format, convert it to one-hot encoding
#             if args.num_targets == 1:
#                 if len(y.shape) < 2: 
#                     y = y.type(torch.int64)
#                     y = torch.zeros(y.size(0), args.num_classes).scatter_(1, y.view(-1, 1), 1.) 
            
            # load dataset on device
            x = torch.cat((x,x,x), dim=1).to(args.device)
            
            if args.task == 'mnist_ctrv':
                x = 1.0 - x
             
            y = y.type(torch.long).to(args.device)
            yloss_w = yloss_w.to(args.device)
            # forward pass
            y_pred = model(x)
    
            # compute loss for this batch and append it to training loss
            loss = loss_fn(y_pred, y, args, writer, global_step)
            
            training_loss += loss.data #* x.size(0) 
            
            # zero out previous gradients and backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # record grad norm and clip to prevent exploding gradients
            if args.record_gradnorm:
                grad_norm = 0
                for name, p in model.named_parameters():
                    grad_norm += p.grad.norm().item() if p.grad is not None else 0
                writer.add_scalar('grad_norm', grad_norm, global_step)
            nn.utils.clip_grad_norm_(model.parameters(), 10)

            # update param
            optimizer.step()

            # end of each batch, update tqdm tracking
            pbar.set_postfix(batch_loss='{:.3f}'.format(loss.item()))
            pbar.update()
    
    # logging training info to tensorboard writer
    train_loss = training_loss / len(train_dataloader.dataset)
    writer.add_scalar('Train/Loss', train_loss, epoch)
    
    return train_loss
    
def train_and_evaluate(model, train_dataloader, val_dataloader, loss_fn, optimizer, writer, args):
    """
    for each epoch:
        - train the model, update param, and log the training loss ('train_epoch')
        - save checkpoint
        - compute and log average val loss/acc and
        - save best model
        
    """
    start_epoch = 1

    if args.restore_file:
        print('Restoring parameters from {}'.format(args.restore_file))
        start_epoch = load_checkpoint(args.restore_file, [model], [optimizer], map_location=args.device.type)
        args.n_epochs += start_epoch
        print('Resuming training from epoch {}'.format(start_epoch))

    for epoch in range(start_epoch, args.n_epochs+1):
        
        # train epoch
        train_loss = train_epoch(model, train_dataloader, loss_fn, optimizer, epoch, writer, args)

        # save checkpoint 
        if args.save_checkpoint:
            save_checkpoint({'epoch': epoch,
                             'model_state_dicts': [model.state_dict()],
                             'optimizer_state_dicts': [optimizer.state_dict()]}, 
                            checkpoint=args.log_dir,
                            quiet=True)
        
        # compute validation loss and acc
        if (epoch) % args.validate_after_howmany_epochs == 0:
            val_loss, val_acc_partial, val_acc_exact = test(model, val_dataloader, args)
            
            # logging validation info to tensorboard writer
            writer.add_scalar('Val/Loss', val_loss, epoch)
            writer.add_scalar('Val/Accuracy_partial', val_acc_partial, epoch)
            writer.add_scalar('Val/Accuracy_exact', val_acc_exact, epoch)
                        
            if args.verbose:
                print("==> Epoch %02d: train_loss=%.5f, val_loss=%.5f, val_acc_partial=%.4f,  val_acc_exact=%.4f" \
                  % (epoch, train_loss, val_loss, val_acc_partial, val_acc_exact))
               
            # update best validation acc and save best model to output dir
            if (val_acc_exact > args.best_val_acc):  
                args.best_val_acc = val_acc_exact
                torch.save(model.state_dict(), args.log_dir +'/best_model_epoch%d_acc%.4f.pt'% (epoch, val_acc_exact))  #output_dir
                print("the model with best val_acc (%.4f) was saved to disk" % val_acc_exact)

        # for experiments, abort the local mimima trials
        if (epoch) % 100 == 0:
            if hasattr(args, 'abort_if_valacc_below'):
                if (args.best_val_acc < args.abort_if_valacc_below) or math.isnan(val_acc_exact):
                    status = f'===== EXPERIMENT ABORTED: val_acc_exact is {val_acc_exact} at epoch {epoch} (Criterion is {args.abort_if_valacc_below}) ===='
                    writer.add_text('Status', status, epoch)
                    print(status)
                    sys.exit()
                else:
                    status = '==== EXPERIMENT CONTINUE ===='
                    writer.add_text('Status', status, epoch)
                    print(status)


In [4]:
############################################
# instantiate model 
############################################
import torchvision
from torchvision.models import resnet50 


class ResNet_model(nn.Module):
    """Model modified.
    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.
    """
    def __init__(self, args):
        super(ResNet_model, self).__init__()
        
        self.num_classes = args.num_classes
        
        model = torchvision.models.resnet18(pretrained=True)
        
#         model.layer2[0].conv1.stride = (1,1)
#         model.layer2[0].downsample[0].stride = (1,1)

#         model.layer3[0].conv2.stride = (1,1)
#         model.layer3[0].downsample[0].stride = (1,1)

#         model.layer4[0].conv1.stride = (1,1)
#         model.layer4[0].downsample[0].stride = (1,1)

        self.backbone = model 
        
        self.class_readout = nn.Linear(1000, self.num_classes)
        
#         features = nn.Sequential(*(list(densenet121.children())[0]))
#         self.backbone = nn.Sequential(*(list(features.children())[0:11]))
        
         #Fix the parameters of the feature extractor:
#         for param in self.backbone.parameters():
#             param.requires_grad = False

#         self.densenet121.features.denseblock4.denselayer16.conv2.output
        
#         num_ftrs = self.densenet121.classifier.in_features
#         self.densenet121.classifier = nn.Sequential(
#             nn.Linear(num_ftrs, out_size),
#             nn.Sigmoid()
#         )

    def forward(self, x):
        x = self.backbone(x)
        y_pred = self.class_readout(x)
        
        return y_pred
    
    
# ---------------------------------------------------------------    

    
    
    
#model = ResNet_model(args).to(args.device) 
model = ResNet_model(args).to(args.device) 

# set up model, optimizer, and hooks for monitoring
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
# if args.record_attn_hooks:
#     record_forward_backward_attn_hooks(model)

# print model info
count_parameters(model)

+---------------------------------------+------------+
|                Modules                | Parameters |
+---------------------------------------+------------+
|         backbone.conv1.weight         |    9408    |
|          backbone.bn1.weight          |     64     |
|           backbone.bn1.bias           |     64     |
|     backbone.layer1.0.conv1.weight    |   36864    |
|      backbone.layer1.0.bn1.weight     |     64     |
|       backbone.layer1.0.bn1.bias      |     64     |
|     backbone.layer1.0.conv2.weight    |   36864    |
|      backbone.layer1.0.bn2.weight     |     64     |
|       backbone.layer1.0.bn2.bias      |     64     |
|     backbone.layer1.1.conv1.weight    |   36864    |
|      backbone.layer1.1.bn1.weight     |     64     |
|       backbone.layer1.1.bn1.bias      |     64     |
|     backbone.layer1.1.conv2.weight    |   36864    |
|      backbone.layer1.1.bn2.weight     |     64     |
|       backbone.layer1.1.bn2.bias      |     64     |
|     back

11691514

## start training

In [5]:
###########################
# model training...
##########################

DO_TRAIN = True  # false if you want to skip this cell
COMMENT = 'test'

if DO_TRAIN:
    # load dataloader 
    train_dataloader, val_dataloader = fetch_dataloader(args, args.train_batch_size, train=True, train_val_split='train-val')

    # set writer for tensorboard
    writer, current_log_path = set_writer(log_path = args.output_dir if args.restore_file is None else os.path.dirname(args.restore_file),
                        comment = COMMENT, 
                        restore = args.restore_file is not None) 

    args.log_dir = current_log_path

    # save used param info to writer and logging directory
    writer.add_text('Params', pprint.pformat(args.__dict__))
    
    with open(os.path.join(args.log_dir, 'params.txt'), 'w') as f:
        pprint.pprint(args.__dict__, f, sort_dicts=False)

    # start training
    print('Start training with args set above...')
    train_and_evaluate(model, train_dataloader, val_dataloader, loss_fn, optimizer, writer, args)

    writer.close()

Start training with args set above...


epoch 1 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 01: train_loss=0.00462, val_loss=0.31400, val_acc_partial=0.8693,  val_acc_exact=0.8693
the model with best val_acc (0.8693) was saved to disk


epoch 2 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 02: train_loss=0.00156, val_loss=0.14430, val_acc_partial=0.9488,  val_acc_exact=0.9488
the model with best val_acc (0.9488) was saved to disk


epoch 3 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 03: train_loss=0.00082, val_loss=0.11038, val_acc_partial=0.9602,  val_acc_exact=0.9602
the model with best val_acc (0.9602) was saved to disk


epoch 4 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 04: train_loss=0.00215, val_loss=0.70169, val_acc_partial=0.5018,  val_acc_exact=0.5018


epoch 5 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 05: train_loss=0.00277, val_loss=0.11907, val_acc_partial=0.9591,  val_acc_exact=0.9591


epoch 6 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 06: train_loss=0.00048, val_loss=0.09095, val_acc_partial=0.9709,  val_acc_exact=0.9709
the model with best val_acc (0.9709) was saved to disk


epoch 7 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 07: train_loss=0.00028, val_loss=0.08606, val_acc_partial=0.9773,  val_acc_exact=0.9773
the model with best val_acc (0.9773) was saved to disk


epoch 8 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 08: train_loss=0.00044, val_loss=0.09459, val_acc_partial=0.9746,  val_acc_exact=0.9746


epoch 9 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 09: train_loss=0.00019, val_loss=0.07969, val_acc_partial=0.9790,  val_acc_exact=0.9790
the model with best val_acc (0.9790) was saved to disk


epoch 10 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 10: train_loss=0.00014, val_loss=0.11362, val_acc_partial=0.9732,  val_acc_exact=0.9732


epoch 11 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 11: train_loss=0.00014, val_loss=0.11858, val_acc_partial=0.9734,  val_acc_exact=0.9734


epoch 12 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 12: train_loss=0.00014, val_loss=0.08921, val_acc_partial=0.9794,  val_acc_exact=0.9794
the model with best val_acc (0.9794) was saved to disk


epoch 13 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 13: train_loss=0.00013, val_loss=0.07771, val_acc_partial=0.9797,  val_acc_exact=0.9797
the model with best val_acc (0.9797) was saved to disk


epoch 14 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 14: train_loss=0.00058, val_loss=0.14953, val_acc_partial=0.9479,  val_acc_exact=0.9479


epoch 15 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 15: train_loss=0.00024, val_loss=0.07801, val_acc_partial=0.9797,  val_acc_exact=0.9797


epoch 16 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 16: train_loss=0.00032, val_loss=0.07350, val_acc_partial=0.9788,  val_acc_exact=0.9788


epoch 17 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 17: train_loss=0.00009, val_loss=0.07354, val_acc_partial=0.9832,  val_acc_exact=0.9832
the model with best val_acc (0.9832) was saved to disk


epoch 18 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 18: train_loss=0.00008, val_loss=0.09073, val_acc_partial=0.9825,  val_acc_exact=0.9825


epoch 19 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 19: train_loss=0.00008, val_loss=0.06307, val_acc_partial=0.9832,  val_acc_exact=0.9832


epoch 20 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 20: train_loss=0.00783, val_loss=0.69956, val_acc_partial=0.5034,  val_acc_exact=0.5034


epoch 21 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 21: train_loss=0.00542, val_loss=0.63812, val_acc_partial=0.6473,  val_acc_exact=0.6473


epoch 22 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 22: train_loss=0.00098, val_loss=0.06705, val_acc_partial=0.9805,  val_acc_exact=0.9805


epoch 23 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 23: train_loss=0.00012, val_loss=0.07366, val_acc_partial=0.9831,  val_acc_exact=0.9831


epoch 24 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 24: train_loss=0.00008, val_loss=0.06236, val_acc_partial=0.9825,  val_acc_exact=0.9825


epoch 25 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 25: train_loss=0.00006, val_loss=0.07051, val_acc_partial=0.9837,  val_acc_exact=0.9837
the model with best val_acc (0.9837) was saved to disk


epoch 26 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 26: train_loss=0.00007, val_loss=0.06297, val_acc_partial=0.9850,  val_acc_exact=0.9850
the model with best val_acc (0.9850) was saved to disk


epoch 27 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 27: train_loss=0.00007, val_loss=0.05961, val_acc_partial=0.9855,  val_acc_exact=0.9855
the model with best val_acc (0.9855) was saved to disk


epoch 28 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 28: train_loss=0.00008, val_loss=0.06525, val_acc_partial=0.9846,  val_acc_exact=0.9846


epoch 29 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 29: train_loss=0.00007, val_loss=0.06373, val_acc_partial=0.9842,  val_acc_exact=0.9842


epoch 30 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 30: train_loss=0.00006, val_loss=0.06917, val_acc_partial=0.9834,  val_acc_exact=0.9834


epoch 31 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 31: train_loss=0.00007, val_loss=0.07629, val_acc_partial=0.9814,  val_acc_exact=0.9814


epoch 32 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 32: train_loss=0.00007, val_loss=0.06491, val_acc_partial=0.9826,  val_acc_exact=0.9826


epoch 33 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 33: train_loss=0.00007, val_loss=0.07412, val_acc_partial=0.9842,  val_acc_exact=0.9842


epoch 34 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 34: train_loss=0.00005, val_loss=0.07233, val_acc_partial=0.9846,  val_acc_exact=0.9846


epoch 35 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 35: train_loss=0.00005, val_loss=0.06247, val_acc_partial=0.9848,  val_acc_exact=0.9848


epoch 36 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 36: train_loss=0.00005, val_loss=0.06992, val_acc_partial=0.9856,  val_acc_exact=0.9856
the model with best val_acc (0.9856) was saved to disk


epoch 37 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 37: train_loss=0.00006, val_loss=0.06870, val_acc_partial=0.9831,  val_acc_exact=0.9831


epoch 38 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 38: train_loss=0.00005, val_loss=0.06712, val_acc_partial=0.9862,  val_acc_exact=0.9862
the model with best val_acc (0.9862) was saved to disk


epoch 39 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 39: train_loss=0.00004, val_loss=0.06672, val_acc_partial=0.9865,  val_acc_exact=0.9865
the model with best val_acc (0.9865) was saved to disk


epoch 40 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 40: train_loss=0.00005, val_loss=0.07902, val_acc_partial=0.9826,  val_acc_exact=0.9826


epoch 41 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 41: train_loss=0.00006, val_loss=0.07698, val_acc_partial=0.9804,  val_acc_exact=0.9804


epoch 42 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 42: train_loss=0.00005, val_loss=0.07572, val_acc_partial=0.9849,  val_acc_exact=0.9849


epoch 43 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 43: train_loss=0.00004, val_loss=0.05586, val_acc_partial=0.9841,  val_acc_exact=0.9841


epoch 44 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 44: train_loss=0.00004, val_loss=0.08233, val_acc_partial=0.9846,  val_acc_exact=0.9846


epoch 45 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 45: train_loss=0.00005, val_loss=0.06796, val_acc_partial=0.9860,  val_acc_exact=0.9860


epoch 46 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 46: train_loss=0.00003, val_loss=0.06631, val_acc_partial=0.9858,  val_acc_exact=0.9858


epoch 47 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 47: train_loss=0.00004, val_loss=0.08526, val_acc_partial=0.9866,  val_acc_exact=0.9866
the model with best val_acc (0.9866) was saved to disk


epoch 48 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 48: train_loss=0.00004, val_loss=0.06453, val_acc_partial=0.9862,  val_acc_exact=0.9862


epoch 49 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 49: train_loss=0.00210, val_loss=0.06778, val_acc_partial=0.9831,  val_acc_exact=0.9831


epoch 50 of 50:   0%|          | 0/422 [00:00<?, ?it/s]

==> Epoch 50: train_loss=0.00005, val_loss=0.07529, val_acc_partial=0.9849,  val_acc_exact=0.9849


# Test pretrained models

In [None]:
# load required libraries & modules
import os
from tqdm.notebook import tqdm
import pprint
import time

import torch

from utils import *
# from ocra import *
from loaddata import *

import warnings
# warnings.filterwarnings('ignore')

In [6]:
###########################
# model testing...
##########################

# testing parameters
DO_TEST = True # false if you want to skip this cell
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# give path for the model to load, or None if you want to continue with the current trained model
# load_model_path = "./pretrained/multimnist_cluttered/run1.pt"
load_model_path = None
# load_model_path = './results/svrt_task1/Nov27_5848_test/best_model_epoch70_acc0.9899.pt'
if DO_TEST:
    if load_model_path: 
        # if path is given, load a saved model (make sure that you loaded right model args)
        print("param file will be loaded from your saved model folder")
        params_filename = os.path.dirname(load_model_path) + '/params.txt'
        assert os.path.isfile(params_filename), "No param flie exists"

        # remove the arguments that caanot be translated into literal
        removelist = ['device'] 
        args = parse_params_wremove(params_filename, removelist) 
        args.device = device 
        
        # print params
        pprint.pprint(args.__dict__, sort_dicts=False)
        
        # load model
        model = ResNet_model(args).to(args.device) 
        model.load_state_dict(torch.load(load_model_path,map_location=args.device))
        print('model loaded.')
        
        # get test results
        model.eval()
        test_dataloader = fetch_dataloader(args, args.test_batch_size, train=False)
        test_loss, test_acc_partial, test_acc_exact = test(model, test_dataloader, args)
        print("==> Epoch %02d: test_loss=%.5f, test_acc_partial=%.4f, test_acc_exact=%.4f"
              % (args.n_epochs, test_loss, test_acc_partial, test_acc_exact))

    else: # use the current trained model to get results
        model.eval()
        test_dataloader = fetch_dataloader(args, args.test_batch_size, train=False)
        test_loss, test_acc_partial, test_acc_exact = test(model, test_dataloader, args)
        print("==> Epoch %02d: test_loss=%.5f, test_acc_partial=%.4f, test_acc_exact=%.4f"
              % (args.n_epochs, test_loss, test_acc_partial, test_acc_exact))

==> Epoch 50: test_loss=0.07836, test_acc_partial=0.9844, test_acc_exact=0.9844
