Using GRPO to finetune the chat gpt2 model

# Imports

In [1]:
from importlib.metadata import version
import torch, tiktoken, time, os, tensorflow
import torch.optim as optim
import numpy as np
import argparse
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import zipfile
from pathlib import Path
import pandas as pd
from torch.nn import Module # For type hinting
from typing import Tuple, Callable # Import Tuple and Callable
from tiktoken import Encoding
from utils.previous_chapters import GPTModel, load_weights_into_gpt
# Relative import from the gpt_download.py contained in this folder
from utils.gpt_download import download_and_load_gpt2
pkgs = ["numpy", 
        "tiktoken", 
        "torch",
        # "tensorflow", # For OpenAI's pretrained weights
        "pandas"
       ]
for p in pkgs:
    print(f"{p} version: {version(p)}")

numpy version: 1.23.5
tiktoken version: 0.9.0
torch version: 2.5.1
pandas version: 2.3.1


## Dataset Class

In [2]:
def prepare_datasets(data_file_path="./sms_spam_collection/SMSSpamCollection.tsv", sep="\t", header=None, column_names=["Label", "Text"], train_frac=0.7, validation_frac=0.15, store_directory="./sms_spam_collection/data_splits"):
    """This function prepares the train, test, and validation datasets from the original data.
    Code inspired from: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb
    Args:
        data_file_path (str): Path to the original dataset.
        sep (str): The separator used in the dataset file.
        column_names (list): A list of strings representing the column names used to read the dataset file using pandas.
        train_frac (float): The percentage of the data used for training.
        validation_frac (float): The percentage of the overall data used for the validation dataset.
        store_directory (str): The parent directory path that will contain the 3 datasets (train, test, and validation) .

    Returns:
        store_directory (str): The parent directory path containing the 3 datasets.
        """
    
    print(f"'prepare_datasets' function call: Using data_file_path='{data_file_path}' to find the original dataset.\n Using store_directory='{store_directory}' for the train, test, and validation dataset parent directory")

    # Construct the full paths for the output files
    train_csv_path = os.path.join(store_directory, "train.csv")
    test_csv_path = os.path.join(store_directory, "test.csv")
    validation_csv_path = os.path.join(store_directory, "validation.csv")

    if os.path.exists(train_csv_path) and os.path.exists(test_csv_path) and os.path.exists(validation_csv_path) :
        print(f"Train, Test, and Validation datasets detected in '{store_directory}', skipping generation")
    else:
        print(f"Datasets not found in '{store_directory}' or incomplete. Generating datasets...")

        df = pd.read_csv(data_file_path, sep=sep, header=header, names=column_names)

        # Count the instances of "spam"
        num_spam = df[df["Label"] == "spam"].shape[0]
        
        # Randomly sample "ham" instances to match the number of "spam" instances
        ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
        
        # Combine ham "subset" with "spam"
        balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

        balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

        # Shuffle the entire DataFrame
        balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True)

        # Calculate split indices
        train_end = int(len(balanced_df) * train_frac)
        validation_end = train_end + int(len(balanced_df) * validation_frac)

        # Split the DataFrame
        train_df = balanced_df[:train_end]
        test_df = balanced_df[validation_end:]
        validation_df = balanced_df[train_end:validation_end]

        # Create the directory if it doesn't exist
        os.makedirs(store_directory, exist_ok=True) 

        train_df.to_csv(train_csv_path, index=None)
        test_df.to_csv(test_csv_path, index=None)
        validation_df.to_csv(validation_csv_path, index=None)

    print(f"'prepare_datasets' function returning: {store_directory} as parent directory.")
    
    return store_directory

In [3]:
class SpamDataset(Dataset):
    """Dataset class to turn text into tokenized inputs
    Original Code in: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb"""
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        try:
            self.data = pd.read_csv(csv_file)
        except FileNotFoundError:
            raise FileNotFoundError(f"File not found: {csv_file}")

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]    # For each row in the text section of the pandas data frame tokenize the text string(sentence); creates list of token IDs for each example/item of the text data
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length

# Data Pipeline

In [4]:
def initialize_datasets_and_dataloaders_pipeline(data_file_path=None, store_directory=None, batch_size=64, num_workers=0, pin_memory=False, drop_last=True) -> tuple[Dataset, Dataset, Dataset, DataLoader, DataLoader, DataLoader, Encoding]:
    """This pipeline does the following: calls function to prepare the train, test, and validation datasets, creates the dataloaders for each dataset, initializes the tokenizer, and returns them. Original code: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb
    Args:
        data_file_path (str): Path to the original dataset.
        sep (str): The separator used in the dataset file.
        column_names (list): A list of strings representing the column names used to read the dataset file using pandas.
        train_frac (float): The percentage of the data used for training.
        validation_frac (float): The percentage of the overall data used for the validation dataset.
        store_directory (str): The parent directory path that will contain the 3 datasets (train, test, and validation).
        batch_size (int): The dataloader's batch_size.
        num_workers (int): The dataloader's number of workers.
        pin_memory (bool): The dataloader's pin memory option.
        drop_last (bool): The dataloader's drop_last option.

    Returns: 
        train_dataset (Dataset): Dataset Class for the training dataset.
        test_dataset (Dataset): Dataset Class for the test dataset.
        validation_dataset (Dataset): Dataset Class for the validation dataset.
        train_dataloader (DataLoader): The train dataloader.
        test_dataloader (DataLoader): The test dataloader.
        validation_dataloader (DataLoader): The validation dataloader.
        tokenizer (Encoding): The tokenizer used to convert the text data into tokens.
    """
    tokenizer = tiktoken.get_encoding("gpt2")

    # Construct the arguments for prepare_datasets conditionally
    prepare_args = {}
    if data_file_path is not None:
        prepare_args['data_file_path'] = data_file_path
    if store_directory is not None:
        prepare_args['store_directory'] = store_directory

    # Prepare the 3 datasets files and return their parent directory
    store_directory = prepare_datasets(**prepare_args)

    train_dataset = SpamDataset(csv_file=os.path.join(store_directory, "train.csv"), tokenizer=tokenizer)
    test_dataset = SpamDataset(csv_file=os.path.join(store_directory, "test.csv"), tokenizer=tokenizer)
    validation_dataset = SpamDataset(csv_file=os.path.join(store_directory, "validation.csv"), tokenizer=tokenizer)

    print(f"Created Train dataset with '{len(train_dataset)}' samples")
    print(f"Created Test dataset with '{len(test_dataset)}' samples")
    print(f"Created Validation dataset with '{len(validation_dataset)}' samples")

    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
    validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)

    print(f"Created Train Dataloader with '{len(train_dataloader)}' batches.")
    print(f"Created Test Dataloader with '{len(test_dataloader)}' batches.")
    print(f"Created Validation dataset with '{len(validation_dataloader)}' batches.")
    print(f"Each Dataloader has a batch_size of {batch_size}")
    
    return (train_dataset, test_dataset, validation_dataset, train_dataloader, test_dataloader, validation_dataloader, tokenizer)

## Building Policies

In [5]:
def build_new_policy(base_config, chosen_model="gpt2-small (124M)", num_classes=2) -> GPTModel:
    """Build and load in the GPT2 model. Swap out the Head layer, and freeze up to the last Transformer module for transfer learning. Code Inspired from: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb
    Args:
        base_config (dict): The base configurations of the gpt2 model indicating vocab_size, context_length, drop_rate, and qkv_bias.
        chosen_model (str): The specific gpt2 model to construct.
        num_classes (int): The amount of classes in the classification task.
    Returns:
        model (GPTModel): The constructed Transformer model for classification."""
    
    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    base_config.update(model_configs[chosen_model]) # Add the emb_dim, n_layers, and n_heads to the config

    model_size = chosen_model.split(" ")[-1].lstrip("(").rstrip(")")    # Extract the number of parameters from the chosen_model
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")
    settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

    model = GPTModel(base_config)

    load_weights_into_gpt(model, params)
    
    for param in model.parameters(): # Freeze model parameters
        param.requires_grad = False 

    # Unfreeze the last transformer block
    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    # Unfreeze the final layer normalizing layer
    for param in model.final_norm.parameters():
        param.requires_grad = True

    model.out_head = torch.nn.Linear(in_features=base_config["emb_dim"], out_features=num_classes) # Reconfigure the output layer for the classification task
    return model

In [6]:
def build_old_policy(base_config, chosen_model="gpt2-small (124M)", num_classes = 2) -> GPTModel:
    """Construct the GPT2 model architecture without loading the weights. Code inspired from: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb
    Args:
        base_config (dict): The base configurations of the gpt2 model indicating vocab_size, context_length, drop_rate, and qkv_bias.
        chosen_model (str): The specific gpt2 model to construct.
        num_classes (int): The amount of classes in the classification task.
    Returns:
        model (GPTModel): The constructed Transformer model for classification."""
    
    model_configs = {
        "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
        "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
        "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
        "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
    }

    base_config.update(model_configs[chosen_model]) # Add the emb_dim, n_layers, and n_heads to the config

    model_size = chosen_model.split(" ")[-1].lstrip("(").rstrip(")")
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")
    model = GPTModel(base_config)

    model.out_head = torch.nn.Linear(in_features=base_config["emb_dim"], out_features=num_classes) # Reconfigure the output layer
    return model

## Utility Functions

In [7]:
def calculate_discounted_rewards(predictions, batch_labels, gamma) -> torch.tensor:
    """A general function to calculate the discounted rewards for each of the model's trajectories. For this implementation, however, use non-discounted rewards.
    Args:
        predictions (torch.tensor): A flattened 1-d tensor containing the policy's predictions.
        batch_labels (torch.tensor): A flattened 1-d tensor containing the target predictions.
        gamma (float): The amount of discount to apply to future rewards [Not Used in this Implementation].
    Returns:
        disc_rewards (torch.tensor): The 1-d tensor containing the discounted rewards of each of the model's trajectories.
        """
    disc_rewards = (predictions == batch_labels).float()    # Simple comparison to evaluate rewards for each example; output a tensor of floats
    return disc_rewards

In [8]:
def log_epoch_stats(epoch, epoch_limit, total_loss, ratio, entropy) -> None:
    print(f"=====================  [Epoch ({epoch})]  =====================")
    print("Last k_epoch stats:")
    print(f"Loss: {total_loss:.7f} | Ratio: {ratio:.7f} | Entropy Term: {entropy:.7f}")
    print(f"===========================================================")

In [9]:
def evaluate_policy(Policy: Module, dataloader: DataLoader, current_epoch: int = None, max_epochs: int=None, device: str = 'cpu') -> float:
    """
    Evaluates the policy model (greedy version) on a given dataset.
    Args:
        Policy (Module): The Policy Model.
        dataloader (DataLoader): The dataloader to evaluate with.
        current_epoch (int): The current epoch [optional].
        max_epochs (int): The maximum number of epochs [optional].
        device (str): The device that the calculations will take place on.
    Returns:
        accuracy (float): The calculated accuracy.
    """
    Policy.eval()   # Turn off dropout layers and prevent grad tracking

    # Dataset check before continuing
    if len(dataloader.dataset) == 0: # Check the underlying dataset size
        print(f"Warning: Evaluation dataset is empty. Skipping accuracy calculation.")
        return float('nan')
    
    accuracy, num_correct, num_of_samples = 0.0, 0.0, 0.0

    Softmax_lyr = torch.nn.Softmax(dim=-1)  # The layer to transform the logits to probabilities
    
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device) # Move the training data to the target device

            logits = Policy(batch_inputs)[:,-1,:]   # Get logits from model and only focus on the last iterations of each sample!!!
            # print(old_logits)
            
            classification_probabilities = Softmax_lyr(logits)
            class_predictions = torch.argmax(classification_probabilities, dim=-1).flatten()
            num_of_samples += batch_labels.size(0)
            num_correct += sum((class_predictions == batch_labels).float()).item()
    
    accuracy = num_correct/num_of_samples
    if current_epoch and max_epochs:   # If the function was called in the training loop
        print(f"===================  [Epoch ({current_epoch}/{max_epochs})]  ===================")
        print(f"Entire Validation Dataset Accuracy: {accuracy:.4f}| {num_correct} / {num_of_samples} samples")
        print(f"====================================================")

    else:   # If the function was called outside of the training loop
        print(f"===============================================")
        print(f"Entire Dataset Accuracy: {accuracy:.4f} | {num_correct} / {num_of_samples} samples")
        print(f"=====================================================")

            
    Policy.train()  # Set back to training mode 
    return accuracy

In [10]:
def simple_spam_classify_single(Policy: GPTModel, input_text: list, tokenizer: Encoding, device='cpu'):
    """Used to test the Policy with a single messages to classify it as SPAM or NOT SPAM.
    Args:
        Policy (GPTModel): The Policy that will classify the text
        input_text (list): The list containing the input texts.
        tokenizer (Encoding): The tokenizer that turns text into tokens to feed the policy.
        device (str): The device to run the Policy and input tokens to.
    """
    Policy.eval().to(device)
    Softmax_lyr = torch.nn.Softmax(dim=-1)

    tokenized_text = tokenizer.encode(input_text)
    torch_text=torch.tensor(tokenized_text).unsqueeze(0)    # turn into a tensor and add a batch dimension
    model_inputs = torch_text.to(device)
    # print(f"torch_text: {torch_text} | {torch_text.shape}")
    with torch.no_grad():
        logits = Policy(model_inputs)[:,-1,:]
        Class_probabilities = Softmax_lyr(logits)
    prediction = torch.argmax(input=Class_probabilities, dim=-1)
    # print(f"prediction: {prediction}")

    print("==================================================================")
    print(f"Classifiying the following text as [SPAM or NOT SPAM]:")
    print(f"'{input_text}'")
    print(f"Prediction ... [ => {'SPAM' if prediction.item() == 1 else 'NOT SPAM'} <= ]")
    print("==================================================================")

In [11]:
def simple_spam_classify_batch(Policy: GPTModel, input_text: list, tokenizer: Encoding, device='cpu'):
    """Used to test the Policy with a batch of messages to classify them as SPAM or NOT SPAM.
    Args:
        Policy (GPTModel): The Policy that will classify the text
        input_text (list): The list containing the input texts.
        tokenizer (Encoding): The tokenizer that turns text into tokens to feed the policy.
        device (str): The device to run the Policy and input tokens to.
    """
    Policy.eval().to(device)
    Softmax_lyr = torch.nn.Softmax(dim=-1)

    tokenized_text = [
            tokenizer.encode(text) for text in input_text    # For each row in the text section of the pandas data frame tokenize the text string(sentence); creates list of token IDs for each example/item of the text data
        ]
    
    max_length = 0
    for encoded_text in tokenized_text:
        encoded_length = len(encoded_text)
        if encoded_length > max_length:
            max_length = encoded_length

    
    torch_text=torch.tensor(tokenized_text)    # turn into a tensor and add a batch dimension
    model_inputs = torch_text.to(device)
    # print(f"torch_text: {torch_text} | {torch_text.shape}")
    with torch.no_grad():
        logits = Policy(model_inputs)[:,-1,:]
        Class_probabilities = Softmax_lyr(logits)
    predictions = torch.argmax(input=Class_probabilities, dim=-1)
    print(f"predictions: {predictions}")

    bundle = zip(input_text, predictions.item())

    for i, (text_str, pred) in enumerate(bundle):
        print("==================================================================")
        print(f"Classifiying the following text:")
        print(f"[SPAM || NOT SPAM]: \n'{text_str}'")
        print(f"Prediction ... [ => {'SPAM' if pred == 1 else 'NOT SPAM'} <= ]")
        print("==================================================================")

## Training Loop

In [12]:
def grpo_train(model_config: dict, train_dataloader: DataLoader, validation_dataloader: DataLoader, gpt2_size="gpt2-small (124M)", epochs=32, learning_rate=0.0003, batch_size=64, gamma=0.99, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.05, log_iterations=10, eval_iterations=10, device="cpu", num_envs:int=None) -> GPTModel:
    """The Group-Relative Policy Optimization Training function.

    Args:
        model_config (dict): The base configurations for building the policies.
        train_dataloader (DataLoader): The dataloader for the training loop.
        validation_dataloader (DataLoader): The dataloader for the validation loop.
        gpt2_size (str): The GPT2 model and parameter choice (e.g. 'gpt2-small (124M)').
        epochs (int): The number of times the outer loop is performed.
        learning_rate (float): The hyperparameter that affects how much the model's parameters learn on each update iteration.
        batch_size (int): The number of training examples that the Old Policy gathers to perform a GRPO update [Not Used in this Implementation].
        gamma (float): The discount rate used to calculate discounted rewards.
        k_epochs (int): The number of inner iterations used to update Policy New.
        epsilon (float): The hyperparameter that affects the clipping of the 'R1_ratio'.
        beta_kl (float): The hyperparameter that adjusts how much the KL Divergence term affects the overall GRPO Loss.
        max_grad_norm (float): Used to promote numerical stability and prevent exploding gradients.
        entropy_coeff (float): The hyperparameter that adjusts how much the entropy term affects the overall GRPO Loss.
        log_iterations (int): Used to log information about the state of the New Policy.
        eval_iterations (int): Used to run an evaluation of the New Policy.
        device (str): The device that the model will be trained on.
        num_envs (int): The number of parallel training environments that will be used during training [Not Used in this Implementation].

    Returns: 
        Policy_New (GPTModel): The Trained Model in evaluation mode.
    """
    # print(f"Training Policy on {device} with {epochs} main epochs, {k_epochs} inner epochs, {learning_rate} learning rate, batch size={batch_size}, KL beta={beta_kl}, gamma={gamma}, epsilon={epsilon}, beta_kl={beta_kl}, max_grad_norm={max_grad_norm}, entropy_coeff={entropy_coeff}.")
    print(f"Training Policy on {device} with {epochs} main epochs, {k_epochs} inner epochs, {learning_rate} learning rate, KL beta={beta_kl}, epsilon={epsilon}, beta_kl={beta_kl}, max_grad_norm={max_grad_norm}, entropy_coeff={entropy_coeff}.")
    print(f"Using gpt2 size: '{gpt2_size}', logging every {log_iterations} epoch iterations, evaluating every {eval_iterations} epoch iterations.")

    Policy_New = build_new_policy(model_config, chosen_model=gpt2_size, num_classes=2).to(device)   # STEP 1 || 
    Policy_New.train()
    # Policy_New = torch.compile(Policy_New) # To reap efficiency benefits ; not working due to Triton dependency

    # STEPS 2 || For I iterations --> OMITTED
    # STEPS 3 || Initialize a reference model --> OMITTED

    optimizer = optim.Adam(params=Policy_New.parameters(), lr=learning_rate)

    Policy_Old = build_old_policy(model_config, chosen_model=gpt2_size, num_classes=2).to(device)
    Policy_Old.eval()
    # Policy_Old = torch.compile(Policy_Old)

    classifier_lyr = torch.nn.Softmax(dim=-1)   # For validation loop

    for epoch in tqdm(range(epochs), desc=f">>>>>>>>>>>>>>>>>>>>>\nMain Epoch (Outer Loop)", leave=True):     # STEP 4 || 
        # STEP 5 || Sample a batch D_b from D
        batch_inputs, batch_labels = next(iter(train_dataloader))
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device) # move the training data to the target device
        # print(f"batch_inputs shape: {batch_inputs.shape}")
        # print(f"batch_labels shape: {batch_labels.shape}")

        # STEP 6 || Update the old policy model PI old <- PI new
        Policy_Old.load_state_dict(Policy_New.state_dict())
        
        # STEP 7 || Collect a Batch of Experiences Using the Old Policy
        with torch.no_grad():
            old_logits = Policy_Old(batch_inputs)[:,-1,:]   # Get logits from model and only focus on the last iterations of each sample
            # print(old_logits)
            old_dist = torch.distributions.Categorical(logits=old_logits) # Create a distribution to sample from
            old_predictions = old_dist.sample() # Tensor of shape (batch_size,) ; list of predictions
            # print(f"old_predictions: \n{old_predictions[:10]}")
            # print(f"batch_labels True Values: \n{batch_labels[:10]}")
            old_log_probs = old_dist.log_prob(old_predictions)

        # STEP 8 || Calculate "Discounted" Rewards for completed trajectories
        discounted_rewards = calculate_discounted_rewards(old_predictions, batch_labels, gamma)    # Output is a 1-d Tensor with "discounted" rewards per each sample in batch
        # print("Calculated discounted returns")
        # STEP 9 || Calculate the Advantage for each Trajectory using normalization
        all_advantages_tensor = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-8)

        Policy_New.train()  # Prepare Policy NN for updates

        # STEP 10 || GRPO Optimization ---
        for k_epoch in tqdm(range(k_epochs), desc=f"Epoch {epoch+1}/{epochs} (Inner K-Epochs)", leave=True):
            print(f"===========================  [({k_epoch+1}/{k_epochs})]  ==========================\n")
            optimizer.zero_grad()   # Flush out all the accumulated gradients for the weights of the model-under-training!!!

            new_logits = Policy_New(batch_inputs)[:,-1,:]   # Get logits from model and only focus on the last iterations of each sample!!!
            new_dist = torch.distributions.Categorical(logits=new_logits)
            new_log_probs = new_dist.log_prob(old_predictions)  # Get the log probability of choosing the same action that the old policy took using the new distribution
            entropy = new_dist.entropy().mean() # Calculate entropy for regularization
            # print(f"Entropy of this k_epoch: {entropy}")
            
            R1_ratio = torch.exp(new_log_probs - old_log_probs)

            unclipped_surrogate = R1_ratio * all_advantages_tensor
            clipped_surrogate = torch.clamp(input=R1_ratio, min=1.0-epsilon, max=1.0+epsilon) * all_advantages_tensor
            # print(f"unclipped_surrogate: \n{unclipped_surrogate[:10]}\nclipped_surrogate: \n{clipped_surrogate[:10]}")
            
            policy_loss = -torch.min(unclipped_surrogate, clipped_surrogate).mean()

            # Calculate KL divergence per sample, then take the mean over the batch
            # Note: Reusing the calculated logits from STEP #7
            kl_div_per_sample = torch.distributions.kl.kl_divergence(p=new_dist, q=old_dist)
            kl_loss = kl_div_per_sample.mean() # Mean over the batch

            # Total Loss for GRPO
            total_loss = policy_loss + beta_kl * kl_loss - entropy_coeff * entropy
            # print(f"KL Divergence Average Loss: {kl_loss}")

            # STEP 11 || Policy Updates
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(Policy_New.parameters(), max_grad_norm)
            optimizer.step()    # Update policy parameters using gradient ascent
                

        # --- Logging and Evaluation ---
        if (epoch + 1) % log_iterations == 0:
            log_epoch_stats(epoch=epoch+1, epoch_limit=epochs, total_loss=total_loss.item(), ratio=R1_ratio.mean().item(), entropy=entropy)

        if (epoch + 1) % eval_iterations == 0:
            accuracy = evaluate_policy(Policy_New, validation_dataloader, current_epoch=epoch+1, max_epochs=epochs, device=device)
                

    Policy_New.eval()   # Change to eval mode for evaluation after training is complete

    print("Training complete.")
    return Policy_New # Return the trained policy

## Main Loop

In [13]:
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

In [14]:
# Pretend that the argument parser will pass these arguments to the main function
args = {
    "epochs":32,
    "learning_rate":0.0003,
    "dataloader_batch_size":64,
    "dataloader_pin_memory": True,  
    "dataloader_num_workers": 0,    # Problem if I change this; slow for windows; try to modify within .py script
    "batch_size":None, # Not needed in this build/project
    "gpt2_size":'gpt2-small (124M)',
    "k_epochs":64,       # GRPO Inner-loop
    "epsilon":0.2,
    "beta_kl":0.01,
    "entropy_coeff":0.05,   # 
    "log_iterations":1,     # Log GRPO stats
    "eval_iterations":1,    # Run model through evaluation at every "x" epochs
    "gamma":None,   # Discounted Rewards
    "num_envs":None,        # Not needed in this build/project
    "use_cuda": None,
    "device":'cuda',
    "save_model":True,
    "model_output_path":'models/Spam-Classifier-GPT2-Model.pt'
}

In [15]:
train_dataset, test_dataset, validation_dataset, train_dataloader, test_dataloader, validation_dataloader, tokenizer = initialize_datasets_and_dataloaders_pipeline()

'prepare_datasets' function call: Using data_file_path='./sms_spam_collection/SMSSpamCollection.tsv' to find the original dataset.
 Using store_directory='./sms_spam_collection/data_splits' for the train, test, and validation dataset parent directory
Train, Test, and Validation datasets detected in './sms_spam_collection/data_splits', skipping generation
'prepare_datasets' function returning: ./sms_spam_collection/data_splits as parent directory.
Created Train dataset with '1045' samples
Created Test dataset with '225' samples
Created Validation dataset with '224' samples
Created Train Dataloader with '16' batches.
Created Test Dataloader with '3' batches.
Created Validation dataset with '3' batches.
Each Dataloader has a batch_size of 64


In [None]:
# function call
trained_policy = grpo_train(
        model_config=BASE_CONFIG,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        gpt2_size=args["gpt2_size"],
        epochs=args["epochs"],
        learning_rate=args["learning_rate"],
        batch_size=args["batch_size"], # Significantly larger batch size recommended for stability
        k_epochs=args["k_epochs"],
        epsilon=args["epsilon"],
        beta_kl=args["beta_kl"],
        entropy_coeff=args["entropy_coeff"],
        log_iterations=args["log_iterations"],
        eval_iterations=args["eval_iterations"],
        gamma=args["gamma"],
        device=args["device"],
        num_envs=args["num_envs"]
    )

Training Policy on cuda with 32 main epochs, 64 inner epochs, 0.0003 learning rate, KL beta=0.01, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.05.
Using gpt2 size: 'gpt2-small (124M)', logging every 1 epoch iterations, evaluating every 10 epoch iterations.
File already exists and is up-to-date: gpt2\124M\checkpoint
File already exists and is up-to-date: gpt2\124M\encoder.json
File already exists and is up-to-date: gpt2\124M\hparams.json
File already exists and is up-to-date: gpt2\124M\model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2\124M\model.ckpt.index
File already exists and is up-to-date: gpt2\124M\model.ckpt.meta
File already exists and is up-to-date: gpt2\124M\vocab.bpe


>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):   0%|          | 0/32 [00:00<?, ?it/s]

















































































































































Epoch 1/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.33it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):   3%|▎         | 1/32 [00:05<02:50,  5.48s/it]



Last k_epoch stats:
Loss: -0.0657852 | Ratio: 1.2813897 | Entropy Term: 0.2360256





















































































































































Epoch 2/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.56it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):   6%|▋         | 2/32 [00:10<02:39,  5.31s/it]


Last k_epoch stats:
Loss: -0.1379615 | Ratio: 3.2132883 | Entropy Term: 0.3460789


































































































































































































Epoch 3/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.56it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):   9%|▉         | 3/32 [00:15<02:32,  5.25s/it]

Last k_epoch stats:
Loss: -0.1260138 | Ratio: 0.9926201 | Entropy Term: 0.2659352


































































































































































































Epoch 4/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.59it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  12%|█▎        | 4/32 [00:21<02:26,  5.22s/it]

Last k_epoch stats:
Loss: -0.1081421 | Ratio: 0.9841046 | Entropy Term: 0.1603588


































































































































































































Epoch 5/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.55it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  16%|█▌        | 5/32 [00:26<02:20,  5.21s/it]

Last k_epoch stats:
Loss: -0.0623029 | Ratio: 0.9962623 | Entropy Term: 0.0653414


































































































































































































Epoch 6/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.54it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  19%|█▉        | 6/32 [00:31<02:15,  5.20s/it]

Last k_epoch stats:
Loss: -0.0283791 | Ratio: 0.9966048 | Entropy Term: 0.0660047


































































































































































































Epoch 7/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.53it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  22%|██▏       | 7/32 [00:36<02:10,  5.20s/it]

Last k_epoch stats:
Loss: -0.0190297 | Ratio: 0.7374310 | Entropy Term: 0.5639080


































































































































































































Epoch 8/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.54it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  25%|██▌       | 8/32 [00:41<02:04,  5.20s/it]

Last k_epoch stats:
Loss: -0.1991711 | Ratio: 0.9379761 | Entropy Term: 0.2678085


































































































































































































Epoch 9/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.52it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  28%|██▊       | 9/32 [00:47<01:59,  5.20s/it]

Last k_epoch stats:
Loss: -0.1017490 | Ratio: 0.9691651 | Entropy Term: 0.1119034


































































































































































































Epoch 10/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.50it/s]


Last k_epoch stats:
Loss: -0.0305402 | Ratio: 0.9932793 | Entropy Term: 0.1217657


>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  31%|███▏      | 10/32 [00:52<01:55,  5.25s/it]

Entire Validation Dataset Accuracy: 0.9583| 184.0 / 192.0 samples

































































































































































































Epoch 11/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  34%|███▍      | 11/32 [00:57<01:49,  5.24s/it]


Last k_epoch stats:
Loss: -0.0490867 | Ratio: 0.9822607 | Entropy Term: 0.0619192


































































































































































































Epoch 12/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.50it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  38%|███▊      | 12/32 [01:02<01:44,  5.23s/it]

Last k_epoch stats:
Loss: -0.0218447 | Ratio: 0.6859295 | Entropy Term: 0.6244335


































































































































































































Epoch 13/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.53it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  41%|████      | 13/32 [01:07<01:39,  5.22s/it]

Last k_epoch stats:
Loss: -0.1988199 | Ratio: 1.0052629 | Entropy Term: 0.3918912


































































































































































































Epoch 14/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  44%|████▍     | 14/32 [01:13<01:33,  5.22s/it]

Last k_epoch stats:
Loss: -0.0959939 | Ratio: 1.0307188 | Entropy Term: 0.0073993


































































































































































































Epoch 15/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.52it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  47%|████▋     | 15/32 [01:18<01:28,  5.21s/it]

Last k_epoch stats:
Loss: -0.0134419 | Ratio: 0.7465652 | Entropy Term: 0.5395281


































































































































































































Epoch 16/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.55it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  50%|█████     | 16/32 [01:23<01:23,  5.20s/it]

Last k_epoch stats:
Loss: -0.1597697 | Ratio: 1.0161905 | Entropy Term: 0.1319515


































































































































































































Epoch 17/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  53%|█████▎    | 17/32 [01:28<01:18,  5.20s/it]

Last k_epoch stats:
Loss: -0.0428009 | Ratio: 1.0173582 | Entropy Term: 0.0564535

































































































































































































Epoch 18/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.48it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  56%|█████▋    | 18/32 [01:33<01:12,  5.21s/it]


Last k_epoch stats:
Loss: -0.0188069 | Ratio: 0.7118240 | Entropy Term: 0.5831895


































































































































































































Epoch 19/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.52it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  59%|█████▉    | 19/32 [01:39<01:07,  5.21s/it]

Last k_epoch stats:
Loss: -0.1635605 | Ratio: 1.0597389 | Entropy Term: 0.3056451


































































































































































































Epoch 20/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.55it/s]


Last k_epoch stats:
Loss: -0.1118433 | Ratio: 0.9781338 | Entropy Term: 0.1279985


>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  62%|██████▎   | 20/32 [01:44<01:02,  5.25s/it]

Entire Validation Dataset Accuracy: 0.9427| 181.0 / 192.0 samples

































































































































































































Epoch 21/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.49it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  66%|██████▌   | 21/32 [01:49<00:57,  5.24s/it]


Last k_epoch stats:
Loss: -0.0619240 | Ratio: 0.9573950 | Entropy Term: 0.0203928


































































































































































































Epoch 22/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.50it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  69%|██████▉   | 22/32 [01:54<00:52,  5.23s/it]

Last k_epoch stats:
Loss: -0.0102705 | Ratio: 0.8563215 | Entropy Term: 0.3004874


































































































































































































Epoch 23/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.50it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  72%|███████▏  | 23/32 [02:00<00:47,  5.22s/it]

Last k_epoch stats:
Loss: -0.0929963 | Ratio: 1.0521438 | Entropy Term: 0.2094985


































































































































































































Epoch 24/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.54it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  75%|███████▌  | 24/32 [02:05<00:41,  5.21s/it]

Last k_epoch stats:
Loss: -0.0734719 | Ratio: 1.0030096 | Entropy Term: 0.0789716


































































































































































































Epoch 25/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.49it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  78%|███████▊  | 25/32 [02:10<00:36,  5.21s/it]

Last k_epoch stats:
Loss: -0.0213952 | Ratio: 0.7188818 | Entropy Term: 0.5911974


































































































































































































Epoch 26/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.52it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  81%|████████▏ | 26/32 [02:15<00:31,  5.21s/it]

Last k_epoch stats:
Loss: -0.1930408 | Ratio: 1.0075721 | Entropy Term: 0.3437636


































































































































































































Epoch 27/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.49it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  84%|████████▍ | 27/32 [02:20<00:26,  5.21s/it]

Last k_epoch stats:
Loss: -0.0971434 | Ratio: 1.0273092 | Entropy Term: 0.1221088


































































































































































































Epoch 28/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  88%|████████▊ | 28/32 [02:26<00:20,  5.21s/it]

Last k_epoch stats:
Loss: -0.0407529 | Ratio: 0.9895287 | Entropy Term: 0.0783962


































































































































































































Epoch 29/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  91%|█████████ | 29/32 [02:31<00:15,  5.21s/it]

Last k_epoch stats:
Loss: -0.0222110 | Ratio: 0.7058790 | Entropy Term: 0.6119809


































































































































































































Epoch 30/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.55it/s]


Last k_epoch stats:
Loss: -0.1908813 | Ratio: 1.0541043 | Entropy Term: 0.4049514


>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  94%|█████████▍| 30/32 [02:36<00:10,  5.25s/it]

Entire Validation Dataset Accuracy: 0.9479| 182.0 / 192.0 samples

































































































































































































Epoch 31/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.50it/s]
>>>>>>>>>>>>>>>>>>>>>
Main Epoch (Outer Loop):  97%|█████████▋| 31/32 [02:41<00:05,  5.24s/it]


Last k_epoch stats:
Loss: -0.1182924 | Ratio: 0.9904348 | Entropy Term: 0.0205331


































































































































































































Epoch 32/32 (Inner K-Epochs): 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
>>>>>>>>>>>>>>>>>>>>>
>>>>>>>>>>>>>>>>>>>>>p): 100%|██████████| 32/32 [02:47<00:00,  5.23s/it]
Main Epoch (Outer Loop): 100%|██████████| 32/32 [02:47<00:00,  5.22s/it]

Last k_epoch stats:
Loss: -0.0106396 | Ratio: 0.8660972 | Entropy Term: 0.2922595
Training complete.





In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
input_text1 = "Hey, wanna go out to watch the new fantastic four movie?"
input_text2 = "XMAS Prize draws! We are trying to contact U. Todays draw shows that you have won a £2000 prize GUARANTEED. Call 09058094565 from land line. Valid 12hrs only"

In [19]:
text_batch = [input_text1, input_text2]


In [20]:
simple_spam_classify_batch(Policy=trained_policy, input_text=text_batch, tokenizer=tokenizer, device=device)


ValueError: expected sequence of length 13 at dim 1 (got 45)

In [None]:
simple_spam_classify_single(Policy=trained_policy, input_text=input_text1, tokenizer=tokenizer, device=device)


Classifiying the following text as [SPAM or NOT SPAM]:
'Hey, wanna go out to watch the new fantastic four movie?'
Prediction ... [ => SPAM <= ]


In [None]:
simple_spam_classify_single(Policy=trained_policy, input_text=input_text2, tokenizer=tokenizer, device=device)


Classifiying the following text as [SPAM or NOT SPAM]:
'XMAS Prize draws! We are trying to contact U. Todays draw shows that you have won a £2000 prize GUARANTEED. Call 09058094565 from land line. Valid 12hrs only'
Prediction ... [ => SPAM <= ]


In [21]:
acc= evaluate_policy(Policy=trained_policy, dataloader=train_dataloader, device=device)

Entire Dataset Accuracy: 0.9580 | 981.0 / 1024.0 samples


In [None]:
def main(args) -> int:
    print("Setting up for Training")
    
    if args.device:     # Check if the user specified to use a CPU or GPU for training
        device = args.device
    else:
        if args.use_cuda:   # Check if the user wanted to use CUDA if available.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    SAVE_LOCATION = "./model/trained_model.pth"   # Define the model path and name of the trained model weights

    BASE_CONFIG = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "drop_rate": 0.1,        # Dropout rate
        "qkv_bias": True         # Query-key-value bias
    }

    # Transfer to argparser setup
    gpt2_size="gpt2-small (124M)"
    dataloader_batch_size=64
    num_workers=0
    pin_memory=True

    # --- Data Preparation Pipeline --- 
    train_dataset, test_dataset, validation_dataset, train_dataloader, test_dataloader, validation_dataloader, tokenizer = initialize_datasets_and_dataloaders_pipeline(data_file_path="./sms_spam_collection/SMSSpamCollection.tsv", store_directory="./sms_spam_collection/data_splits", num_workers=num_workers, dataloader_batch_size=dataloader_batch_size, pin_memory=pin_memory)

    print("Beginning Training Script")
    start_time=time.time()

    trained_policy = grpo_train(
        model_config=BASE_CONFIG,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        gpt2_size=gpt2_size,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size, # Significantly larger batch size recommended for stability
        gamma=args.gamma,
        k_epochs=args.k_epochs,
        epsilon=args.epsilon,
        beta_kl=args.beta_kl,
        entropy_coeff=args.entropy_coeff,
        log_iterations=args.log_iterations,
        eval_iterations=args.eval_iterations,
        device=device,
        num_envs=args.num_envs
    )
    end_time=time.time()

    # --- Calculate Training Time --- 

    elapsed_time= end_time - start_time
    hrs = int(elapsed_time / 3600)
    min = int((elapsed_time % 3600) / 60)
    seconds_remaining = elapsed_time - (hrs * 3600 ) - (min * 60)

    print(f"FINISHED MODEL TRAINING. \nTRAINING TOOK: {hrs} Hours, {min} Minutes, and {seconds_remaining} Seconds")

    # --- Testing Trained Model --- 
    print("\nTesting the trained policy:")

    test_dataset_accuracy = evaluate_policy(trained_policy, test_dataloader, current_epoch=None, max_epochs=None, device=device)

    # ---  Saving Model Section  ---   

    if args.save_model:     # Check if the user wants to save the trained model weights
        if args.model_output_path:     # Check if the user specified a target save location
            SAVE_LOCATION=args.model_output_path
        
        torch.save(trained_policy.parameters(), f=SAVE_LOCATION)
        print(f"Model weights saved in: {SAVE_LOCATION}")

    print("Finished Running Script")
    return 0

In [None]:
# Example usage (assuming you have a way to call this function, e.g., in a main block)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train and test a BlackJack PPO agent.")

    # Add arguments
    parser.add_argument('--epochs', type=int, default=32,
                        help='Number of training epochs.')
    parser.add_argument('--learning_rate', type=float, default=0.0003,
                        help='Learning rate for the optimizer.')
    parser.add_argument('--dataloader_batch_size', type=int, default=64,
                        help='Dataloader Batch sizes for train, test, validation data files.')
    parser.add_argument('--dataloader_pin_memory', action='store_false',
                        help='Use pinned memory for the dataloaders.')
    parser.add_argument('--dataloader_num_workers', type=int, default=0,
                        help='Number of workers for the dataloaders.')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Batch size for gathering experience.')
    parser.add_argument('--gpt2_size', type=str, default="gpt2-small (124M)",
                        help='GPT2 size for model construction.')
    parser.add_argument('--k_epochs', type=int, default=128,
                        help='Number of group relative policy updates per epochs.')
    parser.add_argument('--epsilon', type=float, default=0.2,
                        help='Clipping parameter for PPO.')
    parser.add_argument('--beta_kl', type=float, default=0.01,
                        help='KL divergence coefficient (for PPO-like algorithms).')
    parser.add_argument('--entropy_coeff', type=float, default=0.05,
                        help='Entropy regularization coefficient.')
    parser.add_argument('--log_iterations', type=int, default=8,
                        help='Log training progress every N iterations.')
    parser.add_argument('--eval_iterations', type=int, default=8,
                        help='Evaluate the training model every N iterations.')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='Discount factor for rewards.')
    parser.add_argument('--num_envs', type=int, default=16,
                        help='Number of parallel environments for training.')
    parser.add_argument('--use_cuda', action='store_true',
                        help='Use CUDA if available.')
    parser.add_argument('--device', type=str, default='cpu',
                        help='Explicitly set device (e.g., "cpu, cuda:0", "cpu"). Overrides --use_cuda if specified.')
    parser.add_argument('--save_model', action='store_true',
                        help='Save the trained model weights.')
    parser.add_argument('--model_output_path', type=str, default='models/Spam-Classifier-GPT2-Model.pt',
                        help='Path to save the trained model weights.')

    # Parse the arguments
    args = parser.parse_args()

    
    ret = main(args)

    if ret == 0:
        print("Terminating program")
    else: 
        print("Main Scipt Error")

## Testing