In [1]:
# !pip install lightning vit-pytorch torchmetrics

## Import

In [2]:
import os

import torch
import torchmetrics
from lightning import Trainer, LightningModule, LightningDataModule
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.tuner import Tuner
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from vit_pytorch import SimpleViT

## Hyperparameters

In [3]:
IMAGE_SIZE = (384, 288)
CLASSES = 2
EPOCHS = 100
PATCH_SIZE = 16
BATCH_SIZE = 22
LR = 0.00025
TUNING = False
BASE_DIR = os.getcwd()
# BASE_DIR = "/content/drive/MyDrive/Colab Notebooks"
ORIGIN_DATA_DIR = BASE_DIR + "/HeadPoseMixed"
DATA_DIR = BASE_DIR + "/HeadPoseMixed_output"
ORIGIN_MODEL_PATH = BASE_DIR + "/model_simple.ckpt"
MODEL_PATH = BASE_DIR + "/model_simple_output.ckpt"
LOG_DIR = BASE_DIR + "/log_simple"

## Dataset

In [4]:
class LitDataModule(LightningDataModule):
    def __init__(self, data_dir, batch_size=BATCH_SIZE, num_workers=11):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.9, 1.0)),
            transforms.RandomApply([transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))], p=0.5),
            transforms.ToTensor(),
        ])
        self.val_transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
        ])
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage: str):
        if stage == "fit" or stage is None:
            self.train_dataset = datasets.ImageFolder(
                root=self.data_dir + "/train",
                transform=self.transform
            )
            self.val_dataset = datasets.ImageFolder(
                root=self.data_dir + "/val",
                transform=self.val_transform
            )
        if stage == "test" or stage is None:
            self.test_dataset = datasets.ImageFolder(
                root=self.data_dir + "/test",
                transform=self.val_transform
            )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

origin_data_module = LitDataModule(ORIGIN_DATA_DIR)
data_module = LitDataModule(DATA_DIR)

## Model

In [5]:
class LitModule(LightningModule):
    def __init__(self, num_classes=CLASSES, lr=LR):
        super().__init__()
        self.model = SimpleViT(
            image_size=IMAGE_SIZE,
            patch_size=PATCH_SIZE,
            channels=3,
            num_classes=num_classes,
            dim = 512,
            depth = 4,
            heads = 16,
            mlp_dim = 1024,
        )
        self.lr = lr
        self.save_hyperparameters()

        self.train_metrics = torchmetrics.MetricCollection({
            "acc": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes),
        }, prefix="train_")
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        train_loss = torch.nn.functional.cross_entropy(y_pred, y)
        self.log("train_loss", train_loss, prog_bar=True)
        self.train_metrics.update(y_pred, y)
        self.log_dict(self.train_metrics.compute(), prog_bar=True)
        return train_loss

    def on_train_epoch_end(self):
        self.train_metrics.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        val_loss = torch.nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", val_loss, prog_bar=True)
        self.val_metrics.update(y_pred, y)
        self.log_dict(self.val_metrics.compute(), prog_bar=True)
        return val_loss

    def on_validation_epoch_end(self):
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        test_loss = torch.nn.functional.cross_entropy(y_pred, y)
        self.log("test_loss", test_loss, prog_bar=True)
        self.test_metrics.update(y_pred, y)
        self.log_dict(self.test_metrics.compute(), prog_bar=True)
        return test_loss

    def on_test_epoch_end(self):
        self.test_metrics.reset()

## Trainer

In [6]:
def get_trainer():
    return Trainer(
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=4),
            ModelCheckpoint(monitor="val_loss"),
            TQDMProgressBar(leave=True),
        ],
        max_epochs=EPOCHS,
        default_root_dir=LOG_DIR
    )

## Train

### With origin data

In [7]:
model = LitModule()
trainer = get_trainer()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


#### Tuning

In [8]:
if TUNING:
    tuner = Tuner(trainer)
    tuner.scale_batch_size(
        model,
        mode="binsearch",
        max_trials=10,
        datamodule=origin_data_module,
    )
    data_module.batch_size -= 1
    print("Batch size:", data_module.batch_size)
    lr_finder = tuner.lr_find(
        model,
        datamodule=origin_data_module,
        update_attr=True
    )
    print("Learning rate:", lr_finder.suggestion())

#### Train

In [9]:
trainer.fit(model, datamodule=origin_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | SimpleViT        | 13.0 M | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
13.0 M    Trainable params
0         Non-trainable params
13.0 M    Total params
51.982    Total estimated model params size (MB)
65        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [10]:
trainer.save_checkpoint(ORIGIN_MODEL_PATH)

#### Test

In [11]:
trainer.test(model, datamodule=origin_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9604904651641846
        test_loss           0.21020059287548065
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.21020059287548065, 'test_acc': 0.9604904651641846}]

In [12]:
trainer.test(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9053906798362732
        test_loss           0.39260807633399963
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.39260807633399963, 'test_acc': 0.9053906798362732}]

### With output data

In [13]:
model = LitModule.load_from_checkpoint(ORIGIN_MODEL_PATH)
trainer = get_trainer()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


#### Train

In [14]:
trainer.fit(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | SimpleViT        | 13.0 M | train
1 | train_metrics | MetricCollection | 0      | train
2 | val_metrics   | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
13.0 M    Trainable params
0         Non-trainable params
13.0 M    Total params
51.982    Total estimated model params size (MB)
65        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [15]:
trainer.save_checkpoint(MODEL_PATH)

#### Test

In [16]:
trainer.test(model, datamodule=origin_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9562307596206665
        test_loss           0.21375273168087006
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.21375273168087006, 'test_acc': 0.9562307596206665}]

In [17]:
trainer.test(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9520098567008972
        test_loss           0.23012612760066986
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.23012612760066986, 'test_acc': 0.9520098567008972}]