In [1]:
import torch
import torchinfo
import tiktoken
import numpy as np
import math
import os
import time
from contextlib import nullcontext
from  tqdm import tqdm

import sys; sys.path.append('..')
from models import TransformerLM, AbstractTransformerLM, configure_optimizers
from train_utils import train_model

In [2]:
print('cuda available: ', torch.cuda.is_available())
print('device count: ', torch.cuda.device_count())
print('current device name: ', torch.cuda.get_device_name(torch.cuda.current_device()))
print('Memory Usage:')
print('\tAllocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
print('\tReserved:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

cuda available:  True
device count:  1
current device name:  NVIDIA A100-PCIE-40GB
Memory Usage:
	Allocated: 0.0 GB
	Reserved:    0.0 GB


## Config

In [9]:
# I/O
eval_only = False # if True, script exits right after the first eval

# system
device = 'cuda'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' 
compile = True

# evaluation and output
out_dir = '../out/tiny_stories'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

# wandb logging
wandb_log = False
wandb_project = 'abstract_transformer--tiny_stories'

# optimization hyperparams
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
decay_lr = True # whether to decay the learning rate
lr_decay_iters = 5000 # make equal to max_iters usually
weight_decay = 1e-1
min_lr = 1e-4 # learning_rate / 10 usually
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
warmup_iters = 100
gradient_accumulation_steps = 1 # accumulate gradients over this many steps. simulates larger batch size

# batch size and block size
batch_size = 64
block_size = 256

# DDP (distributed data parallel) training
ddp = False
master_process = True

# TODO: set up DDP for future experiments

In [79]:
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

## Data Loader

In [None]:
from datasets import load_dataset

dataset = load_dataset("roneneldan/TinyStories")



In [None]:
print(dataset['train'][0]['text'])

One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [None]:
# TODO: use tiktoken instead; much faster
dataset = dataset.map(lambda x: tokenizer(x['text'], padding=True, truncation=True, max_length=block_size+1), batched=True)

Map:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Map:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [42]:
dataset.set_format(type='torch', columns=['input_ids'])

In [43]:
train_dataloader = torch.utils.data.DataLoader(dataset['train'], batch_size=batch_size)
val_dataloader = torch.utils.data.DataLoader(dataset['validation'], batch_size=batch_size)

In [59]:
def get_batch(split):
    train_batch = iter(train_dataloader)
    val_batch = iter(val_dataloader)

    if split=='train':
        text = next(train_batch)['input_ids']
    elif split=='val':
        text = next(val_batch)['input_ids']
    else:
        raise ValueError(f"`split` must be 'train' or 'val', got {split}")

    x, y = text[:, :-1], text[:, 1:]
    x = x.pin_memory().to(device, non_blocking=True)
    y = y.pin_memory().to(device, non_blocking=True)

    return x, y

## Training set up

In [79]:
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

In [80]:
@torch.no_grad()
def eval_model(model, ctx=None):

    ctx = nullcontext() if ctx is None else ctx
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[f'{split}/loss'] = losses.mean()
    model.train()
    return out

## Transformer Model

In [47]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=384, n_layers=6, n_heads=6, dff=None,
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=block_size, bias=True)
model = transformer_lm = TransformerLM(**model_args).to(device)

In [48]:
# NOTE: embedding layer and output layer account for majority of params...
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1,block_size)), device='cuda')

Layer (type:depth-idx)                        Output Shape              Param #
TransformerLM                                 [1, 1, 28996]             --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [1, 256, 384]             11,134,464
│    └─Embedding: 2-2                         [256, 384]                98,304
│    └─ModuleList: 2-3                        --                        --
│    │    └─EncoderBlock: 3-1                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-2                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-3                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-4                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-5                 [1, 256, 384]             1,774,464
│    │    └─EncoderBlock: 3-6                 [1, 256, 384]             1,774,464
│    └─Linear: 2-4                       

### Training

In [49]:
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

In [50]:
# optimizer
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

num decayed parameter tensors: 27, with 32,984,064 parameters
num non-decayed parameter tensors: 49, with 58,948 parameters
using fused AdamW: True


In [57]:
# FIXME: getting an error with compile=True; something about dynamo?
train_kwargs = dict(
    model=model, get_batch=get_batch, batch_size=batch_size, max_iters=max_iters,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr, eval_model=eval_model,
    compile=False, grad_clip=0, gradient_accumulation_steps=1,
    eval_main_metric='val/loss', eval_interval=eval_interval, always_save_checkpoint=always_save_checkpoint, out_dir=out_dir,
    log_interval=10, wandb_log=True, wandb_init_kwargs=dict(project=wandb_project, name='TransformerLM'), 
    ckpt_dict=dict(model_args=model_args), track_mfu=True,
    master_process=True, ddp=False, device_type='cuda')

In [58]:
train_model(**train_kwargs)



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112290356929103, max=1.0…

starting training loop...
step 0: train loss 10.4830, val loss 10.4951
iter 0: loss 10.5046, time 8422.64ms, mfu -100.00%
iter 10: loss 7.3179, time 46.22ms, mfu 23.26%
iter 20: loss 5.4109, time 45.99ms, mfu 23.28%
iter 30: loss 3.7952, time 45.92ms, mfu 23.29%
iter 40: loss 3.1675, time 46.45ms, mfu 23.28%
iter 50: loss 2.6550, time 46.66ms, mfu 23.25%
iter 60: loss 2.1376, time 45.90ms, mfu 23.27%
iter 70: loss 1.6392, time 46.50ms, mfu 23.25%
iter 80: loss 1.1522, time 46.37ms, mfu 23.25%
iter 90: loss 0.7892, time 47.05ms, mfu 23.21%
iter 100: loss 0.3758, time 46.51ms, mfu 23.20%
iter 110: loss 0.1616, time 46.07ms, mfu 23.21%
iter 120: loss 0.0756, time 46.52ms, mfu 23.20%
iter 130: loss 0.0477, time 46.66ms, mfu 23.19%
iter 140: loss 0.0366, time 46.59ms, mfu 23.18%
iter 150: loss 0.0314, time 46.54ms, mfu 23.17%
iter 160: loss 0.0290, time 46.56ms, mfu 23.16%
iter 170: loss 0.0260, time 45.99ms, mfu 23.18%
iter 180: loss 0.0260, time 45.93ms, mfu 23.21%
iter 190: loss 0.0245, 



VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
iter,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
lr,▁████▇▇▆▆▅▅▄▄▃▃▃▂▂▂▂▂
mfu,▁████████████████████
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,█▁▂▃▄▃▄▄▄▄▄▄▅▅▄▅▅▅▆▆▇

0,1
iter,5000.0
lr,0.0001
mfu,22.70976
train/loss,0.01627
val/loss,9.78307


DONE.


In [68]:
prompt = 'Once upon a time,'

prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0).to(device)
sample_gen = model.generate(prompt_idx, max_new_tokens=100, temperature=1.0, top_k=None)[0]
sample_gen = tokenizer.decode(sample_gen)
print(sample_gen)

[CLS] Once upon a time, [SEP] a small house, there was a cat and a dog. They liked to play all day. One day, they saw a shiny chain on the floor. They both wanted it. The cat had an idea. She rubbed her head on the dog's leg. The dog felt happy and closed his eyes. The cat took the chain and ran away. The dog felt sad and guilty. Later, the cat felt bad. She went back to the dog and gave him the chain. They both played


## Abstract Transformer Model (Symbolic Attention)

In [82]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=384, n_layers=6, n_heads_enc=4, n_heads_abs=2, dff=None,
    symbol_retrieval='sym_attn', symbol_retrieval_kwargs=dict(num_symbols=50, n_heads=4, model_dim=384), # FIXME make names consistent: d_model, model_dim
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=256, bias=True)
model = abstracttransformer_lm = AbstractTransformerLM(**model_args).to(device)

In [83]:
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1,block_size)), device='cuda')

Layer (type:depth-idx)                        Output Shape              Param #
AbstractTransformerLM                         [1, 1, 28996]             --
├─ModuleDict: 1-1                             --                        --
│    └─Embedding: 2-1                         [1, 256, 384]             11,134,464
│    └─Embedding: 2-2                         [256, 384]                98,304
│    └─ModuleList: 2-3                        --                        --
│    │    └─AbstractEncoderBlock: 3-1         [1, 256, 384]             2,404,224
│    │    └─AbstractEncoderBlock: 3-14        --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-3         --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-4         [1, 256, 384]             2,404,224
│    │    └─AbstractEncoderBlock: 3-14        --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-6         --                        (recursive)
│    │    └─AbstractEncoderBlock:

In [84]:
# torchinfo overcounts # of params... something to do with symbolic attention shared across layers
# this is the correct number (similar to TransformerLM)

num_params = model.get_num_params() #sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of params {num_params:,}')

# of params 35,792,068


In [85]:
# TODO: can we implement in a way that torchinfo can understand? i.e., without "recursive" and overcounting

### Training

In [86]:
# grad scaler
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

num decayed parameter tensors: 42, with 35,824,128 parameters
num non-decayed parameter tensors: 62, with 66,244 parameters
using fused AdamW: True


In [87]:
train_kwargs = dict(
    model=model, get_batch=get_batch, batch_size=batch_size, max_iters=max_iters,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr, eval_model=eval_model,
    compile=True, grad_clip=0, gradient_accumulation_steps=1,
    eval_main_metric='val/loss', eval_interval=eval_interval, always_save_checkpoint=always_save_checkpoint, out_dir=out_dir,
    log_interval=10, wandb_log=True, wandb_init_kwargs=dict(project=wandb_project, name='AbstractTransformerLM'), 
    ckpt_dict=dict(model_args=model_args), track_mfu=True,
    master_process=True, ddp=False, device_type='cuda')

In [88]:
train_model(**train_kwargs)



VBox(children=(Label(value='0.077 MB of 0.077 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112226033583283, max=1.0…

compiling model... done compiling.
starting training loop...
step 0: train loss 10.7940, val loss 10.8234
iter 0: loss 10.8110, time 34751.53ms, mfu -100.00%
iter 10: loss 7.0523, time 67.64ms, mfu 17.50%
iter 20: loss 5.4574, time 66.70ms, mfu 17.52%
iter 30: loss 3.8451, time 67.47ms, mfu 17.52%
iter 40: loss 3.1769, time 67.91ms, mfu 17.51%
iter 50: loss 2.6737, time 66.91ms, mfu 17.53%
iter 60: loss 2.1552, time 67.41ms, mfu 17.53%
iter 70: loss 1.6518, time 67.89ms, mfu 17.52%
iter 80: loss 1.1627, time 66.89ms, mfu 17.54%
iter 90: loss 0.7457, time 66.89ms, mfu 17.56%
iter 100: loss 0.3881, time 68.36ms, mfu 17.53%
iter 110: loss 0.1610, time 67.62ms, mfu 17.53%
iter 120: loss 0.0763, time 67.20ms, mfu 17.54%
iter 130: loss 0.0499, time 67.42ms, mfu 17.54%
iter 140: loss 0.0385, time 67.94ms, mfu 17.53%
iter 150: loss 0.0316, time 67.89ms, mfu 17.52%
iter 160: loss 0.0295, time 67.20ms, mfu 17.53%
iter 170: loss 0.0265, time 67.26ms, mfu 17.53%
iter 180: loss 0.0264, time 67.45ms



VBox(children=(Label(value='0.006 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.327321272885789, max=1.0)…

0,1
iter,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
lr,▁████▇▇▆▆▅▅▄▄▃▃▃▂▂▂▂▂
mfu,▁████████████████████
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/loss,█▁▂▃▃▄▃▄▄▄▄▄▅▅▅▅▆▆▆▇▇

0,1
iter,5000.0
lr,0.0001
mfu,17.12092
train/loss,0.01627
val/loss,9.8164


DONE.


In [91]:
prompt = 'Once upon a time,'

prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0).to(device)
sample_gen = model.generate(prompt_idx, max_new_tokens=250, temperature=1.0, top_k=None)[0]
sample_gen = tokenizer.decode(sample_gen)
print(sample_gen)

[CLS] Once upon a time, [SEP] a big forest, there was a tiny mushroom. It was all alone. The sun was very harsh, and the mushroom did not like it. It wanted to find a friend to play with and to help it hide from the sun. One day, a little bunny came hopping by. The mushroom called out, " Hello, bunny! Will you be my friend? " The bunny looked at the mushroom and smiled. " Sure, I will be your friend. Let's play together! " The bunny and the mushroom played all day, and they were very happy. As they played, the bunny realized that the mushroom needed help to hide from the harsh sun. So, the bunny dug a hole in the ground and put the mushroom inside. Now, the mushroom was safe and cool. The mushroom and the bunny were the best of friends, and they played in the forest every day. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

## Abstract Transformer Model (Relational Symbolic Attention)

In [92]:
model_args = dict(
    vocab_size=tokenizer.vocab_size, d_model=384, n_layers=6, n_heads_enc=4, n_heads_abs=2, dff=None,
    symbol_retrieval='rel_sym_attn', symbol_retrieval_kwargs=dict(
        model_dim=384, rel_n_heads=4, symbolic_attn_n_heads=4,
        num_symbols=20, nbhd_delta=5, causal_nbhd=True, include_self=False,
        normalize_rels=True), # FIXME make names consistent: d_model, model_dim
    dropout_rate=0.2, activation='relu', norm_first=True, max_block_size=256, bias=True)
model = abstracttransformer_lm = AbstractTransformerLM(**model_args).to(device)

In [None]:
torchinfo.summary(model, input_data=torch.randint(0, 10, size=(1,256)), device='cuda')

Layer (type:depth-idx)                                  Output Shape              Param #
AbstractTransformerLM                                   [1, 1, 65]                --
├─ModuleDict: 1-1                                       --                        --
│    └─Embedding: 2-1                                   [1, 256, 384]             24,960
│    └─Embedding: 2-2                                   [256, 384]                98,304
│    └─ModuleList: 2-3                                  --                        --
│    │    └─AbstractEncoderBlock: 3-1                   [1, 256, 384]             2,684,928
│    │    └─AbstractEncoderBlock: 3-14                  --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-3                   --                        (recursive)
│    │    └─AbstractEncoderBlock: 3-4                   [1, 256, 384]             2,684,928
│    │    └─AbstractEncoderBlock: 3-14                  --                        (recursive)
│    │    └

In [None]:
# torchinfo overcounts # of params... something to do with symbolic attention shared across layers
# this is the correct number (similar to TransformerLM)

num_params = model.get_num_params() #sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of params {num_params:,}')

# of params 13,824,833


In [None]:
# TODO: can we implement in a way that torchinfo can understand? i.e., without "recursive" and overcounting

### Training

In [None]:
# grad scaler
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2), device_type=device)

num decayed parameter tensors: 45, with 13,884,672 parameters
num non-decayed parameter tensors: 65, with 38,465 parameters
using fused AdamW: True


In [None]:
train_kwargs = dict(
    model=model, get_batch=get_batch, batch_size=batch_size, max_iters=max_iters,
    optimizer=optimizer, scaler=scaler, get_lr=get_lr, eval_model=eval_model,
    compile=True, grad_clip=0, gradient_accumulation_steps=1,
    eval_main_metric='val/loss', eval_interval=eval_interval, always_save_checkpoint=always_save_checkpoint, out_dir=out_dir,
    log_interval=10, wandb_log=True, wandb_init_kwargs=dict(project=wandb_project, name='AbstractTransformerLM (Relational Symbolic Attn)'), 
    ckpt_dict=dict(model_args=model_args), track_mfu=True,
    master_process=True, ddp=False, device_type='cuda')

In [None]:
train_model(**train_kwargs)

compiling model... done compiling.
starting training loop...
step 0: train loss 4.5255, val loss 4.5268
iter 0: loss 4.5382, time 62455.47ms, mfu -100.00%


In [None]:
prompt = 'Once upon a time,'

prompt_idx = torch.from_numpy(np.array(tokenizer.encode(prompt))).unsqueeze(0).to(device)
sample_gen = model.generate(prompt_idx, max_new_tokens=100, temperature=1.0, top_k=None)[0]
sample_gen = tokenizer.decode(sample_gen)
print(sample_gen)

Romeo,--

KING HENRY VI:
He was for him to be found that Richard now.

DERBY:
My lord, I proud away not: ount take no letters
Twixt thy gracious lord begin. Who's the matter,
With some gone aloof?

BUCKINGHAM:
Say brought your not, sir?

TRANCIO:
What scorn you do very good note:
My lord, I am content my duteous bildhs away.

BAGOT:
I have subjects beats but next thy father,
Apparender thee: Take honour in our senate.

PAULINA:
This way be thy nose physician, if he be
To be a bad-: Tradam? his any conduction,
His rason was not a parturer,
Or hang denied at the king! Sir Jethop's off,
Should be halp there, says Tybalt, Queen Nedalus,
Shore the king.

QUEEN MARGARET:
Oh shall forge you ashame to Richmond!

KING HENRY VI:
For welcome, welcomemes thee from the key-roof
Lord MarsS ay, and fly will not this hoofest
Rest, and let he go forth his wither's groans?

QUEEN MARGARET:
True executions of men:
Ah, with all horses of conceit so evide.

KING HENRY VI:
Defusions, worship thee, and what 