# 02 - Deduplication for Training Data

## Why Deduplication Matters

Duplicate and near-duplicate documents in training data cause real problems:

- **Memorization** -- The model memorizes repeated passages verbatim and
  regurgitates them at inference time, which is especially dangerous for
  legal text that may contain PII or privileged information.
- **Benchmark contamination** -- If evaluation data appears in the training
  set (even slightly modified), benchmark scores are inflated and unreliable.
- **Wasted compute** -- Training on duplicates burns GPU hours without
  teaching the model anything new.

### Exact vs. Fuzzy Deduplication

There are two levels of deduplication:

| Approach | Method | Catches |
|----------|--------|---------|
| **Exact dedup** | Hash each document; remove identical hashes | Byte-for-byte duplicates |
| **Fuzzy (near) dedup** | MinHash + LSH | Documents that differ by minor edits, formatting changes, or OCR variations |

In legal corpora, near-duplicates are common: the same opinion may appear
in multiple databases with slightly different formatting, headers, or
pagination. A court may also issue amended opinions that differ by only a
few sentences from the original.

### MinHash + LSH in Brief

**MinHash** (Min-wise Independent Permutations) is a technique for quickly
estimating the Jaccard similarity between two sets:

1. Represent each document as a set of **shingles** (contiguous word n-grams).
2. Apply multiple hash functions to the shingle set and keep the **minimum**
   hash value for each function. This produces a compact *signature*.
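3. For any single hash function, the probability that two documents share
   the same minimum hash equals their Jaccard similarity, so the fraction
   of matching signature positions estimates that similarity.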
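
The three steps can be sketched in plain Python to see why matching minimums track Jaccard similarity. This is a toy illustration, not how `datasketch` works internally: `seeded_hash`, `toy_signature`, and `estimated_jaccard` are hypothetical helpers, and a seeded SHA-256 stands in for each random hash function.

```python
import hashlib


def shingle_set(text: str, k: int = 3) -> set[str]:
    """Split text into a set of lowercase word k-grams."""
    words = text.lower().split()
    return {" ".join(words[i : i + k]) for i in range(max(1, len(words) - k + 1))}


def seeded_hash(seed: int, item: str) -> int:
    """Hypothetical stand-in for one random hash function, selected by seed."""
    digest = hashlib.sha256(f"{seed}:{item}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")


def toy_signature(items: set[str], num_hashes: int = 256) -> list[int]:
    """Keep the minimum hash value of the set under each hash function."""
    return [min(seeded_hash(seed, item) for item in items) for seed in range(num_hashes)]


def estimated_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """Fraction of signature positions where the two minimums agree."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


doc_a = shingle_set("the court finds that the defendant is liable for damages")
doc_b = shingle_set("the court finds that the defendant is not liable for damages")

# Exact Jaccard on the shingle sets, for comparison with the estimate.
exact = len(doc_a & doc_b) / len(doc_a | doc_b)
estimate = estimated_jaccard(toy_signature(doc_a), toy_signature(doc_b))
print(f"Exact Jaccard:    {exact:.3f}")
print(f"MinHash estimate: {estimate:.3f}")
```

The two sets share 6 of 11 distinct shingles, so the exact value is 6/11. The estimate should land close to it, and the gap shrinks as `num_hashes` grows.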

**LSH** (Locality-Sensitive Hashing) makes the search efficient by splitting
each signature into *bands* and hashing every band into a bucket. Documents
that collide in at least one band become candidate pairs, which are then
verified with a similarity check. Only candidates are compared, so the work
drops from the O(n^2) of all-pairs comparison to near-linear in corpus size
when most documents are dissimilar.
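
The effect of banding can be made concrete with the standard formula for the chance that two documents become a candidate pair. If a signature is split into `b` bands of `r` rows each and the true Jaccard similarity is `s`, that chance is `1 - (1 - s^r)^b`. The sketch below assumes a hypothetical split of 32 bands of 4 rows over a 128-value signature; `datasketch` chooses its own split from the threshold, so these numbers are illustrative only.

```python
def candidate_probability(similarity: float, bands: int, rows: int) -> float:
    """Probability that two documents agree on every row of at least one band."""
    return 1.0 - (1.0 - similarity**rows) ** bands


# Hypothetical banding: 32 bands x 4 rows = 128 signature values.
BANDS, ROWS = 32, 4

for s in (0.2, 0.4, 0.6, 0.8):
    p = candidate_probability(s, BANDS, ROWS)
    print(f"similarity={s:.1f}  P(candidate)={p:.3f}")
```

The curve is S-shaped: dissimilar pairs rarely collide, while highly similar pairs almost always do. More rows per band push the steep part of the curve toward higher similarities.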

## Setup

Import the libraries we need (only `datasketch` has to be installed):
- **datasketch** -- MinHash and LSH implementation
- **hashlib** -- hash-based exact deduplication (Python standard library)

In [None]:
# Install dependencies (uncomment if needed)
# %pip install datasketch

In [None]:
import hashlib
import json
from pathlib import Path
from typing import Any

from datasketch import MinHash, MinHashLSH

In [None]:
DATA_PATH = Path("../../datasets/sample/court_opinions.jsonl")


def load_opinions(path: Path = DATA_PATH) -> list[dict[str, Any]]:
    """Load court opinions from a JSONL file."""
    opinions = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                opinions.append(json.loads(line))
    return opinions


opinions = load_opinions()
print(f"Loaded {len(opinions)} opinions.")

## Exact Deduplication

The simplest form of deduplication: compute a cryptographic hash of each
document's text and remove duplicates. We use SHA-256, which has a
negligible collision probability.

This catches byte-for-byte identical copies but misses documents that
differ even by a single character.
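
One common way to soften that limitation is to normalize the text before hashing, so that copies differing only in whitespace, case, or Unicode form still collide. The sketch below is an optional variant, not part of the pipeline in this notebook. `normalize_text` and `normalized_hash` are hypothetical helpers, and the choice to lowercase is an assumption that may be too aggressive for some corpora.

```python
import hashlib
import re
import unicodedata


def normalize_text(text: str) -> str:
    """Unify Unicode forms, collapse whitespace runs, and lowercase."""
    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()


def normalized_hash(text: str) -> str:
    """SHA-256 of the normalized text (hypothetical helper)."""
    return hashlib.sha256(normalize_text(text).encode("utf-8")).hexdigest()


a = "The Court  finds\nthat the motion is DENIED."
b = "the court finds that the motion is denied. "

raw_equal = hashlib.sha256(a.encode("utf-8")).digest() == hashlib.sha256(b.encode("utf-8")).digest()
print(f"Raw hashes equal:        {raw_equal}")
print(f"Normalized hashes equal: {normalized_hash(a) == normalized_hash(b)}")
```

Normalization only helps with formatting noise. Typos, OCR errors, and amended sentences still need the fuzzy approach.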

In [None]:
def compute_hash(text: str) -> str:
    """Compute a SHA-256 hash of the text."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def exact_dedup(
    docs: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Remove exact duplicates based on text hash.

    Returns (unique_docs, duplicate_docs).
    """
    seen_hashes: set[str] = set()
    unique = []
    duplicates = []

    for doc in docs:
        h = compute_hash(doc["text"])
        if h in seen_hashes:
            duplicates.append(doc)
        else:
            seen_hashes.add(h)
            unique.append(doc)

    return unique, duplicates

In [None]:
# Run exact dedup on the original data
unique, duplicates = exact_dedup(opinions)
print(f"Original documents:  {len(opinions)}")
print(f"Unique documents:    {len(unique)}")
print(f"Exact duplicates:    {len(duplicates)}")

In [None]:
# Demonstrate exact dedup by adding a known duplicate
opinions_with_dup = opinions + [opinions[0].copy()]  # duplicate first opinion
print(f"Documents with added duplicate: {len(opinions_with_dup)}")

unique2, dups2 = exact_dedup(opinions_with_dup)
print(f"After exact dedup:             {len(unique2)}")
print(f"Duplicates found:              {len(dups2)}")
if dups2:
    print(f"Duplicate case: {dups2[0]['case_name']}")

## Near-Duplicate Detection with MinHash

Exact dedup misses near-duplicates -- documents that are almost identical
but differ in minor ways (formatting, typos, amended text). MinHash + LSH
solves this efficiently.

### Step-by-Step

1. **Shingling** -- Convert each document into a set of overlapping word
   n-grams ("shingles"). For example, with 3-word shingles:
   
   ```
   "the court finds that" -> {"the court finds", "court finds that"}
   ```

2. **MinHash signature** -- For each shingle set, compute a fixed-size
   signature using `num_perm` random hash permutations. Each element of the
   signature is the minimum hash value across all shingles for that
   permutation.

3. **LSH indexing** -- Insert signatures into an LSH index that groups
   similar documents into the same buckets.

4. **Query** -- For each document, query the LSH index for candidates.
   Candidates sharing a bucket are likely near-duplicates.

In [None]:
def create_shingles(text: str, k: int = 5) -> set[str]:
    """Create a set of word k-shingles from text.

    Args:
        text: Input text.
        k: Number of words per shingle.

    Returns:
        Set of shingle strings.
    """
    words = text.lower().split()
    if len(words) < k:
        return {" ".join(words)} if words else set()
    return {" ".join(words[i : i + k]) for i in range(len(words) - k + 1)}


# Demonstrate shingling
sample_text = "The court finds that the defendant is liable."
shingles = create_shingles(sample_text, k=3)
print(f"Text: {sample_text!r}")
print(f"3-word shingles ({len(shingles)}):")
for s in sorted(shingles):
    print(f"  {s!r}")

In [None]:
def build_minhash(shingles: set[str], num_perm: int = 128) -> MinHash:
    """Build a MinHash signature from a set of shingles."""
    m = MinHash(num_perm=num_perm)
    for s in shingles:
        m.update(s.encode("utf-8"))
    return m


def compute_minhash_signatures(
    docs: list[dict[str, Any]],
    shingle_k: int = 5,
    num_perm: int = 128,
) -> list[tuple[dict[str, Any], MinHash]]:
    """Compute MinHash signatures for all documents."""
    results = []
    for doc in docs:
        shingles = create_shingles(doc["text"], k=shingle_k)
        mh = build_minhash(shingles, num_perm=num_perm)
        results.append((doc, mh))
    return results


# Compute signatures for our corpus
signatures = compute_minhash_signatures(opinions)
print(f"Computed MinHash signatures for {len(signatures)} documents.")
print(f"Signature size: {signatures[0][1].num_perm} permutations")

In [None]:
# Show pairwise Jaccard similarity estimates between all documents
print("Estimated Jaccard similarities (from MinHash):")
print()

# Header
short_names = [doc["case_name"].split(" v.")[0].split("v.")[0].strip()[:15]
               for doc, _ in signatures]
print(f"{'':>16s}", end="")
for name in short_names:
    print(f"{name:>16s}", end="")
print()

for i, (doc_i, mh_i) in enumerate(signatures):
    print(f"{short_names[i]:>16s}", end="")
    for j, (doc_j, mh_j) in enumerate(signatures):
        sim = mh_i.jaccard(mh_j)
        print(f"{sim:>16.3f}", end="")
    print()

In [None]:
def find_near_duplicates(
    docs: list[dict[str, Any]],
    threshold: float = 0.5,
    num_perm: int = 128,
    shingle_k: int = 5,
) -> list[tuple[int, int, float]]:
    """Find near-duplicate pairs using MinHash + LSH.

    Args:
        docs: List of document dictionaries with 'text' field.
        threshold: Jaccard similarity threshold for near-duplicates.
        num_perm: Number of hash permutations for MinHash.
        shingle_k: Number of words per shingle.

    Returns:
        List of (idx_a, idx_b, similarity) tuples.
    """
    # Build MinHash signatures
    sigs = compute_minhash_signatures(docs, shingle_k=shingle_k, num_perm=num_perm)

    # Create LSH index
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    for i, (doc, mh) in enumerate(sigs):
        lsh.insert(str(i), mh)

    # Query for each document. LSH only proposes candidates, so each one is
    # verified against the threshold using the MinHash similarity estimate.
    duplicate_pairs = []
    seen_pairs: set[tuple[int, int]] = set()

    for i, (doc, mh) in enumerate(sigs):
        candidates = lsh.query(mh)
        for cand in candidates:
            j = int(cand)
            if i >= j:
                continue  # skip self-matches and pairs handled from the other side
            pair = (i, j)
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            sim = sigs[i][1].jaccard(sigs[j][1])
            if sim >= threshold:
                duplicate_pairs.append((i, j, sim))

    return sorted(duplicate_pairs, key=lambda x: -x[2])


# Run near-dedup on the original corpus
near_dups = find_near_duplicates(opinions, threshold=0.5)
print(f"Near-duplicate pairs found (threshold=0.5): {len(near_dups)}")
for i, j, sim in near_dups:
    print(f"  [{i}] {opinions[i]['case_name'][:40]}")
    print(f"  [{j}] {opinions[j]['case_name'][:40]}")
    print(f"  Similarity: {sim:.3f}")
    print()

## Demonstration: Catching Synthetic Near-Duplicates

To show that MinHash works for fuzzy matching, we create synthetic
near-duplicates by taking an existing opinion and replacing a small random
fraction of its words (5% by default) with common legal filler words.

This simulates what happens in practice when the same opinion appears
in multiple databases with different OCR outputs or editorial formatting.
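
It helps to estimate in advance how much Jaccard similarity a given edit fraction should leave. The following is a back-of-the-envelope sketch under simplifying assumptions: edits hit words independently, all shingles in a document are distinct, and a replaced word never matches the original by accident. A k-word shingle then survives with probability `(1 - p)^k`, and two equally sized shingle sets that share a fraction `q` of shingles have Jaccard similarity `q / (2 - q)`. `expected_jaccard` is a hypothetical helper.

```python
def expected_jaccard(edit_fraction: float, shingle_k: int = 5) -> float:
    """Approximate Jaccard similarity after replacing a fraction of words."""
    # Chance that none of the k words in a shingle were replaced.
    surviving = (1.0 - edit_fraction) ** shingle_k
    # Shared shingles over the union of two equally sized shingle sets.
    return surviving / (2.0 - surviving)


for p in (0.02, 0.05, 0.10, 0.20, 0.30):
    print(f"edit_fraction={p:.2f}  expected Jaccard ~ {expected_jaccard(p):.3f}")
```

With 5-word shingles, a 5% edit leaves an expected similarity of about 0.63, while a 10% edit already drops it to about 0.42. Under these assumptions, a threshold of 0.5 should catch the former and often miss the latter.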

In [None]:
import random

random.seed(42)


def create_near_duplicate(text: str, edit_fraction: float = 0.05) -> str:
    """Create a near-duplicate by randomly replacing words.

    Args:
        text: Original text.
        edit_fraction: Fraction of words to replace.

    Returns:
        Modified text.
    """
    words = text.split()
    n_edits = max(1, int(len(words) * edit_fraction))
    indices = random.sample(range(len(words)), min(n_edits, len(words)))

    replacement_words = [
        "the", "court", "hereby", "finds", "that", "pursuant",
        "said", "aforementioned", "respectively", "therein",
    ]

    for idx in indices:
        words[idx] = random.choice(replacement_words)

    return " ".join(words)


# Create near-duplicates of the first two opinions
synthetic_docs = []
for op in opinions[:2]:
    near_dup = op.copy()
    near_dup["text"] = create_near_duplicate(op["text"], edit_fraction=0.05)
    near_dup["case_name"] = op["case_name"] + " [NEAR-DUP]"
    near_dup["id"] = op["id"] + 9000
    synthetic_docs.append(near_dup)

# Combine original + synthetic
augmented_corpus = opinions + synthetic_docs
print(f"Augmented corpus: {len(augmented_corpus)} documents")
print(f"  Original:  {len(opinions)}")
print(f"  Synthetic: {len(synthetic_docs)}")

In [None]:
# First: exact dedup will NOT catch the near-duplicates
unique_exact, dups_exact = exact_dedup(augmented_corpus)
print("Exact dedup results:")
print(f"  Unique:     {len(unique_exact)}")
print(f"  Duplicates: {len(dups_exact)}")
print("  (Exact dedup misses the near-duplicates, as expected.)")

In [None]:
# Now: MinHash + LSH catches them
near_dups_aug = find_near_duplicates(augmented_corpus, threshold=0.5)
print(f"Near-duplicate pairs found: {len(near_dups_aug)}")
print()
for i, j, sim in near_dups_aug:
    name_i = augmented_corpus[i]["case_name"][:50]
    name_j = augmented_corpus[j]["case_name"][:50]
    print(f"  Pair (similarity={sim:.3f}):")
    print(f"    [{i}] {name_i}")
    print(f"    [{j}] {name_j}")
    print()

In [None]:
def deduplicate_corpus(
    docs: list[dict[str, Any]],
    threshold: float = 0.5,
    num_perm: int = 128,
) -> tuple[list[dict[str, Any]], list[list[int]]]:
    """Remove near-duplicates from a corpus.

    For each cluster of near-duplicates, keeps the first document
    (by position in the input list).

    Returns:
        (deduplicated_docs, clusters) where clusters is a list of
        duplicate groups (lists of indices).
    """
    pairs = find_near_duplicates(
        docs, threshold=threshold, num_perm=num_perm
    )

    # Build clusters using union-find
    parent: dict[int, int] = {}

    def find(x: int) -> int:
        while parent.get(x, x) != x:
            parent[x] = parent.get(parent[x], parent[x])
            x = parent[x]
        return x

    def union(a: int, b: int) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            # Keep the lower index as root (so we keep the earlier document)
            if ra < rb:
                parent[rb] = ra
            else:
                parent[ra] = rb

    for i, j, _ in pairs:
        union(i, j)

    # Group by cluster root
    clusters_map: dict[int, list[int]] = {}
    all_indices = set()
    for i, j, _ in pairs:
        all_indices.add(i)
        all_indices.add(j)

    for idx in all_indices:
        root = find(idx)
        clusters_map.setdefault(root, []).append(idx)

    # For each cluster, mark all but the root for removal
    to_remove: set[int] = set()
    clusters = []
    for root, members in clusters_map.items():
        members = sorted(set(members))
        clusters.append(members)
        for idx in members:
            if idx != root:
                to_remove.add(idx)

    deduped = [doc for i, doc in enumerate(docs) if i not in to_remove]
    return deduped, clusters


# Deduplicate the augmented corpus
deduped, clusters = deduplicate_corpus(augmented_corpus, threshold=0.5)
print(f"Before dedup: {len(augmented_corpus)} documents")
print(f"After dedup:  {len(deduped)} documents")
print(f"Removed:      {len(augmented_corpus) - len(deduped)} near-duplicates")
print(f"\nDuplicate clusters found: {len(clusters)}")
for i, cluster in enumerate(clusters):
    print(f"  Cluster {i + 1}: indices {cluster}")
    for idx in cluster:
        print(f"    {augmented_corpus[idx]['case_name'][:60]}")

## Parameter Tuning

MinHash + LSH has two key parameters:

- **`num_perm`** -- Number of hash permutations. More permutations give a
  more accurate Jaccard estimate but use more memory and time.
- **`threshold`** -- Jaccard similarity threshold for considering two
  documents as near-duplicates. Lower threshold = more aggressive dedup
  (catches more distant duplicates but risks false positives).

Let's experiment with both parameters to understand the trade-offs.

In [None]:
# Create near-duplicates at varying edit fractions (share of words replaced)
edit_levels = [0.02, 0.05, 0.10, 0.20, 0.30]
test_docs = list(opinions)  # start with originals

for frac in edit_levels:
    dup = opinions[0].copy()
    dup["text"] = create_near_duplicate(opinions[0]["text"], edit_fraction=frac)
    dup["case_name"] = f"{opinions[0]['case_name']} [edit={frac:.0%}]"
    dup["id"] = 9000 + int(frac * 100)
    test_docs.append(dup)

print(f"Test corpus: {len(test_docs)} documents")
print(f"  Original opinions:    {len(opinions)}")
print(f"  Synthetic variants:   {len(edit_levels)}")
print(f"  Edit levels tested:   {edit_levels}")

In [None]:
# Test different threshold values
thresholds = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

print(f"{'Threshold':>10s}  {'Pairs Found':>12s}  {'Docs Removed':>13s}")
print("-" * 40)

for thresh in thresholds:
    pairs = find_near_duplicates(test_docs, threshold=thresh, num_perm=128)
    deduped_t, _ = deduplicate_corpus(test_docs, threshold=thresh, num_perm=128)
    removed = len(test_docs) - len(deduped_t)
    print(f"{thresh:>10.1f}  {len(pairs):>12d}  {removed:>13d}")

In [None]:
# Test different num_perm values at a fixed threshold
perm_values = [32, 64, 128, 256]
fixed_threshold = 0.5

print(f"Threshold fixed at {fixed_threshold}")
print(f"{'num_perm':>10s}  {'Pairs Found':>12s}  {'Docs Removed':>13s}")
print("-" * 40)

# Use n_perm rather than np to avoid shadowing the usual NumPy alias.
for n_perm in perm_values:
    pairs = find_near_duplicates(
        test_docs, threshold=fixed_threshold, num_perm=n_perm
    )
    deduped_p, _ = deduplicate_corpus(
        test_docs, threshold=fixed_threshold, num_perm=n_perm
    )
    removed = len(test_docs) - len(deduped_p)
    print(f"{n_perm:>10d}  {len(pairs):>12d}  {removed:>13d}")

### Interpreting the Results

- **Lower thresholds** catch more near-duplicates (including documents with
  larger edit distances) but increase the risk of false positives -- removing
  documents that are genuinely different but happen to share legal terminology.
- **Higher `num_perm`** gives more precise similarity estimates. With too few
  permutations, the Jaccard estimate is noisy, so pairs whose true similarity
  sits near the threshold may be detected or missed by chance.
- For legal text, a threshold of **0.5-0.7** and **128 permutations** is a
  reasonable starting point. Legal documents naturally share vocabulary
  ("the court finds", "summary judgment"), so overly aggressive dedup can
  remove distinct opinions that happen to discuss similar legal standards.
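
The noise in the estimate can be quantified. Treating each signature position as an independent match-or-no-match trial (an idealization that assumes well-behaved hash functions), the standard error of the MinHash estimate is `sqrt(J * (1 - J) / num_perm)`. `minhash_std_error` below is a hypothetical helper for that formula.

```python
import math


def minhash_std_error(jaccard: float, num_perm: int) -> float:
    """Standard error of a MinHash estimate under the independence assumption."""
    return math.sqrt(jaccard * (1.0 - jaccard) / num_perm)


# The error is largest at J = 0.5, so this is the worst case per num_perm.
for n_perm in (32, 64, 128, 256):
    print(f"num_perm={n_perm:>3d}  std error at J=0.5: {minhash_std_error(0.5, n_perm):.4f}")
```

Quadrupling `num_perm` only halves the error, so going from 128 to 256 permutations buys less than going from 32 to 64.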

## Exercises

### Exercise (a): Precision/Recall of Dedup

Using the synthetic near-duplicates as ground truth:

1. Create a test set with known duplicate pairs at various edit distances
   (5%, 10%, 20%, 30%, 50% of words replaced).
2. For each combination of `num_perm` (64, 128, 256) and `threshold`
   (0.3, 0.5, 0.7, 0.9), compute:
   - **Precision**: What fraction of detected pairs are actual duplicates?
   - **Recall**: What fraction of actual duplicates are detected?
3. Plot precision vs. recall curves for different parameter settings.
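
As a starting point for step 2, pair-level precision and recall can be computed by comparing sets of index pairs. This is one possible sketch, not a reference solution. `precision_recall` is a hypothetical helper, and returning 1.0 when a set is empty is a convention chosen here for convenience.

```python
from collections.abc import Iterable

Pair = tuple[int, int]


def precision_recall(predicted: Iterable[Pair], truth: Iterable[Pair]) -> tuple[float, float]:
    """Pair-level precision and recall, ignoring the order within each pair."""
    pred = {tuple(sorted(p)) for p in predicted}
    true = {tuple(sorted(t)) for t in truth}
    hits = len(pred & true)
    precision = hits / len(pred) if pred else 1.0
    recall = hits / len(true) if true else 1.0
    return precision, recall


# Toy data: three known duplicate pairs, three detected pairs, two in common.
truth_pairs = [(0, 5), (1, 6), (0, 7)]
predicted_pairs = [(5, 0), (1, 6), (2, 3)]

precision, recall = precision_recall(predicted_pairs, truth_pairs)
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
```

The `(i, j, similarity)` tuples returned by `find_near_duplicates` can be fed in after dropping the similarity value.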

### Exercise (b): Dedup Ratio by Document Type

Legal corpora contain different types of documents with different
duplication characteristics:

- **Appellate opinions** -- Often unique, but amended opinions create
  near-duplicates.
- **Orders** -- Short, formulaic documents that may look very similar
  across cases.
- **Motions** -- Often use template language, leading to high false-positive
  rates in dedup.

Using the `court` field in our sample data:
1. Group documents by court.
2. Run dedup within each group and across groups.
3. Estimate which court or document type has the highest natural duplication
   rate.
4. Discuss: should dedup thresholds vary by document type?