In [None]:
import numpy as np
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from embedded_tabular_data import prepare_tabular_dataset

# A random forest is invariant to the ordering (permutation) of the feature columns: each split
# tests a single feature, so it never multiplies or pools several neighbouring features to detect
# a local relationship (as a CNN or a parametric sliding window would).

In [19]:
def train_rf(X_train, y_train, X_val, y_val,
             n_estimators=100, max_depth=30, max_features='sqrt', random_state=42,
             save_model_path=None, verbose=2):
    """
    Fit a random forest on (X_train, y_train) and evaluate it on (X_val, y_val).

    Prints the overall accuracy, a "Q3" score and the H/E/L confusion matrix,
    and optionally dumps the fitted model with joblib.

    Parameters
    ----------
    X_train, y_train : training features and labels, passed unchanged to
        ``RandomForestClassifier.fit``.
    X_val, y_val : validation features and labels. ``y_val`` may be any
        array-like (list, ndarray, Series); it is converted to an ndarray
        before masking.
    n_estimators, max_depth, max_features, random_state : forwarded to
        ``RandomForestClassifier``.
    save_model_path : if not None, path the fitted model is written to with
        ``joblib.dump``.
    verbose : verbosity forwarded to ``RandomForestClassifier`` (default 2,
        which logs every tree being built; pass 0 to keep the output quiet).

    Returns
    -------
    (rf, acc, q3)
        rf  : the fitted ``RandomForestClassifier``.
        acc : accuracy over *all* validation samples, whatever their label.
        q3  : fraction of exact matches restricted to the validation samples
              whose true label is 'H', 'E' or 'L'. Samples with any other
              true label are ignored (no state reduction is applied), so q3
              and acc differ whenever such labels are present. NaN when no
              validation sample carries one of the three labels.
    """
    # Build the model; n_jobs=-1 uses every available core.
    rf = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        max_features=max_features,
        n_jobs=-1,
        random_state=random_state,
        verbose=verbose,
    )

    print("Fitting Random Forest...")
    rf.fit(X_train, y_train)

    # Coerce both label vectors to ndarrays: boolean-mask indexing below is not
    # supported by plain Python lists and would otherwise fail after the fit.
    y_true = np.asarray(y_val)
    y_pred = np.asarray(rf.predict(X_val))

    # Accuracy over every validation sample.
    acc = accuracy_score(y_true, y_pred)

    # Q3: exact-match rate on the samples whose true label is one of H/E/L.
    classes = ['H', 'E', 'L']
    mask = np.isin(y_true, classes)
    if mask.any():
        q3 = np.mean(y_true[mask] == y_pred[mask])
    else:
        # No H/E/L sample to score: report NaN explicitly rather than relying
        # on the mean of an empty array (which also emits a RuntimeWarning).
        q3 = np.nan

    print(f"Accuracy: {acc:.4f}")
    print(f"Q3: {q3:.4f}")

    # Confusion matrix restricted to the three classes, rows = true labels.
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    print("Confusion matrix (H, E, L):")
    print(cm)

    # Persist the fitted model when a destination was given.
    if save_model_path is not None:
        joblib.dump(rf, save_model_path)
        print(f"Modèle sauvegardé dans {save_model_path}")

    return rf, acc, q3
 

In [8]:
# Build the training features and labels from ../matches_subset_dssp.json.
# NOTE(review): the meaning of the second argument (10) is defined by
# prepare_tabular_dataset and is not visible here - confirm and name it.
X_train, y_train = prepare_tabular_dataset("../matches_subset_dssp.json", 10)

In [9]:
# Build the validation features and labels from ../validation_matches_subset_dssp.json,
# using the same second argument (10) as the training split.
X_val, y_val = prepare_tabular_dataset("../validation_matches_subset_dssp.json", 10)

In [20]:
# Train with the default hyper-parameters, evaluate on the validation split,
# and write the fitted forest to rf_model.joblib.
model, acc, q3 = train_rf(
    X_train,
    y_train,
    X_val,
    y_val,
    save_model_path="rf_model.joblib",
)

Fitting Random Forest...


[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 4 of 100building tree 2 of 100
building tree 5 of 100
building tree 6 of 100
building tree 1 of 100
building tree 8 of 100

building tree 3 of 100
building tree 7 of 100
building tree 9 of 100
building tree 10 of 100
building tree 11 of 100building tree 12 of 100

building tree 13 of 100
building tree 14 of 100
building tree 15 of 100
building tree 16 of 100
building tree 17 of 100
building tree 18 of 100building tree 19 of 100

building tree 20 of 100
building tree 21 of 100
building tree 22 of 100
building tree 23 of 100
building tree 24 of 100
building tree 25 of 100
building tree 26 of 100
building tree 27 of 100
building tree 28 of 100
building tree 29 of 100
building tree 30 of 100
building tree 31 of 100
building tree 32 of 100
building tree 33 of 100


[Parallel(n_jobs=-1)]: Done  25 tasks      | elapsed:  2.5min


building tree 34 of 100
building tree 35 of 100
building tree 36 of 100
building tree 37 of 100
building tree 38 of 100
building tree 39 of 100
building tree 40 of 100
building tree 41 of 100
building tree 42 of 100
building tree 43 of 100
building tree 44 of 100
building tree 45 of 100
building tree 46 of 100
building tree 47 of 100
building tree 48 of 100
building tree 49 of 100
building tree 50 of 100
building tree 51 of 100
building tree 52 of 100
building tree 53 of 100
building tree 54 of 100
building tree 55 of 100
building tree 56 of 100
building tree 57 of 100
building tree 58 of 100
building tree 59 of 100
building tree 60 of 100
building tree 61 of 100
building tree 62 of 100
building tree 63 of 100
building tree 64 of 100
building tree 65 of 100building tree 66 of 100

building tree 67 of 100
building tree 68 of 100
building tree 69 of 100
building tree 70 of 100
building tree 71 of 100
building tree 72 of 100
building tree 73 of 100
building tree 74 of 100
building tree 75

[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  8.1min finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    1.5s
[Parallel(n_jobs=8)]: Done 100 out of 100 | elapsed:    5.8s finished


Accuracy: 0.5302
Q3: 0.6517
Confusion matrix (H, E, L):
[[11923  1362  1439]
 [ 2501  4836  1625]
 [ 2610  1605  6874]]
Modèle sauvegardé dans rf_model.joblib


In [15]:
# Number of entries in X_train (presumably one per training sample - confirm).
len(X_train)

2270581