<a href="https://colab.research.google.com/github/RaphBhz/Scenarformer/blob/main/Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Requirements

In [1]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.2/1.2 MB[0m [31m41.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.8.0


# Data

In [2]:
import random
import torch
import random
import torch
import os.path
from requests.exceptions import HTTPError

DATA_PATH = F"/content/gdrive/MyDrive/AI/Scenarformer/plays/"

def load_play_text(path):
  with open(path) as file:
    data = file.read()
    return data


def get_batch(play_files, batch_size, block_size, episode, steps_num):
  num_plays = len(play_files)
  # Select a play based on episode index
  play_idx = episode % num_plays
  play_text = load_play_text(play_files[play_idx])

  # Convert the play text to tokens
  tokens = torch.tensor(enc.encode(play_text), dtype=torch.long)

  # We want to get a specific number of batches for this episode
  for _ in range(steps_num):
      # Generate random starting indices for the batch
      start_idx = torch.randint(len(tokens) - block_size, (batch_size,))
      inputs = torch.stack([tokens[i:i+block_size] for i in start_idx])
      outputs = torch.stack([tokens[i+1:i+block_size+1] for i in start_idx])
      yield inputs, outputs


def split_files(play_dir = DATA_PATH, validation_ratio=0.1):
  # Get all .txt files in the directory
  all_files = [os.path.join(root, file)
                for root, _, files in os.walk(play_dir)
                for file in files if file.endswith('.txt')]

  # Shuffle the files
  random.shuffle(all_files)

  # Calculate the split index
  split_idx = int(len(all_files) * (1 - validation_ratio))

  # Split into training and validation
  train_files = all_files[:split_idx]
  val_files = all_files[split_idx:]

  return train_files, val_files

# Model

In [3]:
import torch
from torch import nn


class SelfAttentionHead(nn.Module):
  def __init__(self, head_size, embedding_size, block_size, dropout):
    super().__init__()
    self.key = nn.Linear(embedding_size, head_size, bias = False)
    self.query = nn.Linear(embedding_size, head_size, bias = False)
    self.value = nn.Linear(embedding_size, head_size, bias = False)
    self.droupout = nn.Dropout(dropout)
    self.register_buffer('mask', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    Batch, Block, Size = x.shape

    keys = self.key(x)
    queries = self.query(x)

    weights = queries @ keys.transpose(-2, -1) * Size ** -0.5
    weights = weights.masked_fill(self.mask[:Block, :Block] == 0, float('-inf'))
    weights = nn.functional.softmax(weights, dim=-1)
    weights = self.droupout(weights)

    values = self.value(x)

    return weights @ values


class MultiHeadAttention(nn.Module):
  def __init__(self, head_size, num_heads, embedding_size, block_size, dropout):
    super().__init__()
    self.heads = nn.ModuleList(
      [SelfAttentionHead(head_size, embedding_size, block_size, dropout) for _ in range(num_heads)]
    )
    self.projection = nn.Linear(embedding_size, embedding_size)
    self.droupout = nn.Dropout(dropout)

  def forward(self, x):
    y = torch.cat([h(x) for h in self.heads], dim=-1)
    y = self.projection(y)
    y = self.droupout(y)
    return y


class FeedForwardLayer(nn.Module):
  def __init__(self, embedding_size, dropout):
    super().__init__()
    self.network = nn.Sequential(
      nn.Linear(embedding_size, 4 * embedding_size),
      nn.ReLU(),
      nn.Linear(4 * embedding_size, embedding_size),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.network(x)


class DecoderBlock(nn.Module):
  def __init__(self, embedding_size, block_size, num_heads, dropout = 0.2):
    super().__init__()
    head_size = embedding_size // num_heads
    self.attention_heads = MultiHeadAttention(head_size, num_heads, embedding_size, block_size, dropout)
    self.feedforward_layer = FeedForwardLayer(embedding_size, dropout)
    self.layer_norm1 = nn.LayerNorm(embedding_size)
    self.layer_norm2 = nn.LayerNorm(embedding_size)

  def forward(self, x):
    x = x + self.attention_heads(x)
    x = x + self.feedforward_layer(x)
    return x



class SimpleModel(nn.Module):
  def __init__(self, vocab_size, embedding_size, block_size, layer_num=6, attention_heads=8, dropout=0.2):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, embedding_size)
    self.position_embedding_table = nn.Embedding(block_size, embedding_size)
    self.decoder_blocks = nn.Sequential(
      *[DecoderBlock(embedding_size, block_size, num_heads=attention_heads, dropout=dropout) for _ in range(layer_num)],
      nn.LayerNorm(embedding_size)
    )
    self.linear_layer = nn.Linear(embedding_size, vocab_size)

  def forward(self, idx, targets=None):
    Batch, Block = idx.shape

    token_embeddings = self.token_embedding_table(idx)
    positional_embeddings = self.position_embedding_table(torch.arange(Block))

    embeddings = token_embeddings + positional_embeddings
    embeddings = self.decoder_blocks(embeddings)

    logits = self.linear_layer(embeddings)

    if targets is None:
      loss = None
    else:
      Batch, Block, Vocab = logits.shape
      logits = logits.view(Batch * Block, Vocab)
      targets = targets.view(Batch * Block)
      loss = nn.functional.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max):
    Batch, Block = idx.shape

    for _ in range(max):
      cropped_idx = idx[:, -Block:]
      logits, loss = self(cropped_idx)
      logits = logits [:,-1,:]
      probs = nn.functional.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)

    return idx

# Training

## Parameters

In [7]:
from datetime import datetime
import tiktoken
import torch
import os.path

# Making sure the model uses GPU
torch.set_default_device('cuda')

# Part of data used for validation
VALIDATION_RATIO = 0.1

# Number of sequences to train at a time
BATCH_SIZE = 64
# Size of each sequence
BLOCK_SIZE = 128
# Size of embeddings
EMBEDDING_SIZE = 512
# Number of episodes to train for
TRAIN_EPISODES = 50
# Number of steps in each episode
TRAIN_STEPS = 500
# Number of steps in the validation
VALIDATION_STEPS = 100
# Learning rate
LEARNING_RATE = 1e-5

# Using tiktoken's r50k_base encoder
enc = tiktoken.get_encoding("r50k_base")

# Path to save the model to
def MODEL_PATH(episode, date = None):
  model_name = f'model_{datetime.today().strftime("%Y-%m-%d") if date is None else date}_{episode}.pth'
  return f"/content/gdrive/MyDrive/AI/Scenarformer/{model_name}"

In [5]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [6]:
# Splitting the data into test and validation sets
train_files, validation_files = split_files()
print(f'Number of training plays: {len(train_files)}')
print(f'Number of validation plays: {len(validation_files)}')

Number of training plays: 45
Number of validation plays: 5


In [8]:
from torch.amp import autocast, GradScaler

# Creating the model
m = SimpleModel(
  vocab_size=enc.n_vocab,
  embedding_size=EMBEDDING_SIZE,
  block_size = BLOCK_SIZE
)
m.cuda()

# Using torch's Adam optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=LEARNING_RATE)
# Using a ReduceLROnPlateau scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=4, cooldown=1)

# Initialize GradScaler for scaling gradients
scaler = GradScaler('cuda')

# Training the model
for episode in range(TRAIN_EPISODES):
  print(f'EPISODE {episode+1}')
  m.train()

  batches = get_batch(validation_files, BATCH_SIZE, BLOCK_SIZE, episode, VALIDATION_STEPS)
  for inp, out in batches:
    optimizer.zero_grad(set_to_none=True)

    with autocast('cuda'):
      logits, loss = m(inp, out)

    # Scale the loss and backpropagate
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

  # Validation
  m.eval()
  validation_loss = 0
  with torch.no_grad():
    batches = get_batch(validation_files, BATCH_SIZE, BLOCK_SIZE, episode, VALIDATION_STEPS)
    for val_inp, val_out in batches:
      with autocast('cuda'):
        val_logits, val_loss = m(val_inp, val_out)
      validation_loss += val_loss.item()

    # Compute average validation loss
    validation_loss /= VALIDATION_STEPS

  # Step the scheduler
  scheduler.step(validation_loss)

  # Save model
  if (episode + 1) % 10 == 0:
    torch.save(m.state_dict(), MODEL_PATH(episode+1))

  print(f'Loss = {loss.item():.4f} Validation Loss = {validation_loss:.4f}')


# Printing a sample of generation
print(enc.decode(m.generate(torch.zeros((1, 1), dtype=torch.long), 500)[0].tolist()))

EPISODE 1
Loss = 9.0617 Validation Loss = 8.9995
EPISODE 2
Loss = 7.9674 Validation Loss = 7.8777
EPISODE 3
Loss = 7.3190 Validation Loss = 7.2179
EPISODE 4
Loss = 6.7164 Validation Loss = 6.6344
EPISODE 5
Loss = 6.3308 Validation Loss = 6.2557
EPISODE 6
Loss = 5.8787 Validation Loss = 5.7774
EPISODE 7
Loss = 5.4076 Validation Loss = 5.3613
EPISODE 8
Loss = 5.5083 Validation Loss = 5.4260
EPISODE 9
Loss = 5.4192 Validation Loss = 5.3126
EPISODE 10
Loss = 5.1549 Validation Loss = 5.0661
EPISODE 11
Loss = 4.8744 Validation Loss = 4.7751
EPISODE 12
Loss = 4.6056 Validation Loss = 4.5894
EPISODE 13
Loss = 4.6929 Validation Loss = 4.5980
EPISODE 14
Loss = 4.8088 Validation Loss = 4.7587
EPISODE 15
Loss = 4.5545 Validation Loss = 4.5095
EPISODE 16
Loss = 4.3269 Validation Loss = 4.2734
EPISODE 17
Loss = 4.2190 Validation Loss = 4.1745
EPISODE 18
Loss = 4.1713 Validation Loss = 4.0179
EPISODE 19
Loss = 4.4696 Validation Loss = 4.4143
EPISODE 20
Loss = 4.2520 Validation Loss = 4.1769
EPISODE 2

# Loading the model

In [9]:
MODEL_NAME = MODEL_PATH('50', '2024-12-10')

model = SimpleModel(
  vocab_size=enc.n_vocab,
  embedding_size=512,
  block_size = BLOCK_SIZE
)
model.load_state_dict(torch.load(MODEL_NAME, weights_only=True))
model.eval()

SimpleModel(
  (token_embedding_table): Embedding(50257, 512)
  (position_embedding_table): Embedding(128, 512)
  (decoder_blocks): Sequential(
    (0): DecoderBlock(
      (attention_heads): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x SelfAttentionHead(
            (key): Linear(in_features=512, out_features=64, bias=False)
            (query): Linear(in_features=512, out_features=64, bias=False)
            (value): Linear(in_features=512, out_features=64, bias=False)
            (droupout): Dropout(p=0.2, inplace=False)
          )
        )
        (projection): Linear(in_features=512, out_features=512, bias=True)
        (droupout): Dropout(p=0.2, inplace=False)
      )
      (feedforward_layer): FeedForwardLayer(
        (network): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )

In [12]:
generated_play = enc.decode(model.generate(torch.zeros((1, 1), dtype=torch.long), 5000)[0].tolist())

with open(f'/content/gdrive/MyDrive/AI/Scenarformer/generated.txt', 'w') as file:
  file.seek(0)
  file.write(generated_play)
  file.truncate()