In [None]:
from typing import Tuple

import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn.functional as F

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader, Subset

from vae import VariationAutoencoderModule, WassersteinLoss, MultiCategoricalLoss

In [None]:
class UpdrsData(Dataset):
    """Visit-level UPDRS item scores read from a tab-separated file.

    Item columns are those whose name starts with ``"3."``. Rows are sorted by
    ``PatientID`` and then ``Age``, and every row with a missing item score is
    dropped.

    Attributes:
        covariates: DataFrame with ``PatientID``, ``Age``, ``Deep brain stimulation
            available``, ``Deep brain stimulation`` and ``Medication`` for the kept
            rows, re-indexed ``0..n-1`` so row ``i`` matches ``measurements[i]``.
        measurements: float32 tensor holding the (integer-valued) item scores, one
            row per kept visit.
    """

    def __init__(self, path):
        super().__init__()
        visits = pd.read_csv(path, sep="\t")
        visits = visits.sort_values(["PatientID", "Age"])

        item_columns = [name for name in visits.columns if name.startswith("3.")]
        complete_items = visits[item_columns].dropna().astype(int)

        covariate_columns = [
            "PatientID",
            "Age",
            "Deep brain stimulation available",
            "Deep brain stimulation",
            "Medication",
        ]
        # Pick the covariates of the surviving rows (in the same sorted order) and
        # renumber them so they line up with the positions in `measurements`.
        kept_covariates = visits.loc[complete_items.index, covariate_columns]
        self.covariates = kept_covariates.reset_index(drop=True)
        self.measurements = torch.tensor(
            complete_items.to_numpy(), dtype=torch.float32
        )

    def __getitem__(self, index):
        """Return the item-score vector stored at row ``index``."""
        return self.measurements[index]

    def __len__(self):
        """Return the number of kept visits."""
        return self.measurements.size(0)

    @property
    def participant_covariate(self) -> str:
        """Name of the covariate column that identifies the participant."""
        return "PatientID"


class UpdrsDataQoL(Dataset):
    """Visit-level UPDRS part 1 and part 2 item scores read from a comma-separated file.

    Rows are sorted by ``Participant`` and then ``Age``, and every row with a
    missing value in any of ``COLUMNS`` is dropped.

    Attributes:
        covariates: DataFrame with ``Participant`` and ``Age`` for the kept rows,
            re-indexed ``0..n-1`` so row ``i`` matches ``measurements[i]``.
        measurements: float32 tensor of shape ``(n_rows, len(COLUMNS))`` with the
            (integer-valued) item scores.
    """

    # "UPDRS 1.1" .. "UPDRS 1.13" followed by "UPDRS 2.1" .. "UPDRS 2.13" (26 items).
    COLUMNS = [f"UPDRS {part}.{item}" for part in (1, 2) for item in range(1, 14)]

    def __init__(self, path):
        super().__init__()
        visits = pd.read_csv(path, sep=",")
        visits = visits.sort_values(["Participant", "Age"])

        complete_items = visits[UpdrsDataQoL.COLUMNS].dropna().astype(int)

        # Covariates of the surviving rows, renumbered to match `measurements`.
        kept_covariates = visits.loc[complete_items.index, ["Participant", "Age"]]
        self.covariates = kept_covariates.reset_index(drop=True)
        self.measurements = torch.tensor(
            complete_items.to_numpy(), dtype=torch.float32
        )

    def __getitem__(self, index):
        """Return the item-score vector stored at row ``index``."""
        return self.measurements[index]

    def __len__(self):
        """Return the number of kept visits."""
        return self.measurements.size(0)

    @property
    def participant_covariate(self) -> str:
        """Name of the covariate column that identifies the participant."""
        return "Participant"


class UpdrsDataModule(L.LightningDataModule):
    """Participant-wise train/validation split of a UPDRS dataset for Lightning.

    The rows of ``dataset`` are expected to be grouped by participant (both
    ``UpdrsData`` and ``UpdrsDataQoL`` sort by participant and age on load). The
    split is therefore a single row index ``val_start``:

    - rows ``[0, val_start)`` are used for training;
    - rows ``[val_start, len(dataset))`` are used for validation.

    The validation rows belong to the last ``percentage_subjects_in_valid_dataset``
    share of participants, in order of appearance, so no participant ends up in
    both sets.

    Args:
        dataset: Dataset exposing ``covariates`` (DataFrame, one row per sample),
            ``measurements`` (tensor) and ``participant_covariate`` (name of the
            participant-ID column in ``covariates``).
        percentage_subjects_in_valid_dataset: Fraction of participants in (0, 1]
            used for validation. ``1`` puts everything into validation, in which
            case no training loader is available. If the fraction would round
            down to zero participants, one participant is used instead.
        batch_size: Batch size of both data loaders.
        train_num_workers: Number of worker processes of the training loader.
        val_num_workers: Number of worker processes of the validation loader.

    Raises:
        ValueError: If the fraction lies outside (0, 1] or the dataset contains
            no participants.
    """

    def __init__(
        self,
        dataset: Dataset,
        percentage_subjects_in_valid_dataset: float,
        batch_size: int,
        train_num_workers: int = 8,
        val_num_workers: int = 4,
    ):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not 0 < percentage_subjects_in_valid_dataset <= 1:
            raise ValueError(
                "percentage_subjects_in_valid_dataset must be in (0, 1], got "
                f"{percentage_subjects_in_valid_dataset}."
            )

        self.data = dataset
        self.batch_size = batch_size
        self.train_num_workers = train_num_workers
        self.val_num_workers = val_num_workers

        if percentage_subjects_in_valid_dataset < 1:
            participant_ids = self.data.covariates[dataset.participant_covariate]
            patients = participant_ids.unique()
            if len(patients) == 0:
                raise ValueError("The dataset contains no participants.")
            # int() truncates. Keep at least one validation participant, otherwise
            # `patients[first_patient_valid]` below would index past the end.
            num_patients_valid = max(
                1, int(len(patients) * percentage_subjects_in_valid_dataset)
            )
            first_patient_valid = len(patients) - num_patients_valid
            # Label of the first row of the first validation participant. Both
            # dataset classes reset the covariate index to 0..n-1, so the label is
            # also the row's position in the dataset.
            self.val_start = participant_ids[
                participant_ids == patients[first_patient_valid]
            ].index[0]
        else:
            self.val_start = 0

    def calculate_class_weights(self, n_classes: int = 5) -> torch.Tensor:
        """Return scikit-learn "balanced" class weights from the training rows only.

        Args:
            n_classes: Number of score levels; labels are assumed to be
                ``0 .. n_classes - 1``.

        Returns:
            A float32 tensor with one weight per class.

        Raises:
            ValueError: If there is no training data (``val_start == 0``).
        """
        if self.val_start == 0:
            raise ValueError("Only the validation set is used.")
        labels = self.data.measurements[: self.val_start].flatten().long().numpy()
        weights = compute_class_weight(
            "balanced",
            # Recent scikit-learn versions validate `classes` as an ndarray, so a
            # bare `range` is rejected.
            classes=torch.arange(n_classes).numpy(),
            y=labels,
        )
        return torch.tensor(weights).float()

    def train_dataloader(self):
        """Shuffled loader over the training rows ``[0, val_start)``.

        Raises:
            ValueError: If every participant was assigned to validation.
        """
        if self.val_start == 0:
            raise ValueError("Only the validation set is used.")
        return DataLoader(
            Subset(self.data, range(0, self.val_start)),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.train_num_workers,
        )

    def val_dataloader(self):
        """Loader over the validation rows ``[val_start, len(data))``, in order."""
        return DataLoader(
            Subset(self.data, range(self.val_start, len(self.data))),
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.val_num_workers,
        )


# Load the UPDRS part 1/2 item scores and split them participant-wise,
# with 20 % of the participants used for validation.
UPDRS_QOL_PATH = "/workspaces/de.uke.iam.parkinson.vae_longitudinal/data/updrs_amp.csv"
data = UpdrsDataQoL(UPDRS_QOL_PATH)
data_module = UpdrsDataModule(
    dataset=data,
    percentage_subjects_in_valid_dataset=0.2,
    batch_size=512,
)
# Number of batches in the training and in the validation loader.
for loader in (data_module.train_dataloader(), data_module.val_dataloader()):
    print(len(loader))

## Fit the model

In [None]:
NAME = "updrs_qol_vae_new"
SEED = 42

# Seed all RNGs (weight init, dropout, loader shuffling) so the run is reproducible.
L.seed_everything(SEED, workers=True)

# Ordinal loss over the 26 UPDRS part 1/2 items with 5 score levels, using the
# balanced class weights of the training split.
reconstruction_loss = MultiCategoricalLoss(
    n_values=len(UpdrsDataQoL.COLUMNS),
    n_classes=5,
    is_categorical=False,
    is_ordinal=True,
    weight=data_module.calculate_class_weights().to("cuda"),
)
# Alternative generative loss tried before (not imported here): KullbackLeiblerLoss(beta=1.0).
generative_loss = WassersteinLoss(reg_weight=100, kernel_type="imq", z_var=2.0)
model = VariationAutoencoderModule(
    reconstruction_loss,
    generative_loss,
    [64, 48, 32, 16],
    patience=80,
    learning_rate=1e-3,
    dropout=0.05,
)
# Keep the single checkpoint with the highest validation concordance.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename=NAME,
    save_top_k=1,
    verbose=True,
    monitor="val_concordance",
    mode="max",
)
early_stopping = EarlyStopping(monitor="val_concordance", patience=120, mode="max")
logger = TensorBoardLogger("logs", name=NAME)

# Initialize the PyTorch Lightning trainer. The checkpoint callback must be
# registered here; otherwise only Lightning's default checkpoint is written.
trainer = L.Trainer(
    max_epochs=1000,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
    log_every_n_steps=6,
)

trainer.fit(model, data_module)

## Test the model

In [None]:
# NOTE(review): this checkpoint comes from a run logged under "updrs_qol_vae",
# not NAME = "updrs_qol_vae_new", so it is not the model trained above.
# Confirm this is intended.
CHECKPOINT_PATH = (
    "/workspaces/de.uke.iam.parkinson.vae_longitudinal/src/logs/updrs_qol_vae/"
    "version_1/checkpoints/epoch=431-step=2592.ckpt"
)
lightning_module = VariationAutoencoderModule.load_from_checkpoint(CHECKPOINT_PATH)
# Keep only the wrapped network and put it into evaluation mode (e.g. dropout off).
model = lightning_module.model.eval()

In [5]:
import torchmetrics


def load_testdata(path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the external test cohort and split it into UPDRS and PDQ-39 tables.

    The German "no information" / "not performed" markers (``Keine_Angabe``,
    ``Keine_Angaben``, ``Keine_angabe``, ``Nicht_durchgeführt``) are parsed as
    missing. Every row with a missing value in any selected column is dropped.

    Args:
        path: Path to a comma-separated file.

    Returns:
        ``(updrs, pdq)``:
        - ``updrs``: the ``UPDRS 1.*`` and ``UPDRS 2.*`` columns of the complete rows.
        - ``pdq``: the ``PDQ39 *`` columns of the complete rows.
        Both keep the file's column order and share a fresh ``0..n-1`` index.
    """
    missing_markers = [
        "Keine_Angabe",
        "Nicht_durchgeführt",
        "Keine_Angaben",
        "Keine_angabe",
    ]
    raw = pd.read_csv(path, sep=",", na_values=missing_markers)

    pdq_cols = [name for name in raw.columns if name.startswith("PDQ39 ")]
    updrs_cols = [
        name for name in raw.columns if name.startswith(("UPDRS 1.", "UPDRS 2."))
    ]

    complete = raw[pdq_cols + updrs_cols].dropna().reset_index(drop=True)
    return complete[updrs_cols], complete[pdq_cols]


test_updrs, test_pdq = load_testdata(
    "/workspaces/de.uke.iam.parkinson.vae_longitudinal/data/pdq_uke_new.csv"
)

# Evaluate on the complete validation split, not just its first batch of at most
# `batch_size` rows. The validation loader does not shuffle, so the row order
# matches the split.
ground_truth = torch.cat(list(data_module.val_dataloader()))
ground_truth_cuda = ground_truth.to("cuda")

# Run the network once and reuse its output for every metric, so all metrics
# describe the same reconstruction. No autograd graph is needed for evaluation.
with torch.no_grad():
    reconstruction = model(ground_truth_cuda).x_recon

n_items = ground_truth.shape[-1]
ground_truth_sum = ground_truth_cuda.sum(axis=-1)
reconstruction_sum = reconstruction.sum(axis=-1)

# Mean absolute error per item, averaged over rows.
print(
    F.l1_loss(input=reconstruction, target=ground_truth_cuda, reduction="none")
    .sum(axis=-1)
    .mean()
    / n_items
)
# Median (over rows) fraction of items whose reconstruction differs from the input.
# NOTE(review): exact float comparison; only meaningful if x_recon holds discrete
# scores. Confirm.
print(((ground_truth_cuda != reconstruction).sum(axis=-1) / n_items).median())
# Concordance correlation between the true and the reconstructed total scores.
print(
    torchmetrics.functional.concordance_corrcoef(
        preds=reconstruction_sum.to(torch.float),
        target=ground_truth_sum.to(torch.float),
    )
)