In [129]:
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [130]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {}'.format(device))
torch.device(device)

Using cuda


device(type='cuda')

In [131]:
#training variables

imgTransformSize = 224
#range of degrees +- to rotate
imgTransformRngRot = 5

epochs = 25

modelLearnRate = 0.1
modelMomentum= 0.5
modelWeightDecay= 0.003


train_dataset_path = './datasets/FoodTrain1'
valid_dataset_path = './datasets/FoodValidate1'

In [132]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(scale=(0.6, 1.0), size=(imgTransformSize,imgTransformSize)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(imgTransformRngRot),
    transforms.ToTensor()
])

valid_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=(imgTransformSize,imgTransformSize)),
    transforms.ToTensor()
])

In [133]:
train_dataset = torchvision.datasets.ImageFolder(root = train_dataset_path, transform = train_transforms)
valid_dataset = torchvision.datasets.ImageFolder(root = valid_dataset_path, transform = valid_transforms)

In [134]:
def show_transformed_images(dataset):
    loader = torch.utils.data.DataLoader(dataset, batch_size = 6, shuffle=True)
    batch = next(iter(loader))
    images, labels = batch
    
    grid = torchvision.utils.make_grid(images, nrow=3)
    plt.figure(figsize=(11,11))
    plt.imshow(np.transpose(grid,(1,2,0)))
    print('labels: ', labels)

#show_transformed_images(train_dataset)

In [135]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle = True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size = 32, shuffle = False)

train_losses, valid_losses = [], []
train_accs, valid_accs = [], []


In [136]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

resnet18_model = models.resnet18(weights = models.ResNet18_Weights.DEFAULT)
num_features = resnet18_model.fc.in_features
num_object_categories = 40 
resnet18_model.fc = nn.Linear(num_features, num_object_categories)
resnet18_model = resnet18_model.to(device)


resnet50_model = models.resnet50(weights = models.ResNet50_Weights.IMAGENET1K_V2)
num_features = resnet50_model.fc.in_features
num_object_categories = 40 
resnet50_model.fc = nn.Linear(num_features, num_object_categories)
resnet50_model = resnet50_model.to(device)


#The model we're actually using
usedModel = resnet50_model

from datetime import datetime

now = datetime.now()
checkpointName = now.strftime("%y%m%d%H%M")

if (usedModel == resnet50_model):
    checkpointName = 'resnet50_'+checkpointName
else:
    checkpointName = 'resnet18_'+checkpointName

In [137]:
loss_fn = nn.CrossEntropyLoss()

optimiser = optim.SGD(usedModel.parameters(), lr=modelLearnRate, momentum=modelMomentum, weight_decay=modelWeightDecay)

In [138]:
def save_checkpoint(model, epoch, optimiser, best_acc):
    state = {
        'epoch': epoch+1,
        'model': model.state_dict(),
        'best_accuracy': best_acc,
        'optimiser' : optimiser.state_dict(),
        'comments':'learning rate: {:.2f} momentum: {:.2f} Weight Decay: {:.5f}'.format(modelLearnRate, modelMomentum, modelWeightDecay)
    }

    torch.save(state, checkpointName+'.pth.tar')

In [139]:
def train_network(model, train_loader, valid_loader, criterion, optimiser, n_epochs):
    
    best_acc = 0
    
    for epoch in range(n_epochs):
        #print("Epoch number %d (epoch + 1)")
        model.train()
        epoch_loss, epoch_accuracy = 0, 0
        epoch_valid_accuracy, epoch_valid_loss = 0, 0
        totalImg = 0
        running_total = 0

        for data in train_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            totalImg += labels.size(0)

            optimiser.zero_grad()

            outputs = model(images)

            loss = criterion(outputs, labels)
            loss.backward()

            optimiser.step()

            acc = ((outputs.argmax(dim=1) == labels).float().mean())

            running_total += (outputs.argmax(dim=1) == labels).sum().item()

            epoch_accuracy += acc/len(train_loader)
            epoch_loss += loss/len(train_loader)

        print('Epoch: {}, train accuracy: {:.2f}%, train loss: {:.4f}'.format(epoch+1, epoch_accuracy*100, epoch_loss), end=" | ")
        train_losses.append(epoch_loss.item())
        train_accs.append(epoch_accuracy.item())

        #print("     -training set got %d out of %d images (%.3f%%)" % (running_total, totalImg, epoch_accuracy*100))
        model.eval()

        with torch.no_grad():

            for data in valid_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                valid_output = model(images)
                valid_loss = loss_fn(valid_output, labels)

                acc = ((valid_output.argmax(dim=1) == labels).float().mean())
                epoch_valid_accuracy += acc/len(valid_loader)
                epoch_valid_loss += valid_loss/len(valid_loader) 
                
        print('Epoch: {}, validation accuracy: {:.2f}%, valid loss: {:.4f}'.format(epoch+1, epoch_valid_accuracy*100, epoch_valid_loss))
        valid_losses.append(epoch_valid_loss.item())
        valid_losses.append(epoch_valid_accuracy.item())

        if best_acc <= epoch_valid_accuracy:
            best_acc = epoch_valid_accuracy
            save_checkpoint(model, epoch, optimiser, best_acc)

    print("training complete")

    return model


In [140]:
model = train_network(usedModel, train_loader, valid_loader, loss_fn, optimiser, epochs)

Epoch: 1, train accuracy: 3.91%, train loss: 3.6192 | Epoch: 1, validation accuracy: 9.38%, valid loss: 3.4903
Epoch: 2, train accuracy: 31.42%, train loss: 2.7123 | Epoch: 2, validation accuracy: 17.19%, valid loss: 3.2059
Epoch: 3, train accuracy: 61.89%, train loss: 1.5777 | Epoch: 3, validation accuracy: 26.56%, valid loss: 2.9146
Epoch: 4, train accuracy: 80.82%, train loss: 0.7603 | Epoch: 4, validation accuracy: 21.88%, valid loss: 2.5221
Epoch: 5, train accuracy: 94.10%, train loss: 0.3900 | Epoch: 5, validation accuracy: 42.19%, valid loss: 2.0471
Epoch: 6, train accuracy: 96.44%, train loss: 0.1779 | Epoch: 6, validation accuracy: 53.12%, valid loss: 1.8223
Epoch: 7, train accuracy: 98.44%, train loss: 0.1037 | Epoch: 7, validation accuracy: 53.12%, valid loss: 1.5842
Epoch: 8, train accuracy: 100.00%, train loss: 0.0482 | Epoch: 8, validation accuracy: 51.56%, valid loss: 1.7344
Epoch: 9, train accuracy: 100.00%, train loss: 0.0348 | Epoch: 9, validation accuracy: 48.44%, va

In [141]:
checkpoint = torch.load(checkpointName+'.pth.tar')

if usedModel == resnet18_model:
    savedModel = models.resnet18()
else:
    savedModel = models.resnet50()

num_ftrs = savedModel.fc.in_features
savedModel.fc = nn.Linear(num_ftrs, num_object_categories)
savedModel.load_state_dict(checkpoint['model'])

torch.save(savedModel, checkpointName+'.pth')