<a href="https://colab.research.google.com/github/RayAKaan/FUTURE_ML_01/blob/main/Suvidha-Foundation(HATS)-Research.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ==============================================================================
# CELL 1: ROBUST ENVIRONMENT SETUP & INSTALLATION (FIXED VERSION)
#
# This version addresses common installation failures for 'bert-score' and 'hdbscan'.
# RUN THIS CELL FIRST.
# ==============================================================================

import sys
import subprocess
import os

print("🔧 Starting robust environment setup...")

# 1. Upgrade pip and clear cache
print("\n1. Upgrading pip and clearing cache...")
subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "pip"])
subprocess.check_call([sys.executable, "-m", "pip", "cache", "purge"])
print("✅ pip upgraded and cache cleared.")

# 2. Uninstall any conflicting versions for a clean slate
print("\n2. Uninstalling existing 'transformers', 'accelerate', and 'torch'...")
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "transformers", "accelerate", "torch"])
print("✅ Uninstallation complete.")

# 3. Install core libraries first. PyTorch is a fundamental dependency.
print("\n3. Installing core libraries (torch, transformers, accelerate)...")
try:
    print("   - Installing torch...")
    # Let pip decide the best torch version for the environment (CPU/GPU)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
    print("   ✅ torch installed.")
    print("   - Installing accelerate==0.34.2...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "accelerate==0.34.2"])
    print("   ✅ accelerate installed.")
    print("   - Installing transformers==4.44.2...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.44.2"])
    print("   ✅ transformers installed.")
except Exception as e:
    print(f"❌ CRITICAL ERROR installing core libraries: {e}")
    print("   Please check the output above. This error must be resolved to continue.")
    # Stop execution if core libraries fail
    raise

# 4. Install essential build dependencies
print("\n4. Installing essential build dependencies...")
# hdbscan requires Cython to be compiled from source.
try:
    print("🔄 Installing cython...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "cython"])
    print("✅ Successfully installed cython")
except Exception as e:
    print(f"❌ FAILED to install cython: {e}")
    # This is a critical dependency for hdbscan
    raise


# 5. Install remaining dependencies one-by-one for maximum clarity
print("\n5. Installing remaining dependencies one by one...")
# Updated list of packages with fixes for known issues
packages = [
    "datasets==2.21.0",
    "sentence-transformers==3.0.1",
    "evaluate==0.4.3",
    "rouge-score==0.1.2",
    # FIX 1: Let pip resolve the version for bert-score to avoid dependency conflicts.
    "bert-score",
    "nltk==3.9.1",
    "scikit-learn==1.5.2",
    "umap-learn==0.5.6",
    # FIX 2: hdbscan is installed after cython, its build dependency.
    "pandas==2.2.2",
    "ipywidgets==8.1.5",
    "matplotlib==3.9.2",
    "seaborn==0.13.2",
    "tqdm==4.66.5"
]

failed_packages = []
for package in packages:
    try:
        print(f"🔄 Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        print(f"✅ Successfully installed {package}")
    except Exception as e:
        print(f"❌ FAILED to install {package}.")
        print(f"   Error: {e}")
        failed_packages.append(package)

# --- Final Report ---
print("\n" + "="*50)
print("🎉 INSTALLATION PROCESS FINISHED")
print("="*50)
if not failed_packages:
    print("✅ All packages installed successfully!")
    print("\n👉 IMPORTANT: You must now RESTART the runtime.")
    print("   Go to the menu: Runtime -> Restart session.")
else:
    print("⚠️ The following packages failed to install:")
    for pkg in failed_packages:
        print(f"   - {pkg}")
    print("\nPlease try to fix the errors for the failed packages before proceeding.")

🔧 Starting robust environment setup...

1. Upgrading pip and clearing cache...
✅ pip upgraded and cache cleared.

2. Uninstalling existing 'transformers', 'accelerate', and 'torch'...
✅ Uninstallation complete.

3. Installing core libraries (torch, transformers, accelerate)...
   - Installing torch...
   ✅ torch installed.
   - Installing accelerate==0.34.2...
   ✅ accelerate installed.
   - Installing transformers==4.44.2...
   ✅ transformers installed.

4. Installing essential build dependencies...
🔄 Installing cython...
✅ Successfully installed cython

5. Installing remaining dependencies one by one...
🔄 Installing datasets==2.21.0...
✅ Successfully installed datasets==2.21.0
🔄 Installing sentence-transformers==3.0.1...
✅ Successfully installed sentence-transformers==3.0.1
🔄 Installing evaluate==0.4.3...
✅ Successfully installed evaluate==0.4.3
🔄 Installing rouge-score==0.1.2...
✅ Successfully installed rouge-score==0.1.2
🔄 Installing bert-score...
✅ Successfully installed bert-scor

In [2]:
# ==============================================================================
# CELL 3: FINAL ATTEMPT - LETTING CHOOSE THE VERSION
#
# This cell uninstalls the problematic pinned version and lets pip find
# the latest version of hdbscan that can be successfully compiled.
# ==============================================================================

import sys
import subprocess

print("🔧 Starting final attempt to install hdbscan...")

# Step 1: Uninstall any failed installation attempts to ensure a clean slate
print("\n1. Uninstalling any existing 'hdbscan' versions...")
try:
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "hdbscan"], check=False)
    print("✅ Cleanup complete.")
except Exception as e:
    print(f"   (Note: Could not uninstall, might not be installed: {e})")

# Step 2: Install hdbscan without a version pin
print("\n2. Installing hdbscan by letting pip resolve the version...")
try:
    print("🔄 Running: pip install hdbscan")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "hdbscan"])
    print("\n✅ SUCCESS! hdbscan was installed successfully.")
    print("\n🎉 ALL PACKAGES ARE NOW INSTALLED!")
    print("\n👉 IMPORTANT: You must now RESTART the runtime for the changes to take effect.")
    print("   Go to the menu: Runtime -> Restart session.")

except subprocess.CalledProcessError as e:
    print(f"\n❌ FAILED to install hdbscan automatically.")
    print("   This suggests a fundamental incompatibility with the environment.")
    print("\n--- ALTERNATIVE PLAN ---")
    print("If hdbscan is not strictly required, you can use 'DBSCAN' from scikit-learn.")
    print("It's already installed and provides similar functionality, though not identical.")
    print("   from sklearn.cluster import DBSCAN")
    print("\nIf you absolutely need hdbscan, we would need to debug the full compiler output.")

🔧 Starting final attempt to install hdbscan...

1. Uninstalling any existing 'hdbscan' versions...
✅ Cleanup complete.

2. Installing hdbscan by letting pip resolve the version...
🔄 Running: pip install hdbscan

✅ SUCCESS! hdbscan was installed successfully.

🎉 ALL PACKAGES ARE NOW INSTALLED!

👉 IMPORTANT: You must now RESTART the runtime for the changes to take effect.
   Go to the menu: Runtime -> Restart session.


In [17]:
# ==============================================================================
# CELL 2: IMPORTS, CONFIGURATION, AND MODEL LOADING (UPDATED)
# ==============================================================================

import os
import warnings
from pathlib import Path
from typing import List, Tuple, Dict
from dataclasses import dataclass
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn.functional as F
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords

from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM,
                          GPT2LMHeadModel, GPT2Tokenizer,
                          BertTokenizer, BertModel,
                          Trainer, TrainingArguments)
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from rouge_score import rouge_scorer
from bert_score import score as bert_score

warnings.filterwarnings("ignore")
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['WANDB_DISABLED'] = 'true'
sns.set_style("whitegrid")

# --- Device Setup ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ Using device: {device}")

# --- NLTK Data Download ---
print("📦 Downloading NLTK data...")
nltk.download('punkt', quiet=True)
### <-- MODIFICATION START
# Add this line for newer versions of NLTK
nltk.download('punkt_tab', quiet=True)
### <-- MODIFICATION END
nltk.download('stopwords', quiet=True)

# --- Configuration Class ---
@dataclass
class Config:
    # General
    output_dir: str = "./results"
    seed: int = 42

    # Dataset
    dataset_name: str = "cnn_dailymail"
    dataset_version: str = "3.0.0"
    split: str = "test"
    n_samples: int = 500 # Using 10 for very fast testing

    # Preprocessing
    min_sent_length: int = 5
    max_sent_length: int = 80

    # Encoding & Clustering
    encoder_model: str = "sentence-transformers/all-mpnet-base-v2"
    batch_size: int = 16
    use_hdbscan: bool = True # Will fallback to KMeans if hdbscan is not available
    min_cluster_size: int = 3
    max_clusters: int = 6

    # MMR Selection
    lambda_diversity: float = 0.2
    lambda_coverage: float = 0.6
    new_topic_bonus: float = 0.35
    sentences_to_select: int = 25

    # Generation
    gen_model: str = "facebook/bart-large-cnn"
    max_input_length: int = 1024
    max_output_length: int = 350
    min_output_length: int = 80
    num_beams: int = 8
    no_repeat_ngram: int = 3
    repetition_penalty: float = 1.5

    # Evaluation
    compute_bertscore: bool = True
    compute_perplexity: bool = True

    # --- FINE-TUNING SETTINGS ---
    # Set this to True to enable fine-tuning
    finetune_enabled: bool = True

    # Use a smaller model for much faster fine-tuning
    finetune_model_checkpoint: str = "sshleifer/distilbart-cnn-12-6"

    # Fine-tuning hyperparameters
    finetune_samples: int = 500
    finetune_epochs: int = 2
    finetune_batch_size: int = 1 # Use a small batch size to avoid OOM errors
    finetune_lr: float = 3e-5
    # --------------------------------

    # Add a device attribute to the config, using the global 'device' variable
    device: str = device

    def __post_init__(self):
        """This runs after the dataclass is initialized."""
        print(f"✅ Config object initialized. Using device: {self.device}")
        if self.finetune_enabled:
            print(f"🔥 Fine-tuning ENABLED. Will use model: {self.finetune_model_checkpoint}")
        else:
            print(f"ℹ️ Fine-tuning DISABLED. Will use model: {self.gen_model}")

# --- Preload Models for Efficiency ---
print("\n🚀 Preloading models to save time during evaluation...")
config = Config()
encoder_model = SentenceTransformer(config.encoder_model, device=device)
generator_tokenizer = AutoTokenizer.from_pretrained(config.gen_model)
generator_model = AutoModelForSeq2SeqLM.from_pretrained(config.gen_model).to(device)
generator_model.eval()
print("✅ Encoder and Generator models loaded and ready.")

# --- GPU Optimization ---
def optimize_gpu():
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.cuda.empty_cache()
        print("✅ GPU optimization enabled.")

def log_gpu(stage: str):
    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"[GPU] {stage}: Allocated {alloc:.2f} GB, Reserved {reserved:.2f} GB")

✅ Using device: cuda
📦 Downloading NLTK data...

🚀 Preloading models to save time during evaluation...
✅ Config object initialized. Using device: cuda
🔥 Fine-tuning ENABLED. Will use model: sshleifer/distilbart-cnn-12-6
✅ Encoder and Generator models loaded and ready.


In [18]:
# ==============================================================================
# CELL 3: CORE LOGIC CLASSES
# ==============================================================================

# --- Preprocessor ---
class Preprocessor:
    def __init__(self, config: Config):
        self.config = config
        self.stopwords = set(stopwords.words('english'))

    def preprocess(self, text: str) -> List[Tuple[str, float]]:
        sentences = sent_tokenize(text)
        sentences = [s for s in sentences if self.config.min_sent_length <= len(s.split()) <= self.config.max_sent_length]
        results = []
        for i, sent in enumerate(sentences):
            words = word_tokenize(sent.lower())
            content_words = [w for w in words if w.isalpha() and w not in self.stopwords]
            content_ratio = len(content_words) / max(1, len(words))
            # Give a slight weight to earlier sentences
            weight = 1.2 if i < 3 else 1.0
            results.append((sent, content_ratio * weight))
        return results if results else [("", 0.0)]

# --- Topic Encoder & Clustering ---
class TopicEncoder:
    def __init__(self, config: Config, preloaded_model: SentenceTransformer):
        self.config = config
        self.model = preloaded_model

    def encode(self, sentences: List[str]) -> np.ndarray:
        clean_sentences = [s.strip() for s in sentences if s.strip()]
        if not clean_sentences:
            return np.zeros((0, self.model.get_sentence_embedding_dimension()))
        return self.model.encode(clean_sentences, batch_size=self.config.batch_size,
                                 convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)

    def cluster(self, embeddings: np.ndarray) -> Tuple[np.ndarray, int]:
        n_samples = len(embeddings)
        if n_samples < 2:
            return np.zeros(n_samples, dtype=int), 1

        # Try HDBSCAN first
        if self.config.use_hdbscan:
            try:
                import hdbscan
                clusterer = hdbscan.HDBSCAN(min_cluster_size=self.config.min_cluster_size, metric='euclidean')
                labels = clusterer.fit_predict(embeddings)
                # Assign noise points (-1) to the nearest valid cluster
                if -1 in labels:
                    noise_indices = np.where(labels == -1)[0]
                    valid_indices = np.where(labels != -1)[0]
                    if len(valid_indices) > 0:
                        from sklearn.neighbors import NearestNeighbors
                        nbrs = NearestNeighbors(n_neighbors=1).fit(embeddings[valid_indices])
                        _, indices = nbrs.kneighbors(embeddings[noise_indices])
                        labels[noise_indices] = labels[valid_indices[indices.flatten()]]
                n_topics = len(set(labels))
                if n_topics > 1:
                    return labels, n_topics
            except ImportError:
                print("HDBSCAN not found, falling back to KMeans.")

        # Fallback to KMeans
        best_labels, best_score, best_k = None, -1, 2
        for k in range(2, min(self.config.max_clusters + 1, n_samples)):
            kmeans = KMeans(n_clusters=k, random_state=self.config.seed, n_init=10)
            current_labels = kmeans.fit_predict(embeddings)
            score = silhouette_score(embeddings, current_labels)
            if score > best_score:
                best_score, best_labels, best_k = score, current_labels, k

        return best_labels if best_labels is not None else np.zeros(n_samples, dtype=int), best_k

# --- MMR Selector ---
class MMRSelector:
    def __init__(self, config: Config):
        self.config = config

    def select(self, embeddings: np.ndarray, scores: List[float], topics: np.ndarray, n_topics: int) -> List[int]:
        n_sentences = len(embeddings)
        k = min(self.config.sentences_to_select, n_sentences)

        if n_sentences == 0:
            return []

        selected_indices, candidate_indices = [], list(range(n_sentences))
        topic_count = defaultdict(int)

        # Normalize importance scores
        importance = np.array(scores)
        if importance.max() > importance.min():
            importance = (importance - importance.min()) / (importance.max() - importance.min())
        else:
            importance = np.ones_like(importance)

        while len(selected_indices) < k and candidate_indices:
            best_idx, best_score = -1, -float('inf')

            for idx in candidate_indices:
                relevance = importance[idx]

                # Calculate diversity (max similarity to already selected sentences)
                diversity = 1.0
                if selected_indices:
                    # Cosine similarity
                    sim_scores = np.dot(embeddings[idx], embeddings[selected_indices].T)
                    diversity = 1.0 - np.max(sim_scores)

                # Calculate topic bonus/penalty
                topic_id = topics[idx]
                if topic_count[topic_id] == 0:
                    topic_bonus = self.config.new_topic_bonus
                else:
                    # Penalize over-representation of a topic
                    ideal_ratio = 1.0 / n_topics
                    current_ratio = topic_count[topic_id] / max(1, len(selected_indices))
                    topic_penalty = self.config.lambda_coverage * (current_ratio - ideal_ratio)
                    topic_bonus = -topic_penalty

                # Combined score
                combined_score = (self.config.lambda_diversity * diversity +
                                 0.1 * relevance +
                                 topic_bonus)

                if combined_score > best_score:
                    best_score = combined_score
                    best_idx = idx

            if best_idx != -1:
                selected_indices.append(best_idx)
                candidate_indices.remove(best_idx)
                topic_count[topics[best_idx]] += 1

        return selected_indices

# --- Generator ---
class Generator:
    def __init__(self, config: Config, preloaded_model, preloaded_tokenizer):
        self.config = config
        self.model = preloaded_model
        self.tokenizer = preloaded_tokenizer

    def generate(self, text: str) -> str:
        if not text or not text.strip():
            return ""

        inputs = self.tokenizer(text, max_length=self.config.max_input_length, truncation=True, return_tensors="pt").to(self.config.device)

        with torch.no_grad():
            summary_ids = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=self.config.max_output_length,
                min_length=self.config.min_output_length,
                num_beams=self.config.num_beams,
                no_repeat_ngram_size=self.config.no_repeat_ngram,
                repetition_penalty=self.config.repetition_penalty,
                early_stopping=True
            )

        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# --- Baseline Summarizer ---
class BaselineBERTSummarizer:
    def __init__(self, config: Config):
        self.config = config
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertModel.from_pretrained("bert-base-uncased").to(config.device)
        self.model.eval()

    def summarize(self, document: str, top_k: int = 5) -> str:
        sentences = sent_tokenize(document)
        if not sentences:
            return ""

        inputs = self.tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors="pt").to(self.config.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Get sentence embeddings by averaging token embeddings
        sentence_embeddings = outputs.last_hidden_state.mean(dim=1)
        document_embedding = sentence_embeddings.mean(dim=0, keepdim=True)

        # Calculate cosine similarity
        similarities = F.cosine_similarity(sentence_embeddings, document_embedding)

        # Select top-k sentences
        top_indices = similarities.argsort(descending=True)[:top_k]
        selected_sentences = [sentences[i] for i in sorted(top_indices)]

        return " ".join(selected_sentences)

# --- Evaluator ---
class Evaluator:
    def __init__(self, config: Config):
        self.config = config
        self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        if config.compute_perplexity:
            print("Loading GPT-2 for perplexity calculation...")
            self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2").to(config.device)
            self.ppl_model.eval()
        else:
            self.ppl_model = None

        print("Loading semantic similarity model for evaluation...")
        self.sem_model = SentenceTransformer(config.encoder_model, device=config.device)

    def _calculate_repetition(self, text: str) -> Dict[str, float]:
        words = text.lower().split()
        if not words:
            return {f'{n}gram_repetition': 0.0 for n in [2, 3, 4]}

        repetition_metrics = {}
        for n in [2, 3, 4]:
            if len(words) >= n:
                ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
                unique_ngrams = len(set(ngrams))
                repetition = (1 - unique_ngrams / len(ngrams)) * 100
                repetition_metrics[f'{n}gram_repetition'] = repetition
            else:
                repetition_metrics[f'{n}gram_repetition'] = 0.0

        repetition_metrics['unique_word_ratio'] = (len(set(words)) / len(words)) * 100
        return repetition_metrics

    def evaluate(self, preds: List[str], refs: List[str], sources: List[str]) -> Dict[str, float]:
        metrics = {}

        # ROUGE Scores
        rouge1, rouge2, rougeL = [], [], []
        for pred, ref in zip(preds, refs):
            score = self.rouge.score(ref, pred)
            rouge1.append(score['rouge1'].fmeasure)
            rouge2.append(score['rouge2'].fmeasure)
            rougeL.append(score['rougeL'].fmeasure)
        metrics['rouge1_f1'] = np.mean(rouge1) * 100
        metrics['rouge2_f1'] = np.mean(rouge2) * 100
        metrics['rougeL_f1'] = np.mean(rougeL) * 100

        # BERTScore
        if self.config.compute_bertscore:
            try:
                _, _, F1 = bert_score(preds, refs, lang='en', device=self.config.device, verbose=False)
                metrics['bertscore_f1'] = F1.mean().item() * 100
            except Exception as e:
                print(f"Could not compute BERTScore: {e}")
                metrics['bertscore_f1'] = 0.0

        # Repetition Metrics (calculated per-sample and then averaged)
        all_repetitions = [self._calculate_repetition(p) for p in preds]
        for key in all_repetitions[0].keys():
            metrics[key] = np.mean([r[key] for r in all_repetitions])

        # Perplexity
        if self.ppl_model:
            ppls = []
            with torch.no_grad():
                for text in preds:
                    if not text: continue
                    enc = self.ppl_tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
                    input_ids = enc.input_ids.to(self.ppl_model.device)
                    outputs = self.ppl_model(input_ids, labels=input_ids)
                    ppl = torch.exp(outputs.loss).item()
                    ppls.append(min(ppl, 1000)) # Cap perplexity to avoid outliers
            metrics['perplexity'] = np.mean(ppls) if ppls else 0

        # Semantic Consistency
        pred_emb = self.sem_model.encode(preds, convert_to_tensor=True, show_progress_bar=False)
        src_emb = self.sem_model.encode(sources, convert_to_tensor=True, show_progress_bar=False)
        sims = util.cos_sim(pred_emb, src_emb).diagonal()
        metrics['semantic_consistency'] = sims.mean().item() * 100

        return metrics

# --- Visualizer ---
class Visualizer:
    def __init__(self, out_dir: str):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)

    def plot_comparison(self, met_dict: Dict[str, Dict]):
        df = pd.DataFrame(met_dict).T
        cols = [m for m in ['rouge1_f1','rouge2_f1','rougeL_f1','bertscore_f1','3gram_repetition','semantic_consistency'] if m in df.columns]
        if not cols: return

        fig, ax = plt.subplots(figsize=(12,7))
        df[cols].plot(kind='bar', ax=ax, width=0.8)
        ax.set_title("Model Comparison on Key Metrics", fontsize=16)
        ax.set_xlabel("Model")
        ax.set_ylabel("Score")
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.savefig(self.out_dir / "metrics_comparison.png", dpi=300, bbox_inches='tight')
        plt.close()
        print(f"✅ Saved: {self.out_dir / 'metrics_comparison.png'}")

In [19]:
# ==============================================================================
# CELL 4: MAIN EXECUTION PIPELINE (WITH FINE-TUNING)
# ==============================================================================

# --- Assume these classes are defined in previous cells ---
# Preprocessor, TopicEncoder, MMRSelector, Evaluator, BaselineBERTSummarizer, Visualizer

# --- UPDATE/REPLACE YOUR GENERATOR CLASS WITH THIS VERSION ---
# This version includes a finetune() method with speed optimizations.
from transformers import Trainer, TrainingArguments, Seq2SeqTrainingArguments

class Generator:
    def __init__(self, config: Config, preloaded_model=None, preloaded_tokenizer=None):
        self.config = config
        # Use a smaller, faster model for fine-tuning if specified
        model_checkpoint = config.finetune_model_checkpoint if config.finetune_enabled else config.gen_model

        if preloaded_model and model_checkpoint == config.gen_model:
            self.model = preloaded_model
        else:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(config.device)

        if preloaded_tokenizer and model_checkpoint == config.gen_model:
            self.tokenizer = preloaded_tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

        # For perplexity calculation
        if self.config.compute_perplexity:
            print("Loading GPT-2 for perplexity calculation...")
            self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2").to(config.device)
            self.ppl_model.eval()
            # Set pad token for GPT-2
            self.ppl_tokenizer.pad_token = self.ppl_tokenizer.eos_token

    def generate(self, text: str) -> str:
        # ... (your existing generate method) ...
        inputs = self.tokenizer(text, max_length=self.config.max_input_length, truncation=True, return_tensors="pt").to(self.config.device)
        summary_ids = self.model.generate(
            inputs['input_ids'],
            max_length=self.config.max_output_length,
            min_length=self.config.min_output_length,
            num_beams=self.config.num_beams,
            no_repeat_ngram_size=self.config.no_repeat_ngram,
            repetition_penalty=self.config.repetition_penalty,
            early_stopping=True
        )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def finetune(self, train_dataset):
        """Fine-tunes the model on the provided dataset with speed optimizations."""
        print(f"\n🔥 Starting fine-tuning on {len(train_dataset)} samples...")
        print(f"   Model: {self.config.finetune_model_checkpoint}")
        log_gpu("Before Fine-tuning")

        # --- Tokenization Function ---
        def preprocess_function(examples):
            # The model expects 'document' and 'summary' columns
            model_inputs = self.tokenizer(
                examples["document"], max_length=self.config.max_input_length, truncation=True
            )
            labels = self.tokenizer(
                examples["summary"], max_length=self.config.max_output_length, truncation=True
            )
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        # --- Prepare Dataset ---
        tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
        tokenized_train_dataset = tokenized_train_dataset.remove_columns(["document", "summary"])

        # --- Training Arguments with Speed Optimizations ---
        training_args = Seq2SeqTrainingArguments(
            output_dir=self.config.output_dir + "/finetune",
            num_train_epochs=self.config.finetune_epochs,
            per_device_train_batch_size=self.config.finetune_batch_size,

            # --- SPEED OPTIMIZATIONS ---
            fp16=True,  # Use Mixed Precision (FP16) - HUGE speedup on T4/V100
            gradient_accumulation_steps=4, # Simulate larger batch size (e.g., 1 * 4 = 4)
            # --- END SPEED OPTIMIZATIONS ---

            save_strategy="no", # Don't save checkpoints to save time
            logging_steps=10,
            report_to="none", # Disable wandb/tensorboard
        )

        # --- Trainer ---
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_train_dataset,
            tokenizer=self.tokenizer,
        )

        # --- Train ---
        trainer.train()

        # --- Update the model in the generator ---
        self.model = trainer.model
        print("✅ Fine-tuning complete. Model updated.")
        log_gpu("After Fine-tuning")


# --- SummarizationPipeline Class (Updated) ---
class SummarizationPipeline:
    def __init__(self, config: Config):
        self.config = config
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)

        # Initialize components with preloaded models
        self.preproc = Preprocessor(config)
        self.encoder = TopicEncoder(config, preloaded_model=encoder_model)
        self.selector = MMRSelector(config)
        # Generator will load the fine-tuning model if enabled
        self.generator = Generator(config, preloaded_model=generator_model, preloaded_tokenizer=generator_tokenizer)
        self.evaluator = Evaluator(config)

    def finetune(self, dataset):
        """Public method to trigger fine-tuning."""
        if self.config.finetune_enabled:
            print("\n" + "="*60)
            print("FINE-TUNING GENERATOR MODEL")
            print("="*60)
            # Create a small training set
            train_dataset = dataset.select(range(min(self.config.finetune_samples, len(dataset))))
            self.generator.finetune(train_dataset)
        else:
            print("\nℹ️ Fine-tuning is disabled in config. Skipping.")

    # ... (rest of the SummarizationPipeline methods: summarize, evaluate_on_dataset, etc. remain the same) ...
    def summarize(self, document: str) -> Tuple[str, Dict]:
        sent_scores = self.preproc.preprocess(document)
        sentences = [s for s, _ in sent_scores]
        scores = [sc for _, sc in sent_scores]

        if not sentences:
            return "", {'n_topics': 0, 'coverage_ratio': 0, 'n_selected': 0}

        embeddings = self.encoder.encode(sentences)
        topic_labels, n_topics = self.encoder.cluster(embeddings)
        selected_indices = self.selector.select(embeddings, scores, topic_labels, n_topics)

        selected_sentences = [sentences[i] for i in selected_indices]
        coverage = len(set(topic_labels[i] for i in selected_indices)) / n_topics if n_topics > 0 else 0

        summary_text = " ".join(selected_sentences)
        final_summary = self.generator.generate(summary_text) if summary_text else ""

        info = {'n_topics': n_topics, 'coverage_ratio': coverage, 'n_selected': len(selected_indices)}
        return final_summary, info

    def evaluate_on_dataset(self, dataset, name: str) -> Tuple[pd.DataFrame, Dict]:
        print(f"\n{'='*60}\nEvaluating: {name}\n{'='*60}")
        log_gpu(f"Before {name} evaluation")

        predictions, references, sources = [], [], []
        topic_counts, coverage_ratios = [], []

        for example in tqdm(dataset, desc=f"Generating summaries for {name}"):
            doc, ref = example['document'], example['summary']
            summary, info = self.summarize(doc)

            predictions.append(summary)
            references.append(ref)
            sources.append(doc)
            topic_counts.append(info['n_topics'])
            coverage_ratios.append(info['coverage_ratio'])

        metrics = self.evaluator.evaluate(predictions, references, sources)

        results_df = pd.DataFrame({
            'source': sources,
            'reference': references,
            'prediction': predictions,
            'n_topics': topic_counts,
            'coverage_ratio': coverage_ratios
        })

        log_gpu(f"After {name} evaluation")
        self._print_metrics(metrics, name)
        return results_df, metrics

    def evaluate_baseline(self, dataset) -> Tuple[pd.DataFrame, Dict]:
        print(f"\n{'='*60}\nEvaluating: BERT Baseline\n{'='*60}")
        log_gpu("Before baseline evaluation")

        baseline = BaselineBERTSummarizer(self.config)
        predictions, references, sources = [], [], []

        for example in tqdm(dataset, desc="Generating baseline summaries"):
            doc, ref = example['document'], example['summary']
            summary = baseline.summarize(doc)
            predictions.append(summary)
            references.append(ref)
            sources.append(doc)

        metrics = self.evaluator.evaluate(predictions, references, sources)

        results_df = pd.DataFrame({
            'source': sources,
            'reference': references,
            'prediction': predictions,
            'n_topics': [1] * len(predictions), # Placeholder
            'coverage_ratio': [1.0] * len(predictions) # Placeholder
        })

        log_gpu("After baseline evaluation")
        self._print_metrics(metrics, "BERT Baseline")
        return results_df, metrics

    def _print_metrics(self, metrics: Dict, title: str):
        print(f"\n{title.upper()} METRICS")
        print("-" * 40)
        groups = {
            'ROUGE': ['rouge1_f1', 'rouge2_f1', 'rougeL_f1'],
            'Semantic Quality': ['bertscore_f1', 'semantic_consistency'],
            'Redundancy': ['2gram_repetition', '3gram_repetition', '4gram_repetition', 'unique_word_ratio'],
            'Fluency': ['perplexity'],
        }
        for group_name, keys in groups.items():
            print(f"\n{group_name}:")
            for key in keys:
                if key in metrics:
                    print(f"  {key:<25}: {metrics[key]:>8.2f}")
        print("-" * 40)


# --- Main Execution (Updated) ---
def main():
    print("="*80)
    print("HIERARCHICAL TOPIC-AWARE SUMMARIZATION - REMASTERED")
    print("="*80)

    optimize_gpu()
    log_gpu("Initial State")

    config = Config()
    out_dir = Path(config.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load Dataset
    print(f"\nLoading {config.dataset_name} dataset...")
    dataset = load_dataset(config.dataset_name, config.dataset_version, split=config.split, trust_remote_code=True)
    dataset = dataset.shuffle(seed=config.seed).select(range(min(config.n_samples, len(dataset))))
    # Rename columns for consistency
    dataset = dataset.map(lambda x: {"document": x["article"], "summary": x["highlights"]}, remove_columns=["article", "highlights"])
    print(f"✅ Loaded {len(dataset)} samples.")

    # Initialize Pipeline
    pipeline = SummarizationPipeline(config)

    # --- NEW: FINE-TUNING STEP ---
    pipeline.finetune(dataset)

    # --- Run Evaluations ---
    # 1. Hierarchical Model (now potentially fine-tuned)
    hier_df, hier_metrics = pipeline.evaluate_on_dataset(dataset, "Hierarchical Model (Fine-tuned)" if config.finetune_enabled else "Hierarchical Model")
    hier_df.to_csv(out_dir / "hierarchical_results.csv", index=False)
    print(f"✅ Saved results to {out_dir / 'hierarchical_results.csv'}")

    # 2. BERT Baseline
    base_df, base_metrics = pipeline.evaluate_baseline(dataset)
    base_df.to_csv(out_dir / "baseline_results.csv", index=False)
    print(f"✅ Saved results to {out_dir / 'baseline_results.csv'}")

    # --- Save and Visualize Results ---
    all_metrics_df = pd.DataFrame({"Hierarchical": hier_metrics, "BERT_Baseline": base_metrics}).T
    all_metrics_df.to_csv(out_dir / "all_metrics_summary.csv")
    print(f"✅ Saved metric summary to {out_dir / 'all_metrics_summary.csv'}")

    print("\nGenerating visualizations...")
    visualizer = Visualizer(str(out_dir / "plots"))
    visualizer.plot_comparison({"Hierarchical": hier_metrics, "BERT_Baseline": base_metrics})

    print("\n🎉 EVALUATION COMPLETE!")
    print("Results and plots are saved in the './results' directory.")
    log_gpu("Final State")

if __name__ == "__main__":
    main()

HIERARCHICAL TOPIC-AWARE SUMMARIZATION - REMASTERED
✅ GPU optimization enabled.
[GPU] Initial State: Allocated 9.07 GB, Reserved 9.77 GB
✅ Config object initialized. Using device: cuda
🔥 Fine-tuning ENABLED. Will use model: sshleifer/distilbart-cnn-12-6

Loading cnn_dailymail dataset...
✅ Loaded 500 samples.
Loading GPT-2 for perplexity calculation...
Loading GPT-2 for perplexity calculation...
Loading semantic similarity model for evaluation...

FINE-TUNING GENERATOR MODEL

🔥 Starting fine-tuning on 500 samples...
   Model: sshleifer/distilbart-cnn-12-6
[GPU] Before Fine-tuning: Allocated 11.58 GB, Reserved 11.91 GB


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Step,Training Loss
10,2.3487
20,2.1028
30,1.8217
40,1.8876
50,1.7983
60,1.8425
70,1.844
80,1.9862
90,1.718
100,1.6693


✅ Fine-tuning complete. Model updated.
[GPU] After Fine-tuning: Allocated 9.64 GB, Reserved 12.49 GB

Evaluating: Hierarchical Model (Fine-tuned)
[GPU] Before Hierarchical Model (Fine-tuned) evaluation: Allocated 9.64 GB, Reserved 12.49 GB


Generating summaries for Hierarchical Model (Fine-tuned):   0%|          | 0/500 [00:00<?, ?it/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[GPU] After Hierarchical Model (Fine-tuned) evaluation: Allocated 7.37 GB, Reserved 13.51 GB

HIERARCHICAL MODEL (FINE-TUNED) METRICS
----------------------------------------

ROUGE:
  rouge1_f1                :    43.13
  rouge2_f1                :    22.04
  rougeL_f1                :    30.35

Semantic Quality:
  bertscore_f1             :    88.02
  semantic_consistency     :    79.17

Redundancy:
  2gram_repetition         :     0.96
  3gram_repetition         :     0.02
  4gram_repetition         :     0.01
  unique_word_ratio        :    83.26

Fluency:
  perplexity               :    38.89
----------------------------------------
✅ Saved results to results/hierarchical_results.csv

Evaluating: BERT Baseline
[GPU] Before baseline evaluation: Allocated 7.37 GB, Reserved 13.51 GB


Generating baseline summaries:   0%|          | 0/500 [00:00<?, ?it/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[GPU] After baseline evaluation: Allocated 7.78 GB, Reserved 14.44 GB

BERT BASELINE METRICS
----------------------------------------

ROUGE:
  rouge1_f1                :    31.04
  rouge2_f1                :    10.61
  rougeL_f1                :    18.92

Semantic Quality:
  bertscore_f1             :    85.60
  semantic_consistency     :    82.36

Redundancy:
  2gram_repetition         :     7.18
  3gram_repetition         :     4.04
  4gram_repetition         :     3.11
  unique_word_ratio        :    70.15

Fluency:
  perplexity               :    32.80
----------------------------------------
✅ Saved results to results/baseline_results.csv
✅ Saved metric summary to results/all_metrics_summary.csv

Generating visualizations...
✅ Saved: results/plots/metrics_comparison.png

🎉 EVALUATION COMPLETE!
Results and plots are saved in the './results' directory.
[GPU] Final State: Allocated 7.37 GB, Reserved 14.44 GB
