In [23]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
import pandas as pd
full_data = pd.read_csv("../../data/processed/historical/train/lcl_data.csv")
df_100K = full_data.sample(100000, random_state=0)
df_100K.to_csv("../../data/processed/historical/train/lcl_data_100K.csv", index=False)

# Load Data

In [25]:
import torch
import numpy as np
import random
RANDOM_STATE = 0
torch.manual_seed(RANDOM_STATE)
torch.use_deterministic_algorithms(True)
g = torch.Generator()
g.manual_seed(RANDOM_STATE)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [26]:
from pathlib import Path
from opensynth.data_modules.lcl_data_module import LCLDataModule
import pytorch_lightning as pl

import matplotlib.pyplot as plt

data_path = Path("../../data/processed/historical/train/lcl_data_100K.csv")
stats_path = Path("../../data/processed/historical/train/mean_std.csv")
outlier_path = Path("../../data/processed/historical/train/outliers.csv")

dm = LCLDataModule(data_path=data_path, stats_path=stats_path, batch_size=100000, n_samples=100000)
dm.setup()

In [27]:
import torch
from opensynth.models.faraday import FaradayVAE
vae_model = torch.load("vae_model.pt")
vae_model.eval()

FaradayVAE(
  (encoder): Encoder(
    (encoder_layers): Sequential(
      (0): Linear(in_features=50, out_features=512, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): GELU(approximate='none')
      (6): Linear(in_features=128, out_features=64, bias=True)
      (7): GELU(approximate='none')
      (8): Linear(in_features=64, out_features=32, bias=True)
      (9): GELU(approximate='none')
      (10): Linear(in_features=32, out_features=16, bias=True)
    )
  )
  (decoder): Decoder(
    (latent): Linear(in_features=18, out_features=16, bias=True)
    (latent_activations): GELU(approximate='none')
    (decoder_layers): Sequential(
      (0): Linear(in_features=16, out_features=32, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=32, out_features=64, bias=True)
      (3): GELU(approximate='no

In [28]:
from opensynth.models.faraday.gaussian_mixture.prepare_gmm_input import encode_data_for_gmm

next_batch = next(iter(dm.train_dataloader()))
input_tensor = encode_data_for_gmm(data=next_batch, vae_module=vae_model)
input_data = input_tensor.detach().numpy()
n_samples = len(input_tensor)

In [29]:
N_COMPONENTS = 200
REG_COVAR = 1e-4
EPOCHS = 25
IDX = 0
CONVERGENCE_TOL = 1e-2


In [30]:
input_tensor.shape, input_tensor[0][0]

(torch.Size([100000, 18]), tensor(0.4973, grad_fn=<SelectBackward0>))

# Init GMM

In [31]:
from opensynth.models.faraday.new_gmm import gmm_utils

labels_, means_, responsibilities_ = gmm_utils.initialise_centroids(
        X=input_data, n_components=N_COMPONENTS
    )
print(labels_.dtype, responsibilities_.dtype, means_.dtype)

torch.float32 torch.float32 torch.float32


In [32]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params

gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=REG_COVAR,
)
print(gmm_init_params["precision_cholesky"][IDX][0][0])
print(gmm_init_params["weights"].sum())

tensor(5.5259)
tensor(1.)


# Torch Lightning Batch Learning

In [33]:
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
class CustomDataset(Dataset):
    def __init__(self, data_tensor: torch.Tensor):
        self.data = data_tensor
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    
class CustomDataModule(LightningDataModule):
    def __init__(self, data_tensor: torch.Tensor, batch_size: int):
        super().__init__()
        self.data_tensor = data_tensor
        self.batch_size = batch_size
    def setup(self, stage=""):
        self.custom_ds = CustomDataset(self.data_tensor)
    def train_dataloader(self):
        return DataLoader(self.custom_ds, batch_size=self.batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)
    
custom_dm = CustomDataModule(data_tensor=input_tensor, batch_size=25000)
custom_dm.setup(stage="")

In [34]:
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureLightningModule, GaussianMixtureModel
gmm_module = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=REG_COVAR,
    print_idx=IDX
)
gmm_module.initialise(gmm_init_params)
print(f"Initial prec chol: {gmm_module.precision_cholesky[IDX][0][0]}. Initial mean: {gmm_module.means[IDX][0]}")

gmm_lightning_module = GaussianMixtureLightningModule(
    gmm_module = gmm_module,
    vae_module = vae_model,
    num_components = gmm_module.num_components,
    num_features = gmm_module.num_features,
    reg_covar = gmm_module.reg_covar,
    convergence_tolerance = CONVERGENCE_TOL,
    compute_on_batch=False
)
trainer = pl.Trainer(max_epochs=EPOCHS, accelerator="cpu", deterministic=True )
trainer.fit(gmm_lightning_module, custom_dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:182: `LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer

  | Name                      | Type                    | Params | Mode 
------------------------------------------------------------------------------
0 | gmm_module                | GaussianMixtureModel    | 0      | train
1 | vae_module                | FaradayVAE              | 402 K  | eval 
2 | weight_metric             | WeightsMetric           | 0      | train
3 | mean_metric               | MeansMetric             | 0      | train
4 | precision_ch

Initial prec chol: 5.525859355926514. Initial mean: -0.3752790093421936
Epoch 0: 100%|██████████| 4/4 [00:02<00:00,  1.78it/s, v_num=91]Local weights at rank: 0 - means: 0.0224, -0.3269
Reduced weights, means, covar: 0.0224, -0.3269, 0.0144
log prob:  tensor(3.2367)
Epoch 1: 100%|██████████| 4/4 [00:01<00:00,  2.12it/s, v_num=91]Local weights at rank: 0 - means: 0.0202, -0.3121
Reduced weights, means, covar: 0.0202, -0.3121, 0.0071
log prob:  tensor(2.7456)
Epoch 2: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s, v_num=91]Local weights at rank: 0 - means: 0.0188, -0.3117
Reduced weights, means, covar: 0.0188, -0.3117, 0.0049
log prob:  tensor(2.5395)
Epoch 3: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s, v_num=91]Local weights at rank: 0 - means: 0.0211, -0.3161
Reduced weights, means, covar: 0.0211, -0.3161, 0.0037
log prob:  tensor(2.4432)
Epoch 4: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s, v_num=91]Local weights at rank: 0 - means: 0.0227, -0.3224
Reduced weights, means, covar: 0.0

In [36]:
gmm_lightning_module.gmm_module.means

tensor([[-0.3320, -1.6353,  0.7637,  ..., -1.7928,  7.7219,  3.6024],
        [ 0.0227, -4.0463, -1.4244,  ..., -3.2790,  8.0274,  2.4504],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [37]:
ligthning_sum_components = gmm_lightning_module.gmm_module.means.sum(axis=1)
len(ligthning_sum_components[ligthning_sum_components==0])

90

## SK-learn batch learning

In [38]:
from sklearn.mixture import GaussianMixture
init_weights = gmm_init_params["weights"]
init_means = gmm_init_params["means"]

skgmm = GaussianMixture(n_components=N_COMPONENTS, covariance_type='full', tol=CONVERGENCE_TOL, max_iter=EPOCHS, random_state=0, means_init = init_means, weights_init=init_weights, warm_start=True, verbose=1)

dl = custom_dm.train_dataloader()
next_batch = next(iter(dl))
for batch_num, batch_data in enumerate(dl):
    print("Batch number: ", batch_num)
    input_data = batch_data.detach().numpy()
    n_samples = len(input_tensor)
    skgmm.fit(input_data)
    print("means: ", skgmm.means_)


Batch number:  0
Initialization 0
  Iteration 10
  Iteration 20
Initialization did not converge.




means:  [[ -0.27297892  -1.66418081   0.60645039 ...  -1.69709651   7.9664832
    3.15889911]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 ...
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.39421762 -13.38144488  -1.88406379 ...  -5.81491683   7.04163387
    2.75308323]
 [  0.           0.           0.         ...   0.           0.
    0.        ]]
Batch number:  1
Initialization 0
Initialization converged.
means:  [[ -0.30060178  -1.56674916   0.68211978 ...  -1.6561499    7.66094338
    3.16973349]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 ...
 [  0.           0.           0.         ...   0.           0.
    0.        ]
 [  0.51273646 -14.29359097  -2.78757385 ...  -7.26379998   6.79046349
    2.56269853]
 [  0.       

In [39]:
skgmm.means_

array([[ -0.30425054,  -1.47796971,   0.68436253, ...,  -1.60052826,
          7.71406474,   3.38502124],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       ...,
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ],
       [  1.26299495, -17.71371465,  -3.04977976, ...,  -8.88094948,
          6.44890154,   2.84957527],
       [  0.        ,   0.        ,   0.        , ...,   0.        ,
          0.        ,   0.        ]])

In [40]:
sklearn_sum_components = skgmm.means_.sum(axis=1)
len(sklearn_sum_components[sklearn_sum_components==0])

146

# Compare

In [44]:
IDX = 0

In [45]:
df_compare_means = pd.DataFrame()
df_compare_means["skgmm"] = skgmm.means_[IDX]
df_compare_means["lightning"] = gmm_lightning_module.gmm_module.means[IDX]
df_compare_means

Unnamed: 0,skgmm,lightning
0,-0.304251,-0.332041
1,-1.47797,-1.635299
2,0.684363,0.763727
3,-0.472924,-0.531022
4,0.226218,0.241841
5,-0.00107,-0.007984
6,0.832385,0.895011
7,0.658312,0.730633
8,0.017196,0.040649
9,-1.277301,-1.584936


In [46]:
df_compare_covar = pd.DataFrame()
df_compare_covar["skgmm"] = skgmm.covariances_[IDX][0]
df_compare_covar["lightning"] = gmm_lightning_module.gmm_module.covariances.detach().numpy()[IDX][0]
df_compare_covar

Unnamed: 0,skgmm,lightning
0,0.003373,0.002025
1,0.00336,0.006338
2,-0.006471,-0.003585
3,0.005735,0.003117
4,-0.000684,0.000247
5,0.000959,0.000489
6,-0.005469,-0.003131
7,-0.003141,-0.00476
8,-0.001353,-0.000873
9,0.014247,0.012821


In [47]:
df_compare_pre_chol = pd.DataFrame()
df_compare_pre_chol["skgmm"] = skgmm.precisions_cholesky_[IDX][0]
df_compare_pre_chol["lightning"] = gmm_lightning_module.gmm_module.precision_cholesky.detach().numpy()[IDX][0]
df_compare_pre_chol

Unnamed: 0,skgmm,lightning
0,17.218255,22.224047
1,-2.470484,-11.175132
2,21.809571,14.351311
3,-3.038004,-6.554366
4,0.339362,3.524893
5,-29.150468,-16.216255
6,5.500214,3.720157
7,18.770017,9.581682
8,2.760689,-0.528105
9,-39.930301,-23.916332


In [48]:
df_compare_weights = pd.DataFrame()
df_compare_weights["skgmm"] = skgmm.weights_[:10]
df_compare_weights["lightning"] = gmm_lightning_module.gmm_module.weights.detach().numpy()[:10]
df_compare_weights

Unnamed: 0,skgmm,lightning
0,0.02604451,0.01542987
1,8.881784e-20,0.0128762
2,8.881784e-20,4.768371e-12
3,8.881784e-20,4.768371e-12
4,0.006282215,0.009035708
5,8.881784e-20,0.01045682
6,8.881784e-20,0.008056618
7,8.881784e-20,4.768371e-12
8,8.881784e-20,4.768371e-12
9,8.881784e-20,0.01180198
