In [56]:
import torch
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import SGD
import numpy as np
import matplotlib.pyplot as plt
import torchmetrics
import torchvision
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop
from torch import flatten

In [57]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a skip connection.

    The main branch is conv -> BN -> ReLU -> conv -> BN.  Its result is added
    to the skip branch and passed through a final ReLU.  The skip branch is an
    identity (an empty ``nn.Sequential``) unless the block changes the spatial
    stride or the channel count, in which case it is a strided 1x1 convolution
    followed by batch normalisation.

    Args:
        in_planes: number of channels of the incoming feature map.
        planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and of the 1x1
            projection on the skip branch (when one is needed).
    """

    # Output channels of the block are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: only the first convolution may down-sample.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip branch: identity when input and output shapes already agree,
        # otherwise a 1x1 projection that matches stride and channel count.
        shapes_agree = stride == 1 and in_planes == out_planes
        if shapes_agree:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_planes))

    def forward(self, x):
        """Run both branches on ``x`` and return ``relu(main(x) + skip(x))``."""
        branch = self.conv1(x)
        branch = F.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        branch += self.shortcut(x)
        return F.relu(branch)


In [62]:
class ResNet(pl.LightningModule):
    """ResNet image classifier packaged as a LightningModule.

    The network is a 3x3 stem convolution, four residual stages (64, 128,
    256 and 512 base channels; stages 2-4 down-sample with stride 2), a
    4x4 average pool and a linear classifier.  Training uses SGD with
    momentum and a one-cycle learning-rate schedule.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the classifier output.
        lr: peak learning rate, used both as the SGD ``lr`` and as
            ``max_lr`` of the one-cycle schedule.
    """

    def __init__(self, block, num_blocks, num_classes, lr):
        super().__init__()
        self.classes = num_classes
        # Running channel count; advanced by ``_make_layer`` as stages are built.
        self.in_planes = 64
        self.lr = lr

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.accuracy = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=self.classes
        )

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` residual blocks.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        ``self.in_planes`` is updated after every block so the next block (and
        the next stage) receives the right input channel count.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Three stride-2 stages shrink the map by 8x; a 4x4 pool therefore
        # yields the 1x1 map the linear layer expects only for 32x32 inputs.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)

    def configure_optimizers(self):
        """Create the SGD optimizer and a per-batch one-cycle LR schedule.

        OneCycleLR must be stepped after every optimizer step.  Lightning
        steps schedulers once per epoch unless told otherwise, so the
        scheduler is returned as a config dict with ``interval="step"``.
        The schedule length comes from the trainer's own estimate of the
        number of optimizer steps instead of a hard-coded dataset size, so
        it stays correct for any dataset, batch size or gradient
        accumulation setting, and does not require a datamodule.
        """
        optimizer = SGD(
            self.parameters(), lr=self.lr, momentum=0.9, weight_decay=5e-4
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for one batch and log them.

        Args:
            batch: ``(inputs, targets)`` pair.
            stage: prefix for the logged metric names (``"val"``/``"test"``);
                nothing is logged when it is falsy.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Softmax is monotonic, so the arg-max of the logits is already the
        # predicted class; no need to normalise first.
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; metrics are logged as ``test_*``."""
        self.evaluate(batch, stage='test')

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; metrics are logged as ``val_*``."""
        self.evaluate(batch, stage='val')

In [63]:
def ResNet18(lr=0.05, num_classes=10):
    """Build a ResNet-18: ``BasicBlock`` with two blocks in each of the four stages.

    Args:
        lr: peak learning rate forwarded to ``ResNet``.
        num_classes: number of output classes; defaults to 10.

    Returns:
        A freshly initialised ``ResNet`` instance.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, lr=lr)

In [64]:
class load_CIFAR10data(pl.LightningDataModule):
    """LightningDataModule serving CIFAR-10 from ``./CIFAR10_data``.

    Both splits are downloaded (if missing) and opened in the constructor.
    Training images are converted to tensors, normalised, randomly flipped
    horizontally and randomly cropped to 32x32 after 4 pixels of reflect
    padding; evaluation images are only converted and normalised.

    Args:
        batch_size: batch size of the training dataloader.  The validation
            and test dataloaders always use batches of 100.
    """

    # Per-channel normalisation statistics shared by both transforms.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

        self.train_transform = Compose([ToTensor(),
                                Normalize(self._MEAN, self._STD),
                                RandomHorizontalFlip(),
                                RandomCrop(32, padding=4, padding_mode='reflect')])
        self.test_transform = Compose([ToTensor(),
                                Normalize(self._MEAN, self._STD)])

        self.train = torchvision.datasets.CIFAR10(root='./CIFAR10_data', download=True,
                                                train=True, transform=self.train_transform)
        self.test = torchvision.datasets.CIFAR10(root='./CIFAR10_data', download=True,
                                                train=False, transform=self.test_transform)
        print('Data loaded')

    def train_dataloader(self):
        """Training loader; reshuffled every epoch so batches differ between epochs."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        """Loader over the CIFAR-10 test split in fixed order."""
        return DataLoader(self.test, batch_size=100)

    def val_dataloader(self):
        """Validation loader.

        NOTE(review): this serves the same test split as ``test_dataloader``,
        so validation metrics are not an independent estimate; consider
        holding out part of the training set instead.
        """
        return DataLoader(self.test, batch_size=100)

In [65]:
def main(batch_size=10, max_epochs=1, seed=42):
    """Train a ResNet-18 on CIFAR-10 and report test metrics.

    Args:
        batch_size: training batch size passed to the datamodule.
        max_epochs: number of training epochs.
        seed: seed applied to all random number generators before any data
            or model is created, so runs are repeatable; pass ``None`` to
            leave the generators unseeded.
    """
    if seed is not None:
        # Seed before constructing the model: weight initialisation and data
        # augmentation both draw from the global generators.
        pl.seed_everything(seed, workers=True)

    data = load_CIFAR10data(batch_size)
    mod = ResNet18()
    trainer = pl.Trainer(max_epochs=max_epochs)
    trainer.fit(mod, data)
    trainer.test(mod, data)


if __name__ == '__main__':
    main()

Files already downloaded and verified
Files already downloaded and verified


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type               | Params
------------------------------------------------
0 | conv1    | Conv2d             | 1.7 K 
1 | bn1      | BatchNorm2d        | 128   
2 | layer1   | Sequential         | 147 K 
3 | layer2   | Sequential         | 525 K 
4 | layer3   | Sequential         | 2.1 M 
5 | layer4   | Sequential         | 8.4 M 
6 | linear   | Linear             | 5.1 K 
7 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


Data loaded
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   1%|▏         | 71/5000 [00:20<23:32,  3.49it/s, v_num=62]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 100/100 [02:07<00:00,  1.27s/it]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.testing metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.23960000276565552
        test_loss           3.2692244052886963
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
