In [None]:
from config import settings
from hannover_pylon.data import datamodules as dm
from pathlib import Path
from hannover_pylon.modelling.backbone.utils import FromBuffer , CutPSD, NormLayer
import matplotlib.pyplot as plt
from torch import nn 
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from pathlib import Path
from torch import nn
# Frequency grid: 8197 evenly spaced points from 0 to 825.5. These numbers match
# n_fft // 2 + 1 and fs / 2 for the Welch settings named in the database file below.
freq_axis = np.linspace(start=0, stop=825.5, num=8197)

# Database file, resolved under the configured "processed" directory.
db_path = Path(settings.path.processed) / 'Welch(n_fft=16392, fs=1651, max_freq=825.5).db'

# Columns requested from the table; one transform is supplied per column, in the same order.
columns = ['psd', 'level', 'direction']

# PSD column: FromBuffer -> CutPSD restricted to freq_range (0, 150) -> NormLayer
# with the fixed bounds below.
psd_pipeline = nn.Sequential(
    FromBuffer(),
    CutPSD(freq_axis=freq_axis, freq_range=(0, 150)),
    NormLayer(min_val=-5.46, max_val=4.96),
)
# The two remaining columns are passed through unchanged (one shared Identity instance).
passthrough = nn.Identity()
transform_func = [psd_pipeline, passthrough, passthrough]

# Row selection: non-corrupted "accel" rows dated inside the configured healthy-training window.
query_key = f'''
    SELECT id FROM data
    WHERE date BETWEEN "{settings.state.healthy_train.start}" AND "{settings.state.healthy_train.end}"
    AND corrupted = 0
    AND sensor = "accel"
'''

data_loader = dm.PSDDataModule(
    db_path=db_path,
    table_name='data',
    columns=columns,
    transform_func=transform_func,
    query_key=query_key,
    batch_size=128,
    return_dict=True,
    cached=True,
    num_workers=16,
)


In [None]:
# Run the data module's setup step before any dataloader is requested from it.
data_loader.setup()

In [None]:
# Pull a single batch from the training dataloader for inspection.
# next(iter(...)) raises StopIteration right here if the loader yields nothing,
# whereas a for/break loop would silently leave `batch` undefined and fail later.
batch = next(iter(data_loader.train_dataloader()))

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from hannover_pylon.modelling.backbone import utils as ut

def gaussian_nll_loss(mu, logvar, target):
    """
    Compute the Gaussian negative log-likelihood loss for 1D regression.

    For each sample, with target y, predicted mean mu and predicted log variance s,
    the loss (up to an additive constant) is:

        NLL = 0.5 * exp(-s) * (y - mu)^2 + 0.5 * s.

    Shape handling: ``logvar`` and ``target`` are reshaped to ``mu``'s shape whenever
    they hold the same number of elements but are laid out differently (e.g. a target
    of shape (B,) or a log variance of shape (B, 1, 1) against a mean of shape (B, 1)).
    Without this, broadcasting would silently build a pairwise (B, B, ...) tensor and
    average over every sample combination. Inputs whose element count differs from
    ``mu`` (e.g. a scalar log variance) are left untouched and broadcast as usual.

    Args:
        mu (Tensor): Predicted location mean, expected shape (B, 1).
        logvar (Tensor): Predicted log variance, expected shape (B, 1).
        target (Tensor): Ground truth location, expected shape (B, 1).

    Returns:
        Tensor: Scalar tensor with the NLL averaged over all elements.
    """
    # Same element count but different layout: align to mu instead of broadcasting.
    if logvar.shape != mu.shape and logvar.numel() == mu.numel():
        logvar = logvar.reshape(mu.shape)
    if target.shape != mu.shape and target.numel() == mu.numel():
        target = target.reshape(mu.shape)

    sq_error = (target - mu) ** 2
    loss = 0.5 * torch.exp(-logvar) * sq_error + 0.5 * logvar
    return loss.mean()


class OneToOneAutoEncoderWithRegressorNLL(nn.Module):
    """
    Autoencoder over a single PSD vector with an auxiliary head that predicts a
    1D location as a Gaussian (mean and log variance) from the latent code.
    """

    def __init__(self, psd_length=1490, encoder_dims=[512, 128, 64], latent_dim=32):
        """
        A one-to-one autoencoder that reconstructs a sensor's PSD (from key "psd") and also
        predicts its 1D location with uncertainty.

        The encoder processes the PSD, and the decoder is symmetric to the encoder (i.e. the
        decoder's hidden dimensions are the reverse of the encoder's). The location regressor
        outputs 2 values:
          - The first is the predicted 1D location mean.
          - The second is the predicted log variance.

        Args:
            psd_length (int): Number of PSD bins fed to the encoder and produced by the decoder.
            encoder_dims (sequence of int): Hidden layer widths of the encoder, in order.
            latent_dim (int): Size of the latent embedding.
        """
        super().__init__()
        # Private list copy: the shared default list is never aliased, and tuples are
        # accepted too (the list concatenations below would reject a tuple).
        encoder_dims = list(encoder_dims)

        # Encoder: build layers using the provided encoder dimensions.
        self.encoder = ut.build_layers(
            hidden_dims=[psd_length] + encoder_dims,
            activation_list='relu',
            batch_norm=True,
        )
        self.latent_layer = nn.Linear(encoder_dims[-1], latent_dim)

        # Decoder: symmetric to the encoder.
        decoder_dims = encoder_dims[::-1]
        # NOTE(review): hidden_dims describes len(decoder_dims) + 1 layers, while
        # activation_list carries len(decoder_dims) + 2 entries. Confirm how
        # ut.build_layers consumes the list, i.e. whether the output layer really ends
        # up without an activation as the trailing None suggests.
        self.decoder = ut.build_layers(
            hidden_dims=[latent_dim] + decoder_dims + [psd_length],
            activation_list=['relu'] * (len(decoder_dims) + 1) + [None],
            batch_norm=True,
        )

        # Regressor: predicts the 1D sensor location and log variance.
        self.location_regressor = nn.Sequential(
            nn.Linear(latent_dim, latent_dim // 2),
            nn.ReLU(),
            nn.Linear(latent_dim // 2, 2)  # 2 outputs: [location_mean, location_logvar]
        )

    def forward(self, x):
        """
        Args:
            x (dict): Must contain a key "psd" with tensor of shape (B, psd_length).

        Returns:
            dict: Contains:
              - "reconstruction": reconstructed PSD (B, psd_length)
              - "latent": latent embedding (B, latent_dim)
              - "location_mean": predicted 1D location (B, 1)
              - "location_logvar": predicted log variance (B, 1)
        """
        psd = x["psd"]  # (B, psd_length)
        encoded = self.encoder(psd)
        latent = self.latent_layer(encoded)
        reconstruction = self.decoder(latent)
        loc_out = self.location_regressor(latent)  # shape: (B, 2)
        # Slicing with ranges keeps the column axis, so both halves are already (B, 1).
        # An extra unsqueeze here would produce (B, 1, 1) and make the NLL loss
        # broadcast against the (B, 1) mean/target into a pairwise tensor.
        location_mean = loc_out[:, :1]    # (B, 1)
        location_logvar = loc_out[:, 1:]  # (B, 1)
        return {
            "reconstruction": reconstruction,
            "latent": latent,
            "location_mean": location_mean,
            "location_logvar": location_logvar
        }


class OneToOneTrainingModule(pl.LightningModule):
    """
    Lightning wrapper that trains OneToOneAutoEncoderWithRegressorNLL with the sum of
    a PSD reconstruction loss (MSE) and a weighted Gaussian NLL on the predicted location.
    """

    def __init__(self, psd_length=1490, encoder_dims=[512, 128, 64], latent_dim=32, lr=1e-3,
                 location_loss_weight=1.0, location_key="location"):
        """
        Args:
            psd_length (int): Number of PSD bins per sample; forwarded to the model.
            encoder_dims (list of int): Encoder hidden widths; forwarded to the model.
            latent_dim (int): Latent embedding size; forwarded to the model.
            lr (float): Learning rate for the Adam optimizer.
            location_loss_weight (float): Weight for the location regression (NLL) loss.
            location_key (str): Key of the batch dict holding the ground-truth location.
                Defaults to "location".
                NOTE(review): the data module in this notebook is configured with
                columns ['psd', 'level', 'direction'] — confirm which batch key carries
                the regression target and pass it here if it is not "location".
        """
        super().__init__()
        self.model = OneToOneAutoEncoderWithRegressorNLL(psd_length=psd_length, encoder_dims=encoder_dims, latent_dim=latent_dim)
        self.recon_loss_fn = nn.MSELoss()
        self.lr = lr
        self.location_loss_weight = location_loss_weight
        self.location_key = location_key

    def forward(self, x):
        """Run the wrapped model on a batch dict and return its output dict."""
        return self.model(x)

    def _common_step(self, batch, batch_idx):
        """
        Compute the combined loss shared by the training and validation steps.

        Args:
            batch (dict): Must contain "psd" and the key named by ``self.location_key``.
            batch_idx (int): Index of the batch (unused).

        Returns:
            Tensor: ``recon_loss + location_loss_weight * loc_loss``.

        Raises:
            KeyError: If the batch has no entry for ``self.location_key``.
        """
        output = self(batch)
        recon_loss = self.recon_loss_fn(output["reconstruction"], batch["psd"])

        try:
            target = batch[self.location_key]
        except KeyError as err:
            raise KeyError(
                f"batch has no '{self.location_key}' entry (available keys: {list(batch)}); "
                "pass location_key=... to select the regression target"
            ) from err

        # A target with one value per sample but a different layout (e.g. (B,) against
        # the (B, 1) prediction) would broadcast into a pairwise (B, B) error; align it
        # with the predicted mean. Other shapes are passed through unchanged.
        mean = output["location_mean"]
        if target.shape != mean.shape and target.numel() == mean.numel():
            target = target.reshape(mean.shape)

        loc_loss = gaussian_nll_loss(mean, output["location_logvar"], target)
        loss = recon_loss + self.location_loss_weight * loc_loss
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._common_step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._common_step(batch, batch_idx)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
