In [3]:
import torch
import torch.nn as nn

from classifier import GolemClassifier
from backbone import GolemBackbones
from model_factories import Model, load_model
from dataloaders import get_dataloaders
from model_utils import train_epoch, eval, getBestModelParams
from plotter import plot_images, plot_metrics

In [4]:
# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Number of training epochs.
EPOCHS = 20

# Switch between the quick test configuration and the full run. The flag is
# also what decides which dataset label ends up in the plot titles.
IsTesting = False
data_name = "cifar10" if IsTesting else "cifar100"

# CIFAR-10 (test run) / CIFAR-100 (full run)

In [5]:
# Build the train/test loaders and the class list; `IsTesting` selects the
# test-run variant of the data.
(train_loader, test_loader, classes) = get_dataloaders(test_run=IsTesting)

# Number of target classes, passed to every model constructor.
CLASSES_NUM = len(classes)

## Optimal Model Hyperparameters for Custom Backbone Classifier

In [6]:
# Disabled hyperparameter search, kept for reference. When enabled it asks
# getBestModelParams for a learning rate, loss, criterion and optimizer, prints
# them, and copies the returned learning rate into LR.
# NOTE(review): get_golem_model is not among the imports visible at the top of
# this notebook — confirm where it lives before re-enabling.
#
# gc = get_golem_model(CLASSES_NUM)
# gc.to(device)
# lr, loss, criterion, optim = getBestModelParams(gc, train_loader, device)
# print(f"Optim lr: {lr}\nLoss: {loss}\nBest optim: {optim}\nBest loss func: {criterion}")
# LR = lr

# Fixed settings shared by the scenarios. The loss and the optimizer are stored
# as classes (not instances) so that each scenario instantiates its own.
criterion, optim = nn.CrossEntropyLoss, torch.optim.Adam
LR = 1e-4

## Scenario 1 - Custom Backbone

In [7]:
# Label used in the plot titles for this scenario.
model_name = "customBackbone"

# Build the classifier on the GM2 backbone and move it to the selected device.
gc = GolemClassifier(GolemBackbones.GM2, CLASSES_NUM)
gc.to(device)

# Instantiate the shared optimizer and loss classes for this model.
optimizer_gc = optim(params=gc.parameters(), lr=LR)
criterion_gc = criterion()

In [None]:
# Per-epoch training loss history for the custom-backbone run.
training_losses = []
# NOTE(review): nothing in this cell computes a validation loss, so this list
# stays empty; it is kept so the name remains defined for any later use.
validation_losses = []

for epoch in range(1, EPOCHS + 1):
    # One call to train_epoch per epoch; its return value is treated as that
    # epoch's training loss.
    loss = train_epoch(gc, criterion_gc, optimizer_gc, train_loader, device)
    # Record the loss so the history is available after the loop instead of
    # only appearing in the printed log.
    training_losses.append(loss)
    print(f"{epoch}/{EPOCHS}: loss={loss}")

In [None]:
# Evaluate the custom-backbone model on test_loader. `eval` here is the helper
# imported from model_utils, which shadows the builtin of the same name.
loss, metrics, preds = eval(gc, criterion_gc, test_loader, device)
accuracy, precision, recall, f1 = metrics

# Header and metric summary, one after the other.
print(
    "Custom backbone results: ",
    f"Avg loss {loss}\nAccu {accuracy}\nPrecision {precision}\nRecall {recall}\nF-score {f1}\n",
    sep="\n",
)

# Figures are titled with the dataset and model labels.
plot_images(
    test_loader,
    preds,
    classes,
    title=f"Classification_{data_name}_{model_name}",
)
plot_metrics(metrics, title=f"Metrics_{data_name}_{model_name}")

## Scenario 2 - ResNet18 Linear Probing

In [None]:
# Label used in the plot titles for this scenario.
model_name = "resnet18"

# load_model is given the shared loss/optimizer classes, the learning rate, the
# training loader and the device, and returns the ResNet18 model that the next
# cell evaluates.
# NOTE(review): the trailing positional True is not explained in this notebook;
# presumably it enables linear probing (per the section title) — confirm
# against model_factories.load_model.
gc = load_model(
    CLASSES_NUM,
    criterion,
    LR,
    optim,
    train_loader,
    device,
    Model.RESNET18,
    True,
)

# Fresh loss instance for the evaluation step.
criterion_gc = criterion()

In [None]:
# Evaluate the ResNet18 model on test_loader. `eval` here is the helper
# imported from model_utils, which shadows the builtin of the same name.
loss, metrics, preds = eval(gc, criterion_gc, test_loader, device)
accuracy, precision, recall, f1 = metrics

# Header and metric summary, one after the other.
print(
    "ResNet18 results: ",
    f"Avg loss {loss}\nAccu {accuracy}\nPrecision {precision}\nRecall {recall}\nF-score {f1}\n",
    sep="\n",
)

# Figures are titled with the dataset and model labels.
plot_images(
    test_loader,
    preds,
    classes,
    title=f"Classification_{data_name}_{model_name}",
)
plot_metrics(metrics, title=f"Metrics_{data_name}_{model_name}")

## Scenario 3 - ResNet34

In [None]:
# Label used in the plot titles for this scenario.
model_name = "resnet34"

# load_model is given the shared loss/optimizer classes, the learning rate, the
# training loader and the device, and returns the ResNet34 model that the next
# cell evaluates.
# NOTE(review): the meaning of the trailing positional True is not visible in
# this notebook — confirm against model_factories.load_model.
gc = load_model(
    CLASSES_NUM,
    criterion,
    LR,
    optim,
    train_loader,
    device,
    Model.RESNET34,
    True,
)

# NOTE(review): no visible cell uses this optimizer after it is rebuilt here,
# and the ResNet18/ViTbase scenarios do not create one — confirm whether it is
# still needed.
optimizer_gc = optim(params=gc.parameters(), lr=LR)

# Fresh loss instance for the evaluation step.
criterion_gc = criterion()

In [None]:
# Evaluate the ResNet34 model on test_loader. `eval` here is the helper
# imported from model_utils, which shadows the builtin of the same name.
loss, metrics, preds = eval(gc, criterion_gc, test_loader, device)
accuracy, precision, recall, f1 = metrics

# Header and metric summary, one after the other.
print(
    "ResNet34 results: ",
    f"Avg loss {loss}\nAccu {accuracy}\nPrecision {precision}\nRecall {recall}\nF-score {f1}\n",
    sep="\n",
)

# Figures are titled with the dataset and model labels.
plot_images(
    test_loader,
    preds,
    classes,
    title=f"Classification_{data_name}_{model_name}",
)
plot_metrics(metrics, title=f"Metrics_{data_name}_{model_name}")

# FIXME - Data for ViTbase Model
The ViT-base model uses images of size 224x224, so the dataloaders are rebuilt with `img_size=224`.

In [None]:
# Rebuild the loaders with 224-pixel images. This rebinds train_loader,
# test_loader and classes, so cells run after this point see the 224 loaders.
(train_loader, test_loader, classes) = get_dataloaders(
    test_run=IsTesting,
    img_size=224,
)

# Recompute the class count from the rebuilt class list.
CLASSES_NUM = len(classes)

## Scenario 4 - ViTbase

In [None]:
# Label used in the plot titles for this scenario.
model_name = "vitbase"

# load_model is given the shared loss/optimizer classes, the learning rate, the
# training loader and the device, and returns the ViTbase model that the next
# cell evaluates.
# NOTE(review): the meaning of the trailing positional True is not visible in
# this notebook — confirm against model_factories.load_model.
gc = load_model(
    CLASSES_NUM,
    criterion,
    LR,
    optim,
    train_loader,
    device,
    Model.VITBASE,
    True,
)

# Fresh loss instance for the evaluation step.
criterion_gc = criterion()

In [None]:
# Evaluate the ViTbase model on test_loader. `eval` here is the helper
# imported from model_utils, which shadows the builtin of the same name.
loss, metrics, preds = eval(gc, criterion_gc, test_loader, device)
accuracy, precision, recall, f1 = metrics

# Header and metric summary, one after the other.
print(
    "ViTbase results: ",
    f"Avg loss {loss}\nAccu {accuracy}\nPrecision {precision}\nRecall {recall}\nF-score {f1}\n",
    sep="\n",
)

# Figures are titled with the dataset and model labels.
plot_images(
    test_loader,
    preds,
    classes,
    title=f"Classification_{data_name}_{model_name}",
)
plot_metrics(metrics, title=f"Metrics_{data_name}_{model_name}")