In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

sys.path.append(os.path.abspath('../'))

In [3]:
import timm

import torch
import torchvision
import torchvision.transforms as transforms

import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

In [8]:
from my_trainer.datasets import cifar_10

In [4]:
seed_everything(42, workers=True)

Global seed set to 42


42

In [None]:
class Model(pl.LightningModule):
    """
    Lightning model
    """
    def __init__(
        self, model_name, 
        num_classes, 
        lr=0.001, max_iter=20
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name=model_name, 
            pretrained=True, 
            num_classes=num_classes
        )
        self.metric = torchmetrics.Accuracy()
        self.loss = torch.nn.CrossEntropyLoss()
        self.lr = lr
        self.max_iter = max_iter
        
    def forward(self, x):
        return self.model(x)
    
    def shared_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.metric(preds, y)
        
        return loss
    
    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_acc', self.metric, on_epoch=True, logger=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_acc', self.metric, on_epoch=True, logger=True, prog_bar=True)
        
        return loss
    
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=self.max_iter)
        
        return [optim], [scheduler]

In [9]:
cifar_dataset = cifar_10.CifarDataset()

In [10]:
classes = cifar_dataset.classes
len(classes)

10

In [12]:
model = Model(
    model_name="vit_tiny_patch16_224", 
    num_classes=len(classes), 
    lr=0.001, 
    max_iter=10,
)

In [17]:
import torchsummary

In [21]:
_ = torchsummary.torchsummary.summary(model, torch.rand(1, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
├─VisionTransformer: 1-1                 [-1, 10]                  --
|    └─PatchEmbed: 2-1                   [-1, 196, 192]            --
|    |    └─Conv2d: 3-1                  [-1, 192, 14, 14]         147,648
|    |    └─Identity: 3-2                [-1, 196, 192]            --
|    └─Dropout: 2-2                      [-1, 197, 192]            --
|    └─Sequential: 2-3                   [-1, 197, 192]            --
|    |    └─Block: 3-3                   [-1, 197, 192]            444,864
|    |    └─Block: 3-4                   [-1, 197, 192]            444,864
|    |    └─Block: 3-5                   [-1, 197, 192]            444,864
|    |    └─Block: 3-6                   [-1, 197, 192]            444,864
|    |    └─Block: 3-7                   [-1, 197, 192]            444,864
|    |    └─Block: 3-8                   [-1, 197, 192]            444,864
|    |    └─Block: 3-9                   [-1, 197,

In [23]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='vit_tpytorch_lightning6_224-cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}'
)

In [27]:
trainer = Trainer(
    deterministic=True, 
    logger=False, 
    callbacks=[
        checkpoint_callback,
        #pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging
    ], 
    # gpus=[0], # change it based on gpu or cpu availability
    max_epochs=1, 
    stochastic_weight_avg=True
)

  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(
    model=model, 
    datamodule=cifar_dataset,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified



  | Name   | Type              | Params
---------------------------------------------
0 | model  | VisionTransformer | 5.5 M 
1 | metric | Accuracy          | 0     
2 | loss   | CrossEntropyLoss  | 0     
---------------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.105    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Swapping scheduler `CosineAnnealingLR` for `SWALR`
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcf89449f70>
Traceback (most recent call last):
  File "/Users/muhsin/opt/anaconda3/envs/panptic_seg/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/Users/muhsin/opt/anaconda3/envs/panptic_seg/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1445, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/muhsin/opt/anaconda3/envs/panptic_seg/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/muhsin/opt/anaconda3/envs/panptic_seg/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/muhsin/opt/anaconda3/envs/panptic_seg/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout