# Defining the general class with init, loading, training, eval methods

In [31]:
#################################################################################
#   MAGIC TRICK FOR HAVING tab, shift+tab COMMANDS!
#################################################################################
# Workaround: switch off the Jedi-based completer, which is what makes Tab
# completion and Shift+Tab tooltips usable again in this environment.

%config Completer.use_jedi = False

# Load the autoreload extension and select mode 2 for it.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import os 
# https://discuss.pytorch.org/t/cuda-launch-blocking-in-jupyter-notebook/163029
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import sys
import gc
from typing import Union
import tqdm

import datetime
import time
import json 

import h5py
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd

import torch
import torchvision

try:
    from torchsummary import summary
except: 
    %pip install torchsummary
    from torchsummary import summary


print(f"PyTorch v.\t{torch.__version__}")
print(f"TorchVision v.\t{torchvision.__version__}\n")

# PyTorch needs data and models to live on the same processing unit.
# Pick "cuda" when a GPU is available (and report what was found), "cpu" otherwise.
if torch.cuda.is_available():
    print('number of devices: ', torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))
    device = 'cuda'
else:
    device = 'cpu'

print(f"Computation device: {device}\n")

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from utils.memory_utils import free_memory

PyTorch v.	1.10.1+cu102
TorchVision v.	0.11.2+cu102

number of devices:  1
Tesla T4
Computation device: cuda



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Dataset utils
from utils.memory_utils import free_memory
from utils.custom_dataset import XRFAE1DDataset, AstroSynthDataset
from utils.pretreatment import custom_transform, normalize_hist, rebin_xrf, tanh_norm, smooth_1d
# Model utils
from utils.VAE1D_model import DeepClustering_VAE1D
from utils.save_best_model import SaveBestModel
# Loss utils
from utils.losses import vae_loss_function, silhouette_loss, compute_total_loss
# Clustering utils
from utils.clustering_utils import silhouette_score, IterativeKMeans, KMeans, kpp_init
# VAE utils
from utils.beta_scheduler import BetaScheduler
# Logs utils
from utils.logs_utils import write_line_to_file, store_hyp_dict
# Loader utils
from utils.open_model import load_model_func

number of devices:  1
Tesla T4
Computation device: cuda

number of devices:  1
Tesla T4
Computation device: cuda

number of devices:  1
Tesla T4
Computation device: cuda

number of devices:  1
Tesla T4
Computation device: cuda

number of devices:  1
Tesla T4
Computation device: cuda

number of devices:  1
Tesla T4
Computation device: cuda



## Define the model Class


In [4]:
class DeepClustering_trainer(nn.Module):
    def __init__(
        self,
        # Verbosity
        verbosity: int = 1, 
        # Dataset
        CUSTOM_TRANSFORM: bool = False ,
        BASE_DATASET_PATH: str = '/jupyter/notebooks/Article/AstroData/Synthetic/1D/' ,
        # NOTE(review): 10e+4 == 1e5, so these defaults evaluate to 2_000_000,
        # 400_000 and 100_000 respectively -- confirm that this is the intended scale.
        MAX_TRAIN_SIZE: int  = int(20*10e+4) ,
        MAX_VAL_SIZE  : int  = int( 4*10e+4) ,
        MAX_TEST_SIZE : int  = int( 1*10e+4) ,
        # Dataloader
        BATCH_SIZE : int = 256, 
        NUM_WORKERS: int = 6,
        DO_NORMALIZE: bool = True,
        # Model kwargs
        input_dim: int = 1024,
        n_layers : int = 4, 
        encoding_space_dim: int = 32, 
        latent_space_dim: int = 6, 
        use_latent_space_activation : bool = True, 
        pow_2_decrease: bool = True, 
        pow_2_increase: bool = True, 
        final_activation = nn.Sigmoid(),
        use_SNN: bool = True,         # If true, Self-Normalising Neural networks are used
        # K-Means
        min_n_cluster: int = 4,
        max_n_cluster: int = 12, 
        Niter: int = 10, 
        random_centroid_init: bool = False, 
        # Optimiser
        learning_rate: float = 1e-3, 
        use_lr_scheduler: bool = True, 
        # Name
        BASE_PATH_TO_STORE: str = './model_data',
        base_model_name: str = 'DeepClustering',
        # training
        epochs  : int = 100,
        PATIENCE: int = 10,
        # beta-VAE param
        BETA_MIN: float = 0.0, # NB if BETA_MIN >= BETA_MAX, beta(t) = BETA_MIN forall t
        BETA_MAX: float = 0.0,
        beta_t_number_of_epochs: int = 60,
        beta_t_starting_epoch  : int = 10, 
        beta_t_ending_epoch    : int = 30,
        # Deep Clustering gamma
        γ_init : float = 0.05,
        USE_VARYING_GAMMA: bool = False,
        gamma_t_number_of_epochs: int = 60,
        gamma_t_starting_epoch  : int = 10, 
        gamma_t_ending_epoch    : int = 30,
        # Loss
        USE_SUM : bool  = True,
        # Device
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        
    ):
        """
        Build a complete training setup and write its metadata to disk.

        Besides storing its arguments, the constructor performs, in this order:
        creation of the run directory, dataset and dataloader construction,
        model construction, optimiser (and optional LR scheduler) construction,
        creation of the beta and gamma schedulers, and finally writing the
        hyper-parameter JSON and the first entry of the text log.

        Args:
            verbosity: values > 0 enable progress prints; values > 1 additionally
                print dataset sizes and a torchsummary report.
            CUSTOM_TRANSFORM: selects the pre-treatment branch used by
                ``custom_transform_realized``.
            BASE_DATASET_PATH: prefix that the ``*_datacube.pt`` / ``*_labels.pt``
                file names are appended to.
            MAX_TRAIN_SIZE, MAX_VAL_SIZE, MAX_TEST_SIZE: forwarded as ``max_size``
                to the respective dataset.
            BATCH_SIZE, NUM_WORKERS: dataloader settings.
            DO_NORMALIZE: when False, ``final_activation`` is replaced by
                ``nn.ReLU()`` and the histogram normalisation of the input
                transform is skipped.
            input_dim ... use_SNN, min_n_cluster, max_n_cluster, Niter,
                random_centroid_init: forwarded unchanged to
                ``DeepClustering_VAE1D`` through ``self.model_kwargs``.
            learning_rate, use_lr_scheduler: Adam learning rate and whether a
                ``ReduceLROnPlateau`` scheduler is created.
            BASE_PATH_TO_STORE, base_model_name: used to build the run directory
                and the model name (suffix ``AE`` when both betas are 0, ``VAE``
                otherwise, plus a timestamp).
            epochs, PATIENCE: maximum number of epochs and early-stopping
                patience used by ``train_model``.
            BETA_MIN, BETA_MAX, beta_t_*: arguments of the beta ``BetaScheduler``.
            γ_init, USE_VARYING_GAMMA, gamma_t_*: arguments of the gamma
                ``BetaScheduler``; with ``USE_VARYING_GAMMA=False`` both of its
                bounds are ``γ_init``.
            USE_SUM: forwarded as ``use_sum`` to the total-loss helper.
            device: device the model and the batches are moved to.
        """
        super().__init__()
        
        # Run identity: storage root, timestamp and the basic training switches
        self.BASE_PATH_TO_STORE = BASE_PATH_TO_STORE
        self.starting_time_str  = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        self.epochs   = epochs
        self.PATIENCE = PATIENCE
        self.DO_NORMALIZE = DO_NORMALIZE
        # beta bounds (both zero means the run is labelled as a plain autoencoder below)
        self.BETA_MIN = BETA_MIN 
        self.BETA_MAX = BETA_MAX
        # Deep Clustering
        self.γ_init = γ_init
        self.USE_VARYING_GAMMA = USE_VARYING_GAMMA
        # loss
        self.USE_SUM = USE_SUM
        # model name
        self.base_model_name = base_model_name
        self.model_name  = f'{base_model_name}AE_{self.starting_time_str}' if self.BETA_MIN == self.BETA_MAX == 0 else f'{base_model_name}VAE_{self.starting_time_str}' 
        self.model_name += '_use_sum' if self.USE_SUM else '' 
        # create dir (needs BASE_PATH_TO_STORE and model_name, sets full_path_to_store)
        self.create_log_dir()
        # other
        self.verbosity = verbosity
        self.device    = device
        self.input_dim = input_dim
        # Apply custom transform or not
        self.CUSTOM_TRANSFORM  = CUSTOM_TRANSFORM
        self.BASE_DATASET_PATH = BASE_DATASET_PATH
        self.MAX_TRAIN_SIZE = MAX_TRAIN_SIZE
        self.MAX_VAL_SIZE   = MAX_VAL_SIZE
        self.MAX_TEST_SIZE  = MAX_TEST_SIZE
        # Dataloader
        self.BATCH_SIZE  = BATCH_SIZE
        self.NUM_WORKERS = NUM_WORKERS
        
        # INIT DATASET (reads the attributes assigned above)
        self.init_dataset()
        # INIT DATALOADER
        self.init_dataloader()
        
        # Model kwargs
        self.input_dim = input_dim
        self.n_layers  = n_layers
        self.encoding_space_dim = encoding_space_dim
        self.latent_space_dim   = latent_space_dim
        self.use_latent_space_activation = use_latent_space_activation
        self.pow_2_decrease = pow_2_decrease
        self.pow_2_increase = pow_2_increase
        # Without input normalisation the requested final activation is overridden by ReLU
        self.final_activation = final_activation if self.DO_NORMALIZE else nn.ReLU()
        self.min_n_cluster = min_n_cluster
        self.max_n_cluster = max_n_cluster
        self.Niter = Niter
        self.random_centroid_init = random_centroid_init
        self.use_SNN = use_SNN
        # Keyword arguments handed to DeepClustering_VAE1D (also reused when reloading)
        self.model_kwargs = {
            "input_dim" : self.input_dim, 
            "n_layers": self.n_layers,
            "encoding_space_dim": self.encoding_space_dim, 
            "latent_space_dim": self.latent_space_dim,
            "use_latent_space_activation": self.use_latent_space_activation,
            "pow_2_decrease": self.pow_2_decrease ,
            "pow_2_increase": self.pow_2_increase ,
            "final_activation": self.final_activation ,
            "min_n_cluster": self.min_n_cluster, 
            "max_n_cluster": self.max_n_cluster, 
            "Niter": self.Niter, 
            "verbose": True if self.verbosity > 0 else False, 
            "random_centroid_init": self.random_centroid_init,
            "use_SNN" : self.use_SNN , 
        }
        # init model
        self.init_model()
        
        # Optimiser
        self.learning_rate    = learning_rate 
        self.use_lr_scheduler = use_lr_scheduler
        # init optimiser
        self.init_adam_optimizer()
          
        # Training utils
        # Save Model
        self.save_best_model = SaveBestModel(model_name=f"{self.full_path_to_store}/{self.model_name}") #initialize checkpoint function
        # Beta Scheduler
        self.beta_t_number_of_epochs = beta_t_number_of_epochs
        self.beta_t_starting_epoch   = beta_t_starting_epoch
        self.beta_t_ending_epoch = beta_t_ending_epoch
        self.beta_t = BetaScheduler(
            β_min = self.BETA_MIN, β_max = self.BETA_MAX,
            number_of_epochs = self.beta_t_number_of_epochs,
            starting_epoch   = self.beta_t_starting_epoch, 
            ending_epoch     = self.beta_t_ending_epoch
        )
        # When beta is scheduled (BETA_MAX > 0), eval_step only checkpoints after
        # beta_t_ending_epoch, and train_model extends the patience by the same amount.
        self.delay_best_model_store = (self.beta_t_starting_epoch > 0 or self.beta_t_ending_epoch > 0) and self.BETA_MAX > 0
        self.Delta_patience = self.beta_t_ending_epoch if self.delay_best_model_store else 0
        
        if self.verbosity > 0: print(f"Delay best model store: {self.delay_best_model_store};\nDelta patience: {self.Delta_patience}")
        
        # beta-related settings, merged into the stored hyper-parameters below
        self._vae_args = {
            'BETA_MIN' : self.BETA_MIN, 
            'BETA_MAX' : self.BETA_MAX, 
            'beta_t_number_of_epochs' : self.beta_t_number_of_epochs ,
            'beta_t_starting_epoch'   : self.beta_t_starting_epoch ,
            'beta_t_ending_epoch'     : self.beta_t_ending_epoch ,
            'delay_best_model_store'  : self.delay_best_model_store ,
        }
        # Gamma scheduler (reuses the BetaScheduler class)
        self.gamma_t_number_of_epochs = gamma_t_number_of_epochs
        self.gamma_t_starting_epoch   = gamma_t_starting_epoch
        self.gamma_t_ending_epoch  = gamma_t_ending_epoch
        self.gamma_t = BetaScheduler(
            β_min = 0.0 if self.USE_VARYING_GAMMA else self.γ_init, # if USE_VARYING_GAMMA = False, gamma_t = γ_init
            β_max = self.γ_init,
            number_of_epochs = self.gamma_t_number_of_epochs,
            starting_epoch = self.gamma_t_starting_epoch, 
            ending_epoch = self.gamma_t_ending_epoch
        )
        # History (one entry appended per epoch by training_step / eval_step / train_model)
        self.training_loss   = []
        self.validation_loss = []

        self.mmd_losses = []
        self.rec_losses = []
        self.s_losses   = []

        self.learning_rates = []
        
        # === STORE HYPERPARAMETERS =====
        self.hyperparam_kwargs = {
            # Model kwargs
            **self.model_kwargs,
            # Dataset info
            'train_size': len(self.dataloader_train.dataset),
            'val_size'  : len(self.dataloader_val.dataset),
            # Hyperparameters
            'epochs'   : self.epochs,
            'patience' : self.PATIENCE,
            'min_n_clusters': self.min_n_cluster,
            'max_n_clusters': self.max_n_cluster,
            'ItKmeans_n_iter' : self.Niter,
            'γ_init' : self.γ_init,
            'γ variable'  : self.USE_VARYING_GAMMA,
            'loss_use_sum': self.USE_SUM,
            # Other hyperparams
            **self._vae_args , 
        }
        # The activation module is dropped before storing -- presumably because an
        # nn.Module cannot be serialised to JSON (TODO confirm against store_hyp_dict).
        del self.hyperparam_kwargs['final_activation']
        store_hyp_dict(f'{self.full_path_to_store}/{self.model_name}.json', self.hyperparam_kwargs)
        if self.verbosity > 0: print("Stored JSON")
        # === STORE LOGS =====
        self.LOG_FILE = f'{self.full_path_to_store}/{self.model_name}.txt'

        # First log entry: model architecture plus the main run settings
        log_text =f"{self.model}\n\nLoss used:\tAECluster\n\nTrain Size: {len(self.dataloader_train.dataset)}\tValidation size: {len(self.dataloader_val.dataset)}\tBatch size: {self.BATCH_SIZE}\nEpochs: {self.epochs}\tPatience: {self.PATIENCE}\nLatent space activation: {self.use_latent_space_activation}\n\n"
        if self.verbosity > 0: print(log_text)
        write_line_to_file(LOG_FILE=self.LOG_FILE, log_line=log_text)
        
    def forward(self, x: torch.Tensor):
        return self.model(x)
    
    def create_log_dir(self):
        """
        Mathod to create the log dir for the model
        """
        path = f'{self.BASE_PATH_TO_STORE}/{self.model_name}'
        if not os.path.exists(path):
            os.makedirs(path)
        
        self.full_path_to_store = path
    
    def custom_transform_realized(self, x: torch.Tensor) -> torch.Tensor:
        """
        Custom pretreatment function.
        
        If self.CUSTOM_TRANSFORM == True, it applies custom_transform method; else, the normalize_hist(1 + log(x)) method. 
        
        Args: 
            x   (torch.Tensor): the tensor to be transformed
            
        """
        if self.CUSTOM_TRANSFORM:
            if x.shape[-1] == self.input_dim:
                _ret = smooth_1d(x)
            else:
                _ret =  rebin_xrf(x, n_bins=self.input_dim) 
            return normalize_hist( _ret )  if self.DO_NORMALIZE else _ret
        else: 
            if x.shape[-1] == self.input_dim:
                return normalize_hist( torch.log(1 + x) )  if self.DO_NORMALIZE else torch.log(1 + x)
            else:
                return normalize_hist( torch.log(1 + rebin_xrf(x, n_bins=self.input_dim) ) ) if self.DO_NORMALIZE else  torch.log(1 + rebin_xrf(x, n_bins=self.input_dim) )
            
    def init_dataset(self):
        """
        Instantiate the three ``AstroSynthDataset`` splits.

        Sets ``self.dataset_train``, ``self.dataset_val`` and ``self.dataset_test``.
        Every split uses ``self.custom_transform_realized`` as transform and is
        capped by its own ``MAX_*_SIZE`` value.

        NOTE(review): the validation split reads the ``test_*`` files and the test
        split reads the ``val_*`` files (and the validation split is announced as
        'Test set:'). Confirm whether this swap is intentional.
        """
        chatty = self.verbosity > 0

        def load_split(file_stem, size_cap):
            # Both files of a split share the same stem: <stem>_datacube.pt / <stem>_labels.pt
            return AstroSynthDataset(
                path_to_data  = self.BASE_DATASET_PATH + file_stem + "_datacube.pt",
                path_to_label = self.BASE_DATASET_PATH + file_stem + "_labels.pt",
                transform = self.custom_transform_realized,
                max_size  = size_cap,
            )

        if chatty:
            print(f"Preparing dataset.\nUsing custom transform: {self.CUSTOM_TRANSFORM}\n")
            print('Train set:')
        self.dataset_train = load_split("train", self.MAX_TRAIN_SIZE)

        if chatty:
            print('Test set:')
        self.dataset_val = load_split("test", self.MAX_VAL_SIZE)

        self.dataset_test = load_split("val", self.MAX_TEST_SIZE)
        if chatty:
            print("\nDone.")

        if self.verbosity > 1:
            print(f"""
            Dataset infos:
            Train: {len(self.dataset_train)}
            Val  : {len(self.dataset_val)}
            Test : {len(self.dataset_test)}
            """)

    def init_dataloader(self):
        # Dataloader
        self.dataloader_train = DataLoader(
            self.dataset_train, batch_size=self.BATCH_SIZE, shuffle=True,
            drop_last=True, pin_memory=True, num_workers=self.NUM_WORKERS
        ) 
        self.dataloader_val   = DataLoader(
            self.dataset_val, batch_size=self.BATCH_SIZE, shuffle=False,
            drop_last=False, pin_memory=False, num_workers=self.NUM_WORKERS
        ) 
        # Test
        self.dataloader_test  = DataLoader(self.dataset_test,  batch_size=self.BATCH_SIZE, shuffle=False) 
        
    def init_model(self):
        """
        Build the ``DeepClustering_VAE1D`` network from ``self.model_kwargs`` and
        move it to ``self.device``.

        With ``verbosity > 0`` the module structure is printed; with
        ``verbosity > 1`` a torchsummary report is printed as well.
        """
        self.model = DeepClustering_VAE1D(**self.model_kwargs)
        self.model.to(self.device)
        if self.verbosity > 0: print(self.model)

        if self.verbosity > 1:
            print(f"\n\nTorchSummary:\n")
            # torchsummary creates its dummy input on the device named here, so it has
            # to match the device the model was just moved to (it accepts only the
            # strings "cuda" and "cpu"). A hard-coded 'cpu' breaks for a CUDA model.
            summary_device = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'
            report = summary(
                self.model,
                input_size=(self.input_dim, ),
                batch_size=self.BATCH_SIZE,
                device=summary_device,
            )
            # NOTE(review): the reference torchsummary prints the table itself and
            # returns None; only print the return value when there is one.
            if report is not None:
                print(report)
            
    def load_model(self, for_eval: bool = True, load_from_json: bool = False):
        """
        Reload the stored model of this run into ``self.model``.

        Args:
            for_eval: put the reloaded model in evaluation mode.
            load_from_json: if True, rebuild the model through ``load_model_func``;
                otherwise read ``<full_path_to_store>/<model_name>.pth`` and load
                its ``'model_state_dict'`` entry into a fresh
                ``DeepClustering_VAE1D``.

        Any failure is reported on stdout and swallowed. ``self.model`` is replaced
        only after the whole reload succeeded, so a failed reload leaves the model
        that was in memory untouched.

        NOTE(review): ``self.optimizer`` keeps pointing at the parameters of the
        previous model object after a successful reload.
        """
        try:
            if load_from_json:
                stored_at = self.full_path_to_store
                base_path = stored_at[2:] if stored_at.startswith('./') else stored_at
                _Name_to_open = f"{base_path}/{self.model_name}"
                reloaded = load_model_func(_Name_to_open, _model_kwargs=self.model_kwargs, model_class=DeepClustering_VAE1D, load_from_json = load_from_json)
            else:
                # load the best model
                RELOAD_MODEL_NAME = f"{self.full_path_to_store}/{self.model_name}.pth"
                # map_location lets a checkpoint written on a GPU be read on a
                # machine where that GPU is not available.
                checkpoint = torch.load(RELOAD_MODEL_NAME, map_location=self.device)
                if self.verbosity > 0: print(f'Best model {RELOAD_MODEL_NAME} at epoch: ', checkpoint['epoch'])

                reloaded = DeepClustering_VAE1D(**self.model_kwargs)
                reloaded.load_state_dict(checkpoint['model_state_dict'])

            # To device
            reloaded.to(self.device)

            if for_eval:
                reloaded.eval()

            # Commit only now: every step above succeeded.
            self.model = reloaded

        except Exception as e:
            print(f"Error:\n{e}\nModel not loaded\n")
        
            
    def init_adam_optimizer(self):
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr = self.learning_rate,
        )
        if self.use_lr_scheduler:
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, 
                mode='min',
                factor   = 0.1, # Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
                patience = 5  , #  Number of epochs with no improvement after which learning rate will be reduced. 
            )
            
    def compute_vae_loss(self, x_hat, x, z, β=1, monte_carlo_size:int=256, return_all:bool=False):
        """Delegate to ``vae_loss_function`` with the given arguments and return its result."""
        vae_loss = vae_loss_function(
            x_hat, x, z, β,
            monte_carlo_size=monte_carlo_size,
            return_all=return_all,
        )
        return vae_loss
    
    def compute_total_loss(self, loss, avg_s_loss, γ, use_sum=False):
        """Delegate to the module-level ``compute_total_loss`` helper and return its result."""
        combined = compute_total_loss(loss, avg_s_loss, γ, use_sum=use_sum)
        return combined
    
    def training_step(self):
        """
        Run one optimisation pass over ``self.dataloader_train``.

        The beta and gamma weights are read from their schedulers for
        ``self.epoch``; if a scheduler call fails (for instance because
        ``self.epoch`` is not set yet) the fall-backs ``0.0`` and ``self.γ_init``
        are used. Both values are stored in ``self.β_t`` / ``self.γ`` so that
        evaluation uses the same weights. The model's ``_is_vae`` flag is switched
        on only while beta is positive.

        Side effects: updates the model parameters and appends the epoch averages
        to ``self.training_loss``, ``self.mmd_losses``, ``self.rec_losses`` and
        ``self.s_losses``.

        Returns:
            tuple: ``(total, reconstruction, mmd, clustering)`` losses, each
            averaged over the number of batches.
        """
        self.model.train()
        train_loss = 0

        train_rec_loss = 0
        train_mmd_loss = 0
        train_s_loss   = 0

        counter = 0
        # beta
        try:
            β_t = self.beta_t(self.epoch)
        except Exception:
            # `except Exception` keeps the best-effort fall-back without also
            # swallowing KeyboardInterrupt / SystemExit.
            β_t = 0.0
        # Do not apply reparametrisation if Beta is zero!
        self.model._is_vae = bool(β_t > 0.0)
        # gamma
        try:
            γ   = self.gamma_t(self.epoch)
        except Exception:
            γ   = self.γ_init

        # SET
        self.β_t = β_t
        self.γ   = γ

        # Iterate over train set
        for x, y in tqdm.tqdm(self.dataloader_train):
            x = x.to(self.device)
            # =================== forward =====================
            x_hat, z, mu, logvar, best_cl, best_c, best_score, best_K, _  = self.forward(x)
            loss, rec_loss, mmd_loss = self.compute_vae_loss(x_hat, x, z, β=β_t, return_all=True)
            avg_s_loss = (1 - best_score)/2
            loss = self.compute_total_loss(loss, avg_s_loss, γ, use_sum=self.USE_SUM)
            # =================== Losses =====================
            train_loss += loss.item()
            train_rec_loss += rec_loss.item()
            train_mmd_loss += mmd_loss.item()
            # Accumulate a plain number: adding the value itself would, if it is a
            # tensor attached to the autograd graph (assumption -- TODO confirm the
            # type of best_score), keep every batch graph alive for the whole epoch
            # and put tensors instead of floats into the history.
            train_s_loss += float(avg_s_loss)
            # =================== backward ====================
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # =================== Free CUDA Memory ====================
            free_memory([x_hat, mu, logvar])
            counter += 1
        # =================== compute LOSS ====================
        train_divisor = counter
        tr_loss = train_loss / train_divisor
        train_rec_loss = train_rec_loss / train_divisor
        train_mmd_loss = train_mmd_loss / train_divisor
        train_s_loss   = train_s_loss   / train_divisor
        # === Append Losses ========
        self.training_loss.append(tr_loss) ## Full train loss

        self.mmd_losses.append(train_mmd_loss) ## MMD train loss
        self.rec_losses.append(train_rec_loss) ## Rec train loss
        self.s_losses.append(train_s_loss)     ## Cluster loss

        return tr_loss, train_rec_loss, train_mmd_loss, train_s_loss
        
    def eval_step(self):
        """
        Evaluate the model on ``self.dataloader_val`` without gradient tracking.

        Uses the beta/gamma weights stored in ``self.β_t`` / ``self.γ`` and the
        current ``self.epoch``. The mean total loss is appended to
        ``self.validation_loss`` and handed to ``self.save_best_model``, unless
        best-model storing is delayed and ``self.epoch`` has not passed
        ``self.beta_t_ending_epoch`` yet.

        Returns:
            float: total validation loss averaged over the batches.
        """
        n_batches = 0
        with torch.no_grad():
            self.model.eval()
            total = 0
            total_rec = 0
            total_mmd = 0
            total_clust = 0
            for x, y in tqdm.tqdm(self.dataloader_val):
                x = x.to(self.device)
                # forward pass and loss terms
                x_hat, z, mu, logvar, best_cl, best_c, best_score, best_K, _ = self.forward(x)
                vae_loss, rec_loss, mmd_loss = self.compute_vae_loss(
                    x_hat, x, z, β=self.β_t, return_all=True
                )
                clust_loss = (1 - best_score) / 2
                batch_loss = self.compute_total_loss(
                    vae_loss, clust_loss, self.γ, use_sum=self.USE_SUM
                )

                total = total + batch_loss.item()
                total_rec = total_rec + rec_loss.item()
                total_mmd = total_mmd + mmd_loss.item()
                total_clust = total_clust + clust_loss
                # release the tensors that are no longer needed
                free_memory([x_hat, mu, logvar])
                n_batches = n_batches + 1

        # Epoch averages (only the total loss is used further down)
        mean_loss = total / n_batches
        total_rec = total_rec / n_batches
        total_mmd = total_mmd / n_batches
        total_clust = total_clust / n_batches

        self.validation_loss.append(mean_loss)

        # Checkpointing is skipped while it is delayed and the delay has not elapsed
        still_delayed = self.delay_best_model_store and not (self.epoch > self.beta_t_ending_epoch)
        if not still_delayed:
            self.save_best_model(mean_loss, self.epoch, self.model, self.optimizer, vae_loss_function)

        return mean_loss
    
    def train_model(self):
        """
        Run the training loop for at most ``self.epochs`` epochs.

        Each epoch performs a training step and a validation step, records the
        current learning rate, steps the LR scheduler (when one was created) with
        the validation loss and writes one line to the log file. The loop stops
        early when the best validation loss is older than
        ``PATIENCE + Delta_patience`` epochs (only if ``PATIENCE > 0``). After the
        loop the history is written with ``store_df``.

        Exceptions raised during training are not propagated: the message and the
        traceback are printed and appended to the log file.
        """
        try:
            # Training loop
            for epoch in range(0, self.epochs):
                self.epoch = epoch
                # =====================================================
                # Training
                tr_loss, train_rec_loss, train_mmd_loss, train_s_loss = self.training_step()
                free_memory([])
                # =====================================================
                # Validation
                val_loss = self.eval_step()
                # =====================================================
                # Learning Rate stepper
                self.current_lr = self.optimizer.param_groups[0]['lr']
                self.learning_rates.append(self.current_lr)
                # update learning rate schedule; the scheduler only exists when
                # use_lr_scheduler was enabled
                scheduler = getattr(self, 'lr_scheduler', None)
                if scheduler is not None:
                    scheduler.step(val_loss) ### NB: ONLY FOR ReduceLROnPlateau
                # =================== log ========================
                log_line = f'====> Epoch: {self.epoch}\tTraining loss: {tr_loss:.6f}\t Validation set loss: {val_loss:.6f}\tBeta: {self.β_t:.2f}\tGamma:  {self.γ:.2f}\tlr: {self.current_lr:.2e}\tTime: {datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}\tTrain Rec_loss: {train_rec_loss:.6f}\tTrain MMD: {train_mmd_loss:.6f}\tTrain clust_loss: {train_s_loss:.6f}\n'
                if self.verbosity > 0: print(log_line)
                write_line_to_file(LOG_FILE=self.LOG_FILE, log_line=log_line)

                # Check patience
                if self.PATIENCE > 0 and len(self.validation_loss) - np.array(self.validation_loss).argmin() > self.PATIENCE + self.Delta_patience:
                    break_log = f"\nPatience treshold = {self.PATIENCE} reached.\nExiting at epoch {self.epoch}.\n"
                    if self.verbosity > 0: print(break_log)
                    write_line_to_file(LOG_FILE=self.LOG_FILE, log_line=break_log)
                    break

            # Store training history in DF
            self.store_df()

            if self.verbosity > 0: print(f"\n\nDone")
        # ==== HANDLE ERRORS =================
        except Exception as e:
            import traceback

            # The loop variable is unbound if the failure happens before the first
            # iteration, so read the epoch from the instance instead.
            failed_epoch = getattr(self, 'epoch', 'n/a')
            err_line = f"\n\nError at epoch {failed_epoch}:\n{e}\n\n{traceback.format_exc()}\n"
            print(err_line)
            write_line_to_file(LOG_FILE=self.LOG_FILE, log_line=err_line)
            
    
    def test_model(self, load_from_json: bool = False):
        """
        Reload the stored model and report its mean total loss on the test set.

        Args:
            load_from_json: forwarded to ``load_model``.

        The loss uses the beta/gamma weights of the last training step. If no
        training step was run on this instance, the same fall-backs as in
        ``training_step`` are used (``0.0`` and ``self.γ_init``). The result is
        printed and appended to the log file.
        """
        # Load model. Keywords are required here: passed positionally, the flag
        # would land in `for_eval` instead of `load_from_json`.
        self.load_model(for_eval=True, load_from_json=load_from_json)
        β_t = getattr(self, 'β_t', 0.0)
        γ   = getattr(self, 'γ', self.γ_init)
        # Eval
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            test_counter = 0
            for x, y in tqdm.tqdm(self.dataloader_test):
                x = x.to(self.device)
                # =================== forward =====================
                x_hat, z, mu, logvar, best_cl, best_c, best_score, best_K, _  = self.forward(x)
                loss, rec_loss, mmd_loss = self.compute_vae_loss(x_hat, x, z, β=β_t, return_all=True)
                avg_s_loss = (1 - best_score)/2
                loss = self.compute_total_loss(loss, avg_s_loss, γ,  use_sum=self.USE_SUM)

                # accumulate a Python float, as the train/validation steps do
                test_loss    += loss.item()
                test_counter += 1
                # =================== Free CUDA memory ====================
                free_memory([x_hat, mu, logvar])

        # =================== log ========================
        test_loss /= test_counter

        test_loss_log = f"\n\n\t\t====> Test loss: {test_loss:.6f}  <====\n"
        print(test_loss_log)
        write_line_to_file(LOG_FILE=self.LOG_FILE, log_line=test_loss_log)
    
    
    def store_df(self):
        # store as pandas csv
        df_train = pd.DataFrame(
            {
                "epochs"          : [ epoch for epoch in range(len(self.training_loss)) ],
                "training_loss"   : self.training_loss,
                "validation_loss" : self.validation_loss,
                "beta"            : [ self.beta_t(epoch)  for epoch in range(len(self.training_loss)) ],
                "gamma"           : [ self.gamma_t(epoch) for epoch in range(len(self.training_loss)) ],
                'lr'              : self.learning_rates, 
                "mmd_train_losses": self.mmd_losses,
                "rec_train_losses": self.rec_losses,
                "clust_losses"    : self.s_losses
            }
        )
        df_train.to_csv(f'{self.full_path_to_store}/{self.model_name}_history.csv')    

## Define model

In [5]:
# ---- Data location ----
BASE_DATASET_PATH = '/jupyter/notebooks/Article/AstroData/Synthetic/1D/'

# ---- Main switches of this run ----
USE_SNN = True
BETA_MIN = 0.000
BETA_MAX = 0.000
GAMMA_INIT = 0.002

beta_t_starting_epoch: int = 15
beta_t_ending_epoch: int = 16

# ---- Model name: "AE"/"VAE" and the timestamp are appended by the trainer ----
BASE_MODEL_NAME = (
    'Synth_Trainer'
    + ('_SNN' if USE_SNN else '_DNN')
    + ('_DeepClustering_' if GAMMA_INIT > 0.0 else '_')
)

# Use torch.device('cpu') here to force a CPU run
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---- Keyword arguments for DeepClustering_trainer ----
kwargs = dict(
    # Model
    n_layers=3,
    encoding_space_dim=32,
    latent_space_dim=3,
    use_latent_space_activation=False,
    min_n_cluster=2,
    max_n_cluster=3,
    DO_NORMALIZE=True,
    use_SNN=USE_SNN,
    # Batch / training length
    BATCH_SIZE=512,
    PATIENCE=25,
    epochs=100,
    # Model name
    BASE_PATH_TO_STORE='./model_data_synth',
    base_model_name=BASE_MODEL_NAME,
    # beta schedule
    BETA_MIN=BETA_MIN,
    BETA_MAX=BETA_MAX,
    beta_t_starting_epoch=beta_t_starting_epoch,
    beta_t_ending_epoch=beta_t_ending_epoch,
    # clustering weight
    γ_init=GAMMA_INIT,
    USE_VARYING_GAMMA=False,
    USE_SUM=True,
    # K-Means initialisation
    random_centroid_init=False,
    # Pre-treatment (smooth)
    CUSTOM_TRANSFORM=False,
    BASE_DATASET_PATH=BASE_DATASET_PATH,
    # Device
    device=device,
)

In [6]:
# Build the trainer. The constructor already creates the run directory, loads the
# datasets, builds the network and optimiser, and writes the hyper-parameter JSON
# and the first log entry.
model = DeepClustering_trainer(**kwargs)

Preparing dataset.
Using custom transform: False

Train set:
Test set:

Done.
DeepClustering_VAE1D(
  (final_activation): Sigmoid()
  (IKMeans): DC_IterativeKMeans()
  (encoder): VAEencoder(
    (encoder): Sequential(
      (0): SNN(
        (network): Sequential(
          (fc0): Linear(in_features=1024, out_features=512, bias=False)
          (selu_0): SELU()
          (dropout_0): AlphaDropout(p=0.1, inplace=False)
          (fc1): Linear(in_features=512, out_features=256, bias=False)
          (selu_1): SELU()
          (dropout_1): AlphaDropout(p=0.1, inplace=False)
          (fc_2): Linear(in_features=256, out_features=32, bias=True)
        )
      )
      (1): Tanh()
    )
    (fc_mu): Linear(in_features=32, out_features=3, bias=True)
    (fc_var): Linear(in_features=32, out_features=3, bias=True)
  )
  (decoder): VAEdecoder(
    (final_activation): Sigmoid()
    (decoder): Sequential(
      (Dec_SNN): SNN(
        (network): Sequential(
          (fc0): Linear(in_features=3, o

## Training

In [None]:
# Run the training loop; per-epoch lines are printed (verbosity > 0) and appended
# to the run's log file.
model.train_model()

100%|██████████| 314/314 [01:14<00:00,  4.20it/s]
100%|██████████| 90/90 [00:19<00:00,  4.51it/s]



Best validation loss: 0.0005897289905179706

Saving best model for epoch: 0

====> Epoch: 0	Training loss: 0.009235	 Validation set loss: 0.000590	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_15:59:02	Train Rec_loss: 0.009028	Train MMD: 1.096702	Train clust_loss: 0.103395



100%|██████████| 314/314 [01:13<00:00,  4.25it/s]
100%|██████████| 90/90 [00:19<00:00,  4.65it/s]


====> Epoch: 1	Training loss: 0.000639	 Validation set loss: 0.000812	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:00:35	Train Rec_loss: 0.000440	Train MMD: 1.143513	Train clust_loss: 0.099629



100%|██████████| 314/314 [01:13<00:00,  4.27it/s]
100%|██████████| 90/90 [00:19<00:00,  4.64it/s]


====> Epoch: 2	Training loss: 0.000643	 Validation set loss: 0.000748	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:02:08	Train Rec_loss: 0.000437	Train MMD: 1.149146	Train clust_loss: 0.102783



100%|██████████| 314/314 [01:13<00:00,  4.28it/s]
100%|██████████| 90/90 [00:20<00:00,  4.47it/s]


====> Epoch: 3	Training loss: 0.000630	 Validation set loss: 0.000658	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:03:41	Train Rec_loss: 0.000435	Train MMD: 1.156473	Train clust_loss: 0.097387



100%|██████████| 314/314 [01:14<00:00,  4.23it/s]
 40%|████      | 36/90 [00:08<00:11,  4.81it/s]

Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2


 77%|███████▋  | 69/90 [00:15<00:04,  4.40it/s]

Error;	K: 3;	uniq: 2


 80%|████████  | 72/90 [00:15<00:03,  4.72it/s]

Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2


100%|██████████| 90/90 [00:19<00:00,  4.59it/s]



Best validation loss: 0.0004881479360240822

Saving best model for epoch: 4

====> Epoch: 4	Training loss: 0.000573	 Validation set loss: 0.000488	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:05:15	Train Rec_loss: 0.000419	Train MMD: 1.154190	Train clust_loss: 0.076611



 44%|████▍     | 139/314 [00:33<00:41,  4.23it/s]

Error;	K: 3;	uniq: 2


100%|██████████| 314/314 [01:14<00:00,  4.22it/s]
  3%|▎         | 3/90 [00:00<00:23,  3.75it/s]

Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2


 26%|██▌       | 23/90 [00:05<00:13,  4.87it/s]

Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2


 31%|███       | 28/90 [00:06<00:12,  4.85it/s]

Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2
Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2


 40%|████      | 36/90 [00:07<00:11,  4.82it/s]

Error;	K: 2;	uniq: 1
Error;	K: 3;	uniq: 2


 58%|█████▊    | 52/90 [00:11<00:07,  4.79it/s]

Error;	K: 3;	uniq: 2


 77%|███████▋  | 69/90 [00:14<00:04,  4.81it/s]

Error;	K: 3;	uniq: 2


 99%|█████████▉| 89/90 [00:19<00:00,  4.65it/s]

Error;	K: 3;	uniq: 2


100%|██████████| 90/90 [00:19<00:00,  4.67it/s]



Best validation loss: 0.00047707894846098496

Saving best model for epoch: 5

====> Epoch: 5	Training loss: 0.000410	 Validation set loss: 0.000477	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:06:49	Train Rec_loss: 0.000325	Train MMD: 1.123139	Train clust_loss: 0.042634



100%|██████████| 314/314 [01:14<00:00,  4.23it/s]
100%|██████████| 90/90 [00:21<00:00,  4.23it/s]



Best validation loss: 0.00046301788825076073

Saving best model for epoch: 6

====> Epoch: 6	Training loss: 0.000410	 Validation set loss: 0.000463	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:08:25	Train Rec_loss: 0.000289	Train MMD: 1.107616	Train clust_loss: 0.060561



100%|██████████| 314/314 [01:14<00:00,  4.24it/s]
100%|██████████| 90/90 [00:20<00:00,  4.49it/s]


====> Epoch: 7	Training loss: 0.000527	 Validation set loss: 0.000562	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:09:59	Train Rec_loss: 0.000259	Train MMD: 1.125538	Train clust_loss: 0.134036



100%|██████████| 314/314 [01:13<00:00,  4.26it/s]
100%|██████████| 90/90 [00:19<00:00,  4.56it/s]


====> Epoch: 8	Training loss: 0.000520	 Validation set loss: 0.000493	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:11:32	Train Rec_loss: 0.000202	Train MMD: 1.161597	Train clust_loss: 0.159265



100%|██████████| 314/314 [01:13<00:00,  4.26it/s]
100%|██████████| 90/90 [00:20<00:00,  4.35it/s]


====> Epoch: 9	Training loss: 0.000498	 Validation set loss: 0.000510	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:13:07	Train Rec_loss: 0.000163	Train MMD: 1.176325	Train clust_loss: 0.167575



100%|██████████| 314/314 [01:14<00:00,  4.22it/s]
100%|██████████| 90/90 [00:20<00:00,  4.48it/s]


====> Epoch: 10	Training loss: 0.000499	 Validation set loss: 0.000512	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:14:42	Train Rec_loss: 0.000146	Train MMD: 1.197503	Train clust_loss: 0.176292



100%|██████████| 314/314 [01:14<00:00,  4.21it/s]
100%|██████████| 90/90 [00:20<00:00,  4.47it/s]


====> Epoch: 11	Training loss: 0.000505	 Validation set loss: 0.000521	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:16:16	Train Rec_loss: 0.000138	Train MMD: 1.218319	Train clust_loss: 0.183446



100%|██████████| 314/314 [01:15<00:00,  4.18it/s]
100%|██████████| 90/90 [00:20<00:00,  4.47it/s]


====> Epoch: 12	Training loss: 0.000509	 Validation set loss: 0.000552	Beta: 0.00	Gamma:  0.00	lr: 1.00e-03	Time: 2023-11-27_16:17:52	Train Rec_loss: 0.000131	Train MMD: 1.237073	Train clust_loss: 0.189182



100%|██████████| 314/314 [01:15<00:00,  4.18it/s]
100%|██████████| 90/90 [00:20<00:00,  4.48it/s]


====> Epoch: 13	Training loss: 0.000511	 Validation set loss: 0.000542	Beta: 0.00	Gamma:  0.00	lr: 1.00e-04	Time: 2023-11-27_16:19:27	Train Rec_loss: 0.000126	Train MMD: 1.243731	Train clust_loss: 0.192737



 23%|██▎       | 72/314 [00:17<00:56,  4.31it/s]

In [None]:
# Display the directory that the cells below save their figures into.
model.full_path_to_store

### Training History

In [None]:
# Training-history figure: per-epoch losses recorded on the trainer on the main
# axis, plus the beta/gamma schedule (or the weighted cluster loss) on a twin axis.
# The figure is saved as <full_path_to_store>/<model_name>.png.
n_epochs = len(model.training_loss)
epochs = range(1, n_epochs + 1)  # epochs are shown 1-based on the x axis

fig, ax = plt.subplots(dpi=120)
fig.suptitle(f"{model.model_name} Training history")
ax.plot(epochs, model.training_loss, color='green', linestyle='-', label='train loss')
ax.plot(range(1, len(model.validation_loss) + 1), model.validation_loss,
        color='blue', linestyle='-', label='validation loss')

# Evaluate the beta schedule once: it decides both whether the MMD curve is
# drawn and whether a twin axis with beta(t) is added further down.
beta_arr = None
if model.beta_t:
    beta_arr = np.array([model.beta_t(epoch) for epoch in range(n_epochs)])
    if beta_arr.size and beta_arr.max() > 0:
        ax.plot(epochs, model.BETA_MAX * np.array(model.mmd_losses),
                color='purple', linestyle=':', label='MMD loss')

try:
    ax.plot(epochs, model.γ * np.array(model.s_losses),
            color='orange', linestyle='-.', label='train S-loss loss', alpha=0.6)
    ax.plot(epochs, model.rec_losses,
            color='gold', linestyle='--', label='train Rec loss', alpha=0.6)
except Exception as e:
    # Best effort: these two curves are optional, report the problem and carry on.
    print(e)

ax.set_xlabel('Epochs', fontsize=14)
ax.set_ylabel('Loss', fontsize=14)

if beta_arr is not None:
    # Only worth a second axis when beta actually varies over the run.
    if beta_arr.size and beta_arr[0] != beta_arr[-1]:
        ax2 = ax.twinx()
        ax2.plot(epochs, beta_arr, color='red', linestyle='-', label='beta(t)')
        ax2.set_ylabel("beta(t)", color="red", fontsize=14)
elif model.gamma_t:
    gamma_arr = [model.gamma_t(epoch) for epoch in range(n_epochs)]
    if gamma_arr and gamma_arr[0] != gamma_arr[-1]:
        ax2 = ax.twinx()
        ax2.plot(epochs, gamma_arr, color='red', linestyle='-', label='gamma(t)')
        ax2.set_ylabel("gamma(t)", color="red", fontsize=14)
        ax2.set_ylim(-0.1, 0.5)
else:
    # No schedule at all: show the weighted cluster loss on its own scale.
    ax2 = ax.twinx()
    ax2.plot(epochs, model.γ * np.array(model.s_losses),
             color='orange', linestyle=':', label='train S-loss loss', alpha=0.6)
    ax2.set_ylabel("Cluster Loss", color="orange", fontsize=14)

ax.legend()

fig.savefig(f"{model.full_path_to_store}/{model.model_name}.png", bbox_inches='tight')
# plt.show() instead of fig.show(): the latter warns on non-GUI (inline) backends.
plt.show()

### Test model

In [None]:
# Run the test routine of the current `model` object.
model.test_model()

## Visualize model

In [None]:
from utils.open_model import load_model_func

In [None]:
""" 
path_to_file = '/jupyter/notebooks/AstroDeepClustering/AstroData/2D/val/ADP:2021-05-17T14:08:08:293.h5'

with h5py.File(path_to_file, 'r') as _h5:
    # Open
    ma_xrf = torch.Tensor( np.array(_h5['img'][()], dtype=float) ).float()
    ma_xrf = ma_xrf[:, :, :3000]
    _shape = ma_xrf.shape
    # Rebin 
    ma_xrf = ma_xrf.cpu()
    ma_xrf = rebin_xrf( ma_xrf.reshape(-1, _shape[-1]), n_bins=1024 )
    ma_xrf = ma_xrf.reshape(*_shape[:2], 1024)
    # Transform
    ma_xrf = model.custom_transform_realized(ma_xrf)
    ma_xrf = ma_xrf.cpu()
    
"""

print()

In [None]:
# Drop a previously imported MAXRFVizDataset so that the import in the next
# cell rebinds the name. On a fresh kernel the name does not exist yet, which
# is the only failure that is expected (and safe) to ignore here.
try:
    del MAXRFVizDataset
except NameError:
    pass

In [None]:
from utils.custom_dataset import MAXRFVizDataset

In [None]:
# Switch between the two ways of obtaining the model used by the cells below:
#   False -> call load_model() on the `model` object that already exists;
#   True  -> build a new trainer and point it at a stored run (LOAD_NAME).
LOAD_MODEL_FROM_DISK = False # <============ HERE =======


if not LOAD_MODEL_FROM_DISK:
    model.load_model()
else:
    # Alternative stored runs / base names left by the author:
    #LOAD_NAME = 'Synth_Trainer_SNN_VAE_2023-09-08_07:22:33_use_sum'
    #LOAD_NAME = 'Synth_Trainer_SNNDeepClustering_VAE_2023-09-07_15:51:26_use_sum'
    #BASE_MODEL_NAME = 'Synth_Trainer_DeepClustering'# AE or VAE are automatically added
    LOAD_NAME = 'Synth_Trainer_SNN_AE_2023-09-12_18:08:04_use_sum'
    USE_SNN = True
    BETA_MIN, BETA_MAX = 1.0, 1.0
    GAMMA_INIT = 0.0
    BASE_MODEL_NAME = 'Dummy_'

    # Architecture / clustering options.
    model_kwargs = dict(
        n_layers=3,
        encoding_space_dim=32,
        latent_space_dim=3,
        use_latent_space_activation=False,
        min_n_cluster=2,
        max_n_cluster=3,
        use_SNN=USE_SNN,
        random_centroid_init=True,
    )

    # Trainer options: the architecture options first, then everything else.
    kwargs = dict(model_kwargs)
    kwargs.update({
        "DO_NORMALIZE": True,
        # Batch
        "BATCH_SIZE": 512,
        "PATIENCE": 25,
        "epochs": 100,
        # Model name
        "BASE_PATH_TO_STORE": './model_data_synth',
        "base_model_name": BASE_MODEL_NAME,
        # Loss weights
        "BETA_MIN": BETA_MIN,
        "BETA_MAX": BETA_MAX,
        "γ_init": GAMMA_INIT,
        "USE_VARYING_GAMMA": False,
        "USE_SUM": True,
        # smooth
        'CUSTOM_TRANSFORM': True,
    })

    model = DeepClustering_trainer(**kwargs)
    # Redirect the new trainer to the stored run before loading it.
    model.model_name = LOAD_NAME
    model.full_path_to_store = f'./model_data_synth/{LOAD_NAME}'

    model.load_model(load_from_json=False)

    # NOTE(review): the training-history cell reads `model.beta_t`, while this
    # sets `model.β_t` (a different attribute name) - confirm which one the
    # trainer uses.
    model.β_t = BETA_MAX
    # NOTE(review): the training-history cell calls `model.gamma_t(epoch)`;
    # a plain float here only works there while it is falsy (0.0) - confirm.
    model.gamma_t = GAMMA_INIT
    model.γ = GAMMA_INIT

In [None]:
#model.test_model()

In [None]:
from utils.custom_dataset import MAXRFVizDataset

# Datacube to visualise; the noisy variant is kept for quick switching.
#path_to_file = '/jupyter/notebooks/AstroDeepClustering/AstroData/Synthetic/AstroDataCube_Generated_noisy.h5'
path_to_file = '/jupyter/notebooks/AstroDeepClustering/AstroData/Synthetic/AstroDataCube_Generated.h5'

# Dataset over the 'img' entry of the file, using the model's own transform.
ma_xrf_dataset = MAXRFVizDataset(
    path_to_datacube=path_to_file,
    data_name='img',
    MAX_BIN=3000,
    REBIN_SIZE=1024,
    transform=model.custom_transform_realized,
)

# shuffle=False: the inference loop writes each batch back by its position.
ma_xrf_dataloader = DataLoader(
    dataset=ma_xrf_dataset,
    batch_size=model.BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Run the whole datacube through the model batch by batch and collect on the CPU:
#   ma_xrf             - the inputs as served by the dataloader (NaNs set to 0)
#   transformed_ma_xrf - first output of model.model(x) (plotted as 'Decoded' below)
#   mu                 - third output (plotted as 'Embedded' below)
#   best_cl            - fifth output, one value per sample
#   best_c / s_scores  - per-batch values, kept as lists
#   best_score / best_K - sums over the batches
ma_xrf = torch.zeros(ma_xrf_dataset.final_shape)
transformed_ma_xrf = torch.zeros_like(ma_xrf)

n_samples = len(ma_xrf_dataset)
mu = torch.zeros([n_samples, model.latent_space_dim])
best_cl = torch.zeros(n_samples)
best_c = []
best_score = 0.0
best_K = 0.0
s_scores = []

# Datacube shape
_final_shape = [*ma_xrf_dataset.shape[:2], ma_xrf_dataset.REBIN_SIZE]

_batch_size = ma_xrf_dataloader.batch_size
# Inference only: no autograd graph is needed, which keeps memory usage down.
with torch.no_grad():
    for counter, x in enumerate(ma_xrf_dataloader):
        x = x.nan_to_num(nan=0.0)
        # Rows of the output buffers that belong to this batch; slicing clamps
        # at the end, so a shorter last batch is handled as well.
        _rows = slice(counter * _batch_size, (counter + 1) * _batch_size)

        # Create Original Datacube
        ma_xrf[_rows, :] = x

        # Forward
        (_transformed_ma_xrf, _, _mu, _,
         _best_cl, _best_c, _best_score, _best_K, _s_scores) = model.model(x.to(model.device))

        transformed_ma_xrf[_rows, :] = _transformed_ma_xrf.detach().cpu()
        mu[_rows, :] = _mu.detach().cpu()
        best_cl[_rows] = _best_cl.detach().cpu()
        best_c.append(_best_c.detach().cpu())

        # Accumulate the batch value (the original added best_score to itself,
        # so the total never moved away from 0.0).
        best_score += _best_score
        best_K += _best_K
        s_scores.append(_s_scores)
    

In [None]:
"""
loaded_model = model.model.cpu()
_final_shape = ma_xrf.shape
transformed_ma_xrf, _, mu, _, best_cl, best_c, best_score, best_K, s_scores  = loaded_model( ma_xrf.reshape(-1, _final_shape[-1] ) )

fig, ax = plt.subplots(nrows=1, ncols=2, dpi=160)
ax[0].imshow(ma_xrf.reshape(*_final_shape).sum(dim=-1).detach().cpu().numpy() )
ax[0].set_title('Original')
ax[1].imshow(transformed_ma_xrf.reshape(*_final_shape).sum(dim=-1).detach().cpu().numpy()  )
ax[1].set_title('Decoded')
plt.show()
"""

print()

In [None]:
# Side-by-side maps of the original, decoded and embedded data, each summed
# over its last axis, saved as <full_path_to_store>/<model_name>_integrated_xrf.png.
fig, ax = plt.subplots(nrows=1, ncols=3, dpi=160, figsize=(16, 8))

_panels = (
    ('Original', ma_xrf.reshape(*_final_shape)),
    ('Decoded', transformed_ma_xrf.reshape(*_final_shape)),
    ('Embedded', mu.reshape([*_final_shape[:2], mu.shape[-1]])),
)
for _axis, (_title, _cube) in zip(ax, _panels):
    _axis.imshow(_cube.sum(dim=-1).detach().cpu().numpy())
    _axis.set_title(_title)

# No hard-coded './' prefix: same form as the training-history figure, and it
# keeps working if full_path_to_store is an absolute path.
fig.savefig(f'{model.full_path_to_store}/{model.model_name}_integrated_xrf.png')
plt.show()

### Perform clustering

In [None]:
try:
    from utils.explore_latent_space import plot_clustered
except: 
    %pip install plotly
    from utils.explore_latent_space import plot_clustered

In [None]:
# reperform clustering
best_cl, best_c, best_score, best_K, s_scores = IterativeKMeans(
    mu.cpu(), 
    min_n_cluster = 3,
    max_n_cluster = 5, #model.max_n_cluster, 
    Niter = 15, random_centroid_init = model.random_centroid_init
)

In [None]:
# Cluster maps (fig_1) and silhouette scores (fig_2) - names as used by the
# savefig cell that follows.
# NOTE(review): the clustering above ran with min_n_cluster = 3, while
# MIN_CLUSTER here is model.min_n_cluster - confirm the two are meant to differ.
fig_1, fig_2 = plot_clustered(
    ma_xrf.cpu(),
    transformed_ma_xrf.cpu(),
    mu.cpu(),
    best_cl.cpu(),
    best_c.cpu(),
    best_score,
    best_K,
    s_scores.cpu(),
    _final_shape,
    MIN_CLUSTER=model.min_n_cluster,
    MAX_CLUSTER=5,
    tech_name='MLM',
)

In [None]:
# Save both clustering figures next to the model. No hard-coded './' prefix:
# same form as the training-history figure, and it keeps working if
# full_path_to_store is an absolute path.
fig_1.savefig(f'{model.full_path_to_store}/{model.model_name}_clustered_xrf.pdf')
fig_2.savefig(f'{model.full_path_to_store}/{model.model_name}_iterative_clustering_silhouette_scores.pdf')

### Visualize Latent Space

In [None]:
from utils.explore_latent_space import plot_latent_space
# Latent-space figure built from the clustering results computed above.
# NOTE(review): `ma_xrf` is passed as both the first and the second argument,
# whereas the plot_clustered call passes `transformed_ma_xrf` second - confirm
# against plot_latent_space's signature that this is intended.
fig_3 = plot_latent_space(
    ma_xrf.cpu().detach(), ma_xrf.cpu().detach(), mu.cpu().detach(),
    best_cl.cpu().detach(), best_c.cpu().detach(), best_score, best_K, s_scores.cpu().detach(),
    _final_shape,
    MIN_CLUSTER = model.min_n_cluster, MAX_CLUSTER = 5,
    # Additional options
    tech_name = 'MLM',
    N_CONTOUR_LEVELS = 5,
    PLOT_LATENT_SPACE_MA_XRF = True,
    PLOT_HIST_LOG_Y_SCALE = True,
)

In [None]:
# Save the latent-space figure next to the model. No hard-coded './' prefix, so
# it keeps working if full_path_to_store is an absolute path.
fig_3.savefig(f'{model.full_path_to_store}/{model.model_name}_clustered_latent_space_with_XRF.pdf')

### Explore Model Layers

In [None]:
from utils.explore_trained_model import plot_model

In [None]:
# Second network built from the same model kwargs; the plot_model cells below
# pass its encoder/decoder as the `untrained_model` argument.
untrained_model = DeepClustering_VAE1D(**model.model_kwargs)

In [None]:
# Encoder of the current model plotted against the encoder of `untrained_model`.
fig, ax = plot_model(
    model_to_plot=model.model.encoder,
    untrained_model=untrained_model.encoder,
    _suptitle="Encoder weights",
    MAGNIFYING_PERCENTAGE=120,
)

try:
    # No hard-coded './' prefix, so an absolute full_path_to_store works as well.
    fig.savefig(f'{model.full_path_to_store}/{model.model_name}_Encoder_view.pdf')
except Exception as e:
    # Saving is best-effort: report the failure and still show the figure.
    print(e)

plt.show()

In [None]:
# Decoder of the current model plotted against the decoder of `untrained_model`.
fig, ax = plot_model(
    model_to_plot=model.model.decoder,
    untrained_model=untrained_model.decoder,
    _suptitle="Decoder weights",
    MAGNIFYING_PERCENTAGE=120,
)

try:
    # No hard-coded './' prefix, so an absolute full_path_to_store works as well.
    fig.savefig(f'{model.full_path_to_store}/{model.model_name}_Decoder_view.pdf')
except Exception as e:
    # Saving is best-effort: report the failure and still show the figure.
    print(e)

plt.show()