In [8]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from torchvision.models import mobilenet_v3_small
from torchvision.transforms import v2
from torch.utils.data import DataLoader, Subset
from preprocessing_pipeline_v2 import data_generator
from sklearn.metrics import top_k_accuracy_score as top_k

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: add more transforms - RandomPerspective/ RandomRotation/ FiveCrop OR RandomCrop &OR RandomResizedCrop
train_transform = v2.Compose([v2.ToImage(), 
                              v2.ToDtype(torch.float32, scale=True)
                              #v2.RandomPerspective(),
                              #v2.RandomRotation(degrees=30),
                              #v2.FiveCrop(size=(48,48)),
                              #v2.Normalize()
                              ])

eval_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
# start small scale with only 10 classes
num_classes = 10

In [10]:
# --- SOME HYPERPARAMETERS ---
# num_workers = [0,1,2,3]
# num_epochs = 5
# learn_rate = 1e-4

In [26]:
#num_workers = 2
num_epochs = 20
learn_rate = 1e-4

In [12]:
# --- LOAD DATA ---
train_data, val_data, test_data = data_generator(train_transform, 
                                                 eval_transform, 
                                                 num_classes)

In [21]:
train_dl = DataLoader(train_data, batch_size=64,
                      shuffle=True, num_workers=2,
                      pin_memory=True)
val_dl = DataLoader(val_data, batch_size=64,
                    shuffle=True, num_workers=2,
                    pin_memory=True)
test_dl = DataLoader(test_data, batch_size=64,
                     shuffle=True, num_workers=4,
                     pin_memory=True)

In [28]:
# --- INSTANTIATE / LOAD MODEL ---
# model.load()???
mobilenet_model = mobilenet_v3_small(weights='DEFAULT')
# get number of in features from source
num_features = mobilenet_model.classifier[3].in_features
# redefine the networks final fully connected layer
mobilenet_model.classifier[3] = nn.Linear(num_features, num_classes)
# send to gpu
mobilenet_model = mobilenet_model.to(device)

# --- TRAINING ---
# TODO: add in validation data, early stopping, etc.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=mobilenet_model.parameters(), lr=learn_rate)

# initialize training loss
training_loss = []

for epoch in range(num_epochs):
    #start_time = time.time()
    i = 0
    mobilenet_model.train()
    for images, labels in train_dl:
        # load data to gpu
        images = images.to(device)
        labels = labels.to(device)
        
        y_pred = mobilenet_model(images)
        loss = loss_fn(y_pred, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      
        if i % 1000 == 0:
            training_loss.append(loss.item())
            print(f"Epoch: {epoch+1}/{num_epochs} --- Training Loss: {loss.item():.4f}")
        i += 1
            
    # mobilenet_model.eval()
    # y_true = torch.empty()
    # y_pred = torch.empty()
    # for (images, labels) in val_dl:
    #     images = images.to(device)
    #     labels = labels.to(device)
        
    #     y_true = torch.cat([y_true, labels], dim=0)
    #     y_pred = torch.cat([y_pred, mobilenet_model(images)], dim=0)
    # print(top_k(y_true, y_pred, k=3))
            
#mobile_train_time = time.time() - start_time
# model.save()???

Epoch: 1/20 --- Training Loss: 2.8585
Epoch: 2/20 --- Training Loss: 0.7852
Epoch: 3/20 --- Training Loss: 0.2633
Epoch: 4/20 --- Training Loss: 0.3553
Epoch: 5/20 --- Training Loss: 0.1422
Epoch: 6/20 --- Training Loss: 0.1344
Epoch: 7/20 --- Training Loss: 0.0714
Epoch: 8/20 --- Training Loss: 0.1120
Epoch: 9/20 --- Training Loss: 0.1151
Epoch: 10/20 --- Training Loss: 0.0966
Epoch: 11/20 --- Training Loss: 0.1372
Epoch: 12/20 --- Training Loss: 0.0845
Epoch: 13/20 --- Training Loss: 0.0231
Epoch: 14/20 --- Training Loss: 0.0014
Epoch: 15/20 --- Training Loss: 0.0038
Epoch: 16/20 --- Training Loss: 0.0164
Epoch: 17/20 --- Training Loss: 0.0017
Epoch: 18/20 --- Training Loss: 0.0725
Epoch: 19/20 --- Training Loss: 0.0062
Epoch: 20/20 --- Training Loss: 0.0033


In [19]:
# custom accuracy computation, optionally displays predictions
def validate(model, data, display_pred = False):
    total = 0
    correct = 0
    model.eval()
    i = 0
    for images, labels in data:
        images = images.to(device)
        labels = labels.to(device)
        
        x = model(images)
        value, pred = torch.max(x, 1)
        
        total += x.size(0)
        correct += torch.sum(pred == labels)
        
        if i % 1000 == 0 & display_pred == True:
            print(f"Pred: {x} / True: {labels}")
        i += 1
    return correct / total

In [29]:
print(f"MobileNetV3 test accuracy: {validate(mobilenet_model.to(device), data=val_dl).item():.4f}")

MobileNetV3 test accuracy: 0.9593
