In [None]:
# ============================================
# Catechol Benchmark — Final Multi-GPU Optimized MLP (one-shot)
# ============================================

import os, sys, math, random, tqdm
import numpy as np
import pandas as pd
from typing import List, Generator
from abc import ABC

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler

# -------------------
# CONFIG
# -------------------
# Add the dataset directory to the import search path
# (presumably so helper modules shipped with the dataset can be imported — confirm).
sys.path.append('/kaggle/input/catechol-benchmark-hackathon/')
# Directory that load_data / load_features read their CSV files from.
DATA_ROOT = '/kaggle/input/catechol-benchmark-hackathon/'
# Number of CUDA devices visible to this process (0 on a CPU-only machine).
NUM_GPUS = torch.cuda.device_count()
# Device on which feature/target tensors and models are placed.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEEDS = [42, 77, 2025]          # one model is trained per seed in every CV fold; predictions are averaged
# Upper bound on training epochs; also used as T_max of the cosine LR schedule.
EPOCHS = 250
# Per-GPU batch size; the effective BATCH_SIZE is derived from it below.
BASE_BATCH_SIZE = 32
# AdamW learning rate and weight decay.
LR = 1e-3
WEIGHT_DECAY = 1e-4
# Early-stopping patience, counted in non-improving epochs (see EarlyStopper).
PATIENCE = 30
# Max gradient norm passed to clip_grad_norm_ during training.
CLIP_GRAD_NORM = 1.0
# `beta` of nn.SmoothL1Loss (threshold between the quadratic and linear regimes).
HUBER_BETA = 1.0
# When True, predict_model clips predictions to the interval [0, 1].
CLIP_PRED_TO_UNIT = True
# When True, _numeric_block appends squared, log1p and interaction terms.
USE_NUMERIC_FE = True

# scale batch size linearly with the GPU count (unchanged when there are 0 or 1 GPUs)
BATCH_SIZE = BASE_BATCH_SIZE * max(1, NUM_GPUS)

# Global side effect: floating-point tensors and module parameters created
# after this call default to float64.
torch.set_default_dtype(torch.float64)

# -------------------
# LABELS
# -------------------
# Input columns selected by load_data("full") (solvent-mixture dataset).
INPUT_LABELS_FULL_SOLVENT = ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"]
# Input columns selected by load_data for the single-solvent dataset.
INPUT_LABELS_SINGLE_SOLVENT = ["Residence Time", "Temperature", "SOLVENT NAME"]
# Numeric process columns consumed by _numeric_block, in this order.
INPUT_LABELS_NUMERIC = ["Residence Time", "Temperature"]
# Target columns; the model emits one output per entry, in this order.
TARGET_LABELS = ["Product 2", "Product 3", "SM"]

# -------------------
# UTILS
# -------------------
def set_seed(seed):
    """Seed every random number generator used by this script.

    Args:
        seed: Value handed, in order, to Python's ``random``, NumPy's global
            generator, PyTorch's CPU generator and all CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def load_data(name="full"):
    """Read one of the yield CSVs and split it into inputs and targets.

    Args:
        name: ``"full"`` selects the solvent-mixture file and its input
            columns; any other value selects the single-solvent file.

    Returns:
        Tuple ``(X, Y)`` of DataFrame copies: the input columns for the
        chosen dataset and the ``TARGET_LABELS`` columns.
    """
    use_full = name == "full"
    if use_full:
        stem = "full_data_yields"
        input_cols = INPUT_LABELS_FULL_SOLVENT
    else:
        stem = "single_solvent_yields"
        input_cols = INPUT_LABELS_SINGLE_SOLVENT

    frame = pd.read_csv(f'{DATA_ROOT}/catechol_{stem}.csv')
    inputs = frame[input_cols].copy()
    targets = frame[TARGET_LABELS].copy()
    return inputs, targets

def load_features(name="spange_descriptors"):
    """Load a descriptor lookup table from ``DATA_ROOT``.

    Args:
        name: Stem of the ``<name>_lookup.csv`` file to read.

    Returns:
        DataFrame indexed by the CSV's first column.
    """
    lookup_path = f'{DATA_ROOT}/{name}_lookup.csv'
    table = pd.read_csv(lookup_path, index_col=0)
    return table

def generate_leave_one_out_splits(X, Y) -> Generator:
    """Yield leave-one-solvent-out train/test splits.

    Solvents are visited in sorted order of the unique values of
    ``X["SOLVENT NAME"]``.

    Args:
        X: Input DataFrame with a ``"SOLVENT NAME"`` column.
        Y: Target frame row-aligned with ``X``.

    Yields:
        ``((X_train, Y_train), (X_test, Y_test))`` where the test part holds
        exactly the rows of the held-out solvent.
    """
    for held_out in sorted(X["SOLVENT NAME"].unique()):
        is_test = X["SOLVENT NAME"] == held_out
        is_train = ~is_test
        train_part = (X[is_train], Y[is_train])
        test_part = (X[is_test], Y[is_test])
        yield train_part, test_part

def generate_leave_one_ramp_out_splits(X, Y) -> Generator:
    """Yield train/test splits that hold out one (solvent A, solvent B) pair.

    Pairs are visited in order of first appearance in ``X``.

    Args:
        X: Input DataFrame with ``"SOLVENT A NAME"`` and ``"SOLVENT B NAME"``.
        Y: Target frame row-aligned with ``X``.

    Yields:
        ``((X_train, Y_train), (X_test, Y_test))`` where the test part holds
        every row whose solvent pair equals the held-out pair.
    """
    pair_cols = ["SOLVENT A NAME", "SOLVENT B NAME"]
    unique_pairs = X[pair_cols].drop_duplicates()
    for solvent_a, solvent_b in unique_pairs.itertuples(index=False, name=None):
        same_a = X["SOLVENT A NAME"] == solvent_a
        same_b = X["SOLVENT B NAME"] == solvent_b
        in_ramp = same_a & same_b
        out_of_ramp = ~in_ramp
        yield (X[out_of_ramp], Y[out_of_ramp]), (X[in_ramp], Y[in_ramp])

# -------------------
# FEATURIZATION
# -------------------
def _numeric_block(X):
    """Build the numeric feature matrix from the process-condition columns.

    Args:
        X: DataFrame containing the ``INPUT_LABELS_NUMERIC`` columns.

    Returns:
        2-D float array. With ``USE_NUMERIC_FE`` off it holds just the two
        raw columns; otherwise seven columns: both raw values, their
        squares, their ``log1p`` transforms and their product.
    """
    raw = X[INPUT_LABELS_NUMERIC].astype(float).values
    if not USE_NUMERIC_FE:
        return raw

    # List-indexing keeps each column 2-D so the pieces concatenate column-wise.
    rt = raw[:, [0]]
    temp = raw[:, [1]]
    engineered = (
        rt,
        temp,
        rt**2,
        temp**2,
        np.log1p(rt),
        np.log1p(temp),
        rt * temp,
    )
    return np.concatenate(engineered, axis=1)

class SmilesFeaturizer(ABC):
    """Base interface for featurizers; both methods must be overridden."""

    def fit_transform(self, X, Y):
        """Fit on training data and return its features (not implemented here)."""
        raise NotImplementedError

    def transform(self, X):
        """Featurize new data with the fitted state (not implemented here)."""
        raise NotImplementedError

class PrecomputedFeaturizer(SmilesFeaturizer):
    """Featurizer for single-solvent rows using a precomputed descriptor table.

    Features are the standardized numeric block followed by the standardized
    descriptor row looked up for each sample's ``"SOLVENT NAME"``.
    """

    def __init__(self, features='spange_descriptors'):
        """Load the lookup table named ``features`` and create unfitted scalers."""
        self.lookup = load_features(features)
        self.scaler_num = StandardScaler()
        self.scaler_desc = StandardScaler()
        self.feats_dim = None  # set by fit_transform to the feature width

    def _raw_blocks(self, X):
        """Return the unscaled (numeric, descriptor) arrays for ``X``."""
        numeric = _numeric_block(X)
        descriptors = self.lookup.loc[X["SOLVENT NAME"]].values
        return numeric, descriptors

    def fit_transform(self, X, Y):
        """Fit both scalers on ``X`` and return ``(features, targets)`` tensors on DEVICE."""
        numeric, descriptors = self._raw_blocks(X)
        numeric_scaled = self.scaler_num.fit_transform(numeric)
        descriptors_scaled = self.scaler_desc.fit_transform(descriptors)
        feats = np.concatenate([numeric_scaled, descriptors_scaled], axis=1)
        self.feats_dim = feats.shape[1]
        x_tensor = torch.tensor(feats, device=DEVICE)
        y_tensor = torch.tensor(Y.values, device=DEVICE)
        return x_tensor, y_tensor

    def transform(self, X):
        """Featurize ``X`` with the already-fitted scalers; returns a tensor on DEVICE."""
        numeric, descriptors = self._raw_blocks(X)
        numeric_scaled = self.scaler_num.transform(numeric)
        descriptors_scaled = self.scaler_desc.transform(descriptors)
        feats = np.concatenate([numeric_scaled, descriptors_scaled], axis=1)
        return torch.tensor(feats, device=DEVICE)

class PrecomputedFeaturizerMixed(SmilesFeaturizer):
    """Featurizer for solvent-mixture rows using a precomputed descriptor table.

    The descriptor part of each sample is the linear blend
    ``A * (1 - pct) + B * pct`` of the two solvents' descriptor rows, where
    ``pct`` is the ``"SolventB%"`` column.
    """

    def __init__(self, features='spange_descriptors'):
        """Load the lookup table named ``features`` and create unfitted scalers."""
        self.lookup = load_features(features)
        self.scaler_num = StandardScaler()
        self.scaler_desc = StandardScaler()
        self.feats_dim = None  # set by fit_transform to the feature width

    def _raw_blocks(self, X):
        """Return the unscaled (numeric, blended-descriptor) arrays for ``X``."""
        numeric = _numeric_block(X)
        desc_a = self.lookup.loc[X["SOLVENT A NAME"]].values
        desc_b = self.lookup.loc[X["SOLVENT B NAME"]].values
        # NOTE(review): the blend treats SolventB% as a fraction; if the CSV
        # stores 0..100 it would need dividing by 100 — confirm against the data.
        pct = X["SolventB%"].astype(float).values.reshape(-1, 1)
        mix = desc_a * (1 - pct) + desc_b * pct
        return numeric, mix

    def fit_transform(self, X, Y):
        """Fit both scalers on ``X`` and return ``(features, targets)`` tensors on DEVICE."""
        numeric, mix = self._raw_blocks(X)
        numeric_scaled = self.scaler_num.fit_transform(numeric)
        mix_scaled = self.scaler_desc.fit_transform(mix)
        feats = np.concatenate([numeric_scaled, mix_scaled], axis=1)
        self.feats_dim = feats.shape[1]
        x_tensor = torch.tensor(feats, device=DEVICE)
        y_tensor = torch.tensor(Y.values, device=DEVICE)
        return x_tensor, y_tensor

    def transform(self, X):
        """Featurize ``X`` with the already-fitted scalers; returns a tensor on DEVICE."""
        numeric, mix = self._raw_blocks(X)
        numeric_scaled = self.scaler_num.transform(numeric)
        mix_scaled = self.scaler_desc.transform(mix)
        feats = np.concatenate([numeric_scaled, mix_scaled], axis=1)
        return torch.tensor(feats, device=DEVICE)

# -------------------
# MODEL
# -------------------
class ImprovedMLP(nn.Module):
    """Fully connected regressor: BatchNorm input, three hidden layers, 3 linear outputs.

    Hidden widths are 256, 128 and 64 with ReLU activations; dropout of 0.30
    and 0.20 follows the first and second hidden layers respectively.
    """

    def __init__(self, feats_dim):
        """Build the network for inputs with ``feats_dim`` features."""
        super().__init__()
        widths = [feats_dim, 256, 128, 64]
        dropouts = [0.30, 0.20, None]  # None: no dropout after that hidden layer

        layers = [nn.BatchNorm1d(feats_dim)]
        for fan_in, fan_out, drop_p in zip(widths[:-1], widths[1:], dropouts):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
            if drop_p is not None:
                layers.append(nn.Dropout(drop_p))
        layers.append(nn.Linear(widths[-1], 3))  # linear head

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (unclipped) 3-column prediction for batch ``x``."""
        return self.net(x)

# -------------------
# TRAINING
# -------------------
class EarlyStopper:
    """Track the best loss seen so far and snapshot the weights that produced it.

    Attributes are ``best`` (lowest loss), ``count`` (consecutive
    non-improving calls), ``state`` (CPU copy of the best ``state_dict`` or
    ``None``) and ``patience``.
    """

    def __init__(self, patience=PATIENCE):
        """Start with no best loss recorded and an empty snapshot."""
        self.patience = patience
        self.state = None
        self.count = 0
        self.best = math.inf

    def step(self, loss, model):
        """Record one epoch's loss and report whether training should stop.

        Args:
            loss: Scalar loss for the epoch just finished.
            model: Module whose ``state_dict`` is copied to CPU on improvement.

        Returns:
            ``True`` once the number of consecutive non-improving calls
            exceeds ``patience``; ``False`` otherwise.
        """
        # An improvement must beat the best loss by more than 1e-12.
        improved = loss < self.best - 1e-12
        if not improved:
            self.count += 1
        else:
            self.best = loss
            self.count = 0
            snapshot = {}
            for name, tensor in model.state_dict().items():
                snapshot[name] = tensor.detach().cpu().clone()
            self.state = snapshot
        return self.count > self.patience

def remove_module_prefix(state_dict):
    """Strip the leading ``"module."`` that ``nn.DataParallel`` adds to keys.

    Only a prefix at the very start of a key is removed. Keys that merely
    contain the text ``"module."`` further in (for example a layer called
    ``submodule``) are returned unchanged, so legitimate parameter names are
    never corrupted.

    Args:
        state_dict: Mapping from parameter/buffer names to values.

    Returns:
        A new dict with the same values and de-prefixed keys.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

def train_one_model(Xtr_t, Ytr_t, feats_dim, seed):
    """Train one ``ImprovedMLP`` and return it with its best weights loaded.

    Training uses AdamW, a cosine-annealed learning rate over ``EPOCHS``,
    SmoothL1 loss and gradient-norm clipping. Early stopping monitors the
    mean *training* loss per epoch (there is no validation split here).

    Args:
        Xtr_t: Training feature tensor, one row per sample.
        Ytr_t: Training target tensor, row-aligned with ``Xtr_t``.
        feats_dim: Number of input features of the network.
        seed: Seed applied to all RNGs before the model is created.

    Returns:
        The trained module, never wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: If the training set is empty.
    """
    n_samples = len(Xtr_t)
    if n_samples == 0:
        raise ValueError("train_one_model requires at least one training sample")

    set_seed(seed)
    model = ImprovedMLP(feats_dim).to(DEVICE)

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs via DataParallel...")
        model = nn.DataParallel(model)

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
    crit = nn.SmoothL1Loss(beta=HUBER_BETA)

    # BatchNorm1d raises in training mode when a batch holds a single sample.
    # If the final batch would be such a singleton, drop it; the shuffle means
    # a different sample is left out each epoch. When the whole dataset fits
    # in one batch nothing is dropped, so behaviour is unchanged there.
    drop_singleton_tail = n_samples > BATCH_SIZE and n_samples % BATCH_SIZE == 1
    loader = DataLoader(
        TensorDataset(Xtr_t, Ytr_t),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=drop_singleton_tail,
    )
    stopper = EarlyStopper(patience=PATIENCE)

    for epoch in range(EPOCHS):
        model.train()
        tot = 0.0
        seen = 0
        for xb, yb in loader:
            opt.zero_grad(set_to_none=True)
            loss = crit(model(xb), yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD_NORM)
            opt.step()
            tot += loss.item() * xb.size(0)
            seen += xb.size(0)
        sched.step()
        # Normalise by the samples actually processed so the mean stays
        # correct when the singleton tail batch was dropped.
        if stopper.step(tot / seen, model):
            break

    # Always hand back the bare module so callers get one consistent type.
    if isinstance(model, nn.DataParallel):
        model = model.module

    # Restore the best weights. The snapshot may carry DataParallel's
    # "module." prefix; after stripping it the keys must match exactly, so a
    # mismatch is reported instead of being silently ignored.
    best = stopper.state
    if best is not None:
        model.load_state_dict(remove_module_prefix(best), strict=True)

    return model

@torch.no_grad()
def predict_model(model, X_t):
    """Run ``model`` on ``X_t`` in eval mode and return a NumPy array.

    The model is switched to eval mode and left there. When
    ``CLIP_PRED_TO_UNIT`` is set the predictions are clipped to [0, 1].
    """
    model.eval()
    preds = model(X_t).cpu().numpy()
    if not CLIP_PRED_TO_UNIT:
        return preds
    return np.clip(preds, 0, 1)

def ensemble_predict(models, X_t):
    """Average the per-model predictions for ``X_t`` element-wise.

    Args:
        models: Iterable of trained models.
        X_t: Feature tensor passed to each model via ``predict_model``.

    Returns:
        NumPy array holding the mean over models of their predictions.
    """
    member_preds = []
    for member in models:
        member_preds.append(predict_model(member, X_t))
    return np.mean(member_preds, axis=0)

# -------------------
# CV LOOPS
# -------------------
def run_single_solvent_cv(features='spange_descriptors'):
    """Run leave-one-solvent-out CV on the single-solvent dataset.

    For every fold a fresh featurizer is fitted on the training part, one
    model per entry of ``SEEDS`` is trained, and their averaged predictions
    are scored against the held-out rows.

    Args:
        features: Name of the descriptor lookup table to use.

    Returns:
        Tuple ``(preds_df, true_df, mean_fold_mse)``. Both frames carry
        ``task`` (always 0), ``fold`` and ``row`` keys; ``mean_fold_mse`` is
        the unweighted mean of the per-fold MSEs.
    """
    X, Y = load_data("single_solvent")
    # Materialise the splits so the progress bar knows the total fold count.
    splits = list(generate_leave_one_out_splits(X, Y))

    preds_rows, true_rows, fold_mses = [], [], []
    for fold, (train_part, test_part) in enumerate(tqdm.tqdm(splits, desc='Single-solvent CV')):
        trX, trY = train_part
        teX, teY = test_part

        feat = PrecomputedFeaturizer(features)
        Xtr_t, Ytr_t = feat.fit_transform(trX, trY)
        Xte_t = feat.transform(teX)

        n_feats = Xtr_t.shape[1]
        models = [train_one_model(Xtr_t, Ytr_t, n_feats, s) for s in SEEDS]
        pred = ensemble_predict(models, Xte_t)
        true = teY.values

        squared_err = (pred - true) ** 2
        fold_mses.append(float(np.mean(squared_err)))

        for i, (p, t) in enumerate(zip(pred, true)):
            key = {"task": 0, "fold": fold, "row": i}
            preds_rows.append({**key, "target_1": p[0], "target_2": p[1], "target_3": p[2]})
            true_rows.append({**key, "true_1": t[0], "true_2": t[1], "true_3": t[2]})

    preds_df = pd.DataFrame(preds_rows)
    true_df = pd.DataFrame(true_rows)
    return preds_df, true_df, float(np.mean(fold_mses))

def run_full_ramp_cv(features='spange_descriptors'):
    """Run leave-one-solvent-pair-out CV on the solvent-mixture dataset.

    For every fold a fresh mixed featurizer is fitted on the training part,
    one model per entry of ``SEEDS`` is trained, and their averaged
    predictions are scored against the held-out rows.

    Args:
        features: Name of the descriptor lookup table to use.

    Returns:
        Tuple ``(preds_df, true_df, mean_fold_mse)``. Both frames carry
        ``task`` (always 1), ``fold`` and ``row`` keys; ``mean_fold_mse`` is
        the unweighted mean of the per-fold MSEs.
    """
    X, Y = load_data("full")
    # Materialise the splits so the progress bar knows the total fold count.
    splits = list(generate_leave_one_ramp_out_splits(X, Y))

    preds_rows, true_rows, fold_mses = [], [], []
    for fold, (train_part, test_part) in enumerate(tqdm.tqdm(splits, desc='Full-data CV')):
        trX, trY = train_part
        teX, teY = test_part

        feat = PrecomputedFeaturizerMixed(features)
        Xtr_t, Ytr_t = feat.fit_transform(trX, trY)
        Xte_t = feat.transform(teX)

        n_feats = Xtr_t.shape[1]
        models = [train_one_model(Xtr_t, Ytr_t, n_feats, s) for s in SEEDS]
        pred = ensemble_predict(models, Xte_t)
        true = teY.values

        squared_err = (pred - true) ** 2
        fold_mses.append(float(np.mean(squared_err)))

        for i, (p, t) in enumerate(zip(pred, true)):
            key = {"task": 1, "fold": fold, "row": i}
            preds_rows.append({**key, "target_1": p[0], "target_2": p[1], "target_3": p[2]})
            true_rows.append({**key, "true_1": t[0], "true_2": t[1], "true_3": t[2]})

    preds_df = pd.DataFrame(preds_rows)
    true_df = pd.DataFrame(true_rows)
    return preds_df, true_df, float(np.mean(fold_mses))

# -------------------
# MAIN
# -------------------
if __name__ == "__main__":
    # Report the hardware the run will use.
    print(f"Detected {NUM_GPUS} GPU(s): {[torch.cuda.get_device_name(i) for i in range(NUM_GPUS)]}")
    print(f"Using device: {DEVICE} | Batch size: {BATCH_SIZE}")

    # Run both cross-validation tasks with the same descriptor table.
    sub_single, true_single, single_cv = run_single_solvent_cv('spange_descriptors')
    sub_full, true_full, full_cv = run_full_ramp_cv('spange_descriptors')

    # Stack the per-task predictions and write them with a fresh "id" index.
    submission = pd.concat([sub_single, sub_full])
    submission = submission.reset_index(drop=True)
    submission.index.name = "id"
    submission.to_csv("submission.csv", index=True)

    # Re-join predictions with ground truth on the (task, fold, row) key.
    merged_true = pd.concat([true_single, true_full])
    merged_true = merged_true.reset_index(drop=True)
    merged = submission.merge(merged_true, on=['task', 'fold', 'row'])

    # Mean squared error over every row and all three targets.
    overall_mse = np.mean([
        (merged[f'target_{k}'] - merged[f'true_{k}'])**2
        for k in (1, 2, 3)
    ])

    banner = "=" * 72
    print("\n" + banner)
    print("FINAL RESULTS — Multi-GPU Upgraded CV")
    print(banner)
    print(f"Overall CV MSE : {overall_mse:.6f}")
    print(f"Overall CV RMSE: {overall_mse**0.5:.6f}")
    print(f"Single Solvent : {single_cv:.6f}")
    print(f"Full Data      : {full_cv:.6f}")
    print(banner)
    print("Saved: submission.csv")