In [1]:
print("⏳ Setting up environment for HMTAS inference...")

# --- Step 1: Ensure clean base ---
!pip uninstall -y bleurt tensorflow tensorflow-text -q || true
!rm -rf /root/.cache/pip /root/.cache/huggingface /root/.cache/nltk

# --- Step 2: Install core stable packages (quietly) ---
!pip install -q -U numpy==1.26.4 scikit-learn==1.3.2
!pip install -q transformers==4.41.2 datasets==2.18.0 sentence-transformers==2.7.0
!pip install -q rouge-score==0.1.2 bert-score==0.3.13
!pip install -q networkx==3.2.1 hdbscan==0.8.33 umap-learn==0.5.5
!pip install -q matplotlib seaborn nltk

# --- Step 3: Torch (CPU) install (GPU uses default runtime CUDA) ---
!pip install -q torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

# --- Step 4: Download NLTK resources silently ---
import nltk
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

print("✅ Environment ready for HMTAS inference and visualization.")


⏳ Setting up environment for HMTAS inference...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m86.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.
umap-learn 0.5.9.post2 requires scikit-learn>=1.6, but you have scikit-learn 1.3.2 which is incompatible.[0m[31m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.1/9.1 MB[0m [31m60.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[

In [2]:
"""
HMTAS (Hierarchical Multi-Document Text Abstractive Summarization) - Kaggle-ready inference pipeline
- Removes fine-tuning / training steps (inference-only)
- Keeps HTAS/HMTAS extractive guidance pipeline (SBERT + UMAP + HDBSCAN + PageRank)
- Uses a pretrained abstractive model (default: facebook/bart-large-cnn)
- Produces single-sentence global summary for 2-3 input news docs
- Keeps all plotting utilities and saves outputs to /kaggle/working/

Run in Kaggle (GPU) or similar environment. All outputs (plots, CSVs, JSON) are written to /kaggle/working/htas_output
"""

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import time
import gc
import json
import random
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset as TorchDataset

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    logging as hf_logging
)
from rouge_score import rouge_scorer
from bert_score import score as bert_score_calculator

import nltk
from nltk.tokenize import sent_tokenize

# Optional metrics
try:
    import sacrebleu
    _HAS_SACREBLEU = True
except Exception:
    _HAS_SACREBLEU = False

try:
    from nltk.translate.meteor_score import meteor_score
    _HAS_METEOR = True
except Exception:
    _HAS_METEOR = False

# Clustering helpers
import hdbscan
import networkx as nx

# Optional UMAP
try:
    import umap.umap_ as umap
    _HAS_UMAP = True
except Exception:
    _HAS_UMAP = False

hf_logging.set_verbosity_error()
sns.set_style("whitegrid")

# --------------------------
# Configuration
# --------------------------
class Config:
    total_train_samples = 0
    total_test_samples = 2000

    ratios = {
        'cnn_dailymail::3.0.0': 0.35,
        'multi_news': 0.35,
        'xsum': 0.30
    }

    htas_model_name = 'facebook/bart-large-cnn'  # abstractive generator used for inference
    max_input_length = 1024
    max_target_length = 150

    eval_models = [
        'facebook/bart-large-cnn',
        't5-base',
        'google/pegasus-cnn_dailymail'
    ]

    # HTAS params
    token_budget = 400
    min_cluster_size = 3
    per_cluster_topk = 2
    fallback_topk = 8
    sim_threshold = 0.15

    pagerank_alpha = 0.85
    pagerank_max_iter = 100

    num_beams = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed = 42

    output_dir = Path('/kaggle/working/htas_output')

config = Config()

# --------------------------
# Utilities
# --------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def mkdir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

# ensure punkt
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# --------------------------
# HTAS components
# --------------------------
class SentenceEncoderGPU:
    def __init__(self, device: str = config.device):
        self.device = device
        self.model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
        self.model.eval()

    def encode(self, sentences: List[str]) -> torch.Tensor:
        if not sentences:
            return torch.empty(0, 384, device=self.device)
        return self.model.encode(sentences, convert_to_tensor=True, device=self.device, show_progress_bar=False)

    def cosine_similarity(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        a_norm = F.normalize(a, p=2, dim=-1)
        b_norm = F.normalize(b, p=2, dim=-1)
        return torch.mm(a_norm, b_norm.t())

class HTASClusterer:
    def __init__(self, cfg):
        self.cfg = cfg

    def cluster_sentences(self, embeddings) -> np.ndarray:
        if isinstance(embeddings, torch.Tensor):
            n = embeddings.size(0)
            emb_cpu = embeddings.cpu().numpy()
        else:
            emb_cpu = np.asarray(embeddings)
            n = emb_cpu.shape[0]

        if n == 0:
            return np.array([], dtype=int)
        if n < self.cfg.min_cluster_size:
            return np.zeros(n, dtype=int)

        clusterer = hdbscan.HDBSCAN(min_cluster_size=self.cfg.min_cluster_size, min_samples=1, metric='euclidean')
        labels = clusterer.fit_predict(emb_cpu)

        noise_mask = labels == -1
        if noise_mask.any() and (~noise_mask).any():
            non_noise_idx = np.where(~noise_mask)[0]
            for idx in np.where(noise_mask)[0]:
                vec = torch.tensor(emb_cpu[idx]).unsqueeze(0)
                non_noise_emb = torch.tensor(emb_cpu[non_noise_idx])
                sims = F.cosine_similarity(vec, non_noise_emb, dim=-1)
                labels[idx] = int(labels[non_noise_idx][sims.argmax().item()])

        return labels

    def select_from_cluster(self, cluster_indices: List[int], sim_matrix: torch.Tensor) -> List[int]:
        if not cluster_indices:
            return []
        if len(cluster_indices) <= self.cfg.per_cluster_topk:
            return cluster_indices

        sim_sub = sim_matrix[cluster_indices][:, cluster_indices].cpu().numpy()
        sim_sub[sim_sub < self.cfg.sim_threshold] = 0
        np.fill_diagonal(sim_sub, 0)
        try:
            G = nx.from_numpy_array(sim_sub)
            prs = nx.pagerank(G, alpha=self.cfg.pagerank_alpha, max_iter=self.cfg.pagerank_max_iter)
            deg = nx.degree_centrality(G)
            combined = {i: 0.6 * prs.get(i, 0) + 0.4 * deg.get(i, 0) for i in range(len(cluster_indices))}
            ranked = sorted(combined.items(), key=lambda x: x[1], reverse=True)
            return [cluster_indices[i] for i, _ in ranked[:self.cfg.per_cluster_topk]]
        except Exception:
            sums = sim_sub.sum(axis=1)
            idxs = np.argsort(-sums)[:self.cfg.per_cluster_topk]
            return [cluster_indices[int(i)] for i in idxs]

class TokenBudgetedSelector:
    def __init__(self, tokenizer, cfg):
        self.tokenizer = tokenizer
        self.cfg = cfg

    def select_with_budget(self, sentences: List[str], candidate_indices: List[int]) -> List[int]:
        selected = []
        total_tokens = 0
        for idx in candidate_indices:
            tokens = len(self.tokenizer.encode(sentences[idx], add_special_tokens=False))
            if total_tokens + tokens <= self.cfg.token_budget:
                selected.append(idx)
                total_tokens += tokens
            if total_tokens >= self.cfg.token_budget:
                break
        return selected

class HTASProcessor:
    def __init__(self, cfg, tokenizer):
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.encoder = SentenceEncoderGPU(device=cfg.device)
        self.clusterer = HTASClusterer(cfg)
        self.selector = TokenBudgetedSelector(tokenizer, cfg)
        self._umap_reducer = None
        if _HAS_UMAP:
            try:
                self._umap_reducer = umap.UMAP(n_components=16, random_state=cfg.seed)
            except Exception:
                self._umap_reducer = None

    def preprocess_documents(self, documents: List[str]) -> List[str]:
        sentences = []
        for doc in documents:
            sents = sent_tokenize(doc)
            sentences.extend([s.strip() for s in sents if 8 < len(s.split()) < 150])
        return sentences

    def create_guided_input(self, documents: List[str]) -> str:
        sentences = self.preprocess_documents(documents)
        if not sentences:
            return " ".join(documents)

        embeddings = self.encoder.encode(sentences)
        if len(embeddings) == 0:
            return " ".join(documents)

        sim_matrix = self.encoder.cosine_similarity(embeddings, embeddings)
        emb_cpu = embeddings.cpu().numpy()
        cluster_input = emb_cpu
        if self._umap_reducer is not None:
            try:
                reduced = self._umap_reducer.fit_transform(emb_cpu)
                cluster_input = reduced
            except Exception:
                cluster_input = emb_cpu

        labels = self.clusterer.cluster_sentences(cluster_input)
        selected_indices = []
        unique_labels = np.unique(labels)
        if len(unique_labels) > 0:
            for label in unique_labels:
                cluster_indices = np.where(labels == label)[0].tolist()
                selected = self.clusterer.select_from_cluster(cluster_indices, sim_matrix)
                selected_indices.extend(selected)

        # fallback
        if len(selected_indices) < self.cfg.fallback_topk:
            sim_np = sim_matrix.cpu().numpy()
            sim_np[sim_np < self.cfg.sim_threshold] = 0
            np.fill_diagonal(sim_np, 0)
            try:
                G = nx.from_numpy_array(sim_np)
                pr = nx.pagerank(G, alpha=self.cfg.pagerank_alpha, max_iter=self.cfg.pagerank_max_iter)
                ranked = sorted(pr.items(), key=lambda x: x[1], reverse=True)
                for idx, _ in ranked:
                    if idx not in selected_indices:
                        selected_indices.append(idx)
                    if len(selected_indices) >= self.cfg.fallback_topk:
                        break
            except Exception:
                sums = sim_np.sum(axis=1)
                idxs = np.argsort(-sums)[:self.cfg.fallback_topk]
                for idx in idxs:
                    if idx not in selected_indices:
                        selected_indices.append(int(idx))

        final_indices = self.selector.select_with_budget(sentences, sorted(list(set(selected_indices))))
        guidance = " ".join([sentences[i] for i in sorted(final_indices)])
        full_text = " ".join(documents)
        return f"{guidance}</s><s>{full_text}"

# --------------------------
# Evaluation & generation helpers (no training)
# --------------------------
def compute_bleu(reference: str, candidate: str) -> float:
    if _HAS_SACREBLEU:
        try:
            return float(sacrebleu.sentence_bleu(candidate, [reference]).score)
        except Exception:
            return 0.0
    return 0.0


def compute_meteor(reference: str, candidate: str) -> float:
    if _HAS_METEOR:
        try:
            return meteor_score([reference], candidate) * 100.0
        except Exception:
            return 0.0
    return 0.0


def compute_coverage_compression(source: str, summary: str, tokenizer) -> Tuple[float, float]:
    src_tokens = len(tokenizer.encode(source, add_special_tokens=False))
    sum_tokens = len(tokenizer.encode(summary, add_special_tokens=False))
    if src_tokens == 0:
        return 0.0, 0.0
    coverage = (sum_tokens / src_tokens) * 100.0
    compression = src_tokens / max(sum_tokens, 1)
    return coverage, compression


def batch_generate(model, tokenizer, inputs: List[str], cfg, batch_size=8, single_sentence=False):
    model.to(cfg.device)
    model.eval()
    outputs = []
    for i in range(0, len(inputs), batch_size):
        batch_texts = inputs[i:i+batch_size]
        enc = tokenizer(batch_texts, max_length=cfg.max_input_length, truncation=True, padding=True, return_tensors='pt').to(cfg.device)
        with torch.no_grad():
            max_len = 50 if single_sentence else cfg.max_target_length
            summary_ids = model.generate(
                enc.input_ids,
                attention_mask=enc.attention_mask,
                num_beams=cfg.num_beams,
                max_length=max_len,
                min_length=8 if single_sentence else 30,
                length_penalty=1.0,
                early_stopping=True)
        for ids in summary_ids:
            text = tokenizer.decode(ids, skip_special_tokens=True).strip()
            if single_sentence:
                # keep only first sentence
                sents = sent_tokenize(text)
                text = sents[0] if sents else text
            outputs.append(text)
    return outputs


def evaluate_model_full(name: str, model, tokenizer, test_dataset, cfg, processor=None, single_sentence=False):
    print(f"\n📊 Evaluating {name} ({len(test_dataset)} samples)...")
    sources = []
    references = []
    inputs = []
    for item in test_dataset:
        src = " ".join(item['texts'])
        ref = item['summary']
        sources.append(src)
        references.append(ref)
        if processor:
            inputs.append(processor.create_guided_input(item['texts']))
        else:
            inputs.append(src)

    candidates = batch_generate(model, tokenizer, inputs, cfg, batch_size=4, single_sentence=single_sentence)

    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_values = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    bleu_values = []
    meteor_values = []
    coverage_values = []
    compression_values = []

    for src, ref, cand in zip(sources, references, candidates):
        if not cand:
            cand = ""
        r = rouge.score(ref, cand)
        rouge_values['rouge1'].append(r['rouge1'].fmeasure)
        rouge_values['rouge2'].append(r['rouge2'].fmeasure)
        rouge_values['rougeL'].append(r['rougeL'].fmeasure)
        bleu_values.append(compute_bleu(ref, cand))
        meteor_values.append(compute_meteor(ref, cand))
        cov, comp = compute_coverage_compression(src, cand, tokenizer)
        coverage_values.append(cov)
        compression_values.append(comp)

    try:
        P, R, F1 = bert_score_calculator(candidates, references, lang='en', model_type='distilbert-base-uncased', device=cfg.device, verbose=False)
        bert_values = (F1.cpu().numpy() * 100.0).tolist()
    except Exception:
        bert_values = [0.0] * len(candidates)

    return {
        'candidates': candidates,
        'references': references,
        'sources': sources,
        'rouge1': rouge_values['rouge1'],
        'rouge2': rouge_values['rouge2'],
        'rougeL': rouge_values['rougeL'],
        'bertscore_f1': bert_values,
        'bleu': bleu_values,
        'meteor': meteor_values,
        'coverage': coverage_values,
        'compression': compression_values
    }

# --------------------------
# Plotting utils
# --------------------------
def plot_metric_bars(results_summary: Dict[str, Dict], save_path: Path):
    metrics = ['rouge1_mean','rouge2_mean','rougeL_mean','bertscore_mean']
    df = pd.DataFrame(results_summary).T
    df_plot = df[[m for m in df.columns if m in metrics]]
    df_plot.columns = ['ROUGE-1','ROUGE-2','ROUGE-L','BERTScore']
    ax = df_plot.plot(kind='bar', figsize=(11,5), rot=0, edgecolor='black', alpha=0.9)
    ax.set_ylabel('Score (%)')
    plt.title('Model Metric Comparison')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()


def plot_rouge_violin(all_scores: Dict[str, Dict], save_path: Path):
    df = pd.DataFrame({k: np.array(v['rougeL']) * 100 for k, v in all_scores.items()})
    plt.figure(figsize=(10,5))
    sns.violinplot(data=df, inner='quartile')
    plt.title('ROUGE-L Distribution Across Models')
    plt.ylabel('ROUGE-L (%)')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()


def plot_correlation_heatmap(all_scores: Dict[str, Dict], save_path: Path):
    rows = []
    for m, v in all_scores.items():
        rows.append({
            'model': m,
            'rouge1': np.mean(v['rouge1']) * 100,
            'rouge2': np.mean(v['rouge2']) * 100,
            'rougeL': np.mean(v['rougeL']) * 100,
            'bert': np.mean(v['bertscore_f1']),
            'bleu': np.mean(v['bleu']),
            'meteor': np.mean(v['meteor'])
        })
    df = pd.DataFrame(rows).set_index('model')
    corr = df.corr()
    plt.figure(figsize=(8,6))
    sns.heatmap(corr, annot=True, fmt='.2f', cmap='vlag', center=0)
    plt.title('Correlation Between Metrics (model-mean)')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()


def plot_radar(results_summary: Dict[str, Dict], save_path: Path):
    labels = ['ROUGE-1','ROUGE-2','ROUGE-L','BERTScore']
    n = len(labels)
    angles = np.linspace(0, 2*np.pi, n, endpoint=False).tolist()
    angles += angles[:1]
    plt.figure(figsize=(7,7))
    ax = plt.subplot(polar=True)
    for model_name, vals in results_summary.items():
        stats = [vals['rouge1_mean'], vals['rouge2_mean'], vals['rougeL_mean'], vals['bertscore_mean']]
        stats += stats[:1]
        ax.plot(angles, stats, label=model_name)
        ax.fill(angles, stats, alpha=0.1)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax.set_title('Radar: Model Comparison')
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

# --------------------------
# Data loader helper
# --------------------------
def load_and_mix_datasets(cfg: Config):
    print("🔎 Loading datasets according to ratios:", cfg.ratios)
    available_parts = {}
    raw_datasets = {}

    for ds_key, frac in cfg.ratios.items():
        try:
            name, ver = (ds_key.split('::', 1) + [None])[:2]
            ds = load_dataset(name, ver, split='train', streaming=False)
            raw_datasets[ds_key] = ds
            available_parts[ds_key] = frac
            print(f"  ✓ Loaded dataset: {ds_key} (size: {len(ds)})")
        except Exception as e:
            print(f"  ⚠️ Could not load {ds_key}: {e}. Skipping it.")
            continue

    if not raw_datasets: raise RuntimeError("No datasets could be loaded. Aborting.")

    total_frac = sum(available_parts.values())
    if total_frac < 1.0:
        for k in list(available_parts.keys()): available_parts[k] /= total_frac

    train_target, test_target = cfg.total_train_samples, cfg.total_test_samples
    per_ds_train = {k: int(round(v * train_target)) for k, v in available_parts.items()}
    per_ds_test = {k: int(round(v * test_target)) for k, v in available_parts.items()}
    train_samples, test_samples = [], []

    field_map = {'cnn_dailymail::3.0.0': ('article', 'highlights'), 'multi_news': ('document', 'summary'), 'xsum': ('document', 'summary')}

    for k, ds in raw_datasets.items():
        ds_list = list(ds)
        random.shuffle(ds_list)
        doc_field, sum_field = field_map.get(k, ('document', 'summary'))
        filtered = [{'texts': [ex[doc_field]], 'summary': ex[sum_field]} for ex in ds_list if ex.get(doc_field) and ex.get(sum_field)]
        if not filtered:
            print(f"  ⚠️ No usable examples from {k} after normalization.")
            continue
        want_train, want_test = per_ds_train[k], per_ds_test[k]
        split_idx = min(want_train, len(filtered) - want_test)
        train_samples.extend(filtered[:split_idx])
        test_samples.extend(filtered[split_idx:split_idx+want_test])
        print(f"  → {k}: train {len(filtered[:split_idx])}, test {len(filtered[split_idx:split_idx+want_test])}")

    random.shuffle(train_samples)
    random.shuffle(test_samples)
    train_samples = train_samples[:train_target]
    test_samples = test_samples[:test_target]

    print(f"✓ Final pools: train {len(train_samples)} | test {len(test_samples)}")
    return train_samples, test_samples

# --------------------------
# Main pipeline (inference-only)
# --------------------------
def main():
    set_seed(config.seed)
    mkdir(config.output_dir)
    print("="*80)
    print("🚀 HMTAS — Hierarchical Multi-Document Text Abstractive Summarization (Inference)")
    print("="*80)
    print(f"Device: {config.device}")
    print(f"Train pool target: {config.total_train_samples}, Test pool target: {config.total_test_samples}")
    print(f"Dataset ratios: {config.ratios}")
    print("="*80)

    # Load datasets (for evaluation) — train pool will be empty if config.total_train_samples==0
    train_pool, test_pool = load_and_mix_datasets(config)

    # init tokenizer and processor
    tokenizer = AutoTokenizer.from_pretrained(config.htas_model_name)
    processor = HTASProcessor(config, tokenizer)

    # Load abstractive model (inference-only)
    print("\n🔧 Loading abstractive model (inference):", config.htas_model_name)
    htas_model = AutoModelForSeq2SeqLM.from_pretrained(config.htas_model_name).to(config.device)

    # Demo: generate single-sentence fused summary for 2-3 sample news docs
    demo_docs = [
        "World leaders gathered in Geneva for the 2025 Climate Summit, marking a crucial moment in global climate policy.",
        "Major economies announced historic climate investments at the Geneva summit, including EU and China actions.",
        "The summit's final agreement includes binding emissions targets, transparency mechanisms, and finance commitments."
    ]
    print("\n🎯 Demo multi-document input (2-3 docs):\n", "\n---\n".join(demo_docs))
    guided_demo = processor.create_guided_input(demo_docs)
    single_sent = batch_generate(htas_model, tokenizer, [guided_demo], config, batch_size=1, single_sentence=True)[0]
    print("\n✅ Generated single-sentence global summary (demo):\n", single_sent)

    # Evaluate HMTAS (model used as-is, no fine-tuning)
    all_results = {}
    print("\n📊 Evaluating HMTAS (inference) on test pool...")
    all_results['HMTAS'] = evaluate_model_full("HMTAS (Inference)", htas_model, tokenizer, test_pool, config, processor, single_sentence=False)

    # Evaluate baselines (zero-shot)
    for mname in config.eval_models:
        try:
            print(f"\n📊 Loading baseline model {mname} (inference)...")
            tok = AutoTokenizer.from_pretrained(mname)
            mdl = AutoModelForSeq2SeqLM.from_pretrained(mname).to(config.device)
            all_results[mname] = evaluate_model_full(mname, mdl, tok, test_pool, config, None, single_sentence=False)
            del mdl, tok
            gc.collect(); torch.cuda.empty_cache()
        except Exception as e:
            print(f"⚠️ Could not evaluate {mname}: {e}")

    # Summarize numeric results
    results_summary = {}
    for name, res in all_results.items():
        results_summary[name] = {
            'rouge1_mean': float(np.mean(res['rouge1']) * 100), 'rouge1_std': float(np.std(res['rouge1']) * 100),
            'rouge2_mean': float(np.mean(res['rouge2']) * 100), 'rouge2_std': float(np.std(res['rouge2']) * 100),
            'rougeL_mean': float(np.mean(res['rougeL']) * 100), 'rougeL_std': float(np.std(res['rougeL']) * 100),
            'bertscore_mean': float(np.mean(res['bertscore_f1'])), 'bertscore_std': float(np.std(res['bertscore_f1'])),
            'bleu_mean': float(np.mean(res['bleu'])), 'bleu_std': float(np.std(res['bleu'])),
            'meteor_mean': float(np.mean(res['meteor'])), 'meteor_std': float(np.std(res['meteor'])),
            'coverage_mean': float(np.mean(res['coverage'])), 'compression_mean': float(np.mean(res['compression']))
        }

    # Save CSV of generated summaries
    rows = []
    for model_name, res in all_results.items():
        for i, cand in enumerate(res['candidates']):
            rows.append({'model': model_name, 'idx': i, 'candidate': cand, 'reference': res['references'][i]})
    df_out = pd.DataFrame(rows)
    df_out.to_csv(config.output_dir / 'generated_summaries.csv', index=False)

    # Save raw results (truncated)
    safe_results = {}
    for k, v in all_results.items():
        safe_results[k] = {kk: (vv if not isinstance(vv, list) else vv[:50]) for kk, vv in v.items()}
    with open(config.output_dir / 'results_raw_sample.json', 'w') as f:
        json.dump(safe_results, f, indent=2)
    with open(config.output_dir / 'results_summary.json', 'w') as f:
        json.dump(results_summary, f, indent=2)

    # Print final table
    print("\n" + "="*80)
    print("📊 FINAL RESULTS SUMMARY")
    print("="*80)
    print("| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | BERTScore | BLEU | METEOR | COVERAGE | COMPRESSION |")
    print("|-------|---------|---------|---------|-----------|------|--------|----------|-------------|")
    for name, vals in results_summary.items():
        print(f"| {name:<20} | {vals['rouge1_mean']:.2f}±{vals['rouge1_std']:.2f} | {vals['rouge2_mean']:.2f}±{vals['rouge2_std']:.2f} | {vals['rougeL_mean']:.2f}±{vals['rougeL_std']:.2f} | {vals['bertscore_mean']:.2f}±{vals['bertscore_std']:.2f} | {vals['bleu_mean']:.2f} | {vals['meteor_mean']:.2f} | {vals['coverage_mean']:.2f} | {vals['compression_mean']:.2f} |")

    # Generate plots and save
    print("\n📈 Generating plots...")
    plot_dir = config.output_dir / 'plots'
    mkdir(plot_dir)
    try:
        plot_metric_bars(results_summary, plot_dir / 'metric_comparison.png')
    except Exception as e:
        print("⚠️ plot_metric_bars failed:", e)
    try:
        plot_rouge_violin(all_results, plot_dir / 'rouge_violin.png')
    except Exception as e:
        print("⚠️ plot_rouge_violin failed:", e)
    try:
        plot_correlation_heatmap(all_results, plot_dir / 'metric_correlation.png')
    except Exception as e:
        print("⚠️ plot_correlation_heatmap failed:", e)
    try:
        plot_radar(results_summary, plot_dir / 'radar_comparison.png')
    except Exception as e:
        print("⚠️ plot_radar failed:", e)

    print(f"\n✓ Saved plots to {plot_dir}")
    print(f"✓ Saved generated summaries to {config.output_dir / 'generated_summaries.csv'}")
    print(f"✓ Saved results JSONs to {config.output_dir}")

    # cleanup
    try:
        del htas_model
    except Exception:
        pass
    gc.collect(); torch.cuda.empty_cache()
    print("\n🎉 Pipeline complete. All outputs are in:", config.output_dir)

if __name__ == '__main__':
    main()


🚀 HMTAS — Hierarchical Multi-Document Text Abstractive Summarization (Inference)
Device: cuda
Train pool target: 0, Test pool target: 2000
Dataset ratios: {'cnn_dailymail::3.0.0': 0.35, 'multi_news': 0.35, 'xsum': 0.3}
🔎 Loading datasets according to ratios: {'cnn_dailymail::3.0.0': 0.35, 'multi_news': 0.35, 'xsum': 0.3}


Downloading readme: 0.00B [00:00, ?B/s]

Downloading data: 100%|██████████| 257M/257M [00:01<00:00, 164MB/s]
Downloading data: 100%|██████████| 257M/257M [00:01<00:00, 177MB/s]
Downloading data: 100%|██████████| 259M/259M [00:01<00:00, 180MB/s]
Downloading data: 100%|██████████| 34.7M/34.7M [00:01<00:00, 18.2MB/s]
Downloading data: 100%|██████████| 30.0M/30.0M [00:00<00:00, 47.0MB/s]


Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

  ✓ Loaded dataset: cnn_dailymail::3.0.0 (size: 287113)


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/58.8M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/66.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.30M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/69.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.31M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/44972 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5622 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5622 [00:00<?, ? examples/s]

  ✓ Loaded dataset: multi_news (size: 44972)


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

  ✓ Loaded dataset: xsum (size: 204045)
  → cnn_dailymail::3.0.0: train 0, test 700
  → multi_news: train 0, test 700
  → xsum: train 0, test 600
✓ Final pools: train 0 | test 2000


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


🔧 Loading abstractive model (inference): facebook/bart-large-cnn


model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]


🎯 Demo multi-document input (2-3 docs):
 World leaders gathered in Geneva for the 2025 Climate Summit, marking a crucial moment in global climate policy.
---
Major economies announced historic climate investments at the Geneva summit, including EU and China actions.
---
The summit's final agreement includes binding emissions targets, transparency mechanisms, and finance commitments.

✅ Generated single-sentence global summary (demo):
 World leaders gathered in Geneva for the 2025 Climate Summit.

📊 Evaluating HMTAS (inference) on test pool...

📊 Evaluating HMTAS (Inference) (2000 samples)...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]


📊 Loading baseline model facebook/bart-large-cnn (inference)...

📊 Evaluating facebook/bart-large-cnn (2000 samples)...

📊 Loading baseline model t5-base (inference)...


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


📊 Evaluating t5-base (2000 samples)...

📊 Loading baseline model google/pegasus-cnn_dailymail (inference)...


tokenizer_config.json:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/280 [00:00<?, ?B/s]


📊 Evaluating google/pegasus-cnn_dailymail (2000 samples)...

📊 FINAL RESULTS SUMMARY
| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | BERTScore | BLEU | METEOR | COVERAGE | COMPRESSION |
|-------|---------|---------|---------|-----------|------|--------|----------|-------------|
| HMTAS                | 27.79±13.02 | 9.55±10.47 | 18.57±10.75 | 78.00±4.27 | 0.00 | 0.00 | 10.44 | 22.26 |
| facebook/bart-large-cnn | 29.17±14.27 | 10.91±11.84 | 20.06±12.34 | 78.46±4.69 | 0.00 | 0.00 | 9.56 | 22.41 |
| t5-base              | 27.23±12.56 | 9.31±9.96 | 18.06±9.96 | 77.33±4.53 | 0.00 | 0.00 | 11.19 | 18.49 |
| google/pegasus-cnn_dailymail | 31.89±16.71 | 13.78±15.37 | 22.93±16.04 | 78.12±5.16 | 0.00 | 0.00 | 10.75 | 19.28 |

📈 Generating plots...

✓ Saved plots to /kaggle/working/htas_output/plots
✓ Saved generated summaries to /kaggle/working/htas_output/generated_summaries.csv
✓ Saved results JSONs to /kaggle/working/htas_output

🎉 Pipeline complete. All outputs are in: /kaggle/working/htas_output
