In [1]:
import torch
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, Dataset

# Note - you must have torchvision installed for this example
from torchvision import transforms

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import os
from torchvision.io import read_image
import torchvision
import json
import matplotlib.animation as animation
from matplotlib.patches import Ellipse
# import matplotlib.transforms as transforms
import shutil

from data import KaggleDataset, KaggleDataModule

In [2]:
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a diagonal Gaussian in latent space.

    The input is first downsampled three times by ONE shared strided convolution
    (``compress_layer``; the same weights are reused on every pass). Then comes a
    three-stage conv stack, a flatten, and a two-layer MLP that emits
    ``2 * encoded_space_dim`` values. Those values are split into a mean and a
    log-variance.

    NOTE(review): the ``9 * 9`` flatten size is hard-wired. With six stride-2,
    kernel-3, unpadded convolutions it matches 639x639 or 640x640 inputs (the
    datamodule in this notebook is built with 640). Other image sizes need a
    different value here.

    Args:
        encoded_space_dim: latent dimensionality Z.
        fc2_input_dim: width of the hidden linear layer.
        num_channels: output channels of the three conv stages.
        in_chan: number of input image channels.
    """

    def __init__(self, encoded_space_dim=2, fc2_input_dim=128, num_channels=(8, 16, 32), in_chan=3):
        super().__init__()
        self.encoded_space_dim = encoded_space_dim
        self.num_channels = num_channels
        self.in_chan = in_chan

        ### Convolutional section
        # Applied three times in forward(), so it must map in_chan -> in_chan.
        self.compress_layer = nn.Conv2d(in_chan, in_chan, 3, stride=2, padding=0)

        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(in_chan, self.num_channels[0], 3, stride=2, padding=0),
            nn.BatchNorm2d(self.num_channels[0]),
            nn.ReLU(),
            nn.Conv2d(self.num_channels[0], self.num_channels[1], 3, stride=2, padding=0),
            nn.BatchNorm2d(self.num_channels[1]),
            nn.ReLU(),
            nn.Conv2d(self.num_channels[1], self.num_channels[2], 3, stride=2, padding=0),
            nn.ReLU()
        )

        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section: last layer outputs [mu | logsigmasq] side by side.
        self.encoder_lin = nn.Sequential(
            nn.Linear(9 * 9 * self.num_channels[2], fc2_input_dim),
            nn.ReLU(),
            nn.Linear(fc2_input_dim, encoded_space_dim * 2)
        )

    def forward(self, x):
        """Encode an image batch.

        Args:
            x: tensor of shape (N, in_chan, H, W).
        Returns:
            (mu, logsigmasq), each of shape (N, encoded_space_dim).
        """
        for _ in range(3):
            x = self.compress_layer(x)
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        mu, logsigmasq = x[:, :self.encoded_space_dim], x[:, self.encoded_space_dim:]
        return mu, logsigmasq


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to a per-pixel Gaussian.

    Spatial trace for the default layout:
    - an MLP + unflatten produce a ``num_channels[0] x 9 x 9`` map;
    - ``decoder_conv`` upsamples 9 -> 19 -> 39 -> 79;
    - the shared ``decompress_layer_0``, applied twice, gives 79 -> 159 -> 319;
    - ``decompress_layer_1`` (output_padding=1) gives 319 -> 640.

    The final map has ``2 * in_chan`` channels. The first ``in_chan`` are the
    mean and the rest are the log-variance.

    Args:
        encoded_space_dim: latent dimensionality Z.
        fc2_input_dim: width of the hidden linear layer.
        num_channels: channels of the three transposed-conv stages.
        in_chan: number of image channels to reconstruct.
    """

    def __init__(self, encoded_space_dim=2, fc2_input_dim=128, num_channels=(32, 16, 8), in_chan=3):
        super().__init__()
        self.in_chan = in_chan
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, fc2_input_dim),
            nn.ReLU(),
            nn.Linear(fc2_input_dim, 9 * 9 * num_channels[0]),
            nn.ReLU()
        )

        self.unflatten = nn.Unflatten(dim=1,
                                      unflattened_size=(num_channels[0], 9, 9))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(num_channels[0], num_channels[1], 3,
                               stride=2, output_padding=(0, 0)),
            nn.BatchNorm2d(num_channels[1]),
            nn.ReLU(),
            nn.ConvTranspose2d(num_channels[1], num_channels[2], 3, stride=2,
                               padding=0, output_padding=(0, 0)),
            nn.BatchNorm2d(num_channels[2]),
            nn.ReLU(),
            nn.ConvTranspose2d(num_channels[2], 2*in_chan, 3, stride=2,
                               padding=0, output_padding=(0, 0))
        )

        # Shared upsampler (used twice in forward) and the final stage whose
        # output_padding adds the extra row/column needed to reach an even size.
        self.decompress_layer_0 = nn.ConvTranspose2d(2*in_chan, 2*in_chan, 3, stride=2, padding=0, output_padding=(0,0))
        self.decompress_layer_1 = nn.ConvTranspose2d(2*in_chan, 2*in_chan, 3, stride=2, padding=0, output_padding=(1,1))

    def forward(self, x):
        """Decode latent vectors.

        Args:
            x: tensor of shape (N, encoded_space_dim).
        Returns:
            (mu, logsigmasq), each of shape (N, in_chan, H, W).
        """
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        for _ in range(2):
            x = self.decompress_layer_0(x)
        x = self.decompress_layer_1(x)
        # Split on the configured channel count rather than a hard-coded 3.
        mu, logsigmasq = x[:, :self.in_chan, :, :], x[:, self.in_chan:, :, :]
        return mu, logsigmasq

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import os


# Seed numpy and torch up front for reproducibility.
r = 202211010
np.random.seed(r)
torch.manual_seed(r)

D = 1 # number of modalities
dim = 64 # dimension of each modality (assume to be same)
K = 2  # number of GMM components
Z = 2  # latent dimensionality
latent_dim = Z

lr = 0.5*1e-3
num_epochs = 50
batch_size = 8
num_workers = 8
em_reg = 1e-6  # added to every EM responsibility, keeping log(gamma_c) finite
logsigmasq_reg = em_reg  # added to the encoder log-variance in encoder_step

# Histories filled during training and later exported / animated.
means_hist = []
mu_c_hist = []
logsigmasq_c_hist = []
gamma_c_train_hist = []
gamma_c_val_hist = []

# Weights of the reconstruction / regularization / entropy ELBO terms.
w_rec = 1.0
w_reg = 1.0
w_entr = 1.0


sim_name = f'gmvae_ld{latent_dim}_nc{K}_rec{w_rec}_reg{w_reg}_entr{w_entr}'
data_dir = './data/kaggle_cats_and_dogs/PetImages'

# NOTE(review): presumably the image side length. Encoder/Decoder hard-wire a
# 9x9 bottleneck that matches 640x640 inputs, so change them together.
# Confirm against KaggleDataModule's signature.
img_size = 640

my_datamodule = KaggleDataModule(data_dir, img_size, n_train=5000, n_valid=1000, n_test=100, batch_size=batch_size)
my_datamodule.setup('fit')

save_every = 1  # snapshot latents and losses every `save_every` epochs

train_loader = my_datamodule.train_dataloader()
valid_loader = my_datamodule.valid_dataloader()

# Class label of every sample, in dataset order (item [1] of each dataset entry).
# NOTE(review): dataset[i] presumably decodes the full image just to read its label.
# These labels are later paired with results gathered by iterating train_loader,
# which is only valid if that loader does not shuffle. Confirm both.
train_labels = torch.Tensor([int(train_loader.dataset[i][1]) for i in range(len(train_loader.dataset))])
valid_labels = torch.Tensor([int(valid_loader.dataset[i][1]) for i in range(len(valid_loader.dataset))])

File corrupted - dog_9931.jpg - Train. Ignoring file...
File corrupted - dog_9955.jpg - Train. Ignoring file...
File corrupted - dog_997.jpg - Train. Ignoring file...
File corrupted - dog_9940.jpg - Train. Ignoring file...
File corrupted - dog_9987.jpg - Train. Ignoring file...
File corrupted - dog_9995.jpg - Train. Ignoring file...
File corrupted - dog_9972.jpg - Train. Ignoring file...
File corrupted - dog_9997.jpg - Train. Ignoring file...
File corrupted - dog_9968.jpg - Train. Ignoring file...
File corrupted - dog_9988.jpg - Train. Ignoring file...
File corrupted - dog_9956.jpg - Train. Ignoring file...
File corrupted - dog_995.jpg - Train. Ignoring file...
File corrupted - dog_9994.jpg - Train. Ignoring file...
File corrupted - dog_9948.jpg - Train. Ignoring file...
File corrupted - dog_9950.jpg - Train. Ignoring file...
File corrupted - dog_9954.jpg - Train. Ignoring file...
File corrupted - dog_9976.jpg - Train. Ignoring file...
File corrupted - dog_9967.jpg - Train. Ignoring fi

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
# from kmeans_pytorch import kmeans

import matplotlib.pyplot as plt
import numpy as np

def encoder_step(x, encoder, decoder, reg=None):
    """Run the encoder and return the parameters of the approximate posterior q(z | x).

    Args:
        x: input batch accepted by ``encoder``.
        encoder: module whose forward returns ``(mu, logsigmasq)``.
        decoder: unused; kept so existing call sites remain valid.
        reg: constant added to the log-variance. Defaults to the notebook-level
            ``logsigmasq_reg``. Adding a constant to a log-variance scales the
            variance by ``exp(reg)``; it is not a lower bound on the variance.
    Returns:
        (mu, logsigmasq + reg): posterior mean and shifted log-variance.
    """
    if reg is None:
        reg = logsigmasq_reg
    mu, logsigmasq = encoder(x)
    return mu, logsigmasq + reg

def em_step(z, mu, params, update_by_batch=False, reg=None):
    """One E-step plus closed-form M-step for the Gaussian-mixture prior.

    E-step: gamma_c[n, k] is proportional to pi_c[k] * N(z_n | mu_c[k], diag(exp(logsigmasq_c[k]))).
    It is normalised over k, then ``reg`` is added so no responsibility is exactly 0.
    M-step: responsibility-weighted mean and variance of the encoder means ``mu``.

    Args:
        z: (N, Z) latent samples used for the responsibilities.
        mu: (N, Z) encoder means used for the M-step statistics.
        params: dict holding 'mu_c' (K, Z), 'logsigmasq_c' (K, Z) and 'pi_c' (K,).
            When ``update_by_batch`` is True it must also hold 'hist_weights' (K, 1),
            'hist_mu_c' (K, Z) and 'hist_logsigmasq_c' (K, Z). Those three
            entries are overwritten with the merged running statistics.
        update_by_batch: if True, merge this batch's statistics into the running ones.
        reg: value added to every responsibility; defaults to the notebook-level ``em_reg``.
    Returns:
        (gamma_c, mu_c, logsigmasq_c) with shapes (N, K), (K, Z), (K, Z).
        With ``update_by_batch`` the last two are the merged running statistics.
    """
    if reg is None:
        reg = em_reg
    dev = z.device  # keep the prior parameters on the same device as the data

    mu_c = params['mu_c'].to(dev)  # (K, Z)
    logsigmasq_c = params['logsigmasq_c'].to(dev)  # (K, Z)
    sigma_c = torch.exp(0.5 * logsigmasq_c)
    pi_c = params['pi_c'].to(dev)

    # E-step: normalised log responsibilities, shape (N, K).
    log_prob_zc = Normal(mu_c, sigma_c).log_prob(z.unsqueeze(dim=1)).sum(dim=2) + torch.log(pi_c)
    log_prob_zc = log_prob_zc - log_prob_zc.logsumexp(dim=1, keepdim=True)
    gamma_c = torch.exp(log_prob_zc) + reg

    # M-step on this batch.
    denominator = torch.sum(gamma_c, dim=0).unsqueeze(1)  # (K, 1) effective counts
    mu_c = torch.einsum('nc,nz->cz', gamma_c, mu) / denominator
    logsigmasq_c = torch.log(torch.einsum('nc,ncz->cz', gamma_c, (mu.unsqueeze(dim=1) - mu_c) ** 2)) - torch.log(denominator)

    if not update_by_batch:
        return gamma_c, mu_c, logsigmasq_c

    hist_weights = params['hist_weights'].to(dev)
    hist_mu_c = params['hist_mu_c'].to(dev)
    hist_logsigmasq_c = params['hist_logsigmasq_c'].to(dev)

    curr_weights = denominator
    new_weights = hist_weights + curr_weights
    new_mu_c = (hist_weights * hist_mu_c + curr_weights * mu_c) / new_weights
    # Pooled variance about the merged mean (law of total variance). Each part
    # contributes its own variance plus the squared shift of its mean from the
    # merged mean; omitting the shift term underestimates the variance.
    new_logsigmasq_c = torch.log(
        hist_weights * (torch.exp(hist_logsigmasq_c) + (hist_mu_c - new_mu_c) ** 2)
        + curr_weights * (torch.exp(logsigmasq_c) + (mu_c - new_mu_c) ** 2)
    ) - torch.log(new_weights)

    params['hist_weights'] = new_weights
    params['hist_mu_c'] = new_mu_c
    params['hist_logsigmasq_c'] = new_logsigmasq_c
    return gamma_c, new_mu_c, new_logsigmasq_c


def decoder_step(x, z, encoder, decoder, params, mu, logsigmasq, gamma_c, weights=None):
    """Compute the weighted GMVAE ELBO for one batch, summed over the batch (constants dropped).

    Args:
        x: observed batch, same shape as the decoder's mean output.
        z: reparameterised latent samples fed to the decoder.
        encoder: unused; kept so existing call sites remain valid.
        decoder: module whose forward returns ``(mu_x, logsigmasq_x)`` for ``x``.
        params: dict with 'mu_c' (K, Z), 'logsigmasq_c' (K, Z), 'pi_c' (K,).
        mu, logsigmasq: (N, Z) encoder mean and log-variance of q(z | x).
        gamma_c: (N, K) responsibilities (strictly positive; their log is taken).
        weights: optional ``(w_rec, w_reg, w_entr)``; defaults to the notebook-level
            ``w_rec``, ``w_reg``, ``w_entr``.
    Returns:
        (elbo, reconstruction, regularization, entropy), all scalar tensors, with
        elbo = w_rec * reconstruction + w_reg * regularization + w_entr * entropy.
    """
    if weights is None:
        weights = (w_rec, w_reg, w_entr)
    weight_rec, weight_reg, weight_entr = weights

    sigma = torch.exp(0.5 * logsigmasq)
    mu_c = params['mu_c']
    logsigmasq_c = params['logsigmasq_c']
    pi_c = params['pi_c']

    # log p(x | z): Gaussian likelihood under the decoder's mean / log-variance.
    mu_x, logsigmasq_x = decoder(z)
    reconstruction = Normal(mu_x, torch.exp(0.5 * logsigmasq_x)).log_prob(x).sum()

    # Responsibility-weighted E_q[log p(z | c)] without the log(2*pi) constant.
    per_dim = logsigmasq_c + (sigma.unsqueeze(1) ** 2 + (mu.unsqueeze(1) - mu_c) ** 2) / torch.exp(logsigmasq_c)  # (N, K, Z)
    regularization = -0.5 * torch.sum(gamma_c * per_dim.sum(dim=2))

    # E[log p(c)] - E[log q(c | x)] plus the Gaussian entropy of q(z | x), without the log(2*pi) constant.
    entropy = torch.sum(gamma_c * (torch.log(pi_c) - torch.log(gamma_c))) + 0.5 * torch.sum(1 + logsigmasq)

    elbo = weight_rec * reconstruction + weight_reg * regularization + weight_entr * entropy
    return elbo, reconstruction, regularization, entropy

In [None]:
# initialize latent GMM model parameters
params = {}
# First CUDA device when available (what `device = 0` meant), otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Unconstrained logits for the mixing weights; mapped to pi_c every training step.
pi_variables = torch.zeros(K, requires_grad=True, device=device)
params['pi_c'] = torch.ones(K) / K
torch.manual_seed(r)
params['mu_c'] = torch.rand((K, Z)) * 2.0 - 1.0  # uniform in [-1, 1)
params['logsigmasq_c'] = torch.zeros((K, Z))

# initialize neural networks
torch.manual_seed(r)
encoder = Encoder(encoded_space_dim=latent_dim).to(device)
decoder = Decoder(encoded_space_dim=latent_dim).to(device)

trainable_parameters = [pi_variables] + list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(trainable_parameters, lr=lr)


# training history buffers, one entry per epoch
train_loss = torch.zeros(num_epochs)
rec_loss = torch.zeros(num_epochs)
reg_loss = torch.zeros(num_epochs)
entr_loss = torch.zeros(num_epochs)

valid_loss = torch.zeros(num_epochs)
pi_history = torch.zeros((num_epochs, K))
min_valid_loss = torch.inf

# values recorded only on epochs where `epoch % save_every == 0`
epoch_list = []
train_loss_list = []
valid_loss_list = []

rec_loss_list = []
reg_loss_list = []
entr_loss_list = []

for epoch in range(num_epochs):
    # ---------------- training pass ----------------
    encoder.train()
    decoder.train()

    train_elbo = 0
    rec_elbo = 0
    reg_elbo = 0
    entr_elbo = 0
    gamma_c_epoch = []
    # Running GMM statistics, rebuilt from zero each epoch by em_step(update_by_batch=True).
    params['hist_weights'] = torch.zeros((K, 1))
    params['hist_mu_c'] = torch.zeros((K, Z))
    params['hist_logsigmasq_c'] = torch.zeros((K, Z))

    for batch_x in train_loader:
        x = batch_x[0].to(device)

        optimizer.zero_grad()
        # Softmax of the free logits keeps pi_c on the probability simplex.
        pi_c = torch.softmax(pi_variables, dim=0)
        params['pi_c'] = pi_c

        mu, logsigmasq = encoder_step(x, encoder, decoder)
        sigma = torch.exp(0.5 * logsigmasq)
        # Reparameterisation with fresh noise every step. Re-seeding the global RNG
        # here (as before) made eps identical for every batch of every epoch.
        eps = torch.randn_like(mu)
        z = mu + eps * sigma

        with torch.no_grad():
            gamma_c, mu_c, logsigmasq_c = em_step(z, mu, params, update_by_batch=True)
        params['mu_c'] = mu_c
        params['logsigmasq_c'] = logsigmasq_c
        gamma_c_epoch.append(gamma_c)

        elbo, rec, reg, entr = decoder_step(x, z, encoder, decoder, params, mu, logsigmasq, gamma_c)

        train_elbo += elbo.item()
        rec_elbo += rec.item()
        reg_elbo += reg.item()
        entr_elbo += entr.item()
        loss = -elbo
        loss.backward()
        optimizer.step()

    gamma_c_train_hist.append(torch.vstack(gamma_c_epoch))

    # ---------------- validation pass ----------------
    encoder.eval()
    decoder.eval()

    valid_elbo = 0
    gamma_c_epoch = []
    # Dedicated generator: validation noise is the same every epoch without
    # resetting the global RNG that training relies on.
    valid_gen = torch.Generator().manual_seed(r)
    with torch.no_grad():
        for batch_x in valid_loader:
            x = batch_x[0].to(device)
            mu, logsigmasq = encoder_step(x, encoder, decoder)
            sigma = torch.exp(0.5 * logsigmasq)
            eps = torch.randn(mu.shape, generator=valid_gen).to(device)
            z = mu + eps * sigma
            gamma_c, _, _ = em_step(z, mu, params)
            gamma_c_epoch.append(gamma_c)
            elbo, rec, reg, entr = decoder_step(x, z, encoder, decoder, params, mu, logsigmasq, gamma_c)
            valid_elbo += elbo.item()

    gamma_c_val_hist.append(torch.vstack(gamma_c_epoch))

    # Per-sample averages; the ELBO components use the same normalisation as the total.
    n_train = len(train_loader.dataset)
    train_elbo /= n_train
    rec_elbo /= n_train
    reg_elbo /= n_train
    entr_elbo /= n_train
    valid_elbo /= len(valid_loader.dataset)
    print('====> Epoch: {} Train ELBO: {:.4f} Val ELBO: {:.4f}'.format(epoch, train_elbo, valid_elbo))

    train_loss[epoch] = -train_elbo
    rec_loss[epoch] = -rec_elbo
    reg_loss[epoch] = -reg_elbo
    entr_loss[epoch] = -entr_elbo
    valid_loss[epoch] = -valid_elbo
    # Detached CPU copy, so pi_history is not attached to the autograd graph of pi_variables.
    pi_history[epoch] = params['pi_c'].detach().cpu()

    if epoch % save_every == 0:
        epoch_list.append(epoch)
        train_loss_list.append(train_loss[epoch].item())
        valid_loss_list.append(valid_loss[epoch].item())
        rec_loss_list.append(rec_loss[epoch].item())
        reg_loss_list.append(reg_loss[epoch].item())
        entr_loss_list.append(entr_loss[epoch].item())

        # Snapshot the encoder means of the whole training set.
        # NOTE(review): these are paired with `train_labels` (dataset order); that is
        # only valid if train_loader does not shuffle. Confirm.
        with torch.no_grad():
            means = []
            for batch_x in train_loader:
                x = batch_x[0].to(device)
                mean, _ = encoder_step(x, encoder, decoder)
                means.append(mean)

        means = torch.vstack(means).cpu()
        mu_c_cpu = params['mu_c'].cpu()
        means_hist.append(means)
        mu_c_hist.append(mu_c_cpu)
        logsigmasq_c_hist.append(params['logsigmasq_c'].cpu())

        # NOTE(review): this per-epoch figure is closed without being saved or shown.
        fig, ax = plt.subplots(figsize=(6, 5))
        for i in range(K):
            means_i = means[train_labels == i]
            ax.scatter(means_i[:, 0], means_i[:, 1], alpha=0.25, label=str(i))
        for i in range(K):
            ax.plot(mu_c_cpu[i, 0], mu_c_cpu[i, 1], 'x', markersize=12)
        ax.set_xlabel('$z_1$')
        ax.set_ylabel('$z_2$')
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        fig.tight_layout()
        plt.close(fig)


In [None]:
# Fresh output directory for this run (previous results for the same sim_name are deleted).
path = f'./results/{sim_name}/'
if os.path.exists(path):
    shutil.rmtree(path)
os.makedirs(path)

# Encoder means for every training sample.
encoder.eval()
with torch.no_grad():
    means = []
    for batch_x in train_loader:
        x = batch_x[0].to(device)
        mean, _ = encoder_step(x, encoder, decoder)
        means.append(mean)
means = torch.vstack(means)

# One full-data (non-incremental) EM step on the means; results are not written back to params.
with torch.no_grad():
    gamma_c, mu_c, logsigmasq_c = em_step(means, means, params, update_by_batch=False)

# Encoder means for the test split.
my_datamodule.setup('test')
test_loader = my_datamodule.test_dataloader()
with torch.no_grad():
    means_test = []
    for batch_x in test_loader:
        x = batch_x[0].to(device)
        mean, _ = encoder_step(x, encoder, decoder)
        means_test.append(mean)
means_test = torch.cat(means_test)

# my_datamodule.setup('outliers')
# outliers_loader = my_datamodule.outliers_dataloader()
# with torch.no_grad():
#     means_outliers = []
#     for batch_x in outliers_loader:
#         x = batch_x[0].to(device)
#         mean, _ = encoder_step(x, encoder, decoder)
#         means_outliers.append(mean)
# means_outliers = torch.cat(means_outliers)

# Latent means coloured by label, GMM centres (black crosses) and test means (green).
fig, ax = plt.subplots(figsize=(6, 5))
cluster_means = torch.zeros((K, Z))
for i in range(K):
    means_i = means[train_labels == i].cpu()
    ax.scatter(means_i[:, 0], means_i[:, 1], alpha=0.25, label=str(i))
    cluster_means[i] = torch.mean(means_i, dim=0)
for i in range(K):
    ax.plot(params['mu_c'].cpu()[i, 0], params['mu_c'].cpu()[i, 1], 'x', markersize=12, color='k')

ax.scatter(means_test[:, 0].cpu(), means_test[:, 1].cpu(), alpha=1, color='g')
# ax.scatter(means_outliers[:, 0].cpu(), means_outliers[:, 1].cpu(), alpha=1, color='r')
ax.set_xlabel('$z_1$ mean')
ax.set_ylabel('$z_2$ mean')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
fig.tight_layout()
plt.savefig(path + 'latent.png', dpi=600)


# Training loss vs. epoch, shifted so the minimum sits at 10 for the log axis.
plt.figure(figsize=(4.5, 4))
const = min(train_loss)
train_loss_adjusted = train_loss - const + 10.
plt.semilogy(train_loss_adjusted, label='train')
plt.xlabel("number of epochs")
plt.tight_layout()
plt.savefig(path + 'train_loss.png', dpi=600)
plt.close()

# Training and validation loss vs. epoch, with a shared shift.
plt.figure(figsize=(4.5, 4))
const = min(min(train_loss), min(valid_loss))
train_loss_adjusted = train_loss - const + 10.
valid_loss_adjusted = valid_loss - const + 10.
plt.semilogy(train_loss_adjusted, label='train')
plt.semilogy(valid_loss_adjusted, label='val')
plt.xlabel("number of epochs")
plt.legend()
plt.tight_layout()
plt.savefig(path + 'loss.png', dpi=600)
plt.close()

# Mixing-weight history. Raw string: '\p' is not a valid escape sequence.
plt.figure(figsize=(4.5, 4))
for i in range(K):
    plt.plot(pi_history[:, i].detach().numpy(), label=r'$\pi$' + str(i+1))
plt.xlabel("number of epochs")
plt.legend()
plt.tight_layout()
plt.savefig(path + 'pi.png', dpi=600)
plt.close()

In [None]:
# Negated ELBO components recorded at each saved epoch.
fig, ax = plt.subplots()
ax.plot(epoch_list, rec_loss_list, label='reconstruction')
ax.plot(epoch_list, reg_loss_list, label='regularization')
ax.plot(epoch_list, entr_loss_list, label='entropy')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend();


In [None]:
results_dic = {}
results_dic['epochs'] = epoch_list
results_dic['train_loss_epoch'] = train_loss_list
results_dic['valid_loss_epoch'] = valid_loss_list
results_dic['reconstruction_loss_epoch'] = rec_loss_list
results_dic['regularization_loss_epoch'] = reg_loss_list
results_dic['entropy_loss_epoch'] = entr_loss_list
results_dic['mu_c'] = [mu_c.cpu().numpy().tolist() for mu in mu_c_hist]
results_dic['logsigmasq_c'] = [logsig.cpu().numpy().tolist() for logsig in logsigmasq_c_hist]
results_dic['means'] = [mu.cpu().numpy().tolist() for mu in means_hist]
results_dic['pi'] = [pi.cpu().detach().numpy().tolist() for pi in pi_history]
results_dic['gamma_c_train'] = [gamma.cpu().detach().numpy().tolist() for gamma in gamma_c_train_hist]
results_dic['train_labels'] = train_labels.tolist()
results_dic['gamma_c_val'] = [gamma.cpu().detach().numpy().tolist() for gamma in gamma_c_val_hist]
results_dic['val_labels'] = valid_labels.tolist()

with open(path + f'GMVAE_rocks_K1_Z4.json', 'w') as outfile:
    json.dump(results_dic, outfile)
    
%xdel results_dic

In [None]:
# Fixed axis limits for the animation: the bounding box of every saved latent
# snapshot, so the view does not jump between frames.
all_snapshot_means = np.concatenate([snapshot.numpy() for snapshot in means_hist], axis=0)
xmin = all_snapshot_means[:, 0].min()
xmax = all_snapshot_means[:, 0].max()
ymin = all_snapshot_means[:, 1].min()
ymax = all_snapshot_means[:, 1].max()
max_epoch = np.array(epoch_list).max()
    

In [None]:
# ffmpeg-backed movie writer, 20 frames per second.
Writer = animation.writers['ffmpeg']
writer = Writer(fps=20, metadata={'artist': 'Me'})

fig, axs = plt.subplots(1, 2, figsize=(13, 5))
# Arrays (not lists) so animate() can slice and divide the histories element-wise.
(train_loss_list, valid_loss_list, rec_loss_list,
 reg_loss_list, entr_loss_list, epoch_list) = [
    np.array(history) for history in (train_loss_list, valid_loss_list, rec_loss_list,
                                       reg_loss_list, entr_loss_list, epoch_list)
]

fig.suptitle(f'GMVAE | latent dim. = {latent_dim}, # clusters = {K}', fontsize=16)
def animate(j):
    """Draw frame ``j`` of the training video.

    Left panel: encoder means of snapshot ``j`` coloured by training label, the GMM
    component means (crosses) and dashed axis-aligned 3-sigma ellipses.
    Right panel: loss curves up to and including snapshot ``j``, each divided by its
    epoch-0 value.

    Args:
        j: index into the saved snapshot histories (means_hist, mu_c_hist, ...).
    """
    means = means_hist[j]
    mu_c = mu_c_hist[j]
    # logsigmasq_c is a log-variance; the ellipse axes need standard deviations.
    sigma_c = np.exp(0.5 * logsigmasq_c_hist[j].numpy())
    axs[0].clear()
    axs[1].clear()

    # Axis 1: latent space
    for i in range(2):
        means_i = means[train_labels == i]
        axs[0].scatter(means_i[:, 0], means_i[:, 1], alpha=0.25, label=str(i))
    for i in range(K):
        (centre_marker,) = axs[0].plot(mu_c[i, 0], mu_c[i, 1], 'x', markersize=12)
        # Colour each ellipse like its own cross (previously all used the first line's colour).
        ellipse = Ellipse((mu_c[i, 0], mu_c[i, 1]),
                          width=2 * 3 * sigma_c[i, 0],
                          height=2 * 3 * sigma_c[i, 1],
                          facecolor='None',
                          edgecolor=centre_marker.get_color(),
                          linestyle='--')
        axs[0].add_patch(ellipse)

    axs[0].set_xlabel('$z_1$')
    axs[0].set_ylabel('$z_2$')
    axs[0].set_xlim([xmin, xmax])
    axs[0].set_ylim([ymin, ymax])
    axs[0].legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Axis 2: losses through the current snapshot (inclusive), normalised by epoch 0.
    upto = slice(0, j + 1)
    axs[1].plot(epoch_list[upto], train_loss_list[upto] / train_loss[0], 'k')
    axs[1].plot(epoch_list[upto], valid_loss_list[upto] / valid_loss[0], 'm')
    axs[1].plot(epoch_list[upto], rec_loss_list[upto] / rec_loss[0], 'r--')
    axs[1].plot(epoch_list[upto], reg_loss_list[upto] / reg_loss[0], 'g--')
    axs[1].plot(epoch_list[upto], entr_loss_list[upto] / entr_loss[0], 'b--')
    axs[1].set_xlabel('$epoch #$')
    axs[1].set_ylabel('$loss$')
    axs[1].set_xlim([0, max_epoch])
    axs[1].legend(['train loss', 'val loss', 'rec. loss', 'reg. loss', 'entropy loss'], loc='center left', bbox_to_anchor=(1, 0.5))
    fig.tight_layout()

In [None]:
# One frame per saved snapshot, written next to the other results.
n_frames = len(epoch_list)
ani = animation.FuncAnimation(fig, animate, frames=n_frames, repeat=False)
ani.save(path + 'training_video.mp4', writer=writer, dpi=600)

In [None]:
# Display the results directory this run wrote its figures, JSON and video into.
path