In [92]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
import pandas as pd
full_data = pd.read_csv("../../data/processed/historical/train/lcl_data.csv")
df_100K = full_data.sample(100000, random_state=0)
df_100K.to_csv("../../data/processed/historical/train/lcl_data_100K.csv", index=False)

# Load Data

In [94]:
import torch
import numpy as np
import random
RANDOM_STATE = 0
torch.manual_seed(RANDOM_STATE)
torch.use_deterministic_algorithms(True)
g = torch.Generator()
g.manual_seed(RANDOM_STATE)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [95]:
from pathlib import Path
from opensynth.data_modules.lcl_data_module import LCLDataModule
import pytorch_lightning as pl

import matplotlib.pyplot as plt

data_path = Path("../../data/processed/historical/train/lcl_data_25K.csv")
stats_path = Path("../../data/processed/historical/train/mean_std.csv")
outlier_path = Path("../../data/processed/historical/train/outliers.csv")

dm = LCLDataModule(data_path=data_path, stats_path=stats_path, batch_size=25000, n_samples=25000)
dm.setup()

In [96]:
import torch
from opensynth.models.faraday import FaradayVAE
vae_model = torch.load("vae_model.pt")
vae_model.eval()

FaradayVAE(
  (encoder): Encoder(
    (encoder_layers): Sequential(
      (0): Linear(in_features=50, out_features=512, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): GELU(approximate='none')
      (6): Linear(in_features=128, out_features=64, bias=True)
      (7): GELU(approximate='none')
      (8): Linear(in_features=64, out_features=32, bias=True)
      (9): GELU(approximate='none')
      (10): Linear(in_features=32, out_features=16, bias=True)
    )
  )
  (decoder): Decoder(
    (latent): Linear(in_features=18, out_features=16, bias=True)
    (latent_activations): GELU(approximate='none')
    (decoder_layers): Sequential(
      (0): Linear(in_features=16, out_features=32, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=32, out_features=64, bias=True)
      (3): GELU(approximate='no

In [97]:
from opensynth.models.faraday.gaussian_mixture.prepare_gmm_input import encode_data_for_gmm

next_batch = next(iter(dm.train_dataloader()))
input_tensor = encode_data_for_gmm(data=next_batch, vae_module=vae_model)
input_data = input_tensor.detach().numpy()
n_samples = len(input_tensor)

In [98]:
N_COMPONENTS = 250
REG_COVAR = 1e-4
EPOCHS = 25
IDX = 0
CONVERGENCE_TOL = 1e-2


In [99]:
input_tensor[0][0]

tensor(0.0195, grad_fn=<SelectBackward0>)

# Init GMM

In [100]:
from opensynth.models.faraday.new_gmm import gmm_utils

labels_, means_, responsibilities_ = gmm_utils.initialise_centroids(
        X=input_data, n_components=N_COMPONENTS
    )
print(labels_.dtype, responsibilities_.dtype, means_.dtype)

torch.float32 torch.float32 torch.float32


In [101]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params

gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=REG_COVAR,
)
print(gmm_init_params["precision_cholesky"][IDX][0][0])
print(gmm_init_params["weights"].sum())

tensor(4.1549)
tensor(1.)


# Torch GMM

In [102]:
from opensynth.models.faraday.new_gmm.train_gmm import initialise_gmm_params, training_loop
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureModel


gmm_init_params = initialise_gmm_params(
    X=input_data,
    n_components = N_COMPONENTS,
    reg_covar=REG_COVAR,
)

means = gmm_init_params["means"].detach().numpy()
weights = gmm_init_params["weights"].detach().numpy()
prec_chol = gmm_init_params["precision_cholesky"].detach().numpy()
print(f"Initial prec chol: {prec_chol[IDX][0][0]}. Initial mean: {means[IDX][0]}")

torch_gmm = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=REG_COVAR,
    print_idx=IDX
)
torch_gmm.initialise(gmm_init_params)
trained_model = training_loop(model=torch_gmm, data=input_tensor, max_iter=EPOCHS)

Initial prec chol: 4.1548590660095215. Initial mean: 0.17847225069999695




Change: inf




Change: 0.5466245412826538




Change: 0.23698663711547852




Change: 0.15521585941314697




Change: 0.11739683151245117




Change: 0.09774953126907349




Change: 0.08635258674621582




Change: 0.07510387897491455




Change: 0.058798372745513916




Change: 0.04888087511062622




Change: 0.04147988557815552




Change: 0.032204270362854004




Change: 0.026453495025634766




Change: 0.022479653358459473




Change: 0.017785489559173584




Change: 0.014730274677276611




Change: 0.013055741786956787




Change: 0.012066006660461426




Change: 0.010769903659820557


 76%|███████▌  | 19/25 [00:10<00:03,  1.74it/s]

Change: 0.009308397769927979
Converged: True. Number of iterations: 19





# SK Learn GMM Manual

In [103]:
import numpy as np
from scipy.special import logsumexp
from scipy import linalg

def is_symmetric_positive_definite(covariance):
    is_symmetric = np.all([np.allclose(covariance[i], covariance[i].T) for i in range(covariance.shape[0])])
    is_positive_definite = np.all([np.all(np.linalg.eigvalsh(covariance[i]) > 0.0) for i in range(covariance.shape[0])])
    return is_symmetric and is_positive_definite

def _estimate_gaussian_parameters(X, resp, reg_covar=REG_COVAR):
    nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
    means = np.dot(resp.T, X) / nk[:, np.newaxis]
    n_components, n_features = means.shape
    covariances = np.empty((n_components, n_features, n_features))
    for k in range(n_components):
        diff = X - means[k]
        covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
        covariances[k].flat[:: n_features + 1] += reg_covar

    check_covariances = is_symmetric_positive_definite(covariances)
    if not check_covariances:
        raise ValueError("Covariance matrix is not positive definite.")
    return nk, means, covariances

def _compute_precision_cholesky(covariances):
    estimate_precision_error_message = (
        "Fitting the mixture model failed because some components have "
        "ill-defined empirical covariance (for instance caused by singleton "
        "or collapsed samples). Try to decrease the number of components, "
        "or increase reg_covar."
    )

    n_components, n_features, _ = covariances.shape
    precisions_chol = np.empty((n_components, n_features, n_features))
    for k, covariance in enumerate(covariances):
        try:
            cov_chol = linalg.cholesky(covariance, lower=True)
        except linalg.LinAlgError:
            raise ValueError(estimate_precision_error_message)
        precisions_chol[k] = linalg.solve_triangular(
            cov_chol, np.eye(n_features), lower=True
        ).T
    return precisions_chol

def _compute_log_det_cholesky(matrix_chol, n_features):
    n_components, _, _ = matrix_chol.shape
    log_det_chol = np.sum(
        np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1
    )
    return log_det_chol

def _estimate_log_gaussian_prob(X, means, precisions_chol):
    n_samples, n_features = X.shape
    n_components, _ = means.shape

    log_det = _compute_log_det_cholesky(precisions_chol, n_features)

    log_prob = np.empty((n_samples, n_components))
    for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
        y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
        log_prob[:, k] = np.sum(np.square(y), axis=1)
    return -0.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det

def _estimate_log_weights(weights):
        return np.log(weights)

def _estimate_weighted_log_prob(X, means, precisions_chol, weights):
        return _estimate_log_gaussian_prob(X, means, precisions_chol) + _estimate_log_weights(weights)


def _estimate_log_prob_resp(X, means, precisions_chol, weights):
    weighted_log_prob = _estimate_weighted_log_prob(X, means, precisions_chol, weights)
    log_prob_norm = logsumexp(weighted_log_prob, axis=1)
    with np.errstate(under="ignore"):
        log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
    return log_prob_norm, log_resp

def _e_step(X,means, precisions_chol, weights):
    log_prob_norm, log_resp = _estimate_log_prob_resp(X, means, precisions_chol, weights)
    return np.mean(log_prob_norm), log_resp

def _m_step(X, log_reponsibilities, reg_covar=REG_COVAR):

    weights_, means_, covariances_ = _estimate_gaussian_parameters(X,np.exp(log_reponsibilities),reg_covar=reg_covar)
    weights_ /= weights_.sum()

    precision_cholesky_ = _compute_precision_cholesky(covariances=covariances_)

    return precision_cholesky_, weights_, means_, covariances_

In [104]:
means = gmm_init_params["means"].detach().numpy()
weights = gmm_init_params["weights"].detach().numpy()
prec_chol = gmm_init_params["precision_cholesky"].detach().numpy()

print(f"Initial prec chol: {prec_chol[IDX][0][0]}. Initial mean: {means[IDX][0]}")

converged = False
lower_bound = -np.inf

for i in range(EPOCHS):
    prev_lower_bound = lower_bound

    print(f"Old Prec Chol: {prec_chol[IDX][0][0]}. Old means: {means[IDX][0]}")
    log_prob, log_resp = _e_step(input_data, means, prec_chol, weights)
    prec_chol, weights, means, covar = _m_step(input_data, log_resp)

    print(f"New prec chol: {prec_chol[IDX][0][0]}. New means: {means[IDX][0]}")

    # Converegence
    lower_bound = log_prob
    change = abs(lower_bound - prev_lower_bound)
    print(f"Change: {change}")
    if change < CONVERGENCE_TOL:
        converged = True
        break

print(f'Converged: {converged}. Number of iterations: {i}')

Initial prec chol: 4.1548590660095215. Initial mean: 0.17847225069999695
Old Prec Chol: 4.1548590660095215. Old means: 0.17847225069999695
New prec chol: 4.264129574723797. New means: 0.15994199961644728
Change: inf
Old Prec Chol: 4.264129574723797. Old means: 0.15994199961644728
New prec chol: 4.316532703063464. New means: 0.13492338340899412
Change: 0.5462640519048083
Old Prec Chol: 4.316532703063464. Old means: 0.13492338340899412
New prec chol: 4.175878808822229. New means: 0.11099307058643207
Change: 0.2368447243335936
Old Prec Chol: 4.175878808822229. Old means: 0.11099307058643207
New prec chol: 3.9374693321175926. New means: 0.09106721682231624
Change: 0.15513712903174603
Old Prec Chol: 3.9374693321175926. Old means: 0.09106721682231624
New prec chol: 3.8129173430607217. New means: 0.07953399423854603
Change: 0.11734342660665598
Old Prec Chol: 3.8129173430607217. Old means: 0.07953399423854603
New prec chol: 3.761243859543583. New means: 0.06862377937016459
Change: 0.0977096888

# SK Learn GMM Epoch

In [105]:
from sklearn.mixture import GaussianMixture

init_weights = gmm_init_params["weights"]
init_means = gmm_init_params["means"]

skgmm = GaussianMixture(n_components=N_COMPONENTS, covariance_type='full', tol=CONVERGENCE_TOL, max_iter=EPOCHS, random_state=0, means_init = init_means, weights_init=init_weights, verbose=2, verbose_interval=1)
skgmm.fit(input_data)
skgmm_pred = skgmm.predict(input_data)

Initialization 0
  Iteration 1	 time lapse 4.53057s	 ll change inf
  Iteration 2	 time lapse 2.58428s	 ll change 0.55539
  Iteration 3	 time lapse 2.58921s	 ll change 0.24410
  Iteration 4	 time lapse 2.05588s	 ll change 0.16252
  Iteration 5	 time lapse 3.81376s	 ll change 0.12118
  Iteration 6	 time lapse 2.24481s	 ll change 0.10090
  Iteration 7	 time lapse 2.11445s	 ll change 0.08873
  Iteration 8	 time lapse 2.25663s	 ll change 0.07771
  Iteration 9	 time lapse 2.43798s	 ll change 0.05981
  Iteration 10	 time lapse 2.23437s	 ll change 0.04934
  Iteration 11	 time lapse 2.00195s	 ll change 0.04429
  Iteration 12	 time lapse 2.22959s	 ll change 0.03445
  Iteration 13	 time lapse 2.11412s	 ll change 0.02836
  Iteration 14	 time lapse 2.09645s	 ll change 0.02424
  Iteration 15	 time lapse 2.15970s	 ll change 0.02104
  Iteration 16	 time lapse 2.14948s	 ll change 0.01738
  Iteration 17	 time lapse 2.01346s	 ll change 0.01459
  Iteration 18	 time lapse 2.40509s	 ll change 0.01319
  Iter

In [106]:
skgmm.converged_, skgmm.n_iter_

(True, 21)

# Torch Lightning

In [107]:
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
class CustomDataset(Dataset):
    def __init__(self, data_tensor: torch.Tensor):
        self.data = data_tensor
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    
class CustomDataModule(LightningDataModule):
    def __init__(self, data_tensor: torch.Tensor, batch_size: int):
        super().__init__()
        self.data_tensor = data_tensor
        self.batch_size = batch_size
    def setup(self, stage=""):
        self.custom_ds = CustomDataset(self.data_tensor)
    def train_dataloader(self):
        return DataLoader(self.custom_ds, batch_size=self.batch_size, shuffle=False, generator=g, worker_init_fn=seed_worker)
    
custom_dm = CustomDataModule(data_tensor=input_tensor, batch_size=25000)
custom_dm.setup(stage="")

In [108]:

for i in range(5):
    print(next(iter(custom_dm.train_dataloader()))[0][0])

tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)
tensor(0.0195, grad_fn=<SelectBackward0>)


In [109]:
from opensynth.models.faraday.new_gmm.new_gmm_model import GaussianMixtureLightningModule
gmm_module = GaussianMixtureModel(
    num_components=N_COMPONENTS,
    num_features = input_data.shape[1],
    reg_covar=REG_COVAR,
    print_idx=IDX
)
gmm_module.initialise(gmm_init_params)
print(f"Initial prec chol: {gmm_module.precision_cholesky[IDX][0][0]}. Initial mean: {gmm_module.means[IDX][0]}")

gmm_lightning_module = GaussianMixtureLightningModule(
    gmm_module = gmm_module,
    vae_module = vae_model,
    num_components = gmm_module.num_components,
    num_features = gmm_module.num_features,
    reg_covar = gmm_module.reg_covar,
    convergence_tolerance = CONVERGENCE_TOL,
    compute_on_batch=True
)
trainer = pl.Trainer(max_epochs=EPOCHS, accelerator="cpu", deterministic=True )
trainer.fit(gmm_lightning_module, custom_dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/charlotte.avery/.virtualenvs/OpenSynth-BNsxhSIM/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:182: `LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer

  | Name                      | Type                    | Params | Mode 
------------------------------------------------------------------------------
0 | gmm_module                | GaussianMixtureModel    | 0      | train
1 | vae_module                | FaradayVAE              | 402 K  | eval 
2 | weight_metric             | WeightsMetric           | 0      | train
3 | mean_metric               | MeansMetric             | 0      | train
4 | precision_ch

Initial prec chol: 4.1548590660095215. Initial mean: 0.17847225069999695
Epoch 0: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, v_num=20]Local weights at rank: 0 - means: 0.0063, 0.1599
Reduced weights, means: 0.0063, 0.1599
Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.49it/s, v_num=20]Local weights at rank: 0 - means: 0.0059, 0.1349
Reduced weights, means: 0.0059, 0.1349
Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  2.27it/s, v_num=20]Local weights at rank: 0 - means: 0.0058, 0.1109
Reduced weights, means: 0.0058, 0.1109
Epoch 3: 100%|██████████| 1/1 [00:00<00:00,  2.59it/s, v_num=20]Local weights at rank: 0 - means: 0.0058, 0.0909
Reduced weights, means: 0.0058, 0.0909
Epoch 4: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s, v_num=20]Local weights at rank: 0 - means: 0.0056, 0.0793
Reduced weights, means: 0.0056, 0.0793
Epoch 5: 100%|██████████| 1/1 [00:00<00:00,  2.92it/s, v_num=20]Local weights at rank: 0 - means: 0.0053, 0.0683
Reduced weights, means: 0.0053, 0.0683
Epoch 6: 100%|█

In [110]:
gmm_lightning_module.gmm_module.weights[0], gmm_lightning_module.gmm_module.means[0][0]

(tensor(0.0044), tensor(-0.0017))

In [111]:
gmm_lightning_module.weight_metric.compute()[0], gmm_lightning_module.mean_metric.compute()[0][0]

(tensor(0.0044), tensor(-0.0017))

# Compare

In [112]:
IDX = 0

In [113]:
df_compare_means = pd.DataFrame()
df_compare_means["skgmm"] = skgmm.means_[IDX]
df_compare_means["numpy"] = means[IDX]
df_compare_means["torch"] = trained_model.means[IDX]
df_compare_means["lightning"] = gmm_lightning_module.gmm_module.means[IDX]
df_compare_means

Unnamed: 0,skgmm,numpy,torch,lightning
0,-0.035201,-0.001878,-0.001677,-0.001677
1,-1.627329,-1.776944,-1.777669,-1.777669
2,-0.842405,-0.960595,-0.959451,-0.959451
3,1.525421,1.739976,1.740235,1.740235
4,0.530031,0.615125,0.616592,0.616592
5,-0.081067,-0.079712,-0.079611,-0.079611
6,-0.521916,-0.67363,-0.674024,-0.674024
7,3.303917,3.522102,3.519754,3.519754
8,-0.257689,-0.288638,-0.288372,-0.288372
9,-1.700934,-2.087236,-2.089393,-2.089393


In [114]:
gmm_init_params["means"][IDX]

tensor([ 0.1785, -2.7549, -1.0892,  1.7319,  0.2749, -0.0532, -0.9788,  3.2775,
        -0.2790, -2.2910, -0.6581, -1.5413, -1.2498,  0.9345,  0.9561, -2.0828,
         5.9111,  3.9389])

In [115]:
df_compare_covar = pd.DataFrame()
df_compare_covar["skgmm"] = skgmm.covariances_[IDX][0]
df_compare_covar["numpy"] = covar[IDX][0]
df_compare_covar["torch"] = trained_model.covariances.detach().numpy()[IDX][0]
df_compare_covar["lightning"] = gmm_lightning_module.gmm_module.covariances.detach().numpy()[IDX][0]
df_compare_covar

Unnamed: 0,skgmm,numpy,torch,lightning
0,0.075459,0.071019,0.071599,0.071599
1,-0.202008,-0.20133,-0.204387,-0.204387
2,-0.111972,-0.098755,-0.099446,-0.099446
3,0.111245,0.102897,0.10397,0.10397
4,-0.023585,-0.013192,-0.012076,-0.012076
5,0.01376,0.009636,0.009897,0.009897
6,-0.139312,-0.135811,-0.136264,-0.136264
7,-0.074386,-0.095631,-0.094559,-0.094559
8,-0.016463,-0.016869,-0.016866,-0.016866
9,-0.039057,-0.069567,-0.071,-0.071


In [116]:
df_compare_pre_chol = pd.DataFrame()
df_compare_pre_chol["skgmm"] = skgmm.precisions_cholesky_[IDX][0]
df_compare_pre_chol["numpy"] = prec_chol[IDX][0]
df_compare_pre_chol["torch"] = trained_model.precision_cholesky.detach().numpy()[IDX][0]
df_compare_pre_chol["lightning"] = gmm_lightning_module.gmm_module.precision_cholesky.detach().numpy()[IDX][0]
df_compare_pre_chol

Unnamed: 0,skgmm,numpy,torch,lightning
0,3.640359,3.752435,3.737195,3.737195
1,2.480077,2.400315,2.415535,2.415535
2,3.012257,2.680574,2.668036,2.668036
3,-0.922032,-1.395591,-1.409341,-1.409341
4,-3.007662,-3.242618,-3.27351,-3.27351
5,-5.049216,-3.70316,-3.714443,-3.714443
6,3.224421,4.641292,4.642696,4.642696
7,12.447998,13.521707,13.529406,13.529406
8,0.293808,1.508207,1.512385,1.512385
9,-7.302855,-5.594008,-5.637901,-5.637901


In [117]:
df_compare_weights = pd.DataFrame()
df_compare_weights["skgmm"] = skgmm.weights_[:10]
df_compare_weights["numpy"] = weights[:10]
df_compare_weights["torch"] = trained_model.weights[:10]
df_compare_weights["lightning"] = gmm_lightning_module.gmm_module.weights.detach().numpy()[:10]
df_compare_weights

Unnamed: 0,skgmm,numpy,torch,lightning
0,0.004774,0.004385,0.004385,0.004385
1,0.00016,0.00016,0.00016,0.00016
2,0.029848,0.041667,0.041665,0.041665
3,0.02173,0.022938,0.022917,0.022917
4,0.002641,0.002473,0.002478,0.002478
5,0.00016,0.00016,0.00016,0.00016
6,0.00642,0.005413,0.005415,0.005415
7,0.004994,0.005059,0.005059,0.005059
8,0.0006,0.0006,0.0006,0.0006
9,0.00048,0.00048,0.00048,0.00048


# Sampling

In [118]:
def sample(means_, covariances_, weights_, n_samples):
    rng = np.random.RandomState(RANDOM_STATE)
    n_samples_comp = rng.multinomial(n_samples, weights_)
    
    X = np.vstack(
            [
                rng.multivariate_normal(mean, covariance, int(sample))
                for (mean, covariance, sample) in zip(
                    means_, covariances_, n_samples_comp
                )
            ]
        )
    
    y = np.concatenate(
        [np.full(sample, j, dtype=int) for j, sample in enumerate(n_samples_comp)]
    )
    return (X, y)

In [119]:
def torch_sample(means_, covariances_, weights_, n_samples):
    # Set up the random generator with a specified seed
    generator = torch.Generator().manual_seed(RANDOM_STATE)
    
    # Sample component counts from the multinomial distribution
    n_samples_comp = torch.multinomial(weights_, n_samples, replacement=True, generator=generator).bincount(minlength=len(weights_))
    
    # Initialize lists to collect samples and labels
    X = []
    y = []
    
    # Sample from each component based on the number of samples
    for j, (mean, covariance, sample_count) in enumerate(zip(means_, covariances_, n_samples_comp)):
        if sample_count > 0:  # Only sample if we need samples from this component
            dist = torch.distributions.MultivariateNormal(
                mean, covariance
            )
            samples = dist.sample((sample_count,))
            X.append(samples)
            y.append(torch.full((sample_count,), j, dtype=torch.int64))
    
    # Concatenate all samples and labels into single tensors
    X = torch.vstack(X)
    y = torch.cat(y)
    
    return X, y

In [120]:
N_SAMPLES = 250

In [121]:
skgmm_samples = sample(skgmm.means_, skgmm.covariances_, skgmm.weights_, n_samples = N_SAMPLES)

skgmm_X, skgmm_y = skgmm_samples
skgmm_X[IDX], skgmm_y[IDX]

(array([ 0.07240839, -1.71829723, -0.60803314,  1.73979589,  2.01262296,
         0.11960121, -0.18303798,  2.26989082, -0.18918716, -3.64823518,
        -0.51437723, -1.0090173 , -1.70681869,  0.34972181,  2.84867092,
        -2.27103036,  2.52582455,  3.44432533]),
 0)

In [122]:
samples = sample(means, covar, weights, n_samples = N_SAMPLES)

X, y = samples
X[IDX], y[IDX]

(array([ 0.4047585 , -2.01683898, -1.32529041,  2.77408286,  1.82210771,
         0.06704174, -1.10510541,  2.33170647, -0.58834022, -3.6683143 ,
        -0.97818583, -2.3354214 , -1.63901546,  0.05558079,  2.33957331,
        -2.19347763,  5.54049101,  1.08385274]),
 0)

In [123]:
train_model_samples = torch_sample(trained_model.means, trained_model.covariances, trained_model.weights, n_samples = N_SAMPLES)
train_model_X, train_model_y = train_model_samples
train_model_X[IDX], train_model_y[IDX]

(tensor([ -5.0416, -30.0005,  24.1288, -27.3287,   5.9395,   1.1542,  22.8423,
         -22.7208,   1.2230, -27.2587,  25.5793,  40.1886,  45.4400,  -4.1804,
          34.8146, -36.0102,   7.2540,   0.3268]),
 tensor(1))

In [124]:
gmm_lightning_samples = torch_sample(gmm_lightning_module.gmm_module.means, gmm_lightning_module.gmm_module.covariances, gmm_lightning_module.gmm_module.weights, n_samples = N_SAMPLES)
gmm_lightning_X, gmm_lightning_y = train_model_samples
gmm_lightning_X[IDX], gmm_lightning_y[IDX]

(tensor([ -5.0416, -30.0005,  24.1288, -27.3287,   5.9395,   1.1542,  22.8423,
         -22.7208,   1.2230, -27.2587,  25.5793,  40.1886,  45.4400,  -4.1804,
          34.8146, -36.0102,   7.2540,   0.3268]),
 tensor(1))