In [1]:
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import tiktoken

#special token ids
enc = tiktoken.get_encoding('cl100k_base')
SOS_TOKEN_ID = enc.encode("<|startoftext|>", allowed_special="all")[0]
EOS_TOKEN_ID = enc.encode("<|endoftext|>", allowed_special="all")[0]
PAD_TOKEN_ID = 0

# HP
batch_size = 32
block_size = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = 100_000
n_embed = 256
n_layers = 4
learning_rate = 1e-3
max_iters = 10000
eval_interval = 100
n_heads = 4


In [2]:
df = pd.read_csv("/content/sample_data/financial_narration_data.csv")

# functions prep pairs each way,  process df into tokenised id tuples

#json to str
json_to_str_dataset = []
for _, row in df.iterrows():
  encoder_input_ids = enc.encode(row['prompt'])
  decoder_target_ids = enc.encode(row['target']) + [EOS_TOKEN_ID]
  decoder_input_ids = [SOS_TOKEN_ID] + decoder_target_ids[:-1]
  json_to_str_dataset.append((encoder_input_ids, decoder_input_ids, decoder_target_ids))

#str to json
str_to_json_dataset = []
for _, row in df.iterrows():
  encoder_input_ids = enc.encode(row['target'])
  decoder_target_ids = enc.encode(row['prompt']) + [EOS_TOKEN_ID]
  decoder_input_ids = [SOS_TOKEN_ID] + decoder_target_ids[:-1]
  str_to_json_dataset.append((encoder_input_ids, decoder_input_ids, decoder_target_ids))

#split
n_j2s = int(0.9 * len(json_to_str_dataset))
train_json_to_str = json_to_str_dataset[:n_j2s]
val_json_to_str = json_to_str_dataset[n_j2s:]

n_s2j = int(0.9 * len(str_to_json_dataset))
train_str_to_json = str_to_json_dataset[:n_s2j]
val_str_to_json = str_to_json_dataset[n_s2j:]

In [3]:
# general attention head
class AttentionHead(nn.Module):
  def __init__(self, n_embed, head_size):
    super().__init__()
    self.key = nn.Linear(n_embed, head_size, bias=False)
    self.query = nn.Linear(n_embed, head_size, bias=False)
    self.value = nn.Linear(n_embed, head_size, bias=False)
    self.dropout = nn.Dropout(0.1)
    self.head_size = head_size

  # takes separate q_input, k_input and v_input (correct for a generic head)
  def forward(self, query_input, key_input, value_input, mask=None):
    q = self.query(query_input)
    k = self.key(key_input)
    v = self.value(value_input)

    wei = q @ k.transpose(-2, -1) * (self.head_size**-0.5)

    if mask is not None:
      wei = wei.masked_fill(mask == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)
    out = wei @ v

    return out

# multi-head attention, for selfattention with a causal flag
class MultiHeadAttention(nn.Module):
  def __init__(self, n_embed, n_heads, block_size, is_causal):
    super().__init__()
    self.n_heads = n_heads
    self.head_size = n_embed // n_heads

    self.heads = nn.ModuleList([
        AttentionHead(n_embed, self.head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * self.head_size, n_embed)
    self.dropout = nn.Dropout(0.1)
    self.is_causal = is_causal

    if self.is_causal:
      self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x, mask=None):
    B, T, C = x.shape
    local_mask = mask

    if self.is_causal:
      casual_mask = self.tril[:T, :T].to(x.device)
      if local_mask is None:
        local_mask = casual_mask
      else:
        if local_mask.ndim == 3:
            local_mask = local_mask.unsqueeze(1) #reshape to (B, 1, T, T)
        local_mask = local_mask * casual_mask

    # Pass the combined mask to the attention heads
    out = torch.cat([h(x, x, x, mask=local_mask) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

# MHCA for encoder-decoder attention
class MultiHeadCrossAttention(nn.Module):
  def __init__(self, n_heads, n_embed, head_size):
    super().__init__()
    self.heads = nn.ModuleList([
        AttentionHead(n_embed, head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * head_size, n_embed)
    self.dropout = nn.Dropout(0.1)

  def forward(self, query_input, key_value_input, mask=None):
    local_mask = mask
    if local_mask is not None:
        # ensure mask has the correct shape for broadcasting with attention weights (B, H, T_q, T_k)
        if local_mask.ndim == 3:
            local_mask = local_mask.unsqueeze(1) #reshape to (B, 1, T_q, T_k)

    out = torch.cat([h(query_input, key_value_input, key_value_input, mask=local_mask) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [4]:
class FeedForward(nn.Module):
  def __init__(self, n_embed):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embed, 4 * n_embed),
        nn.ReLU(),
        nn.Linear(4 * n_embed, n_embed),
        nn.Dropout(0.1)
    )

  def forward(self, x):
      return self.net(x)

In [5]:
# general attention head
class AttentionHead(nn.Module):
  def __init__(self, n_embed, head_size):
    super().__init__()
    self.key = nn.Linear(n_embed, head_size, bias=False)
    self.query = nn.Linear(n_embed, head_size, bias=False)
    self.value = nn.Linear(n_embed, head_size, bias=False)
    self.dropout = nn.Dropout(0.1)
    self.head_size = head_size

  # takes separate q_input, k_input and v_input (correct for a generic head)
  def forward(self, query_input, key_input, value_input, mask=None):

    q = self.query(query_input) # (B, T_q, head_size)
    k = self.key(key_input)     # (B, T_k, head_size)
    v = self.value(value_input)   # (B, T_v, head_size)

    # compute attention scores (weights)
    # (B, T_q, head_size) @ (B, head_size, T_k) -> (B, T_q, T_k)
    k_transposed = k.transpose(-2, -1)

    wei = q @ k_transposed * (self.head_size**-0.5)

    if mask is not None:
      local_mask = mask
      if local_mask.ndim == 4 and local_mask.shape[1] == 1:
          local_mask = local_mask.squeeze(1)

      # Set masked values to a very small number instead of negative infinity
      wei = wei.masked_fill(local_mask == 0, -1e9)


    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)


    # attention weights to values
    # (B, T_q, T_k) @ (B, T_v, head_size) -> (B, T_q, head_size) (since T_k == T_v)
    # torch.bmm for explicit batch matrix multiplication
    wei_contiguous = wei.contiguous()
    v_contiguous = v.contiguous()
    out = torch.bmm(wei_contiguous, v_contiguous) # (B, T_q, head_size)
    return out

# multi-head attention, for selfattention with a causal flag
class MultiHeadAttention(nn.Module):
  def __init__(self, n_embed, n_heads, block_size, is_causal):
    super().__init__()
    self.n_heads = n_heads
    self.head_size = n_embed // n_heads

    self.heads = nn.ModuleList([
        AttentionHead(n_embed, self.head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * self.head_size, n_embed)
    self.dropout = nn.Dropout(0.1)
    self.is_causal = is_causal

    if self.is_causal:
      self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x, mask=None):
    B, T, C = x.shape
    local_mask = mask


    if self.is_causal:
      casual_mask = self.tril[:T, :T].to(x.device)
      if local_mask is None:
        local_mask = casual_mask
      else:
        if local_mask.ndim == 3:
            local_mask = local_mask.unsqueeze(1)
        elif local_mask.ndim == 4:

             casual_mask = casual_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)

        local_mask = local_mask * casual_mask


    out = torch.cat([h(x, x, x, mask=local_mask) for h in self.heads], dim=-1)


    out = self.dropout(self.proj(out))


    return out

# MHCA for encoder-decoder attention
class MultiHeadCrossAttention(nn.Module):
  def __init__(self, n_heads, n_embed, head_size):
    super().__init__()
    self.heads = nn.ModuleList([
        AttentionHead(n_embed, head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * head_size, n_embed)
    self.dropout = nn.Dropout(0.1)

  def forward(self, query_input, key_value_input, mask=None):
    local_mask = mask

    if local_mask is not None:
        # if mask has the correct shape for broadcasting with attention weights (B, H, T_q, T_k)
        if local_mask.ndim == 3:
            local_mask = local_mask.unsqueeze(1) #reshape to (B, 1, T_q, T_k)

    out = torch.cat([h(query_input, key_value_input, key_value_input, mask=local_mask) for h in self.heads], dim=-1)


    out = self.dropout(self.proj(out))

    return out

class DecoderBlock(nn.Module):
  def __init__(self, n_embed, n_heads, block_size):
    super().__init__()
    head_size = n_embed // n_heads
    self.ln1 = nn.LayerNorm(n_embed)
    self.ln2 = nn.LayerNorm(n_embed)
    self.ln3 = nn.LayerNorm(n_embed)

    self.masked_sa = MultiHeadAttention(n_embed, n_heads, block_size, is_causal=True)
    self.cross_sa = MultiHeadCrossAttention(n_heads, n_embed, head_size)
    self.ffwd = FeedForward(n_embed)

  def forward(self, decoder_input, encoder_output, decoder_padding_mask=None, cross_attention_mask=None):
    # masked self-attention (uses decoder_padding_mask)
    # output is (B, T_dec, C)
    masked_sa_out = self.masked_sa(self.ln1(decoder_input), mask=decoder_padding_mask)
    decoder_output = decoder_input + masked_sa_out


    cross_sa_out = self.cross_sa(self.ln2(decoder_output), encoder_output, mask=cross_attention_mask)
    decoder_output = decoder_output + cross_sa_out


    # Feed-forward network
    ffwd_out = self.ffwd(self.ln3(decoder_output))
    decoder_output = decoder_output + ffwd_out


    return decoder_output

In [6]:
class EncoderBlock(nn.Module):
  def __init__(self, n_embed, n_heads, block_size):
    super().__init__()
    head_size = n_embed // n_heads
    self.ln1 = nn.LayerNorm(n_embed)
    self.ln2 = nn.LayerNorm(n_embed)
    self.self_attention = MultiHeadAttention(n_embed, n_heads, block_size, is_causal=False) #enc uses non-causal self-attention
    self.ffwd = FeedForward(n_embed)

  def forward(self, x, padding_mask=None):
    x = x + self.self_attention(self.ln1(x), mask=padding_mask)

    # Feed-forward network
    x = x + self.ffwd(self.ln2(x))
    return x

In [7]:
#remember dropout for embeddings
class Encoder(nn.Module):
  def __init__(self, vocab_size, n_embed, block_size, n_heads, n_layers):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
    self.position_embedding_table = nn.Embedding(block_size, n_embed)
    self.blocks = nn.Sequential(*[
        EncoderBlock(n_embed, n_heads, block_size) for _ in range(n_layers)
    ])
    self.ln_f = nn.LayerNorm(n_embed)
    self.dropout = nn.Dropout(0.1)
    self.block_size = block_size

  def forward(self, idx, padding_mask=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)
    pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
    x =self.dropout(tok_emb + pos_emb)

    for i, block in enumerate(self.blocks):
            x = block(x, padding_mask=padding_mask)


    x = self.ln_f(x)
    return x

In [8]:
class Decoder(nn.Module):
  def __init__(self, vocab_size, n_embed, block_size, n_heads, n_layers):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
    self.position_embedding_table = nn.Embedding(block_size, n_embed)
    self.blocks = nn.ModuleList([
        DecoderBlock(n_embed, n_heads, block_size) for _ in range(n_layers)
    ])
    self.ln_f = nn.LayerNorm(n_embed)
    self.lm_head = nn.Linear(n_embed, vocab_size)
    self.dropout = nn.Dropout(0.1)
    self.block_size = block_size

  def forward(self, decoder_input_ids, encoder_output, decoder_padding_mask=None, cross_attention_mask=None):
    B, T = decoder_input_ids.shape
    tok_emb = self.token_embedding_table(decoder_input_ids)
    pos_emb = self.position_embedding_table(torch.arange(T, device=decoder_input_ids.device))
    x = self.dropout(tok_emb + pos_emb)
    print("Decoder input embeddings shape:", x.shape)
    print("Decoder input embeddings (first 5 elements):", x.flatten()[:5])

    for i, block in enumerate(self.blocks):
      print(f"Before DecoderBlock {i} output (first 5 elements):", x.flatten()[:5])
      #pass decoder_padding_mask and cross_attention_mask with correct keyword arguments
      x = block(x, encoder_output, decoder_padding_mask=decoder_padding_mask, cross_attention_mask=cross_attention_mask)
      print(f"After DecoderBlock {i} output shape:", x.shape)
      print(f"After DecoderBlock {i} output (first 5 elements):", x.flatten()[:5])


    x = self.ln_f(x)
    print("After final LayerNorm shape:", x.shape)
    print("After final LayerNorm (first 5 elements):", x.flatten()[:5])
    logits = self.lm_head(x)
    print("After lm_head shape:", logits.shape)
    print("After lm_head (first 5 elements):", logits.flatten()[:5])
    return logits

In [9]:
class EncoderDecoderTransformer(nn.Module):
  def __init__(self, vocab_size, n_embed, block_size, n_heads, n_layers):
    super().__init__()
    self.encoder = Encoder(vocab_size, n_embed, block_size, n_heads, n_layers)
    self.decoder = Decoder(vocab_size, n_embed, block_size, n_heads, n_layers)

  def forward(self, encoder_input_ids, decoder_input_ids, targets=None,
              encoder_padding_mask=None, decoder_padding_mask=None, cross_attention_mask=None):
      print("Encoder input shape:", encoder_input_ids.shape)
      encoder_output = self.encoder(encoder_input_ids, padding_mask=encoder_padding_mask)
      print("Encoder output shape:", encoder_output.shape)
      print("Decoder input shape:", decoder_input_ids.shape)
      logits = self.decoder(decoder_input_ids, encoder_output, decoder_padding_mask=decoder_padding_mask, cross_attention_mask=cross_attention_mask)
      print("Logits shape:", logits.shape)

      loss = None
      if targets is not None:
        B, T, C = logits.shape
        logits_reshaped = logits.view(B*T, C)
        targets_reshaped = targets.view(B*T)

        # target values that are EOS_TOKEN_ID to PAD_TOKEN_ID so they are ignored
        # Ensure targets are within valid range for vocab_size
        targets_reshaped[targets_reshaped >= self.decoder.lm_head.out_features] = PAD_TOKEN_ID
        targets_reshaped[targets_reshaped == EOS_TOKEN_ID] = PAD_TOKEN_ID


        loss = F.cross_entropy(logits_reshaped, targets_reshaped, ignore_index = PAD_TOKEN_ID)
        print("Loss:", loss.item())

      return logits, loss

  def generate(self, encoder_input_ids, max_new_tokens, encoder_padding_mask=None):
    encoder_output = self.encoder(encoder_input_ids, padding_mask=encoder_padding_mask)
    # batch size
    B = encoder_input_ids.shape[0]
    decoder_input_ids = torch.full((B, 1), SOS_TOKEN_ID, dtype=torch.long, device=encoder_input_ids.device)

    for _ in range(max_new_tokens):
      decoder_input_cond = decoder_input_ids[:, -self.decoder.block_size:]

      _, T_enc, _ = encoder_output.shape
      _, T_dec_cond = decoder_input_cond.shape
      cross_attention_mask = torch.ones(B, T_dec_cond, T_enc, dtype=torch.bool, device=encoder_input_ids.device)


      logits = self.decoder(decoder_input_cond, encoder_output, cross_attention_mask=cross_attention_mask)
      logits = logits[:, -1, :]

      probs = F.softmax(logits, dim=-1)
      next_token = torch.multinomial(probs, num_samples=1)

      decoder_input_ids = torch.cat((decoder_input_ids, next_token), dim=1)

      if (next_token == EOS_TOKEN_ID).all():
          break
    return decoder_input_ids

In [10]:
def get_batch_seq2seq(split, direction='json_to_str'):
  if direction == 'json_to_str':
      data_source = train_json_to_str if split == 'train' else val_json_to_str
  else:
      data_source = train_str_to_json if split == 'train' else val_str_to_json

  batch_enc_in, batch_dec_in, batch_dec_tgt = [], [], []
  indices = torch.randint(0, len(data_source), (batch_size,))

  for i in indices:
    enc_ids, dec_in_ids, dec_tgt_ids = data_source[i.item()]
    batch_enc_in.append(torch.tensor(enc_ids, dtype=torch.long))
    batch_dec_in.append(torch.tensor(dec_in_ids, dtype=torch.long))
    batch_dec_tgt.append(torch.tensor(dec_tgt_ids, dtype=torch.long))

  # Use pad_sequence for efficient padding
  padded_enc_inputs = nn.utils.rnn.pad_sequence(batch_enc_in, batch_first=True, padding_value=PAD_TOKEN_ID)
  padded_dec_inputs = nn.utils.rnn.pad_sequence(batch_dec_in, batch_first=True, padding_value=PAD_TOKEN_ID)
  padded_dec_targets = nn.utils.rnn.pad_sequence(batch_dec_tgt, batch_first=True, padding_value=PAD_TOKEN_ID)

  #masks , 1 for real token , 0 for padding
  # (B, T_enc) -> (B, 1, T_enc) -> (B, T_enc, T_enc)
  encoder_padding_mask = (padded_enc_inputs != PAD_TOKEN_ID).unsqueeze(1)
  encoder_padding_mask = encoder_padding_mask * encoder_padding_mask.transpose(-2, -1)

  # (B, T_dec) -> (B, 1, T_dec) -> (B, T_dec, T_dec)
  decoder_padding_mask = (padded_dec_inputs != PAD_TOKEN_ID).unsqueeze(1)
  decoder_padding_mask = decoder_padding_mask * decoder_padding_mask.transpose(-2, -1)

  cross_attention_mask = (padded_dec_inputs != PAD_TOKEN_ID).unsqueeze(-1) * \
                         (padded_enc_inputs != PAD_TOKEN_ID).unsqueeze(1)

  return padded_enc_inputs.to(device), padded_dec_inputs.to(device), padded_dec_targets.to(device), \
         encoder_padding_mask.to(device), decoder_padding_mask.to(device), cross_attention_mask.to(device)

In [11]:
#two models one for each direction

model_json_to_str = EncoderDecoderTransformer(
    vocab_size=vocab_size, n_embed=n_embed, n_heads=n_heads, n_layers=n_layers, block_size=block_size
).to(device)
optimiser_json_to_str = optim.AdamW(model_json_to_str.parameters(), lr=learning_rate)


model_str_to_json = EncoderDecoderTransformer(
    vocab_size=vocab_size, n_embed=n_embed, n_heads=n_heads, n_layers=n_layers, block_size=block_size
).to(device)
optimiser_str_to_json = optim.AdamW(model_str_to_json.parameters(), lr=learning_rate)

#track best val loss for both
best_val_loss_j2s = float('inf')
best_val_loss_s2j = float('inf')

for step in range(max_iters):
  #training j2s
  model_json_to_str.train()
  enc_in_j2s, dec_in_j2s, dec_tgt_j2s, enc_mask_j2s, dec_mask_j2s, cross_mask_j2s = \
    get_batch_seq2seq('train', direction='json_to_str')

  logits_j2s, loss_j2s = model_json_to_str(enc_in_j2s, dec_in_j2s, dec_tgt_j2s, enc_mask_j2s, dec_mask_j2s, cross_mask_j2s)

  #check for NaN loss before backpropagation
  if torch.isnan(loss_j2s):
      print(f"NaN loss encountered for J2S model at step {step}. Skipping backpropagation.")
      #print gradients if they exist and are NaN
      for name, param in model_json_to_str.named_parameters():
          if param.grad is not None and torch.isnan(param.grad).any():
              print(f"  NaN gradient in {name}")
  else:
      optimiser_json_to_str.zero_grad()
      loss_j2s.backward()
      optimiser_json_to_str.step()

  #training s2j
  model_str_to_json.train()
  enc_in_s2j, dec_in_s2j, dec_tgt_s2j, enc_mask_s2j, dec_mask_s2j, cross_mask_s2j = \
    get_batch_seq2seq('train', direction='str_to_json')

  logits_s2j, loss_s2j = model_str_to_json(enc_in_s2j, dec_in_s2j, dec_tgt_s2j, enc_mask_s2j, dec_mask_s2j, cross_mask_s2j)

  #check for NaN loss before backpropagation
  if torch.isnan(loss_s2j):
      print(f"NaN loss encountered for S2J model at step {step}. Skipping backpropagation.")
      #print gradients if they exist and are NaN
      for name, param in model_str_to_json.named_parameters():
          if param.grad is not None and torch.isnan(param.grad).any():
              print(f"  NaN gradient in {name}")
  else:
      optimiser_str_to_json.zero_grad()
      loss_s2j.backward()
      optimiser_str_to_json.step()

  #evaluate and save checkpoints
  if step % eval_interval == 0:
    model_json_to_str.eval()
    model_str_to_json.eval()
    with torch.no_grad():
      #j2s validation
      val_enc_in_j2s, val_dec_in_j2s, val_dec_tgt_j2s, val_enc_mask_j2s, val_dec_mask_j2s, val_cross_mask_j2s = \
        get_batch_seq2seq('val', direction='json_to_str')

      _, val_loss_j2s = model_json_to_str(val_enc_in_j2s, val_dec_in_j2s, val_dec_tgt_j2s, val_enc_mask_j2s, val_dec_mask_j2s, val_cross_mask_j2s)

      #s2j validation
      val_enc_in_s2j, val_dec_in_s2j, val_dec_tgt_s2j, val_enc_mask_s2j, val_dec_mask_s2j, val_cross_mask_s2j = \
        get_batch_seq2seq('val', direction="str_to_json")

      _, val_loss_s2j = model_str_to_json(val_enc_in_s2j, val_dec_in_s2j, val_dec_tgt_s2j, val_enc_mask_s2j, val_dec_mask_s2j, val_cross_mask_s2j)

      print(f"[step {step}] J2S train loss: {loss_j2s.item():.4f} | J2S val loss: {val_loss_j2s.item():.4f}")
      print(f"[step {step}] S2J train loss: {loss_s2j.item():.4f} | S2J val loss: {val_loss_s2j.item():.4f}")

      #save checkpoints
      if not torch.isnan(val_loss_j2s) and val_loss_j2s.item() < best_val_loss_j2s:
        best_val_loss_j2s = val_loss_j2s.item()
        torch.save(model_json_to_str.state_dict(), 'best_model_json_to_str.pt')
        print("~~~~ Saved best J2S model ~~~~")

      if not torch.isnan(val_loss_s2j) and val_loss_s2j.item() < best_val_loss_s2j:
        best_val_loss_s2j = val_loss_s2j.item()
        torch.save(model_str_to_json.state_dict(), 'best_model_str_to_json.pt')
        print("~~~~ Saved best S2J model ~~~~")


print("~~~~ TRAINING COMPLETE ~~~~")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss: 0.17820023000240326
Encoder input shape: torch.Size([32, 18])
Encoder output shape: torch.Size([32, 18, 256])
Decoder input shape: torch.Size([32, 26])
Decoder input embeddings shape: torch.Size([32, 26, 256])
Decoder input embeddings (first 5 elements): tensor([-0.6779,  0.1377, -2.1049,  1.0556, -2.6279], device='cuda:0',
       grad_fn=<SliceBackward0>)
Before DecoderBlock 0 output (first 5 elements): tensor([-0.6779,  0.1377, -2.1049,  1.0556, -2.6279], device='cuda:0',
       grad_fn=<SliceBackward0>)
After DecoderBlock 0 output shape: torch.Size([32, 26, 256])
After DecoderBlock 0 output (first 5 elements): tensor([-51.2885, -61.5410, -24.1211, -23.9731,  19.7946], device='cuda:0',
       grad_fn=<SliceBackward0>)
Before DecoderBlock 1 output (first 5 elements): tensor([-51.2885, -61.5410, -24.1211, -23.9731,  19.7946], device='cuda:0',
       grad_fn=<SliceBackward0>)
After DecoderBlock 1 output shape: torch.

In [None]:
# solated test for MultiHeadAttention
B, T, C = batch_size, block_size, n_embed
dummy_input = torch.randn(B, T, C, device=device)
dummy_mask = torch.ones(B, T, T, dtype=torch.bool, device=device)

try:
    print("Testing MultiHeadAttention with is_causal=False")
    mha_test_causal_false = MultiHeadAttention(n_embed, n_heads, block_size, is_causal=False).to(device)
    output_causal_false = mha_test_causal_false(dummy_input, mask=dummy_mask)
    print(f"MultiHeadAttention (is_causal=False) output shape: {output_causal_false.shape}")
except ValueError as e:
    print(f"Error during MultiHeadAttention (is_causal=False) test: {e}")

try:
    print("\nTesting MultiHeadAttention with is_causal=True")
    mha_test_causal_true = MultiHeadAttention(n_embed, n_heads, block_size, is_causal=True).to(device)
    output_causal_true = mha_test_causal_true(dummy_input, mask=dummy_mask)
    print(f"MultiHeadAttention (is_causal=True) output shape: {output_causal_true.shape}")
except ValueError as e:
    print(f"Error during MultiHeadAttention (is_causal=True) test: {e}")

try:
    print("\nTesting AttentionHead in isolation")
    head_test = AttentionHead(n_embed, n_embed // n_heads).to(device)
    output_head_test = head_test(dummy_input, dummy_input, dummy_input, mask=dummy_mask)
    print(f"AttentionHead output shape: {output_head_test.shape}")
except ValueError as e:
     print(f"Error during AttentionHead test: {e}")

Testing MultiHeadAttention with is_causal=False
MultiHeadAttention (is_causal=False) output shape: torch.Size([32, 32, 32, 256])

Testing MultiHeadAttention with is_causal=True
MultiHeadAttention (is_causal=True) output shape: torch.Size([32, 32, 32, 256])

Testing AttentionHead in isolation
AttentionHead output shape: torch.Size([32, 32, 32, 64])
