# Load Data

In [15]:
from pathlib import Path
from opensynth.data_modules.lcl_data_module import LCLDataModule
import pytorch_lightning as pl

import matplotlib.pyplot as plt

# Locations of the pre-processed historical training artefacts.
data_path = Path("../../data/processed/historical/train/lcl_data.csv")
stats_path = Path("../../data/processed/historical/train/mean_std.csv")
outlier_path = Path("../../data/processed/historical/train/outliers.csv")

# Build the data module from the data/statistics files and run its setup
# step so that train_dataloader() can be used by the cells below.
dm = LCLDataModule(
    data_path=data_path,
    stats_path=stats_path,
    batch_size=25000,
    n_samples=50000,
)
dm.setup()

In [8]:
import torch
from opensynth.models.faraday import FaradayVAE
# The checkpoint is presumably a fully pickled FaradayVAE module rather than a
# state_dict (it is used directly as `vae_module` below) — TODO confirm.
# Recent PyTorch releases default to weights_only=True, which refuses to
# unpickle arbitrary classes, so the flag is set explicitly.
# SECURITY: weights_only=False runs the pickle machinery and can execute
# arbitrary code; only load checkpoints you created or otherwise trust.
vae_model = torch.load("vae_model.pt", weights_only=False)

  vae_model = torch.load("vae_model.pt")


In [127]:
from opensynth.models.faraday.gaussian_mixture.prepare_gmm_input import encode_data_for_gmm

# Take the first batch served by the training dataloader.
train_loader = dm.train_dataloader()
next_batch = next(iter(train_loader))

# Pass the batch through encode_data_for_gmm together with the loaded VAE
# (presumably this yields the latent representation fed to the GMM — verify
# against the helper's implementation).
input_tensor = encode_data_for_gmm(data=next_batch, vae_module=vae_model)
n_samples = len(input_tensor)

# NumPy copy of the same data, used by the NumPy/SciPy reference functions.
input_data = input_tensor.detach().numpy()

# Init GMM

In [122]:
import torch
import numpy as np
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from opensynth.models.faraday import FaradayVAE

def initialise_gmm(dataloader: DataLoader, vae_module: FaradayVAE, n_components: int, init_method: str = "kmeans", random_state: int = 0):
    """
    Compute initial GMM parameters from the first batch of a dataloader.

    The first batch is passed through ``encode_data_for_gmm`` and the result
    is clustered with scikit-learn's KMeans. Each sample is hard-assigned to
    its cluster, giving a one-hot responsibility matrix.

    Args:
        dataloader (DataLoader): Loader whose first batch is used.
        vae_module (FaradayVAE): Model handed to ``encode_data_for_gmm``.
        n_components (int): Number of mixture components (KMeans clusters).
        init_method (str): Initialisation strategy; only "kmeans" is supported.
        random_state (int): Seed forwarded to KMeans.

    Returns:
        Tuple of three tensors ``(labels, means, responsibilities)``:
        the KMeans label of every sample (n_samples,), the cluster centres
        (n_components, n_features) and the one-hot responsibilities
        (n_samples, n_components; float64).

    Raises:
        NotImplementedError: If ``init_method`` is anything other than "kmeans".
    """
    # Fail fast: reject unsupported methods before pulling a batch and
    # running the encoder.
    if init_method != "kmeans":
        raise NotImplementedError("Only kmeans is supported for now")

    next_batch = next(iter(dataloader))
    encoded = encode_data_for_gmm(data=next_batch, vae_module=vae_module)
    # .cpu() keeps the NumPy conversion working when the encoder output lives
    # on an accelerator; it is a no-op for CPU tensors.
    input_data = encoded.detach().cpu().numpy()
    n_samples = len(input_data)

    kmeans_model = KMeans(n_clusters=n_components, random_state=random_state)
    kmeans_model.fit(input_data)
    cluster_ids = kmeans_model.labels_

    # One-hot encode the hard assignments. Index with the NumPy label array
    # rather than a torch tensor so no implicit tensor->array conversion is
    # needed inside NumPy's fancy indexing.
    responsibilities = np.zeros((n_samples, n_components))
    responsibilities[np.arange(n_samples), cluster_ids] = 1

    labels = torch.from_numpy(cluster_ids)
    means = torch.from_numpy(kmeans_model.cluster_centers_)
    return labels, means, torch.from_numpy(responsibilities)

In [184]:
from scipy import linalg

def sk_estimate_gaussian_parameters(X, means, resp, reg_covar):
    """
    NumPy reference: per-component full covariance matrices for fixed means.

    Args:
        X: Data array of shape (n_samples, n_features).
        means: Component centres, shape (n_components, n_features).
        resp: Responsibilities, shape (n_samples, n_components).
        reg_covar: Value added to the diagonal of every covariance matrix.

    Returns:
        Tuple ``(nk, means, covariances)`` where ``nk`` is the column sum of
        ``resp`` plus a tiny epsilon, ``means`` is the input shifted by the
        same epsilon, and ``covariances`` has shape
        (n_components, n_features, n_features).
    """
    # Ten machine epsilons of resp's dtype: keeps nk strictly positive so the
    # division below cannot hit zero; the same offset is applied to the means.
    tiny = 10 * np.finfo(resp.dtype).eps
    nk = resp.sum(axis=0) + tiny
    means = means + tiny

    n_components, n_features = means.shape
    diag = np.arange(n_features)
    covariances = np.empty((n_components, n_features, n_features))

    for k in range(n_components):
        centred = X - means[k]
        # Responsibility-weighted scatter matrix around the k-th centre.
        weighted_scatter = np.dot(resp[:, k] * centred.T, centred)
        covariances[k] = weighted_scatter / nk[k]
        # Regularise the diagonal only.
        covariances[k, diag, diag] += reg_covar

    return nk, means, covariances


def sk_compute_precision_cholesky(covariances):
    """
    NumPy/SciPy reference: Cholesky factors of the precision matrices.

    For every covariance matrix the lower Cholesky factor L is computed and
    the transpose of L^-1 is stored.

    Args:
        covariances: Array of shape (n_components, n_features, n_features).

    Returns:
        Array with the same shape as ``covariances`` holding, per component,
        the transposed inverse of the lower Cholesky factor.

    Raises:
        ValueError: If a covariance matrix cannot be Cholesky-factorised.
    """
    n_components, n_features, _ = covariances.shape
    identity = np.eye(n_features)
    precisions_chol = np.empty((n_components, n_features, n_features))

    for k, cov_k in enumerate(covariances):
        try:
            lower_factor = linalg.cholesky(cov_k, lower=True)
        except linalg.LinAlgError:
            raise ValueError(
                "Fitting the mixture model failed because some components have "
                "ill-defined empirical covariance (for instance caused by singleton "
                "or collapsed samples). Try to decrease the number of components, "
                "or increase reg_covar."
            )
        # Solving L @ Z = I yields Z = L^-1; the transpose is what gets stored.
        inverse_factor = linalg.solve_triangular(lower_factor, identity, lower=True)
        precisions_chol[k] = inverse_factor.T

    return precisions_chol

In [185]:
# KMeans-based starting point for a 2-component mixture.
# NOTE(review): initialise_gmm draws its own first batch from a fresh
# dataloader; if the loader shuffles, that batch may differ from the one
# encoded into `input_tensor` above — confirm before comparing results.
labels_, means_, responsibilities_ = initialise_gmm(
    dataloader=dm.train_dataloader(),
    vae_module=vae_model,
    n_components=2,
)

In [251]:
def torch_estimate_gaussian_parameters(X: torch.Tensor, responsibilities: torch.Tensor, means: torch.Tensor, reg_covar: float):
    """
    PyTorch port of the NumPy reference ``sk_estimate_gaussian_parameters``:
    per-component full covariance matrices for fixed means.

    Args:
        X (torch.Tensor): Input data, shape (n_samples, n_features).
        responsibilities (torch.Tensor): Responsibility of every component
            for every sample, shape (n_samples, n_components), e.g. a one-hot
            encoding of each sample's cluster label.
        means (torch.Tensor): Component centres, shape (n_components, n_features).
        reg_covar (float): Regulariser added to the diagonal of every
            covariance matrix.

    Returns:
        Tuple ``(weights, covariances)``: ``weights`` is the column sum of
        ``responsibilities`` plus a tiny epsilon, shape (n_components,);
        ``covariances`` has shape (n_components, n_features, n_features) and
        the dtype/device of ``X``.
    """
    n_components, n_features = means.shape

    # Total responsibility per component. The epsilon keeps the divisor
    # strictly positive for empty components (same 10 * eps offset as the
    # NumPy reference).
    weights = responsibilities.sum(dim=0) + 10 * torch.finfo(responsibilities.dtype).eps

    # Work in X's dtype/device throughout so float64 or accelerator inputs are
    # neither downcast nor mixed with default-dtype CPU tensors.
    resp = responsibilities.to(dtype=X.dtype, device=X.device)
    divisors = weights.to(dtype=X.dtype, device=X.device)
    # Use the `means` argument (not a notebook global); the epsilon shift
    # mirrors the NumPy reference.
    means_eps = means.to(dtype=X.dtype, device=X.device) + 10 * torch.finfo(X.dtype).eps

    # reg_covar belongs on the diagonal only; adding it to every entry would
    # bias the off-diagonal covariances.
    ridge = reg_covar * torch.eye(n_features, dtype=X.dtype, device=X.device)

    covariances = torch.empty((n_components, n_features, n_features), dtype=X.dtype, device=X.device)
    for k in range(n_components):
        diff = X - means_eps[k]
        # Responsibility-weighted scatter matrix around the k-th centre.
        scatter = torch.matmul(resp[:, k] * diff.T, diff)
        covariances[k] = scatter / divisors[k] + ridge

    return weights, covariances


def torch_compute_precision_cholesky(covariances: torch.Tensor):
    """
    PyTorch port of ``sk_compute_precision_cholesky``: Cholesky factors of
    the precision matrices.

    For every covariance matrix the lower Cholesky factor L is computed and
    the transpose of L^-1 is stored.

    Args:
        covariances (torch.Tensor): Covariance matrices, shape
            (n_components, n_features, n_features).

    Returns:
        torch.Tensor: Same shape, dtype and device as ``covariances``; per
        component, the transposed inverse of the lower Cholesky factor.

    Raises:
        ValueError: If a covariance matrix cannot be Cholesky-factorised.
    """
    estimate_precision_error_message = (
        "Fitting the mixture model failed because some components have "
        "ill-defined empirical covariance (for instance caused by singleton "
        "or collapsed samples). Try to decrease the number of components, "
        "or increase reg_covar."
    )

    n_components, n_features, _ = covariances.shape
    # Allocate the result and the identity in the input's dtype/device so
    # float64 inputs keep their precision and accelerator inputs are not
    # mixed with CPU tensors.
    precisions_chol = torch.empty(
        (n_components, n_features, n_features),
        dtype=covariances.dtype,
        device=covariances.device,
    )
    identity = torch.eye(n_features, dtype=covariances.dtype, device=covariances.device)

    for k, covariance in enumerate(covariances):
        try:
            cov_chol = torch.linalg.cholesky(covariance, upper=False)
        except torch.linalg.LinAlgError as err:
            # Chain the original error so the failing minor is still visible.
            raise ValueError(estimate_precision_error_message) from err
        # Solving L @ Z = I yields Z = L^-1; the transpose is what gets stored.
        precisions_chol[k] = torch.linalg.solve_triangular(
            cov_chol, identity, upper=False
        ).T
    return precisions_chol

In [210]:
# Run the NumPy reference and the PyTorch port on the same encoded batch,
# centres and responsibilities so their outputs can be compared below.
# sk_estimate_gaussian_parameters takes (X, means, resp, reg_covar), so the
# KMeans centres must be passed explicitly as the second argument.
sknk, skmeans, skcovars = sk_estimate_gaussian_parameters(
    input_data,
    means_.detach().numpy(),
    responsibilities_.detach().numpy(),
    1e-6,
)
torch_weights, torch_covars = torch_estimate_gaussian_parameters(input_tensor, responsibilities_, means_, 1e-6)

In [252]:
# Precision Cholesky factors from both implementations, each fed the
# covariances produced by its own estimator.
sk_precision_cholesky = sk_compute_precision_cholesky(covariances=skcovars)
torch_precision_cholesky = torch_compute_precision_cholesky(covariances=torch_covars)

In [211]:
# First row of the first component's covariance matrix (PyTorch port).
torch_covars[0, 0]

tensor([ 252.1244, -137.3482,  -14.4317,  -95.1547,   71.9191, -146.8477,
         -72.0670, -109.0057,  -56.2617, -236.9324,  155.4083,  -47.3323,
          41.4401,   31.5037,   23.1437, -109.2889,    8.4186,   -6.4330],
       grad_fn=<SelectBackward0>)

In [212]:
# First row of the first component's covariance matrix (NumPy reference).
skcovars[0, 0]

array([ 252.12438633, -137.34862861,  -14.43184656,  -95.15741666,
         71.91888238, -146.84746139,  -72.06651698, -109.00537524,
        -56.26219416, -236.93248741,  155.40830527,  -47.3335113 ,
         41.44068616,   31.50417677,   23.14397379, -109.28911124,
          8.41860832,   -6.43301864])

In [254]:
# First row of the first component's precision Cholesky factor (NumPy/SciPy reference).
sk_precision_cholesky[0, 0]

array([ 0.06297854,  0.05716094,  0.05703893,  0.09549221,  0.03216006,
        0.25168734, -0.02355435, -0.15893924, -0.60385687, -0.04904175,
       -0.22987771,  0.10108616, -0.83541574,  0.38884346, -0.35199873,
        0.15953552,  0.10748288, -0.13349632])

In [255]:
# First row of the first component's precision Cholesky factor (PyTorch port).
torch_precision_cholesky[0, 0]

tensor([ 0.0630,  0.0572,  0.0570,  0.0952,  0.0320,  0.2501, -0.0226, -0.1563,
        -0.6052, -0.0482, -0.2303,  0.0977, -0.8269,  0.3978, -0.3384,  0.1544,
         0.1112, -0.1292], grad_fn=<SelectBackward0>)