In [None]:
!pip install spacy sentence_transformers

In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import re
import math
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Optional, Any
import numpy as np
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
try:
    import spacy
except Exception as e:
    spacy = None
try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None
import torch
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import json
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from torch.utils.data import Dataset
from torch.optim import AdamW

2025-10-31 16:04:12.699418: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-31 16:04:12.751947: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-10-31 16:04:13.676315: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# -------------------------------
# 1. Normalization & Regex Utils
# -------------------------------
def normalize_text(text: str) -> str:
    """Lowercase, normalize whitespace, remove control chars."""
    if not text:
        return ""
    return re.sub(r"\s+", " ", text.strip().lower())


STATUTE_REGEX = re.compile(r"""
    (
        (?:s\.|sec(?:tion)?s?\.?)\s*\d+[A-Za-z]?(?:[\s,]*(?:and)?\s*\d+[A-Za-z]?)* |
        (S\.?\s*\d+[A-Za-z]?) |
        (Section\s+\d+[A-Za-z]?(?:\(\d+\))*) |
        (Sec\.?\s*\d+[A-Za-z]?) |
        (Article\s+\d+[A-Za-z]?(?:\(\d+\))*) |
        (Art\.?\s*\d+[A-Za-z]?) |
        (Order\s+[IVXLC]+\s+Rule\s+\d+[A-Za-z]?) |
        (Rule\s+\d+[A-Za-z]?(?:\(\d+\))*) |
        (Clause\s*\(?\d+[A-Za-z]?\)?) |
        (\b\d{1,4}\s+of\s+\d{4}\b) |
        (\b\d+\s+SCC\s+\d+\b) |
        (\bAIR\s*\[?\d{4}\]?\s+[A-Z]+\s+\d+\b)
    )
""", flags=re.I | re.X)

CASE_CITATION_REGEX = re.compile(
    r"\b([A-Z][\w\.\-& ]{1,120}?\s+v(?:\.|s|ersus)\.?\s+[A-Z][\w\.\-& ]{1,120}?)\b",
    flags=re.I | re.X
)


def extract_statute_references(sentence: str) -> List[str]:
    """Extract statute/section references from a sentence."""
    matches = STATUTE_REGEX.findall(sentence)
    return [" ".join([x for x in m if x]).strip() if isinstance(m, tuple) else m.strip()
            for m in matches if m]


def extract_case_citations(sentence: str) -> List[str]:
    """Extract case citations from a sentence."""
    return [m.strip() for m in CASE_CITATION_REGEX.findall(sentence) if m]


# -------------------------------
# 2. NLP & Embeddings
# -------------------------------
def load_spacy(model: str = "en_core_web_sm"):
    """Lazy-load spaCy model with helpful error if missing."""
    try:
        return spacy.load(model)
    except OSError:
        raise OSError(f"spaCy model '{model}' not found. Install with: python -m spacy download {model}")


def extract_entities_and_keyphrases(nlp, sentence: str, min_len: int = 2) -> Tuple[List[str], List[str]]:
    """Return (entities, keyphrases) for a sentence."""
    doc = nlp(sentence)
    entities = [ent.text.strip() for ent in doc.ents if len(ent.text.strip()) >= min_len]
    phrases = [chunk.text.strip() for chunk in doc.noun_chunks if len(chunk.text.strip()) >= min_len]

    # Deduplicate while preserving order
    def unique(seq):
        seen, out = set(), []
        for x in seq:
            k = x.lower()
            if k not in seen:
                seen.add(k)
                out.append(x)
        return out

    return unique(entities), unique(phrases)


class EmbeddingModel:
    """Wrapper for SentenceTransformer sentence embeddings."""
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)

    def encode(self, sentences: List[str], batch_size: int = 32) -> np.ndarray:
        sentences = [normalize_text(s) for s in sentences]
        return np.array(self.model.encode(sentences, batch_size=batch_size, show_progress_bar=False))


# -------------------------------
# 3. Graph Construction
# -------------------------------
def construct_document_graph(
    sentences: List[str],
    nlp=None,
    sent_embeddings: Optional[np.ndarray] = None,
    emb_model: Optional[EmbeddingModel] = None,
    sim_threshold: float = 0.7,
    max_sim_edges: int = 3,
    include_sequential_edges: bool = True,
    include_semantic_edges: bool = True,
    include_discourse_edges: bool = True,
) -> Tuple[nx.DiGraph, Dict[str, Any]]:

    # --- Normalize ---
    sentences_clean = [normalize_text(s) for s in sentences]
    num_sents = len(sentences_clean)

    # --- Load spaCy if needed ---
    if nlp is None:
        try:
            nlp = spacy.load("en_core_web_sm")
        except Exception:
            nlp = None

    # --- Embeddings ---
    if sent_embeddings is None and emb_model:
        sent_embeddings = emb_model.encode(sentences_clean)

    # --- Graph Init ---
    G = nx.DiGraph()

    # Sentence nodes
    for i, s in enumerate(sentences_clean):
        G.add_node(f"S_{i}", type="sentence", idx=i, text=s,
                   emb=sent_embeddings[i] if sent_embeddings is not None else None)

    # Virtual summary node
    G.add_node("V_SUMMARY", type="virtual", text="<virtual_summary_node>")

    # --- Extract entities/phrases/statutes/cases ---
    phrase_counter, sentence_entities, sentence_phrases = Counter(), [], []
    for s in sentences_clean:
        ents, phrs = ([], [])
        if nlp:
            try:
                ents, phrs = extract_entities_and_keyphrases(nlp, s)
            except Exception:
                pass
        if not ents:  # Fallback heuristic
            ents = re.findall(r"\b([A-Z][a-z]{2,}(?:\s+[A-Z][a-z]{2,})*)\b", s)

        sentence_entities.append(ents)
        sentence_phrases.append(phrs)
        for p in set(ents + phrs):
            phrase_counter[p.lower()] += 1

    phrase_to_idx = {p: i for i, p in enumerate(phrase_counter.keys())}
    for p, idx in phrase_to_idx.items():
        G.add_node(f"P_{idx}", type="phrase", text=p)

    # --- Add Edges ---
    # Sequential edges
    if include_sequential_edges:
        for i in range(num_sents - 1):
            G.add_edge(f"S_{i}", f"S_{i+1}", type="sequential", weight=1.0)

    # Discourse edges
    if include_discourse_edges:
        discourse_markers = {
            "however": "contrast", "therefore": "result", "thus": "result",
            "consequently": "result", "furthermore": "addition",
            "moreover": "addition", "but": "contrast", "although": "contrast"
        }
        for i, s in enumerate(sentences_clean):
            for marker, dtype in discourse_markers.items():
                if re.search(rf"\b{re.escape(marker)}\b", s.lower()) and i > 0:
                    G.add_edge(f"S_{i-1}", f"S_{i}", type=f"discourse_{dtype}", weight=1.0)

    # Phrase / citation edges
    for i, (ents, phrs) in enumerate(zip(sentence_entities, sentence_phrases)):
        sname = f"S_{i}"
        statutes = extract_statute_references(sentences_clean[i])
        cases = extract_case_citations(sentences_clean[i])

        all_items = ents + phrs + statutes + cases
        for item in all_items:
            key = item.lower()
            if key not in phrase_to_idx:
                idx = len(phrase_to_idx)
                phrase_to_idx[key] = idx
                G.add_node(f"P_{idx}", type="phrase", text=key)
            pname = f"P_{phrase_to_idx[key]}"
            G.add_edge(pname, sname, type="phrase_appears_in", weight=1.0)
            G.add_edge(sname, pname, type="sentence_mentions_phrase", weight=1.0)

    # Semantic similarity edges
    if include_semantic_edges and sent_embeddings is not None:
        sims = cosine_similarity(sent_embeddings)
        for i in range(num_sents):
            idxs = np.argsort(-sims[i])
            added = 0
            for j in idxs:
                if i == j or sims[i, j] < sim_threshold:
                    continue
                G.add_edge(f"S_{i}", f"S_{j}", type="semantic", weight=float(sims[i, j]))
                added += 1
                if added >= max_sim_edges:
                    break

    # Virtual summary connections
    for i in range(num_sents):
        sname = f"S_{i}"
        G.add_edge("V_SUMMARY", sname, type="virtual_to_sentence", weight=0.1)
        G.add_edge(sname, "V_SUMMARY", type="sentence_to_virtual", weight=0.1)

    meta = {
        "num_sentences": num_sents,
        "sentences": sentences_clean,
        "phrase_to_idx": phrase_to_idx,
        "phrases": list(phrase_to_idx.keys()),
    }
    return G, meta


# -------------------------------
# 4. Dataset Wrapper
# -------------------------------
def build_graphs_for_dataset(
    dataset: List[Dict[str, Any]],
    nlp=None,
    emb_model: Optional[EmbeddingModel] = None,
    sim_threshold: float = 0.7,
    max_sim_edges: int = 3,
) -> List[Tuple[nx.DiGraph, Dict[str, Any]]]:
    """Build graphs for all documents in dataset."""
    graphs = []
    for example in dataset:
        sents = example.get("judgement_sent") or example.get("judgement_sentences") or example.get("sentences", [])
        G, meta = construct_document_graph(
            sents, nlp=nlp, emb_model=emb_model,
            sim_threshold=sim_threshold, max_sim_edges=max_sim_edges
        )
        graphs.append((G, meta))
    return graphs


In [3]:
# ------------------ Graph Transformer Layer ------------------
class GraphTransformerLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.fc_q = nn.Linear(input_dim, hidden_dim)
        self.fc_k = nn.Linear(input_dim, hidden_dim)
        self.fc_v = nn.Linear(input_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        print(f"[GraphTransformerLayer] Initialized with input_dim={input_dim}, hidden_dim={hidden_dim}")

    def forward(self, node_features, adj_matrix):
        Q = self.fc_q(node_features)
        K = self.fc_k(node_features)
        V = self.fc_v(node_features)

        scores = torch.matmul(Q, K.transpose(0, 1)) / (K.size(-1) ** 0.5)
        scores = scores.masked_fill(adj_matrix == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, V)
        out = self.fc_out(out)
        return out
        # return F.layer_norm(out + node_features, out.shape[-1:])

# ------------------ Graph Encoder ------------------
class GraphEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([
            GraphTransformerLayer(input_dim if i == 0 else hidden_dim, hidden_dim)
            for i in range(num_layers)
        ])

    def forward(self, node_features, adj_matrix):
        x = node_features
        for layer in self.layers:
            x = layer(x, adj_matrix)
        return x

# ------------------ Graph to Sequence Attention ------------------
class GraphToSeqAttention(nn.Module):
    def __init__(self, hidden_dim, decoder_dim):
        super().__init__()
        self.key = nn.Linear(hidden_dim, decoder_dim)
        self.value = nn.Linear(hidden_dim, decoder_dim)
        self.query = nn.Linear(decoder_dim, decoder_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, graph_node_features, decoder_hidden):
        K = self.key(graph_node_features)       # [num_nodes, decoder_dim]
        V = self.value(graph_node_features)     # [num_nodes, decoder_dim]
        Q = self.query(decoder_hidden)          # [batch, seq_len, decoder_dim]

        attn_scores = torch.matmul(Q, K.transpose(0, 1)) / (K.size(-1) ** 0.5)
        attn_weights = self.softmax(attn_scores)
        context = torch.matmul(attn_weights, V)
        return context, attn_weights

# ------------------ Graph Enhanced Pegasus ------------------
#NEW CODE
import transformers
class GraphEnhancedPegasus(nn.Module):
    def __init__(self, pegasus_model_name="google/pegasus-large", graph_hidden_dim=256, num_graph_layers=2, seed=42):
        super().__init__()
        self.tokenizer = PegasusTokenizer.from_pretrained(pegasus_model_name)
        self.pegasus = PegasusForConditionalGeneration.from_pretrained(pegasus_model_name)
        self.graph_encoder = GraphEncoder(input_dim=768, hidden_dim=graph_hidden_dim, num_layers=num_graph_layers)
        self.graph_to_seq = GraphToSeqAttention(hidden_dim=graph_hidden_dim, decoder_dim=self.pegasus.config.d_model)
        
        # Project graph context to Pegasus hidden dimension
        self.proj = nn.Linear(graph_hidden_dim, self.pegasus.config.d_model)

        # deterministic generation
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        self.pegasus.config.do_sample = False

    def forward(self,input_texts=None,graph_node_features=None,adj_matrix=None,labels=None,input_ids=None,attention_mask=None,max_length=60):
        # Tokenize if raw text provided
        if input_ids is None:
            if input_texts is None:
                raise ValueError("Provide either input_texts or input_ids")
            inputs = self.tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs.input_ids.to(graph_node_features.device)
            attention_mask = inputs.attention_mask.to(graph_node_features.device)
        else:
            input_ids = input_ids.to(graph_node_features.device)
            attention_mask = attention_mask.to(graph_node_features.device)
        
        # Pegasus encoder
        encoder_outputs = self.pegasus.model.encoder(input_ids, attention_mask=attention_mask)
        encoder_hidden_states = encoder_outputs.last_hidden_state  # [batch, seq_len, d_model]
        
        # Graph encoding
        graph_hidden = self.graph_encoder(graph_node_features, adj_matrix)
        
        # Attend from decoder to graph
        context, attn_weights = self.graph_to_seq(graph_hidden, encoder_hidden_states)
        
        # Add graph context
        enhanced_encoder_states = encoder_hidden_states + context
        
        encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=enhanced_encoder_states
        )
        seq_len = input_ids.size(1)
        summary_ratio = 0.5  # generate summary ~30% of input tokens
        max_length = max(10, int(seq_len * summary_ratio))
        min_length = max(5, int(max_length * 0.5))
        
        if labels is None:
            generated_ids = self.pegasus.generate(
                input_ids=input_ids,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                min_length=min_length,
                max_length=max_length,
                do_sample=False,
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=1.2,
                early_stopping=True
            )
            return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        else:
            decoder_input_ids = self.pegasus.prepare_decoder_input_ids_from_labels(labels)
            outputs = self.pegasus(
                input_ids=None,
                attention_mask=attention_mask,
                encoder_outputs=encoder_outputs,
                decoder_input_ids=decoder_input_ids,
                labels=labels
            )
            return outputs

    def generate(self, input_ids=None, attention_mask=None, graph_node_features=None, adj_matrix=None, **kwargs):
        encoder_outputs = self.pegasus.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        encoder_hidden_states = encoder_outputs.last_hidden_state
    
        # Use same graph-to-seq attention as in forward()
        graph_hidden = self.graph_encoder(graph_node_features, adj_matrix)
        context, _ = self.graph_to_seq(graph_hidden, encoder_hidden_states)
        enhanced_encoder_states = encoder_hidden_states + context
    
        encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=enhanced_encoder_states
        )
    
        safe_kwargs = {
            k: v for k, v in kwargs.items()
            if k not in ["encoder_outputs", "inputs_embeds"]
        }
    
        # Step 5: Call Pegasus generate safely
        return self.pegasus.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            **safe_kwargs  # user-specified params like max_length, num_beams, etc.
        )

In [None]:
!pip install rouge_score evaluate bert_score

In [4]:
import evaluate 
with open("data12.json", "r") as f:
    dataset = json.load(f)


In [5]:
class GraphLegalDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_len=1024, max_target_len=256, emb_dim=768):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len
        self.emb_dim = emb_dim

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        judgement = item["judgement_sent"]
        summary = " ".join(item["headnote_sent"])

        # -------------------------------
        # 1. Tokenize input
        # -------------------------------
        enc = self.tokenizer(
            " ".join(judgement),
            truncation=True,
            padding="max_length",
            max_length=self.max_input_len,
            return_tensors="pt"
        )

        dec = self.tokenizer(
            summary,
            truncation=True,
            padding="max_length",
            max_length=self.max_target_len,
            return_tensors="pt"
        )

        # -------------------------------
        # 2. Build document graph
        # -------------------------------
        # Use your build_graphs_for_dataset
        graphs = build_graphs_for_dataset([item])
        G, meta = graphs[0]

        num_nodes = G.number_of_nodes()
        node_to_idx = {n: i for i, n in enumerate(G.nodes())}

        # adjacency matrix
        adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for u, v in G.edges():
            adj[node_to_idx[u], node_to_idx[v]] = 1.0

        adj_matrix = torch.tensor(adj)

        # node features (placeholder: random embeddings, replace with actual if available)
        node_features = torch.randn(num_nodes, self.emb_dim)

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": dec["input_ids"].squeeze(),
            "node_features": node_features,
            "adj_matrix": adj_matrix,
            "original_summary_text": summary
        }


In [6]:
tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-large")

# Split dataset
train_dataset = GraphLegalDataset(dataset[:8000], tokenizer)
test_dataset  = GraphLegalDataset(dataset[8000:8500], tokenizer)

model = GraphEnhancedPegasus(
    pegasus_model_name="google/pegasus-large",
    graph_hidden_dim=256,
    num_graph_layers=2
)

Error during conversion: ChunkedEncodingError(ProtocolError('Response ended prematurely'))
Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-large and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[GraphTransformerLayer] Initialized with input_dim=768, hidden_dim=256
[GraphTransformerLayer] Initialized with input_dim=256, hidden_dim=256


In [None]:
#DONE ON FIRST 8000 DATA FOR TRAINING
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
optimizer = AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for epoch in range(4):
    print(f"\n=== Epoch {epoch+1} ===")
    total_loss = 0
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        node_features = batch["node_features"][0].to(device)
        adj_matrix = batch["adj_matrix"][0].to(device)

        with autocast():  # Mixed precision
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                graph_node_features=node_features,
                adj_matrix=adj_matrix,
                labels=labels
            )
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        
        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")
            torch.cuda.empty_cache()

    print(f"Epoch {epoch+1} Average Loss: {total_loss / len(train_loader):.4f}")


In [None]:
save_dir = "./result1_graph_pegasus_model"
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), f"{save_dir}/graph_pegasus_state_dict.pt")

model.pegasus.save_pretrained(save_dir)
model.tokenizer.save_pretrained(save_dir)

print(f" Model and tokenizer saved at: {save_dir}")

In [None]:
save_dir = "./result1_graph_pegasus_model"  # folder where you saved model
device = "cuda" if torch.cuda.is_available() else "cpu"


# Initialize model with saved Pegasus directory
model = GraphEnhancedPegasus(pegasus_model_name=save_dir)
model.load_state_dict(torch.load(os.path.join(save_dir, "graph_pegasus_state_dict.pt"), map_location=device))
model.to(device)
model.eval()
print(" Trained Graph-Enhanced Pegasus loaded.")

tokenizer = PegasusTokenizer.from_pretrained("result1_graph_pegasus_model")
test_dataset  = GraphLegalDataset(dataset[8000:8500], tokenizer)
# Create test DataLoader
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model.eval()
predictions, references = [], []
total_loss, total_tokens = 0, 0  # for perplexity

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=model.tokenizer.pad_token_id)

with torch.no_grad():
    for step, batch in enumerate(test_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        node_features = batch["node_features"][0].to(device)
        adj_matrix = batch["adj_matrix"][0].to(device)

        seq_len = input_ids.size(1)
        summary_ratio = 0.5  # generate summary of ~50% of input tokens
        max_length = max(10, int(seq_len * summary_ratio))
        min_length = max(5, int(max_length * 0.5))

        # ---- Generation ----
        generated_tokens = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            graph_node_features=node_features,
            adj_matrix=adj_matrix,
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            no_repeat_ngram_size=3,
            repetition_penalty=2.0,
            length_penalty=1.2,
            early_stopping=True
        )

        # ---- Decoding ----
        preds = model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        refs = model.tokenizer.batch_decode(labels, skip_special_tokens=True)

        predictions.extend(preds)
        references.extend(refs)

        # Optional progress display
        if step % 500 == 0:
            print(f"Test step {step}")
            print("Pred:", preds[0][:200])
            print("Ref :", refs[0][:200])
            print("-" * 80)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            graph_node_features=node_features,
            adj_matrix=adj_matrix
        )

        loss = outputs.loss
        total_loss += loss.item() * labels.numel()
        total_tokens += (labels != model.tokenizer.pad_token_id).sum().item()

# -------------------------------
#  Evaluation Metrics
# -------------------------------
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")

# ROUGE
rouge_result = rouge.compute(predictions=predictions, references=references)

# BLEU
bleu_result = bleu.compute(
    predictions=[p for p in predictions],
    references=[r.split() for r in references]
)

# METEOR
meteor_result = meteor.compute(predictions=predictions, references=references)

# BERTScore
bertscore_result = bertscore.compute(
    predictions=predictions,
    references=references,
    lang="en"
)
# -------------------------------
#  Perplexity Calculation
# -------------------------------
avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss) if avg_loss < 50 else float("inf")

# -------------------------------
#  Display Results
# -------------------------------
print("\n=== Evaluation Results ===")
print("ROUGE:", rouge_result)
print("BLEU :", bleu_result["bleu"])
print("METEOR:", meteor_result["meteor"])
print(f"BERTScore: P={sum(bertscore_result['precision'])/len(bertscore_result['precision']):.4f}, "
      f"R={sum(bertscore_result['recall'])/len(bertscore_result['recall']):.4f}, "
      f"F1={sum(bertscore_result['f1'])/len(bertscore_result['f1']):.4f}")
print(f" Perplexity: {perplexity:.4f}")

In [None]:
# === Evaluation Results ===
# ROUGE: {'rouge1': np.float64(0.5832954886912554), 'rouge2': np.float64(0.3478273416488371), 'rougeL': np.float64(0.3866087180337574), 'rougeLsum': np.float64(0.38616374920897983)}
# BLEU : 0.01892779342079829
# METEOR: 0.4163490888882138
# BERTScore: P=0.8725, R=0.8773, F1=0.8748
# 🧠 Perplexity: 3.6866