<a href="https://colab.research.google.com/github/RayAKaan/Personal-Research/blob/main/HTAS-V4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
# ======================================================================================
# COLAB ENVIRONMENT SETUP (Fixed for Oct 2025)
# ======================================================================================

print("⏳ Installing all required Python packages... This may take a few minutes.")

# --- Full environment setup with binary compatibility fix ---
!pip install -U numpy==1.26.4 scikit-learn==1.3.2 --force-reinstall --no-cache-dir
!pip install transformers==4.38.2 datasets==2.18.0 sentence-transformers==2.7.0 \
rouge-score==0.1.2 bert-score==0.3.13 networkx==3.2.1 hdbscan==0.8.33 \
nltk==3.8.1 torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 \
matplotlib seaborn --extra-index-url https://download.pytorch.org/whl/cu118

print("✅ Python packages installed successfully and compatible with NumPy ABI.")

# --- Download NLTK resources ---
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')  # Some newer NLTK versions also need this
print("✅ Punkt tokenizer data reinstalled successfully!")

print("\n" + "="*80)
print("  ✅ SETUP COMPLETE. THE ENVIRONMENT IS NOW PREPARED.")
print("  ⚠️ IMPORTANT: Please restart the runtime before running your experiments.")
print("  Go to the menu: 'Runtime' -> 'Restart Session' (or 'Factory reset runtime').")
print("="*80)


⏳ Installing all required Python packages... This may take a few minutes.
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m100.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scikit-learn==1.3.2
  Downloading scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting scipy>=1.5.0 (from scikit-learn==1.3.2)
  Downloading scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m197.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting joblib>=1.1.1 (from scikit-learn==1.3.2)
  Downloading joblib-1.5.2-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn==1.3.2)
  Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Downloadi

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
✅ Python packages installed successfully and compatible with NumPy ABI.


ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [1]:
# --------------------------
# Begin script
# --------------------------
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import time
import gc
import json
import random
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset as TorchDataset

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    get_linear_schedule_with_warmup,
    DataCollatorForSeq2Seq,
    logging as hf_logging
)
from rouge_score import rouge_scorer
from bert_score import score as bert_score_calculator

import nltk
from nltk.tokenize import sent_tokenize

# try to import sacrebleu (preferred BLEU)
try:
    import sacrebleu
    _HAS_SACREBLEU = True
except Exception:
    _HAS_SACREBLEU = False

# meteor fallback (nltk)
try:
    from nltk.translate.meteor_score import meteor_score
    _HAS_METEOR = True
except Exception:
    _HAS_METEOR = False

import hdbscan
import networkx as nx

hf_logging.set_verbosity_error()
sns.set_style("whitegrid")

# --------------------------
# Configuration
# --------------------------
class Config:
    # Data
    num_train_samples = 2500   # increased (2x)
    num_test_samples  = 400   # increased (2x)

    # Model (HTAS fine-tune)
    htas_model_name = 'facebook/bart-base'
    max_input_length = 1024
    max_target_length = 150

    # Baselines (zero-shot evaluation)
    eval_models = [
        'facebook/bart-large-cnn',
        't5-base',
        'google/pegasus-cnn_dailymail'
    ]

    # Training
    batch_size = 4
    gradient_accumulation_steps = 2
    learning_rate = 3e-5
    max_steps = 1200
    epochs = 3
    warmup_ratio = 0.1
    fp16 = True

    # HTAS
    token_budget = 400
    min_cluster_size = 3
    per_cluster_topk = 2
    fallback_topk = 8
    sim_threshold = 0.15

    # PageRank
    pagerank_alpha = 0.85
    pagerank_max_iter = 100

    # Evaluation
    num_beams = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed = 42

    # Output
    output_dir = Path('/content/htas_output')

config = Config()

# --------------------------
# Utilities
# --------------------------
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def mkdir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

# ensure NLTK punkt
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# --------------------------
# HTAS components (encoder, clusterer, selector)
# --------------------------
class SentenceEncoderGPU:
    def __init__(self, device: str = config.device):
        self.device = device
        self.model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
        self.model.eval()

    def encode(self, sentences: List[str]) -> torch.Tensor:
        if not sentences:
            return torch.empty(0, 384, device=self.device)
        return self.model.encode(sentences, convert_to_tensor=True, device=self.device, show_progress_bar=False)

    def cosine_similarity(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        a_norm = F.normalize(a, p=2, dim=-1)
        b_norm = F.normalize(b, p=2, dim=-1)
        return torch.mm(a_norm, b_norm.t())

class HTASClusterer:
    def __init__(self, cfg):
        self.cfg = cfg

    def cluster_sentences(self, embeddings: torch.Tensor) -> np.ndarray:
        n = len(embeddings)
        if n == 0:
            return np.array([], dtype=int)
        if n < self.cfg.min_cluster_size:
            return np.zeros(n, dtype=int)
        emb_cpu = embeddings.cpu().numpy()
        clusterer = hdbscan.HDBSCAN(min_cluster_size=self.cfg.min_cluster_size, min_samples=1, metric='euclidean')
        labels = clusterer.fit_predict(emb_cpu)
        # assign noise to nearest non-noise cluster
        noise_mask = labels == -1
        if noise_mask.any() and (~noise_mask).any():
            non_noise_idx = np.where(~noise_mask)[0]
            for idx in np.where(noise_mask)[0]:
                sims = F.cosine_similarity(embeddings[idx].unsqueeze(0), embeddings[non_noise_idx])
                labels[idx] = int(labels[non_noise_idx][sims.argmax().item()])
        return labels

    def select_from_cluster(self, cluster_indices: List[int], sim_matrix: torch.Tensor) -> List[int]:
        if not cluster_indices:
            return []
        if len(cluster_indices) <= self.cfg.per_cluster_topk:
            return cluster_indices
        sim_sub = sim_matrix[cluster_indices][:, cluster_indices].cpu().numpy()
        sim_sub[sim_sub < self.cfg.sim_threshold] = 0
        np.fill_diagonal(sim_sub, 0)
        try:
            G = nx.from_numpy_array(sim_sub)
            prs = nx.pagerank(G, alpha=self.cfg.pagerank_alpha, max_iter=self.cfg.pagerank_max_iter)
            deg = nx.degree_centrality(G)
            combined = {i: 0.6 * prs[i] + 0.4 * deg[i] for i in range(len(cluster_indices))}
            ranked = sorted(combined.items(), key=lambda x: x[1], reverse=True)
            return [cluster_indices[i] for i, _ in ranked[:self.cfg.per_cluster_topk]]
        except Exception:
            sums = sim_sub.sum(axis=1)
            idxs = np.argsort(-sums)[:self.cfg.per_cluster_topk]
            return [cluster_indices[int(i)] for i in idxs]

class TokenBudgetedSelector:
    def __init__(self, tokenizer, cfg):
        self.tokenizer = tokenizer
        self.cfg = cfg

    def select_with_budget(self, sentences: List[str], candidate_indices: List[int]) -> List[int]:
        selected = []
        total_tokens = 0
        for idx in candidate_indices:
            tokens = len(self.tokenizer.encode(sentences[idx], add_special_tokens=False))
            if total_tokens + tokens <= self.cfg.token_budget:
                selected.append(idx)
                total_tokens += tokens
            if total_tokens >= self.cfg.token_budget:
                break
        return selected

class HTASProcessor:
    def __init__(self, cfg, tokenizer):
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.encoder = SentenceEncoderGPU(device=cfg.device)
        self.clusterer = HTASClusterer(cfg)
        self.selector = TokenBudgetedSelector(tokenizer, cfg)

    def preprocess_documents(self, documents: List[str]) -> List[str]:
        sentences = []
        for doc in documents:
            sents = sent_tokenize(doc)
            sentences.extend([s.strip() for s in sents if 8 < len(s.split()) < 150])
        return sentences

    def create_guided_input(self, documents: List[str]) -> str:
        sentences = self.preprocess_documents(documents)
        if not sentences:
            return " ".join(documents)
        embeddings = self.encoder.encode(sentences)
        if len(embeddings) == 0:
            return " ".join(documents)
        sim_matrix = self.encoder.cosine_similarity(embeddings, embeddings)
        labels = self.clusterer.cluster_sentences(embeddings)
        selected_indices = []
        for label in np.unique(labels):
            cluster_indices = np.where(labels == label)[0].tolist()
            selected = self.clusterer.select_from_cluster(cluster_indices, sim_matrix)
            selected_indices.extend(selected)
        if len(selected_indices) < self.cfg.fallback_topk:
            sim_np = sim_matrix.cpu().numpy()
            sim_np[sim_np < self.cfg.sim_threshold] = 0
            np.fill_diagonal(sim_np, 0)
            G = nx.from_numpy_array(sim_np)
            pr = nx.pagerank(G, alpha=self.cfg.pagerank_alpha, max_iter=self.cfg.pagerank_max_iter)
            ranked = sorted(pr.items(), key=lambda x: x[1], reverse=True)
            for idx, _ in ranked:
                if idx not in selected_indices:
                    selected_indices.append(idx)
                if len(selected_indices) >= self.cfg.fallback_topk:
                    break
        final_indices = self.selector.select_with_budget(sentences, selected_indices)
        final_indices = sorted(set(final_indices))
        guidance = " ".join([sentences[i] for i in final_indices])
        full_text = " ".join(documents)
        return f"{guidance}</s><s>{full_text}"

# --------------------------
# Dataset and DataLoader
# --------------------------
class SummarizationDataset(TorchDataset):
    def __init__(self, texts: List[str], summaries: List[str], tokenizer, cfg):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.cfg = cfg

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        model_inputs = self.tokenizer(str(self.texts[idx]), max_length=self.cfg.max_input_length, truncation=True, padding=False)
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(str(self.summaries[idx]), max_length=self.cfg.max_target_length, truncation=True, padding=False)
        model_inputs['labels'] = labels['input_ids']
        return model_inputs

# --------------------------
# Trainer (manual loop supporting fp16)
# --------------------------
def train_model(model, tokenizer, train_dataset, cfg):
    model.to(cfg.device)
    model.train()
    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)
    dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collator)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, eps=1e-8)
    num_steps = min(cfg.max_steps, (len(dataloader) // cfg.gradient_accumulation_steps) * cfg.epochs)
    num_warmup = int(cfg.warmup_ratio * num_steps) if num_steps > 0 else 0
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup, num_steps) if num_steps > 0 else None
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16 and cfg.device.startswith('cuda'))
    global_step = 0
    losses = []
    t0 = time.time()
    print(f"\n🚂 Training for {num_steps} steps...")
    for epoch in range(cfg.epochs):
        if global_step >= num_steps:
            break
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
        for step, batch in enumerate(pbar):
            if global_step >= num_steps:
                break
            batch = {k: v.to(cfg.device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=cfg.fp16 and cfg.device.startswith('cuda')):
                outputs = model(**batch)
                loss = outputs.loss / cfg.gradient_accumulation_steps
            scaler.scale(loss).backward()
            if (step + 1) % cfg.gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                if scheduler:
                    scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                val = float(loss.item() * cfg.gradient_accumulation_steps)
                losses.append(val)
                pbar.set_postfix({'loss': f'{val:.4f}'})
    elapsed = time.time() - t0
    print(f"✓ Fine-tuning finished in {elapsed/60:.2f} min")
    return model, losses

# --------------------------
# Metrics helpers
# --------------------------
def compute_bleu(reference: str, candidate: str) -> float:
    if _HAS_SACREBLEU:
        try:
            # sacrebleu sentence-level (returns score between 0-100)
            return float(sacrebleu.sentence_bleu(candidate, [reference]).score)
        except Exception:
            return 0.0
    else:
        # fallback: quick token overlap percent (not true BLEU)
        try:
            ref_tokens = reference.split()
            cand_tokens = candidate.split()
            if not cand_tokens or not ref_tokens:
                return 0.0
            # n-gram overlap heuristic
            overlap = len(set(ref_tokens) & set(cand_tokens)) / max(1, len(set(ref_tokens)))
            return overlap * 100.0
        except Exception:
            return 0.0

def compute_meteor(reference: str, candidate: str) -> float:
    if _HAS_METEOR:
        try:
            return meteor_score([reference], candidate) * 100.0
        except Exception:
            return 0.0
    return 0.0

def compute_coverage_compression(source: str, summary: str, tokenizer) -> Tuple[float, float]:
    src_tokens = len(tokenizer.encode(source, add_special_tokens=False))
    sum_tokens = len(tokenizer.encode(summary, add_special_tokens=False))
    if src_tokens == 0:
        return 0.0, 0.0
    coverage = (sum_tokens / src_tokens) * 100.0
    compression = src_tokens / max(sum_tokens, 1)
    return coverage, compression

# --------------------------
# Batched generation and evaluation
# --------------------------
def batch_generate(model, tokenizer, inputs: List[str], cfg, batch_size=8):
    model.to(cfg.device)
    model.eval()
    outputs = []
    for i in range(0, len(inputs), batch_size):
        batch_texts = inputs[i:i+batch_size]
        enc = tokenizer(batch_texts, max_length=cfg.max_input_length, truncation=True, padding=True, return_tensors='pt').to(cfg.device)
        with torch.no_grad():
            summary_ids = model.generate(enc.input_ids, attention_mask=enc.attention_mask, num_beams=cfg.num_beams, max_length=cfg.max_target_length, min_length=30, length_penalty=1.0, early_stopping=True)
        for ids in summary_ids:
            outputs.append(tokenizer.decode(ids, skip_special_tokens=True).strip())
    return outputs

def evaluate_model_full(name: str, model, tokenizer, test_dataset, cfg, processor=None):
    print(f"\n📊 Evaluating {name} ({len(test_dataset)} samples)...")
    sources = []
    references = []
    # build input list first
    inputs = []
    for item in test_dataset:
        src = " ".join(item['texts'])
        ref = item['summary']
        sources.append(src)
        references.append(ref)
        if processor:
            inputs.append(processor.create_guided_input(item['texts']))
        else:
            inputs.append(src)
    candidates = batch_generate(model, tokenizer, inputs, cfg, batch_size=4)
    # per-sample metrics
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_values = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    bert_values = []
    bleu_values = []
    meteor_values = []
    coverage_values = []
    compression_values = []
    for src, ref, cand in zip(sources, references, candidates):
        if not cand:
            cand = ""
        r = rouge.score(ref, cand)
        rouge_values['rouge1'].append(r['rouge1'].fmeasure)
        rouge_values['rouge2'].append(r['rouge2'].fmeasure)
        rouge_values['rougeL'].append(r['rougeL'].fmeasure)
        bleu_values.append(compute_bleu(ref, cand))
        meteor_values.append(compute_meteor(ref, cand))
        cov, comp = compute_coverage_compression(src, cand, tokenizer)
        coverage_values.append(cov)
        compression_values.append(comp)
    # BERTScore in batch
    try:
        P, R, F1 = bert_score_calculator(candidates, references, lang='en', model_type='distilbert-base-uncased', device=cfg.device, verbose=False)
        bert_values = (F1.cpu().numpy() * 100.0).tolist()
    except Exception:
        bert_values = [0.0] * len(candidates)
    return {
        'candidates': candidates,
        'references': references,
        'sources': sources,
        'rouge1': rouge_values['rouge1'],
        'rouge2': rouge_values['rouge2'],
        'rougeL': rouge_values['rougeL'],
        'bertscore_f1': bert_values,
        'bleu': bleu_values,
        'meteor': meteor_values,
        'coverage': coverage_values,
        'compression': compression_values
    }

# --------------------------
# Plotting (5 plots)
# --------------------------
def plot_training_loss(losses: List[float], save_path: Path):
    plt.figure(figsize=(8,4))
    plt.plot(losses, marker='o', linewidth=1)
    plt.title("Training Loss Curve")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

def plot_metric_bars(results_summary: Dict[str, Dict], save_path: Path):
    metrics = ['rouge1_mean','rouge2_mean','rougeL_mean','bertscore_mean']
    df = pd.DataFrame(results_summary).T
    df_plot = df[[m for m in df.columns if m in metrics]]
    df_plot.columns = ['ROUGE-1','ROUGE-2','ROUGE-L','BERTScore']
    ax = df_plot.plot(kind='bar', figsize=(11,5), rot=0, edgecolor='black', alpha=0.9)
    ax.set_ylabel('Score (%)')
    plt.title('Model Metric Comparison')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

def plot_rouge_violin(all_scores: Dict[str, Dict], save_path: Path):
    df = pd.DataFrame({k: np.array(v['rougeL']) * 100 for k, v in all_scores.items()})
    plt.figure(figsize=(10,5))
    sns.violinplot(data=df, inner='quartile')
    plt.title('ROUGE-L Distribution Across Models')
    plt.ylabel('ROUGE-L (%)')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

def plot_correlation_heatmap(all_scores: Dict[str, Dict], save_path: Path):
    rows = []
    for m, v in all_scores.items():
        rows.append({
            'model': m,
            'rouge1': np.mean(v['rouge1']) * 100,
            'rouge2': np.mean(v['rouge2']) * 100,
            'rougeL': np.mean(v['rougeL']) * 100,
            'bert': np.mean(v['bertscore_f1']),
            'bleu': np.mean(v['bleu']),
            'meteor': np.mean(v['meteor'])
        })
    df = pd.DataFrame(rows).set_index('model')
    corr = df.corr()
    plt.figure(figsize=(8,6))
    sns.heatmap(corr, annot=True, fmt='.2f', cmap='vlag', center=0)
    plt.title('Correlation Between Metrics (model-mean)')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

def plot_radar(results_summary: Dict[str, Dict], save_path: Path):
    labels = ['ROUGE-1','ROUGE-2','ROUGE-L','BERTScore']
    n = len(labels)
    angles = np.linspace(0, 2*np.pi, n, endpoint=False).tolist()
    angles += angles[:1]
    plt.figure(figsize=(7,7))
    ax = plt.subplot(polar=True)
    for model_name, vals in results_summary.items():
        stats = [vals['rouge1_mean'], vals['rouge2_mean'], vals['rougeL_mean'], vals['bertscore_mean']]
        stats += stats[:1]
        ax.plot(angles, stats, label=model_name)
        ax.fill(angles, stats, alpha=0.1)
    ax.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax.set_title('Radar: Model Comparison')
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

# --------------------------
# Main pipeline
# --------------------------
def main():
    set_seed(config.seed)
    mkdir(config.output_dir)
    print("="*80)
    print("🚀 HTAS-V3 (Upgraded Multi-Model Pipeline) — Full Run")
    print("="*80)
    print(f"Device: {config.device}")
    print(f"Train samples: {config.num_train_samples}, Test samples: {config.num_test_samples}")
    print(f"HTAS model (to fine-tune): {config.htas_model_name}")
    print(f"Baselines (zero-shot): {config.eval_models}")
    print("="*80)

    # Load and prepare dataset
    def normalize(ex):
        if 'article' in ex and ex.get('article'):
            return {'texts':[ex['article']], 'summary': ex['highlights']}
        return {'texts': [], 'summary': ''}

    total = config.num_train_samples + config.num_test_samples
    print("\n📚 Loading dataset...")
    dataset = load_dataset('cnn_dailymail', '3.0.0', split=f'train[:{total}]', trust_remote_code=True)
    dataset = dataset.map(normalize, remove_columns=dataset.column_names)
    dataset = dataset.filter(lambda x: x['texts'] and x['summary'])
    split = dataset.train_test_split(test_size=config.num_test_samples, seed=config.seed)
    train_data = split['train']
    test_data  = split['test']
    print(f"✓ Loaded {len(train_data)} train, {len(test_data)} test samples")

    # init tokenizer and processor
    tokenizer = AutoTokenizer.from_pretrained(config.htas_model_name)
    processor = HTASProcessor(config, tokenizer)

    # Prepare HTAS-guided training data
    print("\n🔧 Preparing HTAS-guided training data...")
    guided_texts = []
    for item in tqdm(train_data, desc="Processing"):
        guided_texts.append(processor.create_guided_input(item['texts']))
    train_summaries = [item['summary'] for item in train_data]
    train_dataset = SummarizationDataset(guided_texts, train_summaries, tokenizer, config)

    # Fine-tune HTAS model
    print("\n🚂 Fine-tuning HTAS model...")
    htas_model = AutoModelForSeq2SeqLM.from_pretrained(config.htas_model_name)
    htas_model, losses = train_model(htas_model, tokenizer, train_dataset, config)

    # Save HTAS model
    model_dir = config.output_dir / 'htas_model'
    mkdir(model_dir)
    htas_model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
    print(f"✓ HTAS model saved to {model_dir}")

    # Demo
    print("\n🎯 Multi-doc demo (short):")
    demo_docs = [
        "World leaders gathered in Geneva for the 2025 Climate Summit, marking a crucial moment in global climate policy.",
        "Major economies announced historic climate investments at the Geneva summit, including EU and China actions.",
        "The summit's final agreement includes binding emissions targets, transparency mechanisms, and finance commitments."
    ]
    guided = processor.create_guided_input(demo_docs)
    out = batch_generate(htas_model, tokenizer, [guided], config, batch_size=1)[0]
    print("Generated (demo):", out[:600], "...")

    # Evaluate HTAS
    all_results = {}
    print("\n📊 Evaluating HTAS (fine-tuned)...")
    htas_eval = evaluate_model_full("HTAS-V3 (Fine-Tuned)", htas_model, tokenizer, test_data, config, processor)
    all_results['HTAS-V3'] = htas_eval

    # Evaluate baselines
    for mname in config.eval_models:
        try:
            print(f"\n📊 Loading baseline model {mname} (this may download large weights)...")
            tok = AutoTokenizer.from_pretrained(mname)
            mdl = AutoModelForSeq2SeqLM.from_pretrained(mname)
            baseline_eval = evaluate_model_full(mname, mdl, tok, test_data, config, None)
            all_results[mname] = baseline_eval
            del mdl
            gc.collect()
            torch.cuda.empty_cache()
        except Exception as e:
            print(f"⚠️ Could not evaluate {mname}: {e}")

    # Summarize numeric results
    results_summary = {}
    for name, res in all_results.items():
        results_summary[name] = {
            'rouge1_mean': np.mean(res['rouge1']) * 100, 'rouge1_std': np.std(res['rouge1']) * 100,
            'rouge2_mean': np.mean(res['rouge2']) * 100, 'rouge2_std': np.std(res['rouge2']) * 100,
            'rougeL_mean': np.mean(res['rougeL']) * 100, 'rougeL_std': np.std(res['rougeL']) * 100,
            'bertscore_mean': np.mean(res['bertscore_f1']), 'bertscore_std': np.std(res['bertscore_f1']),
            'bleu_mean': np.mean(res['bleu']), 'bleu_std': np.std(res['bleu']),
            'meteor_mean': np.mean(res['meteor']), 'meteor_std': np.std(res['meteor']),
            'coverage_mean': np.mean(res['coverage']), 'compression_mean': np.mean(res['compression'])
        }

    # Save generated summaries CSV
    rows = []
    for model_name, res in all_results.items():
        for i, cand in enumerate(res['candidates']):
            rows.append({'model': model_name, 'idx': i, 'candidate': cand, 'reference': res['references'][i]})
    df_out = pd.DataFrame(rows)
    df_out.to_csv(config.output_dir / 'generated_summaries.csv', index=False)

    # Save raw results (safe JSON)
    safe_results = {}
    for k, v in all_results.items():
        safe_results[k] = {kk: (vv if not isinstance(vv, list) else vv[:20]) for kk, vv in v.items()}  # truncate list fields
    with open(config.output_dir / 'results_raw_sample.json', 'w') as f:
        json.dump(safe_results, f, indent=2)

    with open(config.output_dir / 'results_summary.json', 'w') as f:
        json.dump(results_summary, f, indent=2)

    # Print summary table
    print("\n" + "="*80)
    print("📊 FINAL RESULTS SUMMARY")
    print("="*80)
    print("| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | BERTScore | BLEU | METEOR | COVERAGE | COMPRESSION |")
    print("|-------|---------|---------|---------|-----------|------|--------|----------|-------------|")
    for name, vals in results_summary.items():
        print(f"| {name:<20} | {vals['rouge1_mean']:.2f}±{vals['rouge1_std']:.2f} | {vals['rouge2_mean']:.2f}±{vals['rouge2_std']:.2f} | {vals['rougeL_mean']:.2f}±{vals['rougeL_std']:.2f} | {vals['bertscore_mean']:.2f}±{vals['bertscore_std']:.2f} | {vals['bleu_mean']:.2f} | {vals['meteor_mean']:.2f} | {vals['coverage_mean']:.2f} | {vals['compression_mean']:.2f} |")

    # Generate plots (5)
    print("\n📈 Generating plots...")
    mkdir(config.output_dir / 'plots')
    plot_training_loss(losses, config.output_dir / 'plots' / 'training_loss.png')
    plot_metric_bars(results_summary, config.output_dir / 'plots' / 'metric_comparison.png')
    plot_rouge_violin(all_results, config.output_dir / 'plots' / 'rouge_violin.png')
    plot_correlation_heatmap(all_results, config.output_dir / 'plots' / 'metric_correlation.png')
    plot_radar(results_summary, config.output_dir / 'plots' / 'radar_comparison.png')

    print(f"\n✓ Saved plots to {config.output_dir / 'plots'}")
    print(f"✓ Saved generated summaries to {config.output_dir / 'generated_summaries.csv'}")
    print(f"✓ Saved results JSONs to {config.output_dir}")

    # Cleanup
    try:
        del htas_model
    except Exception:
        pass
    gc.collect()
    torch.cuda.empty_cache()
    print("\n🎉 Pipeline complete. All outputs are in:", config.output_dir)

if __name__ == "__main__":
    main()


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


🚀 HTAS-V3 (Upgraded Multi-Model Pipeline) — Full Run
Device: cuda
Train samples: 2500, Test samples: 400
HTAS model (to fine-tune): facebook/bart-base
Baselines (zero-shot): ['facebook/bart-large-cnn', 't5-base', 'google/pegasus-cnn_dailymail']

📚 Loading dataset...


Downloading readme: 0.00B [00:00, ?B/s]

Downloading data: 100%|██████████| 257M/257M [00:01<00:00, 145MB/s]
Downloading data: 100%|██████████| 257M/257M [00:04<00:00, 53.4MB/s]
Downloading data: 100%|██████████| 259M/259M [00:05<00:00, 46.1MB/s]
Downloading data: 100%|██████████| 34.7M/34.7M [00:00<00:00, 109MB/s] 
Downloading data: 100%|██████████| 30.0M/30.0M [00:00<00:00, 96.8MB/s]


Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Map:   0%|          | 0/2900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2900 [00:00<?, ? examples/s]

✓ Loaded 2500 train, 400 test samples


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


🔧 Preparing HTAS-guided training data...


Processing:   0%|          | 0/2500 [00:00<?, ?it/s]


🚂 Fine-tuning HTAS model...


model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]


🚂 Training for 936 steps...


Epoch 1/3:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 2/3:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 3/3:   0%|          | 0/625 [00:00<?, ?it/s]

✓ Fine-tuning finished in 6.66 min
✓ HTAS model saved to /content/htas_output/htas_model

🎯 Multi-doc demo (short):
Generated (demo): World leaders gathered in Geneva for the 2025 Climate Summit.
Major economies announced historic climate investments at the summit.
The summit's final agreement includes binding emissions targets, transparency mechanisms, and finance commitments. ...

📊 Evaluating HTAS (fine-tuned)...

📊 Evaluating HTAS-V3 (Fine-Tuned) (400 samples)...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]


📊 Loading baseline model facebook/bart-large-cnn (this may download large weights)...


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]


📊 Evaluating facebook/bart-large-cnn (400 samples)...

📊 Loading baseline model t5-base (this may download large weights)...


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


📊 Evaluating t5-base (400 samples)...

📊 Loading baseline model google/pegasus-cnn_dailymail (this may download large weights)...


tokenizer_config.json:   0%|          | 0.00/88.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/1.91M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/280 [00:00<?, ?B/s]


📊 Evaluating google/pegasus-cnn_dailymail (400 samples)...

📊 FINAL RESULTS SUMMARY
| Model | ROUGE-1 | ROUGE-2 | ROUGE-L | BERTScore | BLEU | METEOR | COVERAGE | COMPRESSION |
|-------|---------|---------|---------|-----------|------|--------|----------|-------------|
| HTAS-V3              | 34.41±11.19 | 13.63±10.52 | 24.37±10.41 | 80.18±3.81 | 26.72 | 0.00 | 9.51 | 14.09 |
| facebook/bart-large-cnn | 40.66±13.01 | 18.94±13.44 | 29.83±13.14 | 82.03±4.24 | 32.97 | 0.00 | 9.74 | 13.73 |
| t5-base              | 33.90±12.72 | 13.63±11.69 | 23.81±11.88 | 79.28±4.62 | 28.63 | 0.00 | 11.83 | 11.87 |
| google/pegasus-cnn_dailymail | 55.07±19.87 | 36.63±23.86 | 47.37±23.51 | 84.94±5.72 | 44.91 | 0.00 | 9.97 | 13.57 |

📈 Generating plots...

✓ Saved plots to /content/htas_output/plots
✓ Saved generated summaries to /content/htas_output/generated_summaries.csv
✓ Saved results JSONs to /content/htas_output

🎉 Pipeline complete. All outputs are in: /content/htas_output
