In [1]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from hy2dl.modelzoo.inputlayer import InputLayer
from hy2dl.utils.config import Config
from hy2dl.utils.distributions import Distribution

# Seed torch's global RNG so the random inputs, weight initialisation, dropout masks and
# sampling below are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

path_experiment_settings = "../examples/mdn.yml"

# NOTE(review): the meaning of Config's second positional argument is not visible here — confirm.
config = Config(path_experiment_settings, True)

In [None]:
class LSTMMDN(nn.Module):
    def __init__(self, cfg: Config):

        super().__init__()

        self.embedding_net = InputLayer(cfg)

        self.lstm = nn.LSTM(input_size=self.embedding_net.output_size, hidden_size=cfg.hidden_size, batch_first=True)

        self.dropout = torch.nn.Dropout(p=cfg.dropout_rate)

        self.distribution = Distribution.from_string(cfg.distribution)
        match self.distribution:
            case Distribution.GAUSSIAN:
                self.num_params = 2
            case Distribution.LAPLACIAN:
                self.num_params = 3

        self.fc_params = nn.Linear(cfg.hidden_size, self.num_params * cfg.num_components * cfg.output_features)

        self.fc_weights = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.num_components * cfg.output_features),
            nn.Unflatten(-1, (cfg.num_components, cfg.output_features)),
            nn.Softmax(dim=-2)
        )

        self.num_components = cfg.num_components
        self.predict_last_n = cfg.predict_last_n

        self.output_features = cfg.output_features

        self._reset_parameters(cfg=cfg)

    def _reset_parameters(self, cfg: Config):
        """Special initialization of the bias."""
        if cfg.initial_forget_bias is not None:
            self.lstm.bias_hh_l0.data[cfg.hidden_size : 2 * cfg.hidden_size] = cfg.initial_forget_bias
    
    def forward(self, sample):
        # Pre-process data to be sent to the LSTM
        # processed_sample = self.embedding_net(sample)
        # x_lstm = self.embedding_net.assemble_sample(processed_sample)
        x_lstm = sample

        # Forward pass through the LSTM
        out, _ = self.lstm(x_lstm)
        
        # Extract sequence of interest
        out = out[:, -self.predict_last_n:, :]
        out = self.dropout(out)

        # Probabilistic things
        w = self.fc_weights(out)

        params = self.fc_params(out)
        match self.distribution:
            case Distribution.GAUSSIAN:
                loc, scale = params.chunk(2, dim=-1)
                scale = F.softplus(scale)
                params = {"loc": loc, "scale": scale}
            case Distribution.LAPLACIAN:
                loc, scale, kappa = params.chunk(3, dim=-1)
                scale = F.softplus(scale)
                kappa = F.softplus(kappa)
                params = {"loc": loc, "scale": scale, "kappa": kappa}
        params = {k: v.reshape(v.shape[0], v.shape[1], self.num_components, self.output_features) for k, v in params.items()}
        
        return {"params": params, "weights": w}
    
    def mean(self, x):
        params, w = self(x).values()
        match self.distribution:
            case Distribution.GAUSSIAN:
                mean = params["loc"]
            case Distribution.LAPLACIAN:
                loc, scale, kappa = params.values()
                mean = loc + scale * (1 - kappa.pow(2)) / kappa
        mean = (mean * w).sum(axis=-2)
        return mean
    
    def sample(self, x, num_samples):
        params, w = self(x).values()
        num_batches, sequence_length, num_components, num_targets = next(iter(params.values())).shape
        match self.distribution:
            case Distribution.GAUSSIAN:
                loc, scale = params.values()
                
                samples = torch.randn(num_batches, sequence_length, num_components, num_samples, num_targets).to(x.device)
            case Distribution.LAPLACIAN:
                loc, scale, kappa = params.values()

                u = torch.rand(num_batches, sequence_length, num_components, num_samples, num_targets).to(x.device)

                # Sampling left or right of the mode?
                kappa = kappa.unsqueeze(-2).repeat((1, 1, 1, num_samples, 1))
                p_at_mode = kappa**2 / (1 + kappa**2)

                mask = u < p_at_mode

                samples = torch.zeros_like(u)

                samples[mask] = kappa[mask] * torch.log(u[mask] * (1 + kappa[mask].pow(2)) / kappa[mask].pow(2)) # Left side
                samples[~mask] = -1 * torch.log((1 - u[~mask]) * (1 + kappa[~mask].pow(2))) / kappa[~mask] # Right side

        # Forgive me father for I have sinned.
        
        # samples: [num_batches, sequence_length, num_components, num_samples, output_features]
        # loc, scale: [num_batches, sequence_length, num_components, output_features]
        samples = samples * scale.unsqueeze(-2) + loc.unsqueeze(-2)  # [num_batches, sequence_length, num_components, num_samples, output_features]

        # Select samples according to weights
        # w: [num_batches, sequence_length, num_components, output_features]
        # Reshape w to [num_batches * sequence_length * output_features, num_components] for multinomial
        w_reshaped = w.permute(0, 1, 3, 2).reshape(-1, w.size(2))  # [num_batches * sequence_length * output_features, num_components]
        indices = torch.multinomial(w_reshaped, num_samples, replacement=True)  # [num_batches * sequence_length * output_features, num_samples]

        # Reshape indices back to proper dimensions
        indices = indices.view(num_batches, sequence_length, num_targets, num_samples)  # [num_batches, sequence_length, output_features, num_samples]
        indices = indices.permute(0, 1, 3, 2)  # [num_batches, sequence_length, num_samples, output_features]
        indices = indices.unsqueeze(2)  # [num_batches, sequence_length, 1, num_samples, output_features]

        # Now gather from the num_components dimension (dim=2)
        samples = torch.gather(samples, dim=2, index=indices)  # [num_batches, sequence_length, 1, num_samples, output_features]
        samples = samples.squeeze(2)  # [num_batches, sequence_length, num_samples, output_features]
        
        return samples
   
    def _calc_cdf(self, x, xi):
        """
        Calculate the mixture CDF at points `xi`.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor [B, L, I].
        xi : torch.Tensor
            Evaluation points [B, T, Q, D].

        Returns
        -------
        torch.Tensor
            Mixture CDF values [B, N, Q, D].
        """
        xi = xi.unsqueeze(-2) # [B, N, 1, T]

        params, weights = self(x).values()

        match self.distribution:
            case Distribution.GAUSSIAN:
                loc, scale = params.values() # loc: [B, N, K, T]
                z = (xi - loc) / (scale * math.sqrt(2)) 
                cdf = 0.5 * (1 + torch.erf(z))

            case Distribution.LAPLACIAN:
                loc, scale, kappa = params.values()
                z = (xi - loc) / scale
                mask = (z >= 0)
                cdf = torch.zeros_like(z)
                cdf[mask] = 1 - (1 / (1 + kappa[mask].pow(2))) * torch.exp(-1 * kappa[mask] * z[mask])
                cdf[~mask] = (kappa[~mask].pow(2) / (1 + kappa[~mask].pow(2))) * torch.exp(z[~mask] / kappa[~mask])

        # Mix CDF (weighted mixture over components)
        cdf = (weights * cdf).sum(dim=-2)  # [B, N, T]
        return cdf
    
    def _calc_logpdf(self, x, xi):
        """
        Calculate the density of `xi` in the mixture PDF of `x`.

        Parameters
        ----------
        x : torch.Tensor 
            Tensor of shape [B, L, I].
        xi : torch.Tensor
            The points at which to evaluate the PDF. Tensor of shape [B, N, T].

        Returns
        -------
        torch.Tensor
            The PDF values at `xi`. Tensor of shape [B, N, T].
        """

        xi = xi.unsqueeze(-2) # [B, N, 1, T]

        params, weights = self(x).values() # loc: [B, N, K, T]
        match self.distribution:
            case Distribution.GAUSSIAN:
                loc, scale = params.values()
                scale = torch.clamp(scale, min=1e-6)
                p = (xi - loc) / scale
                log_p = -0.5 * p.pow(2) - torch.log(scale) - 0.5 * torch.log(2 * math.pi)

            case Distribution.LAPLACIAN:
                loc, scale, kappa = params.values()
                scale = torch.clamp(scale, min=1e-6)
                kappa = torch.clamp(kappa, min=1e-6)
                

                p = (xi - loc) / scale
                mask = (p >= 0)

                log_p = torch.zeros_like(p)

                log_p[mask] = -1 * p[mask] * kappa[mask]
                log_p[~mask] = p[~mask] / kappa[~mask]

                log_p = log_p - torch.log(kappa + 1 / kappa) - torch.log(scale)

        log_w = torch.log(torch.clamp(weights, min=1e-10))
        log_p = torch.logsumexp(log_p + log_w, dim=-2) # [B, N, T]
    
        return log_p

    def quantile(self, x, q: list[float], max_iter: int = 50, tol: float = 1e-6):
        out = []

        # Solve one quantile at a time
        for qi in q:
            # Mean as the initial guess
            xi = self.mean(x)  # [B, N, T]
            for _ in range(max_iter):
                pdf = self._calc_logpdf(x, xi).exp()   # [B, N, T]
                cdf = self._calc_cdf(x, xi)            # [B, N, T]

                # Newton step
                delta = (cdf - qi) / (pdf + 1e-12)     # [B, N, T]
                new = xi - delta

                # Convergence check
                if delta.abs().max() < tol:
                    xi = new
                    break

                xi = new

            out.append(xi)

        # Stack quantiles → [B, N, Q, T]
        return torch.stack(out, dim=2)

# Build the network from the experiment configuration.
model = LSTMMDN(cfg=config)

In [3]:
# Random input, [batch, sequence length, input features] (the LSTM is batch_first).
x_lstm = torch.randn(256, 365, 33)
out = model(x_lstm)

# Component parameters and mixture weights share the same shape.
print(out["params"]["loc"].shape, out["weights"].shape, sep="\n")

torch.Size([256, 15, 3, 2])
torch.Size([256, 15, 3, 2])


In [4]:
# Mixture mean: component means weighted by the mixture weights.
mean = model.mean(x=x_lstm)
print(mean.size())

torch.Size([256, 15, 2])


In [5]:
# Random evaluation points with the same shape as the mixture mean.
xi = torch.randn(256, 15, 2)

logpdf = model._calc_logpdf(x=x_lstm, xi=xi)
print(logpdf.size())

cdf = model._calc_cdf(x=x_lstm, xi=xi)
print(cdf.size())

torch.Size([256, 15, 2])
torch.Size([256, 15, 2])


In [6]:
# Median plus the bounds of the central 95% interval.
q = [0.025, 0.5, 0.975]
quantiles = model.quantile(x=x_lstm, q=q)

: 

In [None]:
# Recompute the mixture mean and show its shape (repeats the earlier mean cell).
print((mean := model.mean(x=x_lstm)).size())

In [None]:
# Monte-Carlo draws from the predicted mixture.
samples = model.sample(x=x_lstm, num_samples=667)
print(samples.size())

In [None]:
def _mask(*tensors):
    masks = []
    for tensor in tensors:
        num_dim = tensor.dim()
        for _ in range(num_dim - 1):
            tensor = tensor.sum(dim=1)
        mask = ~tensor.isnan()
        masks.append(mask)
    mask = torch.stack(masks, dim=1).all(dim=1)

    return tuple(tensor[mask] for tensor in tensors)

def loss_nll(
    params: dict[str, torch.Tensor],
    weights: torch.Tensor,
    dist: Distribution,
    y: torch.Tensor,
) -> torch.Tensor:
    
    y = y.unsqueeze(-2)  # [batch_size, sequence_length, 1, output_features]

    match dist:
        case Distribution.GAUSSIAN:
            loc, scale = params.values()
            scale = torch.clamp(scale, min=1e-6)
            p = (y - loc) / scale
            log_p = -0.5 * p.pow(2) - torch.log(scale) - 0.5 * torch.log(2 * math.pi)

        case Distribution.LAPLACIAN:
            loc, scale, kappa = params.values()
            scale = torch.clamp(scale, min=1e-6)
            kappa = torch.clamp(kappa, min=1e-6)

            p = (y - loc) / scale

            mask = (p >= 0)

            log_p = torch.zeros_like(p)

            log_p[mask] = -1 * p[mask] * kappa[mask]
            log_p[~mask] = p[~mask] / kappa[~mask]

            log_p = log_p - torch.log(kappa + 1 / kappa) - torch.log(scale)

    log_w = torch.log(torch.clamp(weights, min=1e-10))
    loss = -torch.logsumexp(log_p + log_w, dim=1)
    return loss.mean(dim=(0, 1))

# Synthetic observations shaped like the mixture mean prediction.
y_obs = torch.randn(256, 15, 2)

loss = loss_nll(params=out["params"], weights=out["weights"], dist=model.distribution, y=y_obs)
print(loss)