In [113]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel
from darts.metrics import mae, mse
from darts.utils.timeseries_generation import datetime_attribute_timeseries

import warnings
warnings.filterwarnings("ignore")

# === 1. Prétraitement des données ===
def preprocess_CPI_data(csv_path="services/data/raw/ConsumerPriceIndexCPI2019AsBaseYearMonthly.csv"):
    # Lecture du fichier CSV d'origine
    df = pd.read_csv(csv_path)
    
    df["DataSeries"] = df["DataSeries"].str.strip()
    
    melted = pd.melt(df, id_vars=["DataSeries"], var_name="year_month", value_name="CPI")
    
    pivot_df = melted.pivot(index="year_month", columns="DataSeries", values="CPI").reset_index()

    pivot_df["year_month"] = pivot_df["year_month"].apply(lambda x: x[:4] + "-" + x[4:])

    pivot_df = pivot_df[['year_month', 'All Items', 'Food', 'Clothing & Footwear', 'Housing & Utilities', 'Household Durables & Services', 'Health Care', 'Transport', 'Communication', 'Recreation & Culture', 'Education', 'Personal Care', 'Alcoholic Drinks & Tobacco', 'Public Transport']]
    return pivot_df


In [114]:
dataframe = preprocess_CPI_data()[['year_month', 'All Items']]

# on converti dataframe["All Items"] en float

dataframe["All Items"] = dataframe["All Items"].astype(float)

# === 2. Conversion en série temporelle Darts ===
series = TimeSeries.from_dataframe(dataframe, time_col="year_month", value_cols="All Items")

In [115]:
# === 3. Normalisation ===
scaler = Scaler()
series_scaled = scaler.fit_transform(series)

In [116]:
# === 4. Split train/val (80/20) ===
train, val = series_scaled.split_after(0.96)

In [117]:
# === 5. Création de covariates temporelles (Année, Mois en sin/cos) ===
year_series = datetime_attribute_timeseries(series, attribute="year", one_hot=False)
year_series = Scaler().fit_transform(year_series)

month_series = datetime_attribute_timeseries(series, attribute="month", one_hot=True)
covariates = year_series.stack(month_series)

cov_train, cov_val = covariates.split_after(0.96)

In [118]:
# === 6. Définition du modèle LSTM avec Darts ===
model = RNNModel(
    model="LSTM",
    hidden_dim=80,
    dropout=0.1,
    batch_size=32,
    n_epochs=100,
    optimizer_kwargs={"lr": 5e-4},
    model_name="CPI_LSTM",
    random_state=42,
    training_length=24,  
    input_chunk_length=24,  # Augmenté
    force_reset=True,
    save_checkpoints=True
)


In [119]:
# === 7. Entraînement ===
model.fit(
    train,
    future_covariates=covariates,
    val_series=val,
    val_future_covariates=covariates,
    verbose=True
)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

RNNModel(model=LSTM, hidden_dim=80, n_rnn_layers=1, dropout=0.1, training_length=24, batch_size=32, n_epochs=100, optimizer_kwargs={'lr': 0.0005}, model_name=CPI_LSTM, random_state=42, input_chunk_length=24, force_reset=True, save_checkpoints=True)

In [120]:
# === 8. Prédictions ===
pred_series = model.predict(n=len(val), future_covariates=covariates)

# === 9. Calcul des métriques ===
mae_score = mae(val, pred_series)
mse_score = mse(val, pred_series)

print(f"MAE: {mae_score:.4f}, MSE: {mse_score:.4f}")

Predicting: |          | 0/? [00:00<?, ?it/s]

MAE: 0.0261, MSE: 0.0007


In [121]:
dataframe["All Items"]

0       24.187
1       24.517
2       24.487
3       24.565
4       24.542
        ...   
763    115.662
764    116.574
765    116.792
766    116.756
767    117.123
Name: All Items, Length: 768, dtype: float64

In [122]:
import plotly.graph_objects as go

# Inversion correcte des prédictions
pred_dates = dataframe["year_month"][-len(pred_series):]
pred_values = scaler.inverse_transform(pred_series).values().flatten()

# on converti dataframe["All Items"] en float

dataframe["All Items"] = dataframe["All Items"].astype(float)

# Définir une échelle Y continue et uniforme
y_min = min(dataframe["All Items"].min(), pred_values.min()) * 0.95
y_max = max(dataframe["All Items"].max(), pred_values.max()) * 1.05

# Création du graphique
fig = go.Figure()

# Ajout des valeurs réelles
fig.add_trace(go.Scatter(
    x=dataframe["year_month"], 
    y=dataframe["All Items"], 
    mode='lines', 
    name='Réel', 
    line=dict(color='blue', width=2)
))

# Ajout des prédictions
fig.add_trace(go.Scatter(
    x=pred_dates, 
    y=pred_values, 
    mode='lines', 
    name='Prédictions', 
    line=dict(color='red', width=2)
))

# Mise en forme
fig.update_layout(
    title="Prédictions CPI avec LSTM",
    xaxis_title="Date",
    yaxis_title="CPI (Indexation 2018)",
    template="plotly_white",
    xaxis=dict(showgrid=True, tickangle=-45, tickformat="%Y-%b"),
    yaxis=dict(showgrid=True, zeroline=False, range=[y_min, y_max]),  # 🔥 Correction ici !
    legend=dict(x=0, y=1)
)

# Affichage du graphique
fig.show()
