<a href="https://colab.research.google.com/github/Y-Noor/JAX/blob/main/xgboost/jax_xgboost.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook shell escape: install the packages imported by the cells below.
!pip install jax jaxlib equinox pandas numpy scikit-learn matplotlib seaborn

Collecting equinox
  Downloading equinox-0.13.2-py3-none-any.whl.metadata (19 kB)
Collecting jaxtyping>=0.2.20 (from equinox)
  Downloading jaxtyping-0.3.4-py3-none-any.whl.metadata (7.8 kB)
Collecting wadler-lindig>=0.1.0 (from equinox)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading equinox-0.13.2-py3-none-any.whl (179 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 179.2/179.2 kB 7.3 MB/s eta 0:00:00
Downloading jaxtyping-0.3.4-py3-none-any.whl (56 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.0/56.0 kB 5.7 MB/s eta 0:00:00
Downloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping, equinox
Successfully installed equinox-0.13.2 jaxtyping-0.3.4 wadler-lindig-0.1.7


In [None]:
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap, jit
from functools import partial
import equinox as eqx
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

In [None]:
# ============================================================================
# PART 1: REFINED FEATURE ENGINEERING (Skewness + Outliers)
# ============================================================================

def engineer_features_v4(train_df, test_df):
    """
    Robust Feature Engineering for Leaderboard Stability.

    Train and test rows are processed as one frame so that imputation values,
    log-transform decisions and one-hot columns are identical for both.

    Steps:
    - Missing basement/fireplace/garage/pool grades become 'None' (rank 0);
      other categoricals are filled with their mode, numerics with their median.
    - Engineered features (TotalSF, OverallScore) are built from the raw columns.
    - Numeric features with |skew| > 0.75 (engineered ones included) are
      log1p-transformed.
    - Quality grades are mapped to ordinal ranks (Ex=5 ... Po=1, None=0).
    - Remaining categoricals are one-hot encoded.

    Args:
        train_df: training frame with 'Id' and 'SalePrice' columns.
        test_df: test frame with an 'Id' column.

    Returns:
        (X_train, X_test, y_train): feature frames sharing the same columns,
        and a copy of the (untransformed) training 'SalePrice'.
    """
    train_len = len(train_df)
    y_train = train_df['SalePrice'].copy()

    # Combine for consistent processing
    df = pd.concat([train_df.drop(['Id', 'SalePrice'], axis=1),
                    test_df.drop(['Id'], axis=1)], axis=0).reset_index(drop=True)

    qual_map = {'Ex': 5, 'Gd': 4, 'TA': 3, 'Fa': 2, 'Po': 1, 'None': 0}
    qual_cols = ['ExterQual', 'ExterCond', 'BsmtQual', 'BsmtCond', 'HeatingQC',
                 'KitchenQual', 'FireplaceQu', 'GarageQual', 'GarageCond', 'PoolQC']
    # Per the Ames data description, NA in these grades means the house has no
    # basement / fireplace / garage / pool. Encode that as 'None' (rank 0):
    # imputing the mode would hand, e.g., the grade of the few pools to every house.
    absent_if_missing = ['BsmtQual', 'BsmtCond', 'FireplaceQu',
                         'GarageQual', 'GarageCond', 'PoolQC']

    # 1. Fill Missing Values
    for col in absent_if_missing:
        if col in df.columns:
            df[col] = df[col].fillna('None')

    cat_cols = df.select_dtypes(include=['object']).columns
    for col in cat_cols:
        modes = df[col].mode()
        df[col] = df[col].fillna(modes[0] if not modes.empty else "None")

    num_cols = df.select_dtypes(exclude=['object']).columns
    for col in num_cols:
        df[col] = df[col].fillna(df[col].median())

    # 2. Engineered features. Built BEFORE the log-transform so that TotalSF is
    # a sum of areas rather than a sum of log-areas.
    df['TotalSF'] = df['TotalBsmtSF'] + df['1stFlrSF'] + df['2ndFlrSF']
    df['OverallScore'] = df['OverallQual'] * df['OverallCond']

    # 3. FIX SKEWNESS: log1p every numeric feature with |skew| > 0.75 so that
    # long-tailed features such as 'LotArea' or 'MiscVal' are compressed.
    # Quality grades are still strings here, so they are not transformed.
    num_cols = df.select_dtypes(exclude=['object']).columns
    skewness = df[num_cols].skew()
    for feat in skewness[skewness.abs() > 0.75].index:
        df[feat] = np.log1p(df[feat])

    # 4. Manual Ordinal Mapping (Preserving Quality Ranks); unknown labels -> 0.
    for col in qual_cols:
        if col in df.columns:
            df[col] = df[col].map(qual_map).fillna(0).astype(int)

    # 5. One-Hot Encoding
    df = pd.get_dummies(df)

    X_train = df.iloc[:train_len].copy()
    X_test = df.iloc[train_len:].copy()

    return X_train, X_test, y_train


In [None]:
# ============================================================================
# PART 2: THE JAX-XGBOOST ENGINE
# ============================================================================

class JAX_XGB_Tree(eqx.Module):
    """A complete binary regression tree of fixed depth, stored in level order.

    Internal node ``i`` has children ``2*i + 1`` (taken when ``x[f] <= t``)
    and ``2*i + 2`` (taken when ``x[f] > t``); the leaves follow the
    ``2**max_depth - 1`` internal nodes.
    """
    split_features: jnp.ndarray    # feature index per internal node
    split_thresholds: jnp.ndarray  # split threshold per internal node
    leaf_values: jnp.ndarray       # output value per leaf, left to right
    max_depth: int                 # number of split levels

    def predict(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return one leaf value per row of the 2-D feature matrix ``x``."""
        x = jnp.asarray(x)
        rows = jnp.arange(x.shape[0])
        node_idx = jnp.zeros(x.shape[0], dtype=jnp.int32)
        # Route every sample one level at a time with vectorised gathers.
        # (A fresh ``jit`` closure per call would be retraced and recompiled
        # on every call, i.e. once per tree per prediction.)
        for _ in range(self.max_depth):
            f = self.split_features[node_idx]
            t = self.split_thresholds[node_idx]
            go_right = x[rows, f] > t
            node_idx = 2 * node_idx + 1 + go_right.astype(jnp.int32)
        return self.leaf_values[node_idx - (2**self.max_depth - 1)]

class JAX_XGB_Model:
    """Squared-error gradient boosting with fixed-depth trees; split search runs in JAX.

    Every round grows one complete binary tree of depth ``max_depth`` on the
    gradients of the current predictions. Split candidates per feature are
    ``n_bins`` quantiles of the training data; split gains and leaf weights
    use the second-order (XGBoost-style) formulas with L2 penalty ``lambda_``.
    """

    def __init__(self, n_estimators=450, max_depth=3, learning_rate=0.025, lambda_=25.0, n_bins=64):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.lambda_ = lambda_
        self.n_bins = n_bins
        self.trees = []          # fitted JAX_XGB_Tree objects, in boosting order
        self.base_score = None   # mean training target; set by fit()

    @partial(jit, static_argnums=(0,))
    def _find_best_split(self, X, G, H, bins, mask):
        """Return (feature index, threshold) with the largest gain for the node ``mask``.

        Args:
            X: (n_samples, n_features) features.
            G, H: (n_samples,) gradients and hessians.
            bins: (n_features, n_bins) candidate thresholds; the left child is
                ``X[:, f] <= threshold``.
            mask: (n_samples,) bool, samples belonging to the node being split.

        ``self`` is a static jit argument, so ``lambda_`` is captured when the
        method is first traced for a given instance.
        """
        # Node totals do not depend on the candidate split.
        G_tot, H_tot = jnp.sum(G * mask), jnp.sum(H * mask)

        def term(g, h):
            return (g**2) / (h + self.lambda_)

        def get_gain(feat_idx, threshold):
            m_l = (X[:, feat_idx] <= threshold) & mask
            G_l, H_l = jnp.sum(G * m_l), jnp.sum(H * m_l)
            G_r, H_r = G_tot - G_l, H_tot - H_l
            gain = 0.5 * (term(G_l, H_l) + term(G_r, H_r) - term(G_tot, H_tot))
            # Empty node: every candidate gets the same sentinel gain.
            return jnp.where(jnp.sum(mask) > 0, gain, -1.0)

        v_bins = vmap(get_gain, in_axes=(None, 0))   # over the bins of one feature
        v_feats = vmap(v_bins, in_axes=(0, 0))       # over features (and their bin rows)
        gains = v_feats(jnp.arange(X.shape[1]), bins)
        best_idx = jnp.argmax(gains)
        f_idx, b_idx = jnp.unravel_index(best_idx, gains.shape)
        return f_idx, bins[f_idx, b_idx]

    def fit(self, X_np, y_np):
        """Fit ``n_estimators`` trees to (X_np, y_np), replacing any previous fit.

        Args:
            X_np: (n_samples, n_features) numeric feature matrix.
            y_np: (n_samples,) regression target.
        """
        X, y = jnp.array(X_np), jnp.array(y_np)
        n_samples = X.shape[0]
        n_internal = 2**self.max_depth - 1

        # Candidate thresholds: n_bins quantiles per feature -> (n_features, n_bins).
        self.bins = jnp.quantile(X, jnp.linspace(0, 1, self.n_bins), axis=0).T

        # Start from scratch so refitting does not keep trees from an earlier fit.
        self.trees = []
        self.base_score = jnp.mean(y)
        preds = jnp.full_like(y, self.base_score)

        print(f"Training JAX-XGBoost (Strong Regularization: lambda={self.lambda_})")
        for i in range(self.n_estimators):
            # Squared-error loss: gradient = residual, hessian = 1.
            G, H = preds - y, jnp.ones_like(y)
            s_f = np.zeros(n_internal, dtype=int)
            s_t = np.zeros(n_internal)
            l_v = np.zeros(n_internal + 1)

            # Grow level by level; the nodes of each level are visited left to
            # right, which is the level order JAX_XGB_Tree expects.
            masks = [jnp.ones(n_samples, dtype=bool)]
            node = 0
            for _ in range(self.max_depth):
                children = []
                for m in masks:
                    f, t = self._find_best_split(X, G, H, self.bins, m)
                    f, t = int(f), float(t)
                    s_f[node], s_t[node] = f, t
                    column = X[:, f]
                    children.append(m & (column <= t))
                    children.append(m & (column > t))
                    node += 1
                masks = children

            # Leaf weight -G/(H + lambda); empty leaves keep their 0.0.
            for j, m in enumerate(masks):
                if jnp.any(m):
                    l_v[j] = -jnp.sum(G[m]) / (jnp.sum(H[m]) + self.lambda_)

            tree = JAX_XGB_Tree(jnp.array(s_f), jnp.array(s_t), jnp.array(l_v), self.max_depth)
            preds += self.learning_rate * tree.predict(X)
            self.trees.append(tree)
            if (i+1) % 150 == 0:
                print(f"  Round {i+1} RMSE: {jnp.sqrt(jnp.mean((preds-y)**2)):.5f}")

    def predict(self, X_np):
        """Return ``base_score`` plus the learning-rate-scaled sum of all tree outputs.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.base_score is None:
            raise RuntimeError("JAX_XGB_Model.predict() called before fit().")
        X = jnp.array(X_np)
        p = jnp.full(X.shape[0], self.base_score)
        for tree in self.trees:
            p += self.learning_rate * tree.predict(X)
        return np.array(p)


In [None]:
# ============================================================================
# PART 3: MAIN EXECUTION
# ============================================================================

def main():
    """Train the K-fold CV ensemble and write the Kaggle submission.

    Reads 'train.csv' / 'test.csv', drops the GrLivArea outliers, engineers
    features, trains one JAX_XGB_Model per fold on log1p(SalePrice), reports
    per-fold and out-of-fold RMSE (log scale), and saves the fold-averaged
    test predictions to 'submission_jax_robust.csv'.
    """
    # 1. Load Data
    train = pd.read_csv('train.csv')
    test = pd.read_csv('test.csv')
    test_ids = test['Id']

    # AMES OUTLIER REMOVAL (Removes noisy high-leverage points)
    train = train.drop(train[(train['GrLivArea'] > 4000) & (train['SalePrice'] < 300000)].index)

    # 2. Engineer Features
    X_train, X_test, y_train = engineer_features_v4(train, test)
    y_log = np.log1p(y_train.values)

    # Standardize with the training statistics; the test set reuses them.
    mu, std = X_train.mean(), X_train.std() + 1e-7
    X_train_s, X_test_s = (X_train - mu) / std, (X_test - mu) / std

    # 3. Cross Validation
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    n_folds = kf.get_n_splits()
    oof_preds = np.zeros(len(y_log))
    test_preds_total = np.zeros(len(X_test))

    print(f"Training on {X_train_s.shape[1]} features...")
    for fold, (tr_idx, val_idx) in enumerate(kf.split(X_train_s)):
        X_tr, X_val = X_train_s.iloc[tr_idx].values, X_train_s.iloc[val_idx].values
        y_tr, y_val = y_log[tr_idx], y_log[val_idx]

        # PARAMETER TUNING FOR LEADERBOARD:
        # High lambda (25.0) and small learning rate (0.025)
        model = JAX_XGB_Model(n_estimators=500, max_depth=3, learning_rate=0.025, lambda_=25.0)
        model.fit(X_tr, y_tr)

        f_preds = model.predict(X_val)
        oof_preds[val_idx] = f_preds
        # Average the test predictions over however many folds KFold produces.
        test_preds_total += model.predict(X_test_s.values) / n_folds
        print(f"Fold {fold+1} Val RMSE: {np.sqrt(mean_squared_error(y_val, f_preds)):.5f}")

    # 4. Final Verification
    total_rmse = np.sqrt(mean_squared_error(y_log, oof_preds))
    print("\n" + "="*40)
    print(f"ROBUST OOF RMSE: {total_rmse:.5f}")
    print(f"R2 SCORE:        {r2_score(y_log, oof_preds):.5f}")
    print("="*40)

    # 5. Save Final Submission (predictions back on the dollar scale)
    pd.DataFrame({'Id': test_ids, 'SalePrice': np.expm1(test_preds_total)}).to_csv('submission_jax_robust.csv', index=False)
    print("Final file saved: submission_jax_robust.csv")

if __name__ == "__main__":
    main()

Training on 254 features...
Training JAX-XGBoost (Strong Regularization: lambda=25.0)
  Round 150 RMSE: 0.12788
  Round 300 RMSE: 0.10156
  Round 450 RMSE: 0.09282
Fold 1 Val RMSE: 0.12426
Training JAX-XGBoost (Strong Regularization: lambda=25.0)
  Round 150 RMSE: 0.12874
  Round 300 RMSE: 0.10302
  Round 450 RMSE: 0.09400
Fold 2 Val RMSE: 0.12456
Training JAX-XGBoost (Strong Regularization: lambda=25.0)
  Round 150 RMSE: 0.12643
  Round 300 RMSE: 0.10024
  Round 450 RMSE: 0.09114
Fold 3 Val RMSE: 0.13432
Training JAX-XGBoost (Strong Regularization: lambda=25.0)
  Round 150 RMSE: 0.12585
  Round 300 RMSE: 0.09982
  Round 450 RMSE: 0.09170
Fold 4 Val RMSE: 0.13641
Training JAX-XGBoost (Strong Regularization: lambda=25.0)
  Round 150 RMSE: 0.13111
  Round 300 RMSE: 0.10579
  Round 450 RMSE: 0.09605
Fold 5 Val RMSE: 0.10610

ROBUST OOF RMSE: 0.12559
R2 SCORE:        0.90120
Final file saved: submission_jax_robust.csv


In [None]:
# Out-Of-Fold results, filled in by main_with_evaluation() for the evaluation cell.
oof_preds = None
y_actual_log = None

def main_with_evaluation():
    """Run 5-fold CV and publish the out-of-fold predictions as module globals.

    On completion sets ``oof_preds`` (OOF predictions, log1p scale) and
    ``y_actual_log`` (log1p of the training SalePrice) so that a later cell
    can evaluate them.
    """
    global oof_preds, y_actual_log  # Make these accessible outside the function

    print("Loading and Engineering...")
    train = pd.read_csv('train.csv')
    test = pd.read_csv('test.csv')
    # Same outlier filter as main(), so the OOF scores describe the same training data.
    train = train.drop(train[(train['GrLivArea'] > 4000) & (train['SalePrice'] < 300000)].index)
    # Use the same feature pipeline as main().
    X_train, _, y_train = engineer_features_v4(train, test)
    y_log = np.log1p(y_train.values)

    mu, std = X_train.mean(), X_train.std() + 1e-7
    X_train_s = (X_train - mu) / std

    oof = np.zeros(len(y_log))
    kf = KFold(n_splits=5, shuffle=True, random_state=42)

    for fold, (tr_idx, val_idx) in enumerate(kf.split(X_train_s)):
        X_tr, X_val = X_train_s.iloc[tr_idx].values, X_train_s.iloc[val_idx].values

        # Train model
        model = JAX_XGB_Model(n_estimators=300, learning_rate=0.05, lambda_=5.0, n_bins=64)
        model.fit(X_tr, y_log[tr_idx])

        # Save validation predictions to the OOF array
        oof[val_idx] = model.predict(X_val)

        print(f"Fold {fold+1} complete.")

    # Publish only once every fold has filled its slice of the OOF array.
    oof_preds = oof
    y_actual_log = y_log
    print("\nTraining Finished! You can now run the evaluation cell.")

main_with_evaluation()

Loading and Engineering...


NameError: name 'engineer_features' is not defined

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

def evaluate_model(y_true_log, y_pred_log):
    """Print regression metrics for log1p-scale predictions and plot the errors.

    Args:
        y_true_log: true targets on the log1p scale.
        y_pred_log: predictions on the log1p scale, same shape as ``y_true_log``.

    Returns:
        dict with 'log_rmse' and 'r2' (log scale), 'mae' (dollars) and
        'mape' (percent).

    Raises:
        ValueError: if an input is None (e.g. the training cell did not run)
            or the shapes differ.
    """
    if y_true_log is None or y_pred_log is None:
        raise ValueError("evaluate_model needs both y_true_log and y_pred_log; "
                         "run the training cell that produces them first.")
    y_true_log = np.asarray(y_true_log, dtype=float)
    y_pred_log = np.asarray(y_pred_log, dtype=float)
    if y_true_log.shape != y_pred_log.shape:
        raise ValueError(f"Shape mismatch: y_true_log {y_true_log.shape} "
                         f"vs y_pred_log {y_pred_log.shape}")

    # Convert back to actual Dollars
    y_true = np.expm1(y_true_log)
    y_pred = np.expm1(y_pred_log)

    # Calculate Metrics (RMSE and R2 on the log scale; MAE and MAPE in dollars)
    log_rmse = np.sqrt(mean_squared_error(y_true_log, y_pred_log))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true_log, y_pred_log)
    # Mean Absolute Percentage Error
    mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100

    print("="*40)
    print("      FINAL MODEL PERFORMANCE")
    print("="*40)
    print(f"Kaggle Score (Log RMSE): {log_rmse:.5f}")
    print(f"R-Squared (Variance):    {r2:.5f}")
    print(f"Avg. Price Error (MAE):  ${mae:,.2f}")
    print(f"Avg. Percentage Error:   {mape:.2f}%")
    print("="*40)

    # Plot
    plt.figure(figsize=(12, 6))

    # Subplot 1: Actual vs Predicted
    plt.subplot(1, 2, 1)
    sns.regplot(x=y_true, y=y_pred, scatter_kws={'alpha':0.3}, line_kws={'color':'red'})
    plt.title('Actual vs Predicted Price')
    plt.xlabel('Actual Price ($)')
    plt.ylabel('Predicted Price ($)')

    # Subplot 2: Residuals
    plt.subplot(1, 2, 2)
    residuals = y_true - y_pred
    sns.histplot(residuals, kde=True, color='purple')
    plt.title('Distribution of Price Errors')
    plt.xlabel('Error ($)')

    plt.tight_layout()
    plt.show()

    return {'log_rmse': log_rmse, 'r2': r2, 'mae': mae, 'mape': mape}

# Run the evaluation using the Out-Of-Fold data set by main_with_evaluation()
metrics = evaluate_model(y_actual_log, oof_preds)

In [None]:
import pandas as pd
import numpy as np


# Single-house inference.
# main() keeps its model, feature columns and scaler in local variables, so they
# are not visible from this cell. Rebuild that state here by fitting one model on
# the full (outlier-filtered) training set with main()'s preprocessing and
# hyper-parameters (this costs roughly one CV fold of training time).


def _fit_inference_state(train_path='train.csv', test_path='test.csv'):
    """Fit one JAX_XGB_Model on every (outlier-filtered) training row.

    Returns:
        (model, train_columns, mu, std, raw_train, raw_test): the fitted model,
        the engineered feature columns it expects, the per-column scaler
        statistics, and the raw train/test frames needed to preprocess new
        houses with the same pipeline.
    """
    raw_train = pd.read_csv(train_path)
    raw_test = pd.read_csv(test_path)
    raw_train = raw_train.drop(
        raw_train[(raw_train['GrLivArea'] > 4000) & (raw_train['SalePrice'] < 300000)].index)

    X_train, _, y_train = engineer_features_v4(raw_train, raw_test)
    mu, std = X_train.mean(), X_train.std() + 1e-7
    model = JAX_XGB_Model(n_estimators=500, max_depth=3, learning_rate=0.025, lambda_=25.0)
    model.fit(((X_train - mu) / std).values, np.log1p(y_train.values))
    return model, X_train.columns, mu, std, raw_train, raw_test


fitted_model, train_columns, scaler_mu, scaler_std, raw_train, raw_test = _fit_inference_state()
# Ordinal grade map (engineer_features_v4 applies the same mapping internally).
qual_map_global = {'Ex': 5, 'Gd': 4, 'TA': 3, 'Fa': 2, 'Po': 1, 'None': 0}

# 1. ENTER YOUR HOUSE DETAILS HERE
new_house = {
    'MSSubClass': 60,
    'MSZoning': 'RL',
    'LotFrontage': 70.0,
    'LotArea': 9000,
    'Street': 'Pave',
    'LotShape': 'Reg',
    'LandContour': 'Lvl',
    'Utilities': 'AllPub',
    'LotConfig': 'Inside',
    'LandSlope': 'Gtl',
    'Neighborhood': 'CollgCr',
    'Condition1': 'Norm',
    'Condition2': 'Norm',
    'BldgType': '1Fam',
    'HouseStyle': '2Story',
    'OverallQual': 7,       # 1-10 Scale
    'OverallCond': 5,       # 1-10 Scale
    'YearBuilt': 2005,
    'YearRemodAdd': 2006,
    'RoofStyle': 'Gable',
    'RoofMatl': 'CompShg',
    'Exterior1st': 'VinylSd',
    'Exterior2nd': 'VinylSd',
    'MasVnrType': 'None',
    'MasVnrArea': 0.0,
    'ExterQual': 'Gd',      # Ex, Gd, TA, Fa, Po
    'ExterCond': 'TA',
    'Foundation': 'PConc',
    'BsmtQual': 'Gd',
    'BsmtCond': 'TA',
    'BsmtExposure': 'No',
    'BsmtFinType1': 'GLQ',
    'BsmtFinSF1': 700,
    'BsmtFinType2': 'Unf',
    'BsmtFinSF2': 0,
    'BsmtUnfSF': 300,
    'TotalBsmtSF': 1000,
    'Heating': 'GasA',
    'HeatingQC': 'Ex',
    'CentralAir': 'Y',
    'Electrical': 'SBrkr',
    '1stFlrSF': 1000,
    '2ndFlrSF': 1000,
    'LowQualFinSF': 0,
    'GrLivArea': 2000,
    'BsmtFullBath': 1,
    'BsmtHalfBath': 0,
    'FullBath': 2,
    'HalfBath': 1,
    'BedroomAbvGr': 3,
    'KitchenAbvGr': 1,
    'KitchenQual': 'Gd',
    'TotRmsAbvGrd': 8,
    'Functional': 'Typ',
    'Fireplaces': 1,
    'FireplaceQu': 'Gd',
    'GarageType': 'Attchd',
    'GarageYrBlt': 2005,
    'GarageFinish': 'RFn',
    'GarageCars': 2,
    'GarageArea': 500,
    'GarageQual': 'TA',
    'GarageCond': 'TA',
    'PavedDrive': 'Y',
    'WoodDeckSF': 0,
    'OpenPorchSF': 50,
    'EnclosedPorch': 0,
    '3SsnPorch': 0,
    'ScreenPorch': 0,
    'PoolArea': 0,
    'PoolQC': 'None',
    'Fence': 'None',
    'MiscFeature': 'None',
    'MiscVal': 0,
    'MoSold': 5,
    'YrSold': 2010,
    'SaleType': 'WD',
    'SaleCondition': 'Normal'
}

def predict_single_house(input_dict):
    """Return the predicted SalePrice (dollars) for the house described by ``input_dict``.

    The house goes through engineer_features_v4 exactly like the training
    rows, so imputation, log-transforms, ordinal mapping and one-hot encoding
    follow the training pipeline; columns absent from ``input_dict`` are
    imputed like missing training values.
    """
    house = pd.DataFrame([{**input_dict, 'Id': -1}])

    # Append the house to the raw test rows so the imputation and skewness
    # statistics come from (almost) the same data as during training; a single
    # extra row can only shift them marginally.
    _, X_rest, _ = engineer_features_v4(raw_train, pd.concat([raw_test, house], ignore_index=True))

    # The last row is the new house. Align to the training columns: dummies of
    # categories never seen in training are dropped, absent ones become 0.
    x = X_rest.iloc[[-1]].reindex(columns=train_columns, fill_value=0)

    # Scale with the training statistics, predict on the log1p scale, convert to dollars.
    x_scaled = (x - scaler_mu) / scaler_std
    log_price = fitted_model.predict(x_scaled.values)
    return np.expm1(log_price)[0]

# RUN INFERENCE
price = predict_single_house(new_house)

print("-" * 30)
print(f"PREDICTED HOUSE PRICE")
print("-" * 30)
print(f"Estimated Value: ${price:,.2f}")
print("-" * 30)