In [1]:
import numpy as np
import torch
import wandb
import torchinfo
from contextlib import nullcontext
from  tqdm import tqdm, trange

import os
import sys; sys.path += ['../', '../..']
from train_utils import train_model
from seq2seq_models import Seq2SeqAbstractTransformer, Seq2SeqTransformer
from og_seq2seq_models import Seq2SeqAbstractorArchb

## Config

In [2]:
# I/O
eval_only = False # if True, script exits right after the first eval

# system
# device = 'cpu'
device = 'cuda'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' 
dtype = 'float32'
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' 
compile = True

# evaluation and output
out_dir = '../out/object_sorting'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

# wandb logging
wandb_log = False
wandb_project = 'abstract_transformer--object_sorting'

# optimization hyperparams
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
decay_lr = True # whether to decay the learning rate
lr_decay_iters = 5000 # make equal to max_iters usually
weight_decay = 1e-1
min_lr = 1e-4 # learning_rate / 10 usually
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
warmup_iters = 100
gradient_accumulation_steps = 1 # accumulate gradients over this many steps. simulates larger batch size

# batch size and block size
# batch_size = 64
# block_size = 256

# DDP (distributed data parallel) training
ddp = False
master_process = True

# TODO: set up DDP for future experiments

In [3]:
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Load Data

In [4]:
data_path = 'object_sorting_datasets/task1_object_sort_dataset.npy'
data = np.load(data_path, allow_pickle=True).item()

objects, seqs, sorted_seqs, object_seqs, target, labels, start_token = tuple(
    data[key] for key in ['objects', 'seqs', 'sorted_seqs', 'object_seqs', 'target', 'labels', 'start_token'])

# convert to torch tensors
object_seqs = torch.tensor(object_seqs, dtype=ptdtype, device=device)
target = torch.tensor(target, dtype=torch.long, device=device)
labels = torch.tensor(labels, dtype=torch.long, device=device)

In [5]:
def train_val_test_split(*arrays, val_size=0.1, test_size=0.2):
    n = len(arrays[0])
    indices = np.random.permutation(n)
    val_start = int(n * (1 - val_size - test_size))
    test_start = int(n * (1 - test_size))
    train_indices = indices[:val_start]
    val_indices = indices[val_start:test_start]
    test_indices = indices[test_start:]
    return tuple(tuple(array[idx] for idx in (train_indices, val_indices, test_indices)) for array in arrays)

In [6]:
(object_seqs_train, object_seqs_val, object_seqs_test), (target_train, target_val, target_test), (labels_train, labels_val, labels_test) = train_val_test_split(
    object_seqs, target, labels, val_size=0.1, test_size=0.2)

In [7]:
print(f'training shapes: {object_seqs_train.shape}, {target_train.shape}, {labels_train.shape}')
print(f'validation shapes: {object_seqs_val.shape}, {target_val.shape}, {labels_val.shape}')
print(f'test shapes: {object_seqs_test.shape}, {target_test.shape}, {labels_test.shape}')

training shapes: torch.Size([70000, 10, 8]), torch.Size([70000, 10]), torch.Size([70000, 10])
validation shapes: torch.Size([10000, 10, 8]), torch.Size([10000, 10]), torch.Size([10000, 10])
test shapes: torch.Size([20000, 10, 8]), torch.Size([20000, 10]), torch.Size([20000, 10])


In [8]:
train_size = 1500
sample_idx = np.random.choice(object_seqs_train.shape[0], train_size)

train_ds = torch.utils.data.TensorDataset(object_seqs_train[sample_idx], target_train[sample_idx], labels_train[sample_idx])
val_ds = torch.utils.data.TensorDataset(object_seqs_val, target_val, labels_val)
test_ds = torch.utils.data.TensorDataset(object_seqs_test, target_test, labels_test)

batch_size = 128 # 512
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, drop_last=True)

In [9]:
def get_train_dl(train_size, batch_size=batch_size):
    sample_idx = np.random.choice(object_seqs_train.shape[0], train_size)
    train_ds = torch.utils.data.TensorDataset(object_seqs_train[sample_idx], target_train[sample_idx], labels_train[sample_idx])
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size)

    return train_dl

In [10]:
import lightning as L

# TODO: add to module
# TODO: add features; e.g., logging etc

class LitSeq2SeqModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y, z = batch
        with ctx:
            logits, loss = self.model(x, y, z)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, z = batch
        with ctx:
            logits, loss = self.model(x, y, z)
        self.log("val_loss", loss)
        tf_acc = torch.mean((torch.argmax(logits, dim=-1) == z).float())

    def test_step(self, batch, batch_idx):
        x, y, z = batch

        n, seqs_length = y.shape
        output = torch.zeros(size=(n, (seqs_length+1)), dtype=torch.int, device=device)
        output[:,0] = start_token

        for i in range(seqs_length):
            with ctx:
                predictions, _ = self.model(x, output[:, :-1], z)
            predictions = predictions[:, i, :]
            predicted_id = torch.argmax(predictions, axis=-1)
            output[:,i+1] = predicted_id

        elementwise_acc = torch.mean((output[:,1:] == z).float()).item()
        # acc_per_position = [torch.mean((output[:, i+1] == labels_test[:, i]).float()).item() for i in range(seqs_length)]
        seq_acc = torch.mean((torch.all(output[:,1:]==z, axis=1)).float()).item()

        with ctx:
            tf_pred, loss = self.model(x, y, z)
            tf_pred = torch.argmax(tf_pred, axis=-1)
        teacher_forcing_acc = torch.mean((z==tf_pred).float()).item()

        self.log("test_loss", loss)
        self.log("teacher_forcing_acc", teacher_forcing_acc)
        self.log("elementwise_acc", elementwise_acc)
        self.log("seq_acc", seq_acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


## Training Setup

In [11]:
# def get_lr(it):
#     # 1) linear warmup for warmup_iters steps
#     if it < warmup_iters:
#         return learning_rate * it / warmup_iters
#     # 2) if it > lr_decay_iters, return min learning rate
#     if it > lr_decay_iters:
#         return min_lr
#     # 3) in between, use cosine decay down to min learning rate
#     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
#     assert 0 <= decay_ratio <= 1
#     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
#     return min_lr + coeff * (learning_rate - min_lr)

def get_lr(it):
    return 0.001
@torch.no_grad()
def eval_model(model, ctx=None):

    ctx = nullcontext() if ctx is None else ctx
    out = {}
    model.eval()
    for split in ['train', 'val']:
        dl = train_dl if split == 'train' else val_dl
        max_batches = min(eval_iters, len(dl)) if eval_iters is not None else len(dl)
        losses = torch.zeros(max_batches)
        tfaccs = torch.zeros(max_batches)
        for k, batch in enumerate(dl):
            source, target, label = batch
            if eval_iters is not None and k >= max_batches:
                break
            with ctx:
                logits, loss = model(source, target, label)
            losses[k] = loss.item()
            tfaccs[k] = torch.mean((torch.argmax(logits, dim=-1) == label).float())

        out[f'{split}/loss'] = losses.mean() # FIXME loss is averaged over batch. batch sizes may be unnequal?
        out[f'{split}/tfacc'] = tfaccs.mean()
    model.train()
    return out

@torch.no_grad()
def evaluate_seq2seq_model(model, source_test, target_test, labels_test, start_token, print_=False, ctx=ctx):

    model.eval()

    n, seqs_length = target_test.shape
    output = torch.zeros(size=(n, (seqs_length+1)), dtype=torch.int, device=device)
    output[:,0] = start_token

    for i in range(seqs_length):
        with ctx:
            predictions, _ = model(source_test, output[:, :-1], labels_test)
        predictions = predictions[:, i, :]
        predicted_id = torch.argmax(predictions, axis=-1)
        output[:,i+1] = predicted_id

    elementwise_acc = torch.mean((output[:,1:] == labels_test).float()).item()
    acc_per_position = [torch.mean((output[:, i+1] == labels_test[:, i]).float()).item() for i in range(seqs_length)]
    seq_acc = torch.mean((torch.all(output[:,1:]==labels_test, axis=1)).float()).item()

    with ctx:
        tf_pred = model(source_test, target_test, labels_test)[0]
        tf_pred = torch.argmax(tf_pred, axis=-1)
    teacher_forcing_acc = torch.mean((labels_test==tf_pred).float()).item()

    if print_:
        print('element-wise accuracy: %.2f%%' % (100*elementwise_acc))
        print('full sequence accuracy: %.2f%%' % (100*seq_acc))
        print('teacher-forcing accuracy:  %.2f%%' % (100*teacher_forcing_acc))


    return_dict = {
        'elementwise_accuracy': elementwise_acc, 'full_sequence_accuracy': seq_acc,
        'teacher_forcing_accuracy': teacher_forcing_acc, 'acc_by_position': acc_per_position
        }

    return return_dict


# Modeling

In [23]:
# TODO: implement attn with relative positional embedding or RoPE
# TODO: add these to module

## Seq2Seq Transforemr

In [12]:
model_args = dict(
    input_spec=dict(type='vector', dim=8), output_spec=dict(type='token', vocab_size=10+1),
    d_model=64, out_dim=10, n_layers_enc=2, n_layers_dec=2,
    encoder_kwargs=dict(n_heads=2, dff=128, activation='relu', norm_first=True, dropout_rate=0.1, causal=False),
    decoder_kwargs=dict(n_heads=2, dff=128, activation='relu', norm_first=True, dropout_rate=0.1, causal=True),
    in_block_size=10, out_block_size=10)
seq2seqtransformer = Seq2SeqTransformer(**model_args)
torchinfo.summary(seq2seqtransformer, row_settings=["depth", "var_names"], col_names=["num_params", "params_percent", "trainable"], depth=3, col_width=20)

Layer (type (var_name):depth-idx)                                           Param #              Param %              Trainable
Seq2SeqTransformer (Seq2SeqTransformer)                                     --                        --              True
├─ModuleDict (layers): 1-1                                                  --                        --              True
│    └─Linear (source_embedder): 2-1                                        576                    0.34%              True
│    └─Embedding (target_embedder): 2-2                                     704                    0.42%              True
│    └─SinusoidalPositionalEncoding (source_pos_embedder): 2-3              --                        --              --
│    │    └─Dropout (dropout): 3-1                                          --                        --              --
│    └─SinusoidalPositionalEncoding (target_pos_embedder): 2-4              --                        --              --
│    │    └─Dropo

In [18]:
train_kwargs = dict(
    model=model, train_dl=train_dl, eval_model=eval_model, n_epochs=200,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr,
    compile=True, grad_clip=0,
    eval_main_metric='val/loss',
    always_save_checkpoint=always_save_checkpoint, ckpt_dict=dict(model_args=model_args), out_dir=out_dir,
    wandb_log=False, wandb_init_kwargs=dict(project=wandb_project, name='Transformer'), track_mfu=True,
    ddp=False, device_type='cuda')

NameError: name 'model' is not defined

In [None]:
train_dl.batch_size

512

In [None]:
train_model(**train_kwargs)

compiling model... done compiling.
starting training loop...


KeyboardInterrupt: 

In [None]:
evaluate_seq2seq_model(model, source_test, target_test, labels_test, start_token, print_=True, ctx=ctx)

element-wise accuracy: 99.84%
full sequence accuracy: 99.22%
teacher-forcing accuracy:  99.92%


{'elementwise_accuracy': 0.9984375238418579,
 'full_sequence_accuracy': 0.9921875,
 'teacher_forcing_accuracy': 0.999218761920929,
 'acc_by_position': [1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  0.9921875,
  0.9921875,
  1.0,
  1.0]}

# Abstract Transformer

In [12]:
import sys; sys.path += ['..', '../..']
from seq2seq_models import Seq2SeqAbstractTransformer, configure_optimizers

In [13]:
model_args = dict(
    input_spec=dict(type='vector', dim=8), output_spec=dict(type='token', vocab_size=10+1),
    symbol_retrieval='positional_symbols', symbol_retrieval_kwargs=dict(symbol_dim=64, max_symbols=10),
    d_model=64, out_dim=10, n_layers_enc=2, n_layers_dec=2,
    encoder_kwargs=dict(n_heads_enc=2, n_heads_abs=2, dff=128, activation='relu', norm_first=False, dropout_rate=0.1, causal=False, rel_mask_diag=False),
    decoder_kwargs=dict(n_heads_enc=2, n_heads_abs=0, n_heads_cross=2, dff=128, activation='relu', norm_first=False, dropout_rate=0.1, causal=True, rel_mask_diag=False),
    in_block_size=10, out_block_size=10)
seq2seqabstransformer = Seq2SeqAbstractTransformer(**model_args)#.to(device)
torchinfo.summary(seq2seqabstransformer, row_settings=["depth", "var_names"], col_names=["num_params", "params_percent", "trainable"], depth=3, col_width=20)



Layer (type (var_name):depth-idx)                                           Param #              Param %              Trainable
Seq2SeqAbstractTransformer (Seq2SeqAbstractTransformer)                     --                    -0.32%              True
├─PositionalSymbolRetriever (symbol_retriever): 1-1                         --                        --              True
│    └─Embedding (symbol_library): 2-1                                      640                    0.32%              True
├─ModuleDict (layers): 1-2                                                  --                    -0.32%              True
│    └─Linear (source_embedder): 2-2                                        576                    0.29%              True
│    └─Embedding (target_embedder): 2-3                                     704                    0.36%              True
│    └─SinusoidalPositionalEncoding (source_pos_embedder): 2-4              --                        --              --
│    │    └─D

In [14]:
model = seq2seqabstransformer

In [15]:
lit_model = LitSeq2SeqModel(model)

In [16]:
from lightning.pytorch.callbacks import RichProgressBar, TQDMProgressBar

In [17]:
trainer = L.Trainer(
    max_epochs=500, enable_checkpointing=False, logger=False, enable_model_summary=True, precision='64-true',
    # callbacks=[RichProgressBar()]
    )
trainer.fit(model=lit_model, train_dataloaders=train_dl)#, val_dataloaders=val_dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/awni/miniconda3/envs/abstract_transformer/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA GeForce RTX 4070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | Seq2SeqAbstractTransformer | 194 K 
-----------------------------------------------------
194 K     T

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


In [19]:
trainer.test(lit_model, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.044189517713503185,
  'teacher_forcing_acc': 0.9870042174290388,
  'elementwise_acc': 0.9724309016496707,
  'seq_acc': 0.878155048076923}]

In [None]:
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-7)
# optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

train_kwargs = dict(
    model=model, train_dl=train_dl, eval_model=eval_model, n_epochs=500,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr,
    compile=True, grad_clip=0,
    eval_main_metric='val/loss',
    always_save_checkpoint=always_save_checkpoint, ckpt_dict=dict(model_args=model_args), out_dir=out_dir,
    wandb_log=False, wandb_init_kwargs=dict(project=wandb_project, name='AbstractTransformer'), track_mfu=True,
    ddp=False, device_type='cuda')

train_model(**train_kwargs)

In [None]:
source_test, target_test, labels_test = next(iter(test_dl))
evaluate_seq2seq_model(model, source_test, target_test, labels_test, start_token, print_=True, ctx=ctx)

# Original Abstractor Architecture (b)

In [11]:
import sys; sys.path += ['..', '../..']
from og_seq2seq_models import Seq2SeqAbstractorArchb

In [12]:
model_args = dict(
    input_spec=dict(type='vector', dim=8), output_spec=dict(type='token', vocab_size=10+1),
    symbol_retrieval='positional_symbols', symbol_retrieval_kwargs=dict(symbol_dim=64, max_symbols=10),
    d_model=64, out_dim=10, n_layers_enc=2, n_layers_dec=2,
    encoder_kwargs=dict(n_heads=2, dff=128, activation='relu', norm_first=False, dropout_rate=0.1, causal=False),
    abstractor_kwargs=dict(
        n_layers=2, d_model=64, n_heads=2, dff=128, activation='relu', norm_first=False, dropout_rate=0.1, symbol_retriever_type='symbolic_attention', symbol_add_pos_embedding=False, symbol_retriever_kwargs=dict(model_dim=64, n_heads=2, num_symbols=10), max_len=10),
    decoder_kwargs=dict(n_heads=2, dff=128, activation='relu', norm_first=True, dropout_rate=0.1, causal=True),
    in_block_size=10, out_block_size=10)
model = Seq2SeqAbstractorArchb(**model_args)#.to(device)
torchinfo.summary(model, row_settings=["depth", "var_names"], col_names=["num_params", "params_percent", "trainable"], depth=3, col_width=20)

Layer (type (var_name):depth-idx)                                                Param #              Param %              Trainable
Seq2SeqAbstractorArchb (Seq2SeqAbstractorArchb)                                  --                        --              True
├─PositionalSymbolRetriever (symbol_retriever): 1-1                              --                        --              True
│    └─Embedding (symbol_library): 2-1                                           640                    0.26%              True
├─ModuleDict (layers): 1-2                                                       --                        --              True
│    └─Linear (source_embedder): 2-2                                             576                    0.24%              True
│    └─Embedding (target_embedder): 2-3                                          704                    0.29%              True
│    └─SinusoidalPositionalEncoding (source_pos_embedder): 2-4                   --                

In [13]:
lit_model = LitSeq2SeqModel(model)

In [14]:
trainer = L.Trainer(
    max_epochs=500, enable_checkpointing=False, logger=False, enable_model_summary=True, precision='64-true',
    # callbacks=[RichProgressBar()]
    )
trainer.fit(model=lit_model, train_dataloaders=train_dl)#, val_dataloaders=val_dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/awni/miniconda3/envs/abstract_transformer/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA GeForce RTX 4070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params
-------------------------------------------------
0 | model | Seq2SeqAbstractorArchb | 242 K 
-------------------------------------------------
242 K     Trainable params


Training: |          | 0/? [00:00<?, ?it/s]

In [15]:
trainer.test(lit_model, test_dl)

NameError: name 'trainer' is not defined

In [59]:
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-7)
# optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

In [60]:
train_kwargs = dict(
    model=model, train_dl=train_dl, eval_model=eval_model, n_epochs=500,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr,
    compile=True, grad_clip=0,
    eval_main_metric='val/loss',
    always_save_checkpoint=always_save_checkpoint, ckpt_dict=dict(model_args=model_args), out_dir=out_dir,
    wandb_log=False, wandb_init_kwargs=dict(project=wandb_project, name='Abstractor (arch. b)'), track_mfu=True,
    ddp=False, device_type='cuda')

In [61]:
train_model(**train_kwargs) # set norm_first = False in this one

compiling model... done compiling.
starting training loop...
epoch: 0, step: 6 train/loss: 2.2730, train/tfacc: 0.1290, val/loss: 2.2817, val/tfacc: 0.1252, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.2860, time 50486.99ms, mfu -100.00%
epoch: 1, step: 12 train/loss: 2.2274, train/tfacc: 0.1634, val/loss: 2.2376, val/tfacc: 0.1484, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 2.2389, time 1289.85ms, mfu -100.00%
epoch: 2, step: 18 train/loss: 2.1491, train/tfacc: 0.2076, val/loss: 2.1670, val/tfacc: 0.1924, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 2.1712, time 1438.48ms, mfu -100.00%
epoch: 3, step: 24 train/loss: 2.0488, train/tfacc: 0.2443, val/loss: 2.0739, val/tfacc: 0.2266, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 2.0807, time 1551.41ms, mfu -100.00%
epoch: 4, step: 30 train/loss: 1.9562, train/tfacc: 0.2705, val/loss: 1.9832, val/tfacc: 0.2520, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 1.9946, tim

In [62]:
source_test, target_test, labels_test = next(iter(test_dl))

In [63]:
evaluate_seq2seq_model(model, source_test, target_test, labels_test, start_token, print_=True, ctx=ctx)

element-wise accuracy: 85.31%
full sequence accuracy: 42.97%
teacher-forcing accuracy:  92.73%


{'elementwise_accuracy': 0.8531250357627869,
 'full_sequence_accuracy': 0.4296875,
 'teacher_forcing_accuracy': 0.9273437857627869,
 'acc_by_position': [0.9765625,
  0.921875,
  0.8828125,
  0.8359375,
  0.796875,
  0.7578125,
  0.7109375,
  0.78125,
  0.8984375,
  0.96875]}

## Learning Curves

In [23]:
wandb_project_name = 'abstract_transformer--object_sorting'
num_trials = 1
train_sizes = [250, 500, 1000, 1500, 2000, 2500, 3000]
start_trial = 0
n_epochs = 10

In [2]:
ee, ea, de, da = 0, 2, 2, 0

In [3]:
def create_abstransformer_model(ee, ea, de, da):

    model_args = dict(
        input_spec=dict(type='vector', dim=8), output_spec=dict(type='token', vocab_size=10+1),
        symbol_retrieval='positional_symbols', symbol_retrieval_kwargs=dict(symbol_dim=64, max_symbols=10),
        d_model=64, out_dim=10, n_layers_enc=2, n_layers_dec=2,
        encoder_kwargs=dict(n_heads_enc=ee, n_heads_abs=ea, dff=128, activation='relu', norm_first=True, dropout_rate=0.1, causal=False, rel_mask_diag=False),
        decoder_kwargs=dict(n_heads_enc=de, n_heads_abs=da, n_heads_cross=2, dff=128, activation='relu', norm_first=True, dropout_rate=0.1, causal=True, rel_mask_diag=False),
        in_block_size=10, out_block_size=10)
    seq2seqabstransformer = Seq2SeqAbstractTransformer(**model_args)#.to(device)
    return seq2seqabstransformer


In [4]:
def create_model():
    return create_abstransformer_model(ee, ea, de, da)

In [5]:
def evaluate_learning_curves(
    create_model,
    wandb_project_name, group_name,
    train_sizes=train_sizes, num_trials=num_trials):

    for train_size in tqdm(train_sizes, desc='train size'):

        for trial in trange(start_trial, start_trial + num_trials, desc='trial', leave=False):
            # run = wandb.init(project=wandb_project_name, group=group_name, name=f'train size = {train_size}; trial = {trial}',
            #                 config={'train size': train_size, 'trial': trial, 'group': group_name})
            # TODO: add model args to config?

            model = create_model().to(device)

            scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
            # optimizer
            optimizer = torch.optim.Adam(model.parameters())
            # optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

            train_dl = get_train_dl(train_size)

            # TODO: make training loop support pre-initiated wanbd runs
            train_kwargs = dict(
                model=model, train_dl=train_dl, eval_model=eval_model, n_epochs=n_epochs,
                optimizer=optimizer, scaler=scaler, get_lr=get_lr,
                compile=True, grad_clip=0,
                eval_main_metric='val/loss',
                always_save_checkpoint=always_save_checkpoint,
                # ckpt_dict=dict(model_args=model_args), 
                out_dir=out_dir,
                wandb_log=False, wandb_init_kwargs=dict(project=wandb_project, group=group_name, name=f'{group_name}--trial={trial}'),
                track_mfu=True,
                ddp=False, device_type='cuda')
            train_utils.train_model(**train_kwargs)

            source_test, target_test, labels_test = test_ds.tensors
            eval_dict = evaluate_seq2seq_model(model, source_test, target_test, labels_test, start_token, print_=True, ctx=ctx)

            # wandb.log(eval_dict)
            # wandb.finish(quiet=True)

            del model

# endregion


NameError: name 'train_sizes' is not defined

In [28]:
evaluate_learning_curves(
    create_model, wandb_project_name='abstract_transformer--object_sorting', group_name='abstract_transformer_',
    train_sizes=train_sizes, num_trials=num_trials)

train size:   0%|          | 0/7 [00:00<?, ?it/s]

compiling model... done compiling.
starting training loop...
epoch: 0, step: 2 train/loss: 2.4196, train/tfacc: 0.1109, val/loss: 2.4156, val/tfacc: 0.1127, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.4996, time 31263.68ms, mfu -100.00%
epoch: 1, step: 4 train/loss: 2.3761, train/tfacc: 0.1036, val/loss: 2.3727, val/tfacc: 0.1108, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 2.3970, time 1199.03ms, mfu -100.00%
epoch: 2, step: 6 train/loss: 2.2999, train/tfacc: 0.1197, val/loss: 2.2974, val/tfacc: 0.1235, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 2.3223, time 1316.90ms, mfu -100.00%
epoch: 3, step: 8 train/loss: 2.2637, train/tfacc: 0.1348, val/loss: 2.2621, val/tfacc: 0.1352, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 2.2540, time 1215.06ms, mfu -100.00%
epoch: 4, step: 10 train/loss: 2.2459, train/tfacc: 0.1503, val/loss: 2.2451, val/tfacc: 0.1489, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 2.2274, time 1

train size:  14%|█▍        | 1/7 [00:48<04:49, 48.24s/it]

element-wise accuracy: 9.84%
full sequence accuracy: 0.00%
teacher-forcing accuracy:  18.95%




compiling model... done compiling.
starting training loop...
epoch: 0, step: 4 train/loss: 2.3567, train/tfacc: 0.1200, val/loss: 2.3605, val/tfacc: 0.1170, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.3820, time 31078.46ms, mfu -100.00%
epoch: 1, step: 8 train/loss: 2.2654, train/tfacc: 0.1387, val/loss: 2.2674, val/tfacc: 0.1385, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 2.2617, time 1609.61ms, mfu -100.00%
epoch: 2, step: 12 train/loss: 2.2119, train/tfacc: 0.1634, val/loss: 2.2120, val/tfacc: 0.1590, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 2.2001, time 1601.12ms, mfu -100.00%
epoch: 3, step: 16 train/loss: 2.1700, train/tfacc: 0.1665, val/loss: 2.1690, val/tfacc: 0.1681, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 2.1578, time 2182.98ms, mfu -100.00%
epoch: 4, step: 20 train/loss: 2.1237, train/tfacc: 0.1955, val/loss: 2.1229, val/tfacc: 0.1962, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 2.1035, time

train size:  29%|██▊       | 2/7 [01:42<04:18, 51.77s/it]

element-wise accuracy: 10.29%
full sequence accuracy: 0.00%
teacher-forcing accuracy:  28.17%




compiling model... done compiling.
starting training loop...
epoch: 0, step: 8 train/loss: 2.2773, train/tfacc: 0.1335, val/loss: 2.2818, val/tfacc: 0.1264, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.2982, time 27287.14ms, mfu -100.00%
epoch: 1, step: 16 train/loss: 2.1743, train/tfacc: 0.1809, val/loss: 2.1779, val/tfacc: 0.1779, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 2.1800, time 1463.14ms, mfu -100.00%
epoch: 2, step: 24 train/loss: 2.0826, train/tfacc: 0.2285, val/loss: 2.0851, val/tfacc: 0.2272, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 2.0898, time 1454.96ms, mfu -100.00%
epoch: 3, step: 32 train/loss: 1.9839, train/tfacc: 0.2645, val/loss: 1.9865, val/tfacc: 0.2620, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 2.0029, time 1364.19ms, mfu -100.00%
epoch: 4, step: 40 train/loss: 1.8838, train/tfacc: 0.2777, val/loss: 1.8878, val/tfacc: 0.2775, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 1.9006, tim

train size:  43%|████▎     | 3/7 [02:28<03:16, 49.13s/it]

element-wise accuracy: 14.34%
full sequence accuracy: 0.00%
teacher-forcing accuracy:  35.10%




compiling model... done compiling.
starting training loop...
epoch: 0, step: 12 train/loss: 2.2288, train/tfacc: 0.1542, val/loss: 2.2273, val/tfacc: 0.1583, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.2477, time 26540.96ms, mfu -100.00%
epoch: 1, step: 24 train/loss: 2.1060, train/tfacc: 0.2095, val/loss: 2.1080, val/tfacc: 0.2033, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 2.1008, time 1467.04ms, mfu -100.00%
epoch: 2, step: 36 train/loss: 1.9700, train/tfacc: 0.2723, val/loss: 1.9729, val/tfacc: 0.2620, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 1.9841, time 1491.36ms, mfu -100.00%
epoch: 3, step: 48 train/loss: 1.8269, train/tfacc: 0.2885, val/loss: 1.8300, val/tfacc: 0.2827, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 1.8582, time 1519.06ms, mfu -100.00%
epoch: 4, step: 60 train/loss: 1.7072, train/tfacc: 0.2975, val/loss: 1.7121, val/tfacc: 0.2943, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 1.7388, ti

train size:  57%|█████▋    | 4/7 [03:14<02:24, 48.03s/it]

element-wise accuracy: 20.20%
full sequence accuracy: 0.00%
teacher-forcing accuracy:  44.40%




compiling model... done compiling.
starting training loop...
epoch: 0, step: 16 train/loss: 2.1509, train/tfacc: 0.1898, val/loss: 2.1518, val/tfacc: 0.1861, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.1634, time 26247.43ms, mfu -100.00%
epoch: 1, step: 32 train/loss: 1.9588, train/tfacc: 0.2586, val/loss: 1.9622, val/tfacc: 0.2526, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 1.9982, time 1570.01ms, mfu -100.00%
epoch: 2, step: 48 train/loss: 1.8069, train/tfacc: 0.2842, val/loss: 1.8116, val/tfacc: 0.2796, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 1.8532, time 1669.52ms, mfu -100.00%
epoch: 3, step: 64 train/loss: 1.6832, train/tfacc: 0.3164, val/loss: 1.6848, val/tfacc: 0.3133, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 1.7336, time 1492.06ms, mfu -100.00%
epoch: 4, step: 80 train/loss: 1.5763, train/tfacc: 0.3599, val/loss: 1.5780, val/tfacc: 0.3598, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 1.6322, ti

train size:  71%|███████▏  | 5/7 [04:01<01:35, 47.61s/it]

element-wise accuracy: 38.68%
full sequence accuracy: 0.38%
teacher-forcing accuracy:  64.36%




compiling model... done compiling.
starting training loop...
epoch: 0, step: 20 train/loss: 2.1053, train/tfacc: 0.2101, val/loss: 2.1051, val/tfacc: 0.2109, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.1382, time 26348.18ms, mfu -100.00%
epoch: 1, step: 40 train/loss: 1.8719, train/tfacc: 0.2753, val/loss: 1.8725, val/tfacc: 0.2781, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 1.9096, time 1763.87ms, mfu -100.00%
epoch: 2, step: 60 train/loss: 1.6740, train/tfacc: 0.3191, val/loss: 1.6756, val/tfacc: 0.3196, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 1.7244, time 1842.89ms, mfu -100.00%
epoch: 3, step: 80 train/loss: 1.5187, train/tfacc: 0.3747, val/loss: 1.5236, val/tfacc: 0.3684, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 1.5904, time 1875.92ms, mfu -100.00%
epoch: 4, step: 100 train/loss: 1.3967, train/tfacc: 0.4431, val/loss: 1.3959, val/tfacc: 0.4427, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 1.4544, t

train size:  86%|████████▌ | 6/7 [04:49<00:47, 47.77s/it]

element-wise accuracy: 49.62%
full sequence accuracy: 1.81%
teacher-forcing accuracy:  71.83%




compiling model... done compiling.
starting training loop...
epoch: 0, step: 24 train/loss: 2.0660, train/tfacc: 0.2462, val/loss: 2.0685, val/tfacc: 0.2451, 🤖
saving checkpoint to ../out/object_sorting
epoch 0: loss 2.0755, time 28061.02ms, mfu -100.00%
epoch: 1, step: 48 train/loss: 1.7881, train/tfacc: 0.2956, val/loss: 1.7908, val/tfacc: 0.2945, 🤖
saving checkpoint to ../out/object_sorting
epoch 1: loss 1.8357, time 1992.52ms, mfu -100.00%
epoch: 2, step: 72 train/loss: 1.5935, train/tfacc: 0.3345, val/loss: 1.5926, val/tfacc: 0.3334, 🤖
saving checkpoint to ../out/object_sorting
epoch 2: loss 1.6511, time 2048.97ms, mfu -100.00%
epoch: 3, step: 96 train/loss: 1.4810, train/tfacc: 0.3809, val/loss: 1.4815, val/tfacc: 0.3765, 🤖
saving checkpoint to ../out/object_sorting
epoch 3: loss 1.5479, time 1653.12ms, mfu -100.00%
epoch: 4, step: 120 train/loss: 1.3636, train/tfacc: 0.4442, val/loss: 1.3630, val/tfacc: 0.4438, 🤖
saving checkpoint to ../out/object_sorting
epoch 4: loss 1.4169, t

train size: 100%|██████████| 7/7 [05:40<00:00, 48.67s/it]

element-wise accuracy: 56.42%
full sequence accuracy: 4.30%
teacher-forcing accuracy:  76.18%





: 