### Step 1: Install necesscary packages

In [1]:
!pip install matplotlib torch numpy transformers datasets tiktoken wandb tqdm --quiet

### Step 2: Package imports and configuration

- Tokeniser appears to not have !

In [2]:
import sys
import os
sys.path.append(os.path.abspath("..")) 
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import pickle
from model import GPT, GPTConfig
import random
from tqdm import tqdm
import time
import json
import matplotlib.pyplot as plt
from transformers import GPT2TokenizerFast, GPT2Tokenizer

# Hyperparameters for training and generation
beta = 0.5
if torch.backends.mps.is_available():
    # For Macbook
    device = torch.device("mps")
elif torch.cuda.is_available():
    # For Windows
    device = 'cuda'  
else:
    device = torch.device("cpu")

# Base learning rate - For controlling how much the weights are updated. Smaller is better to prevent overshoot but will take longer
base_lr = 1e-4
# Number of times model will see the training dataset. Not too many to prevent overfitting
epochs = 10
# Number of samples before updating the model's weights
batch_size = 128
# Max number of tokens in each input sequence, i.e. the length the model can see at once
max_length = 64
num_samples = 1
# Max number of tokens to generate in answer
max_new_tokens = 200
# Controls randomness in generation of tokens. Lower values (close to 0) make output more deterministic
temperature = 0.8
# Samples from top k most likely tokens to generate next
top_k = 200

# Loads a pickled dictionary, which is a saved dict
with open("../sft/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]
def encode(s): 
    return [stoi[c] for c in s]
def decode(l): 
    return ''.join([itos[i] for i in l])

# Verify device used
print(f"Using: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using: mps


### Step 3: Define helper functions

In [3]:
"""
Used for scoring various sequences of tokens generated by the GPT to decide which sequence should be returned
"""
def compute_logprob(input_ids):
    # All tokens except for the last token is used as input
    inputs = input_ids[:, :-1]
    # All tokens except for the first, i.e. the next target token that should be predicted
    targets = input_ids[:, 1:]
    # Runs the model to get the logits (unnormalized scores) for each possible next token at each position
    logits, _ = gpt(inputs, full_seq=True)

    # Reshape both the logits and targets to the correct shape, i.e. (batch_size * seq_length)
    B, T, V = logits.size()
    logits_flat = logits.reshape(-1, V)
    targets_flat = targets.reshape(-1)
    # Calculates the cross-entropy loss (negative log-probability) for each token
    loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=0, reduction='none')
    loss = loss.reshape(B, T)
    # Creates a mask to ignore padding tokens
    attention_mask = (targets != 0).float()
    loss = (loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    # Returns the negative loss, which is the average log-probability per sequence (higher is better, since log-probs are negative)
    return -loss 

"""
Ensures that every sequence is exactly max_length tokens long, either by truncating longer sequences or padding shorter ones with zeros (the padding token)

Either keep the last max_length tokens if length exceeds.
Or pads with 0s to reach the end
"""
def pad_or_truncate(seq, max_length):
    return seq[-max_length:] if len(seq) > max_length else seq + [0] * (max_length - len(seq))

"""
Prepares batches of positive and negative examples for training or evaluation, yielding them as tensors ready for the model
"""
def get_batches(lines, batch_size):
    # Randomise the order of data for each epoch
    random.shuffle(lines)
    #for l in lines:
    #    print(l[1])
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i+batch_size]
        # Skips the last batch if the size is too small to make sure all batches have the same size
        if len(batch) < batch_size:
            continue
        # Encode the negative and positive strings and pad/truncate to maxLenght
        neg_inputs = [pad_or_truncate(encode(p['negative'] + '\n\n\n\n'), max_length) for p in batch]
        pos_inputs = [pad_or_truncate(encode(p['positive'] + '\n\n\n\n'), max_length) for p in batch]
        # Converts the lists of token IDs into PyTorch tensors on the correct device.
        neg_tensor = torch.tensor(neg_inputs, dtype=torch.long, device=device)
        pos_tensor = torch.tensor(pos_inputs, dtype=torch.long, device=device)
        # Yields a tuple (neg_tensor, pos_tensor) for each batch, as tensors are used for model input
        yield neg_tensor, pos_tensor

### Step 4: Load the pretrained NanoGPT model

In [4]:
# Load the pre-trained model, i.e. the checkpoint, which has all the model weights
ckpt = torch.load("../sft/gpt.pt", map_location=device)
# Recreate the model configuration with the saved details
gptconf = GPTConfig(**ckpt['model_args'])
# Create a new instance of the model using the pre-trained configs
gpt = GPT(gptconf)

# Clean up any keys that might have unwanted prefixes while saving
state_dict = ckpt['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

# Load the model weights from the checkpoint to the new model instance
gpt.load_state_dict(state_dict)
# Set it to training mode after ensuring it runs on GPU if it exists
gpt.to(device).train()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(74, 348)
    (wpe): Embedding(256, 348)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=348, out_features=1044, bias=False)
          (c_proj): Linear(in_features=348, out_features=348, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=348, out_features=1392, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1392, out_features=348, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=348, out_features=74, bias=False)
)

### Step 5: Load Data (**students are required to complete this part!**)

- Currently I generated 5000 simple arithmetic like the ones provided, and 5000 more 2 step arithmetic like 2*(5+3)=?
- Would need to generate more until at least 100k in 1:1 ratio
- Then maybe 50k more for ones that uses brackets like (x+5)*2=30,x=?. Need about 50k to maintain the ratio

In [5]:
# Load my training dataset
with open("./pos_neg_pairs.json") as file:
    lines = json.load(file)

# Filter out characters like ! that have not been encoded by the trained model
num_clean_datapoints = 0
for line in lines:
    positive_datapoint = line["positive"]
    negative_datapoint = line["negative"]

    # Clean both datapoints
    cleaned_positive_datapoint = ''.join(char for char in positive_datapoint if char in stoi)
    cleaned_negative_datapoint = ''.join(char for char in negative_datapoint if char in stoi)

    # Check if cleaned
    if cleaned_positive_datapoint != positive_datapoint or negative_datapoint != cleaned_negative_datapoint:
        num_clean_datapoints += 1
    
    # Update with cleaned
    line["positive"] = cleaned_positive_datapoint
    line["negative"] = cleaned_negative_datapoint

print(f"Number of Training Datapoints:  {len(lines)}")
print(f"Number of Cleaned Datapoints: {num_clean_datapoints}")

Number of Training Datapoints:  500000
Number of Cleaned Datapoints: 500000


### Step 6: Build the optimizer and scheduler (**students are required to complete this part!**)

In [6]:
# Recommended by prof to use the AdamW optimizer
## Apply a weight decay which acts as a regulariser to reduce overfitting
adamw_optimizer = torch.optim.AdamW(gpt.parameters(), lr=base_lr, weight_decay=0.1)

## Scheduler helps adjust the learning rate of the optimiser while training. Idea is to have learning rate be high and decrease slowly
## Cosine is pretty good cos the graph is high then low
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(adamw_optimizer, T_max=epochs)

### Step 7: Begin training (**students are required to complete this part!**)

In [7]:
# Track the best model only
best_loss = float('inf')

# Lines refer to the number of lines of training data we have so we know how much steps/batches is req to go thru dataset once
total_steps = len(lines) // batch_size

# Loop over the dataset for different epochs
for epoch in range(epochs):
    # Calculate the average loss
    epoch_loss_sum = 0
    batch_count = 0
    # Retrieve the batches of positive and negative tensors from the helper function we saw earlier
    pbar = tqdm(get_batches(lines, batch_size))

    for step, (neg_tensor,pos_tensor) in enumerate(pbar):
        # 1. Zero the accumulated gradients from previous batch of training or epoch of training. Do once at the start of every batch
        adamw_optimizer.zero_grad()
        # 2. Calculate the actual loss using the loss function in helper. We use DPO
        ## First calculate the log probs using the helper
        positive_log_prob = compute_logprob(pos_tensor)
        negative_log_prob = compute_logprob(neg_tensor)
        ## Assign higher weights to positive completions in this case, so the model knows what to do.
        dpo_loss = -F.logsigmoid((positive_log_prob - negative_log_prob) / beta).mean() - positive_log_prob.mean() * 0.1
        # 3. Backward Propagation, i.e. compute gradient of loss w.r.t model params
        dpo_loss.backward()
        # 4. Use optimiser to update the model params using new gradients to reduce dpo_loss
        adamw_optimizer.step()
        # 5. Update the progress bar so we can see how the training is going
        pbar.set_description(f"Epoch {epoch+1} Step {step+1} Loss {dpo_loss.item():.4f}")
        # 6. Ensures the scheduler decays the learning rate after every epoch
        scheduler.step()

        # 7. Track the loss for each batch
        epoch_loss_sum += dpo_loss.item()
        batch_count += 1

    # Calculate average loss for the epoch
    avg_epoch_loss = epoch_loss_sum / batch_count
    print(f"Epoch {epoch+1} average loss: {avg_epoch_loss:.4f}")

    # Update the checkpoint only if the model is improving
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss

        # Specify where to save our trained model
        ckpt_path = f"./dpo.pt"
        # Saves the model's weights and configuration after each epoch, so you can resume training or use the model later
        torch.save({
            "model_state_dict": gpt.state_dict(),
            "model_args": ckpt['model_args'],
        }, ckpt_path)
        print(f"Saved checkpoint to {ckpt_path}")
    else:
        # Dont save checkpoint to prevent overfitting, but continue training to see if it will improve
        print(f"Loss did not improve. Best is still: {best_loss:.4f}")

Epoch 1 Step 3906 Loss 0.0319: : 3906it [40:41,  1.60it/s]


Epoch 1 average loss: 0.0665
Saved checkpoint to ./dpo.pt


Epoch 2 Step 3906 Loss 0.0263: : 3906it [40:26,  1.61it/s]


Epoch 2 average loss: 0.0290
Saved checkpoint to ./dpo.pt


Epoch 3 Step 3906 Loss 0.0228: : 3906it [40:05,  1.62it/s]


Epoch 3 average loss: 0.0245
Saved checkpoint to ./dpo.pt


Epoch 4 Step 3906 Loss 0.0214: : 3906it [39:49,  1.63it/s]


Epoch 4 average loss: 0.0220
Saved checkpoint to ./dpo.pt


Epoch 5 Step 3906 Loss 0.0214: : 3906it [39:55,  1.63it/s]


Epoch 5 average loss: 0.0212
Saved checkpoint to ./dpo.pt


Epoch 6 Step 3906 Loss 0.0206: : 3906it [40:25,  1.61it/s]


Epoch 6 average loss: 0.0207
Saved checkpoint to ./dpo.pt


Epoch 7 Step 3906 Loss 0.0198: : 3906it [40:22,  1.61it/s]


Epoch 7 average loss: 0.0202
Saved checkpoint to ./dpo.pt


Epoch 8 Step 3906 Loss 0.0200: : 3906it [39:56,  1.63it/s]


Epoch 8 average loss: 0.0199
Saved checkpoint to ./dpo.pt


Epoch 9 Step 3906 Loss 0.0194: : 3906it [40:10,  1.62it/s]


Epoch 9 average loss: 0.0197
Saved checkpoint to ./dpo.pt


Epoch 10 Step 3906 Loss 0.0192: : 3906it [40:05,  1.62it/s]

Epoch 10 average loss: 0.0194
Saved checkpoint to ./dpo.pt





# Can continue training with the trained model here ( train from 101 to 1000 ig)

- The loss never stopped decreasing, so perhaps we can increase the number of epochs, since it have not overfitted on the train set yet. Even though we are using 10k only for now.
- can increase epoch since the dpo loss did not plateau or increase
- can consider increase batch size to 128

### Step 8: Begin testing (**students are required to complete this part!**)

In [8]:
# Load the fine-tuned model that we trained earlier above
ckpt_path = "./dpo.pt"
checkpoint = torch.load(ckpt_path, map_location=device)

# Load the saved model args
gptconf = GPTConfig(**checkpoint['model_args'])
gpt = GPT(gptconf).to(device)

# Load the model saved weights
state_dict = checkpoint['model_state_dict']

# Clean the data keys in dict again. Remember the keys might get affected whenever we save a checkpoint
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
gpt.load_state_dict(state_dict)

<All keys matched successfully>

In [9]:
# Test - Set the Model to Eval mode for deterministic testing behaviur
gpt.eval()
# Load the test set
# test_set = ["1*1=?"]
test_set = ["17+19=?", "3*17=?", "72/4=?", "72-x=34,x=?", "x*11=44,x=?", "54+25=?"]

# Disable gradient computation since we are not training the model anymore
with torch.no_grad():
    for prompt in test_set: 
        # Encode the string and convert it to a tensor to be used as input
        test_tensor = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
        # Use the test tensors to generate an answer
        answer, _ = gpt.generate(test_tensor, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
        # Decode ans back for human understanding
        output = decode(answer[0].tolist())
        print(f"Prompt: {prompt}\nOutput: {output}\n")

Prompt: 17+19=?
Output: 17+19=? The answer is 36 because 17+19 equals 36.

Prompt: 3*17=?
Output: 3*17=? The answer is 51 because 3*17 equals 51.

Prompt: 72/4=?
Output: 72/4=? The answer is 18 because 72/4 equals 18.

Prompt: 72-x=34,x=?
Output: 72-x=34,x=? answer is -2 because 72-94 equals -2.

Prompt: x*11=44,x=?
Output: x*11=44,x=? answer is 4 because solving for x gives 4.

Prompt: 54+25=?
Output: 54+25=? The answer is 79 because 54+25 equals 79.



- epochs is good enuf, can try with larger dataset and no fractions or decimals