Using GRPO to finetune the chat gpt2 model

# Imports

In [1]:
from importlib.metadata import version
import torch, tiktoken, time, os, tensorflow
import torch.optim as optim
import numpy as np
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import zipfile
from pathlib import Path
import pandas as pd

pkgs = ["numpy", 
        "tiktoken", 
        "torch",
        # "tensorflow", # For OpenAI's pretrained weights
        "pandas"
       ]
for p in pkgs:
    print(f"{p} version: {version(p)}")

numpy version: 1.23.5
tiktoken version: 0.9.0
torch version: 2.5.1
pandas version: 2.3.1


In [2]:
!python --version

Python 3.10.8


In [3]:

from utils.previous_chapters import generate_text_simple, text_to_token_ids, token_ids_to_text,GPTModel, create_dataloader_v1, load_weights_into_gpt
# Relative import from the gpt_download.py contained in this folder
from utils.gpt_download import download_and_load_gpt2

## Dataset Class

In [4]:
def prepare_datasets(data_file_path, sep="\t", header=None, column_names=["Label", "Text"], train_frac=0.7, validation_frac=0.15, store_directory="./"):
    df = pd.read_csv(data_file_path, sep=sep, header=header, names=column_names)

    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]
    
    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    
    # Combine ham "subset" with "spam"
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

    # Shuffle the entire DataFrame
    balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(balanced_df) * train_frac)
    validation_end = train_end + int(len(balanced_df) * validation_frac)

    # Split the DataFrame
    train_df = balanced_df[:train_end]
    validation_df = balanced_df[train_end:validation_end]
    test_df = balanced_df[validation_end:]

    train_df.to_csv(store_directory+"/train.csv", index=None)
    validation_df.to_csv(store_directory+"/validation.csv", index=None)
    test_df.to_csv(store_directory+"/test.csv", index=None)

In [5]:
class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        try:
            self.data = pd.read_csv(csv_file)
        except FileNotFoundError:
            raise FileNotFoundError(f"File not found: {csv_file}")

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]    # For each row in the text section of the pandas data frame tokenize the text string(sentence); creates list of token IDs for each example/item of the text data
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length

In [10]:
tokenizer = tiktoken.get_encoding("gpt2")

In [33]:
train_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/train.csv", tokenizer=tokenizer)
test_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/test.csv", tokenizer=tokenizer)
validation_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/validation.csv", tokenizer=tokenizer)

## Create Dataloaders

In [34]:
batch_size=64
num_workers=0
pin_memory=False

In [None]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)

## Building Policies

In [6]:
def build_new_policy(base_config, chosen_model="gpt2-small (124M)", num_classes = 2) -> GPTModel:
    """Build and load in the GPT2 model. Swap out the Head layer, and freeze up to the last Transformer module for transfer learning."""
    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    base_config.update(model_configs[chosen_model]) # add the emb_dim, n_layers, and n_heads to the config

    model_size = chosen_model.split(" ")[-1].lstrip("(").rstrip(")")
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")
    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

    model = GPTModel(base_config)

    load_weights_into_gpt(model, params)
    
    for param in model.parameters(): # freeze model parameters
        param.requires_grad = False 

    # Unfreeze the last transformer block
    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    # Unfreeze the final layer normalizing layer
    for param in model.final_norm.parameters():
        param.requires_grad = True

    model.out_head = torch.nn.Linear(in_features=base_config["emb_dim"], out_features=num_classes) # reconfigure the output layer
    return model

In [7]:
def build_old_policy(base_config, chosen_model="gpt2-small (124M)", num_classes = 2) -> GPTModel:
    """Construct just the model without loading the weights"""
    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    base_config.update(model_configs[chosen_model]) # add the emb_dim, n_layers, and n_heads to the config

    model_size = chosen_model.split(" ")[-1].lstrip("(").rstrip(")")
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")
    model = GPTModel(base_config)

    model.out_head = torch.nn.Linear(in_features=base_config["emb_dim"], out_features=num_classes) # reconfigure the output layer
    return model

In [23]:
model = build_old_policy(BASE_CONFIG)

In [17]:
model

<function __main__.build_old_policy(base_config, chosen_model='gpt2-small (124M)', num_classes=2) -> utils.previous_chapters.GPTModel>

In [36]:
for batch_inputs, batch_labels in validation_dataloader:
    pass

In [39]:
batch_inputs.shape

torch.Size([64, 92])

In [40]:
batch_labels.shape

torch.Size([64])

## Utility Functions

In [8]:
def calculate_discounted_rewards(predictions, batch_labels) -> torch.tensor:
    """For this implementation, use non-discounted rewards"""
    disc_rewards = (predictions == batch_labels).float()    # Simple comparison to evaluate rewards for each example; output a tensor of floats
    return disc_rewards

## Training Loop

In [None]:
def grpo_train(model_config, train_dataloader, validation_dataloader, gpt_size="gpt2-small (124M)", epochs=50, learning_rate=0.0001, batch_size=64, gamma=0.99, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.01, log_iterations=10, eval_iterations=10, device="cpu", num_envs=None):
    print(f"Training Policy on {device} with {epochs} main epochs, {k_epochs} inner epochs, {learning_rate} learning rate, batch size {batch_size}, KL beta {beta_kl}.")
    print(f"Using gpt2 size:{gpt_size} , logging every {log_iterations} iterations, evaluating every {eval_iterations} iterations.")


    Policy_New = build_new_policy(model_config, chosen_model=gpt_size, num_classes=2).to(device)   # STEP 3 || 
    Policy_New.train()

    optimizer = optim.Adam(params=Policy_New.parameters(), lr=learning_rate)

    Policy_Old = build_old_policy(model_config, chosen_model=gpt_size, num_classes=2).to(device)
    Policy_Old.eval()
    
    classifier_lyr = torch.nn.Softmax(dim=-1)   # For validation loop

    for epoch in tqdm(range(epochs), desc=f"Main Epoch (Outer Loop)", leave=False):     # STEP 4 || 
        # STEP 5 || Sample a batch D_b from D --> OMITTED 
        # STEP 6 || Update the old policy model PI old <- PI new
        Policy_Old.load_state_dict(Policy_New.state_dict())
        print("loaded Policy Old Weights")
        # --- STEP 7 Collect a Batch of Experiences Using the Old Policy---
        # for batch_inputs, batch_labels in train_dataloader:
        #     pass
        batch_inputs, batch_labels = next(iter(train_dataloader))
        
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device) # move the training data to the target device
        print("Transferred Data")
        with torch.no_grad():
            old_logits = Policy_Old(batch_inputs)[:,-1,:]   # Get logits from model and only focus on the last iterations of each sample
            print(old_logits)
            old_dist = torch.distributions.Categorical(logits=old_logits)
            old_predictions = old_dist.sample() # Tensor of shape [1] ; list of predictions
            print(f"old_predictions: {old_predictions}")
            old_log_probs = old_dist.log_prob(old_predictions)

        # STEP 8 || Calculate "Discounted" Rewards for completed trajectories
        discounted_rewards = calculate_discounted_rewards(old_predictions, batch_labels)    # Tensor with "discounted" rewards per each sample in batch
        print("Calculated discounted returns")
        # STEP 9 || Calculate the Advantage for each Trajectory using normalization
        all_advantages_tensor = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)

        Policy_New.train()  # Prepare NN for updates

        # --- STEP 10 || GRPO Optimization ---
        for k_epoch in tqdm(range(k_epochs), desc=f"Epoch {epoch+1}/{epochs} (Inner K-Epochs)", leave=True):
            print("Entered GRPO Optimization loop")
            optimizer.zero_grad()   # Flush out all the accumulated gradients for the weights of the model-under-training

            new_logits = Policy_New(batch_inputs)[:,-1,:]   # Get logits from model and only focus on the last iterations of each sample
            new_dist = torch.distributions.Categorical(logits=new_logits)
            new_log_probs = new_dist.log_prob(old_predictions)  # Get the log probability of choosing the same action that the old policy took using the new distribution
            entropy = new_dist.entropy().mean() # Calculate entropy for regularization

            R1_ratio = torch.exp(new_log_probs - old_log_probs)

            unclipped_surrogate = R1_ratio * all_advantages_tensor
            clipped_surrogate = torch.clamp(input=R1_ratio, min=1.0-epsilon, max=1.0+epsilon) * all_advantages_tensor

            policy_loss = -torch.min(unclipped_surrogate, clipped_surrogate).mean()

            # --- KL Divergence Calculation ---
            # Create distributions for old policies using the trajectory states
            # with torch.no_grad():
            #     old_logits = Policy_Old(all_states_tensor)
            # old_dist = torch.distributions.Categorical(logits=old_logits)
            # INSTEAD, just reusing the calculated logits from STEP #7

            # Calculate KL divergence per sample, then take the mean over the batch
            kl_div_per_sample = torch.distributions.kl.kl_divergence(p=new_dist, q=old_dist)
            kl_loss = kl_div_per_sample.mean() # Mean over the batch

            # Total Loss for GRPO
            total_loss = policy_loss + beta_kl * kl_loss - entropy_coeff * entropy

            # STEP 11 || Policy Updates
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(Policy_New.parameters(), max_grad_norm)
            optimizer.step()    # Update policy parameters using gradient ascent

            # --- 4. Logging and Evaluation ---
        if (epoch + 1) % log_iterations == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss.item():.4f}, Ratio: {R1_ratio.mean().item():.5f}, Entropy Term: {entropy:.5f}")

        if (epoch + 1) % eval_iterations == 0:
            Policy_New.eval()   # Turn off dropout layers and prevent grad tracking
            accuracy = 0.0
            num_correct = 0.0
            num_of_samples = 0.0
            with torch.no_grad():
                for batch_inputs, batch_labels in validation_dataloader:
                    batch_labels, predictions = batch_labels.to(device), predictions.to(device) # move the training data to the target device
                    logits = Policy_New(batch_inputs)[:,-1,:]
                    classifications = classifier_lyr(logits)
                    class_predictions = torch.argmax(classifications, dim=-1).flatten()
                    num_of_samples += batch_labels.size(0)
                    num_correct += sum((class_predictions == batch_labels).float()).item()
                accuracy = num_correct/num_of_samples
                print(f"Epoch {epoch+1} | Entire Validation Dataset Accuracy: {accuracy:.4f}")
            

    Policy_New.eval()   # Change to eval mode for evaluation after training is complete

    print("Training complete.")
    return Policy_New # Return the trained policy

## Main Loop

In [11]:
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

In [19]:
# Pretend that the argument parser will pass these arguments to the main function
args = {
    "epochs":10,
    "learning_rate":0.0003,
    "dataloader_batch_size":64,
    "dataloader_pin_memory": True,
    "dataloader_num_workers": 0,
    "batch_size":1024, # Significantly larger batch size recommended for stability
    "gpt_size":'gpt2-small (124M)',
    "k_epochs":2,
    "epsilon":0.2,
    "beta_kl":0.01,
    "entropy_coeff":0.001,
    "log_iterations":5,
    "gamma":None,   # Discounted Rewards
    "device":'cuda',
    "num_envs":None,
    "save_model":True,
    "model_output_path":'models/first.pt'
}

In [14]:
prepare_datasets(data_file_path="./sms_spam_collection/SMSSpamCollection.tsv", store_directory="./sms_spam_collection/data_splits")

In [15]:
train_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/train.csv", tokenizer=tokenizer)
test_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/test.csv", tokenizer=tokenizer)
validation_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/validation.csv", tokenizer=tokenizer)

In [16]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=args["batch_size"], num_workers=args["dataloader_num_workers"], pin_memory=args["dataloader_pin_memory"], drop_last=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=args["batch_size"], num_workers=args["dataloader_num_workers"], pin_memory=args["dataloader_pin_memory"], drop_last=True)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=args["batch_size"], num_workers=args["dataloader_num_workers"], pin_memory=args["dataloader_pin_memory"], drop_last=True)

In [20]:
# function call
trained_policy = grpo_train(
        model_config=BASE_CONFIG,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        gpt_size=args["gpt_size"],
        epochs=args["epochs"],
        learning_rate=args["learning_rate"],
        batch_size=args["batch_size"], # Significantly larger batch size recommended for stability
        k_epochs=args["k_epochs"],
        epsilon=args["epsilon"],
        beta_kl=args["beta_kl"],
        entropy_coeff=args["entropy_coeff"],
        log_iterations=args["log_iterations"],
        gamma=args["gamma"],
        device=args["device"],
        num_envs=args["num_envs"]
    )

File already exists and is up-to-date: gpt2\124M\checkpoint
File already exists and is up-to-date: gpt2\124M\encoder.json
File already exists and is up-to-date: gpt2\124M\hparams.json
File already exists and is up-to-date: gpt2\124M\model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2\124M\model.ckpt.index
File already exists and is up-to-date: gpt2\124M\model.ckpt.meta
File already exists and is up-to-date: gpt2\124M\vocab.bpe


Main Epoch (Outer Loop):   0%|          | 0/10 [00:00<?, ?it/s]

loaded Policy Old Weights
Transferred Data
tensor([[-3.8413,  5.7205],
        [-3.7407,  5.6240],
        [-3.7503,  5.6361],
        ...,
        [-3.7944,  5.7532],
        [-3.6782,  5.6643],
        [-3.6839,  5.6124]], device='cuda:0')
old_predictions: tensor([1, 1, 1,  ..., 1, 1, 1], device='cuda:0')
Calculated discounted returns




Entered GRPO Optimization loop




Entered GRPO Optimization loop


Epoch 1/10 (Inner K-Epochs): 100%|██████████| 2/2 [00:38<00:00, 19.21s/it]
Main Epoch (Outer Loop):  10%|█         | 1/10 [00:51<07:46, 51.79s/it]

loaded Policy Old Weights
Transferred Data
tensor([[-4.4633,  7.0328],
        [-4.3885,  6.9651],
        [-4.3911,  6.9824],
        ...,
        [-4.4387,  7.0710],
        [-4.3349,  6.9577],
        [-4.3283,  6.9639]], device='cuda:0')
old_predictions: tensor([1, 1, 1,  ..., 1, 1, 1], device='cuda:0')
Calculated discounted returns




Entered GRPO Optimization loop




Entered GRPO Optimization loop


Epoch 2/10 (Inner K-Epochs): 100%|██████████| 2/2 [00:55<00:00, 27.81s/it]
Main Epoch (Outer Loop):  20%|██        | 2/10 [02:26<10:16, 77.12s/it]

loaded Policy Old Weights
Transferred Data
tensor([[-5.2405,  7.8761],
        [-5.1815,  7.8221],
        [-5.1804,  7.8407],
        ...,
        [-5.2304,  7.9218],
        [-5.1291,  7.8077],
        [-5.1223,  7.8254]], device='cuda:0')
old_predictions: tensor([1, 1, 1,  ..., 1, 1, 1], device='cuda:0')
Calculated discounted returns




Entered GRPO Optimization loop


Epoch 3/10 (Inner K-Epochs):   0%|          | 0/2 [00:25<?, ?it/s]
                                                                       

KeyboardInterrupt: 

In [None]:
def main(args):
    print("Setting up for Training")
    
    if args.device:     # Check if the user specified to use a CPU or GPU for training
        device = args.device
    else:
        if args.use_cuda:   # Check if the user wanted to use CUDA if available.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    # Transfer to argparser setup
    gpt_size="gpt2-small (124M)"
    dataloader_batch_size=64
    num_workers=0
    pin_memory=True
    tokenizer=tiktoken.get_encoding("gpt2")

    print("Creating Datasets using train, test, and validation files.")

    prepare_datasets(data_file_path="./sms_spam_collection/SMSSpamCollection.tsv", store_directory="./sms_spam_collection/data_splits")

    train_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/train.csv", tokenizer=tokenizer)
    test_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/test.csv", tokenizer=tokenizer)
    validation_dataset = SpamDataset(csv_file="./sms_spam_collection/data_splits/validation.csv", tokenizer=tokenizer)
    
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=dataloader_batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=dataloader_batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
    validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=dataloader_batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)

    print("Beginning Training Script")
    start_time=time.time()

    trained_policy = grpo_train(
        model_config=BASE_CONFIG,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        gpt_size=gpt_size,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size, # Significantly larger batch size recommended for stability
        k_epochs=args.k_epochs,
        epsilon=args.epsilon,
        beta_kl=args.beta_kl,
        entropy_coeff=args.entropy_coeff,
        log_iterations=args.log_iterations,
        gamma=args.gamma,
        device=device,
        num_envs=args.num_envs
    )
    end_time=time.time()

    elapsed_time= end_time - start_time
    hrs = int(elapsed_time / 3600)
    min = int((elapsed_time % 3600) / 60)
    seconds_remaining = elapsed_time - (hrs * 3600 ) - (min * 60)
    print(f"FINISHED MODEL TRAINING. \nTRAINING TOOK: {hrs} Hours, {min} Minutes, and {seconds_remaining} Seconds")


    print("\nTesting the trained policy:")

    classification_lyr = torch.nn.Softmax(dim=-1)
    trained_policy.eval()   # Turn off dropout layers and prevent grad tracking
    accuracy = 0.0
    num_correct = 0.0
    num_of_samples = 0.0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_dataloader:
            batch_labels, predictions = batch_labels.to(device), predictions.to(device) # move the training data to the target device
            logits = trained_policy(batch_inputs)[:,-1,:]
            classifications = classification_lyr(logits)
            class_predictions = torch.argmax(classifications, dim=-1).flatten()
            num_of_samples += batch_labels.size(0)
            num_correct += sum((class_predictions == batch_labels).float()).item()
        accuracy = num_correct/num_of_samples
        print(f" Entire test Dataset Accuracy: {accuracy:.4f} |  {num_correct} corrct/ {num_of_samples} samples")


    #---------------  !!!  ---------------
    SAVE_LOCATION = "./model/trained_model.pth"   # Define the model path and name of the trained model weights

    if args.save_model:     # Check if the user wants to save the trained model weights
        if args.model_output_path:     # Check if the user specified a target save location
            SAVE_LOCATION=args.model_output_path
        
        torch.save(trained_policy.parameters(), f=SAVE_LOCATION)
        print(f"Model weights saved in: {SAVE_LOCATION}")

    print("Finished Running Script")

In [None]:
# Example usage (assuming you have a way to call this function, e.g., in a main block)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train and test a BlackJack PPO agent.")

    # Add arguments
    parser.add_argument('--epochs', type=int, default=2000,
                        help='Number of training epochs.')
    parser.add_argument('--learning_rate', type=float, default=0.0003,
                        help='Learning rate for the optimizer.')
    parser.add_argument('--dataloader_batch_size', type=int, default=64,
                        help='Dataloader Batch sizes for train, test, validation data files.')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Batch size for training.')
    parser.add_argument('--gpt2_size', type=str, default="gpt2-small (124M)",
                        help='GPT2 size for model construction.')
    parser.add_argument('--k_epochs', type=int, default=128,
                        help='Number of policy update epochs per trajectory collection.')
    parser.add_argument('--epsilon', type=float, default=0.2,
                        help='Clipping parameter for PPO.')
    parser.add_argument('--beta_kl', type=float, default=0.01,
                        help='KL divergence coefficient (for PPO-like algorithms).')
    parser.add_argument('--entropy_coeff', type=float, default=0.001,
                        help='Entropy regularization coefficient.')
    parser.add_argument('--log_iterations', type=int, default=100,
                        help='Log training progress every N iterations.')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='Discount factor for rewards.')
    parser.add_argument('--num_envs', type=int, default=16,
                        help='Number of parallel environments for training.')
    parser.add_argument('--use_cuda', action='store_true',
                        help='Use CUDA if available.')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Explicitly set device (e.g., "cpu, cuda:0", "cpu"). Overrides --use_cuda if specified.')
    parser.add_argument('--save_model', action='store_true',
                        help='Save the trained model weights.')
    parser.add_argument('--model_output_path', type=str, default='blackjack_policy_model.pth',
                        help='Path to save the trained model weights.')

    # Parse the arguments
    args = parser.parse_args()

    
    main(args)

## Testing

In [35]:
grpo_train(BASE_CONFIG)

File already exists and is up-to-date: gpt2\124M\checkpoint
File already exists and is up-to-date: gpt2\124M\encoder.json
File already exists and is up-to-date: gpt2\124M\hparams.json
File already exists and is up-to-date: gpt2\124M\model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2\124M\model.ckpt.index
File already exists and is up-to-date: gpt2\124M\model.ckpt.meta
File already exists and is up-to-date: gpt2\124M\vocab.bpe


                                                               

In [42]:
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) # add the emb_dim, n_layers, and n_heads to the config

In [8]:
tokenizer = tiktoken.get_encoding("gpt2")

In [43]:
model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
model = GPTModel(BASE_CONFIG)

load_weights_into_gpt(model, params)
model.eval()

checkpoint: 100%|██████████| 77.0/77.0 [00:00<00:00, 5.26kiB/s]
encoder.json: 100%|██████████| 1.04M/1.04M [00:00<00:00, 1.93MiB/s]
hparams.json: 100%|██████████| 90.0/90.0 [00:00<00:00, 8.73kiB/s]
model.ckpt.data-00000-of-00001: 100%|██████████| 498M/498M [00:31<00:00, 15.6MiB/s] 
model.ckpt.index: 100%|██████████| 5.21k/5.21k [00:00<00:00, 541kiB/s]
model.ckpt.meta: 100%|██████████| 471k/471k [00:00<00:00, 987kiB/s] 
vocab.bpe: 100%|██████████| 456k/456k [00:00<00:00, 1.27MiB/s]


GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=True)
        (W_key): Linear(in_features=768, out_features=768, bias=True)
        (W_value): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.0, inplace=False)
    )
    (1): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768,

In [22]:
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}