In [None]:
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import notebook_login, create_repo, upload_folder
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import psutil
import numpy as np
import math
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
import random
from collections import Counter
from datasets import concatenate_datasets
import nltk
from nltk.corpus import wordnet
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
from accelerate import Accelerator
import traceback

# Download required NLTK data
# 'wordnet' backs the wordnet.synsets() lookups used for synonym augmentation;
# 'omw-1.4' and 'punkt_tab' are presumably the companion WordNet data and the
# tokenizer tables needed by nltk.word_tokenize -- confirm against the NLTK
# version in use. Only the last call prints download progress (no quiet=True).
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
nltk.download('punkt_tab')

class Config:
    """Central, class-level settings for the model, adapters and training run.

    All values are plain class attributes read as ``Config.<name>``; several
    are captured as default arguments when the model classes are defined, so
    they must be edited before those classes are (re)defined to take effect.
    """
    # Model hyperparameters
    # Checkpoint name passed to AutoModel / AutoTokenizer.from_pretrained
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'
    # NOTE(review): the two hidden sizes below are only written to config.json
    # in the visible code; no visible layer is sized from them.
    lstm_hidden_size = 200
    context_hidden_size = 200
    # Documents are truncated/padded to this many sentences; also the size of
    # the positional-embedding table
    max_num_sentences = 32
    # Token length every sentence is padded/truncated to
    max_length = 64
    # Dropout applied to the sentence ([CLS]) embeddings
    dropout_rate = 0.4
    # Focal-loss exponent used by FocalCRF
    gamma = 2.0
    # Not read in the visible code; presumably consumed when the optimizer is built
    weight_decay = 1e-5

    # FSSA parameters - reduced ranks and increased sparsity
    fssa_linear_rank = 1   # Reduced from 4
    fssa_emb_rank = 1      # Reduced from 2
    fssa_linear_sparsity = 0.99  # Increased sparsity from 0.92
    fssa_emb_sparsity = 0.995
    # Side length of the square blocks in the FSSALayer sparsity mask
    fssa_block_size = 32

    # New parameters for size reduction
    # Feed-forward width of the transformer context layer
    context_intermediate_size = 380  # Reduced from 2000
    # Hidden width of the emission MLP
    emission_hidden_size = 64        # Reduced from 384

    # QLoRA parameters
    # Forwarded to LoraConfig (r, lora_alpha, lora_dropout, target_modules)
    qlora_r = 8
    qlora_alpha = 32
    qlora_dropout = 0.05
    qlora_target_modules = ["query", "key", "value", "dense"]
    # Compute dtype handed to BitsAndBytesConfig for the 4-bit BERT
    qlora_compute_dtype = torch.bfloat16

    # Training parameters
    epochs = 8
    batch_size = 4
    # Peak learning rate reached at the end of warmup
    learning_rate = 5e-5
    # Fraction of all training steps spent on linear warmup
    warmup_ratio = 0.1

    # Paths and repo info
    # Placeholder: replace the leading text with a real Hugging Face user id
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-improved-fssa-qlora"
    # Directory where train_model writes best_model.pt
    output_dir = "./improved_hierarchical_model_fssa_qlora"
    save_checkpoint = "best_model"

In [None]:
# Login to Hugging Face Hub
# Interactive step from huggingface_hub; presumably stores the access token
# that later Hub uploads rely on -- confirm it is needed for read-only runs.
notebook_login()

In [None]:
import pandas as pd
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login

# Parquet shard of each split, relative to the dataset repository root
splits = {
    'train': 'data/train-00000-of-00001-bb0092e0d8549337.parquet',
    'dev': 'data/dev-00000-of-00001-af55705c75623915.parquet',
    'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'
}

# Read the three shards, in train/dev/test order, into pandas DataFrames
train_df, dev_df, test_df = (
    pd.read_parquet("hf://datasets/opennyaiorg/InRhetoricalRoles/" + splits[split_name])
    for split_name in ("train", "dev", "test")
)

# Wrap each DataFrame as a Hugging Face Dataset
train_ds, dev_ds, test_ds = (
    Dataset.from_pandas(frame) for frame in (train_df, dev_df, test_df)
)

def get_synonyms(word):
    """Return WordNet synonyms of ``word`` for data augmentation.

    Every lemma of every synset of ``word`` is lower-cased and has its
    underscores replaced by spaces. Single-character candidates and the word
    itself are dropped.

    Args:
        word: The token to look up.

    Returns:
        A sorted list of distinct synonyms, or ``[word]`` when none are found.
    """
    # Candidates are lower-cased below, so compare against the lower-cased
    # input; otherwise "Court" would receive "court" as its own synonym.
    target = word.lower()
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonym = lemma.name().replace("_", " ").lower()
            if synonym != target and len(synonym) > 1:
                synonyms.add(synonym)
    # Sort so that random.choice() over the result is reproducible under a
    # fixed seed (set iteration order of strings varies between processes).
    return sorted(synonyms) if synonyms else [word]

def augment_sentence(sent):
    """Return a randomly perturbed copy of ``sent``.

    Blank sentences and sentences with fewer than two tokens are returned
    unchanged. Otherwise one technique is drawn (weights 0.4/0.2/0.3/0.1):
    1 = replace a token with a synonym, 2 = delete a token (needs >= 4
    tokens), 3 = swap two tokens (needs >= 3 tokens), 4 = insert a synonym in
    front of a token. The tokens are re-joined with single spaces.
    """
    if not sent.strip():
        return sent

    tokens = nltk.word_tokenize(sent)
    if len(tokens) < 2:
        return sent

    technique = random.choices([1, 2, 3, 4], weights=[0.4, 0.2, 0.3, 0.1])[0]
    n_tokens = len(tokens)

    if technique in (1, 4):
        # Replacement and insertion share the position draw and synonym lookup
        pos = random.randint(0, n_tokens - 1)
        candidates = get_synonyms(tokens[pos])
        if candidates:
            picked = random.choice(candidates)
            if technique == 1:
                tokens[pos] = picked
            else:
                tokens.insert(pos, picked)
    elif technique == 2:
        if n_tokens >= 4:
            tokens.pop(random.randint(0, n_tokens - 1))
    elif technique == 3:
        if n_tokens >= 3:
            first, second = random.sample(range(n_tokens), 2)
            tokens[first], tokens[second] = tokens[second], tokens[first]

    return " ".join(tokens)

def prepare_hierarchical_datasets(train_ds, dev_ds, test_ds):
    """Turn the raw annotated splits into sentence-sequence datasets.

    For every split, the text and first label of each annotation result are
    extracted into ``spans``/``labels`` (mirrored as ``text``/``label``), and
    documents without any span are dropped. The training split is then
    augmented (one perturbed copy per document, three when it contains a
    label seen fewer than 10 times) and re-sampled per label.

    Note on balancing: label counts are per sentence while the re-sampled
    units are whole documents, and a document is added once for every label
    it contains, so the result is an over-sampled document pool rather than an
    exactly balanced one.

    Args:
        train_ds, dev_ds, test_ds: Hugging Face datasets with an
            ``annotations`` column.

    Returns:
        ``(train_hier, dev_hier, test_hier, label2id, id2label, label_list)``.
    """
    # Local import: the module-level name ``Dataset`` is bound by both the
    # ``datasets`` and ``torch.utils.data`` imports, so which class it names
    # depends on import/cell order. Pin the Hugging Face class explicitly.
    from datasets import Dataset as HFDataset

    print("Preprocessing datasets with efficient sampling...")

    def get_spans_and_labels(example):
        """Collect the span text and first label of each annotation result."""
        spans, labels = [], []
        annotations = example.get('annotations')
        if annotations and annotations[0].get('result'):
            for ann in annotations[0]['result']:
                value = ann.get('value')
                if value and value.get('text') and value.get('labels'):
                    spans.append(value['text'])
                    labels.append(value['labels'][0])
        return {'spans': spans, 'labels': labels}

    def has_spans(example):
        return len(example['spans']) > 0

    def prepare_for_hierarchical(example):
        return {'text': example['spans'], 'label': example['labels']}

    # Extract spans, drop empty documents, then expose text/label columns
    train_ds = train_ds.map(get_spans_and_labels).filter(has_spans)
    dev_ds = dev_ds.map(get_spans_and_labels).filter(has_spans)
    test_ds = test_ds.map(get_spans_and_labels).filter(has_spans)

    train_hier = train_ds.map(prepare_for_hierarchical)
    dev_hier = dev_ds.map(prepare_for_hierarchical)
    test_hier = test_ds.map(prepare_for_hierarchical)

    # Build label mapping from the training labels only
    all_labels = set()
    for example in train_hier:
        all_labels.update(example['label'])
    label_list = sorted(all_labels)
    label2id = {l: i for i, l in enumerate(label_list)}
    id2label = {i: l for i, l in enumerate(label_list)}

    print(f"Identified {len(label_list)} labels: {label_list}")

    def augment_dataset(dataset, label2id):
        """Append perturbed copies of every document (more for rare labels)."""
        label_counts = Counter()
        for example in dataset:
            label_counts.update(example['label'])

        rare_classes = [label for label, count in label_counts.items() if count < 10]
        print(f"Rare classes (<10 samples): {rare_classes}")
        rare_lookup = set(rare_classes)  # O(1) membership in the loop below

        augmented_examples = []
        for example in dataset:
            labels = example['label']
            # Moderate augmentation for documents containing a rare class
            copies = 3 if any(label in rare_lookup for label in labels) else 1

            for _ in range(copies):
                augmented_text = [
                    augment_sentence(sent) if random.random() < 0.5 and sent.strip() else sent
                    for sent in example['text']
                ]
                augmented_examples.append({
                    'text': augmented_text,
                    'label': labels.copy()
                })

        print(f"Added {len(augmented_examples)} augmented examples")
        # NOTE(review): ``dataset`` still carries the raw columns while the
        # augmented rows only have text/label -- confirm the installed
        # ``datasets`` version aligns the missing columns on concatenation.
        return concatenate_datasets([dataset, HFDataset.from_list(augmented_examples)])

    # Apply augmentation to training set
    train_hier = augment_dataset(train_hier, label2id)

    # Materialise the rows once: iterating the dataset again for every label
    # would decode the whole Arrow table len(label_list) times.
    train_examples = list(train_hier)

    label_counts = Counter()
    for example in train_examples:
        label_counts.update(example['label'])

    # Calculate target counts - cap at 1000 samples per class
    max_count = min(1000, max(label_counts.values()))
    balanced_examples = []
    for label in label_list:
        # Documents containing this label are always kept ...
        class_examples = [ex for ex in train_examples if label in ex['label']]
        balanced_examples.extend(class_examples)

        # ... and topped up with random duplicates when the label is short
        needed = max(0, max_count - label_counts[label])
        if needed > 0 and class_examples:
            duplicates = min(needed, len(class_examples))
            balanced_examples.extend(random.choices(class_examples, k=duplicates))

    # Create balanced dataset
    train_hier = HFDataset.from_list(balanced_examples)
    print(f"Created balanced dataset with {len(train_hier)} examples")

    return train_hier, dev_hier, test_hier, label2id, id2label, label_list

class PositionalEncoding(nn.Module):
    """Adds a learned embedding of each sentence's index to its vector."""
    def __init__(self, d_model, max_len=Config.max_num_sentences):
        super().__init__()
        # One trainable d_model-sized vector per sentence position
        self.position_emb = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model)

    def forward(self, x):
        # Indices 0..x.size(1)-1 as a single row, broadcast over the batch
        num_positions = x.shape[1]
        index_row = torch.arange(num_positions, device=x.device)[None, :]
        offsets = self.position_emb(index_row)
        return x + offsets

class TransformerContextLayer(nn.Module):
    """Single-layer, batch-first transformer encoder over sentence vectors."""
    def __init__(self, d_model, nhead=4, dim_feedforward=Config.context_intermediate_size, dropout=0.1):
        super().__init__()
        layer_options = {
            "d_model": d_model,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
            "batch_first": True,
        }
        # The layer is kept as an attribute and also handed to the encoder.
        # NOTE(review): nn.TransformerEncoder appears to clone the layer it
        # receives, which would make this attribute a separate, unused copy
        # -- confirm against the PyTorch docs.
        self.encoder_layer = nn.TransformerEncoderLayer(**layer_options)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

    def forward(self, x):
        contextualised = self.transformer_encoder(x)
        return contextualised

class EmissionLayer(nn.Module):
    """Two-layer MLP that maps context vectors to per-label scores."""
    def __init__(self, input_size, num_labels, dropout=0.2):
        super().__init__()
        hidden_size = Config.emission_hidden_size
        # Linear -> GELU -> Dropout -> Linear; positions 0 and 3 are the
        # linear layers
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        scores = self.mlp(x)
        return scores

class FocalCRF(nn.Module):
    """CRF whose per-sequence loss is re-weighted with a focal term.

    Args:
        num_tags: Number of labels handled by the CRF.
        gamma: Focal exponent; larger values down-weight sequences the CRF
            already scores as likely.
    """
    def __init__(self, num_tags, gamma=Config.gamma):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal CRF loss over the batch.

        Args:
            emissions: Per-position label scores passed to the CRF.
            tags: Gold label ids, indexed as ``(batch, seq_len)``.
            mask: ``(batch, seq_len)`` mask of valid positions.
            class_weights: Optional 1-D tensor of per-label weights. Each
                sequence's loss is scaled by the mean weight of its valid
                positions.

        Returns:
            A scalar loss tensor.
        """
        # Per-sequence log-likelihood (reduction='none' keeps one per sequence)
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')

        # Focal transformation: pt is the sequence probability
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            # The weights may have been created on another device than the
            # batch; indexing across devices would fail.
            class_weights = class_weights.to(emissions.device)
            valid = mask.to(class_weights.dtype)
            # Zero out padded positions: their tag ids are padding labels and
            # must not contribute to the sequence's average weight.
            weights_per_tag = class_weights[tags] * valid  # (batch_size, seq_len)
            # clamp avoids 0/0 for a sequence without any valid position
            valid_counts = valid.sum(dim=1).clamp(min=1)  # (batch_size,)
            weights_per_sequence = weights_per_tag.sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Delegate to the wrapped CRF's ``decode`` for the given emissions and mask."""
        return self.crf.decode(emissions, mask=mask)

# ====================================================
# FSSA: Factorized Structured Sparse Adaptation Layers
# ====================================================
class FSSALayer(nn.Module):
    """Frozen linear layer plus a low-rank, block-sparse trainable update.

    The output is ``original_layer(x) + F.linear(x, (B @ A) * mask)`` where
    ``A`` is ``(rank, in_features)``, ``B`` is ``(out_features, rank)`` and
    ``mask`` is a fixed 0/1 matrix made of ``block_size`` x ``block_size``
    blocks. Only ``A`` and ``B`` are trainable.

    Args:
        original_layer: Layer to wrap; must expose ``in_features`` and
            ``out_features``. Its parameters are frozen in place.
        rank: Rank of the factorised update.
        sparsity: Fraction of mask blocks left inactive (zero).
        block_size: Side length of a mask block.
    """
    def __init__(self, original_layer, rank=Config.fssa_linear_rank,
                 sparsity=Config.fssa_linear_sparsity, block_size=Config.fssa_block_size):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.sparsity = sparsity
        self.block_size = block_size

        # Freeze original parameters
        for param in self.original_layer.parameters():
            param.requires_grad = False

        in_features = original_layer.in_features
        out_features = original_layer.out_features

        # Factorized sparse adaptation parameters
        self.A = nn.Parameter(torch.zeros(rank, in_features))
        self.B = nn.Parameter(torch.zeros(out_features, rank))

        # The mask is random, so it has to travel with the weights: as a
        # buffer it is stored in state_dict (a reloaded model reuses the same
        # pattern A and B were trained under) and follows .to(device).
        self.register_buffer("mask", self.create_sparsity_mask(out_features, in_features))

        # Initialize parameters
        self.init_parameters()

    def create_sparsity_mask(self, rows, cols):
        """Build a ``(rows, cols)`` 0/1 mask with randomly chosen active blocks."""
        row_blocks = (rows + self.block_size - 1) // self.block_size
        col_blocks = (cols + self.block_size - 1) // self.block_size

        num_blocks = row_blocks * col_blocks
        num_active = int(num_blocks * (1 - self.sparsity))
        # For small layers the product rounds down to 0, which would zero the
        # whole update and leave the (frozen) layer untrainable. Keep at least
        # one block unless full sparsity was requested explicitly.
        if self.sparsity < 1 and num_blocks > 0:
            num_active = max(1, num_active)
        # Keep the sample size valid for out-of-range sparsity values
        num_active = max(0, min(num_active, num_blocks))
        active_blocks = random.sample(range(num_blocks), num_active)

        mask = torch.zeros(rows, cols)
        for block_idx in active_blocks:
            i, j = divmod(block_idx, col_blocks)
            row_start = i * self.block_size
            col_start = j * self.block_size
            # Edge blocks are clipped to the matrix bounds
            row_end = min(row_start + self.block_size, rows)
            col_end = min(col_start + self.block_size, cols)
            mask[row_start:row_end, col_start:col_end] = 1

        return mask

    def init_parameters(self):
        """Kaiming-uniform init for both factors (the update starts non-zero)."""
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.B, a=math.sqrt(5))

    def forward(self, x):
        # Original (frozen) layer output
        base_output = self.original_layer(x)

        # (out_features, in_features) update, zeroed outside the active blocks
        adapted = (self.B @ self.A) * self.mask.to(self.B.device)

        # F.linear expects the weight as (out_features, in_features), which is
        # already the layout of ``adapted`` -- no transpose needed.
        delta_output = F.linear(x, adapted)

        return base_output + delta_output

class FSSAEmbedding(nn.Module):
    """Frozen embedding table plus a sparse, low-rank trainable update.

    The looked-up vector is ``original_embedding(ids) + ((U * mask) @ V)[ids]``
    with ``U`` of shape ``(num_embeddings, rank)``, ``V`` of shape
    ``(rank, embedding_dim)`` and ``mask`` a fixed random 0/1 tensor shaped
    like ``U``. Only ``U`` and ``V`` are trainable.

    Args:
        original_embedding: Embedding to wrap; its parameters are frozen.
        rank: Rank of the factorised update.
        sparsity: Probability that an entry of ``U`` is masked out.
    """
    def __init__(self, original_embedding, rank=Config.fssa_emb_rank,
                 sparsity=Config.fssa_emb_sparsity):
        super().__init__()
        self.original_embedding = original_embedding
        self.rank = rank
        self.sparsity = sparsity

        # Freeze original parameters
        for param in self.original_embedding.parameters():
            param.requires_grad = False

        num_embeddings = original_embedding.num_embeddings
        embedding_dim = original_embedding.embedding_dim

        # Factorized adaptation parameters
        self.U = nn.Parameter(torch.zeros(num_embeddings, rank))
        self.V = nn.Parameter(torch.zeros(rank, embedding_dim))

        mask = (torch.rand(num_embeddings, rank) > sparsity).float()
        # With few rows and high sparsity the draw is usually all zeros, which
        # would disable the update and leave the frozen table untrainable.
        # Activate one random entry unless full sparsity was requested.
        if sparsity < 1 and mask.numel() > 0 and not mask.any():
            forced = int(torch.randint(mask.numel(), (1,)).item())
            mask.view(-1)[forced] = 1.0
        # Buffer: saved with the weights and moved together with the module,
        # so a reloaded model keeps the pattern U and V were trained under.
        self.register_buffer("mask", mask)

        # Initialize parameters
        self.init_parameters()

    def init_parameters(self):
        """Draw both factors from N(0, 0.02)."""
        nn.init.normal_(self.U, mean=0, std=0.02)
        nn.init.normal_(self.V, mean=0, std=0.02)

    def forward(self, input_ids):
        base_embeds = self.original_embedding(input_ids)

        # Full (num_embeddings, embedding_dim) update table, then a lookup
        adapted = (self.U * self.mask.to(self.U.device)) @ self.V
        delta_embeds = F.embedding(input_ids, adapted)

        return base_embeds + delta_embeds

def apply_fssa_to_hierarchical(model):
    """Wrap the non-BERT layers of ``model`` with FSSA adapters, in place.

    Replaces the positional embedding with an ``FSSAEmbedding``, every
    ``nn.Linear`` found among the children and grandchildren of
    ``model.context_encoder`` with an ``FSSALayer``, and every ``nn.Linear``
    entry of ``model.emission.mlp`` likewise. Returns the same ``model``.

    NOTE(review): the context-encoder walk is only two levels deep. Linears
    nested further down (e.g. inside a layer list of the transformer encoder)
    are not reached -- verify which layers actually end up wrapped.
    """
    # Positional embeddings
    position_enc = model.position_enc
    position_enc.position_emb = FSSAEmbedding(position_enc.position_emb)

    # Context encoder: direct children first, otherwise their children
    encoder = model.context_encoder
    for child_name, child in encoder.named_children():
        if isinstance(child, nn.Linear):
            setattr(encoder, child_name, FSSALayer(child))
            continue
        for grandchild_name, grandchild in child.named_children():
            if isinstance(grandchild, nn.Linear):
                setattr(child, grandchild_name, FSSALayer(grandchild))

    # Emission MLP: swap linear stages by index
    mlp = model.emission.mlp
    for index, stage in enumerate(mlp):
        if isinstance(stage, nn.Linear):
            mlp[index] = FSSALayer(stage)

    return model

def count_parameters(model):
    """Return the number of scalar parameters that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def print_parameter_counts(model):
    """Print total and trainable parameter counts (in millions).

    Returns the trainable parameter count.
    """
    trainable_params = count_parameters(model)
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    rule = "="*50
    print(rule)
    print(f"TOTAL PARAMETERS: {total_params/1e6:.2f}M")
    print(f"TRAINABLE PARAMETERS: {trainable_params/1e6:.2f}M")
    print(rule)

    return trainable_params

class ImprovedHSLNModel(nn.Module):
    """Hybrid QLoRA (BERT) + FSSA (hierarchical) sentence-sequence tagger.

    Each sentence is encoded by a 4-bit BERT with LoRA adapters; the [CLS]
    vectors are normalised, given positional embeddings, passed through a
    transformer context layer and an emission MLP, and scored by a focal CRF.

    Args:
        num_labels: Number of sentence labels.
        class_weights: Optional 1-D tensor of per-label loss weights.
    """
    def __init__(self, num_labels, class_weights=None):
        super().__init__()
        # Non-persistent buffer: follows model.to(device) like the parameters
        # (a plain attribute would stay on its original device) without
        # adding a key to state_dict.
        if class_weights is not None and not isinstance(class_weights, torch.Tensor):
            class_weights = torch.as_tensor(class_weights)
        self.register_buffer("class_weights", class_weights, persistent=False)

        # Configure 4-bit quantization for BERT
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=Config.qlora_compute_dtype
        )

        # Load BERT with QLoRA
        self.bert = AutoModel.from_pretrained(
            Config.bert_model_name,
            quantization_config=bnb_config
        )

        # Prepare model for k-bit training
        self.bert = prepare_model_for_kbit_training(self.bert)

        # Apply QLoRA adapters
        lora_config = LoraConfig(
            r=Config.qlora_r,
            lora_alpha=Config.qlora_alpha,
            target_modules=Config.qlora_target_modules,
            lora_dropout=Config.qlora_dropout,
            bias="none",
            task_type="FEATURE_EXTRACTION"
        )
        self.bert = get_peft_model(self.bert, lora_config)

        # Hierarchical components
        hidden_size = self.bert.config.hidden_size
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden_size)
        self.position_enc = PositionalEncoding(hidden_size)
        self.context_encoder = TransformerContextLayer(d_model=hidden_size)
        self.emission = EmissionLayer(input_size=hidden_size, num_labels=num_labels)

        # Apply FSSA to hierarchical components (skip BERT). The helper
        # modifies the model in place, so its return value is not needed.
        apply_fssa_to_hierarchical(self)

        # CRF layer with focal loss (created after FSSA, so never wrapped)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

        # Print entire model parameter counts
        print_parameter_counts(self)

    def forward(self, input_ids, attention_mask, labels=None):
        """Score every sentence of every document.

        Args:
            input_ids: Token ids of shape ``(batch, num_sent, seq_len)``.
            attention_mask: Token mask with the same shape.
            labels: Optional ``(batch, num_sent)`` gold label ids.

        Returns:
            ``{"emissions": ...}``, plus ``"loss"`` when ``labels`` is given.
        """
        batch_size, num_sent, seq_len = input_ids.shape

        # Fold sentences into the batch dimension so BERT sees 2-D input
        flat_input_ids = input_ids.view(-1, seq_len)
        flat_mask = attention_mask.view(-1, seq_len)

        bert_out = self.bert(
            input_ids=flat_input_ids,
            attention_mask=flat_mask
        ).last_hidden_state

        # Sentence embeddings (CLS token)
        sent_emb = bert_out[:, 0, :]
        sent_emb = self.sent_layer_norm(sent_emb)
        sent_emb = self.sent_dropout(sent_emb)
        sent_emb = sent_emb.view(batch_size, num_sent, -1)

        # Context modeling
        sent_emb = self.position_enc(sent_emb)
        context_emb = self.context_encoder(sent_emb)

        # Emissions
        emissions = self.emission(context_emb)
        # Sentence-level mask, read from the first token's attention flag.
        # NOTE(review): this only masks padding sentences if tokenization
        # leaves that flag at 0 for them -- a tokenizer that emits [CLS] for
        # an empty string sets it to 1; verify against the tokenization step.
        mask = attention_mask[:, :, 0] > 0

        if labels is not None:
            loss = self.crf(
                emissions,
                labels,
                mask=mask,
                class_weights=self.class_weights
            )
            return {"loss": loss, "emissions": emissions}
        return {"emissions": emissions}

def tokenize_datasets(train_hier, dev_hier, test_hier, label2id):
    """Tokenize every document into fixed-size sentence/token grids.

    Each document is truncated or padded to ``Config.max_num_sentences``
    sentences, and each sentence to ``Config.max_length`` tokens. Padding
    sentences are empty strings labelled with the first key of ``label2id``.

    Args:
        train_hier, dev_hier, test_hier: Datasets with ``text`` (list of
            sentences) and ``label`` (list of label names) columns.
        label2id: Mapping from label name to integer id.

    Returns:
        ``(train_tokenized, dev_tokenized, test_tokenized, tokenizer)``; the
        datasets gain ``input_ids``, ``attention_mask`` and ``labels``.

    Raises:
        KeyError: If a document carries a label missing from ``label2id``.
    """
    print("Tokenizing datasets...")
    tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

    def tokenize_document(example):
        sentences = list(example['text'][:Config.max_num_sentences])
        labels = list(example['label'][:Config.max_num_sentences])
        num_real = len(sentences)
        pad_len = Config.max_num_sentences - num_real
        sentences += [""] * pad_len
        labels += [next(iter(label2id))] * pad_len if pad_len else []

        # One batched call yields (num_sentences, max_length) tensors directly
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Downstream code derives the sentence-level mask from
        # attention_mask[..., 0]. The tokenizer still emits special tokens for
        # an empty string, so that flag would be 1 for padding sentences and
        # they would be trained on and scored as real ones. Clear only that
        # single flag: the remaining special token stays attendable, so the
        # encoder never receives a fully masked row. The first sentence is
        # always kept valid because the CRF requires the first step to be on.
        attention_mask[max(num_real, 1):, 0] = 0

        label_ids = torch.tensor([label2id[l] for l in labels])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": label_ids
        }

    train_tokenized = train_hier.map(tokenize_document)
    dev_tokenized = dev_hier.map(tokenize_document)
    test_tokenized = test_hier.map(tokenize_document)

    return train_tokenized, dev_tokenized, test_tokenized, tokenizer

class HierarchicalDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset exposing only the model inputs of each row.

    The base class is spelled out in full because the bare name ``Dataset``
    is imported from both ``datasets`` and ``torch.utils.data`` in this file,
    so what it refers to depends on which import ran last.

    Args:
        dataset: Any indexable collection whose items provide the keys
            ``input_ids``, ``attention_mask`` and ``labels``.
    """
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Drop every other column so collation only sees the model inputs
        item = self.dataset[idx]
        return {
            "input_ids": item["input_ids"],
            "attention_mask": item["attention_mask"],
            "labels": item["labels"]
        }

def collate_fn(batch):
    """Stack the per-document fields of ``batch`` into batched tensors.

    Values that are not already tensors (e.g. nested lists) are converted
    with ``torch.tensor`` first.
    """
    def as_tensor(value):
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value)

    stacked = {}
    for field in ("input_ids", "attention_mask", "labels"):
        stacked[field] = torch.stack([as_tensor(sample[field]) for sample in batch])
    return stacked

def create_data_loaders(train_tokenized, dev_tokenized, test_tokenized):
    """Build the train/dev/test loaders; only the training loader shuffles."""
    def build(split, shuffle):
        return DataLoader(
            HierarchicalDataset(split),
            batch_size=Config.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn
        )

    train_loader = build(train_tokenized, True)
    dev_loader = build(dev_tokenized, False)
    test_loader = build(test_tokenized, False)
    return train_loader, dev_loader, test_loader

def compute_class_weights(train_hier, label2id):
    """Return inverse-frequency class weights, scaled so the smallest is 1.

    Weights follow the iteration order of ``label2id``. A label that occurs
    in the data but not in ``label2id`` raises ``KeyError``.
    """
    occurrences = dict.fromkeys(label2id, 0)
    for example in train_hier:
        for label in example['label']:
            occurrences[label] += 1

    total_samples = sum(occurrences.values())
    # Linear (not squared) inverse frequency; 1e-5 avoids division by zero
    inverse_freq = []
    for label in label2id:
        inverse_freq.append(total_samples / (occurrences[label] + 1e-5))

    weights = torch.tensor(inverse_freq, dtype=torch.float32)
    smallest = weights.min()
    return weights / smallest

def train_model(model, train_loader, dev_loader, optimizer, device, epochs, label_list):
    """Train ``model`` with linear LR warmup and keep the best dev checkpoint.

    After every epoch the model is evaluated on ``dev_loader``; whenever the
    weighted dev F1 improves, the state dict (minus quantization bookkeeping
    entries) is written to ``Config.output_dir/best_model.pt``. At the end the
    best checkpoint is loaded back on a best-effort basis.

    Args:
        model: Model returning ``{"loss", "emissions"}`` and exposing
            ``model.crf.decode``.
        train_loader, dev_loader: Loaders yielding dicts with ``input_ids``,
            ``attention_mask`` and ``labels``.
        optimizer: Optimizer over the trainable parameters.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        label_list: Label names, forwarded to ``evaluate_metrics``.

    Returns:
        ``(model, history)`` where ``history`` has one dict per epoch with
        ``epoch``, ``train_loss``, ``train_f1`` and ``dev_f1``.
    """
    print(f"\n{'='*30} TRAINING STARTED {'='*30}")
    print(f"Training on: {device}")
    print(f"Number of epochs: {epochs}")
    print(f"Batch size: {Config.batch_size}")
    print(f"Learning rate: {Config.learning_rate}")
    print(f"Total batches: {len(train_loader)}")

    # Substrings of state-dict keys that hold quantization state
    quant_markers = ('quant_state', 'absmax', 'quant_map', 'nested_absmax', 'nested_quant_map')

    def without_quant_state(state):
        return {
            name: param for name, param in state.items()
            if not any(q in name for q in quant_markers)
        }

    # torch.save does not create missing directories
    os.makedirs(Config.output_dir, exist_ok=True)
    best_model_path = os.path.join(Config.output_dir, "best_model.pt")

    best_dev_f1 = 0
    training_start = time.time()
    history = []

    # Warmup scheduling
    steps_per_epoch = len(train_loader)
    total_steps = epochs * steps_per_epoch
    warmup_steps = int(total_steps * Config.warmup_ratio)

    for epoch in range(epochs):
        epoch_start = time.time()
        total_loss = 0
        all_preds, all_labels = [], []

        # Training (evaluation below switches the model to eval mode)
        model.train()
        for batch_idx, batch in enumerate(train_loader):
            current_step = epoch * steps_per_epoch + batch_idx

            # Learning rate warmup: ramp linearly up to Config.learning_rate
            if current_step < warmup_steps:
                lr_scale = min(1.0, float(current_step + 1) / warmup_steps)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = Config.learning_rate * lr_scale

            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs["loss"]
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            total_loss += loss.item()

            # Decoding is for monitoring only; keep it out of autograd
            sent_mask = attention_mask[:, :, 0] > 0
            with torch.no_grad():
                preds = model.crf.decode(outputs["emissions"].detach(), mask=sent_mask)

            # Pair predictions with labels per sequence. Decoded sequences
            # only cover masked-in positions, so padding the flattened list at
            # the end would shift every later sequence against its labels.
            for seq_preds, seq_mask, seq_labels in zip(preds, sent_mask, labels):
                all_preds.extend(seq_preds)
                all_labels.extend(seq_labels[seq_mask].cpu().tolist())

            if (batch_idx + 1) % 10 == 0:
                avg_loss = total_loss/(batch_idx+1)
                print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} | "
                      f"Loss: {loss.item():.4f} | Avg Loss: {avg_loss:.4f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        epoch_time = time.time() - epoch_start
        train_f1 = f1_score(all_labels, all_preds, average="weighted")

        # Validation
        model.eval()
        dev_metrics = evaluate_metrics(model, dev_loader, device, label_list)
        dev_f1 = dev_metrics["weighted_f1"]

        history.append({
            'epoch': epoch+1,
            'train_loss': total_loss/len(train_loader),
            'train_f1': train_f1,
            'dev_f1': dev_f1
        })

        print(f"\nEpoch {epoch+1} completed in {epoch_time:.2f}s")
        print(f"Train Loss: {total_loss/len(train_loader):.4f} | F1: {train_f1:.4f}")
        print(f"Dev Weighted F1: {dev_f1:.4f} | Macro F1: {dev_metrics['macro_f1']:.4f}")

        # Save best model without quantization parameters
        if dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            torch.save(without_quant_state(model.state_dict()), best_model_path)
            print(f"New best model saved with F1: {dev_f1:.4f}")

    training_time = time.time() - training_start
    print(f"Training completed in {training_time:.2f} seconds")
    print(f"{'='*30} TRAINING COMPLETED {'='*30}\n")

    # Load best model with safe loading. Deliberately best-effort: a failed
    # reload falls back to the final weights instead of aborting the run.
    try:
        if os.path.exists(best_model_path):
            checkpoint = torch.load(best_model_path, map_location=device)

            # Filter again in case the file was written by other code;
            # strict=False tolerates the keys removed by the filter.
            model.load_state_dict(without_quant_state(checkpoint), strict=False)
            print("Loaded best model checkpoint")
        else:
            print("Best model checkpoint not found. Using final model state.")
    except Exception as e:
        print(f"Error loading best model: {str(e)}")
        print("Using final model state for evaluation")

    return model, history

def evaluate_metrics(model, dataloader, device, label_list):
    """Evaluate ``model`` on ``dataloader`` and return quality/latency metrics.

    Predictions and labels are compared only over the first ``k`` positions
    of each document, ``k`` being the number of masked-in sentences. Returns
    a dict with macro/weighted F1, accuracy, per-label F1, forward-pass
    latency per document and per sentence (ms), total wall time and the
    number of documents. Any failure is printed and re-raised.
    """
    try:
        model.eval()
        predictions, references = [], []
        forward_seconds = 0
        doc_count = 0
        sentence_count = 0
        wall_start = time.time()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                mask = attention_mask[:, :, 0] > 0

                # Only the forward pass counts towards latency
                tic = time.time()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                toc = time.time()

                preds = model.crf.decode(outputs["emissions"], mask=mask)

                for row, row_labels in enumerate(labels):
                    seq_len = mask[row].sum().item()
                    predictions.extend(preds[row][:seq_len])
                    references.extend(row_labels[:seq_len].cpu().numpy())

                forward_seconds += (toc - tic)
                doc_count += input_ids.shape[0]
                sentence_count += mask.sum().item()

        eval_time = time.time() - wall_start

        report = classification_report(
            references, predictions,
            labels=list(range(len(label_list))),
            target_names=label_list,
            output_dict=True,
            zero_division=0
        )

        per_label_f1 = {}
        for label in label_list:
            per_label_f1[label] = report[label]['f1-score']

        # Guard the averages against an empty dataloader
        if doc_count:
            latency_doc = (forward_seconds / doc_count) * 1000
        else:
            latency_doc = 0
        if sentence_count:
            latency_sent = (forward_seconds / sentence_count) * 1000
        else:
            latency_sent = 0

        return {
            "macro_f1": report['macro avg']['f1-score'],
            "weighted_f1": report['weighted avg']['f1-score'],
            "accuracy": accuracy_score(references, predictions),
            "per_label_f1": per_label_f1,
            "latency_ms_per_doc": latency_doc,
            "latency_ms_per_sentence": latency_sent,
            "eval_time_seconds": eval_time,
            "num_samples": doc_count
        }

    except Exception as e:
        print(f"Evaluation failed: {str(e)}")
        raise

def get_model_size_mb(model):
    """Return the size of all parameters and buffers of ``model`` in MiB."""
    def total_bytes(tensors):
        # element count times bytes per element, summed over the tensors
        return sum(t.nelement() * t.element_size() for t in tensors)

    combined = total_bytes(model.parameters()) + total_bytes(model.buffers())
    return combined / (1024 ** 2)

def get_memory_footprint():
    """Return the resident set size of the current process in GiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 3)

def save_checkpoint(model, tokenizer, metrics, label2id, config, save_dir, checkpoint_name):
    """Write a checkpoint directory and return its path.

    The directory ``save_dir/checkpoint_name`` (created if missing) receives:

    * ``pytorch_model.bin`` - ``model.state_dict()`` without the entries whose
      name contains one of the quantization-state markers listed below,
    * whatever ``tokenizer.save_pretrained`` writes,
    * ``config.json`` - the label mappings plus selected ``config`` attributes,
    * ``metrics.json`` - ``metrics`` serialized as JSON.

    Any exception is reported on stdout and then re-raised.
    """
    # Substrings identifying quantization bookkeeping entries to leave out.
    quant_markers = ('quant_state', 'absmax', 'quant_map', 'nested_absmax', 'nested_quant_map')

    try:
        checkpoint_path = os.path.join(save_dir, checkpoint_name)
        os.makedirs(checkpoint_path, exist_ok=True)

        # Keep only the entries whose name matches none of the markers.
        weights = {
            key: tensor
            for key, tensor in model.state_dict().items()
            if all(marker not in key for marker in quant_markers)
        }
        torch.save(weights, os.path.join(checkpoint_path, "pytorch_model.bin"))
        tokenizer.save_pretrained(checkpoint_path)

        with open(os.path.join(checkpoint_path, "config.json"), "w") as config_file:
            payload = {
                "label2id": label2id,
                "id2label": {idx: name for name, idx in label2id.items()},
            }
            model_config = {
                "bert_model_name": config.bert_model_name,
                "lstm_hidden_size": config.lstm_hidden_size,
                "context_hidden_size": config.context_hidden_size,
                "max_num_sentences": config.max_num_sentences,
                "max_length": config.max_length,
                "dropout_rate": config.dropout_rate,
                "gamma": config.gamma,
            }
            model_config["fssa_params"] = {
                "linear_rank": config.fssa_linear_rank,
                "emb_rank": config.fssa_emb_rank,
                "linear_sparsity": config.fssa_linear_sparsity,
                "emb_sparsity": config.fssa_emb_sparsity,
                "block_size": config.fssa_block_size,
            }
            model_config["qlora_params"] = {
                "r": config.qlora_r,
                "alpha": config.qlora_alpha,
                "dropout": config.qlora_dropout,
            }
            payload["model_config"] = model_config
            json.dump(payload, config_file, indent=2)

        with open(os.path.join(checkpoint_path, "metrics.json"), "w") as metrics_file:
            json.dump(metrics, metrics_file, indent=2)

        print(f"Checkpoint saved to {checkpoint_path}")
        return checkpoint_path

    except Exception as e:
        print(f"Error saving checkpoint: {str(e)}")
        raise

def upload_to_huggingface(save_path, repo_id):
    """Upload the folder at ``save_path`` to the Hub model repo ``repo_id``.

    Calls ``create_repo`` with ``exist_ok=True`` and then ``upload_folder``
    with a fixed commit message. This is best-effort: any exception is
    printed rather than raised, so the function returns ``None`` either way.
    """
    upload_options = {
        "repo_id": repo_id,
        "folder_path": save_path,
        "commit_message": "Hybrid QLoRA + FSSA Hierarchical Legal Model",
        "repo_type": "model",
        "token": True,
    }
    try:
        create_repo(repo_id, exist_ok=True, token=True)
        upload_folder(**upload_options)
        print(f"Model uploaded to https://huggingface.co/{repo_id}")
    except Exception as e:
        print(f"Upload failed: {str(e)}")

def main():
    """Run the end-to-end training pipeline and return its metrics.

    Steps, in order: build the hierarchical datasets from the module-level
    ``train_ds`` / ``dev_ds`` / ``test_ds``, compute class weights, tokenize,
    create the data loaders, build ``ImprovedHSLNModel`` with an AdamW
    optimizer, enable gradient checkpointing on ``model.bert``, pass the
    model, optimizer and train/dev loaders through ``Accelerator.prepare``,
    train, evaluate on the train and dev loaders, save a checkpoint, upload
    it, and print a summary.

    Returns:
        The metrics dict on success, or ``None`` if any step raised. On
        failure the error and traceback are printed and, best-effort, also
        written to ``error_log.txt`` under ``Config.output_dir``.
    """
    try:
        start_time = time.time()
        accelerator = Accelerator()
        device = accelerator.device
        print(f"\n{'='*50}")
        print(f"STARTING HYBRID QLORA + FSSA HIERARCHICAL LEGAL MODEL TRAINING")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Accelerator: {accelerator.state}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # train_ds / dev_ds / test_ds are not defined in this function; they
        # must already exist at module level when main() is called.
        train_hier, dev_hier, test_hier, label2id, id2label, label_list = prepare_hierarchical_datasets(
            train_ds, dev_ds, test_ds
        )

        class_weights = compute_class_weights(train_hier, label2id).to(device)
        print(f"Class weights: {class_weights.cpu().numpy()}")

        train_tokenized, dev_tokenized, test_tokenized, tokenizer = tokenize_datasets(
            train_hier, dev_hier, test_hier, label2id
        )
        # test_loader is created here but not used anywhere below.
        train_loader, dev_loader, test_loader = create_data_loaders(
            train_tokenized, dev_tokenized, test_tokenized
        )

        model = ImprovedHSLNModel(
            num_labels=len(label2id),
            class_weights=class_weights
        ).to(device)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=Config.learning_rate,
            weight_decay=Config.weight_decay
        )

        # Enable gradient checkpointing for memory efficiency
        model.bert.gradient_checkpointing_enable()
        model.bert.enable_input_require_grads()

        # Prepare with accelerator
        model, optimizer, train_loader, dev_loader = accelerator.prepare(
            model, optimizer, train_loader, dev_loader
        )

        model, history = train_model(
            model, train_loader, dev_loader,
            optimizer, device, Config.epochs, label_list
        )

        print("\nEvaluating on training set...")
        train_metrics = evaluate_metrics(model, train_loader, device, label_list)

        print("\nEvaluating on dev set...")
        dev_metrics = evaluate_metrics(model, dev_loader, device, label_list)

        metrics = {
            "train": train_metrics,
            "dev": dev_metrics,
            # Positive gap = better weighted F1 on train than on dev.
            "overfitting_gap": train_metrics["weighted_f1"] - dev_metrics["weighted_f1"],
            "model_size_mb": get_model_size_mb(model),
            "training_memory_footprint_gb": get_memory_footprint(),
            "label2id": label2id,
            "id2label": id2label,
            # Measured from the top of main(), so it covers data preparation
            # and evaluation as well as training.
            "training_time": time.time() - start_time,
            "training_history": history
        }

        # NOTE(review): `model` is whatever train_model returned after
        # accelerator.prepare; if that is a wrapped module its state_dict keys
        # may carry a wrapper prefix - confirm whether
        # accelerator.unwrap_model(model) is needed before saving.
        checkpoint_path = save_checkpoint(
            model, tokenizer, metrics, label2id, Config,
            Config.output_dir, Config.save_checkpoint
        )
        upload_to_huggingface(checkpoint_path, Config.hf_repo_id)

        print("\n==== FINAL METRICS ====")
        print(f"Train Weighted F1: {train_metrics['weighted_f1']:.4f}")
        print(f"Dev Weighted F1:   {dev_metrics['weighted_f1']:.4f}")
        print(f"Overfitting Gap:   {metrics['overfitting_gap']:.4f}")
        print(f"Model Size:        {metrics['model_size_mb']:.2f} MB")
        print(f"Training Time:     {metrics['training_time']:.2f} seconds")
        print(f"Saved to:          {checkpoint_path}")

        print("\nPer-class F1 Scores (Dev Set):")
        for label, score in dev_metrics["per_label_f1"].items():
            print(f"{label}: {score:.4f}")

        print(f"\n{'='*50}")
        print("TRAINING PIPELINE COMPLETED SUCCESSFULLY")
        print(f"{'='*50}")

        return metrics

    except Exception as e:
        print(f"\n{'!'*50}")
        print("PIPELINE FAILED!")
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        # Capture the traceback text while the pipeline error is still the
        # exception being handled.
        traceback_text = traceback.format_exc()
        # Writing the log is best-effort: the failure may have happened before
        # the output directory was created, and a problem while logging must
        # not mask the original error or break the "return None" contract.
        try:
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "error_log.txt"), "w") as f:
                f.write(f"Pipeline error at {datetime.now()}\n")
                f.write(str(e))
                f.write("\n\nTraceback:\n")
                f.write(traceback_text)
        except Exception as log_error:
            print(f"Could not write error log: {log_error}")
        return None

# Run the pipeline only when this file is executed as a script, not when it is
# imported; the value returned by main() is discarded here.
if __name__ == "__main__":
    main()