Aims:
- Label a subset of BY4741 time series as oscillating or non-oscillating (50 for now; can increase when needed).
- Use these labels to train an SVM on principal components (PCA), then compute mutual information (MI).
- See whether the classifier can tell oscillating and non-oscillating time series apart.

Specify the input file name and sampling period.

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Import data

In [None]:
import numpy as np
import pandas as pd
import csv

# PARAMETERS
filename_prefix = './data/arin/Omero19979_'
#filename_prefix = './data/arin/Omero20016_'
#

# Import flavin signals.
signal = pd.read_csv(filename_prefix + 'flavin.csv')

# Identifier/metadata columns; every other column is treated as a time point.
id_columns = ['position', 'cellID', 'distfromcentre']
timepoint_columns = [column for column in signal.columns if column not in id_columns]
# Zeros presumably mark missing time points in this CSV format, so convert
# them to NaN -- but only in the time-point columns, so that a legitimate 0
# in an identifier column (e.g. a cellID or position of 0) is left intact.
signal[timepoint_columns] = signal[timepoint_columns].replace(0, np.nan)

# Import look-up table mapping positions to strains (would prefer to directly CSV -> dict).
strainlookup_df = pd.read_csv(filename_prefix + 'strains.csv')
strainlookup_dict = dict(zip(strainlookup_df.position, strainlookup_df.strain))

# Positions -> strain names (more informative).
signal = signal.replace({'position': strainlookup_dict})
signal = signal.rename(columns={'position': 'strain'})
signal = signal.drop(columns=['distfromcentre'])

# Convert to a (strain, cellID) multi-indexed dataframe of time points.
# Column labels are discarded on purpose: downstream cells address time
# points positionally (0, 1, 2, ...).
multiindex = pd.MultiIndex.from_frame(signal[['strain', 'cellID']])
signal = pd.DataFrame(signal.drop(columns=['strain', 'cellID']).to_numpy(),
                      index=multiindex)

signal

# Choose a subset of cells (selected by strain) as working data

List the available strains

In [None]:
# Unique values of the first index level (strain), in order of first appearance.
signal.index.unique(level=0).to_list()

Define `signal_wd` as working data

In [None]:
# Strains to analyse; all rows of `signal` for these strains form the working data.
selected_strains = ['swe1_Del', 'tsa1_Del_tsa2_Del']
signal_wd = signal.loc[selected_strains]

signal_wd

# Processing time series

## Range

Truncate each time series to the time points from `interval_start` (inclusive) to `interval_end` (exclusive), then remove cells that contain any NaNs in that window.

In [None]:
# PARAMETERS
interval_start = 25
interval_end = 168
#

# Keep only time points in [interval_start, interval_end), then drop every
# cell (row) that still has a missing value inside that window.
signal_interval = signal_wd.iloc[:, interval_start:interval_end]
signal_processed = signal_interval.dropna(axis=0, how='any')

signal_processed

## Detrend

Detrend using a sliding window: normalise each time series by its mean, then divide it by its centred moving average.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# PARAMETERS
# Number of time points averaged by the moving-average detrending filter.
window = 45
#

# Heatmap of the (chopped) signal before detrending: rows are cells.
fig, ax = plt.subplots()
sns.heatmap(signal_processed, ax=ax)
plt.title('Before detrending')
plt.show()

def moving_average(input_timeseries,
                   window = 3):
    """Return the simple moving average of a 1-D series.

    Each output value is the mean of ``window`` consecutive input values,
    computed with a cumulative-sum trick. The result therefore has
    ``len(input_timeseries) - window + 1`` elements, or none when the series
    is shorter than the window.

    Parameters
    ----------
    input_timeseries : array-like
        1-D sequence of numbers.
    window : int, default 3
        Number of consecutive points averaged; must be at least 1.

    Returns
    -------
    numpy.ndarray
        Float array of window means.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1. Without this check, a window of 0
        fails with an opaque broadcasting error and a negative window
        silently returns meaningless values.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    cumulative = np.cumsum(input_timeseries, dtype=float)
    # Subtract the running total lagged by `window` steps, so that each entry
    # from index window-1 onwards holds the sum of exactly `window` values.
    cumulative[window:] = cumulative[window:] - cumulative[:-window]
    return cumulative[window - 1:] / window

# Normalise each cell's series by its own mean, then compute its moving average.
signal_processed = signal_processed.div(signal_processed.mean(axis=1), axis=0)
signal_movavg = signal_processed.apply(
    lambda row: pd.Series(moving_average(row.values, window)), axis=1
)

# The moving average is window - 1 points shorter than the series. Trim
# window//2 points from the start (and ceil(window/2) from the end) so each
# kept point lines up with the average centred on it (exactly centred for an
# odd window). The last moving-average column is dropped to match that length.
half_window = window // 2
n_kept = signal_movavg.shape[1] - 1
signal_norm = (
    signal_processed.iloc[:, half_window:half_window + n_kept]
    / signal_movavg.iloc[:, :n_kept].values
)

fig, ax = plt.subplots()
sns.heatmap(signal_norm)
plt.title('After detrending')
plt.show()

signal_processed = signal_norm

signal_processed

# Featurisation

Compute `catch22` features for each time series.

In [None]:
from postprocessor.core.processes.catch22 import catch22Parameters, catch22

# Build a catch22 processor with its default parameters and featurise every
# cell (row) of the detrended signal.
default_parameters = catch22Parameters.default()
catch22_processor = catch22(default_parameters)
features = catch22_processor.run(signal_processed)

sns.heatmap(features)

# Labels

Import oscillatory/non-oscillatory labels

In [None]:
# Oscillatory / non-oscillatory labels: the first CSV column (no header) is the
# cellID and is used as the index.
filename_targets = 'categories_19979_detrend.csv'
targets_df = pd.read_csv(filename_targets, header=None, index_col=0)
targets_df.index.name = 'cellID'

# Look the labels up in the same row order as `features`.
# NOTE(review): this lookup is by cellID only; if the same cellID occurs in
# more than one selected strain, every occurrence receives the same label.
feature_cell_ids = features.index.get_level_values('cellID')
targets = targets_df.loc[feature_cell_ids].to_numpy().ravel()

targets

In [None]:
import pandas as pd

# Frequency of each oscillation label among the featurised cells.
# Uses `targets` (loaded above); `Wlist` is not defined anywhere in this notebook.
freq_table = pd.Series(targets).value_counts()
print(freq_table)

Group the time series by oscillation category into arrays, as input for MI estimation

In [None]:
# One array per label category: rows are cells, columns are time points.
# `Wlist` is not defined in this notebook, so the groups are built from
# `signal_processed` and `targets` instead.
# NOTE(review): assumes `targets` is aligned row-for-row with
# `signal_processed` (it is looked up in `features` order, and catch22
# presumably preserves row order) -- TODO confirm.
timeseries_array = signal_processed.to_numpy()
if len(targets) != len(timeseries_array):
    raise ValueError(
        f"{len(targets)} labels for {len(timeseries_array)} time series; "
        "labels and time series are not aligned"
    )
# Use the label values actually present, so this works for both string and
# numeric labels.
categories = np.unique(targets)

mi_data = [timeseries_array[targets == category] for category in categories]

Alternatively, group the time series by strain:

In [None]:
# Strain label of every row of the processed signal.
strain_labels = signal_processed.index.get_level_values('strain')
# Strains in order of first appearance. This order is deterministic, unlike
# iterating over a set, which can reorder strings between sessions.
strains = strain_labels.unique().to_list()
#strains = ['FY4', 'CEN_PK_Mat_A']

# One array per strain: rows are cells, columns are time points.
# Built from `signal_processed` because `Wlist` is not defined in this notebook.
timeseries_array = signal_processed.to_numpy()
mi_data = [timeseries_array[strain_labels == strain] for strain in strains]

View and visualise the data arrays

In [None]:
# Display the list of per-group arrays that will be fed to the MI estimator.
mi_data

In [None]:
import seaborn as sns

# Shared colour limits across all groups so their heatmaps are comparable.
# NOTE(review): titles and legend labels are taken from `strains`, so they
# are only meaningful when `mi_data` was built per strain. Also, `center=0`
# assumes the data are centred on zero; ratio-detrended data sit near 1.
mi_data_min = min(np.min(group_array) for group_array in mi_data)
mi_data_max = max(np.max(group_array) for group_array in mi_data)

for group_index, group_array in enumerate(mi_data):
    fig, ax = plt.subplots()
    sns.heatmap(group_array, vmin=mi_data_min, vmax=mi_data_max,
                center=0, cmap='vlag', ax=ax)
    plt.title(strains[group_index])
    plt.show()

# Mean time series of each group on one set of axes.
fig, ax = plt.subplots()
for group_index, group_array in enumerate(mi_data):
    group_mean = group_array.mean(axis=0)
    plt.plot(group_mean, label=strains[group_index])
plt.legend()
plt.show()

Train the classifier with bootstrapping and estimate MI

In [None]:
from MIdecoding import estimateMI, plotMIovertime

# Estimate MI between group identity and time series (100 bootstraps, not over time).
results = estimateMI(mi_data, n_bootstraps=100, overtime=False, verbose=True)

In [None]:
# Display the MI estimation results.
results

Negative control: randomly reassign traces to groups and estimate MI again

In [None]:
# Negative control: shuffle traces across groups, then re-estimate MI.
# Traces are permuted rather than drawn independently at random, for three reasons:
#  - the number of groups always matches `mi_data`, regardless of whether it
#    was built per category or per strain (the old code used `len(strains)`);
#  - each random group has the same size as the corresponding real group, so
#    the class balance matches the real analysis;
#  - a random group is empty only if the real group is (with independent
#    draws, an empty group made np.vstack raise).
random_seed = None  # set to an int for a reproducible shuffle
rng = np.random.default_rng(random_seed)

stack_all = np.vstack(mi_data)
group_sizes = [group_array.shape[0] for group_array in mi_data]
shuffled = stack_all[rng.permutation(stack_all.shape[0])]
# Split the shuffled rows back into consecutive chunks of the original sizes.
mi_random = np.split(shuffled, np.cumsum(group_sizes)[:-1])

mi_random_min = min(np.min(group_array) for group_array in mi_random)
mi_random_max = max(np.max(group_array) for group_array in mi_random)
for group_index, group_array in enumerate(mi_random):
    fig, ax = plt.subplots()
    sns.heatmap(group_array, vmin=mi_random_min, vmax=mi_random_max, center=0, cmap='vlag')
    plt.title(group_index)
    plt.show()

fig, ax = plt.subplots()
for group_index, group_array in enumerate(mi_random):
    plt.plot(np.mean(group_array, axis=0), label=group_index)
plt.legend()
plt.show()

results = estimateMI(mi_random, verbose=True, overtime=False, n_bootstraps=1000)