In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from lightning.pytorch.callbacks import ModelCheckpoint
import torchmetrics
import lightning as L

from data_modules.cifar10 import CIFAR10DataModule

# Trade some float32 matmul precision for speed on supporting GPUs.
# NOTE(review): per the PyTorch docs, 'medium' permits reduced-precision
# internal computation for float32 matmuls — confirm this is acceptable here.
torch.set_float32_matmul_precision('medium')


class ImagenetTransferLearning(L.LightningModule):
    """ImageNet-pretrained AlexNet used as a frozen feature extractor.

    The convolutional backbone (``features`` + ``avgpool``) keeps its
    pretrained weights and is never updated; only a freshly initialised
    classifier head is trained with cross-entropy loss.

    Args:
        lr: Learning rate passed to the Adam optimizer.
        num_target_classes: Number of output classes of the new classifier
            head (defaults to 10, the value previously hard-coded).
    """

    def __init__(self, lr, num_target_classes=10):
        super().__init__()

        # Record the constructor arguments in checkpoints so that
        # `load_from_checkpoint` can rebuild the module without extra arguments.
        self.save_hyperparameters()
        self.lr = lr

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Init a pretrained AlexNet
        self.model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)

        # Freeze the backbone explicitly. `forward` already runs it under
        # `torch.no_grad()`, but marking the parameters as non-trainable keeps
        # the trainable-parameter count honest and avoids "unused parameter"
        # errors when the module is wrapped in DistributedDataParallel.
        self.model.features.requires_grad_(False)

        # Re-initialize the linear layers of AlexNet
        self.model.classifier = nn.Sequential(
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_target_classes),
        )

        # One metric object per stage: a shared instance would mix validation
        # and test statistics in the same internal state.
        self.val_accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_target_classes
        )
        self.test_accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_target_classes
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        The backbone is evaluated without gradient tracking; gradients only
        flow through the classifier head.
        """
        with torch.no_grad():
            x = self.model.features(x)
            x = self.model.avgpool(x)
            features = torch.flatten(x, 1)
        return self.model.classifier(features)

    def _shared_step(self, batch):
        """Run the forward pass on ``batch`` and return (logits, target, loss)."""
        inputs, target = batch
        output = self(inputs)
        loss = self.loss_fn(output, target)
        return output, target, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; returns the loss."""
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch."""
        output, target, loss = self._shared_step(batch)

        # Logging the metric object (not the batch value) lets Lightning
        # compute the epoch-level accuracy and reset the metric afterwards.
        self.val_accuracy(output, target)
        self.log("val_accuracy", self.val_accuracy, on_step=False, on_epoch=True)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log test loss and accuracy for one batch."""
        output, target, loss = self._shared_step(batch)

        self.test_accuracy(output, target)
        self.log("test_accuracy", self.test_accuracy, on_step=False, on_epoch=True)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return an Adam optimizer using the configured learning rate.

        Frozen backbone parameters never receive gradients, so Adam leaves
        them untouched even though they are part of ``self.parameters()``.
        """
        return optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Tunable settings for this run, gathered in one place.
SEED = 42
LEARNING_RATE = 1e-6
BATCH_SIZE = 512
NUM_WORKERS = 8
EARLY_STOPPING_PATIENCE = 5
DATA_DIR = "~/Data/cifar10"

# Seed Python, NumPy and PyTorch (and dataloader workers) so that the
# classifier initialisation and data shuffling are reproducible across runs.
L.seed_everything(SEED, workers=True)

# Keep the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
)

# Stop once the validation loss has not improved for the configured number
# of validation checks.
early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=EARLY_STOPPING_PATIENCE,
)

model = ImagenetTransferLearning(LEARNING_RATE)

cifar10 = CIFAR10DataModule(DATA_DIR, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# max_epochs=-1 removes the epoch limit, so early stopping is what ends
# training; the sanity-check validation pass is disabled.
trainer = L.Trainer(
    accelerator="gpu",
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=-1,
    num_sanity_val_steps=0,
)
trainer.fit(model, datamodule=cifar10)

In [None]:
# Evaluate the checkpoint with the lowest validation loss rather than the
# in-memory weights, which early stopping leaves several epochs past the best.
trainer.test(model, datamodule=cifar10, ckpt_path="best")