<a href="https://colab.research.google.com/github/Shravan-Kumar-18/legal-domain-similarity/blob/main/Paragraph_Filtering3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 📦 Setup: clean environment
!pip install -U datasets huggingface_hub fsspec
!pip install -q datasets transformers scikit-learn tqdm

In [None]:
# Imports
import torch
import numpy as np
import re
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr, spearmanr
from collections import defaultdict, Counter

In [None]:
#  Device setup: use CUDA when PyTorch reports an available GPU, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
#  LegalLongformer: weights loaded explicitly as float32 ("safe dtype") and moved to `device`,
#  plus the tokenizer from the same checkpoint.
model = AutoModel.from_pretrained(
    "lexlms/legal-longformer-large",
    torch_dtype=torch.float32,
).to(device)
tokenizer = AutoTokenizer.from_pretrained("lexlms/legal-longformer-large")

In [None]:
#  Load up to N_CASES training cases whose text mentions "freedom of expression"
N_CASES = 50
dataset = load_dataset("lex_glue", "ecthr_a", split="train")
expression_cases = dataset.filter(lambda d: "freedom of expression" in str(d["text"]).lower())
# Dataset.select fails for indices past the end, so clamp to the number of rows the filter kept.
subset = expression_cases.select(range(min(N_CASES, len(expression_cases))))
# When a case's text is a list of paragraphs, join them with blank lines (not spaces) so the
# downstream r"\n{2,}" paragraph splitting can still recover the paragraph boundaries.
texts_raw = [doc["text"] if isinstance(doc["text"], str) else "\n\n".join(doc["text"]) for doc in subset]
titles = [f"Case {i}" for i in range(len(texts_raw))]

In [None]:
#  Citation extraction
def extract_citations(text):
    """Return the set of citation-like spans found in *text*.

    Matches "Article <n>", "Section <n>", any parenthesised span (non-greedy, single line),
    and bracketed-year references of the form "[YYYY] ... ]".
    """
    pattern = r"(Article\s\d+|Section\s\d+|\(.*?\)|\[\d{4}\][^\]]+\])"
    return {match for match in re.findall(pattern, text)}
citation_sets = list(map(extract_citations, texts_raw))  # one citation set per case, aligned with `titles`

In [None]:
# Paragraph filtering helpers
def is_structural(para):
    """Return True when *para* (ignoring surrounding whitespace and case) starts with a
    header label such as "coram:", "bench:", "court:", "party:", "author:" or "before:"."""
    head = para.strip().lower()
    return re.match(r'^(author:|bench:|party:|court:|coram:|before:)', head) is not None

def is_substantive(para):
    """Return True when *para* has at least 25 words and mentions a core legal keyword.

    Matching is case-insensitive and by substring, so e.g. "held" also matches "upheld".
    """
    text = para.strip().lower()
    if len(text.split()) < 25:
        return False
    keywords = ("facts", "issue", "judgment", "reasoning", "held", "argument", "legal", "disputed")
    return any(keyword in text for keyword in keywords)

In [None]:
def paragraph_filter(text, min_keep=3):
    """Select the substantive, citation-bearing paragraphs of *text*, in document order.

    Paragraphs are blocks separated by one or more blank lines. The first and last paragraphs
    are kept when they are substantive (``is_substantive``). Interior paragraphs are kept when
    they are substantive, are not structural headers (``is_structural``), and contain at least
    one citation that also appears in the citations of the whole text.

    Args:
        text: Full case text.
        min_keep: If fewer than this many paragraphs survive, the filter is abandoned and all
            paragraphs are returned instead.

    Returns:
        list[str]: stripped paragraphs, in the order they appear in *text*.
    """
    paras = [p.strip() for p in re.split(r"\n{2,}", text) if p.strip()]
    all_cites = extract_citations(text)
    last = len(paras) - 1
    filtered = []
    # Walk paragraphs in order so the kept text reads like the original document
    # (previously the closing paragraph was placed before all interior ones).
    for idx, para in enumerate(paras):
        if not is_substantive(para):
            continue
        if idx == 0 or idx == last:
            # Opening/closing paragraphs are kept on substance alone.
            filtered.append(para)
        elif not is_structural(para) and extract_citations(para) & all_cites:
            filtered.append(para)
    return filtered if len(filtered) >= min_keep else paras

# Join kept paragraphs with blank lines (not spaces): segment_sections and explain_similarity
# both split on r"\n{2,}", so a space join would collapse each case into a single paragraph.
filtered_texts = ["\n\n".join(paragraph_filter(text)) for text in texts_raw]

In [None]:
#  Section segmentation
def segment_sections(text):
    """Bucket the blank-line-separated paragraphs of *text* into Facts / Reasoning / Judgment.

    Each paragraph goes to the first matching bucket (case-insensitive substring test):
    "facts" -> Facts; else "reasoning"/"held" -> Reasoning; else "judgment"/"conclusion" -> Judgment.
    Paragraphs matching none are dropped. Each bucket is the concatenation of its paragraphs,
    each followed by a newline ("" when empty).
    """
    buckets = {"Facts": [], "Reasoning": [], "Judgment": []}
    for chunk in re.split(r"\n{2,}", text):
        lowered = chunk.lower()
        if "facts" in lowered:
            target = "Facts"
        elif "reasoning" in lowered or "held" in lowered:
            target = "Reasoning"
        elif "judgment" in lowered or "conclusion" in lowered:
            target = "Judgment"
        else:
            continue
        buckets[target].append(chunk + "\n")
    return {name: "".join(parts) for name, parts in buckets.items()}

In [None]:
#  Embedding
def get_embedding(text, max_length=2048):
    """Embed *text* with the global Longformer and return the first-token hidden state.

    Args:
        text: A string, or a list of strings (joined with single spaces).
        max_length: Token budget; longer inputs are truncated.

    Returns:
        np.ndarray of shape (hidden_size,). On any error the message is printed and a float32
        zero vector of that shape is returned, so batch loops keep running.
    """
    try:
        if isinstance(text, list):
            text = " ".join(text)
        # Truncate but do not pad: padding every input to max_length wasted compute on short
        # sections/paragraphs, and pad positions are masked out via attention_mask anyway.
        inputs = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # NOTE(review): no global_attention_mask is passed, so per the Longformer docs the first
        # token only uses local windowed attention; consider global attention on token 0.
        with torch.no_grad():
            output = model(**inputs)
        return output.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
    except Exception as e:
        print(f" Embedding error: {e}")
        # float32 to match the dtype of real embeddings from the float32 model.
        return np.zeros(model.config.hidden_size, dtype=np.float32)

In [None]:
#  Section-wise embeddings: one vector per case = mean of its non-empty section embeddings
section_embeddings = []
for text in filtered_texts:
    sections = segment_sections(text)
    vecs = [get_embedding(sections[name]) for name in ["Facts", "Reasoning", "Judgment"] if sections[name].strip()]
    if not vecs:
        # No section keyword matched: embed the whole filtered text rather than storing an
        # all-zero vector, whose cosine similarity to every other case would be 0.
        vecs = [get_embedding(text)]
    # Zero placeholders for missing sections would only rescale the mean (cosine ignores scale),
    # so averaging just the non-empty sections leaves cosine scores for those cases unchanged.
    section_embeddings.append(np.mean(vecs, axis=0))

In [None]:
#  Semantic similarity: cosine between the per-case section-averaged embeddings
sim_matrix_cosine = cosine_similarity(section_embeddings)

# Citation similarity: Jaccard index |A ∩ B| / |A ∪ B| of two cases' citation sets,
# defined as 0.0 when both sets are empty.
sim_matrix_citation = np.zeros((len(titles), len(titles)))
for i, j in np.ndindex(sim_matrix_citation.shape):
    shared = len(citation_sets[i] & citation_sets[j])
    combined = len(citation_sets[i] | citation_sets[j])
    sim_matrix_citation[i][j] = shared / combined if combined else 0.0

#  Hybrid matrix: unweighted mean of the semantic and citation scores
sim_matrix_avg = (sim_matrix_cosine + sim_matrix_citation) / 2

In [None]:
#  Matrix printer
def print_matrix(matrix, title, labels=None):
    """Print a similarity matrix as labelled rows of two-decimal scores.

    Args:
        matrix: 2-D array-like of numbers; row ``i`` is labelled ``labels[i]``.
        title: Name shown in the header line.
        labels: Row labels; defaults to the notebook-level ``titles`` list.
    """
    if labels is None:
        labels = titles
    # Label column keeps the previous minimum width of 8 but always leaves at least one space
    # after the longest label, so long labels cannot run into the first score.
    width = max([8] + [len(label) + 1 for label in labels])
    print(f"\n {title} Similarity Matrix:")
    for i, row in enumerate(matrix):
        row_str = "  ".join(f"{v:.2f}" for v in row)
        print(f"{labels[i].ljust(width)}{row_str}")
# Show the semantic, citation and hybrid matrices in turn.
for matrix_to_show, matrix_name in ((sim_matrix_cosine, "Cosine"), (sim_matrix_citation, "Citation"), (sim_matrix_avg, "Hybrid")):
    print_matrix(matrix_to_show, matrix_name)

In [None]:
#  Silver-standard Precision@k
# A case's "silver" neighbours are the other cases sharing at least one rare citation with it,
# where rare means the citation appears in at most RARE_CITATION_MAX_DOCS cases.
k = 5
RARE_CITATION_MAX_DOCS = 3

global_freq = Counter()
for c in citation_sets:
    global_freq.update(c)  # entries are sets, so this counts how many cases cite each citation


def get_strong_citations(cites):
    """Return the citations in *cites* that occur in at most RARE_CITATION_MAX_DOCS cases."""
    return [c for c in cites if global_freq[c] <= RARE_CITATION_MAX_DOCS]


# Rare-citation set per case, computed once rather than inside the pairwise loop.
strong_sets = [set(get_strong_citations(c_set)) for c_set in citation_sets]
silver_set = defaultdict(set)
for i, strong_i in enumerate(strong_sets):
    for j, strong_j in enumerate(strong_sets):
        if i != j and strong_i & strong_j:
            silver_set[titles[i]].add(titles[j])

print(f"\n Silver Standard Evaluation (Precision@{k}):")
judged_precisions = []
for i, title in enumerate(titles):
    sims = sorted(
        [(titles[j], sim_matrix_avg[i][j]) for j in range(len(titles)) if j != i],
        key=lambda x: x[1],
        reverse=True,
    )[:k]
    predicted = {d for d, _ in sims}
    actual = silver_set[title]
    if not actual:
        # Precision@k is undefined without any relevant neighbours; printing 0.00 would
        # wrongly count cases the silver standard cannot judge as retrieval failures.
        print(f"{title}: Precision@{k} = n/a (no silver neighbours)")
        continue
    match = predicted & actual
    judged_precisions.append(len(match) / k)
    print(f"{title}: Precision@{k} = {len(match)/k:.2f} | Matches: {sorted(match)}")

if judged_precisions:
    print(f"\n Mean Precision@{k} over {len(judged_precisions)} judged cases: {np.mean(judged_precisions):.3f}")

In [None]:
#  Gold standard: (case index, case index, expert similarity score)
# NOTE(review): the provenance of these scores is not shown here; confirm they are real expert
# annotations for these exact cases, since case numbering depends on the dataset filter order.
gold_index_pairs = [(0, 1, 0.80), (2, 4, 0.65), (5, 7, 0.90), (10, 13, 0.70)]
# Skip pairs that point past the loaded subset instead of failing with IndexError.
gold_pairs = [
    (titles[a], titles[b], score)
    for a, b, score in gold_index_pairs
    if max(a, b) < len(titles)
]
true_scores, model_scores = [], []
for d1, d2, expert in gold_pairs:
    i, j = titles.index(d1), titles.index(d2)
    true_scores.append(expert)
    model_scores.append(sim_matrix_avg[i][j])
# pearsonr/spearmanr need at least two points.
if len(true_scores) >= 2:
    p, _ = pearsonr(true_scores, model_scores)
    s, _ = spearmanr(true_scores, model_scores)
    print(f"\n Gold Correlation:\n → Pearson  : {p:.3f}\n → Spearman : {s:.3f}")
else:
    print(f"\n Gold Correlation: skipped ({len(true_scores)} usable gold pair(s); need at least 2)")

In [None]:
#  Top-5 retrieval by hybrid score: for each case, its five highest-scoring other cases
print("\n Top 5 Similar Documents (Hybrid Score):")
for i, query_title in enumerate(titles):
    # Stable sort, so equal scores keep ascending case order.
    ranked = sorted(
        (j for j in range(len(titles)) if j != i),
        key=lambda j: sim_matrix_avg[i][j],
        reverse=True,
    )
    print(f"\n {query_title}")
    for j in ranked[:5]:
        print(f" → {titles[j]} | Score: {sim_matrix_avg[i][j]:.2f}")

In [None]:
#  Search Engine (text or title index)
def search_similar_docs(query_text_or_id, k=5, exclude_self=True):
    """Print the *k* cases whose section-averaged embeddings are most cosine-similar to a query.

    Args:
        query_text_or_id: Either a case index into ``titles`` / ``section_embeddings``
            (Python or NumPy integer; negative indices allowed) or free text to embed with
            ``get_embedding``.
        k: Number of results to print.
        exclude_self: For index queries, leave the query case itself out of the results
            (its self-similarity is 1.0 for any non-zero embedding, so it would always rank first).

    Raises:
        IndexError: if an integer query is out of range.
    """
    if isinstance(query_text_or_id, (int, np.integer)):  # ID-based query
        # Indexing a range normalises negative indices and raises IndexError when out of bounds.
        query_idx = range(len(titles))[int(query_text_or_id)]
        qvec = section_embeddings[query_idx]
        print(f"\n Query Case: {titles[query_idx]}")
    else:  # Free-text query
        query_idx = None
        qvec = get_embedding(query_text_or_id)
        ellipsis = "..." if len(query_text_or_id) > 60 else ""
        print(f"\n Free-text Query: {query_text_or_id[:60]}{ellipsis}")
    sims = cosine_similarity([qvec], section_embeddings)[0]
    candidates = [
        (title, score)
        for idx, (title, score) in enumerate(zip(titles, sims))
        if not (exclude_self and idx == query_idx)
    ]
    top = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]
    for title, score in top:
        print(f" → {title} | Cosine Score: {score:.2f}")
# Demo: one free-text query, then one query by case index.
for demo_query in ("article 10 freedom of expression", 45):
    search_similar_docs(demo_query, k=5)


In [None]:
def _extract_keywords(text):
    """Return the subset of a fixed list of legal theme words occurring (as substrings) in *text*."""
    legal_kw = ["freedom", "expression", "restriction", "right", "speech", "privacy", "press", "torture", "protection", "minority"]
    lowered = text.lower()
    return {kw for kw in legal_kw if kw in lowered}


def _long_paragraphs(text, min_chars=80):
    """Split *text* on blank lines and keep the stripped paragraphs longer than *min_chars*."""
    return [p.strip() for p in re.split(r"\n{2,}", text) if len(p.strip()) > min_chars]


def _top_paragraph_pairs(text1, text2, n=2):
    """Return up to *n* (paragraph of text1, paragraph of text2, cosine score) triples, best first."""
    paras1, paras2 = _long_paragraphs(text1), _long_paragraphs(text2)
    if not paras1 or not paras2:
        return []  # nothing to compare; cosine_similarity rejects empty inputs
    embs1 = np.array([get_embedding(p) for p in paras1])
    embs2 = np.array([get_embedding(p) for p in paras2])
    # One (len(paras1), len(paras2)) matrix instead of a Python loop of 1x1 cosine calls.
    sim_table = cosine_similarity(embs1, embs2)
    pairs = [(i, j, sim_table[i, j]) for i in range(len(paras1)) for j in range(len(paras2))]
    pairs.sort(key=lambda x: x[2], reverse=True)
    return [(paras1[i], paras2[j], score) for i, j, score in pairs[:n]]


def _preview(text, limit=300):
    """Return *text* cut to *limit* characters, appending '...' only when something was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def explain_similarity(query_id, target_id, top_para_count=2):
    """Print why two cases are considered similar.

    Shows the hybrid score from ``sim_matrix_avg``, shared citations, shared theme keywords,
    and the *top_para_count* most similar paragraph pairs (paragraphs over 80 characters,
    compared by cosine similarity of their ``get_embedding`` vectors).

    Args:
        query_id: Index of the first case.
        target_id: Index of the second case.
        top_para_count: Number of paragraph pairs to show.
    """
    print(f"\n Comparing: {titles[query_id]} ↔ {titles[target_id]}")

    #  Score
    sim_score = sim_matrix_avg[query_id][target_id]
    print(f"\n Hybrid Similarity Score: {sim_score:.3f}")

    #  Shared Citations
    shared_cites = sorted(citation_sets[query_id] & citation_sets[target_id])
    print(f"\n Shared Citations ({len(shared_cites)}): {shared_cites if shared_cites else 'None'}")

    #  Common Themes (Keyword Overlap)
    common_kws = sorted(_extract_keywords(filtered_texts[query_id]) & _extract_keywords(filtered_texts[target_id]))
    print(f"\n Common Themes: {common_kws if common_kws else 'None'}")

    #  Similar Paragraphs
    print("\n Most Similar Paragraphs:")
    top_matches = _top_paragraph_pairs(filtered_texts[query_id], filtered_texts[target_id], top_para_count)
    for rank, (p1, p2, score) in enumerate(top_matches, start=1):
        print(f"\n Match {rank} (Score: {score:.2f})")
        print(f" → [{titles[query_id]}]: {_preview(p1)}")
        print(f" → [{titles[target_id]}]: {_preview(p2)}")
explain_similarity(query_id=11, target_id=14)  # demo: explain the similarity between Case 11 and Case 14
