<a href="https://colab.research.google.com/github/Hitika-Jain/LegalTalk/blob/main/notebooks/graph_based_legalbert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install once (uncomment if needed)
#!pip install -U sentence-transformers faiss-cpu networkx pyvis tqdm

import os
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
import networkx as nx
import pickle

# Optional FAISS (faster NN for larger corpora)
try:
    import faiss
    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False

# ----------------------
# Config / paths
# ----------------------
DATA_CSV = "/content/drive/MyDrive/legal_dataset/statutes.csv"
EMBEDDINGS_NPY = "/content/ipc_embeddings.npy"
SECTIONS_PKL = "/content/ipc_sections.pkl"
GRAPH_PKL = "/content/ipc_graph.pkl"
PYVIS_HTML = "/content/ipc_graph.html"

MODEL_NAME = "law-ai/InLegalBERT"   # your legal model; replace if unavailable
EMBED_BATCH = 32
SIM_THRESHOLD = 0.75    # initial threshold; you can tune this
USE_FAISS = _HAS_FAISS  # set False to avoid faiss usage

# ----------------------
# Load CSV and prepare texts
# ----------------------
df = pd.read_csv(DATA_CSV)
# adjust column names if different
df['Section'] = df.get('id', df.columns[0]).astype(str).str.strip().str.upper()
# combine fields for text (adjust fields to your CSV)
df['text'] = df.get('Description', '').fillna('') + ' ' + df.get('Offense', '').fillna('') + ' ' + df.get('Punishment', '').fillna('')

sections = df['Section'].tolist()
texts = df['text'].tolist()

print(f"Loaded {len(sections)} sections")

# ----------------------
# Build or load embeddings
# ----------------------
if os.path.exists(EMBEDDINGS_NPY) and os.path.exists(SECTIONS_PKL):
    print("Loading cached embeddings and sections...")
    embeddings = np.load(EMBEDDINGS_NPY)
    with open(SECTIONS_PKL, "rb") as f:
        sections = pickle.load(f)
    # ensure lengths match
    assert embeddings.shape[0] == len(sections)
else:
    print("Computing embeddings (this may take a while)...")
    model = SentenceTransformer(MODEL_NAME)
    embeddings = model.encode(texts, batch_size=EMBED_BATCH, show_progress_bar=True, convert_to_numpy=True)
    # normalize embeddings for cosine via dot product
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms==0] = 1.0
    embeddings = embeddings / norms
    # save
    np.save(EMBEDDINGS_NPY, embeddings)
    with open(SECTIONS_PKL, "wb") as f:
        pickle.dump(sections, f)
    print("Saved embeddings and section list.")

# ----------------------
# Build neighbor index (FAISS) or use exact cosine
# ----------------------
if USE_FAISS:
    d = embeddings.shape[1]
    # FAISS expects float32
    xb = embeddings.astype('float32')
    index = faiss.IndexFlatIP(d)  # inner product on unit-normalized vectors = cosine similarity
    index.add(xb)
    print("FAISS index built with", index.ntotal, "vectors")
else:
    index = None
    # precompute cosine similarity matrix if small
    print("FAISS not available; will use exact cosine similarity (O(n^2) memory/time)")

# ----------------------
# Build Graph (weighted edges) - with pruning
# ----------------------
def build_graph(embeddings, sections, threshold=SIM_THRESHOLD, topk_per_node=None):
    n = embeddings.shape[0]
    G = nx.Graph()
    for i, sec in enumerate(sections):
        row = df.loc[df['Section']==sec].iloc[0] if 'df' in globals() else {}
        G.add_node(sec,
                   description=row.get('Description', '') if isinstance(row, dict) else df.loc[df['Section']==sec, 'Description'].values[0] if 'Description' in df.columns else '',
                   offense=row.get('Offense', '') if isinstance(row, dict) else (df.loc[df['Section']==sec, 'Offense'].values[0] if 'Offense' in df.columns else ''),
                   punishment=row.get('Punishment', '') if isinstance(row, dict) else (df.loc[df['Section']==sec, 'Punishment'].values[0] if 'Punishment' in df.columns else ''))
    # Efficient creation: for each node, get topk neighbors and add edges
    if USE_FAISS:
        xb = embeddings.astype('float32')
        # query all at once: k = topk_per_node+1 (including self)
        k = topk_per_node+1 if topk_per_node is not None else min(50, n)
        D, I = index.search(xb, k)  # inner product scores (cosine)
        for i in range(n):
            for score, j in zip(D[i,1:], I[i,1:]):  # skip self at pos 0
                if score >= threshold:
                    a, b = sections[i], sections[j]
                    if not G.has_edge(a,b):
                        G.add_edge(a, b, weight=float(score))
    else:
        # exact cosine matrix (symmetric): compute only upper triangle
        sim = embeddings.dot(embeddings.T)
        # optionally prune by topk_per_node to reduce edges
        if topk_per_node is None:
            for i in range(n):
                for j in range(i+1, n):
                    s = float(sim[i,j])
                    if s >= threshold:
                        G.add_edge(sections[i], sections[j], weight=s)
        else:
            for i in range(n):
                row = sim[i]
                # ignore self
                idxs = np.argsort(row)[::-1]
                taken = 0
                for j in idxs:
                    if j == i:
                        continue
                    if taken >= topk_per_node:
                        break
                    s = float(row[j])
                    if s >= threshold:
                        a,b = sections[i], sections[j]
                        if not G.has_edge(a,b):
                            G.add_edge(a,b, weight=s)
                            taken += 1
    return G

print("Building graph with threshold", SIM_THRESHOLD)
G = build_graph(embeddings, sections, threshold=SIM_THRESHOLD, topk_per_node=10)
print("Graph built. Nodes:", G.number_of_nodes(), "Edges:", G.number_of_edges())

# save graph
with open(GRAPH_PKL, "wb") as f:
    pickle.dump(G, f)
print("Saved graph to", GRAPH_PKL)

# ----------------------
# Query utilities
# ----------------------
def get_topk_similar_by_faiss(query_text, model, k=10):
    q_emb = model.encode([query_text], convert_to_numpy=True)
    q_emb = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
    if USE_FAISS:
        D, I = index.search(q_emb.astype('float32'), k)
        results = [(sections[i], float(D[0,idx])) for idx,i in enumerate(I[0])]
    else:
        sims = q_emb.dot(embeddings.T)[0]
        idxs = np.argsort(sims)[::-1][:k]
        results = [(sections[i], float(sims[i])) for i in idxs]
    return results

def find_related_sections(section_id, k=5):
    section_id = str(section_id).strip().upper()
    if section_id not in G:
        print(f"{section_id} not found in graph.")
        return []
    neighbors = [(nbr, G[section_id][nbr]['weight']) for nbr in G.neighbors(section_id)]
    neighbors = sorted(neighbors, key=lambda x: x[1], reverse=True)
    return neighbors[:k]

# Example usage (if you want to run interactively)
# load model for on-the-fly queries (if not already loaded)
try:
    query_model
except NameError:
    query_model = SentenceTransformer(MODEL_NAME)

print("Top similar to first section via FAISS/Exact:")
print(get_topk_similar_by_faiss(texts[0], query_model, k=5))

print("Top graph neighbors for first section:")
print(find_related_sections(sections[0], k=5))

# ----------------------
# Optional: pyvis visualization (saves html)
# ----------------------
try:
    from pyvis.network import Network
    net = Network(height="900px", width="100%", notebook=False)
    # add a subset of nodes to avoid huge HTML — take top connected components
    comps = list(nx.connected_components(G))
    # take the largest component
    largest = max(comps, key=len)
    subnodes = list(largest)[:200]  # limit to 200 nodes
    subG = G.subgraph(subnodes)
    for n in subG.nodes(data=True):
        net.add_node(n[0], title=n[1].get('description',''), label=n[0])
    for u,v,data in subG.edges(data=True):
        net.add_edge(u, v, value=data.get('weight', 1.0))
    net.force_atlas_2based()  # layout
    net.show(PYVIS_HTML)
    print("Saved interactive graph to", PYVIS_HTML)
except Exception as e:
    print("pyvis visualization skipped - not available or failed:", str(e))

# Done — graph and embeddings saved for reuse


Loaded 455 sections
Computing embeddings (this may take a while)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/671 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/534M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/516 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/534M [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Batches:   0%|          | 0/15 [00:00<?, ?it/s]

Saved embeddings and section list.
FAISS index built with 455 vectors
Building graph with threshold 0.75




Graph built. Nodes: 454 Edges: 3522
Saved graph to /content/ipc_graph.pkl
Top similar to first section via FAISS/Exact:
[('IPC_2', 0.9999999403953552), ('IPC_3', 0.936187207698822), ('IPC_376DB', 0.9171409606933594), ('IPC_120B', 0.9170635342597961), ('IPC_195', 0.9170509576797485)]
Top graph neighbors for first section:
[('IPC_3', 0.9361872673034668), ('IPC_376DB', 0.9171410202980042), ('IPC_120B', 0.9170634746551514), ('IPC_195', 0.9170506596565247), ('IPC_511', 0.9162786602973938)]
/content/ipc_graph.html
pyvis visualization skipped - not available or failed: 'NoneType' object has no attribute 'render'


sanity checks

In [3]:
# run these to confirm everything exists and sizes look right
print("embeddings:", type(embeddings), getattr(embeddings, "shape", None))
print("num sections:", len(sections))
print("graph nodes:", G.number_of_nodes(), "edges:", G.number_of_edges())
print("FAISS index present:", 'index' in globals() and index is not None)
# sample some node metadata
sample = list(G.nodes)[:5]
for s in sample:
    print(s, G.nodes[s].get('description','')[:120])

embeddings: <class 'numpy.ndarray'> (455, 768)
num sections: 455
graph nodes: 454 edges: 3522
FAISS index present: True
IPC_2 Description of IPC Section 2, Every person shall be liable to punishment under this Code and not otherwise for every act
IPC_3 Description of IPC Section 3, 
 'Any person liable, by any Indian law to be tried for an offence committed beyond India 
IPC_4 Description of IPC Section 4, 
 'The provisions of this Code apply also to any offence committed by'
 '(1) any citizen o
IPC_5 Description of IPC Section 5, Nothing in this Act shall affect the provisions of any Act for punishing mutiny and desert
IPC_13 Description of IPC Section 13, Rep. by the A.O. 1950.


Test nearest neighbors (text query) — FAISS (if available) or fallback

In [4]:
# text query -> topk similar sections
def topk_text_query(qtext, k=5):
    q_emb = model.encode([qtext], convert_to_numpy=True)
    q_emb = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
    if 'index' in globals() and index is not None:
        D, I = index.search(q_emb.astype('float32'), k)
        return [(sections[i], float(D[0,idx])) for idx,i in enumerate(I[0])]
    else:
        sims = q_emb.dot(embeddings.T)[0]
        ids = np.argsort(sims)[::-1][:k]
        return [(sections[i], float(sims[i])) for i in ids]

# try:
print(topk_text_query("dishonest inducement, cheating, fraud", k=8))


[('IPC_415', 0.6561954617500305), ('IPC_391', 0.6285791397094727), ('IPC_420', 0.6223862171173096), ('IPC_463', 0.6194421052932739), ('IPC_416', 0.617461085319519), ('IPC_464', 0.6165838241577148), ('IPC_424', 0.6155958771705627), ('IPC_109', 0.6151362061500549)]


test graph neighbour lookup

In [5]:
def show_graph_neighbors(section_id, k=8):
    section_id = str(section_id).strip().upper()
    if section_id not in G:
        return f"{section_id} not in graph"
    neigh = sorted([(n, G[section_id][n]['weight']) for n in G.neighbors(section_id)],
                   key=lambda x: x[1], reverse=True)[:k]
    return neigh

print(show_graph_neighbors('IPC_420', k=8))


[('IPC_423', 0.9680238366127014), ('IPC_330', 0.9664953947067261), ('IPC_477', 0.965897798538208), ('IPC_331', 0.964945912361145), ('IPC_348', 0.964379072189331), ('IPC_329', 0.964013397693634), ('IPC_239', 0.9638614058494568), ('IPC_327', 0.9633259773254395)]


visual sanity: degree distribution + components

In [6]:
import numpy as np
deg = np.array([d for _, d in G.degree()])
print("degree: min,median,mean,max:", deg.min(), np.median(deg), deg.mean(), deg.max())
# largest connected component size
cc = max(nx.connected_components(G), key=len)
print("largest component size:", len(cc))

degree: min,median,mean,max: 10 13.0 15.515418502202643 66
largest component size: 454


save artifacts

In [7]:
import pickle, numpy as np, os
os.makedirs('/content/graph_artifacts', exist_ok=True)
np.save('/content/graph_artifacts/ipc_embeddings.npy', embeddings)
with open('/content/graph_artifacts/ipc_sections.pkl', 'wb') as f:
    pickle.dump(sections, f)
with open('/content/graph_artifacts/ipc_graph.pkl', 'wb') as f:
    pickle.dump(G, f)
print("Saved embeddings, sections and graph to /content/graph_artifacts")

Saved embeddings, sections and graph to /content/graph_artifacts
