In [None]:
from datasets import load_dataset
from huggingface_hub import notebook_login, create_repo, upload_folder
import pandas as pd
from datasets import Dataset
import torch
import time
import os
import json
import psutil
import numpy as np
from datetime import datetime
from sklearn.metrics import f1_score, accuracy_score, classification_report
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch.nn as nn
from torchcrf import CRF
import random
from collections import Counter
from datasets import concatenate_datasets
import nltk
from nltk.corpus import wordnet
from transformers import get_linear_schedule_with_warmup

# Download required NLTK data.
# 'wordnet' / 'omw-1.4' back the synonym lookup in get_synonyms();
# 'punkt' / 'punkt_tab' back nltk.word_tokenize() in augment_sentence().
# Recent NLTK releases load the tokenizer from 'punkt_tab' instead of 'punkt',
# so both are requested; an unknown resource name is reported by
# nltk.download() rather than raising.
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

In [None]:
class Config:
    """Central hyperparameters, training settings and paths for the pipeline.

    All values are class attributes and are read directly as ``Config.<name>``
    by the model classes and the training/evaluation helpers in this notebook.
    """

    # Reproducibility: seed for the stochastic parts of the pipeline
    # (augmentation, oversampling, weight init, shuffling).
    seed = 42

    # Optimized hyperparameters
    bert_model_name = 'nlpaueb/legal-bert-base-uncased'  # sentence encoder checkpoint
    context_hidden_size = 256   # not referenced by the code visible in this notebook
    max_num_sentences = 48      # documents are truncated/padded to this many sentences
    max_length = 96             # tokens per sentence after truncation/padding
    dropout_rate = 0.4
    gamma = 2.0                 # focal-loss exponent used by FocalCRF
    weight_decay = 1e-5

    # Training parameters
    epochs = 8
    batch_size = 4              # documents per batch
    learning_rate = 5e-5
    warmup_ratio = 0.1          # fraction of scheduler steps used for linear warmup
    patience = 3                # epochs without dev improvement before early stopping
    gradient_accumulation_steps = 2

    # Role-Routed Adapter parameters
    adapter_intermediate_size = 256
    num_roles = 13  # 13 rhetorical roles (e.g., FAC, ARG_RESPONDENT, etc.)

    # Paths and repo info
    # The repo id can be supplied through the HF_REPO_ID environment variable;
    # otherwise the placeholder below must be edited before uploading.
    hf_repo_id = os.environ.get(
        "HF_REPO_ID",
        "Please enter your huggingface user id here/hierarchical-legal-model-role-routed",
    )
    output_dir = "./role_routed_model"
    save_checkpoint = "best_model"  # sub-directory name used for the final checkpoint


In [None]:
# Login to Hugging Face Hub
# Interactive credential prompt. upload_to_huggingface() later calls
# create_repo/upload_folder with token=True, so a login is needed before the
# upload step can succeed.
notebook_login()

In [None]:
# NOTE: this re-import is load-bearing, not just a duplicate of the first cell.
# The first import cell binds `Dataset` to datasets.Dataset and then rebinds it
# to torch.utils.data.Dataset; importing it from `datasets` again here makes
# `Dataset` refer to the Hugging Face class, which the Dataset.from_pandas()
# calls in the next cell require.
import pandas as pd
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login

In [None]:
# Load datasets
# Parquet shard for each split, relative to the dataset repository root.
splits = {
    'train': 'data/train-00000-of-00001-bb0092e0d8549337.parquet',
    'dev': 'data/dev-00000-of-00001-af55705c75623915.parquet',
    'test': 'data/test-00000-of-00001-2526ab833e27e0ee.parquet'
}

_HF_DATASET_ROOT = "hf://datasets/opennyaiorg/InRhetoricalRoles/"

# Read the three shards in train -> dev -> test order.
train_df, dev_df, test_df = (
    pd.read_parquet(_HF_DATASET_ROOT + splits[split_name])
    for split_name in ("train", "dev", "test")
)

# Convert to Hugging Face Datasets
train_ds, dev_ds, test_ds = (
    Dataset.from_pandas(frame) for frame in (train_df, dev_df, test_df)
)

In [None]:
def get_synonyms(word):
    """Return WordNet synonyms of ``word`` for data augmentation.

    Lemma names from every synset of ``word`` are normalised (underscores
    become spaces, text is lowercased). Candidates that equal the word itself,
    compared case-insensitively, or that are a single character are dropped.

    Args:
        word: Token to look up.

    Returns:
        Alphabetically sorted list of synonyms, or ``[word]`` when none are
        found. Sorting makes ``random.choice`` over the result reproducible
        under a fixed seed.
    """
    # Compare in lowercase: candidates are lowercased below, so comparing them
    # to the raw word would let "Court" receive "court" as its own synonym.
    target = word.lower()
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            candidate = lemma.name().replace("_", " ").lower()
            if candidate != target and len(candidate) > 1:
                synonyms.add(candidate)
    return sorted(synonyms) if synonyms else [word]

def augment_sentence(sent):
    """Apply one randomly chosen word-level perturbation to a sentence.

    The sentence is tokenized with ``nltk.word_tokenize`` and one of four
    operations is drawn with weights 0.4 / 0.2 / 0.3 / 0.1:

    1. replace a random token with one of its synonyms,
    2. delete a random token (only when there are at least 4 tokens),
    3. swap two random tokens (only when there are at least 3 tokens),
    4. insert a synonym of a random token in front of that token.

    Blank sentences and sentences with fewer than two tokens are returned
    unchanged. Otherwise the tokens are re-joined with single spaces, so the
    result is the tokenized form even when the drawn operation did nothing.
    """
    if not sent.strip():
        return sent

    tokens = nltk.word_tokenize(sent)
    if len(tokens) < 2:
        return sent

    operation = random.choices([1, 2, 3, 4], weights=[0.4, 0.2, 0.3, 0.1])[0]
    n_tokens = len(tokens)

    if operation == 1:
        # Synonym replacement
        pos = random.randint(0, n_tokens - 1)
        candidates = get_synonyms(tokens[pos])
        if candidates:
            tokens[pos] = random.choice(candidates)
    elif operation == 2:
        # Random deletion, skipped for very short sentences
        if n_tokens >= 4:
            tokens.pop(random.randint(0, n_tokens - 1))
    elif operation == 3:
        # Word swap, skipped for two-token sentences
        if n_tokens >= 3:
            first, second = random.sample(range(n_tokens), 2)
            tokens[first], tokens[second] = tokens[second], tokens[first]
    elif operation == 4:
        # Random insertion of a synonym before the sampled token
        pos = random.randint(0, n_tokens - 1)
        candidates = get_synonyms(tokens[pos])
        if candidates:
            tokens.insert(pos, random.choice(candidates))

    return " ".join(tokens)

def prepare_hierarchical_datasets(train_ds, dev_ds, test_ds, seed=None):
    """Build document-level (sentence list, label list) datasets and rebalance train.

    Steps:
      1. Extract annotated spans and their first label from
         ``example['annotations'][0]['result']`` and drop documents with no spans.
      2. Expose them as ``text`` (list of sentences) and ``label`` (list of
         label names) columns.
      3. Derive the label vocabulary from the training split.
      4. Augment training documents: every document gets one augmented copy,
         documents containing a rare label (<10 sentence occurrences) get three.
      5. Oversample documents containing under-represented labels.

    Unlike the dev/test splits, the returned training dataset contains only
    the ``text`` and ``label`` columns.

    Args:
        train_ds, dev_ds, test_ds: Hugging Face datasets with an ``annotations`` column.
        seed: Optional seed for Python's ``random`` module so augmentation and
            oversampling are reproducible. ``None`` leaves the RNG state untouched.

    Returns:
        Tuple ``(train_hier, dev_hier, test_hier, label2id, id2label, label_list)``.

    Raises:
        ValueError: If the training split has no annotated documents.
    """
    # `Dataset` is imported from both `datasets` and `torch.utils.data` in this
    # notebook; importing under an alias here removes the dependency on which
    # binding happens to be active when this function runs.
    from datasets import Dataset as HFDataset

    if seed is not None:
        random.seed(seed)

    print("Preprocessing datasets with efficient sampling...")

    def get_spans_and_labels(example):
        spans = []
        labels = []
        if example.get('annotations') and len(example['annotations']) > 0:
            if example['annotations'][0].get('result'):
                for ann in example['annotations'][0]['result']:
                    if ann.get('value') and ann['value'].get('text') and ann['value'].get('labels'):
                        spans.append(ann['value']['text'])
                        labels.append(ann['value']['labels'][0])
        return {'spans': spans, 'labels': labels}

    def prepare_for_hierarchical(example):
        return {'text': example['spans'], 'label': example['labels']}

    def to_hierarchical(dataset):
        # extract spans -> drop empty documents -> expose text/label columns
        dataset = dataset.map(get_spans_and_labels)
        dataset = dataset.filter(lambda x: len(x['spans']) > 0)
        return dataset.map(prepare_for_hierarchical)

    def count_labels(docs):
        counts = Counter()
        for doc in docs:
            counts.update(doc['label'])
        return counts

    train_hier = to_hierarchical(train_ds)
    dev_hier = to_hierarchical(dev_ds)
    test_hier = to_hierarchical(test_ds)

    # Materialise the training documents once as plain dicts so the steps
    # below do not re-decode the Arrow table for every label.
    train_docs = [{'text': list(ex['text']), 'label': list(ex['label'])} for ex in train_hier]
    if not train_docs:
        raise ValueError("No annotated training documents found; cannot build label set")

    label_list = sorted({label for doc in train_docs for label in doc['label']})
    label2id = {l: i for i, l in enumerate(label_list)}
    id2label = {i: l for i, l in enumerate(label_list)}

    print(f"Identified {len(label_list)} labels: {label_list}")

    # --- Augmentation -------------------------------------------------------
    base_counts = count_labels(train_docs)
    rare_classes = [label for label, count in base_counts.items() if count < 10]
    print(f"Rare classes (<10 samples): {rare_classes}")
    rare_lookup = set(rare_classes)

    augmented_examples = []
    for doc in train_docs:
        copies = 3 if any(label in rare_lookup for label in doc['label']) else 1
        for _ in range(copies):
            augmented_text = [
                augment_sentence(sent) if random.random() < 0.5 and sent.strip() else sent
                for sent in doc['text']
            ]
            augmented_examples.append({'text': augmented_text, 'label': list(doc['label'])})

    print(f"Added {len(augmented_examples)} augmented examples")
    train_docs = train_docs + augmented_examples

    # --- Oversampling -------------------------------------------------------
    label_counts = count_labels(train_docs)
    max_count = min(1000, max(label_counts.values()))

    # Index documents by the labels they contain (single pass).
    docs_by_label = {label: [] for label in label_list}
    for doc in train_docs:
        for label in set(doc['label']):
            docs_by_label[label].append(doc)

    # Every document is kept exactly once; only the extra samples for
    # under-represented labels are added on top. Previously each document was
    # re-added once per distinct label it contained, multiplying the dataset
    # size by roughly the number of labels per document.
    balanced_examples = list(train_docs)
    for label in label_list:
        needed = max(0, max_count - label_counts[label])
        class_examples = docs_by_label[label]
        if needed > 0 and class_examples:
            duplicates = min(needed, len(class_examples))
            balanced_examples.extend(random.choices(class_examples, k=duplicates))

    train_hier = HFDataset.from_list(balanced_examples)
    print(f"Created balanced dataset with {len(train_hier)} examples")

    return train_hier, dev_hier, test_hier, label2id, id2label, label_list

class RoleRoutedAdapter(nn.Module):
    """Soft mixture of per-role bottleneck adapters with a residual connection.

    ``role_count`` parallel adapters (Linear -> GELU -> Linear) are applied to
    the input; a router computes one softmax weight per adapter from the mean
    of the input over dimension 1, and the weighted sum of adapter outputs is
    added back onto the input.

    Args:
        config: Object exposing ``hidden_size`` (adapter input/output width).
        role_count: Number of parallel adapters / router outputs.
        intermediate_size: Bottleneck width of each adapter.
    """
    def __init__(self, config, role_count, intermediate_size=64):
        super().__init__()
        self.config = config
        self.role_count = role_count
        self.intermediate_size = intermediate_size

        width = config.hidden_size

        # One bottleneck adapter per role
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Linear(width, intermediate_size),
                nn.GELU(),
                nn.Linear(intermediate_size, width),
            )
            for _ in range(role_count)
        ])

        # Router producing one logit per adapter
        self.router = nn.Sequential(
            nn.Linear(width, 256),
            nn.Tanh(),
            nn.Linear(256, role_count),
        )

        # Re-initialise every Linear layer: N(0, 0.02) weights, zero biases.
        # modules() walks the adapters first and then the router, in
        # registration order.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x):
        """Return ``x`` plus the router-weighted sum of all adapter outputs.

        Args:
            x: Tensor of shape (batch_size, num_sentences, hidden_size).
        """
        # Unpacking enforces a 3-D input, as the einsum below expects.
        _batch, _sentences, _hidden = x.shape

        # Route on the document-level average embedding: (batch, role_count)
        gate = torch.softmax(self.router(x.mean(dim=1)), dim=-1)

        # Run every adapter and stack: (batch, role_count, num_sentences, hidden)
        per_role = torch.stack([adapter(x) for adapter in self.adapters], dim=1)

        # Weighted sum over the role dimension: (batch, num_sentences, hidden)
        mixed = torch.einsum('br,brsh->bsh', gate, per_role)

        return x + mixed

class PositionalEncoding(nn.Module):
    """Learned positional embeddings added to the input, followed by dropout.

    Args:
        d_model: Embedding width; must match the last dimension of the input.
        max_len: Number of positions in the embedding table
            (defaults to ``Config.max_num_sentences``).
    """
    def __init__(self, d_model, max_len=Config.max_num_sentences):
        super().__init__()
        self.dropout = nn.Dropout(Config.dropout_rate)
        self.position_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Add the embedding of each index along dimension 1 to ``x`` and apply dropout."""
        num_positions = x.size(1)
        # (1, num_positions) index row, broadcast over the batch dimension
        index_row = torch.arange(num_positions, device=x.device).unsqueeze(0)
        with_positions = x + self.position_emb(index_row)
        return self.dropout(with_positions)

class TransformerContextLayer(nn.Module):
    """Post-norm transformer block with a RoleRoutedAdapter after each sub-layer.

    Layout: self-attention -> residual + LayerNorm -> adapter1 ->
    feed-forward -> residual + LayerNorm -> adapter2.

    The adapters take their width from the encoder config loaded from
    ``Config.bert_model_name``, so ``d_model`` must equal that config's
    ``hidden_size``; this is checked at construction time.

    Args:
        d_model: Width of the sentence embeddings.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout used inside attention and the feed-forward sub-layer.
        num_roles: Number of parallel adapters in each RoleRoutedAdapter.

    Raises:
        ValueError: If ``d_model`` differs from the encoder config's hidden size.
    """
    def __init__(self, d_model, nhead=8, dim_feedforward=1024, dropout=0.2, num_roles=Config.num_roles):
        super().__init__()
        self.d_model = d_model

        # Load the encoder config once and share it between both adapters
        # instead of resolving the checkpoint twice.
        adapter_config = AutoConfig.from_pretrained(Config.bert_model_name)
        if adapter_config.hidden_size != d_model:
            # Without this check the mismatch only surfaces as a shape error
            # inside the first adapter during forward().
            raise ValueError(
                f"d_model ({d_model}) must match the hidden size of "
                f"{Config.bert_model_name} ({adapter_config.hidden_size})"
            )

        # Multi-Head Attention
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        # First set of parallel adapters, applied after attention
        self.adapter1 = RoleRoutedAdapter(
            config=adapter_config,
            role_count=num_roles,
            intermediate_size=Config.adapter_intermediate_size
        )

        # Feed Forward
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        # Second set of parallel adapters, applied after the feed-forward sub-layer
        self.adapter2 = RoleRoutedAdapter(
            config=adapter_config,
            role_count=num_roles,
            intermediate_size=Config.adapter_intermediate_size
        )

    def forward(self, x):
        """Contextualise sentence embeddings.

        Args:
            x: Tensor of shape (batch_size, num_sentences, d_model).

        Returns:
            Tensor with the same shape as ``x``.

        Note:
            No key padding mask is passed to the attention call, so every
            position (including padded sentences) takes part in attention.
        """
        # Multi-Head Attention with residual connection
        attn_output, _ = self.multihead_attn(x, x, x)
        x = self.norm1(x + attn_output)
        # First set of parallel adapters
        x = self.adapter1(x)

        # Feed Forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        # Second set of parallel adapters
        x = self.adapter2(x)

        return x

class EmissionLayer(nn.Module):
    """Two-layer emission head with a linear shortcut from input to label scores.

    Main path: Linear(input -> 2*input) -> LayerNorm -> GELU -> Dropout ->
    Linear(2*input -> num_labels). A separate Linear(input -> num_labels)
    projection of the raw input is added to the result.

    Args:
        input_size: Width of the incoming features.
        num_labels: Number of emission scores produced per position.
        dropout: Dropout probability on the expanded hidden representation.
    """
    def __init__(self, input_size, num_labels, dropout=0.3):
        super().__init__()
        expanded_size = input_size * 2
        self.linear1 = nn.Linear(input_size, expanded_size)
        self.linear2 = nn.Linear(expanded_size, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.gelu = nn.GELU()
        self.layer_norm = nn.LayerNorm(expanded_size)
        self.residual_proj = nn.Linear(input_size, num_labels)

    def forward(self, x):
        """Return per-label scores for every position of ``x``."""
        shortcut = self.residual_proj(x)
        hidden = self.layer_norm(self.linear1(x))
        hidden = self.dropout(self.gelu(hidden))
        return self.linear2(hidden) + shortcut

class FocalCRF(nn.Module):
    """CRF layer whose sequence-level negative log-likelihood is focal-weighted.

    Args:
        num_tags: Number of tags handled by the underlying CRF.
        gamma: Focal exponent; larger values down-weight sequences the CRF
            already assigns high likelihood to.
    """
    def __init__(self, num_tags, gamma=Config.gamma):
        super().__init__()
        self.crf = CRF(num_tags, batch_first=True)
        self.gamma = gamma

    def forward(self, emissions, tags, mask, class_weights=None):
        """Compute the mean focal CRF loss over the batch.

        Args:
            emissions: Emission scores passed to the CRF.
            tags: Gold tag indices, indexed as (batch, position).
            mask: Boolean mask of valid positions, same layout as ``tags``.
            class_weights: Optional 1-D tensor of per-tag weights. When given,
                each sequence's loss is scaled by the mean weight of its
                *valid* tags.

        Returns:
            Scalar loss tensor.
        """
        log_likelihood = self.crf(emissions, tags, mask=mask, reduction='none')
        pt = torch.exp(log_likelihood)
        focal_loss = -((1 - pt) ** self.gamma) * log_likelihood

        if class_weights is not None:
            valid = mask.to(class_weights.dtype)
            # Zero out padded positions: their tag ids are filler values and
            # must not contribute to the per-sequence weight.
            weights_per_tag = class_weights[tags] * valid
            # clamp avoids a division by zero for a fully masked sequence
            valid_counts = valid.sum(dim=1).clamp(min=1.0)
            weights_per_sequence = weights_per_tag.sum(dim=1) / valid_counts
            focal_loss = focal_loss * weights_per_sequence

        return focal_loss.mean()

    def decode(self, emissions, mask):
        """Return the CRF's decoded tag sequences for ``emissions`` under ``mask``."""
        return self.crf.decode(emissions, mask=mask)

class ImprovedHSLNModel(nn.Module):
    """Hierarchical sentence labeller: BERT sentence encoder -> context layer -> CRF.

    Each sentence is encoded with the pretrained encoder and its first-token
    vector is used as the sentence embedding. Sentence embeddings get
    positional information, pass through a TransformerContextLayer with
    role-routed adapters, and are turned into per-label emission scores that
    feed a focal-weighted CRF.

    Args:
        num_labels: Number of sentence labels.
        class_weights: Optional 1-D tensor of per-label weights forwarded to the CRF loss.
    """
    def __init__(self, num_labels, class_weights=None):
        super().__init__()
        # Keep tensor weights as a non-persistent buffer so they follow the
        # module across devices without being written to the state dict.
        if isinstance(class_weights, torch.Tensor):
            self.register_buffer("class_weights", class_weights, persistent=False)
        else:
            self.class_weights = class_weights

        # Load base Legal-BERT without QLoRA.
        # No device_map here: the caller places the whole model with
        # .to(device), which conflicts with letting the loader dispatch
        # the encoder itself.
        self.bert = AutoModel.from_pretrained(Config.bert_model_name)

        hidden_size = self.bert.config.hidden_size

        # Feature extraction layers
        self.sent_dropout = nn.Dropout(Config.dropout_rate)
        self.sent_layer_norm = nn.LayerNorm(hidden_size)
        self.sent_projection = nn.Linear(hidden_size, hidden_size)

        # Context encoding with custom transformer layer
        self.position_enc = PositionalEncoding(hidden_size)
        self.context_encoder = TransformerContextLayer(d_model=hidden_size)

        # Emission and CRF
        self.emission = EmissionLayer(input_size=hidden_size, num_labels=num_labels)
        self.crf = FocalCRF(num_labels, gamma=Config.gamma)

    def forward(self, input_ids, attention_mask, labels=None):
        """Score every sentence of every document in the batch.

        Args:
            input_ids: Token ids of shape (batch, num_sentences, seq_len).
            attention_mask: Token mask with the same shape; a sentence counts
                as real when its first token is unmasked.
            labels: Optional gold label ids of shape (batch, num_sentences).

        Returns:
            ``{"emissions": ...}`` and, when ``labels`` is given, also ``"loss"``.
        """
        batch_size, num_sent, seq_len = input_ids.shape
        # Flatten documents so every sentence is one encoder input row
        flat_input_ids = input_ids.view(-1, seq_len)
        flat_mask = attention_mask.view(-1, seq_len)

        bert_out = self.bert(input_ids=flat_input_ids, attention_mask=flat_mask).last_hidden_state
        sent_emb = bert_out[:, 0, :]  # first-token vector as sentence embedding
        sent_emb = self.sent_projection(sent_emb)
        sent_emb = self.sent_layer_norm(sent_emb)
        sent_emb = self.sent_dropout(sent_emb)
        sent_emb = sent_emb.view(batch_size, num_sent, -1)

        sent_emb = self.position_enc(sent_emb)
        context_emb = self.context_encoder(sent_emb)

        emissions = self.emission(context_emb)
        # Sentence-level mask: padded sentences have an all-zero token mask
        mask = attention_mask[:, :, 0] > 0

        if labels is not None:
            loss = self.crf(emissions, labels, mask=mask, class_weights=self.class_weights)
            return {"loss": loss, "emissions": emissions}
        return {"emissions": emissions}

def tokenize_datasets(train_hier, dev_hier, test_hier, label2id):
    """Tokenize every document into fixed-size sentence/token grids.

    Each document is cut to ``Config.max_num_sentences`` sentences, each
    sentence is tokenized to exactly ``Config.max_length`` tokens, and short
    documents are padded with all-pad sentences whose attention mask is zero.
    Padded positions receive the id of the first label in ``label2id`` as a
    filler label.

    Args:
        train_hier, dev_hier, test_hier: Datasets with ``text`` and ``label`` columns.
        label2id: Mapping from label name to integer id.

    Returns:
        ``(train_tokenized, dev_tokenized, test_tokenized, tokenizer)``.
    """
    print("Tokenizing datasets...")
    tokenizer = AutoTokenizer.from_pretrained(Config.bert_model_name)

    def tokenize_document(example):
        sentence_cap = Config.max_num_sentences
        sentences = example['text'][:sentence_cap]
        labels = example['label'][:sentence_cap]

        encoded = tokenizer(
            sentences,
            padding="max_length",
            truncation=True,
            max_length=Config.max_length,
            return_tensors="pt",
            return_attention_mask=True
        )
        token_ids = encoded["input_ids"]
        token_mask = encoded["attention_mask"]

        missing = sentence_cap - len(sentences)
        if missing > 0:
            # Append fully padded sentences so every document has the same height
            filler_shape = (missing, Config.max_length)
            token_ids = torch.cat([
                token_ids,
                torch.full(filler_shape, tokenizer.pad_token_id, dtype=torch.long)
            ])
            token_mask = torch.cat([
                token_mask,
                torch.zeros(filler_shape, dtype=torch.long)
            ])
            labels = labels + [list(label2id.keys())[0]] * missing

        label_ids = torch.tensor([label2id[name] for name in labels], dtype=torch.long)
        return {"input_ids": token_ids, "attention_mask": token_mask, "labels": label_ids}

    train_tokenized, dev_tokenized, test_tokenized = [
        split.map(tokenize_document, batched=False)
        for split in (train_hier, dev_hier, test_hier)
    ]
    return train_tokenized, dev_tokenized, test_tokenized, tokenizer

class HierarchicalDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over tokenized documents.

    The base class is spelled out as ``torch.utils.data.Dataset`` because the
    bare name ``Dataset`` is imported from both ``datasets`` and
    ``torch.utils.data`` in this notebook, so which class it refers to depends
    on cell execution order.

    Args:
        dataset: Indexable collection whose items are mappings with
            ``input_ids``, ``attention_mask`` and ``labels`` entries.
    """
    _FIELDS = ("input_ids", "attention_mask", "labels")

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return item ``idx`` with all three fields as ``torch.long`` tensors."""
        item = self.dataset[idx]
        # as_tensor converts nested lists and returns an existing long tensor
        # unchanged, without copying.
        return {field: torch.as_tensor(item[field], dtype=torch.long) for field in self._FIELDS}

def collate_fn(batch):
    """Stack the per-document tensors of a batch along a new leading dimension.

    Args:
        batch: List of items with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors of identical shapes across items.

    Returns:
        Dict with the same three keys, each holding the stacked tensor.
    """
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("input_ids", "attention_mask", "labels")
    }

def create_data_loaders(train_tokenized, dev_tokenized, test_tokenized):
    """Wrap the three tokenized splits in DataLoaders.

    All loaders use ``Config.batch_size``, ``collate_fn`` and pinned memory;
    only the training loader shuffles.

    Returns:
        ``(train_loader, dev_loader, test_loader)``.
    """
    def build_loader(split, shuffle):
        return DataLoader(
            HierarchicalDataset(split),
            batch_size=Config.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    train_loader = build_loader(train_tokenized, True)
    dev_loader = build_loader(dev_tokenized, False)
    test_loader = build_loader(test_tokenized, False)
    return train_loader, dev_loader, test_loader

def compute_class_weights(train_hier, label2id):
    """Compute square-root inverse-frequency weights, normalised so the minimum is 1.

    Args:
        train_hier: Iterable of examples whose ``label`` entry lists label names.
        label2id: Mapping whose key order defines the order of the returned weights.

    Returns:
        float32 tensor with one weight per key of ``label2id``.
    """
    occurrences = dict.fromkeys(label2id, 0)
    for document in train_hier:
        for tag in document['label']:
            occurrences[tag] += 1

    grand_total = sum(occurrences.values())

    raw_weights = []
    for tag in label2id:
        # small epsilon keeps labels with zero occurrences from dividing by zero
        raw_weights.append((grand_total / (occurrences[tag] + 1e-5)) ** 0.5)

    smallest = min(raw_weights)
    return torch.tensor(raw_weights, dtype=torch.float32) / smallest

def train_model(model, train_loader, dev_loader, optimizer, device, epochs, label_list):
    """Train ``model`` with gradient accumulation, dev-based checkpointing and early stopping.

    After every epoch the model is evaluated on ``dev_loader``. Whenever the
    weighted or the macro dev F1 improves, the weights are written to
    ``<Config.output_dir>/best_model.pt``; training stops after
    ``Config.patience`` consecutive epochs without improvement. The best
    checkpoint is loaded back into the model before returning.

    Args:
        model: Model whose forward returns ``{"loss", "emissions"}`` and which
            exposes ``model.crf.decode``.
        train_loader, dev_loader: DataLoaders yielding ``input_ids``,
            ``attention_mask`` and ``labels`` batches.
        optimizer: Optimizer over the model parameters.
        device: Device batches are moved to.
        epochs: Maximum number of epochs.
        label_list: Label names, forwarded to ``evaluate_metrics``.

    Returns:
        ``(model, history)`` where ``history`` has one dict per completed epoch
        with ``epoch``, ``train_loss``, ``train_f1``, ``dev_f1`` and ``dev_macro_f1``.
    """
    print(f"\n{'='*30} TRAINING STARTED {'='*30}")

    accum_steps = max(1, Config.gradient_accumulation_steps)
    # One optimizer step per full accumulation group plus one for a trailing
    # partial group, i.e. ceil(len / accum_steps) steps per epoch.
    steps_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * Config.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    os.makedirs(Config.output_dir, exist_ok=True)
    best_model_path = os.path.join(Config.output_dir, "best_model.pt")

    def apply_optimizer_step():
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Start below any reachable F1 so the first epoch always writes a checkpoint.
    best_dev_f1 = -1.0
    best_macro_f1 = -1.0
    no_improve = 0
    training_start = time.time()
    history = []

    for epoch in range(epochs):
        # Evaluation at the end of the previous epoch leaves the model in eval
        # mode; switch back so dropout is active for every epoch.
        model.train()
        epoch_start = time.time()
        total_loss = 0.0
        num_batches = 0
        all_preds, all_labels = [], []
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            sentence_mask = attention_mask[:, :, 0] > 0

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            batch_loss = outputs["loss"]
            # Scale only the gradient; the logged loss stays unscaled.
            (batch_loss / accum_steps).backward()

            batch_loss_value = batch_loss.item()
            total_loss += batch_loss_value
            num_batches += 1

            # Decoding is only for monitoring; keep it out of the autograd graph.
            with torch.no_grad():
                preds = model.crf.decode(outputs["emissions"].detach(), mask=sentence_mask)

            # Collect predictions and labels per document over valid sentences
            # only, so both lists stay aligned position by position.
            labels_cpu = labels.cpu()
            valid_lengths = sentence_mask.sum(dim=1).tolist()
            for doc_idx, decoded in enumerate(preds):
                n_valid = int(valid_lengths[doc_idx])
                if n_valid > 0:
                    all_preds.extend(decoded[:n_valid])
                    all_labels.extend(labels_cpu[doc_idx, :n_valid].tolist())

            if (batch_idx + 1) % accum_steps == 0:
                apply_optimizer_step()

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} | Loss: {batch_loss_value:.4f}")

        # Flush gradients left over from a trailing partial accumulation group.
        if num_batches % accum_steps != 0:
            apply_optimizer_step()

        epoch_time = time.time() - epoch_start
        train_f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0) if all_labels else 0.0

        model.eval()
        dev_metrics = evaluate_metrics(model, dev_loader, device, label_list)
        dev_f1 = dev_metrics["weighted_f1"]
        dev_macro = dev_metrics["macro_f1"]

        history.append({'epoch': epoch+1, 'train_loss': total_loss / max(1, num_batches), 'train_f1': train_f1, 'dev_f1': dev_f1, 'dev_macro_f1': dev_macro})

        print(f"Epoch {epoch+1} completed in {epoch_time:.2f}s | Train F1: {train_f1:.4f} | Dev F1: {dev_f1:.4f} | Macro F1: {dev_macro:.4f}")

        if dev_f1 > best_dev_f1 or dev_macro > best_macro_f1:
            best_dev_f1 = max(best_dev_f1, dev_f1)
            best_macro_f1 = max(best_macro_f1, dev_macro)
            no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved with F1: {dev_f1:.4f}, Macro F1: {dev_macro:.4f}")
        else:
            no_improve += 1
            if no_improve >= Config.patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    training_time = time.time() - training_start
    print(f"Training completed in {training_time:.2f} seconds")
    # No checkpoint exists when zero epochs were run.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    return model, history

def evaluate_metrics(model, dataloader, device, label_list):
    """Evaluate ``model`` on ``dataloader`` over valid (unpadded) sentences.

    Args:
        model: Model whose forward returns ``{"emissions"}`` and which exposes
            ``model.crf.decode``.
        dataloader: DataLoader yielding ``input_ids``, ``attention_mask`` and ``labels``.
        device: Device batches are moved to.
        label_list: Label names; index ``i`` names label id ``i``.

    Returns:
        Dict with ``macro_f1``, ``weighted_f1``, ``accuracy``, ``per_label_f1``,
        ``latency_ms_per_doc``, ``latency_ms_per_sentence``,
        ``eval_time_seconds`` and ``num_samples``. All scores are 0.0 when the
        loader yields no valid sentences.

    Note:
        The model is left in eval mode.
    """
    model.eval()
    all_preds, all_labels = [], []
    total_time = 0
    n_docs = 0
    n_sentences = 0
    eval_start = time.time()

    # CUDA kernels run asynchronously; synchronise around the forward pass so
    # the measured latency covers the actual computation.
    sync_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']  # only read on the CPU side below
            mask = attention_mask[:, :, 0] > 0

            if sync_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            if sync_cuda:
                torch.cuda.synchronize()
            end = time.perf_counter()

            emissions = outputs["emissions"]
            preds = model.crf.decode(emissions, mask=mask)

            valid_lengths = mask.sum(dim=1).tolist()
            for i in range(len(labels)):
                seq_len = int(valid_lengths[i])
                if seq_len > 0:
                    all_preds.extend(preds[i][:seq_len])
                    all_labels.extend(labels[i][:seq_len].tolist())

            total_time += (end - start)
            n_docs += input_ids.shape[0]
            n_sentences += int(sum(valid_lengths))

    eval_time = time.time() - eval_start

    if all_labels:
        report = classification_report(all_labels, all_preds, labels=list(range(len(label_list))), target_names=label_list, output_dict=True, zero_division=0)
        macro_f1 = report['macro avg']['f1-score']
        weighted_f1 = report['weighted avg']['f1-score']
        per_label_f1 = {label: report[label]['f1-score'] for label in label_list}
        # Computed directly: when `labels=` is passed, the report may expose a
        # 'micro avg' entry instead of an 'accuracy' key.
        accuracy = float(accuracy_score(all_labels, all_preds))
    else:
        macro_f1 = weighted_f1 = accuracy = 0.0
        per_label_f1 = {label: 0.0 for label in label_list}

    return {
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
        "accuracy": accuracy,
        "per_label_f1": per_label_f1,
        "latency_ms_per_doc": (total_time / n_docs) * 1000 if n_docs else 0,
        "latency_ms_per_sentence": (total_time / n_sentences) * 1000 if n_sentences else 0,
        "eval_time_seconds": eval_time,
        "num_samples": n_docs
    }

def get_model_size_mb(model):
    """Return the combined size of the model's parameters and buffers in MiB."""
    total_bytes = 0
    for tensor in list(model.parameters()) + list(model.buffers()):
        total_bytes += tensor.nelement() * tensor.element_size()
    return total_bytes / (1024 ** 2)

def get_memory_footprint():
    """Return the resident set size of the current process in GiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 3)

def save_checkpoint(model, tokenizer, metrics, label2id, config, save_dir, checkpoint_name):
    """Write model weights, tokenizer, config and metrics to a checkpoint folder.

    Files written to ``<save_dir>/<checkpoint_name>``:
      * ``pytorch_model.bin`` - the model's state dict,
      * tokenizer files via ``tokenizer.save_pretrained``,
      * ``config.json`` - label maps and selected attributes of ``config``,
      * ``metrics.json`` - ``metrics`` as given.

    Returns:
        Path of the checkpoint folder.
    """
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    os.makedirs(checkpoint_path, exist_ok=True)

    def inside(filename):
        return os.path.join(checkpoint_path, filename)

    torch.save(model.state_dict(), inside("pytorch_model.bin"))
    tokenizer.save_pretrained(checkpoint_path)

    exported_fields = (
        "bert_model_name",
        "max_num_sentences",
        "max_length",
        "dropout_rate",
        "gamma",
        "num_roles",
        "adapter_intermediate_size",
    )

    with open(inside("config.json"), "w") as handle:
        payload = {
            "label2id": label2id,
            "id2label": {idx: name for name, idx in label2id.items()},
            "model_config": {field: getattr(config, field) for field in exported_fields},
        }
        json.dump(payload, handle, indent=2)

    with open(inside("metrics.json"), "w") as handle:
        json.dump(metrics, handle, indent=2)

    print(f"Checkpoint saved to {checkpoint_path}")
    return checkpoint_path

def upload_to_huggingface(save_path, repo_id):
    """Create the model repo if needed and upload the folder ``save_path`` to it.

    Both hub calls are made with ``token=True``.
    """
    create_repo(repo_id, exist_ok=True, token=True)
    upload_options = dict(
        repo_id=repo_id,
        folder_path=save_path,
        commit_message="Hierarchical Legal Model with Role-Routed Adapters",
        repo_type="model",
        token=True,
    )
    upload_folder(**upload_options)
    print(f"Model uploaded to https://huggingface.co/{repo_id}")

def main():
    """Run the full pipeline: prepare data, train, evaluate, save and upload.

    Reads the module-level ``train_ds`` / ``dev_ds`` / ``test_ds`` datasets and
    the settings in ``Config``. Any exception is printed with its traceback and
    written to ``<Config.output_dir>/error_log.txt`` instead of propagating.
    An upload failure alone does not discard the results: the checkpoint is
    already on disk and the final metrics are printed first.
    """
    try:
        start_time = time.time()

        # Seed every RNG the pipeline draws from (augmentation, oversampling,
        # weight init, shuffling). Falls back to 42 when Config has no `seed`.
        seed = getattr(Config, "seed", 42)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'='*50}\nHIERARCHICAL LEGAL MODEL WITH ROLE-ROUTED ADAPTERS\nTimestamp: {datetime.now().isoformat()}\nDevice: {device}\n{'='*50}\n")
        os.makedirs(Config.output_dir, exist_ok=True)

        train_hier, dev_hier, test_hier, label2id, id2label, label_list = prepare_hierarchical_datasets(train_ds, dev_ds, test_ds)
        class_weights = compute_class_weights(train_hier, label2id).to(device)
        train_tokenized, dev_tokenized, test_tokenized, tokenizer = tokenize_datasets(train_hier, dev_hier, test_hier, label2id)
        # The test loader is built but not evaluated in this pipeline.
        train_loader, dev_loader, test_loader = create_data_loaders(train_tokenized, dev_tokenized, test_tokenized)

        model = ImprovedHSLNModel(num_labels=len(label2id), class_weights=class_weights).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=Config.learning_rate, weight_decay=Config.weight_decay)

        model, history = train_model(model, train_loader, dev_loader, optimizer, device, Config.epochs, label_list)
        dev_metrics = evaluate_metrics(model, dev_loader, device, label_list)
        train_metrics = evaluate_metrics(model, train_loader, device, label_list)

        metrics = {
            "train": train_metrics,
            "dev": dev_metrics,
            "model_size_mb": get_model_size_mb(model),
            "training_memory_footprint_gb": get_memory_footprint(),
            "label2id": label2id,
            "id2label": id2label,
            "training_time": time.time() - start_time,
            "training_history": history
        }

        checkpoint_path = save_checkpoint(model, tokenizer, metrics, label2id, Config, Config.output_dir, Config.save_checkpoint)

        print("\n==== FINAL METRICS ====")
        print(f"Train Weighted F1: {train_metrics['weighted_f1']:.4f} | Macro F1: {train_metrics['macro_f1']:.4f} | Accuracy: {train_metrics['accuracy']:.4f}")
        print(f"Dev Weighted F1: {dev_metrics['weighted_f1']:.4f} | Macro F1: {dev_metrics['macro_f1']:.4f} | Accuracy: {dev_metrics['accuracy']:.4f}")
        print(f"Model Size: {metrics['model_size_mb']:.2f} MB | Training Time: {metrics['training_time']:.2f} seconds")

        # Upload last and in isolation: a hub failure (e.g. the placeholder
        # repo id was never replaced) must not hide the results above.
        try:
            upload_to_huggingface(checkpoint_path, Config.hf_repo_id)
        except Exception as upload_error:
            print(f"Upload to Hugging Face failed: {upload_error}\nCheckpoint remains available at {checkpoint_path}")

    except Exception as e:
        print(f"\n{'!'*50}\nPIPELINE FAILED!\nError: {str(e)}")
        import traceback
        traceback.print_exc()
        try:
            # The failure may have happened before the output directory existed.
            os.makedirs(Config.output_dir, exist_ok=True)
            with open(os.path.join(Config.output_dir, "error_log.txt"), "w") as f:
                f.write(f"Pipeline error at {datetime.now()}\n{str(e)}\n{traceback.format_exc()}")
        except OSError as log_error:
            print(f"Could not write error log: {log_error}")

if __name__ == "__main__":
    main()