In [1]:
import warnings

warnings.filterwarnings(
    "ignore",
    message="Computing tree-based bins involves the conversion of the input PyTorch tensors to NumPy arrays.*"
)

In [2]:
import os
import json
import time
import math
import random
from copy import deepcopy
from typing import NamedTuple

import optuna
import pandas as pd
import numpy as np

from scipy.stats import skew, kurtosis
from sklearn.preprocessing import StandardScaler, QuantileTransformer, PowerTransformer
from sklearn.metrics import mean_absolute_percentage_error
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor

# import rtdl_num_embeddings
from  rtdl_num_embeddings import compute_bins, PiecewiseLinearEmbeddings
import scipy.special
import tabm

In [3]:
def compute_volume_weighted_component_features(X):
    """
    Build the 50 volume-weighted component features.

    For every component j in 1..5 and property k in 1..10 the feature
    ``WjPk`` equals ``Componentj_fraction * Componentj_Propertyk``.

    Parameters
    ----------
    X : pandas.DataFrame
        Must contain the ``Component{j}_fraction`` and
        ``Component{j}_Property{k}`` columns.

    Returns
    -------
    pandas.DataFrame
        One column per ``WjPk``, ordered by component then property,
        holding the row-wise products.
    """
    weighted = {
        f'W{j}P{k}': X[f'Component{j}_fraction'] * X[f'Component{j}_Property{k}']
        for j in range(1, 6)
        for k in range(1, 11)
    }
    return pd.DataFrame(weighted)

In [4]:
# All data/model locations hang off one root directory. It defaults to the original
# scratch layout and can be redirected via the TABM_ROOT_DIR environment variable
# instead of editing several hardcoded absolute paths.
ROOT_DIR = os.environ.get("TABM_ROOT_DIR", "/pscratch/sd/r/ritesh11/temp_dir")
model_dir = os.path.join(ROOT_DIR, "TabM_models")        # best_params_{target}.json output
TARGETS = [f"BlendProperty{i}" for i in range(1, 11)]      # one Optuna study per target
BASE_PATH = os.path.join(ROOT_DIR, "dataset")             # {train,val}/{target}_{X,y}.csv
fi_path = os.path.join(ROOT_DIR, "feature_importance")    # {target}.csv importance tables
N_TRIALS = 100                                            # Optuna trials per target

In [5]:
def get_data(target, selective=False, skew_thresh=0.5, kurt_thresh=3.5, random_state=None):
    """
    Load, feature-engineer and scale the train/val split for one target.

    Reads ``{BASE_PATH}/{train,val}/{target}_{X,y}.csv`` and appends the
    volume-weighted component features. When ``selective`` is set, it keeps
    only the columns with importance > 0.01 in ``{fi_path}/{target}.csv``.
    Every non-fraction feature is then scaled, with the scaler fitted on
    train only:

    * QuantileTransformer(output_distribution='normal') when
      ``|skew| > skew_thresh`` or ``kurtosis > kurt_thresh``;
    * StandardScaler otherwise.

    Columns whose name contains "fraction" are passed through unscaled.
    The target is Yeo-Johnson power-transformed when
    ``|skew(y_train)| > skew_thresh`` and standardised otherwise.

    Parameters
    ----------
    target : str
        Target name used in the CSV file names (e.g. ``"BlendProperty1"``).
    selective : bool
        Apply the feature-importance column filter.
    skew_thresh, kurt_thresh : float
        Thresholds that select the quantile transform. ``skew_thresh`` also
        selects the target power transform.
    random_state : int, numpy.random.Generator or None
        Seed for the tiny jitter added to zero-variance training columns.
        ``None`` keeps the previous non-deterministic behaviour.

    Returns
    -------
    tuple
        ``(X_train_scaled, y_train_transformed, X_val_scaled,
        y_val_transformed, y_transformer)``. The X parts are DataFrames, the
        y parts are 1-D NumPy arrays, and ``y_transformer`` is the fitted
        target transformer (its ``inverse_transform`` returns original units).
    """
    # Local generator: reproducible when seeded, and independent of the global RNG.
    rng = np.random.default_rng(random_state)

    # Load train and val sets
    X_train = pd.read_csv(f"{BASE_PATH}/train/{target}_X.csv")
    y_train = pd.read_csv(f"{BASE_PATH}/train/{target}_y.csv")
    X_val = pd.read_csv(f"{BASE_PATH}/val/{target}_X.csv")
    y_val = pd.read_csv(f"{BASE_PATH}/val/{target}_y.csv")

    # Feature engineering: append the WjPk volume-weighted features
    X_train = pd.concat([X_train, compute_volume_weighted_component_features(X_train)], axis=1)
    X_val = pd.concat([X_val, compute_volume_weighted_component_features(X_val)], axis=1)

    # Feature selection: the first column of the importance file holds feature names
    if selective:
        importance = pd.read_csv(os.path.join(fi_path, f"{target}.csv"))
        cols = importance[importance["importance"] > 0.01].iloc[:, 0].tolist()
        X_train = X_train[cols]
        X_val = X_val[cols]

    # Fraction columns are left unscaled
    fraction_cols = [col for col in X_train.columns if "fraction" in col.lower()]
    non_fraction_cols = [col for col in X_train.columns if col not in fraction_cols]

    # Per-column distribution statistics on the training split
    feature_stats = pd.DataFrame({
        'skewness': X_train[non_fraction_cols].apply(skew, nan_policy='omit'),
        'kurtosis': X_train[non_fraction_cols].apply(kurtosis, nan_policy='omit'),
        'std': X_train[non_fraction_cols].std(),
    })

    X_train_scaled = X_train.copy()
    X_val_scaled = X_val.copy()

    for col in non_fraction_cols:
        # Explicit writable float copy. Under pandas Copy-on-Write `.values` may be
        # a read-only view, and an integer buffer cannot absorb float noise in place.
        col_vals = X_train[[col]].to_numpy(dtype=np.float64, copy=True)

        if feature_stats.loc[col, 'std'] == 0.0:
            # Tiny jitter so a constant training column has non-zero spread.
            col_vals = col_vals + rng.normal(0.0, 1e-5, size=col_vals.shape)

        use_quantile = (
            abs(feature_stats.loc[col, 'skewness']) > skew_thresh or
            feature_stats.loc[col, 'kurtosis'] > kurt_thresh
        )

        if use_quantile:
            scaler = QuantileTransformer(
                # roughly one quantile per 30 rows, clamped to [20, 1000]
                n_quantiles=max(min(len(col_vals) // 30, 1000), 20),
                output_distribution='normal',
                subsample=10**9,  # large enough to effectively disable subsampling
            )
        else:
            scaler = StandardScaler()

        X_train_scaled[col] = scaler.fit_transform(col_vals).ravel()
        X_val_scaled[col] = scaler.transform(X_val[[col]].to_numpy(dtype=np.float64)).ravel()

    # Target transform: Yeo-Johnson for skewed targets, plain standardisation otherwise
    y_train_2d = y_train.to_numpy().reshape(-1, 1)
    y_val_2d = y_val.to_numpy().reshape(-1, 1)
    if abs(skew(y_train_2d.ravel())) > skew_thresh:
        y_transformer = PowerTransformer(method="yeo-johnson")
    else:
        y_transformer = StandardScaler()
    y_train_transformed = y_transformer.fit_transform(y_train_2d).ravel()
    y_val_transformed = y_transformer.transform(y_val_2d).ravel()

    return X_train_scaled, y_train_transformed, X_val_scaled, y_val_transformed, y_transformer


In [6]:
# Device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Mixed-precision dtype: bfloat16 on GPUs that support it, float16 on other
# GPUs, and no autocast at all on CPU.
if not torch.cuda.is_available():
    amp_dtype = None
elif torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
else:
    amp_dtype = torch.float16

amp_enabled = amp_dtype is not None

# A GradScaler is only created for float16 autocast.
# `torch.cuda.amp.GradScaler()` is deprecated in recent PyTorch in favour of
# `torch.amp.GradScaler("cuda")`; fall back to the old spelling on older releases.
if amp_dtype is torch.float16:
    grad_scaler = (
        torch.amp.GradScaler("cuda")  # type: ignore
        if hasattr(torch.amp, "GradScaler")
        else torch.cuda.amp.GradScaler()  # type: ignore
    )
else:
    grad_scaler = None

# torch.compile
compile_model = False

# fmt: off
print(f'Device:        {device.type.upper()}')
print(f'AMP:           {amp_enabled}{f" ({amp_dtype})"if amp_enabled else ""}')
print(f'torch.compile: {compile_model}')
# fmt: on

Device:        CUDA
AMP:           True (torch.bfloat16)
torch.compile: False


In [7]:
def get_model(model_params, X_train, Y_train=None):
    """
    Build a TabM model with piecewise-linear numeric embeddings and an AdamW optimizer.

    Parameters
    ----------
    model_params : dict
        Hyper-parameters. Required keys: ``bins_emb``, ``d_emb``, ``act_emb``,
        ``d_out``, ``n_blocks``, ``d_block``, ``dropout``, ``activation``,
        ``arch_type``, ``start_scaling_init``, ``lr``, ``decay``.
        Optional keys: ``tree_binning`` (default False), ``tree_kwargs``
        (default ``{'min_samples_leaf': 64}``), ``k`` (ensemble size, default 32).
    X_train : torch.Tensor
        Training features, (n_samples, n_features). Used to compute the
        embedding bins and the number of numeric features.
    Y_train : torch.Tensor or array-like, optional
        Training targets. Required only when ``tree_binning`` is True, because
        tree-based bins are fitted against the target.

    Returns
    -------
    (model, optimizer)
        The model moved to the global ``device`` and an AdamW optimizer over
        its parameters (``lr`` / ``decay`` from ``model_params``).

    Raises
    ------
    ValueError
        If ``tree_binning`` is requested without ``Y_train``.
    """
    if model_params.get('tree_binning', False):
        # The targets must be passed in explicitly rather than read from a global.
        if Y_train is None:
            raise ValueError("tree_binning=True requires Y_train to be passed to get_model()")
        # Tree-based binning converts to NumPy, so hand compute_bins CPU tensors.
        bins = compute_bins(
            X_train.cpu(),
            n_bins=model_params['bins_emb'],
            y=torch.as_tensor(Y_train).cpu(),
            regression=True,
            tree_kwargs=model_params.get('tree_kwargs', {'min_samples_leaf': 64}),
            verbose=False,
        )
    else:
        # Target-free bins (presumably quantile-based in rtdl_num_embeddings).
        bins = compute_bins(X_train, n_bins=model_params['bins_emb'])

    num_embeddings = PiecewiseLinearEmbeddings(
        bins=bins,
        d_embedding=model_params['d_emb'],
        activation=model_params['act_emb'],
        version='B',
    )

    model = tabm.TabM.make(
        n_num_features=X_train.shape[1],
        d_out=model_params["d_out"],
        num_embeddings=num_embeddings,
        n_blocks=model_params["n_blocks"],
        d_block=model_params["d_block"],
        dropout=model_params["dropout"],
        activation=model_params["activation"],
        k=model_params.get("k", 32),  # number of ensemble members
        arch_type=model_params["arch_type"],
        start_scaling_init=model_params["start_scaling_init"],
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=model_params['lr'], weight_decay=model_params['decay']
    )
    return model, optimizer

In [8]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(model, data: Tensor, idx: Tensor) -> Tensor:
    """
    Run ``model`` on the rows ``data[idx]`` under the notebook's autocast settings.

    ``None`` is passed as the model's second positional argument. This is
    presumably the categorical-feature slot, unused here; confirm against
    the tabm API. The trailing output axis is squeezed away and the result
    is cast to float32, so losses and metrics are computed in full precision.
    """
    batch = data[idx]
    out = model(batch, None)
    out = out.squeeze(-1)  # regression with d_out == 1: drop the size-1 output axis
    return out.float()

In [9]:
# True: all ensemble members train on the same mini-batch (1-D batch indices).
# False: each member gets its own permutation (2-D (batch, k) indices, see `train`).
share_training_batches = True

def loss_fn(y_pred: Tensor, y_true: Tensor, model, delta: float = 1.0) -> Tensor:
    """
    Huber loss averaged over every ensemble member's prediction.

    The first two axes of ``y_pred`` (presumably batch and ensemble member)
    are flattened together. With shared batches ``y_true`` is 1-D, so each
    target is repeated ``model.backbone.k`` times to line up with that
    flattening. Otherwise ``y_true`` is already indexed per member and is
    flattened the same way.
    """
    flat_pred = y_pred.flatten(0, 1)
    flat_true = (
        y_true.repeat_interleave(model.backbone.k)
        if share_training_batches
        else y_true.flatten(0, 1)
    )
    return F.huber_loss(flat_pred, flat_true, delta=delta)

In [10]:
@torch.inference_mode()
def evaluate(x_data: Tensor, y_data: Tensor, model, y_scaler=None, eval_batch_size: int = 64) -> float:
    """
    Mean absolute percentage error of the ensemble-averaged prediction.

    The model is run in eval mode on ``x_data`` in chunks of
    ``eval_batch_size`` rows. Predictions are averaged over axis 1
    (presumably the k ensemble members) and, when ``y_scaler`` is given,
    mapped back with ``y_scaler.inverse_transform``.

    Parameters
    ----------
    x_data : Tensor
        Features, (n_samples, n_features).
    y_data : Tensor
        1-D targets. When ``y_scaler`` is given these must already be in the
        ORIGINAL (untransformed) units, because only the predictions are
        inverse-transformed here.
    model
        Model evaluated through ``apply_model``.
    y_scaler : optional
        Fitted target transformer (e.g. the one returned by ``get_data``).
    eval_batch_size : int
        Rows per forward pass.

    Returns
    -------
    float
        sklearn ``mean_absolute_percentage_error`` of the predictions.
    """
    model.eval()
    # Build the index chunks on the same device as the data they index.
    row_chunks = torch.arange(len(x_data), device=x_data.device).split(eval_batch_size)
    y_pred: np.ndarray = (
        torch.cat([apply_model(model, x_data, idx) for idx in row_chunks])
        .cpu()
        .numpy()
    )
    y_pred = y_pred.mean(1)

    if y_scaler is not None:
        y_pred = y_scaler.inverse_transform(y_pred.reshape(-1, 1)).ravel()

    y_true = y_data.cpu().numpy()

    score = mean_absolute_percentage_error(y_true, y_pred)
    return float(score)

In [11]:
def train(X_train, Y_train, X_val, Y_val, Y_scaler, model, optimizer, batch_size, delta,
          patience=100, n_epochs=2000):
    """
    Train ``model`` with early stopping on validation MAPE, then restore the best weights.

    Each epoch proceeds as follows:

    * Iterate over mini-batches of ``batch_size`` rows. Batches use one shared
      permutation when ``share_training_batches`` is True, and an independent
      permutation per ensemble member otherwise.
    * Minimise the Huber loss with threshold ``delta``.
    * Score the validation split with ``evaluate``.

    Training stops after ``patience + 1`` consecutive non-improving epochs or
    after ``n_epochs`` epochs, whichever comes first.

    Parameters
    ----------
    X_train, Y_train : Tensor
        Training features / targets (targets in the model's output space).
    X_val, Y_val : Tensor
        Validation features / targets, in the form ``evaluate`` expects.
    Y_scaler
        Forwarded to ``evaluate`` to inverse-transform predictions.
    model, optimizer
        Model (updated in place) and its optimizer.
    batch_size : int
        Rows per training mini-batch.
    delta : float
        Huber loss threshold.
    patience : int
        Number of extra non-improving epochs tolerated before stopping.
    n_epochs : int
        Hard cap on the number of epochs (2000 by default, the previous fixed value).

    Returns
    -------
    float
        Best (lowest) validation MAPE. ``model`` is left holding the matching weights.
    """
    device = X_train.device
    train_size = X_train.shape[0]

    best_val = float('inf')  # minimising MAPE
    best_state = deepcopy(model.state_dict())
    remaining_patience = patience

    for _ in range(n_epochs):
        if share_training_batches:
            batches = torch.randperm(train_size, device=device).split(batch_size)
        else:
            # One independent row permutation per ensemble member, split by rows.
            batches = (
                torch.rand((train_size, model.backbone.k), device=device)
                .argsort(dim=0)
                .split(batch_size, dim=0)
            )

        # evaluate() switches the model to eval mode at the end of every epoch,
        # so re-enabling training mode once per epoch is sufficient.
        model.train()
        for batch_idx in batches:
            optimizer.zero_grad()
            loss = loss_fn(apply_model(model, X_train, batch_idx), Y_train[batch_idx], model, delta)

            if grad_scaler is None:
                loss.backward()
                optimizer.step()
            else:
                # Only set up for float16 autocast: scale the loss before backward.
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

        val_metric = evaluate(X_val, Y_val, model, Y_scaler)  # lower is better

        if val_metric < best_val:
            best_val = val_metric
            best_state = deepcopy(model.state_dict())
            remaining_patience = patience
        else:
            remaining_patience -= 1

        if remaining_patience < 0:
            break

    model.load_state_dict(best_state)
    return best_val

In [12]:
def objective(trial, target):
    """
    Optuna objective: sample TabM hyper-parameters, train, return the best validation MAPE.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample the hyper-parameters below.
    target : str
        Name of the blend property to model (e.g. ``"BlendProperty1"``).

    Returns
    -------
    float
        Best validation MAPE reached during training, measured in the
        target's original units.
    """
    model_params = {
        "lr": trial.suggest_float("lr", 1e-4, 1e-2, log=True),
        'selective': trial.suggest_categorical("selective", [True, False]),
        "decay": trial.suggest_float("decay", 1e-6, 0.5, log=True),
        "bins_emb": trial.suggest_int("bins_emb", 16, 128),
        "d_emb": trial.suggest_categorical("d_emb", [8, 16, 32, 64]),
        "batch_size": trial.suggest_categorical("batch_size", [8, 16, 32]),
        "act_emb": trial.suggest_categorical("act_emb", [True, False]),
        "n_blocks": trial.suggest_int("n_blocks", 2, 6),
        "d_block": trial.suggest_categorical("d_block", [256, 512, 1024, 2048]),
        "dropout": trial.suggest_float("dropout", 0.0, 0.5),
        "activation" : trial.suggest_categorical("activation", ["ReLU", "GELU", "SiLU","ELU"]),
        "arch_type": "tabm",
        "start_scaling_init": trial.suggest_categorical("start_scaling_init", ["normal","random-signs"]),
        "d_out": 1,
        "tree_binning": False,
        "huber_delta" : trial.suggest_float("huber_delta", 0.1, 3.0, log=True)
    }

    X_train, Y_train, X_val, Y_val, Y_scaler = get_data(target, selective=model_params['selective'])

    # get_data returns *transformed* targets, but evaluate() inverse-transforms
    # only the predictions. Map the validation targets back to original units so
    # MAPE compares like with like. Standardised targets are centred on 0, where
    # MAPE is meaningless.
    Y_val = Y_scaler.inverse_transform(np.asarray(Y_val).reshape(-1, 1)).ravel()

    X_train = torch.tensor(X_train.values, dtype=torch.float32, device=device)
    Y_train = torch.tensor(Y_train, dtype=torch.float32, device=device)
    X_val   = torch.tensor(X_val.values,   dtype=torch.float32, device=device)
    Y_val   = torch.tensor(Y_val,   dtype=torch.float32, device=device)

    model, optimizer = get_model(model_params, X_train)
    best_val = train(
        X_train, Y_train,
        X_val, Y_val, Y_scaler,
        model, optimizer,
        batch_size=model_params['batch_size'],
        delta=model_params['huber_delta'],
        patience=100,
    )

    return best_val

In [13]:
# Show only Optuna warnings and errors; per-trial INFO lines would flood the notebook.
optuna.logging.set_verbosity(verbosity=optuna.logging.WARNING)

In [None]:
# One Optuna study per target. n_jobs=16 runs trials concurrently, so they
# share the GPU and the module-level globals used by get_model/train.
for target in TARGETS:
    study = optuna.create_study(direction="minimize")
    study.optimize(
        lambda trial, _target=target: objective(trial, _target),
        n_trials=N_TRIALS,
        n_jobs=16,
        show_progress_bar=True,
    )

    complete_params = dict(study.best_params)
    print(f"\nBest MAPE for {target}: {study.best_value:.4f}")
    print(f"Best params for {target}:\n{complete_params}\n")

    # Persist only the tuned hyper-parameters (model training is skipped here for now).
    os.makedirs(model_dir, exist_ok=True)
    params_path = os.path.join(model_dir, f"best_params_{target}.json")
    with open(params_path, "w") as f:
        json.dump(complete_params, f, indent=2)

  0%|          | 0/100 [00:00<?, ?it/s]