# Base Model

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_from_disk  # or load_dataset if remote
import re

# -----------------------------
# Configuration
# -----------------------------
MODEL_NAME = "t5-small"
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128
BATCH_SIZE = 4
NUM_EPOCHS = 3
LR = 3e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Load Dataset
# -----------------------------
full_dataset = load_from_disk("masked_dataset")

# split train test
train_test = full_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test["train"]
test_dataset = train_test["test"]

dataset = train_dataset

def remove_mask_content(example):
    text = example["input_text"]
    example["input_text"] = re.sub(r"\[MASK_START\].*?\[MASK_END\]", "[MASK_START][MASK_END]", text, flags=re.DOTALL)
    return example

dataset = dataset.map(remove_mask_content)
test_dataset = test_dataset.map(remove_mask_content)

# Option 2: or if it’s in memory already:
# from datasets import Dataset
# dataset = Dataset.from_dict({"input_text": [...], "target_text": [...]})

# -----------------------------
# Load tokenizer and model
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)

# -----------------------------
# Tokenize function for datasets.map()
# -----------------------------
def preprocess_function(batch):
    model_inputs = tokenizer(
        batch["input_text"],
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        truncation=True,
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch["target_text"],
            max_length=MAX_TARGET_LENGTH,
            padding="max_length",
            truncation=True,
        )

    # Replace pad token IDs with -100 so they’re ignored in cross-entropy loss
    labels["input_ids"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# -----------------------------
# Apply preprocessing
# -----------------------------
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset.column_names,
)

tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

tokenized_test = test_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=test_dataset.column_names,
)

tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# Create DataLoader
dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(tokenized_test, batch_size=BATCH_SIZE)

# -----------------------------
# Optimizer
# -----------------------------
optimizer = AdamW(model.parameters(), lr=LR)

In [8]:
from tqdm import tqdm

# -----------------------------
# Training loop
# -----------------------------
model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0

    # tqdm progress bar for batches
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=True)

    for batch in pbar:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Avg Loss: {avg_loss:.4f}")

Epoch 1/3: 100%|██████████| 5169/5169 [18:23<00:00,  4.69it/s]


Epoch 1/3 | Avg Loss: 1.9718


Epoch 2/3: 100%|██████████| 5169/5169 [18:06<00:00,  4.76it/s]


Epoch 2/3 | Avg Loss: 1.7341


Epoch 3/3: 100%|██████████| 5169/5169 [18:03<00:00,  4.77it/s]

Epoch 3/3 | Avg Loss: 1.6260





In [9]:
# -----------------------------
# Save model & tokenizer
# -----------------------------
SAVE_PATH = "./model1"

model.save_pretrained(SAVE_PATH)
tokenizer.save_pretrained(SAVE_PATH)
print(f"Model and tokenizer saved to {SAVE_PATH}")

Model and tokenizer saved to ./model1


In [None]:
# Example of test prediction

def generate_masked_span(model, tokenizer, input_text, max_new_tokens=128):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(DEVICE)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=250,
        min_new_tokens=30,
        num_beams=5,
        length_penalty=1.4,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

'''
| Category                  | Parameters                                                                  | Purpose                            |
| ------------------------- | --------------------------------------------------------------------------- | ---------------------------------- |
| **Length control**        | `max_new_tokens`, `min_new_tokens`, `length_penalty`                        | Controls output size               |
| **Quality (beam search)** | `num_beams`, `length_penalty`, `no_repeat_ngram_size`, `repetition_penalty` | Improves coherence, avoids loops   |
| **Creativity (sampling)** | `do_sample`, `temperature`, `top_p`                                         | Adds randomness and variation      |
| **Structure**             | `**inputs`                                                                  | Provides prompt and attention mask |
'''

# -----------------------------
# Example inference
# -----------------------------

# Load the saved model and tokenizer
SAVE_PATH = "./model1"
model = AutoModelForSeq2SeqLM.from_pretrained(SAVE_PATH).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH)
test_example = test_dataset[0]["input_text"]
predicted = generate_masked_span(model, tokenizer, test_example)
print("\n--- Predicted masked section ---\n")
print(predicted)

In [13]:
# More test predictions

for i in range(5):
  test_example = test_dataset[i]["input_text"]
  predicted = generate_masked_span(model, tokenizer, test_example)
  print("\n--- Input Example ---\n")
  print(test_example)
  print("\n--- Predicted Masked Sentence ---\n")
  print(predicted)
  print("\n--- Target Sentence ---\n")
  print(test_dataset[i]["target_text"])


--- Input Example ---

<DELETE> <START_OUTLINE> <BOE> Amy discovers a big envelope on the table while preparing a gift for her friend Tom's birthday. <EOE> <BOE> Amy feels happy with the wrapped gift, but her mom is upset when she sees it. <EOE> <BOE> Amy's mom explains that the envelope contained important papers needed for her work. <EOE> <BOE> Feeling guilty, Amy helps her mom unwrap the gift to retrieve the important papers. <EOE> <BOE> They return the papers to the envelope and find another way to wrap Tom's gift. <EOE> <BOE> Through this experience, Amy learns the importance of listening to her mom and being obedient. <EOE> <END_OUTLINE> <START_STORY> One day, a little girl named Amy found a big envelope on the table. She wanted to wrap a gift for her friend Tom's birthday. [MASK_START][MASK_END] She was very happy with how it looked. But when her mom saw the wrapped gift, she was not happy. The envelope had important papers inside for her work. They needed to find the papers be

# Experiments

## Setup

In [None]:
import re
from datasets import load_from_disk
from transformers import AutoTokenizer
from torch.utils.data import DataLoader


def remove_mask_content(example):
    text = example["input_text"]
    example["input_text"] = re.sub(
        r"\[MASK_START\].*?\[MASK_END\]",
        "[MASK_START][MASK_END]",
        text,
        flags=re.DOTALL,
    )
    return example


def make_preprocess_function(tokenizer, max_input_length, max_target_length):
    """
    Factory that returns a preprocess_function configured with
    the chosen max_input_length and max_target_length.
    """
    def preprocess_function(batch):
        # Encoder inputs
        model_inputs = tokenizer(
            batch["input_text"],
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
        )

        # Decoder targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                batch["target_text"],
                max_length=max_target_length,
                padding="max_length",
                truncation=True,
            )

        # Replace pad token IDs with -100 so they’re ignored in loss
        labels["input_ids"] = [
            [
                (token if token != tokenizer.pad_token_id else -100)
                for token in label
            ]
            for label in labels["input_ids"]
        ]
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    return preprocess_function


def prepare_dataloaders(
    data_path: str = "masked_dataset",
    model_name: str = "t5-small",
    subset_frac: float = 0.25,
    max_input_length: int = 512,
    max_target_length: int = 128,
    batch_size: int = 4,
    seed: int = 42,
):
    """
    Loads the dataset, applies cleaning, takes a subset of the train split,
    tokenizes, and returns DataLoaders (plus tokenizer & tokenized datasets).
    """

    # Load dataset and split
    full_dataset = load_from_disk(data_path)
    train_test = full_dataset.train_test_split(test_size=0.1, seed=seed)
    train_dataset = train_test["train"]
    test_dataset = train_test["test"]

    # Clean mask content
    train_dataset = train_dataset.map(remove_mask_content)
    test_dataset = test_dataset.map(remove_mask_content)

    # Subset 25% of the training data (after cleaning)
    if 0 < subset_frac < 1.0:
        train_dataset = train_dataset.shuffle(seed=seed)
        subset_size = int(len(train_dataset) * subset_frac)
        train_dataset = train_dataset.select(range(subset_size))

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Preprocessing function with chosen lengths
    preprocess_fn = make_preprocess_function(
        tokenizer,
        max_input_length=max_input_length,
        max_target_length=max_target_length,
    )

    # Tokenize train and test
    tokenized_train = train_dataset.map(
        preprocess_fn,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    tokenized_test = test_dataset.map(
        preprocess_fn,
        batched=True,
        remove_columns=test_dataset.column_names,
    )

    # Set torch format
    tokenized_train.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"],
    )
    tokenized_test.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"],
    )

    # Dataloaders
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(tokenized_test, batch_size=batch_size)

    return {
        "tokenizer": tokenizer,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "tokenized_train": tokenized_train,
        "tokenized_test": tokenized_test,
    }


## Models

### Stage 0 BaseLine Model

In [5]:
import torch
from transformers import AutoModelForSeq2SeqLM
from torch.optim import AdamW
from tqdm import tqdm

# Core configs for Stage 0
MODEL_NAME = "t5-small"
LR = 3e-5
BATCH_SIZE = 4
NUM_EPOCHS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Prepare data (25% train subset, 512/128 lengths)
data = prepare_dataloaders(
    data_path="masked_dataset",
    model_name=MODEL_NAME,
    subset_frac=0.25,
    max_input_length=512,
    max_target_length=128,
    batch_size=BATCH_SIZE,
    seed=42,
)
tokenizer = data["tokenizer"]
train_loader = data["train_loader"]
test_loader = data["test_loader"]

# Model
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)

# Optimizer
optimizer = AdamW(model.parameters(), lr=LR)

# -----------------------------
# Stage 0: Training + Validation Loop
# -----------------------------
for epoch in range(NUM_EPOCHS):
    # ---- Training ----
    model.train()
    total_train_loss = 0.0

    train_pbar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
        leave=True,
    )

    for batch in train_pbar:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        train_pbar.set_postfix({"batch_loss": loss.item()})

    avg_train_loss = total_train_loss / len(train_loader)

    # ---- Validation ----
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        val_pbar = tqdm(
            test_loader,
            desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]",
            leave=False,
        )
        for batch in val_pbar:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            total_val_loss += loss.item()

            val_pbar.set_postfix({"batch_loss": loss.item()})

    avg_val_loss = total_val_loss / len(test_loader)

    print(
        f"Epoch {epoch+1}/{NUM_EPOCHS} | "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val Loss: {avg_val_loss:.4f}"
    )

Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:29<00:00,  4.79it/s, batch_loss=1.86]
                                                                                   

Epoch 1/2 | Train Loss: 2.2187 | Val Loss: 1.8843


Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.78it/s, batch_loss=1.51]
                                                                                    

Epoch 2/2 | Train Loss: 1.9613 | Val Loss: 1.7831




In [None]:
# -----------------------------
# Save final model checkpoint
# -----------------------------
# save_path = "stage0_baseline_model"
# model.save_pretrained(save_path)
# tokenizer.save_pretrained(save_path)

# print(f"Model saved to {save_path}")

Model saved to stage0_baseline_model


### Stage 1 LR Experiments

In [None]:
# -----------------------------
# Stage 1: Learning Rate Sweep
# -----------------------------

LR_LIST = [1e-5, 5e-5, 1e-4, 1e-3, 1e-2]
NUM_EPOCHS = 2                     
BATCH_SIZE = 4
MODEL_NAME = "t5-small"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Prepare dataloaders once (no need to redo for each LR)
data = prepare_dataloaders(
    data_path="masked_dataset",
    model_name=MODEL_NAME,
    subset_frac=0.25,
    max_input_length=512,
    max_target_length=128,
    batch_size=BATCH_SIZE,
    seed=42,
)
tokenizer = data["tokenizer"]
train_loader = data["train_loader"]
test_loader = data["test_loader"]

from transformers import AutoModelForSeq2SeqLM
from torch.optim import AdamW
import torch
from tqdm import tqdm

for lr in LR_LIST:
    print(f"\n============================")
    print(f"  Starting LR experiment: {lr}")
    print(f"============================\n")

    # --- Reinitialize the model for every LR ---
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=lr)

    # --- Train + validate for NUM_EPOCHS ---
    for epoch in range(NUM_EPOCHS):

        # ----- Training -----
        model.train()
        total_train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"LR {lr} | Epoch {epoch+1}/{NUM_EPOCHS} [Train]")
        for batch in pbar:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            pbar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_train_loss / len(train_loader)

        # ----- Validation -----
        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            vbar = tqdm(test_loader, desc=f"LR {lr} | Epoch {epoch+1}/{NUM_EPOCHS} [Val]", leave=False)
            for batch in vbar:
                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                total_val_loss += outputs.loss.item()
                vbar.set_postfix({"loss": outputs.loss.item()})

        avg_val_loss = total_val_loss / len(test_loader)

        print(f"[LR {lr}] Epoch {epoch+1}/{NUM_EPOCHS} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f}")

    # ----- Save checkpoint for this LR -----
    # save_dir = f"checkpoint_lr_{lr}"
    # model.save_pretrained(save_dir)
    # tokenizer.save_pretrained(save_dir)
    # print(f"Saved model for LR={lr} to '{save_dir}'\n")



  Starting LR experiment: 1e-05



LR 1e-05 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.78it/s, loss=2.25]
                                                                                        

[LR 1e-05] Epoch 1/2 | Train Loss: 2.4031 | Val Loss: 2.0461


LR 1e-05 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.77it/s, loss=2.32]
                                                                                        

[LR 1e-05] Epoch 2/2 | Train Loss: 2.1493 | Val Loss: 1.9317
Saved model for LR=1e-05 to 'checkpoint_lr_1e-05'


  Starting LR experiment: 5e-05



LR 5e-05 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.77it/s, loss=1.83]
                                                                                         

[LR 5e-05] Epoch 1/2 | Train Loss: 2.1369 | Val Loss: 1.8149


LR 5e-05 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:31<00:00,  4.75it/s, loss=1.97] 
                                                                                         

[LR 5e-05] Epoch 2/2 | Train Loss: 1.8757 | Val Loss: 1.7110
Saved model for LR=5e-05 to 'checkpoint_lr_5e-05'


  Starting LR experiment: 0.0001



LR 0.0001 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:31<00:00,  4.76it/s, loss=2.4]  
                                                                                          

[LR 0.0001] Epoch 1/2 | Train Loss: 2.0448 | Val Loss: 1.7309


LR 0.0001 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.77it/s, loss=1.28] 
                                                                                          

[LR 0.0001] Epoch 2/2 | Train Loss: 1.7640 | Val Loss: 1.6220
Saved model for LR=0.0001 to 'checkpoint_lr_0.0001'


  Starting LR experiment: 0.001



LR 0.001 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:31<00:00,  4.76it/s, loss=2.2]  
                                                                                         

[LR 0.001] Epoch 1/2 | Train Loss: 1.9597 | Val Loss: 1.6397


LR 0.001 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.77it/s, loss=1.7]  
                                                                                         

[LR 0.001] Epoch 2/2 | Train Loss: 1.5781 | Val Loss: 1.5512
Saved model for LR=0.001 to 'checkpoint_lr_0.001'


  Starting LR experiment: 0.01



LR 0.01 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.77it/s, loss=3.43]
                                                                                       

[LR 0.01] Epoch 1/2 | Train Loss: 3.2018 | Val Loss: 2.9429


LR 0.01 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.78it/s, loss=4.93]
                                                                                       

[LR 0.01] Epoch 2/2 | Train Loss: 3.7458 | Val Loss: 4.5052
Saved model for LR=0.01 to 'checkpoint_lr_0.01'



### Stage 2 Batch Size Experiments

In [None]:
# -----------------------------
# Stage 2: Batch Size Sweep
# -----------------------------

import torch
from transformers import AutoModelForSeq2SeqLM
from torch.optim import AdamW
from tqdm import tqdm

MODEL_NAME = "t5-small"
LR = 1e-3                 # chosen from Stage 1
BATCH_SIZE_LIST = [4, 6, 8]   
NUM_EPOCHS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

for batch_size in BATCH_SIZE_LIST:
    print("\n====================================")
    print(f"  Starting batch size experiment: {batch_size}")
    print("====================================\n")

    # Prepare data for this batch size (25% subset, same lengths)
    data = prepare_dataloaders(
        data_path="masked_dataset",
        model_name=MODEL_NAME,
        subset_frac=0.25,
        max_input_length=512,
        max_target_length=128,
        batch_size=batch_size,
        seed=42,
    )
    tokenizer = data["tokenizer"]
    train_loader = data["train_loader"]
    test_loader = data["test_loader"]

    # Model & optimizer (re-init for each batch size)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=LR)

    for epoch in range(NUM_EPOCHS):
        # ----- Training -----
        model.train()
        total_train_loss = 0.0

        train_pbar = tqdm(
            train_loader,
            desc=f"BS {batch_size} | Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
            leave=True,
        )

        for batch in train_pbar:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            train_pbar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_train_loss / len(train_loader)

        # ----- Validation -----
        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            val_pbar = tqdm(
                test_loader,
                desc=f"BS {batch_size} | Epoch {epoch+1}/{NUM_EPOCHS} [Val]",
                leave=False,
            )
            for batch in val_pbar:
                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                loss = outputs.loss
                total_val_loss += loss.item()
                val_pbar.set_postfix({"loss": loss.item()})

        avg_val_loss = total_val_loss / len(test_loader)

        print(
            f"[BS {batch_size}] Epoch {epoch+1}/{NUM_EPOCHS} | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}"
        )

    # ----- Save checkpoint for this batch size -----
    # save_dir = f"checkpoint_bs_{batch_size}"
    # model.save_pretrained(save_dir)
    # tokenizer.save_pretrained(save_dir)
    # print(f"Saved model for batch_size={batch_size} to '{save_dir}'\n")



  Starting batch size experiment: 4



BS 4 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:29<00:00,  4.80it/s, loss=2.04]
                                                                                     

[BS 4] Epoch 1/2 | Train Loss: 1.9635 | Val Loss: 1.6554


BS 4 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:30<00:00,  4.77it/s, loss=1.44] 
                                                                                     

[BS 4] Epoch 2/2 | Train Loss: 1.5883 | Val Loss: 1.5584
Saved model for batch_size=4 to 'checkpoint_bs_4'


  Starting batch size experiment: 6



BS 6 | Epoch 1/2 [Train]: 100%|██████████| 862/862 [07:22<00:00,  1.95it/s, loss=1.44] 
                                                                                    

[BS 6] Epoch 1/2 | Train Loss: 1.9487 | Val Loss: 1.6574


BS 6 | Epoch 2/2 [Train]: 100%|██████████| 862/862 [07:15<00:00,  1.98it/s, loss=1.88] 
                                                                                     

[BS 6] Epoch 2/2 | Train Loss: 1.5741 | Val Loss: 1.5340
Saved model for batch_size=6 to 'checkpoint_bs_6'


  Starting batch size experiment: 8



BS 8 | Epoch 1/2 [Train]: 100%|██████████| 646/646 [18:04<00:00,  1.68s/it, loss=1.7] 
                                                                                    

[BS 8] Epoch 1/2 | Train Loss: 1.9518 | Val Loss: 1.6295


BS 8 | Epoch 2/2 [Train]: 100%|██████████| 646/646 [18:30<00:00,  1.72s/it, loss=1.18] 
                                                                                    

[BS 8] Epoch 2/2 | Train Loss: 1.5779 | Val Loss: 1.5356
Saved model for batch_size=8 to 'checkpoint_bs_8'



### Stage 3 Context Length Experiments

In [None]:
# -----------------------------
# Stage 3: Context Length Sweep
# -----------------------------

import torch
from transformers import AutoModelForSeq2SeqLM
from torch.optim import AdamW
from tqdm import tqdm

MODEL_NAME = "t5-small"
LR = 1e-3           # from Stage 1
BATCH_SIZE = 4      # from Stage 2
NUM_EPOCHS = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# (max_input_length, max_target_length) configs to test
CONTEXT_CONFIGS = [
    (256, 64),
    (384, 96),
    (512, 128),
]

for max_in, max_tgt in CONTEXT_CONFIGS:
    print("\n====================================")
    print(f"  Starting context experiment: input={max_in}, target={max_tgt}")
    print("====================================\n")

    # Prepare data for this context length
    data = prepare_dataloaders(
        data_path="masked_dataset",
        model_name=MODEL_NAME,
        subset_frac=0.25,
        max_input_length=max_in,
        max_target_length=max_tgt,
        batch_size=BATCH_SIZE,
        seed=42,
    )
    tokenizer = data["tokenizer"]
    train_loader = data["train_loader"]
    test_loader = data["test_loader"]

    # Model & optimizer (re-init for each config)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=LR)

    for epoch in range(NUM_EPOCHS):
        # ----- Training -----
        model.train()
        total_train_loss = 0.0

        train_pbar = tqdm(
            train_loader,
            desc=f"CTX in={max_in}, tgt={max_tgt} | Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
            leave=True,
        )

        for batch in train_pbar:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            train_pbar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_train_loss / len(train_loader)

        # ----- Validation -----
        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            val_pbar = tqdm(
                test_loader,
                desc=f"CTX in={max_in}, tgt={max_tgt} | Epoch {epoch+1}/{NUM_EPOCHS} [Val]",
                leave=False,
            )
            for batch in val_pbar:
                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                loss = outputs.loss
                total_val_loss += loss.item()
                val_pbar.set_postfix({"loss": loss.item()})

        avg_val_loss = total_val_loss / len(test_loader)

        print(
            f"[CTX in={max_in}, tgt={max_tgt}] Epoch {epoch+1}/{NUM_EPOCHS} | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}"
        )

    # ----- Save checkpoint for this context config -----
    # save_dir = f"checkpoint_ctx_in{max_in}_tgt{max_tgt}"
    # model.save_pretrained(save_dir)
    # tokenizer.save_pretrained(save_dir)
    # print(f"Saved model for context (in={max_in}, tgt={max_tgt}) to '{save_dir}'\n")



  Starting context experiment: input=256, target=64





CTX in=256, tgt=64 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [06:12<00:00,  3.47it/s, loss=1.29] 
                                                                                                   

[CTX in=256, tgt=64] Epoch 1/2 | Train Loss: 2.0172 | Val Loss: 1.7111


CTX in=256, tgt=64 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [06:02<00:00,  3.56it/s, loss=1.08] 
                                                                                                   

[CTX in=256, tgt=64] Epoch 2/2 | Train Loss: 1.6402 | Val Loss: 1.6247
Saved model for context (in=256, tgt=64) to 'checkpoint_ctx_in256_tgt64'


  Starting context experiment: input=384, target=96



CTX in=384, tgt=96 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [03:26<00:00,  6.26it/s, loss=1.43] 
                                                                                                   

[CTX in=384, tgt=96] Epoch 1/2 | Train Loss: 1.9733 | Val Loss: 1.6703


CTX in=384, tgt=96 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [03:25<00:00,  6.28it/s, loss=1.27] 
                                                                                                   

[CTX in=384, tgt=96] Epoch 2/2 | Train Loss: 1.5872 | Val Loss: 1.5662
Saved model for context (in=384, tgt=96) to 'checkpoint_ctx_in384_tgt96'


  Starting context experiment: input=512, target=128



CTX in=512, tgt=128 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [04:31<00:00,  4.76it/s, loss=1.37] 
                                                                                                    

[CTX in=512, tgt=128] Epoch 1/2 | Train Loss: 1.9669 | Val Loss: 1.6796


CTX in=512, tgt=128 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [04:31<00:00,  4.76it/s, loss=2.08] 
                                                                                                    

[CTX in=512, tgt=128] Epoch 2/2 | Train Loss: 1.5891 | Val Loss: 1.5520
Saved model for context (in=512, tgt=128) to 'checkpoint_ctx_in512_tgt128'



### Stage 4 Epoch Count and Weight Decay Experiments

In [None]:
# -----------------------------
# Stage 4: Epoch & Weight Decay Sweep
# -----------------------------

import torch
from transformers import AutoModelForSeq2SeqLM
from torch.optim import AdamW
from tqdm import tqdm

MODEL_NAME = "t5-small"
LR = 1e-3
BATCH_SIZE = 4
MAX_INPUT_LENGTH = 384
MAX_TARGET_LENGTH = 96
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Epoch and weight decay configs to test
EPOCH_OPTIONS = [2, 3, 4]
WEIGHT_DECAY_OPTIONS = [0.0, 0.01]

# Prepare data once for this context & batch size
data = prepare_dataloaders(
    data_path="masked_dataset",
    model_name=MODEL_NAME,
    subset_frac=0.25,
    max_input_length=MAX_INPUT_LENGTH,
    max_target_length=MAX_TARGET_LENGTH,
    batch_size=BATCH_SIZE,
    seed=42,
)
tokenizer = data["tokenizer"]
train_loader = data["train_loader"]
test_loader = data["test_loader"]

for wd in WEIGHT_DECAY_OPTIONS:
    for num_epochs in EPOCH_OPTIONS:
        print("\n====================================")
        print(f"  Epoch/WD experiment: epochs={num_epochs}, weight_decay={wd}")
        print("====================================\n")

        # Re-init model & optimizer for this config
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
        optimizer = AdamW(model.parameters(), lr=LR, weight_decay=wd)

        for epoch in range(num_epochs):
            # ----- Training -----
            model.train()
            total_train_loss = 0.0

            train_pbar = tqdm(
                train_loader,
                desc=f"WD={wd} | Epoch {epoch+1}/{num_epochs} [Train]",
                leave=True,
            )

            for batch in train_pbar:
                optimizer.zero_grad()

                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                total_train_loss += loss.item()
                train_pbar.set_postfix({"loss": loss.item()})

            avg_train_loss = total_train_loss / len(train_loader)

            # ----- Validation -----
            model.eval()
            total_val_loss = 0.0

            with torch.no_grad():
                val_pbar = tqdm(
                    test_loader,
                    desc=f"WD={wd} | Epoch {epoch+1}/{num_epochs} [Val]",
                    leave=False,
                )
                for batch in val_pbar:
                    input_ids = batch["input_ids"].to(DEVICE)
                    attention_mask = batch["attention_mask"].to(DEVICE)
                    labels = batch["labels"].to(DEVICE)

                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                    )
                    loss = outputs.loss
                    total_val_loss += loss.item()
                    val_pbar.set_postfix({"loss": loss.item()})

            avg_val_loss = total_val_loss / len(test_loader)

            print(
                f"[WD={wd}, Epoch {epoch+1}/{num_epochs}] | "
                f"Train Loss: {avg_train_loss:.4f} | "
                f"Val Loss: {avg_val_loss:.4f}"
            )

        # Save checkpoint for this (epochs, weight decay) config
        # wd_str = str(wd).replace(".", "p")
        # save_dir = f"checkpoint_in{MAX_INPUT_LENGTH}_tgt{MAX_TARGET_LENGTH}_ep{num_epochs}_wd{wd_str}"
        # model.save_pretrained(save_dir)
        # tokenizer.save_pretrained(save_dir)
        # print(f"Saved model to '{save_dir}'\n")



  Epoch/WD experiment: epochs=2, weight_decay=0.0



WD=0.0 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [03:20<00:00,  6.43it/s, loss=1.52] 
                                                                                       

[WD=0.0, Epoch 1/2] | Train Loss: 1.9738 | Val Loss: 1.6793


WD=0.0 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [03:21<00:00,  6.41it/s, loss=1.76] 
                                                                                       

[WD=0.0, Epoch 2/2] | Train Loss: 1.5967 | Val Loss: 1.5823
Saved model to 'checkpoint_in384_tgt96_ep2_wd0p0'


  Epoch/WD experiment: epochs=3, weight_decay=0.0



WD=0.0 | Epoch 1/3 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.37it/s, loss=2.04] 
                                                                                       

[WD=0.0, Epoch 1/3] | Train Loss: 1.9828 | Val Loss: 1.6752


WD=0.0 | Epoch 2/3 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.38it/s, loss=1.21] 
                                                                                       

[WD=0.0, Epoch 2/3] | Train Loss: 1.5988 | Val Loss: 1.5591


WD=0.0 | Epoch 3/3 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.39it/s, loss=1.51] 
                                                                                       

[WD=0.0, Epoch 3/3] | Train Loss: 1.4002 | Val Loss: 1.5202
Saved model to 'checkpoint_in384_tgt96_ep3_wd0p0'


  Epoch/WD experiment: epochs=4, weight_decay=0.0



WD=0.0 | Epoch 1/4 [Train]: 100%|██████████| 1292/1292 [03:23<00:00,  6.36it/s, loss=2.18] 
                                                                                       

[WD=0.0, Epoch 1/4] | Train Loss: 1.9699 | Val Loss: 1.6684


WD=0.0 | Epoch 2/4 [Train]: 100%|██████████| 1292/1292 [03:23<00:00,  6.36it/s, loss=1.37] 
                                                                                       

[WD=0.0, Epoch 2/4] | Train Loss: 1.5968 | Val Loss: 1.5774


WD=0.0 | Epoch 3/4 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.38it/s, loss=2.11] 
                                                                                       

[WD=0.0, Epoch 3/4] | Train Loss: 1.3989 | Val Loss: 1.5294


WD=0.0 | Epoch 4/4 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.37it/s, loss=2.08] 
                                                                                       

[WD=0.0, Epoch 4/4] | Train Loss: 1.2537 | Val Loss: 1.5049
Saved model to 'checkpoint_in384_tgt96_ep4_wd0p0'


  Epoch/WD experiment: epochs=2, weight_decay=0.01



WD=0.01 | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [03:25<00:00,  6.27it/s, loss=2.2]  
                                                                                        

[WD=0.01, Epoch 1/2] | Train Loss: 1.9685 | Val Loss: 1.6617


WD=0.01 | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [03:26<00:00,  6.25it/s, loss=1.89] 
                                                                                        

[WD=0.01, Epoch 2/2] | Train Loss: 1.5978 | Val Loss: 1.5823
Saved model to 'checkpoint_in384_tgt96_ep2_wd0p01'


  Epoch/WD experiment: epochs=3, weight_decay=0.01



WD=0.01 | Epoch 1/3 [Train]: 100%|██████████| 1292/1292 [03:25<00:00,  6.28it/s, loss=1.94] 
                                                                                        

[WD=0.01, Epoch 1/3] | Train Loss: 1.9844 | Val Loss: 1.6791


WD=0.01 | Epoch 2/3 [Train]: 100%|██████████| 1292/1292 [03:25<00:00,  6.29it/s, loss=1.84] 
                                                                                        

[WD=0.01, Epoch 2/3] | Train Loss: 1.5974 | Val Loss: 1.5714


WD=0.01 | Epoch 3/3 [Train]: 100%|██████████| 1292/1292 [03:26<00:00,  6.25it/s, loss=1.17] 
                                                                                        

[WD=0.01, Epoch 3/3] | Train Loss: 1.3915 | Val Loss: 1.5400
Saved model to 'checkpoint_in384_tgt96_ep3_wd0p01'


  Epoch/WD experiment: epochs=4, weight_decay=0.01



WD=0.01 | Epoch 1/4 [Train]: 100%|██████████| 1292/1292 [03:27<00:00,  6.22it/s, loss=2.36] 
                                                                                        

[WD=0.01, Epoch 1/4] | Train Loss: 1.9744 | Val Loss: 1.6690


WD=0.01 | Epoch 2/4 [Train]: 100%|██████████| 1292/1292 [03:27<00:00,  6.23it/s, loss=1.8]  
                                                                                        

[WD=0.01, Epoch 2/4] | Train Loss: 1.5941 | Val Loss: 1.5814


WD=0.01 | Epoch 3/4 [Train]: 100%|██████████| 1292/1292 [03:27<00:00,  6.22it/s, loss=0.832]
                                                                                        

[WD=0.01, Epoch 3/4] | Train Loss: 1.4007 | Val Loss: 1.5249


WD=0.01 | Epoch 4/4 [Train]: 100%|██████████| 1292/1292 [03:26<00:00,  6.26it/s, loss=1.1]  
                                                                                        

[WD=0.01, Epoch 4/4] | Train Loss: 1.2494 | Val Loss: 1.5095
Saved model to 'checkpoint_in384_tgt96_ep4_wd0p01'



### Stage 5 Scheduler Experiments

In [None]:
# -----------------------------
# Stage 5: Scheduler Comparison
# -----------------------------

import torch
from transformers import AutoModelForSeq2SeqLM, get_scheduler
from torch.optim import AdamW
from tqdm import tqdm

MODEL_NAME = "t5-small"
LR = 1e-3
BATCH_SIZE = 4
MAX_INPUT_LENGTH = 384
MAX_TARGET_LENGTH = 96
NUM_EPOCHS = 2
WEIGHT_DECAY = 0.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Prepare data once with the chosen config
data = prepare_dataloaders(
    data_path="masked_dataset",
    model_name=MODEL_NAME,
    subset_frac=0.25,
    max_input_length=MAX_INPUT_LENGTH,
    max_target_length=MAX_TARGET_LENGTH,
    batch_size=BATCH_SIZE,
    seed=42,
)
tokenizer = data["tokenizer"]
train_loader = data["train_loader"]
test_loader = data["test_loader"]

# Schedulers to test (compared against constant LR baseline you already ran)
SCHEDULER_TYPES = ["linear", "cosine"]

for sched_name in SCHEDULER_TYPES:
    print("\n====================================")
    print(f"  Scheduler experiment: {sched_name}")
    print("====================================\n")

    # Re-init model & optimizer
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # Total steps = epochs * steps per epoch
    num_update_steps_per_epoch = len(train_loader)
    num_training_steps = NUM_EPOCHS * num_update_steps_per_epoch
    warmup_steps = int(0.1 * num_training_steps)  # 10% warmup

    scheduler = get_scheduler(
        name=sched_name,
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )

    global_step = 0

    for epoch in range(NUM_EPOCHS):
        # ----- Training -----
        model.train()
        total_train_loss = 0.0

        train_pbar = tqdm(
            train_loader,
            desc=f"{sched_name} | Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
            leave=True,
        )

        for batch in train_pbar:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()     # update LR
            global_step += 1

            total_train_loss += loss.item()
            train_pbar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_train_loss / len(train_loader)

        # ----- Validation -----
        model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            val_pbar = tqdm(
                test_loader,
                desc=f"{sched_name} | Epoch {epoch+1}/{NUM_EPOCHS} [Val]",
                leave=False,
            )
            for batch in val_pbar:
                input_ids = batch["input_ids"].to(DEVICE)
                attention_mask = batch["attention_mask"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                loss = outputs.loss
                total_val_loss += loss.item()
                val_pbar.set_postfix({"loss": loss.item()})

        avg_val_loss = total_val_loss / len(test_loader)

        print(
            f"[{sched_name}] Epoch {epoch+1}/{NUM_EPOCHS} | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}"
        )

    # Save checkpoint for this scheduler
    # save_dir = f"checkpoint_sched_{sched_name}_in{MAX_INPUT_LENGTH}_tgt{MAX_TARGET_LENGTH}"
    # model.save_pretrained(save_dir)
    # tokenizer.save_pretrained(save_dir)
    # print(f"Saved model with scheduler='{sched_name}' to '{save_dir}'\n")



  Scheduler experiment: linear



linear | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.39it/s, loss=1.68] 
                                                                                       

[linear] Epoch 1/2 | Train Loss: 2.0003 | Val Loss: 1.6377


linear | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [03:22<00:00,  6.37it/s, loss=1.41] 
                                                                                       

[linear] Epoch 2/2 | Train Loss: 1.5317 | Val Loss: 1.5182
Saved model with scheduler='linear' to 'checkpoint_sched_linear_in384_tgt96'


  Scheduler experiment: cosine



cosine | Epoch 1/2 [Train]: 100%|██████████| 1292/1292 [03:21<00:00,  6.41it/s, loss=2.2]  
                                                                                       

[cosine] Epoch 1/2 | Train Loss: 1.9903 | Val Loss: 1.6264


cosine | Epoch 2/2 [Train]: 100%|██████████| 1292/1292 [03:21<00:00,  6.41it/s, loss=1.16] 
                                                                                       

[cosine] Epoch 2/2 | Train Loss: 1.5040 | Val Loss: 1.5240
Saved model with scheduler='cosine' to 'checkpoint_sched_cosine_in384_tgt96'



### Stage 6 Final Model

In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, get_scheduler
from torch.optim import AdamW
from tqdm import tqdm

# -----------------------------
# Full training config
# -----------------------------
MODEL_NAME = "t5-small"
LR = 1e-3
BATCH_SIZE = 4
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128
NUM_EPOCHS = 10
WEIGHT_DECAY = 0.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EARLY_STOPPING_PATIENCE = 3   # stop if no val improvement for N epochs
WARMUP_RATIO = 0.1            # 10% warmup for linear scheduler

# -----------------------------
# Data: full dataset (no subset)
# -----------------------------
data = prepare_dataloaders(
    data_path="masked_dataset_2", # larger dataset
    model_name=MODEL_NAME,
    subset_frac=1.0,  # use 100% of training data
    max_input_length=MAX_INPUT_LENGTH,
    max_target_length=MAX_TARGET_LENGTH,
    batch_size=BATCH_SIZE,
    seed=42,
)
tokenizer = data["tokenizer"]
train_loader = data["train_loader"]
val_loader = data["test_loader"]   # treat this as validation

# -----------------------------
# Model, optimizer, scheduler
# -----------------------------
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

num_update_steps_per_epoch = len(train_loader)
num_training_steps = NUM_EPOCHS * num_update_steps_per_epoch
num_warmup_steps = int(WARMUP_RATIO * num_training_steps)

scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# -----------------------------
# Training loop with early stopping
# -----------------------------
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(NUM_EPOCHS):
    print(f"\n========== Epoch {epoch+1}/{NUM_EPOCHS} ==========\n")

    # ----- Training -----
    model.train()
    total_train_loss = 0.0

    train_pbar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
        leave=True,
    )

    for batch in train_pbar:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_train_loss += loss.item()
        train_pbar.set_postfix({"loss": loss.item()})

    avg_train_loss = total_train_loss / len(train_loader)

    # ----- Validation -----
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        val_pbar = tqdm(
            val_loader,
            desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]",
            leave=False,
        )
        for batch in val_pbar:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            total_val_loss += loss.item()
            val_pbar.set_postfix({"loss": loss.item()})

    avg_val_loss = total_val_loss / len(val_loader)

    print(
        f"Epoch {epoch+1}/{NUM_EPOCHS} | "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val Loss: {avg_val_loss:.4f}"
    )

    # ----- Early stopping check -----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0

        # Save best model so far
        best_save_dir = "full_model_best"
        model.save_pretrained(best_save_dir)
        tokenizer.save_pretrained(best_save_dir)
        print(f"✅ New best model saved to '{best_save_dir}' (val_loss={best_val_loss:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epoch(s).")

        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            print("⏹ Early stopping triggered.")
            break

# -----------------------------
# Save final model (last epoch)
# -----------------------------
final_save_dir = "full_model_best"
model.save_pretrained(final_save_dir)
tokenizer.save_pretrained(final_save_dir)
print(f"\nFinal model (last epoch) saved to '{final_save_dir}'")
print(f"Best validation loss achieved: {best_val_loss:.4f}")






Epoch 1/10 [Train]: 100%|██████████| 7182/7182 [24:51<00:00,  4.82it/s, loss=1.81] 
                                                                               

Epoch 1/10 | Train Loss: 1.9227 | Val Loss: 1.5481
✅ New best model saved to 'full_model_best' (val_loss=1.5481)




Epoch 2/10 [Train]: 100%|██████████| 7182/7182 [24:52<00:00,  4.81it/s, loss=1.34] 
                                                                               

Epoch 2/10 | Train Loss: 1.5571 | Val Loss: 1.3589
✅ New best model saved to 'full_model_best' (val_loss=1.3589)




Epoch 3/10 [Train]: 100%|██████████| 7182/7182 [24:49<00:00,  4.82it/s, loss=1.82] 
                                                                               

Epoch 3/10 | Train Loss: 1.3364 | Val Loss: 1.2291
✅ New best model saved to 'full_model_best' (val_loss=1.2291)




Epoch 4/10 [Train]: 100%|██████████| 7182/7182 [24:49<00:00,  4.82it/s, loss=1.69] 
                                                                               

Epoch 4/10 | Train Loss: 1.1697 | Val Loss: 1.1426
✅ New best model saved to 'full_model_best' (val_loss=1.1426)




Epoch 5/10 [Train]: 100%|██████████| 7182/7182 [24:49<00:00,  4.82it/s, loss=0.931]
                                                                               

Epoch 5/10 | Train Loss: 1.0314 | Val Loss: 1.0593
✅ New best model saved to 'full_model_best' (val_loss=1.0593)




Epoch 6/10 [Train]: 100%|██████████| 7182/7182 [24:49<00:00,  4.82it/s, loss=0.861]
                                                                               

Epoch 6/10 | Train Loss: 0.9080 | Val Loss: 1.0041
✅ New best model saved to 'full_model_best' (val_loss=1.0041)




Epoch 7/10 [Train]: 100%|██████████| 7182/7182 [24:53<00:00,  4.81it/s, loss=1.03] 
                                                                               

Epoch 7/10 | Train Loss: 0.7992 | Val Loss: 0.9577
✅ New best model saved to 'full_model_best' (val_loss=0.9577)




Epoch 8/10 [Train]: 100%|██████████| 7182/7182 [24:56<00:00,  4.80it/s, loss=0.644]
                                                                               

Epoch 8/10 | Train Loss: 0.7034 | Val Loss: 0.9107
✅ New best model saved to 'full_model_best' (val_loss=0.9107)




Epoch 9/10 [Train]: 100%|██████████| 7182/7182 [24:50<00:00,  4.82it/s, loss=1.22] 
                                                                               

Epoch 9/10 | Train Loss: 0.6247 | Val Loss: 0.8757
✅ New best model saved to 'full_model_best' (val_loss=0.8757)




Epoch 10/10 [Train]: 100%|██████████| 7182/7182 [24:51<00:00,  4.82it/s, loss=0.394]
                                                                                

Epoch 10/10 | Train Loss: 0.5662 | Val Loss: 0.8687
✅ New best model saved to 'full_model_best' (val_loss=0.8687)

Final model (last epoch) saved to 'Generator Models/full_model_best'
Best validation loss achieved: 0.8687
