In [16]:
# import os
# os.environ['https_proxy'] = "http://hpc-proxy00.city.ac.uk:3128"
# # os.environ['http_proxy'] = “http://hpc-proxy00.city.ac.uk:3128”

In [17]:
import numpy as np
from numpy.random import default_rng
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import PIL
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as transforms
import torch.utils.data.dataloader as dataloader
from torchvision import datasets

In [18]:
transform_train = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
transform_test = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),     
            ])

In [19]:
filepath='.'
filedir=os.path.join(os.getcwd(),filepath,'Vegetable Images')
batch_size=128

for i in os.listdir(filedir):
    if(i=='train'):
        train_dataset=datasets.ImageFolder(
            root=os.path.join(filedir,i),
            transform=transform_train
    )
    elif(i=='test'):
        test_dataset=datasets.ImageFolder(
            root=os.path.join(filedir,i),
            transform=transform_test
    )
    elif(i=='validation'):
        valid_dataset=datasets.ImageFolder(
            root=os.path.join(filedir,i),
            transform=transform_test
    )
    else:
        raise(Exception('Unexpected error occurred.'))
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=True)
valid_loader=DataLoader(valid_dataset,batch_size=batch_size,shuffle=True)

# Base Model

In [20]:
D1=32
D2=64
D3=128

A1=512
A2=256
num_classes=15
input_pix=224
num_neurons=int(np.floor(np.floor(np.floor(input_pix/2)/2)/2)**2*D3)
input=3
class my_nn(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            
            nn.Conv2d(input, D1, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
        
            nn.Conv2d(D1, D2, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            
            nn.Conv2d(D2, D3, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            
            nn.Flatten(),
            
            
            nn.Linear(num_neurons,A1),
            nn.ReLU(),
            nn.Linear(A1, A2),
            nn.ReLU(),
            nn.Linear(A2,num_classes)
        )
    
    def forward(self, xb):
        return self.network(xb)

In [21]:
#Loading checkpoint in case of outside problems interrupting training
mydir='MODELS'
save_path=os.path.join(mydir, 'mymodel.pt')

device = ('cuda' if torch.cuda.is_available() else 'cpu')
lr = 1e-3
epochs = 100
opt=torch.optim.Adam
model=my_nn().to(device)

if not os.path.isdir(mydir):
    os.makedirs(mydir)

if os.path.isfile(save_path):
    checkpoint=torch.load(save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    start=checkpoint['epoch']
    best_valid=checkpoint['valid_acc']
    best_test=checkpoint['test_acc']
    print('Starting from epoch', start)
else:
    print('No Checkpoint, starting from scratch')

checkpoint=torch.load

No Checkpoint, starting from scratch


In [22]:
def train(model,train_loader,epochs,lr,opt,loss_func):
    model.train()
    losslog=[]
    optimizer=opt(model.parameters(),lr)
    for _,(x,y) in enumerate(train_loader):
        batch_x=x.to(device)
        ypred=model(batch_x)
        loss=loss_func(ypred,y.to(device))
        losslog.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    return losslog

def evaluate(model,test_loader):
    model.eval()
    acc = 0
        
    with torch.no_grad():
        for _, (x, y) in enumerate(test_loader):
            batch_x = model(x.to(device))
            acc += (batch_x.argmax(1) == y.to(device)).sum().item()
            
    #return the accuracy from the epoch     
    return acc / len(test_loader.dataset)
            

def fit(model,train_loader,test_loader,val_loader,epochs,lr=1e-3,opt=torch.optim.SGD,loss_func=F.cross_entropy):
    train_acc_log=[]
    test_acc_log=[]
    valid_acc_log=[]
    best_valid=0
    best_test=0
    for iter in range(epochs):
        #training step
        train_log=train(model,train_loader,epochs,lr,opt,loss_func)
            
        train_acc=evaluate(model,train_loader)
        test_acc=evaluate(model,test_loader)
        valid_acc=evaluate(model,test_loader)
        print("Epoch [{}], train acc: {:.2f}, test acc: {:.2f}, val acc: {:.2f}".format(iter, train_acc, test_acc, valid_acc))
        #evaluate step
        train_acc_log.append(train_acc)
        test_acc_log.append(test_acc)
        valid_acc_log.append(valid_acc)
        
        #saving models to avoid running through entire epochs
        if(valid_acc>best_valid and test_acc>=best_test):
            best_valid=valid_acc
            best_test=test_acc
            
            torch.save(
                {
                    'epochs': epochs,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict(),
                    'train_acc': train_acc,
                    'test_acc': test_acc,
                    'valid_acc': valid_acc,
                }, save_path
            )
            print('Model Saved for epoch {}'.format(iter))
    
    
    print('Training Complete')
    return train_acc,test_acc,valid_acc

In [23]:
train_acc,test_acc,valid_acc=fit(model,train_loader,test_loader,valid_loader,20,opt=torch.optim.Adam)

Epoch [0], train acc: 0.92, test acc: 0.90, val acc: 0.90
Epoch [1], train acc: 0.97, test acc: 0.95, val acc: 0.95
Epoch [2], train acc: 0.98, test acc: 0.95, val acc: 0.95


KeyboardInterrupt: 