In [14]:
import xgboost
import shap 

import numpy as np
from cell_paint_seg.utils import get_id_to_path, check_valid_labels, threat_score
from cell_paint_seg.image_io import read_ims, read_seg
from skimage import io, exposure, measure, transform
from skimage.measure import label, regionprops
from pathlib import Path
import napari
from tqdm.notebook import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import umap
import umap.plot
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score 
from PIL import Image
import shutil

# Prepare dataset

In [None]:
# Stack each sample's channel images into one float32 array (channel axis first).
# NOTE: `id_from_name_dataset` is defined in the "Universal" section further down,
# so that cell has to be run before this one.
# The directory and loop variables are deliberately not called `dir` / `id`, which
# would hide the Python builtins for the rest of the kernel session.
processed_dir = "/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/processed"
id_to_path = get_id_to_path(processed_dir, tag='.tif', id_from_name=id_from_name_dataset)

for sample_id, paths in id_to_path.items():
    ims = read_ims(paths)
    im = np.stack(ims, axis=0).astype(np.float32)

    # Export is currently disabled; ids containing 'e4' went to test/, the rest to train/.
    # if 'e4' in sample_id:
    #     np.save(f'/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/als_v_control_2/test/{sample_id}.npy', im)
    # else:
    #     np.save(f'/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/als_v_control_2/train/{sample_id}.npy', im)


## Single neuron

In [3]:
# Directory holding the single-cell patches as `.npy` files; later cells glob it for '*.npy'.
path_patch_dataset = Path("/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/patches/")

In [None]:
def id_from_name_single(name):
    """Return the first three "_"-separated fields of `name`, joined by "_".

    A name with fewer than three fields is returned unchanged.
    """
    leading_fields = name.split("_", 3)[:3]
    return "_".join(leading_fields)

# Index the segmentation masks and the processed images by the same sample id.
seg_dir = "/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/tifs_3channel"
id_to_path_seg = get_id_to_path(seg_dir, tag='_masks.png', id_from_name=id_from_name_single)

im_dir = "/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/processed"
id_to_path_im = get_id_to_path(im_dir, tag='.tif', id_from_name=id_from_name_single)

print(len(id_to_path_seg))

# Half the side length of the square window cut around each region centroid (64 px total).
half = 32

for sample_id, path_seg in tqdm(id_to_path_seg.items()):
    paths_im = id_to_path_im[sample_id]
    seg = read_seg(path_seg)
    ims = read_ims(paths_im)
    im = np.stack(ims, axis=0).astype(np.float32)

    n_rows, n_cols = seg.shape[0], seg.shape[1]
    for region in tqdm(regionprops(seg), leave=False):
        # Centroid truncated to integer pixel coordinates (first two axes of the mask).
        row_c = int(region.centroid[0])
        col_c = int(region.centroid[1])

        r0, r1 = row_c - half, row_c + half
        c0, c1 = col_c - half, col_c + half

        # Skip windows that would stick out of the image.
        inside = r0 >= 0 and c0 >= 0 and r1 <= n_rows and c1 <= n_cols
        if not inside:
            continue

        im_crop = im[:, r0:r1, c0:c1]
        # -1 marks pixels already claimed by an earlier window; skip overlapping windows.
        if np.any(im_crop == -1):
            continue

        # Saving is currently disabled:
        # np.save(path_patch_dataset / f'{sample_id}_{row_c}_{col_c}.npy', im_crop)
        im[:, r0:r1, c0:c1] = -1
            

In [None]:
# Read every .npy file in path_patch_dataset and check its shape and contents.
files = list(path_patch_dataset.glob('*.npy'))
for file in tqdm(files):
    im = np.load(file)
    assert im.shape == (6, 64, 64), f"{file.name}: unexpected shape {im.shape}"
    # -1 is the sentinel used during extraction to mark already-claimed pixels.
    assert not np.any(im == -1), f"{file.name}: contains -1 sentinel values"

# Group patches by source image (first three "_" fields of the file stem) and check
# that their combined area stays below 1024*1024
# (presumably the size of one source image — TODO confirm).
ids = ["_".join(f.stem.split("_")[:3]) for f in files]
unique_ids, counts = np.unique(ids, return_counts=True)
# zip() has no length, so the total is passed explicitly to get a real progress bar.
for image_id, count in tqdm(zip(unique_ids, counts), total=len(unique_ids)):
    area = count*64*64
    assert area < 1024*1024, f"{image_id}: {count} patches cover {area} px"

  0%|          | 0/59191 [00:00<?, ?it/s]

0it [00:00, ?it/s]

### Balanced dataset

In [34]:
# Destination directory for the balanced subset of patches (files are copied into it below).
path_patch_balanced = Path("/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/patches_balanced")

def id_from_name_patch(name):
    """Build the grouping key "<first two chars>_<condition>_<line>" for a patch file name.

    The first two "_"-separated fields of `name` form the id that is handed to
    `condition_from_id` and `line_from_id`; only the first two characters of
    that id are kept in the returned key.
    """
    well_id = "_".join(name.split("_", 2)[:2])
    return f"{well_id[:2]}_{condition_from_id(well_id)}_{line_from_id(well_id)}"

# Group the patch files by experiment / condition / cell line.
id_to_path_patch = get_id_to_path(path_patch_dataset, tag='.npy', id_from_name=id_from_name_patch)
num_cells = [len(paths) for paths in id_to_path_patch.values()]

# Every group is subsampled to the size of the smallest one.
# (Not named `min`: that would hide the builtin for the rest of the session.)
n_per_group = int(np.amin(num_cells))

# Create the destination so the copy does not fail on a fresh checkout.
path_patch_balanced.mkdir(parents=True, exist_ok=True)

# Fixed seed: the same subset is selected on every run.
rng = np.random.default_rng(0)

for group_id, paths in id_to_path_patch.items():
    # Draw indices instead of shuffling `paths` in place, which would reorder
    # the lists stored inside id_to_path_patch.
    keep = rng.permutation(len(paths))[:n_per_group]
    for idx in keep:
        path = paths[idx]
        shutil.copyfile(path, path_patch_balanced / path.name)


In [53]:
# Apply adaptive histogram equalisation to every balanced patch, channel by channel.
indir = "/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/patches_balanced"
outdir = Path("/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/patches_balanced_clahe")

files = list(Path(indir).glob('*.npy'))
for file in tqdm(files):
    im = np.load(file)
    # In-place scaling by 1/255 before equalisation.
    im /= 255
    # Iterating the array yields one plane per leading index; each result is written back.
    for c, plane in enumerate(im):
        im[c] = exposure.equalize_adapthist(plane)

    np.save(outdir.joinpath(file.name), im)


  0%|          | 0/20900 [00:00<?, ?it/s]

In [54]:
# Value range of the last patch processed above.
print((im.min(), im.max()))

(0.0, 1.0)


In [55]:
# Convert the equalised patches to 8-bit multi-channel TIFF files.
indir = "/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/patches_balanced_clahe"
outdir = Path("/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/patches_balanced_tif")

files = list(Path(indir).glob('*.npy'))
for file in tqdm(files):
    scaled = np.load(file) * 255
    # astype truncates the fractional part when casting to uint8.
    im = scaled.astype(np.uint8)
    # The channel axis stays first (moving it last is left disabled):
    #im = np.moveaxis(im, 0, -1)
    io.imsave(outdir.joinpath(f'{file.stem}.tif'), im)

  0%|          | 0/20900 [00:00<?, ?it/s]

  io.imsave(outdir / f'{file.stem}.tif', im)


# Universal

In [25]:
def id_from_name_dataset(name):
    """Return `name` without its last seven characters.

    Names of seven characters or fewer yield an empty string.
    """
    suffix_len = 7
    return name[:-suffix_len]

def condition_from_id(id):
    """Return the zero-based condition code of the well encoded in `id`.

    `id` is split on "_" and the second field is read as the well token: its
    character at index 1 is the plate row letter (B-G) and everything from
    index 2 on is the plate column number (2-11).

    Returns:
        int: the entry of the plate layout table for that well, minus one
        (the table holds codes 1-5, so the result is 0-4).

    Raises:
        KeyError: the row letter is not one of B-G.
        ValueError: the column part of the well token is not an integer.
        IndexError: `id` has no second field, or the column is outside 2-11.
    """
    row_to_roid = {"B":0, "C":1, "D":2, "E":3, "F":4, "G":5}

    well = id.split("_")[1]
    row = row_to_roid[well[1]]
    # The layout table starts at plate column 2.
    col = int(well[2:]) - 2

    conditions = [[1,1,2,2,5,1,1,4,5,5],
                  [3,3,4,4,5,2,2,4,3,3],
                  [5,2,2,1,1,2,2,3,3,5],
                  [3,3,4,4,5,5,4,4,1,1],
                  [4,4,1,1,3,2,2,1,3,3],
                  [5,5,2,2,3,5,5,1,4,4]]

    # Check the range explicitly: a negative offset would otherwise wrap around
    # and silently return the condition of a well at the other end of the row.
    if not 0 <= col < len(conditions[row]):
        raise IndexError(f"plate column {col + 2} is outside the layout (2-{len(conditions[row]) + 1})")

    return conditions[row][col] - 1

def line_from_id(id):
    """Look up the cell-line name for a sample id.

    The experiment number is the character at index 1 of `id`. The well token
    is the second "_"-separated field: row letter at index 1, column number
    from index 2 on. The plate is treated as six blocks (row pairs B/C, D/E
    and F/G, each split into columns below 7 and columns 7 and above), and
    every block maps the experiment number (1-4) to a line name.

    Raises KeyError for a row letter outside B-G or an experiment number
    outside 1-4.
    """
    row_index = {"B": 0, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5}
    # Keyed by (row pair, column < 7); each value maps experiment number -> line name.
    block_lines = {
        (0, True): {1: "RFTiALS", 2: "AE8iCTR", 3: "ADKiCTR", 4: "EGMiALS"},
        (0, False): {1: "AE8iCTR", 2: "BFUiALS", 3: "ZLMiALS", 4: "XH7iCTR"},
        (1, True): {1: "ZKZiCTR", 2: "KRCiALS", 3: "BFUiALS", 4: "ADKiCTR"},
        (1, False): {1: "TJViALS", 2: "XH7iCTR", 3: "NK3iCTR", 4: "LJXiALS"},
        (2, True): {1: "DG9iALS", 2: "ZKZiCTR", 3: "ZKZiCTR", 4: "NK3iCTR"},
        (2, False): {1: "XH7iCTR", 2: "RJViALS", 3: "AFGiALS", 4: "AE8iCTR"},
    }

    experiment = int(id[1])
    well = id.split("_")[1]
    row_pair = row_index[well[1]] // 2
    in_low_columns = int(well[2:]) < 7

    return block_lines[(row_pair, in_low_columns)][experiment]

# Directory with the processed `.tif` images that are indexed by sample id below.
path_dataset = "/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/processed"


# Display names for the 1-based condition codes (condition_from_id returns code - 1).
cond_to_cond = {1: "KPT", 2:"H2O2", 3:"Tunicamycin", 4:"Autophagy", 5:"DMSO"}
# Names of the six image channels; they prefix the per-pixel feature column names.
channels = ["ER", "DNA", "Mito", "Actin", "RNA", "Golgi/membrane"]
# NOTE(review): presumably the pixel size in metres — units are not stated in this notebook; confirm.
res = 6.9e-7

# Map every sample id to its image file path(s).
id_to_path = get_id_to_path(path_dataset, tag=".tif", id_from_name=id_from_name_dataset)

print(f"{len(id_to_path.keys())} samples found")

720 samples found


# Make Dataset

In [None]:
# tile_test: when False, experiment-4 samples contribute a single centred tile
#            instead of the full grid of tiles.
# ds_size:   side length every channel image is resized to.
# tile_sz:   side length of the square tiles cut from the resized image.
tile_test = True
ds_size = 1024//8
tile_sz = 64//2


def _tile_features(ims, r0, c0, size, meta):
    """Flatten the `size` x `size` window at (r0, c0) of every image in `ims` and append `meta`."""
    flat = [im[r0:r0 + size, c0:c0 + size].flatten() for im in ims]
    return np.concatenate(flat + [meta], axis=0)


data = []
for sample_id, paths in tqdm(id_to_path.items()):
    condition = condition_from_id(sample_id)
    line = line_from_id(sample_id)
    disease = 1 if "ALS" in line else 0
    exp = int(sample_id[1])
    meta = [disease, condition, exp]

    ims = [transform.resize(im, (ds_size, ds_size), anti_aliasing=True) for im in read_ims(paths)]

    if exp == 4 and not tile_test:
        # One centred tile. It has the same size as the grid tiles so that all
        # rows share one length and can be stacked (a fixed [8:24] crop only
        # matched the case ds_size=32, tile_sz=16).
        offset = (ds_size - tile_sz) // 2
        data.append(_tile_features(ims, offset, offset, tile_sz, meta))
    else:
        for x0 in range(0, ds_size, tile_sz):
            for y0 in range(0, ds_size, tile_sz):
                data.append(_tile_features(ims, x0, y0, tile_sz, meta))


data = np.stack(data, axis=0)
# One column per pixel of each channel tile, followed by the three label columns.
n_pixels = tile_sz * tile_sz
columns = [f"{channel}_{i}" for channel in channels for i in range(n_pixels)] + ["Disease", "Condition", "Experiment"]

df = pd.DataFrame(data, columns=columns)
df.shape

In [None]:
#df.to_csv("/Users/thomasathey/Documents/shavit-lab/fraenkel/papers/cvpr/data/all/dataframes/df_128_32.csv")

In [None]:
# Keep only the rows whose "Condition" value is 0 (code 1, "KPT", in cond_to_cond).
# NOTE(review): this overwrites `df`; the unfiltered frame can only be recovered by
# re-running the dataset-building cell.
df = df[df["Condition"] == 0]

In [None]:
# Leave-one-experiment-out evaluation: each experiment (1-4) is held out in turn
# and an XGBoost classifier is trained on the remaining ones.
separate_plates = True

# Column used as the prediction target.
# NOTE(review): the cell above restricts `df` to Condition == 0 — confirm that
# "Condition" is still the intended target here.
dep_var = "Condition"
# Label / bookkeeping columns that must not be used as features.
meta_columns = ["Condition", "Experiment", "Disease"]

if not separate_plates:
    # A random stratified split across plates was never implemented.
    raise NotImplementedError("only separate_plates=True (leave-one-experiment-out) is implemented")

# The variables of the last fold stay in scope and are reused by later cells.
for exclude in range(1, 5):
    held_out = df["Experiment"] == exclude
    df_train = df[~held_out]
    df_test = df[held_out]

    x_train = df_train.drop(columns=meta_columns)
    y_train = df_train[dep_var].values
    x_test = df_test.drop(columns=meta_columns)
    y_test = df_test[dep_var].values

    model = xgboost.XGBClassifier().fit(x_train, y_train)
    y_pred = model.predict(x_test)
    acc = accuracy_score(y_test, y_pred)

    confusion_mat = confusion_matrix(y_test,y_pred)

    # Label the fold so the four result blocks can be told apart.
    print(f"Held-out experiment: {exclude}")
    print("Accuracy is",acc)
    print("Confusion Matrix")
    print(confusion_mat)

In [None]:
# Show the same 16x16 window of every image currently held in `ims`, one panel per channel.
# Distinct names are used for the loop variables so that neither the `ims` list
# nor the array of axes is shadowed while iterating.
ims_crop = [channel_im[16:32, 16:32] for channel_im in ims]

f, axes = plt.subplots(2, 3, figsize=(10, 8))
for im_crop, ax, channel in zip(ims_crop, axes.flatten(), channels):
    ax.imshow(im_crop, cmap='gray')
    ax.set_title(channel)
    # Hide ticks and frame; only the image and its title are of interest.
    ax.axis('off')


ims[0].shape

# Predict

In [None]:
# Fit a classifier on the train split currently in scope and score it on the test split.
model = xgboost.XGBClassifier()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)

confusion_mat = confusion_matrix(y_test, y_pred)
acc = accuracy_score(y_test, y_pred)

print("Accuracy is", acc)
print("Confusion Matrix", confusion_mat, sep="\n")

In [None]:
# Explain the model's predictions using SHAP.
# (The same syntax works for LightGBM, CatBoost, scikit-learn, transformers, Spark, etc.)
# The explainer is built from the fitted `model`; calling it on `x_test` returns the
# explanation object whose `.values`, `.base_values` and `.data` are inspected below.
explainer = shap.Explainer(model)
shap_values = explainer(x_test)


In [None]:
print(f"values: {shap_values.values.shape}, base_values: {shap_values.base_values.shape}, data: {shap_values.data.shape}")
# `data` is simply the input feature matrix that was explained.
# `base_values` is the expected value of the model, one per class (the prior distribution? — TODO confirm).
# `values` holds the Shapley values with shape n x d x k: each entry is the contribution
# of that sample's feature to that class.

In [None]:
# Shape of the Shapley-value array (also printed in the cell above).
shap_values.values.shape

In [None]:
# Additivity check for the first test sample: the base value plus the sum of all
# feature contributions reconstructs the model output for each class.
reconstructed = shap_values.values[0].sum(axis=0) + shap_values.base_values[0, :]
print(reconstructed)
# For tree models SHAP explains the raw margin by default, so the reconstruction
# is compared with the un-normalised scores rather than with probabilities.
print(model.predict(x_test.iloc[[0]], output_margin=True)[0, :])
# Class probabilities for the same sample, for reference.
model.predict_proba(x_test)[0,:]

In [None]:
# Visualize the explanation of the first prediction for class 0.
first_sample_class0 = shap_values[0, :, 0]
shap.plots.waterfall(first_sample_class0)

# NOTE(review): 5*256+60 reads as "feature 60 of channel index 5 at 256 features per
# channel" — confirm this still matches the tile size used to build the features.
feature_idx = 5*256+60
print(shap_values.values[0, feature_idx, 0])
print(shap_values.data[0, feature_idx])
print(shap_values.base_values[0, 0])
# Base value plus all contributions for sample 0, class 0.
print(shap_values.values[0, :, 0].sum() + shap_values.base_values[0, 0])

In [None]:
# Per-class channel attribution: for every sample, the absolute value of the summed
# SHAP values over all features that belong to one channel.
n_samples = shap_values.shape[0]
n_channels = len(channels)
# Features are laid out channel by channel; derive the block width from the data
# instead of hard-coding it, so it follows the tile size used for the features.
n_per_channel = shap_values.shape[1] // n_channels

for condition in range(len(cond_to_cond)):
    class_values = shap_values.values[:, :n_channels * n_per_channel, condition]
    channel_sums = np.abs(class_values.reshape(n_samples, n_channels, n_per_channel).sum(axis=2))

    # Separate name: `df` holds the feature table used for training.
    channel_shap_df = pd.DataFrame(channel_sums, columns=channels)

    sns.stripplot(data=channel_shap_df)
    plt.title(f"{cond_to_cond[condition+1]}")
    plt.show()

In [None]:
# Channel attribution for the disease model: for every sample, the absolute value of
# the summed SHAP values over all features that belong to one channel.
n_channels = len(channels)
# Features are laid out channel by channel; derive the block width from the data
# instead of hard-coding it, so it follows the tile size used for the features.
n_per_channel = shap_values.shape[1] // n_channels

channel_values = shap_values.values[:, :n_channels * n_per_channel]
# The trailing -1 axis absorbs a class axis if one is present, so every remaining
# entry of a channel block is included in the sum.
channel_sums = np.abs(
    channel_values.reshape(channel_values.shape[0], n_channels, n_per_channel, -1).sum(axis=(2, 3))
)

# Separate name: `df` holds the feature table used for training.
channel_shap_df = pd.DataFrame(channel_sums, columns=channels)

sns.stripplot(data=channel_shap_df)
plt.title("ALS Prediction")
plt.show()