# Vision Experiments

In [1]:
import os;
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
import torch

import torchinfo
import lightning as L
from lightning.pytorch.loggers.wandb import WandbLogger
import torchmetrics

import wandb

import matplotlib.pyplot as plt
import numpy as np

from relational_games_data_utils import RelationalGamesDataset

import sys; sys.path.append('../..')
from vision_models import ViT, VAT, configure_optimizers
from utils.pl_tqdm_progbar import TQDMProgressBar

In [3]:
# Where the data lives and which task / batch size to use.
data_path = '../../data/relational_games'
task = '1task_match_patt'
batch_size = 512

# Options common to every DataLoader built in this cell.
loader_kwargs = dict(batch_size=batch_size, num_workers=8, pin_memory=True)

# Training data: a single split, reshuffled by the loader.
train_split = 'pentos'
train_ds = RelationalGamesDataset(data_path, task, train_split)
train_dataloader = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)

# Validation data: one dataset and one (unshuffled) loader per split.
# `val_dls` follows the order of `val_splits`; `val_ds_dict` is keyed by split name.
val_splits = ('hexos', 'stripes')
val_ds_dict = {}
val_dls = []
for val_split in val_splits:
    ds = RelationalGamesDataset(data_path, task, val_split)
    val_ds_dict[val_split] = ds
    dl = torch.utils.data.DataLoader(ds, shuffle=False, **loader_kwargs)
    val_dls.append(dl)

In [4]:
# Sanity check: pull one batch and show the shape / dtype of inputs and labels.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    x, y = first_batch
    for batch_property in (x.shape, x.dtype, y.dtype, y.shape):
        print(batch_property)

torch.Size([512, 3, 36, 36])
torch.float32
torch.int64
torch.Size([512])


## Config

In [5]:
# Report the CUDA environment. The device-specific queries are only made when
# CUDA is available, because asking for the current device raises otherwise.
print('cuda available: ', torch.cuda.is_available())
print('device count: ', torch.cuda.device_count())
if torch.cuda.is_available():
    print('current device name: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    print('Memory Usage:')
    # Bytes -> GiB for device 0, rounded to one decimal place.
    print('\tAllocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('\tReserved:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

cuda available:  True
device count:  1
current device name:  NVIDIA H100 80GB HBM3
Memory Usage:
	Allocated: 0.0 GB
	Reserved:    0.0 GB


In [6]:
# Use the GPU when one is present; otherwise fall back to the CPU so that the
# later `.to(device)` call does not fail on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# Reduced-precision dtype: bfloat16 when the GPU supports it, float16 otherwise.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# optimization hyperparams
learning_rate = 5e-4 # with baby networks can afford to go a bit higher #NOTE: this was useful for match_pattern
# max_iters = 5000
grad_clip = 0.0 # 1.0 # clip gradients at this value, or disable if == 0.0
# decay_lr = True # whether to decay the learning rate
# lr_decay_iters = 5000 # make equal to max_iters usually
weight_decay = None # 1e-1 # NOTE: maybe need this?
# min_lr = 1e-4 # learning_rate / 10 usually
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
# warmup_iters = 100
gradient_accumulation_steps = 1 # 32 # 1 # accumulate gradients over this many steps. simulates larger batch size


## Define PyTorch Lightning Module

In [7]:
# for x in train_dataloader:
#     # print(x)
#     print(x[0].shape)
#     print(x[1].shape)
#     # print(len(x))
#     break

In [8]:
# Whether training metrics are logged per step (they are always logged per epoch).
log_on_step = True

class LitVisionModel(L.LightningModule):
    """Lightning wrapper that trains and validates an image classifier.

    Args:
        model: module called as ``model(x)`` on the input batch; its output is
            used as logits for cross-entropy and for the accuracy metric.
        num_classes: number of classes passed to the accuracy metric. When
            ``None`` (default) the notebook-level ``n_classes`` is looked up
            each time the metric is computed, as in the original code.
        val_split_names: labels for the validation dataloaders, indexed by
            ``dataloader_idx``. When ``None`` (default) the notebook-level
            ``val_splits`` is used, as in the original code.
    """

    def __init__(self, model, num_classes=None, val_split_names=None):
        super().__init__()
        self.model = model
        self.criterion = torch.nn.functional.cross_entropy
        self.num_classes = num_classes
        self.val_split_names = None if val_split_names is None else tuple(val_split_names)

    def accuracy(self, pred, y):
        """Return micro-averaged top-1 multiclass accuracy of ``pred`` against ``y``."""
        # Fall back to the notebook global only when no explicit value was given.
        num_classes = n_classes if self.num_classes is None else self.num_classes
        return torchmetrics.functional.accuracy(
            pred, y, task="multiclass", num_classes=num_classes, top_k=1, average='micro')

    def _loss_and_accuracy(self, batch):
        """Run the model on an ``(x, y)`` batch and return ``(loss, accuracy)``."""
        x, y = batch
        logits = self.model(x)
        return self.criterion(logits, y), self.accuracy(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log loss and accuracy, and return the loss."""
        loss, acc = self._loss_and_accuracy(batch)

        self.log('train/loss', loss, prog_bar=True, logger=True, on_step=log_on_step, on_epoch=True)
        self.log('train/acc', acc, prog_bar=True, logger=True, on_step=log_on_step, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log loss and accuracy for one validation batch under its split's name."""
        loss, acc = self._loss_and_accuracy(batch)

        split_names = val_splits if self.val_split_names is None else self.val_split_names
        split = split_names[dataloader_idx]
        # add_dataloader_idx=False: the split name already makes each key unique.
        self.log(f"val/loss_{split}", loss, prog_bar=True, logger=True, add_dataloader_idx=False)
        self.log(f"val/acc_{split}", acc, prog_bar=True, logger=True, add_dataloader_idx=False)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters.

        Uses the notebook-level ``learning_rate``, ``beta1`` and ``beta2``;
        ``weight_decay`` is not applied here.
        """
        # optimizer = configure_optimizers(self.model, weight_decay, learning_rate, (beta1, beta2), device_type=device)
        # optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate, betas=(beta1, beta2))
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate, betas=(beta1, beta2))
        return optimizer

# endregion


## Create Model

In [9]:
# Input geometry is hard-coded here rather than read from a batch
# (i.e. instead of `images.shape[1:]`).
image_shape = (3, 36, 36)
c, w, h = image_shape

# Number of output classes of the classifier.
n_classes = 2

In [10]:
# # THIS WORKED FOR MATCH-PATTERN
# # model args
# symbol_type = 'positional_symbols'
# d_model, n_layers, dff = 128, 1, 256
# sa, rca = 4, 4
# patch_size = (12, 12)
# n_patches = (w // patch_size[0]) * (h // patch_size[1])
# activation = 'gelu'
# dropout_rate = 0.1
# rca_type = 'relational_attention'
# norm_first = False
# bias = False
# pool = 'mean'

# run_name = f'sa={sa}; rca={rca}; d={d_model}; L={n_layers}; rca_type={rca_type}; symbol_type={symbol_type}'

In [11]:
# Model hyperparameters for the current run.
# Previously tried here: (d_model, n_layers, dff) = (32, 2, 64);
# (sa, rca) = (4, 4), (2, 2), (8, 0); rca_type = 'relational_attention'.
symbol_type = 'positional_symbols'

d_model = 128
n_layers = 2
dff = 256

# Head counts: `sa` self-attention heads, `rca` RCA heads (rca == 0 selects the ViT model).
sa = 4
rca = 0
rca_type = 'rca'

# Patching: the image is cut into non-overlapping patches of this size.
patch_size = (12, 12)
n_patches = (w // patch_size[0]) * (h // patch_size[1])

activation = 'swiglu'
dropout_rate = 0.1
norm_first = True
norm_type = 'layernorm'
bias = False
pool = 'mean'

run_name = f'sa={sa}; rca={rca}; d={d_model}; L={n_layers}; rca_type={rca_type}; symbol_type={symbol_type}'

In [12]:
# Display the number of patches per image computed from the image and patch sizes.
n_patches

9

In [13]:
# Keyword arguments for the symbol-retrieval module, keyed by `symbol_type`.
rca_kwargs = {}
symbol_kwargs_by_type = {
    # NOTE: n_heads, n_symbols fixed for now
    'symbolic_attention': dict(d_model=d_model, n_symbols=50, n_heads=4),
    'positional_symbols': dict(symbol_dim=d_model, max_length=n_patches+1),
    'position_relative': dict(symbol_dim=d_model, max_rel_pos=n_patches+1),
}
if symbol_type in symbol_kwargs_by_type:
    symbol_retrieval_kwargs = symbol_kwargs_by_type[symbol_type]
    if symbol_type == 'position_relative':
        # position-relative symbols also have to be announced to the RCA module
        rca_kwargs['use_relative_positional_symbols'] = True
elif rca != 0:
    # an unknown symbol type is only an error when the model has RCA heads
    raise ValueError(f'`symbol_type` {symbol_type} not valid')

if rca != 0:
    # at least one RCA head requested: build the VAT model
    model_args = {
        'image_shape': image_shape, 'patch_size': patch_size, 'num_classes': n_classes, 'pool': pool,
        'd_model': d_model, 'n_layers': n_layers, 'n_heads_sa': sa, 'n_heads_rca': rca, 'dff': dff,
        'dropout_rate': dropout_rate, 'norm_type': norm_type,
        'activation': activation, 'norm_first': norm_first, 'bias': bias, 'rca_type': rca_type,
        'symbol_retrieval': symbol_type, 'symbol_retrieval_kwargs': symbol_retrieval_kwargs,
        'rca_kwargs': rca_kwargs,
    }
    abstracttransformer_lm = VAT(**model_args).to(device)
    model = abstracttransformer_lm
else:
    # no RCA heads: build the plain ViT model
    model_args = {
        'image_shape': image_shape, 'patch_size': patch_size, 'num_classes': n_classes, 'pool': pool,
        'd_model': d_model, 'n_layers': n_layers, 'n_heads': sa, 'dff': dff,
        'dropout_rate': dropout_rate, 'norm_type': norm_type,
        'activation': activation, 'norm_first': norm_first, 'bias': bias,
    }
    transformer_lm = ViT(**model_args).to(device)
    model = transformer_lm

# Per-layer summary for a single-image batch.
summary_columns = ("input_size", "output_size", "num_params", "params_percent")
print(torchinfo.summary(model, input_size=(1, *image_shape), col_names=summary_columns))


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Param %
ViT                                      [1, 3, 36, 36]            [1, 2]                    1,408                       0.36%
├─Sequential: 1-1                        [1, 3, 36, 36]            [1, 9, 128]               --                             --
│    └─Rearrange: 2-1                    [1, 3, 36, 36]            [1, 9, 432]               --                             --
│    └─LayerNorm: 2-2                    [1, 9, 432]               [1, 9, 432]               864                         0.22%
│    └─Linear: 2-3                       [1, 9, 432]               [1, 9, 128]               55,424                     14.32%
│    └─LayerNorm: 2-4                    [1, 9, 128]               [1, 9, 128]               256                         0.07%
├─Dropout: 1-2                           [1, 10, 128]              [1, 10, 128]              --                

In [14]:
# Keep a handle on the eager (uncompiled) module. If this cell is re-run, `model`
# is already the wrapper returned by torch.compile, which exposes the original
# module as `_orig_mod`; unwrap it so the model is never compiled twice and
# `unoptimized_model` always refers to the eager module.
unoptimized_model = getattr(model, '_orig_mod', model)
model = torch.compile(unoptimized_model)
lit_model = LitVisionModel(model)

## Train Model

In [15]:
# Logging: set to True to send the run to Weights & Biases.
log_to_wandb = False
log_every_n_steps = 20

# Training length: stop after `n_epochs`; max_steps = -1 puts no cap on the step count.
n_epochs, max_steps = 25, -1

# Validation interval (only referenced by a commented-out Trainer argument).
eval_interval = None

In [16]:
# Select the 'medium' float32 matmul precision mode (a global PyTorch setting);
# see the PyTorch docs for the speed/precision trade-off of each mode.
torch.set_float32_matmul_precision('medium')

In [17]:
if log_to_wandb:
    # NOTE(review): `wandb_project`, `group_name` and `log_model` are not defined in
    # any cell visible in this notebook; set them before enabling `log_to_wandb`.
    # Total parameter count, recorded in the run config.
    num_params = sum(p.numel() for p in model.parameters())
    run = wandb.init(project=wandb_project, group=group_name, name=run_name,
        config={'group': group_name, 'num_params': num_params, **model_args})

    # No trailing comma after the call: it previously turned `wandb_logger`
    # into a 1-tuple instead of a WandbLogger instance.
    wandb_logger = WandbLogger(experiment=run, log_model=log_model)
else:
    wandb_logger = None

callbacks = [
    # TQDMProgressBar(refresh_rate=50)
    TQDMProgressBar(),
    # L.pytorch.callbacks.ModelCheckpoint(dirpath=f'out/{run_name}', every_n_train_steps=10) #every_n_epochs=1)
]

# Trainer settings come from the config cells: epochs / steps, gradient
# accumulation and clipping, and the logging interval.
trainer_kwargs = dict(
    max_epochs=n_epochs, enable_model_summary=True, benchmark=True, enable_checkpointing=True,
    enable_progress_bar=True, callbacks=callbacks, logger=wandb_logger,
    accumulate_grad_batches=gradient_accumulation_steps, gradient_clip_val=grad_clip,
    # log_every_n_steps=log_every_n_steps, max_steps=max_steps, val_check_interval=eval_interval) # FIXME
    log_every_n_steps=log_every_n_steps, max_steps=max_steps)#, val_check_interval=eval_interval)

trainer = L.Trainer(
    **trainer_kwargs
    )

# Train on the training loader and validate on every loader in `val_dls`.
trainer.fit(model=lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dls)
# endregion


/home/ma2393/.conda/envs/abstract_transformer/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ma2393/.conda/envs/abstract_transformer/lib/py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024-05-04 21:12:08.009134: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-04 21:12:08.036637: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has alre

Epoch 24: 100%|██████████| 489/489 [00:03<00:00, 127.29it/s, v_num=5898, train/loss_step=0.693, train/acc_step=0.521, val/loss_hexos=0.693, val/acc_hexos=0.497, val/loss_stripes=0.693, val/acc_stripes=0.496, train/loss_epoch=0.693, train/acc_epoch=0.506]

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 489/489 [00:03<00:00, 126.67it/s, v_num=5898, train/loss_step=0.693, train/acc_step=0.521, val/loss_hexos=0.693, val/acc_hexos=0.497, val/loss_stripes=0.693, val/acc_stripes=0.496, train/loss_epoch=0.693, train/acc_epoch=0.506]


In [None]:
# L=1, sa=0, rca=4, layernorm, d_model, n_layers, dff = 64, 1, 128, mean pooling
# Epoch 24: 100%|██████████| 977/977 [00:06<00:00, 154.98it/s, v_num=5898, train/loss_step=0.432, train/acc_step=0.799, val/loss_hexos=0.553, val/acc_hexos=0.807, val/loss_stripes=1.190, val/acc_stripes=0.569, train/loss_epoch=0.437, train/acc_epoch=0.805]

In [None]:
# L=1, sa=4, rca=0, layernorm, d_model, n_layers, dff = 64, 1, 128, mean pooling
# Epoch 24: 100%|██████████| 977/977 [00:05<00:00, 165.00it/s, v_num=5898, train/loss_step=0.509, train/acc_step=0.771, val/loss_hexos=0.648, val/acc_hexos=0.713, val/loss_stripes=0.958, val/acc_stripes=0.582, train/loss_epoch=0.565, train/acc_epoch=0.714]

In [None]:
# L=1, sa=2, rca=2, layernorm, d_model, n_layers, dff = 64, 1, 128, mean pooling
# Epoch 24: 100%|██████████| 977/977 [00:06<00:00, 145.32it/s, v_num=5898, train/loss_step=0.695, train/acc_step=0.465, val/loss_hexos=0.693, val/acc_hexos=0.502, val/loss_stripes=0.693, val/acc_stripes=0.497, train/loss_epoch=0.693, train/acc_epoch=0.508]

In [None]:
# L=1, sa=2, rca=2, layernorm, d_model, n_layers, dff = 64, 1, 128, mean pooling; sym_attn
# Epoch 24: 100%|██████████| 977/977 [00:06<00:00, 141.41it/s, v_num=5898, train/loss_step=0.548, train/acc_step=0.729, val/loss_hexos=0.579, val/acc_hexos=0.728, val/loss_stripes=0.757, val/acc_stripes=0.575, train/loss_epoch=0.568, train/acc_epoch=0.707]

In [None]:
# L=2, sa=2, rca=2, layernorm, d_model, n_layers, dff = 64, 2, 128, mean pooling; sym_attn; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 101.71it/s, v_num=5898, train/loss_step=0.255, train/acc_step=0.910, val/loss_hexos=0.484, val/acc_hexos=0.866, val/loss_stripes=1.560, val/acc_stripes=0.643, train/loss_epoch=0.192, train/acc_epoch=0.931]

In [None]:
# L=2, sa=2, rca=2, layernorm, d_model, n_layers, dff = 64, 2, 128, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 103.85it/s, v_num=5898, train/loss_step=0.188, train/acc_step=0.931, val/loss_hexos=0.325, val/acc_hexos=0.903, val/loss_stripes=1.050, val/acc_stripes=0.762, train/loss_epoch=0.191, train/acc_epoch=0.931]

In [None]:
# L=2, sa=0, rca=4, layernorm, d_model, n_layers, dff = 64, 2, 128, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 110.84it/s, v_num=5898, train/loss_step=0.317, train/acc_step=0.896, val/loss_hexos=0.511, val/acc_hexos=0.851, val/loss_stripes=1.280, val/acc_stripes=0.691, train/loss_epoch=0.238, train/acc_epoch=0.914]

In [None]:
# L=2, sa=2, rca=2, layernorm, d_model, n_layers, activation=gelu; dff = 64, 2, 128, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 106.02it/s, v_num=5898, train/loss_step=0.459, train/acc_step=0.785, val/loss_hexos=0.538, val/acc_hexos=0.766, val/loss_stripes=0.800, val/acc_stripes=0.688, train/loss_epoch=0.469, train/acc_epoch=0.775]

In [None]:
# L=2, sa=2, rca=2, layernorm, d_model, n_layers, activation=swiglu; dff = 64, 2, 64*4, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 106.59it/s, v_num=5898, train/loss_step=0.168, train/acc_step=0.917, val/loss_hexos=0.665, val/acc_hexos=0.841, val/loss_stripes=2.550, val/acc_stripes=0.546, train/loss_epoch=0.201, train/acc_epoch=0.928]
# hexos better but stripes worse

In [None]:
# L=2, sa=0, rca=4, layernorm, d_model, n_layers, activation=swiglu; dff = 64, 2, 64*4, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 115.16it/s, v_num=5898, train/loss_step=0.185, train/acc_step=0.917, val/loss_hexos=0.539, val/acc_hexos=0.861, val/loss_stripes=1.650, val/acc_stripes=0.642, train/loss_epoch=0.186, train/acc_epoch=0.933]

In [None]:
# L=2, sa=0, rca=4, layernorm, d_model, n_layers, activation=swiglu; dff = 32, 2, 32*4, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 115.21it/s, v_num=5898, train/loss_step=0.377, train/acc_step=0.826, val/loss_hexos=0.527, val/acc_hexos=0.820, val/loss_stripes=0.858, val/acc_stripes=0.712, train/loss_epoch=0.410, train/acc_epoch=0.820]

In [None]:
# L=2, sa=0, rca=4, layernorm, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 114.35it/s, v_num=5898, train/loss_step=0.296, train/acc_step=0.903, val/loss_hexos=0.543, val/acc_hexos=0.830, val/loss_stripes=1.220, val/acc_stripes=0.627, train/loss_epoch=0.272, train/acc_epoch=0.898]

In [None]:
# L=2, sa=2, rca=2, layernorm, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:05<00:00, 93.00it/s, v_num=5898, train/loss_step=0.700, train/acc_step=0.458, val/loss_hexos=0.691, val/acc_hexos=0.527, val/loss_stripes=0.693, val/acc_stripes=0.510, train/loss_epoch=0.690, train/acc_epoch=0.527]

In [None]:
# L=1, sa=2, rca=2, layernorm, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:05<00:00, 90.42it/s, v_num=5898, train/loss_step=0.471, train/acc_step=0.778, val/loss_hexos=0.538, val/acc_hexos=0.757, val/loss_stripes=0.744, val/acc_stripes=0.686, train/loss_epoch=0.474, train/acc_epoch=0.772]

In [None]:
# L=1, sa=4, rca=0, layernorm, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:03<00:00, 125.21it/s, v_num=5898, train/loss_step=0.493, train/acc_step=0.778, val/loss_hexos=0.513, val/acc_hexos=0.789, val/loss_stripes=0.661, val/acc_stripes=0.701, train/loss_epoch=0.551, train/acc_epoch=0.736]

In [None]:
# L=1, sa=0, rca=4, layernorm, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:05<00:00, 97.05it/s, v_num=5898, train/loss_step=0.284, train/acc_step=0.917, val/loss_hexos=0.597, val/acc_hexos=0.826, val/loss_stripes=1.460, val/acc_stripes=0.632, train/loss_epoch=0.292, train/acc_epoch=0.890]

In [None]:
# L=1, sa=0, rca=4, norm_type=none, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 115.33it/s, v_num=5898, train/loss_step=0.354, train/acc_step=0.833, val/loss_hexos=0.346, val/acc_hexos=0.857, val/loss_stripes=0.748, val/acc_stripes=0.680, train/loss_epoch=0.327, train/acc_epoch=0.872]

In [None]:
# L=1, sa=0, rca=4, norm_type=none, d_model, n_layers, activation=swiglu; dff = 32, 2, 64, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 112.48it/s, v_num=5898, train/loss_step=0.549, train/acc_step=0.750, val/loss_hexos=0.531, val/acc_hexos=0.761, val/loss_stripes=0.594, val/acc_stripes=0.738, train/loss_epoch=0.482, train/acc_epoch=0.769]

In [None]:
# L=1, sa=0, rca=4, norm_type=none, d_model, n_layers, activation=swiglu; dff = 128, 2, 256, mean pooling; pos_sym_retriever; batch_size = 512 (from 256)
# Epoch 24: 100%|██████████| 489/489 [00:04<00:00, 119.37it/s, v_num=5898, train/loss_step=0.0976, train/acc_step=0.965, val/loss_hexos=0.213, val/acc_hexos=0.941, val/loss_stripes=0.751, val/acc_stripes=0.803, train/loss_epoch=0.171, train/acc_epoch=0.941]