In [13]:
import optuna
import pandas as pd
import numpy as np

from sklearn.metrics import mean_absolute_percentage_error
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler

import os

import json
import time

import math
import random
from copy import deepcopy
from typing import Any, Literal, NamedTuple

import numpy as np
import rtdl_num_embeddings  # https://github.com/yandex-research/rtdl-num-embeddings
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import tabm
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor

In [14]:
def compute_volume_weighted_component_features(X):
    """
    Computes individual volume-weighted features WjPk = Componentj_fraction * Componentj_Propertyk
    for j in 1..5 and k in 1..10 (total 50 features).
    """
    features = {}
    for comp_idx in range(1, 6):  # Components 1–5
        for prop_idx in range(1, 11):  # Properties 1–10
            vol_col = f'Component{comp_idx}_fraction'
            prop_col = f'Component{comp_idx}_Property{prop_idx}'
            feat_name = f'W{comp_idx}P{prop_idx}'
            features[feat_name] = X[vol_col] * X[prop_col]
    return pd.DataFrame(features)

In [15]:
# Load data
train_df = pd.read_csv("/pscratch/sd/r/ritesh11/temp_dir/dataset/train.csv")
test_df = pd.read_csv("/pscratch/sd/r/ritesh11/temp_dir/dataset/test.csv")

# Separate features and targets
feature_cols = train_df.columns[:55]
target_cols = train_df.columns[55:]

X = train_df[feature_cols].copy()
y = train_df[target_cols].copy()

In [16]:
model_dir = "/pscratch/sd/r/ritesh11/temp_dir/NN_models"

In [17]:
targets = [f"BlendProperty{i}" for i in range(1, 11)]
target = targets[2]

In [18]:
kmeans = KMeans(n_clusters=10, random_state=42)
clusters = kmeans.fit_predict(y[[target]])

# Stratified sampling on clusters
sss = StratifiedShuffleSplit(n_splits=1, test_size=100, random_state=42)
train_idx, val_idx = next(sss.split(X, clusters))

In [19]:
X_train = X.iloc[train_idx]
y_train = y.iloc[train_idx][[target]]
X_val = X.iloc[val_idx]
y_val  = y.iloc[val_idx][[target]]

In [20]:
scaler = StandardScaler()

In [21]:
blend_features = compute_volume_weighted_component_features(X_train)
X_train = pd.concat([X_train, blend_features], axis=1)
# X_train = pd.DataFrame(
#         scaler.fit_transform(X_train),
#         columns=X_train.columns,
#         index=X_train.index
#     )

blend_features = compute_volume_weighted_component_features(X_val)
X_val = pd.concat([X_val, blend_features], axis=1)
# X_val = pd.DataFrame(
#         scaler.transform(X_val),
#         columns=X_val.columns,
#         index=X_val.index
#     )

In [22]:
# Device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

Y_train = torch.as_tensor(y_train.values, device=device, dtype=torch.float32)
Y_val = torch.as_tensor(y_val.values, device=device, dtype=torch.float32)

# 1. Convert DataFrames to NumPy
X_train_np = X_train.values
X_val_np = X_val.values

# 2. Add tiny Gaussian noise to training set to avoid constant features
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_train_np.shape)
    .astype(X_train_np.dtype)
)

# 3. Fit QuantileTransformer on X_train + noise
quantile_preproc = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(X_train_np) // 30, 1000), 10),
    output_distribution='normal',
    subsample=10**9,
)
quantile_preproc.fit(X_train_np + noise)

# 4. Transform both X_train and X_val
X_train_transformed = quantile_preproc.transform(X_train_np)
X_val_transformed = quantile_preproc.transform(X_val_np)

# 5. Convert back to PyTorch tensors
X_train_tensor = torch.tensor(X_train_transformed, dtype=torch.float32, device=device)
X_val_tensor = torch.tensor(X_val_transformed, dtype=torch.float32, device=device)

# (Optional) replace existing vars
X_train = X_train_tensor
X_val = X_val_tensor

# Label preprocessing.
class RegressionLabelStats(NamedTuple):
    mean: float
    std: float

regression_label_stats = RegressionLabelStats(
    Y_train.mean().item(), Y_train.std().item()
)

Y_train = (Y_train - regression_label_stats.mean) / regression_label_stats.std
Y_val = (Y_val - regression_label_stats.mean) / regression_label_stats.std

amp_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
    if torch.cuda.is_available()
    else None
)


amp_enabled = amp_dtype is not None
grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore

# torch.compile
compile_model = True

# fmt: off
print(f'Device:        {device.type.upper()}')
print(f'AMP:           {amp_enabled}{f" ({amp_dtype})"if amp_enabled else ""}')
print(f'torch.compile: {compile_model}')
# fmt: on

Device:        CUDA
AMP:           True (torch.bfloat16)
torch.compile: True


In [23]:
# Piecewise-linear embeddings.
num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
    rtdl_num_embeddings.compute_bins(X_train, n_bins=48),
    d_embedding=16,
    activation=False,
    version='B',
)

In [24]:
model = tabm.TabM.make(
    n_num_features=X_train.shape[1],
    d_out=1,
    num_embeddings=num_embeddings,
    # n_blocks = 4,
    # k=64,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=3e-4)

if compile_model:
    # NOTE
    # `torch.compile(model, mode="reduce-overhead")` caused issues during training,
    # so the `mode` argument is not used.
    model = torch.compile(model)
    evaluation_mode = torch.no_grad
else:
    evaluation_mode = torch.inference_mode

In [25]:
model

OptimizedModule(
  (_orig_mod): TabM(
    (num_module): PiecewiseLinearEmbeddings(
      (linear0): LinearEmbeddings()
      (impl): _PiecewiseLinearEncodingImpl()
      (linear): _NLinear()
    )
    (ensemble_view): EnsembleView()
    (backbone): MLPBackboneBatchEnsemble(
      (blocks): ModuleList(
        (0-1): 2 x Sequential(
          (0): LinearBatchEnsemble()
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (output): LinearEnsemble()
  )
)

In [26]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(data: Tensor, idx: Tensor) -> Tensor:
    return (
        model(
            data[idx],
            None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )

In [27]:
share_training_batches = True

def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
    
    y_pred = y_pred.flatten(0, 1)

    if share_training_batches:
        # (batch_size,) -> (batch_size * k,)
        y_true = y_true.repeat_interleave(model.backbone.k)
    else:
        # (batch_size, k) -> (batch_size * k,)
        y_true = y_true.flatten(0, 1)

    return F.mse_loss(y_pred, y_true)

In [28]:
@evaluation_mode()
def evaluate(x_data: Tensor, y_data: Tensor) -> float:
    model.eval()

    eval_batch_size = 64
    y_pred: np.ndarray = (
        torch.cat(
            [
                apply_model(x_data, idx)
                for idx in torch.arange(len(x_data), device=device).split(eval_batch_size)
            ]
        )
        .cpu()
        .numpy()
    )

    assert regression_label_stats is not None
    y_pred = y_pred * regression_label_stats.std + regression_label_stats.mean

    y_pred = y_pred.mean(1) 

    y_true = y_data.cpu().numpy()

    score = -mean_absolute_percentage_error(y_true, y_pred) 
    return float(score)

In [30]:
n_epochs = 1_000_000_000
train_size = X_train.shape[0]
batch_size = 64
epoch_size = math.ceil(train_size / batch_size)

epoch = -1
metrics = {'val': -math.inf, 'test': -math.inf}


def make_checkpoint() -> dict[str, Any]:
    return deepcopy(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'metrics': metrics,
        }
    )


best_checkpoint = make_checkpoint()

# Early stopping: the training stops if the validation score
# does not improve for more than `patience` consecutive epochs.
patience = 200
remaining_patience = patience

for epoch in range(n_epochs):
    batches = (
        # Create one standard batch sequence.
        torch.randperm(train_size, device=device).split(batch_size)
        if share_training_batches
        # Create k independent batch sequences.
        else (
            torch.rand((train_size, model.backbone.k), device=device)
            .argsort(dim=0)
            .split(batch_size, dim=0)
        )
    )
    for batch_idx in batches:
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model(X_train, batch_idx), Y_train[batch_idx])
        if grad_scaler is None:
            loss.backward()
            optimizer.step()
        else:
            grad_scaler.scale(loss).backward()  # type: ignore
            grad_scaler.step(optimizer)
            grad_scaler.update()

    val_metric = evaluate(X_val, Y_val)
    val_score_improved = val_metric > best_checkpoint['metrics']['val']
    
    metrics["val"] = val_metric  # <-- add this line
    
    print(
        f'{"*" if val_score_improved else " "}'
        f' [epoch] {epoch:<3}'
        f' [val] {metrics["val"]:.3f}'
    )

    print(
        f'{"*" if val_score_improved else " "}'
        f' [epoch] {epoch:<3}'
        f' [val] {metrics["val"]:.3f}'
    )

    if val_score_improved:
        best_checkpoint = make_checkpoint()
        remaining_patience = patience
    else:
        remaining_patience -= 1

    if remaining_patience < 0:
        break

# To make final predictions, load the best checkpoint.
model.load_state_dict(best_checkpoint['model'])

* [epoch] 0   [val] -1.424
* [epoch] 0   [val] -1.424
* [epoch] 1   [val] -1.318
* [epoch] 1   [val] -1.318
* [epoch] 2   [val] -1.199
* [epoch] 2   [val] -1.199
* [epoch] 3   [val] -1.105
* [epoch] 3   [val] -1.105
* [epoch] 4   [val] -1.055
* [epoch] 4   [val] -1.055
* [epoch] 5   [val] -0.985
* [epoch] 5   [val] -0.985
* [epoch] 6   [val] -0.932
* [epoch] 6   [val] -0.932
* [epoch] 7   [val] -0.901
* [epoch] 7   [val] -0.901
* [epoch] 8   [val] -0.893
* [epoch] 8   [val] -0.893
* [epoch] 9   [val] -0.845
* [epoch] 9   [val] -0.845
* [epoch] 10  [val] -0.802
* [epoch] 10  [val] -0.802
* [epoch] 11  [val] -0.781
* [epoch] 11  [val] -0.781
* [epoch] 12  [val] -0.767
* [epoch] 12  [val] -0.767
* [epoch] 13  [val] -0.724
* [epoch] 13  [val] -0.724
* [epoch] 14  [val] -0.719
* [epoch] 14  [val] -0.719
* [epoch] 15  [val] -0.699
* [epoch] 15  [val] -0.699
* [epoch] 16  [val] -0.690
* [epoch] 16  [val] -0.690
* [epoch] 17  [val] -0.668
* [epoch] 17  [val] -0.668
* [epoch] 18  [val] -0.626
*

KeyboardInterrupt: 