In [None]:
from tqdm import tqdm
import torch
from mathbert_encoder import MathBERTEncoder
import retriever_cosine as rc
# from import retrieve_top_k_cosine, retrieve_sample_k_cosine
from response_sampler import sample_responses_per_demo
from reward_aggregator import compute_demo_accuracy
from icl_model_wrapper import OpenAIICLModel
from grpo_optimizer import grpo_step
from datasets import load_dataset
from dotenv import load_dotenv
import os
from transformers import get_linear_schedule_with_warmup
from importlib import reload

# Re-import the retriever module so edits made to retriever_cosine.py during
# this kernel session take effect without a restart.
reload(rc)

# Load variables from a local .env file (if present) into the environment.
load_dotenv()

# === Settings ===
# Seed PyTorch's CPU and CUDA generators so that torch-driven randomness is
# repeatable across "Restart & Run All".
# NOTE(review): presumably this covers the retriever's sampling and encoder
# dropout -- confirm. Responses from the remote ICL model stay nondeterministic.
SEED = 42
torch.manual_seed(SEED)

API_KEY = os.getenv("OPENAI_API_KEY")
if not API_KEY:
    # Fail here with an actionable message instead of part-way into a long
    # training run when the first remote call is attempted.
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file "
        "before running this notebook."
    )

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
K = 16                     # demonstrations retrieved per query (capped by pool size)
NUM_SAMPLES_PER_DEMO = 5   # responses sampled from the ICL model per demonstration
LEARNING_RATE = 1e-5       # Adam learning rate for the encoder parameters
MAX_STEPS = 5              # number of passes over the inference examples
TEMPERATURE = 0.7          # sampling temperature passed to the ICL model

# === Init ===
# Encoder whose parameters are optimised; put it in training mode.
encoder = MathBERTEncoder(device=DEVICE, trainable=True)
encoder.train()

# Remote model that answers the query given a demonstration, and the
# optimizer that updates the encoder parameters.
icl_model = OpenAIICLModel(api_key=API_KEY, model_name="gpt-4.1-nano", temperature=TEMPERATURE)
optimizer = torch.optim.Adam(encoder.parameters(), lr=LEARNING_RATE)

# Dataset sizes are named so they can be tuned in one place (the previous
# inline comments claimed 200 examples while the code selected 256 and 5).
DEMO_POOL_SIZE = 256          # leading training rows kept as the working set
NUM_INFERENCE_EXAMPLES = 5    # leading rows of the working set used as queries

gsm8k_train = load_dataset('gsm8k', 'main')['train']
# Working set: the first DEMO_POOL_SIZE training examples.
gsm8k_data = gsm8k_train.select(range(DEMO_POOL_SIZE))
# Queries: the first NUM_INFERENCE_EXAMPLES rows of the working set, so row i
# here is the same example as row i of gsm8k_data.
gsm8k_data_to_infer = gsm8k_data.select(range(NUM_INFERENCE_EXAMPLES))

# === Training Loop ===
# Every outer step visits each inference example once. For one example:
#   1. embed the query and every other question in the pool,
#   2. retrieve up to K demonstrations,
#   3. sample responses from the ICL model for each demonstration,
#   4. score each demonstration against the reference answer,
#   5. pass rewards, similarities and embeddings to grpo_step.

# Materialise the (question, answer) pairs once. The dataset does not change
# during training, so re-reading every row for every example was
# loop-invariant work.
all_qa_pairs = [(row["question"], row["answer"]) for row in gsm8k_data]

for step in tqdm(range(MAX_STEPS), desc="Training Steps"):
    print(f"\n=== Training Step {step+1} ===")

    for inference_index in tqdm(range(len(gsm8k_data_to_infer)), desc="Examples"):
        # gsm8k_data_to_infer is the leading slice of gsm8k_data, so the same
        # index addresses the same example in both.
        Q_inf, A_gt = all_qa_pairs[inference_index]

        # Leave-one-out pool: every pair except the query itself, in the
        # original dataset order.
        demos = all_qa_pairs[:inference_index] + all_qa_pairs[inference_index + 1:]

        # Embeddings are recomputed on every iteration because the encoder is
        # being updated. NOTE(review): detach=False presumably keeps the
        # autograd graph so grpo_step can backpropagate -- confirm.
        q_emb = encoder.encode([Q_inf], detach=False).squeeze(0)
        demo_embs = encoder.encode([question for question, _ in demos], detach=False)

        # The returned indices are positions within `demos` (the pool with the
        # query removed), not row numbers of gsm8k_data.
        top_k_indices, similarities = rc.retrieve_sample_k_cosine(q_emb, demo_embs, k=min(K, len(demos)))
        selected_demos = [demos[i] for i in top_k_indices]

        print(f"\n🧠 Inference Index {inference_index}")
        print(f"🔍 Top-K Indices: {top_k_indices}")

        all_responses = sample_responses_per_demo(
            demo_tuples=selected_demos,
            Q_inf=Q_inf,
            icl_model=icl_model,
            num_samples=NUM_SAMPLES_PER_DEMO
        )

        # One reward per selected demonstration, in retrieval order, created
        # directly on the training device.
        rewards = torch.tensor(
            [compute_demo_accuracy(responses, A_gt) for responses in all_responses],
            dtype=torch.float32,
            device=DEVICE,
        )

        # NOTE(review): grpo_step receives the optimizer, so it presumably
        # runs the backward pass and parameter update itself -- confirm.
        loss = grpo_step(
            rewards,
            similarities,
            q_emb,
            demo_embs,
            optimizer
        )

        print(f"✅ Loss: {loss:.4f}")



Training Steps:   0%|          | 0/5 [00:00<?, ?it/s]


=== Training Step 1 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([194, 131, 242, 157,  77, 146, 175,   4, 160, 238, 197, 222, 249, 199,
        220, 168], device='cuda:0')
tensor([ 0.8456,  0.8456, -0.5074,  0.8456, -0.9583,  0.8456,  0.8456,  0.8456,
         0.8456,  0.8456, -1.4093, -0.0564, -1.4093, -1.4093,  0.3946, -1.4093],
       device='cuda:0')
tensor(0.4435, device='cuda:0')
tensor([-5.1454, -4.3108, -8.3210, -5.2740, -4.4696, -5.1823, -5.6701, -4.7556,
        -3.6997, -5.5623, -5.4184, -4.9934, -5.0418, -6.1091, -4.2354, -5.9010],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.3311

🧠 Inference Index 1
🔍 Top-K Indices: tensor([218,  87, 154, 163,  17, 226,  82,  67, 143, 197,  76,  33, 104, 220,
         84, 196], device='cuda:0')
tensor([-2.1137, -0.6433,  0.8271,  0.8271,  0.8271,  0.8271,  0.8271,  0.8271,
        -1.3785,  0.0919,  0.0919, -0.6433,  0.8271, -1.3785, -0.6433,  0.8271],
       device='cuda:0')
tensor(0.2720, device='cuda:0')
tensor([-5.2589, -5.8771, -5.4172, -4.5478, -6.0249, -5.9719, -6.0917, -4.8652,
        -6.2500, -4.3484, -4.1903, -5.8066, -5.7541, -5.0994, -4.8990, -5.8860],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.0132

🧠 Inference Index 2
🔍 Top-K Indices: tensor([ 23,  16, 159,  64, 125, 185, 104,  22,  86, 138,  59, 236, 231,  11,
        132,   8], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-5.1143, -4.9787, -4.1707, -3.9710, -4.7787, -5.7246, -4.6564, -4.7204,
        -5.8670, -4.9582, -5.4938, -4.6734, -5.4807, -4.6598, -5.4189, -4.7008],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 3
🔍 Top-K Indices: tensor([ 13,  81, 201, 156, 229, 110, 233,  46,  56, 159, 113, 140,  78, 192,
        124,  32], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-3.9856, -4.7440, -5.1136, -6.1375, -4.5493, -3.7004, -6.5981, -5.5969,
        -4.2024, -4.3169, -5.9472, -5.4135, -5.9562, -5.9105, -4.9632, -5.5239],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 4
🔍 Top-K Indices: tensor([ 29, 203, 202, 244, 192, 144, 134, 176, 173, 194, 157, 166, 217,  95,
        112, 163], device='cuda:0')
tensor([-0.4564, -0.4564, -0.4564, -0.4564, -0.4564, -0.4564, -0.4564, -0.4564,
         0.4564, -0.4564,  0.4564,  3.1950, -0.4564, -0.4564, -0.4564,  1.3693],
       device='cuda:0')
tensor(0.2191, device='cuda:0')
tensor([-5.4522, -5.5894, -3.6526, -5.8748, -5.1846, -4.7333, -6.1995, -5.6799,
        -4.4903, -4.2771, -4.9436, -6.0303, -3.8308, -5.6954, -5.8779, -5.7498],
       device='cuda:0', grad_fn=<IndexBackward0>)


Examples: 100%|██████████| 5/5 [03:18<00:00, 39.64s/it]
Training Steps:  20%|██        | 1/5 [03:18<13:12, 198.20s/it]

✅ Loss: 0.1953

=== Training Step 2 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([ 61,  91,  98, 154, 217, 228,   4, 123, 153,   3, 224, 237,  37, 191,
        236, 220], device='cuda:0')
tensor([-1.5077,  0.5454,  1.0586, -1.5077, -0.4812, -1.5077,  1.0586,  0.0321,
         0.5454, -1.5077,  1.0586,  0.5454,  1.0586,  0.0321,  0.5454,  0.0321],
       device='cuda:0')
tensor(0.3897, device='cuda:0')
tensor([-5.3046, -4.6960, -6.5801, -4.9103, -6.1965, -5.6340, -3.7565, -4.9843,
        -5.0018, -5.6715, -6.8213, -5.4343, -5.0571, -6.1420, -5.3174, -3.8577],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0174

🧠 Inference Index 1
🔍 Top-K Indices: tensor([184,  37, 207, 251,  19,   8,  76,  62,  26, 137, 129, 160, 178, 201,
        141, 154], device='cuda:0')
tensor([ 0.7287,  0.0810,  0.7287,  0.7287,  0.0810, -0.5668, -1.8623, -0.5668,
        -2.5100,  0.7287, -0.5668,  0.7287,  0.7287,  0.7287,  0.0810,  0.7287],
       device='cuda:0')
tensor(0.3088, device='cuda:0')
tensor([-5.7786, -5.3247, -4.6060, -5.8521, -5.0467, -4.6162, -4.5253, -4.8948,
        -4.9147, -5.0262, -4.9026, -6.4407, -5.8219, -5.5181, -5.2575, -4.1002],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.2358

🧠 Inference Index 2
🔍 Top-K Indices: tensor([183, 238,  80, 189, 154, 112, 159, 252,  37,  45, 168, 223, 220, 204,
         99,   8], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-6.1123, -4.1177, -4.4104, -5.7861, -6.4446, -5.3398, -3.7722, -5.7668,
        -4.1658, -7.3370, -7.7732, -5.9772, -5.2241, -5.9636, -5.3656, -4.4173],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 3
🔍 Top-K Indices: tensor([212,  24, 103,  93,  19, 223, 206, 237, 174, 149, 175,  30,  20, 240,
        229,   1], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-5.4534, -5.5201, -6.5064, -5.7796, -6.0349, -5.5206, -5.1551, -4.0949,
        -5.8499, -5.9529, -5.4827, -5.2768, -5.7629, -5.1516, -4.4070, -6.9638],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 4
🔍 Top-K Indices: tensor([116, 140, 149,  89, 150, 248, 110, 243,  68,  35, 246, 211, 199, 196,
         24,  81], device='cuda:0')
tensor([-0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500,
        -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500,  3.7500],
       device='cuda:0')
tensor(0.0500, device='cuda:0')
tensor([-5.7195, -4.6046, -6.3385, -4.5890, -5.5211, -4.7957, -5.1232, -6.1087,
        -4.9333, -5.4732, -7.4061, -4.1805, -6.2001, -5.4392, -5.4651, -4.7046],
       device='cuda:0', grad_fn=<IndexBackward0>)


Examples: 100%|██████████| 5/5 [02:49<00:00, 33.93s/it]
Training Steps:  40%|████      | 2/5 [06:07<09:04, 181.41s/it]

✅ Loss: -0.1770

=== Training Step 3 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([137, 186,  60, 177,  83, 161,   1, 233,  94,  59,   4, 138,  81, 230,
        185,  44], device='cuda:0')
tensor([-1.1107, -1.1107, -0.2221, -0.2221,  1.1107, -0.6664, -1.1107,  1.1107,
         1.1107,  1.1107,  1.1107, -1.1107, -0.6664, -1.1107,  0.6664,  1.1107],
       device='cuda:0')
tensor(0.4502, device='cuda:0')
tensor([-4.6149, -5.9698, -4.5673, -4.7276, -5.2845, -4.6536, -4.3347, -4.7922,
        -5.9147, -6.2993, -4.3382, -4.6063, -4.3064, -6.0478, -6.4308, -5.5178],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.2219

🧠 Inference Index 1
🔍 Top-K Indices: tensor([216, 154, 152,  32, 187, 159, 110,  48, 157, 125, 164,   8, 204, 193,
        177, 124], device='cuda:0')
tensor([ 0.1936,  0.9682,  0.9682,  0.9682, -1.3555,  0.9682,  0.9682, -0.5809,
        -2.1301,  0.1936, -0.5809,  0.9682, -0.5809, -1.3555,  0.1936,  0.1936],
       device='cuda:0')
tensor(0.2582, device='cuda:0')
tensor([-5.0549, -4.7842, -5.6304, -4.9849, -6.2269, -4.9631, -5.4400, -6.8245,
        -4.3126, -6.7031, -4.3579, -5.4229, -5.3314, -6.9176, -4.3204, -5.7362],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.1337

🧠 Inference Index 2
🔍 Top-K Indices: tensor([ 80, 220,  82,  84, 161,  28, 127, 251,  58, 119,  20, 236,  36,  41,
         30, 183], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-3.9159, -5.4330, -5.0236, -5.9459, -4.6576, -6.0128, -5.7390, -5.4344,
        -5.7743, -5.4262, -5.7456, -4.2494, -5.1139, -5.8046, -4.6115, -6.8162],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 3
🔍 Top-K Indices: tensor([ 26,  37, 104, 177,  17,  60, 140, 117,   9, 204, 149, 193, 222,   6,
        156, 143], device='cuda:0')
tensor([-0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500,
        -0.2500, -0.2500,  3.7500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500],
       device='cuda:0')
tensor(0.0500, device='cuda:0')
tensor([-4.3985, -3.6619, -4.3918, -4.8656, -5.7011, -5.1925, -5.2015, -4.7866,
        -5.5185, -5.0953, -5.8171, -5.8671, -5.4881, -7.0398, -4.9963, -4.3981],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.1665

🧠 Inference Index 4
🔍 Top-K Indices: tensor([143,  80, 233,   9,  64, 225, 167, 151, 204,  62, 211,  25, 223, 109,
         69,  17], device='cuda:0')
tensor([-0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500,
        -0.2500,  3.7500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500],
       device='cuda:0')
tensor(0.0500, device='cuda:0')
tensor([-5.3958, -5.4613, -4.5443, -5.2438, -4.6022, -5.3108, -5.5604, -5.5012,
        -5.2376, -6.1026, -5.1004, -5.6553, -5.2865, -5.2916, -6.0371, -6.2014],
       device='cuda:0', grad_fn=<IndexBackward0>)


Examples: 100%|██████████| 5/5 [03:03<00:00, 36.62s/it]
Training Steps:  60%|██████    | 3/5 [09:10<06:04, 182.19s/it]

✅ Loss: 0.1736

=== Training Step 4 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([150, 188, 102,  40,  11,  37, 237,   1, 103, 167, 130, 121, 105, 104,
          8,  23], device='cuda:0')
tensor([ 0.8280,  0.8280, -0.5915, -1.5378,  0.3549,  0.8280, -0.1183, -1.5378,
         0.8280,  0.3549,  0.3549, -1.5378, -1.5378,  0.8280,  0.8280,  0.8280],
       device='cuda:0')
tensor(0.4227, device='cuda:0')
tensor([-4.5484, -7.3329, -5.2084, -5.4376, -5.5828, -4.8591, -5.1967, -3.8067,
        -6.1593, -5.5976, -5.4290, -4.7948, -4.6353, -4.9659, -3.8999, -4.0540],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.1964

🧠 Inference Index 1
🔍 Top-K Indices: tensor([211,  95,  49, 157, 233, 181, 136,  58, 201, 145,  47,  20,  88, 250,
        227, 215], device='cuda:0')
tensor([-1.1929, -0.1023,  0.9884,  0.4431,  0.9884,  0.4431, -1.7383,  0.9884,
         0.9884, -0.1023, -1.7383,  0.4431,  0.4431, -0.6476, -1.1929,  0.9884],
       device='cuda:0')
tensor(0.3667, device='cuda:0')
tensor([-5.4420, -4.3240, -5.2243, -4.9031, -4.8259, -5.2990, -5.9148, -5.3513,
        -4.2279, -5.1477, -6.6325, -5.3957, -5.5835, -4.2031, -5.5104, -4.7142],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.3200

🧠 Inference Index 2
🔍 Top-K Indices: tensor([109, 233, 112,  67, 235,  75,  64,  37, 125, 239, 205,  65, 142, 194,
         36, 117], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-6.2929, -5.0003, -4.4130, -5.4473, -5.7478, -4.7631, -4.9842, -4.3250,
        -4.6143, -5.6872, -6.6292, -3.6572, -6.0327, -5.3463, -6.0972, -4.9444],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 3
🔍 Top-K Indices: tensor([129,  29,  49,  47, 159,  72,   8, 194, 145, 187, 232,  24, 117, 192,
        176, 224], device='cuda:0')
tensor([-0.2500, -0.2500, -0.2500,  3.7500, -0.2500, -0.2500, -0.2500, -0.2500,
        -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500],
       device='cuda:0')
tensor(0.0500, device='cuda:0')
tensor([-5.5028, -6.7819, -3.3453, -5.7307, -5.3695, -4.1733, -4.4171, -6.7028,
        -4.6699, -5.5706, -7.2216, -5.5014, -4.0198, -5.4283, -5.2018, -5.3066],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.1054

🧠 Inference Index 4
🔍 Top-K Indices: tensor([168,  49,  20,  36, 182,   2, 120, 114,  35, 125, 207,  44,  77, 150,
        100,  68], device='cuda:0')
tensor([-0.3660,  2.5617, -0.3660, -0.3660, -0.3660, -0.3660, -0.3660, -0.3660,
        -0.3660,  2.5617, -0.3660, -0.3660, -0.3660, -0.3660, -0.3660, -0.3660],
       device='cuda:0')
tensor(0.0683, device='cuda:0')
tensor([-5.9560, -5.2797, -5.3349, -4.2228, -4.6860, -5.5160, -5.5321, -5.0060,
        -5.6139, -4.2365, -4.8170, -5.6048, -4.8593, -5.7493, -5.0928, -4.7979],
       device='cuda:0', grad_fn=<IndexBackward0>)


Examples: 100%|██████████| 5/5 [02:52<00:00, 34.41s/it]
Training Steps:  80%|████████  | 4/5 [12:03<02:58, 178.18s/it]

✅ Loss: -0.1412

=== Training Step 5 ===





🧠 Inference Index 0
🔍 Top-K Indices: tensor([244,  26,  10,  28, 215,  98,  65, 175,  69, 217, 140, 159,   9,  19,
        187, 130], device='cuda:0')
tensor([-1.4630, -1.4630,  0.1330,  0.6650,  1.1970,  0.1330,  0.6650,  1.1970,
        -1.4630, -0.3990,  1.1970,  0.6650,  0.6650, -1.4630,  0.1330, -0.3990],
       device='cuda:0')
tensor(0.3759, device='cuda:0')
tensor([-6.2727, -4.8072, -5.0861, -5.5377, -3.9404, -5.3990, -6.0288, -5.7879,
        -4.6695, -6.5060, -6.4245, -4.3273, -5.9697, -4.8239, -5.4092, -5.5173],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: 0.0682

🧠 Inference Index 1
🔍 Top-K Indices: tensor([ 98, 197, 178, 233, 209,  49, 131, 191, 177, 116, 184, 100, 124, 199,
        144,  26], device='cuda:0')
tensor([-0.1863, -2.4224,  0.5590,  0.5590,  0.5590,  0.5590,  0.5590, -0.1863,
         0.5590,  0.5590,  0.5590,  0.5590, -0.1863, -0.1863,  0.5590, -2.4224],
       device='cuda:0')
tensor(0.2683, device='cuda:0')
tensor([-4.6756, -3.5942, -4.7653, -5.1762, -4.7495, -4.2407, -4.2206, -6.2633,
        -4.7452, -4.5677, -4.7534, -5.8612, -5.6512, -5.3931, -5.1343, -5.8437],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0004

🧠 Inference Index 2
🔍 Top-K Indices: tensor([201,  37, 177, 231,  17,  52, 197, 210,  72, 107, 223,   8, 117,  98,
        157, 132], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-3.3002, -3.6954, -3.8424, -6.4328, -5.9578, -4.6884, -6.4111, -5.5998,
        -5.1886, -6.0345, -5.6689, -4.8982, -5.6503, -4.0359, -5.4935, -5.5356],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 3
🔍 Top-K Indices: tensor([ 10, 147,  88, 152, 157,   5, 194, 143,  61,  50,  11,  58,  60,  23,
        104, 129], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor(1.0000e-04, device='cuda:0')
tensor([-3.4395, -4.6450, -6.1735, -6.0821, -4.3212, -3.9897, -5.7370, -5.1578,
        -6.7016, -5.5341, -5.4668, -5.4358, -5.4115, -5.0638, -4.6127, -4.2444],
       device='cuda:0', grad_fn=<IndexBackward0>)




✅ Loss: -0.0000

🧠 Inference Index 4
🔍 Top-K Indices: tensor([ 72, 198, 200, 124,  49,  45,  83,  11,  14, 222, 201,  34, 197,  92,
         31, 146], device='cuda:0')
tensor([-0.2500, -0.2500, -0.2500, -0.2500,  3.7500, -0.2500, -0.2500, -0.2500,
        -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500, -0.2500],
       device='cuda:0')
tensor(0.1000, device='cuda:0')
tensor([-4.6053, -5.9042, -5.1781, -6.1039, -4.4286, -4.4074, -5.3460, -3.9077,
        -5.0142, -5.6397, -5.7098, -6.6044, -5.3218, -5.6438, -6.2366, -4.2818],
       device='cuda:0', grad_fn=<IndexBackward0>)


Examples: 100%|██████████| 5/5 [03:01<00:00, 36.40s/it]
Training Steps: 100%|██████████| 5/5 [15:04<00:00, 181.00s/it]

✅ Loss: -0.2106





: 

In [None]:
# Persist the updated encoder: model and tokenizer are written to the same
# directory so it can be reloaded as one checkpoint.
save_path = "./updated_mathbert"  # output directory for the checkpoint
for artifact in (encoder.model, encoder.tokenizer):
    artifact.save_pretrained(save_path)

# LOADING
# NOTE(review): the snippet below assumes the saved model is a standard BERT
# checkpoint -- confirm against MathBERTEncoder before relying on it.

# from transformers import BertTokenizer, BertModel

# model = BertModel.from_pretrained("./updated_mathbert")
# tokenizer = BertTokenizer.from_pretrained("./updated_mathbert")
