# Tutorial for Gromov-Wasserstein unsupervised alignment 

In [1]:
import os, sys
sys.path.append(os.path.join(os.getcwd(), '../../'))

import numpy as np

from src.align_representations import Representation, AlignRepresentations, OptimizationConfig, VisualizationConfig

# Step1: Prepare dissimilarity matrices or embeddings from the data
First, you need to prepare dissimilarity matrices or embeddings from your data.  
To store dissimilarity matrices or embeddings, an instance of the class `Representation` is used.   
Please put your dissimilarity matrices or embeddings into the variables `sim_mat` or `embedding` in this instance.   

## Load data
`AllenBrain`: Neuropixels data recorded in the visual areas of mice from the Allen Brain Observatory

In [2]:
# Dataset selector; it is also used to name the results directory and the data.
data_select = "AllenBrain"
# Container that will hold the instances of the "Representation" class.
representations = []

### Dataset `AllenBrain`
We treat the average spike counts during `natural movie one` in VISal (pseudo-mouse A) and VISam (pseudo-mouse B), recorded from two independent pseudo-mice, as embeddings.

In [3]:
# Build one "Representation" per pseudo-mouse recording.
# NOTE: `emb` is deliberately left at module level; a later cell reads len(emb).
representations = []
for name in ("pseudo_a_VISal", "pseudo_b_VISam"):
    emb = np.load(f"../../data/AllenBrain/{name}.npy")
    representation = Representation(
        name=name,
        # The dissimilarity matrix will be computed from this embedding
        # with the cosine metric.
        embedding=emb,
        metric="cosine",
        # The embeddings are supplied directly, so this is set to False.
        get_embedding=False,
        # One integer label per row of the embedding: 0 .. n_rows - 1.
        object_labels=np.arange(emb.shape[0]),
    )
    representations.append(representation)

# Step 2: Set the parameters for the optimization of GWOT
Second, you need to set the parameters for the optimization of GWOT.    
For most of the parameters, you can start with the default values.   
However, there are some essential parameters that you need to check for your original applications.  

## Optimization Config  

#### Most important parameters to check for your application:
`eps_list`: The range of epsilon values for entropic GWOT.   
If epsilon is not in an appropriate range (e.g., if it is too low), the optimization may not work properly.   
Although the algorithm will find good epsilon values after many trials, it is good practice to narrow down the range beforehand.   

`num_trial`: The number of trials to test epsilon values from the specified range.   
This number directly determines the quality of the unsupervised alignment.   
You should set this number high enough to find good local minima. 

In [4]:
# --- Compute backend ---
device = 'cpu'        # 'cpu' or 'cuda'
to_types = 'numpy'    # 'numpy' or 'torch'
multi_gpu = False

# --- Epsilon search for entropic GWOT ---
eps_list_tutorial = [1e-4, 1e-1]   # lower and upper bound of the epsilon range
eps_log = True                     # presumably: search epsilon on a log scale
num_trial = 100                    # number of epsilon values to try

# --- Initial transport plan and sampler ---
# Active setting: 'uniform' + 'grid'. Alternatives used in this tutorial
# (their accuracies are listed after the accuracy cell below):
#   init_mat_plan = 'random';  sampler_name = 'tpe'
#   init_mat_plan = 'random';  sampler_name = 'grid'
init_mat_plan = 'uniform'
sampler_name = 'grid'

In [5]:
config = OptimizationConfig(
    # Epsilon search for entropic GWOT.
    eps_list=eps_list_tutorial,
    eps_log=eps_log,
    num_trial=num_trial,
    sinkhorn_method='sinkhorn_log',

    # Backend: device is 'cuda' or 'cpu'; to_types is 'torch' or 'numpy'.
    device=device,
    to_types=to_types,
    data_type="double",

    # Parallelism and the results database.
    n_jobs=1,
    multi_gpu=multi_gpu,
    db_params={"drivername": "sqlite"},

    # Optimization parameters.
    # init_mat_plan: 'uniform' (uniform matrix), 'diag' (diagonal matrix),
    # or 'random' (random matrix).
    init_mat_plan=init_mat_plan,
    n_iter=1,
    max_iter=200,
    sampler_name=sampler_name,
)

# Step 3: Gromov-Wasserstein Optimal Transport (GWOT) between Representations
Third, you perform GWOT between the instances of "Representation" by using the class `AlignRepresentations`.  
This class has methods for the optimization of entropic Gromov-Wasserstein distance, and the evaluation of the GWOT (Step 4).  
This class also has a method to perform conventional Representation Similarity Analysis (RSA).   

In [None]:
# Bundle the representations with the optimization config. This object runs
# GWOT between the representations and evaluates the results (Step 4).
align_representation = AlignRepresentations(
    config=config,
    representations_list=representations,
    # Set to True to adjust the histogram of the target to that of the source.
    histogram_matching=False,
    # Folder (or file name) where the results are saved.
    # Alternative used for the random + grid run:
    #   main_results_dir="../../results/random+grid/" + data_select,
    main_results_dir="../../results/" + data_select,
    # Rename this when using your own data.
    data_name=data_select,
)

## Show dissimilarity matrices

In [None]:
# Format in which the dissimilarity matrices are shown.
sim_mat_format = "default"

# Figure appearance for the dissimilarity matrices.
visualize_config = VisualizationConfig(
    # Figure.
    show_figure=True,
    fig_ext='svg',
    figsize=(12, 12),
    font='Arial',
    cmap='rocket_r',
    title_size=60,
    # Axis labels.
    xlabel='90 short movies',
    xlabel_size=60,
    ylabel='90 short movies',
    ylabel_size=60,
    # Colorbar.
    cbar_label='cosine distance',
    cbar_label_size=60,
    cbar_ticks_size=50,
    # Ticks and category lines.
    ticks=None,
    ot_object_tick=False,
    ot_category_tick=False,
    # Setting draw_category_line=True also requires ot_category_tick=True.
    draw_category_line=False,
    category_line_color='black',
    category_line_alpha=0.5,
    category_line_style='dashed',
    plot_eps_log=eps_log,
)

# Appearance of the histogram figure (passed as visualization_config_hist;
# presumably only drawn when show_distribution=True).
visualize_hist = VisualizationConfig(figsize=(8, 6), color='C0')

sim_mat = align_representation.show_sim_mat(
    sim_mat_format=sim_mat_format,
    visualization_config=visualize_config,
    visualization_config_hist=visualize_hist,
    show_distribution=False,
)

## Representation Similarity Analysis (RSA)
This performs a conventional representation similarity analysis.

In [None]:
### Conventional RSA
# metric: "pearson" or "spearman" (computed with scipy.stats).
# The RSA result for each pair is stored in align_representation.RSA_corr.
rsa_metric = "pearson"
align_representation.RSA_get_corr(metric=rsa_metric)

# To inspect the result: print(align_representation.RSA_corr)

GWOT is performed by applying the method `gw_alignment` to the instance of the `AlignRepresentations` class.

We show all the parameters for running the GWOT computation using the THINGS or DNN datasets as examples, because these datasets have category information labels.

For the color, AllenBrain, and simulation datasets (which do not have category information), we show how to do this in the next cell. 

Below is an example of computing GWOT for each pair in the color, AllenBrain, and simulation datasets.

In [9]:
# compute_OT: set to True to run the GWOT optimization. With False, the
# previously calculated OT plans are loaded instead of being recomputed.
compute_OT = False

# delete_results: set to True to delete both the database and the directory
# holding the results of a previous optimization. The code prompts for
# confirmation before deleting anything.
delete_results = False

In [None]:
# Format of the OT plan figure (same option as for the dissimilarity matrices).
sim_mat_format = "default"

# Figure settings for the OT plan. The area/mouse pairing in the axis labels
# must match the loaded data: pseudo mouse A is "pseudo_a_VISal" (VISal) and
# pseudo mouse B is "pseudo_b_VISam" (VISam).
visualize_config.set_params(
    figsize=(18, 18),
    title_size=50,
    xlabel='90 short movies of VISam (pseudo mouse B)',
    ylabel='90 short movies of VISal (pseudo mouse A)',
    font='Arial',
    cbar_label='Probability',
    cbar_label_size=40,
    cbar_ticks_size=30,
    xlabel_size=50,
    ylabel_size=50,
)

# Run GWOT for each pair of representations (or load the previous results
# when compute_OT is False) and plot the OT plans.
align_representation.gw_alignment(
    compute_OT=compute_OT,
    delete_results=delete_results,
    return_data=False,
    return_figure=True,
    OT_format=sim_mat_format,
    visualization_config=visualize_config,
)

# Step 4: Evaluation and Visualization
Finally, you can evaluate and visualize the unsupervised alignment of GWOT.   

## Show how the GWD was optimized
`show_optimization_log` will make two figures to show both the relationships between epsilons (x-axis) and GWD (y-axis), and between accuracy (x-axis) and GWD (y-axis).



In [None]:
### Show how the GWD was optimized (evaluation figure)
# Plots both epsilon vs. GWD and accuracy vs. GWD.
visualize_config.set_params(
    show_figure=True,
    figsize=(10, 8),
    cmap='viridis',
    font='Arial',
    xticks_rotation=0,
    cbar_label_size=30,
    cbar_ticks_size=30,
    title_size=25,
    xticks_size=30,
    yticks_size=30,
    marker_size=60,
    plot_eps_log=eps_log,
    lim_acc=[0, 100],  # accuracy axis (values are percentages, 0-100)
    lim_gwd=[0, 0.1],
    # Derive the epsilon axis from the searched range rather than repeating
    # its bounds, so changing eps_list_tutorial cannot crop the plot.
    lim_eps=list(eps_list_tutorial),
)

align_representation.show_optimization_log(fig_dir=None, visualization_config=visualize_config)

## Evaluation of the accuracy of the unsupervised alignment
There are two ways to evaluate the accuracy.  
1. Calculate the accuracy based on the OT plan. 
- To use this method, set the parameter `eval_type = "ot_plan"` in "calc_accuracy()".
  
2. Calculate the matching rate based on the k-nearest neighbors of the embeddings.
-  To use this method, set the parameter `eval_type = "k_nearest"` in "calc_accuracy()".

For both cases, the accuracy evaluation criterion can be adjusted by considering "top k".  
By setting "top_k_list", you can observe how the accuracy increases as the criterion is relaxed.

In [None]:
## Calculate the accuracy based on the OT plan.
align_representation.calc_accuracy(top_k_list=[1, 5, 10], eval_type="ot_plan")
align_representation.plot_accuracy(eval_type="ot_plan", scatter=True)

top_k_accuracy = align_representation.top_k_accuracy  # you can get the dataframe directly

# Reference top-k accuracy (%) for pseudo_a_VISal_vs_pseudo_b_VISam, obtained
# with different (init_mat_plan + sampler_name) settings. These are comments
# rather than a bare string literal, which would be echoed as the cell output.
#
#  top_n | random + tpe | uniform + grid | random + grid
#  ------+--------------+----------------+--------------
#    1   |   91.111111  |    63.333333   |   92.222222
#    5   |   95.555556  |    80.000000   |   95.555556
#   10   |  100.000000  |    91.111111   |  100.000000

In [None]:
## Matching rate based on the k-nearest neighbors of the embeddings
## (using metric="cosine").
align_representation.calc_accuracy(
    top_k_list=[1, 5, 10],
    eval_type="k_nearest",
    metric="cosine",
)
align_representation.plot_accuracy(eval_type="k_nearest", scatter=True)

# The resulting dataframe can also be accessed directly.
k_nearest_matching_rate = align_representation.k_nearest_matching_rate

## Procrustes Analysis
Using the optimized transportation plans, you can align the embeddings of each representation to a shared space in an unsupervised manner.  
The `"pivot"` refers to the target embeddings space to which the other embeddings will be aligned.   
You have the option to designate the `"pivot"` as one of the representations or the barycenter.  
Please ensure that 'pair_number_list' includes all pairs between the pivot and the other Representations.  

If you wish to use the barycenter, please use the method `AlignRepresentations.barycenter_alignment()`.  
You can use it in the same manner as `AlignRepresentations.gw_alignment()`.

In [None]:
# Dimensionality-reduction method used for the plot: "PCA", "TSNE", or "MDS".
emb_name = "PCA"

visualization_embedding = VisualizationConfig(
    fig_ext="svg",
    figsize=(10, 10),
    marker_size=80,
    # These axis labels assume emb_name == "PCA"; adjust them for TSNE/MDS.
    xlabel="PC1",
    ylabel="PC2",
    zlabel="PC3",
    legend_size=20,
    font="Arial",
    color_hue="cool",
    cmap="cool",
    colorbar_label="short movies",
    # NOTE(review): `emb` is the last embedding loaded in the data-loading cell;
    # this assumes every representation has the same number of objects.
    colorbar_range=[0, len(emb)],
    colorbar_shrink=0.8,
    xlabel_size=20,
    ylabel_size=20,
    zlabel_size=20,
    alpha=0.7,
)

align_representation.visualize_embedding(
    dim=3,  # dimensionality of the space the points are embedded in: 2 or 3
    pivot=0,  # index of one of the representations, or "barycenter"
    method=emb_name,  # honor the method selected above
    visualization_config=visualization_embedding,
)