# ASSIGNMENT NO : 02

*   Nooran Ishtiaq
*   22i-2010
*   DS-B



In [None]:
import zipfile
import os

# Where the uploaded archive lives and where its CSV files are unpacked.
zip_file_path = '/content/archive.zip'
extraction_path = '/content/csv/'

# Make sure the destination exists (no error if it is already there).
os.makedirs(extraction_path, exist_ok=True)
with zipfile.ZipFile(zip_file_path, mode='r') as archive:
    archive.extractall(path=extraction_path)
print(f'Archive extracted to: {extraction_path}')

# Echo the top-level entries so the extraction can be confirmed by eye.
print('\nContents of extracted directory:')
for entry in os.listdir(extraction_path):
    print(entry)

# Baseline 1 :  BiLSTMEncoder

In [None]:
import argparse
import math
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import time
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
)

# ---- Standalone utilities (inlined from baseline1) ----


@dataclass
class Clause:
    """One clause row read from a CSV file."""

    # Clause body; the loader in this file strips surrounding whitespace.
    text: str
    # Clause category label; the loader stores "" when the CSV cell is missing.
    clause_type: str
    # File name (not full path) of the CSV the row was read from.
    source_file: str


def set_global_seeds(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) generators."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_clauses_from_folder(folder: Path) -> List[Clause]:
    """Load every clause found in the CSV files under ``folder``.

    CSV files are discovered recursively and processed in sorted path order.
    A file that cannot be parsed, or that lacks a ``clause_text`` or
    ``clause_type`` column, is skipped with a printed warning.

    Args:
        folder: Directory searched (recursively) for ``*.csv`` files.

    Returns:
        One ``Clause`` per row whose text is non-empty after stripping.
        Missing (NaN) cells are read as empty strings, so a returned clause
        may carry an empty ``clause_type``.
    """
    clauses: List[Clause] = []
    for csv_path in sorted(folder.rglob("*.csv")):
        try:
            df = pd.read_csv(csv_path)
        except Exception as exc:
            # Deliberate best effort: one unreadable file must not abort the load.
            print(f"Warning: failed to read {csv_path.name}: {exc}")
            continue

        if "clause_text" not in df.columns or "clause_type" not in df.columns:
            print(f"Warning: {csv_path.name} does not contain required columns 'clause_text' and 'clause_type'. Skipping.")
            continue

        # Walk the two columns directly instead of DataFrame.iterrows():
        # iterrows builds a Series per row (slow on large files) and coerces
        # every row to one common dtype, which can alter the string form of
        # the cell values.
        for raw_text, raw_type in zip(df["clause_text"], df["clause_type"]):
            text = "" if pd.isna(raw_text) else str(raw_text).strip()
            ctype = "" if pd.isna(raw_type) else str(raw_type).strip()
            if text:
                clauses.append(Clause(text=text, clause_type=ctype, source_file=csv_path.name))

    return clauses


def make_balanced_pairs(
    clauses: List[Clause],
    max_pairs_per_class: int | None,
    rng: np.random.Generator,
) -> Tuple[List[Tuple[str, str]], np.ndarray]:
    """Build a shuffled, label-balanced set of clause-text pairs.

    Positive pairs (label 1) join two distinct clauses of the same
    ``clause_type``; negative pairs (label 0) join two clauses of different
    types. Clauses whose ``clause_type`` is empty have no known label, so they
    are left out of both kinds of pair instead of being presented as
    "same type" or "different type" examples. When every clause has a type,
    the random draws are the same as before this filtering was added.

    Args:
        clauses: Clauses to pair up.
        max_pairs_per_class: Upper bound on positive pairs drawn from a single
            clause type, or ``None`` for no bound.
        rng: NumPy generator that drives every shuffle and draw.

    Returns:
        ``(pairs, labels)``: a list of ``(text_a, text_b)`` tuples and the
        matching ``int64`` array of 0/1 labels, in the same shuffled order.

    Raises:
        ValueError: If fewer than two non-empty clause types are present, or
            if no type has at least two clauses to form a positive pair.
    """
    # Group clause indices by type, ignoring clauses without a type.
    type_to_indices: dict[str, List[int]] = {}
    for idx, c in enumerate(clauses):
        if c.clause_type:
            type_to_indices.setdefault(c.clause_type, []).append(idx)

    if len(type_to_indices) < 2:
        raise ValueError(
            "Found fewer than 2 distinct clause_type values across CSVs. "
            "To build a similarity dataset, place multiple CSVs (or files having multiple clause_type values) in the folder."
        )

    # Positives: shuffle each type's clauses and pair them off two by two,
    # so no clause appears in more than one positive pair.
    positive_pairs: List[Tuple[int, int]] = []
    for indices in type_to_indices.values():
        if len(indices) < 2:
            continue
        shuffled = indices.copy()
        rng.shuffle(shuffled)
        pair_count = len(shuffled) // 2
        if max_pairs_per_class is not None:
            pair_count = min(pair_count, max_pairs_per_class)
        positive_pairs.extend((shuffled[2 * i], shuffled[2 * i + 1]) for i in range(pair_count))

    if not positive_pairs:
        raise ValueError("Could not form any positive pairs. Ensure each clause_type has at least two rows.")

    num_pos = len(positive_pairs)

    # Negatives: rejection-sample pairs of typed clauses until the classes are
    # balanced; the attempt cap keeps the loop finite on skewed data.
    typed_indices = np.asarray(
        [idx for idx, c in enumerate(clauses) if c.clause_type], dtype=np.int64
    )
    negative_pairs: List[Tuple[int, int]] = []
    attempts = 0
    while len(negative_pairs) < num_pos and attempts < num_pos * 20:
        attempts += 1
        i, j = rng.choice(typed_indices, size=2, replace=False)
        if clauses[i].clause_type != clauses[j].clause_type:
            negative_pairs.append((int(i), int(j)))

    if len(negative_pairs) < num_pos:
        print(
            f"Warning: sampled only {len(negative_pairs)} negative pairs for {num_pos} positive pairs; "
            "the returned dataset is not perfectly balanced."
        )

    pairs_text: List[Tuple[str, str]] = []
    labels: List[int] = []
    for label, index_pairs in ((1, positive_pairs), (0, negative_pairs)):
        for i, j in index_pairs:
            pairs_text.append((clauses[i].text, clauses[j].text))
            labels.append(label)

    # Shuffle pairs and labels together so the classes are interleaved.
    order = np.arange(len(pairs_text))
    rng.shuffle(order)
    pairs_text = [pairs_text[k] for k in order]
    y = np.asarray([labels[k] for k in order], dtype=np.int64)
    return pairs_text, y


def evaluate_binary(y_true: np.ndarray, y_proba: np.ndarray) -> dict:
    """Score binary predictions at a fixed 0.5 decision threshold.

    Returns a dict with ``accuracy``, ``precision``, ``recall`` and ``f1``
    computed on the thresholded predictions, plus ``roc_auc`` and ``pr_auc``
    computed on the raw probabilities. A ranking metric that cannot be
    computed is reported as NaN.
    """
    y_pred = (y_proba >= 0.5).astype(int)

    metrics = {"accuracy": float(accuracy_score(y_true, y_pred))}
    for name, scorer in (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    ):
        metrics[name] = float(scorer(y_true, y_pred, zero_division=0))

    # Ranking metrics work on probabilities; fall back to NaN on any failure.
    for name, scorer in (
        ("roc_auc", roc_auc_score),
        ("pr_auc", average_precision_score),
    ):
        try:
            metrics[name] = float(scorer(y_true, y_proba))
        except Exception:
            metrics[name] = float("nan")
    return metrics


def find_best_threshold(y_true: np.ndarray, y_proba: np.ndarray, metric: str = "f1") -> float:
    """Pick the decision threshold in 0.05..0.95 (step 0.05) with the best score.

    ``metric="accuracy"`` optimises accuracy; any other value optimises F1.
    On ties the lowest threshold wins. 0.5 is returned if no candidate
    scores above -1.
    """

    def score_at(threshold: float) -> float:
        y_pred = (y_proba >= threshold).astype(int)
        if metric == "accuracy":
            return accuracy_score(y_true, y_pred)
        return f1_score(y_true, y_pred, zero_division=0)

    best_t, best_score = 0.5, -1.0
    for candidate in np.linspace(0.05, 0.95, 19):
        candidate_score = score_at(candidate)
        # Strict comparison keeps the earliest (lowest) threshold on ties.
        if candidate_score > best_score:
            best_t, best_score = candidate, candidate_score
    return float(best_t)


def build_vocab(texts: List[str], max_vocab_size: int, min_freq: int) -> tuple[dict, dict]:
    """Create a word-level vocabulary from ``texts``.

    Tokens are lower-cased ``\\w+`` runs. Tokens seen fewer than ``min_freq``
    times are dropped; the rest are ranked by descending frequency, then
    alphabetically, and cut so that the vocabulary (including the two special
    tokens) holds at most ``max_vocab_size`` entries.

    Returns:
        ``(token_to_id, id_to_token)`` with ``<PAD>`` at id 0 and ``<UNK>`` at id 1.
    """
    counts: dict[str, int] = {}
    for text in texts:
        for word in re.findall(r"\b\w+\b", text.lower()):
            counts[word] = counts.get(word, 0) + 1

    # Most frequent first; alphabetical tie-break keeps the result deterministic.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    kept = [word for word, seen in ranked if seen >= min_freq]

    # Two slots are reserved for the PAD and UNK specials.
    budget = max(0, max_vocab_size - 2)
    token_to_id = {"<PAD>": 0, "<UNK>": 1}
    token_to_id.update((word, idx) for idx, word in enumerate(kept[:budget], start=2))
    id_to_token = dict(zip(token_to_id.values(), token_to_id.keys()))
    return token_to_id, id_to_token


def texts_to_ids(texts: List[str], token_to_id: dict, max_len: int) -> tuple[np.ndarray, np.ndarray]:
    """Encode texts as fixed-width, right-padded id sequences.

    Each text is lower-cased, split into ``\\w+`` tokens, truncated to
    ``max_len`` tokens and mapped through ``token_to_id`` (unknown tokens map
    to the ``<UNK>`` id).

    Returns:
        seqs: ``(N, max_len)`` int64 array, padded with the ``<PAD>`` id.
        lens: ``(N,)`` int64 array of unpadded lengths (may be 0).
    """
    pad_id = token_to_id.get("<PAD>", 0)
    unk_id = token_to_id.get("<UNK>", 1)

    n_texts = len(texts)
    seqs = np.full((n_texts, max_len), pad_id, dtype=np.int64)
    lens = np.zeros((n_texts,), dtype=np.int64)

    for row, text in enumerate(texts):
        words = re.findall(r"\b\w+\b", text.lower())[:max_len]
        width = len(words)
        lens[row] = width
        seqs[row, :width] = np.asarray(
            [token_to_id.get(word, unk_id) for word in words], dtype=np.int64
        )
    return seqs, lens


class PairDataset(Dataset):
    def __init__(
        self,
        pairs: List[Tuple[str, str]],
        token_to_id: dict,
        max_len: int,
        labels: np.ndarray | None = None,
    ) -> None:
        self.left_texts = [a for a, _ in pairs]
        self.right_texts = [b for _, b in pairs]
        self.labels = None if labels is None else labels.astype(np.int64)
        self.token_to_id = token_to_id
        self.max_len = max_len

        self.left_ids, self.left_lens = texts_to_ids(self.left_texts, token_to_id, max_len)
        self.right_ids, self.right_lens = texts_to_ids(self.right_texts, token_to_id, max_len)

    def __len__(self) -> int:
        return len(self.left_texts)

    def __getitem__(self, idx: int):
        l_ids = torch.from_numpy(self.left_ids[idx])
        r_ids = torch.from_numpy(self.right_ids[idx])
        l_len = int(self.left_lens[idx])
        r_len = int(self.right_lens[idx])
        if self.labels is None:
            return l_ids, l_len, r_ids, r_len
        return l_ids, l_len, r_ids, r_len, int(self.labels[idx])


class BiLSTMEncoder(nn.Module):
    """Embed a padded token-id batch and encode it with a bidirectional LSTM.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Id of the padding token (its embedding is kept at zero).
        dropout: Inter-layer LSTM dropout; only applied when ``num_layers > 1``.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_size: int,
        num_layers: int,
        pad_idx: int = 0,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM warns if dropout is set on a single-layer network.
            dropout=dropout if num_layers > 1 else 0.0,
        )

    def forward(self, ids: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (B, T) long
            lengths: (B,) long; a length of 0 is treated as 1 (see below)
        Returns:
            enc: (B, 2*hidden_size) final BiLSTM pooled representation
        """
        embeds = self.embedding(ids)  # (B, T, D)
        # A clause with no word tokens has length 0, which
        # pack_padded_sequence rejects. Clamp to 1 so such a row is encoded
        # from its first (padding) position instead of crashing the batch.
        # The lengths must also live on the CPU for packing.
        safe_lengths = lengths.clamp(min=1).cpu()
        packed = pack_padded_sequence(embeds, safe_lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n: (num_layers*2, B, H). Take last layer's forward and backward
        # last forward = h_n[-2], last backward = h_n[-1]
        h_forward = h_n[-2]
        h_backward = h_n[-1]
        enc = torch.cat([h_forward, h_backward], dim=1)  # (B, 2H)
        return enc


class SiameseBiLSTM(nn.Module):
    """Siamese pair classifier: shared BiLSTM encoder followed by an MLP head.

    Both clauses go through the same encoder. The head sees
    ``[u, v, |u - v|, u * v]`` and, optionally, their cosine similarity, and
    emits one logit per pair.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_size: int,
        num_layers: int,
        mlp_hidden: int,
        pad_idx: int = 0,
        dropout: float = 0.2,
        use_cosine_feature: bool = True,
    ) -> None:
        super().__init__()
        self.encoder = BiLSTMEncoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            pad_idx=pad_idx,
            dropout=dropout,
        )
        self.use_cosine_feature = use_cosine_feature

        # Four encoder-sized blocks, plus one scalar when cosine is enabled.
        enc_dim = 2 * hidden_size
        feat_dim = 4 * enc_dim + (1 if use_cosine_feature else 0)
        half_hidden = mlp_hidden // 2

        self.mlp = nn.Sequential(
            nn.Linear(feat_dim, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, half_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_hidden, 1),
        )

    @staticmethod
    def _cosine(u: torch.Tensor, v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Row-wise cosine similarity as a (B, 1) tensor; ``eps`` guards zero norms."""
        dot = torch.sum(u * v, dim=1, keepdim=True)
        norm_product = u.norm(dim=1, keepdim=True) * v.norm(dim=1, keepdim=True)
        return dot / torch.clamp(norm_product, min=eps)

    def forward(self, l_ids: torch.Tensor, l_len: torch.Tensor, r_ids: torch.Tensor, r_len: torch.Tensor) -> torch.Tensor:
        """Return one similarity logit per pair, shape (B,)."""
        left = self.encoder(l_ids, l_len)    # (B, 2H)
        right = self.encoder(r_ids, r_len)   # (B, 2H)
        parts = [left, right, (left - right).abs(), left * right]
        if self.use_cosine_feature:
            parts.append(self._cosine(left, right))  # (B, 1)
        return self.mlp(torch.cat(parts, dim=1)).squeeze(1)


def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device) -> tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``loader`` without gradients and collect its outputs.

    The model is switched to eval mode and is not switched back.

    Returns:
        ``(labels, probabilities)``: an int64 array of the true labels and a
        float64 array of sigmoid probabilities, in loader order.
    """
    model.eval()
    probabilities: list[float] = []
    targets: list[int] = []
    with torch.no_grad():
        for batch in loader:
            # Each batch is (left ids, left lens, right ids, right lens, labels).
            l_ids, l_len, r_ids, r_len, y = (part.to(device) for part in batch)
            scores = torch.sigmoid(model(l_ids, l_len, r_ids, r_len))
            probabilities.extend(scores.cpu().tolist())
            targets.extend(y.cpu().tolist())
    labels_arr = np.asarray(targets, dtype=np.int64)
    probs_arr = np.asarray(probabilities, dtype=np.float64)
    return labels_arr, probs_arr


def run(
    data_dir: Path,
    seed: int,
    max_pairs_per_class: int | None,
    max_len: int,
    max_vocab_size: int,
    min_freq: int,
    embedding_dim: int,
    hidden_size: int,
    num_layers: int,
    mlp_hidden: int,
    batch_size: int,
    epochs: int,
    lr: float,
    log_interval: int,
) -> None:
    """Train and evaluate the Siamese BiLSTM clause-similarity model end to end.

    Steps: load clauses from ``data_dir``; split the clauses 70/15/15 into
    train/val/test (stratified by clause type when possible); build balanced
    pairs inside each split; build the vocabulary from the training pairs;
    train with Adam and BCE-with-logits; keep the weights of the epoch with
    the best validation F1; then print validation metrics at threshold 0.5
    and validation/test metrics at a threshold tuned on validation.

    Args:
        data_dir: Folder searched for clause CSV files.
        seed: Seed for the global RNGs, the splits and the pair sampling.
        max_pairs_per_class: Cap on positive pairs per clause type (``None`` = no cap).
        max_len: Maximum number of tokens kept per clause.
        max_vocab_size: Maximum vocabulary size, special tokens included.
        min_freq: Minimum token frequency for vocabulary membership.
        embedding_dim: Token embedding width.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of LSTM layers.
        mlp_hidden: Width of the first MLP layer of the classifier head.
        batch_size: Batch size for all three loaders.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        log_interval: Steps between progress prints (values below 1 act as 1).

    Returns:
        None. All results are printed.
    """
    set_global_seeds(seed)
    print("Loading clauses from:", data_dir, flush=True)
    clauses = load_clauses_from_folder(data_dir)
    print(f"Loaded {len(clauses)} clauses from {len(set(c.source_file for c in clauses))} CSV file(s).", flush=True)

    # Clause-level split first to avoid leakage of the same clause across splits
    from sklearn.model_selection import train_test_split

    clause_indices = np.arange(len(clauses))
    clause_types = np.array([c.clause_type for c in clauses], dtype=object)

    # Try stratified split by clause_type; if it fails due to rare classes, fall back to random split
    # (70% train, then the remaining 30% is halved into validation and test).
    try:
        idx_train, idx_temp = train_test_split(
            clause_indices, test_size=0.30, random_state=seed, stratify=clause_types
        )
        idx_val, idx_test = train_test_split(
            idx_temp, test_size=0.50, random_state=seed, stratify=clause_types[idx_temp]
        )
    except Exception:
        idx_train, idx_temp = train_test_split(clause_indices, test_size=0.30, random_state=seed)
        idx_val, idx_test = train_test_split(idx_temp, test_size=0.50, random_state=seed)

    clauses_train = [clauses[i] for i in idx_train]
    clauses_val = [clauses[i] for i in idx_val]
    clauses_test = [clauses[i] for i in idx_test]

    # Build pairs independently within each split
    # (one generator is shared, so the three calls consume it in this order).
    rng = np.random.default_rng(seed)
    X_train, y_train = make_balanced_pairs(clauses_train, max_pairs_per_class=max_pairs_per_class, rng=rng)
    X_val, y_val = make_balanced_pairs(clauses_val, max_pairs_per_class=max_pairs_per_class, rng=rng)
    X_test, y_test = make_balanced_pairs(clauses_test, max_pairs_per_class=max_pairs_per_class, rng=rng)

    print(
        f"Split clauses -> train/val/test: {len(clauses_train)}/{len(clauses_val)}/{len(clauses_test)}. "
        f"Built pairs -> train/val/test: {len(X_train)}/{len(X_val)}/{len(X_test)}.",
        flush=True,
    )

    # Build vocab on training texts only
    train_texts = [a for a, _ in X_train] + [b for _, b in X_train]
    token_to_id, _ = build_vocab(train_texts, max_vocab_size=max_vocab_size, min_freq=min_freq)
    vocab_size = len(token_to_id)
    print(f"Vocab size: {vocab_size} (min_freq={min_freq}, max_vocab_size={max_vocab_size})", flush=True)

    # Datasets / Loaders
    ds_train = PairDataset(X_train, token_to_id, max_len=max_len, labels=y_train)
    ds_val = PairDataset(X_val, token_to_id, max_len=max_len, labels=y_val)
    ds_test = PairDataset(X_test, token_to_id, max_len=max_len, labels=y_test)

    def collate(batch):
        # Turn a list of per-sample tuples into five batched tensors.
        l_ids, l_len, r_ids, r_len, y = zip(*batch)
        return (
            torch.stack(l_ids, dim=0),
            torch.tensor(l_len, dtype=torch.long),
            torch.stack(r_ids, dim=0),
            torch.tensor(r_len, dtype=torch.long),
            torch.tensor(y, dtype=torch.long),
        )

    # Only the training loader shuffles; val/test keep a fixed order.
    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate)
    val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate)

    # Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}", flush=True)
    model = SiameseBiLSTM(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        mlp_hidden=mlp_hidden,
        pad_idx=0,
        dropout=0.2,
        use_cosine_feature=True,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    # Track the best validation F1 (at threshold 0.5) and a CPU copy of its weights.
    best_val_f1 = -1.0
    best_state = None
    total_batches = len(train_loader)
    print(f"Starting training for {epochs} epoch(s), steps per epoch: {total_batches}", flush=True)
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        n_batches = 0
        epoch_start = time.time()
        for batch_idx, batch in enumerate(train_loader, start=1):
            l_ids, l_len, r_ids, r_len, y = batch
            l_ids = l_ids.to(device)
            r_ids = r_ids.to(device)
            l_len = l_len.to(device)
            r_len = r_len.to(device)
            # BCEWithLogitsLoss needs float targets.
            y = y.float().to(device)

            logits = model(l_ids, l_len, r_ids, r_len)
            loss = loss_fn(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += float(loss.item())
            n_batches = batch_idx

            # Log every `log_interval` steps and always on the last step of the epoch.
            if (batch_idx % max(1, log_interval) == 0) or (batch_idx == total_batches):
                elapsed = time.time() - epoch_start
                avg_loss = running_loss / n_batches
                print(
                    f"Epoch {epoch:02d}/{epochs} - step {batch_idx:05d}/{total_batches:05d} "
                    f"loss: {float(loss.item()):.4f} (avg: {avg_loss:.4f}) - elapsed: {elapsed:.1f}s",
                    flush=True,
                )

        # Validation
        print("Evaluating on validation...", flush=True)
        y_val_true, y_val_prob = evaluate_model(model, val_loader, device)
        val_metrics = evaluate_binary(y_val_true, y_val_prob)
        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            # Clone to CPU so later optimizer steps cannot change the snapshot.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        # max(1, ...) guards against an empty training loader.
        avg_loss = running_loss / max(1, n_batches)
        print(
            f"Epoch {epoch:02d}/{epochs} - loss: {avg_loss:.4f} - val_f1: {val_metrics['f1']:.4f} - val_acc: {val_metrics['accuracy']:.4f}",
            flush=True,
        )

    # Load best model
    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluate with default 0.5
    print("Final evaluation (val/test, t=0.50) ...", flush=True)
    y_val_true, y_val_prob = evaluate_model(model, val_loader, device)
    y_test_true, y_test_prob = evaluate_model(model, test_loader, device)

    val_metrics = evaluate_binary(y_val_true, y_val_prob)
    # NOTE(review): test_metrics is computed here but never printed below.
    test_metrics = evaluate_binary(y_test_true, y_test_prob)

    print("\nValidation metrics (t=0.50):", flush=True)
    for k, v in val_metrics.items():
        print(f"  {k:>9}: {v:.4f}", flush=True)

    # Tune threshold on validation to maximize F1 (reuse helper)
    best_t = find_best_threshold(y_val_true, y_val_prob, metric="f1")
    print(f"\nChosen decision threshold from validation (max F1): t = {best_t:.2f}", flush=True)

    def with_threshold(y_true: np.ndarray, y_prob: np.ndarray, t: float) -> dict:
        # Threshold-dependent metrics only (no ROC/PR AUC) at cut-off `t`.
        y_pred = (y_prob >= t).astype(int)
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

        return {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "precision": float(precision_score(y_true, y_pred, zero_division=0)),
            "recall": float(recall_score(y_true, y_pred, zero_division=0)),
            "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        }

    # The threshold is chosen on validation only, then applied unchanged to test.
    val_tuned = with_threshold(y_val_true, y_val_prob, best_t)
    test_tuned = with_threshold(y_test_true, y_test_prob, best_t)

    print("\nValidation (tuned threshold):", flush=True)
    for k, v in val_tuned.items():
        print(f"  {k:>9}: {v:.4f}", flush=True)

    print("\nTest (tuned threshold):", flush=True)
    for k, v in test_tuned.items():
        print(f"  {k:>9}: {v:.4f}", flush=True)


def main() -> None:
    """Build the command-line interface with its defaults and start a run.

    The parser is given an empty argument list, so the defaults below are
    always used (this lets the cell run inside a notebook kernel).
    """
    parser = argparse.ArgumentParser(description="Baseline 2: Siamese BiLSTM + MLP for legal clause similarity")

    # (flag, type, default, help) for every option, in display order.
    options = (
        # data / sampling
        (
            "--data_dir",
            Path,
            Path("csv"),
            "Directory containing CSV files with columns: clause_text, clause_type (default: ./csv)",
        ),
        ("--seed", int, 42, "Random seed"),
        (
            "--max_pairs_per_class",
            int,
            200,
            "Upper bound of positive pairs sampled per clause_type (for speed/balance).",
        ),
        ("--max_len", int, 128, "Maximum tokenized length per clause"),
        ("--max_vocab_size", int, 30000, "Maximum vocabulary size"),
        ("--min_freq", int, 2, "Minimum frequency for a token to enter the vocab"),
        # model
        ("--embedding_dim", int, 200, "Embedding dimension"),
        ("--hidden_size", int, 128, "BiLSTM hidden size (per direction)"),
        ("--num_layers", int, 1, "Number of LSTM layers"),
        ("--mlp_hidden", int, 256, "Hidden size of the MLP head"),
        # optimisation / logging
        ("--batch_size", int, 64, "Batch size"),
        ("--epochs", int, 8, "Training epochs"),
        ("--lr", float, 1e-3, "Learning rate"),
        ("--log_interval", int, 200, "Steps between training progress prints"),
    )
    for flag, arg_type, default, help_text in options:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)

    args = parser.parse_args([])  # empty list: ignore the kernel's own argv

    # Every option name matches a keyword parameter of run().
    run(**vars(args))


if __name__ == "__main__":
    main()


Loading clauses from: csv
Loaded 150881 clauses from 395 CSV file(s).
Split clauses -> train/val/test: 105616/22632/22633. Built pairs -> train/val/test: 105306/22438/22438.
Vocab size: 30000 (min_freq=2, max_vocab_size=30000)
Using device: cuda
Starting training for 8 epoch(s), steps per epoch: 1646
Epoch 01/8 - step 00200/01646 loss: 0.0011 (avg: 0.0550) - elapsed: 7.6s
Epoch 01/8 - step 00400/01646 loss: 0.0015 (avg: 0.0319) - elapsed: 13.3s
Epoch 01/8 - step 00600/01646 loss: 0.0026 (avg: 0.0255) - elapsed: 19.6s
Epoch 01/8 - step 00800/01646 loss: 0.0002 (avg: 0.0211) - elapsed: 25.3s
Epoch 01/8 - step 01000/01646 loss: 0.0014 (avg: 0.0182) - elapsed: 31.6s
Epoch 01/8 - step 01200/01646 loss: 0.0027 (avg: 0.0164) - elapsed: 37.3s
Epoch 01/8 - step 01400/01646 loss: 0.0004 (avg: 0.0151) - elapsed: 43.7s
Epoch 01/8 - step 01600/01646 loss: 0.0017 (avg: 0.0142) - elapsed: 49.4s
Epoch 01/8 - step 01646/01646 loss: 0.0000 (avg: 0.0138) - elapsed: 50.7s
Evaluating on validation...
Epoch

#  Baseline 2 :  TextCNNEncoder

In [None]:
import argparse
import random
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Literal

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
)


# -----------------------------
# Standalone utilities (shared)
# -----------------------------
@dataclass
class Clause:
    """One clause row read from a CSV file."""

    # Clause body; the loader in this file strips surrounding whitespace.
    text: str
    # Clause category label; the loader stores "" when the CSV cell is missing.
    clause_type: str
    # File name (not full path) of the CSV the row was read from.
    source_file: str


def set_global_seeds(seed: int) -> None:
    """Apply one seed to every random number generator used in this notebook."""
    # PyTorch: the CPU generator, then every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NumPy's legacy global generator and Python's own `random` module.
    np.random.seed(seed)
    random.seed(seed)


def load_clauses_from_folder(folder: Path) -> List[Clause]:
    """Collect clauses from all CSV files found recursively under ``folder``.

    Files are read in sorted path order. A file that fails to parse, or that
    lacks a ``clause_text`` or ``clause_type`` column, is skipped with a
    printed warning.

    Args:
        folder: Directory searched (recursively) for ``*.csv`` files.

    Returns:
        One ``Clause`` per row with non-empty stripped text. NaN cells become
        empty strings, so ``clause_type`` may be empty on a returned clause.
    """
    clauses: List[Clause] = []
    for csv_path in sorted(folder.rglob("*.csv")):
        try:
            df = pd.read_csv(csv_path)
        except Exception as exc:
            # Deliberate best effort: skip an unreadable file, keep loading.
            print(f"Warning: failed to read {csv_path.name}: {exc}")
            continue
        if "clause_text" not in df.columns or "clause_type" not in df.columns:
            print(f"Warning: {csv_path.name} missing columns 'clause_text'/'clause_type'. Skipping.")
            continue
        # Iterate the two columns in lockstep rather than DataFrame.iterrows():
        # iterrows allocates a Series per row and coerces each row to a
        # single dtype, which is slow and can change how values stringify.
        for raw_text, raw_type in zip(df["clause_text"], df["clause_type"]):
            text = "" if pd.isna(raw_text) else str(raw_text).strip()
            ctype = "" if pd.isna(raw_type) else str(raw_type).strip()
            if text:
                clauses.append(Clause(text=text, clause_type=ctype, source_file=csv_path.name))
    return clauses


def normalize_text(t: str) -> str:
    """Lower-case ``t``, collapse each whitespace run to one space, and trim the ends."""
    lowered = t.lower()
    collapsed = re.sub(r"\s+", " ", lowered)
    return collapsed.strip()


def deduplicate_clauses(clauses: List[Clause]) -> List[Clause]:
    """Drop clauses whose normalised text repeats an earlier clause.

    The first occurrence of each normalised text is kept and the input order
    is preserved.
    """
    # dicts keep insertion order, so setdefault retains the first clause per key.
    first_by_text: dict[str, Clause] = {}
    for clause in clauses:
        first_by_text.setdefault(normalize_text(clause.text), clause)
    return list(first_by_text.values())


def check_overlap(a: List[Clause], b: List[Clause], name_a: str, name_b: str) -> None:
    """Print whether two clause lists share any normalised clause text.

    When they do, the count and up to five example texts are printed.
    """
    texts_a = {normalize_text(c.text) for c in a}
    texts_b = {normalize_text(c.text) for c in b}
    shared = texts_a.intersection(texts_b)

    if not shared:
        print(f"No overlapping clause_texts between {name_a} and {name_b}.")
        return

    sample = list(shared)[:5]
    print(f"Warning: Found {len(shared)} overlapping clause_texts between {name_a} and {name_b}. e.g.: {sample}")


def make_balanced_pairs(
    clauses: List[Clause],
    max_pairs_per_class: int | None,
    rng: np.random.Generator,
) -> Tuple[List[Tuple[str, str]], np.ndarray]:
    """Build a shuffled, label-balanced set of clause-text pairs.

    A positive pair (label 1) holds two distinct clauses of the same
    ``clause_type``; a negative pair (label 0) holds two clauses of different
    types. Clauses with an empty ``clause_type`` have no known label and are
    excluded from both kinds of pair, so they are never presented as
    "same type" or "different type" examples. When every clause has a type,
    the random draws are the same as before this filtering was added.

    Args:
        clauses: Clauses to pair up.
        max_pairs_per_class: Upper bound on positive pairs drawn from a single
            clause type, or ``None`` for no bound.
        rng: NumPy generator that drives every shuffle and draw.

    Returns:
        ``(pairs, labels)``: a list of ``(text_a, text_b)`` tuples and the
        matching ``int64`` array of 0/1 labels, in the same shuffled order.

    Raises:
        ValueError: If fewer than two non-empty clause types are present, or
            if no type has at least two clauses to form a positive pair.
    """
    # Group clause indices by type, ignoring clauses without a type.
    type_to_indices: dict[str, List[int]] = {}
    for idx, c in enumerate(clauses):
        if c.clause_type:
            type_to_indices.setdefault(c.clause_type, []).append(idx)
    if len(type_to_indices) < 2:
        raise ValueError("Need at least two distinct clause_type values to form pairs.")

    # Positives: shuffle each type's clauses and pair them off two by two,
    # so no clause appears in more than one positive pair.
    positives: List[Tuple[int, int]] = []
    for idxs in type_to_indices.values():
        if len(idxs) < 2:
            continue
        pool = idxs.copy()
        rng.shuffle(pool)
        count = len(pool) // 2
        if max_pairs_per_class is not None:
            count = min(count, max_pairs_per_class)
        positives.extend((pool[2 * i], pool[2 * i + 1]) for i in range(count))
    if not positives:
        raise ValueError("Could not form positive pairs inside split.")

    # Negatives: rejection-sample pairs of typed clauses until the classes are
    # balanced; the attempt cap keeps the loop finite on skewed data.
    typed_idx = np.asarray([idx for idx, c in enumerate(clauses) if c.clause_type], dtype=np.int64)
    negatives: List[Tuple[int, int]] = []
    attempts = 0
    while len(negatives) < len(positives) and attempts < len(positives) * 20:
        attempts += 1
        i, j = rng.choice(typed_idx, size=2, replace=False)
        if clauses[i].clause_type != clauses[j].clause_type:
            negatives.append((int(i), int(j)))
    if len(negatives) < len(positives):
        print(
            f"Warning: sampled only {len(negatives)} negative pairs for {len(positives)} positive pairs; "
            "the returned dataset is not perfectly balanced."
        )

    pairs: List[Tuple[str, str]] = []
    labels: List[int] = []
    for label, index_pairs in ((1, positives), (0, negatives)):
        for i, j in index_pairs:
            pairs.append((clauses[i].text, clauses[j].text))
            labels.append(label)

    # Shuffle pairs and labels together so the classes are interleaved.
    order = np.arange(len(pairs))
    rng.shuffle(order)
    pairs = [pairs[k] for k in order]
    y = np.asarray([labels[k] for k in order], dtype=np.int64)
    return pairs, y


def evaluate_binary(y_true: np.ndarray, y_proba: np.ndarray) -> dict:
    """Compute classification metrics for probabilities cut at 0.5.

    The result has ``accuracy``, ``precision``, ``recall`` and ``f1`` on the
    hard predictions, then ``roc_auc`` and ``pr_auc`` on the probabilities
    (NaN when a ranking metric cannot be computed).
    """

    def ranking_score(scorer) -> float:
        # Ranking metrics can fail; report NaN instead of raising.
        try:
            return float(scorer(y_true, y_proba))
        except Exception:
            return float("nan")

    hard = (y_proba >= 0.5).astype(int)
    return {
        "accuracy": float(accuracy_score(y_true, hard)),
        "precision": float(precision_score(y_true, hard, zero_division=0)),
        "recall": float(recall_score(y_true, hard, zero_division=0)),
        "f1": float(f1_score(y_true, hard, zero_division=0)),
        "roc_auc": ranking_score(roc_auc_score),
        "pr_auc": ranking_score(average_precision_score),
    }


def find_best_threshold(y_true: np.ndarray, y_proba: np.ndarray, metric: str = "f1") -> float:
    """Scan thresholds 0.05..0.95 (19 steps) and return the best-scoring one.

    Accuracy is optimised when ``metric == "accuracy"``, F1 otherwise. Ties
    go to the lowest threshold; 0.5 is returned if nothing scores above -1.
    """
    # Choose the scoring function once, outside the scan.
    if metric == "accuracy":
        scorer = accuracy_score
    else:
        def scorer(truth, pred):
            return f1_score(truth, pred, zero_division=0)

    winner = 0.5
    top_score = -1.0
    for t in np.linspace(0.05, 0.95, 19):
        current = scorer(y_true, (y_proba >= t).astype(int))
        if current > top_score:
            top_score = current
            winner = t
    return float(winner)


# -----------------------------
# Tokenizer / Dataset
# -----------------------------
def build_vocab(texts: List[str], max_vocab_size: int, min_freq: int) -> tuple[dict, dict]:
    """Build word-level ``token -> id`` and ``id -> token`` maps from ``texts``.

    Ids 0 and 1 are reserved for ``<PAD>`` and ``<UNK>``. The remaining ids go
    to tokens seen at least ``min_freq`` times, ordered by descending
    frequency and then alphabetically, up to ``max_vocab_size`` entries in total.
    """
    freq: dict[str, int] = {}
    all_tokens = (tok for t in texts for tok in re.findall(r"\b\w+\b", t.lower()))
    for tok in all_tokens:
        freq[tok] = freq.get(tok, 0) + 1

    eligible = [tok for tok in freq if freq[tok] >= min_freq]
    eligible.sort(key=lambda tok: (-freq[tok], tok))

    # Two slots of the budget belong to the special tokens.
    capacity = max(0, max_vocab_size - 2)
    tok2id = {"<PAD>": 0, "<UNK>": 1}
    for offset, tok in enumerate(eligible[:capacity]):
        tok2id[tok] = offset + 2
    id2tok = {idx: tok for tok, idx in tok2id.items()}
    return tok2id, id2tok


def texts_to_ids(texts: List[str], tok2id: dict, max_len: int) -> tuple[np.ndarray, np.ndarray]:
    """Encode texts as right-padded id rows of width ``max_len``.

    Returns ``(seqs, lens)``: an ``(N, max_len)`` int64 array filled with the
    ``<PAD>`` id past each text's end, and an ``(N,)`` int64 array of the
    truncated token counts. Unknown tokens map to the ``<UNK>`` id.
    """
    pad_id = tok2id.get("<PAD>", 0)
    unk_id = tok2id.get("<UNK>", 1)

    # Tokenise, map and truncate every text first ...
    encoded = [
        [tok2id.get(tok, unk_id) for tok in re.findall(r"\b\w+\b", t.lower())][:max_len]
        for t in texts
    ]
    lens = np.asarray([len(ids) for ids in encoded], dtype=np.int64)

    # ... then copy each id list into its row of the padded matrix.
    seqs = np.full((len(texts), max_len), pad_id, dtype=np.int64)
    for row, ids in zip(seqs, encoded):
        row[: len(ids)] = np.asarray(ids, dtype=np.int64)
    return seqs, lens


class PairDataset(Dataset):
    def __init__(self, pairs: List[Tuple[str, str]], tok2id: dict, max_len: int, labels: np.ndarray | None = None) -> None:
        self.left = [a for a, _ in pairs]
        self.right = [b for _, b in pairs]
        self.labels = None if labels is None else labels.astype(np.int64)
        self.left_ids, self.left_lens = texts_to_ids(self.left, tok2id, max_len)
        self.right_ids, self.right_lens = texts_to_ids(self.right, tok2id, max_len)

    def __len__(self) -> int:
        return len(self.left)

    def __getitem__(self, idx: int):
        l_ids = torch.from_numpy(self.left_ids[idx])
        r_ids = torch.from_numpy(self.right_ids[idx])
        l_len = int(self.left_lens[idx])
        r_len = int(self.right_lens[idx])
        if self.labels is None:
            return l_ids, l_len, r_ids, r_len
        return l_ids, l_len, r_ids, r_len, int(self.labels[idx])


# -----------------------------
# Model: TextCNN Encoder
# -----------------------------
class TextCNNEncoder(nn.Module):
    """Encode a batch of token-id sequences with parallel 1-D convolutions.

    Each kernel size gets its own ``Conv1d`` over the embedded sequence; every
    feature map is ReLU-activated and max-pooled over time, and the pooled
    vectors are concatenated into one ``output_dim``-wide representation.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        num_channels: Output channels per convolution.
        kernel_sizes: Non-empty list of convolution widths.
        pad_idx: Id used as ``padding_idx`` for the embedding.
        dropout: Dropout probability applied to the pooled output.

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        num_channels: int,
        kernel_sizes: List[int],
        pad_idx: int = 0,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        if not kernel_sizes:
            # Without this check the failure only shows up later, at torch.cat in forward().
            raise ValueError("kernel_sizes must contain at least one kernel size")
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.convs = nn.ModuleList(
            [nn.Conv1d(in_channels=embedding_dim, out_channels=num_channels, kernel_size=k) for k in kernel_sizes]
        )
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.output_dim = num_channels * len(kernel_sizes)
        # Plain int (not a buffer) so existing state_dicts keep loading unchanged.
        self.min_seq_len = max(kernel_sizes)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            ids: (B, T) long
        Returns:
            enc: (B, output_dim) - concatenated max-pooled feature maps
        """
        x = self.embedding(ids)  # (B, T, D)
        x = x.transpose(1, 2)    # (B, D, T) for Conv1d
        # Conv1d rejects inputs shorter than its kernel, so right-pad short
        # batches with zero vectors (the value the padding embedding is
        # initialised to) up to the widest kernel.
        shortfall = self.min_seq_len - x.size(2)
        if shortfall > 0:
            x = nn.functional.pad(x, (0, shortfall))
        pooled_outputs = []
        for conv in self.convs:
            h = self.activation(conv(x))         # (B, C, T')
            h = torch.max(h, dim=2).values       # Global max pool over time -> (B, C)
            pooled_outputs.append(h)
        out = torch.cat(pooled_outputs, dim=1)   # (B, C * num_kernels)
        out = self.dropout(out)
        return out


class SiameseTextCNN(nn.Module):
    """Siamese pair classifier built on a shared ``TextCNNEncoder``.

    Both inputs go through the same encoder; the two encodings ``u`` and ``v``
    are combined as ``[u, v, |u - v|, u * v]`` (plus their cosine similarity
    when ``use_cosine_feature`` is set) and scored by a three-layer MLP that
    emits one logit per pair.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        num_channels: int,
        kernel_sizes: List[int],
        mlp_hidden: int,
        pad_idx: int = 0,
        dropout: float = 0.2,
        use_cosine_feature: bool = True,
    ) -> None:
        super().__init__()
        self.encoder = TextCNNEncoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_channels=num_channels,
            kernel_sizes=kernel_sizes,
            pad_idx=pad_idx,
            dropout=dropout,
        )
        self.use_cosine_feature = use_cosine_feature

        # Four encoder-width feature groups, plus one scalar for the cosine.
        feat_dim = 4 * self.encoder.output_dim
        if use_cosine_feature:
            feat_dim += 1
        half_hidden = mlp_hidden // 2
        head = [
            nn.Linear(feat_dim, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, half_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_hidden, 1),
        ]
        self.mlp = nn.Sequential(*head)

    @staticmethod
    def _cosine(u: torch.Tensor, v: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Row-wise cosine similarity, shape (B, 1); the denominator is clamped to ``eps``."""
        dot = torch.sum(u * v, dim=1, keepdim=True)
        scale = u.norm(dim=1, keepdim=True) * v.norm(dim=1, keepdim=True)
        return dot / scale.clamp_min(eps)

    def forward(self, l_ids: torch.Tensor, r_ids: torch.Tensor) -> torch.Tensor:
        """Return one logit per pair, shape (B,)."""
        left = self.encoder(l_ids)    # (B, enc_dim)
        right = self.encoder(r_ids)   # (B, enc_dim)
        parts = [left, right, (left - right).abs(), left * right]
        if self.use_cosine_feature:
            parts.append(self._cosine(left, right))
        logits = self.mlp(torch.cat(parts, dim=1))
        return logits.squeeze(1)


def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device) -> tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``loader`` in eval mode and collect its predictions.

    Each batch must unpack as ``(l_ids, l_len, r_ids, r_len, y)``; only the id
    tensors are fed to the model, whose logits go through a sigmoid.

    Returns:
        ``(labels, probs)`` as int64 and float64 arrays in loader order.
    """
    model.eval()
    probs_out: list[float] = []
    labels_out: list[int] = []
    with torch.no_grad():
        for batch in loader:
            left_ids, _left_len, right_ids, _right_len, targets = batch
            logits = model(left_ids.to(device), right_ids.to(device))
            targets = targets.to(device)
            probs_out += torch.sigmoid(logits).cpu().numpy().tolist()
            labels_out += targets.cpu().numpy().tolist()
    labels = np.asarray(labels_out, dtype=np.int64)
    probs = np.asarray(probs_out, dtype=np.float64)
    return labels, probs


# -----------------------------
# Run pipeline
# -----------------------------
def run(
    data_dir: Path,
    seed: int,
    split_mode: Literal["clause", "type"],
    dedup_texts: bool,
    max_pairs_per_class: int | None,
    max_len: int,
    max_vocab_size: int,
    min_freq: int,
    embedding_dim: int,
    cnn_channels: int,
    kernel_sizes: List[int],
    mlp_hidden: int,
    batch_size: int,
    epochs: int,
    lr: float,
    log_interval: int,
) -> None:
    """Train and evaluate the Siamese TextCNN pair classifier end to end.

    Steps: seed the RNGs, load clauses, optionally deduplicate them, split
    them into train/val/test, build balanced same-type / different-type pairs
    per split, fit a vocabulary on the training pairs, train with Adam and
    BCE-with-logits while keeping the weights of the epoch with the best
    validation F1, then print validation/test metrics both at the default
    threshold and at a threshold tuned on validation.

    Args:
        data_dir: Folder searched for clause CSV files.
        seed: Seed for the global RNGs, the splits and pair sampling.
        split_mode: ``"clause"`` splits individual clauses 70/15/15
            (stratified by type when possible); any other value holds out
            whole clause types 70/15/remainder.
        dedup_texts: Whether to call ``deduplicate_clauses`` before splitting.
        max_pairs_per_class: Forwarded to ``make_balanced_pairs``.
        max_len: Token length every clause is truncated/padded to.
        max_vocab_size: Forwarded to ``build_vocab``.
        min_freq: Forwarded to ``build_vocab``.
        embedding_dim: Embedding width of the encoder.
        cnn_channels: Channels per convolution kernel size.
        kernel_sizes: Convolution widths of the encoder.
        mlp_hidden: Hidden width of the classification head.
        batch_size: Batch size for all three loaders.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        log_interval: Training steps between progress prints.
    """
    set_global_seeds(seed)
    print("Loading clauses from:", data_dir, flush=True)
    clauses = load_clauses_from_folder(data_dir)
    print(f"Loaded {len(clauses)} clauses from {len(set(c.source_file for c in clauses))} CSV file(s).", flush=True)

    if dedup_texts:
        before = len(clauses)
        clauses = deduplicate_clauses(clauses)
        after = len(clauses)
        print(f"Deduplicated exact texts: {before} -> {after}", flush=True)

    from sklearn.model_selection import train_test_split

    # One generator drives the type shuffle and all pair sampling, so the call
    # order below determines reproducibility.
    rng = np.random.default_rng(seed)
    if split_mode == "clause":
        clause_indices = np.arange(len(clauses))
        clause_types = np.array([c.clause_type for c in clauses], dtype=object)
        try:
            idx_train, idx_temp = train_test_split(clause_indices, test_size=0.30, random_state=seed, stratify=clause_types)
            idx_val, idx_test = train_test_split(idx_temp, test_size=0.50, random_state=seed, stratify=clause_types[idx_temp])
        except ValueError as exc:
            # Stratification is rejected when a clause type has too few members
            # to appear on both sides; say so instead of silently degrading.
            print(f"Warning: stratified split failed ({exc}); falling back to an unstratified split.", flush=True)
            idx_train, idx_temp = train_test_split(clause_indices, test_size=0.30, random_state=seed)
            idx_val, idx_test = train_test_split(idx_temp, test_size=0.50, random_state=seed)
        clauses_train = [clauses[i] for i in idx_train]
        clauses_val = [clauses[i] for i in idx_val]
        clauses_test = [clauses[i] for i in idx_test]
    else:
        # Hold out entire clause types: no type is shared between splits.
        all_types = sorted({c.clause_type for c in clauses if c.clause_type})
        rng.shuffle(all_types)
        n_types = len(all_types)
        n_train = int(0.7 * n_types)
        n_val = int(0.15 * n_types)
        train_types = set(all_types[:n_train])
        val_types = set(all_types[n_train:n_train + n_val])
        test_types = set(all_types[n_train + n_val:])
        clauses_train = [c for c in clauses if c.clause_type in train_types]
        clauses_val = [c for c in clauses if c.clause_type in val_types]
        clauses_test = [c for c in clauses if c.clause_type in test_types]

    check_overlap(clauses_train, clauses_val, "train", "val")
    check_overlap(clauses_train, clauses_test, "train", "test")
    check_overlap(clauses_val, clauses_test, "val", "test")

    X_train, y_train = make_balanced_pairs(clauses_train, max_pairs_per_class=max_pairs_per_class, rng=rng)
    X_val, y_val = make_balanced_pairs(clauses_val, max_pairs_per_class=max_pairs_per_class, rng=rng)
    X_test, y_test = make_balanced_pairs(clauses_test, max_pairs_per_class=max_pairs_per_class, rng=rng)
    print(
        f"Split clauses -> train/val/test: {len(clauses_train)}/{len(clauses_val)}/{len(clauses_test)}. "
        f"Built pairs -> train/val/test: {len(X_train)}/{len(X_val)}/{len(X_test)}.",
        flush=True,
    )

    # The vocabulary is fitted on training pairs only, so val/test tokens that
    # never occur in training map to <UNK>.
    train_texts = [a for a, _ in X_train] + [b for _, b in X_train]
    tok2id, _ = build_vocab(train_texts, max_vocab_size=max_vocab_size, min_freq=min_freq)
    vocab_size = len(tok2id)
    print(f"Vocab size: {vocab_size} (min_freq={min_freq}, max_vocab_size={max_vocab_size})", flush=True)

    ds_train = PairDataset(X_train, tok2id, max_len=max_len, labels=y_train)
    ds_val = PairDataset(X_val, tok2id, max_len=max_len, labels=y_val)
    ds_test = PairDataset(X_test, tok2id, max_len=max_len, labels=y_test)

    def collate(batch):
        # Stack per-sample tuples into batched tensors.
        l_ids, l_len, r_ids, r_len, y = zip(*batch)
        return (
            torch.stack(l_ids, dim=0),
            torch.tensor(l_len, dtype=torch.long),
            torch.stack(r_ids, dim=0),
            torch.tensor(r_len, dtype=torch.long),
            torch.tensor(y, dtype=torch.long),
        )

    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate)
    val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}", flush=True)
    model = SiameseTextCNN(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        num_channels=cnn_channels,
        kernel_sizes=kernel_sizes,
        mlp_hidden=mlp_hidden,
        pad_idx=0,
        dropout=0.2,
        use_cosine_feature=True,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    best_val_f1 = -1.0
    best_state = None
    steps_per_epoch = len(train_loader)
    print(f"Starting training for {epochs} epoch(s), steps per epoch: {steps_per_epoch}", flush=True)
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        epoch_start = time.time()
        for step, batch in enumerate(train_loader, start=1):
            l_ids, l_len, r_ids, r_len, y = batch
            l_ids = l_ids.to(device)
            r_ids = r_ids.to(device)
            y = y.float().to(device)  # BCEWithLogitsLoss needs float targets

            logits = model(l_ids, r_ids)
            loss = loss_fn(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += float(loss.item())
            if (step % max(1, log_interval) == 0) or (step == steps_per_epoch):
                avg_loss = running_loss / step
                elapsed = time.time() - epoch_start
                print(
                    f"Epoch {epoch:02d}/{epochs} - step {step:05d}/{steps_per_epoch:05d} "
                    f"loss: {float(loss.item()):.4f} (avg: {avg_loss:.4f}) - elapsed: {elapsed:.1f}s",
                    flush=True,
                )

        print("Evaluating on validation...", flush=True)
        y_val_true, y_val_prob = evaluate_model(model, val_loader, device)
        val_metrics = evaluate_binary(y_val_true, y_val_prob)
        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            # clone() so later optimizer steps cannot overwrite the snapshot.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        avg_loss = running_loss / max(1, steps_per_epoch)
        print(f"Epoch {epoch:02d}/{epochs} - loss: {avg_loss:.4f} - val_f1: {val_metrics['f1']:.4f} - val_acc: {val_metrics['accuracy']:.4f}", flush=True)

    # Evaluate with the weights of the best validation epoch, not the last one.
    if best_state is not None:
        model.load_state_dict(best_state)

    print("Final evaluation (val/test, t=0.50) ...", flush=True)
    y_val_true, y_val_prob = evaluate_model(model, val_loader, device)
    y_test_true, y_test_prob = evaluate_model(model, test_loader, device)
    val_metrics = evaluate_binary(y_val_true, y_val_prob)
    test_metrics = evaluate_binary(y_test_true, y_test_prob)

    print("\nValidation metrics (t=0.50):", flush=True)
    for k, v in val_metrics.items():
        print(f"  {k:>9}: {v:.4f}", flush=True)

    # The threshold is chosen on validation only; test is never used for tuning.
    best_t = find_best_threshold(y_val_true, y_val_prob, metric="f1")
    print(f"\nChosen decision threshold from validation (max F1): t = {best_t:.2f}", flush=True)

    def with_threshold(y_true: np.ndarray, y_prob: np.ndarray, t: float) -> dict:
        # Threshold-dependent metrics for probabilities cut at t.
        y_pred = (y_prob >= t).astype(int)
        return {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "precision": float(precision_score(y_true, y_pred, zero_division=0)),
            "recall": float(recall_score(y_true, y_pred, zero_division=0)),
            "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        }

    val_tuned = with_threshold(y_val_true, y_val_prob, best_t)
    test_tuned = with_threshold(y_test_true, y_test_prob, best_t)
    print("\nValidation (tuned threshold):", flush=True)
    for k, v in val_tuned.items():
        print(f"  {k:>9}: {v:.4f}", flush=True)
    print("\nTest (tuned threshold):", flush=True)
    for k, v in test_tuned.items():
        print(f"  {k:>9}: {v:.4f}", flush=True)


def _parse_kernel_sizes(s: str) -> List[int]:
    try:
        return [int(x) for x in s.split(",") if x.strip()]
    except Exception:
        raise argparse.ArgumentTypeError("kernel_sizes must be a comma-separated list of integers, e.g. '3,4,5'")


def main(argv: List[str] | None = None) -> None:
    """Parse options and launch the Siamese TextCNN pipeline.

    Args:
        argv: Command-line style arguments, e.g. ``["--epochs", "2"]``. When
            omitted, an empty list is parsed so that every option takes its
            default and the arguments of the hosting process (such as a
            notebook kernel) are ignored.
    """
    parser = argparse.ArgumentParser(description="Baseline 4: Siamese TextCNN with clause/type split and optional dedup")
    parser.add_argument("--data_dir", type=Path, default=Path("csv"), help="Directory containing CSVs with clause_text, clause_type")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--split_mode", type=str, choices=["clause", "type"], default="clause", help="Split at clause level or hold out entire types")
    parser.add_argument("--dedup_texts", action="store_true", help="Deduplicate exact clause_text duplicates before splitting")
    parser.add_argument("--max_pairs_per_class", type=int, default=200, help="Max positive pairs sampled per type (per split)")
    parser.add_argument("--max_len", type=int, default=128, help="Maximum tokenized length per clause")
    parser.add_argument("--max_vocab_size", type=int, default=30000, help="Maximum vocabulary size")
    parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency for vocab")
    parser.add_argument("--embedding_dim", type=int, default=200, help="Embedding dimension")
    parser.add_argument("--cnn_channels", type=int, default=128, help="Number of channels per convolutional kernel size")
    parser.add_argument("--kernel_sizes", type=_parse_kernel_sizes, default="3,4,5", help="Comma-separated kernel sizes, e.g., '3,4,5'")
    parser.add_argument("--mlp_hidden", type=int, default=256, help="Hidden size of the MLP head")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--epochs", type=int, default=8, help="Training epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--log_interval", type=int, default=200, help="Steps between training progress prints")
    # Never fall back to sys.argv: inside a notebook it holds the kernel's own flags.
    args = parser.parse_args([] if argv is None else argv)

    run(
        data_dir=args.data_dir,
        seed=args.seed,
        split_mode=args.split_mode,  # type: ignore[arg-type]
        dedup_texts=args.dedup_texts,
        max_pairs_per_class=args.max_pairs_per_class,
        max_len=args.max_len,
        max_vocab_size=args.max_vocab_size,
        min_freq=args.min_freq,
        embedding_dim=args.embedding_dim,
        cnn_channels=args.cnn_channels,
        # Defensive: parse here if the value somehow arrives as a raw string.
        kernel_sizes=args.kernel_sizes if isinstance(args.kernel_sizes, list) else _parse_kernel_sizes(args.kernel_sizes),  # type: ignore[arg-type]
        mlp_hidden=args.mlp_hidden,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        log_interval=args.log_interval,
    )


# Run the pipeline only when this code executes as the main module, not on import.
if __name__ == "__main__":
    main()


Loading clauses from: csv
Loaded 150881 clauses from 395 CSV file(s).
Split clauses -> train/val/test: 105616/22632/22633. Built pairs -> train/val/test: 105306/22438/22438.
Vocab size: 30000 (min_freq=2, max_vocab_size=30000)
Using device: cuda
Starting training for 8 epoch(s), steps per epoch: 1646
Epoch 01/8 - step 00200/01646 loss: 0.6842 (avg: 0.6985) - elapsed: 92.8s
Epoch 01/8 - step 00400/01646 loss: 0.6733 (avg: 0.6891) - elapsed: 187.3s
Epoch 01/8 - step 00600/01646 loss: 0.6498 (avg: 0.6742) - elapsed: 282.1s
Epoch 01/8 - step 00800/01646 loss: 0.5053 (avg: 0.6524) - elapsed: 377.0s
Epoch 01/8 - step 01000/01646 loss: 0.5289 (avg: 0.6328) - elapsed: 471.8s
Epoch 01/8 - step 01200/01646 loss: 0.5145 (avg: 0.6136) - elapsed: 566.6s
Epoch 01/8 - step 01400/01646 loss: 0.4789 (avg: 0.5959) - elapsed: 661.4s
Epoch 01/8 - step 01600/01646 loss: 0.4687 (avg: 0.5811) - elapsed: 756.1s
Epoch 01/8 - step 01646/01646 loss: 0.4927 (avg: 0.5774) - elapsed: 777.8s
Evaluating on validation

# Baseline 3: Siamese Self-Attention Encoder

In [None]:
# Colab only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys
# Add the Drive project folder to the module search path.
project_path = '/content/drive/MyDrive/A2'
sys.path.append(project_path)


In [None]:
!pip install numpy pandas torch scikit-learn tqdm matplotlib

In [None]:
!python /content/drive/MyDrive/A2/a2.py --data_dir /content/drive/MyDrive/A2/datafiles --epochs 50 --max_len 128 --max_vocab_size 30000 --embedding_dim 200 --num_heads 4 --mlp_hidden 256 --batch_size 64



In [None]:
import argparse
import random
import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
    roc_curve,
    auc
)


# =========================
# Utility Classes & Helpers
# =========================

@dataclass
class Clause:
    """One clause record as produced by ``load_clauses_from_folder``."""
    text: str  # clause text, stripped of surrounding whitespace by the loader
    clause_type: str  # value of the ``clause_type`` column ("" when missing)
    source_file: str  # file name (not full path) of the CSV the row came from


def set_global_seeds(seed: int) -> None:
    """Seed the global RNGs of ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_clauses_from_folder(folder: Path) -> List[Clause]:
    """Load every clause found in the CSV files under ``folder`` (recursively).

    Files are visited in sorted path order. A file that cannot be read, or that
    lacks the ``clause_text`` / ``clause_type`` columns, is skipped with a
    printed warning. Rows whose text is missing or blank are dropped; a
    missing type becomes the empty string.

    Args:
        folder: Directory to search for ``*.csv`` files.

    Returns:
        One ``Clause`` per kept row, with stripped text and type and the name
        of the CSV file it came from.
    """
    clauses: List[Clause] = []
    for csv_path in sorted(folder.rglob("*.csv")):
        try:
            df = pd.read_csv(csv_path)
        except Exception as exc:
            # Deliberately best-effort: one unreadable file must not abort the load.
            print(f"Warning: failed to read {csv_path.name}: {exc}")
            continue

        if "clause_text" not in df.columns or "clause_type" not in df.columns:
            print(f"Warning: {csv_path.name} missing required columns.")
            continue

        # Walk the two needed columns directly: iterrows() builds a Series per
        # row, which is slow on large files and may upcast values.
        for raw_text, raw_type in zip(df["clause_text"], df["clause_type"]):
            text = "" if pd.isna(raw_text) else str(raw_text).strip()
            ctype = "" if pd.isna(raw_type) else str(raw_type).strip()
            if text:
                clauses.append(Clause(text, ctype, csv_path.name))
    return clauses


def make_balanced_pairs(clauses: List[Clause], max_pairs_per_class: int | None, rng: np.random.Generator):
    type_to_indices: dict[str, List[int]] = {}
    for idx, c in enumerate(clauses):
        type_to_indices.setdefault(c.clause_type, []).append(idx)

    pos_pairs = []
    for ctype, idxs in type_to_indices.items():
        if len(idxs) < 2:
            continue
        rng.shuffle(idxs)
        pair_count = len(idxs) // 2
        if max_pairs_per_class is not None:
            pair_count = min(pair_count, max_pairs_per_class)
        for i in range(pair_count):
            a, b = idxs[2 * i], idxs[2 * i + 1]
            pos_pairs.append((a, b))

    neg_pairs = []
    all_idx = np.arange(len(clauses))
    while len(neg_pairs) < len(pos_pairs):
        i, j = rng.choice(all_idx, 2, replace=False)
        if clauses[i].clause_type != clauses[j].clause_type:
            neg_pairs.append((i, j))

    X, y = [], []
    for i, j in pos_pairs:
        X.append((clauses[i].text, clauses[j].text))
        y.append(1)
    for i, j in neg_pairs:
        X.append((clauses[i].text, clauses[j].text))
        y.append(0)
    order = rng.permutation(len(X))
    X = [X[k] for k in order]
    y = np.array([y[k] for k in order])
    return X, y


def build_vocab(texts: List[str], max_vocab_size: int, min_freq: int):
    """Build a token -> id mapping from raw texts.

    Texts are lower-cased and split into runs of word characters. Tokens with
    fewer than ``min_freq`` occurrences are dropped; the rest are ranked by
    descending frequency (ties broken alphabetically) and cut to the
    ``max_vocab_size - 2`` slots left after the reserved ids.

    Returns:
        Dict mapping token to id, with ``<PAD>`` = 0 and ``<UNK>`` = 1 always
        present.
    """
    token_freq: dict[str, int] = {}
    for t in texts:
        for tok in re.findall(r"\b\w+\b", t.lower()):
            token_freq[tok] = token_freq.get(tok, 0) + 1

    sorted_tokens = sorted(
        (tok for tok, f in token_freq.items() if f >= min_freq),
        key=lambda x: (-token_freq[x], x),
    )
    # Clamp at zero: a negative slice bound would count from the end and keep
    # nearly every token when max_vocab_size < 2.
    limited = sorted_tokens[: max(0, max_vocab_size - 2)]
    token_to_id = {"<PAD>": 0, "<UNK>": 1}
    for i, tok in enumerate(limited, start=2):
        token_to_id[tok] = i
    return token_to_id


def texts_to_ids(texts: List[str], token_to_id: dict, max_len: int):
    """Encode texts as a right-padded int64 id matrix and a length vector.

    Tokens are lower-cased runs of word characters; unknown tokens map to id 1,
    sequences are truncated to ``max_len`` and short rows are filled with id 0.

    Returns:
        ``(seqs, lens)`` with shapes ``(len(texts), max_len)`` and
        ``(len(texts),)``.
    """
    pad_id = 0
    unk_id = 1
    n_texts = len(texts)
    seqs = np.full((n_texts, max_len), pad_id, dtype=np.int64)
    lens = np.zeros(n_texts, dtype=np.int64)
    for row, text in enumerate(texts):
        words = re.findall(r"\b\w+\b", text.lower())[:max_len]
        encoded = [token_to_id.get(word, unk_id) for word in words]
        seqs[row, :len(encoded)] = encoded
        lens[row] = len(encoded)
    return seqs, lens


# ======================
# Dataset
# ======================

class PairDataset(Dataset):
    """Dataset of text pairs, encoded to fixed-width id matrices up front.

    Encoding happens once in the constructor; ``__getitem__`` returns
    ``(l_ids, l_len, r_ids, r_len)`` and appends the int label when labels
    were supplied.
    """

    def __init__(self, pairs, token_to_id, max_len, labels=None):
        self.left_texts = [first for first, _ in pairs]
        self.right_texts = [second for _, second in pairs]
        self.labels = None if labels is None else labels.astype(np.int64)
        self.token_to_id = token_to_id
        self.max_len = max_len
        self.left_ids, self.left_lens = texts_to_ids(self.left_texts, token_to_id, max_len)
        self.right_ids, self.right_lens = texts_to_ids(self.right_texts, token_to_id, max_len)

    def __len__(self):
        return len(self.left_texts)

    def __getitem__(self, idx):
        sample = (
            torch.tensor(self.left_ids[idx]),
            int(self.left_lens[idx]),
            torch.tensor(self.right_ids[idx]),
            int(self.right_lens[idx]),
        )
        if self.labels is None:
            return sample
        return sample + (int(self.labels[idx]),)


# ======================
# Self-Attention Encoder
# ======================

class SelfAttentionEncoder(nn.Module):
    """Encode padded token sequences with one multi-head self-attention layer.

    Pipeline: embedding -> self-attention with a residual connection and
    LayerNorm -> position-wise Linear/ReLU/Dropout -> mean over the real
    (non-padding) positions.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width, which is also the output width.
        num_heads: Number of attention heads.
        dropout: Dropout probability for attention and the feed-forward block.
        pad_idx: Id used as ``padding_idx`` for the embedding.
    """

    def __init__(self, vocab_size, embed_dim, num_heads=4, dropout=0.1, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, ids, lengths):
        """Return mean-pooled encodings of shape (B, embed_dim).

        Args:
            ids: (B, T) long tensor of token ids, right-padded.
            lengths: (B,) tensor with the number of real tokens per row.

        Rows with ``lengths == 0`` are encoded as all zeros.
        """
        positions = torch.arange(ids.size(1), device=ids.device)[None, :]
        pad_mask = positions >= lengths[:, None]  # True at padding positions
        # If every key of a row is masked, the attention softmax yields NaN,
        # which then spreads to the pooled output and the loss. Keep the first
        # position visible for empty rows; pooling still uses the true mask.
        attn_pad_mask = positions >= lengths.clamp_min(1)[:, None]

        x = self.embedding(ids)
        attn_out, _ = self.attn(x, x, x, key_padding_mask=attn_pad_mask)
        x = self.norm(x + attn_out)
        x = self.fc(x)

        keep = (~pad_mask).unsqueeze(-1)           # (B, T, 1)
        token_count = keep.sum(1).clamp_min(1)     # avoid 0/0 for empty rows
        pooled = (x * keep).sum(1) / token_count
        return pooled


# ======================
# Siamese Model
# ======================

class SiameseSelfAttention(nn.Module):
    """Siamese pair classifier on top of a shared ``SelfAttentionEncoder``.

    The encodings ``u`` and ``v`` of the two inputs are combined as
    ``[u, v, |u - v|, u * v, cos(u, v)]`` and scored by a three-layer MLP that
    outputs one logit per pair.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, mlp_hidden, dropout=0.2):
        super().__init__()
        self.encoder = SelfAttentionEncoder(vocab_size, embed_dim, num_heads, dropout)
        # Four embed_dim-wide feature groups plus the scalar cosine.
        feat_dim = 4 * embed_dim + 1
        half_hidden = mlp_hidden // 2
        head = [
            nn.Linear(feat_dim, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, half_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_hidden, 1),
        ]
        self.mlp = nn.Sequential(*head)

    @staticmethod
    def cosine(u, v, eps=1e-12):
        """Row-wise cosine similarity, shape (B, 1); the denominator is clamped to ``eps``."""
        norms = u.norm(dim=1, keepdim=True) * v.norm(dim=1, keepdim=True)
        return (u * v).sum(dim=1, keepdim=True) / norms.clamp_min(eps)

    def forward(self, l_ids, l_len, r_ids, r_len):
        """Return one logit per pair, shape (B,)."""
        left = self.encoder(l_ids, l_len)
        right = self.encoder(r_ids, r_len)
        parts = [
            left,
            right,
            torch.abs(left - right),
            left * right,
            self.cosine(left, right),
        ]
        logits = self.mlp(torch.cat(parts, dim=1))
        return logits.squeeze(1)


# ======================
# Evaluation helpers
# ======================

def evaluate_binary(y_true, y_proba):
    """Compute binary-classification metrics from true labels and probabilities.

    Accuracy, precision, recall and F1 use predictions thresholded at 0.5;
    ROC-AUC and PR-AUC are computed from the raw probabilities.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``roc_auc`` and ``pr_auc`` (plain floats).
    """
    hard_pred = (y_proba >= 0.5).astype(int)

    metrics = {"accuracy": float(accuracy_score(y_true, hard_pred))}
    thresholded_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for name, scorer in thresholded_scorers:
        metrics[name] = float(scorer(y_true, hard_pred, zero_division=0))

    metrics["roc_auc"] = float(roc_auc_score(y_true, y_proba))
    metrics["pr_auc"] = float(average_precision_score(y_true, y_proba))
    return metrics


def evaluate_model(model, loader, device):
    """Run ``model`` over ``loader`` in eval mode and collect predictions.

    Each batch must unpack as ``(l_ids, l_len, r_ids, r_len, y)``; the model's
    logits are passed through a sigmoid.

    Returns:
        ``(labels, probs)`` as NumPy arrays in loader order.
    """
    model.eval()
    collected_probs = []
    collected_labels = []
    with torch.no_grad():
        for l_ids, l_len, r_ids, r_len, y in tqdm(loader, desc="Evaluating", leave=False):
            l_ids = l_ids.to(device)
            r_ids = r_ids.to(device)
            l_len = l_len.to(device)
            r_len = r_len.to(device)
            y = y.to(device)
            logits = model(l_ids, l_len, r_ids, r_len)
            collected_probs.extend(torch.sigmoid(logits).cpu().numpy())
            collected_labels.extend(y.cpu().numpy())
    labels = np.array(collected_labels)
    probs = np.array(collected_probs)
    return labels, probs


# ======================
# Training Function
# ======================
def run(data_dir, seed, max_pairs_per_class, max_len, max_vocab_size, min_freq,
        embedding_dim, num_heads, mlp_hidden, batch_size, epochs, lr,
        output_dir="/content/drive/MyDrive/A2"):
    """Train and evaluate the Siamese self-attention pair classifier.

    Steps: seed the RNGs, load clauses, split them 70/15/15 (stratified by
    clause type when possible), build balanced pairs per split, fit the
    vocabulary on the training pairs, train with Adam and BCE-with-logits
    while tracking validation F1, restore the best-F1 weights, save the
    training-loss, validation-F1 and test ROC plots, and print test metrics.

    Args:
        data_dir: Folder searched for clause CSV files.
        seed: Seed for the global RNGs, the splits and pair sampling.
        max_pairs_per_class: Forwarded to ``make_balanced_pairs``.
        max_len: Token length every clause is truncated/padded to.
        max_vocab_size: Forwarded to ``build_vocab``.
        min_freq: Forwarded to ``build_vocab``.
        embedding_dim: Embedding (and encoder output) width.
        num_heads: Number of attention heads.
        mlp_hidden: Hidden width of the classification head.
        batch_size: Batch size for all three loaders.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        output_dir: Directory that receives ``training_loss.png``,
            ``val_f1.png`` and ``roc_curve.png``; created if missing.
    """
    set_global_seeds(seed)
    print("Loading dataset...")
    clauses = load_clauses_from_folder(data_dir)
    print(f"Loaded {len(clauses)} clauses from {len(set(c.source_file for c in clauses))} files.")

    from sklearn.model_selection import train_test_split
    idx = np.arange(len(clauses))
    clause_types = [c.clause_type for c in clauses]
    try:
        idx_train, idx_temp = train_test_split(idx, test_size=0.3, stratify=clause_types, random_state=seed)
        idx_val, idx_test = train_test_split(idx_temp, test_size=0.5, stratify=[clause_types[i] for i in idx_temp], random_state=seed)
    except ValueError as exc:
        # Stratification is rejected when a clause type has too few members to
        # appear on both sides; say so instead of silently degrading.
        print(f"Warning: stratified split failed ({exc}); falling back to an unstratified split.")
        idx_train, idx_temp = train_test_split(idx, test_size=0.3, random_state=seed)
        idx_val, idx_test = train_test_split(idx_temp, test_size=0.5, random_state=seed)

    clauses_train = [clauses[i] for i in idx_train]
    clauses_val = [clauses[i] for i in idx_val]
    clauses_test = [clauses[i] for i in idx_test]

    # One generator for all three splits: the call order fixes the sampled pairs.
    rng = np.random.default_rng(seed)
    X_train, y_train = make_balanced_pairs(clauses_train, max_pairs_per_class, rng)
    X_val, y_val = make_balanced_pairs(clauses_val, max_pairs_per_class, rng)
    X_test, y_test = make_balanced_pairs(clauses_test, max_pairs_per_class, rng)

    # Vocabulary is fitted on training pairs only.
    token_to_id = build_vocab([a for a, _ in X_train] + [b for _, b in X_train], max_vocab_size, min_freq)
    vocab_size = len(token_to_id)
    print(f"Vocab size: {vocab_size}")

    ds_train = PairDataset(X_train, token_to_id, max_len, y_train)
    ds_val = PairDataset(X_val, token_to_id, max_len, y_val)
    ds_test = PairDataset(X_test, token_to_id, max_len, y_test)

    def collate(batch):
        # Stack per-sample tuples into batched tensors.
        l_ids, l_len, r_ids, r_len, y = zip(*batch)
        return (
            torch.stack(l_ids),
            torch.tensor(l_len),
            torch.stack(r_ids),
            torch.tensor(r_len),
            torch.tensor(y),
        )

    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False, collate_fn=collate)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, collate_fn=collate)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    model = SiameseSelfAttention(vocab_size, embedding_dim, num_heads, mlp_hidden, dropout=0.2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    # Track training progress
    train_losses = []
    val_f1_scores = []
    best_f1 = -1
    best_state = None

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            l_ids, l_len, r_ids, r_len, y = batch
            l_ids, r_ids = l_ids.to(device), r_ids.to(device)
            l_len, r_len, y = l_len.to(device), r_len.to(device), y.float().to(device)

            logits = model(l_ids, l_len, r_ids, r_len)
            loss = loss_fn(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        y_val_true, y_val_prob = evaluate_model(model, val_loader, device)
        val_metrics = evaluate_binary(y_val_true, y_val_prob)
        val_f1_scores.append(val_metrics["f1"])

        print(f"Epoch {epoch:02d} | loss={avg_train_loss:.4f} | val_f1={val_metrics['f1']:.4f}")

        if val_metrics["f1"] > best_f1:
            best_f1 = val_metrics["f1"]
            # clone() is required: on CPU, .cpu() returns the live parameter
            # tensor, so later optimizer steps would overwrite the snapshot.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # Restore best weights
    if best_state is not None:
        model.load_state_dict(best_state)

    # Create the output folder first so a missing directory cannot discard the
    # results of a finished training run.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Plot training graphs
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(train_losses)+1), train_losses, marker='o', label='Training Loss')
    plt.title("Training Loss per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()
    plt.savefig(str(out_dir / "training_loss.png"))
    plt.close()

    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(val_f1_scores)+1), val_f1_scores, marker='o', color='orange', label='Validation F1-Score')
    plt.title("Validation F1-Score per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("F1-Score")
    plt.grid(True)
    plt.legend()
    plt.savefig(str(out_dir / "val_f1.png"))
    plt.close()

    # Final test evaluation
    y_test_true, y_test_prob = evaluate_model(model, test_loader, device)
    test_metrics = evaluate_binary(y_test_true, y_test_prob)
    print("\nFinal Test Metrics:")
    for k, v in test_metrics.items():
        print(f"{k:>10}: {v:.4f}")

    # ROC Curve
    fpr, tpr, _ = roc_curve(y_test_true, y_test_prob)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.3f})")
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve on Test Set")
    plt.legend()
    plt.grid(True)
    plt.savefig(str(out_dir / "roc_curve.png"))
    plt.close()



def main(argv=None):
    """Parse command-line options and launch the Siamese self-attention baseline.

    All parsed options are forwarded to ``run`` as keyword arguments via
    ``run(**vars(args))``.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, exactly as before. Pass
            an explicit list (e.g. ``main([])`` or
            ``main(["--epochs", "50"])``) to call this from a notebook cell
            or another script.
    """
    parser = argparse.ArgumentParser(description="Siamese Self-Attention Encoder Baseline")
    parser.add_argument("--data_dir", type=Path, default=Path("datafiles"))
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max_pairs_per_class", type=int, default=200)
    parser.add_argument("--max_len", type=int, default=128)
    parser.add_argument("--max_vocab_size", type=int, default=30000)
    parser.add_argument("--min_freq", type=int, default=2)
    parser.add_argument("--embedding_dim", type=int, default=200)
    parser.add_argument("--num_heads", type=int, default=4)
    parser.add_argument("--mlp_hidden", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)

    # Jupyter/Colab kernels put their own arguments (e.g. ``-f <connection
    # file>``) into sys.argv; a strict parse_args() would abort with
    # "unrecognized arguments". Parse only the known options and report the
    # rest so genuine typos are still visible.
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f"Warning: ignoring unrecognized arguments: {unknown}")

    run(**vars(args))


# Script entry point: invoke the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Loading dataset...
Loaded 150881 clauses from 395 files.

Vocab size: 30000

Using device: cuda

Epoch 01 | loss=0.3763 | val_f1=0.8993

Epoch 02 | loss=0.1894 | val_f1=0.9330

Epoch 03 | loss=0.1332 | val_f1=0.9352

Epoch 04 | loss=0.1034 | val_f1=0.9410

Epoch 05 | loss=0.0829 | val_f1=0.9435

Epoch 06 | loss=0.0674 | val_f1=0.9383

Epoch 07 | loss=0.0592 | val_f1=0.9440

Epoch 08 | loss=0.0509 | val_f1=0.9448

Epoch 09 | loss=0.0469 | val_f1=0.9415

Epoch 10 | loss=0.0393 | val_f1=0.9452

Epoch 11 | loss=0.0385 | val_f1=0.9428

Epoch 12 | loss=0.0351 | val_f1=0.9419

Epoch 13 | loss=0.0342 | val_f1=0.9416

Epoch 14 | loss=0.0307 | val_f1=0.9436

Epoch 15 | loss=0.0280 | val_f1=0.9419

Epoch 16 | loss=0.0270 | val_f1=0.9444

Epoch 17 | loss=0.0259 | val_f1=0.9432

Epoch 18 | loss=0.0246 | val_f1=0.9427

Epoch 19 | loss=0.0243 | val_f1=0.9421

Epoch 20 | loss=0.0231 | val_f1=0.9434

Epoch 21 | loss=0.0219 | val_f1=0.9406

Epoch 22 | loss=0.0216 | val_f1=0.9362

Epoch 23 | loss=0.0212 | val_f1=0.9435

Epoch 24 | loss=0.0204 | val_f1=0.9440

Epoch 25 | loss=0.0177 | val_f1=0.9415

Epoch 26 | loss=0.0200 | val_f1=0.9394

Epoch 27 | loss=0.0180 | val_f1=0.9434

Epoch 28 | loss=0.0190 | val_f1=0.9405

Epoch 29 | loss=0.0179 | val_f1=0.9404

Epoch 30 | loss=0.0175 | val_f1=0.9413

Epoch 31 | loss=0.0173 | val_f1=0.9427

Epoch 32 | loss=0.0158 | val_f1=0.9407

Epoch 33 | loss=0.0146 | val_f1=0.9422

Epoch 34 | loss=0.0155 | val_f1=0.9417

Epoch 35 | loss=0.0158 | val_f1=0.9399

Epoch 36 | loss=0.0157 | val_f1=0.9385

Epoch 37 | loss=0.0135 | val_f1=0.9412

Epoch 38 | loss=0.0138 | val_f1=0.9423

Epoch 39 | loss=0.0149 | val_f1=0.9408

Epoch 40 | loss=0.0140 | val_f1=0.9402

Epoch 41 | loss=0.0124 | val_f1=0.9421

Epoch 42 | loss=0.0118 | val_f1=0.9418

Epoch 43 | loss=0.0123 | val_f1=0.9410

Epoch 44 | loss=0.0140 | val_f1=0.9423

Epoch 45 | loss=0.0121 | val_f1=0.9375

Epoch 46 | loss=0.0116 | val_f1=0.9404

Epoch 47 | loss=0.0123 | val_f1=0.9390

Epoch 48 | loss=0.0120 | val_f1=0.9400

Epoch 49 | loss=0.0111 | val_f1=0.9421

Epoch 50 | loss=0.0118 | val_f1=0.9425
                                                  
Final Test Metrics:
  
 accuracy: 0.9438

 precision: 0.9479

 recall: 0.9391

 f1: 0.9435

 roc_auc: 0.9850

 pr_auc: 0.9838