# Vision Experiments

In [1]:
import os;
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

import torchinfo
import lightning as L
from lightning.pytorch.loggers.wandb import WandbLogger
import torchmetrics

import wandb

import matplotlib.pyplot as plt
import numpy as np

import sys; sys.path.append('../..')
from vision_models import ViT, VAT, configure_optimizers
from utils.pl_tqdm_progbar import TQDMProgressBar

In [3]:
# batch_size = 4096
batch_size = 64

In [4]:
from torch.utils.data import DataLoader
from imagenet_data_utils import ImageNetKaggle

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

normalize = transforms.Normalize(mean=mean,std=std)
train_transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
val_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])


def inv_normalize(tensor, mean=mean, std=std):
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor


root = '/home/ma2393/scratch/datasets/imagenet'
train_ds = ImageNetKaggle(root, "train", train_transform)
train_dataloader = DataLoader(
            train_ds,
            batch_size=batch_size, # may need to reduce this depending on your GPU 
            num_workers=8, # may need to reduce this depending on your num of CPUs and RAM
            shuffle=True,
            drop_last=True,
            pin_memory=True
        )

val_ds = ImageNetKaggle(root, "val", val_transform)
val_dataloader = DataLoader(
            val_ds,
            batch_size=batch_size, # may need to reduce this depending on your GPU 
            num_workers=8, # may need to reduce this depending on your num of CPUs and RAM
            shuffle=False,
            drop_last=True,
            pin_memory=True
        )

n_classes = 1000

In [5]:
# correct = 0
# total = 0
# with torch.no_grad():
#     for x, y in tqdm(dataloader):
#         y_pred = model(x.cuda())
#         correct += (y_pred.argmax(axis=1) == y.cuda()).sum().item()
#         total += len(y)
# print(correct / total)

In [6]:
# x,y = next(iter(train_dataloader))

# n_samples = 5
# fig, axs = plt.subplots(ncols=n_samples, figsize=(12, 4))
# for s, ax in zip(np.random.choice(len(y), n_samples), axs):
#     img = inv_normalize(x[s])
#     ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))
#     ax.set_title(y[s].numpy())

## Config

In [7]:
print('cuda available: ', torch.cuda.is_available())
print('device count: ', torch.cuda.device_count())
print('current device name: ', torch.cuda.get_device_name(torch.cuda.current_device()))
print('Memory Usage:')
print('\tAllocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
print('\tReserved:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

cuda available:  True
device count:  1
current device name:  NVIDIA A100 80GB PCIe
Memory Usage:
	Allocated: 0.0 GB
	Reserved:    0.0 GB


In [8]:
device = 'cuda'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# optimization hyperparams
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
# max_iters = 5000
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
decay_lr = True # whether to decay the learning rate
lr_decay_iters = 5000 # make equal to max_iters usually
weight_decay = 1e-1
min_lr = 1e-4 # learning_rate / 10 usually
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
# warmup_iters = 100
gradient_accumulation_steps = 32 # 1 # accumulate gradients over this many steps. simulates larger batch size


## Define Pytorch Lightning Module

In [9]:
# for x in train_dataloader:
#     # print(x)
#     print(x[0].shape)
#     print(x[1].shape)
#     # print(len(x))
#     break

In [10]:
topks = 10

In [11]:
log_on_step = True

class LitVisionModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = torch.nn.functional.cross_entropy
        self.accuracy = lambda pred, y: torchmetrics.functional.accuracy(pred, y, task="multiclass", num_classes=n_classes, top_k=1, average='micro')

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        acc = self.accuracy(logits, y)

        self.log('train/loss', loss, prog_bar=True, logger=True, on_step=log_on_step, on_epoch=True)
        self.log('train/acc', acc, prog_bar=True, logger=True, on_step=log_on_step, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        acc = self.accuracy(logits, y)

        self.log(f"val/loss", loss, prog_bar=True, logger=True, add_dataloader_idx=False)
        self.log(f"val/acc", acc, prog_bar=True, logger=True, add_dataloader_idx=False)

        for k in range(1, topks):
            acc = torchmetrics.functional.accuracy(logits, y, task="multiclass", num_classes=n_classes, top_k=k, average='micro')
            self.log(f"val/top{k}_acc", acc, prog_bar=True, logger=True, add_dataloader_idx=False)


    def configure_optimizers(self):
        optimizer = configure_optimizers(self.model, weight_decay, learning_rate, (beta1, beta2), device_type=device)
        return optimizer

# endregion


## Create Model

In [12]:
# c, w, h = images.shape[1:]
c, w, h = (3, 224, 224)
image_shape = (c, w, h)

In [13]:
# model args
symbol_type = 'pos_relative'
d_model, n_layers, dff = 768, 12, None
sa, rca = 12, 0
patch_size = (16, 16)
n_patches = (w // patch_size[0]) * (h // patch_size[1])
activation = 'swiglu'
dropout_rate = 0.1
rca_type = 'disentangled_v2'
norm_first = True
bias = False
pool = 'cls'

run_name = f'sa={sa}; rca={rca}; d={d_model}; L={n_layers}; rca_type={rca_type}; symbol_type={symbol_type}'

In [14]:
n_patches

196

In [15]:
# define kwargs for symbol-retrieval module based on type
rca_kwargs = dict()
if symbol_type == 'sym_attn':
    symbol_retrieval_kwargs = dict(d_model=d_model, n_symbols=50, n_heads=4) # NOTE: n_heads, n_symbols fixed for now
elif symbol_type == 'pos_sym_retriever':
    symbol_retrieval_kwargs = dict(symbol_dim=d_model, max_length=n_patches+1)
elif symbol_type == 'pos_relative':
    symbol_retrieval_kwargs = dict(symbol_dim=d_model, max_rel_pos=n_patches+1)
    rca_kwargs['use_relative_positional_symbols'] = True # if using position-relative symbols, need to tell RCA module
elif rca != 0:
    raise ValueError(f'`symbol_type` {symbol_type} not valid')

# if rca=0, use TransformerLM
if rca == 0:
    model_args = dict(
        image_shape=image_shape, patch_size=patch_size, num_classes=n_classes, pool=pool,
        d_model=d_model, n_layers=n_layers, n_heads=sa, dff=dff, dropout_rate=dropout_rate,
        activation=activation, norm_first=norm_first, bias=bias)

    model = transformer_lm = ViT(**model_args).to(device)
# otherwise, use AbstractTransformerLM
else:
    model_args = dict(
        image_shape=image_shape, patch_size=patch_size, num_classes=n_classes, pool=pool,
        d_model=d_model, n_layers=n_layers, n_heads_sa=sa, n_heads_rca=rca, dff=dff, dropout_rate=dropout_rate,
        activation=activation, norm_first=norm_first, bias=bias, rca_type=rca_type,
        symbol_retrieval=symbol_type, symbol_retrieval_kwargs=symbol_retrieval_kwargs, rca_kwargs=rca_kwargs)

    model = abstracttransformer_lm = VAT(**model_args).to(device)

print(torchinfo.summary(
    model, input_size=(1, *image_shape),
    col_names=("input_size", "output_size", "num_params", "params_percent")))


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %
ViT                                      [1, 3, 224, 224]          [1, 1000]                 152,064                     0.13%
├─Sequential: 1-1                        [1, 3, 224, 224]          [1, 196, 768]             --                             --
│    └─Rearrange: 2-1                    [1, 3, 224, 224]          [1, 196, 768]             --                             --
│    └─LayerNorm: 2-2                    [1, 196, 768]             [1, 196, 768]             1,536                       0.00%
│    └─Linear: 2-3                       [1, 196, 768]             [1, 196, 768]             590,592                     0.51%
│    └─LayerNorm: 2-4                    [1, 196, 768]             [1, 196, 768]             1,536                       0.00%
├─Dropout: 1-2                           [1, 197, 768]             [1, 197, 768]             --                

In [16]:
# unoptimized_model = model
# model = torch.compile(model)
lit_model = LitVisionModel(model)

## Train Model

In [17]:
log_to_wandb = False
n_epochs = 1
max_steps = -1
log_every_n_steps = 20
eval_interval = None

In [18]:
torch.set_float32_matmul_precision('medium')

In [19]:
# lit_model = torch.compile(lit_model)

In [20]:
if log_to_wandb:
    run = wandb.init(project=wandb_project, group=group_name, name=run_name,
        config={'group': group_name, 'num_params': num_params, **model_args})

    wandb_logger = WandbLogger(experiment=run, log_model=log_model),
else:
    wandb_logger = None

callbacks = [
    # TQDMProgressBar(refresh_rate=50)
    TQDMProgressBar(),
    L.pytorch.callbacks.ModelCheckpoint(dirpath=f'out/{run_name}', every_n_train_steps=10) #every_n_epochs=1)
]

trainer_kwargs = dict(
    max_epochs=n_epochs, enable_model_summary=True, benchmark=True, enable_checkpointing=True,
    enable_progress_bar=True, callbacks=callbacks, logger=wandb_logger,
    accumulate_grad_batches=gradient_accumulation_steps, gradient_clip_val=grad_clip,
    # log_every_n_steps=log_every_n_steps, max_steps=max_steps, val_check_interval=eval_interval) # FIXME
    log_every_n_steps=log_every_n_steps, max_steps=100)#, val_check_interval=eval_interval)

trainer = L.Trainer(
    **trainer_kwargs
    )

trainer.fit(model=lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# endregion


/home/ma2393/.conda/envs/abstract_transformer/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ma2393/.conda/envs/abstract_transformer/lib/py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024-05-03 16:17:31.537963: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-03 16:17:32.962592: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has alre

num decayed parameter tensors: 88, with 114,756,096 parameters
num non-decayed parameter tensors: 78, with 115,432 parameters
using fused AdamW: True
Epoch 0:   1%|          | 124/20018 [00:26<1:09:39,  4.76it/s, v_num=5807, train/loss_step=10.70, train/acc_step=0.000] 

/home/ma2393/.conda/envs/abstract_transformer/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
trainer.validate(lit_model, val_dataloader)

2024-05-02 21:40:35.062466: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-02 21:40:36.188600: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-02 21:40:36.188649: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-02 21:40:36.214019: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-02 21:40:36.228311: I tensorflow/core/platform/cpu_feature_guar

Validation DataLoader 0: 100%|██████████| 781/781 [03:36<00:00,  3.60it/s]


[{'val/loss': 7.431337833404541,
  'val/acc': 0.0010003200732171535,
  'val/top1_acc': 0.0010003200732171535,
  'val/top2_acc': 0.0024607875384390354,
  'val/top3_acc': 0.003481114050373435,
  'val/top4_acc': 0.004061299841850996,
  'val/top5_acc': 0.005081626120954752,
  'val/top6_acc': 0.0059619080275297165,
  'val/top7_acc': 0.006862195674329996,
  'val/top8_acc': 0.007822503335773945,
  'val/top9_acc': 0.008722791448235512}]