In [1]:
!pip install torch_geometric
!pip install --upgrade pandas

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
Collecting pandas
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [2]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from torch_geometric.data import HeteroData
from transformers import AutoTokenizer, AutoModel
import spacy


In [4]:
# === 1) Load models ===
nlp = spacy.load("en_core_web_sm")
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
model     = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")
model.eval()

def embed_texts(texts, max_length=512):
    """
    Mean‑pooled Legal‑BERT embeddings for a list of texts.
    Returns tensor [len(texts), hidden_dim].
    """
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**encoded)
    last_hidden = outputs.last_hidden_state          # [batch, seq_len, hidden_dim]
    mask        = encoded.attention_mask.unsqueeze(-1)  # [batch, seq_len, 1]
    summed      = (last_hidden * mask).sum(dim=1)     # [batch, hidden_dim]
    lengths     = mask.sum(dim=1).clamp(min=1e-9)     # [batch, 1]
    return summed / lengths                          # [batch, hidden_dim]

def build_hetero_graph_from_row(doc_id: str, text: str, summary: str) -> HeteroData:
    """
    Build a heterogeneous graph per document:
      • Node types: 'document', 'sentence', 'entity'
      • Edge types:
          ('document','has_sentence','sentence'),
          ('sentence','to_sentence','sentence'),
          ('sentence','mentions','entity'),
          ('entity','rev_mentions','sentence')
    Stores summary on the document node as data['document'].y.
    """
    # Sentence split
    doc = nlp(text)
    sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]

    # Unique entities
    unique_entities = []
    sent_docs = []
    for sent in sentences:
        sdoc = nlp(sent)
        sent_docs.append(sdoc)
        for ent in sdoc.ents:
            if ent.text not in unique_entities:
                unique_entities.append(ent.text)

    # Node features via Legal-BERT
    sent_feats = embed_texts(sentences)            # [num_sentences, hidden_dim]
    ent_feats  = embed_texts(unique_entities)      # [num_entities, hidden_dim]

    # Initialize HeteroData
    data = HeteroData()
    # Document node
    data["document"].x      = torch.ones((1,1))
    data["document"].doc_id = [doc_id]
    data["document"].y      = [summary]
    data["document"].text = text
    # Sentence nodes
    data["sentence"].x         = sent_feats
    data["sentence"].num_nodes = sent_feats.size(0)
    # Entity nodes
    data["entity"].x         = ent_feats
    data["entity"].num_nodes = ent_feats.size(0)

    # Edges: document -> sentences
    doc_sent = [(0, i) for i in range(len(sentences))]
    data["document", "has_sentence", "sentence"].edge_index = torch.tensor(doc_sent).t().contiguous()

    # Edges: sentence -> next sentence
    seq_edges = [(i, i+1) for i in range(len(sentences)-1)]
    data["sentence", "to_sentence", "sentence"].edge_index = torch.tensor(seq_edges).t().contiguous()

    # Edges: sentence ↔ entity (mentions)
    sent_ent = []
    ent_rev  = []
    for i, sdoc in enumerate(sent_docs):
        for ent in sdoc.ents:
            j = unique_entities.index(ent.text)
            sent_ent.append((i, j))
            ent_rev.append((j, i))
    if sent_ent:
        data["sentence", "mentions", "entity"].edge_index     = torch.tensor(sent_ent).t().contiguous()
        data["entity", "rev_mentions", "sentence"].edge_index = torch.tensor(ent_rev).t().contiguous()

    # Make graph undirected for bidirectional message passing
    return ToUndirected()(data)

def validate_hetero_graph(data: HeteroData):
    """
    Validates that the HeteroData graph is consumable by R-GCN/HeteroConv:
      - Each node type has 'x' of shape [num_nodes, feat_dim].
      - Each edge type has 'edge_index' of shape [2, num_edges].
      - Indices in edge_index are within valid bounds.
      - No duplicate relations.
    """
    # Node checks
    for ntype in data.node_types:
        if "x" not in data[ntype]:
            raise ValueError(f"Node '{ntype}' missing 'x' features.")
        x = data[ntype].x
        if x.ndim != 2:
            raise ValueError(f"Node '{ntype}' features must be 2D, got {x.ndim}D.")
        ncount = getattr(data[ntype], "num_nodes", x.size(0))
        if x.size(0) != ncount:
            raise ValueError(f"Node '{ntype}': num_nodes={ncount} != x.shape[0]={x.size(0)}")
    # Edge checks
    seen = set()
    for (src, rel, dst) in data.edge_types:
        if (src, rel, dst) in seen:
            raise ValueError(f"Duplicate relation {(src, rel, dst)}")
        seen.add((src, rel, dst))
        ei = data[(src, rel, dst)].edge_index
        if ei.ndim != 2 or ei.size(0) != 2:
            raise ValueError(f"Relation {(src,rel,dst)} edge_index must be [2, E], got {tuple(ei.shape)}")
        max_src = int(ei[0].max()) if ei.numel() > 0 else -1
        max_dst = int(ei[1].max()) if ei.numel() > 0 else -1
        if max_src >= data[src].num_nodes:
            raise ValueError(f"Relation {(src,rel,dst)} src index {max_src} out of bounds")
        if max_dst >= data[dst].num_nodes:
            raise ValueError(f"Relation {(src,rel,dst)} dst index {max_dst} out of bounds")
    print("✅ Graph validation passed")




The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [5]:
if __name__ == "__main__":
    # 1. pick cuda if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 2. read in
    df = pd.read_csv(
    "/content/sample_data/train.csv",
    engine="python",               # use the Python parser
    sep=",",                       # your delimiter
    quotechar='"',                 # standard CSV quoting
    escapechar="\\",               # allow backslash escapes
    on_bad_lines="skip",           # skip rows with parse errors
    dtype=str,                     # read everything as string to avoid type coercion issues
)
    df = df.head(1330)

    graphs = []
    for _, row in df.iterrows():
        # build on CPU, then move
        g = build_hetero_graph_from_row(row.doc_id, row.text, row.summary)

        # 3. move all the graph's tensors to GPU
        g = g.to(device)

        validate_hetero_graph(g)
        graphs.append(g)

    print(f"Built and validated {len(graphs)} heterogeneous graphs on {device}")

Using device: cuda
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation passed
✅ Graph validation 

In [6]:
print(len(graphs))

764


In [7]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=acbfb69f12761a160df716586dcb1f562b71c9cf3e2dfbf37e48cfd34f9e1100
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [9]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GATConv
from transformers import BartTokenizer, BartForConditionalGeneration
from rouge_score import rouge_scorer

# === 1) Load CSV and your pre-built 'graphs' list ===
# df = pd.read_csv("documents.csv")        # must have columns: doc_id, text, summary
summaries = [g["document"].y[0] for g in graphs]      # parallel list of gold summaries

# Assume you already have a list `graphs` of HeteroData, in the same order as df,
# where each graph has data["document"].y = [summary_str] set at build time.

# === 2) Split into train / validation ===
train_graphs, val_graphs, train_summaries, val_summaries = train_test_split(
    graphs, summaries, test_size=0.15, random_state=42
)

# === 3) Dataset for training (batch_size=1) ===
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

class GraphTrainDataset(torch.utils.data.Dataset):
    def __init__(self, graphs, summaries, tokenizer, max_len=150):
        self.graphs    = graphs
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_len   = max_len

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        data    = self.graphs[idx]
        summary = self.summaries[idx]
        tok     = self.tokenizer(
            summary,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        labels = tok.input_ids.squeeze(0)                   # [max_len]
        labels[labels == self.tokenizer.pad_token_id] = -100 # mask PAD
        data.dec_input_ids      = labels
        data.dec_attention_mask = tok.attention_mask.squeeze(0)
        return data

train_ds     = GraphTrainDataset(train_graphs, train_summaries, tokenizer)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)

# === 4) Model: 2‑layer GAT + BART projector/generator ===
class Graph2Seq(nn.Module):
    def __init__(self, hidden_dim=768, gat_layers=2):
        super().__init__()
        # 2‑layer heterogeneous GAT
        self.convs = nn.ModuleList([
            HeteroConv({
                ('sentence','to_sentence','sentence'): GATConv(hidden_dim, hidden_dim, add_self_loops=False),
                ('sentence','mentions','entity'):    GATConv(hidden_dim, hidden_dim, add_self_loops=False),
            }, aggr='sum')
            for _ in range(gat_layers)
        ])
        # BART head
        self.bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        self.proj = nn.Linear(hidden_dim, self.bart.config.d_model)

    def forward(self, data: HeteroData):
        # 1) GAT encode
        x_dict, eidx = data.x_dict, data.edge_index_dict
        for conv in self.convs:
            x_dict = conv(x_dict, eidx)
            x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        # 2) Build BART encoder inputs from 'sentence' nodes
        sent_emb   = x_dict['sentence']               # [N_sent, hidden_dim]
        enc_inputs = self.proj(sent_emb).unsqueeze(0) # [1, N_sent, d_model]
        enc_mask   = torch.ones(enc_inputs.shape[:2], device=enc_inputs.device)
        # 3) Compute loss, with PAD masked
        labels   = data.dec_input_ids.unsqueeze(0)    # [1, max_len]
        dec_mask = data.dec_attention_mask.unsqueeze(0)
        out = self.bart(
            inputs_embeds=enc_inputs,
            attention_mask=enc_mask,
            labels=labels,
            decoder_attention_mask=dec_mask,
            use_cache=False
        )
        return out.loss

    def generate(self, graph: HeteroData, max_length=150, num_beams=4):
        # identical encode
        x_dict, eidx = graph.x_dict, graph.edge_index_dict
        for conv in self.convs:
            x_dict = conv(x_dict, eidx)
            x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        sent_emb   = x_dict['sentence']
        enc_inputs = self.proj(sent_emb).unsqueeze(0)
        enc_mask   = torch.ones(enc_inputs.shape[:2], device=enc_inputs.device)
        # graph‑aware generation
        return self.bart.generate(
            inputs_embeds=enc_inputs,
            attention_mask=enc_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=3,
            length_penalty=1.2,
        )

# === 5) Train + Validation ===
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = Graph2Seq().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-5)
scorer    = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeL"], use_stemmer=True)

for epoch in range(1, 11):
    # — Training —
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    print(f"Epoch {epoch} | Train Loss: {train_loss/len(train_loader):.4f}")

    # — Validation (pure Python loop) —
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for graph, gold in zip(val_graphs, val_summaries):
            graph = graph.to(device)
            gen_ids = model.generate(graph, max_length=150, num_beams=6)
            pred    = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
            preds.append(pred)
            refs.append(gold)

    # ROUGE‑1
    rouge1 = sum(scorer.score(r,p)["rouge1"].fmeasure for r,p in zip(refs, preds)) / len(refs)
    print(f"Epoch {epoch} | Val ROUGE‑1: {rouge1:.4f}")


Epoch 1 | Train Loss: 3.6121
Epoch 1 | Val ROUGE‑1: 0.2596
Epoch 2 | Train Loss: 3.0094
Epoch 2 | Val ROUGE‑1: 0.2412
Epoch 3 | Train Loss: 2.7337
Epoch 3 | Val ROUGE‑1: 0.2904
Epoch 4 | Train Loss: 2.4910
Epoch 4 | Val ROUGE‑1: 0.2682
Epoch 5 | Train Loss: 2.2957
Epoch 5 | Val ROUGE‑1: 0.2588
Epoch 6 | Train Loss: 2.1052
Epoch 6 | Val ROUGE‑1: 0.2515
Epoch 7 | Train Loss: 1.9216
Epoch 7 | Val ROUGE‑1: 0.2447
Epoch 8 | Train Loss: 1.7569
Epoch 8 | Val ROUGE‑1: 0.2650
Epoch 9 | Train Loss: 1.6031
Epoch 9 | Val ROUGE‑1: 0.1446
Epoch 10 | Train Loss: 1.4689
Epoch 10 | Val ROUGE‑1: 0.1988


In [None]:
torch.save(model.state_dict(), "graph2seq_state_dict.pth")