# Tutorial 5: Unbalanced Gromov-Wasserstein and Fused Gromov-Wasserstein Distances
The Gromov-Wasserstein distance is highly useful for quantifying differences in cell morphology, and a number of variants of it have been proposed in the literature. Here we introduce two such variants, "Unbalanced Gromov-Wasserstein" and "Fused Gromov-Wasserstein", and discuss their applications to neuron taxonomy. We will see that these variants recover known labels, such as the RNA family, better than classical Gromov-Wasserstein does. Additional background can be found on the "Variants of Gromov-Wasserstein" page. We will use the same data set that was studied in Tutorial 4, and the same sampled points. All data can be downloaded from [Dropbox](https://www.dropbox.com/scl/fo/a5b2t4rkek0j5xvjrt5un/ALrhJWIU0zYWuk2QShiGjLs?rlkey=qt79k4qzy2oeo5rnvik7jimu1&st=bu6yzcuw&dl=0).

## Intuition for Unbalanced Gromov-Wasserstein
The big-picture idea behind unbalanced Gromov-Wasserstein (UGW) is that it is less sensitive to small changes in morphology than ordinary GW. If GW answers the question "How well can we align these two cell morphologies?", then UGW answers the question "How well can we align a large chunk of cell 1 with a large chunk of cell 2, where we want to maximize a weighted sum of the size of the pieces being matched and the goodness-of-fit of the match?" In situations where small pieces of a cell can safely be discarded without substantially changing its morphology, unbalanced GW may be more robust than GW.

The definition of Gromov-Wasserstein distance involves searching through all possible "couplings" between two cells. The notion of "coupling" employed here is rather rigid: cells are regarded as having unit mass, and couplings are required to satisfy a "conservation of mass" law, that is, all mass in the first cell must be paired with corresponding mass in the second cell. If two neurons are modelled as point clouds with 100 points each, then each point is modelled as having mass 0.01 units, and a valid coupling must pair each point in one cell with exactly 0.01 units of mass from the other cell.
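
The conservation-of-mass condition can be checked numerically. The following is a minimal sketch using plain NumPy, not the CAJAL API; the helper `is_coupling` and the three example plans are illustrative assumptions. It shows two valid couplings for the 100-point case and one plan that fails the condition because it leaves ten points unmatched.

```python
import numpy as np

n = 100                               # points sampled per cell, as in the text
mu = np.full(n, 1.0 / n)              # mass of each point in cell 1 (0.01)
nu = np.full(n, 1.0 / n)              # mass of each point in cell 2 (0.01)

# Two valid couplings: spread all mass evenly, or match points one-to-one.
product_plan = np.outer(mu, nu)       # every pair of points shares mass 0.0001
matching_plan = np.eye(n) / n         # point i is paired only with point i


def is_coupling(plan, mu, nu):
    """Check conservation of mass: row sums equal mu and column sums equal nu."""
    return bool(
        np.allclose(plan.sum(axis=1), mu)
        and np.allclose(plan.sum(axis=0), nu)
    )


# Dropping the last 10 matched pairs leaves 10% of each cell unpaired,
# so this plan is not a valid coupling in the classical GW sense.
partial_plan = matching_plan.copy()
partial_plan[:, 90:] = 0.0

print(is_coupling(product_plan, mu, nu))
print(is_coupling(matching_plan, mu, nu))
print(is_coupling(partial_plan, mu, nu))
```

Classical GW searches only over plans for which `is_coupling` holds; the partial plan is exactly the kind of alignment it cannot consider.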

Suppose we have two neurons that are identical except that one has an additional dendrite which is absent from the other. Such a relationship is biologically interesting, and it is plausible that considering embeddings of one neuron into another would help us to capture important biological similarities. But Gromov-Wasserstein does not recognize such embeddings as valid couplings, because they violate "conservation of mass": all the mass from the first neuron is paired with only a fraction of the mass of the second neuron, and the extra dendrite of the second neuron is not paired with anything. The optimal GW transport plan would likely bear no trace of the structural equivalence between the first neuron and a fragment of the second.

The Unbalanced Gromov-Wasserstein distance allows for such embeddings: its transport plans are permitted to create or destroy mass, at the cost of an additional penalty. The size of the penalty is controlled by a user-supplied parameter $\rho$. When $\rho$ is large, very little deviation from a "perfect coupling" is allowed; as $\rho$ decreases, the algorithm becomes more tolerant of deviations and allows looser fits. The [Unbalanced Gromov-Wasserstein paper](https://arxiv.org/abs/2009.04266) by Séjourné, Vialard, and Peyré provides some useful examples of situations where the extra flexibility of unbalanced Gromov-Wasserstein makes it more tolerant of small differences between objects.
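
To give a feel for how $\rho$ trades off fit against discarded mass, here is a deliberately simplified sketch. It charges $\rho$ times a generalized Kullback-Leibler divergence between the marginals of a plan and the original point masses. This is an assumption made for illustration: the penalty used in the paper and in the implementation may take a different (for example, quadratic) form, and the function names below are hypothetical.

```python
import numpy as np


def generalized_kl(a, b):
    """Generalized KL divergence between nonnegative vectors (0 * log 0 := 0)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    pos = a > 0
    return float(np.sum(a[pos] * np.log(a[pos] / b[pos])) - a.sum() + b.sum())


def marginal_penalty(plan, mu, nu, rho):
    """Simplified penalty: rho * (KL(row sums | mu) + KL(column sums | nu))."""
    return rho * (
        generalized_kl(plan.sum(axis=1), mu)
        + generalized_kl(plan.sum(axis=0), nu)
    )


n = 100
mu = np.full(n, 1.0 / n)
nu = np.full(n, 1.0 / n)

perfect = np.eye(n) / n        # mass-conserving one-to-one matching
partial = perfect.copy()
partial[:, 90:] = 0.0          # discard 10% of the mass in each cell

for rho in (10.0, 1000.0):
    print(rho, marginal_penalty(perfect, mu, nu, rho), marginal_penalty(partial, mu, nu, rho))
```

A mass-conserving plan pays no penalty for any $\rho$, while the cost of the partial plan grows linearly with $\rho$, so a large $\rho$ effectively forces the solver back toward a perfect coupling.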

Choosing a specific numerical value for $\rho$ can be challenging because it is not a priori clear what order of magnitude $\rho$ should have in order to get sensible results. We expose a more intuitive and interpretable control parameter: a lower bound on the amount of mass that is kept while aligning cells. That is, the user can supply a parameter `mass_kept=0.90` to guarantee that when two neurons are aligned, at least 90% of the points in both neurons are properly aligned, and at most 10% are discarded to improve the fit.
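
One way to read the `mass_kept` guarantee is as a condition on the total mass carried by the transport plan. The sketch below makes that reading concrete; it is an assumption about the interpretation rather than a description of the library's internals, and the helper `mass_report` and its 0.5 threshold for calling a point "discarded" are hypothetical.

```python
import numpy as np


def mass_report(plan, mu, nu, mass_kept=0.90):
    """Summarize how much mass a plan transports and which points it drops."""
    total = float(plan.sum())
    matched_1 = plan.sum(axis=1) / mu    # fraction of each point of cell 1 that is matched
    matched_2 = plan.sum(axis=0) / nu    # fraction of each point of cell 2 that is matched
    return {
        "total_mass": total,
        "satisfies_bound": bool(total >= mass_kept - 1e-12),
        # Arbitrary illustrative threshold: a point is "discarded" if under half matched.
        "discarded_in_cell_1": np.flatnonzero(matched_1 < 0.5).tolist(),
        "discarded_in_cell_2": np.flatnonzero(matched_2 < 0.5).tolist(),
    }


n = 100
mu = np.full(n, 1.0 / n)
nu = np.full(n, 1.0 / n)

keeps_95 = np.eye(n) / n
keeps_95[:, 95:] = 0.0       # drop 5 matched pairs: 95% of the mass is kept
keeps_85 = np.eye(n) / n
keeps_85[:, 85:] = 0.0       # drop 15 matched pairs: only 85% of the mass is kept

print(mass_report(keeps_95, mu, nu)["satisfies_bound"])
print(mass_report(keeps_85, mu, nu)["satisfies_bound"])
```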

Let us demonstrate how to use the implementation.

In [1]:
from os.path import join
from cajal.ugw import _multicore, UGW # Substitute _single_core for single-threaded usage, useful if you want to parallelize at the level of Python processes
UGW_multicore = UGW(_multicore) # For GPU backends, the constructor has to negotiate a connection to the GPU, so it may take a long time to initialize.

bd = "/home/jovyan/tutorial5" # Base directory

The appropriate parameters are sensitive to the absolute scale of your data.
To choose appropriate coefficients, you can run the ordinary GW computation first and use its output to estimate the relevant scales; see the "Variants of Gromov-Wasserstein" page.

The algorithm for UGW is much more time-intensive than the algorithm for classical Gromov-Wasserstein and we don't recommend running this cell during the tutorial.

In [2]:
eps = 100.0
UGW_dmat = UGW_multicore.ugw_armijo_pairwise(
    mass_kept = 0.80,
    eps=eps,
    dmats=join(bd,"geodesic_100_icdm.csv")
)

(645, 100, 100)
10865.144685592724
Done first pass.
CPU times: user 2d 1h 5min 41s, sys: 50.3 s, total: 2d 1h 6min 32s
Wall time: 2h 28min 12s


In [3]:
import numpy as np
np.save("UGW_dmat_mass_80_eps_100.npy", UGW_dmat)

Then one can simply use the associated dissimilarity matrix as we have shown in other tutorials, for example, drawing a nearest-neighbors graph through the space and using community detection algorithms to cluster the cells.
Here we test the ability of UGW to recover genetic information about the cell from its neighbors in the morphology space.

In [5]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import LeaveOneOut, cross_val_score
from cajal.utilities import cell_iterator_csv

cells, idcms = zip(*cell_iterator_csv(intracell_csv_loc=join(bd,"geodesic_100_icdm.csv")))
metadata = pd.read_csv(join(bd,'m1_patchseq_meta_data.csv'),sep='\t',index_col='Cell').loc[pd.Series(cells)]
RNA_family = metadata['RNA family']
hq = RNA_family != 'low quality' # Filter down to the cells that have a well-defined RNA family.
clf = KNeighborsClassifier(metric="precomputed", n_neighbors=10, weights="distance")
cv = LeaveOneOut()

results = cross_val_score(clf, X=UGW_dmat[hq,:][:,hq], y=RNA_family.loc[hq],cv=cv)
print("Accuracy:", results.sum()/results.shape[0])

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import matthews_corrcoef
cvp = cross_val_predict(clf, X= UGW_dmat[hq,:][:,hq], y=RNA_family.loc[hq], cv=cv)
print("MCC: ", matthews_corrcoef(cvp, RNA_family.loc[hq]))

Accuracy: 0.5510204081632653
MCC:  0.4520340022810735


A scan across values of $\rho$ ranging from 400 to 40000 gives MCC values ranging from 0.444 to 0.482 with a median of 0.468, so the result above is representative.
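
To make explicit what the classifier above does with a precomputed dissimilarity matrix, here is a minimal NumPy sketch of leave-one-out k-nearest-neighbor voting with inverse-distance weights. The function `loo_knn_predict` and the toy data are illustrative assumptions. The sketch approximates, rather than reproduces, scikit-learn's `KNeighborsClassifier(metric="precomputed", weights="distance")`; ties and zero distances, for instance, may be handled differently.

```python
import numpy as np


def loo_knn_predict(dmat, labels, n_neighbors=10):
    """Leave-one-out k-NN on a precomputed distance matrix with inverse-distance votes."""
    dmat = np.asarray(dmat, dtype=float)
    labels = np.asarray(labels)
    predictions = []
    for i in range(len(labels)):
        dists = np.delete(dmat[i], i)        # distances from cell i to every other cell
        others = np.delete(labels, i)        # labels of the remaining cells
        nearest = np.argsort(dists)[:n_neighbors]
        weights = 1.0 / np.maximum(dists[nearest], 1e-12)   # closer neighbors count more
        votes = {}
        for lab, w in zip(others[nearest], weights):
            votes[lab] = votes.get(lab, 0.0) + w
        predictions.append(max(votes, key=votes.get))
    return np.array(predictions)


# Toy data: two well-separated groups of "cells" on a line.
positions = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 100.0, 101.0, 102.0, 103.0, 104.0])
toy_dmat = np.abs(positions[:, None] - positions[None, :])
toy_labels = np.array(["A"] * 5 + ["B"] * 5)

toy_pred = loo_knn_predict(toy_dmat, toy_labels, n_neighbors=3)
print("Accuracy:", float((toy_pred == toy_labels).mean()))
```

Each cell is classified using only its distances to the other cells, which is why a dissimilarity matrix alone, without coordinates, is enough for this evaluation.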

We benchmark against classical Gromov-Wasserstein in the same way for comparison:

In [6]:
import cajal.utilities

_, classical_gw_dists = cajal.utilities.read_gw_dists(join(bd, 'swc_bdad_100pts_geodesic_gw.csv'), header=True)
classical_gw_dmat = cajal.utilities.dist_mat_of_dict(classical_gw_dists, metadata.index[hq].to_list())

gw_results = cross_val_score(clf, X=classical_gw_dmat, y=RNA_family[hq],cv=cv)
print("Accuracy:", gw_results.sum()/gw_results.shape[0])
cvp = cross_val_predict(clf, X=classical_gw_dmat, y=RNA_family[hq], cv=cv)
print("MCC: ", matthews_corrcoef(cvp, RNA_family[hq]))

Accuracy: 0.5196232339089482
MCC:  0.41437665858208905


The UGW MCC of 0.452 shown above is about 9% higher than the classical GW MCC of 0.414; the median MCC of 0.468 from the scan over $\rho$ would be about 13% higher. Of course, these statistics are themselves affected by the sampling distribution that the neurons are drawn from, and a different set of neurons would give different numbers.

## Using Fused Gromov-Wasserstein

Fused Gromov-Wasserstein is discussed in detail on the "Variants of Gromov-Wasserstein" page. Here we discuss its application to SWC morphology reconstructions. The interface for Fused Gromov-Wasserstein is similar to that for classical Gromov-Wasserstein. It requires the following additional pieces of information:
- a file path to the location of the SWC node type identifiers for the sampled points
- fields `soma_dendrite_penalty` and `basal_apical_penalty`. Setting `soma_dendrite_penalty` to a high value means that the algorithm will try to avoid pairing soma nodes from one neuron with dendrite nodes of the other neuron; `basal_apical_penalty` is the penalty for pairing a basal dendrite node from one neuron with an apical dendrite node from the other.
- `penalty_dictionary`, an optional argument which overrides the `soma_dendrite_penalty` and `basal_apical_penalty` fields. The user can directly specify the penalty for each pair of node types. This will be most useful for users whose SWC files contain structure ids outside the most commonly used ones (0-4).
- `worst_case_gw_increase`, an optional argument which changes the way the node penalties are interpreted. By default, node penalties are "absolute": if you specify that the soma-to-dendrite penalty is 5.0, then the FGW cost will increase by 5.0 whenever a soma node is paired with a dendrite node. Setting this argument makes the node penalties relative (so that only the ratios between penalties matter), and all penalties will be rescaled by a constant chosen so that the median increase in GW cost due to the node type penalties will be at most `worst_case_gw_increase`.
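
As an illustration of how per-pair node penalties might be organized, the sketch below builds a penalty table keyed by pairs of SWC structure identifiers (1 = soma, 2 = axon, 3 = basal dendrite, 4 = apical dendrite) and turns it into a node-by-node cost matrix. The key type `tuple[int, int]` follows the type hint shown for `penalty_dictionary` in this tutorial. The specific penalty values, the mirroring of each entry, and the helper `node_type_cost` are assumptions made for illustration, not the library's internals.

```python
import numpy as np

# Standard SWC structure identifiers.
SOMA, AXON, BASAL, APICAL = 1, 2, 3, 4

# Hypothetical penalties keyed by (node type in cell 1, node type in cell 2).
penalties = {
    (SOMA, BASAL): 1.0,
    (SOMA, APICAL): 1.0,
    (BASAL, APICAL): 0.5,
}

# Mirror each entry so the penalty does not depend on the order of the two cells
# (whether the library requires both orderings is an assumption here).
penalty_dictionary = dict(penalties)
penalty_dictionary.update({(b, a): p for (a, b), p in penalties.items()})


def node_type_cost(types_a, types_b, penalty_dictionary):
    """Matrix whose (i, j) entry is the penalty for pairing node i of A with node j of B."""
    return np.array(
        [[penalty_dictionary.get((ta, tb), 0.0) for tb in types_b] for ta in types_a]
    )


types_a = [SOMA, BASAL, BASAL]   # node types of three sampled points in cell A
types_b = [SOMA, APICAL]         # node types of two sampled points in cell B
cost = node_type_cost(types_a, types_b, penalty_dictionary)
print(cost)
```

Pairs of node types missing from the table cost nothing in this sketch, so matching soma to soma is free while matching soma to any dendrite is penalized.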

In [None]:
from cajal.fused_gw_swc import fused_gromov_wasserstein_parallel

fused_gw_dmat = fused_gromov_wasserstein_parallel(
    intracell_csv_loc=join(bd,'geodesic_100_icdm.csv'),
    swc_node_types=join(bd,"geodesic_100_node_types.npy"),
    fgw_dist_csv_loc=join(bd,"geodesic_100_fgw.csv"),
    num_processes=14,
    soma_dendrite_penalty= 1., # The cost of aligning a soma node to a dendrite node is initialized to be 1.0, but it will be rescaled based on the value of `worst_case_gw_increase`
    basal_apical_penalty=0., # The data set we are using doesn't distinguish basal and apical dendrites, so this parameter has no effect
    # penalty_dictionary: Optional[dict[tuple[int, int], float]] = None,
    chunksize = 100,
    worst_case_gw_increase= 0.50, # We want the GW cost to go up by at most 50% (a factor of at most 1.5) for the median pair of cells in the data set.
)
fused_gw_dmat = fused_gw_dmat[hq, :][:, hq]

In [None]:
fgw_results = cross_val_score(clf, X=fused_gw_dmat, y=RNA_family[hq],cv=cv)
print("Accuracy:", fgw_results.sum()/fgw_results.shape[0])
cvp = cross_val_predict(clf, X=fused_gw_dmat, y=RNA_family[hq], cv=cv)
print("MCC: ", matthews_corrcoef(cvp, RNA_family[hq]))

Accuracy: 0.5698587127158555
MCC:  0.47469628255037005


So, by incorporating the SWC node type data to distinguish between soma nodes and dendrite nodes, Fused Gromov-Wasserstein outperforms classical Gromov-Wasserstein, raising the MCC from 0.414 to 0.475 (about 15%), which is comparable to the upper end of the Unbalanced Gromov-Wasserstein results above.
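
For reference, the relative improvements can be computed directly from the MCC values printed in the cells above. This is plain arithmetic on those printed scores; the variable names are illustrative.

```python
# MCC values copied from the printed cell outputs in this tutorial.
mcc = {
    "GW": 0.41437665858208905,
    "UGW": 0.4520340022810735,
    "FGW": 0.47469628255037005,
}

baseline = mcc["GW"]
# Relative change in MCC with respect to classical Gromov-Wasserstein.
relative_gain = {name: (score - baseline) / baseline for name, score in mcc.items()}

for name, gain in relative_gain.items():
    print(f"{name}: MCC = {mcc[name]:.3f}, change vs. GW = {gain:+.1%}")
```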