<a href="https://colab.research.google.com/github/Hitika-Jain/LegalTalk/blob/main/notebooks/data_transformation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/Hitika-Jain/LegalTalk.git

%cd LegalTalk
#

Cloning into 'LegalTalk'...
remote: Enumerating objects: 77, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (68/68), done.[K
remote: Total 77 (delta 30), reused 24 (delta 5), pack-reused 0 (from 0)[K
Receiving objects: 100% (77/77), 473.01 KiB | 4.93 MiB/s, done.
Resolving deltas: 100% (30/30), done.
/content/LegalTalk


In [3]:
# scripts/convert_labels_to_canonical.py
import pandas as pd
import ast
import re
from pathlib import Path
from collections import Counter

# ========== CONFIG ==========
STATUTES_CSV = "/content/drive/MyDrive/legal_dataset/statutes-00000-of-00001.csv"   # contains canonical IDs like IPC_302 in a column e.g. 'id' or 'statute_id'
TRAIN_CSV = "/content/drive/MyDrive/legal_dataset/train-00000-of-00001.csv"
DEV_CSV = "/content/drive/MyDrive/legal_dataset/dev-00000-of-00001 (1).csv"
TEST_CSV = "/content/drive/MyDrive/legal_dataset/test-00000-of-00001.csv"  # set path if you have a test file
OUT_DIR = Path("/content/converted_labels")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# If your filesystem/statutes index is 1-based (labels start at 1 -> first row), set to 1
# If unsure, leave as None to auto-detect
ASSUME_INDEX_BASE = None  # set to 0 or 1 to force
# ============================

def read_statutes(path):
    df = pd.read_csv(path, dtype=str).fillna('')
    # pick canonical id column heuristically
    for cand in ('id','statute_id','canonical_id','code','label'):
        if cand in df.columns:
            return df, cand
    # otherwise return first column as canonical id column
    return df, df.columns[0]

def parse_label_cell(cell):
    """Return list of tokens (strings) from many label formats."""
    if pd.isna(cell) or str(cell).strip() == "":
        return []
    s = str(cell).strip()
    # try python literal
    if (s.startswith("[") and s.endswith("]")) or (s.startswith("(") and s.endswith(")")):
        try:
            parsed = ast.literal_eval(s)
            if isinstance(parsed, (list, tuple)):
                return [str(x).strip() for x in parsed if str(x).strip()!='']
        except Exception:
            pass
    # if contains semicolon or comma separate
    if ";" in s:
        toks = [t.strip().strip("'\"") for t in s.split(";") if t.strip()]
        return toks
    if "," in s and not re.fullmatch(r'^\[\s*\d+(?:\s+\d+)*\s*\]$', s):  # avoid "[69  9  3]" false comma
        toks = [t.strip().strip("'\"") for t in s.split(",") if t.strip()]
        return toks
    # fallback: extract tokens (either words like IPC_302 or numbers)
    # treat sequences like "69  9  3" -> ['69','9','3']
    splitted = re.split(r'\s+', re.sub(r'[\[\]\(\)\,;]+',' ', s)).strip() if False else re.split(r'\s+', re.sub(r'[\[\]\(\)\,;]+',' ', s))
    toks = [t.strip().strip("'\"") for t in splitted if t.strip()]
    return toks

def build_mappings(stat_df, id_col, index_base_guess=0):
    """Return index->ipc mapping and digit->ipc mapping (from numeric suffix)."""
    index_to_ipc = {}
    digit_to_ipc = {}
    for idx, row in stat_df.reset_index().iterrows():
        ipc = str(row[id_col]).strip()
        # Map index token according to index_base_guess (0 or 1)
        index_to_ipc[idx + index_base_guess] = ipc
        # map numeric suffix, e.g., IPC_302 -> '302'
        m = re.search(r'(\d+)', ipc)
        if m:
            digit_to_ipc[m.group(1)] = ipc
        digit_to_ipc[ipc] = ipc
    return index_to_ipc, digit_to_ipc

def try_detect_index_base(sample_labels, index_to_ipc_len):
    """
    Heuristics:
     - If any numeric token > len(index_to_ipc) and <= len+1 then probably 1-based.
     - If all numeric tokens <= len(index_to_ipc) maybe 0-based.
    Returns 0 or 1 (best guess).
    """
    nums = []
    for lab in sample_labels:
        for tok in lab:
            if tok.isdigit():
                nums.append(int(tok))
    if not nums:
        return 0  # default 0 if no numeric tokens found

    max_num = max(nums)
    # if max exceeds length and equals length -> likely 1-based
    if max_num > index_to_ipc_len and max_num <= index_to_ipc_len + 1:
        return 1
    # if smallest numeric token is 1 and many are >=1, could be 1-based
    if min(nums) >= 1 and max_num <= index_to_ipc_len:
        return 1
    return 0

def convert_df_labels(df, index_to_ipc, digit_to_ipc, index_base):
    out_rows = []
    unmapped = Counter()
    for _, r in df.iterrows():
        text_col = None
        for c in ['text','sentence_text','document_text','content']:
            if c in df.columns:
                text_col = c
                break
        if text_col is None:
            # fallback: pick first non-id column
            cols = list(df.columns)
            if len(cols) >= 2:
                text_col = cols[1]
            else:
                raise ValueError("No text column found.")
        text = r[text_col]
        raw_labels = r.get('labels') or r.get('label') or r.get('labels_raw') or ''
        toks = parse_label_cell(raw_labels)

        mapped = []
        for t in toks:
            t_str = str(t).strip()
            if t_str == '':
                continue
            # numeric?
            if re.fullmatch(r'\d+', t_str):
                n = int(t_str)
                # try digit->ipc mapping first (if legends like '302' -> IPC_302)
                ipc_from_digit = digit_to_ipc.get(str(n))
                if ipc_from_digit:
                    mapped.append(ipc_from_digit)
                    continue
                # else use index mapping (apply index_base)
                key = n - index_base
                ipc_from_index = index_to_ipc.get(key)
                if ipc_from_index:
                    mapped.append(ipc_from_index)
                    continue
                # fallback: try direct index if present
                ipc_from_index2 = index_to_ipc.get(n)
                if ipc_from_index2:
                    mapped.append(ipc_from_index2)
                    continue
                unmapped[t_str] += 1
                mapped.append(t_str)  # keep raw for inspection
            else:
                # maybe already canonical like IPC_302
                # try digit extraction
                m = re.search(r'(\d+)', t_str)
                if t_str in digit_to_ipc:
                    mapped.append(digit_to_ipc[t_str])
                elif m and m.group(1) in digit_to_ipc:
                    mapped.append(digit_to_ipc[m.group(1)])
                else:
                    # unknown token - keep as-is but count
                    unmapped[t_str] += 1
                    mapped.append(t_str)
        # unique preserve order
        seen = set()
        final = []
        for x in mapped:
            if x not in seen:
                final.append(x); seen.add(x)
        out_rows.append({'text': text, 'labels': ';'.join(final)})
    return pd.DataFrame(out_rows), unmapped

def process_file(in_path, out_path, index_to_ipc, digit_to_ipc, index_base):
    df = pd.read_csv(in_path, dtype=str).fillna('')
    conv_df, unmapped = convert_df_labels(df, index_to_ipc, digit_to_ipc, index_base)
    conv_df.to_csv(out_path, index=False)
    return conv_df, unmapped

# ========== run ==========
stat_df, id_col = read_statutes(STATUTES_CSV)
# Build index mapping using both possible index bases to inspect
index_to_ipc_0, digit_to_ipc_0 = build_mappings(stat_df, id_col, index_base_guess=0)
index_to_ipc_1, digit_to_ipc_1 = build_mappings(stat_df, id_col, index_base_guess=1)

# Take a small sample from train to detect base, else use ASSUME_INDEX_BASE
sample_df = pd.read_csv(TRAIN_CSV, dtype=str, nrows=200).fillna('')
sample_labels = [parse_label_cell(x) for x in sample_df.get('labels', sample_df.columns[-1])]
if ASSUME_INDEX_BASE is None:
    guess = try_detect_index_base(sample_labels, len(index_to_ipc_0))
    print("Auto-detected index_base:", guess)
    index_base_used = guess
else:
    index_base_used = ASSUME_INDEX_BASE
    print("Forced index_base:", index_base_used)

index_to_ipc = index_to_ipc_0 if index_base_used == 0 else index_to_ipc_1
digit_to_ipc = digit_to_ipc_0 if index_base_used == 0 else digit_to_ipc_1

# Process train/dev/test
print("Processing train ->", OUT_DIR / 'train_converted.csv')
train_conv, train_unmapped = process_file(TRAIN_CSV, OUT_DIR / 'train_converted.csv', index_to_ipc, digit_to_ipc, index_base_used)
print("Unmapped tokens in train (sample):", dict(list(train_unmapped.items())[:20]))

if DEV_CSV:
    print("Processing dev ->", OUT_DIR / 'dev_converted.csv')
    dev_conv, dev_unmapped = process_file(DEV_CSV, OUT_DIR / 'dev_converted.csv', index_to_ipc, digit_to_ipc, index_base_used)
    print("Unmapped tokens in dev (sample):", dict(list(dev_unmapped.items())[:20]))

if TEST_CSV:
    print("Processing test ->", OUT_DIR / 'test_converted.csv')
    test_conv, test_unmapped = process_file(TEST_CSV, OUT_DIR / 'test_converted.csv', index_to_ipc, digit_to_ipc, index_base_used)
    print("Unmapped tokens in test (sample):", dict(list(test_unmapped.items())[:20]))

# Quick label distribution check
all_labels = ';'.join(train_conv['labels'].astype(str)).split(';')
print("Top labels in train:")
from collections import Counter
print(Counter([l for l in all_labels if l]).most_common(30))
print("Done. Converted CSVs are in:", OUT_DIR)

Auto-detected index_base: 1
Processing train -> /content/converted_labels/train_converted.csv
Unmapped tokens in train (sample): {'0': 343}
Processing dev -> /content/converted_labels/dev_converted.csv
Unmapped tokens in dev (sample): {'0': 86}
Processing test -> /content/converted_labels/test_converted.csv
Unmapped tokens in test (sample): {'0': 117}
Top labels in train:
[('Section 5', 11730), ('Section 34', 11151), ('Section 500', 7880), ('Section 313', 7728), ('Section 304B', 7120), ('Section 13', 4805), ('Section 120B', 4324), ('Section 114', 4134), ('Section 417', 4046), ('Section 494', 3835), ('Section 147', 3818), ('Section 366A', 3723), ('Section 320', 3660), ('Section 337', 3184), ('Section 300', 3157), ('Section 229A', 3009), ('Section 323', 2618), ('Section 324', 2509), ('Section 395', 2303), ('Section 342', 2204), ('Section 465', 2113), ('Section 467', 2066), ('Section 498A', 1973), ('Section 193', 1909), ('Section 353', 1787), ('Section 419', 1739), ('Section 304A', 1713),

In [4]:
# convert_labels_for_training.py
import pandas as pd
from pathlib import Path
from src.utils import normalize_to_ipc   # use your utils if present
import re

IN_CSV = "/content/converted_labels/train_converted.csv"   # your converted file (text, labels)
OUT_CSV = "/content/converted_labels/train_canonical.csv"

df = pd.read_csv(IN_CSV, dtype=str).fillna('')
# preserve raw
df['labels_raw'] = df['labels'].astype(str)

# helper to split semicolon separated raw labels into tokens
def split_raw_labels(s):
    if pd.isna(s) or str(s).strip()=='':
        return []
    toks = [t.strip() for t in re.split(r'[;,\|]', str(s)) if t.strip()]
    return toks

# map tokens -> canonical using normalize_to_ipc
def map_tokens_to_canonical(tokens):
    mapped = []
    unmapped = []
    for t in tokens:
        # first try if token already looks canonical (IPC_)
        if isinstance(t, str) and t.upper().startswith('IPC_'):
            mapped.append(t.upper())
            continue
        # use normalize_to_ipc which converts things like 'Section 120-B' or '120' -> 'IPC_120B' etc
        cand = normalize_to_ipc(t)
        if cand:
            mapped.append(cand)
        else:
            unmapped.append(t)
    # unique preserve order
    seen = set(); out=[]
    for m in mapped:
        if m not in seen: out.append(m); seen.add(m)
    return out, unmapped

all_unmapped = []
labels_canonical = []
for _, row in df.iterrows():
    toks = split_raw_labels(row['labels_raw'])
    mapped, unmapped = map_tokens_to_canonical(toks)
    labels_canonical.append(';'.join(mapped))
    if unmapped:
        all_unmapped.extend(unmapped)

df['labels_canonical'] = labels_canonical

# save
Path(OUT_CSV).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_CSV, index=False)
print("Wrote:", OUT_CSV)
print("Total rows:", len(df))
print("Unique unmapped tokens (sample up to 50):", list(dict.fromkeys(all_unmapped))[:50])

Wrote: /content/converted_labels/train_canonical.csv
Total rows: 42750
Unique unmapped tokens (sample up to 50): []


In [5]:
# convert_labels_for_training.py
import pandas as pd
from pathlib import Path
from src.utils import normalize_to_ipc   # use your utils if present
import re

IN_CSV = "/content/converted_labels/test_converted.csv"   # your converted file (text, labels)
OUT_CSV = "/content/converted_labels/test_canonical.csv"

df = pd.read_csv(IN_CSV, dtype=str).fillna('')
# preserve raw
df['labels_raw'] = df['labels'].astype(str)

# helper to split semicolon separated raw labels into tokens
def split_raw_labels(s):
    if pd.isna(s) or str(s).strip()=='':
        return []
    toks = [t.strip() for t in re.split(r'[;,\|]', str(s)) if t.strip()]
    return toks

# map tokens -> canonical using normalize_to_ipc
def map_tokens_to_canonical(tokens):
    mapped = []
    unmapped = []
    for t in tokens:
        # first try if token already looks canonical (IPC_)
        if isinstance(t, str) and t.upper().startswith('IPC_'):
            mapped.append(t.upper())
            continue
        # use normalize_to_ipc which converts things like 'Section 120-B' or '120' -> 'IPC_120B' etc
        cand = normalize_to_ipc(t)
        if cand:
            mapped.append(cand)
        else:
            unmapped.append(t)
    # unique preserve order
    seen = set(); out=[]
    for m in mapped:
        if m not in seen: out.append(m); seen.add(m)
    return out, unmapped

all_unmapped = []
labels_canonical = []
for _, row in df.iterrows():
    toks = split_raw_labels(row['labels_raw'])
    mapped, unmapped = map_tokens_to_canonical(toks)
    labels_canonical.append(';'.join(mapped))
    if unmapped:
        all_unmapped.extend(unmapped)

df['labels_canonical'] = labels_canonical

# save
Path(OUT_CSV).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_CSV, index=False)
print("Wrote:", OUT_CSV)
print("Total rows:", len(df))
print("Unique unmapped tokens (sample up to 50):", list(dict.fromkeys(all_unmapped))[:50])

Wrote: /content/converted_labels/test_canonical.csv
Total rows: 13019
Unique unmapped tokens (sample up to 50): []


In [6]:
# convert_labels_for_training.py
import pandas as pd
from pathlib import Path
from src.utils import normalize_to_ipc   # use your utils if present
import re

IN_CSV = "/content/converted_labels/dev_converted.csv"   # your converted file (text, labels)
OUT_CSV = "/content/converted_labels/dev_canonical.csv"

df = pd.read_csv(IN_CSV, dtype=str).fillna('')
# preserve raw
df['labels_raw'] = df['labels'].astype(str)

# helper to split semicolon separated raw labels into tokens
def split_raw_labels(s):
    if pd.isna(s) or str(s).strip()=='':
        return []
    toks = [t.strip() for t in re.split(r'[;,\|]', str(s)) if t.strip()]
    return toks

# map tokens -> canonical using normalize_to_ipc
def map_tokens_to_canonical(tokens):
    mapped = []
    unmapped = []
    for t in tokens:
        # first try if token already looks canonical (IPC_)
        if isinstance(t, str) and t.upper().startswith('IPC_'):
            mapped.append(t.upper())
            continue
        # use normalize_to_ipc which converts things like 'Section 120-B' or '120' -> 'IPC_120B' etc
        cand = normalize_to_ipc(t)
        if cand:
            mapped.append(cand)
        else:
            unmapped.append(t)
    # unique preserve order
    seen = set(); out=[]
    for m in mapped:
        if m not in seen: out.append(m); seen.add(m)
    return out, unmapped

all_unmapped = []
labels_canonical = []
for _, row in df.iterrows():
    toks = split_raw_labels(row['labels_raw'])
    mapped, unmapped = map_tokens_to_canonical(toks)
    labels_canonical.append(';'.join(mapped))
    if unmapped:
        all_unmapped.extend(unmapped)

df['labels_canonical'] = labels_canonical

# save
Path(OUT_CSV).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_CSV, index=False)
print("Wrote:", OUT_CSV)
print("Total rows:", len(df))
print("Unique unmapped tokens (sample up to 50):", list(dict.fromkeys(all_unmapped))[:50])

Wrote: /content/converted_labels/dev_canonical.csv
Total rows: 10181
Unique unmapped tokens (sample up to 50): []
