In [None]:
from huggingface_hub import notebook_login, create_repo, upload_folder

In [None]:
# Log in to the Hugging Face Hub from this notebook session.
notebook_login()

# LegalBERTHSLN - LoRA


In [None]:
# -*- coding: utf-8 -*-
"""eval_lora_model.py

Evaluates a pre-trained LoRA model on the test set and saves metrics to JSON.
"""

from huggingface_hub import hf_hub_download
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import numpy as np
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
from peft import LoraConfig, get_peft_model
import random

# Reproducibility: pin cuDNN to deterministic mode and seed every RNG in use
# (PyTorch CPU + all CUDA devices, NumPy, and Python's `random`).
SEED = 42
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

class Config:
    """Default settings for the LoRA evaluation run.

    The attributes are class-level defaults; other code in this file
    overwrites them at runtime with ``setattr(Config, key, value)`` using the
    ``model_config`` stored alongside the checkpoint.
    """

    # --- Where to read the model from and where to write results ---
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-lora-final"  # Your LoRA model repo
    output_dir = "./lora_evaluation_results"

    # --- Encoder and document layout ---
    bert_model_name = "nlpaueb/legal-bert-base-uncased"
    max_length = 128           # passed to the tokenizer as max_length per sentence
    max_num_sentences = 48     # documents are truncated/padded to this many sentences
    context_hidden_size = 256  # not referenced by the code visible in this script

    # --- Regularisation / loss / batching ---
    dropout_rate = 0.4
    gamma = 2.0                # exponent used by FocalCRF
    batch_size = 64

    # --- LoRA adapter hyper-parameters (fed to LoraConfig) ---
    lora_r = 128
    lora_alpha = 64
    lora_dropout = 0.15
    lora_target_modules = ["query", "value", "key"]


# Fetch the test split of the InRhetoricalRoles dataset straight from the Hub
# (parquet over the hf:// filesystem) and wrap it as a `datasets.Dataset`.
splits = {'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'}

test_df = pd.read_parquet(f"hf://datasets/opennyaiorg/InRhetoricalRoles/{splits['test']}")
test_ds = Dataset.from_pandas(test_df)

def get_spans_and_labels(example):
    """Collect sentence texts and their first label from the first annotation.

    Only ``example['annotations'][0]['result']`` is inspected. A result entry
    is kept when its ``value`` is truthy and holds both a truthy ``text`` and
    a truthy ``labels``; the text and ``labels[0]`` are appended to two
    parallel lists.

    Returns:
        dict with keys ``'spans'`` and ``'labels'`` (both empty lists when
        nothing usable is found).
    """
    spans, labels = [], []
    annotations = example.get('annotations')
    results = annotations[0].get('result') if annotations else None
    for entry in results or []:
        value = entry.get('value')
        if not value:
            continue
        text, tags = value.get('text'), value.get('labels')
        if text and tags:
            spans.append(text)
            labels.append(tags[0])
    return {'spans': spans, 'labels': labels}

def preprocess_single_dataset(dataset, label2id):
    """Turn raw annotated documents into ``text`` / ``label`` columns.

    Steps: extract spans and labels per document, drop documents without any
    span, then expose the spans as ``'text'`` and the labels as ``'label'``.

    Args:
        dataset: object providing ``map`` and ``filter`` (a ``datasets.Dataset``
            in this script).
        label2id: accepted for signature symmetry; not used by this function.

    Returns:
        The transformed dataset.
    """
    with_spans = dataset.map(get_spans_and_labels)
    non_empty = with_spans.filter(lambda row: len(row['spans']) > 0)
    return non_empty.map(lambda row: {'text': row['spans'], 'label': row['labels']})

def tokenize_single_dataset(dataset, tokenizer, label2id):
    """Tokenize every document into fixed-size hierarchical tensors.

    Each document is truncated/padded to ``Config.max_num_sentences``
    sentences and each sentence to ``Config.max_length`` tokens.

    Padding sentences are tokenized from the empty string, but their
    attention-mask rows are then zeroed. Downstream code derives the
    sentence-level mask from ``attention_mask[:, :, 0] > 0``; a BERT-style
    tokenizer typically still emits special tokens (attention 1) for an empty
    string, which would make padding sentences look real and let their dummy
    labels leak into decoding and metrics.

    NOTE(review): zeroing these rows changes what the encoder sees for padding
    sentences; confirm this matches how the checkpoint was trained.

    Args:
        dataset: dataset with ``'text'`` (list of sentences) and ``'label'``
            (list of label names) per example.
        tokenizer: callable tokenizer accepting a list of strings and
            returning ``input_ids`` / ``attention_mask`` tensors.
        label2id: mapping from label name to integer id. Its first key is
            used as the dummy label for padding sentences.

    Returns:
        The mapped dataset, formatted to yield torch tensors for
        ``input_ids``, ``attention_mask`` and ``labels``.
    """
    # Dummy label for padded positions; they are masked out downstream.
    pad_label = list(label2id)[0]

    def tokenize_document(example):
        max_sents = Config.max_num_sentences
        sentences = list(example['text'][:max_sents])
        labels = list(example['label'][:max_sents])

        num_real = len(sentences)
        pad_len = max_sents - num_real
        sentences += [""] * pad_len
        labels += [pad_label] * pad_len

        # One tokenizer call per document instead of one per sentence.
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Mark padding sentences as fully masked so the sentence-level mask
        # (first token's attention bit) excludes them.
        attention_mask[num_real:] = 0

        label_ids = torch.tensor([label2id[name] for name in labels])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": label_ids
        }

    tokenized_ds = dataset.map(tokenize_document)
    # Set format to PyTorch tensors
    tokenized_ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    return tokenized_ds

class PositionalEncoding(nn.Module):
    """Learned positional embeddings added to per-sentence vectors.

    Args:
        d_model: size of the last dimension of the input.
        max_len: number of positions in the embedding table. When ``None``
            (the default) ``Config.max_num_sentences`` is read at
            construction time, so values applied to ``Config`` after this
            class was defined (e.g. from a saved ``model_config``) are
            honoured rather than the value frozen at definition time.
    """
    def __init__(self, d_model, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = Config.max_num_sentences
        self.dropout = nn.Dropout(Config.dropout_rate)
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add the embedding of each index along dim 1 to ``x``, then apply dropout."""
        # Shape (1, x.size(1)) so the lookup broadcasts over the batch dimension.
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return self.dropout(x + self.position_emb(positions))

class TransformerContextLayer(nn.Module):
    """Two stacked Transformer encoder layers over the sentence sequence.

    Args:
        d_model: feature size of the input and output.
        nhead: number of attention heads.
        dim_feedforward: hidden size of each layer's feed-forward block.
        dropout: dropout probability inside each layer.
    """
    def __init__(self, d_model, nhead=8, dim_feedforward=1024, dropout=0.2):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation='gelu',
        )
        # Both attributes are registered sub-modules, so their names are part
        # of the state_dict keys; keep them unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)

    def forward(self, x):
        """Run ``x`` through the encoder stack and return the result."""
        contextualised = self.transformer_encoder(x)
        return contextualised

class EmissionLayer(nn.Module):
    """Two-layer projection to label scores plus a linear shortcut.

    The main path is Linear -> LayerNorm -> GELU -> Dropout -> Linear; a
    separate linear projection of the untouched input is added to its output.

    Args:
        input_size: feature size of the input.
        num_labels: number of output scores per position.
        dropout: dropout probability on the main path.
    """
    def __init__(self, input_size, num_labels, dropout=0.3):
        super().__init__()
        expanded = input_size * 2
        # Attribute names and creation order are kept stable: names form the
        # state_dict keys, order determines random initialisation.
        self.linear1 = nn.Linear(input_size, expanded)
        self.linear2 = nn.Linear(expanded, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(expanded)
        self.residual_proj = nn.Linear(input_size, num_labels)

    def forward(self, x):
        """Return main-path scores plus the shortcut projection of ``x``."""
        hidden = self.gelu(self.layer_norm(self.linear1(x)))
        main_path = self.linear2(self.dropout(hidden))
        shortcut = self.residual_proj(x)
        return main_path + shortcut

class FocalCRF(nn.Module):
    """CRF layer whose per-sequence loss is re-weighted focal-loss style.

    Args:
        num_tags: number of CRF tags.
        gamma: focal exponent. When ``None`` (the default) ``Config.gamma`` is
            read at construction time instead of being frozen when the class
            is defined.
    """
    def __init__(self, num_tags, gamma=None):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = Config.gamma if gamma is None else gamma

    @staticmethod
    def _sequence_weights(class_weights, tags, mask):
        """Average the class weight of each sequence over its valid positions only."""
        mask = mask.to(class_weights.dtype)
        # Zero the contribution of masked (padding) positions before summing;
        # otherwise their dummy tags inflate the per-sequence weight.
        weights_per_tag = class_weights[tags] * mask  # (batch_size, seq_len)
        valid_counts = mask.sum(dim=1).clamp(min=1)   # (batch_size,), avoids 0-division
        return weights_per_tag.sum(dim=1) / valid_counts

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal-weighted CRF loss over the batch.

        Args:
            emissions: scores passed to the CRF.
            tags: gold tag indices; also used to index ``class_weights``.
            mask: per-position validity mask passed to the CRF.
            class_weights: optional 1-D tensor indexed by tag id. When given,
                each sequence's loss is scaled by the mean weight of its
                valid tags.
        """
        # Called with reduction='none' so one value per sequence comes back;
        # the code treats it as a log-likelihood (exponentiated below).
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')

        # Focal transformation: down-weight sequences the model already
        # assigns high probability to.
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            focal_loss = focal_loss * self._sequence_weights(class_weights, tags, mask)

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Delegate decoding to the wrapped CRF."""
        return self.crf.decode(emissions, mask=mask)

class LoRAHSLNModel(nn.Module):
    """Hierarchical sentence labeller: LoRA-wrapped encoder, context encoder, CRF.

    Args:
        num_labels: number of output labels / CRF tags.
        model_config: mapping whose items are copied onto ``Config`` before
            any layer is built.
    """
    def __init__(self, num_labels, model_config):
        super().__init__()
        # Apply the saved hyper-parameters to the shared Config first.
        for name in model_config:
            setattr(Config, name, model_config[name])

        # Sentence encoder wrapped with LoRA adapters.
        base_bert = AutoModel.from_pretrained(Config.bert_model_name)
        adapter_cfg = LoraConfig(
            r=Config.lora_r,
            lora_alpha=Config.lora_alpha,
            target_modules=Config.lora_target_modules,
            lora_dropout=Config.lora_dropout,
            bias="none",
        )
        self.bert = get_peft_model(base_bert, adapter_cfg)

        hidden = self.bert.config.hidden_size

        # Sentence-vector post-processing (attribute names are state_dict keys).
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden)
        self.sent_projection = nn.Linear(hidden, hidden)

        # Document-level context over the sentence sequence.
        self.position_enc = PositionalEncoding(hidden)
        self.context_encoder = TransformerContextLayer(d_model=hidden)

        # Per-sentence label scores and the CRF on top of them.
        self.emission = EmissionLayer(input_size=hidden, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute emissions and, when ``labels`` is given, the CRF loss.

        ``input_ids`` is unpacked as (batch, sentences, tokens);
        ``attention_mask`` is reshaped the same way.

        Returns:
            ``{"emissions": ...}``, or ``{"loss": ..., "emissions": ...}``
            when ``labels`` is provided.
        """
        batch_size, num_sent, seq_len = input_ids.shape

        # Encode all sentences of the batch in one pass.
        encoder_out = self.bert(
            input_ids=input_ids.view(-1, seq_len),
            attention_mask=attention_mask.view(-1, seq_len),
        )

        # The first-token vector serves as the sentence embedding.
        first_token = encoder_out.last_hidden_state[:, 0, :]
        sent_emb = self.sent_dropout(
            self.sent_layer_norm(self.sent_projection(first_token))
        )
        sent_emb = sent_emb.view(batch_size, num_sent, -1)

        # Positional information, then context, then label scores.
        context_emb = self.context_encoder(self.position_enc(sent_emb))
        emissions = self.emission(context_emb)

        if labels is None:
            return {"emissions": emissions}

        # A sentence counts as valid when its first token is attended to.
        sentence_mask = attention_mask[:, :, 0] > 0
        loss = self.crf(emissions, labels, mask=sentence_mask)
        return {"loss": loss, "emissions": emissions}

def plot_confusion_matrix(y_true, y_pred, labels, label_names, output_dir):
    """Draw the raw-count confusion matrix as a heatmap and save it as a PNG.

    Args:
        y_true: gold label ids.
        y_pred: predicted label ids.
        labels: label ids, in the row/column order of the matrix.
        label_names: tick labels for both axes.
        output_dir: directory the image is written into.

    Returns:
        Path of the saved ``confusion_matrix.png``.
    """
    counts = confusion_matrix(y_true, y_pred, labels=labels)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Bold fonts for axis titles and for the numbers inside the cells.
    axis_title_style = {'size': 10, 'weight': 'bold'}
    cell_text_style = {'size': 9, 'weight': 'bold'}

    sns.heatmap(
        counts,
        annot=True,
        fmt='d',  # integer counts
        cmap="Blues",
        xticklabels=label_names,
        yticklabels=label_names,
        cbar=True,
        ax=ax,
        annot_kws=cell_text_style,
    )

    ax.set_title('Confusion Matrix (Raw Counts)', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Label', fontdict=axis_title_style)
    ax.set_xlabel('Predicted Label', fontdict=axis_title_style)

    # Slanted x ticks, upright y ticks.
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=9, fontweight='bold')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=9, fontweight='bold')

    # Match the colorbar tick size to the rest of the figure.
    colorbar = ax.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=9)

    plot_path = os.path.join(output_dir, "confusion_matrix.png")
    fig.tight_layout()
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    return plot_path

def evaluate_metrics(model, dataloader, device, label_list):
    """Run the model over ``dataloader`` and compute classification metrics.

    Only positions where the sentence-level mask
    (``attention_mask[:, :, 0] > 0``) is set are scored.

    Changes from the previous revision:
      * Forward-pass latency is measured with ``time.perf_counter`` and,
        on CUDA, bracketed by ``torch.cuda.synchronize`` so queued GPU
        work is included in the timing.
      * Macro/weighted precision and recall come from the same
        classification report as the F1 scores, so every average covers
        the same label set (all of ``label_list``).
      * Labels are moved to the CPU once per batch.

    Args:
        model: model returning ``{"emissions": ...}`` and exposing
            ``model.crf.decode``.
        dataloader: yields batches with ``input_ids``, ``attention_mask``
            and ``labels``.
        device: device the inputs are moved to.
        label_list: label names, indexed by label id.

    Returns:
        dict of aggregate and per-label metrics, latency figures, and the
        flattened ``all_labels`` / ``all_preds`` lists.
    """
    model.eval()
    all_preds, all_labels = [], []
    total_time = 0.0
    n_docs = 0
    n_sentences = 0
    sync_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"
    eval_start = time.perf_counter()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels_cpu = batch['labels'].cpu()
            mask = attention_mask[:, :, 0] > 0

            # CUDA kernels run asynchronously; synchronise around the timed
            # region so the measurement reflects the forward pass itself.
            if sync_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            if sync_cuda:
                torch.cuda.synchronize(device)
            end = time.perf_counter()

            emissions = outputs["emissions"]
            preds = model.crf.decode(emissions, mask=mask)

            # One device->host transfer for all sequence lengths in the batch.
            seq_lens = mask.sum(dim=1).tolist()
            for pred_seq, label_row, seq_len in zip(preds, labels_cpu, seq_lens):
                all_preds.extend(pred_seq[:seq_len])
                all_labels.extend(label_row[:seq_len].tolist())

            total_time += end - start
            n_docs += input_ids.shape[0]
            n_sentences += sum(seq_lens)

    eval_time = time.perf_counter() - eval_start

    labels_for_report = list(range(len(label_list)))

    report = classification_report(
        all_labels, all_preds,
        labels=labels_for_report,
        target_names=label_list,
        output_dict=True,
        zero_division=0
    )

    macro_avg = report['macro avg']
    weighted_avg = report['weighted avg']

    per_label_precision = {label: report[label]['precision'] for label in label_list}
    per_label_recall = {label: report[label]['recall'] for label in label_list}
    per_label_f1 = {label: report[label]['f1-score'] for label in label_list}

    accuracy = accuracy_score(all_labels, all_preds)

    latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
    latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

    return {
        "macro_f1": macro_avg['f1-score'],
        "weighted_f1": weighted_avg['f1-score'],
        "accuracy": accuracy,
        "per_label_f1": per_label_f1,
        "macro_precision": macro_avg['precision'],
        "macro_recall": macro_avg['recall'],
        "weighted_precision": weighted_avg['precision'],
        "weighted_recall": weighted_avg['recall'],
        "per_label_precision": per_label_precision,
        "per_label_recall": per_label_recall,
        "latency_ms_per_doc": latency_doc,
        "latency_ms_per_sentence": latency_sent,
        "eval_time_seconds": eval_time,
        "num_samples": n_docs,
        "all_labels": all_labels,
        "all_preds": all_preds
    }

def evaluate_lora_model():
    """Download the LoRA checkpoint, evaluate it on the test set, save results.

    Steps: fetch ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, apply the saved ``model_config`` to ``Config``,
    preprocess and tokenize the module-level ``test_ds``, load the weights,
    compute metrics, save a confusion-matrix image and a metrics JSON into
    ``Config.output_dir``, and print a summary.

    Returns:
        The metrics dict (without the raw label/prediction lists), or
        ``None`` when any step raised; in that case the error is printed and
        an error log is written on a best-effort basis.
    """
    try:
        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print("STARTING LoRA MODEL EVALUATION")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"Seed: {SEED}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts from Hugging Face Hub
        print("Downloading LoRA model artifacts from Hugging Face Hub...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # Load configuration
        with open(config_path, 'r', encoding="utf-8") as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        model_config = saved_config['model_config']

        # JSON object keys are strings; restore integer label ids.
        id2label = {int(k): v for k, v in saved_config['id2label'].items()}

        # Label names ordered by id.
        label_list = [id2label[i] for i in range(len(id2label))]

        # Update Config with model parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for LoRA model: {Config.bert_model_name}")
        print(f"LoRA r: {Config.lora_r}, alpha: {Config.lora_alpha}")
        print(f"Number of labels: {len(label_list)}")

        # 2. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

        # 3. Preprocess test dataset
        print("\nPreprocessing test dataset...")
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset
        print("Tokenizing test dataset...")
        test_tokenized = tokenize_single_dataset(test_hier, tokenizer, label2id)

        # 5. Create data loader
        test_loader = DataLoader(
            test_tokenized,
            batch_size=Config.batch_size,
            shuffle=False
        )
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize LoRA model
        print("\nInitializing LoRA model...")
        model = LoRAHSLNModel(
            num_labels=len(label2id),
            model_config=model_config
        ).to(device)

        # 7. Load model weights
        # SECURITY: the checkpoint is a downloaded pickle. weights_only=True
        # restricts unpickling to tensors/plain containers, which is all a
        # state_dict needs.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print("LoRA model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating LoRA model on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list)

        # 9. Generate confusion matrix
        print("Generating confusion matrix with optimized visibility...")
        labels_idx = list(range(len(label_list)))
        cm_path = plot_confusion_matrix(
            test_metrics["all_labels"],
            test_metrics["all_preds"],
            labels=labels_idx,
            label_names=label_list,  # Pass label names for axes
            output_dir=Config.output_dir
        )
        print(f"Confusion matrix saved to: {cm_path}")

        # The raw lists are large and not JSON-friendly; drop them before saving.
        del test_metrics["all_labels"]
        del test_metrics["all_preds"]

        # 10. Save test metrics
        metrics_path = os.path.join(Config.output_dir, "lora_test_metrics.json")
        with open(metrics_path, 'w', encoding="utf-8") as f:
            json.dump(test_metrics, f, indent=2)

        print(f"\n{'='*30} LoRA TEST RESULTS {'='*30}")
        print(f"Weighted F1:      {test_metrics['weighted_f1']:.4f}")
        print(f"Macro F1:         {test_metrics['macro_f1']:.4f}")
        print(f"Accuracy:         {test_metrics['accuracy']:.4f}")
        print(f"Weighted Precision: {test_metrics['weighted_precision']:.4f}")
        print(f"Weighted Recall:    {test_metrics['weighted_recall']:.4f}")
        print(f"Macro Precision:  {test_metrics['macro_precision']:.4f}")
        print(f"Macro Recall:     {test_metrics['macro_recall']:.4f}")
        print(f"Latency:          {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Evaluation time:  {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Metrics saved to: {metrics_path}")

        print("\nPer-class Metrics:")
        for label in label_list:
            print(f"  {label}:")
            print(f"    F1:       {test_metrics['per_label_f1'][label]:.4f}")
            print(f"    Precision: {test_metrics['per_label_precision'][label]:.4f}")
            print(f"    Recall:    {test_metrics['per_label_recall'][label]:.4f}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        # Top-level boundary: report the failure, persist a log, return None.
        import traceback
        print(f"\n{'!'*50}")
        print("EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        traceback.print_exc()
        try:
            # The failure may have happened before the directory was created.
            os.makedirs(Config.output_dir, exist_ok=True)
            log_path = os.path.join(Config.output_dir, "lora_error_log.txt")
            with open(log_path, "w", encoding="utf-8") as f:
                f.write(f"Evaluation error at {datetime.now()}\n")
                # Newline keeps the message separate from the traceback.
                f.write(f"{e}\n")
                f.write(traceback.format_exc())
        except OSError as log_err:
            # Logging is best-effort; never let it mask the original failure.
            print(f"Could not write error log: {log_err}")
        return None

# Run the LoRA evaluation when this cell/script is executed directly.
if __name__ == "__main__":
    evaluate_lora_model()

# LegalBERTHSLN - QLoRA


In [None]:
# -*- coding: utf-8 -*-
"""eval_qlora_model.py

Evaluates a pre-trained QLoRA hierarchical model on the test set with optimized confusion matrix.
"""

from huggingface_hub import hf_hub_download
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import numpy as np
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix, precision_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset as TorchDataset
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
from peft import PeftModel, PeftConfig
import random

# Reproducibility: pin cuDNN to deterministic mode and seed every RNG in use
# (PyTorch CPU + all CUDA devices, NumPy, and Python's `random`).
SEED = 42
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

class Config:
    """Default settings for the QLoRA evaluation run.

    The attributes are class-level defaults; the model class in this script
    overwrites them at runtime with ``setattr(Config, key, value)`` from the
    ``model_config`` it is given.
    """

    # --- Where to read the adapter from and where to write results ---
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-qlora-final"  # Your QLoRA model repo
    output_dir = "./qlora_evaluation_results"

    # --- Encoder and document layout ---
    bert_model_name = "nlpaueb/legal-bert-base-uncased"
    max_length = 128           # passed to the tokenizer as max_length per sentence
    max_num_sentences = 48     # documents are truncated/padded to this many sentences
    context_hidden_size = 256  # not referenced by the code visible in this script

    # --- Regularisation / loss / batching ---
    dropout_rate = 0.4
    gamma = 2.0                # exponent used by FocalCRF
    batch_size = 4  # Reduced for QLoRA memory constraints

    # --- LoRA adapter hyper-parameters ---
    lora_r = 128
    lora_alpha = 64
    lora_dropout = 0.15
    lora_target_modules = ["query", "value", "key"]

    # --- 4-bit quantisation settings (fed to BitsAndBytesConfig) ---
    qlora_quant_type = "nf4"
    use_double_quant = True


# Fetch the test split of the InRhetoricalRoles dataset straight from the Hub
# (parquet over the hf:// filesystem) and wrap it as a `datasets.Dataset`.
splits = {'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'}

test_df = pd.read_parquet(f"hf://datasets/opennyaiorg/InRhetoricalRoles/{splits['test']}")
test_ds = Dataset.from_pandas(test_df)

def get_spans_and_labels(example):
    """Pull sentence texts and their first label out of the first annotation.

    Looks only at ``example['annotations'][0]['result']``. Entries whose
    ``value`` is missing/falsy, or whose ``text`` or ``labels`` is falsy, are
    skipped.

    Returns:
        dict with parallel lists under ``'spans'`` and ``'labels'``.
    """
    annotations = example.get('annotations')
    if not annotations:
        return {'spans': [], 'labels': []}

    results = annotations[0].get('result')
    if not results:
        return {'spans': [], 'labels': []}

    usable = [
        ann['value']
        for ann in results
        if ann.get('value') and ann['value'].get('text') and ann['value'].get('labels')
    ]
    return {
        'spans': [value['text'] for value in usable],
        'labels': [value['labels'][0] for value in usable],
    }

def preprocess_single_dataset(dataset, label2id):
    """Extract spans/labels, drop empty documents, expose ``text``/``label``.

    Args:
        dataset: object providing ``map`` and ``filter`` (a ``datasets.Dataset``
            in this script).
        label2id: accepted for signature symmetry; not used by this function.

    Returns:
        The transformed dataset.
    """
    def _has_spans(row):
        return len(row['spans']) > 0

    def _as_text_and_label(row):
        return {'text': row['spans'], 'label': row['labels']}

    extracted = dataset.map(get_spans_and_labels)
    kept = extracted.filter(_has_spans)
    return kept.map(_as_text_and_label)

def tokenize_document(example, tokenizer, label2id):
    """Tokenize one document into fixed-size tensors.

    The document is cut to ``Config.max_num_sentences`` sentences, which are
    tokenized in a single call and padded to ``Config.max_length`` tokens.
    Missing sentences are appended as rows of ``tokenizer.pad_token_id`` with
    an all-zero attention mask, and they receive the first key of
    ``label2id`` as a dummy label.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    max_sents = Config.max_num_sentences
    real_sentences = example['text'][:max_sents]
    real_labels = example['label'][:max_sents]
    missing = max_sents - len(real_sentences)

    # Single batched tokenizer call for all real sentences.
    encoded = tokenizer(
        real_sentences,
        padding="max_length",
        truncation=True,
        max_length=Config.max_length,
        return_tensors="pt",
        return_attention_mask=True,
    )

    # Filler rows for the sentences the document does not have.
    filler_shape = (missing, Config.max_length)
    filler_ids = torch.full(filler_shape, tokenizer.pad_token_id, dtype=torch.long)
    filler_mask = torch.zeros(filler_shape, dtype=torch.long)

    all_labels = real_labels + [list(label2id.keys())[0]] * missing
    label_ids = torch.tensor([label2id[name] for name in all_labels], dtype=torch.long)

    return {
        "input_ids": torch.cat([encoded["input_ids"], filler_ids]),
        "attention_mask": torch.cat([encoded["attention_mask"], filler_mask]),
        "labels": label_ids,
    }

class PositionalEncoding(nn.Module):
    """Learned positional embeddings added to per-sentence vectors.

    Args:
        d_model: size of the last dimension of the input.
        max_len: number of positions in the embedding table. When ``None``
            (the default) ``Config.max_num_sentences`` is read at
            construction time, so values applied to ``Config`` after this
            class was defined (e.g. from a saved ``model_config``) are
            honoured rather than the value frozen at definition time.
    """
    def __init__(self, d_model, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = Config.max_num_sentences
        self.dropout = nn.Dropout(Config.dropout_rate)
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add the embedding of each index along dim 1 to ``x``, then apply dropout."""
        # Shape (1, x.size(1)) so the lookup broadcasts over the batch dimension.
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return self.dropout(x + self.position_emb(positions))

class TransformerContextLayer(nn.Module):
    """Two stacked Transformer encoder layers over the sentence sequence.

    Args:
        d_model: feature size of the input and output.
        nhead: number of attention heads.
        dim_feedforward: hidden size of each layer's feed-forward block.
        dropout: dropout probability inside each layer.
    """
    def __init__(self, d_model, nhead=8, dim_feedforward=1024, dropout=0.2):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation='gelu',
        )
        # Both attributes are registered sub-modules, so their names are part
        # of the state_dict keys; keep them unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)

    def forward(self, x):
        """Run ``x`` through the encoder stack and return the result."""
        contextualised = self.transformer_encoder(x)
        return contextualised

class EmissionLayer(nn.Module):
    """Two-layer projection to label scores plus a linear shortcut.

    The main path is Linear -> LayerNorm -> GELU -> Dropout -> Linear; a
    separate linear projection of the untouched input is added to its output.

    Args:
        input_size: feature size of the input.
        num_labels: number of output scores per position.
        dropout: dropout probability on the main path.
    """
    def __init__(self, input_size, num_labels, dropout=0.3):
        super().__init__()
        expanded = input_size * 2
        # Attribute names and creation order are kept stable: names form the
        # state_dict keys, order determines random initialisation.
        self.linear1 = nn.Linear(input_size, expanded)
        self.linear2 = nn.Linear(expanded, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(expanded)
        self.residual_proj = nn.Linear(input_size, num_labels)

    def forward(self, x):
        """Return main-path scores plus the shortcut projection of ``x``."""
        hidden = self.gelu(self.layer_norm(self.linear1(x)))
        main_path = self.linear2(self.dropout(hidden))
        shortcut = self.residual_proj(x)
        return main_path + shortcut

class FocalCRF(nn.Module):
    """CRF layer whose per-sequence loss is re-weighted focal-loss style.

    Args:
        num_tags: number of CRF tags.
        gamma: focal exponent. When ``None`` (the default) ``Config.gamma`` is
            read at construction time instead of being frozen when the class
            is defined.
    """
    def __init__(self, num_tags, gamma=None):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = Config.gamma if gamma is None else gamma

    @staticmethod
    def _sequence_weights(class_weights, tags, mask):
        """Average the class weight of each sequence over its valid positions only."""
        mask = mask.to(class_weights.dtype)
        # Zero the contribution of masked (padding) positions before summing;
        # otherwise their dummy tags inflate the per-sequence weight.
        weights_per_tag = class_weights[tags] * mask  # (batch_size, seq_len)
        valid_counts = mask.sum(dim=1).clamp(min=1)   # (batch_size,), avoids 0-division
        return weights_per_tag.sum(dim=1) / valid_counts

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal-weighted CRF loss over the batch.

        Args:
            emissions: scores passed to the CRF.
            tags: gold tag indices; also used to index ``class_weights``.
            mask: per-position validity mask passed to the CRF.
            class_weights: optional 1-D tensor indexed by tag id. When given,
                each sequence's loss is scaled by the mean weight of its
                valid tags.
        """
        # Called with reduction='none' so one value per sequence comes back;
        # the code treats it as a log-likelihood (exponentiated below).
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')

        # Focal transformation: down-weight sequences the model already
        # assigns high probability to.
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            focal_loss = focal_loss * self._sequence_weights(class_weights, tags, mask)

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Delegate decoding to the wrapped CRF."""
        return self.crf.decode(emissions, mask=mask)

class QLoRAHSLNModel(nn.Module):
    """Hierarchical sentence labeller on a 4-bit quantised encoder with a LoRA adapter.

    Args:
        num_labels: number of output labels / CRF tags.
        model_config: mapping whose items are copied onto ``Config`` before
            any layer is built.
    """

    # Sub-modules that are created in float precision and then moved onto the
    # device of the quantised encoder.
    _HEAD_MODULES = (
        "sent_dropout",
        "sent_layer_norm",
        "sent_projection",
        "position_enc",
        "context_encoder",
        "emission",
        "crf",
    )

    def __init__(self, num_labels, model_config):
        super().__init__()
        # Apply the saved hyper-parameters to the shared Config first.
        for name in model_config:
            setattr(Config, name, model_config[name])

        # 4-bit quantisation settings for the base encoder.
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=Config.qlora_quant_type,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=Config.use_double_quant,
        )
        base_bert = AutoModel.from_pretrained(
            Config.bert_model_name,
            quantization_config=quant_cfg,
            device_map="auto",
        )

        # Attach the adapter stored in the Hub repository.
        self.bert = PeftModel.from_pretrained(
            base_bert,
            Config.hf_repo_id,
            device_map="auto",
        )

        hidden = self.bert.config.hidden_size

        # Sentence-vector post-processing (attribute names are state_dict keys).
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden)
        self.sent_projection = nn.Linear(hidden, hidden)

        # Document-level context over the sentence sequence.
        self.position_enc = PositionalEncoding(hidden)
        self.context_encoder = TransformerContextLayer(d_model=hidden)

        # Per-sentence label scores and the CRF on top of them.
        self.emission = EmissionLayer(input_size=hidden, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

        # Use whatever device the encoder's first parameter landed on.
        self.device = next(self.bert.parameters()).device
        print(f"Model initialized on device: {self.device}")

        # Co-locate the custom head with the encoder.
        for attr in self._HEAD_MODULES:
            setattr(self, attr, getattr(self, attr).to(self.device))

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute emissions and, when ``labels`` is given, the CRF loss.

        ``input_ids`` is unpacked as (batch, sentences, tokens). All inputs
        are moved to ``self.device`` before use.

        Returns:
            ``{"emissions": ...}``, or ``{"loss": ..., "emissions": ...}``
            when ``labels`` is provided.
        """
        batch_size, num_sent, seq_len = input_ids.shape

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)

        # Encode all sentences of the batch in one pass.
        encoder_out = self.bert(
            input_ids=input_ids.view(-1, seq_len),
            attention_mask=attention_mask.view(-1, seq_len),
        )

        # The head layers run in float32, so cast the encoder output first.
        token_states = encoder_out.last_hidden_state.to(torch.float32)

        # The first-token vector serves as the sentence embedding.
        first_token = token_states[:, 0, :]
        sent_emb = self.sent_dropout(
            self.sent_layer_norm(self.sent_projection(first_token))
        )
        sent_emb = sent_emb.view(batch_size, num_sent, -1)

        # Positional information, then context, then label scores.
        context_emb = self.context_encoder(self.position_enc(sent_emb))
        emissions = self.emission(context_emb)

        if labels is None:
            return {"emissions": emissions}

        # A sentence counts as valid when its first token is attended to.
        sentence_mask = attention_mask[:, :, 0] > 0
        loss = self.crf(emissions, labels, mask=sentence_mask)
        return {"loss": loss, "emissions": emissions}

def plot_confusion_matrix(y_true, y_pred, labels, label_names, output_dir):
    """Render a raw-count confusion matrix and save it as a PNG.

    Args:
        y_true: Ground-truth label ids.
        y_pred: Predicted label ids.
        labels: Label ids fixing the row/column order of the matrix.
        label_names: Display names, one per entry of ``labels``.
        output_dir: Directory that receives ``confusion_matrix.png``; it is
            created when missing.

    Returns:
        Path of the written image file.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Bold fonts throughout so the figure stays legible at small sizes.
    title_font = {'size': 12, 'weight': 'bold'}
    label_font = {'size': 10, 'weight': 'bold'}
    tick_font = {'size': 8, 'weight': 'bold'}
    annot_font = {'size': 7, 'weight': 'bold'}

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',  # integer format for raw counts
            cmap="Blues",
            cbar=True,
            cbar_kws={'label': 'Count', 'shrink': 0.75},
            ax=ax,
            annot_kws=annot_font
        )

        ax.set_title('Confusion Matrix (Raw Counts)', fontdict=title_font)
        ax.set_xlabel('Predicted Label', fontdict=label_font)
        ax.set_ylabel('True Label', fontdict=label_font)

        # seaborn draws cell i over [i, i + 1]; ticks therefore belong at
        # i + 0.5. Integer positions would put every label on a cell edge.
        tick_positions = np.arange(len(label_names)) + 0.5
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(label_names, rotation=45, ha="right", fontdict=tick_font)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(label_names, rotation=0, fontdict=tick_font)

        cbar = ax.collections[0].colorbar
        cbar.ax.tick_params(labelsize=8)
        cbar.ax.set_ylabel('Count', fontsize=10, fontweight='bold')

        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        plot_path = os.path.join(output_dir, "confusion_matrix.png")
        fig.savefig(plot_path, bbox_inches='tight', pad_inches=0.05)
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)

    return plot_path

class HierarchicalDataset(TorchDataset):
    """Map-style wrapper that serves each document's fields as tensors."""

    _FIELDS = ("input_ids", "attention_mask", "labels")

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _as_long_tensor(value):
        """Return tensors untouched; turn anything else into an int64 tensor."""
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=torch.long)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        return {
            field: self._as_long_tensor(record[field])
            for field in self._FIELDS
        }

def collate_fn(batch):
    """Stack per-document tensors into batch tensors along a new dim 0.

    Each item must already hold equally shaped ``torch`` tensors under
    ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
    """
    return {
        field: torch.stack([example[field] for example in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }

def convert_to_serializable(obj):
    """Recursively convert NumPy values into JSON-friendly Python objects.

    NumPy booleans, integers and floats become ``bool``/``int``/``float``,
    arrays become nested lists, and dicts, lists and tuples are rebuilt with
    their contents converted. Any other value is returned unchanged.

    The scalar checks use the abstract ``np.integer``/``np.floating`` bases
    rather than an explicit alias list: aliases such as ``np.float_`` no
    longer exist in NumPy 2.0 and merely naming them raises AttributeError.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_to_serializable(item) for item in obj)
    return obj

def evaluate_metrics(model, dataloader, device, label_list):
    """Evaluate ``model`` on ``dataloader`` and collect sentence-level metrics.

    Args:
        model: Callable model returning ``{"emissions": ...}`` and exposing
            ``crf.decode(emissions, mask=...)``.
        dataloader: Iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        device: Device the batch tensors are moved to.
        label_list: Label names ordered by label id.

    Returns:
        Dict with macro/weighted F1, precision and recall, accuracy,
        per-label scores, latency figures, the number of evaluated documents
        and the flattened ``all_labels``/``all_preds`` lists. The same keys
        are present (with zero values) when the dataloader yields nothing.
    """
    model.eval()
    all_preds, all_labels = [], []
    total_time = 0
    n_docs = 0
    n_sentences = 0
    eval_start = time.time()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # A sentence is counted when its first token is attended to.
            mask = attention_mask[:, :, 0] > 0

            # NOTE(review): only the forward call is timed and no device
            # synchronisation is done, so GPU latency may be under-reported.
            start = time.time()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            end = time.time()

            preds = model.crf.decode(outputs["emissions"], mask=mask)

            valid_counts = mask.sum(dim=1).tolist()
            for doc_preds, doc_labels, seq_len in zip(preds, labels, valid_counts):
                all_preds.extend(doc_preds[:seq_len])
                all_labels.extend(doc_labels[:seq_len].tolist())  # native ints

            total_time += (end - start)
            n_docs += input_ids.shape[0]
            n_sentences += sum(valid_counts)

    eval_time = time.time() - eval_start

    # Nothing to score: return the full key set so callers that print or
    # serialise per-label values do not hit a KeyError.
    if not all_labels:
        print("WARNING: No samples processed during evaluation!")
        return {
            "macro_f1": 0,
            "weighted_f1": 0,
            "accuracy": 0,
            "per_label_f1": {label: 0.0 for label in label_list},
            "macro_precision": 0,
            "macro_recall": 0,
            "weighted_precision": 0,
            "weighted_recall": 0,
            "per_label_precision": {label: 0.0 for label in label_list},
            "per_label_recall": {label: 0.0 for label in label_list},
            "latency_ms_per_doc": 0,
            "latency_ms_per_sentence": 0,
            "eval_time_seconds": 0,
            "num_samples": 0,
            "all_labels": [],
            "all_preds": []
        }

    labels_for_report = list(range(len(label_list)))

    report = classification_report(
        all_labels, all_preds,
        labels=labels_for_report,
        target_names=label_list,
        output_dict=True,
        zero_division=0
    )

    # All averages come from the same report so that macro precision/recall
    # are taken over the same label set as macro F1. Calling precision_score
    # without ``labels=`` would average only over labels that actually occur.
    macro_avg = report['macro avg']
    weighted_avg = report['weighted avg']
    accuracy = accuracy_score(all_labels, all_preds)

    per_label_f1 = {label: float(report[label]['f1-score']) for label in label_list}
    per_label_precision = {label: float(report[label]['precision']) for label in label_list}
    per_label_recall = {label: float(report[label]['recall']) for label in label_list}

    latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
    latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

    # Plain Python floats/ints keep the result JSON-serialisable.
    return {
        "macro_f1": float(macro_avg['f1-score']),
        "weighted_f1": float(weighted_avg['f1-score']),
        "accuracy": float(accuracy),
        "per_label_f1": per_label_f1,
        "macro_precision": float(macro_avg['precision']),
        "macro_recall": float(macro_avg['recall']),
        "weighted_precision": float(weighted_avg['precision']),
        "weighted_recall": float(weighted_avg['recall']),
        "per_label_precision": per_label_precision,
        "per_label_recall": per_label_recall,
        "latency_ms_per_doc": float(latency_doc),
        "latency_ms_per_sentence": float(latency_sent),
        "eval_time_seconds": float(eval_time),
        "num_samples": int(n_docs),
        "all_labels": all_labels,
        "all_preds": all_preds
    }

def evaluate_qlora_model():
    """Download the fine-tuned model from the Hub and evaluate it on the test set.

    Steps: fetch ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, rebuild labels and hyper-parameters from the saved
    config, preprocess and tokenize ``test_ds``, load the weights, compute
    metrics, then write a confusion matrix and ``qlora_test_metrics.json``
    under ``Config.output_dir`` and print a summary.

    Returns:
        The metrics dict produced by ``evaluate_metrics``, or ``None`` when
        any step raises (the error is printed and also written to
        ``qlora_error_log.txt``).
    """
    try:
        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print("STARTING QLoRA MODEL EVALUATION")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"Seed: {SEED}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts from Hugging Face Hub
        print("Downloading QLoRA model artifacts from Hugging Face Hub...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # 2. Load configuration
        with open(config_path, 'r') as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        model_config = saved_config['model_config']

        # JSON object keys are strings; label ids are needed as integers.
        id2label = {int(k): v for k, v in saved_config['id2label'].items()}

        # Label names ordered by id
        label_list = [id2label[i] for i in range(len(id2label))]

        # Overwrite the Config defaults with the saved hyper-parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for QLoRA model: {Config.bert_model_name}")
        # The attribute only exists when the saved config provides it.
        print(f"QLoRA quantization: {getattr(Config, 'qlora_quant_type', 'n/a')}")
        print(f"Number of labels: {len(label_list)}")

        # 3. Preprocess test dataset
        print("\nPreprocessing test dataset...")
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset, one document per call
        print("Tokenizing test dataset with batch processing...")
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)
        test_tokenized = test_hier.map(
            lambda x: tokenize_document(x, tokenizer, label2id),
            batched=False
        )

        # 5. Create data loader with custom dataset and collate function
        print("Creating data loader with custom collation...")
        test_loader = DataLoader(
            HierarchicalDataset(test_tokenized),
            batch_size=Config.batch_size,
            shuffle=False,
            collate_fn=collate_fn
        )
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize QLoRA model
        print("\nInitializing QLoRA model...")
        model = QLoRAHSLNModel(
            num_labels=len(label2id),
            model_config=model_config
        )

        # 7. Load custom layers weights
        print("Loading custom layer weights...")
        model_device = next(model.parameters()).device
        # NOTE(security): torch.load unpickles the downloaded file; only point
        # Config.hf_repo_id at a repository you trust.
        state_dict = torch.load(model_path, map_location=model_device)
        # strict=False tolerates key mismatches, so report them rather than
        # letting a partially loaded model pass silently.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            print(f"WARNING: {len(load_result.missing_keys)} model keys were not found in the checkpoint")
        if load_result.unexpected_keys:
            print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint keys were not used by the model")
        print("QLoRA model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating QLoRA model on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list)

        # 9. Generate confusion matrix with raw counts
        print("Generating publication-quality confusion matrix with raw counts...")
        cm_path = plot_confusion_matrix(
            test_metrics["all_labels"],
            test_metrics["all_preds"],
            labels=list(range(len(label_list))),
            label_names=label_list,
            output_dir=Config.output_dir
        )
        print(f"Confusion matrix saved to: {cm_path}")

        # 10. Save test metrics with proper serialization
        metrics_path = os.path.join(Config.output_dir, "qlora_test_metrics.json")
        with open(metrics_path, 'w') as f:
            json.dump(convert_to_serializable(test_metrics), f, indent=2)

        print(f"\n{'='*30} QLoRA TEST RESULTS {'='*30}")
        print(f"Weighted F1:      {test_metrics['weighted_f1']:.4f}")
        print(f"Macro F1:         {test_metrics['macro_f1']:.4f}")
        print(f"Accuracy:         {test_metrics['accuracy']:.4f}")
        print(f"Weighted Precision: {test_metrics['weighted_precision']:.4f}")
        print(f"Weighted Recall:    {test_metrics['weighted_recall']:.4f}")
        print(f"Macro Precision:  {test_metrics['macro_precision']:.4f}")
        print(f"Macro Recall:     {test_metrics['macro_recall']:.4f}")
        print(f"Latency:          {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Per-sentence:     {test_metrics['latency_ms_per_sentence']:.2f} ms/sent")
        print(f"Evaluation time:  {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Metrics saved to: {metrics_path}")

        # Per-label entries may be missing when nothing was evaluated.
        per_f1 = test_metrics.get('per_label_f1', {})
        per_precision = test_metrics.get('per_label_precision', {})
        per_recall = test_metrics.get('per_label_recall', {})
        print("\nPer-class Metrics:")
        for label in label_list:
            print(f"  {label}:")
            print(f"    F1:       {per_f1.get(label, 0.0):.4f}")
            print(f"    Precision: {per_precision.get(label, 0.0):.4f}")
            print(f"    Recall:    {per_recall.get(label, 0.0):.4f}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        # Top-level boundary: report the failure, persist it, return None.
        import traceback
        print(f"\n{'!'*50}")
        print("EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        traceback.print_exc()
        # Make sure the log directory exists even if the failure came early.
        os.makedirs(Config.output_dir, exist_ok=True)
        with open(os.path.join(Config.output_dir, "qlora_error_log.txt"), "w") as f:
            f.write(f"Evaluation error at {datetime.now()}\n")
            f.write(f"{e}\n")
            f.write(traceback.format_exc())
        return None

# Entry point: start the evaluation only when this code is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    evaluate_qlora_model()

# LegalBERTHSLN-AdaLoRA

In [None]:
import pandas as pd
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login

In [None]:
# -*- coding: utf-8 -*-
"""eval_adalora.py

Evaluates AdaLoRA-enhanced hierarchical model on the test set with proper total_step handling.
"""

from huggingface_hub import hf_hub_download
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import numpy as np
from datetime import datetime
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
from peft import AdaLoraConfig, PeftModel  # AdaLoRA-specific imports
import matplotlib.pyplot as plt
import seaborn as sns
import random

# Seed every random number generator in use so repeated runs are comparable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Trade cuDNN autotuning for deterministic kernels on GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(SEED)

class Config:
    """Evaluation settings; entries are overwritten from the saved model config."""

    # --- encoder and input geometry ---
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'
    max_length = 128          # token cap handed to the tokenizer per sentence
    max_num_sentences = 48    # documents are cut or padded to this many sentences

    # --- model hyper-parameters ---
    dropout_rate = 0.4
    gamma = 2.0               # exponent passed to FocalCRF

    # --- evaluation run ---
    batch_size = 4
    total_step = 10000  # Dummy value for evaluation (forwarded to AdaLoraConfig)
    output_dir = "./adalora_evaluation_results"
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-adalora-final"  # Replace with your model repo

import pandas as pd
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login

# Parquet file of each split, relative to the dataset repository root
splits = {
    'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'
}

# Read the test split through the hf:// filesystem, then wrap the frame as a
# Hugging Face Dataset so .map/.filter are available downstream.
test_df = pd.read_parquet(
    "hf://datasets/opennyaiorg/InRhetoricalRoles/" + splits["test"]
)
test_ds = Dataset.from_pandas(test_df)

def get_spans_and_labels(example):
    """Collect annotated text spans and their first label from one example.

    Only the first entry of ``example['annotations']`` is read. A result item
    contributes when its ``value`` has a non-empty ``text`` and a non-empty
    ``labels`` list; the first label of that list is kept.

    Returns:
        ``{'spans': [...], 'labels': [...]}`` with aligned entries; both
        lists are empty when nothing usable is found.
    """
    spans, labels = [], []

    annotations = example.get('annotations')
    if annotations and len(annotations) > 0:
        results = annotations[0].get('result')
        for ann in results or []:
            value = ann.get('value')
            if value and value.get('text') and value.get('labels'):
                spans.append(value['text'])
                labels.append(value['labels'][0])

    return {'spans': spans, 'labels': labels}

def preprocess_single_dataset(dataset, label2id):
    """Turn raw annotated examples into documents with ``text``/``label`` lists.

    Spans and labels are extracted per example, examples without any span
    are dropped, and the two lists are exposed as ``text`` and ``label``.
    ``label2id`` is accepted but not used by this step.
    """
    annotated = dataset.map(get_spans_and_labels).filter(
        lambda ex: len(ex['spans']) > 0
    )
    return annotated.map(lambda ex: {'text': ex['spans'], 'label': ex['labels']})

def tokenize_single_dataset(dataset, tokenizer, label2id):
    """Tokenize every document into fixed-size hierarchical tensors.

    Each document is cut or padded to ``Config.max_num_sentences`` sentences
    and every sentence is padded/truncated to ``Config.max_length`` tokens.
    Padded sentence slots are empty strings labelled with the first key of
    ``label2id``.

    Args:
        dataset: Dataset whose examples carry ``text`` (list of sentences)
            and ``label`` (list of label names).
        tokenizer: Tokenizer called on the list of sentences.
        label2id: Mapping from label name to integer id.

    Returns:
        ``dataset.map`` result with ``input_ids``, ``attention_mask`` and
        ``labels`` added per document.

    NOTE(review): padded sentences are tokenized like real ones. If the
    tokenizer adds special tokens to an empty string, their first attention
    value is 1 and the ``attention_mask[:, :, 0] > 0`` sentence mask used
    downstream would count padding as real sentences - confirm.
    """
    # Label assigned to padded sentence slots.
    pad_label = next(iter(label2id))

    def tokenize_document(example):
        limit = Config.max_num_sentences
        sentences = list(example['text'][:limit])
        labels = list(example['label'][:limit])
        pad_len = limit - len(sentences)
        sentences += [""] * pad_len
        labels += [pad_label] * pad_len

        # One batched call yields the (sentences, max_length) tensors
        # directly instead of tokenizing and stacking sentence by sentence.
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )
        label_ids = torch.tensor([label2id[name] for name in labels])

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": label_ids
        }

    return dataset.map(tokenize_document)

class PositionalEncoding(nn.Module):
    """Learned position embeddings added to a sequence of sentence vectors."""

    def __init__(self, d_model, max_len):
        super().__init__()
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # Positions 0..S-1 along dim 1; the embedding is broadcast over dim 0.
        seq_len = x.size(1)
        index = torch.arange(seq_len, device=x.device)
        return x + self.position_emb(index).unsqueeze(0)

class TransformerContextLayer(nn.Module):
    """Two stacked Transformer encoder layers applied over a sequence."""

    def __init__(self, d_model, nhead=8, dim_feedforward=1024, dropout=0.2):
        super().__init__()
        layer_options = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation='gelu',
        )
        # The template layer stays registered as an attribute of its own.
        self.encoder_layer = nn.TransformerEncoderLayer(**layer_options)
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer,
            num_layers=2
        )

    def forward(self, x):
        contextualised = self.transformer_encoder(x)
        return contextualised

class EmissionLayer(nn.Module):
    """Two-layer emission head plus a linear shortcut from input to scores."""

    def __init__(self, input_size, num_labels, dropout=0.3):
        super().__init__()
        expanded = input_size * 2
        self.linear1 = nn.Linear(input_size, expanded)
        self.linear2 = nn.Linear(expanded, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(expanded)
        self.residual_proj = nn.Linear(input_size, num_labels)

    def forward(self, x):
        # Shortcut: the raw input projected straight to label scores.
        shortcut = self.residual_proj(x)
        # Main path: expand -> normalise -> GELU -> dropout -> label scores.
        hidden = self.dropout(self.gelu(self.layer_norm(self.linear1(x))))
        return self.linear2(hidden) + shortcut

class FocalCRF(nn.Module):
    """CRF whose per-sequence log-likelihood is turned into a focal loss."""

    def __init__(self, num_tags, gamma):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal loss over the sequences of the batch.

        Args:
            emissions: Emission scores passed to the wrapped CRF.
            tags: Gold tag ids, indexed as (batch, sequence).
            mask: Boolean mask marking the valid positions of ``tags``.
            class_weights: Optional 1-D tensor of per-tag weights. Each
                sequence's loss is scaled by the mean weight of its valid
                tags.

        Returns:
            Scalar tensor: focal loss averaged over sequences.
        """
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood
        if class_weights is not None:
            weights_per_tag = class_weights[tags]
            # Zero out padded positions: the sum is divided by the number of
            # valid positions, so padding must not contribute to it.
            valid = mask.to(weights_per_tag.dtype)
            valid_counts = valid.sum(dim=1).clamp(min=1)
            weights_per_sequence = (weights_per_tag * valid).sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence
        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Return the wrapped CRF's decoded tag sequences for ``emissions``."""
        return self.crf.decode(emissions, mask=mask)

class ImprovedHSLNModel(nn.Module):
    """Hierarchical sentence labeller built on an AdaLoRA-wrapped encoder.

    Sentences are encoded one by one, their first-token states are projected
    and contextualised across the document, and a focal CRF scores the
    resulting emissions.
    """

    def __init__(self, num_labels, model_config, class_weights=None):
        super().__init__()
        self.class_weights = class_weights

        # Mirror the saved hyper-parameters onto the shared Config class.
        for name, setting in model_config.items():
            setattr(Config, name, setting)

        encoder = AutoModel.from_pretrained(Config.bert_model_name)

        # AdaLoRA settings come from model_config, with fallbacks; total_step
        # is taken from Config (a dummy value for evaluation).
        adapter_config = AdaLoraConfig(
            init_r=model_config.get('adalora_init_r', 64),
            target_r=model_config.get('adalora_target_r', 4608),
            beta1=model_config.get('adalora_beta1', 0.85),
            beta2=model_config.get('adalora_beta2', 0.85),
            tinit=model_config.get('adalora_tinit', 200),
            tfinal=model_config.get('adalora_tfinal', 1000),
            deltaT=model_config.get('adalora_deltaT', 10),
            lora_alpha=model_config.get('lora_alpha', 64),
            lora_dropout=model_config.get('lora_dropout', 0.15),
            target_modules=model_config.get('lora_target_modules', ["query", "value", "key"]),
            bias="none",
            total_step=Config.total_step
        )
        self.bert = PeftModel(encoder, adapter_config)

        hidden = self.bert.config.hidden_size

        # Sentence feature extraction
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden)
        self.sent_projection = nn.Linear(hidden, hidden)

        # Document-level context
        self.position_enc = PositionalEncoding(
            hidden,
            max_len=Config.max_num_sentences
        )
        self.context_encoder = TransformerContextLayer(d_model=hidden)

        # Label scores and structured loss
        self.emission = EmissionLayer(input_size=hidden, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return emissions, plus the CRF loss when ``labels`` is given.

        ``input_ids`` is unpacked as (batch, sentences, tokens) and
        ``attention_mask`` is expected in the same layout.
        """
        n_docs, n_sents, n_tokens = input_ids.shape

        # Flatten documents so the encoder receives one sentence per row.
        token_states = self.bert(
            input_ids=input_ids.view(-1, n_tokens),
            attention_mask=attention_mask.view(-1, n_tokens)
        ).last_hidden_state

        # First-token state -> projection -> layer norm -> dropout.
        first_token = token_states[:, 0, :]
        sentence_vecs = self.sent_dropout(
            self.sent_layer_norm(self.sent_projection(first_token))
        )
        sentence_vecs = sentence_vecs.view(n_docs, n_sents, -1)

        # Add sentence positions, then contextualise across the document.
        contextual = self.context_encoder(self.position_enc(sentence_vecs))
        emissions = self.emission(contextual)

        if labels is None:
            return {"emissions": emissions}

        # A sentence is treated as valid when its first token is attended to.
        sentence_mask = attention_mask[:, :, 0] > 0
        loss = self.crf(
            emissions,
            labels,
            mask=sentence_mask,
            class_weights=self.class_weights
        )
        return {"loss": loss, "emissions": emissions}

# The base class is spelled out in full: the bare name ``Dataset`` is imported
# from both ``datasets`` and ``torch.utils.data`` in this script, so which
# class it refers to depends on import order.
class HierarchicalDataset(torch.utils.data.Dataset):
    """Map-style torch dataset exposing the three fields the model needs."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Values are returned as stored; collate_fn converts them to tensors.
        item = self.dataset[idx]
        return {
            "input_ids": item["input_ids"],
            "attention_mask": item["attention_mask"],
            "labels": item["labels"]
        }

def collate_fn(batch):
    """Stack the documents of a batch, converting non-tensors on the way.

    Values that are not already ``torch`` tensors go through
    ``torch.tensor`` before the fields are stacked along a new dim 0.
    """
    def as_tensor(value):
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value)

    return {
        field: torch.stack([as_tensor(sample[field]) for sample in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }

def create_single_loader(dataset):
    """Build a non-shuffling DataLoader over ``dataset``.

    The dataset is wrapped in ``HierarchicalDataset``; batches of
    ``Config.batch_size`` documents are assembled with ``collate_fn``.
    """
    wrapped = HierarchicalDataset(dataset)
    loader_options = {
        "batch_size": Config.batch_size,
        "shuffle": False,
        "collate_fn": collate_fn,
    }
    return DataLoader(wrapped, **loader_options)

def plot_confusion_matrix(cm, labels, output_path, title='Confusion Matrix'):
    """Draw ``cm`` as an annotated heatmap and save it to ``output_path``.

    Args:
        cm: Confusion matrix of integer counts.
        labels: Tick labels used on both axes.
        output_path: Destination image file.
        title: Figure title.
    """
    plt.figure(figsize=(10, 8))
    heatmap_axes = sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=True,
        xticklabels=labels,
        yticklabels=labels,
    )

    # Title and axis captions in bold
    heatmap_axes.set_title(title, fontsize=14, fontweight='bold')
    heatmap_axes.set_xlabel('Predicted', fontsize=12, fontweight='bold')
    heatmap_axes.set_ylabel('Actual', fontsize=12, fontweight='bold')

    # Bold tick labels; x labels are slanted
    plt.xticks(fontweight='bold', rotation=45, ha='right')
    plt.yticks(fontweight='bold')

    # Cell annotations: bold, with a trailing space appended to the text
    for cell_text in heatmap_axes.texts:
        cell_text.set_text(cell_text.get_text() + " ")
        cell_text.set_fontweight('bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def convert_to_serializable(obj):
    """Recursively convert NumPy values into JSON-friendly Python objects.

    NumPy booleans, integers and floats become ``bool``/``int``/``float``,
    arrays become nested lists, and dicts, lists and tuples are rebuilt with
    their contents converted (a tuple stays a tuple). Any other value is
    returned unchanged.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        # Without this branch NumPy scalars nested in a tuple would reach
        # json.dump unconverted.
        return tuple(convert_to_serializable(item) for item in obj)
    return obj

def evaluate_metrics(model, dataloader, device, label_list):
    """Evaluate ``model`` on ``dataloader`` and collect sentence-level metrics.

    Args:
        model: Callable model returning ``{"emissions": ...}`` and exposing
            ``crf.decode(emissions, mask=...)``.
        dataloader: Iterable of batches holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        device: Device the batch tensors are moved to.
        label_list: Label names ordered by label id.

    Returns:
        Dict with macro/weighted F1, precision and recall, accuracy,
        ``per_label_metrics``, the confusion matrix as nested lists, latency
        figures and the number of evaluated documents. When the dataloader
        yields nothing, the same keys are returned with zero values.

    Raises:
        Any exception raised during evaluation is printed and re-raised.
    """
    try:
        model.eval()
        all_preds, all_labels = [], []
        total_time = 0
        n_docs = 0
        n_sentences = 0
        eval_start = time.time()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                # A sentence is counted when its first token is attended to.
                mask = attention_mask[:, :, 0] > 0

                start = time.time()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                end = time.time()

                preds = model.crf.decode(outputs["emissions"], mask=mask)

                valid_counts = mask.sum(dim=1).tolist()
                for doc_preds, doc_labels, seq_len in zip(preds, labels, valid_counts):
                    all_preds.extend(doc_preds[:seq_len])
                    all_labels.extend(doc_labels[:seq_len].tolist())  # native ints

                total_time += (end - start)
                n_docs += input_ids.shape[0]
                n_sentences += sum(valid_counts)

        eval_time = time.time() - eval_start

        # Nothing to score: return well-defined zeros instead of handing
        # empty inputs to the metric functions.
        if not all_labels:
            print("WARNING: No samples processed during evaluation!")
            n_labels = len(label_list)
            return {
                "macro_f1": 0.0,
                "macro_precision": 0.0,
                "macro_recall": 0.0,
                "weighted_f1": 0.0,
                "weighted_precision": 0.0,
                "weighted_recall": 0.0,
                "accuracy": 0.0,
                "per_label_metrics": {
                    label: {'f1': 0.0, 'precision': 0.0, 'recall': 0.0}
                    for label in label_list
                },
                "confusion_matrix": [[0] * n_labels for _ in range(n_labels)],
                "latency_ms_per_doc": 0,
                "latency_ms_per_sentence": 0,
                "eval_time_seconds": eval_time,
                "num_samples": 0
            }

        labels_for_report = list(range(len(label_list)))

        report = classification_report(
            all_labels, all_preds,
            labels=labels_for_report,
            target_names=label_list,
            output_dict=True,
            zero_division=0
        )

        cm = confusion_matrix(all_labels, all_preds, labels=labels_for_report)

        macro_avg = report['macro avg']
        weighted_avg = report['weighted avg']
        accuracy = accuracy_score(all_labels, all_preds)

        per_label_metrics = {
            label: {
                'f1': report[label]['f1-score'],
                'precision': report[label]['precision'],
                'recall': report[label]['recall']
            }
            for label in label_list
        }

        latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
        latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

        return {
            "macro_f1": macro_avg['f1-score'],
            "macro_precision": macro_avg['precision'],
            "macro_recall": macro_avg['recall'],
            "weighted_f1": weighted_avg['f1-score'],
            "weighted_precision": weighted_avg['precision'],
            "weighted_recall": weighted_avg['recall'],
            "accuracy": accuracy,
            "per_label_metrics": per_label_metrics,
            "confusion_matrix": cm.tolist(),
            "latency_ms_per_doc": latency_doc,
            "latency_ms_per_sentence": latency_sent,
            "eval_time_seconds": eval_time,
            "num_samples": n_docs
        }

    except Exception as e:
        # Report the failure here, then let the caller decide how to handle it.
        print(f"Evaluation failed: {str(e)}")
        raise

def evaluate_test_set():
    """Download the checkpoint, evaluate it on the test split and persist the results.

    Steps: fetch ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, copy the saved ``model_config`` onto ``Config``,
    preprocess and tokenize the module-level ``test_ds``, rebuild
    ``ImprovedHSLNModel``, load its weights, compute metrics, write them to
    ``test_metrics.json`` and save a confusion-matrix plot (both under
    ``Config.output_dir``).

    Returns:
        The metrics dictionary produced by ``evaluate_metrics``, or ``None``
        when any step fails. On failure the message and full traceback are
        written (best effort) to ``error_log.txt`` in ``Config.output_dir``.
    """
    import traceback

    try:
        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print(f"ADALORA TEST SET EVALUATION")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"Seed: {SEED}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts
        print("Downloading model artifacts...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # Load configuration
        with open(config_path, 'r') as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        id2label = saved_config['id2label']
        model_config = saved_config['model_config']

        # JSON object keys are always strings; convert them back to integer ids
        id2label = {int(k): v for k, v in id2label.items()}

        # Label names ordered by id (requires ids to be exactly 0..N-1)
        label_list = [id2label[i] for i in range(len(id2label))]

        # Update Config with model parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for AdaLoRA model")
        print(f"Number of labels: {len(label_list)}")
        print(f"AdaLoRA target_r: {model_config.get('adalora_target_r', 4608)}")

        # 2. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

        # 3. Preprocess test dataset
        print("\nPreprocessing test dataset...")
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset
        print("Tokenizing test dataset...")
        test_tokenized = tokenize_single_dataset(test_hier, tokenizer, label2id)

        # 5. Create data loader
        test_loader = create_single_loader(test_tokenized)
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize AdaLoRA-enhanced model
        print("\nInitializing AdaLoRA model...")
        model = ImprovedHSLNModel(
            num_labels=len(label2id),
            model_config=model_config
        ).to(device)

        # 7. Load model weights
        # NOTE(review): torch.load unpickles the downloaded file. Only point
        # hf_repo_id at a trusted repository, or pass weights_only=True where
        # the installed torch version supports it.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("AdaLoRA model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list)

        # 9. Save test metrics
        metrics_path = os.path.join(Config.output_dir, "test_metrics.json")
        with open(metrics_path, 'w') as f:
            # Convert metrics to serializable format
            serializable_metrics = convert_to_serializable(test_metrics)
            json.dump(serializable_metrics, f, indent=2)

        # 10. Plot confusion matrix
        cm_path = os.path.join(Config.output_dir, "confusion_matrix.png")
        cm = np.array(test_metrics["confusion_matrix"])
        plot_confusion_matrix(cm, label_list, cm_path, "AdaLoRA Model Confusion Matrix")

        print(f"\n{'='*30} ADALORA TEST RESULTS {'='*30}")
        print(f"Macro F1:     {test_metrics['macro_f1']:.4f}")
        print(f"Macro Prec:   {test_metrics['macro_precision']:.4f}")
        print(f"Macro Recall: {test_metrics['macro_recall']:.4f}")
        print(f"Weighted F1:  {test_metrics['weighted_f1']:.4f}")
        print(f"Accuracy:     {test_metrics['accuracy']:.4f}")
        print(f"Latency:      {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Evaluation time: {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Metrics saved to: {metrics_path}")
        print(f"Confusion matrix saved to: {cm_path}")

        print("\nPer-class Metrics:")
        for label, metrics in test_metrics['per_label_metrics'].items():
            print(f"  {label}:")
            print(f"    F1:       {metrics['f1']:.4f}")
            print(f"    Precision: {metrics['precision']:.4f}")
            print(f"    Recall:    {metrics['recall']:.4f}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        print(f"\n{'!'*50}")
        print("ADALORA EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        # Best-effort error log. The directory may not exist if the failure
        # happened very early, and a failing log write must not replace the
        # original error with a new exception.
        try:
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "error_log.txt"), "w") as f:
                f.write(f"Evaluation error at {datetime.now()}\n")
                f.write(str(e))
                # The message alone rarely identifies the failing step
                f.write("\n\n")
                f.write(traceback.format_exc())
        except OSError as log_err:
            print(f"Could not write error log: {log_err}")
        return None

# Entry point: run the evaluation only when this cell/file is executed
# directly. The returned metrics dict (or None on failure) is discarded here.
if __name__ == "__main__":
    evaluate_test_set()

# LegalBERTHSLN-FSSA

In [None]:
# -*- coding: utf-8 -*-
"""eval_fssa_model.py

Evaluation script for FSSA-based hierarchical legal model.
"""

from huggingface_hub import notebook_login, hf_hub_download
import torch.nn.functional as F
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import numpy as np
import math
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
import random
import matplotlib.pyplot as plt
import seaborn as sns

# Set seeds for reproducibility
# Seeds Python's `random`, NumPy and torch (CPU) unconditionally; the CUDA
# seeds and the cuDNN determinism flags are only touched when a GPU is present.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels, with autotuning (benchmark) disabled
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class Config:
    # Evaluation-time settings for the FSSA model, used as a plain namespace
    # of class attributes (never instantiated in this script).
    # evaluate_test_set() copies every key of the checkpoint's saved
    # `model_config` onto this class with setattr, so the values below are
    # fallbacks for keys the checkpoint does not provide.
    # Defaults - will be overridden by saved config
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'
    lstm_hidden_size = 200
    context_hidden_size = 200
    max_num_sentences = 32  # documents are truncated/padded to this many sentences
    max_length = 64  # tokenizer max_length per sentence
    dropout_rate = 0.4
    gamma = 2.0  # focal-loss exponent used by FocalCRF

    # FSSA parameters
    fssa_linear_rank = 1
    fssa_emb_rank = 1
    fssa_linear_sparsity = 0.99
    fssa_emb_sparsity = 0.995
    fssa_block_size = 32

    # Size reduction parameters
    context_intermediate_size = 380  # feed-forward width of the context transformer
    emission_hidden_size = 64  # hidden width of the emission MLP

    # Evaluation settings
    batch_size = 4
    output_dir = "./fssa_evaluation_results"
    # Placeholder: replace the leading text with a real Hugging Face user id
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-improved-fssa"

# =============================================
# Model Components (copied from training script)
# =============================================
class PositionalEncoding(nn.Module):
    """Learned positional embeddings encoding sentence order within a document.

    Args:
        d_model: Size of each sentence embedding.
        max_len: Number of positions in the embedding table. When omitted,
            ``Config.max_num_sentences`` is read at construction time so that
            values loaded into ``Config`` after this class was defined (for
            example from a checkpoint's saved config) are honoured.
    """
    def __init__(self, d_model, max_len=None):
        super().__init__()
        # Resolve late: a `max_len=Config.max_num_sentences` default would be
        # frozen when the class statement runs and ignore later Config updates.
        if max_len is None:
            max_len = Config.max_num_sentences
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add the embedding of positions ``0..x.size(1)-1`` to ``x``.

        ``x.size(1)`` must not exceed the table size chosen at construction.
        """
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return x + self.position_emb(positions)

class TransformerContextLayer(nn.Module):
    """Single-layer Transformer encoder that contextualises sentence embeddings.

    Args:
        d_model: Size of each sentence embedding.
        nhead: Number of attention heads.
        dim_feedforward: Width of the encoder's feed-forward sub-layer. When
            omitted, ``Config.context_intermediate_size`` is read at
            construction time so values loaded into ``Config`` after this
            class was defined are honoured.
        dropout: Dropout probability inside the encoder layer.
    """
    def __init__(self, d_model, nhead=4, dim_feedforward=None, dropout=0.1):
        super().__init__()
        # Resolve late: a default taken from Config in the signature would be
        # frozen when the class statement runs.
        if dim_feedforward is None:
            dim_feedforward = Config.context_intermediate_size
        # Kept as an attribute (in addition to the copy held by the encoder
        # below) so the module's state_dict keys stay the same.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

    def forward(self, x):
        """Run ``x`` (batch-first) through the encoder and return the result."""
        return self.transformer_encoder(x)

class EmissionLayer(nn.Module):
    """Two-layer MLP that maps context vectors to per-label emission scores.

    The hidden width is read from ``Config.emission_hidden_size`` when the
    layer is constructed.
    """
    def __init__(self, input_size, num_labels, dropout=0.2):
        super().__init__()
        hidden_size = Config.emission_hidden_size
        stages = [
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),
        ]
        # Sequential indices 0..3 are part of the checkpoint's parameter names
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return the emission scores for ``x``."""
        scores = self.mlp(x)
        return scores

class FocalCRF(nn.Module):
    """CRF whose sequence log-likelihood is turned into a focal-style loss.

    Args:
        num_tags: Number of labels handled by the CRF.
        gamma: Focal exponent. When omitted, ``Config.gamma`` is read at
            construction time.
    """
    def __init__(self, num_tags, gamma=None):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        # Resolve late so values loaded into Config after class definition count
        self.gamma = Config.gamma if gamma is None else gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal loss over the batch.

        Args:
            emissions: Emission scores passed to the CRF.
            tags: Gold label ids, indexable as ``class_weights[tags]``.
            mask: Boolean mask of valid positions (``None`` means all valid).
            class_weights: Optional per-label weights; when given, each
                sequence's loss is scaled by the mean weight of its *valid*
                positions.
        """
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            weights_per_tag = class_weights[tags]
            if mask is not None:
                # Padded positions carry a placeholder tag; exclude their
                # weights so they do not skew the per-sequence average.
                valid = mask.to(weights_per_tag.dtype)
                weights_per_tag = weights_per_tag * valid
                # Guard against division by zero for fully masked sequences
                valid_counts = valid.sum(dim=1).clamp(min=1)
            else:
                valid_counts = tags.size(1)
            weights_per_sequence = weights_per_tag.sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Return the CRF's decoded tag sequences for ``emissions`` under ``mask``."""
        return self.crf.decode(emissions, mask=mask)

class FSSALayer(nn.Module):
    """Factorized Structured Sparse Adaptation wrapper around a linear layer.

    The wrapped layer is frozen and a low-rank update ``B @ A``, restricted to
    a random set of square blocks, is added to its output.

    Args:
        original_layer: Layer exposing ``in_features`` / ``out_features``.
        rank: Rank of the factorized update (default ``Config.fssa_linear_rank``).
        sparsity: Fraction of blocks left inactive
            (default ``Config.fssa_linear_sparsity``).
        block_size: Side length of a mask block
            (default ``Config.fssa_block_size``).

    Omitted arguments are read from ``Config`` at construction time, so values
    loaded into ``Config`` after this class was defined are honoured.
    """
    def __init__(self, original_layer, rank=None, sparsity=None, block_size=None):
        super().__init__()
        self.original_layer = original_layer
        self.rank = Config.fssa_linear_rank if rank is None else rank
        self.sparsity = Config.fssa_linear_sparsity if sparsity is None else sparsity
        self.block_size = Config.fssa_block_size if block_size is None else block_size

        # Freeze original parameters
        for param in self.original_layer.parameters():
            param.requires_grad = False

        in_features = original_layer.in_features
        out_features = original_layer.out_features

        # Factorized adaptation parameters
        self.A = nn.Parameter(torch.zeros(self.rank, in_features))
        self.B = nn.Parameter(torch.zeros(out_features, self.rank))

        # Structured sparsity mask. Registered as a NON-persistent buffer: it
        # follows the module across devices/dtypes (no per-forward copy) while
        # staying out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "mask",
            self.create_sparsity_mask(out_features, in_features),
            persistent=False,
        )
        self.init_parameters()

    def create_sparsity_mask(self, rows, cols):
        """Return a ``rows x cols`` 0/1 mask with randomly chosen active blocks.

        The active blocks are drawn with Python's global ``random`` module, so
        the result depends on that RNG's state at call time.
        """
        row_blocks = (rows + self.block_size - 1) // self.block_size
        col_blocks = (cols + self.block_size - 1) // self.block_size
        num_blocks = row_blocks * col_blocks
        num_active = int(num_blocks * (1 - self.sparsity))
        active_blocks = random.sample(range(num_blocks), num_active)

        mask = torch.zeros(rows, cols)
        for block_idx in active_blocks:
            i = block_idx // col_blocks
            j = block_idx % col_blocks
            row_start = i * self.block_size
            col_start = j * self.block_size
            # Edge blocks are clipped to the matrix bounds
            row_end = min(row_start + self.block_size, rows)
            col_end = min(col_start + self.block_size, cols)
            mask[row_start:row_end, col_start:col_end] = 1

        return mask

    def init_parameters(self):
        """Initialise both factors with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.B, a=math.sqrt(5))

    def forward(self, x):
        """Return the frozen layer's output plus the masked low-rank update."""
        base_output = self.original_layer(x)
        adapted = (self.B @ self.A) * self.mask
        delta_output = F.linear(x, adapted)
        return base_output + delta_output

class FSSAEmbedding(nn.Module):
    """FSSA wrapper around an embedding table.

    The wrapped embedding is frozen and a sparse low-rank update
    ``(U * mask) @ V`` is added to the looked-up vectors.

    Args:
        original_embedding: Module exposing ``num_embeddings`` /
            ``embedding_dim``.
        rank: Rank of the factorized update (default ``Config.fssa_emb_rank``).
        sparsity: Probability that an entry of ``U`` is masked out
            (default ``Config.fssa_emb_sparsity``).

    Omitted arguments are read from ``Config`` at construction time, so values
    loaded into ``Config`` after this class was defined are honoured.
    """
    def __init__(self, original_embedding, rank=None, sparsity=None):
        super().__init__()
        self.original_embedding = original_embedding
        self.rank = Config.fssa_emb_rank if rank is None else rank
        self.sparsity = Config.fssa_emb_sparsity if sparsity is None else sparsity

        # Freeze original parameters
        for param in self.original_embedding.parameters():
            param.requires_grad = False

        num_embeddings = original_embedding.num_embeddings
        embedding_dim = original_embedding.embedding_dim

        # Factorized adaptation
        self.U = nn.Parameter(torch.zeros(num_embeddings, self.rank))
        self.V = nn.Parameter(torch.zeros(self.rank, embedding_dim))
        # Random 0/1 mask drawn from torch's global RNG. Registered as a
        # NON-persistent buffer: it follows the module across devices (no
        # per-forward copy) while staying out of state_dict.
        self.register_buffer(
            "mask",
            (torch.rand(num_embeddings, self.rank) > self.sparsity).float(),
            persistent=False,
        )
        self.init_parameters()

    def init_parameters(self):
        """Initialise both factors from N(0, 0.02)."""
        nn.init.normal_(self.U, mean=0, std=0.02)
        nn.init.normal_(self.V, mean=0, std=0.02)

    def forward(self, input_ids):
        """Return the frozen embeddings plus the masked low-rank update."""
        base_embeds = self.original_embedding(input_ids)
        # Gather the needed rows of the masked factor first, then project.
        # Mathematically equal to indexing (U * mask) @ V, but avoids
        # materialising a full num_embeddings x embedding_dim matrix per call.
        gathered = F.embedding(input_ids, self.U * self.mask)
        delta_embeds = gathered @ self.V
        return base_embeds + delta_embeds

def apply_fssa(model):
    """Wrap the word embeddings and encoder projections of ``model`` with FSSA.

    The model is modified in place and also returned. The wrapping order
    (embeddings first, then per encoder layer: query, key, value, attention
    output, intermediate, output) is kept fixed because the wrappers draw
    random masks when they are constructed.
    """
    # Apply to embeddings
    if hasattr(model, 'embeddings'):
        embeddings = model.embeddings
        embeddings.word_embeddings = FSSAEmbedding(embeddings.word_embeddings)

    # Apply to transformer layers
    for block in model.encoder.layer:
        self_attention = block.attention.self
        targets = (
            (self_attention, 'query'),
            (self_attention, 'key'),
            (self_attention, 'value'),
            (block.attention.output, 'dense'),
            (block.intermediate, 'dense'),
            (block.output, 'dense'),
        )
        for parent, attr_name in targets:
            wrapped = FSSALayer(getattr(parent, attr_name))
            setattr(parent, attr_name, wrapped)

    return model

class ImprovedHSLNModel(nn.Module):
    """Hierarchical sentence-labelling model with an FSSA-adapted BERT encoder.

    Pipeline: BERT per sentence -> first-token vector -> layer norm + dropout
    -> positional encoding -> context transformer -> emission MLP -> CRF.

    Args:
        num_labels: Number of output labels.
        model_config: Mapping providing ``bert_model_name``, ``dropout_rate``
            and ``gamma``.
        class_weights: Optional per-label weights forwarded to the CRF loss.
    """
    def __init__(self, num_labels, model_config, class_weights=None):
        super().__init__()
        self.class_weights = class_weights

        # Base BERT, wrapped in place with the FSSA adapters
        backbone = AutoModel.from_pretrained(model_config['bert_model_name'])
        self.bert = apply_fssa(backbone)

        hidden_size = self.bert.config.hidden_size

        # Sentence-level head
        self.sent_dropout = nn.Dropout(model_config['dropout_rate'])
        self.sent_layer_norm = nn.LayerNorm(hidden_size)
        self.position_enc = PositionalEncoding(hidden_size)
        self.context_encoder = TransformerContextLayer(d_model=hidden_size)
        self.emission = EmissionLayer(input_size=hidden_size, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=model_config['gamma'])

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute emissions (and the CRF loss when ``labels`` is given).

        Args:
            input_ids: 3-D tensor unpacked as (documents, sentences, tokens).
            attention_mask: Tensor viewed with the same layout as ``input_ids``.
            labels: Optional gold label ids per sentence.

        Returns:
            ``{"emissions": ...}``, plus ``"loss"`` when ``labels`` is given.
        """
        n_docs, n_sents, n_tokens = input_ids.shape

        # Flatten documents so every sentence is one BERT sequence
        token_states = self.bert(
            input_ids=input_ids.view(-1, n_tokens),
            attention_mask=attention_mask.view(-1, n_tokens)
        ).last_hidden_state

        # Sentence embeddings (CLS token)
        sentence_vecs = token_states[:, 0, :]
        sentence_vecs = self.sent_dropout(self.sent_layer_norm(sentence_vecs))
        sentence_vecs = sentence_vecs.view(n_docs, n_sents, -1)

        # Context modeling, then emissions
        contextual = self.context_encoder(self.position_enc(sentence_vecs))
        emissions = self.emission(contextual)

        if labels is None:
            return {"emissions": emissions}

        # Sentence-level mask: a sentence counts when its first token is attended
        sentence_mask = attention_mask[:, :, 0] > 0
        loss = self.crf(
            emissions,
            labels,
            mask=sentence_mask,
            class_weights=self.class_weights
        )
        return {"loss": loss, "emissions": emissions}

# ===================================
# Dataset Preparation and Evaluation
# ===================================
def get_spans_and_labels(example):
    """Collect span texts and their first label from an example's first annotation.

    Only entries of ``example['annotations'][0]['result']`` whose ``value``
    has a non-empty ``text`` and a non-empty ``labels`` list are kept.

    Returns:
        dict with parallel lists under ``'spans'`` and ``'labels'`` (both
        empty when the example has no usable annotations).
    """
    collected = {'spans': [], 'labels': []}

    annotations = example.get('annotations')
    if not annotations or len(annotations) == 0:
        return collected

    results = annotations[0].get('result')
    if not results:
        return collected

    for annotation in results:
        value = annotation.get('value')
        if not value:
            continue
        if not value.get('text') or not value.get('labels'):
            continue
        collected['spans'].append(value['text'])
        # Only the first label of each span is used
        collected['labels'].append(value['labels'][0])

    return collected

def preprocess_single_dataset(dataset, label2id):
    """Extract spans/labels, drop documents without spans, and add text/label columns.

    ``dataset`` must provide ``map`` and ``filter``; ``label2id`` is accepted
    for signature compatibility but is not used here.
    """
    def has_spans(example):
        return len(example['spans']) > 0

    def to_text_and_label(example):
        return {'text': example['spans'], 'label': example['labels']}

    with_spans = dataset.map(get_spans_and_labels)
    non_empty = with_spans.filter(has_spans)
    return non_empty.map(to_text_and_label)

def tokenize_single_dataset(dataset, tokenizer, label2id):
    """Tokenize every document into fixed-size hierarchical tensors.

    Each document is truncated or padded to ``Config.max_num_sentences``
    sentences. Padding sentences are empty strings labelled with the first
    key of ``label2id``. Every sentence is tokenized to exactly
    ``Config.max_length`` tokens.

    Args:
        dataset: Object with a ``map`` method whose examples have ``'text'``
            (list of sentences) and ``'label'`` (list of label names).
        tokenizer: Callable tokenizer accepting a list of strings.
        label2id: Non-empty mapping from label name to integer id.

    Returns:
        The result of ``dataset.map`` with ``input_ids``, ``attention_mask``
        and ``labels`` added per document.

    Raises:
        ValueError: If ``label2id`` is empty.
        KeyError: (while mapping) if a label is missing from ``label2id``.
    """
    if not label2id:
        raise ValueError("label2id must not be empty")
    # Hoisted: the padding label is the same for every document
    pad_label = next(iter(label2id))

    # NOTE(review): padding sentences are empty strings. A BERT-style tokenizer
    # still emits special tokens for "", so their first attention-mask entry is
    # presumably 1 and a downstream `attention_mask[:, :, 0] > 0` sentence mask
    # would not exclude them - confirm.
    def tokenize_document(example):
        max_sentences = Config.max_num_sentences
        sentences = list(example['text'][:max_sentences])
        labels = list(example['label'][:max_sentences])
        pad_len = max_sentences - len(sentences)
        sentences.extend([""] * pad_len)
        labels.extend([pad_label] * pad_len)

        # One batched call instead of one tokenizer call per sentence; with
        # padding="max_length" every row has the same length either way.
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )
        label_ids = torch.tensor([label2id[l] for l in labels])

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": label_ids
        }

    return dataset.map(tokenize_document)

class HierarchicalDataset(Dataset):
    """Thin wrapper exposing only the model inputs of an indexable dataset."""

    # Columns passed through to the collate function
    _FIELDS = ("input_ids", "attention_mask", "labels")

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the three model-input fields of row ``idx`` as a dict."""
        row = self.dataset[idx]
        return {field: row[field] for field in self._FIELDS}

def collate_fn(batch):
    """Stack the per-document fields of ``batch`` into batched tensors.

    Values that are not already tensors (for example nested lists) are first
    converted with ``torch.tensor``.
    """
    def stack_field(key):
        tensors = []
        for sample in batch:
            value = sample[key]
            if not isinstance(value, torch.Tensor):
                value = torch.tensor(value)
            tensors.append(value)
        return torch.stack(tensors)

    return {
        key: stack_field(key)
        for key in ("input_ids", "attention_mask", "labels")
    }

def create_single_loader(dataset):
    """Build an in-order (unshuffled) DataLoader over ``dataset``.

    The batch size is read from ``Config.batch_size`` at call time.
    """
    wrapped = HierarchicalDataset(dataset)
    loader_options = {
        "batch_size": Config.batch_size,
        "shuffle": False,
        "collate_fn": collate_fn,
    }
    return DataLoader(wrapped, **loader_options)

def plot_confusion_matrix(cm, labels, output_path, title="Confusion Matrix"):
    """Render ``cm`` as an annotated heatmap and save it to ``output_path``.

    Args:
        cm: 2-D array of integer counts (cells are formatted with ``'d'``).
        labels: Tick labels, one per row/column of ``cm``.
        output_path: Destination image path.
        title: Figure title.

    The figure is always closed, even when plotting or saving raises, so
    repeated calls do not accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            cbar=True,
            annot_kws={"fontsize": 12, "fontweight": "bold"},
            ax=ax
        )

        ax.set_xlabel('Predicted Labels', fontsize=14, fontweight='bold')
        ax.set_ylabel('True Labels', fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold')

        # Set tick labels with bold font
        ax.set_xticklabels(
            labels,
            rotation=45,
            ha='right',
            fontsize=12,
            fontweight='bold'
        )
        ax.set_yticklabels(
            labels,
            rotation=0,
            fontsize=12,
            fontweight='bold'
        )

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def evaluate_metrics(model, dataloader, device, label_list):
    """Run the model over ``dataloader`` and compute classification metrics.

    Args:
        model: Module returning ``{"emissions": ...}`` and exposing
            ``model.crf.decode(emissions, mask=...)``.
        dataloader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        device: Device the batches are moved to.
        label_list: Label names ordered by label id.

    Returns:
        dict with weighted precision/recall, macro/weighted F1, accuracy,
        per-label precision/recall/F1, the confusion matrix (nested lists),
        forward-pass latency per document and per sentence in milliseconds,
        total wall-clock evaluation time and the number of documents.
    """
    model.eval()
    all_preds, all_labels = [], []
    total_time = 0
    n_docs = 0
    n_sentences = 0
    # CUDA kernels run asynchronously; without synchronising, the timer would
    # stop before the forward pass has actually finished on the GPU.
    sync_cuda = torch.device(device).type == "cuda"
    eval_start = time.time()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # NOTE(review): the sentence mask is the first token's attention
            # bit. If padded sentences were tokenized from empty strings (as
            # tokenize_single_dataset in this file does), a BERT-style
            # tokenizer still emits special tokens, so padding may not be
            # excluded here - confirm.
            mask = attention_mask[:, :, 0] > 0

            if sync_cuda:
                torch.cuda.synchronize(device)
            start = time.perf_counter()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            if sync_cuda:
                torch.cuda.synchronize(device)
            end = time.perf_counter()

            emissions = outputs["emissions"]
            preds = model.crf.decode(emissions, mask=mask)

            # One device->host transfer per batch instead of one per document
            seq_lens = mask.sum(dim=1).tolist()
            labels_cpu = labels.cpu()
            for doc_preds, doc_labels, seq_len in zip(preds, labels_cpu, seq_lens):
                all_preds.extend(doc_preds[:seq_len])
                all_labels.extend(doc_labels[:seq_len].tolist())

            total_time += (end - start)
            n_docs += input_ids.shape[0]
            n_sentences += sum(seq_lens)

    eval_end = time.time()
    eval_time = eval_end - eval_start

    labels_for_report = list(range(len(label_list)))
    target_names = label_list

    # Compute classification metrics
    report = classification_report(
        all_labels, all_preds,
        labels=labels_for_report,
        target_names=target_names,
        output_dict=True,
        zero_division=0
    )

    # Compute precision and recall
    precision = precision_score(
        all_labels, all_preds,
        labels=labels_for_report,
        average='weighted',
        zero_division=0
    )
    recall = recall_score(
        all_labels, all_preds,
        labels=labels_for_report,
        average='weighted',
        zero_division=0
    )

    # Compute per-class precision and recall
    per_label_precision = {
        label: report[label]['precision']
        for label in label_list
    }
    per_label_recall = {
        label: report[label]['recall']
        for label in label_list
    }

    macro_f1 = report['macro avg']['f1-score']
    weighted_f1 = report['weighted avg']['f1-score']
    accuracy = accuracy_score(all_labels, all_preds)
    per_label_f1 = {
        label: report[label]['f1-score']
        for label in label_list
    }

    # Compute confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=labels_for_report)

    # Latency covers the model forward pass only (CRF decoding is excluded)
    latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
    latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

    return {
        "precision": precision,
        "recall": recall,
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
        "accuracy": accuracy,
        "per_label_f1": per_label_f1,
        "per_label_precision": per_label_precision,
        "per_label_recall": per_label_recall,
        "confusion_matrix": cm.tolist(),
        "latency_ms_per_doc": latency_doc,
        "latency_ms_per_sentence": latency_sent,
        "eval_time_seconds": eval_time,
        "num_samples": n_docs
    }

def evaluate_test_set():
    """Download the FSSA checkpoint, evaluate it on the test split and persist the results.

    Steps: fetch ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, copy the saved ``model_config`` onto ``Config``,
    load and preprocess the test parquet shard, rebuild ``ImprovedHSLNModel``,
    load its weights, compute metrics, write them to
    ``fssa_test_metrics.json`` and save a confusion-matrix plot (both under
    ``Config.output_dir``).

    Returns:
        The metrics dictionary produced by ``evaluate_metrics``, or ``None``
        when any step fails. On failure the message and full traceback are
        written (best effort) to ``fssa_error_log.txt`` in
        ``Config.output_dir``.
    """
    import traceback

    try:
        # Set seeds again for double safety
        # The FSSA wrappers in this file draw their sparsity masks from the
        # global Python/torch RNGs at construction and the masks are not part
        # of the checkpoint, so the RNG state determines them.
        # NOTE(review): this only reproduces the training-time masks if
        # training seeded and built the model the same way - confirm.
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(SEED)

        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print(f"STARTING FSSA MODEL EVALUATION")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"Random Seed: {SEED}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts
        print("Downloading model artifacts from Hugging Face Hub...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # Load configuration
        with open(config_path, 'r') as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        # JSON object keys are always strings; convert them back to integer ids
        id2label = {int(k): v for k, v in saved_config['id2label'].items()}
        model_config = saved_config['model_config']

        # Create label list sorted by ID (requires ids to be exactly 0..N-1)
        label_list = [id2label[i] for i in range(len(id2label))]

        # Update Config with model parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for model: {Config.bert_model_name}")
        print(f"Number of labels: {len(label_list)}")

        # 2. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

        # 3. Preprocess test dataset
        # Local import on purpose: at module level `Dataset` is rebound to
        # torch.utils.data.Dataset by a later import, which has no from_pandas.
        import pandas as pd
        from datasets import Dataset

        print("\nPreprocessing test dataset...")
        splits = {'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'}
        test_df = pd.read_parquet("hf://datasets/opennyaiorg/InRhetoricalRoles/" + splits["test"])
        test_ds = Dataset.from_pandas(test_df)
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset
        print("Tokenizing test dataset...")
        test_tokenized = tokenize_single_dataset(test_hier, tokenizer, label2id)

        # 5. Create data loader
        test_loader = create_single_loader(test_tokenized)
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize FSSA model
        print("\nInitializing FSSA model...")
        model = ImprovedHSLNModel(
            num_labels=len(label2id),
            model_config=model_config,
            class_weights=None
        ).to(device)

        # 7. Load model weights
        # NOTE(review): torch.load unpickles the downloaded file. Only point
        # hf_repo_id at a trusted repository, or pass weights_only=True where
        # the installed torch version supports it.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("FSSA model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list)

        # 9. Save test metrics and plot confusion matrix
        metrics_path = os.path.join(Config.output_dir, "fssa_test_metrics.json")
        with open(metrics_path, 'w') as f:
            json.dump(test_metrics, f, indent=2)

        # Plot and save confusion matrix
        cm_path = os.path.join(Config.output_dir, "fssa_confusion_matrix.png")
        cm = np.array(test_metrics["confusion_matrix"])
        plot_confusion_matrix(cm, label_list, cm_path)

        print(f"\n{'='*30} FSSA TEST RESULTS {'='*30}")
        print(f"Weighted Precision: {test_metrics['precision']:.4f}")
        print(f"Weighted Recall:    {test_metrics['recall']:.4f}")
        print(f"Weighted F1:        {test_metrics['weighted_f1']:.4f}")
        print(f"Macro F1:           {test_metrics['macro_f1']:.4f}")
        print(f"Accuracy:           {test_metrics['accuracy']:.4f}")
        print(f"Latency:            {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Evaluation time:    {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Confusion matrix saved to: {cm_path}")
        print(f"Metrics saved to:   {metrics_path}")

        print("\nPer-class Metrics:")
        for label in label_list:
            print(f"  {label}:")
            print(f"    Precision: {test_metrics['per_label_precision'][label]:.4f}")
            print(f"    Recall:    {test_metrics['per_label_recall'][label]:.4f}")
            print(f"    F1:        {test_metrics['per_label_f1'][label]:.4f}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        print(f"\n{'!'*50}")
        print("FSSA EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        # Best-effort error log. The directory may not exist if the failure
        # happened very early, and a failing log write must not replace the
        # original error with a new exception.
        try:
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "fssa_error_log.txt"), "w") as f:
                f.write(f"Evaluation error at {datetime.now()}\n")
                f.write(str(e))
                # The message alone rarely identifies the failing step
                f.write("\n\n")
                f.write(traceback.format_exc())
        except OSError as log_err:
            print(f"Could not write error log: {log_err}")
        return None

# Entry point: run the FSSA evaluation only when this cell/file is executed
# directly. The returned metrics dict (or None on failure) is discarded here.
if __name__ == "__main__":
    evaluate_test_set()

# LegalBERTHSLN-QloraFSSA


In [None]:
# -*- coding: utf-8 -*-
"""eval_qlora_fssa.py

Evaluates a pre-trained QLora-FSSA model on the test set and saves metrics to JSON.
"""

from huggingface_hub import hf_hub_download
import torch.nn.functional as F
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import numpy as np
import math
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import matplotlib.pyplot as plt
import seaborn as sns
import random

# Set seed for reproducibility
# Seeds torch (CPU), NumPy and Python's `random` unconditionally; the CUDA
# seeds and the cuDNN determinism flags are only touched when a GPU is present.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels, with autotuning (benchmark) disabled
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class Config:
    # Evaluation-time settings for the QLoRA-FSSA model, used as a plain
    # namespace of class attributes.
    # Default values - will be overwritten by saved config
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'
    max_num_sentences = 32  # documents are truncated/padded to this many sentences
    max_length = 64  # tokenizer max_length per sentence
    dropout_rate = 0.4
    gamma = 2.0
    batch_size = 4
    output_dir = "./qlora_fssa_evaluation_results"
    # Placeholder: replace the leading text with a real Hugging Face user id
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-improved-fssa-qlora"

    # FSSA parameters
    fssa_linear_rank = 1
    fssa_emb_rank = 1
    fssa_linear_sparsity = 0.99
    fssa_emb_sparsity = 0.995
    fssa_block_size = 32

    # QLoRA parameters
    qlora_r = 8
    qlora_alpha = 32
    qlora_dropout = 0.05
    qlora_target_modules = ["query", "key", "value", "dense"]
    # NOTE(review): a torch.dtype cannot come back from JSON as-is; if the saved
    # config carries this key, check what type its consumer expects - confirm.
    qlora_compute_dtype = torch.bfloat16

    # Size reduction parameters
    context_intermediate_size = 380
    emission_hidden_size = 64

# Load datasets
# Runs at import/cell-execution time: reads the test parquet shard of the
# opennyaiorg/InRhetoricalRoles dataset through pandas' hf:// path and wraps
# the DataFrame as a `datasets.Dataset` in the module-level `test_ds`.
splits = {
    'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'
}

test_df = pd.read_parquet("hf://datasets/opennyaiorg/InRhetoricalRoles/" + splits["test"])
test_ds = Dataset.from_pandas(test_df)

def get_spans_and_labels(example):
    """Return span texts and first labels from the example's first annotation.

    Entries of ``example['annotations'][0]['result']`` are kept only when
    their ``value`` has a non-empty ``text`` and a non-empty ``labels`` list.
    The result is ``{'spans': [...], 'labels': [...]}`` with parallel lists.
    """
    spans, labels = [], []

    annotations = example.get('annotations')
    has_annotations = bool(annotations) and len(annotations) > 0
    results = annotations[0].get('result') if has_annotations else None

    # A missing/empty result list simply yields no spans
    for entry in results or ():
        value = entry.get('value')
        if not value:
            continue
        if value.get('text') and value.get('labels'):
            spans.append(value['text'])
            # Only the first label of each span is used
            labels.append(value['labels'][0])

    return {'spans': spans, 'labels': labels}

def preprocess_single_dataset(dataset, label2id):
    """Derive ``text``/``label`` columns and drop documents that have no spans.

    ``dataset`` must provide ``map`` and ``filter``; ``label2id`` is part of
    the signature but is not used by this function.
    """
    def keep_non_empty(example):
        return len(example['spans']) > 0

    def rename_columns(example):
        return {'text': example['spans'], 'label': example['labels']}

    extracted = dataset.map(get_spans_and_labels)
    kept = extracted.filter(keep_non_empty)
    return kept.map(rename_columns)

def tokenize_single_dataset(dataset, tokenizer, label2id):
    """Tokenize every document into fixed-size hierarchical tensors.

    Each document is truncated or padded to ``Config.max_num_sentences``
    sentences and every sentence to ``Config.max_length`` tokens.

    Args:
        dataset: Object providing ``map``; rows carry ``text`` (list of
            sentences) and ``label`` (list of label names).
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` when asked for ``return_tensors="pt"``.
        label2id: Mapping from label name to integer id.

    Returns:
        The mapped dataset with ``input_ids``, ``attention_mask`` and
        ``labels`` added per row.

    NOTE(review): padded sentence slots are encoded from the empty string.
    BERT-style tokenizers presumably still emit special tokens with attention 1
    for those rows, so a sentence mask derived from ``attention_mask[..., 0]``
    would not exclude them -- confirm against the training pipeline.
    """
    # Padded sentence slots need some valid label; the first key of label2id
    # is used as filler.
    pad_label = next(iter(label2id))

    def tokenize_document(example):
        limit = Config.max_num_sentences
        sentences = list(example['text'][:limit])
        labels = list(example['label'][:limit])

        pad_len = limit - len(sentences)
        sentences.extend([""] * pad_len)
        labels.extend([pad_label] * pad_len)

        # One batched call encodes all sentences of the document at once and
        # already returns (num_sentences, max_length) tensors, replacing the
        # per-sentence tokenizer calls and the manual stacking.
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": torch.tensor([label2id[name] for name in labels])
        }

    return dataset.map(tokenize_document)

# ====================== MODEL ARCHITECTURE ======================
class PositionalEncoding(nn.Module):
    """Learned positional embeddings for sentence order.

    Args:
        d_model: Width of the vectors the embeddings are added to.
        max_len: Number of rows in the position table. ``None`` (default)
            resolves to ``Config.max_num_sentences`` when the layer is built,
            so values written into ``Config`` at runtime are honoured. The
            previous default was frozen when the class was defined.
    """
    def __init__(self, d_model, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = Config.max_num_sentences
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add the embedding of each index along dimension 1 to ``x``."""
        # Shape (1, x.size(1)); broadcasts over the leading dimension of x.
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return x + self.position_emb(positions)

class TransformerContextLayer(nn.Module):
    """Transformer-based context modeling with reduced FFN size.

    Args:
        d_model: Width of the input vectors.
        nhead: Number of attention heads.
        dim_feedforward: Feed-forward width. ``None`` (default) resolves to
            ``Config.context_intermediate_size`` when the layer is built, so
            values written into ``Config`` at runtime are honoured. The
            previous default was frozen when the class was defined.
        dropout: Dropout probability inside the encoder layer.
    """
    def __init__(self, d_model, nhead=4, dim_feedforward=None, dropout=0.1):
        super().__init__()
        if dim_feedforward is None:
            dim_feedforward = Config.context_intermediate_size
        # Kept as an attribute on purpose: it registers `encoder_layer.*`
        # entries in the state dict next to `transformer_encoder.layers.0.*`.
        # Dropping it would change the state-dict keys expected when loading
        # existing checkpoints.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

    def forward(self, x):
        """Run the single-layer encoder over ``x`` (batch-first layout)."""
        return self.transformer_encoder(x)

class EmissionLayer(nn.Module):
    """Small MLP head producing one score per label.

    The hidden width is read from ``Config.emission_hidden_size`` when the
    layer is built.

    Args:
        input_size: Width of the incoming features.
        num_labels: Number of output scores.
        dropout: Dropout probability between the two linear layers.
    """
    def __init__(self, input_size, num_labels, dropout=0.2):
        super().__init__()
        hidden = Config.emission_hidden_size
        # Positional order matters: other code indexes into this Sequential.
        stages = [
            nn.Linear(input_size, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_labels),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return the MLP scores for ``x``."""
        scores = self.mlp(x)
        return scores

class FocalCRF(nn.Module):
    """CRF with focal loss for class imbalance.

    Args:
        num_tags: Number of tags handled by the CRF.
        gamma: Focal exponent. ``None`` (default) resolves to ``Config.gamma``
            when the layer is built rather than when the class is defined.
    """
    def __init__(self, num_tags, gamma=None):
        super().__init__()
        if gamma is None:
            gamma = Config.gamma
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal-weighted negative log-likelihood.

        Args:
            emissions: Per-position tag scores passed to the CRF.
            tags: Gold tag ids; also used to index ``class_weights``.
            mask: Boolean validity mask; summed over dimension 1 to count the
                valid positions of each sequence.
            class_weights: Optional 1-D tensor of per-tag weights.
        """
        # With reduction='none' the CRF is expected to return one
        # log-likelihood per sequence.
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            # Average the tag weights over VALID positions only. Previously
            # the weights of padded positions were summed in as well while the
            # divisor counted valid positions only.
            valid = mask.to(class_weights.dtype)
            weights_per_tag = class_weights[tags] * valid
            valid_counts = valid.sum(dim=1).clamp(min=1)  # avoid 0-division
            weights_per_sequence = weights_per_tag.sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Return the CRF's decoded tag sequences for ``emissions``."""
        return self.crf.decode(emissions, mask=mask)

class FSSALayer(nn.Module):
    """Factorized Structured Sparse Adaptation layer.

    Wraps a frozen linear layer and adds a block-sparse low-rank update
    ``(B @ A) * mask`` to its weight.

    Args:
        original_layer: Layer exposing ``in_features``/``out_features``; its
            parameters are frozen.
        rank: Inner dimension of the ``B @ A`` factorization.
        sparsity: Fraction of weight blocks left inactive.
        block_size: Edge length of the square blocks of the mask.

    ``None`` for ``rank``/``sparsity``/``block_size`` resolves to the matching
    ``Config.fssa_*`` value when the layer is built, so values written into
    ``Config`` at runtime are honoured. The previous defaults were frozen when
    the class was defined.
    """
    def __init__(self, original_layer, rank=None, sparsity=None, block_size=None):
        super().__init__()
        if rank is None:
            rank = Config.fssa_linear_rank
        if sparsity is None:
            sparsity = Config.fssa_linear_sparsity
        if block_size is None:
            block_size = Config.fssa_block_size

        self.original_layer = original_layer
        self.rank = rank
        self.sparsity = sparsity
        self.block_size = block_size

        for param in self.original_layer.parameters():
            param.requires_grad = False

        in_features = original_layer.in_features
        out_features = original_layer.out_features

        self.A = nn.Parameter(torch.zeros(rank, in_features))
        self.B = nn.Parameter(torch.zeros(out_features, rank))
        # Non-persistent buffer: follows the module across devices/dtypes but
        # stays out of the state dict, so checkpoint keys are unchanged.
        # NOTE(review): the mask is drawn from Python's `random` module and is
        # not stored in checkpoints, so it matches training only if the RNG
        # state and construction order match -- confirm with the training run.
        self.register_buffer(
            "mask",
            self.create_sparsity_mask(out_features, in_features),
            persistent=False,
        )
        self.init_parameters()

    def create_sparsity_mask(self, rows, cols):
        """Return a (rows, cols) 0/1 mask with randomly chosen active blocks."""
        row_blocks = (rows + self.block_size - 1) // self.block_size
        col_blocks = (cols + self.block_size - 1) // self.block_size
        num_blocks = row_blocks * col_blocks
        # int() truncates, so very small layers may end up with no active block.
        num_active = int(num_blocks * (1 - self.sparsity))
        active_blocks = random.sample(range(num_blocks), num_active)

        mask = torch.zeros(rows, cols)
        for block_idx in active_blocks:
            i = block_idx // col_blocks
            j = block_idx % col_blocks
            row_start = i * self.block_size
            col_start = j * self.block_size
            # Edge blocks are clipped to the matrix bounds.
            row_end = min(row_start + self.block_size, rows)
            col_end = min(col_start + self.block_size, cols)
            mask[row_start:row_end, col_start:col_end] = 1
        return mask

    def init_parameters(self):
        """Initialise both factors with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.B, a=math.sqrt(5))

    def forward(self, x):
        """Return the frozen layer's output plus the masked low-rank update."""
        base_output = self.original_layer(x)
        # No-op when the buffer already lives with the parameters.
        mask = self.mask.to(device=self.B.device, dtype=self.B.dtype)
        adapted = (self.B @ self.A) * mask
        delta_output = F.linear(x, adapted)
        return base_output + delta_output

class FSSAEmbedding(nn.Module):
    """Factorized Structured Sparse Adaptation for embeddings.

    Wraps a frozen embedding table and adds a sparse low-rank update
    ``(U * mask) @ V`` to it.

    Args:
        original_embedding: Embedding exposing ``num_embeddings`` and
            ``embedding_dim``; its parameters are frozen.
        rank: Inner dimension of the factorization.
        sparsity: Threshold for the random mask; an entry of ``U`` is kept
            when its uniform draw exceeds this value.

    ``None`` for ``rank``/``sparsity`` resolves to ``Config.fssa_emb_rank`` /
    ``Config.fssa_emb_sparsity`` when the layer is built, so values written
    into ``Config`` at runtime are honoured. The previous defaults were frozen
    when the class was defined.
    """
    def __init__(self, original_embedding, rank=None, sparsity=None):
        super().__init__()
        if rank is None:
            rank = Config.fssa_emb_rank
        if sparsity is None:
            sparsity = Config.fssa_emb_sparsity

        self.original_embedding = original_embedding
        self.rank = rank
        self.sparsity = sparsity

        for param in self.original_embedding.parameters():
            param.requires_grad = False

        num_embeddings = original_embedding.num_embeddings
        embedding_dim = original_embedding.embedding_dim

        self.U = nn.Parameter(torch.zeros(num_embeddings, rank))
        self.V = nn.Parameter(torch.zeros(rank, embedding_dim))
        # Non-persistent buffer: follows the module across devices/dtypes but
        # stays out of the state dict, so checkpoint keys are unchanged.
        # NOTE(review): the mask comes from torch's global RNG and is not
        # stored in checkpoints -- confirm it is regenerated identically to
        # the training run.
        self.register_buffer(
            "mask",
            (torch.rand(num_embeddings, rank) > sparsity).float(),
            persistent=False,
        )
        self.init_parameters()

    def init_parameters(self):
        """Initialise both factors from N(0, 0.02)."""
        nn.init.normal_(self.U, mean=0, std=0.02)
        nn.init.normal_(self.V, mean=0, std=0.02)

    def forward(self, input_ids):
        """Return base embeddings plus the masked low-rank update."""
        base_embeds = self.original_embedding(input_ids)
        # No-op when the buffer already lives with the parameters.
        mask = self.mask.to(device=self.U.device, dtype=self.U.dtype)
        adapted = (self.U * mask) @ self.V
        delta_embeds = F.embedding(input_ids, adapted)
        return base_embeds + delta_embeds

def apply_fssa_to_hierarchical(model):
    """Wrap the hierarchical (non-BERT) components of ``model`` with FSSA.

    Replaces, in this order: the position embedding with ``FSSAEmbedding``;
    every ``nn.Linear`` that is a child or grandchild of
    ``model.context_encoder`` with ``FSSALayer``; every ``nn.Linear`` inside
    ``model.emission.mlp`` with ``FSSALayer``. The model is modified in place
    and also returned.

    NOTE(review): only two levels below ``context_encoder`` are visited, and
    nn.TransformerEncoder presumably works on its own copy of the layer it was
    given -- confirm the wrapped linears are the ones executed in forward.
    """
    position_enc = model.position_enc
    position_enc.position_emb = FSSAEmbedding(position_enc.position_emb)

    encoder = model.context_encoder
    for child_name, child in encoder.named_children():
        if isinstance(child, nn.Linear):
            setattr(encoder, child_name, FSSALayer(child))
            continue
        # Not a linear itself: look one level deeper.
        for grandchild_name, grandchild in child.named_children():
            if isinstance(grandchild, nn.Linear):
                setattr(child, grandchild_name, FSSALayer(grandchild))

    mlp = model.emission.mlp
    for index, layer in enumerate(mlp):
        if isinstance(layer, nn.Linear):
            mlp[index] = FSSALayer(layer)

    return model

class ImprovedHSLNModel(nn.Module):
    """Hybrid QLoRA (BERT) + FSSA (Hierarchical) Model.

    Sentences are encoded by a 4-bit quantized BERT with LoRA adapters. The
    first-token vector of every sentence is normalised, given a position
    embedding, passed through a transformer context layer and an emission
    head, and scored by a focal CRF.

    Args:
        num_labels: Number of tags.
        model_config: Mapping whose entries are copied onto ``Config`` before
            any layer is built (this mutates the shared ``Config`` class).
        class_weights: Optional per-tag weights forwarded to the CRF loss.
    """
    def __init__(self, num_labels, model_config, class_weights=None):
        super().__init__()
        self.class_weights = class_weights

        # Load configuration overrides
        for key, value in model_config.items():
            setattr(Config, key, value)

        # Configure 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=Config.qlora_compute_dtype
        )

        # Load BERT with QLoRA
        self.bert = AutoModel.from_pretrained(
            Config.bert_model_name,
            quantization_config=bnb_config
        )
        self.bert = prepare_model_for_kbit_training(self.bert)

        # Apply QLoRA adapters
        lora_config = LoraConfig(
            r=Config.qlora_r,
            lora_alpha=Config.qlora_alpha,
            target_modules=Config.qlora_target_modules,
            lora_dropout=Config.qlora_dropout,
            bias="none",
            task_type="FEATURE_EXTRACTION"
        )
        self.bert = get_peft_model(self.bert, lora_config)

        # Hierarchical components. Sizes are passed explicitly from Config so
        # the overrides applied above take effect; the layers' own defaults
        # may have been fixed before the overrides were loaded.
        hidden_size = self.bert.config.hidden_size
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden_size)
        self.position_enc = PositionalEncoding(
            hidden_size,
            max_len=Config.max_num_sentences
        )
        self.context_encoder = TransformerContextLayer(
            d_model=hidden_size,
            dim_feedforward=Config.context_intermediate_size
        )
        self.emission = EmissionLayer(
            input_size=hidden_size,
            num_labels=num_labels
        )

        # Apply FSSA to hierarchical components (modifies this module in place)
        apply_fssa_to_hierarchical(self)

        # CRF layer
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute emissions (and the CRF loss when ``labels`` is given).

        Args:
            input_ids: Tensor of shape (batch, sentences, tokens).
            attention_mask: Tensor with the same shape as ``input_ids``.
            labels: Optional tag ids passed to the CRF loss.

        Returns:
            ``{"emissions": ...}``, plus ``"loss"`` when labels are supplied.
        """
        batch_size, num_sent, seq_len = input_ids.shape

        # Process each sentence: fold the sentence axis into the batch axis
        flat_input_ids = input_ids.view(-1, seq_len)
        flat_mask = attention_mask.view(-1, seq_len)

        bert_out = self.bert(
            input_ids=flat_input_ids,
            attention_mask=flat_mask
        ).last_hidden_state

        # Sentence embeddings (first token of every sentence)
        sent_emb = bert_out[:, 0, :]
        sent_emb = self.sent_layer_norm(sent_emb)
        sent_emb = self.sent_dropout(sent_emb)
        sent_emb = sent_emb.view(batch_size, num_sent, -1)

        # Context modeling
        sent_emb = self.position_enc(sent_emb)
        context_emb = self.context_encoder(sent_emb)

        # Emissions
        emissions = self.emission(context_emb)
        # Sentence-level mask: a sentence counts as present when its first
        # token is attended to.
        mask = attention_mask[:, :, 0] > 0

        if labels is not None:
            loss = self.crf(
                emissions,
                labels,
                mask=mask,
                class_weights=self.class_weights
            )
            return {"loss": loss, "emissions": emissions}
        return {"emissions": emissions}
# ====================== END MODEL ARCHITECTURE ======================

class HierarchicalDataset(Dataset):
    """Index-based view exposing only the model inputs of each tokenized row."""

    def __init__(self, dataset):
        # The wrapped object must support len() and integer indexing.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        # Keep just the three fields the collate function consumes.
        return {
            field: row[field]
            for field in ("input_ids", "attention_mask", "labels")
        }

def collate_fn(batch):
    """Stack per-document fields into batch tensors.

    Each item must provide ``input_ids``, ``attention_mask`` and ``labels``;
    values that are not tensors yet are converted with ``torch.tensor``.

    Returns:
        dict with the same three keys, each stacked along a new first axis.
    """
    def stack_field(field):
        tensors = []
        for sample in batch:
            value = sample[field]
            if not isinstance(value, torch.Tensor):
                value = torch.tensor(value)
            tensors.append(value)
        return torch.stack(tensors)

    return {
        field: stack_field(field)
        for field in ("input_ids", "attention_mask", "labels")
    }

def create_single_loader(dataset):
    """Build the evaluation DataLoader for a tokenized dataset.

    The dataset is wrapped in ``HierarchicalDataset``; batches hold
    ``Config.batch_size`` documents, keep their original order, and are
    assembled by ``collate_fn``.
    """
    wrapped = HierarchicalDataset(dataset)
    loader_options = {
        "batch_size": Config.batch_size,
        "shuffle": False,
        "collate_fn": collate_fn,
    }
    return DataLoader(wrapped, **loader_options)

def evaluate_metrics(model, dataloader, device, label_list):
    """Run the model over ``dataloader`` and compute sentence-level metrics.

    Args:
        model: Module returning ``{"emissions": ...}`` and exposing
            ``crf.decode``.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: Device the batches are moved to.
        label_list: Label names ordered by label id.

    Returns:
        dict with ``macro_f1``, ``weighted_f1``, ``accuracy``,
        ``per_label_metrics``, ``confusion_matrix`` (array of raw counts),
        ``latency_ms_per_doc``, ``latency_ms_per_sentence``,
        ``eval_time_seconds`` and ``num_samples``.

    Raises:
        Re-raises any exception after printing a short message.

    NOTE(review): positions are selected with ``attention_mask[:, :, 0] > 0``.
    If padded sentence slots were tokenized from empty strings they may still
    carry attention on their first token and be counted here -- confirm.
    """
    try:
        model.eval()
        all_preds, all_labels = [], []
        total_time = 0
        n_docs = 0
        n_sentences = 0
        # CUDA work is queued asynchronously; without synchronising around
        # the forward pass the timer would mostly measure kernel launch time.
        sync_cuda = torch.device(device).type == "cuda"
        eval_start = time.time()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                mask = attention_mask[:, :, 0] > 0

                if sync_cuda:
                    torch.cuda.synchronize(device)
                start = time.perf_counter()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                if sync_cuda:
                    torch.cuda.synchronize(device)
                end = time.perf_counter()

                emissions = outputs["emissions"]
                preds = model.crf.decode(emissions, mask=mask)

                # Keep only the positions marked valid for each document.
                seq_lens = mask.sum(dim=1).tolist()
                for pred_seq, label_seq, seq_len in zip(preds, labels, seq_lens):
                    all_preds.extend(pred_seq[:seq_len])
                    all_labels.extend(label_seq[:seq_len].tolist())

                total_time += (end - start)
                n_docs += input_ids.shape[0]
                n_sentences += sum(seq_lens)

        eval_end = time.time()
        eval_time = eval_end - eval_start

        labels_for_report = list(range(len(label_list)))
        target_names = label_list

        report = classification_report(
            all_labels, all_preds,
            labels=labels_for_report,
            target_names=target_names,
            output_dict=True,
            zero_division=0
        )

        # Calculate additional metrics
        macro_f1 = report['macro avg']['f1-score']
        weighted_f1 = report['weighted avg']['f1-score']
        accuracy = accuracy_score(all_labels, all_preds)

        # Per-label metrics
        per_label_metrics = {
            label: {
                'f1': report[label]['f1-score'],
                'precision': report[label]['precision'],
                'recall': report[label]['recall']
            }
            for label in label_list
        }

        # Confusion matrix (non-normalized)
        cm = confusion_matrix(all_labels, all_preds, labels=labels_for_report)

        # Forward-pass time only; CRF decoding is not included.
        latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
        latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

        return {
            "macro_f1": macro_f1,
            "weighted_f1": weighted_f1,
            "accuracy": accuracy,
            "per_label_metrics": per_label_metrics,
            "confusion_matrix": cm,
            "latency_ms_per_doc": latency_doc,
            "latency_ms_per_sentence": latency_sent,
            "eval_time_seconds": eval_time,
            "num_samples": n_docs
        }

    except Exception as e:
        print(f"Evaluation failed: {str(e)}")
        raise

def plot_confusion_matrix(cm, classes, output_path):
    """Render ``cm`` as an annotated heatmap and write it to ``output_path``.

    Args:
        cm: Matrix of integer counts (cells are formatted with ``'d'``).
        classes: Tick labels used for both axes.
        output_path: Destination of the saved figure.
    """
    bold = 'bold'
    tick_font = {"fontsize": 10, "fontweight": bold}

    plt.figure(figsize=(10, 8))
    heatmap_axes = sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=True,
        annot_kws={"size": 12, "weight": bold},
    )

    # Axis captions and title
    heatmap_axes.set_xlabel('Predicted Labels', fontsize=14, fontweight=bold)
    heatmap_axes.set_ylabel('True Labels', fontsize=14, fontweight=bold)
    heatmap_axes.set_title('Confusion Matrix', fontsize=16, fontweight=bold)

    # Class names on both axes; the x labels are slanted
    heatmap_axes.set_xticklabels(classes, rotation=45, ha='right', **tick_font)
    heatmap_axes.set_yticklabels(classes, rotation=0, **tick_font)

    # Colour bar tick size
    colorbar = heatmap_axes.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def evaluate_test_set():
    """Loads a pre-trained model and evaluates it on the test set.

    Downloads ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, copies the saved ``model_config`` onto ``Config``,
    preprocesses and tokenizes ``test_ds``, evaluates the model, and writes
    ``confusion_matrix.png`` and ``test_metrics.json`` to
    ``Config.output_dir``.

    Returns:
        The metrics dict (confusion matrix converted to nested lists), or
        ``None`` when any step fails. On failure the error is printed and,
        best effort, written to ``error_log.txt`` in the output directory.
    """
    try:
        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print("STARTING TEST SET EVALUATION (QLora-FSSA)")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts from Hugging Face Hub
        print("Downloading model artifacts from Hugging Face Hub...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # Load configuration
        with open(config_path, 'r') as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        id2label = saved_config['id2label']
        model_config = saved_config['model_config']

        # Convert id2label keys to integers (JSON object keys are strings)
        id2label = {int(k): v for k, v in id2label.items()}

        # Create label list sorted by ID
        label_list = [id2label[i] for i in range(len(id2label))]

        # Update Config with model parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for model: {Config.bert_model_name}")
        print(f"Number of labels: {len(label_list)}")

        # 2. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

        # 3. Preprocess test dataset
        print("\nPreprocessing test dataset...")
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset
        print("Tokenizing test dataset...")
        test_tokenized = tokenize_single_dataset(test_hier, tokenizer, label2id)

        # 5. Create data loader
        test_loader = create_single_loader(test_tokenized)
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize model
        print("\nInitializing model...")
        model = ImprovedHSLNModel(
            num_labels=len(label2id),
            model_config=model_config,
            class_weights=None
        ).to(device)

        # 7. Load model weights
        # SECURITY NOTE: torch.load unpickles the downloaded file. Only point
        # Config.hf_repo_id at a repository you trust, or pass
        # weights_only=True on PyTorch versions that support it.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list)

        # 9. Save confusion matrix plot
        cm_path = os.path.join(Config.output_dir, "confusion_matrix.png")
        plot_confusion_matrix(test_metrics["confusion_matrix"], label_list, cm_path)

        # Convert confusion matrix to list for JSON serialization
        test_metrics["confusion_matrix"] = test_metrics["confusion_matrix"].tolist()

        # 10. Save test metrics
        metrics_path = os.path.join(Config.output_dir, "test_metrics.json")
        with open(metrics_path, 'w') as f:
            json.dump(test_metrics, f, indent=2)

        print(f"\n{'='*30} TEST RESULTS {'='*30}")
        print(f"Weighted F1: {test_metrics['weighted_f1']:.4f}")
        print(f"Macro F1:    {test_metrics['macro_f1']:.4f}")
        print(f"Accuracy:    {test_metrics['accuracy']:.4f}")
        print(f"Latency:     {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Evaluation time: {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Metrics saved to: {metrics_path}")
        print(f"Confusion matrix saved to: {cm_path}")

        print("\nPer-class Metrics:")
        for label, metrics in test_metrics['per_label_metrics'].items():
            print(f"  {label}:")
            print(f"    F1:       {metrics['f1']:.4f}")
            print(f"    Precision: {metrics['precision']:.4f}")
            print(f"    Recall:    {metrics['recall']:.4f}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        import traceback
        print(f"\n{'!'*50}")
        print("EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        traceback.print_exc()
        # Capture the trace before entering another try block.
        trace_text = traceback.format_exc()
        # Writing the log is best effort: a failure here (e.g. the output
        # directory could not be created) must not hide the original error.
        try:
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "error_log.txt"), "w") as f:
                f.write(f"Evaluation error at {datetime.now()}\n")
                f.write(str(e))
                f.write("\n\nTraceback:\n")
                f.write(trace_text)
        except OSError as log_error:
            print(f"Could not write error log: {log_error}")
        return None

# Run the evaluation when this code is executed directly rather than imported.
# The returned metrics (or None on failure) are not used here.
if __name__ == "__main__":
    evaluate_test_set()

# LegalBERTHSLN-Role_Routed

In [None]:
# -*- coding: utf-8 -*-
"""eval_role_routed.py

Evaluates a pre-trained Role-Routed Adapter model on the test set and saves metrics to JSON.
"""

from huggingface_hub import hf_hub_download
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import numpy as np
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModel
import torch.nn as nn
from torchcrf import CRF
import random

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
# Also pin cuDNN autotuning off, as the other evaluation cells in this file do,
# so algorithm selection cannot vary between runs.
torch.backends.cudnn.benchmark = False

class Config:
    """Fallback settings for evaluating the Role-Routed Adapter model.

    Entries of the checkpoint's ``model_config`` are copied onto this class
    when the model is built.
    """

    # --- Where artifacts come from / go to ---
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-role-routed"
    output_dir = "./role_routed_evaluation_results"

    # --- Sentence encoder and input shaping ---
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'
    max_num_sentences = 48  # sentences kept (or padded to) per document
    max_length = 96         # tokens per sentence after truncation/padding
    batch_size = 4          # documents per evaluation batch

    # --- Role-Routed specific parameters ---
    num_roles = 13                   # number of parallel adapters
    adapter_intermediate_size = 256  # bottleneck width of each adapter
    context_hidden_size = 256

    # --- Regularisation and loss shaping ---
    dropout_rate = 0.4
    gamma = 2.0             # focal exponent handed to FocalCRF

# Load datasets
# Only the test split is used by this evaluation cell; the value is the
# parquet file's path inside the dataset repository.
splits = {
    'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'
}

# Read the parquet file through an hf:// URL into pandas, then wrap the frame
# as a Dataset so the .map/.filter preprocessing below can be applied to it.
test_df = pd.read_parquet("hf://datasets/opennyaiorg/InRhetoricalRoles/" + splits["test"])
test_ds = Dataset.from_pandas(test_df)

def get_spans_and_labels(example):
    """Pull annotated spans and their first label out of one example.

    Looks only at the first entry of ``example['annotations']``. Result items
    lacking a non-empty ``text`` or a non-empty ``labels`` list are skipped.

    Returns:
        dict with parallel lists ``'spans'`` and ``'labels'``.
    """
    spans, labels = [], []

    annotations = example.get('annotations')
    has_annotations = bool(annotations) and len(annotations) > 0
    results = annotations[0].get('result') if has_annotations else None

    # A missing or empty result list simply yields no spans.
    for ann in results or []:
        value = ann.get('value')
        if value and value.get('text') and value.get('labels'):
            spans.append(value['text'])
            labels.append(value['labels'][0])

    return {'spans': spans, 'labels': labels}

def preprocess_single_dataset(dataset, label2id):
    """Extract spans/labels, drop empty documents, and rename the columns.

    Args:
        dataset: Object providing ``map`` and ``filter`` (used as such below).
        label2id: Not used here; kept so the signature mirrors the
            tokenization step.

    Returns:
        The dataset with ``text`` (spans) and ``label`` (labels) per row.
    """
    extracted = dataset.map(get_spans_and_labels)
    non_empty = extracted.filter(lambda row: len(row['spans']) > 0)
    return non_empty.map(lambda row: {'text': row['spans'], 'label': row['labels']})

def tokenize_single_dataset(dataset, tokenizer, label2id):
    """Tokenize every document into fixed-size hierarchical tensors.

    Each document is truncated or padded to ``Config.max_num_sentences``
    sentences and every sentence to ``Config.max_length`` tokens.

    Args:
        dataset: Object providing ``map``; rows carry ``text`` (list of
            sentences) and ``label`` (list of label names).
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` when asked for ``return_tensors="pt"``.
        label2id: Mapping from label name to integer id.

    Returns:
        The mapped dataset with ``input_ids``, ``attention_mask`` and
        ``labels`` added per row.

    NOTE(review): padded sentence slots are encoded from the empty string.
    BERT-style tokenizers presumably still emit special tokens with attention 1
    for those rows, so a sentence mask derived from ``attention_mask[..., 0]``
    would not exclude them -- confirm against the training pipeline.
    """
    # Padded sentence slots need some valid label; the first key of label2id
    # is used as filler.
    pad_label = next(iter(label2id))

    def tokenize_document(example):
        limit = Config.max_num_sentences
        sentences = list(example['text'][:limit])
        labels = list(example['label'][:limit])

        pad_len = limit - len(sentences)
        sentences.extend([""] * pad_len)
        labels.extend([pad_label] * pad_len)

        # One batched call encodes all sentences of the document at once and
        # already returns (num_sentences, max_length) tensors, replacing the
        # per-sentence tokenizer calls and the manual stacking.
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": torch.tensor([label2id[name] for name in labels])
        }

    return dataset.map(tokenize_document)

class PositionalEncoding(nn.Module):
    """Positional embeddings with dropout.

    Args:
        d_model: Width of the vectors the embeddings are added to.
        max_len: Number of rows in the position table. ``None`` (default)
            resolves to ``Config.max_num_sentences`` when the layer is built,
            so values written into ``Config`` at runtime are honoured. The
            previous default was frozen when the class was defined.

    The dropout probability is read from ``Config.dropout_rate`` at
    construction time.
    """
    def __init__(self, d_model, max_len=None):
        super().__init__()
        if max_len is None:
            max_len = Config.max_num_sentences
        self.dropout = nn.Dropout(Config.dropout_rate)
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add position embeddings along dimension 1, then apply dropout."""
        # Shape (1, x.size(1)); broadcasts over the leading dimension of x.
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return self.dropout(x + self.position_emb(positions))

class RoleRoutedAdapter(nn.Module):
    """Mixture of parallel bottleneck adapters combined by a learned router.

    Args:
        config: Object exposing ``hidden_size`` (the adapter input/output width).
        role_count: Number of parallel adapters and of router outputs.
        intermediate_size: Bottleneck width of each adapter.
    """
    def __init__(self, config, role_count, intermediate_size=64):
        super().__init__()
        self.config = config
        self.role_count = role_count
        self.intermediate_size = intermediate_size

        hidden = config.hidden_size

        # One down-project / GELU / up-project adapter per role
        experts = []
        for _ in range(role_count):
            experts.append(nn.Sequential(
                nn.Linear(hidden, intermediate_size),
                nn.GELU(),
                nn.Linear(intermediate_size, hidden),
            ))
        self.adapters = nn.ModuleList(experts)

        # Router producing one logit per role
        self.router = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.Tanh(),
            nn.Linear(256, role_count),
        )

        # Re-initialise every linear layer (adapters first, then the router):
        # weights from N(0, 0.02), biases zeroed.
        linear_layers = [
            layer
            for block in (*self.adapters, self.router)
            for layer in block
            if isinstance(layer, nn.Linear)
        ]
        for linear in linear_layers:
            linear.weight.data.normal_(mean=0.0, std=0.02)
            if linear.bias is not None:
                linear.bias.data.zero_()

    def forward(self, x):
        """Return ``x`` plus the router-weighted sum of all adapter outputs."""
        _batch, _positions, _width = x.shape  # input must be 3-D

        # Route on the mean over dimension 1: one weight vector per batch item
        pooled = x.mean(dim=1)
        routing_weights = torch.softmax(self.router(pooled), dim=-1)

        # Stack adapter outputs on a new role axis: (batch, role, seq, hidden)
        per_role = torch.stack([expert(x) for expert in self.adapters], dim=1)
        mixed = torch.einsum('br,brsh->bsh', routing_weights, per_role)
        return x + mixed

class TransformerContextLayer(nn.Module):
    """Custom transformer layer with role-routed adapters.

    Self-attention and the feed-forward block are each followed by a residual
    LayerNorm and a ``RoleRoutedAdapter``.

    Args:
        d_model: Width of the input vectors.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout used in attention and in the feed-forward block.
        num_roles: Adapters per ``RoleRoutedAdapter``. ``None`` (default)
            resolves to ``Config.num_roles`` when the layer is built rather
            than when the class is defined.

    NOTE(review): the adapters take their width from the BERT config's
    ``hidden_size``, not from ``d_model``; the two must be equal for forward
    to work -- confirm callers always pass the encoder's hidden size.
    """
    def __init__(self, d_model, nhead=8, dim_feedforward=1024, dropout=0.2, num_roles=None):
        super().__init__()
        if num_roles is None:
            num_roles = Config.num_roles
        self.d_model = d_model

        # Multi-Head Attention
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)

        # Both adapters are sized from the same encoder config: load it once
        # instead of once per adapter.
        bert_config = AutoConfig.from_pretrained(Config.bert_model_name)
        self.adapter1 = RoleRoutedAdapter(
            config=bert_config,
            role_count=num_roles,
            intermediate_size=Config.adapter_intermediate_size
        )

        # Feed Forward
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.adapter2 = RoleRoutedAdapter(
            config=bert_config,
            role_count=num_roles,
            intermediate_size=Config.adapter_intermediate_size
        )

    def forward(self, x):
        """Apply attention and feed-forward sub-blocks, each with its adapter."""
        # No attention mask is passed: every position attends to every other.
        attn_output, _ = self.multihead_attn(x, x, x)
        x = self.norm1(x + attn_output)
        x = self.adapter1(x)

        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        x = self.adapter2(x)
        return x

class EmissionLayer(nn.Module):
    """Emission head: a widened MLP path plus a linear shortcut to the labels.

    Args:
        input_size: Width of the incoming features.
        num_labels: Number of output scores.
        dropout: Dropout probability applied after the activation.
    """
    def __init__(self, input_size, num_labels, dropout=0.3):
        super().__init__()
        expanded = input_size * 2
        # Attribute order is kept as-is: it fixes both the state-dict key
        # order and the order in which the linears draw their initial weights.
        self.linear1 = nn.Linear(input_size, expanded)
        self.linear2 = nn.Linear(expanded, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(expanded)
        self.residual_proj = nn.Linear(input_size, num_labels)

    def forward(self, x):
        """Return MLP scores plus the linear projection of the raw input."""
        shortcut = self.residual_proj(x)
        hidden = self.layer_norm(self.linear1(x))
        hidden = self.dropout(self.gelu(hidden))
        return self.linear2(hidden) + shortcut

class FocalCRF(nn.Module):
    """CRF with focal loss.

    Args:
        num_tags: Number of tags handled by the CRF.
        gamma: Focal exponent. ``None`` (default) resolves to ``Config.gamma``
            when the layer is built rather than when the class is defined.
    """
    def __init__(self, num_tags, gamma=None):
        super().__init__()
        if gamma is None:
            gamma = Config.gamma
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal-weighted negative log-likelihood.

        Args:
            emissions: Per-position tag scores passed to the CRF.
            tags: Gold tag ids; also used to index ``class_weights``.
            mask: Boolean validity mask; summed over dimension 1 to count the
                valid positions of each sequence.
            class_weights: Optional 1-D tensor of per-tag weights.
        """
        # With reduction='none' the CRF is expected to return one
        # log-likelihood per sequence.
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            # Average the tag weights over VALID positions only. Previously
            # the weights of padded positions were summed in as well while the
            # divisor counted valid positions only.
            valid = mask.to(class_weights.dtype)
            weights_per_tag = class_weights[tags] * valid
            valid_counts = valid.sum(dim=1).clamp(min=1)  # avoid 0-division
            weights_per_sequence = weights_per_tag.sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Return the CRF's decoded tag sequences for ``emissions``."""
        return self.crf.decode(emissions, mask=mask)

class ImprovedHSLNModel(nn.Module):
    """HSLN model with Role-Routed Adapters.

    Sentences are encoded by BERT. The first-token vector of every sentence
    is projected, normalised, given a position embedding, passed through the
    role-routed context layer and the emission head, and scored by a focal CRF.

    Args:
        num_labels: Number of tags.
        model_config: Mapping whose entries are copied onto ``Config`` before
            any layer is built (this mutates the shared ``Config`` class).
        class_weights: Optional per-tag weights forwarded to the CRF loss.
    """
    def __init__(self, num_labels, model_config, class_weights=None):
        super().__init__()
        self.class_weights = class_weights

        # Update config from model_config
        for key, value in model_config.items():
            setattr(Config, key, value)

        # Load base Legal-BERT
        self.bert = AutoModel.from_pretrained(Config.bert_model_name)
        hidden_size = self.bert.config.hidden_size

        # Feature extraction layers
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden_size)
        self.sent_projection = nn.Linear(hidden_size, hidden_size)

        # Context encoding. Sizes are passed explicitly from Config so the
        # overrides applied above take effect; the layers' own defaults may
        # have been fixed before the overrides were loaded.
        self.position_enc = PositionalEncoding(
            hidden_size,
            max_len=Config.max_num_sentences
        )
        self.context_encoder = TransformerContextLayer(
            d_model=hidden_size,
            num_roles=Config.num_roles
        )

        # Emission and CRF
        self.emission = EmissionLayer(input_size=hidden_size, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute emissions (and the CRF loss when ``labels`` is given).

        Args:
            input_ids: Tensor of shape (batch, sentences, tokens).
            attention_mask: Tensor with the same shape as ``input_ids``.
            labels: Optional tag ids passed to the CRF loss.

        Returns:
            ``{"emissions": ...}``, plus ``"loss"`` when labels are supplied.
        """
        batch_size, num_sent, seq_len = input_ids.shape
        # Fold the sentence axis into the batch axis for the encoder
        flat_input_ids = input_ids.view(-1, seq_len)
        flat_mask = attention_mask.view(-1, seq_len)

        bert_out = self.bert(input_ids=flat_input_ids, attention_mask=flat_mask).last_hidden_state
        sent_emb = bert_out[:, 0, :]  # first token of every sentence
        sent_emb = self.sent_projection(sent_emb)
        sent_emb = self.sent_layer_norm(sent_emb)
        sent_emb = self.sent_dropout(sent_emb)
        sent_emb = sent_emb.view(batch_size, num_sent, -1)

        sent_emb = self.position_enc(sent_emb)
        context_emb = self.context_encoder(sent_emb)

        emissions = self.emission(context_emb)
        # Sentence-level mask: a sentence counts as present when its first
        # token is attended to.
        mask = attention_mask[:, :, 0] > 0

        if labels is not None:
            loss = self.crf(emissions, labels, mask=mask, class_weights=self.class_weights)
            return {"loss": loss, "emissions": emissions}
        return {"emissions": emissions}

def collate_fn(batch):
    """Stack the per-document fields of ``batch`` into batch tensors.

    Args:
        batch: Sequence of dicts with "input_ids", "attention_mask" and
            "labels"; each value is either a tensor or something
            ``torch.tensor`` can convert.

    Returns:
        dict with the same three keys, each stacked along a new first axis.
    """
    def _as_tensor(value):
        # Existing tensors pass through untouched; anything else becomes int64.
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=torch.long)

    return {
        field: torch.stack([_as_tensor(item[field]) for item in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }

def create_single_loader(dataset):
    """Wrap ``dataset`` in a DataLoader that keeps examples in their original order.

    Batch size comes from ``Config.batch_size``; batches are assembled by
    ``collate_fn``.
    """
    loader_options = {
        "batch_size": Config.batch_size,
        "shuffle": False,
        "collate_fn": collate_fn,
    }
    return DataLoader(dataset, **loader_options)

def plot_confusion_matrix(y_true, y_pred, labels, output_path, title):
    """Draw a raw-count confusion matrix and save it as an image.

    Args:
        y_true: True class ids.
        y_pred: Predicted class ids.
        labels: Class names; position ``i`` names class id ``i``.
        output_path: File the figure is written to.
        title: Figure title.
    """
    class_ids = np.arange(len(labels))
    counts = confusion_matrix(y_true, y_pred, labels=class_ids)

    bold = {"fontweight": "bold"}

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        counts,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={'label': 'Count'},
    )

    # Tick labels: bold, with the x labels slanted
    plt.xticks(fontsize=10, rotation=45, ha='right', **bold)
    plt.yticks(fontsize=10, **bold)

    # Axis captions and title, all bold
    captions = (
        (plt.xlabel, 'Predicted Labels', 12),
        (plt.ylabel, 'True Labels', 12),
        (plt.title, title, 14),
    )
    for apply_caption, text, size in captions:
        apply_caption(text, fontsize=size, **bold)

    # Fit the rotated labels inside the canvas before writing the file
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()

def evaluate_metrics(model, dataloader, device, label_list, output_dir):
    """Comprehensive evaluation with padding masking.

    Runs ``model`` over every batch, decodes tag sequences with the model's
    CRF, and derives classification metrics plus timing figures.

    Args:
        model: Module whose forward returns a dict with "emissions" and which
            exposes ``crf.decode(emissions, mask=...)``.
        dataloader: Iterable of batches holding 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device: Device the inputs are moved to.
        label_list: Label names; position ``i`` names class id ``i``.
        output_dir: Directory that receives ``confusion_matrix.png``.

    Returns:
        dict with macro/weighted F1, accuracy, per-label metrics, latency per
        document and per sentence (forward pass only, in milliseconds), total
        evaluation time, document count and the confusion-matrix path.

    Raises:
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        model.eval()
        all_preds, all_labels = [], []
        total_time = 0.0
        n_docs = 0
        n_sentences = 0
        # CUDA kernels are launched asynchronously, so the clock must only be
        # read after the device has finished its queued work.
        on_cuda = torch.device(device).type == "cuda"
        eval_start = time.perf_counter()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                # NOTE(review): a sentence is masked out only when its first
                # token has attention 0. If padding sentences are produced by
                # tokenizing an empty string with special tokens, that first
                # position is still attended and padding rows get scored --
                # confirm against the tokenization step.
                mask = attention_mask[:, :, 0] > 0

                if on_cuda:
                    torch.cuda.synchronize(torch.device(device))
                start = time.perf_counter()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                if on_cuda:
                    torch.cuda.synchronize(torch.device(device))
                total_time += time.perf_counter() - start

                preds = model.crf.decode(outputs["emissions"], mask=mask)

                # One host transfer per batch instead of one per document.
                lengths = mask.sum(dim=1).tolist()
                label_rows = batch['labels'].tolist()
                for pred_row, label_row, length in zip(preds, label_rows, lengths):
                    all_preds.extend(pred_row[:length])
                    all_labels.extend(label_row[:length])

                n_docs += input_ids.shape[0]
                n_sentences += sum(lengths)

        eval_time = time.perf_counter() - eval_start

        labels_for_report = list(range(len(label_list)))

        report = classification_report(
            all_labels, all_preds,
            labels=labels_for_report,
            target_names=label_list,
            output_dict=True,
            zero_division=0
        )

        # Generate confusion matrix
        cm_path = os.path.join(output_dir, "confusion_matrix.png")
        plot_confusion_matrix(
            all_labels,
            all_preds,
            label_list,
            cm_path,
            "Role-Routed Model Confusion Matrix"
        )

        # Extract metrics
        per_label_metrics = {
            label: {
                'f1': report[label]['f1-score'],
                'precision': report[label]['precision'],
                'recall': report[label]['recall'],
                'support': report[label]['support'],
            }
            for label in label_list
        }

        latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
        latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

        return {
            "macro_f1": report['macro avg']['f1-score'],
            "weighted_f1": report['weighted avg']['f1-score'],
            "accuracy": accuracy_score(all_labels, all_preds),
            "per_label_metrics": per_label_metrics,
            "latency_ms_per_doc": latency_doc,
            "latency_ms_per_sentence": latency_sent,
            "eval_time_seconds": eval_time,
            "num_samples": n_docs,
            "confusion_matrix_path": cm_path
        }

    except Exception as e:
        print(f"Evaluation failed: {str(e)}")
        raise

def evaluate_test_set():
    """Loads a pre-trained model and evaluates it on the test set.

    Downloads ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, rebuilds the model, evaluates it on the
    module-level ``test_ds`` and writes ``test_metrics.json`` into
    ``Config.output_dir``.

    Returns:
        The metrics dict produced by ``evaluate_metrics``, or ``None`` when
        any step fails. On failure the message and traceback are written to
        ``error_log.txt`` in ``Config.output_dir`` when that is possible.
    """
    import traceback  # local import: only needed for the failure log

    try:
        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print("STARTING ROLE-ROUTED MODEL EVALUATION")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"Seed: {SEED}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts from Hugging Face Hub
        print("Downloading model artifacts from Hugging Face Hub...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # Load configuration
        with open(config_path, 'r') as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        model_config = saved_config['model_config']

        # JSON object keys are strings; convert them back to integer class ids
        id2label = {int(k): v for k, v in saved_config['id2label'].items()}

        # Create label list sorted by ID
        label_list = [id2label[i] for i in range(len(id2label))]

        # Update Config with model parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for model: {Config.bert_model_name}")
        print(f"Number of labels: {len(label_list)}")
        # These two settings are only reported here; fall back to a placeholder
        # instead of aborting the whole run when they are not configured.
        num_roles = getattr(Config, "num_roles", "n/a")
        adapter_size = getattr(Config, "adapter_intermediate_size", "n/a")
        print(f"Role-Routed Parameters: num_roles={num_roles}, adapter_size={adapter_size}")

        # 2. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

        # 3. Preprocess test dataset
        print("\nPreprocessing test dataset...")
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset
        print("Tokenizing test dataset...")
        test_tokenized = tokenize_single_dataset(test_hier, tokenizer, label2id)

        # 5. Create data loader
        test_loader = create_single_loader(test_tokenized)
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize model
        print("\nInitializing Role-Routed model...")
        model = ImprovedHSLNModel(
            num_labels=len(label2id),
            model_config=model_config,
            class_weights=None
        ).to(device)

        # 7. Load model weights
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code. Only point hf_repo_id at a repository you
        # trust (or pass weights_only=True if the installed torch supports it).
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list, Config.output_dir)

        # 9. Save test metrics
        metrics_path = os.path.join(Config.output_dir, "test_metrics.json")
        with open(metrics_path, 'w') as f:
            json.dump(test_metrics, f, indent=2)

        print(f"\n{'='*30} TEST RESULTS {'='*30}")
        print(f"Weighted F1: {test_metrics['weighted_f1']:.4f}")
        print(f"Macro F1:    {test_metrics['macro_f1']:.4f}")
        print(f"Accuracy:    {test_metrics['accuracy']:.4f}")
        print(f"Latency:     {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Evaluation time: {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Metrics saved to: {metrics_path}")
        print(f"Confusion matrix saved to: {test_metrics['confusion_matrix_path']}")

        print("\nPer-class Metrics:")
        for label, metrics in test_metrics['per_label_metrics'].items():
            print(f"  {label}:")
            print(f"    F1:       {metrics['f1']:.4f}")
            print(f"    Precision:{metrics['precision']:.4f}")
            print(f"    Recall:   {metrics['recall']:.4f}")
            print(f"    Support:  {metrics['support']}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        print(f"\n{'!'*50}")
        print("EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        # Writing the log must never hide the original failure, so the log
        # directory is (re)created here and I/O problems are only reported.
        try:
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "error_log.txt"), "w") as f:
                f.write(f"Evaluation error at {datetime.now()}\n")
                f.write(str(e))
                f.write("\n\n")
                f.write(traceback.format_exc())
        except OSError as log_error:
            print(f"Could not write error log: {log_error}")
        return None

# Entry point: run the test-set evaluation when this code is executed directly
# (the returned metrics dict is not used here).
if __name__ == "__main__":
    evaluate_test_set()

# LegalBERTHSLN-Baseline-Full_Finetuning

In [None]:
# -*- coding: utf-8 -*-
"""test4_evaluate_test_set

Evaluates a pre-trained model on the test set and saves metrics to JSON.
"""

from huggingface_hub import hf_hub_download, snapshot_download
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import psutil
import numpy as np
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import torch.nn as nn
from transformers import AutoModel
from torchcrf import CRF
import matplotlib.pyplot as plt
import seaborn as sns
import random

# Set seed for reproducibility: the same value is fed to Python's `random`,
# NumPy and PyTorch so each of them starts from a fixed state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed every visible CUDA device as well.
    torch.cuda.manual_seed_all(SEED)
    # Request deterministic cuDNN kernels and switch off its benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class Config:
    """Default settings for this evaluation; values from the saved model config are copied over them at run time."""
    # These will be updated from the model's config
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'  # checkpoint name given to AutoModel / AutoTokenizer
    lstm_hidden_size = 200  # not read by any code visible in this section
    context_hidden_size = 200  # not read by any code visible in this section
    max_num_sentences = 32  # sentences kept per document; shorter documents are padded up to this
    max_length = 128  # tokens per sentence after truncation / padding
    dropout_rate = 0.4  # dropout applied to the sentence embeddings
    gamma = 2.0  # exponent handed to FocalCRF
    batch_size = 4  # documents per DataLoader batch
    output_dir = "./test_evaluation_results"  # receives metrics JSON, confusion matrix and error log
    hf_repo_id = "Please enter your huggingface user id here/hierarchical-legal-model-improved-augmentation"  # placeholder: replace the user-id part before running

import pandas as pd
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login

# Load datasets
# Only the test split is listed; the value is the parquet file's path inside
# the dataset repository.
splits = {
    'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'
}

# Read the parquet file through its hf:// path into a DataFrame, then wrap it
# with `Dataset.from_pandas` (`Dataset` here is the name bound by the
# `from datasets import ...` line just above).
test_df = pd.read_parquet("hf://datasets/opennyaiorg/InRhetoricalRoles/" + splits["test"])
test_ds = Dataset.from_pandas(test_df)

def get_spans_and_labels(example):
    """Collect annotated text spans and their first label from one example.

    Only the first entry of ``example['annotations']`` is inspected. A result
    item contributes a span when it has a non-empty 'value' with non-empty
    'text' and 'labels'; the first label of the item is kept.

    Returns:
        dict with parallel lists under 'spans' and 'labels' (both empty when
        nothing usable is found).
    """
    annotations = example.get('annotations')
    if not (annotations and len(annotations) > 0):
        return {'spans': [], 'labels': []}

    results = annotations[0].get('result')
    if not results:
        return {'spans': [], 'labels': []}

    usable_values = [
        value
        for value in (item.get('value') for item in results)
        if value and value.get('text') and value.get('labels')
    ]
    return {
        'spans': [value['text'] for value in usable_values],
        'labels': [value['labels'][0] for value in usable_values],
    }

def preprocess_single_dataset(dataset, label2id):
    """Turn raw annotated examples into documents with 'text' and 'label' columns.

    Args:
        dataset: Object exposing ``map`` and ``filter``.
        label2id: Accepted but not used inside this function.

    Returns:
        The dataset after span extraction, removal of examples without spans,
        and addition of 'text' / 'label' columns copied from the spans.
    """
    def has_spans(example):
        return len(example['spans']) > 0

    def as_text_and_label(example):
        return {'text': example['spans'], 'label': example['labels']}

    with_spans = dataset.map(get_spans_and_labels)
    non_empty = with_spans.filter(has_spans)
    return non_empty.map(as_text_and_label)

def tokenize_single_dataset(dataset, tokenizer, label2id):
    """Tokenize dataset for hierarchical input.

    Every document is cut or padded to ``Config.max_num_sentences`` sentences
    and every sentence to ``Config.max_length`` tokens.

    Args:
        dataset: Object exposing ``map``; each example provides 'text' (list
            of sentences) and 'label' (list of label names).
        tokenizer: Callable tokenizer that accepts a list of strings together
            with ``padding``, ``truncation``, ``max_length`` and
            ``return_tensors`` and returns "input_ids" / "attention_mask".
        label2id: Mapping from label name to integer id.

    Returns:
        The result of ``dataset.map`` with "input_ids", "attention_mask" and
        "labels" added to every example.

    Raises:
        KeyError: If a document carries a label missing from ``label2id``.
    """
    def tokenize_document(example):
        max_sentences = Config.max_num_sentences
        sentences = list(example['text'][:max_sentences])
        labels = list(example['label'][:max_sentences])

        # Pad short documents with empty sentences. Padding needs some label,
        # so the first key of label2id is used as filler.
        # NOTE(review): the filler label is a real class and the empty
        # sentences are tokenized like any other, so downstream code must not
        # rely on these rows being distinguishable by label -- confirm how the
        # sentence mask treats them.
        pad_len = max_sentences - len(sentences)
        pad_label = next(iter(label2id))
        sentences.extend([""] * pad_len)
        labels.extend([pad_label] * pad_len)

        # One batched call yields (sentences, max_length) tensors directly,
        # replacing a per-sentence call followed by squeeze and stack.
        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": torch.tensor([label2id[name] for name in labels])
        }

    return dataset.map(tokenize_document)

class PositionalEncoding(nn.Module):
    """Adds a learned embedding of each position along dimension 1 to the input."""
    def __init__(self, d_model, max_len):
        super().__init__()
        # Lookup table with one d_model-sized row per position index below max_len.
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        num_positions = x.size(1)
        position_ids = torch.arange(num_positions, device=x.device)
        # Leading axis of size 1 lets the same offsets broadcast over dimension 0.
        offsets = self.position_emb(position_ids).unsqueeze(0)
        return x + offsets

class TransformerContextLayer(nn.Module):
    """One-layer Transformer encoder run over a batch-first sequence of vectors."""
    def __init__(self, d_model, nhead=4, dim_feedforward=512, dropout=0.1):
        super().__init__()
        layer_options = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        # Stored as an attribute (not just a local), as in the saved checkpoints.
        self.encoder_layer = nn.TransformerEncoderLayer(**layer_options)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)

    def forward(self, x):
        contextualized = self.transformer_encoder(x)
        return contextualized

class EmissionLayer(nn.Module):
    """Two-layer MLP that maps each input vector to one score per label."""
    def __init__(self, input_size, num_labels, dropout=0.2):
        super().__init__()
        hidden_size = input_size * 2
        stages = [
            nn.Linear(input_size, hidden_size),   # expand to twice the input width
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),   # project to per-label scores
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        scores = self.mlp(x)
        return scores

class FocalCRF(nn.Module):
    """CRF with focal loss for class imbalance.

    Wraps a ``CRF`` layer and rescales each sequence's negative
    log-likelihood by ``(1 - p) ** gamma``, where ``p`` is the exponential of
    that sequence's log-likelihood.
    """
    def __init__(self, num_tags, gamma):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Return the mean focal CRF loss over the batch.

        Args:
            emissions: Per-position tag scores passed to the CRF.
            tags: Integer tag ids, indexed as (batch, sequence).
            mask: Boolean/0-1 tensor of the same (batch, sequence) layout;
                only positions where it is on count as part of a sequence.
            class_weights: Optional 1-D tensor indexed by tag id. When given,
                each sequence's loss is multiplied by the mean weight of its
                unmasked tags.

        Returns:
            Scalar tensor with the batch-mean loss.
        """
        # Compute standard CRF loss (one log-likelihood per sequence)
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')

        # Apply focal loss transformation
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        # Apply class weights if provided
        if class_weights is not None:
            valid = mask.to(class_weights.dtype)
            # Zero out masked positions so padding tags do not leak into the
            # numerator while the denominator counts only valid positions.
            weights_per_tag = class_weights[tags] * valid  # (batch_size, seq_len)
            # clamp guards against division by zero for a fully masked row
            valid_counts = valid.sum(dim=1).clamp(min=1)  # (batch_size,)
            weights_per_sequence = weights_per_tag.sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Return the wrapped CRF's decoded tag sequences for ``emissions``."""
        return self.crf.decode(emissions, mask=mask)

class ImprovedHSLNModel(nn.Module):
    """Hierarchical sentence tagger: sentence encoder, context encoder, emission head, CRF."""
    def __init__(self, num_labels, model_config, class_weights=None):
        """Create all sub-modules.

        Args:
            num_labels: Number of output tags; sizes the emission layer and CRF.
            model_config: Mapping of setting name to value; every entry is
                copied onto the module-level ``Config`` class first.
            class_weights: Optional per-class weights forwarded to the CRF loss.
        """
        super().__init__()
        self.class_weights = class_weights

        # NOTE: this writes to the shared Config class, not a per-instance copy.
        for setting_name, setting_value in model_config.items():
            setattr(Config, setting_name, setting_value)

        # Sentence encoder; its hidden width sizes everything that follows.
        self.bert = AutoModel.from_pretrained(Config.bert_model_name)
        hidden_dim = self.bert.config.hidden_size
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden_dim)

        # Sentence order, then cross-sentence context.
        self.position_enc = PositionalEncoding(hidden_dim, max_len=Config.max_num_sentences)
        self.context_encoder = TransformerContextLayer(d_model=hidden_dim)

        # Tag scores per sentence, then structured loss / decoding.
        self.emission = EmissionLayer(input_size=hidden_dim, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

    def forward(self, input_ids, attention_mask, labels=None):
        """Score tags for every sentence; also compute the CRF loss when labels are given.

        Args:
            input_ids: 3-D tensor unpacked as (documents, sentences, tokens).
            attention_mask: Tensor flattened with the same token dimension.
            labels: Optional tags for the loss.

        Returns:
            ``{"emissions": ...}``, or ``{"loss": ..., "emissions": ...}`` when
            ``labels`` is provided.
        """
        n_docs, n_sents, n_tokens = input_ids.shape

        # One sentence per row for the encoder.
        encoder_states = self.bert(
            input_ids=input_ids.view(-1, n_tokens),
            attention_mask=attention_mask.view(-1, n_tokens),
        ).last_hidden_state

        # The hidden state at token position 0 is the sentence vector.
        sentence_vecs = self.sent_layer_norm(encoder_states[:, 0, :])
        sentence_vecs = self.sent_dropout(sentence_vecs).view(n_docs, n_sents, -1)

        contextual = self.context_encoder(self.position_enc(sentence_vecs))
        emissions = self.emission(contextual)

        if labels is None:
            return {"emissions": emissions}

        # A sentence counts as active when its first token is attended to.
        sentence_mask = attention_mask[:, :, 0] > 0
        loss = self.crf(
            emissions,
            labels,
            mask=sentence_mask,
            class_weights=self.class_weights,
        )
        return {"loss": loss, "emissions": emissions}

class HierarchicalDataset(torch.utils.data.Dataset):
    """Map-style wrapper exposing only the model inputs of each example.

    The base class is spelled out in full because the bare name ``Dataset``
    is imported from both ``torch.utils.data`` and ``datasets`` in this cell,
    and whichever import ran last would otherwise decide the parent class.
    """
    def __init__(self, dataset):
        # Any object supporting len() and integer indexing that yields
        # mappings with the three keys read in __getitem__.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Drop every other column so batches carry only what the model needs.
        return {
            "input_ids": item["input_ids"],
            "attention_mask": item["attention_mask"],
            "labels": item["labels"]
        }

def collate_fn(batch):
    """Stack the per-document fields of ``batch`` into batch tensors.

    Values that are not already tensors are converted with ``torch.tensor``
    (dtype inferred); tensors are used as they are.

    Returns:
        dict with "input_ids", "attention_mask" and "labels", each stacked
        along a new first axis.
    """
    def stack_field(field):
        tensors = []
        for item in batch:
            value = item[field]
            if not isinstance(value, torch.Tensor):
                value = torch.tensor(value)
            tensors.append(value)
        return torch.stack(tensors)

    return {
        "input_ids": stack_field("input_ids"),
        "attention_mask": stack_field("attention_mask"),
        "labels": stack_field("labels"),
    }

def create_single_loader(dataset):
    """Build an order-preserving DataLoader over ``dataset``.

    The dataset is wrapped in ``HierarchicalDataset``; batch size comes from
    ``Config.batch_size`` and batches are assembled by ``collate_fn``.
    """
    wrapped = HierarchicalDataset(dataset)
    loader_options = {
        "batch_size": Config.batch_size,
        "shuffle": False,
        "collate_fn": collate_fn,
    }
    return DataLoader(wrapped, **loader_options)

def plot_confusion_matrix(cm, labels, output_dir):
    """Render ``cm`` as an annotated heatmap and save it as confusion_matrix.png.

    Args:
        cm: Matrix of integer counts to draw.
        labels: Tick labels applied to both axes.
        output_dir: Directory the image is written into.
    """
    bold = {"fontweight": "bold"}
    target_file = os.path.join(output_dir, "confusion_matrix.png")

    plt.figure(figsize=(10, 8))
    axes = sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        cbar=True,
        annot_kws={"size": 12, "weight": "bold"},
    )

    # Axis captions
    axes.set_xlabel('Predicted labels', fontsize=14, **bold)
    axes.set_ylabel('True labels', fontsize=14, **bold)

    # Tick labels: class names, with the x labels slanted
    axes.set_xticklabels(labels, rotation=45, ha="right", fontsize=12, **bold)
    axes.set_yticklabels(labels, rotation=0, fontsize=12, **bold)

    plt.title('Confusion Matrix', fontsize=16, **bold)
    plt.tight_layout()

    # Write the image, then release the figure
    plt.savefig(target_file, dpi=300, bbox_inches='tight')
    plt.close()

def evaluate_metrics(model, dataloader, device, label_list, output_dir):
    """Comprehensive evaluation with padding masking.

    Runs ``model`` over every batch, decodes tag sequences with the model's
    CRF, and derives classification metrics plus timing figures.

    Args:
        model: Module whose forward returns a dict with "emissions" and which
            exposes ``crf.decode(emissions, mask=...)``.
        dataloader: Iterable of batches holding 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device: Device the inputs are moved to.
        label_list: Label names; position ``i`` names class id ``i``.
        output_dir: Directory handed to ``plot_confusion_matrix``.

    Returns:
        dict with accuracy, macro/weighted precision, recall and F1, per-label
        F1/precision/recall, latency per document and per sentence (forward
        pass only, in milliseconds), total evaluation time and document count.

    Raises:
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        model.eval()
        all_preds, all_labels = [], []
        total_time = 0.0
        n_docs = 0
        n_sentences = 0
        # CUDA kernels are launched asynchronously, so the clock must only be
        # read after the device has finished its queued work.
        on_cuda = torch.device(device).type == "cuda"
        eval_start = time.perf_counter()

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                # NOTE(review): a sentence is masked out only when its first
                # token has attention 0. Padding sentences are created by
                # tokenizing "" in tokenize_single_dataset; if the tokenizer
                # adds special tokens, that first position is still attended
                # and padding rows get scored -- confirm.
                mask = attention_mask[:, :, 0] > 0

                if on_cuda:
                    torch.cuda.synchronize(torch.device(device))
                start = time.perf_counter()
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                if on_cuda:
                    torch.cuda.synchronize(torch.device(device))
                total_time += time.perf_counter() - start

                preds = model.crf.decode(outputs["emissions"], mask=mask)

                # One host transfer per batch instead of one per document.
                lengths = mask.sum(dim=1).tolist()
                label_rows = batch['labels'].tolist()
                for pred_row, label_row, length in zip(preds, label_rows, lengths):
                    all_preds.extend(pred_row[:length])
                    all_labels.extend(label_row[:length])

                n_docs += input_ids.shape[0]
                n_sentences += sum(lengths)

        eval_time = time.perf_counter() - eval_start

        labels_for_report = list(range(len(label_list)))

        report = classification_report(
            all_labels, all_preds,
            labels=labels_for_report,
            target_names=label_list,
            output_dict=True,
            zero_division=0
        )

        # Generate confusion matrix
        cm = confusion_matrix(all_labels, all_preds, labels=labels_for_report)
        plot_confusion_matrix(cm, label_list, output_dir)

        # Extract metrics
        macro_avg = report['macro avg']
        weighted_avg = report['weighted avg']

        per_label_f1 = {label: report[label]['f1-score'] for label in label_list}
        per_label_precision = {label: report[label]['precision'] for label in label_list}
        per_label_recall = {label: report[label]['recall'] for label in label_list}

        latency_doc = (total_time / n_docs) * 1000 if n_docs else 0
        latency_sent = (total_time / n_sentences) * 1000 if n_sentences else 0

        return {
            "macro_f1": macro_avg['f1-score'],
            "weighted_f1": weighted_avg['f1-score'],
            "accuracy": accuracy_score(all_labels, all_preds),
            "macro_precision": macro_avg['precision'],
            "macro_recall": macro_avg['recall'],
            "weighted_precision": weighted_avg['precision'],
            "weighted_recall": weighted_avg['recall'],
            "per_label_f1": per_label_f1,
            "per_label_precision": per_label_precision,
            "per_label_recall": per_label_recall,
            "latency_ms_per_doc": latency_doc,
            "latency_ms_per_sentence": latency_sent,
            "eval_time_seconds": eval_time,
            "num_samples": n_docs
        }

    except Exception as e:
        print(f"Evaluation failed: {str(e)}")
        raise

def evaluate_test_set():
    """Loads a pre-trained model and evaluates it on the test set.

    Downloads ``config.json`` and ``pytorch_model.bin`` from
    ``Config.hf_repo_id``, rebuilds the model, evaluates it on the
    module-level ``test_ds`` and writes ``test_metrics.json`` into
    ``Config.output_dir``.

    Returns:
        The metrics dict produced by ``evaluate_metrics``, or ``None`` when
        any step fails. On failure the message and traceback are written to
        ``error_log.txt`` in ``Config.output_dir`` when that is possible.
    """
    import traceback  # local import: only needed for the failure log

    try:
        start_time = time.time()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}")
        print("STARTING TEST SET EVALUATION")
        print(f"Timestamp: {datetime.now().isoformat()}")
        print(f"Device: {device}")
        print(f"Model Repo: {Config.hf_repo_id}")
        print(f"{'='*50}\n")

        os.makedirs(Config.output_dir, exist_ok=True)

        # 1. Download model artifacts from Hugging Face Hub
        print("Downloading model artifacts from Hugging Face Hub...")
        config_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="config.json"
        )
        model_path = hf_hub_download(
            repo_id=Config.hf_repo_id,
            filename="pytorch_model.bin"
        )

        # Load configuration
        with open(config_path, 'r') as f:
            saved_config = json.load(f)

        label2id = saved_config['label2id']
        model_config = saved_config['model_config']

        # JSON object keys are strings; convert them back to integer class ids
        id2label = {int(k): v for k, v in saved_config['id2label'].items()}

        # Create label list sorted by ID
        label_list = [id2label[i] for i in range(len(id2label))]

        # Update Config with model parameters
        for key, value in model_config.items():
            setattr(Config, key, value)

        print(f"Loaded configuration for model: {Config.bert_model_name}")
        print(f"Number of labels: {len(label_list)}")

        # 2. Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

        # 3. Preprocess test dataset
        print("\nPreprocessing test dataset...")
        test_hier = preprocess_single_dataset(test_ds, label2id)
        print(f"Test examples after preprocessing: {len(test_hier)}")

        # 4. Tokenize test dataset
        print("Tokenizing test dataset...")
        test_tokenized = tokenize_single_dataset(test_hier, tokenizer, label2id)

        # 5. Create data loader
        test_loader = create_single_loader(test_tokenized)
        print(f"Test batches: {len(test_loader)}")

        # 6. Initialize model
        print("\nInitializing model...")
        model = ImprovedHSLNModel(
            num_labels=len(label2id),
            model_config=model_config,
            class_weights=None
        ).to(device)

        # 7. Load model weights
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code. Only point hf_repo_id at a repository you
        # trust (or pass weights_only=True if the installed torch supports it).
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Model weights loaded successfully")

        # 8. Evaluate on test set
        print("\nEvaluating on test set...")
        test_metrics = evaluate_metrics(model, test_loader, device, label_list, Config.output_dir)

        # 9. Save test metrics
        metrics_path = os.path.join(Config.output_dir, "test_metrics.json")
        with open(metrics_path, 'w') as f:
            json.dump(test_metrics, f, indent=2)

        print(f"\n{'='*30} TEST RESULTS {'='*30}")
        print(f"Weighted F1: {test_metrics['weighted_f1']:.4f}")
        print(f"Macro F1:    {test_metrics['macro_f1']:.4f}")
        print(f"Accuracy:    {test_metrics['accuracy']:.4f}")
        print(f"Macro Precision: {test_metrics['macro_precision']:.4f}")
        print(f"Macro Recall:    {test_metrics['macro_recall']:.4f}")
        print(f"Latency:     {test_metrics['latency_ms_per_doc']:.2f} ms/doc")
        print(f"Evaluation time: {test_metrics['eval_time_seconds']:.2f} seconds")
        print(f"Metrics saved to: {metrics_path}")

        print("\nPer-class Metrics:")
        for label in label_list:
            print(f"  {label}:")
            print(f"    F1:       {test_metrics['per_label_f1'][label]:.4f}")
            print(f"    Precision: {test_metrics['per_label_precision'][label]:.4f}")
            print(f"    Recall:    {test_metrics['per_label_recall'][label]:.4f}")

        total_time = time.time() - start_time
        print(f"\nEvaluation completed in {total_time:.2f} seconds")
        print(f"{'='*30} EVALUATION COMPLETE {'='*30}")

        return test_metrics

    except Exception as e:
        print(f"\n{'!'*50}")
        print("EVALUATION FAILED!")
        print(f"Error: {str(e)}")
        # Writing the log must never hide the original failure, so the log
        # directory is (re)created here and I/O problems are only reported.
        try:
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "error_log.txt"), "w") as f:
                f.write(f"Evaluation error at {datetime.now()}\n")
                f.write(str(e))
                f.write("\n\n")
                f.write(traceback.format_exc())
        except OSError as log_error:
            print(f"Could not write error log: {log_error}")
        return None

# Entry point: run the test-set evaluation when this code is executed directly
# (the returned metrics dict is not used here).
if __name__ == "__main__":
    evaluate_test_set()