# Vision Transformer

## Lightning Model

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from vision_transformer import VisionTransformer
import lightning as L

class LitViT(L.LightningModule):
    """PyTorch Lightning wrapper around ``VisionTransformer`` for image classification.

    Trains with ``nn.CrossEntropyLoss`` and Adam. It logs ``train_loss`` for each
    training step, and ``val_loss`` and ``val_accuracy`` during validation.

    Args:
        num_classes: Number of output classes, forwarded to ``VisionTransformer``.
        lr: Adam learning rate. Defaults to ``0.001``, the value previously hard-coded.
        image_size: First positional argument passed to ``VisionTransformer``.
            Presumably the input (height, width); confirm against
            ``vision_transformer.VisionTransformer``.
        patch_size: Second positional argument passed to ``VisionTransformer``.
            Presumably the (height, width) of each patch; confirm.
    """

    def __init__(self, num_classes, lr=0.001, image_size=(224, 224), patch_size=(16, 16)):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` and in checkpoints.
        # Without this, ``LitViT.load_from_checkpoint(path)`` cannot rebuild the
        # module because ``num_classes`` is a required argument.
        self.save_hyperparameters()

        # NOTE(review): the remaining positional values (512, 4, 8, 1024) are
        # forwarded unchanged. Their meaning (embedding dim / depth / heads / MLP
        # dim?) depends on VisionTransformer's signature; confirm before exposing them.
        self.model = VisionTransformer(image_size, patch_size, num_classes, 512, 4, 8, 1024)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw class scores produced by the wrapped ``VisionTransformer``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the forward pass and loss for ``batch``; return ``(loss, logits, labels)``.

        Assumes ``batch`` unpacks into ``(inputs, labels)`` with class-index
        labels (the validation accuracy compares them to ``argmax`` indices).
        TODO confirm against the dataloader.
        """
        inputs, labels = batch
        logits = self(inputs)
        loss = self.criterion(logits, labels)
        return loss, logits, labels

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch."""
        loss, _, _ = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the cross-entropy loss and top-1 accuracy for one validation batch."""
        loss, logits, labels = self._shared_step(batch)
        self.log('val_loss', loss)
        # Keep the accuracy as a tensor on the model's device. The previous
        # ``.sum().item()`` forced a host sync on every validation batch.
        # ``argmax`` picks the first maximal index, exactly like ``torch.max(..., 1)``.
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        self.log('val_accuracy', accuracy)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return optim.Adam(self.parameters(), lr=self.hparams.lr)
    
# Alias used by the experiment cell(s) below (``model_class(num_classes=...)``).
# Any LightningModule accepting ``num_classes`` can be swapped in by changing
# this one line.
model_class = LitViT

## CIFAR-10

In [2]:
from pipeline import CIFAR10_EXP

# CIFAR-10 has 10 target categories; named here instead of a bare literal.
CIFAR10_NUM_CLASSES = 10

# Instantiate the model with an output head sized for CIFAR-10.
model = model_class(num_classes=CIFAR10_NUM_CLASSES)

# Run the CIFAR-10 experiment. Per the captured output, CIFAR10_EXP drives a
# Lightning Trainer. NOTE(review): the log shows `max_epochs` was not set
# (defaulting to 1000) and CPU-only training; check the pipeline's Trainer
# configuration.
CIFAR10_EXP(model)

Seed set to 42


Files already downloaded and verified
Files already downloaded and verified


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
c:\Users\benoi\AppData\Local\pypoetry\Cache\virtualenvs\projet-8inf974-qm87_-2b-py3.9\lib\site-packages\lightning\pytorch\trainer\connectors\logger_connector\logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
c:\Users\benoi\AppData\Local\pypoetry\Cache\virtualenvs\projet-8inf974-qm87_-2b-py3.9\lib\site-packages\lightning\pytorch\loops\utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.

  | Name    

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\benoi\AppData\Local\pypoetry\Cache\virtualenvs\projet-8inf974-qm87_-2b-py3.9\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\benoi\AppData\Local\pypoetry\Cache\virtualenvs\projet-8inf974-qm87_-2b-py3.9\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0:   1%|▏         | 5/391 [01:05<1:24:02,  0.08it/s, v_num=18]

c:\Users\benoi\AppData\Local\pypoetry\Cache\virtualenvs\projet-8inf974-qm87_-2b-py3.9\lib\site-packages\lightning\pytorch\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
