In [None]:
#---- Import necessary libraries ----
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import os
import time

In [None]:
#---- Set up the training/test datasets and their data loaders ----
base_dir = "data/"

# Convert each image to a tensor, then normalize each of the three channels
# with mean 0.5 and standard deviation 0.5
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# ImageFolder simplifies data loading: it is pointed at the 'train/' and 'test/'
# directories under base_dir and applies the transform to every image it loads
train_dataset = ImageFolder(os.path.join(base_dir, 'train/'), transform=image_transform)
test_dataset = ImageFolder(os.path.join(base_dir, 'test/'), transform=image_transform)

BATCH_SIZE = 16

# Shuffle the training data so every epoch sees the samples in a new order
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation data is read in a fixed order: shuffling brings no benefit there and
# would regroup the batches (including a trailing partial batch) on every pass,
# making per-batch-averaged metrics vary from run to run
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
#---- Define CNN model for multiclass classification ----
NUM_OUTPUT_CLASSES = 5

# Spatial size (height, width) of the input images. Assumed to match the
# 480x640 dummy input used by the ONNX export cell -- TODO confirm against
# the actual images in the dataset.
INPUT_HEIGHT, INPUT_WIDTH = 480, 640

model = nn.Sequential()

# First block of convolution -> ReLU -> MaxPooling
model.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=0))
model.add_module('relu1', nn.ReLU())
model.add_module('pool1', nn.MaxPool2d(kernel_size=4))

# Second block of convolution -> ReLU -> MaxPooling
model.add_module('conv2', nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=0))
model.add_module('relu2', nn.ReLU())
model.add_module('pool2', nn.MaxPool2d(kernel_size=4))

# Third block of convolution -> ReLU -> MaxPooling
model.add_module('conv3', nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=0))
model.add_module('relu3', nn.ReLU())
model.add_module('pool3', nn.MaxPool2d(kernel_size=4))

# Flatten layer to transform feature maps into 1D vector
model.add_module('flatten', nn.Flatten())

# Output layer: must set output size to number of classes.
# The input width of the linear layer is measured by pushing one all-zero image
# through the layers added so far, so it stays correct if the blocks above change.
with torch.no_grad():
    num_flat_features = model(torch.zeros(1, 3, INPUT_HEIGHT, INPUT_WIDTH)).shape[1]
model.add_module('fc', nn.Linear(num_flat_features, NUM_OUTPUT_CLASSES))

# Display model information (printed explicitly: a bare expression is only
# rendered by the notebook when it is the last statement of the cell)
print(model)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable model parameters: ", num_params)

In [None]:
#---- Select the compute device and move the model onto it ----
# Prefer the first CUDA GPU when one is available; otherwise stay on the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)

# Cross entropy is the usual loss function for multiclass classification.
# NOTE: this cell does not create a loss function or an optimizer; both must
# exist before training_loop is called.



In [None]:
def training_loop(model, dataloader, val_dataloader, opt, criterion, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs, validating after every epoch.

    Each epoch makes one optimisation pass over ``dataloader`` followed by one
    gradient-free evaluation pass over ``val_dataloader``. Loss and accuracy of
    both passes, and the epoch's wall-clock time, are printed.

    Args:
        model: network to train; batches are moved to the device of its
            parameters (CPU if it has none).
        dataloader: loader yielding ``(data, target)`` training batches; must
            expose ``.dataset`` so accuracy can be normalised by sample count.
        val_dataloader: loader yielding ``(data, target)`` validation batches,
            with the same ``.dataset`` requirement.
        opt: optimizer that updates ``model``'s parameters.
        criterion: callable ``criterion(outputs, target)`` returning the loss.
        num_epochs: number of epochs to run.

    Returns:
        Total elapsed seconds summed over all epochs. Each epoch's time covers
        both its training pass and its validation pass.
    """
    # Use the device the model actually lives on instead of relying on a global
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_time = 0.0

    # Iterate through number of training epochs
    for n in range(num_epochs):
        start = time.time()

        # ###### TRAINING ######
        train_loss, train_acc = _run_epoch(model, dataloader, criterion, device, opt)
        print(f'Epoch: {n} | Training loss: {train_loss:.3f} | Train accuracy: {train_acc:.3f}%')

        # ###### VALIDATION ######
        val_loss, val_acc = _run_epoch(model, val_dataloader, criterion, device)
        print(f'Validation loss: {val_loss:.3f} | Validation accuracy: {val_acc:.3f}%')

        epoch_time = time.time() - start
        print("Epoch training time: ", epoch_time, " seconds")
        total_time += epoch_time
        print("#" * 40)

    return total_time


def _run_epoch(model, dataloader, criterion, device, opt=None):
    """Make one pass over ``dataloader`` and return ``(mean_batch_loss, accuracy_pct)``.

    With an optimizer the pass trains the model; without one it only evaluates.
    """
    is_training = opt is not None
    # Set the mode on every pass: the validation pass switches the model to
    # eval mode, so the next training pass has to switch it back explicitly.
    model.train(is_training)

    loss_sum = 0.0
    num_correct = 0
    num_batches = 0
    # Gradients are only tracked while training
    with torch.set_grad_enabled(is_training):
        for data, target in dataloader:
            # Send input and targets to compute device
            data = data.to(device)
            target = target.to(device)
            if is_training:
                # Must zero gradient every training step
                opt.zero_grad()
            # Forward pass and loss
            outputs = model(data)
            loss = criterion(outputs, target)
            loss_sum += loss.item()
            # Record accuracy: predicted class is the index of the largest output
            preds = outputs.detach().argmax(dim=1)
            num_correct += (preds == target).sum().item()
            if is_training:
                # Backpropagate and update the weights
                loss.backward()
                opt.step()
            num_batches += 1

    # max(..., 1) keeps an empty loader/dataset from raising ZeroDivisionError
    mean_batch_loss = loss_sum / max(num_batches, 1)
    accuracy_pct = num_correct / max(len(dataloader.dataset), 1) * 100
    return mean_batch_loss, accuracy_pct

In [None]:
#---- Run training loop ----
EPOCHS = 5
LEARNING_RATE = 1e-3

# Cross entropy loss for the multiclass problem, and an Adam optimizer over
# all of the model's parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The test loader is passed as the validation loader that is evaluated after
# every epoch; the returned value is the total elapsed time in seconds
train_time = training_loop(model, train_dataloader, test_dataloader, optimizer, criterion, EPOCHS)
print("Total training time: ", train_time, " seconds")

In [None]:
#---- Export our trained model to an ONNX file ----
import torch.onnx

ONNX_PATH = 'onnx_classifier.onnx'
# Shape of the dummy input: (batch, channels, height, width)
EXPORT_INPUT_SHAPE = (1, 3, 480, 640)

model.eval()
# ONNX stores our model by giving dummy data as an input and recording the operations.
# The dummy tensor is created directly on the model's device and does not need
# gradients: it only drives the trace.
x = torch.randn(EXPORT_INPUT_SHAPE, device=device)
# Forward pass as a smoke check, so a shape problem surfaces before the export
with torch.no_grad():
    model(x)
# Save model to onnx
torch.onnx.export(model, x, ONNX_PATH,
                  export_params=True,
                  opset_version=10,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
                  # variable length axes: the batch dimension stays dynamic
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
                                