In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

# Plotting setup: switch matplotlib into interactive mode for this session.
plt.ion()   # interactive mode
# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is not
# available so that moving the model and batches to `device` does not fail
# on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import torch

def normalize_transform():
    """Return the per-channel ``transforms.Normalize`` used by this notebook.

    The three-element mean/std lists are the values torchvision documents
    for its pretrained classification models.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)

def val_dataset(data_dir, resize_size=256, crop_size=224):
    """Build the ``ImageFolder`` dataset for the ``val`` split under ``data_dir``.

    Every image goes through: ``Resize(resize_size)`` ->
    ``CenterCrop(crop_size)`` -> ``ToTensor()`` -> ``normalize_transform()``.

    Args:
        data_dir: Root directory of the dataset; images are read from
            ``<data_dir>/val``.
        resize_size: Value passed to ``transforms.Resize``. Defaults to 256.
        crop_size: Value passed to ``transforms.CenterCrop``. Defaults to 224.
            Architectures with a different expected input resolution need a
            different value here (torchvision documents 299 for inception_v3).

    Returns:
        A ``datasets.ImageFolder`` over the validation directory with the
        preprocessing pipeline attached.
    """
    val_dir = os.path.join(data_dir, 'val')

    # Deterministic (no augmentation) evaluation preprocessing.
    preprocessing = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize_transform(),
    ])

    # Return directly instead of binding a local that shadows this function's name.
    return datasets.ImageFolder(val_dir, preprocessing)

In [5]:
# Root of the ImageNet copy; val_dataset() reads the images from its 'val' sub-directory.
data_dir = '../../../../data/imagenet/'   # Change this to the directory of your imagenet dataset
val_ds = val_dataset(data_dir)

# Shuffled loader over the validation images: batches of 16, 4 worker processes.
val_loader = torch.utils.data.DataLoader(
    dataset=val_ds,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [8]:
# Load the pretrained network, move it to `device` and switch it to
# inference mode. Use any pretrained NN from pytorch.
# NOTE(review): torchvision's documentation says inception_v3 expects
# 299x299 inputs, while val_dataset() in this notebook is called with its
# 224 center crop -- TODO confirm this mismatch is intended.
model_ft = models.inception_v3(pretrained=True)
model_ft = model_ft.to(device).eval()

In [9]:
# Run the model over the whole validation loader, keeping the softmax
# output and the label of every sample, and report running top-1 accuracy.

# The loader shuffles; with no explicit generator the shuffle draws from
# torch's global RNG. Seeding right before iterating makes the sample order
# (and therefore any later split by position) reproducible across runs.
torch.manual_seed(0)

correct = 0   # number of correct top-1 predictions so far
seen = 0      # number of samples processed so far

confs = []
labels = []

with torch.no_grad():
    for i, (image, target) in enumerate(val_loader):
        image, target = image.to(device), target.to(device)
        pred = model_ft(image)
        conf = F.softmax(pred, dim=1)
        # Move results off the device per batch so they do not accumulate
        # in accelerator memory for the whole pass.
        confs.append(conf.cpu())
        labels.append(target.cpu())
        correct += (torch.argmax(conf, dim=1) == target).sum().item()
        # Count actual samples rather than assuming a fixed batch size, so
        # the accuracy stays right for any batch size or a short last batch.
        seen += target.size(0)
        if (i+1) % 100 == 0:
            print("Iteration %d, accuracy=%.3f" % (i, correct / seen))

if seen:
    print("Final accuracy over %d samples: %.3f" % (seen, correct / seen))

Iteration 99, accuracy=0.699
Iteration 199, accuracy=0.691
Iteration 299, accuracy=0.692
Iteration 399, accuracy=0.692
Iteration 499, accuracy=0.691
Iteration 599, accuracy=0.691
Iteration 699, accuracy=0.692
Iteration 799, accuracy=0.694
Iteration 899, accuracy=0.692
Iteration 999, accuracy=0.694
Iteration 1099, accuracy=0.695
Iteration 1199, accuracy=0.695
Iteration 1299, accuracy=0.696
Iteration 1399, accuracy=0.695
Iteration 1499, accuracy=0.697
Iteration 1599, accuracy=0.697
Iteration 1699, accuracy=0.698
Iteration 1799, accuracy=0.697
Iteration 1899, accuracy=0.698
Iteration 1999, accuracy=0.697
Iteration 2099, accuracy=0.696
Iteration 2199, accuracy=0.696
Iteration 2299, accuracy=0.696
Iteration 2399, accuracy=0.695
Iteration 2499, accuracy=0.695
Iteration 2599, accuracy=0.696
Iteration 2699, accuracy=0.696
Iteration 2799, accuracy=0.696
Iteration 2899, accuracy=0.695
Iteration 2999, accuracy=0.696
Iteration 3099, accuracy=0.695


In [10]:
# Concatenate the per-batch outputs into single CPU tensors.
# This cell rebinds `confs` / `labels` from lists to tensors, so guard the
# concatenation: re-running the cell would otherwise hand torch.cat() a
# bare tensor, which it rejects.
if not torch.is_tensor(confs):
    confs = torch.cat(confs).cpu()
if not torch.is_tensor(labels):
    labels = torch.cat(labels).cpu()

In [11]:
# Split the collected outputs by position -- the first `n_val` samples become
# the "val" part, the remainder the "test" part -- and save them to disk.
n_val = 40000
out_path = 'imagenet/inception_v3.pt'   # Name this with the correct architecture name

assert len(labels) == len(confs), "labels and probabilities are misaligned"
assert len(labels) > n_val, "not enough samples to leave a non-empty test split"

# `labels` and `confs` are already CPU tensors at this point, so no further
# .cpu() calls are needed on the slices.
results = {'val_labels': labels[:n_val], 'val_prob': confs[:n_val],
           'test_labels': labels[n_val:], 'test_prob': confs[n_val:]}

# torch.save does not create missing parent directories.
out_dir = os.path.dirname(out_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
torch.save(results, out_path)