In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import wandb

import torch
import torchinfo
import tiktoken
import numpy as np
import math
import os
import time
from contextlib import nullcontext
from  tqdm import tqdm

import lightning as L
from lightning.pytorch.loggers.wandb import WandbLogger
import torchmetrics

import sys; sys.path.append('../..')
from language_models import TransformerLM, AbstractTransformerLM, configure_optimizers
from train_utils import train_model
from utils.pl_tqdm_progbar import TQDMProgressBar

In [2]:
print('cuda available: ', torch.cuda.is_available())
print('device count: ', torch.cuda.device_count())
print('current device name: ', torch.cuda.get_device_name(torch.cuda.current_device()))
print('Memory Usage:')
print('\tAllocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
print('\tReserved:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

cuda available:  True
device count:  1
current device name:  NVIDIA A100 80GB PCIe
Memory Usage:
	Allocated: 0.0 GB
	Reserved:    0.0 GB


## Config

In [28]:
eval_interval = 500 # keep frequent because we'll overfit
max_steps = 5000
n_epochs = 1
log_every_n_steps = 10
log_model = True

In [4]:
# I/O
eval_only = False # if True, script exits right after the first eval

# system
device = 'cuda'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' 
compile = True

# # evaluation and output
# out_dir = '../out/tiny_stories'
# if not os.path.exists(out_dir):
#     os.makedirs(out_dir)
# eval_interval = 250 # keep frequent because we'll overfit
# eval_iters = 200
# log_interval = 10 # don't print too too often

# # we expect to overfit on this small dataset, so only save when val improves
# always_save_checkpoint = False

# wandb logging
wandb_log = False
wandb_project = 'abstract_transformer--tiny_stories'

# optimization hyperparams
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
decay_lr = True # whether to decay the learning rate
lr_decay_iters = 5000 # make equal to max_iters usually
weight_decay = 1e-1
min_lr = 1e-4 # learning_rate / 10 usually
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
# warmup_iters = 100
gradient_accumulation_steps = 1 # accumulate gradients over this many steps. simulates larger batch size

# batch size and block size
batch_size = 64
block_size = 256

# DDP (distributed data parallel) training
# ddp = False
# master_process = True

# TODO: set up DDP for future experiments

In [5]:
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

## Data Loader

In [6]:
from datasets import load_dataset

dataset = load_dataset("roneneldan/TinyStories")



In [7]:
print(dataset['train'][0]['text'])

One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [9]:
# TODO: use tiktoken instead?; much faster
dataset = dataset.map(lambda x: tokenizer(x['text'], padding=True, truncation=True, max_length=block_size+1), batched=True)

In [10]:
dataset.set_format(type='torch', columns=['input_ids'])

In [11]:
print("EXAMPLES")

for x in dataset['train']['input_ids'][:5]:
    print(tokenizer.decode(x))
    print('-'*100)
    print()

EXAMPLES
[CLS] One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt. Lily went to her mom and said, " Mom, I found this needle. Can you share it with me and sew my shirt? " Her mom smiled and said, " Yes, Lily, we can share the needle and fix your shirt. " Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

In [12]:
train_dataloader = torch.utils.data.DataLoader(dataset['train'], batch_size=batch_size, pin_memory=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(dataset['validation'], batch_size=batch_size, pin_memory=True, num_workers=4)

In [13]:
# TODO: need to handle padding token? ignore in loss/perplexity/etc?

## Train with Pytorch Lightning

In [14]:
log_on_step = True

class LitLanguageModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        text = batch['input_ids']
        x, y = text[:, :-1], text[:, 1:]

        # with ctx:
        logits, loss = self.model(x, y)
        perplexity = torchmetrics.functional.text.perplexity(logits, y, ignore_index=tokenizer.pad_token_id)

        self.log('train/loss', loss, prog_bar=True, logger=True, on_step=log_on_step, on_epoch=True)
        self.log('train/perplexity', perplexity, prog_bar=True, logger=True, on_step=log_on_step, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        text = batch['input_ids']
        x, y = text[:, :-1], text[:, 1:]
        # with ctx:
        logits, loss = self.model(x, y)

        perplexity = torchmetrics.functional.text.perplexity(logits, y, ignore_index=tokenizer.pad_token_id)

        self.log(f"val/loss", loss, prog_bar=True, logger=True, add_dataloader_idx=False)
        self.log(f'val/perplexity', perplexity, prog_bar=True, logger=True, add_dataloader_idx=False)

    def test_step(self, batch, batch_idx):
        text = batch['input_ids']
        x, y = text[:, :-1], text[:, 1:]
        # with ctx:
        logits, loss = self.model(x, y)

        perplexity = torchmetrics.functional.text.perplexity(logits, y, ignore_index=tokenizer.pad_token_id)

        self.log(f"test/loss", loss, prog_bar=True, logger=True, add_dataloader_idx=False)
        self.log(f'test/perplexity', perplexity, prog_bar=True, logger=True, add_dataloader_idx=False)

    def configure_optimizers(self):
        optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)
        return optimizer

In [15]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=768, n_layers=2, n_heads=12, dff=None, pos_enc_type='RoPE',
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=block_size, bias=True)
model = transformer_lm = TransformerLM(**model_args).to(device)
# NOTE: embedding layer/output layer account for much of the params... (but they use weight sharing)
torchinfo.summary(model, device='cuda') # input_data=[torch.randint(0, 10, size=(1,block_size))]*2,

Layer (type:depth-idx)                        Param #
TransformerLM                                 --
├─ModuleDict: 1-1                             --
│    └─Embedding: 2-1                         22,268,928
│    │    └─Linear: 3-1                       22,297,924
│    └─Dropout: 2-2                           --
│    └─ModuleList: 2-3                        --
│    │    └─EncoderBlock: 3-2                 7,085,568
│    │    └─EncoderBlock: 3-3                 7,085,568
│    └─Linear: 2-4                            (recursive)
Total params: 58,737,988
Trainable params: 58,737,988
Non-trainable params: 0

In [17]:
lit_model = LitLanguageModel(model)

In [18]:
len(train_dataloader)

33121

In [18]:
group_name = None
run_name = 'TransformerLM (TEST)'
run = wandb.init(project=wandb_project, group=group_name, name=run_name,
    config={'group': group_name, **model_args})

wandb_logger = WandbLogger(experiment=run, log_model=False) # name=run_name, project=wandb_project,
# wandb_logger.watch(model, log_graph=False)
# wandb_logger = None
callbacks = [
    TQDMProgressBar(refresh_rate=50)
    # LitProgressBar()
    # L.pytorch.callbacks.TQDMProgressBar(refresh_rate=50)
    # L.pytorch.callbacks.RichProgressBar()
]

trainer_kwargs = dict(
    max_epochs=n_epochs, enable_checkpointing=False, enable_model_summary=True,
    enable_progress_bar=True, callbacks=callbacks, logger=wandb_logger,
    accumulate_grad_batches=gradient_accumulation_steps, benchmark=True, gradient_clip_val=grad_clip,
    log_every_n_steps=log_every_n_steps, max_steps=max_steps, val_check_interval=eval_interval)

trainer = L.Trainer(
    **trainer_kwargs
    )
trainer.fit(model=lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/ma2393/.conda/envs/abstract_transformer/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ma2393/.conda/envs/abstract_transformer/lib/py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params
----------------------------------------
0 | model | Tra

num decayed parameter tensors: 15, with 58,890,240 parameters
num non-decayed parameter tensors: 15, with 44,356 parameters
using fused AdamW: True
Epoch 0:  15%|█▌        | 5000/33121 [29:08<2:43:52,  2.86it/s, v_num=ycs5, train/loss_step=1.150, train/perplexity_step=5.210, val/loss=1.480, val/perplexity=7.620, train/loss_epoch=1.810, train/perplexity_epoch=35.70]

`Trainer.fit` stopped: `max_steps=5000` reached.


Epoch 0:  15%|█▌        | 5000/33121 [29:08<2:43:52,  2.86it/s, v_num=ycs5, train/loss_step=1.150, train/perplexity_step=5.210, val/loss=1.480, val/perplexity=7.620, train/loss_epoch=1.810, train/perplexity_epoch=35.70]


In [None]:
prompts = [
    'Once upon a time,',
    'There once was a girl named ',
    'On a rainy day,'
    'Emma is a curious person.'
]

def generate_from_prompt(model, prompt, max_new_tokens=100, temperature=1.0, top_k=None, tokenizer=tokenizer):
    prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0)#.to(device)
    prompt_idx = prompt_idx[:, :-1] # remove final token because it is [SEP]
    sample_gen = model.generate(prompt_idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)[0]
    sample_gen = tokenizer.decode(sample_gen)
    return sample_gen

for prompt in prompts:
    print(f"PROMPT: {prompt}")
    print("GENERATED TEXT:")
    sample_gen = generate_from_prompt(model, prompt)
    print(sample_gen)
    print('-'*100)

## Abstract Transformer Model (Symbolic Attention; RoPE; Disentangled RCA)

In [19]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=384, n_layers=6, n_heads_sa=4, n_heads_rca=2, dff=None, rca_disentangled=True,
    symbol_retrieval='sym_attn', symbol_retrieval_kwargs=dict(num_symbols=50, n_heads=4, model_dim=384), # FIXME make names consistent: d_model, model_dim
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=256, bias=True, pos_enc_type='RoPE')
model = abstracttransformer_lm = AbstractTransformerLM(**model_args).to(device)

In [20]:
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1,block_size)), device='cuda')

Layer (type:depth-idx)                                                 Output Shape              Param #
AbstractTransformerLM                                                  [1, 1, 28996]             --
├─ModuleDict: 1-1                                                      --                        --
│    └─Embedding: 2-1                                                  [1, 256, 384]             22,297,924
│    └─ModuleList: 2-2                                                 --                        --
│    │    └─AbstractEncoderBlock: 3-1                                  [1, 256, 384]             1,991,936
│    │    └─AbstractEncoderBlock: 3-14                                 --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-3                                  --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-4                                  [1, 256, 384]             1,991,936
│    │    └─AbstractEncoderBlock: 3-14                 

In [21]:
# torchinfo overcounts # of params... something to do with symbolic attention shared across layers
# this is the correct number (similar to TransformerLM)

num_params = model.get_num_params() #sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of params {num_params:,}')

AttributeError: 'ModuleDict' object has no attribute 'positional_embedder'

In [26]:
# TODO: can we implement in a way that torchinfo can understand? i.e., without "recursive" and overcounting

In [22]:
lit_model = LitLanguageModel(model)

In [23]:
len(train_dataloader)

33121

In [24]:
group_name = None
run_name = 'AbstractTransformerLM (TEST)'
run = wandb.init(project=wandb_project, group=group_name, name=run_name,
    config={'group': group_name, **model_args})

wandb_logger = WandbLogger(experiment=run, log_model=False) # name=run_name, project=wandb_project,
# wandb_logger.watch(model, log_graph=False)
# wandb_logger = None
callbacks = [
    TQDMProgressBar(refresh_rate=50)
    # LitProgressBar()
    # L.pytorch.callbacks.TQDMProgressBar(refresh_rate=50)
    # L.pytorch.callbacks.RichProgressBar()
]

trainer_kwargs = dict(
    max_epochs=n_epochs, enable_checkpointing=False, enable_model_summary=True,
    enable_progress_bar=True, callbacks=callbacks, logger=wandb_logger,
    accumulate_grad_batches=gradient_accumulation_steps, benchmark=True, gradient_clip_val=grad_clip,
    log_every_n_steps=log_every_n_steps, max_steps=max_steps, val_check_interval=eval_interval)

trainer = L.Trainer(
    **trainer_kwargs
    )
trainer.fit(model=lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/ma2393/.conda/envs/abstract_transformer/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ma2393/.conda/envs/abstract_transformer/lib/py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                  | Params
------------------------------------------------

num decayed parameter tensors: 77, with 33,268,224 parameters
num non-decayed parameter tensors: 38, with 50,116 parameters
using fused AdamW: True
Epoch 0:   3%|▎         | 1150/33121 [06:10<2:51:47,  3.10it/s, v_num=0258, train/loss_step=1.420, train/perplexity_step=8.040, val/loss=1.770, val/perplexity=11.20]

In [None]:
prompts = [
    'Once upon a time,',
    'There once was a girl named ',
    'On a rainy day,'
    'Emma is a curious person.'
]

def generate_from_prompt(model, prompt, max_new_tokens=100, temperature=1.0, top_k=None, tokenizer=tokenizer):
    prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0)#.to(device)
    prompt_idx = prompt_idx[:, :-1] # remove final token because it is [SEP]
    sample_gen = model.generate(prompt_idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)[0]
    sample_gen = tokenizer.decode(sample_gen)
    return sample_gen

for prompt in prompts:
    print(f"PROMPT: {prompt}")
    print("GENERATED TEXT:")
    sample_gen = generate_from_prompt(model, prompt)
    print(sample_gen)
    print('-'*100)

## Abstract Transformer Model (Position-Relative Symbols; RoPE; Disentangled RCA)

In [16]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=384, n_layers=6, n_heads_sa=4, n_heads_rca=2, dff=None, rca_disentangled=True,
    symbol_retrieval='pos_relative', symbol_retrieval_kwargs=dict(symbol_dim=384, max_rel_pos=block_size), rca_kwargs=dict(use_relative_positional_symbols=True),
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=256, bias=True, pos_enc_type='RoPE')
model = abstracttransformer_lm = AbstractTransformerLM(**model_args).to(device)

In [17]:
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1,block_size)), device='cuda')

Layer (type:depth-idx)                                                 Output Shape              Param #
AbstractTransformerLM                                                  [1, 1, 28996]             --
├─ModuleDict: 1-1                                                      --                        --
│    └─Embedding: 2-1                                                  [1, 256, 384]             22,297,924
│    └─ModuleList: 2-2                                                 --                        --
│    │    └─AbstractEncoderBlock: 3-1                                  [1, 256, 384]             2,002,688
│    │    └─AbstractEncoderBlock: 3-14                                 --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-3                                  --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-4                                  [1, 256, 384]             2,002,688
│    │    └─AbstractEncoderBlock: 3-14                 

In [18]:
# torchinfo overcounts # of params... something to do with symbolic attention shared across layers
# this is the correct number (similar to TransformerLM)

num_params = model.get_num_params() #sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of params {num_params:,}')

# of params 33,329,092


In [None]:
# TODO: can we implement in a way that torchinfo can understand? i.e., without "recursive" and overcounting
# maybe symbol retrieval should be done outside AbstractBlock!

In [19]:
lit_model = LitLanguageModel(model)

In [20]:
len(train_dataloader)

33121

In [21]:
group_name = None
run_name = 'AbstractTransformerLM (TEST)'
run = wandb.init(project=wandb_project, group=group_name, name=run_name,
    config={'group': group_name, **model_args})

wandb_logger = WandbLogger(experiment=run, log_model=log_model) # name=run_name, project=wandb_project,
# wandb_logger.watch(model, log_graph=False)
# wandb_logger = None
callbacks = [
    TQDMProgressBar(refresh_rate=50)
    # LitProgressBar()
    # L.pytorch.callbacks.TQDMProgressBar(refresh_rate=50)
    # L.pytorch.callbacks.RichProgressBar()
]

trainer_kwargs = dict(
    max_epochs=n_epochs, enable_checkpointing=False, enable_model_summary=True,
    enable_progress_bar=True, callbacks=callbacks, logger=wandb_logger,
    accumulate_grad_batches=gradient_accumulation_steps, benchmark=True, gradient_clip_val=grad_clip,
    log_every_n_steps=log_every_n_steps, max_steps=max_steps, val_check_interval=eval_interval)

trainer = L.Trainer(
    **trainer_kwargs
    )
trainer.fit(model=lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/ma2393/.conda/envs/abstract_transformer/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ma2393/.conda/envs/abstract_transformer/lib/py ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                  | Params
-------------------------------------

num decayed parameter tensors: 75, with 33,279,360 parameters
num non-decayed parameter tensors: 37, with 49,732 parameters
using fused AdamW: True
Epoch 0:  15%|█▌        | 5000/33121 [22:07<2:04:24,  3.77it/s, v_num=d4cn, train/loss_step=0.458, train/perplexity_step=1.930, val/loss=0.599, val/perplexity=2.230, train/loss_epoch=1.130, train/perplexity_epoch=86.00]

`Trainer.fit` stopped: `max_steps=5000` reached.


Epoch 0:  15%|█▌        | 5000/33121 [22:07<2:04:24,  3.77it/s, v_num=d4cn, train/loss_step=0.458, train/perplexity_step=1.930, val/loss=0.599, val/perplexity=2.230, train/loss_epoch=1.130, train/perplexity_epoch=86.00]


In [29]:
prompts = [
    'Once upon a time,',
    'There once was a girl named ',
    'On a rainy day,',
    'Emma is a curious person.',
    '',
    '',
    '',
]

def generate_from_prompt(model, prompt, max_new_tokens=100, temperature=1.0, top_k=None, tokenizer=tokenizer):
    prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0)#.to(device)
    prompt_idx = prompt_idx[:, :-1] # remove final token because it is [SEP]
    sample_gen = model.generate(prompt_idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)[0]
    sample_gen = tokenizer.decode(sample_gen)
    return sample_gen

print()
print('='*100)
print("GENERATING SAMPLES")
samples = []
for prompt in prompts:
    print(f"PROMPT: {prompt}")
    print("GENERATED TEXT:")
    sample_gen = generate_from_prompt(model, prompt)
    print(sample_gen)
    print('-'*100)
    print()
    samples.append(sample_gen)


GENERATING SAMPLES
PROMPT: Once upon a time,
GENERATED TEXT:
[CLS] Once upon a time, there was a little boy who had a powerful and strong robot. The powerful shark was a good, powerful ship, the powerful ships. They imagined they were playing and laughing. After a while, the powerful ship was tired, but he was so tired. He gently guessed to the powerful shark, and the powerful shark was able to get away from the powerful ship. He was just as fast as he could, but then he heard the wave coming from the powerful water. He hopped all back to his
----------------------------------------------------------------------------------------------------

PROMPT: There once was a girl named 
GENERATED TEXT:
[CLS] There once was a girl named Lily in the park. Her mummy warn her she always used the Making would be arranged away and she would stay safe and sound. Lily learned that sometimes we have a promise engine sometimes a good time and it makes us it! They left the park to be important and tidy.

In [30]:
samples_table = [[p, g] for p, g in zip(prompts, samples)]
samples_table = wandb.Table(columns=["Prompt", "Generated Sample"], data = samples_table)
run.log({"Generated Samples": samples_table})

In [31]:
wandb.finish()



VBox(children=(Label(value='0.103 MB of 0.152 MB uploaded\r'), FloatProgress(value=0.6817309567701223, max=1.0…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_epoch,▁
train/loss_step,▇▆██▅▆▆▄▃▃▂▃▂▂▂▃▂▃▂▂▂▂▃▃▁▃▃▃▃▂▂▃▂▁▂▁▁▂▁▂
train/perplexity_epoch,▁
train/perplexity_step,█▄▃▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/loss,█▃▂▂▂▁▁▁▁▁
val/perplexity,█▂▁▁▁▁▁▁▁▁

0,1
epoch,0.0
train/loss_epoch,1.13302
train/loss_step,0.4578
train/perplexity_epoch,85.96976
train/perplexity_step,1.93338
trainer/global_step,4999.0
val/loss,0.59906
val/perplexity,2.23264


### Training

In [None]:
# grad scaler
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

num decayed parameter tensors: 66, with 32,776,704 parameters
num non-decayed parameter tensors: 38, with 50,116 parameters
using fused AdamW: True


In [None]:
train_kwargs = dict(
    model=model, get_batch=get_batch, batch_size=batch_size, max_iters=max_iters,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr, eval_model=eval_model,
    compile=True, grad_clip=0, gradient_accumulation_steps=1,
    eval_main_metric='val/loss', eval_interval=eval_interval, always_save_checkpoint=always_save_checkpoint, out_dir=out_dir,
    log_interval=10, wandb_log=True, wandb_init_kwargs=dict(project=wandb_project, name='AbstractTransformerLM'), 
    ckpt_dict=dict(model_args=model_args), track_mfu=True,
    master_process=True, ddp=False, device_type='cuda')

NameError: name 'get_batch' is not defined

In [None]:
train_model(**train_kwargs)

NameError: name 'train_kwargs' is not defined

In [None]:
prompt = 'Once upon a time,'

prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0).to(device)
sample_gen = model.generate(prompt_idx, max_new_tokens=250, temperature=1.0, top_k=None)[0]
sample_gen = tokenizer.decode(sample_gen)
print(sample_gen)

[CLS] Once upon a time, [SEP] a big forest, there was a tiny mushroom. It was all alone. The sun was very harsh, and the mushroom did not like it. It wanted to find a friend to play with and to help it hide from the sun. One day, a little bunny came hopping by. The mushroom called out, " Hello, bunny! Will you be my friend? " The bunny looked at the mushroom and smiled. " Sure, I will be your friend. Let's play together! " The bunny and the mushroom played all day, and they were very happy. As they played, the bunny realized that the mushroom needed help to hide from the harsh sun. So, the bunny dug a hole in the ground and put the mushroom inside. Now, the mushroom was safe and cool. The mushroom and the bunny were the best of friends, and they played in the forest every day. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

## Abstract Transformer Model (Relational Symbolic Attention)

In [None]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=384, n_layers=6, n_heads_enc=4, n_heads_abs=2, dff=None,
    symbol_retrieval='rel_sym_attn', symbol_retrieval_kwargs=dict(
        model_dim=384, rel_n_heads=4, symbolic_attn_n_heads=4,
        num_symbols=20, nbhd_delta=5, causal_nbhd=True, include_self=False,
        normalize_rels=True), # FIXME make names consistent: d_model, model_dim
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=256, bias=True)
model = abstracttransformer_lm = AbstractTransformerLM(**model_args).to(device)

TypeError: AbstractTransformerLM.__init__() got an unexpected keyword argument 'n_heads_enc'

In [None]:
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1,256)), device='cuda')

Layer (type:depth-idx)                                  Output Shape              Param #
AbstractTransformerLM                                   [1, 1, 65]                --
├─ModuleDict: 1-1                                       --                        --
│    └─Embedding: 2-1                                   [1, 256, 384]             24,960
│    └─Embedding: 2-2                                   [256, 384]                98,304
│    └─ModuleList: 2-3                                  --                        --
│    │    └─AbstractEncoderBlock: 3-1                   [1, 256, 384]             2,684,928
│    │    └─AbstractEncoderBlock: 3-14                  --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-3                   --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-4                   [1, 256, 384]             2,684,928
│    │    └─AbstractEncoderBlock: 3-14                  --                        (recursive)
│    │    └

In [None]:
# torchinfo overcounts # of params... something to do with symbolic attention shared across layers
# this is the correct number (similar to TransformerLM)

num_params = model.get_num_params() #sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of params {num_params:,}')

# of params 13,824,833


In [None]:
# TODO: can we implement in a way that torchinfo can understand? i.e., without "recursive" and overcounting

### Training

In [None]:
# grad scaler
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

num decayed parameter tensors: 45, with 13,884,672 parameters
num non-decayed parameter tensors: 65, with 38,465 parameters
using fused AdamW: True


In [None]:
train_kwargs = dict(
    model=model, get_batch=get_batch, batch_size=batch_size, max_iters=max_iters,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr, eval_model=eval_model,
    compile=True, grad_clip=0, gradient_accumulation_steps=1,
    eval_main_metric='val/loss', eval_interval=eval_interval, always_save_checkpoint=always_save_checkpoint, out_dir=out_dir,
    log_interval=10, wandb_log=True, wandb_init_kwargs=dict(project=wandb_project, name='AbstractTransformerLM (Relational Symbolic Attn)'), 
    ckpt_dict=dict(model_args=model_args), track_mfu=True,
    master_process=True, ddp=False, device_type='cuda')

In [None]:
train_model(**train_kwargs)

compiling model... done compiling.
starting training loop...
step 0: train loss 4.5255, val loss 4.5268
iter 0: loss 4.5382, time 62455.47ms, mfu -100.00%


In [None]:
prompt = 'Once upon a time,'

prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0).to(device)
sample_gen = model.generate(prompt_idx, max_new_tokens=100, temperature=1.0, top_k=None)[0]
sample_gen = tokenizer.decode(sample_gen)
print(sample_gen)

Romeo,--

KING HENRY VI:
He was for him to be found that Richard now.

DERBY:
My lord, I proud away not: ount take no letters
Twixt thy gracious lord begin. Who's the matter,
With some gone aloof?

BUCKINGHAM:
Say brought your not, sir?

TRANCIO:
What scorn you do very good note:
My lord, I am content my duteous bildhs away.

BAGOT:
I have subjects beats but next thy father,
Apparender thee: Take honour in our senate.

PAULINA:
This way be thy nose physician, if he be
To be a bad-: Tradam? his any conduction,
His rason was not a parturer,
Or hang denied at the king! Sir Jethop's off,
Should be halp there, says Tybalt, Queen Nedalus,
Shore the king.

QUEEN MARGARET:
Oh shall forge you ashame to Richmond!

KING HENRY VI:
For welcome, welcomemes thee from the key-roof
Lord MarsS ay, and fly will not this hoofest
Rest, and let he go forth his wither's groans?

QUEEN MARGARET:
True executions of men:
Ah, with all horses of conceit so evide.

KING HENRY VI:
Defusions, worship thee, and what 