In [None]:
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting.metrics import RMSE
import pandas as pd 
import numpy as np
from pytorch_forecasting import TimeSeriesDataSet
import holidays


In [None]:
# 1-2. Load the raw price history, parse the "Date" column into datetimes,
#      then order the rows by date and ticket and renumber the index 0..n-1.
df = (
    pd.read_csv("/home/jathur/Bureau/Projects/TeleFinance/data/raw/History.csv")
    .assign(Date=lambda frame: pd.to_datetime(frame["Date"]))
    .sort_values(["Date", "ticket"])
    .reset_index(drop=True)
)
# 3. Global time index: dense rank of the date, so the earliest date gets 1
#    and all rows sharing a date share the same index.
df["time_idx"] = (
    df["Date"]
    .rank(method="dense")
    .astype(int)
)
# 4. (optional) Coverage check: first index, last index and row count per ticket.
print(
    df.groupby("ticket")["time_idx"].agg(["min", "max", "count"])
)

In [None]:

# French public-holiday calendar covering every calendar year present in the data
french_holidays = holidays.France(
    years=df["Date"].dt.year.unique(),
)
# US public-holiday calendar over the same years (used for the non-Paris tickets)
us_holidays = holidays.US(
    years=df["Date"].dt.year.unique(),
)
# Market column derived from the ticket suffix: ".PA" -> "FR", anything else -> "US"
df["market"] = df["ticket"].map(
    lambda symbol: "FR" if symbol.endswith(".PA") else "US"
)

# Holiday flag per row, looked up in the calendar of the row's market
def is_market_holiday(row):
    """Return whether the row's date is a holiday on the row's market.

    Args:
        row: Mapping-like object with a "Date" entry exposing ``.date()``
            and a "market" entry ("FR" or "US").

    Returns:
        The result of the membership test of the date in ``french_holidays``
        (market "FR") or ``us_holidays`` (market "US"); ``False`` for any
        other market value.
    """
    day = row["Date"].date()
    market = row["market"]
    # Unknown markets have no calendar: never a holiday.
    if market not in ("FR", "US"):
        return False
    # Only the calendar of the matching market is looked up.
    calendar = french_holidays if market == "FR" else us_holidays
    return day in calendar

# Evaluate is_market_holiday once per row and store the result as a new column
df["is_holiday"] = df.apply(
    is_market_holiday,
    axis="columns",
)


In [None]:
# Work on a copy and replace time_idx with a per-ticket running counter
# (0, 1, 2, ... following the frame's current row order within each ticket).
full_df = df.copy()
full_df["time_idx"] = full_df.groupby("ticket").cumcount()

# Keep only the tickets whose series has at least `min_len` rows
min_len = 90
long_enough = full_df.groupby("ticket")["ticket"].transform("size") >= min_len
valid_tickets = full_df.loc[long_enough, "ticket"].unique()

filtered_df = full_df[full_df["ticket"].isin(valid_tickets)]

# Split: the last `min_len` rows of every ticket form the validation set,
# everything before them is the training set.
# NOTE(review): a ticket with exactly `min_len` rows lands entirely in val_df
# and contributes no training rows.
val_df = filtered_df.groupby("ticket").tail(min_len)
train_df = filtered_df.loc[~filtered_df.index.isin(val_df.index)]


In [None]:
# Window sizes (in rows) of every sample. They were referenced below but never
# defined anywhere in the notebook, which raised a NameError on a fresh kernel.
# NOTE(review): 60/30 are placeholder values; tune as needed, but keep their
# sum at or below the per-ticket validation window (`min_len` rows), otherwise
# no complete validation sample can be built with the default minimum lengths.
max_encoder_length = 60      # history rows fed to the encoder
max_prediction_length = 30   # future rows the model has to predict

# Keyword arguments shared by the training and validation TimeSeriesDataSet
dataset_params = dict(
    time_idx="time_idx",                    # integer time-step column
    target="Close",                         # column to forecast
    group_ids=["ticket"],                   # one series per ticket
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_known_reals=["time_idx"],  # known for future steps as well
    time_varying_unknown_reals=["Close", "Volume"],  # only known up to the present
    allow_missing_timesteps=True
)

In [None]:
# Training dataset: fits the scalers / encoders on the training rows
train_dataset = TimeSeriesDataSet(train_df, **dataset_params)
# Validation dataset: derived from the training dataset so it reuses the same
# parameters and the scalers / encoders fitted on the training data, instead of
# re-fitting them on the validation rows.
val_dataset = TimeSeriesDataSet.from_dataset(
    train_dataset,
    val_df,
    stop_randomization=True,
)


In [None]:
# Report the number of samples in each dataset
for _label, _ds in (("Train dataset:", train_dataset), ("Val dataset:", val_dataset)):
    print(_label, len(_ds))


In [None]:
# One dataloader per dataset, both with batches of 64 samples;
# only the training loader is created in training mode.
train_loader, val_loader = (
    dataset.to_dataloader(train=is_train, batch_size=64)
    for dataset, is_train in ((train_dataset, True), (val_dataset, False))
)


In [None]:
# Trainer, EarlyStopping and LearningRateMonitor are already imported in the
# notebook's first cell, so they are not imported again here.

# Training loop configuration: at most 20 epochs, hardware chosen by Lightning.
trainer = Trainer(
    max_epochs=20,
    accelerator="auto",  # or "cpu" to opt out of automatic device selection
    callbacks=[
        # Stop early based on the logged "val_loss" metric (patience of 3)
        EarlyStopping(monitor="val_loss", patience=3),
        # Log the learning rate once per epoch
        LearningRateMonitor(logging_interval="epoch")
    ]
)


In [None]:
# The model was passed to `trainer.fit` without ever being created in the
# notebook (NameError on a fresh kernel). Build it from the training dataset so
# its inputs match the dataset definition, using the imported RMSE metric as loss.
# NOTE(review): all other hyperparameters are the library defaults; tune them
# (hidden_size, learning_rate, dropout, ...) as needed.
tft = TemporalFusionTransformer.from_dataset(
    train_dataset,
    loss=RMSE(),
)

# Train on the training loader and validate on the validation loader
trainer.fit(
    model=tft,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader
)
