In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms

import time
import copy
import os

In [3]:
import logging
# Notebook-wide logging: timestamped records at INFO level and above go to the
# root handler; `l` is the notebook's named logger.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s]: %(message)s',
)
l = logging.getLogger(__name__)

In [4]:
# Preprocessing expected by torchvision's ImageNet-pretrained models:
# resize the shorter side to 256, centre-crop 224x224, convert to a tensor and
# normalise each channel with the ImageNet mean/std.
# https://pytorch.org/vision/stable/models.html
# `transforms.Resize` replaces the deprecated `transforms.Scale`, which newer
# torchvision releases no longer provide; with an int argument both resize the
# smaller image edge to that size while keeping the aspect ratio.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
    
def prepare_image(image):
    """
    Turn a single image into a model-ready batch of size one.

    Runs the notebook-level ``preprocess`` pipeline on ``image`` (presumably a
    PIL image, as accepted by ``transforms.ToTensor``) and prepends a batch
    dimension so the result can be fed straight to the network.

    :param image: Image accepted by ``preprocess``.
    :return: Preprocessed tensor with a leading batch axis of length 1.
    """
    tensor = preprocess(image)
    # Add the batch axis in place on the freshly created tensor.
    return tensor.unsqueeze_(0)

In [8]:
ds_train = datasets.ImageFolder(root="./training", transform=preprocess)  # one sub-folder per class

In [9]:
ds_test = datasets.ImageFolder(root="./testing", transform=preprocess)  # one sub-folder per class

In [10]:
ds_train_ldr = torch.utils.data.DataLoader(dataset=ds_train, shuffle=True, batch_size=1)  # reshuffled each epoch

In [11]:
# The summed validation loss/accuracy do not depend on sample order, so there
# is no reason to shuffle the test set; a fixed order keeps evaluation reproducible.
ds_test_ldr = torch.utils.data.DataLoader(ds_test, batch_size=1, shuffle=False)

In [12]:
def _train_one_epoch(model, loss_function, optimizer, loader):
    """Run one optimisation pass over ``loader`` with ``model`` in training mode."""
    model.train(True)
    for data, target in loader:
        optimizer.zero_grad()
        loss = loss_function(model(data), target)
        loss.backward()
        optimizer.step()


def _evaluate(model, loss_function, loader):
    """
    Evaluate ``model`` on every batch in ``loader`` without tracking gradients.

    :return: ``(mean_loss, accuracy)`` averaged per sample, as Python floats.
    :raises ValueError: if ``loader`` yields no samples.
    """
    model.train(False)
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    # No autograd graph is needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            # Dim 1 is treated as the class axis (assumes (batch, num_classes) output).
            _, preds = torch.max(output, 1)
            batch_size = target.size(0)
            # Assumes loss_function averages over the batch (CrossEntropyLoss's
            # default reduction), so re-weight by batch size for a per-sample mean.
            total_loss += loss_function(output, target).item() * batch_size
            n_correct += (preds == target).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / n_seen, n_correct / n_seen


def train(model, loss_function, optimizer, num_epochs=10,
          train_loader=None, val_loader=None):
    """
    Train ``model`` and keep the snapshot with the best validation accuracy.

    Each epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. Only the parameters registered with
    ``optimizer`` are updated.

    :param model: Model to train; called as ``model(batch)``. Its output is
        assumed to be class scores of shape (batch, num_classes).
    :param loss_function: Criterion called as ``loss_function(output, target)``.
    :param optimizer: Optimizer over the parameters that should be trained.
    :param num_epochs: Number of epochs to train for.
    :param train_loader: Iterable of ``(data, target)`` training batches;
        defaults to the notebook-level ``ds_train_ldr``.
    :param val_loader: Iterable of ``(data, target)`` validation batches;
        defaults to the notebook-level ``ds_test_ldr``.
    :return: Deep copy of ``model`` taken after the epoch with the highest
        validation accuracy, or ``model`` itself if no epoch scored above 0.
    """
    # Same logger as the notebook's module-level `l` (same name).
    log = logging.getLogger(__name__)
    if train_loader is None:
        train_loader = ds_train_ldr
    if val_loader is None:
        val_loader = ds_test_ldr

    best_model = model
    best_acc = 0.0

    for epoch in range(num_epochs):
        log.info("EPOCH: %d/%d", epoch, num_epochs - 1)

        _train_one_epoch(model, loss_function, optimizer, train_loader)
        epoch_loss, epoch_acc = _evaluate(model, loss_function, val_loader)

        log.info("Epoch %d, loss %f, acc %f", epoch, epoch_loss, epoch_acc)

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            # Snapshot, because `model` keeps being updated in later epochs.
            best_model = copy.deepcopy(model)

    log.info("Complete. Best acc %f", best_acc)
    return best_model

In [37]:
# ImageNet-pretrained ResNet-152 backbone (`models` is torchvision.models).
# NOTE(review): newer torchvision deprecates `pretrained=` in favour of `weights=`.
model_conv = models.resnet152(pretrained=True)

In [38]:
for param in model_conv.parameters():
    param.requires_grad = True

In [39]:
# model_conv
# num_ftrs = model_conv.fc.in_features
num_ftrs = 9216
# model_conv.fc = nn.Linear(num_ftrs, 3)
model_conv.classifier = nn.Linear(num_ftrs, 3)

In [40]:
loss_function = nn.CrossEntropyLoss()

In [41]:
optimizer = optim.SGD(model_conv.classifier.parameters(), lr=0.001, momentum=0.9)

In [None]:
# Run the training loop; `train` returns the snapshot with the best validation accuracy.
model_conv = train(model_conv, loss_function, optimizer, num_epochs=100)

[2017-11-01 15:21:08,870][INFO]: EPOCH: 0/99
