In [None]:
import numpy as np
import pandas as pd

import os
import random
import shutil
import sys

from collections import defaultdict
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchsummary import summary

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [None]:
# Put the parent directory on the import path; the project-local
# `utilities` package imported below is presumably located there.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
from utilities.data_downloader import train_val_test_downloader
from utilities.upsampling import upsampling
from utilities.plots import plt, COLORMAP, visualize_latent

In [None]:
from warnings import simplefilter
# Install a global filter that silences every RuntimeWarning from here on.
simplefilter("ignore", category=RuntimeWarning)

In [None]:
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed to every seeding function.
    """
    # cuDNN: deterministic algorithms on, benchmark mode off
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (torch.manual_seed, torch.cuda.manual_seed,
               np.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)

# Seed through the local helper (which also sets the cuDNN flags) and then
# through Lightning's own seeding utility, both with the same value.
set_random_seed(42)
L.seed_everything(42)

# Description

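This notebook trains a 1D-convolutional autoencoder (AE) and a variational autoencoder (VAE) on fixed-length light-curve time series and extracts a low-dimensional latent representation, together with reconstruction errors, for every event.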

# Dataset
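The train, validation and test splits are loaded with `train_val_test_downloader`; the training split is additionally upsampled. Each row provides a series (`lgRate`) and a matching series of per-point weights (`weight`), and the DataFrame index (event name) is used as the label.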

In [None]:
# Fetch the train/validation/test frames and `labels` through the project
# helper; the meaning of the 'interp' argument is defined in
# utilities.data_downloader (not visible in this notebook).
train, val, test, labels = train_val_test_downloader('interp')

In [None]:
# Upsampled version of the training split (logic lives in
# utilities.upsampling); `train` itself is reused further below for
# predictions on non-augmented data.
train_upsampled = upsampling(train)

In [None]:
class LightCurveDataset(Dataset):
    """Torch dataset serving ``(series, label, weight)`` triples.

    Args:
        dataframe: Frame whose index holds the event names; the cells of
            ``data_col`` and ``weight_col`` are presumably equal-length
            numeric sequences (they are stacked into one float32 array).
        data_col: Column holding the series values.
        weight_col: Column holding the per-point weights.
    """

    def __init__(self, dataframe:pd.DataFrame,
                 data_col:str='lgRate',
                 weight_col:str='weight'):

        def column_to_tensor(column):
            # stack the rows of one column into a float32 array and add
            # an axis at position 1
            stacked = np.array(dataframe.loc[:, column].tolist(),
                               dtype=np.float32)
            return torch.from_numpy(stacked).unsqueeze(dim=1)

        self.data = column_to_tensor(data_col)        # value
        self.weight = column_to_tensor(weight_col)    # weight

        # the dataframe index (= event names) serves as labels; the fitted
        # encoder is kept so integer codes can be mapped back to names
        self.label_enc = LabelEncoder()
        encoded = self.label_enc.fit_transform(dataframe.index)
        self.labels = torch.as_tensor(encoded)

    def __len__(self):
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(data, label, weight)`` for position ``idx``."""
        sample = (self.data[idx], self.labels[idx], self.weight[idx])
        return sample

In [None]:
# one dataset per split; the training one is built from the upsampled frame
train_dataset = LightCurveDataset(train_upsampled)
val_dataset = LightCurveDataset(val)
test_dataset = LightCurveDataset(test)

In [None]:
# Loader settings shared by every split, kept in one place so batch size
# and worker count cannot drift apart between the loaders.
BATCH_SIZE = 256
NUM_WORKERS = 2
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# only the (upsampled) training data is shuffled
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# for predictions on non-augmented train:
train_loader_ = DataLoader(LightCurveDataset(train), shuffle=False,
                           **_loader_kwargs)

# Model
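The model is a 1D-convolutional autoencoder: an encoder of strided convolutions compresses each series into a latent vector, and a decoder of upsampling + convolution stages reconstructs the series from it. In the variational variant (VAE) the encoder output is split into a mean and a log-variance half, and the decoder receives a sample drawn from that distribution.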

## Torch Models

In [None]:
class Encoder(nn.Module):
    """1D-convolutional encoder mapping a time series to a latent vector.

    The network is a stack of stride-2 ``Conv1d -> BatchNorm1d -> LeakyReLU``
    stages followed by ``Flatten`` and one ``Linear`` projection.

    Args:
        latent_dim: Size of the output vector.
        architecture: ``(base_filters, num_layers)``; layer ``i`` uses
            ``base_filters * 2**i`` filters.
        tseries_length: Length of the input series. It does not have to be
            a multiple of ``2**num_layers``.

    ``forward`` takes a ``(batch, 1, tseries_length)`` tensor and returns a
    ``(batch, latent_dim)`` tensor.
    """

    def __init__(self, latent_dim:int,
                 architecture:tuple=(32, 4),
                 tseries_length:int=64):
        super().__init__()

        self.hidden_dims = [
            architecture[0] * 2**power for power in range(architecture[1])
            ]                                       # num of filters in layers
        self.tseries_length = tseries_length

        modules = []
        in_channels = 1                             # initial num of channels
        out_length = tseries_length                 # length after each stage
        for h_dim in self.hidden_dims:              # conv layers
            modules.append(
                nn.Sequential(
                    nn.Conv1d(
                        in_channels=in_channels,    # num of input channels
                        out_channels=h_dim,         # num of output channels
                        kernel_size=3,
                        stride=2,                   # convolution kernel step
                        padding=1,
                    ),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim                     # input channels of the
                                                    # next stage
            # kernel 3 / stride 2 / padding 1 maps length n to ceil(n / 2);
            # tracking it per stage keeps the Linear layer correct for
            # lengths that are not a multiple of 2**num_layers
            out_length = (out_length + 1) // 2

        modules.append(nn.Flatten())                # to vector
        intermediate_dim = in_channels * out_length
        modules.append(nn.Linear(in_features=intermediate_dim,
                                 out_features=latent_dim))

        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode a batch of series into latent vectors."""
        return self.encoder(x)


class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector back to a time series.

    A ``Linear`` layer expands the latent vector, the result is reshaped to
    ``(batch, hidden_dims[0], base_length)`` and passed through
    ``Upsample -> Conv1d`` stages (the inner ones followed by
    ``BatchNorm1d`` and ``LeakyReLU``). The last stage upsamples directly to
    ``tseries_length``, so the output is ``(batch, 1, tseries_length)`` even
    when the length is not a multiple of ``2**num_stages``.

    Args:
        latent_dim: Size of the input latent vector.
        architecture: ``(base_filters, num_layers)``; the filter counts are
            ``base_filters * 2**p`` for ``p = num_layers-1, ..., 1``.
        tseries_length: Length of the reconstructed series.

    Raises:
        ValueError: If ``tseries_length`` is shorter than ``2**num_stages``,
            which would leave an empty feature map to start from.
    """

    def __init__(self, latent_dim:int,
                 architecture:tuple=(32, 4),
                 tseries_length:int=64):
        super().__init__()
        self.hidden_dims = [
            architecture[0] * 2**power
            for power in range(architecture[1] - 1, 0, -1)
            ]                                       # num of filters in layers
        self.tseries_length = tseries_length

        num_stages = len(self.hidden_dims)
        base_length = self._base_length()
        if base_length < 1:
            raise ValueError(
                f'tseries_length={tseries_length} is too short for '
                f'{num_stages} upsampling stages'
            )

        # size must equal hidden_dims[0] * base_length, i.e. exactly what
        # forward() reshapes to (the floor division is applied first)
        intermediate_dim = self.hidden_dims[0] * base_length
        self.linear = nn.Linear(in_features=latent_dim,
                                out_features=intermediate_dim)

        modules = []
        for i in range(num_stages - 1):             # inner upsample stages
            modules.append(
                nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv1d(
                        in_channels=self.hidden_dims[i],
                        out_channels=self.hidden_dims[i + 1],
                        kernel_size=3,
                        padding=1,
                    ),
                    nn.BatchNorm1d(self.hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        # final stage: resample to the exact target length (identical to
        # scale_factor=2 whenever the length is a multiple of 2**num_stages)
        modules.append(
            nn.Sequential(
                nn.Upsample(size=tseries_length),
                nn.Conv1d(in_channels=self.hidden_dims[-1],
                          out_channels=1,
                          kernel_size=3, padding=1)
            )
        )

        self.decoder = nn.Sequential(*modules)

    def _base_length(self):
        """Length of the feature map the latent vector is reshaped to."""
        return self.tseries_length // (2**len(self.hidden_dims))

    def forward(self, x):
        """Decode a batch of latent vectors into series."""
        x = self.linear(x)        # from latent space to feature vector
        x = x.view(
            -1, self.hidden_dims[0], self._base_length()
            )                     # reshape to (batch, channels, length)
        x = self.decoder(x)       # reconstruction
        return x

class VAEncoder(Encoder):
    """Encoder for the variational model; requires an even latent size.

    ``LitVAE.vae_split`` cuts the encoder output into two equal halves
    (mean and log-variance), hence the parity check.

    Args:
        latent_dim: Size of the encoder output; must be even.
        architecture: Forwarded unchanged to ``Encoder``.
        tseries_length: Forwarded unchanged to ``Encoder``.

    Raises:
        ValueError: If ``latent_dim`` is odd.
    """

    def __init__(self, latent_dim,
                 architecture:tuple=(32, 4),
                 tseries_length:int=64):
        if latent_dim % 2 != 0:   # check for the parity of the latent space
            raise ValueError('Latent size for VAEncoder must be even')

        super().__init__(latent_dim, architecture, tseries_length)

## Lightning wrapper

In [None]:
class LitAE(L.LightningModule):
    """Lightning module training an encoder/decoder pair as an autoencoder.

    The loss is a weighted reconstruction MSE plus ``derivative_weight``
    times the mean absolute first difference of the reconstruction. During
    testing, the inputs, weights, reconstructions, latents and labels of
    every batch are collected in ``self.test_result`` and converted to NumPy
    arrays when the test epoch ends.

    Args:
        encoder: Module mapping a data batch to latent vectors.
        decoder: Module mapping latent vectors to a reconstruction.
        derivative_weight: Multiplier of the derivative penalty.
    """

    def __init__(self, encoder, decoder, derivative_weight=1.0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.derivative_weight = derivative_weight

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr = 1e-4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

    def forward_handler(self, data,
                        *args, **kwargs):
        """Encode ``data`` and decode it again; returns ``(latent, recon)``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        latent = self.encoder(data)
        return latent, self.decoder(latent)

    def loss_handler(self, recon, data, weight, latent,
                     *args, **kwargs):
        """Return weighted reconstruction loss plus derivative penalty.

        ``latent`` and any extra arguments are not used here.
        """
        # point-wise squared error scaled by the weights, averaged over the
        # points whose weight is >= 0
        # NOTE(review): `ge(0.0)` also keeps zero-weight points, which then
        # count in the mean's denominator, whereas get_dict_result masks
        # zero-weight points out. Confirm which behavior is intended.
        weighted_sq_err = F.mse_loss(recon, data, reduction='none') * weight
        recon_loss = torch.masked_select(
            input=weighted_sq_err, mask=weight.ge(0.0)
        ).mean()

        # derivative penalty: mean absolute first difference of the output
        # series along the last axis
        derivative_loss = torch.diff(recon, dim=-1).abs().mean()

        return recon_loss + self.derivative_weight * derivative_loss

    def _batch_loss(self, batch):
        """Run one batch through the model and return its loss."""
        data, labels, weight = batch
        latent, recon = self.forward_handler(data, labels)
        return self.loss_handler(recon, data, weight, latent)

    def training_step(self, batch, batch_idx):
        """Compute and log the epoch-aggregated ``train_loss``."""
        loss = self._batch_loss(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the epoch-aggregated ``val_loss``."""
        loss = self._batch_loss(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss

    def on_test_epoch_start(self):
        """Reset the accumulator used by ``test_step``."""
        # missing keys start as an empty tensor, so the first torch.cat in
        # update_test_result needs no special case
        self.test_result = defaultdict(torch.Tensor)

    def test_step(self, batch, batch_idx):
        """Run the model on one batch and store its outputs."""
        data, labels, weight = batch
        latent, recon = self.forward_handler(data, labels)
        self.update_test_result(data, weight, recon, latent, labels)

    def update_test_result(self, data, weight, recon, latent, labels):
        """Append the CPU copies of one batch to ``self.test_result``."""
        batch_items = (
            ('real', data),
            ('weight', weight),
            ('recon', recon),
            ('latent', latent),
            ('labels', labels),
        )
        for key, tensor in batch_items:
            self.test_result[key] = torch.cat(
                [self.test_result[key], tensor.cpu()]
            )

    def on_test_epoch_end(self):
        """Convert every accumulated tensor to a NumPy array."""
        for key, tensor in self.test_result.items():
            self.test_result[key] = tensor.numpy()

In [None]:
class LitVAE(LitAE):
    """Variational version of ``LitAE``.

    The encoder output is split into two halves, mean and log-variance; the
    decoder receives a reparametrized sample, and a KL-divergence term
    weighted by ``kld_weight`` is added to the autoencoder loss.

    Args:
        encoder: Module whose output has an even number of features.
        decoder: Module taking vectors of half that size.
        derivative_weight: Multiplier of the derivative penalty.
        kld_weight: Multiplier of the KL-divergence term.
    """

    def __init__(self, encoder, decoder,
                 derivative_weight=1.0,
                 kld_weight=0.005,
                 ):
        super().__init__(encoder, decoder, derivative_weight)
        self.kld_weight = kld_weight

    def vae_split(self, latent):
        """Split ``latent`` along axis 1 into ``(mu, log_var)`` halves."""
        size = latent.shape[1] // 2
        mu = latent[:, :size]
        log_var = latent[:, size:]
        return mu, log_var

    def vae_reparametrize(self, mu, log_var):
        """Draw ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        The noise is created with the shape, dtype and device of ``mu``, so
        the method also works for tensors that do not live on
        ``self.device`` (e.g. CPU tensors built from NumPy arrays) and needs
        no host-to-device copy. On a GPU the noise therefore comes from the
        device generator rather than the CPU one.
        """
        sigma = torch.exp(0.5 * log_var)
        eps = torch.randn_like(mu)
        return eps * sigma + mu

    def kld_loss(self, mu, log_var):
        """Batch mean of ``-0.5 * sum(1 + log_var - mu**2 - exp(log_var))``."""
        var = log_var.exp()
        kl_loss = torch.mean(
            -0.5 * torch.sum(log_var - var - mu**2 + 1, dim=1), dim=0
        )
        return kl_loss

    def forward_handler(self, data, *args, **kwargs):
        """Encode, sample from the latent distribution, decode.

        Returns ``(latent, recon)`` where ``latent`` is the raw encoder
        output (mean and log-variance halves), not the drawn sample.
        """
        latent = self.encoder(data)

        mu, log_var = self.vae_split(latent)
        sample = self.vae_reparametrize(mu, log_var)

        recon = self.decoder(sample)
        return latent, recon

    def loss_handler(self, recon, data, weight, latent, *args, **kwargs):
        """Return the ``LitAE`` loss plus the weighted KL divergence."""
        mu, log_var = self.vae_split(latent)
        # reuse the parent's reconstruction + derivative terms instead of
        # repeating them, so both classes cannot drift apart
        ae_loss = super().loss_handler(recon, data, weight, latent,
                                       *args, **kwargs)
        return ae_loss + self.kld_weight * self.kld_loss(mu, log_var)

## Utilities

In [None]:
def reparametrize_latent(vae, latent):
    """Draw a latent sample from a VAE's raw encoder output.

    Args:
        vae: Object providing ``vae_split`` and ``vae_reparametrize``.
        latent: NumPy array holding the raw encoder output.

    Returns:
        NumPy array with the result of ``vae.vae_reparametrize``.
    """
    mu, log_var = vae.vae_split(latent)

    # vae_reparametrize works on tensors; the variance itself is not needed
    # here (it is derived from log_var inside vae_reparametrize)
    mu, log_var = torch.tensor(mu), torch.tensor(log_var)
    sample = vae.vae_reparametrize(mu, log_var).numpy()
    return sample

In [None]:
def get_dict_result(trainer, model, dataloader, ckpt_path):
    """Run the test loop and collect latents and reconstruction errors.

    Args:
        trainer: Object whose ``test(model, dataloader, ckpt_path=...)``
            fills ``model.test_result``.
        model: Module exposing ``test_result`` with the keys ``'labels'``,
            ``'real'``, ``'recon'``, ``'weight'`` and ``'latent'``. If it
            has a callable ``vae_reparametrize``, the latent is
            reparametrized first.
        dataloader: Loader whose ``dataset.label_enc`` maps the integer
            label codes back to the original labels.
        ckpt_path: Passed through to ``trainer.test``.

    Returns:
        Tuple ``(latent, real, recon, weight)``. ``latent`` is a DataFrame
        indexed by the decoded labels with the columns ``feature_<i>``,
        ``wMSE`` (mean weighted squared error over the points with non-zero
        weight) and ``pred_error`` (list of per-point square roots of the
        weighted squared error). The other three are the squeezed arrays
        from ``model.test_result``.
    """
    with torch.no_grad():
        trainer.test(model, dataloader, ckpt_path=ckpt_path)

    # the test loop stored encoded labels; map them back to the originals
    label_encoder = dataloader.dataset.label_enc
    model.test_result['labels'] = label_encoder.inverse_transform(
        model.test_result['labels'].astype(int)
    )

    real = model.test_result['real'].squeeze()
    recon = model.test_result['recon'].squeeze()
    weight = model.test_result['weight'].squeeze()

    weightedMSE = (real - recon)**2 * weight
    pred_errors = (weightedMSE ** 0.5).tolist()

    # average only over the points that carry a non-zero weight
    weightedMSE = np.ma.masked_array(data=weightedMSE,
                                     mask=~(weight.astype(bool)))
    weightedMSE = weightedMSE.mean(axis=1, keepdims=True)

    latent = model.test_result['latent'].copy()

    if hasattr(model, 'vae_reparametrize') and callable(model.vae_reparametrize):
        # for VAE, we must reparametrize latent first
        latent = reparametrize_latent(model, latent)

    latentdim = latent.shape[-1]
    feature_columns = ['feature_' + str(dim) for dim in range(latentdim)]

    latent = pd.DataFrame(
        data=np.concatenate((latent, weightedMSE), axis=1),
        index=model.test_result['labels'],
        columns=feature_columns + ['wMSE'])

    # position right after 'wMSE'; it must come from the width of *this*
    # latent (a notebook-level `latent_dim` may be undefined, and for a VAE
    # it is twice the width of the reparametrized latent)
    latent.insert(loc=latentdim + 1, column='pred_error', value=pred_errors)

    return latent, real, recon, weight

# Training

In [None]:
latent_dim = 3

encoder, decoder = Encoder(latent_dim), Decoder(latent_dim)

# torchsummary's `summary` prints the table itself; wrapping it in print()
# only adds a stray "None" line, so it is called directly.
print(">>> Encoder")
summary(encoder, (1, encoder.tseries_length), device="cpu")

print(">>> Decoder")
summary(decoder, (1, latent_dim), device="cpu")