In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import os, sys
import PIL
from PIL import Image
from getimagenetclasses import get_label


Loss is 1 and accuracy is 1/2 = 0.5


In [2]:
class ImageNetDataset(Dataset):
    def __init__(self, path, use_crop=True, crop_size=224, batch_size=50, max_iter=250):
        self.path = path
        self.totensor = transforms.ToTensor()
        self.images = [i for i in os.listdir(path+'/imagespart/') if i[0] != '.']
        self.crop_size = crop_size
        self.batch_size = batch_size
        self.max_iter = max_iter
        self.use_crop = use_crop
        
    def load_image(self, filename):
        img_name = os.path.splitext(filename)[0]
        img = Image.open(self.path+'/imagespart/'+filename)
        width, height = img.size
        ratio = width/height
        if width >= height:
            tup = (int(ratio*280), 280)
        else:
            tup = (280, int(280/ratio))
        img = img.resize(tup, PIL.Image.BICUBIC)
        label = get_label(img_name)
        img_tensor = self.totensor(img)
        if img_tensor.size(0) == 1:
            img_tensor = img_tensor.expand(3, -1, -1)
        if self.use_crop:
            return self.five_crops(img_tensor), torch.full((5,), label, dtype=torch.long)
        else:
            return self.crop(img_tensor, 0), torch.full((1,), label, dtype=torch.long)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        return self.load_image(self.images[index])
    
    def crop(self, img_tensor, loc=0):
        # loc: 0 = center, 1 = top left, 2 = top right, 3 = btm left, 4 = btm right
        H, W = img_tensor.size(1), img_tensor.size(2)
        crop_size = self.crop_size
        if loc==0:
            Hs, Ws = int(H/2-crop_size/2), int(W/2-crop_size/2)
            return img_tensor[:,Hs:Hs+crop_size,Ws:Ws+crop_size]
        elif loc==1:
            return img_tensor[:, :crop_size, :crop_size]
        elif loc==2:
            return img_tensor[:, :crop_size, -crop_size:]
        elif loc==3:
            return img_tensor[:, -crop_size:, :crop_size]
        elif loc==4:
            return img_tensor[:, -crop_size:, -crop_size:]
        return 
    
    def five_crops(self, img_tensor):
        tensors = []
        for i in range(5):
            cropped = self.crop(img_tensor, i).unsqueeze(dim=0)
            tensors.append(cropped)
        return torch.cat(tensors, dim=0)
    

def test(model, device, test_loader, nsample, use_crop=True):
    model.eval()
    test_loss = 0
    match = 0
    with torch.no_grad():
        for data, target in test_loader:
            if use_crop:
                c, w, h = data.size(2), data.size(3), data.size(4)
                data = data.view(-1, c, w, h)
            target = target.view(-1)
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            match += pred.eq(target.view_as(pred)).sum().item()
    if use_crop:
        nsample *= 5
    test_loss /= len(test_loader.dataset)
    print ("Loss is {} and accuracy is {}/{} = {}".format(test_loss, match, nsample, (match / nsample)))

### Resnet 18 with Five Crop

In [3]:
import torchvision.models as models
cwd = os.getcwd()
dataset_path = cwd
dataset = ImageNetDataset(dataset_path)
nsample=250

valsampler = SubsetRandomSampler(list(range(nsample)))
val_loader = torch.utils.data.DataLoader(dataset, batch_size=20, sampler=valsampler)

resnet18 = models.resnet18(pretrained=True)
device = 'cpu'
test(resnet18, device, val_loader, nsample)

Loss is -4.877224951171875 and accuracy is 453/1250 = 0.3624


### Resnet 18 without Five Crop

In [4]:
use_crop=False

dataset = ImageNetDataset(dataset_path, use_crop=use_crop)
nsample=250

valsampler = SubsetRandomSampler(list(range(nsample)))
val_loader = torch.utils.data.DataLoader(dataset, batch_size=20, sampler=valsampler)

resnet18 = models.resnet18(pretrained=True)
device = 'cpu'
test(resnet18, device, val_loader, nsample, use_crop)

Loss is -1.0572685119628906 and accuracy is 108/250 = 0.432
