In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from pytorch_lightning.callbacks import ModelCheckpoint
import torchmetrics
import pytorch_lightning as pl

from data_modules.cifar10 import CIFAR10DataModule


class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, lr):
        super().__init__()

        self.lr = lr
        
        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
        # init a pretrained alexnet
        # self.model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        self.model = models.alexnet()

        num_target_classes = 10
        # Re-initialize the linear layers of AlexNet
        self.model.classifier = nn.Sequential(
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_target_classes),
        )

        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_target_classes)

    def forward(self, x):
        # self.model.eval()
        # with torch.no_grad():
        #     x = self.model.features(x)
        #     x = self.model.avgpool(x)
        #     features = torch.flatten(x, 1)
        # x = self.model.classifier(features)
        x = self.model(x)
        
        return x
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("train_loss", loss)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)

        accuracy = self.accuracy(output, target)
        self.log("val_accuracy", accuracy)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)

        accuracy = self.accuracy(output, target)
        self.log("test_accuracy", accuracy)
        self.log("test_loss", loss)
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

In [2]:
from pytorch_lightning.tuner import Tuner

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min"
)

early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5
)

model = ImagenetTransferLearning(1e-6)

cifar10 = CIFAR10DataModule("~/Data/cifar10", batch_size=512, num_workers=8)

trainer = pl.Trainer(accelerator="gpu", callbacks=[checkpoint_callback, early_stopping_callback], max_epochs=-1, num_sanity_val_steps=0)
# tuner = Tuner(trainer)
# tuner.lr_find(model, cifar10)
trainer.fit(model, datamodule=cifar10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | loss_fn  | CrossEntropyLoss   | 0     
1 | model    | AlexNet            | 57.0 M
2 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
57.0 M    Trainable params
0         Non-trainable params
57.0 M    Total params
228.179   Total estimated model params size (MB)


Epoch 113: 100%|██████████| 88/88 [00:27<00:00,  3.17it/s, v_num=5]


In [3]:
trainer.test(model, datamodule=cifar10)

Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 20/20 [00:03<00:00,  6.44it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.8751999735832214
        test_loss           0.3754540681838989
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_accuracy': 0.8751999735832214, 'test_loss': 0.3754540681838989}]