In [None]:
import torchvision.models as models
import torch
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np

In [None]:
class AverageMeter:
    """Track the latest value and a count-weighted running average.

    Attributes:
        val: value passed to the most recent ``update`` call.
        sum: running total of ``val * n`` over all updates.
        count: running total of the ``n`` weights.
        avg: ``sum / count`` as of the latest update (0 before any update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic so the meter can be reused."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count



In [None]:
def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor indexed as ``output[sample, class]``; the top-k
            classes are taken along dim 1.
        target: tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate; ``max(topk)`` must not exceed
            the number of classes in ``output``.

    Returns:
        A list with one 1-element tensor per entry of ``topk``, each holding the
        percentage of samples whose target is among the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # topk along dim 1 gives (batch, maxk); after the transpose, row j holds
    # every sample's (j+1)-th highest-scoring class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the non-contiguous layout of the transposed `pred`,
        # so `.view(-1)` can raise for k > 1; `.reshape` copies when required.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [None]:
def validate(val_loader, model):
    """Evaluate ``model`` on ``val_loader`` and report top-1 / top-5 precision.

    Each batch is moved to the device holding the model's parameters, run
    through the model with autograd disabled, and scored with ``accuracy``.
    A progress line is printed every 100 batches and a summary at the end.

    Args:
        val_loader: iterable of ``(image, target)`` batches that also supports
            ``len()`` (used in the progress line).
        model: network to evaluate; it is switched to eval mode.

    Returns:
        The running average held by the top-1 ``AverageMeter`` (precision@1 in
        percent, weighted by batch size).
    """
    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Send inputs to wherever the model lives; fall back to CUDA (the previous
    # hard-coded behaviour) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # switch to evaluate mode
    model.eval()
    end = time.time()
    # The forward pass must run inside no_grad; otherwise autograd records a
    # graph for every batch, wasting memory and time during evaluation.
    with torch.no_grad():
        for i, (image, target) in enumerate(val_loader):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # compute output
            output = model(image)

            # measure accuracy
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            top1.update(prec1[0], image.size(0))
            top5.update(prec5[0], image.size(0))

            # measure elapsed time (includes the time spent fetching the batch)
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time,
                       top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg

In [None]:
# Per-channel input normalization (presumably the usual ImageNet RGB mean/std).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
# ResNet-50 built with pretrained weights and moved to the GPU; a later cell
# loads the pruned checkpoint into it.
model = models.resnet50(pretrained=True).cuda()

In [None]:
# Replace the path below with the location of the validation images on your machine.
valdir = 'imagenet.data/val/'

# Evaluation preprocessing: resize, center-crop to 224, convert to tensor, normalize.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
val_dataset = datasets.ImageFolder(valdir, val_transform)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=16,
    num_workers=2,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Replace with the path to the checkpoint of the pruned model.
checkpoint_path = 'resnet50_pruned_70_best.pth.tar'
# NOTE: torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on any GPU (or on a machine with a
# different GPU layout) load here; load_state_dict later copies the tensors
# into the model's parameters on its own device.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict = checkpoint['state_dict']

In [None]:
# Build a state dict whose keys lack the `module.` prefix (presumably added by
# saving from an nn.DataParallel / DistributedDataParallel wrapper), so the
# keys match the unwrapped model. Keys without the prefix are kept unchanged;
# blindly slicing off 7 characters would corrupt them.
from collections import OrderedDict

MODULE_PREFIX = 'module.'
new_state_dict = OrderedDict(
    (k[len(MODULE_PREFIX):] if k.startswith(MODULE_PREFIX) else k, v)
    for k, v in state_dict.items()
)
# load params (strict by default: missing or unexpected keys raise)
model.load_state_dict(new_state_dict)

In [None]:
# Print wall-clock timestamps around the evaluation so its duration is visible.
start_stamp = time.ctime()
print(start_stamp)
prec1 = validate(val_loader, model)
end_stamp = time.ctime()
print(end_stamp)