In [None]:
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import tiktoken

# --- Tokenizer and special-token ids ----------------------------------------
# cl100k_base defines "<|endoftext|>" as a special token but has NO
# "<|startoftext|>" token: encoding that string yields several ordinary BPE
# pieces, so taking [0] would silently pick an unrelated text token as SOS.
# Id 0 is likewise an ordinary token, so it cannot double as padding.
# We therefore reserve two fresh ids directly above the tokenizer's range.
# NOTE: these two ids exist only inside the model; strip them from a sequence
# before passing it to `enc.decode`.
enc = tiktoken.get_encoding('cl100k_base')
EOS_TOKEN_ID = enc.eot_token        # id of "<|endoftext|>" (100257 in cl100k_base)
SOS_TOKEN_ID = enc.n_vocab          # first id past the tokenizer's range
PAD_TOKEN_ID = enc.n_vocab + 1      # second reserved id

# --- Hyper-parameters --------------------------------------------------------
SEED = 1337
torch.manual_seed(SEED)  # reproducible weight init / dropout / sampling

batch_size = 32
block_size = 32  # max sequence length; sizes the causal-mask buffers
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Embedding / output layers must cover every tokenizer id (EOS alone is above
# 100_000) plus the two reserved ids defined above.
vocab_size = enc.n_vocab + 2
n_embed = 256
n_layers = 6
learning_rate = 1e-3
max_iters = 10000
eval_interval = 100


In [None]:
DATA_PATH = "/content/sample_data/financial_narration_data.csv"
TRAIN_FRAC = 0.9  # fraction of rows (in file order) used for training

df = pd.read_csv(DATA_PATH)


def make_seq2seq_pairs(frame, src_col, tgt_col):
  """Tokenise one translation direction into teacher-forcing triples.

  For each row, builds ``(encoder_input_ids, decoder_input_ids,
  decoder_target_ids)``: the target text is tokenised and terminated with
  ``EOS_TOKEN_ID``, and the decoder input is that target shifted right by one
  behind ``SOS_TOKEN_ID``.

  Args:
    frame: DataFrame holding the text columns.
    src_col: column fed to the encoder.
    tgt_col: column the decoder has to generate.

  Returns:
    A list with one ``(list[int], list[int], list[int])`` tuple per row, in
    row order.
  """
  pairs = []
  for src_text, tgt_text in zip(frame[src_col], frame[tgt_col]):
    encoder_input_ids = enc.encode(src_text)
    decoder_target_ids = enc.encode(tgt_text) + [EOS_TOKEN_ID]
    # Teacher forcing: the decoder sees SOS + target[:-1] and predicts target.
    decoder_input_ids = [SOS_TOKEN_ID] + decoder_target_ids[:-1]
    pairs.append((encoder_input_ids, decoder_input_ids, decoder_target_ids))
  return pairs


# Both directions are built from the same rows. The json->str direction
# encodes 'prompt' and decodes 'target'; str->json is the reverse.
json_to_str_dataset = make_seq2seq_pairs(df, 'prompt', 'target')
str_to_json_dataset = make_seq2seq_pairs(df, 'target', 'prompt')
assert len(json_to_str_dataset) == len(str_to_json_dataset) == len(df)

# Ordered (unshuffled) split. Both directions cut at the same row index, so the
# same rows are held out for validation in both directions.
n_j2s = int(TRAIN_FRAC * len(json_to_str_dataset))
train_json_to_str = json_to_str_dataset[:n_j2s]
val_json_to_str = json_to_str_dataset[n_j2s:]

n_s2j = int(TRAIN_FRAC * len(str_to_json_dataset))
train_str_to_json = str_to_json_dataset[:n_s2j]
val_str_to_json = str_to_json_dataset[n_s2j:]


In [None]:
# Single scaled dot-product attention head, shared by self- and cross-attention
class AttentionHead(nn.Module):
  """One scaled dot-product attention head.

  Queries, keys and values are projected from separate inputs, so the same
  module serves self-attention (pass one tensor three times) and
  cross-attention (queries from one sequence, keys/values from another).
  Any masking (causal, padding) is supplied by the caller through ``mask``.

  Args:
    n_embed: feature size (last dim) of the inputs.
    head_size: size of the projected q/k/v vectors and of the output.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, n_embed, head_size, dropout=0.1):
    super().__init__()
    # nn.Linear's third positional parameter is `bias`; passing block_size
    # there as well as bias=False raised "multiple values for 'bias'".
    self.key = nn.Linear(n_embed, head_size, bias=False)
    self.query = nn.Linear(n_embed, head_size, bias=False)
    self.value = nn.Linear(n_embed, head_size, bias=False)
    self.dropout = nn.Dropout(dropout)
    self.head_size = head_size

  def forward(self, query_input, key_input, value_input, mask=None):
    """Attend from ``query_input`` over ``key_input`` / ``value_input``.

    Args:
      query_input: tensor with last dim ``n_embed``; presumably
        (B, T_q, n_embed) — TODO confirm with callers.
      key_input: tensor with last dim ``n_embed``; presumably (B, T_k, n_embed).
      value_input: same sequence length as ``key_input``.
      mask: optional tensor broadcastable to the (..., T_q, T_k) score
        matrix; positions where it equals 0 receive no attention.
        NOTE: a query row that is masked out entirely yields NaN after softmax.

    Returns:
      Tensor of shape (..., T_q, head_size).
    """
    q = self.query(query_input)
    k = self.key(key_input)
    v = self.value(value_input)

    # Swap the last two dims of k so q @ k^T gives (..., T_q, T_k) scores.
    # (`transpose(-2. -1)` parsed as transpose(-3.0) and raised.)
    wei = q @ k.transpose(-2, -1) / self.head_size ** 0.5

    if mask is not None:
      wei = wei.masked_fill(mask == 0, float('-inf'))

    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)
    return wei @ v

# Multi-head self-attention, with an optional causal mask (the `is_casual` flag)
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention with an optional causal (look-ahead) mask.

  Runs ``n_heads`` AttentionHead modules on the same input, concatenates their
  outputs along the last dim and projects back to ``n_embed``.

  Args:
    n_heads: number of parallel heads.
    n_embed: model width (input and output feature size).
    head_size: per-head projection size.
    block_size: longest sequence the causal mask supports.
    is_casual: if True, position t attends only to positions <= t. ("casual"
      is a misspelling of "causal"; the name is kept for existing callers.)
    dropout: dropout probability applied after the output projection.
  """

  def __init__(self, n_heads, n_embed, head_size, block_size, is_casual=True, dropout=0.1):
    super().__init__()
    self.heads = nn.ModuleList([
        AttentionHead(n_embed, head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * head_size, n_embed)
    self.dropout = nn.Dropout(dropout)
    self.is_casual = is_casual
    self.block_size = block_size
    if self.is_casual:
      # Lower-triangular ones: entry (i, j) is 1 iff j <= i.
      self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x, mask=None):
    """Self-attend over ``x``.

    Args:
      x: 3-D input; the second dim is the sequence length T (the unpacking
        below requires exactly three dims).
      mask: optional external mask. It is multiplied element-wise with the
        (1, T, T) causal mask when ``is_casual``, so it must broadcast against
        it — presumably (B, T, T) or (B, 1, T); TODO confirm with callers.

    Returns:
      Tensor with the same shape as ``x``.

    Raises:
      ValueError: if ``is_casual`` and T exceeds ``block_size``.
    """
    _, T, _ = x.shape
    local_mask = mask  # external mask, possibly None

    if self.is_casual:
      if T > self.block_size:
        raise ValueError(
            f"sequence length {T} exceeds block_size={self.block_size} "
            "of the causal mask")
      causal_mask = self.tril[:T, :T].view(1, T, T)
      local_mask = causal_mask if local_mask is None else local_mask * causal_mask

    out = torch.cat([h(x, x, x, mask=local_mask) for h in self.heads], dim=-1)
    return self.dropout(self.proj(out))


# Multi-head encoder-decoder (cross) attention. Defined at module level: it was
# previously indented by two spaces, which made it a class nested inside
# MultiHeadAttention rather than a usable top-level name.
class MultiHeadCrossAttention(nn.Module):
  """Multi-head cross-attention: queries from one sequence, keys/values from another.

  Args:
    n_heads: number of parallel heads.
    n_embed: model width (input and output feature size).
    head_size: per-head projection size.
    dropout: dropout probability applied after the output projection.
  """

  def __init__(self, n_heads, n_embed, head_size, dropout=0.1):
    super().__init__()
    self.heads = nn.ModuleList([
        AttentionHead(n_embed, head_size) for _ in range(n_heads)
    ])
    self.proj = nn.Linear(n_heads * head_size, n_embed)
    self.dropout = nn.Dropout(dropout)

  def forward(self, query_input, key_value_input, mask=None):
    """Attend from ``query_input`` over ``key_value_input``.

    ``key_value_input`` is used as both keys and values. No causal mask is
    applied here; ``mask`` (if given) is passed unchanged to every head.
    Returns a tensor shaped like ``query_input``.
    """
    out = torch.cat(
        [h(query_input, key_value_input, key_value_input, mask=mask) for h in self.heads],
        dim=-1)
    return self.dropout(self.proj(out))