In [None]:
%%capture
# Colab setup: install Unsloth, vLLM and CairoSVG into the runtime.
# %%capture hides pip's (very long) output.
# NOTE(review): versions are unpinned, so a re-run may pull releases that are
# incompatible with each other or with this notebook -- consider pinning.

!pip install unsloth
!pip install vllm
!pip install cairosvg

In [None]:
import sys

from google.colab import drive

# Mount Google Drive and put the project folder on the import path. The helper
# modules imported in the next cell (symbolic_utilities, neural_utilities,
# evaluation) are presumably kept in this folder.
DEV_FILES_DIR = '/content/drive/MyDrive/dev-files'
drive.mount('/content/drive')
sys.path.append(DEV_FILES_DIR)

In [None]:

%%capture

import re
import itertools
import math
import time
import random
from collections import Counter
from pprint import pprint
import json

import numpy as np
import pandas as pd
from numpy.random import choice, randint
from IPython.display import HTML, display, clear_output
import matplotlib.pyplot as plt
import ipywidgets as widgets

# Utilities for plotting
from symbolic_utilities import progress, compute_global_limits_smc, plot_mh_trace_upto, plot_state_2d
# MHMC sampler
from symbolic_utilities import propose_tree, get_coordinates, \
    mh_sampler, smc_sampler, define_bs_DSL, define_lt_DSL, enumerate_full_sentences

from neural_utilities import extract_xml_answer, extract_xml_reasoning, produce_tasks, get_data

from neural_utilities import print_func, lt_correctness_reward_func, \
    xmlcount_reward_func, soft_format_reward_func, strict_format_reward_func, cfg_reward_func, lt_correctness_reward_func, \
    direct_cfg_reward_func, direct_lt_correctness_reward_func, direct_conciseness_reward_func

from symbolic_utilities import \
    ltgrammar, lt_nonterminals, lt_terminals, lt_eval_dict, \
    bsgrammar, bs_nonterminals, bs_terminals, bs_eval_dict

from evaluation import generate_answers, select_best, get_accuracy

# NOTE: PatchFastRL needs to run **before** the imports below
from unsloth import FastLanguageModel, is_bfloat16_supported, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

import torch, gc
from torch import tensor
from datasets import load_dataset, Dataset, DatasetDict
from transformers import EarlyStoppingCallback, TextStreamer, TrainingArguments
from peft import AutoPeftModelForCausalLM
from trl import SFTTrainer, GRPOConfig, GRPOTrainer, SFTConfig
from unsloth.chat_templates import get_chat_template
from vllm import SamplingParams

from dotenv import load_dotenv, find_dotenv
import os
from openai import OpenAI
from tqdm import tqdm

# NOTE: this cell runs under %%capture, so the printed device name is captured
# rather than shown in the cell output.
gpu_name = torch.cuda.get_device_name(0)
print(gpu_name)

In [None]:
max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

# Empty system prompt, passed through to get_data below.
lt_system_prompt = ""

SEED = 0
MAX_POOL_SIZE = 500_000  # cap on the number of enumerated candidate sentences
N_TASKS = 100_000        # tasks requested from get_data
N_TRAIN = 2**16          # 65,536 training examples
N_HELDOUT = 2 * 2**7     # 256 held-out examples -> 128 test + 128 validation

# Seed the global RNGs so the generated tasks and the splits below can be
# reproduced across sessions. This matters because later cells evaluate a
# model saved in an earlier session on data['test'].
# NOTE(review): assumes get_data / enumerate_full_sentences draw from the
# global `random` / numpy state -- confirm in the helper modules.
random.seed(SEED)
np.random.seed(SEED)

# Enumerate candidate target programs from start symbol 'T' (max_depth=6) and
# keep at most MAX_POOL_SIZE of them. islice stops pulling from the generator
# once the cap is reached.
sentences_pool = list(itertools.islice(
    enumerate_full_sentences('T', ltgrammar, max_depth=6),
    MAX_POOL_SIZE,
))

print(sentences_pool[:5])

data = get_data(
    ltgrammar,
    lt_system_prompt,
    eval_dict=lt_eval_dict,
    n_tasks=N_TASKS,
    sentences_pool=sentences_pool
)

# The 'sentence' is what we want the model to output for a given input; wrap it
# as a single assistant message so each row is a (prompt, completion) chat pair.
data = data.map(lambda x: {
    'completion': [{'content': x['sentence'], 'role': 'assistant'}],
})

# Take N_TRAIN rows for training and N_HELDOUT rows as a held-out pool ...
train_testvalid = data.train_test_split(
    train_size=N_TRAIN, test_size=N_HELDOUT, seed=SEED
)
# ... then split the held-out pool evenly into test and validation.
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=SEED)
data = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train'],
})

data

sentence: target program\
e.g. filter_(and_(or_(gt(2),gt(2)),or_(even,even)))

examples:
[[[5, 5, 8], [8]], [[3, 4], [4]], ...] - shows input and output pairs

task:
human-readable version of the examples which could be directly fed into the prompt

prompt:
the chat-formatted prompt that presents `task` in an LLM-friendly way

completion: a single assistant message whose content is `sentence`, i.e. the program the model should output for `prompt`

In [None]:
# Keep the first training example's prompt (chat message list) for the
# smoke-test generation further down.
first_train_example = data['train'][0]
d = first_train_example['prompt']

In [None]:
# Tally how often each terminal symbol occurs (as a substring) across all
# training target programs.
c = Counter()
for target_sentence in data['train']['sentence']:
    per_sentence = {terminal: target_sentence.count(terminal) for terminal in lt_terminals}
    c.update(per_sentence)

c

In [None]:
# Peek at the first rows of the test split.
test_df = data['test'].to_pandas()
test_df.head()

## Build the model

In [None]:
# Base model: Qwen2.5-0.5B-Instruct, loaded in 4-bit. vLLM fast inference is
# left disabled.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-0.5B-Instruct",
    load_in_4bit=True,
    max_seq_length=max_seq_length,
    # fast_inference=True,
    max_lora_rank=lora_rank,
)

# Attach LoRA adapters to the attention and MLP projections, with
# lora_alpha == r (LoRA scaling factor alpha / r = 1).
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    lora_alpha=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

In [None]:
# Tokenize the sample prompt with the chat template, appending the assistant
# generation header, and move the tensor to the GPU.
inputs = tokenizer.apply_chat_template(
    d,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

In [None]:
# Stream a completion from the not-yet-fine-tuned model (base weights plus
# freshly initialised LoRA adapters) as a baseline.
FastLanguageModel.for_inference(model)
text_streamer = TextStreamer(tokenizer)
generation_kwargs = dict(streamer=text_streamer, max_new_tokens=512)
_ = model.generate(inputs, **generation_kwargs)

In [None]:
%%capture
# Switch the model back to Unsloth's training mode (the smoke test above put it
# in inference mode); %%capture hides the cell output.
FastLanguageModel.for_training(model)

In [None]:
def format_dataset(examples, drop_empty_messages=False):
    """Render batched (prompt, completion) chat pairs into training text.

    Intended for ``Dataset.map(format_dataset, batched=True)``: ``examples``
    maps column names to lists. ``examples["prompt"][i]`` and
    ``examples["completion"][i]`` are lists of ``{"role", "content"}`` message
    dicts that are concatenated into one conversation.

    Args:
        examples: batch dict with "prompt" and "completion" columns.
        drop_empty_messages: if True, drop messages whose content is empty,
            whitespace-only or None before templating. Defaults to False so
            the training text is built from exactly the same messages as the
            inference prompts, which this notebook templates raw (``d``,
            ``data['test']['prompt']``). Removing an empty system turn is not
            neutral: some chat templates (Qwen2.5's included) substitute a
            default system prompt when none is supplied, so the model would
            be trained on a different prefix than it sees at inference.

    Returns:
        ``{"text": [...]}``: one fully rendered conversation string per
        example, produced with the global ``tokenizer``'s chat template
        (no generation prompt appended, since the completion is included).
    """
    texts = []
    for prompt_msgs, completion_msgs in zip(examples["prompt"], examples["completion"]):
        messages = list(prompt_msgs) + list(completion_msgs)
        if drop_empty_messages:
            # `or ""` guards against content=None as well as empty strings.
            messages = [msg for msg in messages if (msg.get("content") or "").strip()]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        texts.append(text)
    return {"text": texts}

# Render the train and test splits into a single "text" column, dropping the
# original columns so SFTTrainer only sees the rendered conversations.
formatted_train, formatted_test = (
    data[split_name].map(
        format_dataset,
        batched=True,
        remove_columns=data[split_name].column_names,
    )
    for split_name in ('train', 'test')
)

In [None]:
# Supervised fine-tuning hyperparameters. Mixed precision follows hardware
# support: bf16 where available, fp16 otherwise.
sft_config = SFTConfig(
    output_dir="lt_SFT_noreasoning",
    seed=0,
    # optimisation
    learning_rate=3e-5,
    lr_scheduler_type="linear",
    warmup_steps=10,
    weight_decay=0.05,
    optim="adamw_8bit",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=0.4,  # fraction of one pass over formatted_train
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    # logging / checkpointing
    logging_steps=1,
    save_steps=100,
    # periodic evaluation on formatted_test
    eval_strategy="steps",
    eval_steps=10,
    per_device_eval_batch_size=2,
    eval_accumulation_steps=4,
    fp16_full_eval=True,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=sft_config,
    train_dataset=formatted_train,
    eval_dataset=formatted_test,
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=True,  # pack several short examples into each max_seq_length sequence
)

In [None]:
# Run supervised fine-tuning; the returned TrainOutput is the cell's output.
sft_train_output = trainer.train()
sft_train_output

In [None]:
# Save the SFT model locally AND to Drive. The reload cell that follows reads
# from the Drive path, so saving only locally would make it silently load a
# stale model from an earlier session. Drive is already mounted at the top of
# the notebook.
SFT_LOCAL_DIR = 'finetuned_lt'
SFT_DRIVE_DIR = '/content/drive/MyDrive/dev-files/finetuned_lt'

trainer.save_model(SFT_LOCAL_DIR)
trainer.save_model(SFT_DRIVE_DIR)

In [None]:
%%capture
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/content/drive/MyDrive/dev-files/finetuned_lt", # "finetuned_lt",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    # Enable vLLM fast inference
    # fast_inference = True,
)

In [None]:
# Batched generation with a decoder-only model needs left padding.
tokenizer.padding_side = 'left'

# Render each test prompt as chat-formatted text ending with the assistant
# generation header.
# NOTE(review): `prompts` is not consumed by the cells that follow here;
# generate_answers receives the raw message lists instead.
prompts = []
for prompt_messages in data['test']['prompt']:
    rendered = tokenizer.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )
    prompts.append(rendered)

In [None]:
%%capture
# Put the reloaded SFT model into Unsloth's inference mode before generating;
# %%capture hides the cell output.
FastLanguageModel.for_inference(model)

In [None]:
# Generate the SFT model's answers for the held-out test prompts.
test_prompts = data['test']['prompt']
answers = generate_answers(model, tokenizer, test_prompts)

In [None]:
# Score the SFT answers against the test split (shown as the cell output).
test_split = data['test']
get_accuracy(answers, test_split)

## Fine-tune with reinforcement learning (GRPO)

In [None]:
# GRPO continues from the SFT model reloaded above. That model was last put in
# inference mode for evaluation, so switch it back to training mode first
# (mirroring the for_training call before SFT).
FastLanguageModel.for_training(model)

grpo_config = GRPOConfig(
    # use_vllm stays off: enabling it raised an error in this setup (the model
    # was also loaded without fast_inference=True).
    # use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    # NOTE(review): recent TRL versions require the global batch size to be a
    # multiple of num_generations; Unsloth may adjust the batch size
    # automatically -- confirm the effective value in the training logs.
    per_device_train_batch_size = 1,
    # Increase to 4 for smoother training
    gradient_accumulation_steps = 1,
    # Completions sampled per prompt; decrease if out of memory.
    num_generations = 8,
    max_prompt_length = 256,
    max_completion_length = 64,
    # max_steps > 0 overrides num_train_epochs, so the run length is set by
    # max_steps (10000 optimizer steps), not by num_train_epochs.
    num_train_epochs = 1,
    max_steps = 10000,
    save_steps = 500,
    # Checkpoints go straight to Drive so they survive a Colab disconnect.
    output_dir="/content/drive/MyDrive/dev-files/grpo_checkpoints",
    max_grad_norm = 0.1,
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        # print_func,
        # direct_cfg_reward_func,
        direct_lt_correctness_reward_func,
        # direct_conciseness_reward_func
    ],
    args=grpo_config,
    train_dataset=data['train'],
)

In [None]:
# Resume from the newest checkpoint in output_dir when one exists. On a fresh
# run, `resume_from_checkpoint=True` raises ("No valid checkpoint found"), so
# start from scratch instead. The name pattern matches the Trainer's
# "checkpoint-<step>" directories.
grpo_output_dir = trainer.args.output_dir
has_grpo_checkpoint = os.path.isdir(grpo_output_dir) and any(
    re.fullmatch(r"checkpoint-\d+", entry)
    and os.path.isdir(os.path.join(grpo_output_dir, entry))
    for entry in os.listdir(grpo_output_dir)
)

trainer.train(
    resume_from_checkpoint=True if has_grpo_checkpoint else None
)

In [None]:
trainer.save_model('rl_finetuned')

!cp -r /content/rl_finetuned/ /content/drive/MyDrive/dev-files

In [None]:
# One row per entry in the trainer's log history.
log_history = trainer.state.log_history
df_history = pd.DataFrame(log_history)

In [None]:
REWARD_COL = 'rewards/direct_lt_correctness_reward_func'
SMOOTHING_WINDOW = 100

# Keep only the log rows that carry the reward (e.g. the end-of-training
# summary row does not) and index them by optimizer step. The previous x-axis
# was the log-row position, mislabelled as "Episode". Dropping NaN rows first
# also keeps them from leaking NaNs into the rolling window.
reward_log = (
    df_history
    .dropna(subset=[REWARD_COL])
    .set_index('step')[REWARD_COL]
)
smoothed_rewards = reward_log.rolling(window=SMOOTHING_WINDOW).mean()

# Raw per-step reward plus its moving-average trend.
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(reward_log.index, reward_log, s=1, label="Reward per step")
ax.plot(
    smoothed_rewards.index, smoothed_rewards,
    label=f"Trend ({SMOOTHING_WINDOW}-step moving average)",
    color="red", linewidth=2,
)
ax.set(
    xlabel="Training step",
    ylabel="Reward",
    title="GRPO correctness reward",
    ylim=(0, 3),
)
ax.legend()
plt.show()

In [None]:
%%capture
model.for_inference()

In [None]:
# Re-generate answers for the same test prompts after GRPO training.
grpo_test_prompts = data['test']['prompt']
answers = generate_answers(model, tokenizer, grpo_test_prompts)

In [None]:
# Accuracy of the GRPO-trained model on the test split (cell output).
post_rl_test = data['test']
get_accuracy(answers, post_rl_test)

In [None]:
# Sample 32 candidates per test prompt at temperature 0.7, in batches of 4.
# The I/O examples and eval_dict are passed along to generate_answers,
# presumably so it can check candidates against the examples -- confirm in
# evaluation.py.
sampling_kwargs = dict(
    examples=data['test']['examples'],  # pass the I/O examples
    eval_dict=lt_eval_dict,
    batch_size=4,
    num_return_sequences=32,
    do_sample=True,
    temperature=0.7,
)
answers = generate_answers(model, tokenizer, data['test']['prompt'], **sampling_kwargs)

## Evaluating test accuracy across GRPO checkpoints

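As a last diagnostic of what is happening with the reward signal, evaluate every saved GRPO checkpoint on the same test subset and plot accuracy against training step 😧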

In [None]:
N_EVAL_EXAMPLES = 64  # TODO: increase for a less noisy accuracy estimate
EVAL_CURVE_PATH = "/content/drive/MyDrive/dev-files/checkpoint_eval_curve.png"
EVAL_RESULTS_PATH = "/content/drive/MyDrive/dev-files/checkpoint_results.json"


def _checkpoint_step(name):
    """Return the step of a Trainer 'checkpoint-<step>' directory name, else None."""
    match = re.fullmatch(r"checkpoint-(\d+)", name)
    return int(match.group(1)) if match else None


def _evaluate_checkpoint(ckpt_path, eval_subset):
    """Load one checkpoint, return its accuracy on eval_subset, then free it.

    Uses local names so the notebook's global `model` / `tokenizer` (the
    GRPO-trained model) are left untouched by the sweep.
    """
    ckpt_model, ckpt_tokenizer = FastLanguageModel.from_pretrained(
        model_name=ckpt_path,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
    )
    ckpt_tokenizer.padding_side = 'left'
    FastLanguageModel.for_inference(ckpt_model)
    try:
        ckpt_answers = generate_answers(ckpt_model, ckpt_tokenizer, eval_subset['prompt'])
        return float(get_accuracy(ckpt_answers, eval_subset))
    finally:
        # Drop the last reference, collect it, THEN release cached CUDA blocks;
        # emptying the cache before collection would not free this model.
        del ckpt_model
        gc.collect()
        torch.cuda.empty_cache()


# Discover checkpoints, ordered by training step.
checkpoint_dir = "/content/drive/MyDrive/dev-files/grpo_checkpoints"
checkpoints = sorted(
    (name for name in os.listdir(checkpoint_dir) if _checkpoint_step(name) is not None),
    key=_checkpoint_step,
)
print(f"Found {len(checkpoints)} checkpoints: {checkpoints}")

# Use the same fixed subset for every checkpoint so accuracies are comparable.
test_subset = data['test'].select(range(min(N_EVAL_EXAMPLES, len(data['test']))))

results = {}
for ckpt in tqdm(checkpoints):
    step = _checkpoint_step(ckpt)
    acc = _evaluate_checkpoint(os.path.join(checkpoint_dir, ckpt), test_subset)
    results[step] = acc
    print(f"Checkpoint {step}: {acc:.4f}")

# Plot the learning curve (zip(*...) would fail on an empty results dict).
if results:
    steps, accs = zip(*sorted(results.items()))
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, accs, 'o-', markersize=4)
    ax.set(xlabel="Training Steps", ylabel="Test Accuracy", title="GRPO Checkpoint Evaluation")
    ax.grid(True, alpha=0.3)
    fig.savefig(EVAL_CURVE_PATH, dpi=150)
    plt.show()
else:
    print(f"No checkpoints found in {checkpoint_dir}; nothing to plot.")

# Save results (JSON turns the integer step keys into strings).
with open(EVAL_RESULTS_PATH, "w") as f:
    json.dump(results, f, indent=2)