In [1]:
# Import necessary libraries
# Basic Python libraries for various operations
import random
import copy
import re
import os
import sys
import numpy as np
import wandb
from dotenv import load_dotenv
from DGXutils import GetLowestGPU
from tqdm.auto import tqdm

# PyTorch and related libraries for deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

# Hugging Face libraries for transformer models
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

sys.path.append('../')

# custom
from utils import preprocess as pp
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset
from utils.evaluate import eval_funcs
from utils.config import parse_args_llama
from utils.ckpt import _save_checkpoint, _reload_best_model
from utils.collate import collate_fn
from utils.seed import seed_everything
from utils.lr_schedule import adjust_learning_rate

def set_random_seed(seed: int = 42):
    """
    Seed every random number generator this notebook relies on.

    Args:
        seed (int): Value handed to each generator. Defaults to 42.

    Returns:
        None

    Details:
        - Python's ``random`` module, NumPy's global generator and PyTorch's
          default generator all receive the same seed.
        - When ``torch.cuda.is_available()`` is true, the generators of all
          CUDA devices are seeded as well.
        - cuDNN is put into deterministic mode and its benchmark mode is
          turned off.

    Note:
        The cuDNN settings can slow training down; they are chosen here to
        favour repeatable runs over speed.
    """
    # Same seed for the three CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators only exist when CUDA is usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Reproducibility over speed for cuDNN kernels.
    cudnn_flags = {"deterministic": True, "benchmark": False}
    for flag_name, flag_value in cudnn_flags.items():
        setattr(torch.backends.cudnn, flag_name, flag_value)

# Restrict this process to GPUs 1-7. This must happen before the first CUDA
# query (set_random_seed calls torch.cuda.is_available()), because CUDA reads
# CUDA_VISIBLE_DEVICES when it is first initialised and later changes to the
# variable are not guaranteed to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4,5,6,7"

# Call the function to set random seed for reproducibility
set_random_seed(42)

# Load variables from a local .env file (if any) into os.environ.
load_dotenv()

# The W&B settings are expected in the environment. Copying os.getenv(name)
# back into os.environ changes nothing when the variable exists and raises
# TypeError when it does not, so report missing variables explicitly instead.
_missing_wandb_vars = [
    name
    for name in ("WANDB_API_KEY", "WANDB_PROJECT", "WANDB_ENTITY")
    if not os.getenv(name)
]
if _missing_wandb_vars:
    print(
        "Warning: missing environment variable(s): " + ", ".join(_missing_wandb_vars),
        file=sys.stderr,
    )

# Setup 

In [22]:
# Text length limit; passed as max_txt_len when the GraphLLM model is built.
# NOTE(review): presumably counted in tokens - confirm in GraphLLM.
T = 512

# get dataset to see what we're working with
path = "../data/subgraphs/all/"
dataset = BiologicalDataset(path)
# One example per batch, drawn in shuffled order and assembled by the project's collate_fn.
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

In [3]:
# load model to see what we're working with
# The recorded output of this cell reports "Training with Lora".
# NOTE(review): max_new_tokens=200 presumably caps generation length inside
# GraphLLM.inference - confirm in utils/graph_llm.py.
model = GraphLLM(max_txt_len=T,
                max_new_tokens=200,
                llm_model_path='meta-llama/Meta-Llama-3.1-8B-Instruct',
                llm_frozen=False, # set frozen to false so we can train with RL
                fsdp=False, # NOTE(review): presumably toggles sharded data-parallel training - confirm
                )

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Training with Lora


# Extraction + Evaluation

In [5]:
def extract_answer(text):
    """
    Extract the final answer from the model output.

    Args:
        text (str): The model output text.

    Returns:
        str: The lower-cased content of the last ``<answer>...</answer>``
            pair in ``text`` with surrounding whitespace removed, or an
            empty string when ``text`` is empty or contains no such pair.
    """
    if not text:
        return ""

    # DOTALL lets an answer span several lines.
    answers = re.findall(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)

    # A response without answer tags yields "" instead of raising IndexError,
    # so callers can simply score it as a wrong answer.
    if not answers:
        return ""

    # Only the last tagged answer counts.
    return answers[-1].strip().lower()

In [6]:
def evaluate_model(model, batch):
    """
    Evaluate the model on one collated batch and print a per-example report.

    Args:
        model (GraphLLM): Object providing ``eval()``, ``train()`` and
            ``inference(batch)``. ``inference`` must return a mapping whose
            ``"pred"`` entry holds one generated string per example.
        batch (dict): A collated batch. The entries ``"label"``, ``"desc"``
            and ``"question"`` are read here and must be indexable and of
            equal length; the whole batch is forwarded to ``model.inference``.

    Returns:
        float: Accuracy in percent (0-100). ``0.0`` for an empty batch.

    Side effects:
        Prints the prompt, expected answer, extracted answer and full
        response for every example. The model is put into eval mode for the
        duration of the call and is always switched to training mode on
        exit, even when inference raises.
    """
    model.eval()
    try:
        correct = 0

        batch_size = len(batch["label"])
        print("\n" + "=" * 50)
        print(f"EVALUATION ON {batch_size} EXAMPLES")
        print("=" * 50)

        # Perform model inference on the whole batch with no gradient computation.
        with torch.no_grad():
            outputs = model.inference(batch)
        responses = outputs["pred"]

        for i in range(batch_size):
            response = responses[i]
            # A response without an <answer> tag can make extraction raise
            # IndexError; count it as a wrong (empty) answer instead of
            # aborting the whole evaluation.
            try:
                predicted = extract_answer(response)
            except IndexError:
                predicted = ""
            expected = batch["label"][i]
            is_correct = (predicted == expected)
            if is_correct:
                correct += 1

            # Print details for this example.
            print("\nPrompt:")
            print(batch["desc"][i] + ' ' + batch["question"][i])
            print("\nExpected Answer:")
            print(expected)
            print("\nExtracted Answer:")
            print(predicted)
            print("\nFull Generated Response:")
            print(response)
            print("\nCorrect:", "✓" if is_correct else "✗")
            print("-" * 50)

        # Avoid ZeroDivisionError on an empty batch.
        accuracy = (correct / batch_size) * 100 if batch_size else 0.0
        print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{batch_size})")
        print("=" * 50)
        return accuracy
    finally:
        # Switch model back to training mode after evaluation, also on failure.
        model.train()

# Reward Functions

In [7]:
# function to reward formatting
def reward_format(gt, pred):
    """
    Score whether a completion follows the required tag layout.

    Args:
        gt: Ground-truth answer. Not used by this reward.
        pred (str): The generated completion.

    Returns:
        1.25 when ``pred`` begins with a ``<think>...</think>`` section that
        is immediately followed (no characters in between) by an
        ``<answer>...</answer>`` section reaching the end of the text;
        otherwise -1.
    """
    # DOTALL lets the reasoning and the answer span multiple lines.
    layout = re.compile(
        r"^<think>.*?</think><answer>.*?</answer>$",
        re.DOTALL | re.VERBOSE,
    )

    if layout.match(pred) is None:
        return -1
    return 1.25

# define reward function for node connectivity
def reward_correct_yn(gt, pred) -> int:
    """
    Score a yes/no completion against the ground truth.

    Args:
        gt (str): Ground-truth answer; compared as given (not lower-cased).
        pred (str): The generated completion.

    Returns:
        int: 1 when the lower-cased text found inside the completion's
        ``<answer>...</answer>`` tags equals ``gt``, otherwise -1. When several
        tag pairs are present their contents are concatenated before the
        comparison; when none is present the extracted text is empty.
    """
    tagged_chunks = re.findall(r"<answer>(.*?)</answer>", pred)
    predicted = "".join(tagged_chunks).lower()
    return 1 if predicted == gt else -1
    
def combined_reward(gt, pred):
    """
    Total reward for a completion: answer correctness plus tag formatting.

    Args:
        gt: Ground-truth answer, forwarded to both reward functions.
        pred (str): The generated completion.

    Returns:
        The sum of ``reward_correct_yn(gt, pred)`` and
        ``reward_format(gt, pred)``.
    """
    correctness_score = reward_correct_yn(gt, pred)
    format_score = reward_format(gt, pred)
    return correctness_score + format_score

# GRPO Train Functions

In [26]:
def selective_log_softmax(logits, input_ids):
    """
    Look up the log probability the model assigns to each given token.

    Args:
        logits (torch.Tensor): Raw model scores; the last dimension is
            normalised with log-softmax.
        input_ids (torch.Tensor): Integer token ids used as indices into the
            last dimension of ``logits``. Expected to have the shape of
            ``logits`` without that last dimension.

    Returns:
        torch.Tensor: One log probability per entry of ``input_ids``.

    Steps:
        1. Normalise the scores with log-softmax over the last dimension.
        2. Give ``input_ids`` a trailing axis so it can index that dimension.
        3. Gather the selected entries and drop the trailing axis again.
    """
    vocab_log_probs = F.log_softmax(logits, dim=-1)
    token_index = input_ids.unsqueeze(-1)
    selected = torch.gather(vocab_log_probs, -1, token_index)
    return selected.squeeze(-1)

def compute_log_probs(model, batch, logits_to_keep, input_ids=None):
    """
    Computes the log probabilities of the last ``logits_to_keep`` tokens.

    Args:
        model: Callable for which ``model(batch)`` returns a pair whose
            second item exposes a ``.logits`` tensor.
            NOTE(review): assumes logits are (batch, sequence, vocab) - confirm
            against GraphLLM.forward.
        batch: The batch handed to ``model``. Also the fallback source of the
            token ids (``batch["input_ids"]``) when ``input_ids`` is omitted.
        logits_to_keep (int): Number of tokens to keep from the end of the
            sequence. Must be at least 1.
        input_ids (torch.Tensor, optional): Token ids whose trailing
            ``logits_to_keep`` columns are the tokens to score.
            NOTE(review): must be aligned with the end of the model's logit
            sequence - confirm how GraphLLM lays out its inputs.

    Returns:
        torch.Tensor: Log probabilities of the selected tokens.

    Raises:
        ValueError: If ``logits_to_keep`` is smaller than 1, or if no token
            ids are supplied and ``batch`` has no ``"input_ids"`` entry.

    Explanation:
        1. Gets logits from the model for the batch.
        2. Drops the logits of the last position (each position predicts the
           next token, so the final position has no target).
        3. Keeps only the last ``logits_to_keep`` positions of both the logits
           and the token ids.
        4. Computes log probabilities for these tokens using
           selective_log_softmax.
    """
    # A slice of "-0:" would silently keep the whole sequence.
    if logits_to_keep < 1:
        raise ValueError(f"logits_to_keep must be >= 1, got {logits_to_keep}")

    # Resolve the target token ids before paying for a forward pass.
    if input_ids is None:
        try:
            input_ids = batch["input_ids"]
        except (KeyError, TypeError) as err:
            raise ValueError(
                "compute_log_probs needs token ids: pass input_ids=... "
                "or provide batch['input_ids']"
            ) from err

    _, out = model(batch)

    # All positions except the last one (not just the last position).
    logits = out.logits[:, :-1, :]
    logits = logits[:, -logits_to_keep:, :]

    # gather() requires the index tensor on the same device as the logits.
    target_ids = input_ids[:, -logits_to_keep:].to(logits.device)
    return selective_log_softmax(logits, target_ids)

def create_completion_mask(completion_ids, eos_token_id):
    """
    Build a mask that keeps each completion up to and including its first EOS.

    Args:
        completion_ids (torch.Tensor): 2-D tensor of generated token ids, one
            row per completion.
        eos_token_id (int): Id of the end-of-sequence token.

    Returns:
        torch.Tensor: Integer tensor shaped like ``completion_ids`` holding 1
        for every position up to and including the first EOS of its row and 0
        for every later position. Rows without an EOS are all 1.

    Steps:
        1. Mark where the EOS id occurs.
        2. Take the column of the first occurrence per row; rows without any
           occurrence use the row length instead, which lies past every
           column.
        3. Compare every column index against that per-row cut-off.
    """
    n_rows, n_cols = completion_ids.size(0), completion_ids.size(1)
    device = completion_ids.device

    eos_hits = completion_ids == eos_token_id
    first_hit = eos_hits.int().argmax(dim=1)

    # argmax reports column 0 for rows with no EOS at all, so those rows get
    # the row length as their cut-off to stay fully unmasked.
    past_the_end = torch.full((n_rows,), n_cols, dtype=torch.long, device=device)
    cutoff = torch.where(eos_hits.any(dim=1), first_hit, past_the_end)

    columns = torch.arange(n_cols, device=device).expand(n_rows, -1)
    keep = columns <= cutoff.unsqueeze(1)
    return keep.int()

def generate_completions(model, batch, num_generations=4, max_completion_length=32):
    """
    Sample several completions for every prompt in a batch.

    Args:
        model: Object providing ``tokenizer`` and
            ``inference(batch, num_generations=...)``; the inference result
            must contain an ``"out_ids"`` tensor.
        batch (dict): Collated batch; ``batch["desc"]`` and
            ``batch["question"]`` are joined (without a separator) to form
            the prompt text, and the whole batch is passed to
            ``model.inference``.
        num_generations (int): Number of completions requested per prompt.
        max_completion_length (int): Accepted for interface compatibility.
            NOTE(review): not used anywhere in this function, so it does not
            limit the generated length.

    Returns:
        tuple: ``(prompt_ids, prompt_mask, completion_ids, completion_mask)``.
        The prompt tensors come from the tokenizer with every row repeated
        ``num_generations`` times; ``completion_ids`` is ``out_ids`` with the
        first ``prompt_length`` columns removed; ``completion_mask`` comes
        from ``create_completion_mask``.

    Notes:
        - Prints the input and output batch sizes.
        - NOTE(review): the prompt length is measured on the tokenization done
          here; this assumes ``out_ids`` starts with a prompt of exactly that
          length - confirm against ``GraphLLM.inference``.
        - The prompt tensors are returned as produced by the tokenizer; they
          are not moved to another device here.
    """
    # Build one prompt string per example: description directly followed by the question.
    prompt_texts = []
    for k in range(len(batch["desc"])):
        prompt_texts.append(batch["desc"][k] + batch["question"][k])

    # Left-padded tokenization of the prompts.
    encoded = model.tokenizer(prompt_texts, return_tensors="pt", padding=True, padding_side="left")
    base_ids = encoded["input_ids"]
    base_mask = encoded["attention_mask"]
    print(f"Input batch size: {base_ids.size(0)}")

    # One copy of each prompt row per requested completion.
    prompt_length = base_ids.size(1)
    prompt_ids = base_ids.repeat_interleave(num_generations, dim=0)
    prompt_mask = base_mask.repeat_interleave(num_generations, dim=0)

    generated = model.inference(batch, num_generations=num_generations)
    out_ids = generated["out_ids"]
    print(f"Output batch size: {out_ids.size(0)}")

    # Everything after the prompt columns is treated as completion.
    completion_ids = out_ids[:, prompt_length:]
    eos_id = model.tokenizer.eos_token_id
    completion_mask = create_completion_mask(completion_ids, eos_id)
    return prompt_ids, prompt_mask, completion_ids, completion_mask

In [24]:
# Pull a single batch from the loader to inspect its structure.
# The loader shuffles, so which example is drawn depends on the RNG state.
batch = next(iter(loader))
batch

{'id': [58449],
 'question': ['Is there an edge between nodes 239 and 120?'],
 'scope': ['all'],
 'label': ['no'],
 'desc': ['A question with a yes/no answer is provided along with a graph. Answer the question based on the graph. Provide reasoning inside of <think></think> tags and the answer inside of <answer></answer> tags.'],
 'graph': DataBatch(x=[250, 1024], edge_index=[2, 2234], num_nodes=250, batch=[250], ptr=[2])}

In [27]:
# Smoke-test generation on the batch drawn above: 4 completions for its single prompt.
# NOTE(review): generate_completions does not pass max_completion_length on to
# model.inference, so the value 32 is not enforced by this call.
prompt_ids, prompt_mask, completion_ids, completion_mask = generate_completions(model, batch, num_generations=4, max_completion_length=32)


Input batch size: 1
Output batch size: 4
