In [None]:
from tqdm import tqdm
import torch
from mathbert_encoder import MathBERTEncoder
import retriever_cosine as rc
# from import retrieve_top_k_cosine, retrieve_sample_k_cosine
from response_sampler import sample_responses_per_demo
from reward_aggregator import compute_demo_accuracy
from icl_model_wrapper import OpenAIICLModel
from grpo_optimizer import grpo_step
from datasets import load_dataset
from dotenv import load_dotenv
import os
from transformers import get_linear_schedule_with_warmup
from importlib import reload

# Re-import the retriever module so that edits made to it on disk are picked up
# without restarting the kernel.
reload(rc)

# Read key=value pairs from a local .env file (if any) into the environment.
load_dotenv()

# === Settings ===
API_KEY = os.getenv("OPENAI_API_KEY")
if not API_KEY:
    # Fail fast with a clear message, before the encoder and dataset are loaded,
    # instead of handing a missing key to the API wrapper.
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file."
    )
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
K = 16                     # demos requested from the retriever per question (capped at pool size)
NUM_SAMPLES_PER_DEMO = 5   # passed as num_samples to sample_responses_per_demo
NUM_EXAMPLES = 20          # size of the GSM8K training subset used by this notebook
LEARNING_RATE = 1e-5
MAX_STEPS = 1              # number of passes of the outer training loop
TEMPERATURE = 0.7          # sampling temperature handed to the ICL model wrapper

# === Init ===
encoder = MathBERTEncoder(device=DEVICE, trainable=True)
encoder.train()

icl_model = OpenAIICLModel(api_key=API_KEY, model_name="gpt-4.1-nano", temperature=TEMPERATURE)
optimizer = torch.optim.Adam(encoder.parameters(), lr=LEARNING_RATE)

gsm8k_data = load_dataset('gsm8k', 'main')['train']
# Keep only the first NUM_EXAMPLES rows (or fewer if the split is smaller).
gsm8k_data = gsm8k_data.select(range(min(NUM_EXAMPLES, len(gsm8k_data))))

# === Training Loop ===
# For every example in gsm8k_data: treat it as the inference question, use all
# other examples as the candidate demo pool, pick k demos with the retriever,
# score each picked demo by the accuracy of the responses sampled with it, and
# hand rewards + similarities + embeddings to grpo_step.

# Materialise the (question, answer) pairs once. Rebuilding the pool by
# iterating the dataset inside the inner loop re-read every row per example.
all_pairs = [(ex["question"], ex["answer"]) for ex in gsm8k_data]

for step in tqdm(range(MAX_STEPS), desc="Training Steps"):
    print(f"\n=== Training Step {step+1} ===")

    for inference_index in tqdm(range(len(all_pairs)), desc="Examples"):
        Q_inf, A_gt = all_pairs[inference_index]
        # Leave-one-out pool: every pair except the current inference example.
        demos = all_pairs[:inference_index] + all_pairs[inference_index + 1:]

        # Embeddings are recomputed on every iteration with detach=False;
        # presumably so gradients can reach the encoder — confirm in grpo_step.
        q_emb = encoder.encode([Q_inf], detach=False).squeeze(0)
        demo_embs = encoder.encode([q for (q, a) in demos], detach=False)

        top_k_indices, similarities = rc.retrieve_sample_k_cosine(q_emb, demo_embs, k=min(K, len(demos)))
        # Convert to plain ints before indexing the Python list of demos.
        selected_demos = [demos[i] for i in top_k_indices.tolist()]

        print(f"\n🧠 Inference Index {inference_index}")
        print(f"🔍 Top-K Indices: {top_k_indices}")

        all_responses = sample_responses_per_demo(
            demo_tuples=selected_demos,
            Q_inf=Q_inf,
            icl_model=icl_model,
            num_samples=NUM_SAMPLES_PER_DEMO
        )

        # One reward per selected demo, in the same order as selected_demos.
        demo_rewards = []
        for i, responses in enumerate(all_responses):
            reward = compute_demo_accuracy(responses, A_gt)
            demo_rewards.append(reward)
            print(f"    Demo {i} | Reward: {reward:.2f}")

        rewards = torch.tensor(demo_rewards, dtype=torch.float32, device=DEVICE)
        print('similarities:', similarities.shape)      # expected (N,) or (batch, N)
        print('idx_set:',     top_k_indices.shape)      # expected (k,) or (batch, k)
        print('idx_set max:', top_k_indices.max())

        # NOTE(review): the last recorded run of this cell ended with
        # "IndexError: index 16 is out of bounds for dimension 0 with size 16"
        # right after the three prints above (traceback location not recorded).
        # In that run there were 16 rewards and 16 similarities, while
        # top_k_indices held pool positions up to 18 and demo_embs is built from
        # every candidate question. grpo_step may expect per-selected-demo
        # inputs (e.g. demo_embs[top_k_indices]) — confirm against
        # grpo_optimizer before changing the call.
        loss = grpo_step(
            rewards,
            similarities,
            q_emb,
            demo_embs,
            optimizer
        )

        print(f"✅ Loss: {loss:.4f}")



Training Steps:   0%|          | 0/1 [00:00<?, ?it/s]


=== Training Step 1 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([ 2, 14,  3,  5,  1,  4, 11, 15,  8, 16, 10,  6, 13, 18, 17,  9])
---------------0-----------
Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?
A: Maila read 12 x 2 = <<12*2=24>>24 pages today.
So she was able to read a total of 12 + 24 = <<12+24=36>>36 pages since yesterday.
There are 120 - 36 = <<120-36=84>>84 pages left to be read.
Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <<84/2=42>>42 pages.
#### 42

Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
A:
48 + 48/2 = 48 + 24 = 72
---------------0-----------
Q: James creates a media empire.  He creates a movie for $2000.  Each DVD cost $6 to make.  He sells it for 2.

Examples:   0%|          | 0/20 [00:36<?, ?it/s]
Training Steps:   0%|          | 0/1 [00:36<?, ?it/s]

    Demo 0 | Reward: 0.60
    Demo 1 | Reward: 0.20
    Demo 2 | Reward: 0.00
    Demo 3 | Reward: 1.00
    Demo 4 | Reward: 0.20
    Demo 5 | Reward: 1.00
    Demo 6 | Reward: 1.00
    Demo 7 | Reward: 1.00
    Demo 8 | Reward: 0.80
    Demo 9 | Reward: 1.00
    Demo 10 | Reward: 0.60
    Demo 11 | Reward: 0.00
    Demo 12 | Reward: 0.80
    Demo 13 | Reward: 0.60
    Demo 14 | Reward: 0.80
    Demo 15 | Reward: 0.60
similarities: torch.Size([16])
idx_set: torch.Size([16])
idx_set max: tensor(18)





IndexError: index 16 is out of bounds for dimension 0 with size 16

In [None]:
# Persist the updated MathBERT encoder: the model first, then its tokenizer,
# both written into the same directory.
save_path = "./updated_mathbert"  # output directory shared by both artifacts
for component in (encoder.model, encoder.tokenizer):
    component.save_pretrained(save_path)

# Loading recipe (kept commented out; run it in a later session to restore
# the saved checkpoint):

# from transformers import BertTokenizer, BertModel

# model = BertModel.from_pretrained("./updated_mathbert")
# tokenizer = BertTokenizer.from_pretrained("./updated_mathbert")
