In [3]:
# ============================================
# CONFIG — edit as needed
# ============================================
# Input / output locations
IN_DIR = "../Datasets/Ingestor"                          # directory scanned for source CSVs
COMBINED_OUT = "../Datasets/Siamese_Train/pairs_AB.csv"  # combined pairs file (".parquet" extension also supported)

# File discovery and pairing
GLOB_PATTERN = "*.csv"   # filename pattern matched inside IN_DIR
RECURSIVE = False        # also search sub-directories when True
ID_COL = None            # row-ID column; None -> synthetic "<file stem>::<row number>" IDs
NEG_PER_POS = 9          # negatives sampled per positive pair (capped at rows - 1 per file)

# How usable numeric columns are divided between the A side and the B side
COLUMN_SPLIT_MODE = "half"    # first half of the columns -> A, remainder -> B
COLUMN_SPLIT_SHUFFLE = False  # shuffle column order before halving
EXPLICIT_COLS_A = []          # explicit A columns (used only when both explicit lists are non-empty)
EXPLICIT_COLS_B = []          # explicit B columns

# Per-CSV row cap
ROW_LIMIT = 10_000
ROW_LIMIT_MODE = "head"  # "head" keeps the first rows; "sample" draws a random subset

# Type-detection thresholds and Unix timestamp unit for converted date columns
DATE_MIN_VALID_FRACTION = 0.50     # min parseable fraction for a non-numeric column to be treated as a date
NUMERIC_MIN_VALID_FRACTION = 0.98  # min parseable fraction for a column to count as numeric
UNIX_UNIT = "s"

# Logging
LOG_EVERY_FILES = 10
LOG_LEVEL = "INFO"
LOG_FILE = None  # optional log file path; None -> stream handler only

# Per-column z-score normalization of the feature columns
NORMALIZE_FEATURES = True

# Bookkeeping columns removed before the combined file is written
DROP_FROM_OUTPUT = [
    "num_nan_a", "num_nan_b", "num_nan_mismatch",
    "row_idx_A", "row_idx_B", "label_type", "source_file",
]

from __future__ import annotations

import glob
import logging
import os
import time
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

def setup_logging(level: str = "INFO", logfile: Optional[str] = None):
    """Configure the root logger with a stream handler and an optional file handler.

    Args:
        level: Level name such as "DEBUG" or "info" (case-insensitive); names
            that ``logging`` does not define fall back to INFO.
        logfile: Optional log-file path, truncated on every call. Its parent
            directory is created if it does not exist.
    """
    resolved_level = getattr(logging, level.upper(), logging.INFO)
    config = dict(
        level=resolved_level,
        format="%(asctime)s | %(levelname)-7s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    sinks = [logging.StreamHandler()]
    if logfile:
        parent = os.path.dirname(os.path.abspath(logfile))
        if parent:
            os.makedirs(parent, exist_ok=True)
        sinks.append(logging.FileHandler(logfile, mode="w", encoding="utf-8"))

    try:
        # force=True replaces any handlers already attached to the root logger.
        logging.basicConfig(handlers=sinks, force=True, **config)
    except TypeError:
        # Interpreters without `force`: detach existing root handlers by hand.
        root_logger = logging.getLogger()
        for existing in root_logger.handlers[:]:
            root_logger.removeHandler(existing)
        logging.basicConfig(handlers=sinks, **config)


def _clean_numeric_like_text(s: pd.Series) -> pd.Series:
    """Remove common decorators like commas/spaces before numeric coercion."""
    return (s.astype(str)
              .str.replace(",", "", regex=False)
              .str.replace(" ", "", regex=False)
              .str.replace("\u00A0", "", regex=False))


def detect_numeric_columns(df: pd.DataFrame, *, id_col: Optional[str], min_valid_fraction: float) -> List[str]:
    """Return the non-ID columns of *df* that parse as numbers often enough.

    A column qualifies when, after ``_clean_numeric_like_text`` and
    ``pd.to_numeric(errors="coerce")``, the fraction of non-NaN values is at
    least ``min_valid_fraction``. Columns keep the order of ``df.columns``; the
    ID column (if any) is never parsed.
    """
    def _parse_rate(column) -> float:
        parsed = pd.to_numeric(_clean_numeric_like_text(df[column]), errors="coerce")
        return parsed.notna().mean()

    return [
        column
        for column in df.columns
        if (id_col is None or column != id_col) and _parse_rate(column) >= min_valid_fraction
    ]


def datetime_series_to_unix(dt: pd.Series, unit: str) -> pd.Series:
    """Convert a datetime Series to float Unix timestamps.

    Args:
        dt: Datetime Series, e.g. the result of ``pd.to_datetime(..., utc=True)``.
            ``to_numpy(dtype="datetime64[ns]")`` takes tz-aware values in UTC.
        unit: Output unit: "s", "ms", "us" or "ns".

    Returns:
        float64 Series aligned to ``dt.index``; NaT becomes NaN.

    Raises:
        ValueError: If *unit* is not supported (previously any unknown unit was
            silently treated as seconds).
    """
    ns_per_unit = {"s": 1e9, "ms": 1e6, "us": 1e3, "ns": 1.0}
    try:
        divisor = ns_per_unit[unit]
    except KeyError:
        raise ValueError(
            f"Unsupported Unix unit {unit!r}; expected one of {sorted(ns_per_unit)}"
        ) from None

    arr_dt = dt.to_numpy(dtype="datetime64[ns]")
    mask_nat = np.isnat(arr_dt)
    # NaT is stored as the minimum int64; it is masked to NaN right after the cast.
    arr_ns = arr_dt.astype("int64").astype("float64")
    arr_ns[mask_nat] = np.nan
    return pd.Series(arr_ns / divisor, index=dt.index, dtype="float64")


def convert_only_non_numeric_dates_to_unix(
    df: pd.DataFrame,
    *,
    id_col: Optional[str],
    numeric_min_valid_fraction: float,
    date_min_valid_fraction: float,
    unit: str,
) -> Tuple[pd.DataFrame, List[str]]:
    """Replace date-like, non-numeric columns with Unix timestamps.

    The ID column and every column ``detect_numeric_columns`` classifies as
    numeric are left alone. Each remaining column is parsed with
    ``pd.to_datetime(errors="coerce", utc=True)``. If at least
    ``date_min_valid_fraction`` of its values parse, the column is replaced by
    ``datetime_series_to_unix(parsed, unit=unit)``.

    Returns:
        (copy of df with converted columns, names of converted columns in column order)
    """
    result = df.copy()
    already_numeric = set(
        detect_numeric_columns(result, id_col=id_col, min_valid_fraction=numeric_min_valid_fraction)
    )
    converted: List[str] = []
    for name in list(result.columns):
        is_id = id_col is not None and name == id_col
        if is_id or name in already_numeric:
            continue
        parsed = pd.to_datetime(result[name], errors="coerce", utc=True)
        if parsed.notna().mean() < date_min_valid_fraction:
            continue
        result[name] = datetime_series_to_unix(parsed, unit=unit)
        converted.append(name)
    return result, converted


def as_numeric(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    """Return a copy of *df* with *cols* cleaned and coerced to numbers.

    Values that still fail to parse become NaN; unlisted columns are unchanged.
    """
    converted = df.copy()
    for name in cols:
        cleaned = _clean_numeric_like_text(converted[name])
        converted[name] = pd.to_numeric(cleaned, errors="coerce")
    return converted


# Feature normalization helper (z-score).
def normalize_features(df: pd.DataFrame, cols: List[str], stats: Optional[dict] = None) -> Tuple[pd.DataFrame, dict]:
    """
    Z-score normalize ``cols``: ``(x - mean) / std``.

    Args:
        df: DataFrame holding numeric columns to normalize.
        cols: Names of the columns to normalize; other columns are copied unchanged.
        stats: Optional ``{col: {"mean": ..., "std": ...}}`` with pre-computed
            statistics (e.g. fitted on training data). A column in ``cols``
            that is missing from ``stats`` is fitted on ``df`` and a warning is
            logged, instead of failing with a bare KeyError.

    Returns:
        (normalized copy of df, statistics that were applied). The returned
        dict is always a new dict; a caller-supplied ``stats`` is not mutated.

    Notes:
        ``std`` is pandas' sample standard deviation (ddof=1). A std of 0 or NaN
        (constant or single-value column) is replaced by 1.0, so such columns
        are only mean-centred.
    """
    def _fit(column) -> dict:
        mean_val = df[column].mean()
        std_val = df[column].std()
        if std_val == 0 or pd.isna(std_val):
            std_val = 1.0  # avoid division by zero
        return {"mean": float(mean_val), "std": float(std_val)}

    if stats is None:
        stats_out = {c: _fit(c) for c in cols}
    else:
        stats_out = dict(stats)
        missing = [c for c in cols if c not in stats_out]
        if missing:
            logging.warning(
                f"normalize_features: no pre-computed stats for {missing}; fitting them on this DataFrame"
            )
            for c in missing:
                stats_out[c] = _fit(c)

    df_norm = df.copy()
    for c in cols:
        df_norm[c] = (df[c] - stats_out[c]["mean"]) / stats_out[c]["std"]
    return df_norm, stats_out


def split_numeric_columns(
    df: pd.DataFrame,
    *,
    id_col: Optional[str],
    mode: str = "half",
    shuffle: bool = False,
    seed: int = 42,
    explicit_A: Optional[List[str]] = None,
    explicit_B: Optional[List[str]] = None,
    min_valid_fraction: float = 0.01,
) -> Tuple[List[str], List[str]]:
    """Return disjoint ``(cols_A, cols_B)`` lists of numeric feature columns.

    A non-ID column is *usable* when at least ``min_valid_fraction`` of its
    values parse as numbers after ``_clean_numeric_like_text``.

    * When both ``explicit_A`` and ``explicit_B`` are non-empty they are
      validated (all usable, no overlap) and returned as lists; ``mode`` is
      not consulted in that case.
    * Otherwise the usable columns (in ``df`` order, or shuffled with ``seed``
      when ``shuffle``) are split according to ``mode``. The only supported
      mode is ``"half"``: the first ``len // 2`` columns go to A, the rest to B.
      Supplying only one explicit list logs a warning and uses this split.

    Raises:
        ValueError: unusable or overlapping explicit columns, an unsupported
            ``mode``, or a split that leaves one side empty.
    """
    feature_cols = [c for c in df.columns.tolist() if (id_col is None or c != id_col)]

    usable = set()
    for c in feature_cols:
        parsed = pd.to_numeric(_clean_numeric_like_text(df[c]), errors="coerce")
        if parsed.notna().mean() >= min_valid_fraction:
            usable.add(c)
    numeric_ok = [c for c in feature_cols if c in usable]

    if explicit_A and explicit_B:
        missing = [c for c in (list(explicit_A) + list(explicit_B)) if c not in usable]
        if missing:
            raise ValueError(f"Explicit columns not usable: {missing}")
        overlap = set(explicit_A).intersection(explicit_B)
        if overlap:
            raise ValueError(f"Explicit A/B overlap: {overlap}")
        return list(explicit_A), list(explicit_B)

    if explicit_A or explicit_B:
        # A one-sided explicit config cannot define both sides; say so instead of ignoring it silently.
        logging.warning(
            "Only one of explicit_A / explicit_B was provided; ignoring it and splitting columns by mode."
        )

    # Previously `mode` was accepted but never read, so a typo silently produced a half split.
    if mode != "half":
        raise ValueError(f"Unsupported column split mode: {mode!r} (supported: 'half')")

    cols = numeric_ok.copy()
    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(cols)

    dropped = [c for c in feature_cols if c not in usable]
    if dropped:
        logging.warning(f"Dropping {len(dropped)} non-usable column(s): {dropped[:10]}{'...' if len(dropped)>10 else ''}")

    mid = len(cols) // 2
    cols_A, cols_B = cols[:mid], cols[mid:]
    if len(cols_A) == 0 or len(cols_B) == 0:
        raise ValueError("Column split resulted in an empty side; need at least two usable numeric columns.")
    return cols_A, cols_B


def _apply_row_cap(df: pd.DataFrame, cap: int, mode: str, seed: int) -> pd.DataFrame:
    if cap is None or len(df) <= cap:
        return df
    if mode == "sample":
        return df.sample(n=cap, random_state=seed)
    return df.head(cap)


def rows_to_AB_record(
    i: int, j: int,
    A_df: pd.DataFrame, B_df: pd.DataFrame,
    cols_A: List[str], cols_B: List[str],
    id_col_eff: str,
    label: int,
    source_file: str
) -> dict:
    """Build one pair record from row ``i`` of ``A_df`` and row ``j`` of ``B_df``.

    Rows are addressed positionally (``.iloc``). Keys in insertion order:
    ``idA``, ``idB``, ``label`` (int), ``source_file`` (basename),
    ``row_idx_A``, ``row_idx_B``, then ``A_<col>`` per *cols_A* and
    ``B_<col>`` per *cols_B*.
    """
    record = dict(
        idA=A_df[id_col_eff].iloc[i],
        idB=B_df[id_col_eff].iloc[j],
        label=int(label),
        source_file=os.path.basename(source_file),
        row_idx_A=int(i),
        row_idx_B=int(j),
    )
    record.update((f"A_{col}", A_df[col].iloc[i]) for col in cols_A)
    record.update((f"B_{col}", B_df[col].iloc[j]) for col in cols_B)
    return record


def build_pairs_from_single_df_column_split_siamese(
    df: pd.DataFrame,
    *,
    id_col: Optional[str],
    negatives_per_positive: int,
    seed: int,
    source_file: str,
    column_split_mode: str = "half",
    column_split_shuffle: bool = False,
    explicit_cols_A: Optional[List[str]] = None,
    explicit_cols_B: Optional[List[str]] = None,
    normalize: bool = False,
    norm_stats: Optional[dict] = None,
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, Optional[dict]]:
    """
    Turn the rows of one CSV into Siamese training pairs.

    Pipeline: ensure an ID column -> convert date-like text columns to Unix
    time -> split numeric columns into A/B sides -> coerce to numeric ->
    optionally z-score normalize -> emit pairs.

    Pairs: one positive ``(i, i)`` per row, i.e. the A features and B features
    of the same row. Then, grouped by ``i``, ``min(negatives_per_positive, n - 1)``
    negatives ``(i, j)`` with distinct ``j != i`` drawn without replacement.

    Args:
        df: Rows of one source file.
        id_col: ID column; when None or absent, a synthetic ``_row_id`` column
            ``"<file stem>::<row number>"`` is created.
        negatives_per_positive: Negatives per row (capped at n - 1).
        seed: Seed for the column shuffle and negative sampling.
        source_file: Path of the source file (IDs, logging, index metadata).
        column_split_mode, column_split_shuffle, explicit_cols_A, explicit_cols_B:
            Forwarded to ``split_numeric_columns``.
        normalize: Apply z-score normalization to the feature columns.
        norm_stats: Pre-computed stats forwarded to ``normalize_features``.

    Returns:
        (pairs_df, labels, index_meta, normalization_stats)
        ``pairs_df`` columns: ``idA``, ``idB``, ``label``, ``A_<col>``..., ``B_<col>``...
        ``index_meta`` columns: ``idA``, ``idB``, ``row_idx_A``, ``row_idx_B``, ``source_file``.
        ``normalization_stats`` is None when ``normalize`` is False.

    Raises:
        ValueError: if fewer than 2 rows remain (negatives need a second row).
    """
    rng = np.random.default_rng(seed)
    base_name = os.path.basename(source_file)

    # Ensure an ID column
    if id_col is None or id_col not in df.columns:
        df = df.copy()
        df["_row_id"] = np.arange(len(df)).astype(str)
        stem = os.path.splitext(base_name)[0]
        df["_row_id"] = stem + "::" + df["_row_id"]
        id_col_eff = "_row_id"
        logging.debug(f"[{source_file}] Synthetic ID '_row_id' created")
    else:
        id_col_eff = id_col

    # Convert date-like (non-numeric) columns to Unix timestamps
    df, converted = convert_only_non_numeric_dates_to_unix(
        df,
        id_col=id_col_eff,
        numeric_min_valid_fraction=NUMERIC_MIN_VALID_FRACTION,
        date_min_valid_fraction=DATE_MIN_VALID_FRACTION,
        unit=UNIX_UNIT,
    )
    if converted:
        logging.info(f"Converted {len(converted)} date columns: {converted}")

    # Split usable numeric columns into the A and B sides
    cols_A, cols_B = split_numeric_columns(
        df, id_col=id_col_eff,
        mode=column_split_mode, shuffle=column_split_shuffle, seed=seed,
        explicit_A=explicit_cols_A, explicit_B=explicit_cols_B,
    )
    logging.info(f"[{base_name}] Split: A={len(cols_A)}, B={len(cols_B)}")

    feature_cols = cols_A + cols_B
    df_num = as_numeric(df, feature_cols)

    if normalize:
        df_num, norm_stats = normalize_features(df_num, feature_cols, stats=norm_stats)
        logging.info(f"Applied z-score normalization to {len(feature_cols)} features")
    else:
        norm_stats = None

    n = len(df_num)
    if n < 2:
        raise ValueError("Need at least 2 rows to form negatives.")

    # --- Pair indices -------------------------------------------------------
    rows = np.arange(n)
    size = min(negatives_per_positive, n - 1)
    if size > 0:
        drawn_per_row = []
        for i in range(n):
            # Draw positions from range(n - 1) and shift those >= i up by one.
            # Generator.choice consumes the RNG identically for an int population
            # and an array population of the same length. So this selects exactly
            # what rng.choice(np.delete(np.arange(n), i), size, replace=False)
            # selected, without rebuilding an (n - 1)-element array per row (O(n^2)).
            drawn = rng.choice(n - 1, size=size, replace=False)
            drawn[drawn >= i] += 1
            drawn_per_row.append(drawn)
        neg_i = np.repeat(rows, size)
        neg_j = np.concatenate(drawn_per_row)
    else:
        neg_i = np.empty(0, dtype=np.int64)
        neg_j = np.empty(0, dtype=np.int64)

    # Positives (i, i) first, then the negatives grouped by i
    idx_A = np.concatenate([rows, neg_i]).astype(np.int64)
    idx_B = np.concatenate([rows, neg_j]).astype(np.int64)
    labels = np.concatenate([np.ones(n, dtype=np.int64), np.zeros(len(neg_i), dtype=np.int64)])

    # --- Gather whole columns at once (replaces per-cell .iloc records) -----
    # Column order matches rows_to_AB_record: ids, label, source, row indices, A_*, B_*.
    id_values = df_num[id_col_eff].to_numpy()
    columns = {
        "idA": id_values[idx_A],
        "idB": id_values[idx_B],
        "label": labels,
        "source_file": np.full(len(labels), base_name, dtype=object),
        "row_idx_A": idx_A,
        "row_idx_B": idx_B,
    }
    for c in cols_A:
        columns[f"A_{c}"] = df_num[c].to_numpy()[idx_A]
    for c in cols_B:
        columns[f"B_{c}"] = df_num[c].to_numpy()[idx_B]
    out_df = pd.DataFrame(columns)

    y = out_df["label"].astype(int).rename("label")
    idx = out_df[["idA", "idB", "row_idx_A", "row_idx_B", "source_file"]].copy()

    return out_df.drop(columns=["row_idx_A", "row_idx_B", "source_file"]), y, idx, norm_stats


def save_siamese_pairs(out_df: pd.DataFrame, y: pd.Series, idx: pd.DataFrame, out_path: str, drop_cols: List[str]):
    """Write the combined pairs table to *out_path*.

    Parquet is used when the extension is ``.parquet`` (case-insensitive),
    otherwise CSV; the index is never written. *y* is appended as the label
    column only when *out_df* has no ``label`` column. Names in *drop_cols*
    that are present are removed. *idx* is accepted but not used. Missing
    parent directories are created.
    """
    table = out_df.copy()
    if "label" not in table.columns:
        table = pd.concat([table, y.reset_index(drop=True)], axis=1)

    to_drop = [name for name in drop_cols if name in table.columns]
    if to_drop:
        table = table.drop(columns=to_drop)
        logging.info(f"Dropped from output: {to_drop}")

    target_dir = os.path.dirname(os.path.abspath(out_path)) or "."
    os.makedirs(target_dir, exist_ok=True)
    is_parquet = os.path.splitext(out_path)[1].lower() == ".parquet"
    writer = table.to_parquet if is_parquet else table.to_csv
    writer(out_path, index=False)
    logging.info(f"Combined saved: {out_path} | rows={len(table):,} | cols={table.shape[1]:,}")


def build_pairs_from_dir_column_split_siamese(
    in_dir: str,
    *,
    glob_pattern: str = "*.csv",
    recursive: bool = False,
    id_col: Optional[str] = None,
    negatives_per_positive: int = 9,
    seed: int = 42,
    combined_out: str = "../Datasets/Siamese_Train/pairs_AB.csv",
    log_every_files: int = 10,
    row_limit: Optional[int] = None,
    row_limit_mode: str = "head",
    column_split_mode: str = "half",
    column_split_shuffle: bool = False,
    explicit_cols_A: Optional[List[str]] = None,
    explicit_cols_B: Optional[List[str]] = None,
    normalize: bool = False,
):
    """Build Siamese pairs for every matching CSV in *in_dir*, save and return them.

    Each file is read, capped with ``_apply_row_cap``, skipped when it has
    fewer than 2 rows, and converted with
    ``build_pairs_from_single_df_column_split_siamese``. All pairs are
    concatenated, written to ``combined_out`` (minus ``DROP_FROM_OUTPUT``), and
    the concatenated DataFrame is returned.

    With ``normalize=True`` the stats fitted on the first file that is actually
    processed are reused for all later files. They are written to
    ``<combined_out without extension>_norm_stats.json``.

    Raises:
        FileNotFoundError: no file matches the search pattern.
        RuntimeError: every file was skipped.
    """
    t0 = time.perf_counter()

    search = os.path.join(in_dir, "**", glob_pattern) if recursive else os.path.join(in_dir, glob_pattern)
    files = sorted(glob.glob(search, recursive=recursive))
    if not files:
        raise FileNotFoundError(f"No files matched: {search}")
    logging.info(f"Found {len(files):,} CSV files")

    all_rows = []
    total_pos, total_neg = 0, 0
    # None until the first *processed* file fits them. The old `k == 1` check lost
    # the stats (and every file self-normalized) whenever file #1 was skipped.
    # NOTE(review): one file's stats are applied to every other file. Files on
    # different scales are therefore not standardized (see e.g. A_vwap ~15 and
    # ~-33 in the sample rows). The stats are also fitted before the
    # train/val/test split done after this call. Consider per-file or
    # train-only statistics.
    norm_stats = None

    for k, f in enumerate(files, 1):
        df = pd.read_csv(f)

        original_rows = len(df)
        df = _apply_row_cap(df, row_limit, row_limit_mode, seed)
        if len(df) < original_rows:
            logging.info(f"[{k}/{len(files)}] {os.path.basename(f)} — capped {original_rows:,} -> {len(df):,}")

        if len(df) < 2:
            logging.warning(f"[{k}/{len(files)}] {os.path.basename(f)} has <2 rows; skipping.")
            continue

        logging.info(f"[{k}/{len(files)}] {os.path.basename(f)} — rows={len(df):,}, cols={len(df.columns):,}")

        out_df, y, _idx_meta, file_norm_stats = build_pairs_from_single_df_column_split_siamese(
            df=df,
            id_col=id_col,
            negatives_per_positive=negatives_per_positive,
            seed=seed,
            source_file=f,
            column_split_mode=column_split_mode,
            column_split_shuffle=column_split_shuffle,
            explicit_cols_A=explicit_cols_A,
            explicit_cols_B=explicit_cols_B,
            normalize=normalize,
            norm_stats=norm_stats,
        )

        if normalize:
            # Keep the stats that were actually applied (normalize_features returns them).
            norm_stats = file_norm_stats

        all_rows.append(out_df)
        pos = int(y.sum())
        neg = len(y) - pos
        total_pos += pos
        total_neg += neg

        if k % log_every_files == 0:
            logging.info(f"  Progress: {k:,}/{len(files):,} files | "
                         f"pairs so far={total_pos+total_neg:,} (pos={total_pos:,}, neg={total_neg:,})")

    if not all_rows:
        raise RuntimeError("No valid CSVs produced pairs.")

    final_df = pd.concat(all_rows, axis=0, ignore_index=True)
    logging.info(
        f"FINAL — pairs={len(final_df):,} "
        f"(pos={int(final_df['label'].sum()):,}, neg={len(final_df)-int(final_df['label'].sum()):,}), "
        f"A-cols={len([c for c in final_df.columns if c.startswith('A_')])}, "
        f"B-cols={len([c for c in final_df.columns if c.startswith('B_')])}"
    )

    save_siamese_pairs(final_df, final_df["label"], final_df[["idA","idB"]], combined_out, drop_cols=DROP_FROM_OUTPUT)

    if normalize and norm_stats:
        import json
        # splitext instead of str.replace('.csv', ...): a .parquet output used to
        # map to itself and overwrite the dataset, and '.csv' in a directory name
        # was rewritten too.
        stats_path = os.path.splitext(combined_out)[0] + "_norm_stats.json"
        with open(stats_path, "w", encoding="utf-8") as fh:
            json.dump(norm_stats, fh, indent=2)
        logging.info(f"Normalization stats saved to {stats_path}")

    logging.info(f"Total elapsed: {time.perf_counter() - t0:.2f}s")
    return final_df


# ============================================
# MAIN EXECUTION
# ============================================
setup_logging(LOG_LEVEL, LOG_FILE)

# Build pairs for every CSV in IN_DIR and write the combined file.
pairs_df = build_pairs_from_dir_column_split_siamese(
    IN_DIR,
    glob_pattern=GLOB_PATTERN,
    recursive=RECURSIVE,
    id_col=ID_COL,
    negatives_per_positive=NEG_PER_POS,
    seed=42,
    combined_out=COMBINED_OUT,
    log_every_files=LOG_EVERY_FILES,
    row_limit=ROW_LIMIT,
    row_limit_mode=ROW_LIMIT_MODE,
    column_split_mode=COLUMN_SPLIT_MODE,
    column_split_shuffle=COLUMN_SPLIT_SHUFFLE,
    explicit_cols_A=EXPLICIT_COLS_A,
    explicit_cols_B=EXPLICIT_COLS_B,
    normalize=NORMALIZE_FEATURES,
)

rule = "=" * 60
for banner_line in ("\n" + rule, "DATA PREPARATION COMPLETE", rule,
                    f"Saved: {COMBINED_OUT}", f"Shape: {pairs_df.shape}"):
    print(banner_line)

# Number of feature columns on each side
cols = list(pairs_df.columns)
num_A = len([c for c in cols if c.startswith("A_")])
num_B = len([c for c in cols if c.startswith("B_")])
print(f"A_* columns: {num_A} | B_* columns: {num_B}")

# Label balance
print("\nLabel distribution:")
label_table = pairs_df["label"].value_counts(dropna=False).to_frame("count")
print(label_table)

print("\nSample rows:")
print(pairs_df.sample(min(3, len(pairs_df)), random_state=0))

# ===========================
# TRAIN/VAL/TEST SPLIT
# ===========================
print("\n" + "="*60)
print("CREATING TRAIN/VAL/TEST SPLITS")
print("="*60)

TRAIN_FRACTION = 0.75
VAL_FRACTION   = 0.10
TEST_FRACTION  = 0.15
SEED = 42

# Validate with real exceptions: `assert` statements are stripped under `python -O`.
# All three must be positive: a zero val/test share would divide by zero below.
_split_fractions = (TRAIN_FRACTION, VAL_FRACTION, TEST_FRACTION)
if min(_split_fractions) <= 0:
    raise ValueError(f"All split fractions must be positive, got {_split_fractions}")
if abs(sum(_split_fractions) - 1.0) >= 1e-9:
    raise ValueError(f"Split fractions must sum to 1.0, got {sum(_split_fractions)}")

from sklearn.model_selection import train_test_split

OUT_DIR = os.path.dirname(COMBINED_OUT)
y = pairs_df["label"]

# NOTE(review): pairs are split independently of the rows they were built from.
# A source row's positive (i, i) and its negatives (i, j) can therefore land in
# different splits, which leaks row content across train/val/test. Consider
# splitting by source row (e.g. by idA) or splitting rows before pairing.

# First split: train vs temp (val + test), stratified on label
test_val_size = VAL_FRACTION + TEST_FRACTION
train_df, temp_df = train_test_split(
    pairs_df, test_size=test_val_size, random_state=SEED, stratify=y
)

# Second split: val vs test, with val's share expressed relative to temp
rel_val = VAL_FRACTION / (VAL_FRACTION + TEST_FRACTION)
temp_y = temp_df["label"]
val_df, test_df = train_test_split(
    temp_df, test_size=(1 - rel_val), random_state=SEED, stratify=temp_y
)

# Summary
def summarize(split_name, d):
    """Print row count, per-label counts and per-label fractions (4 decimals) of *d*."""
    label_counts = d["label"].value_counts(dropna=False).sort_index()
    fractions = label_counts.div(len(d)).round(4)
    print(
        f"{split_name:>5} | rows={len(d):,} | "
        f"label counts: {label_counts.to_dict()} | ratio: {fractions.to_dict()}"
    )

for split_label, split_frame in (("train", train_df), (" val ", val_df), ("test ", test_df)):
    summarize(split_label, split_frame)

# Save splits next to the combined pairs file
train_path = os.path.join(OUT_DIR, "train.csv")
val_path   = os.path.join(OUT_DIR, "val.csv")
test_path  = os.path.join(OUT_DIR, "test.csv")

for split_frame, split_path in ((train_df, train_path), (val_df, val_path), (test_df, test_path)):
    split_frame.to_csv(split_path, index=False)

print("\nSaved:")
for split_path in (train_path, val_path, test_path):
    print(f" - {split_path}")
print("\n✓ Ready for model training!")

2025-10-13 23:02:00 | INFO    | Found 25 CSV files
2025-10-13 23:02:01 | INFO    | [1/25] AAPL.csv — capped 987,754 -> 10,000
2025-10-13 23:02:01 | INFO    | [1/25] AAPL.csv — rows=10,000, cols=8
2025-10-13 23:02:01 | INFO    | Converted 1 date columns: ['ts']
2025-10-13 23:02:01 | INFO    | [AAPL.csv] Split: A=4, B=4
2025-10-13 23:02:01 | INFO    | Applied z-score normalization to 8 features
2025-10-13 23:02:04 | INFO    | [2/25] AMZN.csv — capped 824,787 -> 10,000
2025-10-13 23:02:04 | INFO    | [2/25] AMZN.csv — rows=10,000, cols=8
2025-10-13 23:02:04 | INFO    | Converted 1 date columns: ['ts']
2025-10-13 23:02:04 | INFO    | [AMZN.csv] Split: A=4, B=4
2025-10-13 23:02:04 | INFO    | Applied z-score normalization to 8 features
2025-10-13 23:02:07 | INFO    | [3/25] BA.csv — capped 664,841 -> 10,000
2025-10-13 23:02:07 | INFO    | [3/25] BA.csv — rows=10,000, cols=8
2025-10-13 23:02:07 | INFO    | Converted 1 date columns: ['ts']
2025-10-13 23:02:07 | INFO    | [BA.csv] Split: A=4, 


DATA PREPARATION COMPLETE
Saved: ../Datasets/Siamese_Train/pairs_AB.csv
Shape: (2500000, 11)
A_* columns: 4 | B_* columns: 4

Label distribution:
         count
label         
0      2250000
1       250000

Sample rows:
                idA         idB  label  A_volume     A_vwap     A_open  \
103401   AMZN::3401  AMZN::3401      1  0.454418  15.780217  15.756851   
2167674     T::6408     T::7845      0 -0.264422 -32.888015 -32.897835   
202020     BA::2020    BA::2020      1 -0.572847  18.582892  18.696160   

           A_close     B_high      B_low      B_ts  B_transactions  
103401   15.757568  15.786320  15.760567  0.398547       -0.255348  
2167674 -32.868750 -33.202602 -33.171527  2.387220       -0.654134  
202020   18.655676  18.672651  18.685645 -0.927760       -0.649567  

CREATING TRAIN/VAL/TEST SPLITS
train | rows=1,875,000 | label counts: {0: 1687500, 1: 187500} | ratio: {0: 0.9, 1: 0.1}
 val  | rows=250,000 | label counts: {0: 225000, 1: 25000} | ratio: {0: 0.9, 1: 0.1}


## Model

In [5]:
# ============================================
# SIAMESE ENCODER MODEL
# ============================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm import tqdm
import matplotlib.pyplot as plt

# ============================================
# CONFIG
# ============================================
# Split files written by the data-preparation cell
DATA_DIR = "../Datasets/Siamese_Train"
TRAIN_PATH = f"{DATA_DIR}/train.csv"
VAL_PATH = f"{DATA_DIR}/val.csv"
TEST_PATH = f"{DATA_DIR}/test.csv"

# Model architecture
HIDDEN_DIMS = [64, 32]      # hidden-layer widths of each tower
EMBEDDING_DIM = 16          # tower output size
DROPOUT = 0.2
USE_BATCH_NORM = False      # BatchNorm1d after each hidden Linear when True

# Optimisation
BATCH_SIZE = 128
LEARNING_RATE = 0.01
NUM_EPOCHS = 100
PATIENCE = 20               # presumably early-stopping patience in epochs — confirm in train_siamese_model
WEIGHT_DECAY = 0

# Focal loss: FocalLoss weights positives by alpha and negatives by 1 - alpha;
# a larger gamma down-weights well-classified examples more strongly.
USE_FOCAL_LOSS = True
FOCAL_ALPHA = 0.9
FOCAL_GAMMA = 3.0

# Additional positive-class weighting.
# NOTE(review): consumed by the training loop (not shown here); it stacks with FOCAL_ALPHA — confirm intended.
USE_CLASS_WEIGHTS = True
POSITIVE_CLASS_WEIGHT = 15.0

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output artifacts
MODEL_SAVE_PATH = f"{DATA_DIR}/siamese_model.pt"
RESULTS_PATH = f"{DATA_DIR}/training_results.csv"


# ============================================
# DATASET CLASS
# ============================================
class SiamesePairDataset(Dataset):
    """
    Map-style dataset of ``(A features, B features, label)`` rows read from a CSV.

    Feature columns default to every column prefixed ``A_`` / ``B_``. Pass the
    training set's ``a_cols`` / ``b_cols`` so other splits use the same columns
    in the same order.

    Args:
        csv_path: CSV with the feature columns and a ``label`` column.
        a_cols, b_cols: Explicit feature column lists (optional).
        a_fill, b_fill: Optional per-column NaN fill values, one per A / B column.
            Pass the training dataset's ``a_fill`` / ``b_fill`` attributes to
            impute validation/test data with training statistics. When omitted,
            this file's own column means are used (the previous behaviour).

    Attributes:
        X_a, X_b: float32 feature matrices (NaN-free after construction).
        y: float32 labels.
        a_fill, b_fill: float32 fill values that were used for NaNs.
    """
    def __init__(self, csv_path, a_cols=None, b_cols=None, a_fill=None, b_fill=None):
        self.df = pd.read_csv(csv_path)

        # Auto-detect A_* and B_* columns unless the caller pins them
        self.a_cols = [c for c in self.df.columns if c.startswith("A_")] if a_cols is None else a_cols
        self.b_cols = [c for c in self.df.columns if c.startswith("B_")] if b_cols is None else b_cols

        # Extract features and labels
        self.X_a = self.df[self.a_cols].values.astype(np.float32)
        self.X_b = self.df[self.b_cols].values.astype(np.float32)
        self.y = self.df["label"].values.astype(np.float32)

        self._impute_nans(a_fill, b_fill)

        print(f"Loaded {len(self.df)} pairs from {csv_path}")
        print(f"  A features: {len(self.a_cols)}, B features: {len(self.b_cols)}")
        print(f"  Positive pairs: {self.y.sum():.0f} ({100*self.y.mean():.1f}%)")

    @staticmethod
    def _fill_nans(X, fill):
        """Return ``(X with NaNs replaced column-wise, fill values used)``.

        ``fill=None`` means NaN-ignoring column means of *X*. A fill value that
        is itself NaN (an all-NaN column) becomes 0. ``np.nan_to_num`` is then
        applied to the result, which also maps +/-inf to the largest finite
        float32, as before.
        """
        if fill is None:
            import warnings
            with warnings.catch_warnings():
                # all-NaN columns raise "Mean of empty slice"; they are zero-filled below
                warnings.simplefilter("ignore", category=RuntimeWarning)
                fill = np.nanmean(X, axis=0)
        fill = np.asarray(fill, dtype=np.float32).reshape(-1)
        if fill.shape[0] != X.shape[1]:
            raise ValueError(f"Expected {X.shape[1]} fill values, got {fill.shape[0]}")
        fill = np.nan_to_num(fill, nan=0.0)
        filled = np.where(np.isnan(X), fill, X)  # fill broadcasts across rows
        return np.nan_to_num(filled, nan=0.0), fill

    def _impute_nans(self, a_fill=None, b_fill=None):
        """Replace NaNs in ``X_a`` / ``X_b`` and record the fill values used."""
        self.X_a, self.a_fill = self._fill_nans(self.X_a, a_fill)
        self.X_b, self.b_fill = self._fill_nans(self.X_b, b_fill)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return (
            torch.from_numpy(self.X_a[idx]),
            torch.from_numpy(self.X_b[idx]),
            torch.tensor(self.y[idx], dtype=torch.float32)
        )


# ============================================
# MODEL: TWO-TOWER SIAMESE ENCODER
# ============================================
class TowerNetwork(nn.Module):
    """
    One encoder tower: an MLP mapping ``input_dim`` features to ``embedding_dim``.

    Every width in ``hidden_dims`` adds ``Linear -> [BatchNorm1d] -> ReLU -> Dropout``.
    A final ``Linear`` produces the embedding with no activation. All layers
    live in ``self.network`` (an ``nn.Sequential``), which fixes the
    state-dict key layout (``network.<i>.*``).
    """
    def __init__(self, input_dim, hidden_dims, embedding_dim, dropout=0.3, use_batch_norm=True):
        super().__init__()

        def hidden_block(fan_in, fan_out):
            block = [nn.Linear(fan_in, fan_out)]
            if use_batch_norm:
                block.append(nn.BatchNorm1d(fan_out))
            block += [nn.ReLU(), nn.Dropout(dropout)]
            return block

        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend(hidden_block(fan_in, fan_out))
        # Raw embedding output: no activation on the last projection
        modules.append(nn.Linear(widths[-1], embedding_dim))

        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return the embedding of each row of *x*."""
        return self.network(x)


class SiameseNetwork(nn.Module):
    """
    Two-tower pair classifier.

    ``tower_a`` encodes the A_* features and ``tower_b`` the B_* features; the
    towers do not share weights. The head scores
    ``[emb_a, emb_b, |emb_a - emb_b|]`` with
    ``Linear(3 * embedding_dim, 64) -> ReLU -> Dropout -> Linear(64, 1) -> Sigmoid``,
    so ``forward`` returns a match probability in (0, 1) per pair.

    NOTE(review): ``distance_metric`` is stored but not read by ``forward``.
    """
    def __init__(self, input_dim_a, input_dim_b, hidden_dims, embedding_dim,
                 dropout=0.3, use_batch_norm=True, distance_metric="cosine"):
        super().__init__()

        self.distance_metric = distance_metric

        tower_kwargs = dict(hidden_dims=hidden_dims, embedding_dim=embedding_dim,
                            dropout=dropout, use_batch_norm=use_batch_norm)
        self.tower_a = TowerNetwork(input_dim_a, **tower_kwargs)
        self.tower_b = TowerNetwork(input_dim_b, **tower_kwargs)

        head_width = 64
        self.classifier = nn.Sequential(
            nn.Linear(3 * embedding_dim, head_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x_a, x_b):
        """Return ``(match_probability, emb_a, emb_b)``; the probability has shape (batch,)."""
        emb_a, emb_b = self.tower_a(x_a), self.tower_b(x_b)
        pair_features = torch.cat((emb_a, emb_b, (emb_a - emb_b).abs()), dim=1)
        return self.classifier(pair_features).squeeze(-1), emb_a, emb_b

    def get_embeddings(self, x_a, x_b):
        """Return ``(emb_a, emb_b)`` computed without gradient tracking.

        The train/eval mode is not changed here, so dropout stays active if the
        model is in training mode.
        """
        with torch.no_grad():
            return self.tower_a(x_a), self.tower_b(x_b)


# ============================================
# TRAINING UTILITIES
# ============================================
class FocalLoss(nn.Module):
    """
    Batch-mean focal loss on probabilities (not logits).

        FL = alpha_t * (1 - p_t) ** gamma * BCE(p, y)

    ``p_t`` is the predicted probability of the true class. ``alpha_t`` is
    ``alpha`` where ``labels == 1`` and ``1 - alpha`` elsewhere. Predictions are
    clamped to ``[1e-7, 1 - 1e-7]`` before taking logs.

    Args:
        alpha: Weight of the positive class (negatives get ``1 - alpha``).
        gamma: Focusing exponent; larger values down-weight easy examples more.
    """
    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, labels):
        probs = preds.clamp(min=1e-7, max=1 - 1e-7)
        is_pos = labels == 1

        # element-wise binary cross-entropy
        bce = -(labels * probs.log() + (1 - labels) * (1 - probs).log())
        # probability assigned to the true class drives the focusing term
        p_true = torch.where(is_pos, probs, 1 - probs)
        modulator = (1 - p_true).pow(self.gamma)
        class_weight = torch.where(is_pos, self.alpha, 1 - self.alpha)

        return (class_weight * modulator * bce).mean()


def train_epoch(model, loader, optimizer, criterion, device):
    """
    Run one training epoch.

    Per batch: forward pass, ``loss = criterion(preds, labels)``, backward,
    gradient clipping to total norm 1.0, optimizer step.

    Returns:
        (mean per-batch loss, accuracy, F1), with accuracy and F1 at threshold
        0.5. They are computed over every example seen this epoch, using each
        batch's predictions from before that batch's update.

    Raises:
        ValueError: if ``loader`` yields no batches (previously ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    # Whole-batch arrays, concatenated once: far cheaper than extending a list
    # with one numpy scalar per example (~1.9M per epoch here).
    pred_chunks, label_chunks = [], []

    for x_a, x_b, labels in tqdm(loader, desc="Training", leave=False):
        x_a, x_b, labels = x_a.to(device), x_b.to(device), labels.to(device)

        optimizer.zero_grad()
        preds, _, _ = model(x_a, x_b)
        # nn.Module losses and plain callables are invoked identically
        loss = criterion(preds, labels)
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        pred_chunks.append(preds.detach().cpu().numpy().reshape(-1))
        label_chunks.append(labels.detach().cpu().numpy().reshape(-1))

    if num_batches == 0:
        raise ValueError("train_epoch: loader produced no batches")

    avg_loss = total_loss / num_batches
    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    hard_preds = all_preds > 0.5

    accuracy = accuracy_score(all_labels, hard_preds)
    # F1 on the training set, for monitoring
    _, _, train_f1, _ = precision_recall_fscore_support(
        all_labels, hard_preds, average='binary', zero_division=0
    )

    return avg_loss, accuracy, train_f1


def evaluate(model, loader, criterion, device, threshold=0.5):
    """
    Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Returns:
        (mean per-batch loss, accuracy, precision, recall, F1, ROC-AUC). The
        thresholded metrics use ``preds > threshold``. ROC-AUC falls back to
        0.5 when it is undefined (only one class present).

    Raises:
        ValueError: if ``loader`` yields no batches (previously ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    # Whole-batch arrays, concatenated once after the loop
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for x_a, x_b, labels in tqdm(loader, desc="Evaluating", leave=False):
            x_a, x_b, labels = x_a.to(device), x_b.to(device), labels.to(device)

            preds, _, _ = model(x_a, x_b)
            # nn.Module losses and plain callables are invoked identically
            loss = criterion(preds, labels)

            total_loss += loss.item()
            num_batches += 1
            pred_chunks.append(preds.cpu().numpy().reshape(-1))
            label_chunks.append(labels.cpu().numpy().reshape(-1))

    if num_batches == 0:
        raise ValueError("evaluate: loader produced no batches")

    avg_loss = total_loss / num_batches
    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    hard_preds = all_preds > threshold

    accuracy = accuracy_score(all_labels, hard_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, hard_preds, average='binary', zero_division=0
    )

    # roc_auc_score raises ValueError when only one class is present
    try:
        auc = roc_auc_score(all_labels, all_preds)
    except ValueError:
        auc = 0.5

    return avg_loss, accuracy, precision, recall, f1, auc


def find_optimal_threshold(model, loader, device, thresholds=None):
    """
    Find the decision threshold that maximizes F1 on ``loader``.

    A pair counts as positive when ``pred > threshold``, which is the same
    rule ``evaluate`` uses. This makes the returned threshold directly
    usable there.

    Args:
        model: network whose forward returns ``(preds, _, _)``.
        loader: iterable of ``(x_a, x_b, labels)`` tensor batches.
        device: device the model inputs are moved to.
        thresholds: candidate thresholds, tried in order. Defaults to
            ``np.arange(0.1, 0.9, 0.05)``, i.e. 0.10, 0.15, ..., 0.85.

    Returns:
        ``(best_threshold, best_f1)`` as Python floats, so they can be stored
        in checkpoints and printed directly. If no candidate reaches F1 > 0,
        the result is ``(0.5, 0.0)``. On ties, the first candidate wins.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if thresholds is None:
        thresholds = np.arange(0.1, 0.9, 0.05)

    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for x_a, x_b, labels in tqdm(loader, desc="Finding threshold", leave=False):
            # Labels are only consumed on the CPU (sklearn), so they are not
            # shipped to `device` and back.
            preds, _, _ = model(x_a.to(device), x_b.to(device))
            pred_chunks.append(preds.cpu().numpy().reshape(-1))
            label_chunks.append(labels.cpu().numpy().reshape(-1))

    if not pred_chunks:
        raise ValueError("find_optimal_threshold() received a loader that yielded no batches")

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    best_f1 = 0.0
    best_threshold = 0.5

    for threshold in thresholds:
        _, _, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds > threshold, average='binary', zero_division=0
        )
        if f1 > best_f1:  # strict '>' keeps the earliest candidate on ties
            best_f1 = f1
            best_threshold = threshold

    return float(best_threshold), float(best_f1)


# ============================================
# MAIN TRAINING LOOP
# ============================================
def _build_criterion(pos_weight):
    """Create the training loss selected by ``USE_FOCAL_LOSS`` and log the choice.

    Args:
        pos_weight: negatives-to-positives ratio. The weighted-BCE branch
            multiplies the positive-class term by it.

    Returns:
        A callable ``criterion(predictions, labels) -> scalar tensor``.
    """
    if USE_FOCAL_LOSS:
        # Focal Loss - better for extreme imbalance
        criterion = FocalLoss(alpha=FOCAL_ALPHA, gamma=FOCAL_GAMMA)
        print(f"   Focal Loss - alpha={FOCAL_ALPHA}, gamma={FOCAL_GAMMA}")
        print("   This focuses learning on hard-to-classify examples")
        return criterion

    def weighted_bce_loss(predictions, labels):
        # Takes log(p) and log(1 - p), so predictions must lie in [0, 1];
        # the 1e-7 keeps both logs finite at the endpoints.
        loss = -(pos_weight * labels * torch.log(predictions + 1e-7) +
                 (1 - labels) * torch.log(1 - predictions + 1e-7))
        return loss.mean()

    print(f"   Weighted BCE - positive class weighted {pos_weight:.2f}x")
    return weighted_bce_loss


def _print_test_metrics(title, metrics):
    """Print a framed summary of an ``evaluate`` tuple ``(loss, acc, prec, rec, f1, auc)``."""
    loss, acc, prec, rec, f1, auc = metrics
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)
    print(f"  Test Loss: {loss:.4f}")
    print(f"  Test Accuracy: {acc:.4f}")
    print(f"  Test Precision: {prec:.4f}")
    print(f"  Test Recall: {rec:.4f}")
    print(f"  Test F1: {f1:.4f}")
    print(f"  Test AUC: {auc:.4f}")


def train_siamese_model():
    """Train the Siamese matcher, keep the best-val-F1 checkpoint, and report test metrics.

    Configuration comes from module-level constants (TRAIN_PATH, VAL_PATH,
    TEST_PATH, BATCH_SIZE, NUM_EPOCHS, PATIENCE, LEARNING_RATE, ...).

    Side effects:
        Writes the best checkpoint to MODEL_SAVE_PATH. It holds the model and
        optimizer state, validation metrics, and the feature column lists;
        ``optimal_threshold`` is added at the end. Writes the per-epoch
        history to RESULTS_PATH.

    Returns:
        ``(model, history_df)``: the model with the best-validation-F1 weights
        loaded, and a DataFrame with one row per completed epoch.

    Raises:
        ValueError: if the training set contains no positive pairs.
    """
    print(f"Using device: {DEVICE}")

    # Load datasets. Val/test reuse the training column lists so every split
    # feeds the towers the same features in the same order.
    train_dataset = SiamesePairDataset(TRAIN_PATH)
    val_dataset = SiamesePairDataset(VAL_PATH,
                                     a_cols=train_dataset.a_cols,
                                     b_cols=train_dataset.b_cols)
    test_dataset = SiamesePairDataset(TEST_PATH,
                                      a_cols=train_dataset.a_cols,
                                      b_cols=train_dataset.b_cols)

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Initialize model
    input_dim_a = len(train_dataset.a_cols)
    input_dim_b = len(train_dataset.b_cols)

    model = SiameseNetwork(
        input_dim_a=input_dim_a,
        input_dim_b=input_dim_b,
        hidden_dims=HIDDEN_DIMS,
        embedding_dim=EMBEDDING_DIM,
        dropout=DROPOUT,
        use_batch_norm=USE_BATCH_NORM
    ).to(DEVICE)

    print("\nModel architecture:")
    print(f"  Tower A input: {input_dim_a} features")
    print(f"  Tower B input: {input_dim_b} features")
    print(f"  Hidden layers: {HIDDEN_DIMS}")
    print(f"  Embedding dim: {EMBEDDING_DIM}")
    print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Loss: class ratio drives the BCE weighting and the log messages.
    n_pos = float(train_dataset.y.sum())
    if n_pos <= 0:
        raise ValueError("Training set contains no positive pairs; cannot weight the loss or learn matches.")
    pos_weight = (len(train_dataset) - n_pos) / n_pos
    print(f"\n⚠️  Class imbalance ratio: {pos_weight:.2f}:1 (negatives:positives)")
    print(f"   Strategy: {'Focal Loss' if USE_FOCAL_LOSS else 'Weighted BCE Loss'}")
    criterion = _build_criterion(pos_weight)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Halve the LR when val F1 plateaus (mode='max': higher F1 is better).
    # `verbose` is deprecated on PyTorch LR schedulers, so the loop below logs
    # LR reductions itself.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=5, factor=0.5
    )

    # Early stopping on val F1. Starting at -inf guarantees epoch 1 writes a
    # checkpoint. Without it, the reload below could fail, or pick up a stale
    # file from an earlier run, whenever F1 never rises above 0.
    best_val_f1 = float('-inf')
    patience_counter = 0
    history = []

    print(f"\nStarting training for {NUM_EPOCHS} epochs...")
    print(f"Batch size: {BATCH_SIZE}, LR: {LEARNING_RATE}")
    print(f"Monitoring F1 score for early stopping (patience={PATIENCE})")
    print(f"Target: F1 > 0.60 for good performance with {pos_weight:.0f}:1 imbalance\n")

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

        # Train
        train_loss, train_acc, train_f1 = train_epoch(model, train_loader, optimizer, criterion, DEVICE)

        # Validate
        val_loss, val_acc, val_prec, val_rec, val_f1, val_auc = evaluate(model, val_loader, criterion, DEVICE)

        # Learning rate scheduling based on F1 (better for imbalanced data)
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_f1)
        current_lr = optimizer.param_groups[0]['lr']

        # Log results
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print(f"  Val Precision: {val_prec:.4f} | Val Recall: {val_rec:.4f}")
        print(f"  Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}")
        print(f"  LR: {current_lr:.6f}")
        if current_lr < previous_lr:
            print(f"  LR reduced: {previous_lr:.6f} -> {current_lr:.6f}")

        # Collapse checks, after a few warm-up epochs:
        # - F1 ~ 0 means (almost) only negatives are predicted.
        # - AUC ~ 0.5 (or NaN) means the scores carry no ranking signal, e.g.
        #   a constant output that flips the whole val set between
        #   all-positive and all-negative.
        if epoch > 3:
            if val_f1 < 0.01:
                print("  ⚠️  WARNING: F1 is near zero - model may be predicting all negatives!")
                print("     Consider: Lower LR, increase focal gamma, or check data")
            elif np.isnan(val_auc) or abs(val_auc - 0.5) < 1e-3:
                print("  ⚠️  WARNING: Val AUC is ~0.5 - predictions carry no ranking signal "
                      "(model output may have collapsed to a constant)!")
                print("     Consider: Lower LR, check feature scaling, or inspect the output layer")

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'train_f1': train_f1,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'val_precision': val_prec,
            'val_recall': val_rec,
            'val_f1': val_f1,
            'val_auc': val_auc,
            'lr': current_lr
        })

        # Save based on F1 score (better for imbalanced data)
        if val_f1 > best_val_f1:
            best_val_f1 = float(val_f1)
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain Python floats/lists rather than numpy scalars, so the
                # checkpoint loads under torch.load(weights_only=True), the
                # default since PyTorch 2.6.
                # NOTE(review): assumes the column entries are plain str.
                'val_f1': float(val_f1),
                'val_loss': float(val_loss),
                'a_cols': list(train_dataset.a_cols),
                'b_cols': list(train_dataset.b_cols)
            }, MODEL_SAVE_PATH)
            print(f"  → Model saved (best val F1: {best_val_f1:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

    # Save training history
    history_df = pd.DataFrame(history)
    history_df.to_csv(RESULTS_PATH, index=False)
    print(f"\nTraining history saved to {RESULTS_PATH}")

    # Load best model and evaluate on test set. map_location keeps this
    # working when the checkpoint was written on a different device.
    print("\nLoading best model for test evaluation...")
    checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Find optimal threshold on validation set
    print("\nFinding optimal decision threshold on validation set...")
    optimal_threshold, val_f1_at_threshold = find_optimal_threshold(model, val_loader, DEVICE)
    optimal_threshold = float(optimal_threshold)
    print(f"Optimal threshold: {optimal_threshold:.3f} (F1: {val_f1_at_threshold:.4f})")
    print("Default 0.5 threshold may not be optimal for imbalanced data!")

    # Evaluate with both thresholds
    _print_test_metrics(
        "FINAL TEST RESULTS (threshold=0.5):",
        evaluate(model, test_loader, criterion, DEVICE, threshold=0.5),
    )
    _print_test_metrics(
        f"FINAL TEST RESULTS (threshold={optimal_threshold:.3f}):",
        evaluate(model, test_loader, criterion, DEVICE, threshold=optimal_threshold),
    )
    print("=" * 60)

    # Save optimal threshold
    checkpoint['optimal_threshold'] = optimal_threshold
    torch.save(checkpoint, MODEL_SAVE_PATH)
    print(f"\n✓ Optimal threshold ({optimal_threshold:.3f}) saved to model checkpoint")

    return model, history_df


# ============================================
# INFERENCE UTILITIES
# ============================================
def predict_pairs(model, x_a, x_b, device=DEVICE, batch_size=None):
    """
    Predict match probability for new pairs.

    Args:
        model: trained SiameseNetwork; its forward returns ``(preds, _, _)``.
        x_a: A features (n_samples, n_features_a). Accepts an ndarray, a
            DataFrame, or anything else ``np.ascontiguousarray`` accepts.
        x_b: B features (n_samples, n_features_b); same accepted types.
        device: torch device to run inference on.
        batch_size: optional number of rows per forward pass, to bound device
            memory on large inputs. ``None`` (the default) runs a single pass,
            as before. Batching slices x_a and x_b by the same row ranges, so
            it requires equal row counts.

    Returns:
        predictions: numpy array of model outputs (match probabilities). When
        batching, the per-batch outputs are concatenated along axis 0.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if batching is
            requested and x_a / x_b have different row counts.
    """
    # Unlike `.astype`, ascontiguousarray accepts DataFrames and lists. It also
    # yields a C-contiguous buffer; torch.from_numpy rejects negative strides.
    arr_a = np.ascontiguousarray(x_a, dtype=np.float32)
    arr_b = np.ascontiguousarray(x_b, dtype=np.float32)

    if batch_size is None:
        row_spans = [slice(None)]
    else:
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        if arr_a.shape[0] != arr_b.shape[0]:
            raise ValueError(
                f"x_a and x_b must have the same number of rows to batch "
                f"({arr_a.shape[0]} != {arr_b.shape[0]})"
            )
        n_rows = arr_a.shape[0]
        row_spans = [slice(start, start + batch_size) for start in range(0, n_rows, batch_size)]
        if not row_spans:  # zero rows: still run one (empty) pass, like the unbatched path
            row_spans = [slice(None)]

    model.eval()
    outputs = []
    with torch.no_grad():
        for span in row_spans:
            x_a_tensor = torch.from_numpy(arr_a[span]).to(device)
            x_b_tensor = torch.from_numpy(arr_b[span]).to(device)

            preds, _, _ = model(x_a_tensor, x_b_tensor)
            outputs.append(preds.cpu().numpy())

    return outputs[0] if len(outputs) == 1 else np.concatenate(outputs, axis=0)


def get_embeddings(model, x_a, x_b, device=DEVICE):
    """
    Get embeddings for rows from both towers.

    Args:
        model: trained SiameseNetwork exposing ``get_embeddings(a, b)``, which
            returns a pair of tensors ``(emb_a, emb_b)``.
        x_a: A-side features. Accepts an ndarray, a DataFrame, or anything
            else ``np.ascontiguousarray`` accepts.
        x_b: B-side features; same accepted types.
        device: torch device to run inference on.

    Returns:
        emb_a, emb_b: numpy arrays of embeddings
    """
    # Unlike `.astype`, ascontiguousarray accepts DataFrames and lists. It also
    # yields a C-contiguous buffer; torch.from_numpy rejects negative strides.
    arr_a = np.ascontiguousarray(x_a, dtype=np.float32)
    arr_b = np.ascontiguousarray(x_b, dtype=np.float32)

    model.eval()
    with torch.no_grad():
        x_a_tensor = torch.from_numpy(arr_a).to(device)
        x_b_tensor = torch.from_numpy(arr_b).to(device)

        emb_a, emb_b = model.get_embeddings(x_a_tensor, x_b_tensor)
        return emb_a.cpu().numpy(), emb_b.cpu().numpy()


# ============================================
# VISUALIZATION
# ============================================
def plot_training_history(history_df):
    """Plot loss, accuracy, F1 and AUC curves from the per-epoch history.

    Args:
        history_df: DataFrame with one row per epoch, as returned by
            ``train_siamese_model``. It needs an ``epoch`` column plus
            ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc``,
            ``train_f1``, ``val_f1`` and ``val_auc``. Any series whose column
            is missing (e.g. ``train_f1`` in an older results CSV) is skipped.

    Side effects:
        Saves the figure to ``DATA_DIR/training_curves.png`` and shows it. An
        empty history only prints a message.
    """
    if history_df is None or len(history_df) == 0:
        print("No training history to plot.")
        return

    # One entry per subplot: (row, col, y-label, title, [(column, legend label, plot kwargs), ...])
    panels = [
        (0, 0, 'Loss', 'Training & Validation Loss', [
            ('train_loss', 'Train Loss', {'marker': 'o'}),
            ('val_loss', 'Val Loss', {'marker': 's'}),
        ]),
        (0, 1, 'Accuracy', 'Training & Validation Accuracy', [
            ('train_acc', 'Train Acc', {'marker': 'o'}),
            ('val_acc', 'Val Acc', {'marker': 's'}),
        ]),
        (1, 0, 'F1 Score', 'Training & Validation F1 Score', [
            ('train_f1', 'Train F1', {'marker': 'o'}),
            ('val_f1', 'Val F1', {'marker': 's', 'color': 'green'}),
        ]),
        (1, 1, 'AUC', 'Validation AUC', [
            ('val_auc', 'Val AUC', {'marker': 's', 'color': 'purple'}),
        ]),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for row, col, ylabel, title, series in panels:
        ax = axes[row, col]
        for column, label, style in series:
            if column not in history_df.columns:
                continue
            ax.plot(history_df['epoch'], history_df[column], label=label, **style)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    out_path = os.path.join(DATA_DIR, "training_curves.png")
    plt.savefig(out_path, dpi=150)
    print(f"Training curves saved to {out_path}")
    plt.show()


# ============================================
# RUN TRAINING
# ============================================
if __name__ == "__main__":
    # Script / notebook-cell entry point: train (which also writes the
    # checkpoint and history CSV), then plot the learning curves.
    # `model` and `history` are bound at module level on purpose, so later
    # notebook cells can keep using them.
    model, history = train_siamese_model()
    plot_training_history(history)
    
    print("\n✓ Training complete!")
    print(f"✓ Model saved to: {MODEL_SAVE_PATH}")
    print(f"✓ Results saved to: {RESULTS_PATH}")

Using device: cuda
Loaded 1875000 pairs from ../Datasets/Siamese_Train/train.csv
  A features: 4, B features: 4
  Positive pairs: 187500 (10.0%)
Loaded 250000 pairs from ../Datasets/Siamese_Train/val.csv
  A features: 4, B features: 4
  Positive pairs: 25000 (10.0%)




Loaded 375000 pairs from ../Datasets/Siamese_Train/test.csv
  A features: 4, B features: 4
  Positive pairs: 37500 (10.0%)

Model architecture:
  Tower A input: 4 features
  Tower B input: 4 features
  Hidden layers: [64, 32]
  Embedding dim: 16
  Total parameters: 9,057

⚠️  Class imbalance ratio: 9.00:1 (negatives:positives)
   Strategy: Focal Loss
   Focal Loss - alpha=0.9, gamma=3.0
   This focuses learning on hard-to-classify examples

Starting training for 100 epochs...
Batch size: 128, LR: 0.01
Monitoring F1 score for early stopping (patience=20)
Target: F1 > 0.60 for good performance with 9:1 imbalance


Epoch 1/100


                                                                 

  Train Loss: 0.0156 | Train Acc: 0.4959 | Train F1: 0.1669
  Val Loss: 0.0156 | Val Acc: 0.1000
  Val Precision: 0.1000 | Val Recall: 1.0000
  Val F1: 0.1818 | Val AUC: 0.5000
  LR: 0.010000
  → Model saved (best val F1: 0.1818)

Epoch 2/100


                                                                 

  Train Loss: 0.0156 | Train Acc: 0.4996 | Train F1: 0.1667
  Val Loss: 0.0156 | Val Acc: 0.9000
  Val Precision: 0.0000 | Val Recall: 0.0000
  Val F1: 0.0000 | Val AUC: 0.5000
  LR: 0.010000

Epoch 3/100


                                                                 

  Train Loss: 0.0156 | Train Acc: 0.4949 | Train F1: 0.1666
  Val Loss: 0.0156 | Val Acc: 0.1000
  Val Precision: 0.1000 | Val Recall: 1.0000
  Val F1: 0.1818 | Val AUC: 0.5000
  LR: 0.010000

Epoch 4/100


                                                                 

  Train Loss: 0.0156 | Train Acc: 0.4941 | Train F1: 0.1674
  Val Loss: 0.0156 | Val Acc: 0.1000
  Val Precision: 0.1000 | Val Recall: 1.0000
  Val F1: 0.1818 | Val AUC: 0.5000
  LR: 0.010000

Epoch 5/100


                                                                 

  Train Loss: 0.0156 | Train Acc: 0.4975 | Train F1: 0.1664
  Val Loss: 0.0156 | Val Acc: 0.9000
  Val Precision: 0.0000 | Val Recall: 0.0000
  Val F1: 0.0000 | Val AUC: 0.5000
  LR: 0.010000
     Consider: Lower LR, increase focal gamma, or check data

Epoch 6/100


                                                               

KeyboardInterrupt: 