In [1]:
import random
import string
from tqdm.notebook import tqdm  # For progress bars
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
import numpy as np
import pandas as pd
from datetime import datetime

import os

In [2]:
num_blocks = 50
N = 5 #text string in sequence
L = 2 #query string
N_options=[5, 8]
num_sequences_train = 1000000
num_sequences_test = 1000
sequence_len = (num_blocks *N) + (num_blocks-1)
def contains_letters(main_str, target_str, case_sensitive=False):
    if not case_sensitive:
        main_str = main_str.lower()
        target_str = target_str.lower()
    
    main_counts = Counter(main_str)
    target_counts = Counter(target_str)
    
    return all(main_counts[k] >= target_counts[k] for k in target_counts)

def generate_sequence(num_blocks,N, N_options, L):
    # Randomly choose block lengths (N) and generate blocks
    if N is None:
        N = random.choice(N_options)
    blocks = []
    target_block_idx = random.randint(0, num_blocks - 1)
    target_block = ''.join(random.choices(string.ascii_lowercase, k=N))

    while True:
        target_block = ''.join(random.choices(string.ascii_lowercase, k=N))
        unique_block_letters = list(set(target_block))
        if len(unique_block_letters) >3:  # Ensure we get at least 3 unique letters
            break
    # Generate 2 unique query letters that MUST appear in the target block
    while True:
        query_letters = random.sample(target_block, min(L, len(target_block)))
        # print(target_block)
        unique_letters = list(set(query_letters))
        if len(unique_letters) == L:  # Ensure we get exactly 2 letters
            break

    for i,_ in enumerate(range(num_blocks)):
        if i == target_block_idx:
            blocks.append(target_block)
        else:
            while True:
                block = ''.join(random.choices(string.ascii_lowercase, k=N))
                if not contains_letters(block, ''.join(query_letters)): 
                    break
            blocks.append(block)


    
    # Join blocks into sequence and append query
    sequence = '.'.join(blocks) + '#' + ''.join(query_letters)
    return sequence, target_block_idx





def generate_dataset(num_sequences, filepath, force=False):
    if os.path.exists(filepath) and not force:
        print(f"{filepath} already exists. Skipping generation.")
        return
    
    print(f"Generating {num_sequences} sequences to {filepath}...")
    with open(filepath, 'w') as f:
        for _ in tqdm(range(num_sequences), desc=f"Generating {filepath}"):
            seq, target = generate_sequence(num_blocks,N, N_options, L)
            f.write(f"{seq}|{target}\n")




# Only generate files if they don't exist
path_folder = "/kaggle/input/mta-toy-task-dataset/"
generate_dataset(num_sequences_train, f'{path_folder}mta_train.txt')
generate_dataset(num_sequences_test, f'{path_folder}mta_test.txt')
# Verify a sample
if not os.path.exists(f'{path_folder}mta_train.txt'):
    sample_seq, target_idx = generate_sequence()
    print(f"Sample Sequence:\n{sample_seq}")
    print(f"Target Block: {sample_seq.split('.')[target_idx]}")
    print(f"Query Letters: {sample_seq.split('#')[-1]}")
else:
    print("Dataset already exists. Loading a sample...")
    with open(f'{path_folder}mta_train.txt') as f:
        first_line = f.readline().strip()
        seq, target = first_line.split('|')
        print(f"Sample Sequence:\n{seq}")
        print(f"Target Block: {seq.split('.')[int(target)]}")
        print(f"Query Letters: {seq.split('#')[-1]}")



/kaggle/input/mta-toy-task-dataset/mta_train.txt already exists. Skipping generation.
/kaggle/input/mta-toy-task-dataset/mta_test.txt already exists. Skipping generation.
Dataset already exists. Loading a sample...
Sample Sequence:
olbud.xdyxs.cudki.lhwzg.pwhlo.rljxi.ozgfm.ijhhx.gowsd.gamrs.tstxh.gjxcx.qkwht.awtzm.jhctg.aqyxx.bxpyl.pxdks.ryciw.rarbd.lghyj.gdbcm.aeyen.fxdcl.bbhpj.gqnxc.tuhgc.bnieb.krypx.bmiev.yrphk.damen.yqifs.wmpdo.fhkun.dslla.ncfie.buysa.dddbq.srvth.gqugr.owgip.hieqz.zydvi.cfwde.ogbrg.zyerl.ynrca.mdwsf.qphya#tm
Target Block: awtzm
Query Letters: tm


In [3]:
blocks_part, query_letters = seq.split('#')
blocks = blocks_part.split('.')

blocks_part

'olbud.xdyxs.cudki.lhwzg.pwhlo.rljxi.ozgfm.ijhhx.gowsd.gamrs.tstxh.gjxcx.qkwht.awtzm.jhctg.aqyxx.bxpyl.pxdks.ryciw.rarbd.lghyj.gdbcm.aeyen.fxdcl.bbhpj.gqnxc.tuhgc.bnieb.krypx.bmiev.yrphk.damen.yqifs.wmpdo.fhkun.dslla.ncfie.buysa.dddbq.srvth.gqugr.owgip.hieqz.zydvi.cfwde.ogbrg.zyerl.ynrca.mdwsf.qphya'

In [4]:
import string

# use only lowercase letters as the character set
chars = list(string.ascii_lowercase)
chars.append(".")
vocab_size = len(chars)

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

encode = lambda s: [stoi[c] for c in s if c in stoi]  # ignore characters not in vocab
decode = lambda l: ''.join([itos[i] for i in l])

vocab_size


27

In [5]:
torch.tensor(encode(blocks_part))

tensor([14, 11,  1, 20,  3, 26, 23,  3, 24, 23, 18, 26,  2, 20,  3, 10,  8, 26,
        11,  7, 22, 25,  6, 26, 15, 22,  7, 11, 14, 26, 17, 11,  9, 23,  8, 26,
        14, 25,  6,  5, 12, 26,  8,  9,  7,  7, 23, 26,  6, 14, 22, 18,  3, 26,
         6,  0, 12, 17, 18, 26, 19, 18, 19, 23,  7, 26,  6,  9, 23,  2, 23, 26,
        16, 10, 22,  7, 19, 26,  0, 22, 19, 25, 12, 26,  9,  7,  2, 19,  6, 26,
         0, 16, 24, 23, 23, 26,  1, 23, 15, 24, 11, 26, 15, 23,  3, 10, 18, 26,
        17, 24,  2,  8, 22, 26, 17,  0, 17,  1,  3, 26, 11,  6,  7, 24,  9, 26,
         6,  3,  1,  2, 12, 26,  0,  4, 24,  4, 13, 26,  5, 23,  3,  2, 11, 26,
         1,  1,  7, 15,  9, 26,  6, 16, 13, 23,  2, 26, 19, 20,  7,  6,  2, 26,
         1, 13,  8,  4,  1, 26, 10, 17, 24, 15, 23, 26,  1, 12,  8,  4, 21, 26,
        24, 17, 15,  7, 10, 26,  3,  0, 12,  4, 13, 26, 24, 16,  8,  5, 18, 26,
        22, 12, 15,  3, 14, 26,  5,  7, 10, 20, 13, 26,  3, 18, 11, 11,  0, 26,
        13,  2,  5,  8,  4, 26,  1, 20, 

In [6]:
torch.tensor(encode(query_letters))

tensor([19, 12])

In [7]:
torch.tensor(encode(blocks[int(target)]))[1]

tensor(22)

In [8]:
class MTADataset(Dataset):
    def __init__(self, file_path, variant = "all" ):#, tokenizer, max_length=256):
        # self.tokenizer = tokenizer
        # self.max_length = max_length
        self.variant = variant
        with open(file_path, 'r') as f:
            self.sequences = [line.strip() for line in f.readlines()]
    
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        sequence, target_block_idx = self.sequences[idx].split('|')
        blocks_part, query_letters = sequence.split('#')
        blocks = blocks_part.split('.')
        
        # Find target block (contains all query letters)
        target_block = torch.tensor(encode(blocks[int(target_block_idx)]))
        if self.variant == "all":
            target = target_block
        elif self.variant == "first":
            target = target_block[0]
        elif self.variant == "last":
            target = target_block[-1]
        elif self.variant == "position":
            target = torch.tensor(int(target_block_idx))

        else:
            raise f"{self.variant} is not recognized as a supported variant type"
            
        
        # # Tokenize blocks and query SEPARATELY
        blocks_tokens = torch.tensor(encode(blocks_part))
        query_tokens = torch.tensor(encode(query_letters))

        
        return {
            'blocks': blocks_tokens,
            'query': query_tokens,
            'target': target
        }






In [9]:
# i = 0
# for batch in train_dataloader:
#     blocks = batch['blocks']
#     queries = batch['query']
#     targets = batch['target']
#     print(f"blocks: {blocks.shape}")
#     print(f"queries: {queries.shape}")
#     print(f"targets: {targets.shape}")
#     i+=1
#     if i == 5:
#         break
    


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler


class MultiTokenAttention(nn.Module):
    def __init__(self,n_heads, head_size):
        self.n_heads = n_heads
        self.head_size = head_size
        super().__init__()
        # Projections for blocks (keys/values) and queries (letters)
        self.block_proj = nn.Linear(n_embd, n_embd * 2)  # K, V
        self.query_proj = nn.Linear(n_embd, n_embd)      # Q
        
        # Key-Query Convolution
        self.conv2d = nn.Conv2d(
            in_channels=n_heads,
            out_channels=n_heads,
            kernel_size=(c_q, c_k),
            padding= "same",#((c_q-1)//2, (c_k-1)//2),
            # groups=n_heads
        )
        
        # Head Mixing
        self.head_conv = nn.Conv2d(
            in_channels=n_heads,
            out_channels=n_heads,
            padding = "same",# ((c_h-1)//2, (c_h-1)//2),
            kernel_size=c_h,
            groups=n_heads // c_h
        )
        
        # Output
        self.group_norm = nn.GroupNorm(n_heads, n_embd)
        self.gate = nn.Linear(n_embd, n_heads)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(n_embd, n_embd)
        # self.register_buffer('tril', torch.tril(torch.ones(L, block_size)))


    def forward(self, blocks, queries):
        """
        Inputs:
            blocks:  (batch, block_len, n_embd)  # All blocks concatenated
            queries: (batch, query_len, n_embd)  # Query letters
        """
        # batch_size = blocks.size(0)
        B,T,C = blocks.shape
        _,Q_len,_ = queries.shape
        
        # 1. Project blocks to K/V and queries to Q
        K, V = self.block_proj(blocks).chunk(2, dim=-1)  # Each (batch, block_len, n_embd)
        Q = self.query_proj(queries)                     # (batch, query_len, n_embd)

        
        # 2. Split into heads
        Q = Q.view(B, -1, self.n_heads, self.head_size).transpose(1, 2)  # [B, h, q_len, d]
        K = K.view(B, -1, self.n_heads, self.head_size).transpose(1, 2)  # [B, h, b_len, d]
        V = V.view(B, -1, self.n_heads, self.head_size).transpose(1, 2)  # [B, h, b_len, d]

        
        # 3. Attention scores between queries and blocks
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) *(self.head_size **-0.5)  # [B, h, q_len, b_len]
        # Ensure q_len >= c_q (e.g., pad queries if needed)


        # 4. Key-Query Convolution (mixes attention across nearby queries/blocks)
        attn_scores = self.conv2d(attn_scores)  # [B, h, q_len, b_len]


        
        # attn_scores = attn_scores.masked_fill(self.tril[:Q_len, :T] == 0, float('-inf'))    
        # attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))        


        
        # 5. Softmax + Head Mixing
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.head_conv(attn_weights)  # [B, h, q_len, b_len]

        
        # 6. Weighted sum of block values
        output = torch.matmul(attn_weights, V)  # [B, h, q_len, d]
        output = output.transpose(1, 2).reshape(B, -1, self.n_heads * self.head_size)
        # output = self.head_conv(output)  # [B, q_len, C]

        
        # 7. GroupNorm + Gating
        output = output.transpose(1, 2)  # [B, C, seq_len]
        output = self.group_norm(output)
        output = output.transpose(1, 2)  # Back to [B, seq_len, C]
        gates = torch.sigmoid(self.gate(output))  # [B, q_len, h]
        gates = gates.unsqueeze(-1).expand(-1, -1, -1, self.head_size).reshape_as(output)
        output = output * gates
        output = self.dropout(self.proj(output))
        
        return output  # (batch, query_len, n_embd)

In [11]:


# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = sequence_len#32 # what is the maximum context length for predictions?
max_iters = 20000
eval_interval = 100
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 256#//4
n_head = 2
n_layer = 4
dropout = 0.1
N = 5
c_q=2
c_k=2*N -1
c_h=2

# ------------
SEED = 1337
torch.manual_seed(SEED)
# For maximum reproducibility
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(False) 
# 1. Set environment variables for CuBLAS determinism
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
os.environ['PYTHONHASHSEED'] = str(SEED)

# 2. Set Python, NumPy and PyTorch seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 3. Configure PyTorch for deterministic operations
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Add these near your hyperparameters
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_interval = 1000  # Save every 1000 iterations
best_test_loss = float('inf')

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(SEED)

# Add this function to save checkpoints
def save_checkpoint(iteration, model, optimizer, scheduler, best=False):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint = {
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'loss': best_test_loss,
        'timestamp': timestamp
    }
    
    filename = f"checkpoint_{iteration}.pt" if not best else "best_model.pt"
    torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
    print(f"Saved checkpoint at iteration {iteration}")

# Add this function to load checkpoints
def load_checkpoint(path, model, optimizer, scheduler):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler and checkpoint['scheduler_state_dict']:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['iteration']
train_dataset = MTADataset(f"{path_folder}mta_train.txt", variant="all")#, tokenizer)
# Modified DataLoader with deterministic shuffling
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=4,
    persistent_workers=True
)
test_dataset = MTADataset(f"{path_folder}mta_test.txt", variant="all")#, tokenizer)
# Modified DataLoader with deterministic shuffling
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=4,
    persistent_workers=True
)
  
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'test']:
        losses = torch.zeros(eval_iters)
        if split == 'test':
            dataloader = test_dataloader
        else:
            dataloader = train_dataloader
        dataloader_iter = iter(dataloader)
        for k in range(eval_iters):
            try:
                batch = next(dataloader_iter)
                # process the batch
            except StopIteration:
                # Reinitialize the iterator if you reach the end
                dataloader_iter = iter(dataloader)
                batch = next(dataloader_iter)
            blocks = batch['blocks'].to(device)
            queries = batch['query'].to(device)
            targets = batch['target'].to(device)
            logits, loss = model(blocks, queries, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiTokenAttention(n_head, head_size)
        self.ffwd_blocks = FeedFoward(n_embd)
        self.ffwd_queries = FeedFoward(n_embd)

        self.blocks_proj  = nn.Linear(n_embd, n_embd) 
        self.ln_blocks = nn.RMSNorm(n_embd)
        self.ln_queries = nn.RMSNorm(n_embd)

        self.ln2 = nn.RMSNorm(n_embd)

    def forward(self, blocks_pos_emb, queries_emb):
        blocks_pos_emb = self.ln_blocks(blocks_pos_emb)
        queries_emb = self.ln_queries(queries_emb)
        attn_out = self.sa(blocks_pos_emb, queries_emb)
        blocks_pos_emb_attn = self.blocks_proj(attn_out.mean(dim=1))
        blocks_pos_emb = blocks_pos_emb + blocks_pos_emb_attn.unsqueeze(1)
        queries_emb = queries_emb + attn_out
        queries_emb = queries_emb + self.ffwd_queries(self.ln2(queries_emb))
        blocks_pos_emb = blocks_pos_emb + self.ffwd_blocks(self.ln2(blocks_pos_emb))

        return blocks_pos_emb,queries_emb

# super simple bigram model
class MTAAttentionRetrievalModel(nn.Module):

    def __init__(self, variant= "all"):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(sequence_len, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head) for _ in range(n_layer)]) #nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.RMSNorm(n_embd) # final layer norm
        self.variant = variant

        if self.variant == "position":
            self.lm_head = nn.Linear(L*n_embd, num_blocks)
        elif self.variant == "all":
            self.lm_head = nn.Sequential(
                nn.Linear(n_embd*L, n_embd * 2),
                nn.ReLU(),
                nn.Linear(n_embd * 2, vocab_size * N)
            )
        elif self.variant in ['first', 'last']:
            self.lm_head = nn.Linear(n_embd*L, vocab_size)
        else:
            raise f"{self.variant} is not recognized as a supported variant type"

    def forward(self, blocks, queries, targets=None):
        B, T = blocks.shape

        # idx and targets are both (B,T) tensor of integers
        blocks_emb = self.token_embedding_table(blocks) # (B,T,C)
        queries_emb = self.token_embedding_table(queries) # (B,T,C)

        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        blocks_pos_emb = blocks_emb + pos_emb # (B,T,C)
        for block in self.blocks:
            blocks_pos_emb, queries_emb = block(blocks_pos_emb, queries_emb)
        # x = self.blocks(blocks_pos_emb,queries_emb) # (B,T,C)
        x = self.ln_f(queries_emb) # (B,T,C)
        # print(x.shape)
        logits = self.lm_head(x.view(B,-1)) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, C = logits.shape
            if self.variant == "all":
                # logits = logits.view(-1, vocab_size)
                
                logits = logits.view(B*N, vocab_size)
                targets = targets.view(B*N)
            else:
                # logits = logits.view(B, vocab_size)
                
                logits = logits.view(B, vocab_size)
                targets = targets.view(B)                
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    


model = MTAAttentionRetrievalModel(variant="all")
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,weight_decay=0.01)
# optimizer = torch.optim.Lion(model.parameters(), lr=learning_rate)#,weight_decay=0.01)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.1, 
    patience=1000,  # Number of eval intervals to wait before reducing LR
    verbose=True
)

start_scheduler_at_iter = 10000  
resume_from_checkpoint = False

dataloader = train_dataloader
dataloader_iter = iter(dataloader)
if resume_from_checkpoint:
    start_iter = load_checkpoint("checkpoints/best_model.pt", model, optimizer, scheduler)
else:
    start_iter = 0
for i in tqdm(range(start_iter, max_iters)):

    if i % eval_interval == 0 or i == max_iters - 1:
        losses = estimate_loss()
        print(f"step {i}: train loss {losses['train']:.4f}, test loss {losses['test']:.4f}, lr {optimizer.param_groups[0]['lr']:.2e}")
        # Save best model
        if losses['test'] < best_test_loss:
            best_test_loss = losses['test']
            save_checkpoint(i, model, optimizer, scheduler, best=True)
    
    # Regular checkpoint saving
    if i % checkpoint_interval == 0 and i > 0:
        save_checkpoint(i, model, optimizer, scheduler)
        

    try:
        batch = next(dataloader_iter)
        # process the batch
    except StopIteration:
        # Reinitialize the iterator if you reach the end
        dataloader_iter = iter(dataloader)
        batch = next(dataloader_iter)
    blocks = batch['blocks'].to(device)
    queries = batch['query'].to(device)
    targets = batch['target'].to(device)

    # evaluate the loss
    logits, loss = model(blocks, queries, targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Only step the scheduler after start_scheduler_at_iter
    if i >= start_scheduler_at_iter:
        scheduler.step(loss)



5.943551 M parameters




  0%|          | 0/20000 [00:00<?, ?it/s]

  return F.conv2d(


step 0: train loss 3.3232, test loss 3.3180, lr 3.00e-04
Saved checkpoint at iteration 0
step 100: train loss 2.8587, test loss 2.8616, lr 3.00e-04
Saved checkpoint at iteration 100
step 200: train loss 2.8232, test loss 2.8169, lr 3.00e-04
Saved checkpoint at iteration 200
step 300: train loss 2.8164, test loss 2.8103, lr 3.00e-04
Saved checkpoint at iteration 300
step 400: train loss 2.8074, test loss 2.8060, lr 3.00e-04
Saved checkpoint at iteration 400
step 500: train loss 2.8133, test loss 2.8087, lr 3.00e-04
step 600: train loss 2.8036, test loss 2.7988, lr 3.00e-04
Saved checkpoint at iteration 600
step 700: train loss 2.8001, test loss 2.7957, lr 3.00e-04
Saved checkpoint at iteration 700
step 800: train loss 2.7984, test loss 2.7974, lr 3.00e-04
step 900: train loss 2.8000, test loss 2.7918, lr 3.00e-04
Saved checkpoint at iteration 900
step 1000: train loss 2.7973, test loss 2.7941, lr 3.00e-04
Saved checkpoint at iteration 1000
step 1100: train loss 2.7929, test loss 2.7906,

In [12]:
checkpoint = torch.load("checkpoints/best_model.pt")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

MTAAttentionRetrievalModel(
  (token_embedding_table): Embedding(27, 256)
  (position_embedding_table): Embedding(299, 256)
  (blocks): ModuleList(
    (0-3): 4 x Block(
      (sa): MultiTokenAttention(
        (block_proj): Linear(in_features=256, out_features=512, bias=True)
        (query_proj): Linear(in_features=256, out_features=256, bias=True)
        (conv2d): Conv2d(2, 2, kernel_size=(2, 9), stride=(1, 1), padding=same)
        (head_conv): Conv2d(2, 2, kernel_size=(2, 2), stride=(1, 1), padding=same)
        (group_norm): GroupNorm(2, 256, eps=1e-05, affine=True)
        (gate): Linear(in_features=256, out_features=2, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffwd_blocks): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
       

In [13]:

test_dataset = MTADataset(f"{path_folder}mta_test.txt", variant="all")#, tokenizer)
# Modified DataLoader with deterministic shuffling
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=4,
    persistent_workers=True
)

full_output = np.array([])
full_targets = np.array([])
for batch in tqdm(test_dataloader):
    blocks = batch['blocks'].to(device)
    queries = batch['query'].to(device)
    targets = batch['target'].to(device)
    
    # evaluate the loss
    logits, loss = model(blocks, queries, targets=None)
    output = torch.argmax(logits.view(-1, N, vocab_size), dim=-1)
    
    output = np.array([decode(i.tolist()) for i in output])
    full_output = np.append(full_output, output)
    targets = np.array([decode(i.tolist()) for i in targets])
    full_targets = np.append(full_targets, targets)

  0%|          | 0/16 [00:00<?, ?it/s]

In [14]:
acc = sum((full_output == full_targets).astype(int))/ len(targets)
print(f"The Accuracy of the MTA Model on the test set is {acc:.4f}%")

The Accuracy of the MTA Model on the test set is 17.3250%


In [15]:
#[int(a == b) for a, b in zip(output, targets)]


In [16]:
class StandardMultiAttention(nn.Module):
    def __init__(self,n_heads, head_size):
        self.n_heads = n_heads
        self.head_size = head_size
        super().__init__()
        # Projections for blocks (keys/values) and queries (letters)
        self.block_proj = nn.Linear(n_embd, n_embd * 2)  # K, V
        self.query_proj = nn.Linear(n_embd, n_embd)      # Q
        
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(n_embd, n_embd)
        self.register_buffer('tril', torch.tril(torch.ones(L, block_size)))


    def forward(self, blocks, queries):
        """
        Inputs:
            blocks:  (batch, block_len, n_embd)  # All blocks concatenated
            queries: (batch, query_len, n_embd)  # Query letters
        """
        # batch_size = blocks.size(0)
        B,T,C = blocks.shape
        _,Q_len,_ = queries.shape
        
        # 1. Project blocks to K/V and queries to Q
        K, V = self.block_proj(blocks).chunk(2, dim=-1)  # Each (batch, block_len, n_embd)
        Q = self.query_proj(queries)                     # (batch, query_len, n_embd)

        
        # 2. Split into heads
        Q = Q.view(B, -1, self.n_heads, self.head_size).transpose(1, 2)  # [B, h, q_len, d]
        K = K.view(B, -1, self.n_heads, self.head_size).transpose(1, 2)  # [B, h, b_len, d]
        V = V.view(B, -1, self.n_heads, self.head_size).transpose(1, 2)  # [B, h, b_len, d]

        
        # 3. Attention scores between queries and blocks
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) *(self.head_size **-0.5)  # [B, h, q_len, b_len]

        attn_scores = attn_scores.masked_fill(self.tril[:Q_len, :T] == 0, float('-inf'))    
        # attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))        


        
        # 5. Softmax + Head Mixing
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # 6. Weighted sum of block values
        output = torch.matmul(attn_weights, V)  # [B, h, q_len, d]
        output = output.transpose(1, 2).reshape(B, -1, self.n_heads * self.head_size)
        output = self.dropout(self.proj(output))
        
        return output  # (batch, query_len, n_embd)

In [17]:

# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = sequence_len#32 # what is the maximum context length for predictions?
max_iters = 20000
eval_interval = 100
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 256#//4
n_head = 2
n_layer = 4
dropout = 0.1
N = 5
c_q=2
c_k=2*N -1
c_h=2

# ------------
SEED = 1337
torch.manual_seed(SEED)
# For maximum reproducibility
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(False) 
# 1. Set environment variables for CuBLAS determinism
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
os.environ['PYTHONHASHSEED'] = str(SEED)

# 2. Set Python, NumPy and PyTorch seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 3. Configure PyTorch for deterministic operations
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Add these near your hyperparameters
checkpoint_dir = "./checkpoints_stan"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_interval = 1000  # Save every 1000 iterations
best_test_loss = float('inf')

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(SEED)

# Add this function to save checkpoints
def save_checkpoint(iteration, model, optimizer, scheduler, best=False):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint = {
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'loss': best_test_loss,
        'timestamp': timestamp
    }
    
    filename = f"checkpoint_{iteration}.pt" if not best else "best_model.pt"
    torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
    print(f"Saved checkpoint at iteration {iteration}")

# Add this function to load checkpoints
def load_checkpoint(path, model, optimizer, scheduler):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler and checkpoint['scheduler_state_dict']:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['iteration']
train_dataset = MTADataset(f"{path_folder}mta_train.txt", variant="all")#, tokenizer)
# Modified DataLoader with deterministic shuffling
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=4,
    persistent_workers=True
)
test_dataset = MTADataset(f"{path_folder}mta_test.txt", variant="all")#, tokenizer)
# Modified DataLoader with deterministic shuffling
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=4,
    persistent_workers=True
)
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'test']:
        losses = torch.zeros(eval_iters)
        if split == 'test':
            dataloader = test_dataloader
        else:
            dataloader = train_dataloader
        dataloader_iter = iter(dataloader)
        for k in range(eval_iters):
            try:
                batch = next(dataloader_iter)
                # process the batch
            except StopIteration:
                # Reinitialize the iterator if you reach the end
                dataloader_iter = iter(dataloader)
                batch = next(dataloader_iter)
            blocks = batch['blocks'].to(device)
            queries = batch['query'].to(device)
            targets = batch['target'].to(device)
            logits, loss = model(blocks, queries, targets)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = StandardMultiAttention(n_head, head_size)
        self.ffwd_blocks = FeedFoward(n_embd)
        self.ffwd_queries = FeedFoward(n_embd)

        self.blocks_proj  = nn.Linear(n_embd, n_embd) 
        self.ln_blocks = nn.LayerNorm(n_embd)
        self.ln_queries = nn.LayerNorm(n_embd)

        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, blocks_pos_emb, queries_emb):
        blocks_pos_emb = self.ln_blocks(blocks_pos_emb)
        queries_emb = self.ln_queries(queries_emb)
        attn_out = self.sa(blocks_pos_emb, queries_emb)
        blocks_pos_emb_attn = self.blocks_proj(attn_out.mean(dim=1))
        blocks_pos_emb = blocks_pos_emb + blocks_pos_emb_attn.unsqueeze(1)
        queries_emb = queries_emb + attn_out
        queries_emb = queries_emb + self.ffwd_queries(self.ln2(queries_emb))
        blocks_pos_emb = blocks_pos_emb + self.ffwd_blocks(self.ln2(blocks_pos_emb))

        return blocks_pos_emb,queries_emb

# super simple bigram model
class StandardAttentionRetrievalModel(nn.Module):

    def __init__(self,variant="all"):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head) for _ in range(n_layer)]) #nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.variant = variant

        if self.variant == "position":
            self.lm_head = nn.Linear(L*n_embd, num_blocks)
        elif self.variant == "all":
            self.lm_head = nn.Sequential(
                nn.Linear(n_embd*L, n_embd * 2),
                nn.ReLU(),
                nn.Linear(n_embd * 2, vocab_size * N)
            )
        elif self.variant in ['first', 'last']:
            self.lm_head = nn.Linear(n_embd*L, vocab_size)
        else:
            raise f"{self.variant} is not recognized as a supported variant type"

    def forward(self, blocks, queries, targets=None):
        B, T = blocks.shape

        # idx and targets are both (B,T) tensor of integers
        blocks_emb = self.token_embedding_table(blocks) # (B,T,C)
        queries_emb = self.token_embedding_table(queries) # (B,T,C)

        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        blocks_pos_emb = blocks_emb + pos_emb # (B,T,C)
        for block in self.blocks:
            blocks_pos_emb, queries_emb = block(blocks_pos_emb, queries_emb)
        # x = self.blocks(blocks_pos_emb,queries_emb) # (B,T,C)
        x = self.ln_f(queries_emb) # (B,T,C)
        # print(x.shape)
        logits = self.lm_head(x.view(B,-1)) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, C = logits.shape
            if self.variant == "all":
                # logits = logits.view(-1, vocab_size)
                
                logits = logits.view(B*N, vocab_size)
                targets = targets.view(B*N)
            else:
                # logits = logits.view(B, vocab_size)
                
                logits = logits.view(B, vocab_size)
                targets = targets.view(B)                
            loss = F.cross_entropy(logits, targets)

        return logits, loss

model = StandardAttentionRetrievalModel(variant="all")
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.1, 
    patience=1000,  # Number of eval intervals to wait before reducing LR
    verbose=True
)
# scheduler = lr_scheduler.LambdaLR
start_scheduler_at_iter = 10000  
resume_from_checkpoint = False

dataloader = train_dataloader
dataloader_iter = iter(dataloader)
if resume_from_checkpoint:
    start_iter = load_checkpoint("checkpoints_stan/best_model.pt", model, optimizer, scheduler)
else:
    start_iter = 0
for i in tqdm(range(start_iter, max_iters)):

    if i % eval_interval == 0 or i == max_iters - 1:
        losses = estimate_loss()
        print(f"step {i}: train loss {losses['train']:.4f}, test loss {losses['test']:.4f}, lr {optimizer.param_groups[0]['lr']:.2e}")
        # Save best model
        if losses['test'] < best_test_loss:
            best_test_loss = losses['test']
            save_checkpoint(i, model, optimizer, scheduler, best=True)
    
    # Regular checkpoint saving
    if i % checkpoint_interval == 0 and i > 0:
        save_checkpoint(i, model, optimizer, scheduler)
        

    try:
        batch = next(dataloader_iter)
        # process the batch
    except StopIteration:
        # Reinitialize the iterator if you reach the end
        dataloader_iter = iter(dataloader)
        batch = next(dataloader_iter)
    blocks = batch['blocks'].to(device)
    queries = batch['query'].to(device)
    targets = batch['target'].to(device)

    # evaluate the loss
    logits, loss = model(blocks, queries, targets)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # Only step the scheduler after start_scheduler_at_iter
    if i >= start_scheduler_at_iter:
        scheduler.step(loss)

5.942407 M parameters


  0%|          | 0/20000 [00:00<?, ?it/s]

step 0: train loss 3.3188, test loss 3.3167, lr 3.00e-04
Saved checkpoint at iteration 0
step 100: train loss 2.8819, test loss 2.8825, lr 3.00e-04
Saved checkpoint at iteration 100
step 200: train loss 2.8322, test loss 2.8283, lr 3.00e-04
Saved checkpoint at iteration 200
step 300: train loss 2.8184, test loss 2.8155, lr 3.00e-04
Saved checkpoint at iteration 300
step 400: train loss 2.8121, test loss 2.8116, lr 3.00e-04
Saved checkpoint at iteration 400
step 500: train loss 2.8157, test loss 2.8120, lr 3.00e-04
step 600: train loss 2.8075, test loss 2.8027, lr 3.00e-04
Saved checkpoint at iteration 600
step 700: train loss 2.8020, test loss 2.7945, lr 3.00e-04
Saved checkpoint at iteration 700
step 800: train loss 2.7998, test loss 2.7955, lr 3.00e-04
step 900: train loss 2.8009, test loss 2.7962, lr 3.00e-04
step 1000: train loss 2.8000, test loss 2.7950, lr 3.00e-04
Saved checkpoint at iteration 1000
step 1100: train loss 2.7928, test loss 2.7904, lr 3.00e-04
Saved checkpoint at i

In [18]:
checkpoint = torch.load("checkpoints_stan/best_model.pt")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

StandardAttentionRetrievalModel(
  (token_embedding_table): Embedding(27, 256)
  (position_embedding_table): Embedding(299, 256)
  (blocks): ModuleList(
    (0-3): 4 x Block(
      (sa): StandardMultiAttention(
        (block_proj): Linear(in_features=256, out_features=512, bias=True)
        (query_proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (ffwd_blocks): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ffwd_queries): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1024, out_features=256, bias=True)
         

In [19]:

test_dataset = MTADataset(f"{path_folder}mta_test.txt", variant="all")#, tokenizer)
# Modified DataLoader with deterministic shuffling
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    worker_init_fn=seed_worker,
    generator=g,
    num_workers=4,
    persistent_workers=True
)

full_output = np.array([])
full_targets = np.array([])
for batch in tqdm(test_dataloader):
    blocks = batch['blocks'].to(device)
    queries = batch['query'].to(device)
    targets = batch['target'].to(device)
    
    # evaluate the loss
    logits, loss = model(blocks, queries, targets=None)
    output = torch.argmax(logits.view(-1, N, vocab_size), dim=-1)
    
    output = np.array([decode(i.tolist()) for i in output])
    full_output = np.append(full_output, output)
    targets = np.array([decode(i.tolist()) for i in targets])
    full_targets = np.append(full_targets, targets)

  0%|          | 0/16 [00:00<?, ?it/s]

In [1]:
acc = sum((full_output == full_targets).astype(int))/ len(targets)
print(f"The Accuracy of the Standard Attention Model on the test set is {acc:.4f}%")

The Accuracy of the Standard Attention Model on the test set is 0.0000%
