In [10]:
import torch
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.optim import SGD
import numpy as np
import matplotlib.pyplot as plt
import torchmetrics
import torchvision
from torchvision.transforms import Compose, ToTensor
from torch import flatten

In [18]:
class CIFAR10_mod(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        self.classes = num_classes
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 32, 3)
        self.fc1 = nn.Linear(1152, 500)
        self.fc2 = nn.Linear(500, self.classes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.mxpool1 = nn.MaxPool2d(2)
        self.mxpool2 = nn.MaxPool2d(2)
        self.softmax = nn.LogSoftmax(dim=1)
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.classes)

        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.mxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.mxpool2(x)

        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)

        x = self.fc2(x)
        output = self.softmax(x)
        return output
    
    def configure_optimizers(self):
        LR = 1e-3
        optimizer = torch.optim.AdamW(self.parameters(),lr=LR)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        x,y = batch
        out = self(x)
        loss = F.cross_entropy(out,y)
        self.log('train_loss', loss,on_step=True,on_epoch=True)
        return loss 

    def test_step(self,batch,batch_idx):
        x,y = batch
        out = self(x)
        loss = F.cross_entropy(out,y)
        out = nn.Softmax(-1)(out) 
        logits = torch.argmax(out,dim=1)
        accu = self.accuracy(logits, y)        
        self.log('test_loss', loss)
        self.log('train_acc_step', accu)
        return loss, accu
    
    def validation_step(self,batch,batch_idx):
        x,y = batch
        out = self(x)
        loss = F.cross_entropy(out,y)
        out = nn.Softmax(-1)(out) 
        logits = torch.argmax(out,dim=1)
        accu = self.accuracy(logits, y)        
        self.log('Val_loss', loss)
        self.log('Val_acc_step', accu)
        return loss, accu

In [19]:
class load_CIFAR10data(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        transform = transforms.Compose([transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.train = torchvision.datasets.CIFAR10(root='./CIFAR10_data', download=True ,train=True, transform=transform)
        self.test = torchvision.datasets.CIFAR10(root='./CIFAR10_data', download=True ,train=False, transform=transform)
        print('Data loaded')

    def train_dataloader(self):
        return DataLoader(self.train, self.batch_size)    

    def test_dataloader(self):
        return DataLoader(self.test, self.batch_size)
    
    def val_dataloader(self):
        return DataLoader(self.test, self.batch_size)

In [20]:
def main():
     batch_size = 10
     max_epochs = 1
     num_classes = 10
     data = load_CIFAR10data(batch_size)
     mod = CIFAR10_mod(num_classes)
     trainer = pl.Trainer(max_epochs=max_epochs)
     trainer.fit(mod, data)
     trainer.test(mod, data)

if __name__ == '__main__': main()

Files already downloaded and verified
Files already downloaded and verified


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

   | Name     | Type               | Params
-------------------------------------------------
0  | conv1    | Conv2d             | 896   
1  | conv2    | Conv2d             | 9.2 K 
2  | fc1      | Linear             | 576 K 
3  | fc2      | Linear             | 5.0 K 
4  | relu1    | ReLU               | 0     
5  | relu2    | ReLU               | 0     
6  | relu3    | ReLU               | 0     
7  | mxpool1  | MaxPool2d          | 0     
8  | mxpool2  | MaxPool2d          | 0     
9  | softmax  | LogSoftmax         | 0     
10 | accuracy | MulticlassAccuracy | 0     
-------------------------------------------------
591 K     Trainable params
0         Non-trainable params
591 K     Total params
2.367     Total estimated model params size (MB)


Data loaded
                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 5000/5000 [01:39<00:00, 50.23it/s, v_num=54]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 5000/5000 [01:39<00:00, 50.21it/s, v_num=54]
