In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl


class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, data_path, batch_size, lr):
        super().__init__()
        
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        
        # Data preparation
        dataset = CIFAR10(data_path, transform=transforms.Compose([
            transforms.RandAugment(),
            transforms.Resize(224),
            # transforms.RandomCrop(32),
            transforms.ToTensor(),
        ]), download=True)

        dataset_size = len(dataset)
        train_size = int(dataset_size * .95)
        val_size = dataset_size - train_size

        self.train_dataset, self.val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
        # init a pretrained resnet
        backbone = models.resnet18(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify food101
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        
        return x
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("train_loss", loss)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("val_loss", loss)
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)
        
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=False)

  Referenced from: /Users/ajd/anaconda3/envs/cse6363/lib/python3.7/site-packages/torchvision/image.so
  Reason: Incompatible library version: image.so requires version 15.0.0 or later, but libjpeg.9.dylib provides version 14.0.0
  warn(f"Failed to load image Python extension: {e}")


In [4]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min"
)

model = ImagenetTransferLearning("/Users/ajd/data/cifar10/", 32, 1e-3)

trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=5, num_sanity_val_steps=0)
trainer.fit(model)

Files already downloaded and verified


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name              | Type             | Params
-------------------------------------------------------
0 | loss_fn           | CrossEntropyLoss | 0     
1 | feature_extractor | Sequential       | 11.2 M
2 | classifier        | Linear           | 51.8 K
-------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.913    Total estimated model params size (MB)


Epoch 0:   1%|          | 11/1564 [01:17<3:02:11,  7.04s/it, loss=3.16, v_num=9]
Epoch 0:   5%|▍         | 76/1564 [01:57<38:15,  1.54s/it, loss=1.74, v_num=10]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
