In [34]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision

import pytorch_lightning as pl

from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms
from torchvision.datasets import CIFAR10

In [35]:
class CIFAR10DataModule(pl.LightningDataModule):
    '''organize the data pipeline from accessing the data to loading it using PyTorch dataloaders'''


    def __init__(self, batch_size, data_dir: str = './data'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform =  transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.num_classes = 10
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

    def prepare_data(self):
        # download the CIFAR-10 dataset
        torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True)

    # PyTorch dataset instances
    def setup(self, stage=None):

        if stage == 'fit' or stage is None:
            cifar_full = torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        if stage == 'test' or stage is None:
            self.cifar_test =  torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True, transform=self.transform)

    # dataloaders
    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

In [36]:
class CIFARLitModel(pl.LightningModule):
    '''model architecture, training, testing and validation loops'''
    def __init__(self, pretrained, learning_rate=3e-4):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.model  = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', weights='IMAGENET1K_V1' if pretrained else 'DEFAULT')
        # or any of these variants
        # model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True)
        self.model.fc = nn.Linear(in_features=512, out_features=10, bias=True)
        
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
         return self.model(x)

    # train loop
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        # metric
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    # validation loop
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    # test loop
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    # optimizers
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [37]:
model = CIFARLitModel(pretrained = True, learning_rate=1e-4 / 2)
# instantiate classes
dm = CIFAR10DataModule(batch_size=48)
dm.prepare_data()
dm.setup()
# Initialize Callbacks
from pathlib import Path

checkpoint_callback = pl.callbacks.ModelCheckpoint()
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_acc", patience=3, verbose=False, mode="max")
trainer = pl.Trainer(max_epochs=10,
                     callbacks=[checkpoint_callback, early_stop_callback],
                    )
# Train the model
trainer.fit(model, dm, )
# Evaluate the model
trainer.test(dataloaders=dm.test_dataloader())

Using cache found in C:\Users\Matyiko/.cache\torch\hub\pytorch_vision_v0.10.0


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | model    | ResNet             | 11.2 M
1 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Epoch 6: 100%|██████████| 938/938 [01:17<00:00, 12.10it/s, v_num=7, val_loss=0.204, val_acc=0.946]




Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Restoring states from the checkpoint path at c:\Users\Matyiko\Documents\Egyetem\Msc_2_felev\Melytanulas\ImageClassification\lightning_logs\version_7\checkpoints\epoch=6-step=6566.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\Matyiko\Documents\Egyetem\Msc_2_felev\Melytanulas\ImageClassification\lightning_logs\version_7\checkpoints\epoch=6-step=6566.ckpt
c:\Users\Matyiko\Documents\Egyetem\Msc_2_felev\Melytanulas\ImageClassification\myenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 209/209 [00:11<00:00, 18.24it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9416000247001648
        test_loss           0.21263504028320312
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.21263504028320312, 'test_acc': 0.9416000247001648}]

In [38]:
modelNotPretrained = CIFARLitModel(pretrained = False, learning_rate=1e-4 / 2)

# instantiate classes
dm = CIFAR10DataModule(batch_size=48)
dm.prepare_data()
dm.setup()
# Initialize Callbacks
from pathlib import Path

checkpoint_callback = pl.callbacks.ModelCheckpoint()
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_acc", patience=3, verbose=False, mode="max")
trainer = pl.Trainer(max_epochs=10,
                     callbacks=[checkpoint_callback, early_stop_callback],
                    )
# Train the model
trainer.fit(modelNotPretrained, dm, )
# Evaluate the model
trainer.test(dataloaders=dm.test_dataloader())


Using cache found in C:\Users\Matyiko/.cache\torch\hub\pytorch_vision_v0.10.0


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | model    | ResNet             | 11.2 M
1 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Epoch 5: 100%|██████████| 938/938 [01:17<00:00, 12.15it/s, v_num=8, val_loss=0.175, val_acc=0.946]
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Restoring states from the checkpoint path at c:\Users\Matyiko\Documents\Egyetem\Msc_2_felev\Melytanulas\ImageClassification\lightning_logs\version_8\checkpoints\epoch=5-step=5628.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\Matyiko\Documents\Egyetem\Msc_2_felev\Melytanulas\ImageClassification\lightning_logs\version_8\checkpoints\epoch=5-step=5628.ckpt


Testing DataLoader 0: 100%|██████████| 209/209 [00:11<00:00, 18.56it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9427000284194946
        test_loss           0.20132392644882202
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.20132392644882202, 'test_acc': 0.9427000284194946}]