# Main Process Part

## Introduction to the Dataset

In [1]:
import pandas as pd

# Load the label table.
df = pd.read_csv('labels.csv')

# Distinct label values at both granularities, in order of first appearance.
unique_coarse_labels, unique_fine_labels = (
    df[column].unique() for column in ('coarse_label', 'fine_label')
)

# Print
for heading, values in (
    ("Unique Coarse Labels:", unique_coarse_labels),
    ("Unique Fine Labels:", unique_fine_labels),
):
    print(heading, values)

Unique Coarse Labels: ['neutral' 'fear' 'surprise' 'love' 'anger' 'sadness' 'joy']
Unique Fine Labels: ['neutral' 'nervousness' 'curiosity' 'surprise' 'admiration' 'gratitude'
 'anger' 'confusion' 'caring' 'disappointment' 'disapproval' 'desire'
 'amusement' 'fear' 'sadness' 'pride' 'love' 'excitement' 'grief'
 'realization' 'annoyance' 'approval' 'relief' 'remorse' 'embarrassment'
 'optimism']


In [10]:
import gc

import torch  # imported here so the cell does not rely on a later cell having run

# Release unused cached GPU memory, then force a garbage-collection pass.
torch.cuda.empty_cache()
gc.collect()


147

In [3]:
# To keep the experiment both sound and diverse, we randomly choose up to 50 songs per
# fine_label from the whole dataset, without repeating a song across labels (if fewer than
# 50 not-yet-selected songs carry a fine_label, we simply take all of them).
# This cell holds the on-demand sampling code. It is disabled by default; set
# RUN_SAMPLING = True and run the cell to regenerate experiment.csv from labels.csv.
import numpy as np
import pandas as pd

RUN_SAMPLING = False

# Define the order of fine labels, from smallest to largest
FINE_LABEL_ORDER = [
    "optimism", "remorse", "excitement", "embarrassment", "sadness", "disappointment",
    "grief", "fear", "love", "relief", "realization", "disapproval", "anger", "pride",
    "caring", "confusion", "surprise", "admiration", "gratitude", "desire",
    "approval", "annoyance", "amusement", "curiosity", "nervousness", "neutral"
]


def sample_songs_by_fine_label(df, fine_label_order, songs_per_label=50, seed=42):
    """Pick up to ``songs_per_label`` songs for every fine label, never reusing a song.

    Args:
        df: DataFrame with at least the columns ``song_id`` and ``fine_label``.
        fine_label_order: Fine labels in the order they are processed; a song picked
            for an earlier label is not available to later labels.
        songs_per_label: Maximum number of new songs drawn per label.
        seed: Seed passed to ``np.random.seed`` for reproducibility (this reseeds
            NumPy's global generator).

    Returns:
        DataFrame with every row of each selected song, concatenated in label order
        with a fresh 0..n-1 index. Empty (same columns as ``df``) if nothing is selected.
    """
    # Fix random seed for reproducibility
    np.random.seed(seed)

    selected_song_ids = set()
    selected_frames = []

    for label in fine_label_order:
        # All candidate song_ids: songs with at least one row carrying this label
        candidate_song_ids = df.loc[df["fine_label"] == label, "song_id"].unique()

        # Remove song_ids that have already been selected
        available_song_ids = [sid for sid in candidate_song_ids if sid not in selected_song_ids]

        # Randomly select up to `songs_per_label`
        if len(available_song_ids) <= songs_per_label:
            selected_ids = available_song_ids
        else:
            selected_ids = np.random.choice(available_song_ids, size=songs_per_label, replace=False)

        # Mark as selected
        selected_song_ids.update(selected_ids)

        # Add the complete song content (all rows of the song, whatever their label)
        selected_frames.append(df[df["song_id"].isin(selected_ids)])

    if not selected_frames:
        return df.iloc[0:0].copy()
    # Concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(selected_frames, ignore_index=True)


if RUN_SAMPLING:
    # Load original label data
    df = pd.read_csv("labels.csv")
    final_selected_df = sample_songs_by_fine_label(df, FINE_LABEL_ORDER)

    # Save the results
    final_selected_df.to_csv("experiment.csv", index=False)

## Main Experiment

In [None]:
# Recommended versions of key modules: numpy==1.24.4, matplotlib==3.7.1

In [5]:
#for main experiment models
import torch
import pandas as pd
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
# ✅ Training function
from structure_model import StructureAwareClassifier, HANClassifier, compute_total_loss, extended_eval
# Release cached GPU memory before the experiment starts.
torch.cuda.empty_cache()

# ✅ Global config
SEED = 42
NUM_SONGS = 1041  # number of songs in experiment.csv; lower it to run on a subset
EPOCHS = 20
BATCH_SIZE = 16
TRAIN_FRAC = 0.2  # share of rows used for training
VAL_FRAC = 0.1    # share of rows used for validation; the remaining rows form the test split
COARSE_LABEL_NAMES = ["joy", "sadness", "anger", "fear", "surprise", "love", "neutral"]


def seed_all(seed=SEED):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs and make cuDNN deterministic."""
    # PyTorch first, then NumPy, then the stdlib generator; the seeds are independent.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Trade cuDNN autotuning for repeatable kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_all()

# ✅ Coarse label mapping
coarse2id = {label: i for i, label in enumerate(COARSE_LABEL_NAMES)}

# ✅ Load + encode dataset
df = pd.read_csv("experiment.csv")
# Fine label ids follow the alphabetical order of the labels present in the file.
fine_labels = sorted(df["fine_label"].unique())
fine2id = {label: i for i, label in enumerate(fine_labels)}
df["coarse_label_id"] = df["coarse_label"].map(coarse2id)
df["fine_label_id"] = df["fine_label"].map(fine2id)

# Series.map turns labels missing from the dict into NaN; fail here with a clear
# message instead of letting NaN ids reach the tensors built by the dataset.
unmapped_coarse = df.loc[df["coarse_label_id"].isna(), "coarse_label"].unique().tolist()
if unmapped_coarse:
    raise ValueError(f"coarse_label values missing from COARSE_LABEL_NAMES: {unmapped_coarse}")

NUM_FINE_CLASSES = len(fine2id)

# ✅ Tokenizer (loaded from a local directory, no download)
tokenizer = AutoTokenizer.from_pretrained("./bert-base-uncased", local_files_only=True)

# ✅ Dataset wrapper
class LyricsDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Args:
        df: DataFrame providing the columns ``line_text``, ``fine_label_id``,
            ``coarse_label_id``, ``para_id`` and ``line_id``.
        tokenizer: Optional callable used to encode ``line_text``. When omitted,
            the module-level ``tokenizer`` is looked up each time an item is read.
        max_length: Length every encoded line is padded/truncated to.
    """

    def __init__(self, df, tokenizer=None, max_length=128):
        self.df = df
        self._tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded text, label ids and structure ids of row ``idx`` as tensors."""
        row = self.df.iloc[idx]
        # Fall back to the notebook-level tokenizer when none was injected.
        encode = self._tokenizer if self._tokenizer is not None else tokenizer
        encoded = encode(
            row["line_text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # The tokenizer returns a batch of one; drop that leading dimension.
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            # Both label ids use an explicit integer dtype.
            "fine_label": torch.tensor(row["fine_label_id"], dtype=torch.long),
            "coarse_label": torch.tensor(row["coarse_label_id"], dtype=torch.long),
            "para_id": torch.tensor(row["para_id"]),
            "line_id": torch.tensor(row["line_id"]),
        }

# ✅ Dataset split
# Keep only the first NUM_SONGS songs (in order of first appearance) and renumber the rows.
unique_ids = df["song_id"].unique()[:NUM_SONGS]
keep_rows = df["song_id"].isin(unique_ids)
df = df.loc[keep_rows].reset_index(drop=True)

# Rows are sampled individually: first the combined train+val pool, then the
# training share of that pool; whatever is not in the pool becomes the test split.
# NOTE(review): lines of one song can end up in different splits — confirm this is intended.
pool_frac = TRAIN_FRAC + VAL_FRAC
train_val = df.sample(frac=pool_frac, random_state=SEED)
train_df = train_val.sample(frac=TRAIN_FRAC / pool_frac, random_state=SEED)
val_df = train_val.drop(index=train_df.index)
test_df = df.drop(index=train_val.index)

# Only the training loader shuffles its rows.
train_loader = DataLoader(LyricsDataset(train_df), batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(LyricsDataset(part), batch_size=BATCH_SIZE) for part in (val_df, test_df)
)

# ✅ Fine-to-coarse mapping
def get_fine_to_coarse_map(df):
    """Return a dict mapping each fine label id to a coarse label id.

    Pairs are read row by row from the ``fine_label_id`` and ``coarse_label_id``
    columns; if a fine id appears on several rows, the last row wins.
    """
    mapping = {}
    for fine_id, coarse_id in zip(df["fine_label_id"], df["coarse_label_id"]):
        mapping[fine_id] = coarse_id
    return mapping
# Built from the filtered DataFrame; passed to extended_eval when each model is evaluated.
fine_to_coarse_map = get_fine_to_coarse_map(df)


# ✅ Train function with support for models
def train_model(model, train_loader, name, cfg, epochs=5):
    """Train ``model`` on ``train_loader`` and record the mean loss of every epoch.

    Args:
        model: Model called as ``model(input_ids, attention_mask, para_id, line_id)``
            returning ``(coarse_logits, fine_logits, struct_logits)``, and exposing
            ``get_embeddings`` with the same arguments returning ``(z_euc, z_hyp)``.
        train_loader: Sized iterable of batches; each batch is a dict of tensors with
            the keys ``input_ids``, ``attention_mask``, ``para_id``, ``line_id``,
            ``fine_label`` and ``coarse_label``.
        name: Display name used in the log lines; the value "StructFormer-Hyper"
            also selects the Riemannian optimizer.
        cfg: Configuration dict of the model. Kept for the call signature; not read here.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``{"train_loss": [...]}`` with one mean batch loss per epoch.

    Raises:
        ValueError: If training is requested on a loader with zero batches.
    """
    n_batches = len(train_loader)
    if epochs > 0 and n_batches == 0:
        # Previously this surfaced as a ZeroDivisionError after the first epoch.
        raise ValueError("train_loader has no batches; nothing to train on")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Stored on the model so later cells can move batches to the same device.
    model.device = device

    # ✅ Use Riemannian optimizer for StructFormer-Hyper only
    if name == "StructFormer-Hyper":
        # Imported lazily: the other models can then be trained without geoopt installed.
        import geoopt
        optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=1e-5)
    elif isinstance(model, HANClassifier):
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    results = {"train_loss": []}
    for epoch in range(epochs):
        model.train()
        total = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            coarse_logits, fine_logits, struct_logits = model(
                batch["input_ids"], batch["attention_mask"],
                batch["para_id"], batch["line_id"]
            )
            # Embeddings are requested separately, with the same inputs, for the loss terms.
            z_euc, z_hyp = model.get_embeddings(
                batch["input_ids"], batch["attention_mask"],
                batch["para_id"], batch["line_id"]
            )
            losses = compute_total_loss(
                fine_logits, batch["fine_label"],
                coarse_logits, batch["coarse_label"],
                struct_logits=struct_logits,
                para_ids=batch["para_id"],
                z_euc=z_euc, z_hyp=z_hyp,
                lambda_fine=1.0, lambda_coarse=0.5,
                lambda_align=0.4, lambda_geom=0.05, margin=1.0,
                model=model
            )
            loss = losses["total"]
            loss.backward()
            optimizer.step()
            # Gradients are cleared after the step, so each batch starts from zero.
            optimizer.zero_grad()
            total += loss.item()
        results["train_loss"].append(total / n_batches)
        print(f"[{name}] Epoch {epoch+1}: Loss = {results['train_loss'][-1]:.4f}")
    return results


# ✅ Model configuration: the full model plus three baselines (four entries).
# Entries are trained in the order they are listed here.
model_configs = {
    # StructFormer-Hyper: full model with every option switched on
    # (structure, structure projection, coarse supervision, hyperbolic geometry)
    "StructFormer-Hyper": {
        "use_structure": True, "use_hyperbolic": True,
        "project_structure": True, "use_coarse_supervision": True,
        "num_fine_labels": NUM_FINE_CLASSES,
    },
    # BERT: baseline with every option switched off
    "BERT": {
        "use_structure": False, "use_hyperbolic": False,
        "project_structure": False, "use_coarse_supervision": False,
        "num_fine_labels": NUM_FINE_CLASSES,
    },
    # Multi-task BERT: only coarse-label supervision is enabled;
    # no structure, projection or hyperbolic geometry
    "Multi-task BERT": {
        "use_structure": False, "use_hyperbolic": False,
        "project_structure": False, "use_coarse_supervision": True,
        "num_fine_labels": NUM_FINE_CLASSES,
    },
    # HAN: separate model type, instantiated as HANClassifier by the training loop
    "HAN": {"model_type": "han", "num_fine_labels": NUM_FINE_CLASSES},
}

# ✅ Train and evaluate
# results: per-model training history; trained_models: the fitted model objects.
results, trained_models = {}, {}

# ✅ Updated model loading logic
for name, cfg in model_configs.items():
    print(f"\n🚀 Training: {name}")
    # HAN is built from its own class; every other entry is a StructureAwareClassifier config.
    is_han = cfg.get("model_type") == "han"
    model = HANClassifier(num_classes=NUM_FINE_CLASSES) if is_han else StructureAwareClassifier(**cfg)
    train_result = train_model(model, train_loader, name, cfg, epochs=EPOCHS)
    results[name], trained_models[name] = train_result, model
    # Evaluate on the test split right after training.
    eval_metrics = extended_eval(model, test_loader, fine_to_coarse_map)
    print(f"📊 Eval for {name}:", eval_metrics)




🚀 Training: StructFormer-Hyper
[StructFormer-Hyper] Epoch 1: Loss = 2.7318
[StructFormer-Hyper] Epoch 2: Loss = 2.1837
[StructFormer-Hyper] Epoch 3: Loss = 1.7842
[StructFormer-Hyper] Epoch 4: Loss = 1.4584
[StructFormer-Hyper] Epoch 5: Loss = 1.1952
[StructFormer-Hyper] Epoch 6: Loss = 0.9945
[StructFormer-Hyper] Epoch 7: Loss = 0.8464
[StructFormer-Hyper] Epoch 8: Loss = 0.7241
[StructFormer-Hyper] Epoch 9: Loss = 0.6354
[StructFormer-Hyper] Epoch 10: Loss = 0.5686
[StructFormer-Hyper] Epoch 11: Loss = 0.5045
[StructFormer-Hyper] Epoch 12: Loss = 0.4578
[StructFormer-Hyper] Epoch 13: Loss = 0.4160
[StructFormer-Hyper] Epoch 14: Loss = 0.3809
[StructFormer-Hyper] Epoch 15: Loss = 0.3511
[StructFormer-Hyper] Epoch 16: Loss = 0.3271
[StructFormer-Hyper] Epoch 17: Loss = 0.3032
[StructFormer-Hyper] Epoch 18: Loss = 0.2792
[StructFormer-Hyper] Epoch 19: Loss = 0.2641
[StructFormer-Hyper] Epoch 20: Loss = 0.2463
📊 Eval for StructFormer-Hyper: {'macro_f1': 0.6530178202889233, 'weighted_f1'

## Visualization

In [4]:
# Visualization
from visualization import (
    plot_all_model_heatmaps,
    plot_all_model_umaps,
    plot_all_model_norm_trends
)

# ✅ define label
coarse_label_names = ["joy", "sadness", "anger", "fear", "surprise", "love", "neutral"]
# Fine label names are taken from fine2id (built in the training cell) and ordered by
# class id, so position i names class i. The previous hand-written list contained "joy",
# which is not a fine label in the data, omitted "caring", and was not in id order.
# NOTE(review): assumes the plotting helpers look names up by class id — confirm in visualization.py.
fine_label_names = sorted(fine2id, key=fine2id.get)

# ✅ heatmap
plot_all_model_heatmaps(
    models_dict=trained_models,
    loader=test_loader,
    fine_label_names=fine_label_names,
    coarse_label_names=coarse_label_names,
    save_path="viz_heatmaps_all_models.png"
)

# ✅ UMAP
plot_all_model_umaps(
    models_dict=trained_models,
    loader=test_loader,
    fine_label_names=fine_label_names,
    save_path="viz_umap_all_models.png"
)

# ✅ Norm trends
plot_all_model_norm_trends(
    models_dict=trained_models,
    loader=test_loader,
    save_path="viz_norm_trends_all_models.png"
)


  warn(
  warn(
  warn(
  warn(


In [None]:
# Loss Plot
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.ticker import MaxNLocator

# Set the style to the style used by NeurIPS
# matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and later releases
# dropped the old name, so use whichever one this installation provides.
seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(seaborn_style)
sns.set_context("paper", font_scale=1.4)
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

# Prepare data: per-epoch training losses recorded in `results` by the training cell.
# Entries without a non-empty "train_loss" list are skipped (e.g. if `results` was
# rebound to something else by another cell).
models = {
    model_name: list(history["train_loss"])
    for model_name, history in results.items()
    if isinstance(history, dict) and len(history.get("train_loss", [])) > 0
}
max_epochs = max((len(losses) for losses in models.values()), default=20)

# Create the plot
plt.figure(figsize=(10, 6), dpi=300)
colors = sns.color_palette("husl", max(len(models), 4))
linestyles = ['-', '--', '-.', ':']
markers = ['o', 's', 'D', '^']
model_colors = {}

for idx, (model_name, losses) in enumerate(models.items()):
    epochs = np.arange(1, len(losses)+1)
    model_colors[model_name] = colors[idx]
    plt.plot(epochs, losses,
             label=model_name,
             color=colors[idx],
             # Cycle through the styles so more than four curves do not raise IndexError.
             linestyle=linestyles[idx % len(linestyles)],
             marker=markers[idx % len(markers)],
             markersize=6,
             linewidth=2,
             markevery=2)

# Add final evaluation weighted-F1 values as annotations
# NOTE(review): these values are hard-coded; check them against the extended_eval output.
final_f1 = {
    "StructFormer-Hyper": 0.96,
    "BERT": 0.96,
    "Multi-task BERT": 0.96,
    "HAN": 0.91
}

for model_name, f1 in final_f1.items():
    if model_name not in models:
        # No loss curve for this model, so there is nothing to annotate.
        continue
    last_epoch = len(models[model_name])
    y_pos = models[model_name][-1]
    plt.annotate(f'Weighted-F1: {f1:.2f}',
                 xy=(last_epoch, y_pos),
                 xytext=(last_epoch + 2, y_pos),
                 color=model_colors[model_name],
                 fontsize=10,
                 arrowprops=dict(arrowstyle="->", color=model_colors[model_name]))

# Decorate the plot
plt.title(f'Training Loss Curves ({max_epochs} Epochs)', pad=20)
plt.xlabel('Epoch', labelpad=10)
plt.ylabel('Loss Value', labelpad=10)
plt.xticks(np.arange(0, max_epochs + 1, 2))
plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
if models:
    plt.legend(loc='upper right', framealpha=1)
plt.yscale('log')  # Use logarithmic scale to better display loss values across different orders of magnitude

# Add NeurIPS-style grid and borders
sns.despine()
plt.grid(True, which="both", ls="--", alpha=0.2)

# Save the plot
plt.tight_layout()
#plt.savefig('training_loss_curves.pdf', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
from sklearn.metrics import silhouette_score
from scipy.stats import ttest_rel
import numpy as np
import torch

# Step 1: extract embeddings and labels
def extract_embeddings(model, loader):
    """Run ``model`` over ``loader`` in eval mode and collect embeddings and fine labels.

    Returns:
        Tuple ``(embeddings, labels)`` of NumPy arrays: the first output of
        ``model.get_embeddings`` and the ``fine_label`` entries, concatenated over batches.
    """
    model.eval()
    embedding_chunks = []
    label_chunks = []
    input_keys = ("input_ids", "attention_mask", "para_id", "line_id")
    with torch.no_grad():
        for batch in loader:
            # Move only the model inputs to the model's device; labels are not moved.
            model_inputs = [batch[key].to(model.device) for key in input_keys]
            z_euc, _unused = model.get_embeddings(*model_inputs)
            embedding_chunks.append(z_euc.cpu().numpy())
            label_chunks.append(batch["fine_label"].cpu().numpy())
    return np.concatenate(embedding_chunks), np.concatenate(label_chunks)

# Step 2: paired bootstrap sampling
def bootstrap_silhouette_paired(X1, y1, X2, y2, n_iter=500, sample_size=512, seed=42):
    """Bootstrap silhouette scores for two embedding sets using identical resamples.

    Args:
        X1, y1: Embeddings and labels of the first model.
        X2, y2: Embeddings and labels of the second model; must have as many rows as ``X1``.
        n_iter: Number of bootstrap resamples.
        sample_size: Rows drawn (with replacement) per resample, capped at the number
            of rows. ``None`` means "use all rows", matching the Wilcoxon cell.
        seed: Seed of the private random generator.

    Returns:
        Two arrays of length ``n_iter`` with the per-resample silhouette scores.

    Raises:
        ValueError: If the two embedding sets differ in length, since the same
            indices are applied to both.
    """
    if len(X1) != len(X2):
        raise ValueError("X1 and X2 must have the same number of rows for paired resampling")

    # A private RandomState yields the same draws as seeding the global generator,
    # without resetting NumPy's global random state for the rest of the notebook.
    rng = np.random.RandomState(seed)
    scores1, scores2 = [], []
    N = len(X1)
    sample_size = N if sample_size is None else min(sample_size, N)
    for _ in range(n_iter):
        idx = rng.choice(N, size=sample_size, replace=True)
        scores1.append(silhouette_score(X1[idx], y1[idx]))
        scores2.append(silhouette_score(X2[idx], y2[idx]))
    return np.array(scores1), np.array(scores2)

# Step 3: paired t-test comparison
def compare_models(model_main, model_base, loader, base_name, n_iter=500, sample_size=512):
    """Compare two models' bootstrap silhouette scores with a paired t-test.

    Embeddings of both models are extracted from ``loader``, scored on shared
    bootstrap resamples, and the two score series are passed to ``ttest_rel``.

    Returns:
        ``(scores_main, scores_base, p_val)``.
    """
    print(f"\n🔍 Testing StructFormer-Hyper vs {base_name}")
    (X_main, y_main), (X_base, y_base) = (
        extract_embeddings(candidate, loader) for candidate in (model_main, model_base)
    )
    paired_scores = bootstrap_silhouette_paired(
        X_main, y_main, X_base, y_base, n_iter=n_iter, sample_size=sample_size
    )
    scores_main, scores_base = paired_scores
    t_stat, p_val = ttest_rel(scores_main, scores_base)
    print(f"Paired t-test: t = {t_stat:.4f}, p = {p_val:.4f}")
    return scores_main, scores_base, p_val

# Reference model and the baselines it is tested against
main_model = trained_models["StructFormer-Hyper"]
baselines = ["BERT", "Multi-task BERT", "HAN"]

# NOTE(review): this rebinds `results`, which the training cell used for the loss history.
results = {}

for base_name in baselines:
    base_model = trained_models[base_name]
    scores_main, scores_base, p_val = compare_models(
        main_model, base_model, test_loader, base_name,
        n_iter=500, sample_size=512,
    )
    # One entry per baseline: both score series plus the p-value of the test.
    comparison = {"StructFormer-Hyper": scores_main}
    comparison[base_name] = scores_base
    comparison["p-value"] = p_val
    results[base_name] = comparison



In [None]:
import numpy as np
from sklearn.metrics import silhouette_score
from scipy.stats import wilcoxon
import torch

# Function to extract embeddings and labels
def extract_embeddings(model, loader):
    """Collect the first ``get_embeddings`` output and the fine labels for every batch.

    The model is switched to eval mode and gradients are disabled. Returns two
    NumPy arrays, ``(embeddings, labels)``, concatenated across all batches.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch in loader:
            target = model.device
            z_euc, _rest = model.get_embeddings(
                batch["input_ids"].to(target),
                batch["attention_mask"].to(target),
                batch["para_id"].to(target),
                batch["line_id"].to(target),
            )
            collected.append((z_euc.cpu().numpy(), batch["fine_label"].cpu().numpy()))
    embeddings = np.concatenate([emb for emb, _ in collected])
    labels = np.concatenate([lab for _, lab in collected])
    return embeddings, labels

# Bootstrap silhouette using same indices for both models
def bootstrap_silhouette_paired(X1, y1, X2, y2, n_iter=500, sample_size=None, seed=42):
    """Bootstrap silhouette scores for two embedding sets using identical resamples.

    Args:
        X1, y1: Embeddings and labels of the first model.
        X2, y2: Embeddings and labels of the second model; must have as many rows as ``X1``.
        n_iter: Number of bootstrap resamples.
        sample_size: Rows drawn (with replacement) per resample. ``None`` or a value
            larger than the number of rows means "use as many draws as there are rows".
        seed: Seed of the private random generator.

    Returns:
        Two arrays of length ``n_iter`` with the per-resample silhouette scores.

    Raises:
        ValueError: If the two embedding sets differ in length, since the same
            indices are applied to both.
    """
    if len(X1) != len(X2):
        raise ValueError("X1 and X2 must have the same number of rows for paired resampling")

    # A private RandomState yields the same draws as seeding the global generator,
    # without resetting NumPy's global random state for the rest of the notebook.
    rng = np.random.RandomState(seed)
    scores1, scores2 = [], []
    N = len(X1)
    if sample_size is None or sample_size > N:
        sample_size = N

    for _ in range(n_iter):
        idx = rng.choice(N, size=sample_size, replace=True)
        scores1.append(silhouette_score(X1[idx], y1[idx]))
        scores2.append(silhouette_score(X2[idx], y2[idx]))
    return np.array(scores1), np.array(scores2)

# Compare models using Wilcoxon signed-rank test
def compare_models_wilcoxon(model_main, model_base, loader, base_name, n_iter=500, sample_size=None):
    """Compare two models' bootstrap silhouette scores with the Wilcoxon signed-rank test.

    Embeddings of both models are extracted from ``loader``, scored on shared
    bootstrap resamples, and the two score series are passed to ``wilcoxon``.

    Returns:
        ``(scores_main, scores_base, p_val)``.
    """
    print(f"\n🔍 Wilcoxon Test: StructFormer-Hyper vs {base_name}")
    main_embeddings, main_labels = extract_embeddings(model_main, loader)
    base_embeddings, base_labels = extract_embeddings(model_base, loader)

    scores_main, scores_base = bootstrap_silhouette_paired(
        main_embeddings, main_labels,
        base_embeddings, base_labels,
        n_iter=n_iter, sample_size=sample_size,
    )

    stat, p_val = wilcoxon(scores_main, scores_base)
    print(f"Wilcoxon signed-rank test: statistic = {stat:.4f}, p = {p_val:.4f}")
    return scores_main, scores_base, p_val

# Run the Wilcoxon comparison of the full model against every baseline
main_model = trained_models["StructFormer-Hyper"]
baselines = ["BERT", "Multi-task BERT", "HAN"]

results_wilcoxon = {}

for base_name in baselines:
    base_model = trained_models[base_name]
    # sample_size=None: every resample draws as many rows as the test set has.
    outcome = compare_models_wilcoxon(
        main_model, base_model, test_loader,
        base_name=base_name, n_iter=500, sample_size=None,
    )
    scores_main, scores_base, p_val = outcome
    entry = {"StructFormer-Hyper": scores_main}
    entry[base_name] = scores_base
    entry["p-value"] = p_val
    results_wilcoxon[base_name] = entry



🔍 Wilcoxon Test: StructFormer-Hyper vs BERT


## Ablation Studies

In [None]:
import torch
import pandas as pd
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
# ✅ Training function
from structure_model import StructureAwareClassifier, HANClassifier, compute_total_loss, extended_eval
# Release cached GPU memory before the ablation runs start.
torch.cuda.empty_cache()

# ✅ Global config
NUM_SONGS = 300  # experiment.csv holds 1041 songs in total; adjust this as needed
EPOCHS, BATCH_SIZE = 10, 16
TRAIN_FRAC, VAL_FRAC = 0.2, 0.1  # shares of rows for training / validation; the rest is the test split
SEED = 42
COARSE_LABEL_NAMES = ["joy", "sadness", "anger", "fear", "surprise", "love", "neutral"]


def seed_all(seed=SEED):
    """Seed every random generator in use and make cuDNN deterministic."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all()

# ✅ Coarse label mapping
coarse2id = {label: i for i, label in enumerate(COARSE_LABEL_NAMES)}

# ✅ Load + encode dataset
df = pd.read_csv("experiment.csv")
# Fine label ids follow the alphabetical order of the labels present in the file.
fine_labels = sorted(df["fine_label"].unique())
fine2id = {label: i for i, label in enumerate(fine_labels)}
df["coarse_label_id"] = df["coarse_label"].map(coarse2id)
df["fine_label_id"] = df["fine_label"].map(fine2id)

# Series.map turns labels missing from the dict into NaN; fail here with a clear
# message instead of letting NaN ids reach the tensors built by the dataset.
unmapped_coarse = df.loc[df["coarse_label_id"].isna(), "coarse_label"].unique().tolist()
if unmapped_coarse:
    raise ValueError(f"coarse_label values missing from COARSE_LABEL_NAMES: {unmapped_coarse}")

NUM_FINE_CLASSES = len(fine2id)

# ✅ Tokenizer (loaded from a local directory, no download)
tokenizer = AutoTokenizer.from_pretrained("./bert-base-uncased", local_files_only=True)

# ✅ Dataset wrapper
class LyricsDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Args:
        df: DataFrame providing the columns ``line_text``, ``fine_label_id``,
            ``coarse_label_id``, ``para_id`` and ``line_id``.
        tokenizer: Optional callable used to encode ``line_text``. When omitted,
            the module-level ``tokenizer`` is looked up each time an item is read.
        max_length: Length every encoded line is padded/truncated to.
    """

    def __init__(self, df, tokenizer=None, max_length=128):
        self.df = df
        self._tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded text, label ids and structure ids of row ``idx`` as tensors."""
        row = self.df.iloc[idx]
        # Fall back to the notebook-level tokenizer when none was injected.
        encode = self._tokenizer if self._tokenizer is not None else tokenizer
        encoded = encode(
            row["line_text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # The tokenizer returns a batch of one; drop that leading dimension.
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            # Both label ids use an explicit integer dtype.
            "fine_label": torch.tensor(row["fine_label_id"], dtype=torch.long),
            "coarse_label": torch.tensor(row["coarse_label_id"], dtype=torch.long),
            "para_id": torch.tensor(row["para_id"]),
            "line_id": torch.tensor(row["line_id"]),
        }

# ✅ Dataset split
# Restrict the data to the first NUM_SONGS songs (order of first appearance), then reindex.
unique_ids = df["song_id"].unique()[:NUM_SONGS]
in_selected_songs = df["song_id"].isin(unique_ids)
df = df.loc[in_selected_songs].reset_index(drop=True)

# Row-level sampling: draw the train+val pool, carve the training rows out of it,
# and treat everything outside the pool as the test split.
train_val_frac = TRAIN_FRAC + VAL_FRAC
train_val = df.sample(frac=train_val_frac, random_state=SEED)
train_df = train_val.sample(frac=TRAIN_FRAC / train_val_frac, random_state=SEED)
val_df = train_val.drop(index=train_df.index)
test_df = df.drop(index=train_val.index)

# Shuffling is enabled for training only.
train_loader = DataLoader(LyricsDataset(train_df), shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(LyricsDataset(val_df), batch_size=BATCH_SIZE)
test_loader = DataLoader(LyricsDataset(test_df), batch_size=BATCH_SIZE)

# ✅ Fine-to-coarse mapping
def get_fine_to_coarse_map(df):
    """Build a ``{fine_label_id: coarse_label_id}`` dict from the rows of ``df``.

    When the same fine id occurs on several rows, the value from the last row is kept.
    """
    fine_ids = df["fine_label_id"].tolist()
    coarse_ids = df["coarse_label_id"].tolist()
    return {fine_id: coarse_id for fine_id, coarse_id in zip(fine_ids, coarse_ids)}
# Built from the filtered DataFrame; passed to extended_eval when each model is evaluated.
fine_to_coarse_map = get_fine_to_coarse_map(df)


# ✅ Train function with support for both models
def train_model(model, train_loader, name, cfg, epochs=5):
    """Train ``model`` on ``train_loader`` and record the mean loss of every epoch.

    Args:
        model: Model called as ``model(input_ids, attention_mask, para_id, line_id)``
            returning ``(coarse_logits, fine_logits, struct_logits)``, and exposing
            ``get_embeddings`` with the same arguments returning ``(z_euc, z_hyp)``.
        train_loader: Sized iterable of batches; each batch is a dict of tensors with
            the keys ``input_ids``, ``attention_mask``, ``para_id``, ``line_id``,
            ``fine_label`` and ``coarse_label``.
        name: Display name used in the log lines.
        cfg: Configuration dict of the model. Kept for the call signature; not read here.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``{"train_loss": [...]}`` with one mean batch loss per epoch.

    Raises:
        ValueError: If training is requested on a loader with zero batches.
    """
    n_batches = len(train_loader)
    if epochs > 0 and n_batches == 0:
        # Previously this surfaced as a ZeroDivisionError after the first epoch.
        raise ValueError("train_loader has no batches; nothing to train on")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Stored on the model so later cells can move batches to the same device.
    model.device = device

    # HAN uses Adam; every other model uses AdamW.
    if isinstance(model, HANClassifier):
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    results = {"train_loss": []}
    for epoch in range(epochs):
        model.train()
        total = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            coarse_logits, fine_logits, struct_logits = model(
                batch["input_ids"], batch["attention_mask"],
                batch["para_id"], batch["line_id"]
            )
            # Embeddings are requested separately, with the same inputs, for the loss terms.
            z_euc, z_hyp = model.get_embeddings(
                batch["input_ids"], batch["attention_mask"],
                batch["para_id"], batch["line_id"]
            )
            # The align/geom weights here (0.2 / 0.5) differ from the main experiment cell (0.4 / 0.05).
            losses = compute_total_loss(
                fine_logits, batch["fine_label"],
                coarse_logits, batch["coarse_label"],
                struct_logits=struct_logits,
                para_ids=batch["para_id"],
                z_euc=z_euc, z_hyp=z_hyp,
                lambda_fine=1.0, lambda_coarse=0.5,
                lambda_align=0.2, lambda_geom=0.5, margin=1.0,
                model=model
            )
            loss = losses["total"]
            loss.backward()
            optimizer.step()
            # Gradients are cleared after the step, so each batch starts from zero.
            optimizer.zero_grad()
            total += loss.item()
        results["train_loss"].append(total / n_batches)
        print(f"[{name}] Epoch {epoch+1}: Loss = {results['train_loss'][-1]:.4f}")
    return results

# ✅ Ablation experiment configuration: the full model plus four variants,
# each of which switches off one or two options. Trained in the order listed.
model_configs = {
    # Full model: structure, projection, coarse supervision and hyperbolic geometry all on
    "StructFormer-Hyper": {
        "use_structure": True,
        "use_hyperbolic": True,
        "project_structure": True,
        "use_coarse_supervision": True,
        "num_fine_labels": NUM_FINE_CLASSES,
    },

    # No structural information: structure and its projection are both off
    "StructFormer-NoStruct": {
        "use_structure": False,
        "use_hyperbolic": True,
        "project_structure": False,
        "use_coarse_supervision": True,
        "num_fine_labels": NUM_FINE_CLASSES,
    },

    # No hyperbolic geometry
    "StructFormer-NoHyper": {
        "use_structure": True,
        "use_hyperbolic": False,
        "project_structure": True,
        "use_coarse_supervision": True,
        "num_fine_labels": NUM_FINE_CLASSES,
    },

    # No structure projection (structure itself stays on)
    "StructFormer-NoProj": {
        "use_structure": True,
        "use_hyperbolic": True,
        "project_structure": False,
        "use_coarse_supervision": True,
        "num_fine_labels": NUM_FINE_CLASSES,
    },

    # No coarse label supervision
    "StructFormer-NoCoarse": {
        "use_structure": True,
        "use_hyperbolic": True,
        "project_structure": True,
        "use_coarse_supervision": False,
        "num_fine_labels": NUM_FINE_CLASSES,
    },
}

# ✅ Train and evaluate
# Both dicts are keyed by the variant name from model_configs.
results = dict()
trained_models = dict()

# ✅ Updated model loading logic
for name, cfg in model_configs.items():
    print(f"\n🚀 Training: {name}")
    # The "han" branch is kept for parity with the main experiment; no ablation config sets model_type.
    wants_han = cfg.get("model_type") == "han"
    if not wants_han:
        model = StructureAwareClassifier(**cfg)
    else:
        model = HANClassifier(num_classes=NUM_FINE_CLASSES)
    train_result = train_model(model, train_loader, name, cfg, epochs=EPOCHS)
    results[name] = train_result
    trained_models[name] = model
    # Evaluate on the test split right after training.
    eval_metrics = extended_eval(model, test_loader, fine_to_coarse_map)
    print(f"📊 Eval for {name}:", eval_metrics)

