In [1]:
pip install pytorch-lightning pytorch-forecasting pandas scikit-learn


Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.4.0-py3-none-any.whl.metadata (14 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.7.3-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)
  Downloading lightning-2.5.2-py3-none-any.whl.metadata (38 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4

In [6]:
# Convert all categorical columns to string type
categorical_columns = [
    "COHORT", "subgroup", "SEX", "PD v/s NON-PD", "month_year", "sequence_id"
]

for col in categorical_columns:
    df[col] = df[col].astype(str)


In [17]:
# 📚 Imports
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline
from pytorch_forecasting.data import NaNLabelEncoder, GroupNormalizer
from pytorch_forecasting.metrics import CrossEntropy, RMSE, MultiLoss
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

# 📥 Load data
df = pd.read_csv("tft_input_prepared.csv")

# 🧽 Optional: drop duplicates or sort if needed
df = df.sort_values(by=["sequence_id", "time_idx"]).reset_index(drop=True)

# ✅ Cast categorical columns as category dtype
df["participant_id"] = df["participant_id"].astype(str)
df["sequence_id"] = df["sequence_id"].astype(str)
df["target_final"] = df["target_final"].astype(int)

# ✅ Target encoding fix (if 1,2,3 → make it 0,1,2)
df["target_final"] = df["target_final"] - 1

# 🧠 Define columns for TFT
static_categoricals = ["participant_id", "COHORT", "SEX", "subgroup"]
static_reals = []
time_varying_known_categoricals = ["is_future"]
time_varying_known_reals = ["time_idx"]
time_varying_unknown_reals = [
    "age_at_visit", "EDUCYRS", "BMI", "moca", "MSEADLG",
    "orthostasis", "td_pigd", "NP1COG", "APOE_e4", "NHY",
    "duration_yrs", "Years_with_PD", "fampd_bin", "diabetes_flag"
]
# Convert all categorical columns to string type
categorical_columns = [
    "COHORT", "subgroup", "SEX", "PD v/s NON-PD", "month_year", "sequence_id","is_future"
]

for col in categorical_columns:
    df[col] = df[col].astype(str)

# 📐 Scaling (optional but recommended)
scaler = StandardScaler()
df[time_varying_unknown_reals] = scaler.fit_transform(df[time_varying_unknown_reals])

# 🧱 Define parameters
max_encoder_length = 3
max_prediction_length = 3

# 🧠 Composite loss for classification (cross entropy) + regression (RMSE)
multi_loss = MultiLoss([CrossEntropy(), RMSE()])

from pytorch_forecasting.data import MultiNormalizer, GroupNormalizer, NaNLabelEncoder

# 🧪 Create dataset
tft_dataset = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    target=["target_final", "moca"],  # classification + regression
    group_ids=["sequence_id"],
    static_categoricals=static_categoricals,
    static_reals=static_reals,
    time_varying_known_categoricals=time_varying_known_categoricals,
    time_varying_known_reals=time_varying_known_reals,
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=time_varying_unknown_reals,
    target_normalizer=MultiNormalizer([
        NaNLabelEncoder(),                         # classification
        GroupNormalizer(groups=["sequence_id"])    # regression
    ]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    allow_missing_timesteps=True,
)
# Determine split index (80% of sequence IDs for training)
train_ids = df["sequence_id"].unique()
train_cutoff = int(len(train_ids) * 0.8)

train_sequence_ids = train_ids[:train_cutoff]
val_sequence_ids = train_ids[train_cutoff:]

# Subset data for train and val
train_df = df[df["sequence_id"].isin(train_sequence_ids)].copy()
val_df = df[df["sequence_id"].isin(val_sequence_ids)].copy()

# Recreate datasets from splits
training = TimeSeriesDataSet.from_dataset(tft_dataset, train_df, stop_randomization=True)
validation = TimeSeriesDataSet.from_dataset(tft_dataset, val_df, stop_randomization=True)


# 🔁 Dataloaders
batch_size = 64
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size)


# 📈 Logger
logger = TensorBoardLogger("lightning_logs")
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.metrics import CrossEntropy, RMSE, MultiLoss

# Define composite loss function
loss_fn = MultiLoss([CrossEntropy(), RMSE()])

# Model definition (💡 output_size is now a list: [3, 1])
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=1e-3,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    loss=loss_fn,
    output_size=[3, 1],  # 🧠 [classification, regression]
    reduce_on_plateau_patience=4,
)


# ⚠️ Ensure you define logger earlier (optional)
from lightning.pytorch.loggers import TensorBoardLogger
logger = TensorBoardLogger("lightning_logs", name="tft_run")

# 🔧 Define trainer
import lightning.pytorch as pl

trainer = pl.Trainer(
    max_epochs=30,
    gradient_clip_val=0.1,
    limit_train_batches=1.0,
    logger=logger
)

# ✅ Corrected training call
trainer.fit(
    model=tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)


/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU avai

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (44) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [22]:
raw_output = tft.predict(val_dataloader, mode="raw", return_x=True)


In [23]:
print(type(raw_output))


<class 'pytorch_forecasting.models.base._base_model.Prediction'>


In [31]:
print(dir(pred_tensor))


['__add__', '__class__', '__class_getitem__', '__contains__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getnewargs__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__match_args__', '__module__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_asdict', '_field_defaults', '_fields', '_make', '_replace', 'count', 'decoder_attention', 'decoder_lengths', 'decoder_variables', 'encoder_attention', 'encoder_lengths', 'encoder_variables', 'get', 'iget', 'index', 'items', 'keys', 'prediction', 'static_variables']


In [32]:
# Extract actual prediction tensor from Output object
pred_probs = pred_tensor.prediction  # shape: [samples, prediction_length, n_classes]


In [35]:
import torch

valid_preds = []
for i, p in enumerate(pred_probs):
    if isinstance(p, torch.Tensor) and p.ndim == 3 and p.shape[-1] == 3:
        valid_preds.append(p)
    else:
        print(f"⚠️ Skipped prediction at index {i} due to shape: {p.shape}")

# Now safely stack
pred_tensor = torch.stack(valid_preds)

⚠️ Skipped prediction at index 1 due to shape: torch.Size([657, 3, 1])


In [39]:
# ✅ Fetch predictions with targets and input
raw_output = tft.predict(val_dataloader, mode="raw", return_x=True, return_y=True)

# ✅ Extract outputs
pred_probs = raw_output.output         # list of tensors [samples, pred_len, 3]
true_labels = raw_output.y             # list of tensors [samples, pred_len]

In [43]:
valid_preds = []
valid_targets = []

for i, (pred, label) in enumerate(zip(pred_probs, true_labels)):

    # Skip if label is None
    if label is None or pred is None:
        print(f"⚠️ Skipped index {i} due to None value")
        continue

    # Unwrap if in list or tuple
    if isinstance(pred, list):
        pred = pred[0]
    if isinstance(label, list):
        label = label[0]
    if isinstance(label, tuple):
        label = label[0]

    # Final skip if still None
    if label is None or pred is None:
        print(f"⚠️ Still None after unpacking at index {i}")
        continue

    # Try to squeeze safely
    try:
        pred = pred.squeeze()
        label = label.squeeze()
    except Exception as e:
        print(f"⚠️ Error squeezing index {i}: {e}")
        continue

    # Only accept valid classification shapes (e.g., [657, 3])
    if pred.ndim == 3 and pred.shape[-1] == 3:
        valid_preds.append(pred)
        valid_targets.append(label)
    else:
        print(f"⚠️ Skipped index {i} due to invalid shape: {pred.shape}")

# Then continue:
if valid_preds and valid_targets:
    pred_tensor = torch.stack(valid_preds)
    true_tensor = torch.stack(valid_targets)

    y_pred = pred_tensor.argmax(dim=-1).cpu().numpy()
    y_true = true_tensor.cpu().numpy()

    y_pred_flat = y_pred.flatten()
    y_true_flat = y_true.flatten()
else:
    print("❌ No valid predictions found.")


⚠️ Skipped index 1 due to None value


In [44]:
from sklearn.metrics import f1_score
from collections import Counter
import numpy as np

def ipw_f1_score(y_true, y_pred, labels=[0, 1, 2]):
    """
    Compute IPW-F1 Score based on inverse class frequencies.
    """
    counts = Counter(y_true)
    weights = {cls: 1.0 / counts[cls] for cls in labels}
    f1s = f1_score(y_true, y_pred, labels=labels, average=None)
    weighted_f1 = sum(f1s[i] * weights[cls] for i, cls in enumerate(labels)) / sum(weights.values())
    return weighted_f1, f1s

# Flatten if not already
y_true_flat = y_true.flatten()
y_pred_flat = y_pred.flatten()

# Compute scores
ipw_f1, f1_per_class = ipw_f1_score(y_true_flat, y_pred_flat)

# Show results
print(f"\n📊 IPW-F1 Score: {ipw_f1:.4f}")
print("🔍 F1 per class:")
for i, label in enumerate(['NC', 'MCI', 'Dementia']):
    print(f"  {label}: F1 = {f1_per_class[i]:.4f}")



📊 IPW-F1 Score: 0.3413
🔍 F1 per class:
  NC: F1 = 0.9132
  MCI: F1 = 0.3977
  Dementia: F1 = 0.3256
