# Train

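This notebook trains a LightGBM multiclass classifier on the 3W dataset to predict the well `state` from the process tags, then logs and registers the model with MLflow and validates its serving input.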

In [0]:
%pip install uv

In [0]:
%sh uv pip install ../.

In [0]:
from hydrate.utils import DotConfig, get_config_path

# Locate the project config file and wrap it in DotConfig; later cells read
# config.catalog and config.schema as attributes.
config_path = get_config_path()
config = DotConfig(config_path)

In [0]:
from databricks.connect import DatabricksSession as SparkSession

# Databricks Connect session, exposed under the conventional name `spark`.
session_builder = SparkSession.builder
spark = session_builder.getOrCreate()

We focus on five wells (19, 29, 31, 28, 25) where hydrate predictions have been observed, which gives us about 2.6 million observations for a first ML pass. We train a single global LightGBM model across all five wells, with `well_number` included as an input feature.

In [0]:
# Wells selected for this first pass (see the markdown note above).
WELL_NUMBERS = [19, 29, 31, 28, 25]

df = spark.table(f"{config.catalog}.{config.schema}.well_data_c")

# Spark does not guarantee row order for toPandas(). Sort explicitly so that
# ml_df, and therefore the seeded train_test_split below, is reproducible
# across runs. Assumes (well_number, timestamp) identifies a row -- TODO confirm.
ml_df = (df
  .filter(df.well_number.isin(WELL_NUMBERS))
  .orderBy('well_number', 'timestamp')
  .toPandas()
)

In [0]:
# ml_df holds every row for the selected wells, so show its size and a
# short preview instead of rendering the whole frame.
print(f"ml_df shape: {ml_df.shape}")
ml_df.head()

In [0]:
# Column names of the collected frame (DataFrame.keys() returns the same
# column Index as ml_df.columns).
ml_df.keys()

In [0]:
# Column roles used by the rest of the notebook.
indices = ['timestamp', 'well_number']  # identifier columns (not referenced below)
tags = ['P-PDG', 'P-TPT', 'T-TPT', 'P-MON-CKP', 'T-JUS-CKP', 'QGL']  # tag columns used as model inputs
target = ['state']

# Model inputs: the tags plus the well id. well_number enters the model as a
# plain numeric column -- NOTE(review): consider a categorical encoding.
feature_columns = tags + ['well_number']
X_df = ml_df[feature_columns]

# Kept as a one-column DataFrame: later cells access y_df.state / y_test.state.
y_df = ml_df[target]

In [0]:
import pandas as pd
from tsfresh import extract_features
from tsfresh.feature_extraction import MinimalFCParameters
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score

# Frame prepared for TSFresh: tags plus target, sorted by well then time.
# reset_index(drop=False) keeps the original ml_df row label as an 'index'
# column (the commented-out extraction in the next cell refers to it).
# NOTE(review): 'state' is the target; drop it before any feature extraction.
ts_columns = ['well_number', 'timestamp'] + tags + ['state']
ts_df = (
    ml_df[ts_columns]
    .copy()
    .sort_values(['well_number', 'timestamp'])
    .reset_index(drop=False)
)

print(f"Data shape for feature extraction: {ts_df.shape}")
print(f"Wells: {ts_df['well_number'].nunique()}")
print(f"Time range: {ts_df['timestamp'].min()} to {ts_df['timestamp'].max()}")

In [0]:
# TODO: Later
# Extract time series features using TSFresh with minimal feature set
# Use minimal parameters for speed - we can expand later
# NOTE(review): the extraction call below is disabled. Running this cell only
# builds the settings object. Before enabling the call:
#   - column_id='index' uses the row label kept by reset_index(drop=False) in
#     the previous cell. That label is presumably unique per row, so every row
#     would become its own "series". The previous cell's comment says
#     well_number should be the id (or a windowed id derived from it).
#   - ts_df still contains the 'state' target column. Drop it before
#     extraction so the label does not leak into the features.
feature_extraction_settings = MinimalFCParameters()

# Extract features for each tag
# print("Extracting time series features...")
# features = extract_features(
#     ts_df,
#     column_id='index',
#     column_sort='timestamp',
#     default_fc_parameters=feature_extraction_settings,
#     disable_progressbar=False
# )

# print(f"Extracted features shape: {features.shape}")
# print(f"Number of features per tag: {features.shape[1] // len(tags)}")

In [0]:
import mlflow
import mlflow.lightgbm

# Turn on MLflow autologging for LightGBM so the fit() call further down is
# tracked without explicit logging calls.
mlflow.lightgbm.autolog()

In [0]:
# Split the data. Features are the raw tag values. The TSFresh path is
# disabled, so there is no imputation step yet.
# TODO: once TSFresh features exist, impute them (e.g. X = features.fillna(0)).
# stratify=y_df approximately preserves the `state` class proportions in both
# splits.
X_train, X_test, y_train, y_test = train_test_split(
    X_df, y_df, test_size=0.2, random_state=42, stratify=y_df
)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")
# y_train / y_test are one-column DataFrames: list the distinct class labels
# in each split (nunique() on the frame would only give a count).
print(f"Training classes: {sorted(y_train['state'].unique())}")
print(f"Test classes: {sorted(y_test['state'].unique())}")

# LightGBM multiclass classifier.
# Row/column subsampling uses the scikit-learn API parameter names:
#   subsample        == bagging_fraction
#   subsample_freq   == bagging_freq
#   colsample_bytree == feature_fraction
# Passing the native aliases would leave them next to the wrapper's own
# defaults (subsample=1.0, subsample_freq=0, colsample_bytree=1.0). LightGBM
# would then have to resolve the conflict, and both sets end up in get_params()
# and in the autologged params.
lgb_model = lgb.LGBMClassifier(
    objective='multiclass',
    metric='multi_logloss',
    boosting_type='gbdt',
    num_leaves=31,
    learning_rate=0.1,
    colsample_bytree=0.8,
    subsample=0.8,
    subsample_freq=5,
    verbose=-1,
    random_state=42,
)

In [0]:
# y_train is a one-column DataFrame. Pass its 'state' Series so the classifier
# gets the 1-D target scikit-learn expects, which avoids the column-vector
# DataConversionWarning. The fitted model is the same.
print("\nTraining LightGBM model...")
lgb_model.fit(X_train, y_train['state'])
print("Model training completed!")

In [0]:
# Preview the training features rather than rendering every training row.
X_train.head()

In [0]:
# Sanity check: predicted classes for the training rows at positions 5-8.
# .iloc makes the positional slice explicit (same rows as X_train[5:9]).
lgb_model.predict(X_train.iloc[5:9])

In [0]:
# Evaluate model performance on the held-out test set
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

y_pred = lgb_model.predict(X_test)
y_pred_proba = lgb_model.predict_proba(X_test)

# Metrics expect a 1-D target; y_test is a one-column DataFrame.
y_true = y_test['state']
# Class labels in the column order of predict_proba. Passing them explicitly
# keeps the report names and AUC columns aligned even if a class is absent
# from y_true or y_pred.
class_labels = lgb_model.classes_

print("=== Model Performance ===")
print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}")
print("\nClassification Report:")
print(classification_report(
    y_true,
    y_pred,
    labels=class_labels,
    target_names=[f"State_{i}" for i in class_labels],
))

# Multiclass ROC-AUC (one-vs-rest), weighted by class support.
# roc_auc_score is imported in the earlier imports cell.
try:
    auc_score = roc_auc_score(
        y_true,
        y_pred_proba,
        multi_class='ovr',
        average='weighted',
        labels=class_labels,
    )
    print(f"\nWeighted ROC-AUC Score (OvR): {auc_score:.4f}")
except ValueError as err:
    # e.g. a class missing from the test split. Report the reason instead of
    # silently swallowing every possible error.
    print(f"\nROC-AUC not calculated: {err}")

# Feature importance, paired with the columns the model was trained on
feature_importance = pd.DataFrame({
    'feature': X_train.columns,
    'importance': lgb_model.feature_importances_
}).sort_values('importance', ascending=False)

print("\n=== Top 10 Most Important Features ===")
print(feature_importance.head(10))

# Readable names for the state codes (used by the sample-predictions table in
# the next cell).
print("\n=== Sample Predictions ===")
state_names = {
    0: 'Normal', 1: 'Abrupt Increase of BSW', 2: 'Spurious Closure of DHSV',
    3: 'Severe Slugging', 4: 'Flow Instability', 5: 'Rapid Productivity Loss',
    6: 'Quick Restriction in PCK', 7: 'Scaling in PCK', 8: 'Hydrate in Production Line'
}

# Five random test rows, later logged as the MLflow input example. The seed
# makes the logged example reproducible.
input_example = X_test.sample(5, random_state=42)
print(input_example)

In [0]:
# Side-by-side view of the first test rows: true vs predicted state, the
# readable state names, and the model's top class probability.
# Every column is built positionally. y_test keeps ml_df's row labels, while
# y_pred / y_pred_proba are plain arrays. Mixing Series with different indexes
# (or passing the one-column y_test DataFrame) makes pd.DataFrame raise.
n_rows = 10
actual = y_test['state'].iloc[:n_rows].to_numpy()
predicted = y_pred[:n_rows]
sample_results = pd.DataFrame({
    'actual': actual,
    'predicted': predicted,
    'actual_name': pd.Series(actual).map(state_names),
    'predicted_name': pd.Series(predicted).map(state_names),
    'max_probability': np.max(y_pred_proba[:n_rows], axis=1),
})
sample_results

## Log the Model



In [0]:
# Log the trained classifier with the sampled input example and register it.
# NOTE(review): the registered name looks like catalog.schema.model. Confirm
# it should be hard-coded rather than built from config.catalog / config.schema
# like the source table.
registered_name = "shm.3w.lightgbm"
model_info = mlflow.lightgbm.log_model(
    lgb_model=lgb_model,
    artifact_path="model",
    input_example=input_example,
    registered_model_name=registered_name,
)

In [0]:
# URI of the logged model. The following cells pass it to mlflow.models.predict
# and to the serving-input validation.
model_info.model_uri

In [0]:
import mlflow

# Smoke test: score the saved input example with the logged model.
# env_manager="uv" asks MLflow to prepare the model's environment with uv.
predict_kwargs = dict(
    model_uri=model_info.model_uri,
    input_data=input_example,
    env_manager="uv",
)
mlflow.models.predict(**predict_kwargs)

In [0]:
from mlflow.models import validate_serving_input
from mlflow.models.utils import load_serving_example

# Load the serving example recorded for the logged model, then check it is
# accepted by the model's serving-input validation.
logged_uri = model_info.model_uri
serving_example = load_serving_example(logged_uri)
result = validate_serving_input(logged_uri, serving_example)
print(f"Validation result: {result}")

In [0]:
# Display the serving payload that was validated in the previous cell.
serving_example