# GMM over z

In this notebook, we **fit a generative model over the latent variables “z”** produced by the pre-trained encoder network from Task 2. This **defines a probability distribution over “z”**. New z values sampled from it can then be passed to the decoder, which gives **a complete generative mechanism for sampling new sounds**.

We use [sklearn.mixture](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html), specifically its `GaussianMixture` class, to build a Gaussian mixture model (GMM) over z.
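For reference, a GMM with $K$ components (`n_components` in sklearn) models each latent frame $z$ with the density

$$
p(z) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(z \mid \mu_k, \Sigma_k), \qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1,
$$

where, in `GaussianMixture`, the weights $\pi_k$, means $\mu_k$ and covariances $\Sigma_k$ are stored in `weights_`, `means_` and `covariances_`. Sampling from this model amounts to picking a component $k$ with probability $\pi_k$ and then drawing $z \sim \mathcal{N}(\mu_k, \Sigma_k)$.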

<p><b> GMM methods </b></p>

<img src="pictures/GMM_methods.jpg" alt="GMM methods in sklearn" width="800">

```python
import numpy as np
import sklearn
from sklearn.mixture import GaussianMixture

print(sklearn.__version__)
```

## Import and reshape the z dataset

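Related references: the DDSP [encoders.py](https://github.com/magenta/ddsp/blob/master/ddsp/training/encoders.py) and [decoders.py](https://github.com/magenta/ddsp/blob/master/ddsp/training/decoders.py) modules, which define the networks that produce and consume z, and the NumPy documentation for [np.reshape](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html).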

```python
# load the latent z vectors saved from the Task 2 encoder
load_path = 'z_datasets/z_piano_ae.npy'
X = np.load(load_path)
print(X.shape)
```

#### Reshape the dataset into two dimensions

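The loaded array is four-dimensional, with shape `(N, 1, T_step, Z_dim)`: N sounds, a second axis of which only index 0 is used (assumed here to be a singleton batch axis), T_step time frames and Z_dim latent dimensions. `GaussianMixture.fit` expects a 2-D array of shape `(n_samples, n_features)`, so we drop the second axis and flatten the sound and time axes into one, giving one row per z frame and an array of shape `(N*T_step, Z_dim)`. This treats every frame as an independent observation, so the GMM models single frames of z rather than their evolution over time.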
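To make the row ordering of this flattening explicit, here is a small sketch on a made-up toy array (the sizes `N=2`, `T_step=3`, `Z_dim=4` are purely illustrative, not those of the real z data). With NumPy's default C (row-major) order, row `i*T_step + t` of the flattened array is frame `t` of sound `i`.

```python
import numpy as np

# Toy stand-in for the z array: 2 sounds, a singleton axis, 3 frames, 4 latent dims.
# (Illustrative sizes only; the real z_piano_ae.npy data will differ.)
N, T_step, Z_dim = 2, 3, 4
X_toy = np.arange(N * 1 * T_step * Z_dim, dtype=float).reshape(N, 1, T_step, Z_dim)

# Drop the singleton axis, then flatten (sound, frame) into rows.
X_toy_new = X_toy[:, 0]                          # shape (N, T_step, Z_dim)
X_toy_re = X_toy_new.reshape(N * T_step, Z_dim)  # shape (N*T_step, Z_dim)

# In C order, row i*T_step + t holds frame t of sound i.
for i in range(N):
    for t in range(T_step):
        assert np.array_equal(X_toy_re[i * T_step + t], X_toy_new[i, t])

print(X_toy_re.shape)  # (6, 4)
```

Because the mapping is a plain reshape, it can be undone exactly with `reshape(N, T_step, Z_dim)`, which is what makes it safe to go back from frames to per-sound sequences later.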

```python
# X has shape (N, 1, T_step, Z_dim); keep index 0 of the second axis
N, _, T_step, Z_dim = X.shape
X_new = X[:, 0]  # shape (N, T_step, Z_dim)
print(X_new.shape)

# flatten sounds and time frames: one row per z frame
X_re = X_new.reshape(N * T_step, Z_dim)
print('\n', X_re.shape, '\n', X_re)
```

```python
# sanity check: undoing the reshape recovers X_new exactly
D = X_re.reshape(N, T_step, Z_dim)
print(np.array_equal(X_new, D))
```

## Fit a GMM to the z dataset

```python
# fit a 10-component GMM to the flattened z frames
gm = GaussianMixture(n_components=10).fit(X_re)

# fitted attributes of the model:
#   weights_:     shape (n_components,), the weight of each mixture component
#   means_:       shape (n_components, n_features), the mean of each component
#   covariances_: shape (n_components, n_features, n_features) for the default 'full' covariance_type
means = gm.means_
weights = gm.weights_

print('means:', means, '\n', 'weights of each component:', weights)
```
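Since the point of the GMM is to define a probability distribution over z, it can help to check that the fitted `weights_`, `means_` and `covariances_` really give the mixture density. The sketch below, run on synthetic data (`X_demo` and `gm_demo` are stand-ins, not the notebook's z frames), assumes the default `covariance_type="full"` and compares `GaussianMixture.score_samples`, which returns $\log p(z)$ per row, with $\log p(z) = \operatorname{logsumexp}_k [\log \pi_k + \log \mathcal{N}(z \mid \mu_k, \Sigma_k)]$ computed by hand with SciPy.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture

# Synthetic stand-in for X_re: 500 frames of 16-D data from two clusters.
rng = np.random.default_rng(0)
X_demo = np.vstack([rng.normal(loc=-2.0, size=(250, 16)),
                    rng.normal(loc=2.0, size=(250, 16))])

gm_demo = GaussianMixture(n_components=2, covariance_type="full", random_state=0).fit(X_demo)

# sklearn's log-density log p(z) for every frame
log_p = gm_demo.score_samples(X_demo)

# Same quantity from the fitted parameters: log w_k + log N(z | mu_k, Sigma_k) per component ...
per_component = np.stack([
    np.log(w) + multivariate_normal(mean=mu, cov=cov).logpdf(X_demo)
    for w, mu, cov in zip(gm_demo.weights_, gm_demo.means_, gm_demo.covariances_)
], axis=1)  # shape (n_frames, n_components)

# ... combined with a numerically stable log-sum-exp over components.
log_p_manual = logsumexp(per_component, axis=1)

print(np.allclose(log_p, log_p_manual))
```

On the real model, `gm.score_samples(X_re)` gives the same per-frame log-likelihoods, and `gm.score(X_re)` their mean.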

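```python
# draw 1000 new z frames from the fitted GMM;
# sample() returns (samples, component_labels), with samples grouped by component
z_samples, labels = gm.sample(1000)

print('\n', z_samples.shape)  # (1000, n_features)
print('\n', z_samples)
```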
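To turn sampled frames into something shaped like the encoder output, the frames can be regrouped into `(N, 1, T_step, Z_dim)` sequences. This is a sketch under assumptions: `gm_demo` is a synthetic stand-in for `gm`, `n_sounds`, `T_step` and `Z_dim` are hypothetical sizes, and it is assumed that the decoder accepts z in the same layout as the loaded array. Because `sample()` stacks its draws component by component, the frames are shuffled before regrouping. Every frame is still drawn independently, so the resulting sequences have no temporal structure from the GMM. The actual call to the Task 2 decoder, together with any other inputs it needs (for example f0 and loudness in DDSP), is left out because that interface is not shown in this notebook.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Hypothetical sizes standing in for the notebook's N, T_step and Z_dim.
n_sounds, T_step, Z_dim = 4, 250, 16

# Synthetic stand-in for the fitted model `gm` (the notebook fits it on X_re).
rng = np.random.default_rng(0)
gm_demo = GaussianMixture(n_components=3, random_state=0).fit(rng.normal(size=(1000, Z_dim)))

# Draw exactly enough frames to fill n_sounds sequences of T_step frames each.
z_raw, labels_raw = gm_demo.sample(n_sounds * T_step)

# sample() returns draws grouped by component, so shuffle the frame order;
# otherwise each sequence would mostly come from a single component.
perm = rng.permutation(len(z_raw))
z_frames, labels = z_raw[perm], labels_raw[perm]

# Regroup into the layout of the loaded z array, (N, 1, T_step, Z_dim).
z_new = z_frames.reshape(n_sounds, 1, T_step, Z_dim)
print(z_new.shape)  # (4, 1, 250, 16)
```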