In [1]:
from dataclasses import dataclass
from typing import Tuple, Sequence

import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import Tensor

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader, Subset
from coral_pytorch.losses import corn_loss
from coral_pytorch.dataset import corn_label_from_logits
from torchmetrics.functional import concordance_corrcoef


@dataclass(slots=True)
class MultiCategoricalLoss:
    n_values: int
    n_classes: int
    is_categorical: bool
    is_ordinal: bool
    weight: torch.Tensor = None

    @property
    def last_layer_size(self) -> int:
        if self.is_categorical and self.is_ordinal:
            raise ValueError("Cannot be both categorical and ordinal.")

        if self.is_categorical:
            return self.n_values * self.n_classes
        elif self.is_ordinal:
            return self.n_values * (self.n_classes - 1)
        else:
            return self.n_values

    def calculate_reconstruction(self, y_raw: torch.Tensor) -> torch.LongTensor:
        if self.is_categorical:
            return y_raw.reshape((-1, self.n_values, self.n_classes)).argmax(dim=-1)
        elif self.is_ordinal:
            return corn_label_from_logits(
                y_raw.reshape(-1, self.n_classes - 1)
            ).reshape((-1, self.n_values))
        else:
            return (F.sigmoid(y_raw) * (self.n_classes - 1)).round().long()

    def calculate_reconstruction_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        if self.is_categorical:
            y_pred = y_pred.reshape((-1, self.n_values, self.n_classes))
            return F.cross_entropy(
                input=y_pred.reshape(-1, self.n_classes),
                target=y_true.reshape(-1).long(),
                weight=self.weight,
                reduction="mean",
            )
        elif self.is_ordinal:
            y_pred = y_pred.reshape((-1, self.n_values, self.n_classes - 1))
            return corn_loss(
                logits=y_pred.reshape(-1, self.n_classes - 1),
                y_train=y_true.reshape(-1).long(),
                num_classes=self.n_classes,
            )
        else:
            return (
                F.mse_loss(
                    input=F.sigmoid(y_pred),
                    target=y_true / (self.n_classes - 1),
                    reduction="none",
                )
                .sum(axis=-1)
                .mean()
            )


@dataclass(slots=True)
class VariationalAutoencoderOutput:
    z_dist: torch.distributions.Distribution
    z_sample: torch.Tensor
    x_recon: torch.LongTensor

    loss: torch.Tensor
    loss_recon: torch.Tensor
    loss_generative: torch.Tensor

    def log(self, logger, batch, prefix: str):
        logger.log(f"{prefix}_loss", self.loss)
        logger.log(f"{prefix}_loss_recon", self.loss_recon)
        logger.log(f"{prefix}_loss_generative", self.loss_generative)

        # Log evaluation metrices
        logger.log(
            f"{prefix}_rel_wrong_items",
            ((batch != self.x_recon).sum(axis=-1) / batch.shape[-1]).median(),
        )
        real_sum = batch.sum(axis=-1)
        reconstructed_sum = self.x_recon.sum(axis=-1)
        logger.log(
            f"{prefix}_concordance",
            concordance_corrcoef(
                target=real_sum.to(torch.float), preds=reconstructed_sum.to(torch.float)
            ),
        )

    def calculate_total_error(self, batch):
        real_sum = batch.sum(axis=-1)
        reconstructed_sum = self.x_recon.sum(axis=-1)
        return F.l1_loss(input=reconstructed_sum, target=real_sum, reduction="mean")


class GenerativeLoss:
    """Interface for the latent-space regulariser of the VAE.

    Implementations are called with the encoder's latent distribution and a sample
    drawn from it, and return a scalar tensor.
    """

    def __call__(self, dist: torch.distributions.Distribution, z: Tensor):
        # Abstract hook: concrete regularisers must override this.
        raise NotImplementedError


class KullbackLeiblerLoss(GenerativeLoss):
    """``beta``-weighted KL divergence from the latent distribution to N(0, I).

    Assumes ``dist`` is a batched ``MultivariateNormal`` whose ``mean`` has shape
    (batch, latent_dim) -- TODO confirm for other callers. The per-sample KL is
    averaged over the batch.
    """

    def __init__(self, beta: float = 1.0):
        self.beta = beta  # multiplier of the KL term

    def __call__(self, dist: torch.distributions.Distribution, z: Tensor):
        # ``z`` is not needed: the divergence is evaluated from ``dist`` directly.
        loc = dist.mean
        n_batch, latent_dim = loc.shape[0], loc.shape[-1]
        # Identity Cholesky factor, one copy per batch element.
        identity = torch.eye(latent_dim, device=loc.device).expand(
            n_batch, latent_dim, latent_dim
        )
        prior = torch.distributions.MultivariateNormal(
            torch.zeros_like(loc), scale_tril=identity
        )
        per_sample_kl = torch.distributions.kl.kl_divergence(dist, prior)
        return self.beta * per_sample_kl.mean()


class WassersteinLoss(GenerativeLoss):
    """MMD latent regulariser of a Wasserstein autoencoder (WAE-MMD).

    Compares the batch of latent samples ``z`` with an equally sized batch ``p``
    drawn from a standard normal, using the unbiased MMD estimate

        reg_weight / (N (N - 1)) * sum_{i != j} [k(p_i, p_j) + k(z_i, z_j) - 2 k(p_i, z_j)].

    Both kernel methods return the kernel already summed over all pairs with
    ``i != j``, so ``"rbf"`` and ``"imq"`` are normalised identically.

    Args:
        reg_weight: multiplier applied to the MMD estimate.
        kernel_type: ``"rbf"`` or ``"imq"`` (inverse multiquadratic).
        z_var: kernel scale hyperparameter; both kernels use ``C = 2 * z_dim * z_var``.
            Note the prior samples themselves are always standard normal.
    """

    def __init__(self, reg_weight: float, kernel_type: str, z_var: float):
        self.reg_weight = reg_weight
        self.kernel_type = kernel_type
        self.z_var = z_var

    def __call__(self, dist: torch.distributions.Distribution, z: Tensor):
        """Return the weighted MMD between ``z`` (N, D) and prior samples; ``dist`` is unused."""
        batch_size = z.size(0)
        if batch_size < 2:
            # The unbiased estimator needs N * (N - 1) > 0; a trailing batch of one
            # sample would otherwise raise ZeroDivisionError.
            return z.new_zeros(())

        # reg_weight absorbs the 1 / (N (N - 1)) normalisation of the off-diagonal sums.
        bias_corr = batch_size * (batch_size - 1)
        reg_weight = self.reg_weight / bias_corr

        # Sample from prior (standard Gaussian) distribution
        prior_z = torch.randn_like(z)

        prior_z__kernel = self.compute_kernel(prior_z, prior_z)
        z__kernel = self.compute_kernel(z, z)
        priorz_z__kernel = self.compute_kernel(prior_z, z)

        mmd = (
            reg_weight * prior_z__kernel
            + reg_weight * z__kernel
            - 2 * reg_weight * priorz_z__kernel
        )
        return mmd

    def compute_kernel(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return the kernel between rows of ``x1`` and ``x2`` (both (N, D)), summed over i != j."""
        D = x1.size(1)
        N = x1.size(0)

        x1 = x1.unsqueeze(-2).expand(N, N, D)  # entry [i, j] holds x1[i]
        x2 = x2.unsqueeze(-3).expand(N, N, D)  # entry [i, j] holds x2[j]

        if self.kernel_type == "rbf":
            result = self.compute_rbf(x1, x2)
        elif self.kernel_type == "imq":
            result = self.compute_inv_mult_quad(x1, x2)
        else:
            raise ValueError("Undefined kernel type.")

        return result

    @staticmethod
    def _sum_off_diagonal(kernel: Tensor) -> Tensor:
        """Sum all entries of an (N, N) kernel matrix except the i == j ones."""
        return kernel.sum() - kernel.diag().sum()

    def compute_rbf(self, x1: Tensor, x2: Tensor, eps: float = 1e-7) -> Tensor:
        """Gaussian kernel exp(-||x1 - x2||^2 / C), summed over i != j.

        ``eps`` is unused and kept for signature compatibility.
        """
        z_dim = x2.size(-1)
        sigma = 2.0 * z_dim * self.z_var
        # Squared Euclidean distance summed over latent dims (as in the IMQ kernel);
        # sigma already scales with z_dim, so averaging would shrink distances twice.
        kernel = torch.exp(-(x1 - x2).pow(2).sum(dim=-1) / sigma)
        return self._sum_off_diagonal(kernel)

    def compute_inv_mult_quad(
        self, x1: Tensor, x2: Tensor, eps: float = 1e-7
    ) -> Tensor:
        """Inverse multiquadratic kernel C / (C + ||x1 - x2||^2), summed over i != j."""
        z_dim = x2.size(-1)
        C = 2 * z_dim * self.z_var
        kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))
        return self._sum_off_diagonal(kernel)


class Encoder(nn.Module):
    """MLP encoder producing a diagonal Gaussian over the latent space.

    ``layers`` lists the widths from the input size to the latent size. Each hidden
    transition is Linear -> Dropout (only if ``dropout > 0``) -> SiLU; the final
    Linear emits ``2 * layers[-1]`` features, split into location and raw scale.
    """

    def __init__(self, layers: Sequence[int], dropout: float = 0.1):
        super().__init__()
        self.encoder = nn.Sequential()
        final_step = len(layers) - 2
        for step in range(len(layers) - 1):
            fan_in, fan_out = layers[step], layers[step + 1]
            if step == final_step:
                # Twice the latent width: first half location, second half raw scale.
                self.encoder.append(nn.Linear(fan_in, 2 * fan_out))
                break
            self.encoder.append(nn.Linear(fan_in, fan_out))
            if dropout > 0.0:
                self.encoder.append(nn.Dropout(dropout))
            self.encoder.append(nn.SiLU())
        self.softplus = nn.Softplus()

    def forward(self, x, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, scale)``: latent location and a strictly positive scale.

        The second half of the final layer goes through softplus plus ``eps``. It is
        used as the diagonal of the Cholesky factor (a standard deviation), not as a
        log-variance.
        """
        mu, raw_scale = self.encoder(x).chunk(2, dim=-1)
        return mu, self.softplus(raw_scale) + eps

    def forward_and_reparameterize(
        self, x
    ) -> Tuple[torch.distributions.Distribution, torch.Tensor]:
        """Encode ``x`` and draw one reparameterised (differentiable) latent sample."""
        location, scale = self(x)
        posterior = torch.distributions.MultivariateNormal(
            location, scale_tril=torch.diag_embed(scale)
        )
        return posterior, posterior.rsample()


class Decoder(nn.Module):
    """MLP mapping a latent vector to the raw reconstruction outputs.

    ``layers`` lists the widths from the latent size to the output size. Every
    transition except the last is Linear -> Dropout (only if ``dropout > 0``) -> SiLU;
    the last one is a bare Linear, so the outputs are unconstrained.
    """

    def __init__(self, layers: Sequence[int], dropout: float = 0.1):
        super().__init__()
        self.decoder = nn.Sequential()
        last_step = len(layers) - 2
        for step, (fan_in, fan_out) in enumerate(zip(layers[:-1], layers[1:])):
            self.decoder.append(nn.Linear(fan_in, fan_out))
            if step == last_step:
                continue
            if dropout > 0.0:
                self.decoder.append(nn.Dropout(dropout))
            self.decoder.append(nn.SiLU())

    def forward(self, z):
        """Decode latent vectors ``z`` into raw outputs."""
        return self.decoder(z)


class VariationalAutoencoder(nn.Module):
    """Encoder/decoder pair with a pluggable reconstruction loss and latent regulariser.

    The encoder maps ``reconstruction_loss.n_values`` inputs through
    ``intermediate_layers`` (last entry = latent size). The decoder mirrors those
    widths and ends in ``reconstruction_loss.last_layer_size`` raw outputs.

    Note: ``reconstruction_loss`` is not an ``nn.Module``, so ``.to(device)`` does
    not move its ``weight`` tensor.
    """

    def __init__(
        self,
        reconstruction_loss: MultiCategoricalLoss,
        generative_loss: GenerativeLoss,
        intermediate_layers: Sequence[int],
        dropout: float
    ):
        super().__init__()
        self.reconstruction_loss = reconstruction_loss
        self.generative_loss = generative_loss
        # Copy into a list so any Sequence (tuple, range, ...) can be concatenated.
        hidden = list(intermediate_layers)
        self.encoder = Encoder([reconstruction_loss.n_values] + hidden, dropout=dropout)
        self.decoder = Decoder(
            hidden[::-1] + [reconstruction_loss.last_layer_size], dropout=dropout
        )

    def forward(self, x, compute_loss: bool = True):
        """Encode, sample and decode ``x``; optionally compute the losses.

        With ``compute_loss=False`` the three loss fields of the output are ``None``.
        """
        dist, z = self.encoder.forward_and_reparameterize(x)
        recon_x = self.decoder(z)
        x_recon = self.reconstruction_loss.calculate_reconstruction(recon_x)

        if not compute_loss:
            return VariationalAutoencoderOutput(
                z_dist=dist,
                z_sample=z,
                x_recon=x_recon,
                loss=None,
                loss_recon=None,
                loss_generative=None,
            )

        loss_recon = self.reconstruction_loss.calculate_reconstruction_loss(
            y_true=x, y_pred=recon_x
        )
        loss_generative = self.generative_loss(dist, z)

        return VariationalAutoencoderOutput(
            z_dist=dist,
            z_sample=z,
            x_recon=x_recon,
            loss=loss_recon + loss_generative,
            loss_recon=loss_recon,
            loss_generative=loss_generative,
        )


class VariationAutoencoderModule(L.LightningModule):
    """Lightning wrapper that trains a ``VariationalAutoencoder`` with AdamW.

    The learning rate is reduced on plateaus of the logged ``val_concordance``
    (maximised), with ``patience`` epochs of tolerance and a floor of 1e-6.
    """

    def __init__(
        self,
        reconstruction_loss: MultiCategoricalLoss,
        generative_loss: GenerativeLoss,
        intermediate_layers: Sequence[int],
        learning_rate: float,
        patience: int,
        dropout: float
    ):
        super().__init__()
        # The attribute name "model" determines the checkpoint state_dict keys.
        self.model = VariationalAutoencoder(
            reconstruction_loss, generative_loss, intermediate_layers, dropout=dropout
        )
        self.learning_rate = learning_rate
        self.patience = patience
        self.dropout = dropout
        # Records the __init__ arguments so load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()

    def _shared_step(self, batch, prefix: str):
        """Run the model with losses, log its metrics under ``prefix`` and return the loss."""
        result = self.model(batch, compute_loss=True)
        result.log(self, batch, prefix)
        return result.loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        plateau = ReduceLROnPlateau(
            optimizer, mode="max", patience=self.patience, min_lr=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": plateau, "monitor": "val_concordance"},
        }

In [None]:
class UpdrsData(Dataset):
    """Tab-separated item table exposed as a dataset of item-level vectors.

    Rows are sorted by ``PatientID`` and ``Age``. The items are the columns whose
    names start with ``"3."``; rows missing any of them are dropped. ``covariates``
    holds the matching patient/visit columns with a fresh 0..n-1 index, aligned
    row-for-row with ``measurements`` (a float32 tensor).
    """

    def __init__(self, path):
        super().__init__()
        frame = pd.read_csv(path, sep="\t").sort_values(["PatientID", "Age"])
        item_mask = [name.startswith("3.") for name in frame.columns]
        items = frame.loc[:, item_mask].dropna().astype(int)
        covariate_columns = [
            "PatientID",
            "Age",
            "Deep brain stimulation available",
            "Deep brain stimulation",
            "Medication",
        ]
        kept_rows = items.index
        self.covariates = frame.loc[kept_rows, covariate_columns].reset_index(drop=True)
        self.measurements = torch.tensor(items.to_numpy(), dtype=torch.float32)

    def __getitem__(self, index):
        """Return the item vector of row ``index``."""
        return self.measurements[index]

    def __len__(self):
        return self.measurements.size(0)

    @property
    def participant_covariate(self) -> str:
        """Name of the covariate column that identifies a participant."""
        return "PatientID"


class UpdrsDataQoL(Dataset):
    """Comma-separated table of ``UPDRS 1.x`` / ``UPDRS 2.x`` item scores as a dataset.

    Rows are sorted by ``Participant`` and ``Age``; rows missing any of ``COLUMNS``
    are dropped. ``covariates`` keeps participant and age of the retained rows with a
    fresh 0..n-1 index, aligned row-for-row with ``measurements`` (float32 tensor).
    """

    # "UPDRS 1.1" .. "UPDRS 1.13", then "UPDRS 2.1" .. "UPDRS 2.13".
    COLUMNS = [f"UPDRS {part}.{item}" for part in (1, 2) for item in range(1, 14)]

    def __init__(self, path):
        super().__init__()
        table = pd.read_csv(path, sep=",").sort_values(["Participant", "Age"])
        scores = table.loc[:, UpdrsDataQoL.COLUMNS].dropna().astype(int)
        kept_rows = scores.index
        self.covariates = table.loc[kept_rows, ["Participant", "Age"]].reset_index(
            drop=True
        )
        self.measurements = torch.tensor(scores.to_numpy(), dtype=torch.float32)

    def __getitem__(self, index):
        """Return the item vector of row ``index``."""
        return self.measurements[index]

    def __len__(self):
        return self.measurements.size(0)

    @property
    def participant_covariate(self) -> str:
        """Name of the covariate column that identifies a participant."""
        return "Participant"


class UpdrsDataModule(L.LightningDataModule):
    """Participant-wise train/validation split of an item-score dataset.

    ``dataset`` must expose ``covariates`` (a DataFrame with a 0..n-1 index aligned
    with its rows), ``measurements`` (a tensor of item levels) and
    ``participant_covariate`` (the participant-ID column name). Each participant's
    rows must be contiguous; UpdrsData / UpdrsDataQoL sort by participant and age.
    The last ``percentage_subjects_in_valid_dataset`` fraction of participants (in row
    order, at least one) forms the validation set and all earlier rows the training
    set. A value of 1 uses the whole dataset for validation.

    Raises:
        ValueError: if the percentage is outside (0, 1] or the dataset has no
            participants.
    """

    def __init__(
        self,
        dataset: Dataset,
        percentage_subjects_in_valid_dataset: float,
        batch_size: int,
        num_workers: int = 8,
        val_num_workers: int = 4,
    ):
        super().__init__()
        if not 0 < percentage_subjects_in_valid_dataset <= 1:
            raise ValueError(
                "percentage_subjects_in_valid_dataset must be in (0, 1], got "
                f"{percentage_subjects_in_valid_dataset}."
            )

        self.data = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_num_workers = val_num_workers

        if percentage_subjects_in_valid_dataset < 1:
            participant_ids = self.data.covariates[dataset.participant_covariate]
            patients = participant_ids.unique()
            if len(patients) == 0:
                raise ValueError("The dataset contains no participants.")
            # int() truncates to 0 for small cohorts, which would index one past the
            # end below; always keep at least one participant for validation.
            num_patients_valid = max(
                1, int(len(patients) * percentage_subjects_in_valid_dataset)
            )
            first_patient_valid = patients[len(patients) - num_patients_valid]
            # Row index of the first row of the first validation participant.
            self.val_start = participant_ids[
                participant_ids == first_patient_valid
            ].index[0]
        else:
            self.val_start = 0

    def calculate_class_weights(self, n_classes: int = 5):
        """Return "balanced" class weights (float tensor, length ``n_classes``) from training rows.

        Raises:
            ValueError: if there are no training rows. scikit-learn also raises if one
                of the levels 0..n_classes-1 never occurs in the training data.
        """
        if self.val_start == 0:
            raise ValueError("Only the validation set is used.")
        train_levels = self.data.measurements[: self.val_start].flatten().long().numpy()
        return torch.tensor(
            compute_class_weight(
                "balanced",
                # ``classes`` is documented as an ndarray (newer releases validate it).
                classes=torch.arange(n_classes).numpy(),
                y=train_levels,
            )
        ).float()

    def train_dataloader(self):
        """Shuffled loader over rows ``[0, val_start)``."""
        if self.val_start == 0:
            raise ValueError("Only the validation set is used.")
        return DataLoader(
            Subset(self.data, range(0, self.val_start)),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Unshuffled loader over rows ``[val_start, len(data))``."""
        return DataLoader(
            Subset(self.data, range(self.val_start, len(self.data))),
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.val_num_workers,
        )


# Load the UPDRS 1.x / 2.x item table and hold out 20 % of the participants
# (the last ones in sorted order) for validation.
data = UpdrsDataQoL("/workspaces/de.uke.iam.parkinson.vae_longitudinal/data/pdq_amp.csv")
data_module = UpdrsDataModule(
    data,
    percentage_subjects_in_valid_dataset=0.2,
    batch_size=512,
)
# Report the number of batches in the training and the validation split.
for make_loader in (data_module.train_dataloader, data_module.val_dataloader):
    print(len(make_loader()))

## Fit the model

In [None]:
NAME = "updrs_qol_vae_new"

# Ordinal (CORN) reconstruction of all UpdrsDataQoL items with 5 levels each.
reconstruction_loss = MultiCategoricalLoss(
    n_values=len(UpdrsDataQoL.COLUMNS),
    n_classes=5,
    is_categorical=False,
    is_ordinal=True,
    # NOTE: class weights are only consumed by the categorical loss; with
    # is_ordinal=True they are carried along but unused.
    weight=data_module.calculate_class_weights().to("cuda"),
)
# Alternative latent regulariser: generative_loss = KullbackLeiblerLoss(beta=1.0)
generative_loss = WassersteinLoss(reg_weight=100, kernel_type="imq", z_var=2.0)
model = VariationAutoencoderModule(
    reconstruction_loss,
    generative_loss,
    [64, 48, 32, 16],
    patience=80,
    learning_rate=1e-3,
    dropout=0.05,
)
# Keep the single best checkpoint by validation concordance.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename=NAME,
    save_top_k=1,
    verbose=True,
    monitor="val_concordance",
    mode="max",
)
early_stopping = EarlyStopping(monitor="val_concordance", patience=120, mode="max")
logger = TensorBoardLogger("logs", name=NAME)

# Initialize the PyTorch Lightning trainer. The checkpoint callback must be
# registered here, otherwise it never runs and the best model is not saved.
trainer = L.Trainer(
    max_epochs=1000,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
    log_every_n_steps=6,
)

trainer.fit(model, data_module)

## Test the model

In [None]:
CHECKPOINT_PATH = (
    "/workspaces/de.uke.iam.parkinson.vae_longitudinal/src/logs/updrs_qol_vae/"
    "version_1/checkpoints/epoch=431-step=2592.ckpt"
)
# Keep only the wrapped VariationalAutoencoder, switched to eval mode (dropout off).
# eval() does not make inference deterministic: the encoder still samples z.
model = VariationAutoencoderModule.load_from_checkpoint(CHECKPOINT_PATH).model.eval()

In [5]:
import torchmetrics


def load_testdata(path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load an external test table and return its ``(updrs, pdq)`` item columns.

    UPDRS columns are those starting with ``"UPDRS 1."`` or ``"UPDRS 2."``; PDQ
    columns start with ``"PDQ39 "``. The textual missing-value markers listed below
    are parsed as NaN. Every row with a missing value in any selected column is
    dropped, so both frames share the same rows (index reset to 0..n-1).
    """
    data = pd.read_csv(
        path,
        sep=",",
        na_values=[
            "Keine_Angabe",  # "not specified"
            "Nicht_durchgeführt",  # "not performed"
            # Mis-encoded spelling of the marker above (UTF-8 bytes read as Mac
            # Roman); kept so inputs containing it are still treated as missing.
            "Nicht_durchgef√ºhrt",
            "Keine_Angaben",
            "Keine_angabe",
        ],
    )
    pdq_columns = [column for column in data.columns if column.startswith("PDQ39 ")]
    updrs_columns = [
        column
        for column in data.columns
        if column.startswith(("UPDRS 1.", "UPDRS 2."))
    ]
    data = data[pdq_columns + updrs_columns].dropna().reset_index(drop=True)
    return data[updrs_columns], data[pdq_columns]


test_updrs, test_pdq = load_testdata(
    "/workspaces/de.uke.iam.parkinson.vae_longitudinal/data/pdq_uke_new.csv"
)

# Evaluate on the first validation batch only (the validation loader is not shuffled).
ground_truth = next(iter(data_module.val_dataloader()))
ground_truth_cuda = ground_truth.to("cuda")
with torch.no_grad():
    # Run the model ONCE: the encoder draws z via rsample, so each forward pass
    # yields a different reconstruction. Sharing one pass keeps the three
    # metrics below consistent with each other.
    reconstruction = model(ground_truth_cuda, compute_loss=False).x_recon
n_items = ground_truth.shape[-1]
ground_truth_sum = ground_truth_cuda.sum(axis=-1)
reconstruction_sum = reconstruction.sum(axis=-1)

# Mean absolute error per item, averaged over samples.
print(
    F.l1_loss(
        input=ground_truth_cuda,
        target=reconstruction.to(torch.float),
        reduction="none",
    )
    .sum(axis=-1)
    .mean()
    / n_items
)
# Median (over samples) fraction of items reconstructed incorrectly.
print(((ground_truth_cuda != reconstruction).sum(axis=-1) / n_items).median())
# Concordance between true and reconstructed sum scores.
print(
    torchmetrics.functional.concordance_corrcoef(
        target=ground_truth_sum.to(torch.float),
        preds=reconstruction_sum.to(torch.float),
    )
)