# Model 4: Regular features + all embedding families (each PCA-compressed) – Linear classifier

This notebook trains a linear PyTorch classifier on regular features plus **all available embedding families**, each
compressed separately with IncrementalPCA. The result is a "full information, simple model" baseline: it sees every
feature family, but the classifier stays linear, in line with the constraint of using simpler models as feature complexity grows.
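
A minimal sketch of the per-family compression idea on random data, assuming sklearn's `IncrementalPCA`. Each family gets its own fixed component budget instead of letting the highest-variance family dominate a single joint PCA. The family names mirror the notebook's prefixes, but the dimensions and component counts here are illustrative only.

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA

# Hypothetical toy data: regular features plus two embedding families of different widths.
rng = np.random.default_rng(0)
n_rows = 1_000
X_reg = rng.normal(size=(n_rows, 5))
families = {
    "sent_transformer_": rng.normal(size=(n_rows, 384)),
    "scibert_": rng.normal(size=(n_rows, 768)),
}
components = {"sent_transformer_": 64, "scibert_": 32}

compressed = []
fitted = {}
for prefix, X_fam in families.items():
    # One PCA per family; batch_size must be >= n_components.
    ipca = IncrementalPCA(n_components=components[prefix], batch_size=500)
    ipca.fit(X_fam)
    compressed.append(ipca.transform(X_fam))
    fitted[prefix] = ipca  # kept so val/test rows can be transformed the same way

# Regular features first, then each family's PCA scores, in a fixed order.
X_combined = np.concatenate([X_reg] + compressed, axis=1)
print(X_combined.shape)  # (1000, 101)
```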

**Features:**
- ✅ Uses **all embedding families** (sent_transformer, scibert, specter, specter2, ner)
- ✅ **Per-family PCA** to preserve information while reducing dimensionality
- ✅ 5-fold Cross-Validation
- ✅ Hyperparameter Tuning
- ✅ Threshold Fine-tuning
- ✅ Model Weight Saving
- ✅ Submission.csv Generation
- ✅ OOM-conscious: aggressive garbage collection and GPU cache clearing
- ✅ SMOTETomek for class imbalance


# 📑 Model 4 - Code Navigation Index

## Quick Navigation
- **[Setup](#1-setup)** - Imports, paths, device configuration, robustness utilities
- **[Dataset & Utilities](#2-dataset--utilities)** - Load splits, separate regular and embedding features, per-family PCA, DataLoaders
- **[SMOTETomek](#3-class-imbalance-handling-smotetomek)** - Class imbalance resampling
- **[Feature Scaling](#4-feature-scaling)** - StandardScaler normalization
- **[Model Definition](#5-model-definition)** - Single-layer linear classifier
- **[Training Loop](#6-training-loop)** - Per-epoch threshold tuning; saves the best model weights
- **[Cross-Validation](#7-5-fold-cross-validation)** - 5-fold CV over a small hyperparameter grid
- **[Final Model Training](#8-final-model-training)** - Retrain with the best hyperparameters and pick the decision threshold
- **[Submission](#9-generate-submission)** - Generate test predictions

## Model Type: Linear Classifier (regular + all embeddings, PCA)

## Key Features
✅ GPU-friendly with CPU fallback  
✅ Aggressive garbage collection  
✅ OOM mitigation via mini-batch training and a `chunked_operation()` helper  
✅ Graceful shutdown on SIGINT/SIGTERM (signal handlers, emergency checkpoint)  
✅ Polars-only (no pandas)  
✅ GPU-friendly PCA (IncrementalTorchPCA option)  
✅ SMOTETomek for class imbalance  
✅ Feature scaling (StandardScaler)  
✅ Hyperparameter tuning (manual grid of 3 settings scored by 5-fold CV)  
✅ Fine-grained threshold optimization (120-threshold grid during training)  
✅ Model weights saved  
✅ Chunked/batched data processing  

## Memory Management
- `cleanup_memory()`: Aggressive GC + GPU cache clearing
- `check_memory_safe()`: Pre-operation memory checks
- `chunked_operation()`: Process large data in chunks
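- `safe_operation()`: Retry decorator with OOM handling (not defined in this notebook)
- Signal handlers: SIGINT/SIGTERM for graceful shutdown
- Checkpoints: `save_checkpoint()` / `load_checkpoint()` helpers; an emergency checkpoint is written when a signal is received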
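
A retry-on-OOM decorator of the kind the list describes can be sketched in plain Python. This is an assumption-based illustration: the name `retry_on_oom` and its parameters are hypothetical, and a GPU version would typically also call `torch.cuda.empty_cache()` before retrying.

```python
import functools
import gc
import time


def retry_on_oom(max_retries=2, wait_seconds=0.0):
    """Hypothetical decorator: retry a call after an out-of-memory error.

    MemoryError, or a RuntimeError whose message mentions 'out of memory',
    triggers a garbage collection and another attempt; any other error
    propagates immediately, as does the last failed attempt.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except (MemoryError, RuntimeError) as exc:
                    is_oom = isinstance(exc, MemoryError) or "out of memory" in str(exc).lower()
                    if not is_oom or attempt == max_retries:
                        raise
                    gc.collect()  # a GPU version would also clear the CUDA cache here
                    time.sleep(wait_seconds)
        return wrapper
    return decorator


# Usage: a step that fails twice with a simulated OOM, then succeeds.
calls = {"n": 0}


@retry_on_oom(max_retries=2)
def flaky_step():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("CUDA out of memory (simulated)")
    return "ok"


print(flaky_step(), calls["n"])  # ok 3
```

Only OOM-style failures are retried, so genuine bugs (shape mismatches, bad inputs) still surface on the first attempt.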

## Device Handling
- Automatic GPU detection with CPU fallback
- `device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')`
- All tensors moved to device explicitly
- GPU cache cleared aggressively after operations


## 1. Setup
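
The setup cell locates `PROJECT_ROOT` by walking up at most five directories until one contains `data/`, falling back to two levels above the working directory. The same logic as a small, testable helper; the function name `find_project_root` is hypothetical.

```python
from pathlib import Path
import tempfile


def find_project_root(start: Path, marker: str = "data", max_levels: int = 5) -> Path:
    """Walk up from `start` until a directory containing `marker` is found.

    Mirrors the setup cell: check at most `max_levels` directories, then fall
    back to two levels above `start`.
    """
    candidate = start
    for _ in range(max_levels):
        if (candidate / marker).exists():
            return candidate
        candidate = candidate.parent
    return start.parent.parent


# Usage on a throwaway tree: <tmp>/data and <tmp>/notebooks/model4
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "data").mkdir()
    nested = root / "notebooks" / "model4"
    nested.mkdir(parents=True)
    found = find_project_root(nested)
    print(found == root)  # True
```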

In [None]:
import os
from pathlib import Path
import random
import gc
import numpy as np
import polars as pl
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    f1_score,
    classification_report,
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_recall_curve,
)
import matplotlib.pyplot as plt
import sys
import time
import json
import pickle
import signal
import atexit
from functools import wraps

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Device (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Paths
current = Path(os.getcwd())
PROJECT_ROOT = current
for _ in range(5):
    if (PROJECT_ROOT / 'data').exists():
        break
    PROJECT_ROOT = PROJECT_ROOT.parent
else:
    PROJECT_ROOT = current.parent.parent

MODEL_READY_DIR = PROJECT_ROOT / 'data' / 'model_ready'
utils_path = PROJECT_ROOT / 'src' / 'utils'
print('PROJECT_ROOT:', PROJECT_ROOT)
print('MODEL_READY_DIR:', MODEL_READY_DIR)

# Import PCA utilities
USE_TORCH_PCA = False  # Set to True to use PyTorch PCA (requires more memory)
if utils_path.exists():
    sys.path.insert(0, str(utils_path))

if USE_TORCH_PCA:
    try:
        from pca_utils import IncrementalTorchPCA
        IncrementalPCA = IncrementalTorchPCA  # Alias for compatibility
        IS_TORCH_PCA = True
        print("✅ Using PyTorch PCA (GPU-friendly)")
    except ImportError:
        from sklearn.decomposition import IncrementalPCA
        IS_TORCH_PCA = False
        print("⚠️ Using sklearn IncrementalPCA (CPU only)")
else:
    from sklearn.decomposition import IncrementalPCA
    IS_TORCH_PCA = False
    print("✅ Using sklearn IncrementalPCA (memory-efficient)")

# Import memory utilities from shared module (utils_path was added to sys.path above)

try:
    from model_training_utils import cleanup_memory, memory_usage
    print("✅ Memory utilities imported from shared module")
except ImportError:
    def cleanup_memory():
        """Aggressive memory cleanup for both CPU and GPU."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            torch.cuda.ipc_collect()
        gc.collect()
    
    def memory_usage():
        """Display current memory usage statistics."""
        try:
            import psutil
            process = psutil.Process(os.getpid())
            mem_info = process.memory_info()
            print(f"💾 Memory: {mem_info.rss / 1024**3:.2f} GB (RAM)", end="")
            if torch.cuda.is_available():
                gpu_mem = torch.cuda.memory_allocated() / 1024**3
                gpu_reserved = torch.cuda.memory_reserved() / 1024**3
                print(f" | {gpu_mem:.2f}/{gpu_reserved:.2f} GB (GPU used/reserved)")
            else:
                print()
        except ImportError:
            print("💾 Memory tracking requires psutil: pip install psutil")
    
    print("⚠️ Using fallback memory utilities")


In [None]:
# ============================================================================
# ENHANCED ROBUSTNESS UTILITIES
# ============================================================================

# Global checkpoint state
_checkpoint_state = {
    'pca_complete': False,
    'scaling_complete': False,
    'cv_complete': False,
    'final_model_trained': False,
    'last_saved_checkpoint': None
}

def save_checkpoint(state_name: str, data: dict, checkpoint_dir: Path = None):
    """Save checkpoint to resume from failures."""
    if checkpoint_dir is None:
        checkpoint_dir = PROJECT_ROOT / 'data' / 'checkpoints'
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = checkpoint_dir / f'model4_checkpoint_{state_name}.pkl'
    try:
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(data, f)
        _checkpoint_state['last_saved_checkpoint'] = checkpoint_path
        print(f"✅ Checkpoint saved: {checkpoint_path}")
    except Exception as e:
        print(f"⚠️ Failed to save checkpoint: {e}")

def load_checkpoint(state_name: str, checkpoint_dir: Path = None):
    """Load checkpoint to resume from failures."""
    if checkpoint_dir is None:
        checkpoint_dir = PROJECT_ROOT / 'data' / 'checkpoints'
    checkpoint_path = checkpoint_dir / f'model4_checkpoint_{state_name}.pkl'
    if checkpoint_path.exists():
        try:
            with open(checkpoint_path, 'rb') as f:
                data = pickle.load(f)
            print(f"✅ Checkpoint loaded: {checkpoint_path}")
            return data
        except Exception as e:
            print(f"⚠️ Failed to load checkpoint: {e}")
    return None

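def check_memory_safe(ram_threshold_gb=0.85, gpu_threshold=0.80):
    """Check whether memory usage is below safe limits before an operation.

    Despite its name, `ram_threshold_gb` is a fraction of total system RAM
    (0.85 = 85%), and `gpu_threshold` is a fraction of total memory on GPU 0.
    Returns `(is_safe, info_dict)`; if psutil is unavailable, reports safe.
    """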
    try:
        import psutil
        process = psutil.Process(os.getpid())
        ram_gb = process.memory_info().rss / 1024**3
        total_ram = psutil.virtual_memory().total / 1024**3
        ram_ratio = ram_gb / total_ram if total_ram > 0 else 0
        gpu_ratio = 0
        if torch.cuda.is_available():
            gpu_used = torch.cuda.memory_allocated() / 1024**3
            gpu_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            gpu_ratio = gpu_used / gpu_total if gpu_total > 0 else 0
        is_safe = ram_ratio < ram_threshold_gb and gpu_ratio < gpu_threshold
        return is_safe, {'ram_gb': ram_gb, 'ram_ratio': ram_ratio, 'gpu_ratio': gpu_ratio}
    except Exception:
        # psutil missing or query failed: assume it is safe to proceed
        return True, {}

def chunked_operation(data, operation_func, chunk_size: int = 10000, progress_every: int = 10, operation_name: str = "operation"):
    """Execute operation on data in chunks with progress tracking."""
    total_chunks = (len(data) + chunk_size - 1) // chunk_size
    results = []
    for i in range(0, len(data), chunk_size):
        chunk_num = i // chunk_size + 1
        chunk = data[i:i+chunk_size]
        try:
            is_safe, mem_info = check_memory_safe(ram_threshold_gb=0.85, gpu_threshold=0.80)
            if not is_safe:
                cleanup_memory()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                time.sleep(0.5)
            chunk_result = operation_func(chunk)
            results.append(chunk_result)
            if chunk_num % progress_every == 0 or chunk_num == total_chunks:
                print(f"  Progress: {chunk_num}/{total_chunks} chunks ({chunk_num*100//total_chunks}%)")
            del chunk
            if chunk_num % 5 == 0:
                cleanup_memory()
        except (MemoryError, RuntimeError) as e:
            error_msg = str(e).lower()
            if 'out of memory' in error_msg or 'oom' in error_msg:
                cleanup_memory()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                smaller_chunk_size = max(1000, chunk_size // 2)
                if smaller_chunk_size < chunk_size:
                    # Keep the results already computed and retry the remainder with smaller chunks
                    return results + chunked_operation(data[i:], operation_func, chunk_size=smaller_chunk_size, progress_every=progress_every, operation_name=operation_name)
                else:
                    raise
            else:
                raise
    return results

def emergency_cleanup():
    """Emergency cleanup on exit."""
    try:
        cleanup_memory()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("✅ Emergency cleanup completed")
    except Exception:
        pass

atexit.register(emergency_cleanup)

# Signal handler for graceful shutdown
def signal_handler(signum, frame):
    """Handle signals for graceful shutdown."""
    print(f"⚠️ Received signal {signum}, saving checkpoint...")
    save_checkpoint('emergency', {'status': 'signal_received', 'signal': signum})
    emergency_cleanup()
    raise KeyboardInterrupt

try:
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
except Exception:  # e.g. ValueError when not running in the main thread
    pass

print("✅ Enhanced robustness utilities loaded")


## 2. Dataset & utilities
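
The column split below routes each column by name prefix: embedding columns go to their family, remaining numeric columns become regular features, and `id`/`label` are skipped. A pure-Python sketch of that routing (no Polars; the function name `route_columns` and the sample column names are hypothetical). Note that `specter2_0` does not match `specter_` because of the trailing underscore, so the two SPECTER families stay separate.

```python
EMBEDDING_FAMILY_PREFIXES = ["sent_transformer_", "scibert_", "specter_", "specter2_", "ner_"]


def route_columns(columns, numeric_columns):
    """Assign columns to embedding families by prefix; other numeric columns are regular features."""
    reg_cols = []
    families = {p: [] for p in EMBEDDING_FAMILY_PREFIXES}
    for col in columns:
        if col in ("id", "label"):
            continue  # identifier and target are never features
        prefix = next((p for p in EMBEDDING_FAMILY_PREFIXES if col.startswith(p)), None)
        if prefix is not None:
            families[prefix].append(col)
        elif col in numeric_columns:
            reg_cols.append(col)
        # non-numeric, non-embedding columns (e.g. raw text) are dropped
    # Only non-empty families are returned, mirroring X_emb_families
    return reg_cols, {p: cols for p, cols in families.items() if cols}


columns = ["id", "n_authors", "venue", "specter_0", "specter2_0", "specter2_1", "ner_0", "label"]
numeric = {"n_authors", "specter_0", "specter2_0", "specter2_1", "ner_0"}
reg, fams = route_columns(columns, numeric)
print(reg)   # ['n_authors']
print(fams)  # {'specter_': ['specter_0'], 'specter2_': ['specter2_0', 'specter2_1'], 'ner_': ['ner_0']}
```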

In [None]:
def load_parquet_split(split: str) -> pl.DataFrame:
    """Load a model_ready parquet split with error handling."""
    try:
        path = MODEL_READY_DIR / f'{split}_model_ready.parquet'
        if not path.exists():
            alt = MODEL_READY_DIR / f'{split}_model_ready_reduced.parquet'
            if alt.exists():
                path = alt
            else:
                raise FileNotFoundError(f'Could not find {split} data')
        print(f'Loading {split} from {path}')
        return pl.read_parquet(path)
    except Exception as e:
        print(f"❌ Error loading {split}: {e}")
        raise

def split_features_reg_and_all_emb(df: pl.DataFrame):
    """Split features into regular and all embedding families."""
    cols = df.columns
    dtypes = df.dtypes
    label = df['label'].to_numpy() if 'label' in cols else None
    
    reg_cols = []
    EMBEDDING_FAMILY_PREFIXES = ['sent_transformer_', 'scibert_', 'specter_', 'specter2_', 'ner_']
    emb_family_to_cols = {p: [] for p in EMBEDDING_FAMILY_PREFIXES}
    
    NUMERIC_DTYPES = {
        pl.Int8, pl.Int16, pl.Int32, pl.Int64,
        pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
        pl.Float32, pl.Float64
    }
    
    for c, dt in zip(cols, dtypes):
        if c in ('id', 'label'):
            continue
        matched = False
        for p in EMBEDDING_FAMILY_PREFIXES:
            if c.startswith(p):
                emb_family_to_cols[p].append(c)
                matched = True
                break
        if not matched and dt in NUMERIC_DTYPES:
            reg_cols.append(c)
    
    X_reg = df.select(reg_cols).to_numpy() if reg_cols else None
    X_emb_families = {}
    for p, clist in emb_family_to_cols.items():
        if clist:
            X_emb_families[p] = df.select(clist).to_numpy()
    
    return X_reg, X_emb_families, label, reg_cols, emb_family_to_cols


In [None]:
# ============================================================================
# ROBUST FEATURE SELECTION: Use reasoned handpicked features
# ============================================================================

USE_FEATURE_SELECTION = False  # Disabled: JSON contains all features anyway

if USE_FEATURE_SELECTION:
    try:
        curated_path = MODEL_READY_DIR / 'handpicked_features_reasoned.json'
        if curated_path.exists():
            with open(curated_path) as f:
                curated_data = json.load(f)
            handpicked_features = curated_data['handpicked_features']
            print(f'\n📊 Feature Selection: Using {len(handpicked_features)} reasoned handpicked features')
        else:
            print(f'  ⚠️ Reasoned features file not found, using all features')
            handpicked_features = None
    except Exception as e:
        print(f'  ⚠️ Error in feature selection: {e}')
        handpicked_features = None
else:
    handpicked_features = None
    print('\n📊 Feature Selection: DISABLED (using all features)')


In [None]:
# Load data
train_df = load_parquet_split('train')
val_df = load_parquet_split('val')

X_reg_train, X_emb_train_fams, y_train, reg_cols, emb_family_to_cols = split_features_reg_and_all_emb(train_df)
X_reg_val, X_emb_val_fams, y_val, _, _ = split_features_reg_and_all_emb(val_df)

# Apply feature selection to regular features if enabled
if handpicked_features is not None:
    available_features = set(reg_cols)
    selected_features = [f for f in handpicked_features if f in available_features]
    if len(selected_features) < len(reg_cols):
        feature_idx_map = {f: i for i, f in enumerate(reg_cols)}
        selected_indices = [feature_idx_map[f] for f in selected_features]
        X_reg_train = X_reg_train[:, selected_indices]
        X_reg_val = X_reg_val[:, selected_indices]
        reg_cols = selected_features
        print(f'  ✅ Feature selection applied! Regular features: {len(reg_cols)}')

print(f'\n📊 Data Summary:')
print(f'  Regular features: {len(reg_cols)}')
for fam, arr in X_emb_train_fams.items():
    print(f'  Embedding {fam}: {arr.shape[1]} dims')

del train_df, val_df
cleanup_memory()
memory_usage()


In [None]:
# ============================================================================
# PCA COMPRESSION PER EMBEDDING FAMILY
# ============================================================================

# Number of PCA components per family
PCA_COMPONENTS_PER_FAMILY = {
    'sent_transformer_': 64,
    'scibert_': 64,
    'specter_': 64,
    'specter2_': 64,
    'ner_': 32,
}

print('\n📊 Applying PCA compression per embedding family...')

X_emb_train_pca_list = []
X_emb_val_pca_list = []
pca_transformers = {}

for fam_prefix, X_emb_train_fam in X_emb_train_fams.items():
    n_components = PCA_COMPONENTS_PER_FAMILY.get(fam_prefix, 64)
    print(f'\n  Processing {fam_prefix}:')
    print(f'    Original dim: {X_emb_train_fam.shape[1]}')
    print(f'    Target PCA components: {n_components}')
    
    # Fit PCA on training data (use subset if large)
    max_pca_rows = min(150_000, X_emb_train_fam.shape[0])
    if X_emb_train_fam.shape[0] > max_pca_rows:
        idx = np.random.choice(X_emb_train_fam.shape[0], size=max_pca_rows, replace=False)
        pca_fit_data = X_emb_train_fam[idx]
    else:
        pca_fit_data = X_emb_train_fam
    
    # Create PCA transformer
    if IS_TORCH_PCA:
        ipca = IncrementalPCA(n_components=n_components, batch_size=5000, device=device)
    else:
        ipca = IncrementalPCA(n_components=n_components, batch_size=5000)
    
    # Fit and transform
    ipca.fit(pca_fit_data)
    X_emb_train_pca = ipca.transform(X_emb_train_fam)
    X_emb_val_pca = ipca.transform(X_emb_val_fams[fam_prefix]) if fam_prefix in X_emb_val_fams else None
    
    X_emb_train_pca_list.append(X_emb_train_pca)
    if X_emb_val_pca is not None:
        X_emb_val_pca_list.append(X_emb_val_pca)
    
    pca_transformers[fam_prefix] = ipca
    
    print(f'    ✅ PCA fitted with {n_components} components')
    print(f'    Reduced dim: {X_emb_train_pca.shape[1]}')
    
    del X_emb_train_pca, X_emb_val_pca, pca_fit_data
    cleanup_memory()

# Combine regular features + all PCA-compressed embeddings
X_train = np.concatenate([X_reg_train] + X_emb_train_pca_list, axis=1)
X_val = np.concatenate([X_reg_val] + X_emb_val_pca_list, axis=1)

del X_reg_train, X_reg_val, X_emb_train_fams, X_emb_val_fams, X_emb_train_pca_list, X_emb_val_pca_list
cleanup_memory()

print(f'\n✅ Combined features:')
print(f'  Train shape: {X_train.shape}')
print(f'  Val shape: {X_val.shape}')
memory_usage()


In [None]:
class TabularDataset(Dataset):
    """Simple tabular dataset for PyTorch."""
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.FloatTensor(y)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# DataLoader configuration
BATCH_SIZE = 512
VAL_BATCH_SIZE = 512
NUM_WORKERS = 0

train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False
)

print(f'\n📊 DataLoader Configuration:')
print(f'   Train batch size: {BATCH_SIZE}')
print(f'   Val batch size: {VAL_BATCH_SIZE}')
print(f'   Num workers: {NUM_WORKERS} (0 = single process, saves memory)')
print(f'Class counts (train): {np.bincount(y_train.astype(int))}')
memory_usage()


## 3. Class Imbalance Handling: SMOTETomek
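
SMOTETomek first oversamples the minority class with SMOTE, then removes Tomek links (mutual nearest neighbours from opposite classes). The SMOTE half can be sketched in NumPy: each synthetic row is a random point on the segment between a minority sample and one of its k nearest minority neighbours. This is a simplified illustration (brute-force neighbour search, no Tomek-link cleaning, hypothetical function name), not imblearn's implementation.

```python
import numpy as np


def smote_like_oversample(X_minority, n_new, k=5, seed=0):
    """Simplified SMOTE: synthesise rows between minority samples and their minority neighbours."""
    n = len(X_minority)
    if k >= n:
        raise ValueError("k must be smaller than the number of minority samples")
    rng = np.random.default_rng(seed)
    # Brute-force pairwise squared distances (fine for small n; imblearn uses a neighbour search).
    diffs = X_minority[:, None, :] - X_minority[None, :, :]
    dist2 = (diffs ** 2).sum(axis=-1)
    np.fill_diagonal(dist2, np.inf)  # a sample is not its own neighbour
    neighbours = np.argsort(dist2, axis=1)[:, :k]

    base = rng.integers(0, n, size=n_new)                       # random minority anchors
    partner = neighbours[base, rng.integers(0, k, size=n_new)]  # one of each anchor's k neighbours
    lam = rng.random((n_new, 1))                                # interpolation factor in [0, 1)
    return X_minority[base] + lam * (X_minority[partner] - X_minority[base])


X_min = np.random.default_rng(1).normal(size=(20, 3))
synthetic = smote_like_oversample(X_min, n_new=50, k=5)
print(synthetic.shape)  # (50, 3)
```

Because every synthetic row is a convex combination of two real minority rows, it always lies inside the minority samples' bounding box.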

In [None]:
from imblearn.combine import SMOTETomek

print('\n📊 Checking class imbalance and applying SMOTETomek resampling...')
print(f'  Before: {len(X_train)} samples, Positive: {y_train.sum()}, Negative: {(y_train==0).sum()}')
print(f'  Imbalance ratio: {(y_train==0).sum() / max(y_train.sum(), 1):.2f}:1')

try:
    smt = SMOTETomek(random_state=42, sampling_strategy='auto', n_jobs=-1)
    X_train_resampled, y_train_resampled = smt.fit_resample(X_train, y_train)
    
    print(f'  After: {len(X_train_resampled)} samples, Positive: {y_train_resampled.sum()}, Negative: {(y_train_resampled==0).sum()}')
    print(f'  Balance ratio: {(y_train_resampled==0).sum() / max(y_train_resampled.sum(), 1):.2f}:1')
    
    X_train = X_train_resampled
    y_train = y_train_resampled
    
    del X_train_resampled, y_train_resampled
    cleanup_memory()
    
    train_dataset = TabularDataset(X_train, y_train)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=False
    )
except Exception as e:
    print(f'  ⚠️ SMOTETomek failed: {e}')
    print('  Continuing with original training data...')
    cleanup_memory()


## 4. Feature Scaling
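
`StandardScaler` learns a per-column mean and standard deviation from the training rows only and reuses them for validation and test rows. A NumPy sketch of the same computation (population standard deviation, constant columns left unscaled as StandardScaler does for zero-variance features); the helper names are hypothetical.

```python
import numpy as np


def fit_standardizer(X_train):
    """Per-column mean and population std (ddof=0) learned from training rows."""
    mean = X_train.mean(axis=0)
    scale = X_train.std(axis=0)
    scale[scale == 0.0] = 1.0  # constant columns: centre only, do not divide by zero
    return mean, scale


def apply_standardizer(X, mean, scale):
    return (X - mean) / scale


X_tr = np.array([[1.0, 10.0, 5.0], [3.0, 30.0, 5.0]])
X_va = np.array([[2.0, 20.0, 7.0]])
mean, scale = fit_standardizer(X_tr)       # statistics come from training rows only
Z_tr = apply_standardizer(X_tr, mean, scale)
Z_va = apply_standardizer(X_va, mean, scale)  # validation reuses the training statistics
```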

In [None]:
from sklearn.preprocessing import StandardScaler

print('\n📊 Applying Feature Scaling...')

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)

X_train = X_train_scaled
X_val = X_val_scaled

del X_train_scaled, X_val_scaled
cleanup_memory()

print(f'  ✅ Features scaled: {X_train.shape}')

# Update datasets
train_dataset = TabularDataset(X_train, y_train)
val_dataset = TabularDataset(X_val, y_val)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False
)

memory_usage()


## 5. Model Definition
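
`LinearClassifier` is logistic regression written in PyTorch: the single `nn.Linear` layer outputs a logit $z$, and `BCEWithLogitsLoss` applies the sigmoid internally (the notebook applies `torch.sigmoid` explicitly only at evaluation time). With `pos_weight` set to $w_+ = N_{\text{neg}}/N_{\text{pos}}$ as in the training cell, the per-sample loss, averaged over each mini-batch, is:

$$
z = \mathbf{w}^\top \mathbf{x} + b, \qquad \hat{p} = \sigma(z) = \frac{1}{1 + e^{-z}}
$$

$$
\ell(z, y) = -\Bigl[\, w_{+}\, y \log \sigma(z) + (1 - y)\log\bigl(1 - \sigma(z)\bigr) \Bigr], \qquad w_{+} = \frac{N_{\text{neg}}}{N_{\text{pos}}}
$$

Because SMOTETomek roughly balances the classes beforehand, $w_+$ should come out close to 1 whenever resampling succeeds.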

In [None]:
class LinearClassifier(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        return self.linear(x)

input_dim = X_train.shape[1]
print(f'📊 Input dimension: {input_dim}')
print(f'   Regular features: {len(reg_cols)}')
total_pca_dims = sum(PCA_COMPONENTS_PER_FAMILY.get(fam, 0) for fam in emb_family_to_cols.keys() if emb_family_to_cols.get(fam))
print(f'   PCA-compressed embeddings: {total_pca_dims} total')

model = LinearClassifier(input_dim)
model = model.to(device)
print(model)


## 6. Training Loop
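
The per-epoch threshold search in the training cell calls `f1_score` once per threshold (120 calls per epoch). A vectorised NumPy alternative computes $F_1 = 2TP / (2TP + FP + FN)$ for every threshold at once. This is a sketch; the function name `best_f1_threshold` is hypothetical, and F1 is treated as 0 when its denominator is 0.

```python
import numpy as np


def best_f1_threshold(y_true, probs, thresholds):
    """Return (threshold, F1) maximising F1 = 2TP / (2TP + FP + FN) over a threshold grid."""
    y_true = np.asarray(y_true).astype(bool)
    probs = np.asarray(probs)
    preds = probs[None, :] >= thresholds[:, None]   # (n_thresholds, n_samples) boolean matrix
    tp = (preds & y_true).sum(axis=1)
    fp = (preds & ~y_true).sum(axis=1)
    fn = (~preds & y_true).sum(axis=1)
    denom = 2 * tp + fp + fn
    f1 = np.divide(2 * tp, denom, out=np.zeros(len(thresholds)), where=denom > 0)
    best = int(np.argmax(f1))  # first maximum, like the strict '>' comparison in the loop
    return float(thresholds[best]), float(f1[best])


# Same grid as the training cell
thresholds = np.concatenate([
    np.linspace(0.01, 0.05, 20),
    np.linspace(0.05, 0.15, 50),
    np.linspace(0.15, 0.3, 30),
    np.linspace(0.3, 0.9, 20),
])
y = np.array([0, 0, 1, 1, 0, 1])
p = np.array([0.10, 0.40, 0.35, 0.80, 0.05, 0.60])
thr, f1 = best_f1_threshold(y, p, thresholds)
```

In this toy example the best F1 is 6/7, reached by the first grid threshold above 0.10: all three positives are caught and only the negative at 0.40 is a false positive.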

In [None]:
EPOCHS = 10
LR = 1e-3

# Compute pos_weight for BCEWithLogitsLoss
pos_count = (y_train == 1).sum()
neg_count = (y_train == 0).sum()
pos_weight_value = torch.tensor([neg_count / max(pos_count, 1)], dtype=torch.float32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_value)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

best_val_f1 = 0.0
best_state_dict = None

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device).unsqueeze(1)
        
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * xb.size(0)
        del xb, yb, logits, loss
    
    avg_train_loss = running_loss / len(train_loader.dataset)
    
    # Validation
    model.eval()
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            yb_np = yb.numpy()
            logits = model(xb)
            probs = torch.sigmoid(logits).cpu().numpy().ravel()
            all_preds.append(probs)
            all_targets.append(yb_np)
            del xb, logits, probs, yb_np
    
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    
    # Threshold tuning
    roc_auc = roc_auc_score(all_targets, all_preds)
    pr_auc = average_precision_score(all_targets, all_preds)
    
    best_epoch_f1 = 0.0
    best_thr = 0.5
    
    thresholds = np.concatenate([
        np.linspace(0.01, 0.05, 20),
        np.linspace(0.05, 0.15, 50),
        np.linspace(0.15, 0.3, 30),
        np.linspace(0.3, 0.9, 20)
    ])
    
    for thr in thresholds:
        preds_bin = (all_preds >= thr).astype(int)
        f1 = f1_score(all_targets, preds_bin, pos_label=1)
        if f1 > best_epoch_f1:
            best_epoch_f1 = f1
            best_thr = thr
        del preds_bin
    
    del all_preds, all_targets
    
    print(f'Epoch {epoch:02d} | train_loss={avg_train_loss:.4f} | val_f1={best_epoch_f1:.4f} @ thr={best_thr:.2f} | roc_auc={roc_auc:.4f} | pr_auc={pr_auc:.4f}')
    memory_usage()
    
    if best_epoch_f1 > best_val_f1:
        best_val_f1 = best_epoch_f1
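        # Clone the weights: state_dict() tensors share storage with the live parameters
        best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}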
    
    cleanup_memory()

print('Best val F1:', best_val_f1)

if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
    MODEL_SAVE_DIR = PROJECT_ROOT / 'models' / 'saved_models'
    MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    model_save_path = MODEL_SAVE_DIR / 'refined_model4_best.pt'
    torch.save({
        'model_state_dict': best_state_dict,
        'input_dim': input_dim,
        'best_val_f1': best_val_f1,
        'epochs': EPOCHS,
        'learning_rate': LR,
        'pos_weight': pos_weight_value.cpu().item(),
        'pca_components_per_family': PCA_COMPONENTS_PER_FAMILY
    }, model_save_path)
    print(f'\n💾 Saved best model to: {model_save_path}')


## 7. 5-Fold Cross-Validation
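
Because SMOTETomek and the first scaler were fit before this section, the combined train+val matrix already contains synthetic training rows, and some of them can land in CV validation folds, which can inflate fold F1. A leakage-safe pattern fits every data-dependent step inside each fold. A minimal sketch on random data using sklearn's `StratifiedKFold` and `StandardScaler`; the resampling step is indicated only as a comment.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 4))
y = (rng.random(200) < 0.2).astype(int)  # roughly 20% positives

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_pos_rates = []
for train_idx, val_idx in skf.split(X, y):
    # Fit data-dependent preprocessing on the training fold only ...
    scaler = StandardScaler().fit(X[train_idx])
    X_tr = scaler.transform(X[train_idx])
    # ... and only apply it to the held-out fold.
    X_va = scaler.transform(X[val_idx])
    # A resampler (e.g. SMOTETomek) would likewise be fit on (X_tr, y[train_idx]) only,
    # so no synthetic rows ever reach X_va.
    fold_pos_rates.append(y[val_idx].mean())

print(np.round(fold_pos_rates, 3))  # each fold keeps roughly the overall positive rate
```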

In [None]:
# Import CV utilities
try:
    from model_training_utils import stratified_kfold_splits, find_optimal_threshold
    USE_UTILS = True
except ImportError:
    USE_UTILS = False
    from sklearn.model_selection import StratifiedKFold

# Combine train and val for CV (X_train already contains SMOTETomek synthetic rows)
X_full = np.vstack([X_train, X_val])
y_full = np.concatenate([y_train, y_val])

# Hyperparameter search space
hyperparams_list = [
    {'lr': 0.001, 'batch_size': 512},
    {'lr': 0.0005, 'batch_size': 512},
    {'lr': 0.001, 'batch_size': 256},
]

best_hyperparams = None
best_cv_score = 0.0

for hyperparams in hyperparams_list:
    print(f'\n{"="*80}')
    print(f'Hyperparameter Set: {hyperparams}')
    print(f'{"="*80}')
    
    cv_scores = []
    
    if USE_UTILS:
        splits = stratified_kfold_splits(y_full, n_splits=5, shuffle=True)
    else:
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
        splits = skf.split(X_full, y_full)
    
    for fold_idx, (train_idx, val_idx) in enumerate(splits, 1):
        print(f'\nFold {fold_idx}/5')
        
        X_fold_train, X_fold_val = X_full[train_idx], X_full[val_idx]
        y_fold_train, y_fold_val = y_full[train_idx], y_full[val_idx]
        
        # Scale
        scaler_fold = StandardScaler()
        X_fold_train = scaler_fold.fit_transform(X_fold_train)
        X_fold_val = scaler_fold.transform(X_fold_val)
        
        # Create model
        fold_model = LinearClassifier(input_dim)
        fold_model = fold_model.to(device)
        
        # Training
        fold_dataset = TabularDataset(X_fold_train, y_fold_train)
        fold_loader = DataLoader(fold_dataset, batch_size=hyperparams['batch_size'], shuffle=True)
        
        pos_count_fold = (y_fold_train == 1).sum()
        neg_count_fold = (y_fold_train == 0).sum()
        pos_weight_fold = torch.tensor([neg_count_fold / max(pos_count_fold, 1)], dtype=torch.float32).to(device)
        
        criterion_fold = nn.BCEWithLogitsLoss(pos_weight=pos_weight_fold)
        optimizer_fold = torch.optim.Adam(fold_model.parameters(), lr=hyperparams['lr'])
        
        best_fold_f1 = 0.0
        
        for epoch in range(1, 6):  # Max 5 epochs per fold
            fold_model.train()
            for xb, yb in fold_loader:
                xb, yb = xb.to(device), yb.to(device).unsqueeze(1)
                optimizer_fold.zero_grad()
                logits = fold_model(xb)
                loss = criterion_fold(logits, yb)
                loss.backward()
                optimizer_fold.step()
            
            # Validation
            fold_model.eval()
            val_preds = []
            val_targets = []
            
            with torch.no_grad():
                val_dataset_fold = TabularDataset(X_fold_val, y_fold_val)
                val_loader_fold = DataLoader(val_dataset_fold, batch_size=512, shuffle=False)
                
                for xb, yb in val_loader_fold:
                    xb = xb.to(device)
                    logits = fold_model(xb)
                    probs = torch.sigmoid(logits).cpu().numpy().ravel()
                    val_preds.append(probs)
                    val_targets.append(yb.numpy())
            
            val_preds = np.concatenate(val_preds)
            val_targets = np.concatenate(val_targets)
            
            # Find best threshold
            if USE_UTILS:
                best_thr_fold, best_f1_fold = find_optimal_threshold(val_targets, val_preds)
            else:
                thresholds = np.linspace(0.01, 0.5, 50)
                best_f1_fold = 0.0
                best_thr_fold = 0.5
                for thr in thresholds:
                    preds_bin = (val_preds >= thr).astype(int)
                    f1 = f1_score(val_targets, preds_bin, pos_label=1)
                    if f1 > best_f1_fold:
                        best_f1_fold = f1
                        best_thr_fold = thr
            
            if best_f1_fold > best_fold_f1:
                best_fold_f1 = best_f1_fold
            
            del val_preds, val_targets
            
            # Heuristic early exit: stop this fold once validation F1 exceeds 0.3
            if best_fold_f1 > 0.3:
                break
        
        cv_scores.append(best_fold_f1)
        print(f'  Fold {fold_idx} - Val F1: {best_fold_f1:.4f}')
        
        cleanup_memory()
    
    mean_cv_score = np.mean(cv_scores)
    print(f'\n📊 CV Results: Mean F1: {mean_cv_score:.4f} ± {np.std(cv_scores):.4f}')
    
    if mean_cv_score > best_cv_score:
        best_cv_score = mean_cv_score
        best_hyperparams = hyperparams
        print(f'  ✅ New best!')
    
    cleanup_memory()

print(f'\n🏆 Best hyperparameters: {best_hyperparams}')
print(f'🏆 Best CV F1: {best_cv_score:.4f}')


## 8. Final Model Training
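
In PyTorch, the tensors in `model.state_dict()` share storage with the model's parameters, and `optimizer.step()` updates those parameters in place. A snapshot taken with `model.state_dict().copy()` therefore copies only the dictionary, so the "best" weights silently follow later epochs. The aliasing is easy to see with NumPy arrays standing in for tensors (illustration only):

```python
import copy
import numpy as np

# A plain dict of arrays stands in for a model's state_dict.
params = {"weight": np.zeros(3), "bias": np.zeros(1)}

shallow = params.copy()        # new dict, but the SAME array objects
deep = copy.deepcopy(params)   # new dict AND new arrays

params["weight"] += 1.0        # in-place update, analogous to optimizer.step()

print(shallow["weight"])  # [1. 1. 1.]  the "snapshot" moved with training
print(deep["weight"])     # [0. 0. 0.]  the snapshot is preserved
```

For real state dicts, `{k: v.detach().clone() for k, v in model.state_dict().items()}` or `copy.deepcopy(model.state_dict())` produces an independent snapshot.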

In [None]:
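# Train the final model on the training split with the best CV hyperparameters;
# the validation split selects the best epoch and the decision threshold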
if best_hyperparams is None:
    best_hyperparams = {'lr': 0.001, 'batch_size': 512}

final_model = LinearClassifier(input_dim)
final_model = final_model.to(device)

final_dataset = TabularDataset(X_train, y_train)
final_loader = DataLoader(final_dataset, batch_size=best_hyperparams['batch_size'], shuffle=True)

pos_weight_final = torch.tensor([neg_count / max(pos_count, 1)], dtype=torch.float32).to(device)
criterion_final = nn.BCEWithLogitsLoss(pos_weight=pos_weight_final)
optimizer_final = torch.optim.Adam(final_model.parameters(), lr=best_hyperparams['lr'])

best_final_f1 = 0.0
best_final_state = None

for epoch in range(1, EPOCHS + 1):
    final_model.train()
    running_loss = 0.0
    
    for xb, yb in final_loader:
        xb, yb = xb.to(device), yb.to(device).unsqueeze(1)
        optimizer_final.zero_grad()
        logits = final_model(xb)
        loss = criterion_final(logits, yb)
        loss.backward()
        optimizer_final.step()
        running_loss += loss.item() * xb.size(0)
    
    avg_loss = running_loss / len(final_loader.dataset)
    
    # Validation
    final_model.eval()
    val_preds = []
    val_targets = []
    
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            logits = final_model(xb)
            probs = torch.sigmoid(logits).cpu().numpy().ravel()
            val_preds.append(probs)
            val_targets.append(yb.numpy())
    
    val_preds = np.concatenate(val_preds)
    val_targets = np.concatenate(val_targets)
    
    # Find best threshold
    thresholds = np.concatenate([
        np.linspace(0.01, 0.05, 20),
        np.linspace(0.05, 0.15, 50),
        np.linspace(0.15, 0.3, 30),
    ])
    
    best_epoch_f1 = 0.0
    best_thr = 0.5
    
    for thr in thresholds:
        preds_bin = (val_preds >= thr).astype(int)
        f1 = f1_score(val_targets, preds_bin, pos_label=1)
        if f1 > best_epoch_f1:
            best_epoch_f1 = f1
            best_thr = thr
    
    print(f'Epoch {epoch:02d} | loss={avg_loss:.4f} | val_f1={best_epoch_f1:.4f} @ thr={best_thr:.2f}')
    
    if best_epoch_f1 > best_final_f1:
        best_final_f1 = best_epoch_f1
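        # Clone the weights: state_dict() tensors share storage with the live parameters
        best_final_state = {k: v.detach().clone() for k, v in final_model.state_dict().items()}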
        final_threshold = best_thr
    
    del val_preds, val_targets
    cleanup_memory()

if best_final_state is not None:
    final_model.load_state_dict(best_final_state)
    print(f'\n✅ Final model trained. Best F1: {best_final_f1:.4f} @ threshold: {final_threshold:.4f}')
else:
    final_threshold = 0.5
    print(f'\n⚠️ Using default threshold: {final_threshold}')


## 9. Generate Submission

In [None]:
import re

def extract_work_id(id_value) -> str:
    """Extract the work_id (e.g. 'W123') from a URL, or return the value as-is if it is already an ID."""
    if isinstance(id_value, str) and id_value.startswith('W') and len(id_value) > 1 and '/' not in id_value:
        return id_value
    id_str = str(id_value)
    match = re.search(r'W\d+', id_str)
    if match:
        return match.group(0)
    return id_str

# Load test data
test_df = load_parquet_split('test')
test_ids = test_df['id'].to_numpy()

X_reg_test, X_emb_test_fams, _, test_reg_cols, _ = split_features_reg_and_all_emb(test_df)

# Align test regular features to the training columns (same names, same order).
# This also applies any feature selection, because reg_cols holds the selected subset.
test_col_idx = {f: i for i, f in enumerate(test_reg_cols)}
missing = [f for f in reg_cols if f not in test_col_idx]
if missing:
    raise ValueError(f'Test data is missing {len(missing)} regular features, e.g. {missing[:5]}')
X_reg_test = X_reg_test[:, [test_col_idx[f] for f in reg_cols]]

# Apply the fitted PCA per family, in the same family order used for training
X_emb_test_pca_list = []
for fam_prefix, ipca in pca_transformers.items():
    if fam_prefix not in X_emb_test_fams:
        raise ValueError(f'Test data has no columns for embedding family {fam_prefix}')
    X_emb_test_pca_list.append(ipca.transform(X_emb_test_fams[fam_prefix]))

# Combine regular + all PCA-compressed embeddings
X_test = np.concatenate([X_reg_test] + X_emb_test_pca_list, axis=1)
del X_reg_test, X_emb_test_fams, X_emb_test_pca_list, test_df

# Scale with the scaler fitted on the training split
X_test_scaled = scaler.transform(X_test)

# Predict
test_dataset = TabularDataset(X_test_scaled, np.zeros(len(X_test_scaled)))
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

final_model.eval()
test_preds = []
with torch.no_grad():
    for xb, _ in test_loader:
        xb = xb.to(device)
        logits = final_model(xb)
        probs = torch.sigmoid(logits).cpu().numpy().ravel()
        test_preds.append(probs)

test_preds = np.concatenate(test_preds)
test_preds_binary = (test_preds >= final_threshold).astype(int)

# Save submission
SUBMISSION_DIR = PROJECT_ROOT / 'data' / 'submission_files'
SUBMISSION_DIR.mkdir(parents=True, exist_ok=True)
submission_path = SUBMISSION_DIR / 'submission_refined_model4.csv'

# Extract work_id from test_ids
work_ids = np.array([extract_work_id(str(id_val)) for id_val in test_ids])
submission_df = pl.DataFrame({
    'work_id': work_ids,
    'label': test_preds_binary
})
submission_df.write_csv(submission_path)

print(f'\n✅ Submission saved to: {submission_path}')
print(f'   Predictions: {test_preds_binary.sum()} positive, {(test_preds_binary==0).sum()} negative')

cleanup_memory()
memory_usage()
print('\n✅ All done!')