In [None]:
from dataclasses import dataclass

import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader, Subset
from coral_pytorch.losses import corn_loss
from coral_pytorch.dataset import corn_label_from_logits

@dataclass(slots=True)
class MultiCategoricalLoss:
    n_values: int
    n_classes: int
    is_categorical: bool
    is_ordinal: bool
    weight: torch.Tensor = None

    @property
    def last_layer_size(self) -> int:
        if self.is_categorical and self.is_ordinal:
            raise ValueError("Cannot be both categorical and ordinal.")
        
        if self.is_categorical:
            return self.n_values * self.n_classes
        elif self.is_ordinal:
            return self.n_values * (self.n_classes - 1)
        else:
            return self.n_values

    def calculate_reconstruction(self, y_raw: torch.Tensor) -> torch.LongTensor:
        if self.is_categorical:
            return y_raw.view((-1, self.n_values, self.n_classes)).argmax(dim=-1)
        elif self.is_ordinal:
            return corn_label_from_logits(y_raw.view(-1, self.n_classes - 1)).view((-1, self.n_values))
        else:
            return (F.sigmoid(y_raw) * (self.n_classes - 1)).round().long()

    def calculate_reconstruction_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        if self.is_categorical:
            y_pred = y_pred.view((-1, self.n_values, self.n_classes))
            return F.cross_entropy(
                input=y_pred.view(-1, self.n_classes),
                target=y_true.view(-1).long(),
                weight=self.weight,
                reduction="mean",
            )
        elif self.is_ordinal:
            y_pred = y_pred.view((-1, self.n_values, self.n_classes - 1))
            return corn_loss(
                logits=y_pred.view(-1, self.n_classes - 1),
                y_train=y_true.view(-1).long(),
                num_classes=self.n_classes,
            )
        else:
            return F.mse_loss(
                input=F.sigmoid(y_pred),
                target=y_true / (self.n_classes - 1),
                reduction="none",
            ).sum(axis=-1).mean()

@dataclass(slots=True)
class VariationalAutoencoderOutput:
    z_dist: torch.distributions.Distribution
    z_sample: torch.Tensor
    x_recon: torch.LongTensor

    loss: torch.Tensor
    loss_recon: torch.Tensor
    loss_kl: torch.Tensor

    def log(self, logger, prefix: str):
        logger.log(f"{prefix}_loss", self.loss)
        logger.log(f"{prefix}_loss_recon", self.loss_recon)
        logger.log(f"{prefix}_loss_kl", self.loss_kl)

    def calculate_total_error(self, batch):
        real_sum = batch.sum(axis=-1)
        reconstructed_sum = self.x_recon.sum(axis=-1)
        return F.l1_loss(input=reconstructed_sum, target=real_sum, reduction="mean")


class VariationalAutoencoder(nn.Module):
    """Fully connected variational autoencoder with a diagonal Gaussian posterior.

    The encoder maps `loss.n_values` inputs to `2 * latent_dim` numbers (mean
    and raw scale); the decoder maps a latent sample to `loss.last_layer_size`
    raw outputs, which `loss` turns into reconstructions and a loss value.

    Args:
        loss: output-head configuration; must provide `n_values`,
            `last_layer_size`, `calculate_reconstruction` and
            `calculate_reconstruction_loss`.
        hidden_dim: width of the widest hidden layer.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, loss: "MultiCategoricalLoss", hidden_dim: int, latent_dim: int):
        super().__init__()
        self.loss = loss
        self.encoder = nn.Sequential(
            nn.Linear(loss.n_values, loss.n_values),
            nn.SiLU(),
            nn.Linear(loss.n_values, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            # Two halves: posterior mean and the pre-softplus scale.
            nn.Linear(hidden_dim // 2, 2 * latent_dim),
        )
        self.softplus = nn.Softplus()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, loss.last_layer_size),
            nn.SiLU(),
            nn.Linear(loss.last_layer_size, loss.last_layer_size),
        )

    def encode(self, x, eps: float = 1e-8):
        """Encode `x` into the approximate posterior.

        Args:
            x: input tensor whose last dimension has `loss.n_values` entries.
            eps: small constant added to the scale to keep it strictly positive.

        Returns:
            A `MultivariateNormal` whose mean is the first half of the encoder
            output and whose diagonal `scale_tril` is ``softplus(second half) + eps``.
        """
        mu, raw_scale = torch.chunk(self.encoder(x), 2, dim=-1)
        # The softplus output is used as scale_tril, i.e. as a standard
        # deviation, not as a (log-)variance.
        scale = self.softplus(raw_scale) + eps
        return torch.distributions.MultivariateNormal(
            mu, scale_tril=torch.diag_embed(scale)
        )

    def reparameterize(self, dist):
        """Draw a differentiable sample from `dist` (reparameterisation trick)."""
        return dist.rsample()

    def decode(self, z):
        """Map latent samples to raw decoder outputs."""
        return self.decoder(z)

    @staticmethod
    def _standard_normal_prior(z):
        """Standard normal prior matching the shape, dtype and device of `z`."""
        latent_dim = z.shape[-1]
        identity = torch.eye(latent_dim, device=z.device, dtype=z.dtype)
        return torch.distributions.MultivariateNormal(
            torch.zeros_like(z),
            # Expand over every leading batch dimension, not only the first.
            scale_tril=identity.expand(*z.shape[:-1], latent_dim, latent_dim),
        )

    def forward(self, x, compute_loss: bool = True, weight: torch.Tensor = None):
        """Encode, sample, decode and optionally compute the loss.

        Args:
            x: input batch.
            compute_loss: when False, all three loss fields of the result are None.
            weight: accepted for interface compatibility; not used here.

        Returns:
            A `VariationalAutoencoderOutput`. The total loss is the
            reconstruction loss plus the batch mean of the KL divergence
            between the posterior and a standard normal prior.
        """
        dist = self.encode(x)
        z = self.reparameterize(dist)
        recon_x = self.decode(z)
        x_recon = self.loss.calculate_reconstruction(recon_x)

        if not compute_loss:
            return VariationalAutoencoderOutput(
                z_dist=dist,
                z_sample=z,
                x_recon=x_recon,
                loss=None,
                loss_recon=None,
                loss_kl=None,
            )

        loss_recon = self.loss.calculate_reconstruction_loss(y_true=x, y_pred=recon_x)
        loss_kl = torch.distributions.kl.kl_divergence(
            dist, self._standard_normal_prior(z)
        ).mean()

        return VariationalAutoencoderOutput(
            z_dist=dist,
            z_sample=z,
            x_recon=x_recon,
            loss=loss_recon + loss_kl,
            loss_recon=loss_recon,
            loss_kl=loss_kl,
        )


class VariationAutoencoderModule(L.LightningModule):
    """Lightning wrapper that trains a `VariationalAutoencoder`.

    Args:
        loss: output-head configuration passed on to the autoencoder.
        learning_rate: learning rate of the AdamW optimizer.
        patience: patience of the `ReduceLROnPlateau` scheduler that watches
            ``val_loss``.
        latent_dim: dimensionality of the latent space.
        hidden_dim: width of the widest hidden layer (default 24, the value
            that used to be hard-coded).
    """

    def __init__(
        self,
        loss: MultiCategoricalLoss,
        learning_rate: float = 1e-4,
        patience: int = 10,
        latent_dim: int = 8,
        hidden_dim: int = 24,
    ):
        super().__init__()
        self.model = VariationalAutoencoder(loss, hidden_dim, latent_dim)
        self.learning_rate = learning_rate
        self.patience = patience
        self.save_hyperparameters()

    def _shared_step(self, batch, stage: str):
        """Run the model on `batch`, log all metrics under `stage`, return the loss."""
        output = self.model(batch, compute_loss=True)
        output.log(self, stage)
        self.log(f"{stage}_l1_error", output.calculate_total_error(batch))
        return output.loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """AdamW plus a plateau scheduler driven by the logged ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = {
            "scheduler": ReduceLROnPlateau(
                optimizer, mode="min", patience=self.patience
            ),
            "monitor": "val_loss",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [None]:
class UpdrsData(Dataset):
    """In-memory dataset built from a tab-separated file.

    Rows are ordered by patient and age. Every column whose name starts with
    ``"3."`` is treated as a measurement; rows with a missing measurement are
    discarded. `covariates` keeps the metadata of the surviving rows, aligned
    row by row with `measurements`.
    """

    def __init__(self, path):
        super().__init__()
        table = pd.read_csv(path, sep="\t").sort_values(["PatientID", "Age"])

        item_columns = [name for name in table.columns if name.startswith("3.")]
        complete_items = table[item_columns].dropna().astype(int)

        covariate_columns = [
            "PatientID",
            "Age",
            "Deep brain stimulation available",
            "Deep brain stimulation",
            "Medication",
        ]
        # Select by the surviving index labels so both tables stay aligned,
        # then renumber the rows from zero.
        kept_rows = table.loc[complete_items.index, covariate_columns]
        self.covariates = kept_rows.reset_index(drop=True)
        self.measurements = torch.tensor(
            complete_items.to_numpy(), dtype=torch.float32
        )

    def __len__(self):
        """Number of rows with complete measurements."""
        return len(self.measurements)

    def __getitem__(self, index):
        """Return the measurement vector(s) at `index` as float32."""
        return self.measurements[index]


class UpdrsDataModule(L.LightningDataModule):
    """Patient-wise train/validation split over an `UpdrsData` table.

    The dataset is ordered by patient, so the split is a single row offset
    (`val_start`): rows before it are used for training, rows from it onwards
    for validation. No patient appears on both sides.

    Args:
        path: tab-separated input file, forwarded to `UpdrsData`.
        percentage_subjects_in_valid_dataset: fraction of patients (0 < f <= 1)
            placed in the validation set; 1 means validation only.
        batch_size: batch size of both data loaders.
    """

    def __init__(
        self, path: str, percentage_subjects_in_valid_dataset: float, batch_size: int
    ):
        super().__init__()
        assert 0 < percentage_subjects_in_valid_dataset <= 1

        self.data = UpdrsData(path)
        self.batch_size = batch_size
        self.val_start = self._find_validation_start(
            percentage_subjects_in_valid_dataset
        )

    def _find_validation_start(self, fraction: float) -> int:
        """Return the row offset of the first validation patient."""
        if fraction >= 1:
            return 0

        patient_ids = self.data.covariates["PatientID"]
        patients = patient_ids.unique()
        if len(patients) == 0:
            raise ValueError("The dataset contains no rows with complete measurements.")

        # Always keep at least one validation patient: truncating to zero would
        # index one past the end of `patients`.
        num_patients_valid = max(1, int(len(patients) * fraction))
        first_patient_valid = patients[len(patients) - num_patients_valid]
        # covariates has a fresh 0..n-1 index, so the position of the first
        # matching row is the split offset.
        return int((patient_ids == first_patient_valid).to_numpy().argmax())

    def calculate_class_weights(self, n_classes: int = 5):
        """Balanced class weights estimated on the training rows only.

        Args:
            n_classes: number of levels; weights are returned for the classes
                ``0 .. n_classes - 1``.

        Returns:
            Float tensor with one weight per class.

        Raises:
            ValueError: if there is no training split. scikit-learn itself
                raises a ValueError when a class never occurs in the
                training rows.
        """
        if self.val_start == 0:
            raise ValueError("Only the validation set is used.")

        labels = self.data.measurements[: self.val_start].flatten().long().numpy()
        weights = compute_class_weight(
            "balanced",
            # scikit-learn documents `classes` as an ndarray.
            classes=torch.arange(n_classes).numpy(),
            y=labels,
        )
        return torch.as_tensor(weights, dtype=torch.float32)

    def train_dataloader(self):
        """Shuffled loader over the training rows.

        Raises:
            ValueError: if the module was built as validation only.
        """
        if self.val_start == 0:
            raise ValueError("Only the validation set is used.")
        return DataLoader(
            Subset(self.data, range(0, self.val_start)),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=8,
        )

    def val_dataloader(self):
        """Ordered loader over the validation rows."""
        return DataLoader(
            Subset(self.data, range(self.val_start, len(self.data))),
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=4,
        )


# Training data: 30 % of the patients are held out for validation.
data_module = UpdrsDataModule(
    path="../data/ppmi_with_meta.csv",
    percentage_subjects_in_valid_dataset=0.3,
    batch_size=512,
)
# Number of batches per epoch in the training and validation loaders.
print(
    len(data_module.train_dataloader()),
    len(data_module.val_dataloader()),
    sep="\n",
)

## Fit the model

In [None]:
NAME = "vanilla_vae"

# Ordinal head: 33 items with 5 levels each.
loss = MultiCategoricalLoss(n_values=33, n_classes=5, is_categorical=False, is_ordinal=True)
model = VariationAutoencoderModule(loss, latent_dim=6)

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename=NAME,
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min",
)
early_stopping = EarlyStopping(monitor="val_loss", patience=20, mode="min")
logger = TensorBoardLogger("logs", name=NAME)

# Initialize the PyTorch Lightning trainer. The checkpoint callback has to be
# registered here, otherwise the best-model checkpoint is never written.
trainer = L.Trainer(
    max_epochs=1000,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
    log_every_n_steps=25,
)

trainer.fit(model, data_module)

## Test the model

In [None]:
# To evaluate a stored run instead of the freshly trained model, load it first:
#model = VariationAutoencoderModule.load_from_checkpoint("/workspaces/de.uke.iam.parkinson.vae_longitudinal/src/logs/vanilla_vae/version_4/checkpoints/epoch=324-step=8125.ckpt").model

# Unwrap the LightningModule to the bare autoencoder. getattr keeps the cell
# re-runnable: once `model` already is the autoencoder it has no `.model`.
model = getattr(model, "model", model)
model.eval()

# External cohort, used for validation only.
data_module = UpdrsDataModule(
    "../data/uke_with_meta.csv",
    percentage_subjects_in_valid_dataset=1.0,
    batch_size=32,
)

batch = next(iter(data_module.val_dataloader()))
device = next(model.parameters()).device
with torch.no_grad():
    # Only the reconstruction is needed, so skip the loss computation.
    output = model(batch.to(device), compute_loss=False)

ground_truth = batch[0].numpy().astype(int)
prediction = output.x_recon[0].cpu().numpy()

print(ground_truth)
print(prediction)