In [None]:
%load_ext autoreload
%autoreload 2

from losses import SupConLoss, HingeLoss
from model import Encoder, LinearClassifier, CNN
from train import train

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from typing import Literal
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [None]:
def get_device():
    """Return the best available ``torch.device`` and report which one was chosen.

    Preference order: NVIDIA CUDA GPU, then Apple Silicon MPS, then CPU.
    A one-line message naming the selected backend is printed before returning.
    """
    # 1) NVIDIA GPU: also report the name of device 0.
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        return chosen

    # 2) Apple Silicon GPU: MPS must be both available at runtime and built in.
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        chosen = torch.device("mps")
        print("Using MPS (Apple Silicon GPU)")
        return chosen

    # 3) Nothing accelerated found: fall back to the CPU.
    chosen = torch.device("cpu")
    print("Using CPU")
    return chosen

In [None]:
# Run configuration: everything a reader might want to tune lives in this cell.
DEVICE = get_device()
BATCH_SIZE = 64
EPOCHS = 10
PROJ_DIM = 128  # passed to Encoder as proj_dim and to LinearClassifier as in_dim
MODEL_FILENAME = "custom_model.pt"
TYPE_OF_LOSS: Literal["supcon", "hinge"] = "supcon"  # loss used for the encoder
SEED = 42

# Seed torch's global RNG (all devices) so that anything drawing from it -
# default random_split, DataLoader shuffling, and presumably the models' weight
# initialisation - repeats identically on Restart & Run All.
torch.manual_seed(SEED)

# The Literal annotation is not enforced at runtime; fail fast on a typo here
# rather than silently training with an unintended loss later.
assert TYPE_OF_LOSS in ("supcon", "hinge"), f"unknown TYPE_OF_LOSS: {TYPE_OF_LOSS!r}"

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps each
# channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) # TODO: add transformations/augmentations?

NUM_WORKERS = 8  # DataLoader worker processes; lower this on machines with fewer cores
SPLIT_SEED = 42  # fixed so the train/validation split is identical on every run

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

train_ratio, validation_ratio = 0.8, 0.2
assert abs(train_ratio + validation_ratio - 1.0) < 1e-9, "split ratios must sum to 1"
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
# Take the remainder so the two sizes always add up to dataset_size exactly.
validation_size = dataset_size - train_size

# A dedicated generator makes the split reproducible regardless of how much
# global RNG state earlier cells have consumed.
train_dataset, validation_dataset = random_split(
    dataset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
# Evaluation metrics do not depend on order, so keep test batches deterministic.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Short class names in CIFAR-10 label-index order (0-9).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Loss objects, built once: a cross-entropy classification loss, plus the two
# encoder-loss options that TYPE_OF_LOSS chooses between.
cross_entropy_loss, sup_con_loss, hinge_loss = (
    nn.CrossEntropyLoss(),
    SupConLoss(),
    HingeLoss(margin=1),
)

In [None]:
# 3 input channels (RGB CIFAR-10 images); PROJ_DIM is also the classifier's in_dim.
encoder = Encoder(in_channels=3, proj_dim=PROJ_DIM)
# Adam over every encoder parameter, learning rate 1e-2.
encoder_optimizer = optim.Adam(params=encoder.parameters(), lr=1e-2)

In [None]:
# Map each allowed TYPE_OF_LOSS value to its loss object. Indexing the dict
# raises KeyError on an unrecognised name instead of silently falling back to
# hinge. The previous comparison against "supconv" never matched the config
# value "supcon", so hinge loss was always used.
encoder_losses = {"supcon": sup_con_loss, "hinge": hinge_loss}

encoder = train(
    encoder,
    train_loader,
    validation_loader,
    encoder_optimizer,
    encoder_losses[TYPE_OF_LOSS],
    EPOCHS,
    DEVICE
)

In [None]:
# Linear head over PROJ_DIM-wide encoder outputs, one output per CIFAR-10 class.
classifier = LinearClassifier(in_dim=PROJ_DIM, num_classes=10)
# Adam over every classifier parameter, learning rate 1e-2.
classifier_optimizer = optim.Adam(params=classifier.parameters(), lr=1e-2)

In [None]:
def execute_classifier(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Replace a batch of images with embeddings from the frozen global ``encoder``.

    Intended as ``train``'s ``middleware`` hook (presumably applied to each batch
    before the model sees it - confirm in train.py), so the classifier is fit
    on ``encoder(images)`` rather than raw pixels.

    Args:
        images: input batch fed to the already-trained ``encoder``.
        labels: class labels, passed through unchanged.

    Returns:
        ``(embeddings, labels)``.
    """
    # eval() + no_grad(): use the encoder in inference mode and keep gradients
    # from flowing into (or being tracked for) its weights.
    encoder.eval()
    with torch.no_grad():
        embeddings = encoder(images)
    return embeddings, labels

# Train the classifier on top of the frozen encoder. Several arguments differ
# from the encoder-training call:
#   * the result is bound to `classifier`, so `encoder` still refers to the
#     trained encoder used by execute_classifier and CNN.import_from;
#   * classifier_optimizer is passed, so the classifier's own weights are
#     the ones updated;
#   * cross_entropy_loss is used, as a classification loss
#     (assumes LinearClassifier returns raw logits - TODO confirm);
#   * DEVICE is used, the same device the encoder was trained on
#     (assumes train moves each batch to `device` before the middleware runs).
classifier = train(
    classifier,
    train_loader,
    validation_loader,
    classifier_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=execute_classifier
)

In [None]:
# Build a single CNN from the encoder and classifier and write it to MODEL_FILENAME.
# NOTE(review): torch.save on a whole module pickles it, so loading requires the
# project's model classes to be importable (and, on recent torch, an explicit
# weights_only=False). Saving model.state_dict() would be more portable, but would
# change the file format that consumers load.
model = CNN.import_from(encoder, classifier)
torch.save(obj=model, f=MODEL_FILENAME)