In [2]:
# ====================================
# 🧩 Graph-RAG setup for JudgEx
# ====================================
import pandas as pd
import numpy as np
import torch, gc
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

gc.collect()
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("🖥️ Using:", device)


  from .autonotebook import tqdm as notebook_tqdm


🖥️ Using: cpu


In [3]:
# Replace the path below if graph.csv is elsewhere
df = pd.read_csv("graph.csv")
print("✅ Loaded graph.csv with shape:", df.shape)
display(df.head(3))


✅ Loaded graph.csv with shape: (12747, 5)


Unnamed: 0,filename,label,statutes,charges,facts
0,Uttarakhand_HC_2015_1701,1,"['Wild Life (Protection) Act, 1972']",['Illegal possession of animal bones'],"['The applicant was found sitting in a car.', ..."
1,Bombay_HC_BomHC_2017_1744,1,"['Value Added Tax Act 2002 Sec 45-A', 'Value A...",['Non-payment of VAT on branded tobacco'],['Petitioners sell tobacco under brand names.'...
2,Bombay_HC_BomHC_1987_396,1,"['IPC Sec 494', 'IPC Sec 109', 'IPC Sec 34', '...","['Bigamy', 'Abetment of Bigamy', 'Attempt to C...",['Complainant married Nivrutti about 7-8 years...


In [4]:
# Use a compact legal / general encoder for RTX 3050
model = SentenceTransformer("all-mpnet-base-v2", device=str(device))

fact_texts    = df["facts"].astype(str).tolist()
statute_texts = df["statutes"].astype(str).tolist()
charge_texts  = df["charges"].astype(str).tolist()

print("Encoding node texts ...")
fact_embs    = model.encode(fact_texts, batch_size=8, convert_to_tensor=True, show_progress_bar=True)
statute_embs = model.encode(statute_texts, batch_size=8, convert_to_tensor=True, show_progress_bar=True)
charge_embs  = model.encode(charge_texts, batch_size=8, convert_to_tensor=True, show_progress_bar=True)


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Encoding node texts ...


Batches: 100%|██████████| 1594/1594 [2:02:29<00:00,  4.61s/it]    
Batches: 100%|██████████| 1594/1594 [34:08<00:00,  1.29s/it] 
Batches: 100%|██████████| 1594/1594 [15:13<00:00,  1.75it/s]


In [5]:
G = nx.Graph()

for i, row in df.iterrows():
    f = f"fact_{i}"
    s = f"statute_{row['statutes']}"
    c = f"charge_{row['charges']}"

    G.add_node(f, type="fact", text=row["facts"])
    G.add_node(s, type="statute", text=str(row["statutes"]))
    G.add_node(c, type="charge", text=str(row["charges"]))

    G.add_edge(f, s)
    G.add_edge(s, c)
    G.add_edge(f, c)

print(f"✅ Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


✅ Graph built: 32459 nodes, 37491 edges


In [6]:
query = input("Enter case/fact description for retrieval:\n> ")

query_emb = model.encode(query, convert_to_tensor=True)
cos_scores = util.cos_sim(query_emb, fact_embs)[0]
top_k = torch.topk(cos_scores, k=5)

print("\n🔎 Top-5 Similar Cases:\n")
for idx, score in zip(top_k.indices, top_k.values):
    row = df.iloc[idx]
    print(f"→ Case: {row['filename']} | Similarity Score = {score:.4f}")
    print(f"Facts: {row['facts'][:200]}...\n")



🔎 Top-5 Similar Cases:



TypeError: Cannot index by location index with a non-integer key

In [None]:
deg_cent = nx.degree_centrality(G)
between_cent = nx.betweenness_centrality(G)
density = nx.density(G)

print(f"Graph Density = {density:.4f}\nTop Degree Central Nodes:")
for n, v in sorted(deg_cent.items(), key=lambda x:x[1], reverse=True)[:5]:
    print(f"{n:25} → {v:.4f}")


In [None]:
plt.figure(figsize=(10,8))
pos = nx.spring_layout(G, seed=42, k=0.5)
color_map = ['#4c72b0' if G.nodes[n]['type']=='fact'
             else '#55a868' if G.nodes[n]['type']=='statute'
             else '#c44e52' for n in G.nodes]

nx.draw(G, pos,
        node_color=color_map,
        node_size=60,
        alpha=0.8,
        with_labels=False)
plt.title("Legal Graph — Facts ↔ Statutes ↔ Charges", fontsize=13)
plt.show()


In [None]:
scores = cos_scores.cpu().numpy()
sns.histplot(scores, bins=40, kde=True, color='teal')
plt.title("Distribution of Cosine Similarity (Query vs Facts)")
plt.xlabel("Cosine Similarity")
plt.ylabel("Frequency")
plt.show()


In [None]:
mean_sim = float(np.mean(scores))
std_sim  = float(np.std(scores))
top5_mean = float(top_k.values.mean())

print(f"📊 Mean Similarity: {mean_sim:.4f}")
print(f"📈 Std Deviation: {std_sim:.4f}")
print(f"🏆 Top-5 Mean Score: {top5_mean:.4f}")


In [None]:
torch.save(fact_embs, "fact_embeddings.pt")
nx.write_gpickle(G, "legal_graph.gpickle")
print("✅ Saved graph and embeddings.")
