In [12]:
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import tiktoken

#special token ids
enc = tiktoken.get_encoding('cl100k_base')
SOS_TOKEN_ID = enc.encode("<|startoftext|>", allowed_special="all")[0]
EOS_TOKEN_ID = enc.encode("<|endoftext|>", allowed_special="all")[0]
PAD_TOKEN_ID = 0

# HP
batch_size = 32
block_size = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = 100_000
n_embed = 256
n_layers = 6
learning_rate = 1e-3
max_iters = 10000
eval_interval = 100
n_heads = 4


In [13]:
df = pd.read_csv("/content/sample_data/financial_narration_data.csv")

# functions prep pairs each way,  process df into tokenised id tuples

#json to str
json_to_str_dataset = []
for _, row in df.iterrows():
  encoder_input_ids = enc.encode(row['prompt'])
  decoder_target_ids = enc.encode(row['target']) + [EOS_TOKEN_ID]
  decoder_input_ids = [SOS_TOKEN_ID] + decoder_target_ids[:-1]
  json_to_str_dataset.append((encoder_input_ids, decoder_input_ids, decoder_target_ids))

#str to json
str_to_json_dataset = []
for _, row in df.iterrows():
  encoder_input_ids = enc.encode(row['target'])
  decoder_target_ids = enc.encode(row['prompt']) + [EOS_TOKEN_ID]
  decoder_input_ids = [SOS_TOKEN_ID] + decoder_target_ids[:-1]
  json_to_str_dataset.append((encoder_input_ids, decoder_input_ids, decoder_target_ids))

#split
n_j2s = int(0.9 * len(json_to_str_dataset))
train_json_to_str = json_to_str_dataset[:n_j2s]
val_json_to_str = json_to_str_dataset[n_j2s:]

n_s2j = int(0.9 * len(str_to_json_dataset))
train_str_to_json = str_to_json_dataset[:n_s2j]
val_str_to_json = str_to_json_dataset[n_j2s:]


In [14]:
# general attention head
class AttentionHead(nn.Module):
  def __init__(self, n_embed, head_size):
    super().__init__()
    self.key = nn.Linear(n_embed, head_size, bias=False)
    self.query = nn.Linear(n_embed, head_size, bias=False)
    self.value = nn.Linear(n_embed, head_size, bias=False)
    self.dropout = nn.Dropout(0.1)
    self.head_size = head_size

  # takes separate q_input, k_input and v_input (correct for a generic head)
  def forward(self, query_input, key_input, value_input, mask=None):
    q = self.query(query_input)
    k = self.key(key_input)
    v = self.value(value_input)

    wei = q @ k.transpose(-2, -1) * (self.head_size**-0.5)

    if mask is not None:
      wei = wei.masked_fill(mask == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)
    out = wei @ v
    return out

# multi-head attention, for selfattention with a causal flag
class MultiHeadAttention(nn.Module):
  def __init__(self, n_embed, n_heads, block_size, is_causal):
    super().__init__()
    self.n_heads = n_heads
    self.head_size = n_embed // n_heads

    self.heads = nn.ModuleList([
        AttentionHead(n_embed, self.head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * self.head_size, n_embed)
    self.dropout = nn.Dropout(0.1)
    self.is_causal = is_causal

    if self.is_causal:
      self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x, mask=None):
    B, T, C = x.shape
    local_mask = mask

    if self.is_causal:
      casual_mask = self.tril[:T, :T].view(1, 1, T, T)
      if local_mask is None:
        local_mask = casual_mask
      else:
        local_mask = local_mask * casual_mask

    out = torch.cat([h(x, x, x, mask=local_mask) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

# MHCA for encoder-decoder attention
class MultiHeadCrossAttention(nn.Module):
  def __init__(self, n_heads, n_embed, head_size):
    super().__init__()
    self.heads = nn.ModuleList([
        AttentionHead(n_embed, head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * head_size, n_embed)
    self.dropout = nn.Dropout(0.1)

  def forward(self, query_input, key_value_input, mask=None):
    out = torch.cat([h(query_input, key_value_input, key_value_input, mask=mask) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [15]:
class FeedForward(nn.Module):
  def __init__(self, n_embed):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embed, 4 * n_embed),
        nn.ReLU(),
        nn.Linear(4 * n_embed, n_embed),
        nn.Dropout(0.1)
    )

  def forward(self, x):
      return self.net(x)

In [16]:
class EncoderBlock(nn.Module):
  def __init__(self, n_embed, n_heads, block_size):
    super().__init__()
    head_size = n_embed // n_heads
    self.ln1 = nn.LayerNorm(n_embed)
    self.ln2 = nn.LayerNorm(n_embed)
    self.sa = MultiHeadAttention(n_embed, n_heads, block_size, is_causal=False)
    self.ffwd = FeedForward(n_embed)

  def forward(self, x, mask=None):
    x = x + self.sa(self.ln1(x), mask=mask)
    x = x + self.ffwd(self.ln2(x))
    return x

class DecoderBlock(nn.Module):
  def __init__(self, n_embed, n_heads, block_size):
    super().__init__()
    head_size = n_embed // n_heads
    self.ln1 = nn.LayerNorm(n_embed)
    self.ln2 = nn.LayerNorm(n_embed)
    self.ln3 = nn.LayerNorm(n_embed)

    self.masked_sa = MultiHeadAttention(n_embed, n_heads, block_size, is_causal=True)
    self.cross_sa = MultiHeadCrossAttention(n_heads, n_embed, head_size)
    self.ffwd = FeedForward(n_embed)

  def forward(self, x, encoder_output, padding_mask=None):
    x = x + self.masked_sa(self.ln1(x), mask=padding_mask)
    x = x + self.cross_sa(self.ln2(x), encoder_output, mask=None)
    x = x + self.ffwd(self.ln3(x))
    return x

In [17]:
# remember dropout for embeddings
class Encoder(nn.Module):
  def __init__(self, vocab_size, n_embed, block_size, n_heads, n_layers):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
    self.position_embedding_table = nn.Embedding(block_size, n_embed)
    self.blocks = nn.Sequential(*[
        EncoderBlock(n_embed, n_heads, block_size) for _ in range(n_layers)
    ])
    self.ln_f = nn.LayerNorm(n_embed)
    self.dropout = nn.Dropout(0.1)
    self.block_size = block_size

  def forward(self, idx, padding_mask=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)
    pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
    x =self.dropout(tok_emb + pos_emb)

    for block in self.blocks:
        x = block(x, mask=padding_mask)

    x = self.ln_f(x)
    return x

In [18]:
class Decoder(nn.Module):
  def __init__(self, vocab_size, n_embed, block_size, n_heads, n_layers):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
    self.position_embedding_table = nn.Embedding(block_size, n_embed)
    self.blocks = nn.ModuleList([
        DecoderBlock(n_embed, n_heads, block_size) for _ in range(n_layers)
    ])
    self.ln_f = nn.LayerNorm(n_embed)
    self.lm_head = nn.Linear(n_embed, vocab_size)
    self.dropout = nn.Dropout(0.1)
    self.block_size = block_size

  def forward(self, idx, encoder_output, padding_mask=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)
    pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
    x = self.dropout(tok_emb + pos_emb)

    for block in self.blocks:
      x = block(x, encoder_output, padding_mask=padding_mask)

      x = self.ln_f(x)
      logits = self.lm_head(x)
      return logits

In [19]:
class EncoderDecoderTransformer(nn.Module):
  def __init__(self, vocab_size, n_embed, block_size, n_heads, n_layers):
    super().__init__()
    self.encoder = Encoder(vocab_size, n_embed, block_size, n_heads, n_layers)
    self.decoder = Decoder(vocab_size, n_embed, block_size, n_heads, n_layers)

  def forward(self, encoder_input_ids, decoder_input_ids, targets=None,
              encoder_padding_mask=None, decoder_padding_mask=None):
      encoder_output = self.encoder(encoder_input_ids, padding_mask=encoder_padding_mask)
      logits = self.decoder(decoder_input_ids, encoder_output, padding_mask=decoder_padding_mask)

      loss = None
      if targets is not None:
        B, T, C = logits.shape
        logits_reshaped = logits.view(B*T, C)
        targets_reshaped = targets.view(B*T)
        loss = F.cross_entropy(logits_reshaped, targets_reshaped, ignore_index = PAD_TOKEN_ID)

      return logits, loss

  def generate(self, encoder_input_ids, max_new_tokens, encoder_padding_mask=None):
    encoder_output = self.encoder(encoder_input_ids, padding_mask=encoder_padding_mask)
    # batch size
    B = encoder_input_ids[0]
    decoder_input_ids = torch.full((B, 1), SOS_TOKEN_ID, dtype=torch.long, device=encoder_input_ids.device)

    for _ in range(max_new_tokens):
      decoder_input_cond = decoder_input_ids[:, -self.decoder.block_size:]
      logits = self.decoder(decoder_input_cond, encoder_output)
      logits = logits[:, -1, :]

      probs = F.softmax(logits, dim=-1)
      next_token = torch.multinomial(probs, num_samples=1)

      decoder_input_ids = torch.cat((decoder_input_ids, next_token), dim=1)

      if (next_token == EOS_TOKEN_ID).all():
          break
    return decoder_input_ids





In [22]:
def get_batch_seq2seq(split, direction='json_to_str'):
  if direction == 'json_to_str':
      data_source = train_json_to_str if split == 'train' else val_json_to_str
  else:
      data_source = train_str_to_json if split == 'train' else val_str_to_json

  batch_enc_in, batch_dec_in, batch_dec_tgt = [], [], []
  indices = torch.randint(0, len(data_source), (batch_size,))

  for i in indices:
    enc_ids, dec_in_ids, dec_tgt_ids = data_source[i.item()]
    batch_enc_in.append(torch.tensor(enc_ids, dtype=torch.long))
    batch_dec_in.append(torch.tensor(dec_in_ids, dtype=torch.long))
    batch_dec_tgt.append(torch.tensor(dec_tgt_ids, dtype=torch.long))

  max_enc_len = min(max(len(ids) for ids in batch_enc_in), block_size)
  max_dec_len = min(max(len(ids) for ids in batch_dec_in), block_size)

  # init with pad token
  padded_enc_inputs = torch.full((batch_size, max_enc_len), PAD_TOKEN_ID, dtype=torch.long, device=device)
  padded_dec_inputs = torch.full((batch_size, max_dec_len), PAD_TOKEN_ID, dtype=torch.long, device=device)
  padded_dec_targets = torch.full((batch_size, max_dec_len), PAD_TOKEN_ID, dtype=torch.long, device=device)

  #masks , 1 for real token , 0 for padding
  encoder_padding_mask = torch.zeros(batch_size, max_enc_len, max_enc_len, dtype=torch.bool, device=device)
  decoder_padding_mask = torch.zeros(batch_size, max_dec_len, max_dec_len, dtype=torch.bool, device=device)

  for i in range(batch_size):
      current_enc_len = len(batch_enc_in[i])
      current_dec_len = len(batch_dec_in[i])

      padded_enc_inputs[i, :min(current_enc_len, max_enc_len)] = batch_enc_in[i][:min(current_enc_len, max_enc_len)]
      padded_dec_inputs[i, :min(current_dec_len, max_dec_len)] = batch_dec_in[i][:min(current_dec_len, max_dec_len)]
      padded_dec_targets[i, :min(current_dec_len, max_dec_len)] = batch_dec_tgt[i][:min(current_dec_len, max_dec_len)]

      encoder_padding_mask[i, :min(current_enc_len, max_enc_len), :min(current_enc_len, max_enc_len)] = 1
      decoder_padding_mask[i, :min(current_dec_len, max_dec_len), :min(current_dec_len, max_dec_len)] = 1

  return padded_enc_inputs, padded_dec_inputs, padded_dec_targets, encoder_padding_mask, decoder_padding_mask

In [23]:
#two models one for each direction

model_json_to_str = EncoderDecoderTransformer(
    vocab_size=vocab_size, n_embed=n_embed, n_heads=n_heads, n_layers=n_layers, block_size=block_size
).to(device)
optimiser_json_to_str = optim.AdamW(model_json_to_str.parameters(), lr=learning_rate)


model_str_to_json = EncoderDecoderTransformer(
    vocab_size=vocab_size, n_embed=n_embed, n_heads=n_heads, n_layers=n_layers, block_size=block_size
).to(device)
optimiser_str_to_json = optim.AdamW(model_str_to_json.parameters(), lr=learning_rate)

#track best val loss for both
best_val_loss_j2s = float('inf')
best_val_loss_s2j = float('inf')

for step in range(max_iters):
  #training j2s
  model_json_to_str.train()
  enc_in_j2s, dec_in_j2s, dec_tgt_j2s, enc_mask_j2s, dec_mask_j2s = \
    get_batch_seq2seq('train', direction='json_to_str')
  logits_j2s, loss_j2s = model_json_to_str(enc_in_j2s, dec_in_j2s, dec_tgt_j2s, enc_mask_j2s, dec_mask_j2s)
  optimiser_json_to_str.zero_grad()
  loss_j2s.backward()
  optimiser_json_to_str.step()

  #training s2j
  model_str_to_json.train()
  enc_in_s2j, dec_in_s2j, dec_tgt_s2j, enc_mask_s2j, dec_mask_s2j = \
    get_batch_seq2seq('train', direction='str_to_json')
  logits_s2j, loss_s2j = model_str_to_json(enc_in_s2j, dec_in_s2j, dec_tgt_s2j, enc_mask_s2j, dec_mask_s2j)
  optimiser_str_to_json.zero_grad()
  loss_s2j.backward()
  optimiser_str_to_json.step()

  #evaluate
  if step % eval_interval == 0:
    model_json_to_str.eval()
    model_str_to_json.eval()
    with torch.no_grad():
      #j2s validation
      val_enc_in_j2s, val_dec_in_j2s, val_dec_tgt_j2s, val_enc_mask_j2s, val_dec_mask_j2s = \
        get_batch_seq2seq('val', direction='json_to_str')
      _, val_loss_j2s = model_json_to_str(val_enc_in_j2s, val_dec_in_j2s, val_dec_tgt_j2s, val_enc_mask_j2s, val_dec_mask_j2s)

      #s2j validation
      val_enc_in_s2j, val_dec_in_s2j, val_dec_tgt_s2j, val_enc_mask_s2j, val_dec_mask_s2j = \
        get_batch_seq2seq('val', direction="str_to_json")
      _, val_loss_s2j = model_str_to_json(val_enc_in_s2j, val_dec_in_s2j, val_dec_tgt_s2j, val_enc_mask_s2j, val_dec_mask_s2j)

      print(f"[step {step}] J2S train loss: {loss_j2s.item():.4f} | J2S val loss: {val_loss_j2s.item():.4f}")
      print(f"[step {step}] S2J train loss: {loss_s2j.item():.4f} | S2J val loss: {val_loss_s2j.item():.4f}")

      #save checkpoints
      if val_loss_j2s.item() < best_val_loss_j2s:
        best_val_loss_j2s = val_loss_j2s.item()
        torch.save(model_json_to_str.state_dict(), 'best_model_json_to_str.pt')
        print("~~~~ Saved best J2S model ~~~~")

      if val_loss_s2j.item() < best_val_loss_s2j:
        best_val_loss_s2j = val_loss_s2j.item()
        torch.save(model_str_to_json.state_dict(), 'best_model_str_to_json.pt')
        print("~~~~ Saved best S2J model ~~~~")

print("~~~~ TRAINING COMPLETE ~~~~")

ValueError: too many values to unpack (expected 3)