### 1. Load the pretrained VAE model (MNIST)

In [1]:
import torch
from diva_src.diva import DIVA
# Load the checkpoint onto the CPU. DIVA(...) below is built without any
# .to(device) call (assuming its constructor does not move itself to a GPU), and
# load_state_dict copies values into the model's existing parameters, so mapping
# to the CPU is always safe and lets a GPU-saved checkpoint open on a machine
# without CUDA.
model_ckpt = torch.load('path/to/your/vae_pretrained.ckpt', map_location='cpu')
pretrained_vae_model = DIVA(in_channels=1, latent_dim=16, dpmm_param=None)
state_dict = model_ckpt['state_dict']

# The checkpoint's parameter names may carry a 'model.' prefix (presumably from
# a training wrapper that held DIVA as `self.model` -- TODO confirm); strip it so
# the keys match DIVA's own parameter names. Keys without the prefix are kept.
new_state_dict = {}
for k, v in state_dict.items():
    name = k[len('model.'):] if k.startswith('model.') else k
    new_state_dict[name] = v

pretrained_vae_model.load_state_dict(new_state_dict)
# Switch to evaluation mode (the encoder/decoder contain BatchNorm layers); as
# the cell's last expression this also displays the model structure.
pretrained_vae_model.eval()



DIVA(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
  )
  (fc_mu): Linear(in_features=3136, out_features=16, bias=True)
  (fc_log_var): Linear(in_features=3136, out_features=16, bias=True)
  (decoder_input): Linear(in_features=16, out_features=3136, bias=True)
  (decoder): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequent

### 2. Load the pretrained DPMM model (MNIST)

In [5]:
import bnpy
import numpy as np
from bnpy.data.XData import XData

# Restore the saved bnpy DPMM model stored in this directory under the file
# prefix "Best"; replace the placeholder path with your own model directory.
dpmm_model = bnpy.ioutil.ModelReader.load_model_at_prefix('path/to/your/dpmm_pretrained/', prefix="Best")

# Helper: extract the Gaussian mean and variance of every DPMM cluster component
def calc_cluster_component_params(bnp_model):
    """Return the mean and per-dimension variance of every cluster component.

    Args:
        bnp_model: a bnpy model whose ``obsModel`` exposes ``K`` (the number of
            components), ``get_mean_for_comp(k)`` and
            ``get_covar_mat_for_comp(k)`` (a 2-D covariance matrix).

    Returns:
        ``(comp_mu, comp_var)``: two lists with one ``torch.Tensor`` per
        component -- the component's mean vector and the diagonal of its
        covariance matrix. Both lists are empty when ``K == 0``.
    """
    obs_model = bnp_model.obsModel
    comp_mu = [torch.Tensor(obs_model.get_mean_for_comp(k)) for k in range(obs_model.K)]
    # Take the diagonal of the covariance: summing its columns (axis=0) only
    # equals the variances for a diagonal matrix and would fold in off-diagonal
    # covariance terms for a full one. np.diagonal returns a read-only view, so
    # copy it before handing it to torch.
    comp_var = [
        torch.Tensor(np.diagonal(obs_model.get_covar_mat_for_comp(k)).copy())
        for k in range(obs_model.K)
    ]
    return comp_mu, comp_var

# Extract the per-component Gaussian means and variances from the loaded DPMM
# and print them (one tensor per cluster component).
comps_mu, comps_var = calc_cluster_component_params(dpmm_model)
print('cluster Gaussian mean values: \n', comps_mu)
print('cluster Gaussian variation values: \n', comps_var)

cluster Gaussian mean values: 
 [tensor([-0.1950, -0.7129,  0.7590,  0.2700,  0.6161, -0.0803,  0.0117, -0.0358,
        -0.0504,  0.6183,  0.1274,  0.1265, -0.4410,  1.6921, -0.0860, -0.9918]), tensor([-0.2000,  0.7289, -0.4580, -0.7908, -0.2773, -0.2145, -0.2050, -0.0329,
        -0.3484,  0.1417,  0.9206,  0.6020,  0.4966, -0.1470, -0.2188, -1.0337]), tensor([-0.1575,  1.8463,  0.0973, -0.8047,  0.3417, -0.3271,  0.7113, -0.5461,
        -0.3040, -0.4803, -0.2435, -0.0698, -0.6653,  0.3123, -0.4079,  0.1448]), tensor([ 0.0930,  0.2475,  0.5648,  0.2028,  0.0319, -0.3070,  0.0076,  0.9255,
        -1.1694, -0.7714,  0.2396,  1.0005, -0.3193,  0.1625,  0.5011,  1.0699]), tensor([-0.4780,  0.0511, -0.0534,  0.6606, -0.3352,  1.2047, -0.5367,  1.1271,
         0.1783, -0.5502, -0.9587, -0.1652, -0.0089, -0.2418,  0.0233, -1.9552]), tensor([-0.7130, -0.2564, -2.0810,  0.3622, -0.0418, -1.6575, -0.8449, -0.4024,
        -0.3185, -0.2138, -0.2747, -0.1087,  0.2748, -0.1157, -0.1617, -2.238