## Ne fonctionne pas avec: 
- 'MTGNN'
- 'DCGRU'
- 'LSTM'
- 'RNN'
- 'CNN'
- 'GRU'

In [None]:
import os
import sys
import torch
import numpy as np
import random
import pandas as pd  
import warnings
warnings.simplefilter("error", category=RuntimeWarning)

# === FIXE LA SEED POUR REPRODUCTIBILITÉ ===
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# === CHEMIN ET IMPORTS ===
current_path = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_path, '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from examples.train_and_visu_non_recurrent import evaluate_config,train_the_config,get_ds
from examples.benchmark import local_get_args
from high_level_DL_method import load_optimizer_and_scheduler
from dl_models.full_model import full_model
from trainer import Trainer

# === PARAMÈTRES GÉNÉRAUX ===
SEED = 42
EPOCHS = 1  # une seule epoch
station = ['BEL','PAR','AMP','SAN','FLA']
training_mode_to_visualise = ['test']

model_names = ['STGCN','ASTGCN']   # 'DCGRU','MTGNN','LSTM','RNN','CNN','GRU'
dataset_for_coverage = ['subway_in', 'netmob_image_per_station']
dataset_names= ["subway_in","subway_out"] #     ["subway_in", "netmob_POIs"] #    ["subway_in"]

def get_modification(dataset_names):
    # Définir la base de la modification
    modification = {
        'epochs': EPOCHS,
        'lr': 5e-5,
        'weight_decay': 0.05,
        'dropout': 0.15,
        'scheduler': None,
        'adj_type': 'corr',
        'threshold': 0.7,
        'stacked_contextual': True,
        'target_data': 'subway_in',
        'compute_node_attr_with_attn': False,
        'use_target_as_context': False,
    }

    if "netmob_POIs" in dataset_names:
        modification.update({
            'NetMob_only_epsilon': True,
            'NetMob_selected_apps': ['Google_Maps'],
            'NetMob_transfer_mode': ['DL'],
            'NetMob_selected_tags': ['station_epsilon100'],
            'NetMob_expanded': ''
        })
    return modification

def load_inputs(model_name,dataset_names):
    # Init args
    modification = get_modification(dataset_names=dataset_names)
    args_init = local_get_args(model_name,args_init=None,dataset_names=dataset_names,dataset_for_coverage=dataset_for_coverage,modification=modification)

    # Load ds
    ds,args,trial_id,save_folder,df_loss = get_ds(modification=modification,args_init=args_init)
    return ds,args,trial_id,save_folder

# === FONCTION POUR UNE CONFIG SPÉCIFIQUE ===
def run_test(model_names, dataset_names):
    set_seed(SEED)
    df = pd.DataFrame()
    ds,args,trial_id,save_folder = load_inputs(model_names[0],dataset_names)
    for model_name in model_names:
        print(f"\n=== TESTING {model_name} on {dataset_names} ===")
        args.model_name = model_name
        model = full_model(ds, args).to(args.device)
        optimizer,scheduler,loss_function = load_optimizer_and_scheduler(model,args)
        trainer = Trainer(ds,model,args,optimizer,loss_function,scheduler = scheduler,show_figure = False,trial_id = trial_id, fold=0,save_folder = save_folder)
        trainer.train_and_valid(normalizer = ds.normalizer, mod = 1000,mod_plot = None) 

        mse_test = trainer.performance['test_metrics']['mse']
        mse_valid = trainer.performance['valid_metrics']['mse']
        df = pd.concat([df,pd.DataFrame({'mse_test': [mse_test], 'mse_valid': [mse_valid]}, index=[model_name])], axis=0)

    print("=== TEST COMPLETED ===")
    return ds,trainer,df

# === LANCEMENT GLOBAL ===
if __name__ == "__main__":
    ds,last_trainer,df = run_test(model_names, dataset_names)
    
    # Afficher le DataFrame des résultats
    print("\n=== RÉSULTATS ===")
    display(df)
    
    # Créer un DataFrame de référence pour vérification
    checking = pd.DataFrame({
        'mse_test': [5988.421875, 48209.429688],
        'mse_valid': [6884.651855, 71904.921875]
    }, index=['STGCN', 'ASTGCN'])

    print("\n=== HAS TO BE EQUAL TO: ===")
    display(checking)

    print("\n=== ABSOLUTE DIFFERENCE BETWEEN BOTH DF: ===")
    display(abs(checking-df))

Training and Hyper-parameter tuning with Ray is not possible
----------------------------------------
Loading the initial dataset for K-fold splitting
Coverage Period: 7392 elts between 2019-03-16 00:00:00 and 2019-05-31 23:45:00
Invalid dates within this fold: 776

>>>Tackle Target dataset: subway_in
   Load data from: /home/rrochas/prediction-validation/../../../../data/rrochas/prediction_validation/subway_in/subway_in.csv
   Init Dataset: 'torch.Size([7392, 40]). 0 Nan values
   TRAIN contextual_ds: torch.Size([2821, 40, 7])
   VALID contextual_ds: torch.Size([940, 40, 7])
   TEST contextual_ds: torch.Size([940, 40, 7])

>>>Tackle Contextual dataset:  subway_out
   Load data from: /home/rrochas/prediction-validation/../../../../data/rrochas/prediction_validation/subway_out/subway_out.csv
T_subway_out:  torch.Size([7392, 40])
   Init Dataset: '[torch.Size([7392, 40])]. [tensor(0)] Nan values
   TRAIN contextual_ds: [torch.Size([2821, 40, 7])]
   VALID contextual_ds: [torch.Size([940,

Unnamed: 0,mse_test,mse_valid
STGCN,3018.133789,5507.885254
ASTGCN,31862.908203,65793.15625



=== HAS TO BE EQUAL TO: ===


Unnamed: 0,mse_test,mse_valid
STGCN,5988.421875,6884.651855
ASTGCN,48209.429688,71904.921875



=== ABSOLUTE DIFFERENCE BETWEEN BOTH DF: ===


Unnamed: 0,mse_test,mse_valid
STGCN,2970.288086,1376.766601
ASTGCN,16346.521485,6111.765625


In [11]:
import os 
import pandas as pd
csv_path = '/home/rrochas/prediction-validation/../../../../data/rrochas/prediction_validation/agg_data/validation_individuelle/subway_indiv_15min/subway_indiv_15min.csv'
print(f"   Load data from: {csv_path}.csv")

DATE_COL = 'VAL_DATE'

df = pd.read_csv(csv_path)
df[DATE_COL] = pd.to_datetime(df[DATE_COL])
df.set_index(DATE_COL, inplace=True)

display(df)
START = '2019-11-01'
END = '2020-04-30 23:30:00'
freq = '15min'
reindex = pd.date_range(start=START, end=END, freq=freq)[:-1]
df = df.reindex(reindex).fillna(0)
df

   Load data from: /home/rrochas/prediction-validation/../../../../data/rrochas/prediction_validation/agg_data/validation_individuelle/subway_indiv_15min/subway_indiv_15min.csv.csv


Unnamed: 0_level_0,AMP,BEL,BRO,CHA,COR,CPA,CRO,CUI,CUS,DEB,...,PER,GUI,JAU,REP,SAN,SAX,GER,VMY,SOI,JEA
VAL_DATE,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2019-11-01 00:00:00,20.0,164.0,10.0,59.0,88.0,5.0,20.0,1.0,5.0,21.0,...,89.0,13.0,16.0,12.0,49.0,68.0,8.0,23.0,3.0,59.0
2019-11-01 00:15:00,6.0,82.0,11.0,30.0,43.0,3.0,8.0,0.0,2.0,3.0,...,19.0,7.0,7.0,3.0,11.0,33.0,2.0,16.0,0.0,38.0
2019-11-01 00:30:00,0.0,4.0,0.0,0.0,2.0,2.0,6.0,0.0,1.0,0.0,...,0.0,0.0,0.0,3.0,5.0,2.0,0.0,3.0,1.0,0.0
2019-11-01 00:45:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019-11-01 04:15:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-03-30 22:15:00,0.0,5.0,2.0,8.0,4.0,0.0,2.0,2.0,4.0,1.0,...,4.0,2.0,0.0,1.0,0.0,9.0,0.0,3.0,8.0,4.0
2020-03-30 22:30:00,1.0,5.0,0.0,13.0,0.0,0.0,2.0,1.0,2.0,0.0,...,3.0,2.0,1.0,1.0,2.0,1.0,5.0,5.0,7.0,0.0
2020-03-30 22:45:00,1.0,4.0,0.0,11.0,1.0,0.0,2.0,0.0,4.0,4.0,...,3.0,1.0,1.0,1.0,0.0,2.0,2.0,2.0,3.0,1.0
2020-03-30 23:00:00,1.0,5.0,0.0,2.0,3.0,0.0,0.0,0.0,2.0,0.0,...,0.0,2.0,2.0,0.0,0.0,6.0,0.0,0.0,0.0,0.0


Unnamed: 0,AMP,BEL,BRO,CHA,COR,CPA,CRO,CUI,CUS,DEB,...,PER,GUI,JAU,REP,SAN,SAX,GER,VMY,SOI,JEA
2019-11-01 00:00:00,20.0,164.0,10.0,59.0,88.0,5.0,20.0,1.0,5.0,21.0,...,89.0,13.0,16.0,12.0,49.0,68.0,8.0,23.0,3.0,59.0
2019-11-01 00:15:00,6.0,82.0,11.0,30.0,43.0,3.0,8.0,0.0,2.0,3.0,...,19.0,7.0,7.0,3.0,11.0,33.0,2.0,16.0,0.0,38.0
2019-11-01 00:30:00,0.0,4.0,0.0,0.0,2.0,2.0,6.0,0.0,1.0,0.0,...,0.0,0.0,0.0,3.0,5.0,2.0,0.0,3.0,1.0,0.0
2019-11-01 00:45:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019-11-01 01:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-04-30 22:15:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2020-04-30 22:30:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2020-04-30 22:45:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2020-04-30 23:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


  df = df.reindex(reindex).fillna(0)


Unnamed: 0,VAL_DATE,AMP,BEL,BRO,CHA,COR,CPA,CRO,CUI,CUS,...,PER,GUI,JAU,REP,SAN,SAX,GER,VMY,SOI,JEA
2019-11-01 00:00:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019-11-01 00:15:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019-11-01 00:30:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019-11-01 00:45:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2019-11-01 01:00:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-04-30 22:15:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2020-04-30 22:30:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2020-04-30 22:45:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2020-04-30 23:00:00,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
    df = df.set_index(DATE_COL)
    df_reindexed = df[df.index.isin(coverage_period)].copy()