# Gemma 3 Model Training with LoRA

This notebook fine-tunes Google's Gemma 3 model for financial-tweet sentiment classification using LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning method.

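Features:
1. LoRA adapters (via the PEFT library) on the attention query/value projections
2. Multi-metric early stopping
3. Tracking of validation metrics (accuracy, weighted precision/recall/F1, Cohen's kappa, MCC, ROC-AUC)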

In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    cohen_kappa_score,
    matthews_corrcoef,
    roc_auc_score
)
from sklearn.model_selection import train_test_split

# Make runs repeatable: seed PyTorch's global RNG first, then NumPy's legacy
# global RNG, both with the same value.
SEED = 42
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(SEED)

In [None]:
class FinancialTweetDataset(Dataset):
    """Map-style dataset that tokenizes financial tweets on the fly.

    Each item is a dict with fixed-length ``input_ids`` and ``attention_mask``
    tensors (padded/truncated to ``max_length``) and a scalar ``labels``
    tensor of dtype ``torch.long``.

    Args:
        texts: Indexable sequence of tweet texts; every entry is passed
            through ``str()`` before tokenization.
        labels: Indexable sequence of integer class ids aligned with ``texts``.
        tokenizer: Hugging Face tokenizer, called with the keyword arguments
            used in ``__getitem__``.
        max_length: Length every example is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        # int() raises ValueError on NaN (e.g. a sentiment missing from the
        # label map) instead of silently yielding a garbage class id.
        label = int(self.labels[idx])

        # Calling the tokenizer directly replaces the deprecated encode_plus.
        # With return_tensors='pt' each field carries a leading batch
        # dimension of size 1, which flatten() removes below.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }

In [None]:
class EarlyStoppingCallback:
    """Signal a stop once no tracked metric has improved for ``patience`` calls.

    Every metric is treated as higher-is-better. A call counts as an
    improvement when at least one metric exceeds the best value seen so far
    *for that same metric* by more than ``min_delta``. Each metric's best is
    tracked independently, so a regression in one metric never lowers the
    bar that metric must clear later.

    Attributes:
        patience: Consecutive non-improving calls tolerated before stopping.
        min_delta: Minimum increase over a metric's best that counts as an
            improvement.
        counter: Consecutive non-improving calls since the last improvement.
        best_metrics: Per-metric best values (``None`` until the first call).
        early_stop: Latched to ``True`` once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_metrics = None
        self.early_stop = False

    def __call__(self, metrics):
        """Record one evaluation's ``metrics`` and report whether to stop.

        Args:
            metrics: Mapping of metric name to numeric value.

        Returns:
            ``True`` once training should stop, otherwise ``False``.
        """
        if self.best_metrics is None:
            # Copy so later mutation of the caller's dict cannot move the baseline.
            self.best_metrics = dict(metrics)
            return False

        improved = False
        # No early break: every metric that improved gets its best updated.
        for name, value in metrics.items():
            best = self.best_metrics.get(name, float('-inf'))
            if value > best + self.min_delta:
                self.best_metrics[name] = value
                improved = True

        if improved:
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop

In [None]:
def calculate_metrics(predictions, labels):
    """Calculate classification metrics from per-class model scores.

    Args:
        predictions: Array-like of shape ``(n_samples, n_classes)`` with one
            score per class for each sample: raw logits (as the training
            loop passes them) or class probabilities.
        labels: Array-like of shape ``(n_samples,)`` with integer class ids.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``
        (support-weighted averages), ``kappa`` (Cohen's kappa), ``mcc``
        (Matthews correlation) and ``roc_auc`` (one-vs-rest, macro-averaged).
        ``roc_auc`` is ``0.0`` when sklearn cannot compute it, e.g. when a
        class is missing from ``labels``.
    """
    scores = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels)
    pred_labels = np.argmax(scores, axis=1)

    # Basic metrics. zero_division=0 yields the same value sklearn uses by
    # default for never-predicted classes, without a warning every epoch.
    accuracy = accuracy_score(labels, pred_labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, pred_labels, average='weighted', zero_division=0
    )

    # Chance-corrected agreement metrics.
    kappa = cohen_kappa_score(labels, pred_labels)
    mcc = matthews_corrcoef(labels, pred_labels)

    # Multi-class ROC-AUC needs each row to be a probability distribution;
    # sklearn raises ValueError for raw logits, so softmax them first.
    if np.all(scores >= 0) and np.allclose(scores.sum(axis=1), 1.0):
        probabilities = scores  # already probabilities
    else:
        # Numerically stable softmax: subtract each row's max before exp.
        exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
        probabilities = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    try:
        roc_auc = roc_auc_score(labels, probabilities, multi_class='ovr')
    except ValueError:
        # Undefined for this data (e.g. some class absent from ``labels``).
        roc_auc = 0.0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'kappa': kappa,
        'mcc': mcc,
        'roc_auc': roc_auc,
    }

In [None]:
# Load labeled data
df = pd.read_csv('../data/all_labeled_tweets.csv')

# Convert sentiment strings to class ids. The ids must be 0..len(label_map)-1
# because they index the classifier's output logits.
label_map = {
    'STRONGLY_POSITIVE': 0,
    'POSITIVE': 1,
    'NEUTRAL': 2,
    'NEGATIVE': 3,
    'STRONGLY_NEGATIVE': 4,
    'NOT_RELATED': 5,
    'UNCERTAIN': 6
}
df['label'] = df['sentiment'].map(label_map)

# Sentiments missing from label_map map to NaN, and missing texts would be
# tokenized as the string "nan". Drop such rows (and report them) instead of
# letting them reach the loss function.
invalid_rows = df['label'].isna() | df['text'].isna()
if invalid_rows.any():
    unknown = sorted(df.loc[df['label'].isna(), 'sentiment'].astype(str).unique())
    print(
        f"Dropping {int(invalid_rows.sum())} rows with missing text or "
        f"unknown sentiment (unknown values: {unknown})"
    )
    df = df.loc[~invalid_rows].reset_index(drop=True)
df['label'] = df['label'].astype(int)

# Split data. Stratify so every class keeps its share in the validation set;
# one-vs-rest ROC-AUC is undefined when a class is absent there. sklearn
# raises ValueError when stratification is impossible (too few examples of
# some class), in which case fall back to a plain random split.
tweet_texts = df['text'].values
tweet_labels = df['label'].values
try:
    splits = train_test_split(
        tweet_texts, tweet_labels,
        test_size=0.2, random_state=42, stratify=tweet_labels
    )
except ValueError:
    splits = train_test_split(
        tweet_texts, tweet_labels,
        test_size=0.2, random_state=42
    )
train_texts, val_texts, train_labels, val_labels = splits

# Initialize tokenizer.
# NOTE(review): 'google/gamma-3' does not appear to be a published Hugging
# Face model id (Gemma 3 checkpoints are named like 'google/gemma-3-...');
# confirm the intended checkpoint.
tokenizer = AutoTokenizer.from_pretrained('google/gamma-3')
# FinancialTweetDataset pads to max_length, which needs a pad token; some
# decoder-only tokenizers ship without one, so fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Create datasets
train_dataset = FinancialTweetDataset(train_texts, train_labels, tokenizer)
val_dataset = FinancialTweetDataset(val_texts, val_labels, tokenizer)

In [None]:
# Initialize the classifier; num_labels comes from label_map so the output
# layer always matches the label encoding.
# NOTE(review): 'google/gamma-3' does not appear to be a published Hugging
# Face model id (Gemma 3 checkpoints are named like 'google/gemma-3-...');
# confirm the intended checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    'google/gamma-3',
    num_labels=len(label_map)
)

# Hugging Face decoder-only *ForSequenceClassification heads use
# config.pad_token_id to find each row's last non-padding token; without it
# batches larger than 1 cannot be classified. Keep it in sync with the tokenizer.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# LoRA configuration: low-rank adapters on the attention query/value projections.
lora_config = LoraConfig(
    r=8,  # rank
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS
)

# Prepare model for LoRA.
# NOTE(review): prepare_model_for_kbit_training is intended for 8-/4-bit
# quantized models, but the model above is loaded without quantization;
# confirm this call is still wanted.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

# Training parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# After get_peft_model only a small subset of parameters requires gradients;
# give the optimizer just those rather than every frozen base weight.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-5)
# T_max=10 matches num_epochs in the training cell; keep the two in sync.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
criterion = torch.nn.CrossEntropyLoss()

# Initialize early stopping
early_stopping = EarlyStoppingCallback(patience=3)

In [None]:
# tqdm (used below for the progress bar) is imported here; fall back to plain
# iteration if it is not installed.
try:
    from tqdm.auto import tqdm
except ImportError:
    def tqdm(iterable, **_kwargs):
        return iterable

# Training loop (the LR scheduler's T_max is set for the same epoch count).
num_epochs = 10
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

best_metrics = None
best_model_state = None

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs.logits, labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # max(..., 1) guards against an empty training loader.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"\nAverage training loss: {avg_loss:.4f}")

    # ---- Validation ----
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # .float() first: NumPy cannot represent bfloat16 tensors.
            all_preds.append(outputs.logits.float().cpu().numpy())
            # Labels are only needed on the CPU for metric computation.
            all_labels.append(batch['labels'].numpy())

    predictions = np.vstack(all_preds)
    true_labels = np.concatenate(all_labels)

    # Calculate metrics
    metrics = calculate_metrics(predictions, true_labels)

    print("\nValidation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    # Snapshot the best weights BEFORE the early-stopping check, so an epoch
    # that is best by F1 is not lost when it also triggers the stop.
    if best_metrics is None or metrics['f1'] > best_metrics['f1']:
        best_metrics = metrics
        # state_dict() holds references to the live parameters, so a plain
        # dict copy would be overwritten by later optimizer steps; clone
        # every tensor to the CPU instead.
        # NOTE(review): this keeps the full model (base + adapters); saving
        # only the adapter weights may be preferable for LoRA.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    # Early stopping check
    if early_stopping(metrics):
        print("\nEarly stopping triggered!")
        break

    scheduler.step()

# Save final model and metrics
output_dir = '../models/gamma3'
os.makedirs(output_dir, exist_ok=True)

if best_model_state is None:
    raise RuntimeError("No training epoch completed; there is no model state to save.")

# Save best model state
torch.save(best_model_state, os.path.join(output_dir, 'gamma3_lora_model.pt'))

# Save metrics
metrics_df = pd.DataFrame([best_metrics])
metrics_df.to_csv(os.path.join(output_dir, 'metrics.csv'), index=False)