# Tutorial 5: Unbalanced Gromov-Wasserstein and Fused Gromov-Wasserstein Distances
This notebook demonstrates the practical usage of unbalanced GW and fused GW. The theory behind these concepts is introduced on the "Variants of Gromov-Wasserstein" page. We use the same data set that was studied in Tutorial 4, and the same sampled points. All data can be downloaded from Dropbox.

## Using Unbalanced Gromov-Wasserstein
We have implemented a Python module that computes the unbalanced Gromov-Wasserstein (UGW) distance between cells. By default, CAJAL ships with a single-core version and a multicore version of the algorithm. To build a GPU version using either CUDA or OpenCL, the user can uncomment the appropriate line in the package's `setup.py` build script. The GPU backends are disabled by default because the end user must first configure their machine so that the CUDA (respectively, OpenCL) header files can be found and all necessary libraries are available. A few other backends can be made available upon request. In our experience, the GPU backends are only likely to be useful when the individual UGW problems are very large (i.e., the metric spaces are large).

The user should only import one of the backend modules at a time due to technical limitations of C (C has no namespacing, so there will be symbol conflicts from identically named functions in the two backend modules). Restart the Python interpreter if you want to load a different backend module.

Let us demonstrate how to use the implementation.

We import the module we want to use, in this case the multicore implementation, together with the UGW class. The constructor for the UGW class takes the backend module as its argument, establishes a connection with the library, and returns an object that maintains the internal state of the computation. The wrapper functions for the C backend are then accessible as *methods* of this object. If the user intends to parallelize at the level of Python processes, each process should instantiate the class itself. As usual, one can call `help(UGW_multicore)`, `help(UGW_multicore.ugw_armijo)`, and so on for documentation of the functions.

In [None]:
import cajal.sample_swc
import cajal.swc
from os.path import join
bd = "/home/jovyan/tutorial5" # Base directory

In [3]:
cajal.sample_swc.compute_icdm_all_geodesic(
    infolder=join(bd, 'swc'),
    out_csv=join(bd,'geodesic_100_icdm.csv'),
    out_node_types=join(bd,"geodesic_100_node_types.npy"),
    num_processes=8,
    preprocess=cajal.swc.preprocessor_geo(
        structure_ids=[1,3,4]),
    n_sample=100
    )

100%|█████████▉| 644/645 [00:28<00:00, 22.67it/s]


[]

In [4]:
from cajal.ugw import _multicore, UGW # _single_core for single-threaded usage, useful if you want to parallelize at the level of Python processes
UGW_multicore = UGW(_multicore) # For GPU backends, the constructor has to negotiate a connection to the GPU, so it may take a long time to initialize.

In [None]:
from cajal.run_gw import cell_iterator_csv
import numpy as np

cells, icdms = zip(*cell_iterator_csv(join(bd, "geodesic_100_icdm.csv")))
icdm_block = np.stack(
    icdms, axis=0
) 

For efficient memory usage and effective parallelization, the parallel function requires an array of cells of uniform length.
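
The requirement of uniform length can be checked before any expensive computation starts. The following is a minimal sketch on small synthetic matrices (not the tutorial data); the helper `random_icdm` is hypothetical and only stands in for real intracell distance matrices. It shows the shape of the stacked block and that `np.stack` refuses matrices of mismatched size.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_icdm(n):
    """Hypothetical stand-in: Euclidean distance matrix of n random 3D points."""
    pts = rng.random((n, 3))
    return np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)

# Three synthetic "cells", each sampled at 4 points.
uniform = [random_icdm(4) for _ in range(3)]
block = np.stack(uniform, axis=0)
print(block.shape)  # (3, 4, 4)

# A cell sampled at a different number of points cannot join the block.
try:
    np.stack(uniform + [random_icdm(5)], axis=0)
    stacked_ok = True
except ValueError:
    stacked_ok = False
print(stacked_ok)  # False
```

If the second stack fails on real data, the cells were not all sampled with the same `n_sample`.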

The appropriate parameters are sensitive to the absolute scales of your data.
To choose appropriate coefficients, you can run the ordinary GW computation first and use its results to estimate the appropriate scales; see the "Variants of Gromov-Wasserstein" page.

In [None]:

rho1 = 1600.0
rho2 = 1600.0
eps = 100.0

The algorithm for UGW is much more time-intensive than the algorithm for classical Gromov-Wasserstein, so we don't recommend running this cell during the tutorial.

In [None]:

UGW_results = UGW_multicore.ugw_armijo_pairwise_unif(
    rho1=rho1, rho2=rho2, eps=eps, dmats=icdm_block
)


The `from_futhark()` method converts the library's internal representation of the output to a NumPy array.

In [None]:
UGW_array = UGW_multicore.from_futhark(UGW_results)
np.save(
    join(bd, "unbalanced_gw_bdad_100pts_geodesic_rho1_1600_rho2_1600_eps_100.npy"),
    UGW_array,
)

Unfortunately, the UGW algorithm is not numerically stable and on large data sets it's likely that some inputs will return NaN.
The following wrapper function will make a second pass through the data looking for NaN values, and for each pair of cells where the computation diverged, it will rerun the computation with exponentially increasing values of the regularization parameter until the algorithm converges.
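
To make "exponentially increasing" concrete, here is a minimal sketch. It assumes that each retry multiplies the previous regularization value by `increasing_ratio`, which is how we read the parameter name; the helper `eps_schedule` is hypothetical and is not part of CAJAL.

```python
def eps_schedule(eps, increasing_ratio, n_retries):
    """Regularization values tried on successive retries (assumed geometric growth)."""
    return [eps * increasing_ratio**k for k in range(1, n_retries + 1)]

# Starting from eps = 100 with a ratio of 1.1, as in the cell below.
schedule = eps_schedule(eps=100.0, increasing_ratio=1.1, n_retries=5)
print([round(e, 2) for e in schedule])
```

With a ratio close to 1 the regularization grows slowly, so a pair that diverged is re-solved with a parameter as close as possible to the one originally requested.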

In [None]:
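# The uniform distribution on each cell: one row per cell, one weight per sampled point.
# Shapes are read from icdm_block (assumed to be cells x points x points) instead of being hard-coded.
n_cells, n_points = icdm_block.shape[:2]
u = np.ones(shape=(n_cells, n_points), dtype=np.float64) / n_points
UGW_fix_nans = UGW_multicore.ugw_armijo_pairwise_increasing(
    ugw_dmat=UGW_array,
    increasing_ratio=1.1,
    rho1=rho1,  # Same coefficients as in the first pass
    rho2=rho2,
    eps=eps,
    dmats=icdm_block,
    distrs=u,
)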
UGW_array = UGW_multicore.from_futhark(UGW_fix_nans)
np.save(join(bd, "unbalanced_gw_bdad_100pts_geodesic_rho1_1600_rho2_1600_eps_100_after_cleaning.npy"), UGW_array)

The returned array has five columns: $\mathcal{G}(T)$; the first and second marginal costs $KL(\pi_X(T)\otimes\pi_X(T)\mid \mu\otimes\mu)$ and $KL(\pi_Y(T)\otimes\pi_Y(T)\mid \nu\otimes\nu)$; the entropy regularization term $KL(T\otimes T\mid (\mu\otimes\nu)\otimes(\mu\otimes \nu))$; and the weighted linear combination $\mathcal{L}_\varepsilon(T)=UGW_\varepsilon$, where $T$ is the optimal coupling found by the search. In our analysis, we use $\mathcal{L}$, the objective without the regularization term, rather than $\mathcal{L}_\varepsilon$ as the measure of "distance", because the regularization term is present only for computational reasons and does not inform us about morphological distinctions.
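
Written out, and assuming the usual formulation of the UGW objective (which matches how the columns are combined in the next cell), the two quantities are related as follows:

$$
\mathcal{L}(T) = \mathcal{G}(T) + \rho_1\, KL(\pi_X(T)\otimes\pi_X(T)\mid \mu\otimes\mu) + \rho_2\, KL(\pi_Y(T)\otimes\pi_Y(T)\mid \nu\otimes\nu)
$$

$$
\mathcal{L}_\varepsilon(T) = \mathcal{L}(T) + \varepsilon\, KL(T\otimes T\mid (\mu\otimes\nu)\otimes(\mu\otimes \nu))
$$

Here $\rho_1$, $\rho_2$ and $\varepsilon$ correspond to the arguments `rho1`, `rho2` and `eps`.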

In [None]:
from scipy.spatial.distance import squareform
UGW_dmat = squareform(UGW_array[:,0] + rho1 * UGW_array[:,1] + rho2 * UGW_array[:,2])
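
`squareform` expects one value per unordered pair of cells, listed in the order (0,1), (0,2), ..., (1,2), and so on. The pure-Python sketch below illustrates that layout on a toy input; the helper `condensed_to_square` is purely illustrative, and in practice `scipy.spatial.distance.squareform` should be used.

```python
from itertools import combinations

def condensed_to_square(values, n):
    """Expand n*(n-1)/2 pairwise values into a symmetric n x n matrix."""
    if len(values) != n * (n - 1) // 2:
        raise ValueError("expected n*(n-1)/2 pairwise values")
    square = [[0.0] * n for _ in range(n)]
    # combinations(range(n), 2) yields (0,1), (0,2), ..., (n-2,n-1).
    for value, (i, j) in zip(values, combinations(range(n), 2)):
        square[i][j] = value
        square[j][i] = value
    return square

# Three cells give three pairs, in the order (0,1), (0,2), (1,2).
dmat = condensed_to_square([1.0, 2.0, 3.0], n=3)
print(dmat)  # [[0.0, 1.0, 2.0], [1.0, 0.0, 3.0], [2.0, 3.0, 0.0]]
```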

Then one can simply use the associated dissimilarity matrix as we have shown in other tutorials, for example, drawing a nearest-neighbors graph through the space and using community detection algorithms to cluster the cells.
Here we test the ability of UGW to recover genetic information about the cell from its neighbors in the morphology space.

In [8]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import LeaveOneOut, cross_val_score

metadata = pd.read_csv(join(bd,'m1_patchseq_meta_data.csv'),sep='\t',index_col='Cell').loc[pd.Series(cells)]
RNA_family = metadata['RNA family']
hq = RNA_family != 'low quality' # Filter down to the cells that have a well-defined RNA family.
clf = KNeighborsClassifier(metric="precomputed", n_neighbors=10, weights="distance")
cv = LeaveOneOut()

results = cross_val_score(clf, X=UGW_dmat[hq,:][:,hq], y=RNA_family.loc[hq],cv=cv)
print("Accuracy:", results.sum()/results.shape[0])

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import matthews_corrcoef
cvp = cross_val_predict(clf, X= UGW_dmat[hq,:][:,hq], y=RNA_family.loc[hq], cv=cv)
print("MCC: ", matthews_corrcoef(cvp, RNA_family.loc[hq]))

Accuracy: 0.5682888540031397
MCC:  0.4756706769071058


A scan across values of $\rho$ ranging from 400 to 40000 gives MCC values ranging from 0.444 to 0.482, with a median of 0.468, so the result above is representative.

We benchmark against classical Gromov-Wasserstein in the same way for comparison:

In [9]:
import cajal.utilities

_, classical_gw_dists = cajal.utilities.read_gw_dists(join(bd, 'swc_bdad_100pts_geodesic_gw.csv'), header=True)
classical_gw_dmat = cajal.utilities.dist_mat_of_dict(classical_gw_dists, metadata.index[hq].to_list())

gw_results = cross_val_score(clf, X=classical_gw_dmat, y=RNA_family[hq],cv=cv)
print("Accuracy:", gw_results.sum()/gw_results.shape[0])
cvp = cross_val_predict(clf, X=classical_gw_dmat, y=RNA_family[hq], cv=cv)
print("MCC: ", matthews_corrcoef(cvp, RNA_family[hq]))

Accuracy: 0.5196232339089482
MCC:  0.41437665858208905


The UGW MCC of 0.476 is about 15% higher than the 0.414 obtained with classical GW. Of course, these statistics are themselves affected by the sampling distribution that the neurons are drawn from, and a different set of neurons would give different numbers.
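
The relative improvement can be checked directly from the two MCC values printed above:

```python
ugw_mcc = 0.4756706769071058   # MCC printed for unbalanced GW
gw_mcc = 0.41437665858208905   # MCC printed for classical GW

relative_gain = (ugw_mcc - gw_mcc) / gw_mcc
print(f"{relative_gain:.1%}")  # 14.8%
```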

## Using Fused Gromov-Wasserstein

Fused Gromov-Wasserstein is discussed in detail on the "Variants of Gromov-Wasserstein" page. Here we discuss the application to SWC morphology reconstructions. The interface for the Fused Gromov-Wasserstein is similar to that for classical Gromov-Wasserstein. It requires the following additional pieces of information:
- a file path to the location of the SWC node type identifiers for the sampled points
- fields `soma_dendrite_penalty` and `basal_apical_penalty`. Setting `soma_dendrite_penalty` to a high value means that the algorithm will try to avoid pairing soma nodes from one neuron with dendrite nodes of the other neuron. Likewise, `basal_apical_penalty` is the penalty for pairing a basal dendrite node from one neuron with an apical dendrite node from the other.
- `penalty_dictionary`, an optional argument which overrides the `soma_dendrite_penalty` and `basal_apical_penalty` fields. The user can directly specify the penalty for each pair of node types. This will be most useful for users whose SWC files contain structure ids outside the most commonly used ones (0-4).
- `worst_case_gw_increase`, an optional argument which changes the way the node penalties are interpreted. By default, node penalties are "absolute" - if you specify that the soma-to-dendrite penalty is 5.0, then the FGW cost will increase by 5.0 whenever a soma node is paired with a dendrite node. Setting this argument makes the node penalties relative (so that only the ratios between penalties are important), and all penalties will be rescaled by a constant chosen so that the median increase in GW cost due to the node type penalties will be at most `worst_case_gw_increase`.
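
As a rough illustration of the kind of table `penalty_dictionary` could hold, here is a sketch. It assumes that the keys are pairs of SWC structure ids and the values are penalties, as the type `dict[tuple[int, int], float]` suggests. The helper `make_penalties`, the choice of ids, and the decision to list both orders of each pair are our assumptions for illustration only, so consult `help(fused_gromov_wasserstein_parallel)` before relying on this layout.

```python
from itertools import product

SOMA, BASAL, APICAL = 1, 3, 4  # Standard SWC structure ids

def make_penalties(soma_dendrite, basal_apical):
    """Build a symmetric {(id_a, id_b): penalty} table (hypothetical helper)."""
    penalties = {}
    for a, b in product((SOMA, BASAL, APICAL), repeat=2):
        if a == b:
            cost = 0.0            # Matching node types are never penalized
        elif SOMA in (a, b):
            cost = soma_dendrite  # Soma paired with either kind of dendrite
        else:
            cost = basal_apical   # Basal dendrite paired with apical dendrite
        penalties[(a, b)] = cost
    return penalties

penalty_dictionary = make_penalties(soma_dendrite=1.0, basal_apical=0.5)
print(penalty_dictionary[(1, 3)], penalty_dictionary[(3, 4)])  # 1.0 0.5
```

A table like this extends naturally to additional structure ids: add the new id to the tuple passed to `product` and decide on its penalty against each existing type.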

In [None]:
from cajal.fused_gw_swc import fused_gromov_wasserstein_parallel

fused_gw_dmat = fused_gromov_wasserstein_parallel(
    intracell_csv_loc=join(bd,'geodesic_100_icdm.csv'),
    swc_node_types=join(bd,"geodesic_100_node_types.npy"),
    fgw_dist_csv_loc=join(bd,"geodesic_100_fgw.csv"),
    num_processes=14,
    soma_dendrite_penalty= 1., # The cost of aligning a soma node to a dendrite node is initialized to be 1.0, but it will be rescaled based on the value of `worst_case_gw_increase`
    basal_apical_penalty=0., # The data set we are using doesn't distinguish basal and apical dendrites, so this parameter has no effect
    # penalty_dictionary: Optional[dict[tuple[int, int], float]] = None,
    chunksize = 100,
    worst_case_gw_increase= 0.50, # We want the GW cost to go up by at most 50% for the median pair of cells in the data set.
)
fused_gw_dmat = fused_gw_dmat[hq, :][:, hq]

In [None]:
fgw_results = cross_val_score(clf, X=fused_gw_dmat, y=RNA_family[hq],cv=cv)
print("Accuracy:", fgw_results.sum()/fgw_results.shape[0])
cvp = cross_val_predict(clf, X=fused_gw_dmat, y=RNA_family[hq], cv=cv)
print("MCC: ", matthews_corrcoef(cvp, RNA_family[hq]))

Accuracy: 0.5698587127158555
MCC:  0.47469628255037005


So, by incorporating the node type information, which in this data set distinguishes soma nodes from dendrite nodes, Fused Gromov-Wasserstein outperforms classical Gromov-Wasserstein by a margin similar to that of Unbalanced Gromov-Wasserstein.