In [1]:
import os

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchmetrics.functional import accuracy

import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

In [2]:
# 256 samples per batch when CUDA is available, 64 otherwise.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

# Half of the logical cores are used for data loading.
# - os.cpu_count() may return None, so fall back to 1 before halving.
# - The DataLoaders in this notebook are built with persistent_workers=True, which
#   PyTorch rejects when num_workers == 0; a single-core host would otherwise yield 0,
#   so at least one worker is always requested.
NUM_WORKERS = max(1, (os.cpu_count() or 1) // 2)

In [3]:
# Image pre-processing pipelines.
# Per-channel mean / std handed to Normalize for both pipelines.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.247, 0.243, 0.261]


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps shared by every pipeline."""
    return [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
    ]


# Training pipeline: random 32x32 crop (after 4 px padding) and random horizontal
# flip, followed by the shared tensor conversion and normalisation.
_augmentation_steps = [
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
]
train_transforms = torchvision.transforms.Compose(_augmentation_steps + _tensor_and_normalize())

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transforms = torchvision.transforms.Compose(_tensor_and_normalize())

In [4]:
# setup data
# Two dataset objects over the same training images: one applies the augmenting
# training transform, the other the plain evaluation transform, so the validation
# subset is never augmented.
cifar_train = torchvision.datasets.CIFAR10("./", train=True, download=True, transform=train_transforms)
cifar_val = torchvision.datasets.CIFAR10("./", train=True, download=True, transform=test_transforms)

train_size = int(0.8 * len(cifar_train))
val_size = len(cifar_train) - train_size

# split dataset into training and validation set
# The permutation is drawn from the seeded generator so the split is identical on
# every run (previously the generator was created but never used).
seed = torch.Generator().manual_seed(42)
indices = torch.randperm(len(cifar_train), generator=seed)

# Slice by train_size rather than by -val_size: a negative-zero slice would silently
# swap the two subsets if val_size were ever 0.
train_dataset = torch.utils.data.Subset(cifar_train, indices[:train_size].tolist())
val_dataset = torch.utils.data.Subset(cifar_val, indices[train_size:].tolist())
assert len(train_dataset) + len(val_dataset) == len(cifar_train)

print(len(train_dataset))
print(len(val_dataset))

# Options shared by all loaders. persistent_workers is only valid with at least one
# worker process, so it is switched off automatically when NUM_WORKERS is 0.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0,
)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

cifar_test = torchvision.datasets.CIFAR10("./", train=False, download=True, transform=test_transforms)
test_loader = torch.utils.data.DataLoader(cifar_test, shuffle=False, **loader_kwargs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:33<00:00, 5047151.39it/s] 


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified
40000
10000
Files already downloaded and verified


In [5]:
# Small convolutional classifier: two 5x5 conv + tanh stages (max-pooling after the
# first one only), then a two-layer fully connected head producing 10 outputs.
# The 1600 input features of the first Linear layer correspond to 16 x 10 x 10,
# i.e. the flattened conv output for 32x32 inputs.
_layers = [
    nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
    nn.Tanh(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.Tanh(),
    nn.Flatten(),
    nn.Linear(in_features=1600, out_features=120),
    nn.Tanh(),
    nn.Linear(in_features=120, out_features=10),
]
model = nn.Sequential(*_layers)

In [6]:
class LitConvNet(L.LightningModule):
    """Lightning wrapper that trains and evaluates an image classifier with cross-entropy.

    Args:
        model: Network that maps a batch of images to per-class scores (logits).
        num_classes: Number of target classes; forwarded to the accuracy metric.
        lr: Learning rate used by the Adam optimizer.
    """

    def __init__(self, model, num_classes, lr = 0.0001):
        super().__init__()
        # The network itself is excluded from the saved hyperparameters; only
        # num_classes and lr end up in self.hparams.
        self.save_hyperparameters(ignore=['model'])
        self.num_classes = num_classes
        self.model = model

    def forward(self, images):
        """Return the wrapped network's outputs for a batch of images."""
        return self.model(images)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch.

        Every 10th batch a grid with the first images of the batch is also sent to
        the logger (when the logger supports images).

        Returns:
            The loss tensor Lightning back-propagates.
        """
        images, labels = batch

        # Forward pass
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)

        self.log("train_loss", loss)

        if batch_idx % 10 == 0:
            self._log_image_grid(images)
        return loss

    def _log_image_grid(self, images, max_images=10):
        """Log the first ``max_images`` images of a batch as a single-row grid.

        Does nothing when no logger is attached or when the logger's experiment
        object has no ``add_image`` method (e.g. a CSV logger), so swapping or
        disabling the logger cannot crash training.
        """
        experiment = getattr(self.logger, "experiment", None)
        if not hasattr(experiment, "add_image"):
            return

        # normalize=True rescales the grid to [0, 1]; the inputs have been
        # mean/std-normalised and would otherwise fall outside the displayable range.
        grid = torchvision.utils.make_grid(images[:max_images], nrow=max_images, normalize=True)
        # global_step must be passed explicitly: the third positional argument of
        # add_image is the step and the fourth is the wall time, so the previous
        # call ("..., grid, 0, self.global_step") logged every image at step 0.
        experiment.add_image("cifar_images", grid, global_step=self.global_step)

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for a batch and log them as ``{stage}_loss`` / ``{stage}_acc``.

        Args:
            batch: ``(images, labels)`` pair.
            stage: Metric-name prefix; nothing is logged when it is falsy.
        """
        images, labels = batch

        # Forward pass
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)

        acc = accuracy(outputs, labels, task="multiclass", num_classes=self.num_classes)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one validation batch."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one test batch."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
        


In [7]:
# Wrap the network in the Lightning module (10 classes, learning rate 1e-4).
convNet = LitConvNet(model, lr=0.0001, num_classes=10)

# Metrics and image grids are written as TensorBoard events under tb_logs/cifar_model.
logger = TensorBoardLogger("tb_logs", name="cifar_model")

# Stop training early once the logged "val_loss" no longer decreases.
early_stopping = EarlyStopping(monitor="val_loss", mode="min")

# accelerator / devices are left at the Trainer defaults.
trainer = L.Trainer(
    max_epochs=30,
    callbacks=[early_stopping],
    logger=logger,
    fast_dev_run=False,
)

# train the model
trainer.fit(convNet, train_dataloaders=train_loader, val_dataloaders=val_loader)

# test the model
trainer.test(convNet, dataloaders=test_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: tb_logs\cifar_model

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 196 K 
-------------------------------------
196 K     Trainable params
0         Non-trainable params
196 K     Total params
0.785     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.5888000130653381
        test_loss           1.1605582237243652
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 1.1605582237243652, 'test_acc': 0.5888000130653381}]