In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {}'.format(device))
torch.device(device)

Using cuda


device(type='cuda')

In [3]:
#training variables

imgTransformSize = 224
#range of degrees +- to rotate
imgTransformRngRot = 5

epochs = 25

modelLearnRate = 0.1
modelMomentum=0.5
modelWeightDecay=0.003


train_dataset_path = './datasets/FoodTrain1'
valid_dataset_path = './datasets/FoodValidate1'

In [4]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(scale=(0.6, 1.0), size=(imgTransformSize,imgTransformSize)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(imgTransformRngRot),
    transforms.ToTensor()
])

valid_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=(imgTransformSize,imgTransformSize)),
    transforms.ToTensor()
])

In [5]:
train_dataset = torchvision.datasets.ImageFolder(root = train_dataset_path, transform = train_transforms)
valid_dataset = torchvision.datasets.ImageFolder(root = valid_dataset_path, transform = valid_transforms)

In [6]:
def show_transformed_images(dataset):
    loader = torch.utils.data.DataLoader(dataset, batch_size = 6, shuffle=True)
    batch = next(iter(loader))
    images, labels = batch
    
    grid = torchvision.utils.make_grid(images, nrow=3)
    plt.figure(figsize=(11,11))
    plt.imshow(np.transpose(grid,(1,2,0)))
    print('labels: ', labels)

#show_transformed_images(train_dataset)

In [7]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle = True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size = 32, shuffle = False)

train_losses, valid_losses = [], []
train_accs, valid_accs = [], []


In [8]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

resnet18_model = models.resnet18(weights = None)
num_features = resnet18_model.fc.in_features
num_object_categories = 40 
resnet18_model.fc = nn.Linear(num_features, num_object_categories)
resnet18_model = resnet18_model.to(device)


resnet50_model = models.resnet50(weights = models.ResNet50_Weights.IMAGENET1K_V2)
num_features = resnet50_model.fc.in_features
num_object_categories = 40 
resnet50_model.fc = nn.Linear(num_features, num_object_categories)
resnet50_model = resnet50_model.to(device)


#The model we're actually using
usedModel = resnet50_model

In [9]:
loss_fn = nn.CrossEntropyLoss()

optimiser = optim.SGD(usedModel.parameters(), lr=modelLearnRate, momentum=modelMomentum, weight_decay=modelWeightDecay)

In [10]:
def save_checkpoint(model, epoch, optimiser, best_acc):
    state = {
        'epoch': epoch+1,
        'model': model.state_dict(),
        'best_accuracy': optimiser.state_dict(),
        'comments':'test'
    }
    torch.save(state, 'model_best_checkpoint.pth.tar')

In [11]:
def train_network(model, train_loader, valid_loader, criterion, optimiser, n_epochs):
    
    best_acc = 0
    
    for epoch in range(n_epochs):
        #print("Epoch number %d (epoch + 1)")
        model.train()
        epoch_loss, epoch_accuracy = 0, 0
        epoch_valid_accuracy, epoch_valid_loss = 0, 0
        totalImg = 0
        running_total = 0

        for data in train_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            totalImg += labels.size(0)

            optimiser.zero_grad()

            outputs = model(images)

            loss = criterion(outputs, labels)
            loss.backward()

            optimiser.step()

            acc = ((outputs.argmax(dim=1) == labels).float().mean())

            running_total += (outputs.argmax(dim=1) == labels).sum().item()

            epoch_accuracy += acc/len(train_loader)
            epoch_loss += loss/len(train_loader)

        print('Epoch: {}, train accuracy: {:.2f}%, train loss: {:.4f}'.format(epoch+1, epoch_accuracy*100, epoch_loss))
        train_losses.append(epoch_loss.item())
        train_accs.append(epoch_accuracy.item())

        #print("     -training set got %d out of %d images (%.3f%%)" % (running_total, totalImg, epoch_accuracy*100))
        model.eval()

        with torch.no_grad():

            for data in valid_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                valid_output = model(images)
                valid_loss = loss_fn(valid_output, labels)

                acc = ((valid_output.argmax(dim=1) == labels).float().mean())
                epoch_valid_accuracy += acc/len(valid_loader)
                epoch_valid_loss += valid_loss/len(valid_loader) 
                
        print('Epoch: {}, validation accuracy: {:.2f}%, valid loss: {:.4f}'.format(epoch+1, epoch_valid_accuracy*100, epoch_valid_loss))
        valid_losses.append(epoch_valid_loss.item())
        valid_losses.append(epoch_valid_accuracy.item())

        if best_acc < epoch_valid_accuracy:
            best_acc = epoch_valid_accuracy
            save_checkpoint(model, epoch, optimiser, best_acc)

    print("training complete")

    return model


In [12]:
model = train_network(usedModel, train_loader, valid_loader, loss_fn, optimiser, epochs)

Epoch: 1, train accuracy: 12.67%, train loss: 3.5286
Epoch: 1, validation accuracy: 29.69%, valid loss: 3.1037
Epoch: 2, train accuracy: 57.79%, train loss: 2.2838
Epoch: 2, validation accuracy: 35.94%, valid loss: 2.4698
Epoch: 3, train accuracy: 77.88%, train loss: 1.0892
Epoch: 3, validation accuracy: 78.12%, valid loss: 1.3339
Epoch: 4, train accuracy: 95.00%, train loss: 0.3720
Epoch: 4, validation accuracy: 90.62%, valid loss: 0.8064
Epoch: 5, train accuracy: 98.75%, train loss: 0.1626
Epoch: 5, validation accuracy: 95.31%, valid loss: 0.6737
Epoch: 6, train accuracy: 100.00%, train loss: 0.0808
Epoch: 6, validation accuracy: 96.88%, valid loss: 0.3282
Epoch: 7, train accuracy: 98.67%, train loss: 0.0725
Epoch: 7, validation accuracy: 98.44%, valid loss: 0.4138
Epoch: 8, train accuracy: 100.00%, train loss: 0.0710
Epoch: 8, validation accuracy: 92.19%, valid loss: 0.3469
Epoch: 9, train accuracy: 100.00%, train loss: 0.0187
Epoch: 9, validation accuracy: 98.44%, valid loss: 0.103

In [13]:
#fig, (ax1, ax2) = plt.subplots(2, figsize=(12, 8), sharex=True)
#ax1.plot(train_losses, color='b', label='train')
#ax1.plot(valid_losses, color='g', label='valid')
#ax1.set_ylabel("Loss")
#ax1.legend()
#ax2.plot(train_accs, color='b', label='train')
#ax2.plot(valid_accs, color='g', label='valid')
#ax2.set_ylabel("Accuracy")
#ax2.set_xlabel("Epoch")
#ax2.legend()