In [None]:
###########################################
##                  TFT                  ##
###########################################
## Chargement des données
## Conversion des types
## Séparation Train / Test
## Préparation des TimeSeriesDataSet
## Définition des DataLoaders
## Initialisation du modèle TFT
## Entraînement du modèle
## Sauvegarde et chargement du modèle
## Prédiction et affichage des résultats

In [None]:
import time
from pathlib import Path

import pandas as pd
import torch
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import EarlyStopping

# Load the MetroPT3 dataset.
# The path is built with pathlib: the original backslash-separated raw string only
# resolves on Windows, whereas Path(...) joins the parts with the OS separator.
DATA_PATH = Path("..", "..", "..", "..", "Datasources", "MetroPT3_new_imputed_final.csv")
# Chain reset_index instead of mutating in place (idempotent when the cell is re-run).
df = pd.read_csv(DATA_PATH, delimiter=",", decimal=".", index_col=0).reset_index(drop=True)
display(df.head(2))

In [None]:
# Dataset starts on 2020-04-12 11:20:00 and ends on 2020-07-17 06:00:00.
# Known failure ("panne") windows in chronological order; ids are Panne1..Panne21,
# numbered by position in the list below.
pannes = [
    {'id': f'Panne{number}', 'start': start, 'end': end}
    for number, (start, end) in enumerate(
        [
            ('2020-04-12 11:50:00', '2020-04-12 23:30:00'),
            ('2020-04-18 00:00:00', '2020-04-18 23:59:00'),
            ('2020-04-19 00:00:00', '2020-04-19 01:30:00'),
            ('2020-04-29 03:20:00', '2020-04-29 04:00:00'),
            ('2020-04-29 22:00:00', '2020-04-29 22:20:00'),
            ('2020-05-13 14:00:00', '2020-05-13 23:59:00'),
            ('2020-05-18 05:00:00', '2020-05-18 05:30:00'),
            ('2020-05-19 10:10:00', '2020-05-19 11:00:00'),
            ('2020-05-19 22:10:00', '2020-05-19 23:59:00'),
            ('2020-05-20 00:00:00', '2020-05-20 20:00:00'),
            ('2020-05-23 09:50:00', '2020-05-23 10:10:00'),
            ('2020-05-29 23:30:00', '2020-05-29 23:59:00'),
            ('2020-05-30 00:00:00', '2020-05-30 06:00:00'),
            ('2020-06-01 15:00:00', '2020-06-01 15:40:00'),
            ('2020-06-03 10:00:00', '2020-06-03 11:00:00'),
            ('2020-06-05 10:00:00', '2020-06-05 23:59:00'),
            ('2020-06-06 00:00:00', '2020-06-06 23:59:00'),
            ('2020-06-07 00:00:00', '2020-06-07 14:30:00'),
            ('2020-07-08 17:30:00', '2020-07-08 19:00:00'),
            ('2020-07-15 14:30:00', '2020-07-15 19:00:00'),
            ('2020-07-17 04:30:00', '2020-07-17 05:30:00'),
        ],
        start=1,
    )
]

In [None]:
# Parse the timestamp column so it can be compared and subtracted.
df['timestamp'] = pd.to_datetime(df['timestamp'])

# Temporal train/test boundary (just after the end of failure 16).
split_date = "2020-06-05 23:59:10"

# Rows strictly before the boundary go to train, the rest to test.
df_train = df.loc[df["timestamp"] < split_date].copy()
df_test = df.loc[df["timestamp"] >= split_date].copy()


def _ten_second_index(frame):
    """Return the whole number of 10-second steps elapsed since ``frame``'s first timestamp (as int).

    Dividing by 10 keeps the time index small; each split gets its own origin (0).
    """
    elapsed = frame["timestamp"] - frame["timestamp"].min()
    return (elapsed.dt.total_seconds() // 10).astype(int)


df_train["time_idx"] = _ten_second_index(df_train)
df_test["time_idx"] = _ten_second_index(df_test)

print(f"Min time_idx in df_train: {df_train['time_idx'].min()}, Max: {df_train['time_idx'].max()}")
print(f"Min time_idx in df_test: {df_test['time_idx'].min()}, Max: {df_test['time_idx'].max()}")

In [None]:
def _as_str_category(series):
    """Cast a column to int, then str, then pandas 'category' (e.g. 1.0 -> '1')."""
    return series.astype(int).astype(str).astype("category")


# The categorical features and the 'panne' target all receive the same
# int -> str -> category conversion, in both splits.
for column in ["COMP", "DV_eletric", "Towers", "panne"]:
    df_train[column] = _as_str_category(df_train[column])
    df_test[column] = _as_str_category(df_test[column])

print("Conversion des colonnes cat√©goriques et de panne termin√©e.")

In [None]:
# Every row gets the same group id, so each split forms a single time series.
for frame in (df_train, df_test):
    frame["group_id"] = "compresseur"

print("group_id ajout√©.")

In [None]:
# One independent NaNLabelEncoder instance per categorical column.
categorical_encoders = {
    column: NaNLabelEncoder() for column in ("COMP", "DV_eletric", "Towers")
}

print("Encodeurs cat√©goriques d√©finis.")

In [None]:
# Look-back window and forecast horizon, in time_idx steps (10 s each).
max_encoder_length = 30  # longest history window fed to the encoder
max_prediction_length = 3  # number of future steps predicted per sample

train_dataset = TimeSeriesDataSet(
    df_train,
    time_idx="time_idx",
    target="panne",
    group_ids=["group_id"],  # column holding the series identifier
    time_varying_known_reals=["TP2", "H1", "DV_pressure", "Oil_temperature", "Motor_current"],
    time_varying_known_categoricals=["DV_eletric", "Towers"],
    # Use the constants above rather than repeating the literals, so that editing
    # the window sizes in one place actually changes the dataset.
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    min_encoder_length=10,  # shortest history window accepted
    min_prediction_length=max_prediction_length,  # only full-horizon samples
    categorical_encoders=categorical_encoders,
    allow_missing_timesteps=True,
)


print(f"Dataset train cr√©√© avec {len(train_dataset)} s√©quences.")

In [None]:
# Build the test set from the training set's parameters so both share one
# configuration: same group ids, window lengths, variables and fitted encoders.
# The previous hand-written call silently diverged from train_dataset even though
# its comments required identical settings:
#   group_ids=["COMP"] vs ["group_id"], max_encoder_length=10 vs 30,
#   min_encoder_length=5 vs 10, min_prediction_length=2 vs 3.
# stop_randomization=True is the documented setting for evaluation datasets.
test_dataset = TimeSeriesDataSet.from_dataset(train_dataset, df_test, stop_randomization=True)

print(f"Dataset test cr√©√© avec {len(test_dataset)} s√©quences.")

In [None]:
# NOTE(review): despite the label, len(test_dataset.index) may count sampleable
# sequences rather than distinct groups - confirm against pytorch-forecasting docs.
print("Nombre de groupes dans test_dataset: " + (str(len(test_dataset.index)) if hasattr(test_dataset, "index") else "N/A"))

In [None]:
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import CrossEntropy, QuantileLoss

# Build the TFT directly from the training dataset (variables, embeddings and
# output size are taken from it).
# The target 'panne' was cast to a pandas category, i.e. this is a classification
# problem: CrossEntropy yields one output per class. QuantileLoss is a regression
# loss and would fit quantiles of the label-encoded class id instead.
tft = TemporalFusionTransformer.from_dataset(
    train_dataset,
    learning_rate=1e-3,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=CrossEntropy(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# tft.size() reports the number of model parameters.
print(f"Mod√®le TFT initialis√© avec {tft.size()} param√®tres.")

In [None]:
from pytorch_lightning import LightningModule


class TFTLightningModule(LightningModule):
    """Thin ``pytorch_lightning.LightningModule`` wrapper around a TFT model.

    Training and validation share one batch-unpacking path. The validation loss is
    logged as ``val_loss`` so callbacks monitoring that name (e.g.
    ``EarlyStopping(monitor="val_loss")``) can find it.

    Args:
        tft: the wrapped model. It is called on the feature part of a batch and must
            expose ``loss`` and ``configure_optimizers``.
    """

    def __init__(self, tft):
        super().__init__()
        self.model = tft

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _unpack_batch(batch):
        """Split a batch into ``(features, target)``; return None if that is impossible.

        Accepts an ``(x, y)`` tuple/list (what a DataLoader yields) as well as a
        mapping with ``"x"`` and ``"y"`` keys.
        """
        if batch is None:
            return None
        if isinstance(batch, dict):
            if "x" not in batch or "y" not in batch:
                return None
            return batch["x"], batch["y"]
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            return batch[0], batch[1]
        return None

    @staticmethod
    def _prediction(output):
        """Return the prediction part of the model output.

        Assumes the model returns a dict or an object with a ``prediction``
        attribute (TODO confirm for the installed pytorch-forecasting version);
        any other output is passed through unchanged.
        """
        if isinstance(output, dict):
            return output["prediction"]
        return getattr(output, "prediction", output)

    def _compute_loss(self, x, y):
        """Run the model on ``x`` and score its prediction against ``y`` with the model's loss."""
        return self.model.loss(self._prediction(self.model(x)), y)

    def training_step(self, batch, batch_idx):
        unpacked = self._unpack_batch(batch)
        if unpacked is None:
            # Returning None skips this optimisation step (automatic optimisation).
            return None
        loss = self._compute_loss(*unpacked)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        unpacked = self._unpack_batch(batch)
        if unpacked is None:
            print(f"Batch {batch_idx} mal structur√© : {batch}")
            return None  # best effort: skip a malformed batch instead of crashing

        try:
            loss = self._compute_loss(*unpacked)
        except KeyError as e:
            print(f"Erreur de cl√© dans `validation_step`: {e}")
            return None

        # Must be logged (not only returned) for EarlyStopping / LR schedulers
        # monitoring "val_loss" to see it.
        self.log("val_loss", loss, prog_bar=True)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return self.model.configure_optimizers()


# Wrap the TFT model in the LightningModule.
tft_lightning = TFTLightningModule(tft)

In [None]:
import torch

def tft_collate_fn(batch):
    """Collate a list of dataset samples into a ``(features, labels)`` pair of tensors.

    Args:
        batch: list of samples. ``None`` entries are dropped. Each remaining sample
            is indexed as ``sample[0]`` (dict of features) and ``sample[1][0]``
            (label; presumably the target part of a ``(target, weight)`` pair -
            TODO confirm).

    Returns:
        ``(batch_data, batch_labels)``. ``batch_data`` maps every key of the first
        sample's feature dict to a ``torch.stack`` of that key across samples, and
        ``batch_labels`` stacks the labels. Returns ``None`` when no sample remains
        after filtering, or when any exception is raised (the error is printed, not
        re-raised).

    NOTE(review): ``torch.stack`` requires equally-shaped tensors. If samples have
    different encoder lengths (variable-length ``x_cat`` / ``x_cont``), stacking
    raises and this function returns ``None``.
    """
    
    # Drop invalid (None) samples.
    batch = [b for b in batch if b is not None]  
    if not batch:
        return None  # Nothing left to collate: return None instead of raising.

    try:
        # Split each sample into its feature dict and its label.
        feature_dicts = [b[0] for b in batch]
        labels = [b[1][0] for b in batch]  # first element of each sample's label tuple

        # Fail early if the first sample lacks the keys this function relies on
        # (per-sample names are the singular "encoder_length" / "decoder_length").
        required_keys = ["encoder_length", "decoder_length", "x_cat", "x_cont"]
        missing_keys = [k for k in required_keys if k not in feature_dicts[0]]

        if missing_keys:
            raise KeyError(f"üö® Cl√©s manquantes dans le batch : {missing_keys}")

        # One stacked tensor per feature key; non-tensor values are converted first.
        batch_data = {k: torch.stack([d[k].clone().detach() if isinstance(d[k], torch.Tensor) else torch.tensor(d[k]) for d in feature_dicts]) for k in feature_dicts[0].keys()}

        # Stack the labels into a single tensor.
        batch_labels = torch.stack([label.clone().detach() if isinstance(label, torch.Tensor) else torch.tensor(label) for label in labels])

        return batch_data, batch_labels

    except Exception as e:
        # Best effort: report and return None (the KeyError raised above lands here too).
        print(f"üö® Erreur dans `tft_collate_fn`: {e}")
        return None  # signals failure to the caller

# Smoke test: collate the first 10 training samples and report the stacked shapes.
sample_batch = [train_dataset[i] for i in range(10)]
collate_result = tft_collate_fn(sample_batch)

if collate_result:
    print("‚úÖ `collate_fn` a bien fonctionn√©. Voici les dimensions des donn√©es :", {k: v.shape for k, v in collate_result[0].items()})
else:
    print("üö® `collate_fn` a √©chou√©.")
  

In [None]:
import torch
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

# Batches are built by pytorch-forecasting's own collate function (via
# `to_dataloader` below). The hand-written `tft_collate_fn` formerly redefined in
# this cell kept only four raw per-sample keys ("encoder_length", "decoder_length",
# "x_cat", "x_cont"). It therefore dropped every other key the dataset returns.

# Check that the dataset yields samples whose first element is a feature dict.
try:
    sample = train_dataset[0]
    print(f"‚úÖ √âchantillon dataset : {list(sample[0].keys())}")
except Exception as e:
    print(f"üö® Erreur lors du chargement du dataset : {e}")

# DataLoader configuration.
batch_size = 64
num_workers = 0  # no worker processes (multiprocessing is avoided on Windows)
pin_memory = torch.cuda.is_available()

# `train=True` shuffles; `train=False` keeps the order, as before.
train_dataloader = train_dataset.to_dataloader(
    train=True, batch_size=batch_size, num_workers=num_workers, drop_last=False,
    pin_memory=pin_memory,
)

test_dataloader = test_dataset.to_dataloader(
    train=False, batch_size=batch_size, num_workers=num_workers, drop_last=False,
    pin_memory=pin_memory,
)

# Check the first batch: it must be an (x, y) pair whose x is a dict of tensors.
# (The previous check reported any tuple as malformed, although a tuple is exactly
# the expected batch format.)
try:
    test_batch = next(iter(train_dataloader))
    if not isinstance(test_batch, (tuple, list)) or len(test_batch) != 2:
        print(f"üö® Erreur : Batch mal structur√© ! Type : {type(test_batch)}")
    else:
        print(f"‚úÖ Batch structur√© : Cl√©s = {list(test_batch[0].keys())}")
except Exception as e:
    print(f"üö® Erreur lors du chargement du batch : {e}")

# Training callbacks.
early_stop_callback = EarlyStopping(monitor="val_loss", patience=5, verbose=True, mode="min")
progress_bar = TQDMProgressBar(refresh_rate=10)

# Trainer definition.
trainer = Trainer(
    max_epochs=30,
    accelerator="auto",
    enable_progress_bar=True,
    enable_checkpointing=False,
    callbacks=[early_stop_callback, progress_bar],
    gradient_clip_val=0.1
)

# Start training; the test loader is used as the validation set.
trainer.fit(tft_lightning, train_dataloader, test_dataloader)

print("‚úÖ Entra√Ænement termin√© avec succ√®s !")


In [None]:
##############################################
## 9. Save and reload the model             ##
##############################################
import os

# Portable relative path: a backslash-separated raw string only resolves on Windows.
model_path = os.path.join("..", "..", "Generated_Files", "TFT", "tft_model.ckpt")
trainer.save_checkpoint(model_path)
print("Mod√®le sauvegard√© !")

# The trainer was fitted on the `TFTLightningModule` wrapper. The checkpoint therefore
# holds the wrapper's state dict (TFT weights under its `model` attribute) and no TFT
# hyperparameters, since the wrapper never saves any.
# `TemporalFusionTransformer.load_from_checkpoint` cannot rebuild a TFT from that.
# Instead, reload through the wrapper class, supplying its `tft` constructor
# argument, and unwrap the model.
# NOTE: the checkpoint weights are loaded into the `tft` instance passed here.
tft_loaded = TFTLightningModule.load_from_checkpoint(model_path, tft=tft).model
print("Mod√®le charg√© avec succ√®s !")


In [None]:
##############################################
## 10. Predictions with the model           ##
##############################################

# Predict on the test set and show the first ten predictions.
predictions = tft.predict(
    test_dataloader,
    mode="prediction",
)
print("Pr√©dictions :", predictions[0:10])

In [None]:
# *********************************************
# (Visual separator. A bare line of asterisks is a SyntaxError in a code cell and
# breaks "Restart & Run All".)

In [None]:
# *********************************************
# (Visual separator. A bare line of asterisks is a SyntaxError in a code cell and
# breaks "Restart & Run All".)

In [None]:
from pathlib import Path

import pandas as pd
import torch
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import NaNLabelEncoder
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import EarlyStopping

# Load the dataset (pathlib keeps the relative path valid outside Windows).
df = pd.read_csv(
    Path("..", "..", "..", "..", "Datasources", "MetroPT3_new_imputed_final.csv"),
    delimiter=",", decimal=".", index_col=0,
).reset_index(drop=True)

# Parse timestamps.
df['timestamp'] = pd.to_datetime(df['timestamp'])

# Temporal train/test boundary.
split_date = "2020-06-05 23:59:10"

# Split into train/test.
df_train = df[df["timestamp"] < split_date].copy()
df_test = df[df["timestamp"] >= split_date].copy()

# time_idx = whole 10-second steps since each split's own first timestamp.
df_train["time_idx"] = ((df_train["timestamp"] - df_train["timestamp"].min()).dt.total_seconds() // 10).astype(int)
df_test["time_idx"] = ((df_test["timestamp"] - df_test["timestamp"].min()).dt.total_seconds() // 10).astype(int)

# Categorical columns: int -> str -> category.
for col in ["COMP", "DV_eletric", "Towers"]:
    df_train[col] = df_train[col].astype(int).astype(str).astype("category")
    df_test[col] = df_test[col].astype(int).astype(str).astype("category")

# 'panne' is not converted in this cell, so the message no longer claims it was.
print("‚úÖ Conversion des colonnes cat√©goriques termin√©e.")

# group_id = COMP value + index of the 1000-step chunk of time_idx.
# NOTE(review): time_idx restarts at 0 in df_test, so test group ids reuse the
# train labels (e.g. "1_0") for a different time period.
df_train["group_id"] = df_train["COMP"].astype(str) + "_" + (df_train["time_idx"] // 1000).astype(str)
df_test["group_id"] = df_test["COMP"].astype(str) + "_" + (df_test["time_idx"] // 1000).astype(str)

print("‚úÖ group_id ajout√©.")

# Check how many groups the test split now contains.
print(df_test["group_id"].value_counts())
print(f"Nombre de groupes uniques dans df_test : {df_test['group_id'].nunique()}")

In [None]:
# Look-back window and forecast horizon, in time_idx steps.
max_encoder_length = 30
max_prediction_length = 3

# Training TimeSeriesDataSet for this second pipeline.
train_dataset = TimeSeriesDataSet(
    df_train,
    target="panne",
    time_idx="time_idx",
    group_ids=["group_id"],
    time_varying_known_categoricals=["DV_eletric", "Towers"],
    time_varying_known_reals=["TP2", "H1", "DV_pressure", "Oil_temperature", "Motor_current"],
    max_prediction_length=max_prediction_length,
    max_encoder_length=max_encoder_length,
    allow_missing_timesteps=True,
    # No target normalizer: the target is used without any scaling/encoding.
    target_normalizer=None,
    # NOTE(review): 'panne' is not cast to a categorical in this pipeline;
    # confirm whether this encoder is actually used.
    categorical_encoders={"panne": NaNLabelEncoder(add_nan=True)},
)

In [None]:

# No key-renaming loop. `train_dataset[i]` builds a new sample each time it is
# indexed, so the former loop
#     train_dataset[i][0]["encoder_lengths"] = train_dataset[i][0].pop("encoder_length")
# only mutated throw-away dicts. It never changed the dataset, yet it
# materialised every sample four times (two index calls per line).
# The batch format the model consumes is produced by the dataset's own collate
# function (`train_dataset.to_dataloader(...)`).
print(f"‚úÖ Dataset train corrig√© avec {len(train_dataset)} s√©quences.")


In [None]:
# Test dataset built with the training dataset's parameters and fitted encoders.
# predict=True presumably keeps only the last prediction window of each group -
# confirm against the pytorch-forecasting docs.
test_dataset = TimeSeriesDataSet.from_dataset(train_dataset, df_test, predict=True)

# No key-renaming loop here either: indexing `test_dataset[i]` returns a freshly
# built sample, so popping/renaming its keys never affected the dataset. Batch keys
# come from the dataset's own collate function (`to_dataloader`).
print(f"‚úÖ Dataset test corrig√© avec {len(test_dataset)} s√©quences.")

# Structure check.
print(f"‚úÖ Nombre de groupes dans test_dataset: {len(test_dataset.index) if hasattr(test_dataset, 'index') else 'N/A'}")

In [None]:

# DataLoaders built by the datasets themselves, so batches are collated by
# pytorch-forecasting. A plain `DataLoader` falls back to PyTorch's default
# collate, which does not produce the TFT batch format.
batch_size = 64
train_dataloader = train_dataset.to_dataloader(train=True, batch_size=batch_size)
test_dataloader = test_dataset.to_dataloader(train=False, batch_size=batch_size)

# Each batch is an (x, y) pair. The keys live on x (a dict of tensors), not on
# the pair itself, which has no `.keys()`.
test_batch = next(iter(train_dataloader))
batch_x, _batch_y = test_batch
print(f"‚úÖ Batch structur√© : {list(batch_x.keys())}")


In [None]:

# Train a model built from *this* train_dataset.
# `tft_lightning` from the first pipeline wraps a network derived from a different
# dataset: a single "compresseur" group and a categorical 'panne' target. Its
# embeddings, encoders and output layer need not match the batches produced here.
# 'panne' is presumably numeric in this pipeline (no cast, target_normalizer=None),
# hence a regression loss; the other hyperparameters mirror the first model.
tft_v2 = TemporalFusionTransformer.from_dataset(
    train_dataset,
    learning_rate=1e-3,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)
tft_v2_lightning = TFTLightningModule(tft_v2)

trainer = Trainer(
    max_epochs=30,
    accelerator="auto",
    enable_checkpointing=False,
    callbacks=[EarlyStopping(monitor="val_loss", patience=5)],
    gradient_clip_val=0.1
)

trainer.fit(tft_v2_lightning, train_dataloader, test_dataloader)

print("‚úÖ Entra√Ænement termin√© avec succ√®s !")