In [None]:
import itertools
import json
import pickle
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import flowkit as fk
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler

import ipywidgets as widgets
from tqdm.notebook import tqdm

from utils import ask_directory
from pair_precision import (
    get_logicle_polygon_gates,
    get_pair_precisions
)


# Fixed seed for NumPy's global random number generator.
SEED = 7
np.random.seed(SEED)

In [None]:
# Create the "results" folder if it does not exist yet; later cells write the
# top-pairs CSV, the high-similarity mask and the gate precisions into it.
Path("./results").mkdir(exist_ok=True)

In [None]:
# functions to separate alive vs. others

# File-name pattern: "<name>_<id zero-padded to 8 digits>.tiff".
TIFF_NAME_FORMAT = "{name}_{id:08d}.tiff"

def get_id_from_tiff_path(tiff_path: Path) -> int:
    """Return the integer that follows the last underscore of the file stem."""
    _, _, id_text = tiff_path.stem.rpartition("_")
    return int(id_text)


def get_tiff_filtered_dataframes(dataframes: dict, sample_paths: dict):
    """Split every sample's rows into those with and without a TIFF image.

    Parameters
    ----------
    dataframes : dict
        Maps sample name -> path of an HDF file that holds the keys
        "df" and "logicles".
    sample_paths : dict
        Maps sample name -> path of the sample's FCS file. TIFF files are
        searched recursively inside the "images" folder next to that file.

    Returns
    -------
    dict
        Maps sample name -> [df_without_tiff, df_with_tiff, df_logicles].

    Notes
    -----
    The integer parsed from each TIFF file name is used as a row *position*
    in the sample's dataframe. Both halves of the split are positional, so
    they stay complementary even if the stored index is not 0..n-1.
    """
    results = {}
    for sample_name, hdf_path in tqdm(dataframes.items(), desc="separating samples with tiff"):
        df = pd.read_hdf(hdf_path, key="df")
        df_logicles = pd.read_hdf(hdf_path, key="logicles")

        image_path = sample_paths[sample_name].parent.joinpath("images")
        tiff_positions = np.asarray(
            [get_id_from_tiff_path(img_path) for img_path in image_path.glob("**/*.tiff")],
            dtype=np.intp,
        )

        df_with_tiff = df.iloc[tiff_positions]
        # Build the complement from the same positions instead of dropping by
        # index label, which would pick different rows on a non-default index.
        has_tiff = np.zeros(len(df), dtype=bool)
        has_tiff[tiff_positions] = True
        df_without_tiff = df[~has_tiff]

        print(
            f"{sample_name:<20}: Total: {len(df):<7,d} | With tiff: {len(df_with_tiff):<7,d} | "
            f"Without tiff: {len(df_without_tiff):,d}."
        )
        results[sample_name] = [df_without_tiff, df_with_tiff, df_logicles]

    return results

In [None]:
# Reuse the settings saved by the previous step when ./config.json is present;
# otherwise start from an empty config.
config_file = Path("./config.json")
config = json.loads(config_file.read_text()) if config_file.exists() else {}

config

### Set the raw data directory (*FCS* files):

In [None]:
# Prefer the directory stored in the config; ask the user only when it is absent.
raw_data_dir = config.get("raw_data_dir")
if raw_data_dir is None:
    raw_data_dir = ask_directory("Select your data directory")
raw_data_dir = Path(raw_data_dir)

print(raw_data_dir)

In [None]:
# Keep only the FCS files whose folder holds at least one entry matching "*images";
# samples are keyed by the FCS file stem.
fcs_files = raw_data_dir.glob("**/*.fcs")
sample_paths = {
    fcs.stem: fcs
    for fcs in fcs_files
    if any(True for _ in fcs.parent.glob("*images"))
}

print("Sample files with images:")
pprint(sample_paths)

### Select the Logicle transformed directory:

In [None]:
# Prefer the directory stored in the config; ask the user only when it is absent.
transformed_data_dir = config.get("logicle_data_dir")
if transformed_data_dir is None:
    transformed_data_dir = ask_directory("Select directory of the transformed data")
transformed_data_dir = Path(transformed_data_dir)

print(transformed_data_dir)

In [None]:
# Map each *.h5 file directly inside the transformed directory to its stem.
h5_files = transformed_data_dir.glob("*.h5")
logicle_samples = {h5_path.stem: h5_path for h5_path in h5_files}

print("Logicle transformed data files:")
pprint(logicle_samples)

### Separating alive samples vs. others for each species:

In [None]:
# For every sample: [rows without tiff, rows with tiff, logicle params].
species_dfs = get_tiff_filtered_dataframes(logicle_samples, sample_paths)

### Select the target species:

In [None]:
# One entry per loaded sample; the highlighted entry becomes the target species.
selector = widgets.Select(
    options=list(species_dfs),
    rows=14,
    description="Target species:",
)

display(selector)

In [None]:
# Read the species currently highlighted in the selector widget.
# NOTE(review): every later result depends on this manual selection.
target_species = selector.value
print(f"Your target species: {target_species}")

In [None]:
# Alive events of the target species: its rows that have a tiff image.
df_alive = species_dfs[target_species][1]

# Everything from the other species (rows without and with tiff) forms "others".
df_others = pd.concat([
    part
    for name, parts in species_dfs.items()
    if name != target_species
    for part in (parts[0], parts[1])
])


print(f"alive: {df_alive.shape}, others: {df_others.shape}")

In [None]:
# Free some memory: df_alive and df_others were already extracted from this dict.
# Cells that read species_dfs cannot be re-run after this point.
del species_dfs

### Normalize the data

In [None]:
# Fit on alive and others stacked row-wise so both classes share one scaling.
scaler = StandardScaler()

scaler.fit(np.concatenate(
    [df_alive.to_numpy(), df_others.to_numpy()],
    axis=0,
))

In [None]:
# Apply the jointly fitted scaler to each class separately.
alive_normed = scaler.transform(df_alive.to_numpy())
others_normed = scaler.transform(df_others.to_numpy())

### Get the Fisher Discriminant Ratio (FDR) for each channel:
(For two classes alive vs. others)

In [None]:
# Per-channel statistics of each class. Both the mean and the variance are
# reduced over rows (axis=0) so that every channel gets its own value; without
# the axis argument .var() collapses the whole array into a single scalar.
alive_mu = alive_normed.mean(axis=0)
others_mu = others_normed.mean(axis=0)

alive_vars = alive_normed.var(axis=0)
others_vars = others_normed.var(axis=0)

In [None]:
# Fisher ratio: squared difference of the class means over the sum of the class variances.
squared_mean_gap = np.power(alive_mu - others_mu, 2)
fishers = squared_mean_gap / (alive_vars + others_vars)
fishers.shape

In [None]:
# One row per channel, best-separating channel first.
df_fisher = (
    pd.DataFrame({"channel": df_alive.columns, "fisher": fishers})
    .sort_values(by="fisher", ascending=False)
    .reset_index(drop=True)
)

df_fisher

### Get fisher score for each possible pair (average of two channels in a pair)

In [None]:
# Every unordered pair of channels, as an array with one (channel_1, channel_2) row per pair.
all_pairs_cols = np.array(
    list(itertools.combinations(df_alive.columns.to_list(), 2))
)

pprint(all_pairs_cols)

In [None]:
# Build the channel -> Fisher score lookup once, instead of filtering df_fisher
# twice for every pair. If a channel name occurs more than once, its first row
# in df_fisher is the one used.
fisher_by_channel = (
    df_fisher.drop_duplicates("channel", keep="first")
    .set_index("channel")["fisher"]
    .to_dict()
)

pair_fishers = np.zeros(len(all_pairs_cols))

for i, pair in tqdm(
    enumerate(all_pairs_cols), total=len(all_pairs_cols), desc="Getting pairs' fishers"
):
    # a pair's score is the plain average of its two channels' scores
    pair_fishers[i] = (fisher_by_channel[pair[0]] + fisher_by_channel[pair[1]]) / 2

In [None]:
# One row per channel pair, rounded to 4 decimals and sorted best-first.
df_pair_fishers = (
    pd.DataFrame({
        "channel_1": all_pairs_cols[:, 0],
        "channel_2": all_pairs_cols[:, 1],
        "fisher_avg": pair_fishers,
    })
    .round(4)
    .sort_values("fisher_avg", ascending=False)
    .reset_index(drop=True)
)
df_pair_fishers

### Top Pairs: 

#### 1. Select pairs with fisher score above the average

In [None]:
# Mean pair score, rounded to 4 decimals like the fisher_avg column itself.
fisher_mean = df_pair_fishers["fisher_avg"].mean().round(4)
print(fisher_mean)

In [None]:
# Keep the pairs whose score is strictly above the mean pair score.
mask = df_pair_fishers["fisher_avg"].gt(fisher_mean)
df_pair_fisher_above_mean = df_pair_fishers.loc[mask].reset_index(drop=True)

df_pair_fisher_above_mean

#### 2. For each channel selected in the previous step, keep the pairs whose Fisher score is at or above that channel's median (50th percentile):

In [None]:
def get_top_pairs(channel_list, df_fisher):
    """Select, per channel, the pairs scoring at or above that channel's median.

    Parameters
    ----------
    channel_list : iterable
        Channel names to look up in the "channel_1" column.
    df_fisher : pandas.DataFrame
        Pair table with the columns "channel_1", "channel_2" and "fisher_avg".

    Returns
    -------
    pandas.DataFrame
        The selected rows, grouped in ``channel_list`` order and sorted by
        "fisher_avg" (descending) within each channel, with a fresh index.
        If no channel has any pair, an empty frame with the columns of
        ``df_fisher`` is returned.
    """
    best_matches = []
    for channel in tqdm(channel_list, desc="Proposing best matches"):
        channel_pairs = df_fisher[df_fisher["channel_1"] == channel]
        if channel_pairs.empty:
            # channel never appears as the first element of a pair
            continue
        channel_pairs = channel_pairs.sort_values("fisher_avg", ascending=False)
        ch_threshold = np.quantile(channel_pairs["fisher_avg"], 0.5)
        best_matches.append(channel_pairs[channel_pairs["fisher_avg"] >= ch_threshold])

    if not best_matches:
        # pd.concat raises ValueError on an empty list; return an empty table instead
        return df_fisher.iloc[0:0].reset_index(drop=True)

    return pd.concat(best_matches).reset_index(drop=True)

In [None]:
# Run the per-channel median selection on the above-mean pairs only.
df_top_pairs = get_top_pairs(df_alive.columns.to_list(), df_pair_fisher_above_mean)
df_top_pairs

In [None]:
# save the top pairs (one CSV per target species, without the index column)
df_top_pairs.to_csv(f"./results/{target_species}_top_pairs.csv", index=False)

### Get the average vector of the alive samples (with images) for the target species.

In [None]:
# get the target species data in logicle space
# (all rows stored under the "df" key, with and without tiff)
df_positive_logicle = pd.read_hdf(logicle_samples[target_species], key="df")
df_positive_logicle.shape

In [None]:
# get the channels' logicle params (stored under the "logicles" key of the same file)
df_logicle_params = pd.read_hdf(logicle_samples[target_species], key="logicles")
df_logicle_params.shape

In [None]:
# alive (with tiffs) samples average vector: column-wise mean of df_alive
target_alive_vector = df_alive.mean().to_numpy()
# np.save(f"./results/{target_species}_alive_vector.npy", target_alive_vector)
target_alive_vector.shape

### Calculate the similarity between each sample and the alive average vector.

In [None]:
# Cosine similarity between every data row and the alive average vector:
# dot product divided by the product of the two norms.
data = df_positive_logicle.to_numpy()
sim_mat = data.dot(target_alive_vector) / (
    np.linalg.norm(data, axis=1) *
    np.linalg.norm(target_alive_vector)
)

# Weight each similarity by its distance from the mean similarity
# to make a better distribution of the similarities.
weights = sim_mat - sim_mat.mean()
sim_mat = weights * sim_mat

# Min-max scale the weighted similarities to the range [0, 1].
sim_low, sim_high = sim_mat.min(), sim_mat.max()
sim_mat = (sim_mat - sim_low) / (sim_high - sim_low)

### Set the threshold of the similarity:
##### A higher threshold produces a tighter gate with purer samples, but fewer of them.

In [None]:
# Slider for the similarity threshold: range [0, 1] in steps of 0.01, starting at 0.85.
# continuous_update=False: the value is only updated when the handle is released.
threshold_slider = widgets.FloatSlider(
    value=0.85,
    min=0,
    max=1.0,
    step=0.01,
    description="Similarity Threshold:",
    disabled=False,
    continuous_update=False,
    orientation="horizontal",
    readout=True,
    readout_format='.2f',
)
display(threshold_slider)

In [None]:
# Read the value currently set on the slider.
similarity_threshold = threshold_slider.value
print(similarity_threshold)

# Boolean mask of the rows whose scaled similarity is strictly above the
# threshold; it is also saved to the results folder.
high_sim_mask = sim_mat > similarity_threshold
np.save(f"./results/{target_species}_high_sim_mask.npy", high_sim_mask)
print(f"Alive-Similar data rows: {high_sim_mask.sum():,d}")

#### Visualize the threshold effect on a sample pair

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(7, 6.5))

# use the first listed top pair as the example
ch1 = df_top_pairs.loc[0, "channel_1"]
ch2 = df_top_pairs.loc[0, "channel_2"]

# other species' events
ax.scatter(
    df_others[ch2].to_numpy(), df_others[ch1].to_numpy(),
    color="salmon", s=7, alpha=0.5, label="others"
)
# target-species events whose similarity is above the chosen threshold
ax.scatter(
    df_positive_logicle[ch2].to_numpy()[high_sim_mask], df_positive_logicle[ch1].to_numpy()[high_sim_mask],
    color="limegreen", s=9, alpha=0.5, label=f"{target_species} (above threshold)"
)

# title and legend so the figure can be read on its own
ax.set_title(f"{target_species}: similarity > {similarity_threshold:.2f}")
ax.set_xlabel(ch2)
ax.set_ylabel(ch1)
ax.legend()

plt.show()

### Create the polygon gates for the selected top pairs:

In [None]:
# Build the gates from the top pairs and the high-similarity rows of the target data.
# NOTE(review): presumably one gate per row of df_top_pairs — confirm against
# pair_precision.get_logicle_polygon_gates.
logicle_polygon_gates = get_logicle_polygon_gates(df_top_pairs, df_positive_logicle, high_sim_mask)

len(logicle_polygon_gates)

In [None]:
# saving gates
# Build a "channel_1|channel_2" label for every top pair.
pairs = [
    f"{first}|{second}"
    for first, second in zip(df_top_pairs["channel_1"], df_top_pairs["channel_2"])
]

# with open(f"./results/{target_species}_top_gates.bin", mode="wb") as f:
#     pickle.dump({
#         pair:gate for pair, gate in zip(pairs, logicle_polygon_gates)
#     }, f)

In [None]:
# to free up some memory
# (cells that read these names cannot be re-run after this point)

del df_others
del df_positive_logicle
del weights
del sim_mat
del data

### Calculate the precision for each pair/gate

In [None]:
# Load the raw (untransformed) events: the target species is the positive
# class, every other sample is concatenated into the negative class.
df_positive = fk.Sample(sample_paths[target_species]).as_dataframe(source="raw")

df_negative = pd.concat([
    fk.Sample(sample_paths[sample_name]).as_dataframe(source="raw")
    for sample_name in logicle_samples
    if sample_name != target_species
])

# !important: drop second level columns (pns)
df_positive, df_negative = (
    raw_df.droplevel(level=1, axis=1) for raw_df in (df_positive, df_negative)
)

print(df_positive.shape, df_negative.shape)

In [None]:
# (Takes time!)
# NOTE(review): presumably scores every gate on the raw positive vs. negative
# events — confirm against pair_precision.get_pair_precisions.
df_precisions = get_pair_precisions(
    df_top_pairs, logicle_polygon_gates, df_logicle_params,
    df_positive, df_negative
)

In [None]:
# Optional: uncomment to list the gates from highest to lowest precision.
# df_precisions = df_precisions.sort_values("precision", ascending=False).reset_index(drop=True)
df_precisions

In [None]:
# Save the gate precisions (one CSV per target species, without the index column).
df_precisions.to_csv(f"./results/{target_species}_gate_precisions.csv", index=False)

In [None]:
# Scatter every gate's precision against its pair's average Fisher ratio.
fig, ax = plt.subplots(1, 1, figsize=(7.5, 6.5))
ax.scatter(
    x=df_precisions["precision"],
    y=df_precisions["fisher_avg"],
    color="dodgerblue",
    s=11,
    lw=0,
)
ax.set(title=target_species, xlabel="Precision", ylabel="Fisher Ratio")
ax.grid(alpha=0.3)

plt.show()