In [4]:
from datetime import datetime
from aeon.classification.convolution_based import RocketClassifier
from aeon.classification.feature_based import Catch22Classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression
from aeon.classification.deep_learning import InceptionTimeClassifier
from aeon.classification.deep_learning import ResNetClassifier

from f1_etl import (
    DataConfig,
    SessionConfig,
    create_safety_car_dataset,
)
from f1_etl.train import (
    ModelEvaluationSuite,
    create_metadata_from_f1_dataset,
    prepare_data_with_validation,
    create_model_metadata,
    train_and_validate_model,
    evaluate_on_external_dataset,
    compare_performance_across_datasets,
)

In [5]:
drivers=["1"]

# 1. Load dataset
data_config = DataConfig(
    sessions=[
        SessionConfig(2024, "Qatar Grand Prix", "R"),
    ],
    drivers=drivers,
    include_weather=False,
)

dataset = create_safety_car_dataset(
    config=data_config,
    window_size=50,
    prediction_horizon=100,
    normalize=True,
    target_column="TrackStatus",
    # resampling_strategy="adasyn",
)

# 2. Create metadata
dataset_metadata = create_metadata_from_f1_dataset(
    data_config=data_config,
    dataset=dataset,
    features_used="multivariate_all_9_features",
)

# 3. Prepare data
splits = prepare_data_with_validation(dataset, val_size=0.00, test_size=0.30)
class_names = list(dataset["label_encoder"].class_to_idx.keys())

# 5. Train model

# model_name = f"rocket_driver{"_".join(drivers)}"
# model = RocketClassifier(n_kernels=1000)

model_name = f"rocket_rf_driver{"_".join(drivers)}"
model = RocketClassifier(
    n_kernels=1000,
    estimator=RandomForestClassifier(
        n_estimators=100,
        random_state=42,
        # class_weight=cls_weight,
        max_depth=10,
    ),
)

# model_name = f"catch22_rf_driver{"_".join(drivers)}"
# model = Catch22Classifier(
#     estimator=RandomForestClassifier(
#         n_estimators=100,
#         random_state=42,
#         # class_weight=cls_weight,
#         max_depth=10,
#     ),
#     outlier_norm=True,
#     random_state=42,
# )

# model_name = f"catch22_logistic_driver{"_".join(drivers)}"
# model = Catch22Classifier(
#     estimator=LogisticRegression(
#         random_state=42,
#         max_iter=3000,
#         # solver='liblinear',
#         solver="saga",
#         penalty="l1",
#         C=0.1,
#         # class_weight=cls_weight,
#     ),
#     outlier_norm=True,
#     random_state=42,
# )

# TODO fix for binary classification
# model_name = f"catch22_linear_driver{"_".join(drivers)}"
# model = Catch22Classifier(
#     estimator=LinearRegression(),
#     outlier_norm=True,
#     random_state=42,
# )

# model_name = f"inceptiontime_driver{'_'.join(drivers)}"
# model = InceptionTimeClassifier(
#     n_classifiers=5,  # ensemble of 5 models
#     depth=6,  # network depth
#     n_filters=32,  # number of filters
#     n_epochs=100,
#     batch_size=16,
#     random_state=42,
#     verbose=False,
# )

# model_name = f"resnet_driver{'_'.join(drivers)}"
# model = ResNetClassifier(
#     n_residual_blocks=3,
#     n_filters=[128, 256, 128],
#     n_epochs=100,
#     batch_size=16,
#     random_state=42,
#     verbose=False
# )

run_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{model_name}"

evaluator = ModelEvaluationSuite(
    output_dir="evaluation_results",
    run_id=run_id,
)

model_metadata = create_model_metadata(
    model_name=model_name,
    model=model,
)

training_results = train_and_validate_model(
    model=model,
    splits=splits,
    class_names=class_names,
    evaluator=evaluator,
    dataset_metadata=dataset_metadata,
    model_metadata=model_metadata,
)

2025-07-11 20:44:36,691 - f1_etl - INFO - Preprocessing configuration:
2025-07-11 20:44:36,692 - f1_etl - INFO -   Missing values: enabled (forward_fill)
2025-07-11 20:44:36,692 - f1_etl - INFO -   Normalization: enabled (standard)
2025-07-11 20:44:36,693 - f1_etl - INFO -   Resampling: disabled
2025-07-11 20:44:36,693 - f1_etl - INFO - Driver configuration:
2025-07-11 20:44:36,694 - f1_etl - INFO -   Global drivers: ['1']
2025-07-11 20:44:36,695 - f1_etl - INFO -   Qatar Grand Prix: ['1']


Loading session: 2024 Qatar Grand Prix R


core           INFO 	Loading data for Qatar Grand Prix - Race [v3.5.3]
req            INFO 	Using cached data for session_info
req            INFO 	Using cached data for driver_info
req            INFO 	Using cached data for session_status_data
req            INFO 	Using cached data for lap_count
req            INFO 	Using cached data for track_status_data
req            INFO 	Using cached data for _extended_timing_data
req            INFO 	Using cached data for timing_app_data
core           INFO 	Processing timing data...
req            INFO 	Using cached data for car_data
req            INFO 	Using cached data for position_data
req            INFO 	Using cached data for weather_data
req            INFO 	Using cached data for race_control_messages
core           INFO 	Finished loading data for 20 drivers: ['1', '16', '81', '63', '10', '55', '14', '24', '20', '4', '77', '44', '22', '30', '23', '27', '11', '18', '43', '31']
2025-07-11 20:49:38,866 - f1_etl - INFO - Creating new fixed v


📊 Track Status Analysis (training_data):
   green       : 57940 samples ( 83.7%)
   safety_car  : 10503 samples ( 15.2%)
   vsc         :   146 samples (  0.2%)
   yellow      :   613 samples (  0.9%)
   Missing classes: [np.str_('red'), np.str_('unknown'), np.str_('vsc_ending')]
✅ FixedVocabTrackStatusEncoder fitted
   Classes seen: ['green', 'safety_car', 'vsc', 'yellow']
   Total classes: 7
   Output mode: integer labels


2025-07-11 20:49:39,399 - f1_etl - INFO - Total sequences generated: 2763
2025-07-11 20:49:39,403 - f1_etl - INFO - Generated 2763 sequences with shape (2763, 50, 9)
2025-07-11 20:49:39,404 - f1_etl - INFO - No missing values detected, skipping imputation
2025-07-11 20:49:39,404 - f1_etl - INFO - Applying normalization with method: standard
2025-07-11 20:49:39,415 - f1_etl - INFO - Final dataset summary:
2025-07-11 20:49:39,416 - f1_etl - INFO -   Sequences: 2763
2025-07-11 20:49:39,416 - f1_etl - INFO -   Features: 9
2025-07-11 20:49:39,417 - f1_etl - INFO -   Classes: 7 (integer)
2025-07-11 20:49:39,417 - f1_etl - INFO -   Label shape: (2763,)
2025-07-11 20:49:39,418 - f1_etl - INFO -     green       :  2312 samples ( 83.7%)
2025-07-11 20:49:39,418 - f1_etl - INFO -     safety_car  :   420 samples ( 15.2%)
2025-07-11 20:49:39,419 - f1_etl - INFO -     vsc         :     6 samples (  0.2%)
2025-07-11 20:49:39,419 - f1_etl - INFO -     yellow      :    25 samples (  0.9%)



=== DATA SPLIT SUMMARY ===
Total samples: 2,763
Train: 1,934 (70.0%)
Val:   None (skipped)
Test:  729 (26.4%) - removed 100 samples

Train class distribution:
  Class 0: 1,775 (91.8%)
  Class 2: 140 (7.2%)
  Class 6: 19 (1.0%)

Test class distribution:
  Class 0: 500 (68.6%)
  Class 2: 217 (29.8%)
  Class 4: 6 (0.8%)
  Class 6: 6 (0.8%)

TRAINING WITH TEST: rocket_rf_driver1
Training on train set...

Validation set not available (val_size=0.0)

Running full evaluation on test set...

EVALUATING: ROCKET_RF_DRIVER1
Evaluation ID: rocket_rf_driver1_20250711_204939_rocket_rf_driver1
Training model...
Generating predictions...

📊 OVERALL PERFORMANCE
Accuracy:    0.7599
F1-Macro:    0.3153
F1-Weighted: 0.7058

🎯 TARGET CLASS ANALYSIS: SAFETY_CAR
Precision:       0.9344
Recall:          0.2627
F1-Score:        0.4101
True Positives:    57
False Negatives:  160 (missed safety_car events)
False Positives:    4 (false safety_car alarms)
True Negatives:   508

📈 PER-CLASS PERFORMANCE
green      

If you need to write the model to disk:

In [7]:
import pickle
import joblib
from pathlib import Path

# Save model
model_path = Path("models/my_model.pkl")
model_path.parent.mkdir(exist_ok=True)

# with open(model_path, 'wb') as f:
#     pickle.dump(training_results['model'], f)

# Or using joblib (often better for sklearn-based models)
joblib.dump(training_results['model'], "models/my_model.joblib")

print(f"Model saved to {model_path}")

Model saved to models/my_model.pkl


Load the model back into memory:

In [1]:
import pickle
import joblib
from pathlib import Path

# Load with error handling
def load_model_safely(model_path):
    try:
        if model_path.suffix == '.joblib':
            return joblib.load(model_path)
        else:
            with open(model_path, 'rb') as f:
                return pickle.load(f)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Load your model
model_path = Path("models/my_model.joblib")  # or .pkl
your_model = load_model_safely(model_path)

if your_model is not None:
    print("Model loaded successfully")
    # Verify it works
    print(f"Model type: {type(your_model)}")
else:
    print("Failed to load model")

Model loaded successfully
Model type: <class 'aeon.classification.convolution_based._rocket.RocketClassifier'>


Check the model resource usage:

In [3]:
import sys
import pickle
from pympler import asizeof

# Memory footprint of loaded model
model_memory = asizeof.asizeof(your_model)
print(f"Model memory: {model_memory / 1024**2:.2f} MB")

# Serialized size (storage footprint)
serialized_size = len(pickle.dumps(your_model))
print(f"Serialized size: {serialized_size / 1024**2:.2f} MB")

Model memory: 0.26 MB
Serialized size: 1.01 MB


In [6]:
# 6. Evaluate on external dataset
external_config = DataConfig(
    sessions=[
        SessionConfig(2024, "Canadian Grand Prix", "R"),
        SessionConfig(2024, "Saudi Arabian Grand Prix", "R"),
    ],
    drivers=drivers,  # Use same drivers as training
    include_weather=False,
)

external_dataset = create_safety_car_dataset(
    config=external_config,
    window_size=50,
    prediction_horizon=100,
    normalize=True,
    target_column="TrackStatus",
)

2025-07-11 20:56:57,305 - f1_etl - INFO - Preprocessing configuration:
2025-07-11 20:56:57,306 - f1_etl - INFO -   Missing values: enabled (forward_fill)
2025-07-11 20:56:57,307 - f1_etl - INFO -   Normalization: enabled (standard)
2025-07-11 20:56:57,308 - f1_etl - INFO -   Resampling: disabled
2025-07-11 20:56:57,308 - f1_etl - INFO - Driver configuration:
2025-07-11 20:56:57,309 - f1_etl - INFO -   Global drivers: ['1']
2025-07-11 20:56:57,309 - f1_etl - INFO -   Canadian Grand Prix: ['1']
2025-07-11 20:56:57,309 - f1_etl - INFO -   Saudi Arabian Grand Prix: ['1']
core           INFO 	Loading data for Canadian Grand Prix - Race [v3.5.3]
req            INFO 	Using cached data for session_info
req            INFO 	Using cached data for driver_info


Loading session: 2024 Canadian Grand Prix R


req            INFO 	Using cached data for session_status_data
req            INFO 	Using cached data for lap_count
req            INFO 	Using cached data for track_status_data
req            INFO 	Using cached data for _extended_timing_data
req            INFO 	Using cached data for timing_app_data
core           INFO 	Processing timing data...
req            INFO 	Using cached data for car_data
req            INFO 	Using cached data for position_data
req            INFO 	Using cached data for weather_data
req            INFO 	Using cached data for race_control_messages
core           INFO 	Finished loading data for 20 drivers: ['1', '4', '63', '44', '81', '14', '18', '3', '10', '31', '27', '20', '77', '22', '24', '55', '23', '11', '16', '2']
core           INFO 	Loading data for Saudi Arabian Grand Prix - Race [v3.5.3]
req            INFO 	Using cached data for session_info
req            INFO 	Using cached data for driver_info


Loading session: 2024 Saudi Arabian Grand Prix R


req            INFO 	Using cached data for session_status_data
req            INFO 	Using cached data for lap_count
req            INFO 	Using cached data for track_status_data
req            INFO 	Using cached data for _extended_timing_data
req            INFO 	Using cached data for timing_app_data
core           INFO 	Processing timing data...
req            INFO 	Using cached data for car_data
req            INFO 	Using cached data for position_data
req            INFO 	Using cached data for weather_data
req            INFO 	Using cached data for race_control_messages
core           INFO 	Finished loading data for 20 drivers: ['1', '11', '16', '81', '14', '63', '38', '4', '44', '27', '23', '20', '31', '2', '22', '3', '77', '24', '18', '10']
2025-07-11 20:57:01,478 - f1_etl - INFO - Creating new fixed vocabulary encoder
2025-07-11 20:57:01,509 - f1_etl - INFO - Processing 117259 total telemetry rows
2025-07-11 20:57:01,509 - f1_etl - INFO - Grouping by: ['SessionId', 'Driver']



📊 Track Status Analysis (training_data):
   green       : 101246 samples ( 86.3%)
   safety_car  : 11292 samples (  9.6%)
   yellow      :  4721 samples (  4.0%)
   Missing classes: [np.str_('red'), np.str_('unknown'), np.str_('vsc'), np.str_('vsc_ending')]
✅ FixedVocabTrackStatusEncoder fitted
   Classes seen: ['green', 'safety_car', 'yellow']
   Total classes: 7
   Output mode: integer labels


2025-07-11 20:57:02,406 - f1_etl - INFO - Total sequences generated: 4679
2025-07-11 20:57:02,415 - f1_etl - INFO - Generated 4679 sequences with shape (4679, 50, 9)
2025-07-11 20:57:02,417 - f1_etl - INFO - No missing values detected, skipping imputation
2025-07-11 20:57:02,417 - f1_etl - INFO - Applying normalization with method: standard
2025-07-11 20:57:02,438 - f1_etl - INFO - Final dataset summary:
2025-07-11 20:57:02,439 - f1_etl - INFO -   Sequences: 4679
2025-07-11 20:57:02,439 - f1_etl - INFO -   Features: 9
2025-07-11 20:57:02,439 - f1_etl - INFO -   Classes: 7 (integer)
2025-07-11 20:57:02,440 - f1_etl - INFO -   Label shape: (4679,)
2025-07-11 20:57:02,440 - f1_etl - INFO -     green       :  4046 samples ( 86.5%)
2025-07-11 20:57:02,440 - f1_etl - INFO -     safety_car  :   450 samples (  9.6%)
2025-07-11 20:57:02,440 - f1_etl - INFO -     yellow      :   183 samples (  3.9%)


In [7]:
model = training_results['model']

In [8]:
external_dataset.keys()

dict_keys(['X', 'y', 'y_raw', 'metadata', 'label_encoder', 'feature_engineer', 'raw_telemetry', 'class_distribution', 'all_classes', 'n_classes', 'config'])

In [30]:
meta_0

{'start_time': Timestamp('2024-06-09 17:59:04.557000'),
 'end_time': Timestamp('2024-06-09 17:59:10.652000'),
 'prediction_time': Timestamp('2024-06-09 17:59:23.417000'),
 'sequence_length': 50,
 'prediction_horizon': 100,
 'features_used': ['Speed',
  'RPM',
  'nGear',
  'Throttle',
  'Brake',
  'X',
  'Y',
  'Distance',
  'DifferentialDistance'],
 'target_column': 'TrackStatus',
 'SessionId': '2024_Canadian Grand Prix_R',
 'Driver': '1'}

In [28]:
raw = external_dataset['raw_telemetry']
raw.shape

(117259, 24)

In [29]:
X_t = external_dataset['X'].transpose(0, 2, 1)

X_t.shape

(4679, 9, 50)

In [16]:
X_0.shape

(9, 50)

In [19]:
y_pred = model.predict(X_t)

In [23]:
print(y_pred.shape)
print(type(y_pred))

(4679,)
<class 'numpy.ndarray'>


In [27]:
X_t.shape

(4679, 9, 50)

In [25]:
import pandas as pd

y_pred_df = pd.DataFrame(y_pred)
y_pred_df.value_counts()


0
0    4361
2     318
Name: count, dtype: int64

In [26]:
y = dataset['y']

y_df = pd.DataFrame(y)
y_df.value_counts()

0
0    2312
2     420
6      25
4       6
Name: count, dtype: int64