In [None]:
import pandas as pd 
import numpy as np
from scipy import stats

import matplotlib.pyplot as plt
%matplotlib inline

import mahalanobis as maha

In [None]:
# Load the screen and expose the compound identifier column as 'batch_id'
df = pd.read_parquet('data/covid_screen.parquet').rename(columns={"cbkid": "batch_id"})

# Reorder the columns: move the 5 trailing metadata columns next to the 2 leading ones,
# so the first 7 columns are metadata and everything after them is morphology
columns = list(df.columns)
new_columns = [*columns[:2], *columns[-5:], *columns[2:-5]]
df = df[new_columns]

In [None]:
# Singular-value spectrum of the healthy (uninfected) morphology features, with
# candidate cut-offs for how many PCA components the distance should keep.
df_healthy = df[df['batch_id'] == 'Uninfected']
x_healthy = df_healthy.iloc[:, 7:].to_numpy()

sigma, rank = maha.get_singular_values(x_healthy)
plt.scatter(range(rank), sigma[:rank])

# Candidate cut-offs use their own names so this exploratory cell never overwrites the
# notebook-level `n_components` that the prediction and plotting cells read.
n_fixed = 5
plt.axhline(sigma[n_fixed - 1], c='brown', label=str(n_fixed))

# Keep every component whose singular value exceeds 5
n_above = np.count_nonzero(sigma > 5)
if n_above > 0:  # with zero components, sigma[-1] would silently mark the smallest value
    plt.axhline(sigma[n_above - 1], c='chocolate', label=str(n_above))

# optimal threshold for PCA of Healthy cells
# NOTE(review): the label treats int(tau) as a component count, but the line is drawn at
# sigma[int(tau)] rather than sigma[int(tau) - 1] as for the cut-offs above. Confirm what
# maha.svht returns (a rank or a singular-value threshold).
tau = maha.svht(x_healthy)
plt.axhline(sigma[int(tau)], c='k', label=f'{int(tau)} (svht)')

plt.grid()
plt.yscale('log')
plt.legend(fontsize=12)
plt.xlabel('number of components', fontsize=14)
plt.ylabel('singular value', fontsize=14)
plt.title("Threshold for PCA on Healthy cells")
plt.show()

In [None]:
# Morphology features of every well, with extreme values bounded to [-clip, clip].
# Set clip = float('inf') to disable clipping.
clip = 50
x_all = df.iloc[:, 7:].clip(lower=-clip, upper=clip)

In [None]:
# Predictions based on whitening transform
n_components = 5  # number of PCA components kept by the Mahalanobis distance

# Fit the healthy reference on the same (clipped) features that are scored below, so
# reference rows and query rows go through identical preprocessing. x_all is a column
# slice of df and shares its index, so the boolean mask aligns row-for-row.
is_healthy = df['batch_id'] == 'Uninfected'
df_healthy = df[is_healthy]
x_healthy = x_all[is_healthy].to_numpy()

# Distance to Healthy (Mahalanobis)
# assumes maha.get_distance / get_proba return one value per row of x_all. TODO: confirm.
dist_pred = maha.get_distance(x_healthy, x_all, n_components=n_components)
# Give the result df's index so later column-wise concatenation lines up with df rows
dist_pred = pd.DataFrame(dist_pred, columns=['distance'], index=df.index)

# Probability of Healthy
proba_pred = maha.get_proba(x_healthy, x_all, n_components=n_components)
proba_pred = pd.DataFrame(proba_pred, columns=['proba'], index=df.index)

In [None]:
# Attach the distance/probability predictions to the 7 metadata columns.
# concat(axis=1) aligns on index labels, so the prediction frames are relabelled with
# df's index. Otherwise a non-default index stored in the parquet file would misalign
# rows and fill the result with NaNs.
predictions = pd.concat(
    [df.iloc[:, :7], dist_pred.set_axis(df.index), proba_pred.set_axis(df.index)],
    axis=1,
)

In [None]:
def quantile(n):
    """Build a named aggregation callable that returns the ``n``-th quantile.

    pandas labels the output of a ``groupby().agg`` function by its ``__name__``,
    so the returned closure is renamed to ``q_<percent>`` (``quantile(.25)`` -> ``q_25``).

    Parameters
    ----------
    n : float
        Quantile level, forwarded to ``Series.quantile``.

    Returns
    -------
    callable
        ``f(x) -> x.quantile(n)`` whose ``__name__`` is ``f"q_{n*100:.0f}"``.
    """
    level = n

    def _q(series):
        return series.quantile(level)

    _q.__name__ = f'q_{level * 100:.0f}'
    return _q

selected_columns = predictions[[
    'batch_id',
    'Count_nuclei',
    'distance',
    # 'proba',  # optional: uncomment together with the 'proba' aggregation below
    'name',
]].copy()

# Add concentration column (dummy, because the screen only has single-dose measurements)
selected_columns['conc'] = 10.

# Aggregate replicates: median nuclei count, plus median and quartiles of the distance
aggregated = selected_columns.groupby(by=['batch_id', 'name', 'conc'], as_index=False).agg({
    'Count_nuclei': 'median',
    'distance': [quantile(.25), quantile(.75), 'median'],
    # 'proba': 'median',
})

# Flatten the (column, aggfunc) MultiIndex in its existing order. Group keys have an empty
# second level and keep their name; aggregated columns become "<column>_<aggfunc>".
aggregated.columns = [
    f'{col}_{func}' if func else col
    for col, func in aggregated.columns.to_flat_index()
]

# Sort by median distance to non-infected (closest first)
average_predictions = (
    aggregated
    .rename(columns={'Count_nuclei_median': 'count_nuclei', 'distance_median': 'distance'})
    .sort_values(by=['distance'], ascending=True)
    .reset_index(drop=True)
)

# Quantiles are written as plus/minus offsets to the median (error-bar lengths):
# distance_q_75 is the distance above the median, distance_q_25 the distance below it.
average_predictions['distance_q_75'] = average_predictions['distance_q_75'] - average_predictions['distance']
average_predictions['distance_q_25'] = average_predictions['distance'] - average_predictions['distance_q_25']

average_predictions.to_csv("output/Covid_screen_ranked.csv", sep=';')
average_predictions.head(20)

In [None]:
import plotly.express as px
import plotly.graph_objects as go


def plot_compounds_interactive(compounds, dmso=None, dof=None):
    """Plot each compound's distance to the non-infected distribution against concentration.

    Draws one line per compound, excluding the 'Uninfected' and 'DMSO' controls. Each line
    has asymmetric interquartile error bars. Shaded royal-blue bands show the square roots
    of chi-square intervals (50/75/95 %) for healthy cells. An optional grey band shows the
    DMSO interquartile range.

    Parameters
    ----------
    compounds : pandas.DataFrame
        Aggregated predictions with columns 'batch_id', 'name', 'conc', 'distance',
        'distance_q_25', 'distance_q_75' and 'count_nuclei'. The quartile columns are
        offsets below (q_25) and above (q_75) the median distance.
    dmso : pandas.DataFrame, optional
        DMSO row(s) in the same format; only the first row is used. Skipped when None
        or empty.
    dof : int, optional
        Degrees of freedom of the chi-square intervals. Defaults to the notebook-level
        ``n_components`` (the number of components used for the distance).
    """
    if dof is None:
        dof = n_components  # notebook-level global set in the prediction cell

    # Legend order follows the row order of `compounds` (names deduplicated)
    drug_names = {
        id_: name_ for (id_, name_) in zip(compounds['batch_id'], compounds['name'])
    }
    name_order = list(dict.fromkeys(drug_names.values()))

    drugs = compounds[~compounds['batch_id'].isin(['Uninfected', 'DMSO'])]
    fig = px.line(
        drugs.sort_values(by="conc"), x="conc", y="distance",
        color='name', markers=True,
        width=1200, height=400,
        # asymmetric IQR bars: q_75 offset above the median, q_25 offset below it
        error_y='distance_q_75', error_y_minus='distance_q_25',
        category_orders={"name": name_order},
        hover_name="name",
        hover_data={"name": False, "count_nuclei": ':.1f', "conc": True, "distance": ':.3f'},
    )

    # Healthy cells - confidence intervals (sqrt of chi-square bounds; wider = fainter)
    for i, conf in enumerate([0.5, 0.75, 0.95]):
        label = dict(
            text="Non-infected", font=dict(size=15, color="black"), textposition="middle left"
        ) if i == 0 else None  # label only the first band

        left, right = stats.chi2.interval(confidence=conf, df=dof)
        fig.add_hrect(
            type="rect",
            y0=np.sqrt(left), y1=np.sqrt(right),
            fillcolor="royalblue", opacity=0.8*(1-conf)**0.5,
            layer="below", line_width=0, label=label
        )

    # DMSO: band from (median - q_25 offset) to (median + q_75 offset)
    if dmso is not None and not dmso.empty:
        reference = dmso.iloc[0]
        median = reference['distance']
        perc_25 = reference['distance_q_25']
        perc_75 = reference['distance_q_75']
        fig.add_hrect(
            type="rect",
            y0=median-perc_25, y1=median+perc_75,
            fillcolor="grey", opacity=0.25,
            layer="below", line_width=0,
            label=dict(
                text="DMSO", font=dict(size=15, color="black"), textposition="middle left"
            )
        )

    conc = np.unique(compounds['conc'])
    fig.update_xaxes(title_text="concentration", gridcolor='lightgrey', type="log", tickvals=conc)
    fig.update_yaxes(title_text="distance", gridcolor='lightgrey', type="log", tickvals=[1, 5, 10, 40, 80, 100])
    fig.update_layout(
        title_text="Distance to non-infected cell distribution",
        plot_bgcolor='white'
    )

    fig.show()


In [None]:
# Distribution of the per-compound median distance to non-infected cells
fig, ax = plt.subplots()
ax.hist(average_predictions['distance'])
ax.set(
    title="Histogram of distances",
    xlabel="median distance to non-infected",
    ylabel="count",
)
plt.show()

In [None]:
# Plot every aggregated compound, drawing the DMSO control as a reference band
compounds = average_predictions.copy()
is_dmso = compounds['batch_id'] == 'DMSO'
dmso = compounds.loc[is_dmso]

plot_compounds_interactive(compounds, dmso)

### Primary hits

In [None]:
# Read the type of drug: primary hit, secondary hit, or analog
subset = (
    pd.read_csv("data/metadata_annotations.csv", usecols=["CBCSbatch_id", "arm"])
    .rename(columns={"CBCSbatch_id": "batch_id", "arm": "type"})
)

# id_to_name.txt: one "<batch_id>\t<name>" pair per line
id_to_name = {}
with open("data/id_to_name.txt", 'r') as file:
    for line in file:
        line = line.rstrip()
        if not line:  # tolerate blank lines, e.g. an empty line at the end of the file
            continue
        # split on the first tab only, so a tab inside the name cannot break the unpacking
        id_, name = line.split('\t', 1)
        id_to_name[id_] = name

subset['name'] = subset['batch_id'].map(id_to_name)

In [None]:
# Extract the compounds annotated as primary hits and plot them against DMSO
is_primary = subset['type'] == 'primary_hit'
primary_hit_name = subset.loc[is_primary, 'name'].to_list()
primary_hits = compounds.loc[compounds['name'].isin(primary_hit_name)].copy()

plot_compounds_interactive(primary_hits, dmso)

### Secondary hits

In [None]:
# Extract the compounds annotated as secondary Cell Painting hits and plot them against DMSO
is_secondary = subset['type'] == 'secondary_cp_hit'
secondary_hit_name = subset.loc[is_secondary, 'name'].to_list()
secondary_hits = compounds.loc[compounds['name'].isin(secondary_hit_name)].copy()

plot_compounds_interactive(secondary_hits, dmso)

Save `secondary_hits` to CSV, together with the `DMSO` row as a reference.

In [None]:
# Secondary hits followed by the DMSO reference row(s); restore the original 'cbkid' column name
secondary_hits_to_save = (
    pd.concat([secondary_hits, dmso], ignore_index=True)
    .rename(columns={'batch_id': 'cbkid'})
)
secondary_hits_to_save.to_csv("output/secondary_hits_screen.csv", sep=';', index=False)