In [2]:
# =====================================================
# 🔧 1️⃣ SETUP
# =====================================================
!pip install transformers torch tqdm scikit-learn pandas --quiet

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm
import numpy as np

# =====================================================
# 📂 2️⃣ LOAD CSV (uploaded or in /content)
# =====================================================
csv_file = "/content/encoder_clean.csv"   # upload to this path in Colab
df = pd.read_csv(csv_file)
print(f"✅ Loaded {csv_file} with shape:", df.shape)

# =====================================================
# ⚙️ 3️⃣ CONFIG
# =====================================================
MODEL_NAME = "law-ai/InLegalBERT"   # or 'bert-base-uncased' if you hit GPU RAM issues
EPOCHS = 2
BATCH_SIZE = 4
LR = 2e-5
MAX_LEN = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("🖥️ Using device:", DEVICE)

# =====================================================
# 🧹 4️⃣ UTILITIES
# =====================================================
def to_list_safe(val):
    """Converts almost anything to a list of numbers if possible."""
    if pd.isna(val):
        return [0]
    try:
        parsed = eval(str(val))
        if isinstance(parsed, list):
            return parsed
        elif isinstance(parsed, (int, float)):
            return [parsed]
        elif isinstance(parsed, str) and "," in parsed:
            return [float(x.strip()) for x in parsed.split(",") if x.strip()]
        else:
            return list(parsed)
    except Exception:
        return [0]

# =====================================================
# 📘 5️⃣ DATASET
# =====================================================
class LegalDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=256):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Build vocab (limit large ones for Colab)
        all_labels, all_charges = set(), set()
        for _, row in df.iterrows():
            val1, val2 = str(row.get("label", "")), str(row.get("charges", ""))
            if not any(c.isdigit() for c in val1):
                all_labels.update([x.strip(" []'\"") for x in val1.split(",") if x.strip()])
            if not any(c.isdigit() for c in val2):
                all_charges.update([x.strip(" []'\"") for x in val2.split(",") if x.strip()])

        all_charges = list(all_charges)[:300]
        self.label_vocab = {lbl: i for i, lbl in enumerate(sorted(all_labels))} or {"dummy": 0}
        self.charge_vocab = {chg: i for i, chg in enumerate(sorted(all_charges))} or {"dummy": 0}
        print(f"🧾 Label vocab size: {len(self.label_vocab)} | Charge vocab size: {len(self.charge_vocab)}")

    def encode_labels(self, text, vocab):
        """Always returns a valid 0–1 vector."""
        text = str(text).strip()
        vec = torch.zeros(len(vocab), dtype=torch.float)

        # textual labels
        if any(c.isalpha() for c in text):
            for item in [x.strip(" []'\"") for x in text.split(",") if x.strip()]:
                if item in vocab:
                    vec[vocab[item]] = 1.0
        # numeric vectors
        else:
            try:
                vals = to_list_safe(text)
                vals = [float(x) for x in vals if str(x).replace('.', '', 1).isdigit()]
                vals = torch.tensor(vals, dtype=torch.float)
                vals = torch.clamp(vals, 0, 1)
                vec[: min(len(vals), len(vec))] = vals[: min(len(vals), len(vec))]
            except Exception:
                pass
        return vec

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text_inputs = self.tokenizer(str(row["text"]), padding="max_length",
                                     truncation=True, max_length=self.max_len, return_tensors="pt")
        stat_inputs = self.tokenizer(str(row["statutes"]), padding="max_length",
                                     truncation=True, max_length=self.max_len, return_tensors="pt")
        fact_inputs = self.tokenizer(str(row["facts"]), padding="max_length",
                                     truncation=True, max_length=self.max_len, return_tensors="pt")
        return {
            "text": {k: v.squeeze(0) for k, v in text_inputs.items()},
            "statute": {k: v.squeeze(0) for k, v in stat_inputs.items()},
            "fact": {k: v.squeeze(0) for k, v in fact_inputs.items()},
            "statute_labels": self.encode_labels(row.get("label", ""), self.label_vocab),
            "charge_labels": self.encode_labels(row.get("charges", ""), self.charge_vocab)
        }

# =====================================================
# 🧠 6️⃣ MODEL
# =====================================================
class CrossAttentionLegalModel(nn.Module):
    def __init__(self, model_name, hidden_dim=768, num_statutes=100, num_charges=50):
        super().__init__()
        self.encoder_text = AutoModel.from_pretrained(model_name)
        self.encoder_stat = AutoModel.from_pretrained(model_name)
        self.encoder_fact = AutoModel.from_pretrained(model_name)

        self.cross_attn_1 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, dropout=0.1)
        self.cross_attn_2 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, dropout=0.1)

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.out_statute = nn.Linear(hidden_dim, num_statutes)
        self.out_charge = nn.Linear(hidden_dim, num_charges)

    def forward(self, text_inputs, statute_inputs, fact_inputs):
        text_out = self.encoder_text(**text_inputs).last_hidden_state.mean(dim=1)
        stat_out = self.encoder_stat(**statute_inputs).last_hidden_state.mean(dim=1)
        fact_out = self.encoder_fact(**fact_inputs).last_hidden_state.mean(dim=1)

        attn1, _ = self.cross_attn_1(text_out.unsqueeze(0), stat_out.unsqueeze(0), stat_out.unsqueeze(0))
        attn2, _ = self.cross_attn_2(text_out.unsqueeze(0), fact_out.unsqueeze(0), fact_out.unsqueeze(0))

        fused = torch.cat([attn1.squeeze(0), attn2.squeeze(0), text_out], dim=-1)
        fused = self.fusion(fused)
        return torch.sigmoid(self.out_statute(fused)), torch.sigmoid(self.out_charge(fused))

# =====================================================
# 📊 7️⃣ METRICS
# =====================================================
def compute_metrics(y_true_stat, y_pred_stat, y_true_chg, y_pred_chg, thr=0.3):
    y_pred_stat = (y_pred_stat > thr).astype(int)
    y_pred_chg = (y_pred_chg > thr).astype(int)
    return {
        "Statute_F1": f1_score(y_true_stat, y_pred_stat, average="macro", zero_division=0),
        "Charge_F1": f1_score(y_true_chg, y_pred_chg, average="macro", zero_division=0),
        "Statute_Prec": precision_score(y_true_stat, y_pred_stat, average="macro", zero_division=0),
        "Charge_Prec": precision_score(y_true_chg, y_pred_chg, average="macro", zero_division=0),
        "Statute_Rec": recall_score(y_true_stat, y_pred_stat, average="macro", zero_division=0),
        "Charge_Rec": recall_score(y_true_chg, y_pred_chg, average="macro", zero_division=0)
    }

# =====================================================
# 🚀 8️⃣ TRAINING
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dataset = LegalDataset(df, tokenizer, MAX_LEN)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

num_statutes, num_charges = len(dataset.label_vocab), len(dataset.charge_vocab)
model = CrossAttentionLegalModel(MODEL_NAME, num_statutes=num_statutes, num_charges=num_charges).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LR)
criterion = nn.BCELoss()

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    all_y_stat_true, all_y_stat_pred, all_y_chg_true, all_y_chg_pred = [], [], [], []

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        optimizer.zero_grad()
        text_inputs = {k: v.to(DEVICE) for k, v in batch["text"].items()}
        statute_inputs = {k: v.to(DEVICE) for k, v in batch["statute"].items()}
        fact_inputs = {k: v.to(DEVICE) for k, v in batch["fact"].items()}
        y_stat, y_chg = batch["statute_labels"].to(DEVICE), batch["charge_labels"].to(DEVICE)

        pred_stat, pred_chg = model(text_inputs, statute_inputs, fact_inputs)
        loss = criterion(pred_stat, y_stat) + criterion(pred_chg, y_chg)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        all_y_stat_true.extend(y_stat.cpu().numpy())
        all_y_stat_pred.extend(pred_stat.detach().cpu().numpy())
        all_y_chg_true.extend(y_chg.cpu().numpy())
        all_y_chg_pred.extend(pred_chg.detach().cpu().numpy())

    metrics = compute_metrics(
        np.array(all_y_stat_true), np.array(all_y_stat_pred),
        np.array(all_y_chg_true), np.array(all_y_chg_pred)
    )
    print(f"\nEpoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

print("\n✅ Training complete!")

# =====================================================
# 💾 9️⃣ SAVE MODEL (optional download)
# =====================================================
torch.save(model.state_dict(), "legal_crossattention.pt")
print("💾 Model saved as legal_crossattention.pt")


✅ Loaded /content/encoder_clean.csv with shape: (12747, 6)
🖥️ Using device: cuda
🧾 Label vocab size: 1 | Charge vocab size: 300


Epoch 1/2: 100%|██████████| 3187/3187 [35:47<00:00,  1.48it/s]



Epoch 1 | Avg Loss: 0.7086
Statute_F1: 0.3754
Charge_F1: 0.0003
Statute_Prec: 0.5526
Charge_Prec: 0.0002
Statute_Rec: 0.5010
Charge_Rec: 0.0320


Epoch 2/2: 100%|██████████| 3187/3187 [35:40<00:00,  1.49it/s]



Epoch 2 | Avg Loss: 0.6597
Statute_F1: 0.3739
Charge_F1: 0.0000
Statute_Prec: 0.6274
Charge_Prec: 0.0000
Statute_Rec: 0.5012
Charge_Rec: 0.0000

✅ Training complete!
💾 Model saved as legal_crossattention.pt


In [None]:
# =====================================================
# ⚖️ Smart Legal Judgment Prediction (Cross-Attention – Charge Only)
# =====================================================

!pip install transformers torch tqdm scikit-learn pandas --quiet

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm
import numpy as np
import gc

# =====================================================
# 1️⃣ CONFIG
# =====================================================
MODEL_NAME = "law-ai/InLegalBERT"   # switch to 'bert-base-uncased' if VRAM low
EPOCHS = 3
BATCH_SIZE = 4
LR = 2e-5
MAX_LEN = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("🖥️ Using device:", DEVICE)

# =====================================================
# 2️⃣ LOAD CSV
# =====================================================
csv_file = "/content/encoder_clean.csv"   # upload to this path
df = pd.read_csv(csv_file)
print(f"✅ Loaded {csv_file} with shape {df.shape}")
df = df.fillna("")

print("\n📋 Columns found:", list(df.columns))
print("🔍 Sample charges:", df["charges"].iloc[:3].tolist())

# =====================================================
# 3️⃣ UTILITIES
# =====================================================
def to_list_safe(val):
    """Convert anything to list of strings safely."""
    if pd.isna(val):
        return []
    try:
        parsed = eval(str(val))
        if isinstance(parsed, list):
            return [str(x).strip() for x in parsed if str(x).strip()]
        elif isinstance(parsed, (int, float, str)):
            return [str(parsed).strip()]
        else:
            return list(parsed)
    except Exception:
        return [str(val)]

# =====================================================
# 4️⃣ DATASET (Charge Only)
# =====================================================
class LegalDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=256):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Build charge vocabulary (limit to top 300 frequent)
        all_charges = df["charges"].explode().astype(str)
        top_charges = all_charges.value_counts().head(300).index
        self.charge_vocab = {chg: i for i, chg in enumerate(sorted(top_charges))}
        print(f"🧾 Charge vocab size: {len(self.charge_vocab)}")

    def encode_labels(self, charge_list):
        """Convert list of charges into 0–1 multi-hot vector."""
        vec = torch.zeros(len(self.charge_vocab), dtype=torch.float)
        for ch in to_list_safe(charge_list):
            if ch in self.charge_vocab:
                vec[self.charge_vocab[ch]] = 1.0
        return vec

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text_inputs = self.tokenizer(
            str(row["text"]), padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        stat_inputs = self.tokenizer(
            str(row["statutes"]), padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        fact_inputs = self.tokenizer(
            str(row["facts"]), padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )

        return {
            "text": {k: v.squeeze(0) for k, v in text_inputs.items()},
            "statute": {k: v.squeeze(0) for k, v in stat_inputs.items()},
            "fact": {k: v.squeeze(0) for k, v in fact_inputs.items()},
            "charge_labels": self.encode_labels(row["charges"])
        }

# =====================================================
# 5️⃣ MODEL
# =====================================================
class CrossAttentionLegalModel(nn.Module):
    def __init__(self, model_name, hidden_dim=768, num_charges=50):
        super().__init__()
        self.encoder_text = AutoModel.from_pretrained(model_name)
        self.encoder_stat = AutoModel.from_pretrained(model_name)
        self.encoder_fact = AutoModel.from_pretrained(model_name)

        self.cross_attn_1 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, dropout=0.1)
        self.cross_attn_2 = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4, dropout=0.1)

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.out_charge = nn.Linear(hidden_dim, num_charges)

    def forward(self, text_inputs, statute_inputs, fact_inputs):
        text_out = self.encoder_text(**text_inputs).last_hidden_state.mean(dim=1)
        stat_out = self.encoder_stat(**statute_inputs).last_hidden_state.mean(dim=1)
        fact_out = self.encoder_fact(**fact_inputs).last_hidden_state.mean(dim=1)

        attn1, _ = self.cross_attn_1(text_out.unsqueeze(0), stat_out.unsqueeze(0), stat_out.unsqueeze(0))
        attn2, _ = self.cross_attn_2(text_out.unsqueeze(0), fact_out.unsqueeze(0), fact_out.unsqueeze(0))

        fused = torch.cat([attn1.squeeze(0), attn2.squeeze(0), text_out], dim=-1)
        fused = self.fusion(fused)
        logits = self.out_charge(fused)
        return logits  # return raw logits (no sigmoid)

# =====================================================
# 6️⃣ METRICS
# =====================================================
def compute_metrics(y_true, y_prob, thr=0.1):
    y_pred = (y_prob > thr).astype(int)
    return {
        "F1": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "Precision": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "Recall": recall_score(y_true, y_pred, average="macro", zero_division=0)
    }

# =====================================================
# 7️⃣ TRAINING LOOP
# =====================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dataset = LegalDataset(df, tokenizer, MAX_LEN)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

num_charges = len(dataset.charge_vocab)
model = CrossAttentionLegalModel(MODEL_NAME, num_charges=num_charges).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LR)

# compute positive weight for imbalance
charge_vectors = torch.stack([dataset.encode_labels(ch) for ch in df["charges"]])
pos_weight = (charge_vectors.numel() - charge_vectors.sum()) / (charge_vectors.sum() + 1e-8)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(DEVICE))

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    all_y_true, all_y_pred = [], []
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch in pbar:
        optimizer.zero_grad()
        text_inputs = {k: v.to(DEVICE) for k, v in batch["text"].items()}
        statute_inputs = {k: v.to(DEVICE) for k, v in batch["statute"].items()}
        fact_inputs = {k: v.to(DEVICE) for k, v in batch["fact"].items()}
        y_true = batch["charge_labels"].to(DEVICE)

        logits = model(text_inputs, statute_inputs, fact_inputs)
        loss = criterion(logits, y_true)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        all_y_true.extend(y_true.cpu().numpy())
        all_y_pred.extend(torch.sigmoid(logits).detach().cpu().numpy())
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    metrics = compute_metrics(np.array(all_y_true), np.array(all_y_pred))
    print(f"\nEpoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    gc.collect()
    torch.cuda.empty_cache()

print("\n✅ Training complete!")

# =====================================================
# 8️⃣ SAVE MODEL
# =====================================================
torch.save(model.state_dict(), "legal_crossattention_charge.pt")
print("💾 Model saved as legal_crossattention_charge.pt")


🖥️ Using device: cuda
✅ Loaded /content/encoder_clean.csv with shape (12747, 6)

📋 Columns found: ['filename', 'label', 'text', 'statutes', 'charges', 'facts']
🔍 Sample charges: ["['Cruelty', 'Dowry Death', 'Criminal Breach of Trust']", "['Breach of Contract', 'Specific Performance of Contract']", "['Murder', 'Rioting armed with deadly weapon', 'Unlawful assembly', 'Use of firearms']"]


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/516 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

🧾 Charge vocab size: 300


config.json:   0%|          | 0.00/671 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/534M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/534M [00:00<?, ?B/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 1/3:  22%|██▏       | 688/3187 [07:42<28:18,  1.47it/s, loss=0.0001][A
Epoch 1/3:  22%|██▏       | 688/3187 [07:43<28:18,  1.47it/s, loss=0.0000][A
Epoch 1/3:  22%|██▏       | 689/3187 [07:43<28:32,  1.46it/s, loss=0.0000][A
Epoch 1/3:  22%|██▏       | 689/3187 [07:44<28:32,  1.46it/s, loss=0.0002][A
Epoch 1/3:  22%|██▏       | 690/3187 [07:44<28:32,  1.46it/s, loss=0.0002][A
Epoch 1/3:  22%|██▏       | 690/3187 [07:44<28:32,  1.46it/s, loss=0.0002][A
Epoch 1/3:  22%|██▏       | 691/3187 [07:44<28:34,  1.46it/s, loss=0.0002][A
Epoch 1/3:  22%|██▏       | 691/3187 [07:45<28:34,  1.46it/s, loss=0.0002][A
Epoch 1/3:  22%|██▏       | 692/3187 [07:45<28:33,  1.46it/s, loss=0.0002][A
Epoch 1/3:  22%|██▏       | 692/3187 [07:46<28:33,  1.46it/s, loss=0.0000][A
Epoch 1/3:  22%|██▏       | 693/3187 [07:46<28:28,  1.46it/s, loss=0.0000][A
Epoch 1/3:  22%|██▏       | 693/3187 [07:47<28:28,  1.46it/s, loss=0.0001][A


Epoch 1 | Avg Loss: 0.0094
F1: 0.0000
Precision: 0.0000
Recall: 0.0000


Epoch 2/3:  24%|██▎       | 755/3187 [08:26<27:17,  1.49it/s, loss=0.0000]