In [180]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [181]:
import pandas as pd
full_data = pd.read_csv("../../data/processed/historical/train/lcl_data.csv")
df_100K = full_data.sample(100000, random_state=0)
df_100K.to_csv("../../data/processed/historical/train/lcl_data_100K.csv", index=False)

# Load Data

In [182]:
import torch
import numpy as np
import random
RANDOM_STATE = 0
torch.manual_seed(RANDOM_STATE)
torch.use_deterministic_algorithms(True)
g = torch.Generator()
g.manual_seed(RANDOM_STATE)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [183]:
from pathlib import Path
from opensynth.data_modules.lcl_data_module import LCLDataModule
import pytorch_lightning as pl

import matplotlib.pyplot as plt

data_path = Path("../../data/processed/historical/train/lcl_data_25K.csv")
stats_path = Path("../../data/processed/historical/train/mean_std.csv")
outlier_path = Path("../../data/processed/historical/train/outliers.csv")

dm = LCLDataModule(data_path=data_path, stats_path=stats_path, batch_size=25000, n_samples=25000)
dm.setup()

In [184]:
import torch
from opensynth.models.faraday import FaradayVAE
vae_model = torch.load("vae_model.pt")
vae_model.eval()

FaradayVAE(
  (encoder): Encoder(
    (encoder_layers): Sequential(
      (0): Linear(in_features=50, out_features=512, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): GELU(approximate='none')
      (6): Linear(in_features=128, out_features=64, bias=True)
      (7): GELU(approximate='none')
      (8): Linear(in_features=64, out_features=32, bias=True)
      (9): GELU(approximate='none')
      (10): Linear(in_features=32, out_features=16, bias=True)
    )
  )
  (decoder): Decoder(
    (latent): Linear(in_features=18, out_features=16, bias=True)
    (latent_activations): GELU(approximate='none')
    (decoder_layers): Sequential(
      (0): Linear(in_features=16, out_features=32, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=32, out_features=64, bias=True)
      (3): GELU(approximate='no

In [185]:
from opensynth.models.faraday.gaussian_mixture.prepare_gmm_input import encode_data_for_gmm

next_batch = next(iter(dm.train_dataloader()))
input_tensor = encode_data_for_gmm(data=next_batch, vae_module=vae_model)
input_data = input_tensor.detach().numpy()
n_samples = len(input_tensor)

In [186]:
N_COMPONENTS = 250
REG_COVAR = 1e-4
EPOCHS = 25
IDX = 0
CONVERGENCE_TOL = 1e-2


In [187]:
input_tensor[0][0]

tensor(0.0195, grad_fn=<SelectBackward0>)

# Init GMM

In [188]:
from opensynth.models.faraday.new_gmm import gmm_utils

labels_, means_, responsibilities_ = gmm_utils.initialise_centroids(
        X=input_data, n_components=N_COMPONENTS
    )
print(labels_.dtype, responsibilities_.dtype, means_.dtype)

torch.float32 torch.float32 torch.float32


In [189]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params

gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=REG_COVAR,
)
print(gmm_init_params["precision_cholesky"][IDX][0][0])
print(gmm_init_params["weights"].sum())

Valid covariance: (True, tensor(True))
tensor(4.1672)
tensor(1.)


# Torch GMM

In [190]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params, training_loop
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureModel


gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=REG_COVAR,
)
torch_gmm = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=REG_COVAR,
    print_idx=IDX
)
torch_gmm.initialise(gmm_init_params)
trained_model = training_loop(model=torch_gmm, data=input_tensor, max_iter=EPOCHS)

Valid covariance: (True, tensor(True))


  4%|▍         | 1/25 [00:00<00:18,  1.30it/s]

Valid covariance: (False, tensor(True))
Change: inf


  8%|▊         | 2/25 [00:01<00:19,  1.18it/s]

Valid covariance: (False, tensor(True))
Change: 0.5455737113952637


 12%|█▏        | 3/25 [00:02<00:15,  1.38it/s]

Valid covariance: (False, tensor(True))
Change: 0.23595058917999268


 16%|█▌        | 4/25 [00:03<00:15,  1.36it/s]

Valid covariance: (False, tensor(True))
Change: 0.1542905569076538


 20%|██        | 5/25 [00:03<00:16,  1.23it/s]

Valid covariance: (False, tensor(True))
Change: 0.11812639236450195


 24%|██▍       | 6/25 [00:04<00:13,  1.37it/s]

Valid covariance: (False, tensor(True))
Change: 0.09741419553756714


 28%|██▊       | 7/25 [00:04<00:11,  1.63it/s]

Valid covariance: (False, tensor(True))
Change: 0.08560162782669067


 32%|███▏      | 8/25 [00:05<00:09,  1.75it/s]

Valid covariance: (False, tensor(True))
Change: 0.07493609189987183


 36%|███▌      | 9/25 [00:05<00:07,  2.01it/s]

Valid covariance: (False, tensor(True))
Change: 0.0600125789642334


 40%|████      | 10/25 [00:06<00:06,  2.24it/s]

Valid covariance: (False, tensor(True))
Change: 0.04927009344100952


 44%|████▍     | 11/25 [00:06<00:05,  2.56it/s]

Valid covariance: (False, tensor(True))
Change: 0.04166579246520996


 48%|████▊     | 12/25 [00:06<00:05,  2.51it/s]

Valid covariance: (False, tensor(True))
Change: 0.03287696838378906


 52%|█████▏    | 13/25 [00:07<00:05,  2.23it/s]

Valid covariance: (False, tensor(True))
Change: 0.026912033557891846


 56%|█████▌    | 14/25 [00:07<00:04,  2.47it/s]

Valid covariance: (False, tensor(True))
Change: 0.02238285541534424


 60%|██████    | 15/25 [00:07<00:03,  2.73it/s]

Valid covariance: (False, tensor(True))
Change: 0.01896977424621582


 64%|██████▍   | 16/25 [00:08<00:03,  2.93it/s]

Valid covariance: (False, tensor(True))
Change: 0.015812456607818604


 68%|██████▊   | 17/25 [00:08<00:02,  3.03it/s]

Valid covariance: (False, tensor(True))
Change: 0.01478499174118042


 72%|███████▏  | 18/25 [00:08<00:02,  3.00it/s]

Valid covariance: (False, tensor(True))
Change: 0.012155234813690186


 76%|███████▌  | 19/25 [00:09<00:02,  2.92it/s]

Valid covariance: (False, tensor(True))
Change: 0.010511398315429688


 76%|███████▌  | 19/25 [00:09<00:03,  1.99it/s]

Valid covariance: (False, tensor(True))
Change: 0.008294880390167236
Converged: True. Number of iterations: 19





# SK Learn GMM Manual

In [191]:
import numpy as np
from scipy.special import logsumexp
from scipy import linalg

def _estimate_gaussian_parameters(X, resp, reg_covar=REG_COVAR):
    nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
    means = np.dot(resp.T, X) / nk[:, np.newaxis]
    n_components, n_features = means.shape
    covariances = np.empty((n_components, n_features, n_features))
    for k in range(n_components):
        diff = X - means[k]
        covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
        covariances[k].flat[:: n_features + 1] += reg_covar
    return nk, means, covariances

def _compute_precision_cholesky(covariances):
    estimate_precision_error_message = (
        "Fitting the mixture model failed because some components have "
        "ill-defined empirical covariance (for instance caused by singleton "
        "or collapsed samples). Try to decrease the number of components, "
        "or increase reg_covar."
    )

    n_components, n_features, _ = covariances.shape
    precisions_chol = np.empty((n_components, n_features, n_features))
    for k, covariance in enumerate(covariances):
        try:
            cov_chol = linalg.cholesky(covariance, lower=True)
        except linalg.LinAlgError:
            raise ValueError(estimate_precision_error_message)
        precisions_chol[k] = linalg.solve_triangular(
            cov_chol, np.eye(n_features), lower=True
        ).T
    return precisions_chol

def _compute_log_det_cholesky(matrix_chol, n_features):
    n_components, _, _ = matrix_chol.shape
    log_det_chol = np.sum(
        np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1
    )
    return log_det_chol

def _estimate_log_gaussian_prob(X, means, precisions_chol):
    n_samples, n_features = X.shape
    n_components, _ = means.shape

    log_det = _compute_log_det_cholesky(precisions_chol, n_features)

    log_prob = np.empty((n_samples, n_components))
    for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
        y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
        log_prob[:, k] = np.sum(np.square(y), axis=1)
    return -0.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det

def _estimate_log_weights(weights):
        return np.log(weights)

def _estimate_weighted_log_prob(X, means, precisions_chol, weights):
        return _estimate_log_gaussian_prob(X, means, precisions_chol) + _estimate_log_weights(weights)


def _estimate_log_prob_resp(X, means, precisions_chol, weights):
    weighted_log_prob = _estimate_weighted_log_prob(X, means, precisions_chol, weights)
    log_prob_norm = logsumexp(weighted_log_prob, axis=1)
    with np.errstate(under="ignore"):
        log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
    return log_prob_norm, log_resp

def _e_step(X,means, precisions_chol, weights):
    log_prob_norm, log_resp = _estimate_log_prob_resp(X, means, precisions_chol, weights)
    return np.mean(log_prob_norm), log_resp

def _m_step(X, log_reponsibilities, reg_covar=REG_COVAR):

    weights_, means_, covariances_ = _estimate_gaussian_parameters(X,np.exp(log_reponsibilities),reg_covar=reg_covar)
    weights_ /= weights_.sum()

    precision_cholesky_ = _compute_precision_cholesky(covariances=covariances_)

    return precision_cholesky_, weights_, means_, covariances_

In [192]:
means = gmm_init_params["means"].detach().numpy()
weights = gmm_init_params["weights"].detach().numpy()
prec_chol = gmm_init_params["precision_cholesky"].detach().numpy()

print(f"Initial prec chol: {prec_chol[IDX][0][0]}. Initial mean: {means[IDX][0]}")

converged = False
lower_bound = -np.inf

for i in range(EPOCHS):
    prev_lower_bound = lower_bound

    print(f"Old Prec Chol: {prec_chol[IDX][0][0]}. Old means: {means[IDX][0]}")
    log_prob, log_resp = _e_step(input_data, means, prec_chol, weights)
    prec_chol, weights, means, covar = _m_step(input_data, log_resp)

    print(f"New prec chol: {prec_chol[IDX][0][0]}. New means: {means[IDX][0]}")

    # Converegence
    lower_bound = log_prob
    change = abs(lower_bound - prev_lower_bound)
    print(f"Change: {change}")
    if change < CONVERGENCE_TOL:
        converged = True
        break

print(f'Converged: {converged}. Number of iterations: {i}')

Initial prec chol: 4.167169570922852. Initial mean: 0.18039406836032867
Old Prec Chol: 4.167169570922852. Old means: 0.18039406836032867
New prec chol: 4.240417114097165. New means: 0.1610282390352588
Change: inf
Old Prec Chol: 4.240417114097165. Old means: 0.1610282390352588
New prec chol: 4.289415021363738. New means: 0.13732364273655587
Change: 0.545201791854862
Old Prec Chol: 4.289415021363738. Old means: 0.13732364273655587
New prec chol: 4.179746626419131. New means: 0.11500928863101198
Change: 0.23580703880669862
Old Prec Chol: 4.179746626419131. Old means: 0.11500928863101198
New prec chol: 3.948317806998173. New means: 0.0940242050240092
Change: 0.15423102791465793
Old Prec Chol: 3.948317806998173. Old means: 0.0940242050240092
New prec chol: 3.825919235870529. New means: 0.0780048305143228
Change: 0.11806440901075788
Old Prec Chol: 3.825919235870529. Old means: 0.0780048305143228
New prec chol: 3.7660081759677153. New means: 0.0639742435193366
Change: 0.09738448263012867
Old 

# SK Learn GMM Epoch

In [193]:
from sklearn.mixture import GaussianMixture

init_weights = gmm_init_params["weights"]
init_means = gmm_init_params["means"]

skgmm = GaussianMixture(n_components=N_COMPONENTS, covariance_type='full', tol=CONVERGENCE_TOL, max_iter=EPOCHS, random_state=0, means_init = init_means, weights_init=init_weights, verbose=2, verbose_interval=1)
skgmm.fit(input_data)
skgmm_pred = skgmm.predict(input_data)

Initialization 0
  Iteration 1	 time lapse 4.49572s	 ll change inf
  Iteration 2	 time lapse 1.99019s	 ll change 0.55425
  Iteration 3	 time lapse 2.10942s	 ll change 0.24330
  Iteration 4	 time lapse 1.90294s	 ll change 0.16184
  Iteration 5	 time lapse 2.04430s	 ll change 0.12116
  Iteration 6	 time lapse 1.91540s	 ll change 0.10078
  Iteration 7	 time lapse 2.20269s	 ll change 0.08830
  Iteration 8	 time lapse 1.84719s	 ll change 0.07739
  Iteration 9	 time lapse 2.01203s	 ll change 0.06126
  Iteration 10	 time lapse 2.27183s	 ll change 0.04845
  Iteration 11	 time lapse 2.16001s	 ll change 0.04266
  Iteration 12	 time lapse 2.42154s	 ll change 0.03660
  Iteration 13	 time lapse 2.35852s	 ll change 0.02900
  Iteration 14	 time lapse 2.49654s	 ll change 0.02518
  Iteration 15	 time lapse 3.23633s	 ll change 0.02246
  Iteration 16	 time lapse 2.54810s	 ll change 0.01872
  Iteration 17	 time lapse 2.47835s	 ll change 0.01665
  Iteration 18	 time lapse 2.25658s	 ll change 0.01431
  Iter

In [194]:
skgmm.converged_, skgmm.n_iter_

(True, 22)

# Torch Lightning

In [195]:
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
class CustomDataset(Dataset):
    def __init__(self, data_tensor: torch.Tensor):
        self.data = data_tensor
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    
class CustomDataModule(LightningDataModule):
    def __init__(self, data_tensor: torch.Tensor, batch_size: int):
        super().__init__()
        self.data_tensor = data_tensor
        self.batch_size = batch_size
    def setup(self, stage=""):
        self.custom_ds = CustomDataset(self.data_tensor)
    def train_dataloader(self):
        return DataLoader(self.custom_ds, batch_size=self.batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)
    
custom_dm = CustomDataModule(data_tensor=input_tensor, batch_size=25000)
custom_dm.setup(stage="")

In [196]:

for i in range(5):
    print(next(iter(custom_dm.train_dataloader()))[0][0])

tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)


In [197]:
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureLightningModule
gmm_module = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=REG_COVAR,
    print_idx=IDX
)
gmm_module.initialise(gmm_init_params)
print(f"Initial prec chol: {gmm_module.precision_cholesky[IDX][0][0]}. Initial mean: {gmm_module.means[IDX][0]}")

gmm_lightning_module = GaussianMixtureLightningModule(
    gmm_module = gmm_module,
    vae_module = vae_model,
    num_components = gmm_module.num_components,
    num_features = gmm_module.num_features,
    reg_covar = gmm_module.reg_covar,
    convergence_tolerance = CONVERGENCE_TOL
)
trainer = pl.Trainer(max_epochs=EPOCHS, accelerator="cpu", deterministic=True )
trainer.fit(gmm_lightning_module, custom_dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:182: `LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer

  | Name                      | Type                    | Params | Mode 
------------------------------------------------------------------------------
0 | gmm_module                | GaussianMixtureModel    | 0      | train
1 | vae_module                | FaradayVAE              | 402 K  | eval 
2 | weight_metric             | WeightsMetric           | 0      | train
3 | mean_metric               | MeansMetric             | 0      | train
4 | precision_ch

Initial prec chol: 4.167169570922852. Initial mean: 0.18039406836032867
Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] Valid covariance: (False, tensor(True))
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s, v_num=27]Local weights at rank: 0 - means: 0.0062, 0.1610
Reduced weights, means: 0.0062, 0.1610
Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s, v_num=27]        Valid covariance: (False, tensor(True))
Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  2.46it/s, v_num=27]Local weights at rank: 0 - means: 0.0058, 0.1373
Reduced weights, means: 0.0058, 0.1373
Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s, v_num=27]        Valid covariance: (False, tensor(True))
Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  2.70it/s, v_num=27]Local weights at rank: 0 - means: 0.0057, 0.1149
Reduced weights, means: 0.0057, 0.1149
Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s, v_num=27]        Valid covariance: (False, tensor(True))
Epoch 3: 100%|██████████| 1/1 [00:00<00:00,  2.99it/s, v_num=27]Local wei

In [198]:
gmm_lightning_module.gmm_module.weights[0], gmm_lightning_module.gmm_module.means[0][0]

(tensor(0.0042), tensor(0.0035))

In [199]:
gmm_lightning_module.weight_metric.compute()[0], gmm_lightning_module.mean_metric.compute()[0][0]

(tensor(0.0042), tensor(0.0035))

# Compare

In [200]:
IDX = 0

In [201]:
df_compare_means = pd.DataFrame()
df_compare_means["skgmm"] = skgmm.means_[IDX]
df_compare_means["numpy"] = means[IDX]
df_compare_means["torch"] = trained_model.means[IDX]
df_compare_means["lightning"] = gmm_lightning_module.gmm_module.means[IDX]
df_compare_means

Unnamed: 0,skgmm,numpy,torch,lightning
0,-0.034104,0.003497,0.003455,0.003455
1,-1.555318,-1.622721,-1.625163,-1.625163
2,-0.805133,-0.963266,-0.962647,-0.962647
3,1.467063,1.691928,1.693601,1.693601
4,0.521913,0.539403,0.543288,0.543288
5,-0.07905,-0.085705,-0.085293,-0.085293
6,-0.479725,-0.688806,-0.688216,-0.688216
7,3.150586,3.29443,3.298344,3.298344
8,-0.261914,-0.301814,-0.301562,-0.301562
9,-1.601044,-1.791389,-1.798772,-1.798772


In [202]:
gmm_init_params["means"][IDX]

tensor([ 0.1804, -2.7538, -1.0911,  1.7346,  0.2788, -0.0517, -0.9778,  3.2770,
        -0.2799, -2.2817, -0.6563, -1.5545, -1.2570,  0.9383,  0.9554, -2.0788,
         5.9050,  3.9497])

In [203]:
df_compare_covar = pd.DataFrame()
df_compare_covar["skgmm"] = skgmm.covariances_[IDX][0]
df_compare_covar["numpy"] = covar[IDX][0]
df_compare_covar["torch"] = trained_model.covariances.detach().numpy()[IDX][0]
df_compare_covar["lightning"] = gmm_lightning_module.gmm_module.covariances.detach().numpy()[IDX][0]
df_compare_covar

Unnamed: 0,skgmm,numpy,torch,lightning
0,0.078326,0.067472,0.067431,0.067431
1,-0.199605,-0.186681,-0.187432,-0.187432
2,-0.11237,-0.087543,-0.087666,-0.087666
3,0.118446,0.097261,0.097659,0.097659
4,0.00264,0.005331,0.005539,0.005539
5,0.017543,0.011311,0.011351,0.011351
6,-0.130956,-0.112912,-0.113341,-0.113341
7,-0.064959,-0.090017,-0.089387,-0.089387
8,-0.017749,-0.016246,-0.016301,-0.016301
9,-0.043763,-0.068598,-0.069978,-0.069978


In [204]:
df_compare_pre_chol = pd.DataFrame()
df_compare_pre_chol["skgmm"] = skgmm.precisions_cholesky_[IDX][0]
df_compare_pre_chol["numpy"] = prec_chol[IDX][0]
df_compare_pre_chol["torch"] = trained_model.precision_cholesky.detach().numpy()[IDX][0]
df_compare_pre_chol["lightning"] = gmm_lightning_module.gmm_module.precision_cholesky.detach().numpy()[IDX][0]
df_compare_pre_chol

Unnamed: 0,skgmm,numpy,torch,lightning
0,3.573121,3.849812,3.850963,3.850963
1,2.461067,2.455015,2.456105,2.456105
2,2.998015,2.292612,2.290046,2.290046
3,-1.307495,-1.795725,-1.789544,-1.789544
4,-2.952195,-3.082288,-3.07105,-3.07105
5,-6.022596,-3.587201,-3.623628,-3.623628
6,2.662257,3.83074,3.820804,3.820804
7,10.953345,12.415186,12.467817,12.467817
8,1.997179,1.477781,1.524926,1.524926
9,-4.058303,-4.109881,-4.275584,-4.275584


In [205]:
df_compare_weights = pd.DataFrame()
df_compare_weights["skgmm"] = skgmm.weights_[:10]
df_compare_weights["numpy"] = weights[:10]
df_compare_weights["torch"] = trained_model.weights[:10]
df_compare_weights["lightning"] = gmm_lightning_module.gmm_module.weights.detach().numpy()[:10]
df_compare_weights

Unnamed: 0,skgmm,numpy,torch,lightning
0,0.004755,0.004216,0.004236,0.004236
1,0.00016,0.00016,0.00016,0.00016
2,0.029276,0.041534,0.041527,0.041527
3,0.021889,0.022639,0.022607,0.022607
4,0.00277,0.002551,0.002554,0.002554
5,0.00016,0.00016,0.00016,0.00016
6,0.006458,0.005502,0.005507,0.005507
7,0.005356,0.005325,0.005326,0.005326
8,0.0006,0.0006,0.0006,0.0006
9,0.00048,0.00048,0.00048,0.00048


# Sampling

In [206]:
def sample(means_, covariances_, weights_, n_samples):
    rng = np.random.RandomState(RANDOM_STATE)
    n_samples_comp = rng.multinomial(n_samples, weights_)

    X = np.vstack(
            [
                rng.multivariate_normal(mean, covariance, int(sample))
                for (mean, covariance, sample) in zip(
                    means_, covariances_, n_samples_comp
                )
            ]
        )
    
    y = np.concatenate(
        [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
    )
    return (X, y)

In [207]:
N_SAMPLES = 250

In [208]:
skgmm_samples = sample(skgmm.means_, skgmm.covariances_, skgmm.weights_, n_samples = N_SAMPLES)

skgmm_X, skgmm_y = skgmm_samples
skgmm_X[IDX], skgmm_y[IDX]

(array([ 0.01564824, -2.48977325, -1.06677265,  1.03595019, -0.56687919,
        -0.17060043, -0.5776165 ,  3.60625806, -0.09829001,  0.36356929,
         0.21534989, -2.43721739, -2.06559949,  1.85034522, -0.18437695,
        -0.96888118, 11.95730699,  3.74599508]),
 0)

In [209]:
samples = sample(means, covar, weights, n_samples = N_SAMPLES)

X, y = samples
X[IDX], y[IDX]

(array([ 0.18499373, -0.83446685, -0.60144353,  2.06066525,  2.49065299,
         0.10429114, -0.31643926,  1.21047995, -0.52292939, -3.62534916,
        -0.78782121, -1.43205912, -1.679075  , -0.43291366,  2.993503  ,
        -1.81419148,  2.90864241,  3.34585713]),
 0)

In [210]:
train_model_samples = sample(trained_model.means.detach().numpy(), trained_model.covariances.detach().numpy(), trained_model.weights.detach().numpy(), n_samples = N_SAMPLES)
train_model_X, train_model_y = train_model_samples
train_model_X[IDX], train_model_y[IDX]

  rng.multivariate_normal(mean, covariance, int(sample))


(array([ 0.1897325 , -0.88072923, -0.61986741,  2.07609162,  2.47815962,
         0.10266462, -0.34256586,  1.2532853 , -0.52365649, -3.65412251,
        -0.79721218, -1.43767327, -1.66040658, -0.42377924,  2.98937742,
        -1.84510692,  2.90485408,  3.33865499]),
 0)

In [211]:
gmm_lightning_samples = sample(gmm_lightning_module.gmm_module.means.detach().numpy(), gmm_lightning_module.gmm_module.covariances.detach().numpy(), gmm_lightning_module.gmm_module.weights.detach().numpy(), n_samples = N_SAMPLES)
gmm_lightning_X, gmm_lightning_y = train_model_samples
gmm_lightning_X[IDX], gmm_lightning_y[IDX]

  rng.multivariate_normal(mean, covariance, int(sample))


(array([ 0.1897325 , -0.88072923, -0.61986741,  2.07609162,  2.47815962,
         0.10266462, -0.34256586,  1.2532853 , -0.52365649, -3.65412251,
        -0.79721218, -1.43767327, -1.66040658, -0.42377924,  2.98937742,
        -1.84510692,  2.90485408,  3.33865499]),
 0)