In [1]:
import torch.nn as nn
import torch
import sys

sys.path.append('../../../')

from configs.data_configs.rosbank import data_configs
from configs.model_configs.gen.rosbank import model_configs
from src.data_load.dataloader import create_data_loaders, create_test_loader

import src.models.preprocessors as prp

from src.models.gen_models import NumericalFeatureProjector, EmbeddingPredictor
#from src.models.timevae import TimeVAE

In [11]:
# due to flatten constant len of input needed

def sample_z(mean, logstd, k_iwae):
    epsilon = torch.randn(k_iwae, mean.shape[0], mean.shape[1]).to(
        logstd.device
    )
    z = epsilon * torch.exp(0.5 * logstd) + mean  # modified
    z = z.view(-1, mean.shape[1])
    return z

def get_normal_kl(mean_1, log_std_1, mean_2=None, log_std_2=None):
    """
    This function should return the value of KL(p1 || p2),
    where p1 = Normal(mean_1, exp(log_std_1)), p2 = Normal(mean_2, exp(log_std_2) ** 2).
    If mean_2 and log_std_2 are None values, we will use standard normal distribution.
    Note that we consider the case of diagonal covariance matrix.
    """
    if mean_2 is None:
        mean_2 = torch.zeros_like(mean_1).to(mean_1.device)
    if log_std_2 is None:
        log_std_2 = torch.zeros_like(log_std_1).to(mean_1.device)
    # ====
    # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
    # https://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians

    sigma_1 = torch.exp(log_std_1)
    sigma_2 = torch.exp(log_std_2)

    out = torch.log(sigma_2 / sigma_1)
    out += (sigma_1**2 + (mean_1 - mean_2) ** 2) / (2 * (sigma_2**2))
    out -= 1 / 2

    return out



class TimeVAE(nn.Module):

    def __init__(self, model_conf, data_conf):
        super().__init__()
        self.model_conf = model_conf
        self.data_conf = data_conf

        self.processor = prp.FeatureProcessor(
            model_conf=model_conf, data_conf=data_conf
        )
        self.time_encoder = prp.TimeEncoder(
            model_conf=self.model_conf, data_conf=self.data_conf
        )

        ### INPUT SIZE ###
        all_emb_size = self.model_conf.features_emb_dim * len(
            self.data_conf.features.embeddings
        )

        self.all_numeric_size = (
            len(self.data_conf.features.numeric_values)
            * self.model_conf.numeric_emb_size
        )

        self.input_dim = (
            all_emb_size + self.all_numeric_size + self.model_conf.use_deltas
        )
        self.out_dim = self.input_dim
        if self.model_conf.time_embedding:
            self.input_dim += self.model_conf.time_embedding * 2 - 1

        ### INIT ENCODER
        ls = []
        prev_dim = self.input_dim
        for i, dim in enumerate(self.model_conf.timevae.hiddens):
            ls.append(nn.Conv1d(in_channels=prev_dim, out_channels=dim, kernel_size=3, stride=2, padding=3))
            ls.append(nn.ReLU())
            prev_dim = dim
        
        ls.append(nn.Flatten())
        self.encoder_conv = nn.Sequential(*ls)
        self.flat_size = self.detect_size(self.input_dim, self.data_conf.train.max_seq_len)
        
        self.mu_head = nn.Linear(self.flat_size, self.model_conf.timevae.latent_dim)
        self.log_std_head = nn.Linear(self.flat_size, self.model_conf.timevae.latent_dim)

        ### Decoder Init ###
        self.dec_proj = nn.Linear(self.model_conf.timevae.latent_dim, self.flat_size)

        ls = []
        prev_dim = self.model_conf.timevae.hiddens[-1]
        for i, dim in enumerate(reversed(self.model_conf.timevae.hiddens[:-1])):
            ls.append(nn.ConvTranspose1d(in_channels=prev_dim, out_channels=dim, kernel_size=3, stride=2, padding=2))
            ls.append(nn.ReLU())
            prev_dim = dim
        ls.append(nn.ConvTranspose1d(in_channels=prev_dim, out_channels=self.input_dim, kernel_size=3, stride=2, padding=2))        
        ls.append(nn.ReLU())
        ls.append(nn.Flatten())
        self.decoder = nn.Sequential(*ls)
        out_proj_size = self.detect_dec_size()
        self.decoder_out_proj = nn.Linear(out_proj_size, self.data_conf.train.max_seq_len * self.out_dim)


        ### LOSS ###
        self.embedding_predictor = EmbeddingPredictor(
            model_conf=self.model_conf, data_conf=self.data_conf
        )
        self.numeric_projector = NumericalFeatureProjector(
            model_conf=self.model_conf, data_conf=self.data_conf
        )
        self.mse_fn = torch.nn.MSELoss(reduction="none")
        self.ce_fn = torch.nn.CrossEntropyLoss(
            reduction="mean", ignore_index=0, label_smoothing=0.05
        )

    def forward(self, padded_batch):
        x, time_steps = self.processor(padded_batch)
        x = self.time_encoder(x, time_steps)
        
        mu, log_std = self.encode(x.transpose(1,2))
        z = sample_z(mu, log_std, 1)
        out = self.decode(z)

        


        pred = self.embedding_predictor(out)
        pred.update(self.numeric_projector(out))

        if self.model_conf.use_deltas:
            pred["delta"] = out[:, :, -1].squeeze(-1)

        gt = {'input_batch': padded_batch, 'time_steps': time_steps}

        res_dict = {'gt': gt,
                    'pred': pred,
                    'latent': z,
                    'mu': mu, 
                    'log_std': log_std}
        return res_dict

    
    def decode(self, z):
        projected = self.dec_proj(z).view(z.size(0), self.model_conf.timevae.hiddens[-1], -1)
        #print(projected)
        decoded = self.decoder(projected)
        print(decoded.size())
        out = self.decoder_out_proj(decoded)
        return out.view(z.size(0), self.data_conf.train.max_seq_len, self.out_dim)

    def encode(self, x):
        features = self.encoder_conv(x)
        mu = self.mu_head(features)
        log_std = self.log_std_head(features)

        return mu, log_std
    
    def detect_size(self, in_dim, seq_len):
        test_value = torch.rand(1, seq_len, in_dim)
        
        with torch.no_grad():
            out = self.encoder_conv(test_value.transpose(1,2))

        return out.size(1)

    def detect_dec_size(self):
        test_value = torch.rand(1, self.model_conf.timevae.latent_dim)
        with torch.no_grad():
            projected = self.dec_proj(test_value).view(1, self.model_conf.timevae.hiddens[-1], -1)
            decoded = self.decoder(projected)

        return decoded.size(1)
    
    
    def loss(self, output, ground_truth):
        """
        output: Dict that is outputed from forward method
        """
        ### MSE ###
        total_mse_loss = self.numerical_loss(output)
        delta_mse_loss = self.delta_mse_loss(output)

        ### CROSS ENTROPY ###
        cross_entropy_losses = self.embedding_predictor.loss(
            output["pred"], output["gt"]["input_batch"]
        )
        total_ce_loss = torch.sum(
            torch.cat([value.unsqueeze(0) for _, value in cross_entropy_losses.items()])
        )

        kl_loss = get_normal_kl(mean_1=output['mu'], log_std_1=output['log_std']).sum(dim=1).mean()

        losses_dict = {
            "total_mse_loss": total_mse_loss,
            "total_CE_loss": total_ce_loss,
            "total_KL_loss": kl_loss, 
            "delta_loss": delta_mse_loss
        }
        losses_dict.update(cross_entropy_losses)

        total_loss = (
            (self.model_conf.mse_weight * losses_dict["total_mse_loss"]
            + self.model_conf.CE_weight * total_ce_loss + self.model_conf.delta_weight * delta_mse_loss) * self.model_conf.timevae.recon_weight
            + kl_loss
        )
        losses_dict["total_loss"] = total_loss

        return losses_dict

    def generate(self, padded_batch, lens):
        Z = torch.randn(lens, self.latent_dim)
        samples = self.decode(Z)
        return samples
    

    def numerical_loss(self, output):
        # MSE
        total_mse_loss = 0
        for key, values in output["gt"]["input_batch"].payload.items():
            if key in self.processor.numeric_names:
                gt_val = values.float()
                gt_val = values.float()
                pred_val = output["pred"][key].squeeze(-1)

                mse_loss = self.mse_fn(
                    gt_val,
                    pred_val,
                )
                mask = gt_val != 0
                masked_mse = mse_loss * mask
                total_mse_loss += (
                    masked_mse.sum(dim=1)  # / (mask != 0).sum(dim=1)
                ).mean()

        return total_mse_loss

    def delta_mse_loss(self, output):
        # DELTA MSE
        if self.model_conf.use_deltas:
            gt_delta = output["gt"]["time_steps"].diff(1)
            if self.model_conf.use_log_delta:
                gt_delta = torch.log(gt_delta + 1e-15)
            delta_mse = self.mse_fn(gt_delta, output["pred"]["delta"][:, :-1])
            # print(delta_mse, gt_delta[0], output["gt"]["time_steps"].diff(1)[0], output["gt"]["time_steps"][0])
            mask = output["gt"]["time_steps"] != -1

            delta_masked = delta_mse * mask[:, :-1]
            delta_mse = delta_masked.sum() / (mask != 0).sum()
        else:
            delta_mse = torch.tensor(0)

        return delta_mse


In [12]:
model_conf = model_configs()
data_conf = data_configs()

In [13]:
tv = TimeVAE(model_conf=model_conf, data_conf=data_conf)

In [14]:
train_loader, val_loader = create_data_loaders(data_conf, supervised=False)

Data shapes: train 8467, val 946, test 0


In [15]:
for batch in train_loader:
    break

In [16]:
tv

TimeVAE(
  (processor): FeatureProcessor(
    (embed_layers): ModuleDict(
      (channel_type): Embedding(5, 12)
      (currency): Embedding(5, 12)
      (mcc): Embedding(100, 12)
      (trx_category): Embedding(12, 12)
    )
    (numeric_processor): ModuleDict(
      (amount): Linear(in_features=1, out_features=12, bias=True)
    )
    (numeric_norms): ModuleDict(
      (amount): RBatchNormWithLens(
        (bn): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (time_encoder): TimeEncoder()
  (encoder_conv): Sequential(
    (0): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(3,))
    (1): ReLU()
    (2): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(3,))
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
  )
  (mu_head): Linear(in_features=3392, out_features=32, bias=True)
  (log_std_head): Linear(in_features=3392, out_features=32, bias=True)
  (dec_proj): Linear(in_features=32, out_features=3392, bias=True)
  (

In [17]:
out = tv(batch[0])

torch.Size([512, 12992])


In [18]:
losses = tv.loss(out, batch[1])

In [19]:
losses

{'total_mse_loss': tensor(4505.9702, grad_fn=<AddBackward0>),
 'total_CE_loss': tensor(934.3182, grad_fn=<SumBackward0>),
 'total_KL_loss': tensor(0.3102, grad_fn=<MeanBackward0>),
 'delta_loss': tensor(0.0124, grad_fn=<DivBackward0>),
 'channel_type': tensor(147.4019, grad_fn=<MeanBackward0>),
 'currency': tensor(154.8810, grad_fn=<MeanBackward0>),
 'mcc': tensor(423.8500, grad_fn=<MeanBackward0>),
 'trx_category': tensor(208.1854, grad_fn=<MeanBackward0>),
 'total_loss': tensor(16321.5498, grad_fn=<AddBackward0>)}

In [27]:
losses['total_loss'].backward()

In [99]:
z = torch.rand(32, 4)
dec = tv.decode(z)

In [103]:
dec.size()

torch.Size([32, 7, 5])

In [53]:

out_proj = nn.Linear(4, 125)
projected = out_proj(z)

In [61]:
projected.reshape(32, -1, 25)[0, 0, :]

tensor([ 0.0246,  0.9201, -0.1976,  0.6290, -0.2905, -0.0487, -0.6348, -0.5932,
         0.1185,  1.0665,  0.1459,  0.0549,  0.7347, -0.1536, -0.5866, -0.4701,
         0.0191, -0.6995, -0.2983, -0.0667,  0.2349, -0.2173, -0.4106, -0.8101,
         0.5722], grad_fn=<SliceBackward0>)

In [65]:
p = projected.reshape(32, 25, -1)

In [66]:
c = nn.ConvTranspose1d(in_channels=25, out_channels=25, kernel_size=3, padding=2, stride=2)

In [67]:
c(p).size()

torch.Size([32, 25, 7])

In [68]:
p.size()

torch.Size([32, 25, 5])

In [45]:
x = torch.rand(32, 200, 64)
c = nn.ConvTranspose1d(in_channels=64, out_channels=60, kernel_size=3, padding=0, stride=2)

In [46]:
c(x.transpose(1,2)).size()

torch.Size([32, 60, 401])