In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
full_data = pd.read_csv("../../data/processed/historical/train/lcl_data.csv")
df_100K = full_data.sample(100000, random_state=0)
df_100K.to_csv("../../data/processed/historical/train/lcl_data_100K.csv", index=False)

# Load Data

In [3]:
import torch
import numpy as np
import random
RANDOM_STATE = 0
torch.manual_seed(RANDOM_STATE)
torch.use_deterministic_algorithms(True)
g = torch.Generator()
g.manual_seed(RANDOM_STATE)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [4]:
from pathlib import Path
from opensynth.data_modules.lcl_data_module import LCLDataModule
import pytorch_lightning as pl

import matplotlib.pyplot as plt

data_path = Path("../../data/processed/historical/train/lcl_data_100K.csv")
stats_path = Path("../../data/processed/historical/train/mean_std.csv")
outlier_path = Path("../../data/processed/historical/train/outliers.csv")

dm = LCLDataModule(data_path=data_path, stats_path=stats_path, batch_size=100000, n_samples=100000)
dm.setup()

In [5]:
import torch
from opensynth.models.faraday import FaradayVAE
vae_model = torch.load("vae_model.pt")
vae_model.eval()

FaradayVAE(
  (encoder): Encoder(
    (encoder_layers): Sequential(
      (0): Linear(in_features=50, out_features=512, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): GELU(approximate='none')
      (6): Linear(in_features=128, out_features=64, bias=True)
      (7): GELU(approximate='none')
      (8): Linear(in_features=64, out_features=32, bias=True)
      (9): GELU(approximate='none')
      (10): Linear(in_features=32, out_features=16, bias=True)
    )
  )
  (decoder): Decoder(
    (latent): Linear(in_features=18, out_features=16, bias=True)
    (latent_activations): GELU(approximate='none')
    (decoder_layers): Sequential(
      (0): Linear(in_features=16, out_features=32, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=32, out_features=64, bias=True)
      (3): GELU(approximate='no

In [6]:
from opensynth.models.faraday.gaussian_mixture.prepare_gmm_input import encode_data_for_gmm

next_batch = next(iter(dm.train_dataloader()))
input_tensor = encode_data_for_gmm(data=next_batch, vae_module=vae_model)
input_data = input_tensor.detach().numpy()
n_samples = len(input_tensor)

In [7]:
N_COMPONENTS = 200
REG_COVAR = 1e-4
EPOCHS = 25
IDX = 0
CONVERGENCE_TOL = 1e-2


In [8]:
input_tensor.shape, input_tensor[0][0]

(torch.Size([100000, 18]), tensor(0.4973, grad_fn=<SelectBackward0>))

# Init GMM

In [9]:
from opensynth.models.faraday.new_gmm import gmm_utils

labels_, means_, responsibilities_ = gmm_utils.initialise_centroids(
        X=input_data, n_components=N_COMPONENTS
    )
print(labels_.dtype, responsibilities_.dtype, means_.dtype)

torch.float32 torch.float32 torch.float32


In [10]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params

gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=REG_COVAR,
)
print(gmm_init_params["precision_cholesky"][IDX][0][0])
print(gmm_init_params["weights"].sum())

tensor(5.5252)
tensor(1.0000)


# Torch Lightning Batch Learning

In [11]:
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
class CustomDataset(Dataset):
    def __init__(self, data_tensor: torch.Tensor):
        self.data = data_tensor
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    
class CustomDataModule(LightningDataModule):
    def __init__(self, data_tensor: torch.Tensor, batch_size: int):
        super().__init__()
        self.data_tensor = data_tensor
        self.batch_size = batch_size
    def setup(self, stage=""):
        self.custom_ds = CustomDataset(self.data_tensor)
    def train_dataloader(self):
        return DataLoader(self.custom_ds, batch_size=self.batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)
    
custom_dm = CustomDataModule(data_tensor=input_tensor, batch_size=25000)
custom_dm.setup(stage="")

In [12]:
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureLightningModule, GaussianMixtureModel
gmm_module = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=REG_COVAR,
    print_idx=IDX
)
gmm_module.initialise(gmm_init_params)
print(f"Initial prec chol: {gmm_module.precision_cholesky[IDX][0][0]}. Initial mean: {gmm_module.means[IDX][0]}")

gmm_lightning_module = GaussianMixtureLightningModule(
    gmm_module = gmm_module,
    vae_module = vae_model,
    num_components = gmm_module.num_components,
    num_features = gmm_module.num_features,
    reg_covar = gmm_module.reg_covar,
    convergence_tolerance = CONVERGENCE_TOL,
    compute_on_batch=False
)
trainer = pl.Trainer(max_epochs=EPOCHS, accelerator="cpu", deterministic=True )
trainer.fit(gmm_lightning_module, custom_dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/core/opt

Initial prec chol: 5.5251617431640625. Initial mean: -0.3756513297557831
Epoch 0: 100%|██████████| 4/4 [00:02<00:00,  1.81it/s, v_num=92]Local weights at rank: 0 - means: 0.0224, -0.3274
Reduced weights, means, covar: 0.0224, -0.3274, 0.0144
NLL:  tensor(3.2374)
Epoch 1: 100%|██████████| 4/4 [00:01<00:00,  2.03it/s, v_num=92]Local weights at rank: 0 - means: 0.0202, -0.3126
Reduced weights, means, covar: 0.0202, -0.3126, 0.0071
NLL:  tensor(2.7450)
Epoch 2: 100%|██████████| 4/4 [00:02<00:00,  1.77it/s, v_num=92]Local weights at rank: 0 - means: 0.0188, -0.3120
Reduced weights, means, covar: 0.0188, -0.3120, 0.0049
NLL:  tensor(2.5370)
Epoch 3: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s, v_num=92]Local weights at rank: 0 - means: 0.0211, -0.3164
Reduced weights, means, covar: 0.0211, -0.3164, 0.0037
NLL:  tensor(2.4421)
Epoch 4: 100%|██████████| 4/4 [00:02<00:00,  1.83it/s, v_num=92]Local weights at rank: 0 - means: 0.0227, -0.3225
Reduced weights, means, covar: 0.0227, -0.3225, 0.003

In [13]:
gmm_lightning_module.gmm_module.means

tensor([[-0.3319, -1.6357,  0.7637,  ..., -1.7925,  7.6466,  3.5776],
        [ 0.0183, -4.0315, -1.4313,  ..., -3.2578,  8.1940,  2.4596],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3236, -3.3379,  0.6209,  ..., -1.8338,  1.1450,  1.7521],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [14]:
ligthning_sum_components = gmm_lightning_module.gmm_module.means.sum(axis=1)
len(ligthning_sum_components[ligthning_sum_components==0])

89

## SK-learn batch learning

In [15]:
from sklearn.mixture import GaussianMixture
init_weights = gmm_init_params["weights"]
init_means = gmm_init_params["means"]

skgmm = GaussianMixture(n_components=N_COMPONENTS, covariance_type='full', tol=CONVERGENCE_TOL, max_iter=EPOCHS, random_state=0, means_init = init_means, weights_init=init_weights, warm_start=True, verbose=1)

dl = custom_dm.train_dataloader()
next_batch = next(iter(dl))
for batch_num, batch_data in enumerate(dl):
    print("Batch number: ", batch_num)
    input_data = batch_data.detach().numpy()
    n_samples = len(input_tensor)
    skgmm.fit(input_data)
    print("means: ", skgmm.means_)


Batch number:  0
Initialization 0
  Iteration 10
  Iteration 20
Initialization did not converge.




means:  [[ -0.27304786  -1.66365247   0.60691867 ...  -1.69716843   7.96467992
    3.16018055]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 ...
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.27640811 -13.29458469  -1.64329459 ...  -5.77571846   6.95806539
    2.85220053]
 [  0.           0.           0.         ...   0.           0.
    0.        ]]
Batch number:  1
Initialization 0
  Iteration 10
Initialization converged.
means:  [[ -0.30174702  -1.56455869   0.68509828 ...  -1.6584728    7.67447356
    3.17778859]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 ...
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.47426091 -15.22395937  -2.84767364 ...  -7.86735751   6.45543307
    2.6997854

In [16]:
skgmm.means_

array([[ -0.30653206,  -1.47836938,   0.68937353, ...,  -1.60552788,
          7.69479591,   3.40170958],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       ...,
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       [  1.23822623, -18.68083721,  -2.66390906, ...,  -9.89375691,
          6.25444043,   2.99957082],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ]])

In [17]:
sklearn_sum_components = skgmm.means_.sum(axis=1)
len(sklearn_sum_components[sklearn_sum_components==0])

145

# Compare

In [18]:
IDX = 0

In [19]:
df_compare_means = pd.DataFrame()
df_compare_means["skgmm"] = skgmm.means_[IDX]
df_compare_means["lightning"] = gmm_lightning_module.gmm_module.means[IDX]
df_compare_means

Unnamed: 0,skgmm,lightning
0,-0.306532,-0.331899
1,-1.478369,-1.635727
2,0.689374,0.76371
3,-0.476471,-0.530579
4,0.229295,0.243099
5,-0.001939,-0.007573
6,0.83609,0.895278
7,0.655178,0.730462
8,0.018107,0.040447
9,-1.286877,-1.583683


In [20]:
df_compare_covar = pd.DataFrame()
df_compare_covar["skgmm"] = skgmm.covariances_[IDX][0]
df_compare_covar["lightning"] = gmm_lightning_module.gmm_module.covariances.detach().numpy()[IDX][0]
df_compare_covar

Unnamed: 0,skgmm,lightning
0,0.003206,0.002145
1,0.004286,0.006281
2,-0.006302,-0.00379
3,0.005536,0.003262
4,-0.000496,0.000255
5,0.000901,0.000547
6,-0.0053,-0.003235
7,-0.00371,-0.004842
8,-0.001353,-0.000872
9,0.014917,0.013158


In [21]:
df_compare_pre_chol = pd.DataFrame()
df_compare_pre_chol["skgmm"] = skgmm.precisions_cholesky_[IDX][0]
df_compare_pre_chol["lightning"] = gmm_lightning_module.gmm_module.precision_cholesky.detach().numpy()[IDX][0]
df_compare_pre_chol

Unnamed: 0,skgmm,lightning
0,17.661761,21.589685
1,-3.481608,-10.049465
2,21.699544,15.095892
3,-2.816636,-6.224689
4,0.844474,3.380112
5,-28.989569,-16.800541
6,4.863907,3.334535
7,17.76335,9.648892
8,2.93221,-0.45522
9,-40.254351,-23.81101


In [22]:
df_compare_weights = pd.DataFrame()
df_compare_weights["skgmm"] = skgmm.weights_[:10]
df_compare_weights["lightning"] = gmm_lightning_module.gmm_module.weights.detach().numpy()[:10]
df_compare_weights

Unnamed: 0,skgmm,lightning
0,0.02498987,0.01638867
1,8.881784e-20,0.01255582
2,8.881784e-20,4.768371e-12
3,8.881784e-20,4.768371e-12
4,0.006050761,0.008841502
5,8.881784e-20,0.01065066
6,8.881784e-20,0.008192629
7,8.881784e-20,4.768371e-12
8,8.881784e-20,4.768371e-12
9,8.881784e-20,0.01158984
