In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim 
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import time

%load_ext autoreload
%autoreload 2

from caltech256 import Caltech256

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    'test': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
}

In [9]:
# data_dir = 'data/256_ObjectCategories'
data_dir = '/datasets/Caltech256/256_ObjectCategories'
caltech256_train = Caltech256(data_dir, data_transforms['train'], train=True)
caltech256_test = Caltech256(data_dir, data_transforms['test'], train=False)

In [10]:
vgg16 = models.vgg16_bn(pretrained=True)

In [11]:
vgg16.classifier = nn.Sequential(nn.Linear(in_features=25088, out_features=4096),
                                nn.ReLU(),
                                nn.Dropout(p=0.5),
                                nn.Linear(in_features=4096, out_features=4096),
                                nn.ReLU(),
                                nn.Dropout(p=0.5),
                                nn.Linear(in_features=4096, out_features=257))

In [12]:
for param in vgg16.features.parameters():
    param.requires_grad = False

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg16.classifier.parameters())
scheduler = lr_scheduler.StepLR(optimizer, step_size=30)

vgg16 = nn.DataParallel(vgg16)
vgg16 = vgg16.cuda()

In [17]:
dataloader = DataLoader(caltech256_train, batch_size=4)
dataiter = iter(dataloader)
image, label = dataiter.next()
print(image.size())
print(label.size())

torch.Size([4, 3, 224, 224])
torch.Size([4])


In [18]:
def train_model(model, dataset, criterion, optimizer, scheduler, num_epochs, batch_size):
    start_time = time.time()
    model.train(True)
    dataset_size = dataset.__len__()
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    for epoch in range(num_epochs):
        scheduler.step()
        running_loss = 0.
        batch_cnt = 0
        
        for data in dataloader:
            inputs, labels = data
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss
            
            batch_cnt += 1
            if batch_cnt % 150 == 0:
                print('Training completed [%d, %d]' % (epoch, batch_cnt))
            
            
        epoch_loss = running_loss / dataset_size
        print('%d epoch loss: %f' % (epoch, running_loss))
        
    model.train(False)
    time_elapsed = time.time() - start_time
    print('Training comple in %dm, %ds' % (time_elapsed//60, time_elapsed%60))
    return model

In [19]:
model_tf = train_model(vgg16, caltech256_train, criterion, optimizer, scheduler, num_epochs=5, batch_size=16)

KeyboardInterrupt: 

In [35]:
test_dataloader = DataLoader(caltech256_test, batch_size=16)
correct_cnt = 0
cnt = 0
for data in test_dataloader:
    inputs, labels = data
    inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
    outputs = vgg16(inputs)
    _, preds = torch.max(outputs, 1)
    correct_cnt += torch.sum(preds.data == labels.data)
    
acc = correct_cnt / caltech256_test.__len__()
print('Test Set Accuracy: %f' % (acc*100))

Test Set Accuracy: 6.738281
