In [1]:
import pandas as pd
import numpy as np
import torch
import json
import ast
from tqdm import tqdm
from torch_geometric.data import Data

# === Load everything ===
df = pd.read_csv("graph.csv")
for col in ['facts', 'statutes', 'charges']:
    df[col] = df[col].apply(ast.literal_eval)

embeddings = np.load("node_embeddings_384.npy")
with open("node_index.json") as f:
    node_index = json.load(f)

# === Create a global ID map ===
id_map = {}
counter = 0
for t in ['case', 'fact', 'statute', 'charge']:
    for node in node_index[t]:
        id_map[node] = counter
        counter += 1

edges = []
edge_types = []

for _, row in tqdm(df.iterrows(), total=len(df), desc="Building edges"):
    c = row['filename']
    case_id = id_map[c]

    # Connect case to facts
    for f in row['facts']:
        if f in id_map:
            edges.append([case_id, id_map[f]])
            edge_types.append('has_fact')

    # Connect case to statutes
    for s in row['statutes']:
        if s in id_map:
            edges.append([case_id, id_map[s]])
            edge_types.append('refers_to')

    # Connect case to charges
    for ch in row['charges']:
        if ch in id_map:
            edges.append([case_id, id_map[ch]])
            edge_types.append('charged_under')

# Convert edges
edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)
x = torch.tensor(embeddings, dtype=torch.float)

# === Label Encoding ===
y = torch.tensor(df['label'].values, dtype=torch.long)  # if label column exists per case

# === Build PyG Graph ===
data = Data(x=x, edge_index=edge_index, y=y)
torch.save(data, "global_graph.pt")

print("✅ Saved PyTorch graph as global_graph.pt")


  from .autonotebook import tqdm as notebook_tqdm
Building edges: 100%|██████████| 12747/12747 [00:01<00:00, 10558.62it/s]


✅ Saved PyTorch graph as global_graph.pt


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import warnings

warnings.filterwarnings("ignore")

# ==========================================
# 1️⃣  Safe load of your saved graph
# ==========================================
torch.serialization.add_safe_globals([Data])  # allowlist PyG Data class

data = torch.load("global_graph.pt", weights_only=False)
print(f"✅ Graph loaded: {data}")

# ==========================================
# 2️⃣  Ensure labels & features match
# ==========================================
num_nodes = data.x.size(0)
print(f"Node features: {num_nodes}, Labels: {len(data.y)}")

# pad labels if unlabeled nodes exist
if len(data.y) < num_nodes:
    y_new = torch.full((num_nodes,), -1, dtype=torch.long)
    y_new[:len(data.y)] = data.y
    data.y = y_new

# ==========================================
# 3️⃣  Create train/test masks for labeled nodes
# ==========================================
labeled_mask = data.y >= 0
labeled_indices = torch.nonzero(labeled_mask, as_tuple=False).squeeze()
num_labeled = len(labeled_indices)
train_size = int(0.8 * num_labeled)

perm = torch.randperm(num_labeled)
train_idx = labeled_indices[perm[:train_size]]
test_idx = labeled_indices[perm[train_size:]]

data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[train_idx] = True
data.test_mask[test_idx] = True

print(f"Train nodes: {data.train_mask.sum().item()}, Test nodes: {data.test_mask.sum().item()}")

# ==========================================
# 4️⃣  Detect number of classes automatically
# ==========================================
unique_labels = torch.unique(data.y[data.y >= 0])
num_classes = int(unique_labels.max().item() + 1)
print(f"Detected {num_classes} classes:", unique_labels.tolist())

# ==========================================
# 5️⃣  Define GAT model
# ==========================================
class LegalGAT(nn.Module):
    def __init__(self, in_dim=384, hidden=128, out_dim=num_classes, heads=3):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden, heads=heads, dropout=0.6)
        self.gat2 = GATConv(hidden * heads, out_dim, heads=1, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        x = F.elu(self.gat1(x, edge_index))
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = LegalGAT()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data, model = data.to(device), model.to(device)

# ==========================================
# 6️⃣  Training loop
# ==========================================
print("\n🚀 Training started...\n")
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:02d} | Loss: {loss.item():.4f}")

# ==========================================
# 7️⃣  Evaluation
# ==========================================
model.eval()
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    preds = logits[data.test_mask].argmax(dim=1).cpu()
    true = data.y[data.test_mask].cpu()

acc  = accuracy_score(true, preds)
prec = precision_score(true, preds, average='macro', zero_division=0)
rec  = recall_score(true, preds, average='macro', zero_division=0)
f1   = f1_score(true, preds, average='macro', zero_division=0)

print("\n📊 Evaluation Metrics")
print(f"Accuracy  : {acc:.4f}")
print(f"Precision : {prec:.4f}")
print(f"Recall    : {rec:.4f}")
print(f"F1-score  : {f1:.4f}")

# ==========================================
# 8️⃣  Save model
# ==========================================
torch.save(model.state_dict(), "legal_gat_model.pt")
print("\n✅ Model trained and saved as 'legal_gat_model.pt'")


✅ Graph loaded: Data(x=[132695, 384], edge_index=[2, 169609], y=[12747])
Node features: 132695, Labels: 12747
Train nodes: 10197, Test nodes: 2550
Detected 3 classes: [0, 1, 2]

🚀 Training started...

Epoch 01 | Loss: 1.8666
Epoch 05 | Loss: 1.2390
Epoch 10 | Loss: 1.0717
Epoch 15 | Loss: 0.9751
Epoch 20 | Loss: 0.9191
Epoch 25 | Loss: 0.8822
Epoch 30 | Loss: 0.8545
Epoch 35 | Loss: 0.8346
Epoch 40 | Loss: 0.8165
Epoch 45 | Loss: 0.7907
Epoch 50 | Loss: 0.7751
Epoch 55 | Loss: 0.7563
Epoch 60 | Loss: 0.7356
Epoch 65 | Loss: 0.7183
Epoch 70 | Loss: 0.7089
Epoch 75 | Loss: 0.7013
Epoch 80 | Loss: 0.6853
Epoch 85 | Loss: 0.6775
Epoch 90 | Loss: 0.6702
Epoch 95 | Loss: 0.6603
Epoch 100 | Loss: 0.6590

📊 Evaluation Metrics
Accuracy  : 0.5024
Precision : 0.3341
Recall    : 0.3499
F1-score  : 0.3315

✅ Model trained and saved as 'legal_gat_model.pt'


In [5]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

model.eval()
preds = model(data.x, data.edge_index).argmax(dim=1).cpu().numpy()
true = data.y.cpu().numpy()

acc = accuracy_score(true, preds)
prec = precision_score(true, preds, average='macro')
rec = recall_score(true, preds, average='macro')
f1 = f1_score(true, preds, average='macro')

print(f"✅ Accuracy : {acc:.4f}")
print(f"✅ Precision: {prec:.4f}")
print(f"✅ Recall   : {rec:.4f}")
print(f"✅ F1-score : {f1:.4f}")
print("\nConfusion Matrix:\n", confusion_matrix(true, preds))


AttributeError: 'tuple' object has no attribute 'argmax'

In [None]:
att_weights = model.gat1.att_src
print("Attention Weights Shape:", att_weights.shape)


In [10]:
# ===========================================================
# Smart Legal Judgment Prediction – Multi-Label GAT
# Predicts Charges & Statutes based on Case Facts
# ===========================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, hamming_loss, precision_score, recall_score
import pandas as pd, ast, warnings, numpy as np
warnings.filterwarnings("ignore")

# ==========================================
# 1️⃣  Load your pre-built graph
# ==========================================
torch.serialization.add_safe_globals([Data])
data = torch.load("global_graph.pt", weights_only=False)
print(f"✅ Graph loaded: {data}")

num_nodes = data.x.size(0)
print(f"Node features: {num_nodes}, Labels: {len(data.y)}")

# Pad labels (for completeness)
if len(data.y) < num_nodes:
    y_new = torch.full((num_nodes,), -1, dtype=torch.long)
    y_new[:len(data.y)] = data.y
    data.y = y_new

# ==========================================
# 2️⃣  Load CSV for multi-label targets
# ==========================================
df = pd.read_csv("graph.csv")
df["charges"]  = df["charges"].apply(ast.literal_eval)
df["statutes"] = df["statutes"].apply(ast.literal_eval)

charge_bin  = MultiLabelBinarizer()
statute_bin = MultiLabelBinarizer()
Y_charge  = torch.tensor(charge_bin.fit_transform(df["charges"]),  dtype=torch.float)
Y_statute = torch.tensor(statute_bin.fit_transform(df["statutes"]), dtype=torch.float)

print(f"Charges vector:  {Y_charge.shape}")
print(f"Statutes vector: {Y_statute.shape}")

# ==========================================
# 3️⃣  Train/Test masks (case-level)
# ==========================================
num_cases = Y_charge.shape[0]
perm = torch.randperm(num_cases)
train_size = int(0.8 * num_cases)
train_idx = perm[:train_size]
test_idx  = perm[train_size:]

data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask  = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[train_idx] = True
data.test_mask[test_idx]  = True

print(f"Train: {data.train_mask.sum().item()} | Test: {data.test_mask.sum().item()}")

# ==========================================
# 4️⃣  Define multi-head LegalGAT
# ==========================================
class MultiHeadLegalGAT(nn.Module):
    def __init__(self, in_dim=384, hidden=128,
                 out_charges=Y_charge.shape[1],
                 out_statutes=Y_statute.shape[1],
                 heads=3):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden, heads=heads, dropout=0.5)
        self.gat2_charge  = GATConv(hidden * heads, out_charges, heads=1, concat=False, dropout=0.5)
        self.gat2_statute = GATConv(hidden * heads, out_statutes, heads=1, concat=False, dropout=0.5)

    def forward(self, x, edge_index):
        x = F.elu(self.gat1(x, edge_index))
        out_charge  = torch.sigmoid(self.gat2_charge(x, edge_index))
        out_statute = torch.sigmoid(self.gat2_statute(x, edge_index))
        return out_charge, out_statute

# ==========================================
# 5️⃣  Setup model, optimizer, device
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiHeadLegalGAT().to(device)
data  = data.to(device)
Y_charge, Y_statute = Y_charge.to(device), Y_statute.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)
criterion = nn.BCELoss()

# ==========================================
# 6️⃣  Training loop
# ==========================================
print("\n🚀 Training Multi-Label GAT...\n")
for epoch in range(40):
    model.train()
    optimizer.zero_grad()
    out_c, out_s = model(data.x, data.edge_index)

    loss_c = criterion(out_c[data.train_mask],  Y_charge[data.train_mask])
    loss_s = criterion(out_s[data.train_mask],  Y_statute[data.train_mask])
    loss = loss_c + loss_s

    loss.backward()
    optimizer.step()

    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:02d} | Total Loss: {loss.item():.4f} "
              f"(Charges: {loss_c.item():.4f}, Statutes: {loss_s.item():.4f})")

# ==========================================
# 7️⃣  Evaluation
# ==========================================
model.eval()
with torch.no_grad():
    pred_c, pred_s = model(data.x, data.edge_index)
    pred_c = (pred_c[data.test_mask] > 0.5).float().cpu().numpy()
    pred_s = (pred_s[data.test_mask] > 0.5).float().cpu().numpy()
    true_c = Y_charge[data.test_mask].cpu().numpy()
    true_s = Y_statute[data.test_mask].cpu().numpy()

def evaluate_multilabel(true, pred, name):
    ham = hamming_loss(true, pred)
    micro_f1 = f1_score(true, pred, average='micro', zero_division=0)
    macro_f1 = f1_score(true, pred, average='macro', zero_division=0)
    prec = precision_score(true, pred, average='micro', zero_division=0)
    rec  = recall_score(true, pred, average='micro', zero_division=0)

    print(f"\n📊 {name} Prediction Results")
    print(f"Hamming Loss : {ham:.4f}")
    print(f"Precision    : {prec:.4f}")
    print(f"Recall       : {rec:.4f}")
    print(f"Micro-F1     : {micro_f1:.4f}")
    print(f"Macro-F1     : {macro_f1:.4f}")

evaluate_multilabel(true_c, pred_c, "Charge")
evaluate_multilabel(true_s, pred_s, "Statute")

# ==========================================
# 8️⃣  Save model
# ==========================================
torch.save(model.state_dict(), "multi_label_legal_gat.pt")
print("\n✅ Model trained & saved as 'multi_label_legal_gat.pt'")


✅ Graph loaded: Data(x=[132695, 384], edge_index=[2, 169609], y=[12747])
Node features: 132695, Labels: 12747
Charges vector:  torch.Size([12747, 9401])
Statutes vector: torch.Size([12747, 27959])
Train: 10197 | Test: 2550

🚀 Training Multi-Label GAT...



RuntimeError: [enforce fail at alloc_cpu.cpp:121] data. DefaultCPUAllocator: not enough memory: you tried to allocate 11367839616 bytes.

In [4]:
# ===============================================================
# ⚖️ Smart Legal Judgment Prediction – CPU-Friendly Weighted GAT
# Multi-Label (Charges + Statutes)
# ===============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, hamming_loss, precision_score, recall_score
import pandas as pd, ast, numpy as np, warnings
warnings.filterwarnings("ignore")

# ===============================================================
# 1️⃣ Load Graph
# ===============================================================
torch.serialization.add_safe_globals([Data])
data = torch.load("global_graph.pt", weights_only=False)
print(f"✅ Graph loaded with {data.num_nodes} nodes and {data.num_edges} edges.")

num_nodes = data.x.size(0)
if not hasattr(data, "y") or data.y is None:
    data.y = torch.zeros(num_nodes, dtype=torch.long)
elif len(data.y) < num_nodes:
    y_new = torch.full((num_nodes,), -1, dtype=torch.long)
    y_new[:len(data.y)] = data.y
    data.y = y_new

# ===============================================================
# 2️⃣ Load and Preprocess CSV
# ===============================================================
df = pd.read_csv("graph.csv")
df["charges"]  = df["charges"].apply(ast.literal_eval)
df["statutes"] = df["statutes"].apply(ast.literal_eval)

# keep only top frequent labels
top_charges  = df["charges"].explode().value_counts().head(300).index
top_statutes = df["statutes"].explode().value_counts().head(300).index
df["charges"]  = df["charges"].apply(lambda x: [c for c in x if c in top_charges])
df["statutes"] = df["statutes"].apply(lambda x: [s for s in x if s in top_statutes])

charge_bin  = MultiLabelBinarizer()
statute_bin = MultiLabelBinarizer()
Y_charge  = torch.tensor(charge_bin.fit_transform(df["charges"]),  dtype=torch.float)
Y_statute = torch.tensor(statute_bin.fit_transform(df["statutes"]), dtype=torch.float)

print(f"Charges vector:  {Y_charge.shape}")
print(f"Statutes vector: {Y_statute.shape}")

# ===============================================================
# 3️⃣ Node ↔ Case Index Mapping
# ===============================================================
case_count = len(df)
case_indices = torch.arange(case_count)  # assume first N nodes = cases

# train/test split
perm = torch.randperm(case_count)
train_size = int(0.8 * case_count)
train_idx = perm[:train_size]
test_idx  = perm[train_size:]

data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask  = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[case_indices[train_idx]] = True
data.test_mask[case_indices[test_idx]]   = True

print(f"Train cases: {len(train_idx)} | Test cases: {len(test_idx)}")

# ===============================================================
# 4️⃣ Lightweight Multi-Head GAT
# ===============================================================
class MultiHeadLegalGAT(nn.Module):
    def __init__(self, in_dim=384, hidden=64,
                 out_charges=Y_charge.shape[1],
                 out_statutes=Y_statute.shape[1],
                 heads=2):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden, heads=heads, dropout=0.6)
        self.gat2_charge  = GATConv(hidden * heads, out_charges, heads=1, concat=False, dropout=0.6)
        self.gat2_statute = GATConv(hidden * heads, out_statutes, heads=1, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        x = F.elu(self.gat1(x, edge_index))
        out_c = self.gat2_charge(x, edge_index)
        out_s = self.gat2_statute(x, edge_index)
        return out_c, out_s  # logits (no sigmoid!)

# ===============================================================
# 5️⃣ Setup & Weighted Loss
# ===============================================================
device = torch.device("cpu")
model = MultiHeadLegalGAT().to(device)
data  = data.to(device)
Y_charge, Y_statute = Y_charge.to(device), Y_statute.to(device)

# class imbalance weights
pos_weight_c = (Y_charge.numel() - Y_charge.sum()) / (Y_charge.sum() + 1e-8)
pos_weight_s = (Y_statute.numel() - Y_statute.sum()) / (Y_statute.sum() + 1e-8)
criterion_c = nn.BCEWithLogitsLoss(pos_weight=pos_weight_c)
criterion_s = nn.BCEWithLogitsLoss(pos_weight=pos_weight_s)
sigmoid = nn.Sigmoid()

optimizer = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=5e-4)

# print density for sanity check
print(f"\n🔍 charge label density: {Y_charge.mean().item():.6f}")
print(f"🔍 statute label density: {Y_statute.mean().item():.6f}\n")

# ===============================================================
# 6️⃣ Training Loop (mini-batch, CPU-friendly)
# ===============================================================
print("🚀 Training (Weighted BCE mode)...\n")
batch_size = 512
indices = torch.arange(case_count)

for epoch in range(500):
    model.train()
    total_loss = 0.0

    for i in range(0, len(indices), batch_size):
        batch_idx = indices[i:i+batch_size]
        optimizer.zero_grad()

        logits_c, logits_s = model(data.x, data.edge_index)
        logits_c, logits_s = logits_c[case_indices], logits_s[case_indices]

        loss_c = criterion_c(logits_c[batch_idx], Y_charge[batch_idx])
        loss_s = criterion_s(logits_s[batch_idx], Y_statute[batch_idx])
        loss = loss_c + loss_s
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 5 == 0 or epoch == 0:
        avg_loss = total_loss / (len(indices) // batch_size + 1)
        print(f"Epoch {epoch+1:02d} | Avg Loss: {avg_loss:.4f}")

# ===============================================================
# 7️⃣ Evaluation
# ===============================================================
model.eval()
with torch.no_grad():
    logits_c, logits_s = model(data.x, data.edge_index)
    prob_c, prob_s = sigmoid(logits_c[case_indices]), sigmoid(logits_s[case_indices])

true_c, true_s = Y_charge.cpu().numpy(), Y_statute.cpu().numpy()
prob_c, prob_s = prob_c.cpu().numpy(), prob_s.cpu().numpy()

def evaluate_multilabel(true, prob, name, threshold=0.2):
    pred = (prob > threshold).astype(int)
    ham = hamming_loss(true, pred)
    micro_f1 = f1_score(true, pred, average='micro', zero_division=0)
    macro_f1 = f1_score(true, pred, average='macro', zero_division=0)
    prec = precision_score(true, pred, average='micro', zero_division=0)
    rec  = recall_score(true, pred, average='micro', zero_division=0)
    print(f"\n📊 {name} Prediction Results @thr={threshold}")
    print(f"Hamming Loss : {ham:.4f}")
    print(f"Precision    : {prec:.4f}")
    print(f"Recall       : {rec:.4f}")
    print(f"Micro-F1     : {micro_f1:.4f}")
    print(f"Macro-F1     : {macro_f1:.4f}")

for thr in [0.3, 0.2, 0.1]:
    evaluate_multilabel(true_c[test_idx], prob_c[test_idx], "Charge", threshold=thr)
    evaluate_multilabel(true_s[test_idx], prob_s[test_idx], "Statute", threshold=thr)

# ===============================================================
# 8️⃣ Save Model
# ===============================================================
torch.save(model.state_dict(), "legal_gat_weighted.pt")
print("\n✅ Model trained & saved as 'legal_gat_weighted.pt'")


✅ Graph loaded with 132695 nodes and 169609 edges.
Charges vector:  torch.Size([12747, 300])
Statutes vector: torch.Size([12747, 300])
Train cases: 10197 | Test cases: 2550

🔍 charge label density: 0.003923
🔍 statute label density: 0.004660

🚀 Training (Weighted BCE mode)...

Epoch 01 | Avg Loss: 3.4298
Epoch 05 | Avg Loss: 2.5896
Epoch 10 | Avg Loss: 2.3077
Epoch 15 | Avg Loss: 2.2184
Epoch 20 | Avg Loss: 2.1551
Epoch 25 | Avg Loss: 2.1208
Epoch 30 | Avg Loss: 2.1038
Epoch 35 | Avg Loss: 2.0913
Epoch 40 | Avg Loss: 2.0760
Epoch 45 | Avg Loss: 2.0709
Epoch 50 | Avg Loss: 2.0679
Epoch 55 | Avg Loss: 2.0677
Epoch 60 | Avg Loss: 2.0505
Epoch 65 | Avg Loss: 2.0572
Epoch 70 | Avg Loss: 2.0552
Epoch 75 | Avg Loss: 2.0496
Epoch 80 | Avg Loss: 2.0457
Epoch 85 | Avg Loss: 2.0492
Epoch 90 | Avg Loss: 2.0379
Epoch 95 | Avg Loss: 2.0438
Epoch 100 | Avg Loss: 2.0445
Epoch 105 | Avg Loss: 2.0381
Epoch 110 | Avg Loss: 2.0441
Epoch 115 | Avg Loss: 2.0377
Epoch 120 | Avg Loss: 2.0370
Epoch 125 | Avg Lo

In [2]:
# 1️⃣ are there any positive labels at all?
print("charge label density:", Y_charge.sum().item() / Y_charge.numel())
print("statute label density:", Y_statute.sum().item() / Y_statute.numel())

# 2️⃣ how confident are predictions?
out_c, out_s = model(data.x, data.edge_index)
print(out_c.min().item(), out_c.max().item())

# 3️⃣ fraction of positives predicted
print("pred positives charge:", (out_c > 0.3).float().mean().item())
print("pred positives statute:", (out_s > 0.3).float().mean().item())


charge label density: 0.00392327606495646
statute label density: 0.004660442980047593
0.04717245325446129 0.18653100728988647
pred positives charge: 0.0
pred positives statute: 0.0
