# `TCoClust`: basic usage and examples

This notebook showcases the main functionalities of `TCoClust`, a Python package for Trimmed Co-Clustering. The notebook itself is not intended to provide a detailed account of the package's content: hopefully, proper documentation will be coming soon. Note, however, that every function and class in the package is documented via a docstring.

## Getting started

Let's start by importing some common Python packages...

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

...and `TCoClust`

In [None]:
import TCoClust as tcc

In its current state, the package consists of the following main modules:
* `cell_tricc`: contains the function `cell_tbsem`, for cellwise-trimmed co-clustering via a cellwise-trimmed block Stochastic EM algorithm, along with ancillary functions (not of direct interest to the end-user);
* `roco_tricc`: contains the function `roco_tbcem`, which implements row- and column-wise trimmed co-clustering via a block Classification EM algorithm, plus ancillary functions as above;
* `plots`: contains functions for plotting co-clustering results, visualising outlier maps etc.;
* `model_selection`: provides functions to compute model selection criteria, such as tICL, and perform model selection leveraging multiprocessing capabilities;
* `metrics`: contains common metrics for the assessment of co-clustering results, such as the CARI (Co-clustering Adj. Rand Index);
* `poisson_utils`: contains utility functions specific to the Poisson case, including functions for simulating data from a Poisson LBM (e.g.: `simulate_poisson_lbm`);
* `normal_utils`: similar to the above, but for normal data;
* `utils`: contains general utility functions, not specific to a particular LBM.

In addition to these main modules, the following files are also included: `class_defs.py` (containing definitions of classes), a `config.py` and, of course, an `__init__.py`.

### Data generation

Synthetic data can be generated by some functions in our package, as shown below.

**Generating synthetic data**
<br>
We will show how to generate synthetic data from a Poisson LBM and add cellwise contamination, including both outlying and missing cells.
<br>
The first step consists of generating a clean data matrix from a Poisson LBM. The easiest way to do so is to use the function `simulate_poisson_lbm` from the `tcc.poisson_utils` module:

In [None]:
n = 100    # number of rows
p = 100    # number of columns
g = 4      # number of row groups
m = 3      # number of column groups

# simulate Poisson LBM
X, rowp, colp = tcc.poisson_utils.simulate_poisson_lbm(n, p, g, m, seed=0)  # returns data matrix and row and column partitions
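
For intuition, the generative model behind a Poisson LBM can be sketched in a few lines of NumPy: each row and each column receives a latent class, and cell $(i, j)$ is Poisson with the mean of block $(z_i, w_j)$. This is a minimal sketch under our own assumptions (the hypothetical helper `sketch_poisson_lbm`, uniform mixing proportions, block means drawn uniformly in $[1, 10]$); `simulate_poisson_lbm` may parametrise the model differently.

```python
import numpy as np

def sketch_poisson_lbm(n, p, g, m, seed=None):
    """Hypothetical sketch of a Poisson latent block model (LBM) simulator."""
    rng = np.random.default_rng(seed)
    # latent row and column labels, drawn with uniform mixing proportions (assumption)
    z = rng.integers(0, g, size=n)
    w = rng.integers(0, m, size=p)
    # one Poisson mean per block (k, l); the range [1, 10] is an arbitrary choice
    lam = rng.uniform(1, 10, size=(g, m))
    # cell (i, j) is Poisson with the mean of block (z[i], w[j])
    X = rng.poisson(lam[z][:, w])
    return X, z, w, lam

X_demo, z_demo, w_demo, lam_demo = sketch_poisson_lbm(100, 100, 4, 3, seed=0)
print(X_demo.shape)  # (100, 100)
```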

Now we can add cellwise contamination using the function `tcc.utils.generate_m`:

In [None]:
alpha_o = 0.02   # fraction of outlying cells
alpha_m = 0.02   # fraction of missing cells

# M0 encodes the contamination and missingness patterns:
#   * M0[i, j] = 1  --> cell X[i, j] chosen as outlier
#   * M0[i, j] = -1 --> cell X[i, j] chosen as missing
#   * M0[i, j] = 0  --> cell X[i, j] stays the same
M0 = tcc.utils.generate_m(n,
                          p,
                          alpha_o - 50/(n*p),  # leave room for the 50 structured outliers added below
                          alpha_m,
                          seed=0
                         )

# contaminate X:
X = X.astype(float)   # ensure a float dtype, so that NaN can be stored for missing cells
rng = np.random.default_rng(seed=0)
X[M0 == 1] = rng.poisson(12, np.sum(M0 == 1))   # cellwise contamination
X[[11, 12], 0:100:4] = rng.poisson(40, 50).reshape(2, 25)   # let's also add some more "structured" contamination to two rows
M0[[11, 12], 0:100:4] = 1
X[M0 == -1] = np.nan   # missing cells
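
The encoding of `M0` can be mimicked with plain NumPy. The sketch below (the hypothetical helper `sketch_mask`, not the package's `generate_m`) samples distinct cell indices without replacement, so that no cell is both outlying and missing. It also illustrates the bookkeeping in the cell above: subtracting $50/(np)$ from `alpha_o` leaves room for the 50 structured outliers added to rows 11 and 12, so that the overall outlying fraction stays close to `alpha_o`.

```python
import numpy as np

def sketch_mask(n, p, alpha_o, alpha_m, seed=None):
    """Hypothetical sketch: 1 = outlying cell, -1 = missing cell, 0 = regular cell."""
    rng = np.random.default_rng(seed)
    n_out = int(round(alpha_o * n * p))   # number of outlying cells
    n_mis = int(round(alpha_m * n * p))   # number of missing cells
    # sample distinct flat indices so that no cell is both outlying and missing
    idx = rng.choice(n * p, size=n_out + n_mis, replace=False)
    M = np.zeros(n * p, dtype=int)
    M[idx[:n_out]] = 1
    M[idx[n_out:]] = -1
    return M.reshape(n, p)

# 2% outliers minus the 50 structured ones, 2% missing cells
M_demo = sketch_mask(100, 100, 0.02 - 50 / (100 * 100), 0.02, seed=0)
print((M_demo == 1).sum(), (M_demo == -1).sum())  # 150 200
```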

Let's take a look:

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

vmax = 18
im = ax1.imshow(X,
                vmax=vmax,
                interpolation="none",
               )

# add a colorbar
cbar = plt.colorbar(im, ticks=np.linspace(0, vmax, 5))
cbar.ax.set_yticklabels([f"{tick:g}" for tick in np.linspace(0, vmax, 5)][:-1] + [f">{vmax}"])
ax1.set(xticks=[], yticks=[], title="Unsorted data matrix simulated from Poisson LBM")

im2 = ax2.imshow(M0, 
                 cmap="coolwarm",
                 interpolation="none"
                )
cbar2 = plt.colorbar(im2, ticks=[-1, 0, 1])
cbar2.ax.set_yticklabels(["missing", "regular", "outlying"])
ax2.set(xticks=[], yticks=[], title="Contamination and missingness patterns")

plt.show()

## Performing cellwise trimmed co-clustering

Now that we have some data, we can go on and apply our cell-TRICC method.

### Cellwise trimmed co-clustering with `cell_tbsem`

Let's apply the cellwise-trimmed method to the simulated data. The function we will use is `cell_tbsem` from module `cell_tricc`.
<br>
To get acquainted with it, we can read its docstring:

In [None]:
help(tcc.cell_tricc.cell_tbsem)

In the text above, the first lines are the function's signature, which includes type hints. A basic description of the function, as well as a more detailed list of the input parameters and of the output, is printed after the signature.
<br>
We call this function on the data stored in numpy ndarray `X`, with the following parameter choices:
* number of row groups $g=4$
* number of column groups $m=3$
* the density function to be used in the LBM is Poisson (at the moment this is the only option supported by this method, but extensions are on their way)
* trimming level (`alpha`): $\alpha=\alpha_o$ (matches the one used to simulate the data)
* number of initialisations: 10
* we require that the method be restarted until a valid solution is found by setting `until_converged = True`
* we set the seed for reproducibility

In [None]:
result = tcc.cell_tricc.cell_tbsem(X, 
                                   g, 
                                   m, 
                                   density="Poisson", 
                                   alpha=alpha_o, 
                                   n_init=10, 
                                   until_converged=True,
                                   seed=0,
                                  )
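
To give some intuition for `alpha`: trimmed procedures typically discard, at each step, the fraction $\alpha$ of cells that are least plausible under the current model. The sketch below illustrates this generic idea for a Poisson LBM with known labels and block means, flagging the $\lfloor \alpha np \rfloor$ cells with the lowest Poisson log-density. It is an assumption-laden illustration (the function `sketch_trim_cells` is hypothetical), not the actual update performed inside `cell_tbsem`, which works within a stochastic EM algorithm.

```python
import math
import numpy as np

def sketch_trim_cells(X, z, w, lam, alpha):
    """Hypothetical sketch: flag the floor(alpha * n * p) cells with the lowest
    Poisson log-density given row labels z, column labels w and block means lam.
    Returns a mask with 0 = flagged, 1 = kept (same convention as M in this notebook)."""
    mu = lam[z][:, w]                                   # block mean of each cell
    log_fact = np.vectorize(math.lgamma)(X + 1.0)       # log(x!)
    logdens = X * np.log(mu) - mu - log_fact            # Poisson log-pmf of each cell
    n_trim = int(np.floor(alpha * X.size))
    mask = np.ones(X.shape, dtype=int)
    if n_trim > 0:
        # flat indices of the n_trim smallest log-densities
        worst = np.argsort(logdens, axis=None)[:n_trim]
        mask[np.unravel_index(worst, X.shape)] = 0
    return mask

rng_s = np.random.default_rng(1)
z_s, w_s = rng_s.integers(0, 2, 20), rng_s.integers(0, 2, 10)
lam_s = np.array([[2.0, 5.0], [8.0, 1.0]])
X_s = rng_s.poisson(lam_s[z_s][:, w_s]).astype(float)
X_s[0, 0] = 60.0                                        # plant an obvious outlier
M_sketch = sketch_trim_cells(X_s, z_s, w_s, lam_s, alpha=0.01)
print(int((M_sketch == 0).sum()), M_sketch[0, 0])       # 2 0
```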

In [None]:
type(result)

The result is stored in the variable `result`, which is an object of type `TccResult`, i.e., it is an instance of the class `TccResult`, defined in module `TCoClust.class_defs`.

What's inside this object? Let's take a look:

In [None]:
help(result)

So, for instance, if we want to get the partitions recovered by the method (in the form of binary matrices) we will write:

In [None]:
Z, W = result.Partitions.values()

# let's take a look at the first 4 rows of the row partition matrix:
Z[:4, :]

Recall that the binary matrices $Z$ and $W$ define a row and a column partition, respectively, in the following way:

$$ Z_{ik} = \begin{cases} 1 & \quad \textrm{if row $i$ is in row class $k$} \\ 0 & \quad \textrm{else} \end{cases} $$

and analogously for $W$ and the columns.

If instead of these binary matrices we want partitions to be represented by labels, we can use the function `part_matrix_to_dict`, found in `tcc.utils`:

In [None]:
print(tcc.utils.part_matrix_to_dict(Z))

For instance, in this case, the dictionary printed above tells us that the first row (i.e., row `0`) is assigned to group 0, the second row (1) to group 1, and so on.
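
The correspondence between the binary matrix $Z$ and a vector of labels can be reproduced with plain NumPy: one-hot encoding in one direction, `argmax` in the other. A minimal sketch with made-up labels; the `{row index: label}` dictionary only mirrors the kind of output shown above and is not produced by the package here.

```python
import numpy as np

labels = np.array([0, 2, 1, 2, 0])                 # row i belongs to row class labels[i]
Z_demo = np.eye(3, dtype=int)[labels]              # one-hot rows: Z_demo[i, k] = 1 iff labels[i] == k
back = Z_demo.argmax(axis=1)                       # recover the labels from the binary matrix
as_dict = {i: int(k) for i, k in enumerate(back)}  # {row index: row class}
print(as_dict)  # {0: 0, 1: 2, 2: 1, 3: 2, 4: 0}
```

Each row of such a matrix sums to one, which is exactly what makes it a partition: every row belongs to one and only one class.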

The method `summary` can be used to print a summary of the solution found by the procedure:

In [None]:
result.summary()

One important attribute of the object returned by `cell_tbsem` is `M`, which contains two mask matrices: the first one accounts for missing cells, the second one for outlying (or 'flagged') cells. Here we are interested in the second mask matrix:

In [None]:
_, M = result.M
M

Recall that

$$ M_{ij} = \begin{cases} 0 & \quad \textrm{if cell $x_{ij}$ is flagged} \\ 1 & \quad \textrm{else}\end{cases} $$
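
Since the simulation stored the true contamination pattern in `M0` (1 = outlying), the flagged cells can be checked against it. The sketch below uses small made-up masks standing in for `M` and `M0`; with the notebook's arrays one would call `flagging_scores(M, M0)`. The helper name is ours, and missing cells (`M0 == -1`) are simply not counted as outliers.

```python
import numpy as np

def flagging_scores(M, M0):
    """Compare flagged cells (M == 0) with true outlying cells (M0 == 1)."""
    flagged = (M == 0)
    true_out = (M0 == 1)
    tp = np.sum(flagged & true_out)                # correctly flagged outliers
    precision = tp / max(flagged.sum(), 1)         # share of flagged cells that are true outliers
    recall = tp / max(true_out.sum(), 1)           # share of true outliers that were flagged
    return float(precision), float(recall)

M_toy = np.array([[1, 0, 1], [1, 1, 0]])     # 0 = flagged by the method
M0_toy = np.array([[0, 1, 0], [-1, 0, 0]])   # 1 = outlying, -1 = missing, 0 = regular
print(flagging_scores(M_toy, M0_toy))  # (0.5, 1.0)
```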

## Visualisation

### The `plot` method

Once the trimmed model is estimated, the corresponding co-clustering can be plotted using the `plot` method of the result object. The co-clustered matrix is shown by sorting the rows and columns of the original data matrix according to the estimated partitions, which are represented by vertical and horizontal lines. Flagged cells, if any, are highlighted as well:

In [None]:
result.plot()
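
Conceptually, the reordering behind this plot amounts to sorting rows and columns by their class labels and drawing separating lines where the class changes. A rough sketch with toy labels (our own helper `sort_by_partitions`, not the code of `TccResult.plot`):

```python
import numpy as np

def sort_by_partitions(X, z, w):
    """Reorder rows/columns of X by labels z, w; return the boundary positions for the lines."""
    r = np.argsort(z, kind="stable")     # row order grouped by row class
    c = np.argsort(w, kind="stable")     # column order grouped by column class
    X_sorted = X[np.ix_(r, c)]
    # positions where the class changes, i.e. where separating lines would be drawn
    row_breaks = np.flatnonzero(np.diff(z[r])) + 1
    col_breaks = np.flatnonzero(np.diff(w[c])) + 1
    return X_sorted, row_breaks, col_breaks

X_toy = np.arange(12).reshape(3, 4)
X_sorted, rb, cb = sort_by_partitions(X_toy, np.array([1, 0, 1]), np.array([1, 0, 0, 1]))
print(rb, cb)  # [1] [2]
```

With matplotlib, the lines could then be drawn with `ax.axhline(b - 0.5)` for each row break `b` and `ax.axvline(b - 0.5)` for each column break, since `imshow` centres pixels on integer coordinates.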

The `plot` method accepts a variety of keyword arguments, many of which are passed on to matplotlib's `imshow`, the function at the heart of `TccResult.plot()`. An example:

In [None]:
_, ax, _ = result.plot(cmap="Blues",
                       vmin=0,
                       vmax=15,
                       colorbar=True,
                       figsize=(10, 5),
                       return_graphics_objects=True,
                      )

ax.set(xlabel="columns", ylabel="rows")
plt.show()

### Diagnostic plots

Diagnostic plots can be used to visualise different aspects of the flagged cells. We proposed diagnostic plots based either on cell posteriors or on cell residuals. The functions `cell_posterior_plots` and `cell_residual_plots` from module `plots` produce these types of plots. The function `outlier_plots` automatically combines these different types of plots in a grid of subplots.
<br>
To build some of these plots, we need a non-robust fit of the LBM. This can be computed automatically inside the plotting functions, or can be computed once before plotting and passed to the plotting functions as an optional parameter, as we do in the following code blocks.

In [None]:
result0 = tcc.cell_tricc.cell_tbsem(X, 
                                    g, 
                                    m, 
                                    density="Poisson", 
                                    alpha=0,   # no trimming: plain (non-robust) LBM fit
                                    n_init=10, 
                                    until_converged=True,
                                    seed=0,
                                   )

In [None]:
fig = tcc.plots.outlier_plots(result, result0)
fig.show()
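
For the residual-based diagnostics, a natural quantity for Poisson data is the Pearson residual $(x_{ij} - \mu_{ij})/\sqrt{\mu_{ij}}$, where $\mu_{ij}$ is the fitted mean of the block containing cell $(i, j)$. Whether `cell_residual_plots` uses exactly this definition is an assumption on our part; the sketch below (hypothetical helper `pearson_residuals`) only shows how such residuals can be computed from labels and block means, leaving missing cells as NaN.

```python
import numpy as np

def pearson_residuals(X, z, w, lam):
    """Poisson Pearson residuals (x - mu) / sqrt(mu), with mu the block mean of each cell.
    Missing cells (NaN in X) stay NaN."""
    mu = lam[z][:, w]
    return (X - mu) / np.sqrt(mu)

X_toy = np.array([[4.0, np.nan], [1.0, 16.0]])
lam_toy = np.array([[4.0, 1.0], [1.0, 4.0]])
R = pearson_residuals(X_toy, np.array([0, 1]), np.array([0, 1]), lam_toy)
print(R)
```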

### Interactive treemap

We can explore the co-clustering results (including the derived block ranking) through an interactive treemap. The function `coclust_treemap` from the `plots` module simply takes as input a `TccResult` object and launches a Dash app. Row and column labels can be passed as additional parameters.

In [None]:
tcc.plots.coclust_treemap(result)

## Model selection and validation

### Model selection

For cell-TRICC, we have developed two criteria for selecting the numbers of groups and the trimming level $\alpha$: a _trimmed Integrated Completed Likelihood_ (tICL) criterion and its BIC-like approximation (tICL-BIC). Both criteria can also be used when no trimming is performed, in which case they reduce to the ICL and ICL-BIC criteria for the Poisson LBM.
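
For reference, in the untrimmed case the usual BIC-like approximation of the ICL for a latent block model with $g$ row classes, $m$ column classes, an $n \times p$ data matrix and one parameter per block (the Poisson mean) reads

$$ \textrm{ICL-BIC}(g, m) = \log f(X, \hat Z, \hat W; \hat\theta) - \frac{g-1}{2}\log n - \frac{m-1}{2}\log p - \frac{gm}{2}\log(np), $$

where the first term is the complete-data log-likelihood evaluated at the estimated partitions and parameters. The sketch below computes this quantity for hard labels, with the parameters set to their complete-data maximum likelihood estimates. It is our own illustration (the function `poisson_icl_bic` is hypothetical), it assumes no empty classes, no missing cells and no trimming, and it uses a "larger is better" convention that may differ from the package's.

```python
import math
import numpy as np

def poisson_icl_bic(X, z, w, g, m):
    """Sketch of ICL-BIC for an untrimmed Poisson LBM with hard row labels z and column labels w.
    Assumes every row and column class is non-empty and X has no missing cells."""
    n, p = X.shape
    row_counts = np.bincount(z, minlength=g)
    col_counts = np.bincount(w, minlength=m)
    pi, rho = row_counts / n, col_counts / p           # mixing proportions
    sums = np.zeros((g, m))
    np.add.at(sums, (z[:, None], w[None, :]), X)       # sum of x_ij within each block
    lam = sums / np.outer(row_counts, col_counts)      # block means (complete-data MLE)
    mu = lam[z][:, w]
    log_fact = np.vectorize(math.lgamma)(X + 1.0)      # log(x!)
    loglik = (np.log(pi[z]).sum() + np.log(rho[w]).sum()
              + (X * np.log(mu) - mu - log_fact).sum())
    penalty = ((g - 1) / 2 * math.log(n) + (m - 1) / 2 * math.log(p)
               + g * m / 2 * math.log(n * p))
    return loglik - penalty

# toy data with a clear 2 x 2 block structure
rng_icl = np.random.default_rng(0)
z_icl, w_icl = np.repeat([0, 1], 20), np.tile([0, 1], 20)
X_icl = rng_icl.poisson(np.array([[1.0, 20.0], [20.0, 1.0]])[z_icl][:, w_icl])
icl_true = poisson_icl_bic(X_icl, z_icl, w_icl, 2, 2)
icl_one = poisson_icl_bic(X_icl, np.zeros(40, dtype=int), np.zeros(40, dtype=int), 1, 1)
print(icl_true, icl_one)   # the true 2 x 2 structure should score higher
```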

These criteria can be computed on a grid to select $g$, $m$ and $\alpha$. This can be done automatically and in parallel thanks to the function `select_model` of the `model_selection` module:

In [None]:
alpha_o

In [None]:
# define a grid
G_grid = [3, 4, 5]                    # candidate numbers of row groups
M_grid = [2, 3, 4]                    # candidate numbers of column groups (named so as not to overwrite the mask M)
alpha_grid = [0, 0.01, 0.02, 0.03]    # candidate trimming levels

# keyword arguments for cell_tbsem (our fitting function)
kwargs = {
    "density": "Poisson", 
    "n_init": 30, 
    "until_converged": True,
    "seed": 0
}

# keyword arguments for poisson_ticl (tICL criterion)
# tau, a and b are the hyperparameters of the prior distributions used to compute the exact tICL criterion
kwargs_icl = {
    "tau": 4.,
    "a": 0.01,
    "b": 0.01,
    "beta": np.log(1 / 0.001 - 1)
}

best_model = tcc.model_selection.select_model(X,
                                              row_grid=G_grid,
                                              column_grid=M_grid,
                                              alpha_grid=alpha_grid,
                                              kwargs=kwargs,  # for the main fitting function
                                              kwargs_icl=kwargs_icl,  # for the tICL criterion
                                              n_jobs=4,  # number of parallel jobs
                                              verbose=True,  # to show progress
                                             )

Both criteria (tICL and tICL-BIC) selected $(g,m)=(4,3)$ and $\alpha = 2\%$, corresponding to the original partition sizes and contamination rate used to simulate the data.
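
Stripped of the parallelism and bookkeeping, a grid search of this kind boils down to scoring every $(g, m, \alpha)$ triple and keeping the best one. In the sketch below, `score` stands in for "fit with `cell_tbsem`, then compute the tICL"; the toy criterion is made up purely for illustration, and `grid_search` is our own helper, not part of the package.

```python
from itertools import product

def grid_search(score, row_grid, column_grid, alpha_grid):
    """Return the (g, m, alpha) triple with the highest score, plus all scores."""
    scores = {(g, m, a): score(g, m, a)
              for g, m, a in product(row_grid, column_grid, alpha_grid)}
    best = max(scores, key=scores.get)
    return best, scores

# made-up criterion peaking at (4, 3, 0.02), standing in for a tICL value
def toy_score(g, m, a):
    return -((g - 4) ** 2 + (m - 3) ** 2 + 1e4 * (a - 0.02) ** 2)

best, scores = grid_search(toy_score, [3, 4, 5], [2, 3, 4], [0, 0.01, 0.02, 0.03])
print(best)  # (4, 3, 0.02)
```

`select_model` spreads the fits over `n_jobs` parallel jobs; with the standard library, the dictionary comprehension could be replaced by a `concurrent.futures.ProcessPoolExecutor` mapping over the same grid.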

### Comparing co-clusterings with the Co-clustering ARI (CARI)

The Co-clustering Adjusted Rand Index (CARI) can be used to compare two co-clusterings; its interpretation is the same as that of the familiar ARI for clustering.
In our package, it can be computed using the function `cari` from module `metrics`.
For now, `cari` expects the four input partitions (true and estimated, row and column) to be expressed as lists of labels.

In [None]:
# transform the partitions' representations to lists
z_true, w_true = list(rowp.values()), list(colp.values())
z_est, w_est = list(tcc.utils.part_matrix_to_dict(Z).values()), list(tcc.utils.part_matrix_to_dict(W).values())

tcc.metrics.cari(z_true, w_true, z_est, w_est)
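
The CARI can be understood as the ordinary ARI computed on the $n \times p$ cells, where each cell is labelled by its block: the corresponding contingency table is the Kronecker product of the row and column contingency tables. The sketch below implements this definition directly (helper names are ours; labels are assumed to be non-negative integers). It is a cross-check under these assumptions, not the package's implementation of `cari`.

```python
import numpy as np

def contingency(a, b):
    """Contingency table between two label vectors (labels assumed to be 0, 1, 2, ...)."""
    T = np.zeros((max(a) + 1, max(b) + 1))
    np.add.at(T, (np.asarray(a), np.asarray(b)), 1)
    return T

def ari_from_table(T):
    """Adjusted Rand Index computed from a contingency table."""
    comb2 = lambda x: x * (x - 1) / 2                # number of pairs
    index = comb2(T).sum()
    sum_a, sum_b = comb2(T.sum(axis=1)).sum(), comb2(T.sum(axis=0)).sum()
    expected = sum_a * sum_b / comb2(T.sum())
    max_index = (sum_a + sum_b) / 2
    return (index - expected) / (max_index - expected)

def cari_sketch(z1, w1, z2, w2):
    """CARI as the ARI of the cell-level table, i.e. the Kronecker product of row and column tables."""
    return ari_from_table(np.kron(contingency(z1, z2), contingency(w1, w2)))

# same co-clustering up to a relabelling of the classes
print(cari_sketch([0, 0, 1, 1], [0, 1, 1], [1, 1, 0, 0], [2, 0, 0]))  # 1.0
```

When both co-clusterings consist of a single block the denominator vanishes and the index is undefined; a production implementation would need to handle that case.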