In [1]:
%load_ext autoreload
%autoreload 2

# Load Data

In [2]:
from pathlib import Path
from opensynth.data_modules.lcl_data_module import LCLDataModule
import pytorch_lightning as pl

import matplotlib.pyplot as plt

data_path = Path("../../data/processed/historical/train/lcl_data.csv")
stats_path = Path("../../data/processed/historical/train/mean_std.csv")
outlier_path = Path("../../data/processed/historical/train/outliers.csv")

dm = LCLDataModule(data_path=data_path, stats_path=stats_path, batch_size=25000, n_samples=50000)
dm.setup()

In [3]:
import torch
from opensynth.models.faraday import FaradayVAE
vae_model = torch.load("vae_model.pt")

  vae_model = torch.load("vae_model.pt")


In [4]:
from opensynth.models.faraday.gaussian_mixture.prepare_gmm_input import encode_data_for_gmm

next_batch = next(iter(dm.train_dataloader()))
input_tensor = encode_data_for_gmm(data=next_batch, vae_module=vae_model)
input_data = input_tensor.detach().numpy()
n_samples = len(input_tensor)

In [24]:
N_COMPONENTS = 50

# Init GMM

In [177]:
from opensynth.models.faraday.new_gmm import gmm_utils

labels_, means_, responsibilities_ = gmm_utils.initialise_centroids(
        X=input_data, n_components=N_COMPONENTS
    )
print(labels_.dtype, responsibilities_.dtype, means_.dtype)
responsibilities_ = responsibilities_.double()
means_ = means_.double()

torch.float64 torch.float64 torch.float64


In [194]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureModel

gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=1e-4,
)
torch_gmm = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=1e-4
)
torch_gmm.initialise(gmm_init_params)

In [195]:
torch_log_prob_norm, torch_log_resp = torch_gmm.e_step(input_tensor.double())

In [191]:
torch_log_resp[0]

tensor([-3.6227e+05, -1.2615e+01, -1.9258e+01, -2.4954e+01, -1.1200e+01,
        -8.3701e+00, -7.1378e+00, -3.1867e+01, -5.2818e+05, -1.3116e+04,
        -5.9181e+01, -2.5521e+01, -3.1732e+01, -3.1955e+01, -7.7710e+03,
        -4.8142e+05, -4.2992e+01, -3.7528e+01, -3.3823e-01, -3.3509e+00,
        -2.1167e+01, -3.5678e+01, -3.2415e+01, -1.4766e+00, -3.0800e+01,
        -6.5013e+01, -6.1166e+01, -3.1126e+01, -1.7480e+01, -3.3293e+01,
        -2.3036e+01, -1.7120e+04, -2.4828e+01, -5.6232e+07, -3.7960e+00,
        -2.3908e+01, -2.7681e+01, -2.3572e+01, -6.5275e+01, -2.0750e+01,
        -1.1995e+01, -1.1659e+07, -1.7925e+03, -2.1349e+01, -1.8273e+01,
        -3.5001e+01, -6.0848e+01, -1.3041e+02, -4.1237e+01, -2.7694e+01],
       dtype=torch.float64, grad_fn=<SelectBackward0>)

In [192]:
torch_log_prob_norm

tensor(0.2317, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [196]:
torch_prec_chol, torch_weights, torch_means = torch_gmm.m_step(input_tensor.double(), log_reponsibilities=torch_log_resp)

In [217]:
torch_means[0]

tensor([-44.6996,   9.7845,  27.4608,  25.7868,   3.7500,  23.0785,   6.0063,
         62.9924, -15.2910,  57.9771, -48.0307,  32.3790, -26.4540,  -9.3495,
         -3.1272,  30.9442,   3.6154,   4.0000], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [230]:
torch_log_prob_norm_step2, torch_log_resp_step2 = torch_gmm.e_step(input_tensor.double())

# SK Learn GMM 1 Epoch

In [241]:
from sklearn.mixture import GaussianMixture
skgmm = GaussianMixture(n_components=N_COMPONENTS, covariance_type='full', max_iter=1, random_state=0)
skgmm_pred = skgmm.fit_predict(input_data)



In [242]:
np.round(skgmm.means_[0],4)

array([-44.6996,   9.7845,  27.4608,  25.7868,   3.75  ,  23.0785,
         6.0063,  62.9924, -15.291 ,  57.9771, -48.0307,  32.379 ,
       -26.454 ,  -9.3495,  -3.1272,  30.9442,   3.6154,   4.    ])

In [243]:
torch_means[0]

tensor([-44.6996,   9.7845,  27.4608,  25.7868,   3.7500,  23.0785,   6.0063,
         62.9924, -15.2910,  57.9771, -48.0307,  32.3790, -26.4540,  -9.3495,
         -3.1272,  30.9442,   3.6154,   4.0000], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [244]:
np.round(skgmm.weights_,5)

array([5.2000e-04, 3.9480e-02, 1.5670e-02, 2.3580e-02, 2.7170e-02,
       4.1030e-02, 1.2215e-01, 1.7200e-03, 2.0000e-04, 5.2000e-04,
       2.1700e-03, 2.5570e-02, 6.7000e-03, 7.4000e-03, 4.0000e-04,
       4.0000e-04, 1.9600e-03, 3.3300e-03, 7.0530e-02, 4.1900e-02,
       1.1660e-02, 4.2800e-03, 9.0300e-03, 1.2375e-01, 8.8600e-03,
       1.3000e-03, 1.1400e-03, 1.2020e-02, 8.3110e-02, 5.3500e-03,
       1.5110e-02, 6.0000e-04, 1.6110e-02, 8.0000e-05, 6.1140e-02,
       1.9990e-02, 7.1000e-03, 3.9560e-02, 1.6200e-03, 2.3030e-02,
       5.1410e-02, 2.4000e-04, 7.2000e-04, 1.5970e-02, 2.0150e-02,
       3.3200e-03, 9.4000e-04, 1.1200e-03, 3.1100e-03, 2.5770e-02])

In [245]:
torch_weights

tensor([5.2000e-04, 3.9001e-02, 1.5583e-02, 2.3427e-02, 2.7020e-02, 4.0506e-02,
        1.2368e-01, 1.7113e-03, 2.0000e-04, 5.2000e-04, 2.1690e-03, 2.5135e-02,
        6.6539e-03, 7.2543e-03, 4.0000e-04, 4.0000e-04, 1.9556e-03, 3.3038e-03,
        7.1844e-02, 4.1670e-02, 1.1400e-02, 4.2469e-03, 8.9280e-03, 1.2502e-01,
        8.6914e-03, 1.2980e-03, 1.1331e-03, 1.1831e-02, 8.4619e-02, 5.3282e-03,
        1.4921e-02, 6.0000e-04, 1.5852e-02, 8.0000e-05, 6.1452e-02, 1.9749e-02,
        7.0153e-03, 3.9111e-02, 1.6215e-03, 2.1725e-02, 5.2204e-02, 2.4000e-04,
        7.2000e-04, 1.5692e-02, 1.9778e-02, 3.3048e-03, 9.4403e-04, 1.1201e-03,
        3.0589e-03, 2.5361e-02], dtype=torch.float64, grad_fn=<DivBackward0>)

In [246]:
skgmm.precisions_cholesky_[0][0]

array([ 1.04936879e-01,  1.78828131e-01,  3.05499820e-01,  2.02909722e-01,
        4.12734793e-01,  2.30530027e-01,  2.52401349e-01,  5.15028124e-01,
       -2.32220625e+00,  3.02152732e+00, -8.87033752e+00,  1.18858961e+00,
        6.68453105e+02,  3.59762976e+01, -5.72866077e+01,  7.41760253e+01,
       -1.15800415e+02, -3.56074488e+02])

In [248]:
torch_prec_chol[0][0]

tensor([  0.1049,   0.1788,   0.3055,   0.2029,   0.4125,   0.2305,   0.2521,
          0.5148,  -2.3180,   3.0113,  -8.7476,   1.1410,  63.6677, -13.2915,
         -4.4993,   2.4293,  -8.2627, -15.2198], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [250]:
skgmm_pred

array([18, 21, 40, ..., 32,  3, 22])

In [252]:
torch_log_resp_step2.argmax(dim=1)

tensor([18, 21, 40,  ..., 32,  3, 22])