In [12]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
from losses import SupConLoss, HingeLoss
from model import Encoder, LinearClassifier, CNN
from train import train

In [14]:
import wandb
import umap
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from typing import Literal
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [15]:
def get_device():
    """Select the compute device and print which one was chosen.

    Preference order: CUDA GPU, then Apple Silicon MPS, then CPU.

    Returns:
        torch.device: the selected device.
    """
    # NVIDIA GPU takes priority whenever CUDA is usable.
    if torch.cuda.is_available():
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        return torch.device("cuda")

    # Apple Silicon GPU: the MPS backend must be both available and built.
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        print("Using MPS (Apple Silicon GPU)")
        return torch.device("mps")

    # No accelerator found: fall back to the CPU.
    print("Using CPU")
    return torch.device("cpu")

In [16]:
# --- Run configuration ---
# Compute device shared by every model in this notebook (get_device() prints its choice).
DEVICE = get_device()

# Training hyperparameters.
BATCH_SIZE, EPOCHS = 64, 20
# Passed to Encoder as proj_dim and to LinearClassifier as in_dim.
PROJ_DIM = 128

# Selects the criterion used when training the linear classifier
# (the encoder stage is trained with SupConLoss).
TYPE_OF_LOSS: Literal["crossentropy", "hinge"] = "hinge"

# Destination file for the final model's state_dict.
MODEL_FILENAME = "custom_model.pt"

Using MPS (Apple Silicon GPU)


In [17]:
# Convert images to tensors and normalise each RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) # TODO: add transformations/augmentations?

# CIFAR-10 training set (downloaded into ./data on first use).
dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

# 80/20 train/validation split; the validation size is the remainder so the two
# sizes always sum to the dataset size.
train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

# Seed the split with a dedicated generator so that the same samples land in the
# validation set on every run (an unseeded random_split changes membership each
# time the cell is executed, making validation metrics incomparable across runs).
SPLIT_SEED = 42
train_dataset, validation_dataset = random_split(
    dataset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

# Held-out CIFAR-10 test set.
test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

# Human-readable names for the 10 class indices.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [18]:
# Encoder built with in_channels=3 and proj_dim=PROJ_DIM, moved to the selected device.
encoder = Encoder(proj_dim=PROJ_DIM, in_channels=3).to(DEVICE)
# Adam over all encoder parameters.
encoder_optimizer = optim.Adam(params=encoder.parameters(), lr=0.01)
# Criterion handed to train() for the encoder stage.
sup_con_loss = SupConLoss()

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /Users/lorenzocusin/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mlorenzocusin02[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Start the W&B run for the encoder stage.
# The config is read from the constants and optimizer that actually drive
# training (EPOCHS, BATCH_SIZE, PROJ_DIM, the optimizer's learning rate) so the
# values recorded in the dashboard cannot drift from the real hyperparameters.
wandb.init(
    project="Cnn-Verification",
    name="Encoder - SupConLearning",
    config={
        "learning_rate": encoder_optimizer.param_groups[0]["lr"],
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM
    }
)

In [None]:
# Train the encoder on train_loader / validation_loader with encoder_optimizer and
# the SupConLoss criterion for EPOCHS epochs on DEVICE, with accuracy computation
# switched off and W&B logging switched on.
# NOTE(review): train() lives in train.py and is not visible here; this cell assumes
# it returns the trained model, which is rebound to `encoder` — confirm.
encoder = train(
    encoder,
    train_loader,
    validation_loader,
    encoder_optimizer,
    sup_con_loss,
    EPOCHS,
    DEVICE,
    compute_accuracy=False,
    wandb_logging=True
)

In [None]:
# Show embedding distribution: embed a sample of training images with the encoder,
# reduce the embeddings to 2-D with UMAP and scatter-plot them coloured by label.
encoder.eval()  # switch the encoder to evaluation mode

all_embeddings = []
all_labels = []

N_SAMPLES = 500  # exact number of points to embed and plot
N_ITERATIONS = -(-N_SAMPLES // BATCH_SIZE)  # ceil division: batches needed to cover N_SAMPLES

with torch.no_grad():
    for i, (images, batch_labels) in enumerate(train_loader):
        # Check the budget before doing any work so exactly N_ITERATIONS batches are encoded.
        if i >= N_ITERATIONS:
            break

        batch_embeddings = encoder(images.to(DEVICE))

        all_embeddings.append(batch_embeddings.cpu())
        # Labels are only used for colouring, so they never need to visit the device.
        all_labels.append(batch_labels)

# Drop the surplus of the last batch so exactly N_SAMPLES points are kept.
embeddings = torch.cat(all_embeddings, dim=0)[:N_SAMPLES].numpy()
labels = torch.cat(all_labels, dim=0)[:N_SAMPLES].numpy()

# umap computation (fixed random_state keeps the layout reproducible)
umap_reducer = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    metric="euclidean",
    random_state=42
)

embeddings_2d = umap_reducer.fit_transform(embeddings)

# plotting: a single axes filling the figure (no empty second subplot slot)
fig, ax = plt.subplots(figsize=(6, 4))

scatter = ax.scatter(
    embeddings_2d[:, 0],
    embeddings_2d[:, 1],
    c=labels,
    cmap="tab10",
    vmin=0,  # pin the colour scale to the 10 class indices so each class
    vmax=9,  # keeps the same colour even if a class is missing from the sample
    s=5
)
ax.set_title("UMAP of Embeddings")
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
fig.colorbar(scatter, ax=ax, ticks=range(10));


In [None]:

# Log the UMAP scatter plot to the active W&B run.
# The figure is taken from the `scatter` artist rather than from the pyplot module:
# with the notebook inline backend, figures are closed once the cell that drew them
# finishes, so passing `plt` here would resolve to a new, empty current figure.
wandb.log({
    "sample_image": wandb.Image(scatter.figure, caption="Embedding distribution")
})

In [19]:
# Finish the currently active W&B run (the encoder-stage run started by wandb.init).
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


In [None]:
# Linear classifier built with in_dim=PROJ_DIM and num_classes=10, moved to the selected device.
classifier = LinearClassifier(num_classes=10, in_dim=PROJ_DIM).to(DEVICE)
# Adam over the classifier parameters only.
classifier_optimizer = optim.Adam(params=classifier.parameters(), lr=0.01)

# Candidate criteria; TYPE_OF_LOSS decides which one is handed to train().
cross_entropy_loss = nn.CrossEntropyLoss()
hinge_loss = HingeLoss(margin=1)

In [None]:
def execute_classifier(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Replace a batch of images with its encoder embeddings.

    Passed to train() as ``middleware``. Runs the module-level ``encoder`` in
    evaluation mode with gradient tracking disabled, so no gradients are
    computed through the encoder in this step.

    Args:
        images: batch of input images.
            NOTE(review): assumed to already be on the encoder's device — confirm in train().
        labels: labels of the batch; returned unchanged.

    Returns:
        A ``(embeddings, labels)`` tuple, where ``embeddings`` is the encoder output.
    """
    encoder.eval()
    with torch.no_grad():
        embeddings = encoder(images)
    return embeddings, labels

# Resolve the criterion through an explicit lookup so that a misspelt TYPE_OF_LOSS
# raises immediately instead of silently falling back to the hinge loss.
classifier_losses = {"crossentropy": cross_entropy_loss, "hinge": hinge_loss}
if TYPE_OF_LOSS not in classifier_losses:
    raise ValueError(
        f"Unknown TYPE_OF_LOSS {TYPE_OF_LOSS!r}; expected one of {sorted(classifier_losses)}"
    )
classifier_loss = classifier_losses[TYPE_OF_LOSS]

# The encoder-stage run is closed with wandb.finish() earlier in the notebook, so
# make sure a run is active before training with wandb_logging=True.
# NOTE(review): assumes train() does not start its own run — confirm in train.py.
if wandb.run is None:
    wandb.init(
        project="Cnn-Verification",
        name=f"Classifier - {TYPE_OF_LOSS}",
        config={
            "learning_rate": classifier_optimizer.param_groups[0]["lr"],
            "epochs": EPOCHS,
            "batch_size": BATCH_SIZE,
            "loss": TYPE_OF_LOSS
        }
    )

# Train the classifier on encoder embeddings (produced by the middleware above).
classifier = train(
    classifier,
    train_loader,
    validation_loader,
    classifier_optimizer,
    classifier_loss,
    EPOCHS,
    DEVICE,
    middleware=execute_classifier,
    wandb_logging=True
)

In [None]:
# Build the final CNN from the encoder and classifier via CNN.import_from,
# then write its state_dict (parameters/buffers only) to MODEL_FILENAME.
model = CNN.import_from(encoder, classifier)
torch.save(obj=model.state_dict(), f=MODEL_FILENAME)