**Purpose:** See how YMCs in mutants differ from each other and from the wild type.

**Aims:**
- Import flavin signals from multiple strains in the same experiment (and thus same nutrient conditions).
   - Obvious dataset: Causton strains, because there are five strains.
- Process data: cut time series to duration of interest, detrend flavin signals.
- Featurise data: use `catch22`
- Use UMAP to visualise the relationship between the data.
   - Adjust hyperparameters as appropriate to help with visualisation.
   - Potentially use the labels themselves to perform supervised UMAP.  This will hopefully separate the classes while retaining some local and global structure.

**Paradigms:**
- Use `aliby` data structures, i.e. `pandas` `DataFrames` with multi-indexing.
- Use `postprocessor` processes for featurisation
- Use `scikit-learn` and `umap` routines.
- Ultimate goal: combine all the cells into a script that goes in `skeletons` (especially if `svm_sandbox.ipynb` and `cycle_alignment_sandbox.ipynb` share *many* cells with this one).

In [None]:
import PyQt5
%matplotlib qt

# Import data

In [None]:
import numpy as np
import pandas as pd
import csv

# PARAMETERS
# Alternative dataset:
#filename_prefix = './data/arin/Omero19979_'
filename_prefix = './data/arin/Omero20016_'
#

# Import flavin signals.  Zeros are turned into NaN straight after loading
# because of how the CSV is constructed.
signal_flavin = pd.read_csv(filename_prefix+'flavin.csv').replace(0, np.nan)

def convert_df_to_aliby(
    signal,
    strainlookup_df,
):
    """Convert a flat signal table into a multi-indexed DataFrame.

    Parameters
    ----------
    signal : pandas.DataFrame
        One row per cell.  Must contain the columns 'cellID', 'position' and
        'distfromcentre'; every other column is treated as a time point.
        The input frame is not modified.
    strainlookup_df : pandas.DataFrame
        Look-up table with 'position' and 'strain' columns.

    Returns
    -------
    pandas.DataFrame
        Time series indexed by a ('strain', 'cellID') MultiIndex.  The time
        point columns keep their original order and are relabelled 0..n-1.
    """
    # Look-up table for strains: position -> strain
    strainlookup_dict = dict(zip(strainlookup_df.position, strainlookup_df.strain))

    # Positions -> Strain (more informative).  replace() leaves positions
    # that are absent from the look-up table unchanged.
    signal = signal.replace({'position': strainlookup_dict})
    signal = signal.rename(columns={'position': 'strain'})
    signal = signal.drop(['distfromcentre'], axis=1)

    # Convert to multi-index dataframe.  The time points are selected by
    # name (everything that is not an index column) instead of by position,
    # so the result does not depend on the column order of the input.
    index_columns = ['strain', 'cellID']
    multiindex = pd.MultiIndex.from_frame(signal[index_columns])
    timepoints = signal.drop(columns=index_columns)

    return pd.DataFrame(timepoints.to_numpy(), index=multiindex)

# Look-up table from which convert_df_to_aliby reads the 'position' and
# 'strain' columns
strainlookup_df = pd.read_csv(filename_prefix+'strains.csv')
# Re-index the flavin signals by (strain, cellID)
signal_flavin = convert_df_to_aliby(signal_flavin, strainlookup_df)

# Processing time series

## Range

Chop up time series according to `interval_start` and `interval_end`, then remove cells that have NaNs.  Print number of cells of each strain.

In [None]:
# PARAMETERS
interval_start = 25
interval_end = 168
#

# Keep the columns at positions interval_start .. interval_end - 1, then
# discard every cell (row) that has a NaN inside that window.
signal_flavin_window = signal_flavin.iloc[:, slice(interval_start, interval_end)]
signal_flavin_processed = signal_flavin_window.dropna()

# Number of remaining cells per strain (strain is index level 0)
signal_flavin_processed.index.get_level_values(0).value_counts()

## Detrend

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# PARAMETERS
# Width of the moving-average window used for detrending
window = 45
#

# Heatmap of the cut time series, one row per cell
fig, ax = plt.subplots()
sns.heatmap(signal_flavin_processed, ax=ax)
ax.set_title('Before detrending')
plt.show()

def moving_average(input_timeseries,
                  window = 3):
    """Return the simple moving average of a 1-D time series.

    Parameters
    ----------
    input_timeseries : array-like
        Values to smooth, in time order.  The input is not modified.
    window : int
        Number of consecutive points averaged together; must be at least 1.

    Returns
    -------
    numpy.ndarray
        Float array in which element ``i`` is the mean of
        ``input_timeseries[i : i + window]``.  Its length is
        ``len(input_timeseries) - window + 1``, or 0 if the window is longer
        than the series.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1.
    """
    # A negative window would slice from the wrong end and silently return
    # meaningless values, so reject it up front.
    if window < 1:
        raise ValueError(f'window must be a positive integer, got {window}')
    running_total = np.cumsum(input_timeseries, dtype=float)
    # Difference of cumulative sums `window` apart = sum over each window
    running_total[window:] = running_total[window:] - running_total[:-window]
    return running_total[window - 1:] / window

# Normalise every cell's time series by its own mean
row_means = signal_flavin_processed.mean(axis=1)
signal_flavin_processed = signal_flavin_processed.div(row_means, axis=0)

# Trend: moving average of each row
signal_flavin_movavg = signal_flavin_processed.apply(
    lambda row: pd.Series(moving_average(row.values, window)),
    axis=1,
)

# Divide the central part of each series by its trend.  (-window) // 2 floors,
# so the slice always removes `window` columns in total, whereas the moving
# average is only `window - 1` columns shorter than the input; its last
# column is therefore dropped to make the shapes agree.
central_part = signal_flavin_processed.iloc[:, window // 2 : (-window) // 2]
trend = signal_flavin_movavg.iloc[:, 0 : signal_flavin_movavg.shape[1] - 1].values
signal_flavin_norm = central_part / trend

# Heatmap of the detrended time series, one row per cell
fig, ax = plt.subplots()
sns.heatmap(signal_flavin_norm, ax=ax)
ax.set_title('After detrending')
plt.show()

# Downstream cells work on the detrended signal
signal_flavin_processed = signal_flavin_norm

# Featurisation

Featurisation, using `catch22`

In [None]:
from postprocessor.core.processes.catch22 import catch22Parameters, catch22

# Build the catch22 post-processor with its default parameters
catch22_processor = catch22(catch22Parameters.default())
# presumably yields one row of features per cell — confirm against postprocessor
features = catch22_processor.run(signal_flavin_processed)

# Heatmap of the features before scaling
sns.heatmap(features)

Alternatively, use time points

In [None]:
# Use the processed time points themselves as features.  Running this cell
# overwrites the `features` produced by the catch22 cell.
features = signal_flavin_processed

# Heatmap of the features before scaling
sns.heatmap(features)

Normalise features

In [None]:
from sklearn.preprocessing import StandardScaler

# Standardise every feature column (StandardScaler removes the mean and
# scales to unit variance)
scaled_features = StandardScaler().fit_transform(features)

# Heatmap of the scaled features
sns.heatmap(scaled_features)

Scatterplot matrix of a few scaled features (columns 2–4 of `scaled_features`); there is probably space for `train.importance` around here.

In [None]:
# Strain of every cell (index level 0), as a column next to the features
strain_column = pd.Series(signal_flavin_processed.index.get_level_values(0))
# Scaled feature columns 2, 3 and 4
df = pd.DataFrame(scaled_features[:, 2:5]).assign(strain=strain_column)
sns.pairplot(df, hue='strain')

# UMAP

Label by strain

In [None]:
# Strain of every cell, in row order
strain_labels = signal_flavin_processed.index.get_level_values('strain')
# Distinct strains, in order of first appearance
strain_unique = strain_labels.unique().to_list()
# Strain name -> integer code (0, 1, 2, ...)
strain_map = {strain_name: code for code, strain_name in enumerate(strain_unique)}
strain_labels_numerical = [
    strain_map.get(strain_name, strain_name) for strain_name in strain_labels
]

Load custom labels (e.g. oscillation categories)

In [None]:
# PARAMETERS
filename_targets = 'categories_20016_detrend.csv'
#

# Header-less CSV: the first column is used as the index, the remaining
# column holds the category.
targets = pd.read_csv(filename_targets, header=None, index_col=0)
targets.index.name = 'cellID'
targets.columns = ['category']

# Category of every cell, in the row order of the processed signal.
# NOTE(review): the look-up uses cellID only and ignores the strain level,
# so it assumes each cellID occurs once in the CSV — confirm.
cell_ids = signal_flavin_processed.index.get_level_values('cellID')
customcat_labels = np.array([targets.loc[cell_id].item() for cell_id in cell_ids])
customcat_labels_numerical = customcat_labels

Combine strain and oscillation categories to produce colour keys:
- If there are n strains, those strains will have numerical labels of 1 to n and will correspond to n colours of the palette.
- Non-oscillating nodes from _any_ strain will have a numerical label of 0 and will correspond to grey.
- Defining it this way because `matplotlib` conveniently has a couple of qualitative colour maps that have grey as the last colour.  I use the reversed colour map so that grey comes first; intuitively it's easier to work with if 0 consistently corresponds to grey.

In [None]:
from matplotlib import cm
from matplotlib import colors

# pyplot.get_cmap is used below because matplotlib.cm.get_cmap was removed
# in matplotlib 3.9.
import matplotlib.pyplot as plt

# Strain names or 'non-oscillatory'
combined_labels = [
    strain if category == 1 else 'non-oscillatory'
    for strain, category in zip(strain_labels, customcat_labels)
]
# Numbers, as described above: 0 for non-oscillatory, strain code + 1 otherwise
combined_labels_numerical = [
    strain_code + 1 if category == 1 else 0
    for strain_code, category in zip(strain_labels_numerical, customcat_labels_numerical)
]
# Create a palette out of the colour map: one colour for 'non-oscillatory'
# plus one per strain
n_colours = len(strain_unique) + 1
palette_cm = plt.get_cmap('Set1_r', n_colours)
combined_labels_numerical_unique = np.unique(combined_labels_numerical)
# The palette is sized by the number of possible labels, not by the labels
# that happen to occur in the data; otherwise a missing label (e.g. no
# non-oscillatory cells) would leave the last strains without a colour.
# An integer argument indexes the resampled colour map directly.
palette_rgb = [
    colors.rgb2hex(palette_cm(index)[:3])
    for index in range(n_colours)
]
# Dict to map label to colour
palette_map = dict(zip(
    np.concatenate((['non-oscillatory'], strain_unique)).tolist(),
    palette_rgb
))

Optional: make non-oscillatory white and transparent

In [None]:
# 8-digit hex colour: white ('ffffff') followed by an alpha of '00'
palette_map['non-oscillatory'] = '#ffffff00'

Fit and plot

In [None]:
import umap
import umap.plot

# Fit
umap_settings = dict(
    random_state=0,
    n_neighbors=4,
    min_dist=0.02,
    n_components=2,
    metric='cosine',
)
reducer = umap.UMAP(**umap_settings)
mapper = reducer.fit(scaled_features)

# Plot, one colour per combined label
point_labels = np.array(combined_labels)
umap.plot.points(mapper, labels=point_labels, color_key=palette_map)

To do: add a way to mouse over points and see what the time series looks like

Vary hyperparameters

In [None]:
import itertools

# Wrap UMAP fitting and plotting into one function,
# taking matplotlib axis as an argument
def generate_umap(
    scaled_features,
    n_neighbors,
    min_dist,
    combined_labels,
    palette_map,
    ax = None,
):
    """Fit a two-component UMAP and draw it, coloured by label.

    Parameters
    ----------
    scaled_features
        Feature matrix passed to ``umap.UMAP.fit``.
    n_neighbors, min_dist
        UMAP hyperparameters, forwarded unchanged.
    combined_labels
        One label per row of ``scaled_features``; converted to a numpy array
        before plotting.
    palette_map
        Colour key handed to ``umap.plot.points``.
    ax
        Axes to draw on; defaults to the current matplotlib axes.

    Returns
    -------
    The object returned by ``umap.plot.points``.

    The metric is fixed to 'cosine' and the random seed to 0.
    """
    fitted = umap.UMAP(
        n_components=2,
        metric='cosine',
        random_state=0,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
    ).fit(scaled_features)

    # Fall back to the current axes only after fitting
    target_ax = plt.gca() if ax is None else ax
    return umap.plot.points(
        fitted,
        labels=np.array(combined_labels),
        color_key=palette_map,
        ax=target_ax,
    )

# Define values of hyperparameters to iterate over here:
hyperparams_to_iterate = {
    'n_neighbors' : [5,10],
    'min_dist' : [0.05, 0.10],
}

# Plot UMAPs in a grid: one row per n_neighbors, one column per min_dist.
# squeeze=False keeps `axs` two-dimensional even when one of the lists holds
# a single value; without it, axs[row, column] fails for such grids.
fig, axs = plt.subplots(
    len(hyperparams_to_iterate['n_neighbors']),
    len(hyperparams_to_iterate['min_dist']),
    squeeze = False,
)
# product() visits the grid in the same order as two nested loops would
for (n_neighbors_index, n_neighbors), (min_dist_index, min_dist) in itertools.product(
    enumerate(hyperparams_to_iterate['n_neighbors']),
    enumerate(hyperparams_to_iterate['min_dist']),
):
    axs[n_neighbors_index, min_dist_index] = generate_umap(
        scaled_features,
        n_neighbors,
        min_dist,
        combined_labels,
        palette_map,
        ax = axs[n_neighbors_index, min_dist_index],
    )
plt.show()