In [1]:
import accelerate
print("Accelerate version:", accelerate.__version__)

Accelerate version: 0.26.0


In [7]:
import torch
print("Torch version:", torch.__version__)
print("MPS available:", torch.backends.mps.is_available())

Torch version: 2.5.1
MPS available: True


In [2]:
# Standard library imports
import re
import json
from pathlib import Path

# Third-party imports
import torch
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Datasets and ML frameworks
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)


In [3]:
with open("gsm8k_train_flawed_plus1_final_answer.jsonl", "r") as f:
    lines = f.readlines()
    print(f"Total lines: {len(lines)}")
    print("First line:", lines[0] if lines else "No data")


Total lines: 7473
First line: {"id": 0, "question": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?", "flawed_answer": "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 73", "label": {"verdict": "Flawed", "error_details": {"error_type": "computational_error", "erroneous_line_number": "L3", "explanation": "The final answer is too high by 1. It should be 72, not 73.", "error_in_text": "#### 72", "correction_in_text": "#### 72"}}}



In [5]:
original_dataset = load_dataset("openai/gsm8k", "main")

def load_jsonl(path):
    with open(path, 'r') as f:
        return [json.loads(line) for line in f]

flawed_train_final_answer = load_jsonl("gsm8k_train_flawed_plus1_final_answer.jsonl")
flawed_test = load_jsonl("gsm8k_test_flawed_plus1_final_answer.jsonl")

combined_train_final_answer = []

for i, ex in enumerate(original_dataset["train"]):
    combined_train_final_answer.append({
        "question": ex["question"],
        "solution": ex["answer"],
        "label": "Correct"
    })

for i, ex in enumerate(flawed_train_final_answer):
    combined_train_final_answer.append({
        "question": ex["question"],
        "solution": ex["flawed_answer"],
        "label": "Flawed"
    })

print(f"Combined training set size: {len(combined_train_final_answer)}")

combined_test_final_answer = []

for i, ex in enumerate(original_dataset["test"]):
    combined_test_final_answer.append({
        "question": ex["question"],
        "solution": ex["answer"],
        "label": "Correct"
    })

for i, ex in enumerate(flawed_test):
    combined_test_final_answer.append({
        "question": ex["question"],
        "solution": ex["flawed_answer"],
        "label": "Flawed"
    })

print(f"Combined test set size: {len(combined_test_final_answer)}")

Combined training set size: 14946
Combined test set size: 2638


In [6]:
original_dataset = load_dataset("openai/gsm8k", "main")

def load_jsonl(path):
    with open(path, 'r') as f:
        return [json.loads(line) for line in f]

flawed_train_2nd_last = load_jsonl("gsm8k_train_flawed_plus1_2nd_last.jsonl")
flawed_test = load_jsonl("gsm8k_test_flawed_plus1_2nd_last.jsonl")

combined_train_2nd_last = []

for i, ex in enumerate(original_dataset["train"]):
    combined_train_2nd_last.append({
        "question": ex["question"],
        "solution": ex["answer"],
        "label": "Correct"
    })

for i, ex in enumerate(flawed_train_2nd_last):
    combined_train_2nd_last.append({
        "question": ex["question"],
        "solution": ex["flawed_answer"],
        "label": "Flawed"
    })

print(f"Combined training set size: {len(combined_train_2nd_last)}")

combined_test_2nd_last = []

for i, ex in enumerate(original_dataset["test"]):
    combined_test_2nd_last.append({
        "question": ex["question"],
        "solution": ex["answer"],
        "label": "Correct"
    })

for i, ex in enumerate(flawed_test):
    combined_test_2nd_last.append({
        "question": ex["question"],
        "solution": ex["flawed_answer"],
        "label": "Flawed"
    })

print(f"Combined test set size: {len(combined_test_2nd_last)}")

Combined training set size: 14946
Combined test set size: 2638


In [21]:
# 📏 ANALYZE RAW TEXT LENGTHS BEFORE TOKENIZATION
print("=== RAW TEXT LENGTH ANALYSIS ===")

def analyze_text_lengths(dataset, name):
    print(f"\n{name} Dataset Analysis:")
    
    # Separate by label
    correct_examples = [ex for ex in dataset if ex['label'] == 'Correct']
    flawed_examples = [ex for ex in dataset if ex['label'] == 'Flawed']
    
    def get_text_stats(examples, label_name):
        question_lengths = [len(ex['question']) for ex in examples]
        solution_lengths = [len(ex['solution']) for ex in examples]
        # 🔥 FIX: Create combined text outside f-string
        combined_lengths = []
        for ex in examples:
            combined_text = f"Question:\n{ex['question']}\n\nSolution:\n{ex['solution']}"
            combined_lengths.append(len(combined_text))
        
        question_words = [len(ex['question'].split()) for ex in examples]
        solution_words = [len(ex['solution'].split()) for ex in examples]
        # 🔥 FIX: Create combined text outside f-string
        combined_words = []
        for ex in examples:
            combined_text = f"Question:\n{ex['question']}\n\nSolution:\n{ex['solution']}"
            combined_words.append(len(combined_text.split()))
        
        print(f"  {label_name} Examples ({len(examples)} total):")
        print(f"    Question chars: min={min(question_lengths)}, max={max(question_lengths)}, avg={sum(question_lengths)/len(question_lengths):.1f}")
        print(f"    Solution chars: min={min(solution_lengths)}, max={max(solution_lengths)}, avg={sum(solution_lengths)/len(solution_lengths):.1f}")
        print(f"    Combined chars: min={min(combined_lengths)}, max={max(combined_lengths)}, avg={sum(combined_lengths)/len(combined_lengths):.1f}")
        print(f"    Question words: min={min(question_words)}, max={max(question_words)}, avg={sum(question_words)/len(question_words):.1f}")
        print(f"    Solution words: min={min(solution_words)}, max={max(solution_words)}, avg={sum(solution_words)/len(solution_words):.1f}")
        print(f"    Combined words: min={min(combined_words)}, max={max(combined_words)}, avg={sum(combined_words)/len(combined_words):.1f}")
        
        return combined_lengths, combined_words
    
    correct_char_lengths, correct_word_lengths = get_text_stats(correct_examples, "CORRECT")
    flawed_char_lengths, flawed_word_lengths = get_text_stats(flawed_examples, "FLAWED")
    
    # Compare lengths between correct and flawed
    print(f"  COMPARISON:")
    print(f"    Avg chars - Correct: {sum(correct_char_lengths)/len(correct_char_lengths):.1f}, Flawed: {sum(flawed_char_lengths)/len(flawed_char_lengths):.1f}")
    print(f"    Avg words - Correct: {sum(correct_word_lengths)/len(correct_word_lengths):.1f}, Flawed: {sum(flawed_word_lengths)/len(flawed_word_lengths):.1f}")
    
    return correct_char_lengths + flawed_char_lengths, correct_word_lengths + flawed_word_lengths

# Analyze training data
train_char_lengths, train_word_lengths = analyze_text_lengths(combined_train_final_answer, "TRAINING")

# Analyze test data  
test_char_lengths, test_word_lengths = analyze_text_lengths(combined_test_final_answer, "TEST")

# Overall statistics
print(f"\n=== OVERALL STATISTICS ===")
print(f"Training set:")
print(f"  Character lengths: min={min(train_char_lengths)}, max={max(train_char_lengths)}, avg={sum(train_char_lengths)/len(train_char_lengths):.1f}")
print(f"  Word lengths: min={min(train_word_lengths)}, max={max(train_word_lengths)}, avg={sum(train_word_lengths)/len(train_word_lengths):.1f}")

print(f"Test set:")
print(f"  Character lengths: min={min(test_char_lengths)}, max={max(test_char_lengths)}, avg={sum(test_char_lengths)/len(test_char_lengths):.1f}")
print(f"  Word lengths: min={min(test_word_lengths)}, max={max(test_word_lengths)}, avg={sum(test_word_lengths)/len(test_word_lengths):.1f}")

# Rough token estimation (1 word ≈ 1.3 tokens for English)
print(f"\n=== ROUGH TOKEN ESTIMATION (words × 1.3) ===")
train_estimated_tokens = [w * 1.3 for w in train_word_lengths]
test_estimated_tokens = [w * 1.3 for w in test_word_lengths]

print(f"Training estimated tokens: min={min(train_estimated_tokens):.0f}, max={max(train_estimated_tokens):.0f}, avg={sum(train_estimated_tokens)/len(train_estimated_tokens):.0f}")
print(f"Test estimated tokens: min={min(test_estimated_tokens):.0f}, max={max(test_estimated_tokens):.0f}, avg={sum(test_estimated_tokens)/len(test_estimated_tokens):.0f}")

# Check how many would exceed 512 tokens
train_over_512 = sum(1 for t in train_estimated_tokens if t > 512)
test_over_512 = sum(1 for t in test_estimated_tokens if t > 512)
print(f"Estimated examples over 512 tokens: Train={train_over_512}/{len(train_estimated_tokens)} ({100*train_over_512/len(train_estimated_tokens):.1f}%), Test={test_over_512}/{len(test_estimated_tokens)} ({100*test_over_512/len(test_estimated_tokens):.1f}%)")

# Show a few examples
print(f"\n=== SAMPLE EXAMPLES ===")
print("Shortest example:")
shortest_idx = train_char_lengths.index(min(train_char_lengths))
shortest_example = combined_train_final_answer[shortest_idx]
# 🔥 FIX: Create combined text outside f-string
shortest_text = f"Question:\n{shortest_example['question']}\n\nSolution:\n{shortest_example['solution']}"
print(f"  Length: {len(shortest_text)} chars")
print(f"  Question: {shortest_example['question'][:100]}...")
print(f"  Solution: {shortest_example['solution'][:100]}...")

print("\nLongest example:")
longest_idx = train_char_lengths.index(max(train_char_lengths))
longest_example = combined_train_final_answer[longest_idx]
# 🔥 FIX: Create combined text outside f-string
longest_text = f"Question:\n{longest_example['question']}\n\nSolution:\n{longest_example['solution']}"
print(f"  Length: {len(longest_text)} chars")
print(f"  Question: {longest_example['question'][:100]}...")
print(f"  Solution: {longest_example['solution'][:100]}...")

print("=" * 50)

=== RAW TEXT LENGTH ANALYSIS ===

TRAINING Dataset Analysis:
  CORRECT Examples (7473 total):
    Question chars: min=42, max=985, avg=234.5
    Solution chars: min=50, max=1228, avg=287.5
    Combined chars: min=146, max=1711, avg=544.0
    Question words: min=9, max=183, avg=45.1
    Solution words: min=4, max=216, avg=51.7
    Combined words: min=23, max=336, avg=98.8
  FLAWED Examples (7473 total):
    Question chars: min=42, max=985, avg=234.5
    Solution chars: min=50, max=1228, avg=287.5
    Combined chars: min=146, max=1711, avg=544.0
    Question words: min=9, max=183, avg=45.1
    Solution words: min=4, max=216, avg=51.7
    Combined words: min=23, max=336, avg=98.8
  COMPARISON:
    Avg chars - Correct: 544.0, Flawed: 544.0
    Avg words - Correct: 98.8, Flawed: 98.8

TEST Dataset Analysis:
  CORRECT Examples (1319 total):
    Question chars: min=73, max=848, avg=239.9
    Solution chars: min=48, max=1070, avg=292.9
    Combined chars: min=182, max=1640, avg=554.8
    Quest

# Data

In [8]:
train_dataset = Dataset.from_list(combined_train_final_answer)
test_dataset = Dataset.from_list(combined_test_final_answer)

# train_dataset = Dataset.from_list(combined_train_2nd_last)
# test_dataset = Dataset.from_list(combined_test_2nd_last)

# AutoTokenizer

Converts text into numbers that neural networks can understand.

### Special Tokens
- `[CLS]` (101): Start of sequence
- `[SEP]` (102): Separator between segments  
- `[PAD]` (0): Padding token
- `[UNK]`: Unknown/out-of-vocabulary words

| Parameter | Purpose |
|-----------|---------|
| `truncation=True` | Cut text if > 512 tokens |
| `padding="max_length"` | Add padding to reach exactly 512 tokens |
| `max_length=512` | Set maximum sequence length |

In [None]:
# Preprocessing with tokenizer
from transformers import AutoTokenizer

model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 🔥 FIX: Add padding token for GPT-2 style models
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(example):
    # Combine the question and solution into a single input string
    full_input = f"Question:\n{example['question']}\n\nSolution:\n{example['solution']}"
    
    # Tokenize the combined input with truncation, padding, and a max length of 512 tokens
    return tokenizer(full_input, truncation=True, padding="max_length", max_length=512)

train_dataset = train_dataset.map(tokenize_fn)
test_dataset = test_dataset.map(tokenize_fn)

# Label encoding
label_map = {"Correct": 0, "Flawed": 1}
train_dataset = train_dataset.map(lambda e: {"labels": label_map[e["label"]]})
test_dataset = test_dataset.map(lambda e: {"labels": label_map[e["label"]]})

Map:   0%|          | 0/14946 [00:00<?, ? examples/s]

Map:   0%|          | 0/2638 [00:00<?, ? examples/s]

Map:   0%|          | 0/14946 [00:00<?, ? examples/s]

Map:   0%|          | 0/2638 [00:00<?, ? examples/s]

In [22]:
# 🔍 DEBUG: Check what's in your datasets
print("=== DEBUGGING DATASET ===")
print("Train dataset columns:", train_dataset.column_names)
print("Test dataset columns:", test_dataset.column_names)

# Check a sample
sample = train_dataset[0]
print("Sample keys:", sample.keys())
print("Input IDs shape:", sample['input_ids'].shape if hasattr(sample['input_ids'], 'shape') else len(sample['input_ids']))
print("Labels:", sample['labels'])

# Decode the input to see what text is actually fed to model
decoded_input = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
print("Decoded input text:")
print(decoded_input)
print("=" * 50)

=== DEBUGGING DATASET ===
Train dataset columns: ['question', 'solution', 'label', 'input_ids', 'attention_mask', 'labels']
Test dataset columns: ['question', 'solution', 'label', 'input_ids', 'attention_mask', 'labels']
Sample keys: dict_keys(['question', 'solution', 'label', 'input_ids', 'attention_mask', 'labels'])
Input IDs shape: 512
Labels: 0
Decoded input text:
Question:
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

Solution:
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72


# Model Setup & Training

## Fine-tuning

#### Epoch

An epoch refers to one complete pass through the entire training dataset. During an epoch, the model processes all the training samples once, updating its weights based on the computed loss. Training for multiple epochs allows the model to learn and refine its parameters iteratively.

Increasing the number of epochs allows the model to learn more but risks overfitting if too high. A balance between batch size and epochs is crucial for optimal performance.

#### Batch size

Batch size determines the number of samples processed before the model updates its weights. For example, a batch size of 8 means 8 samples (e.g., 8 question-answer pairs) are processed together in one forward and backward pass during training.

A smaller batch size uses less memory but may take longer to converge, while a larger batch size can speed up training but requires more memory. 

#### Optimizer - AdamW

AdamW is an optimizer that implements the Adam algorithm with weight decay regularization. It helps prevent overfitting by penalizing large weights, which is particularly useful in deep learning models. AdamW is widely used because it combines the benefits of Adam (adaptive learning rates) with weight decay for better generalization.

Other optimizer choices include:
- **SGD (Stochastic Gradient Descent)**: A simple optimizer with momentum and learning rate decay options.
- **RMSprop**: Designed for non-stationary objectives, often used in RNNs.
- **Adagrad**: Adapts learning rates based on parameter updates, suitable for sparse data.
- **Adadelta**: An extension of Adagrad that reduces aggressive learning rate decay.
- **Adam**: Similar to AdamW but without weight decay.
- **Nadam**: Adam with Nesterov momentum.

#### tqdm

`tqdm` is a Python library used to display progress bars for loops. It provides a visual representation of the progress of an iterable, such as a training loop or data processing, making it easier to monitor the execution time and completion percentage. It is especially useful in long-running tasks.

For example:
```python
from tqdm import tqdm
for i in tqdm(range(100)):
    # Simulate some work
    pass
```

This will display a progress bar in the console, showing the percentage completed, elapsed time, and estimated time remaining.

In [28]:
# 🔍 LEARNING RATE FINDER (FastAI Style) - FIXED FOR MPS
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

def find_learning_rate(model, train_dataset, tokenizer, start_lr=1e-7, end_lr=1e-1, num_iter=100):
    """
    FastAI-style learning rate finder - Fixed for MPS device
    """
    # 🔥 FIX: Ensure model is on the correct device
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = model.to(device)
    model.train()
    
    optimizer = AdamW(model.parameters(), lr=start_lr)
    
    # Create dataloader
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    
    # Learning rate schedule
    lr_mult = (end_lr / start_lr) ** (1.0 / num_iter)
    
    losses = []
    lrs = []
    best_loss = float('inf')
    
    for i, batch in enumerate(tqdm(train_dataloader, desc="Finding LR")):
        if i >= num_iter:
            break
        
        # 🔥 FIX: Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(input_ids=batch['input_ids'], 
                       attention_mask=batch['attention_mask'],
                       labels=batch['labels'])
        loss = outputs.loss
        
        # Track loss and learning rate
        current_lr = optimizer.param_groups[0]['lr']
        losses.append(loss.item())
        lrs.append(current_lr)
        
        # Stop if loss explodes
        if loss.item() > best_loss * 4:
            break
        if loss.item() < best_loss:
            best_loss = loss.item()
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_mult
    
    return lrs, losses

# 🔥 INITIALIZE MODEL FIRST
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# 🔥 FIX: Set pad_token_id in model config to match tokenizer
model.config.pad_token_id = tokenizer.pad_token_id

# Run learning rate finder
print("🔍 Running Learning Rate Finder...")
lrs, losses = find_learning_rate(model, train_dataset, tokenizer)

# Plot results
plt.figure(figsize=(10, 6))
plt.plot(lrs, losses)
plt.xscale('log')
plt.xlabel('Learning Rate')
plt.ylabel('Loss')
plt.title('Learning Rate Finder')
plt.grid(True)
plt.show()

# Find optimal learning rate (usually where loss decreases fastest)
min_loss_idx = losses.index(min(losses))
optimal_lr = lrs[min_loss_idx]
print(f"📊 Suggested Learning Rate: {optimal_lr:.2e}")
print(f"📊 Consider using: {optimal_lr/10:.2e} to {optimal_lr:.2e}")

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


🔍 Running Learning Rate Finder...


Finding LR:   0%|          | 0/935 [00:03<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 18.10 GB, other allocations: 26.70 MB, max allowed: 18.13 GB). Tried to allocate 24.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

### Notes

Learning rate scheduler

torch.optim.lr_scheduler

https://www.datacamp.com/tutorial/fine-tuning-large-language-models

Increase batch_size 

total_norm_util = clip_grad_norm_(model.parameters(), max_norm=float('inf')) 

weight-decay (decided based on if it is overfitting)

cross validation

In [23]:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

training_args = TrainingArguments(
    output_dir="./fine_tuned_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    
    # 🔥 LEARNING RATE - Most Important Change
    learning_rate=5e-4,  # Increased from 5e-5 - try 1e-4 to 3e-4 range
    
    # 🔥 MORE TRAINING
    num_train_epochs=5,  # Increased from 3 - more learning time
    
    # 🔥 BATCH SIZE - Better gradient estimates
    per_device_train_batch_size=16,   # Good starting point
    per_device_eval_batch_size=64,
    
    # 🔥 LEARNING RATE SCHEDULING
    lr_scheduler_type="cosine",  # Add learning rate decay
    warmup_steps=500,           # Gradual warmup
    
    # 🔥 REGULARIZATION ADJUSTMENTS
    weight_decay=0.1,           # Increased from 0.01
    
    # 🔥 EVALUATION & EARLY STOPPING
    eval_steps=250,             # Evaluate more frequently
    load_best_model_at_end=True,  # Load best checkpoint
    metric_for_best_model="eval_f1",  # Use F1 score for model selection
    greater_is_better=True,
    
    # 🔥 IMPROVED LOGGING
    logging_steps=50,           # More frequent logging
    save_total_limit=3,         # Keep more checkpoints
    
    # Keep these settings
    report_to="none",
    remove_unused_columns=False,
    dataloader_num_workers=0,
    dataloader_drop_last=False,
    dataloader_pin_memory=False,
    skip_memory_metrics=True,
)

# Complete the compute_metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    accuracy = accuracy_score(labels, predictions)
    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }

# Ensure datasets are properly formatted
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,  # Use tokenizer instead of processing_class
    compute_metrics=compute_metrics,
)

# Start training
print("Starting training with Trainer...")
trainer.train()

# Save the model
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
print("Training completed and model saved!")

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Starting training with Trainer...


ValueError: Cannot handle batch sizes > 1 if no padding token is defined.

In [35]:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

training_args = TrainingArguments(
    output_dir="./fine_tuned_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    
    # 🔥 LEARNING RATE - Most Important Change
    learning_rate=5e-4,  # Increased from 5e-5 - try 1e-4 to 3e-4 range
    
    # 🔥 MORE TRAINING
    num_train_epochs=10,  # Increased from 3 - more learning time
    
    # 🔥 BATCH SIZE - Better gradient estimates
    per_device_train_batch_size=32,  # Increased from 8
    per_device_eval_batch_size=32,   # Increased from 8
    
    # 🔥 LEARNING RATE SCHEDULING
    lr_scheduler_type="cosine",  # Add learning rate decay
    warmup_steps=500,           # Gradual warmup
    
    # 🔥 REGULARIZATION ADJUSTMENTS
    weight_decay=0.1,           # Increased from 0.01
    
    # 🔥 EVALUATION & EARLY STOPPING
    eval_steps=250,             # Evaluate more frequently
    load_best_model_at_end=True,  # Load best checkpoint
    metric_for_best_model="eval_f1",  # Use F1 score for model selection
    greater_is_better=True,
    
    # 🔥 IMPROVED LOGGING
    logging_steps=50,           # More frequent logging
    save_total_limit=3,         # Keep more checkpoints
    
    # Keep these settings
    report_to="none",
    remove_unused_columns=False,
    dataloader_num_workers=0,
    dataloader_drop_last=False,
    dataloader_pin_memory=False,
    skip_memory_metrics=True,
)

# Complete the compute_metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    accuracy = accuracy_score(labels, predictions)
    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }

# Ensure datasets are properly formatted
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,  # Use tokenizer instead of processing_class
    compute_metrics=compute_metrics,
)

# Start training
print("Starting training with Trainer...")
trainer.train()

# Save the model
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
print("Training completed and model saved!")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Starting training with Trainer...


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.6965,0.693432,0.5,0.006033,0.5,0.003035
2,0.6655,0.61262,0.682094,0.705965,0.656658,0.763278
3,0.5352,0.512835,0.688923,0.5629,0.946237,0.400607
4,0.494,0.51343,0.739757,0.755175,0.712938,0.802731
5,0.4646,0.438555,0.745827,0.686036,0.897059,0.555387
6,0.4385,0.410214,0.748103,0.779841,0.69258,0.892261
7,0.3655,0.395925,0.774659,0.727273,0.92093,0.60091
8,0.3466,0.418233,0.763278,0.783333,0.722151,0.855842
9,0.3406,0.406329,0.773141,0.746825,0.844828,0.669196
10,0.3236,0.412946,0.76176,0.764264,0.756315,0.772382


Training completed and model saved!
