In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import Food101
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl


class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, data_path, batch_size, lr):
        super().__init__()
        
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        
        # Data preparation
        dataset = Food101(data_path, transform=transforms.Compose([
            transforms.RandAugment(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
        ]), download=True)

        dataset_size = len(dataset)
        train_size = int(dataset_size * .95)
        val_size = dataset_size - train_size

        self.train_dataset, self.val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify food101
        num_target_classes = 101
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        
        return x
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("train_loss", loss)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("val_loss", loss)
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)
        
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=False)

In [27]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min"
)

model = ImagenetTransferLearning("/home/alex/Data/food/", 32, 1e-3)

trainer = pl.Trainer(accelerator="gpu", callbacks=[checkpoint_callback], max_epochs=5)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type             | Params
-------------------------------------------------------
0 | loss_fn           | CrossEntropyLoss | 0     
1 | feature_extractor | Sequential       | 23.5 M
2 | classifier        | Linear           | 206 K 
-------------------------------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.860    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

torch.Size([32, 3, 224, 224]) torch.Size([32])
torch.Size([32, 101])
torch.Size([32, 3, 224, 224]) torch.Size([32])
torch.Size([32, 101])


Training: 0it [00:00, ?it/s]

ValueError: Caught ValueError in DataLoader worker process 3.
Original Traceback (most recent call last):
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torch/utils/data/dataset.py", line 471, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torchvision/datasets/food101.py", line 77, in __getitem__
    image = self.transform(image)
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 676, in forward
    i, j, h, w = self.get_params(img, self.size)
  File "/home/alex/anaconda3/envs/cse6363/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 635, in get_params
    raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")
ValueError: Required crop size (224, 224) is larger then input image size (512, 193)
