In [1]:
from tqdm import tqdm
import torch
from mathbert_encoder import MathBERTEncoder
import retriever_cosine as rc
# from import retrieve_top_k_cosine, retrieve_sample_k_cosine
from response_sampler import sample_responses_per_demo
from reward_aggregator import compute_demo_accuracy
from icl_model_wrapper import OpenAIICLModel
from grpo_optimizer import grpo_step
from datasets import load_dataset
from dotenv import load_dotenv
import os
from transformers import get_linear_schedule_with_warmup
from importlib import reload

reload(rc)

load_dotenv()

# === Settings ===
API_KEY = os.getenv("OPENAI_API_KEY")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
K = 16
NUM_SAMPLES_PER_DEMO = 5
LEARNING_RATE =  1e-5
MAX_STEPS = 1
TEMPERATURE = 0.7

# === Init ===
encoder = MathBERTEncoder(device=DEVICE, trainable=True)
encoder.train()

icl_model = OpenAIICLModel(api_key=API_KEY, model_name="gpt-4.1-nano", temperature=TEMPERATURE)
optimizer = torch.optim.Adam(encoder.parameters(), lr=LEARNING_RATE)

gsm8k_data = load_dataset('gsm8k', 'main')['train']
gsm8k_data = gsm8k_data.select(range(20))  # slice first 200 examples

# === Training Loop ===
for step in tqdm(range(MAX_STEPS), desc="Training Steps"):
    print(f"\n=== Training Step {step+1} ===")

    for inference_index in tqdm(range(len(gsm8k_data)), desc="Examples"):
        inference_item = gsm8k_data[inference_index]
        demo_pool = [d for idx, d in enumerate(gsm8k_data) if idx != inference_index]

        Q_inf = inference_item["question"]
        A_gt = inference_item["answer"]
        demos = [(d["question"], d["answer"]) for d in demo_pool]

        q_emb = encoder.encode([Q_inf], detach=False).squeeze(0)
        demo_embs = encoder.encode([q for (q, a) in demos], detach=False)

        top_k_indices, similarities = rc.retrieve_sample_k_cosine(q_emb, demo_embs, k=min(K, len(demos)))
        selected_demos = [demos[i] for i in top_k_indices]

        print(f"\n🧠 Inference Index {inference_index}")
        print(f"🔍 Top-K Indices: {top_k_indices}")

        all_responses = sample_responses_per_demo(
            demo_tuples=selected_demos,
            Q_inf=Q_inf,
            icl_model=icl_model,
            num_samples=NUM_SAMPLES_PER_DEMO
        )

        rewards = []
        for i, responses in enumerate(all_responses):
            reward = compute_demo_accuracy(responses, A_gt)
            rewards.append(reward)
            print(f"    Demo {i} | Reward: {reward:.2f}")

        rewards = torch.tensor(rewards, dtype=torch.float32).to(DEVICE)
        selected_similarities = similarities[top_k_indices]

        loss = grpo_step(
            rewards,
            selected_similarities,
            q_emb,
            demo_embs,
            optimizer
        )

        print(f"✅ Loss: {loss:.4f}")



  from .autonotebook import tqdm as notebook_tqdm
Training Steps:   0%|          | 0/1 [00:00<?, ?it/s]


=== Training Step 1 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([ 2, 11,  4,  7,  1,  0,  3,  6, 13, 10,  8,  9, 14, 12, 15,  5],
       device='cuda:0')
    Demo 0 | Reward: 1.00
    Demo 1 | Reward: 1.00
    Demo 2 | Reward: 1.00
    Demo 3 | Reward: 1.00
    Demo 4 | Reward: 0.20
    Demo 5 | Reward: 1.00
    Demo 6 | Reward: 0.00
    Demo 7 | Reward: 0.00
    Demo 8 | Reward: 1.00
    Demo 9 | Reward: 1.00
    Demo 10 | Reward: 0.80
    Demo 11 | Reward: 0.40
    Demo 12 | Reward: 0.40
    Demo 13 | Reward: 0.00
    Demo 14 | Reward: 1.00
    Demo 15 | Reward: 1.00




✅ Loss: 1811.1215

🧠 Inference Index 1
🔍 Top-K Indices: tensor([ 5, 11,  3,  7,  1,  0,  2,  6, 13, 10,  8,  9, 14, 12, 15,  4],
       device='cuda:0')
    Demo 0 | Reward: 1.00
    Demo 1 | Reward: 1.00
    Demo 2 | Reward: 1.00
    Demo 3 | Reward: 1.00
    Demo 4 | Reward: 1.00
    Demo 5 | Reward: 0.40
    Demo 6 | Reward: 0.80
    Demo 7 | Reward: 1.00
    Demo 8 | Reward: 0.60
    Demo 9 | Reward: 0.80
    Demo 10 | Reward: 1.00
    Demo 11 | Reward: 1.00
    Demo 12 | Reward: 1.00
    Demo 13 | Reward: 1.00
    Demo 14 | Reward: 1.00
    Demo 15 | Reward: 0.20




✅ Loss: 5175.2275

🧠 Inference Index 2
🔍 Top-K Indices: tensor([ 2,  0, 11, 10,  1,  3,  7,  6, 13, 12,  8,  9, 14, 15,  4,  5],
       device='cuda:0')
    Demo 0 | Reward: 0.00
    Demo 1 | Reward: 0.00
    Demo 2 | Reward: 0.00
    Demo 3 | Reward: 0.00
    Demo 4 | Reward: 0.00
    Demo 5 | Reward: 0.00
    Demo 6 | Reward: 0.00
    Demo 7 | Reward: 0.00
    Demo 8 | Reward: 0.00
    Demo 9 | Reward: 0.00
    Demo 10 | Reward: 0.00
    Demo 11 | Reward: 0.00
    Demo 12 | Reward: 0.00
    Demo 13 | Reward: 0.00
    Demo 14 | Reward: 0.00
    Demo 15 | Reward: 0.00




✅ Loss: 642.3002

🧠 Inference Index 3
🔍 Top-K Indices: tensor([13, 10,  3,  6,  1,  0,  2,  5, 12,  9,  7,  8, 14, 11, 15,  4],
       device='cuda:0')
    Demo 0 | Reward: 0.00
    Demo 1 | Reward: 0.00
    Demo 2 | Reward: 0.00
    Demo 3 | Reward: 0.00
    Demo 4 | Reward: 0.00
    Demo 5 | Reward: 0.00
    Demo 6 | Reward: 0.00
    Demo 7 | Reward: 0.00
    Demo 8 | Reward: 0.00
    Demo 9 | Reward: 0.00
    Demo 10 | Reward: 0.20
    Demo 11 | Reward: 0.00
    Demo 12 | Reward: 0.00
    Demo 13 | Reward: 0.00
    Demo 14 | Reward: 0.00
    Demo 15 | Reward: 0.00




✅ Loss: -3606.9187

🧠 Inference Index 4
🔍 Top-K Indices: tensor([ 8,  0, 16,  2, 10,  9,  6,  7, 12, 11, 13, 14,  5,  4,  1,  3],
       device='cuda:0')


Examples:  20%|██        | 4/20 [03:43<14:54, 55.92s/it]
Training Steps:   0%|          | 0/1 [03:43<?, ?it/s]

    Demo 0 | Reward: 0.00
    Demo 1 | Reward: 0.00
    Demo 2 | Reward: 0.00
    Demo 3 | Reward: 0.00
    Demo 4 | Reward: 0.00
    Demo 5 | Reward: 0.00
    Demo 6 | Reward: 0.00
    Demo 7 | Reward: 0.00
    Demo 8 | Reward: 0.00
    Demo 9 | Reward: 0.00
    Demo 10 | Reward: 0.40
    Demo 11 | Reward: 0.00
    Demo 12 | Reward: 0.00
    Demo 13 | Reward: 0.00
    Demo 14 | Reward: 0.20
    Demo 15 | Reward: 0.40





RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Save the updated MathBERT model
save_path = "./updated_mathbert"  # your save directory
encoder.model.save_pretrained(save_path)
encoder.tokenizer.save_pretrained(save_path)

# LOADING

# from transformers import BertTokenizer, BertModel

# model = BertModel.from_pretrained("./updated_mathbert")
# tokenizer = BertTokenizer.from_pretrained("./updated_mathbert")
