# Post-processing notebook

### This notebook imports the visualisation functions from the main library, loads the saved model state and meta file, and lets you modify the appearance of the Affinity-VAE output.


In [1]:
import os
import warnings

import torch
import numpy as np
import pandas as pd
import umap
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, fixed, Dropdown

from avae.vis import latent_embed_plot_umap, latent_embed_plot_tsne
from avae.utils import colour_per_class
warnings.filterwarnings('ignore')


### The following functions take the path of the meta file and of the model state file, respectively, and load them

In [2]:
def load_meta(meta_fn, device="cpu", rng=None):
    """Load latent statistics, poses and labels from a saved Affinity-VAE meta file.

    Columns of the pickled DataFrame are selected by name prefix:
    ``lat*`` (latent means), ``pos*`` (pose), ``std*`` (latent standard
    deviations), plus the ``id`` column (class labels).

    Parameters
    ----------
    meta_fn : str or path-like
        Path to the pickled meta DataFrame. Only load files you trust:
        unpickling can execute arbitrary code.
    device : str, optional
        Unused; kept so the signature mirrors ``load_model``.
    rng : numpy.random.Generator, optional
        Generator used to draw the latent sample ``z``. When ``None`` (the
        default) the global ``np.random`` state is used, so seed it with
        ``np.random.seed`` beforehand for reproducible samples.

    Returns
    -------
    mu : np.ndarray
        Latent means, one row per sample.
    pose : np.ndarray
        Pose values, one row per sample.
    labels : pd.Series
        The ``id`` column.
    std : np.ndarray
        Latent standard deviations, aligned to have the same shape as ``mu``.
    z : np.ndarray
        One sample ``eps * std + mu`` with ``eps ~ N(0, 1)``.

    Raises
    ------
    ValueError
        If the ``std*`` columns cannot be aligned with the ``lat*`` columns.
    """
    meta_df = pd.read_pickle(meta_fn)

    def _columns_with_prefix(prefix):
        # All columns whose name starts with `prefix`, in file order, as an array.
        return meta_df[
            [col for col in meta_df.columns if col.startswith(prefix)]
        ].to_numpy()

    mu = _columns_with_prefix("lat")
    pose = _columns_with_prefix("pos")
    std = _columns_with_prefix("std")
    labels = meta_df["id"]

    # Saved meta files were observed to carry an extra leading "std*" column
    # (a 0.5 value, giving 17 std columns for 16 latents).
    # NOTE(review): origin of that column is unknown - confirm against the
    # meta-file writer. Keep only the trailing columns that line up with `mu`
    # instead of blindly dropping the first one.
    n_extra = std.shape[1] - mu.shape[1]
    if n_extra > 0:
        std = std[:, n_extra:]
    if std.shape != mu.shape:
        raise ValueError(
            f"std columns with shape {std.shape} cannot be aligned with "
            f"latent columns with shape {mu.shape}"
        )

    # Reparameterised sample of the latent distribution.
    if rng is None:
        eps = np.random.randn(*std.shape)
    else:
        eps = rng.standard_normal(std.shape)
    z = eps * std + mu

    return mu, pose, labels, std, z
    

def load_model(model_fn, device="cpu"):
    """Load a saved Affinity-VAE model and switch it to evaluation mode.

    The checkpoint is read as a dict holding the full model object under
    ``"model_class_object"`` and its weights under ``"model_state_dict"``.

    Parameters
    ----------
    model_fn : str or path-like
        Path to the checkpoint written with ``torch.save``.
    device : str or torch.device, optional
        Device the checkpoint tensors are mapped to and the model is moved to.

    Returns
    -------
    torch.nn.Module
        The restored model on ``device``, in eval mode.
    """
    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only
    # machine. weights_only=False is needed because the checkpoint stores a
    # pickled model object, not just tensors (recent torch defaults to True).
    # Unpickling can execute arbitrary code: only load checkpoints you trust.
    checkpoint = torch.load(model_fn, map_location=device, weights_only=False)
    model = checkpoint["model_class_object"]
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model

# Enter the paths of the saved model and its corresponding meta file

In [9]:
# Directory holding the saved states. Set the AVAE_STATE_DIR environment
# variable to point elsewhere instead of editing a machine-specific path.
STATE_DIR = os.environ.get(
    "AVAE_STATE_DIR", "/Users/mfamili/work/exp_avae/alphanum/states"
)
# Run tag shared by the model state file and its meta file.
RUN_ID = "13_38_11_01_2024_E989_16_1"

model_fn = os.path.join(STATE_DIR, f"avae_{RUN_ID}.pt")
meta_fn = os.path.join(STATE_DIR, f"meta_{RUN_ID}.pkl")


# Load the meta file and the model
### From the meta file
1. `mu`: mean of the latents ($\mu$)
2. `std`: standard deviation of the latents ($\sigma$)
3. `p`: pose
4. `labels`: class labels
5. `z`: sampled latent ($z = \mu + \sigma \epsilon$, with $\epsilon \sim \mathcal{N}(0, 1)$)
### The loaded model is stored in `model`

In [10]:
# `z` is a random sample drawn with numpy's global RNG inside `load_meta`;
# seed it so `z` (and every plot built from it) is reproducible across runs.
# torch is seeded too in case downstream plotting helpers draw from it.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

mu, p, labels, std, z = load_meta(meta_fn)
model = load_model(model_fn)

# Configure latent embedding:
#### Use the following controls to plot a 2-D embedding (UMAP or t-SNE) of the latent space:
1. `Select Function`: this drop-down list chooses between UMAP and t-SNE for the dimensionality reduction of the latent embeddings.
2. `rs`: this slider sets the random state for the UMAP plot.
3. `Perplexity`: this slider sets the perplexity for the t-SNE plot.
4. `Data Type`: this drop-down list selects which variable to plot (`z`: a stochastic sample of the latent space, `mu`: the mean of the latent space).

In [11]:
# Define the slider widget
random_state_slider = IntSlider(min=0, max=100, step=1, value=42)
perplexity_slider = IntSlider(min=2, max=100, step=1, value=40, description='Perplexity')

# Define the dropdown widget for selecting the function
function_selector = Dropdown(options=['latent_embed_plot_umap', 'latent_embed_plot_tsne'],
                             value='latent_embed_plot_umap', description='Select Function')

# Use the interact function with both widgets
interact(lambda function, data_type, rs, perplexity: (latent_embed_plot_umap(xs=z if data_type == 'z' and function == 'latent_embed_plot_umap' else mu,
                                                                  ys=labels, rs=rs, display=True)
                                          if function == 'latent_embed_plot_umap'
                                          else latent_embed_plot_tsne(xs=z if data_type == 'z' and function == 'latent_embed_plot_tsne' else mu,
                                                                      ys=labels, perplexity=perplexity, display=True)),
         function=function_selector,
         data_type=Dropdown(options=['z', 'mu'], value='z', description='Data Type'),
         rs=random_state_slider,
         perplexity = perplexity_slider)

interactive(children=(Dropdown(description='Select Function', options=('latent_embed_plot_umap', 'latent_embed…

<function __main__.<lambda>(function, data_type, rs, perplexity)>

## Creating new latent interpolation plots

1. Choose the size of your input images: `dsize`

2. Choose the number of interpolation steps via the slider: `num_int`

Note that every time you drag the `num_int` slider, the plot corners change. If you would like to start from a given number of interpolation steps, change the value of `init_interpolation_steps` in the code cell below.

In [12]:
from avae.vis import latent_4enc_interpolate_plot

dsize = [64, 64]
init_interpolation_steps = 10

# Convert the sampled latents once rather than on every slider move.
z_tensor = torch.from_numpy(z).to(dtype=torch.float)
number_interpolation = IntSlider(min=2, max=20, step=1, value=init_interpolation_steps)


def plot_latent_interpolation(num_int):
    """Redraw the latent interpolation plot using `num_int` interpolation steps."""
    return latent_4enc_interpolate_plot(dsize=dsize, xs=z_tensor, ys=labels, vae=model,
                                        device="cpu", plots_config=f"1,{num_int}",
                                        poses=p, display=True)


# The trailing `;` hides the repr of the function returned by `interact`.
interact(plot_latent_interpolation, num_int=number_interpolation);

interactive(children=(IntSlider(value=10, description='num_int', max=20, min=2), Output()), _dom_classes=('wid…

<function __main__.<lambda>(num_int)>

## Creating new pose interpolation plots
1. Choose the size of your input images: `dsize`
2. Choose the number of interpolation steps via the slider: `number_of_samples`
3. Choose the classes to generate the interpolation for via the variable `pose_vis_class` (a comma-separated string of class labels, e.g. `"i,e"`)



In [13]:
from avae.vis import pose_class_disentanglement_plot

dsize = [64, 64]
# Comma-separated class labels to generate pose interpolations for.
pose_vis_class = "i,e"

# Convert the sampled latents once rather than on every slider move.
z_tensor = torch.from_numpy(z).to(dtype=torch.float)
num_int = IntSlider(min=2, max=20, step=1, value=10)


def plot_pose_disentanglement(number_of_samples):
    """Redraw the pose interpolation plot for `pose_vis_class` with `number_of_samples` steps.

    Uses the `dsize` variable above (previously a literal `[64,64]` was passed,
    so editing `dsize` had no effect).
    """
    return pose_class_disentanglement_plot(dsize=dsize, x=z_tensor, y=labels,
                                           pose_vis_class=pose_vis_class, poses=p,
                                           vae=model, device="cpu",
                                           number_of_samples=number_of_samples,
                                           display=True)


# The trailing `;` hides the repr of the function returned by `interact`.
interact(plot_pose_disentanglement, number_of_samples=num_int);


interactive(children=(IntSlider(value=10, description='number_of_samples', max=20, min=2), Output()), _dom_cla…

<function __main__.<lambda>(number_of_samples)>