# Graph-based Extractive Summarizer with GAT (PyTorch Geometric)


In [5]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
!pip install transformers torch accelerate datasets rouge_score  --quiet

In [None]:
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
!pip install -q torch-geometric sentence-transformers

### 🔹 Step 1: Input & Sentence Embeddings


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sentence_transformers import SentenceTransformer
import numpy as np

2025-06-08 12:14:32.615636: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749384872.794434      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749384872.845665      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Sample multi-document input
documents = [
    "The stock market crashed due to inflation concerns.",
    "Experts blame inflation for the recent market downturn.",
    "New policies are expected to ease the impact of inflation.",
    "Some investors are optimistic about the long-term growth."
]

# Sentence Embeddings
embedder = SentenceTransformer("all-MiniLM-L6-v2")
sentence_embeddings = embedder.encode(documents, convert_to_tensor=True)

In [15]:
sentence_embeddings.shape

torch.Size([4, 384])

### 🔹 Step 2: Graph Construction


In [25]:
from torch_geometric.utils import dense_to_sparse

def build_graph(x):
    N = x.size(0)
    # Create a fully connected graph without self-loops
    adj = torch.ones((N, N)) - torch.eye(N)
    edge_index, _ = dense_to_sparse(adj)
    return edge_index.long().to(x.device)

edge_index = build_graph(sentence_embeddings)

### 🔹 Step 3: GAT-based Graph Model

In [None]:
class GraphSummarizer(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, heads=4, concat=True)
        self.gat2 = GATConv(hidden_dim * 4, out_dim, heads=1)
        self.output = nn.Linear(out_dim, 1)

    def forward(self, x, edge_index):
        x = F.relu(self.gat1(x, edge_index))
        x = self.gat2(x, edge_index)
        return torch.sigmoid(self.output(x)).squeeze()

In [None]:
# Dataset Preparation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move embeddings to device
x = sentence_embeddings.to(device)
y = torch.tensor([1, 1, 0, 0], dtype=torch.float32).to(device)

# Graph edges
edge_index = build_graph(x).to(device)

# Final graph data
data = Data(x=x, edge_index=edge_index, y=y).to(device)

In [None]:
model = GraphSummarizer(in_dim=x.size(1), hidden_dim=64, out_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
loss_fn = nn.BCELoss()

In [35]:
print(edge_index.shape)
print(edge_index)

torch.Size([2, 12])
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]], device='cuda:0')


### 🔹 Step 4: Train Model

In [36]:
model.train()
for epoch in range(50):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = loss_fn(out, data.y)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

Epoch 0: Loss = 0.6942
Epoch 10: Loss = 0.6932
Epoch 20: Loss = 0.6932
Epoch 30: Loss = 0.6932
Epoch 40: Loss = 0.6931


### 🔹 Inference: Select Top Sentences

In [None]:
model.eval()
with torch.no_grad():
    scores = model(data.x, data.edge_index)
    topk_idx = scores.topk(2).indices
    print("\n📝 Summary:")
    for i in topk_idx:
        print("-", documents[i])


📝 Summary:
- Experts blame inflation for the recent market downturn.
- The stock market crashed due to inflation concerns.
