In [1]:
import os
from multiprocessing import cpu_count

import lightning as L
import matplotlib
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import seaborn as sns
import torch
import torchvision
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

from task import train_model
from src.utils import get_datasets

plt.set_cmap("cividis")
%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()


# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")

# Setting the seed
L.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Global seed set to 42


<Figure size 640x480 with 0 Axes>

In [3]:
# Number of participating processes, read from the WORLD_SIZE environment
# variable; falls back to 1 when the variable is not set.
world_size = int(os.environ.get("WORLD_SIZE", 1))

In [4]:
# Display the resolved world size (1 unless WORLD_SIZE is set in the environment).
world_size

1

In [None]:
# Inventory the compute resources visible to this process.
num_cpus = cpu_count()
num_gpus = torch.cuda.device_count()
# Always build a `torch.device`, so `device` has the same type on CPU-only and
# GPU machines (previously it was a plain str in the CPU case).
device = torch.device('cuda' if num_gpus else 'cpu')

print(f'Device: {device}')
print(f'CPUs: {num_cpus}')
print(f'GPUs: {num_gpus}')

In [None]:
# Size the data-loading worker pool by the CPU count.
# A former `device == 'gpu'` check was removed: `device` is only ever 'cuda' or
# 'cpu', so that branch could never run and `num_cpus` was always the result.
# NOTE(review): `num_workers` is forwarded through `loader_kwargs`; presumably it
# becomes the DataLoader worker count — confirm in `task.train_model`.
num_workers = num_cpus

In [None]:
# Display the worker count that will be placed in loader_kwargs.
num_workers

In [None]:
# Model hyperparameters; passed as the first argument to `train_model`.
model_kwargs = dict(
    embed_dim=256,
    hidden_dim=512,
    num_heads=8,
    num_layers=6,
    patch_size=4,
    num_channels=3,
    num_patches=64,
    num_classes=10,
    dropout=0.2,
)

# Trainer settings; passed as the second argument to `train_model`
# (presumably forwarded to the Lightning Trainer — confirm in `task`).
trainer_kwargs = dict(
    default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
    accelerator="auto",
    devices=1,
    max_epochs=180,
    callbacks=[
        # Checkpoint on the highest "val_acc", storing weights only.
        ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
        # Log the learning rate once per epoch.
        LearningRateMonitor("epoch"),
    ],
)

# Data-loading settings; passed as the third argument to `train_model`.
loader_kwargs = dict(
    dataset_path=DATASET_PATH,
    batch_size=128,
    num_workers=num_workers,
)

# Learning rate for the run.
lr = 3e-4

In [None]:
# get_datasets must return exactly three items; only the second is kept, as val_set.
# Presumably the order is (train, val, test) — TODO confirm in src.utils.
_, val_set, _ = get_datasets(DATASET_PATH)

In [None]:
# Visualize some examples
NUM_IMAGES = 4
# Take the image (index 0 of each sample) from the first NUM_IMAGES validation samples.
CIFAR_images = torch.stack([val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
# Tie the grid width to NUM_IMAGES so the preview stays a single row when the
# constant is changed (it was hard-coded to 4 before).
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=NUM_IMAGES, normalize=True, pad_value=0.9)
# Move the first axis last for plt.imshow (presumably CHW -> HWC — confirm tensor layout).
img_grid = img_grid.permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title("Image examples of the CIFAR10 dataset")
ax.imshow(img_grid)
ax.axis("off")
plt.show()
# Release the figure explicitly once it has been rendered.
plt.close(fig)

In [None]:
# Launch training with the configuration assembled in the cells above. Pass the
# `lr` variable instead of repeating the literal, so the learning rate has a
# single source of truth.
train_model(model_kwargs, trainer_kwargs, loader_kwargs, lr=lr)