In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl


class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, data_path, batch_size, lr):
        super().__init__()
        
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        
        # Data preparation
        dataset = CIFAR10(data_path, transform=transforms.Compose([
            transforms.RandAugment(),
            transforms.Resize(224),
            # transforms.RandomCrop(32),
            transforms.ToTensor(),
        ]), download=True)

        dataset_size = len(dataset)
        train_size = int(dataset_size * .95)
        val_size = dataset_size - train_size

        self.train_dataset, self.val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
        # init a pretrained resnet
        backbone = models.resnet18(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify food101
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        
        return x
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("train_loss", loss)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("val_loss", loss)
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)
        
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=False)

In [4]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min"
)

model = ImagenetTransferLearning("data/cifar10/", 32, 1e-3)

trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=5, num_sanity_val_steps=0)
trainer.fit(model)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar10/cifar-10-python.tar.gz to data/cifar10/


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/alex/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name              | Type             | Params
-------------------------------------------------------
0 | loss_fn           | CrossEntropyLoss | 0     
1 | feature_extractor | Sequential       | 11.2 M
2 | classifier        | Linear           | 5.1 K 
-------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
