In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models, transforms
import time
from IPython.display import display, clear_output
import pdb 

In [2]:
# Mini-batch size shared by the training and validation DataLoaders below.
batch_size = 50

In [3]:
# Per-channel (R, G, B) mean / std used by Normalize in the transform pipelines.
# NOTE(review): these look like the usual ImageNet statistics expected by
# torchvision's pretrained models - confirm against the torchvision docs.
chan_mean = [0.485, 0.456, 0.406]
chan_std = [0.229, 0.224, 0.225]

# Training pipeline: resize to 224, random horizontal/vertical flips as
# augmentation, then tensor conversion and per-channel normalisation.
train_data_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=chan_mean, std=chan_std),
    ]
)

# CIFAR-10 training split (downloaded into ./data when missing).
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_data_transform
)

# Shuffled mini-batches, loaded by two worker processes.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=2
)

Using downloaded and verified file: ./data/cifar-10-python.tar.gz


In [4]:
# Validation pipeline: same resize and normalisation as training, but with no
# random flips, so every evaluation pass sees exactly the same images.
val_data_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(chan_mean, chan_std)
])

# CIFAR-10 test split. It must use the deterministic validation pipeline, not
# the augmented training one, otherwise the reported metrics vary run to run.
val_set = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=val_data_transform)
# Despite its name, `val_order` is the validation DataLoader; the name is kept
# because other cells reference it. Batches are served in fixed order.
val_order = torch.utils.data.DataLoader(val_set,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=2)

Files already downloaded and verified


In [5]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
def train_model(model, loss_function, optimizer, data_loader, epochs):
    """Run one training pass over `data_loader`, printing live progress.

    Args:
        model: Network to train; it is switched to training mode here.
        loss_function: Callable mapping ``(outputs, labels)`` to a scalar loss
            tensor.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        data_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        epochs: Pair ``[current_epoch, total_epochs]`` shown in the status
            header. It is display-only and does not control the number of
            passes (this function always makes exactly one).

    Returns:
        None. The per-batch status and the final summary are printed.
    """
    model.train()

    bar_width = 20
    num_batches = len(data_loader)

    loss_sum = 0.0   # sum of per-sample losses seen so far
    correct = 0      # number of correct predictions seen so far
    seen = 0         # number of samples actually processed so far

    # iterate over training data
    for i, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            # forward
            outputs = model(inputs)
            _, predictions = torch.max(outputs, 1)
            loss = loss_function(outputs, labels)

            # backward
            loss.backward()
            optimizer.step()

        # statistics
        # Weighting by the batch size assumes the loss is a per-batch mean
        # (the default reduction of nn.CrossEntropyLoss).
        n_batch = inputs.size(0)
        loss_sum += loss.item() * n_batch
        correct += int(torch.sum(predictions == labels).item())
        seen += n_batch

        # Running averages are taken over the samples really processed, so a
        # short final batch or a loader with a different batch size does not
        # skew them.
        running_loss = loss_sum / seen if seen else float('nan')
        running_acc = correct / seen if seen else float('nan')

        # status: integer arithmetic keeps the bar exact (no float rounding)
        filled = (i + 1) * bar_width // num_batches
        clear_output(wait=True)
        print(str(epochs[0]) + '/' + str(epochs[1]))
        print('batch: ' + str(i + 1) + '/' + str(num_batches) +
              ' [' + '=' * filled + '>' + ' ' * (bar_width - filled) + ']')
        print('Loss: %.4g   Accuracy: %.4g' % (running_loss, running_acc))

    # An empty loader yields NaN metrics instead of raising.
    total_loss = loss_sum / seen if seen else float('nan')
    total_acc = correct / seen if seen else float('nan')

    print('Train Loss: {:.4f}; Accuracy: {:.4f}'.format(total_loss, total_acc))

In [7]:
def test_model(model, loss_function, data_loader):
    model.eval()
    
    current_loss = 0.0
    current_acc = 0
    
    # iterate over validation data
    for i, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        with torch.set_grad_enabled(False): #invece che True
            # forward
            outputs = model(inputs)
            _, predictions = torch.max(outputs,1)
            loss = loss_function(outputs, labels)
            
        # statistics
        current_loss += loss.item() * inputs.size(0)
        current_acc += torch.sum(predictions == labels.data)
        
    total_loss = current_loss / len(data_loader.dataset)
    total_acc = current_acc.double() / len(data_loader.dataset)
    
    print('Test Loss: {:.4f}; Accuracy: {:.4f}'.format(total_loss,total_acc))

In [8]:
def tl_feature_extractor(epochs=3):
    """Transfer learning with a pretrained ResNet-18 as a fixed feature extractor.

    All pretrained parameters are frozen, the final fully connected layer is
    replaced by a fresh 10-output layer, and only that new layer is optimised.
    After each training pass the model is evaluated on the validation loader.

    Args:
        epochs: Number of train/evaluate rounds to run.
    """
    net = torchvision.models.resnet18(pretrained=True)

    # Freeze the pretrained weights so backward() computes no gradients for them.
    # NOTE(review): train_model() calls model.train(), so any BatchNorm layers
    # in the frozen backbone still update their running statistics - confirm
    # this is intended.
    for weight in net.parameters():
        weight.requires_grad = False

    # New classification head; freshly created parameters are trainable.
    net.fc = nn.Linear(net.fc.in_features, 10)
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()

    # Hand only the new head's parameters to the optimizer.
    head_optimizer = optim.Adam(net.fc.parameters())

    # One training pass followed by one evaluation pass per epoch.
    for epoch_number in range(1, epochs + 1):
        train_model(net, criterion, head_optimizer, train_loader,
                    [epoch_number, epochs])
        test_model(net, criterion, val_order)

In [9]:
def tl_fine_tuning(epochs=3):
    """Transfer learning by fine-tuning every layer of a pretrained ResNet-18.

    The final fully connected layer is replaced by a fresh 10-output layer and
    the whole network (pretrained layers plus the new head) is optimised.
    After each training pass the model is evaluated on the validation loader.

    Args:
        epochs: Number of train/evaluate rounds to run.
    """
    net = models.resnet18(pretrained=True)

    # Swap the classifier head for a 10-output layer.
    net.fc = nn.Linear(net.fc.in_features, 10)
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()

    # Optimise ALL parameters, not just the new head - nothing is frozen here.
    full_optimizer = optim.Adam(net.parameters())

    # One training pass followed by one evaluation pass per epoch.
    for epoch_number in range(1, epochs + 1):
        train_model(net, criterion, full_optimizer, train_loader,
                    [epoch_number, epochs])
        test_model(net, criterion, val_order)

In [10]:
# Sanity check: print the training dataset summary (size, split, root, transforms).
print(train_set)

Dataset CIFAR10
    Number of datapoints: 50000
    Split: train
    Root Location: ./data
    Transforms (if any): Compose(
                             Resize(size=224, interpolation=PIL.Image.BILINEAR)
                             RandomHorizontalFlip(p=0.5)
                             RandomVerticalFlip(p=0.5)
                             ToTensor()
                             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                         )
    Target Transforms (if any): None


In [None]:
# Run the feature-extraction experiment with its default number of epochs (3).
tl_feature_extractor()

1/3
batch: 4/1000 [>                    ]
Loss: 2.484   Accuracy: 0.08
