In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics import accuracy_score
import pickle
import xgboost as xgb

# Read the training set (prices offered and the resulting purchase decisions).
TRAIN_CSV_PATH = '/Users/man/Desktop/project-part-2-ok-computer-main/data/train_prices_decisions_2024.csv'
data = pd.read_csv(TRAIN_CSV_PATH)

# Recast the purchase flag as integers so the classifier is trained on 0/1
# labels (presumably the column is boolean in the CSV: True -> 1, False -> 0).
data['item_bought'] = data['item_bought'].astype(int)

# Model inputs are the three covariates plus the offered price; the label is
# the integer purchase flag.
FEATURE_COLUMNS = ['Covariate1', 'Covariate2', 'Covariate3', 'price_item']
X = data[FEATURE_COLUMNS]
y = data['item_bought']

# Expand the inputs with every degree-2 term (squares and pairwise products)
# so the model can pick up non-linear interactions; no constant column is added.
poly = PolynomialFeatures(include_bias=False, degree=2)
X_poly = poly.fit_transform(X)

# Hold out 20% of the rows for validation; the fixed seed makes the split
# repeatable across runs.
split_parts = train_test_split(X_poly, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = split_parts

# Base classifier: the tuned hyperparameters come from the grid below, the
# rest stay at whatever XGBClassifier() uses by default.
xgb_model = xgb.XGBClassifier()

# Candidate values per hyperparameter: 3 * 3 * 3 * 2 * 2 = 108 combinations.
param_grid = dict(
    n_estimators=[50, 100, 200],
    learning_rate=[0.01, 0.1, 0.2],
    max_depth=[3, 5, 7],
    subsample=[0.8, 1.0],
    colsample_bytree=[0.8, 1.0],
)

# Exhaustive search over the grid, ranked by accuracy under 3-fold
# cross-validation, with parallel workers (n_jobs=-1) and per-fit logging.
grid_search = GridSearchCV(
    estimator=xgb_model,
    param_grid=param_grid,
    scoring='accuracy',
    cv=3,
    n_jobs=-1,
    verbose=2,
)
grid_search.fit(X_train, y_train)

# Report the winning hyperparameter combination and keep the estimator the
# search exposes as its best one.
print(f"Best Parameters: {grid_search.best_params_}")
best_model = grid_search.best_estimator_

# Score the selected model on the held-out validation rows.
y_pred = best_model.predict(X_val)
accuracy = accuracy_score(y_true=y_val, y_pred=y_pred)
print("Validation Accuracy:", accuracy)

# Persist the classifier together with the fitted polynomial transformer in a
# single pickle, so whoever loads it can apply the same feature expansion
# before calling predict.
# NOTE(review): the training CSV is read from a 'project-part-2' directory but
# this artifact is written under 'project-part-1' — confirm that is intended.
MODEL_PKL_PATH = '/Users/man/Desktop/project-part-1-ok-computer-main/agents/ok-computer/trained_model.pkl'
model_bundle = {'model': best_model, 'poly': poly}
with open(MODEL_PKL_PATH, mode='wb') as f:
    pickle.dump(model_bundle, file=f)

print("Best XGBoost model and transformer saved.")

Fitting 3 folds for each of 108 candidates, totalling 324 fits
Best Parameters: {'colsample_bytree': 1.0, 'learning_rate': 0.2, 'max_depth': 3, 'n_estimators': 200, 'subsample': 0.8}
Validation Accuracy: 0.9857
Best XGBoost model and transformer saved.
