In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt

In [2]:
from typing import List, Set, Dict, Tuple, Optional, Any
from collections import defaultdict

import pandas as pd
import seaborn as sns
import numpy as np

import math 
import torch
from torch import nn, Tensor
from torch.nn.functional import softplus, relu
from torch.distributions import Distribution, Normal
from torch.utils.data import DataLoader, Dataset

from gmfpp.utils.data_preparation import *
from gmfpp.utils.data_transformers import *
from gmfpp.utils.plotting import *
from gmfpp.utils.training import *

from gmfpp.models.ReparameterizedDiagonalGaussian import *
from gmfpp.models.CytoVariationalAutoencoder import *
from gmfpp.models.VariationalAutoencoder import *
from gmfpp.models.ConvVariationalAutoencoder import *
from gmfpp.models.VariationalInference import *
from gmfpp.models.LoadModels import *

In [3]:
constant_seed()

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# Load Data

In [5]:
# here Andrea had to modify the "all" folder with "small", because there is no "all" folder in the repository
#metadata_all = read_metadata("./data/all/metadata.csv")
#metadata_all = read_metadata("./data/small/metadata.csv")
#metadata_all = read_metadata("./data/two_from_each_well/metadata.csv")
metadata_all = read_metadata("./data/mix_from_all/metadata.csv")
metadata_all = shuffle_metadata(metadata_all)

In [6]:
metadata = metadata_all[:2000]

In [7]:
metadata = shuffle_metadata(metadata)
metadata_train, metadata_validation = split_metadata(metadata, split_fraction = .90)

In [8]:
# same problem as with the dataset as above -- change this back to "all" iff needed

relative_path = get_relative_image_paths(metadata)
#image_paths = ["./data/all/" + path for path in relative_path]
#image_paths = ["./data/small/" + path for path in relative_path]
#image_paths = ["./data/two_from_each_well/" + path for path in relative_path]
image_paths = ["./data/mix_from_all/" + path for path in relative_path]

In [9]:
images = load_images(image_paths, verbose=True, log_every=10000)

14:37:52 | loaded 0/1150 images (0.00%).
14:38:18 | loaded 1150/1150 images (100.00%).


# Normalize Data

In [10]:
normalize_every_image_channels_seperately_inplace(images)
#normalize_channels_inplace(images)

In [11]:
channel_first = view_channel_dim_first(images)
for i in range(channel_first.shape[0]):
    channel = channel_first[i]
    print("channel {} interval: [{:.2f}; {:.2f}]".format(i, torch.min(channel), torch.max(channel)))

channel 0 interval: [0.02; 1.00]
channel 1 interval: [0.01; 1.00]
channel 2 interval: [0.01; 1.00]


# VAE

In [12]:
# VAE
image_shape = np.array([3, 68, 68])
latent_features = 256
vae = CytoVariationalAutoencoder(image_shape, latent_features).to(device) # @TODO: load trained parameters - see below

# load trained parameters
vae, validation_data, training_data, VAE_settings = LoadVAEmodel("pretrained", '2022-11-25 - 20-36-12')

# Finding Targe Cells for each Well and Compound/Concentration

In [13]:
images.shape[0]
batch_size=4

batch_offset = np.arange(start=0, stop=images.shape[0]+1, step=batch_size)
print(batch_offset)

[   0    4    8   12   16   20   24   28   32   36   40   44   48   52
   56   60   64   68   72   76   80   84   88   92   96  100  104  108
  112  116  120  124  128  132  136  140  144  148  152  156  160  164
  168  172  176  180  184  188  192  196  200  204  208  212  216  220
  224  228  232  236  240  244  248  252  256  260  264  268  272  276
  280  284  288  292  296  300  304  308  312  316  320  324  328  332
  336  340  344  348  352  356  360  364  368  372  376  380  384  388
  392  396  400  404  408  412  416  420  424  428  432  436  440  444
  448  452  456  460  464  468  472  476  480  484  488  492  496  500
  504  508  512  516  520  524  528  532  536  540  544  548  552  556
  560  564  568  572  576  580  584  588  592  596  600  604  608  612
  616  620  624  628  632  636  640  644  648  652  656  660  664  668
  672  676  680  684  688  692  696  700  704  708  712  716  720  724
  728  732  736  740  744  748  752  756  760  764  768  772  776  780
  784 

In [14]:
# extracting latent variables for each image/cell

def z_extraction(metadata, images, batch_size, vae):
  images.shape[0]
  batch_size=batch_size
  batch_offset = np.arange(start=0, stop=images.shape[0]+1, step=batch_size)

  df = pd.DataFrame()
  new_metadata = pd.DataFrame()

  for j, item in enumerate(batch_offset[:-1]):
      start = batch_offset[j]
      end = batch_offset[j+1]

      outputs = vae(images[start:end,:,:,:])
      z = outputs["z"]
      z_df = pd.DataFrame(z.detach().numpy())
      z_df.index = list(range(start,end))
      df = pd.concat([metadata.iloc[start:end], z_df], axis=1)
      new_metadata = pd.concat([new_metadata, df], axis=0)

  return new_metadata

In [15]:
nm = z_extraction(metadata, images, batch_size, vae)
nm

Unnamed: 0.1,Unnamed: 0,Multi_Cell_Image_Id,Multi_Cell_Image_Name,Single_Cell_Image_Id,Single_Cell_Image_Name,TableNumber,ImageNumber,Image_FileName_DAPI,Image_PathName_DAPI,Image_FileName_Tubulin,...,246,247,248,249,250,251,252,253,254,255
0,252112,375,G05_s1_w1EEDF0712-3112-4798-9254-788374B3EBD9,5,G05_s1_w1EEDF0712-3112-4798-9254-788374B3EBD9_...,4,3093,G05_s1_w1EEDF0712-3112-4798-9254-788374B3EBD9.tif,Week4_27801,G05_s1_w2A57BFC48-7BA3-43D1-9454-4221E458BD66.tif,...,1.710681,-3.405804,-3.985713,3.628738,0.086376,0.269028,-0.414974,2.483156,2.900129,2.757358
1,74895,721,Week1_150607_B05_s1_w12F4AB0D9-2FD7-4563-BF1D-...,4,Week1_150607_B05_s1_w12F4AB0D9-2FD7-4563-BF1D-...,1,2893,Week1_150607_B05_s1_w12F4AB0D9-2FD7-4563-BF1D-...,Week1_22361,Week1_150607_B05_s1_w2529E219F-9E07-4735-9206-...,...,-2.987455,1.409883,-0.083221,-0.817735,2.238451,-4.588113,2.337346,-0.314419,-4.823447,1.460214
2,141330,1613,Week3_290607_B05_s1_w1B12F5390-FE04-4013-BEE2-...,0,Week3_290607_B05_s1_w1B12F5390-FE04-4013-BEE2-...,3,13,Week3_290607_B05_s1_w1B12F5390-FE04-4013-BEE2-...,Week3_25421,Week3_290607_B05_s1_w2B57FE6C1-75D7-4B0F-901A-...,...,2.414869,-0.171636,-3.995393,3.888929,2.272825,2.030403,3.942008,-5.922760,3.553857,-2.744163
3,3441,471,Week10_200907_C05_s1_w1CCD0C22C-717E-4F62-8FF2...,6,Week10_200907_C05_s1_w1CCD0C22C-717E-4F62-8FF2...,0,53,Week10_200907_C05_s1_w1CCD0C22C-717E-4F62-8FF2...,Week10_40111,Week10_200907_C05_s1_w2DE115BC3-BC1E-4303-8219...,...,-0.641278,3.754029,2.287447,-3.682684,-2.764186,-0.506040,-4.550258,2.500869,-2.233270,-3.043130
4,13277,627,Week10_200907_F02_s1_w1DF5FDA94-5FCF-41B3-9656...,2,Week10_200907_F02_s1_w1DF5FDA94-5FCF-41B3-9656...,0,161,Week10_200907_F02_s1_w1DF5FDA94-5FCF-41B3-9656...,Week10_40111,Week10_200907_F02_s1_w26D7A740E-66D3-47CC-BEB1...,...,3.919448,2.927941,4.166117,-4.497015,3.456796,0.767842,4.958160,0.618127,-2.188950,-1.793081
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1143,4805,505,Week10_200907_C08_s1_w14962D4AB-2A48-4B3B-BBE1...,6,Week10_200907_C08_s1_w14962D4AB-2A48-4B3B-BBE1...,0,65,Week10_200907_C08_s1_w14962D4AB-2A48-4B3B-BBE1...,Week10_40111,Week10_200907_C08_s1_w2F9EB49AC-DE29-415C-A7A0...,...,-2.885714,-0.737390,2.667868,0.656025,-5.110986,0.824040,-2.869857,-0.329856,-4.926680,2.237576
1144,405644,3078,Week7_7__F04_s1_w1D8760743-4E7A-405F-8F24-C134...,7,Week7_7__F04_s1_w1D8760743-4E7A-405F-8F24-C134...,7,3049,Week7_7__F04_s1_w1D8760743-4E7A-405F-8F24-C134...,Week7_34661,Week7_7__F04_s1_w2500B0AA3-AAFB-4356-8C4E-34E4...,...,-0.018181,-2.274724,4.390459,-0.410926,0.215004,-2.517410,-4.201649,0.747309,-3.964204,2.449304
1145,6300,542,Week10_200907_D02_s1_w1E8853E7D-940A-46CA-A42C...,4,Week10_200907_D02_s1_w1E8853E7D-940A-46CA-A42C...,0,81,Week10_200907_D02_s1_w1E8853E7D-940A-46CA-A42C...,Week10_40111,Week10_200907_D02_s1_w2C6B52338-5CD2-4434-8398...,...,-2.106383,3.848512,-1.840647,3.920534,-0.208488,5.649556,-1.303567,3.660558,2.308335,-4.395069
1146,252107,375,G05_s1_w1EEDF0712-3112-4798-9254-788374B3EBD9,0,G05_s1_w1EEDF0712-3112-4798-9254-788374B3EBD9_...,4,3093,G05_s1_w1EEDF0712-3112-4798-9254-788374B3EBD9.tif,Week4_27801,G05_s1_w2A57BFC48-7BA3-43D1-9454-4221E458BD66.tif,...,5.228197,0.922124,-1.052521,-2.091652,-3.856126,-3.613803,2.957520,-4.654444,2.619969,0.196059


In [16]:
# Wells Profiles
def well_profiles(nm):
  wa = nm.groupby('Image_Metadata_Well_DAPI').mean().iloc[:,-256:]
  return wa

wa = well_profiles(nm)
wa

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
Image_Metadata_Well_DAPI,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
B02,2.197969,1.8477,1.075443,-0.755408,0.013666,-0.67136,-0.226263,1.732998,-1.505796,-1.051048,...,-0.879821,-0.773964,0.337882,-0.654532,-0.61545,0.176297,-2.29039,-0.964481,-1.50868,1.189335
B03,0.140139,0.389527,-0.097867,0.41505,-0.206038,-0.924571,-0.624326,0.787322,-0.103471,-0.144696,...,0.700939,-0.061221,-0.175222,0.75722,0.866617,-0.386845,-0.33361,-0.22675,0.081398,-0.23552
B04,-0.170455,0.775914,-0.457341,0.393027,0.465532,-0.075105,0.870882,-0.085963,0.725298,-0.325686,...,1.372616,-0.407576,-0.799089,-0.039321,0.516905,0.035407,0.355185,0.059719,0.565643,0.375286
B05,0.91841,0.451575,0.043338,0.814646,0.129472,0.230597,0.026372,0.758701,-0.276602,0.531423,...,0.518501,-0.689663,-0.315685,0.501627,0.872677,0.806053,0.354716,0.453001,0.301665,-0.049849
B06,-0.292256,0.880781,0.069513,0.500307,0.404478,-0.459916,0.05465,-0.364998,0.52803,0.194081,...,1.504596,-0.68179,-0.19644,0.363401,0.475419,0.408339,0.264879,-0.21325,-0.361162,-0.088381
B07,-1.112124,0.462457,-0.713039,0.034344,-0.905105,-0.368587,-0.355849,1.892814,0.595217,1.424073,...,3.018235,-1.031749,0.921873,1.03882,1.587109,1.350309,-0.749514,0.68702,0.917918,-0.754964
B08,-0.137934,-0.270942,-2.022458,-0.257244,1.290825,1.282376,-0.713167,1.468318,0.354764,-1.928588,...,1.341642,0.084517,-3.166451,-1.222314,0.243468,1.887459,0.332394,-1.228796,0.455811,-0.428544
B09,0.349717,0.041529,-0.864532,-1.02086,1.120184,-0.374426,0.275788,-1.38487,1.606933,-0.351569,...,2.17581,0.603814,-1.160025,-0.755953,0.40764,-0.86076,-2.212229,0.852319,-0.484672,1.480835
B10,-0.995846,-1.19156,-1.199732,-0.037252,-1.298525,-0.245511,-0.092454,0.344612,-0.747745,-0.187368,...,1.783979,0.321939,-0.590307,1.789766,-1.215955,0.560221,-0.341455,0.145511,-0.652745,0.251446
B11,-0.22109,1.374165,-0.115911,1.495991,-0.979506,-0.20727,0.748923,0.689916,-0.521191,-0.482381,...,-0.02369,-0.362566,1.725099,-1.951243,0.553803,-0.719344,-0.167915,0.135322,0.404172,0.556688


In [17]:
# function to get the cell closest to each Well profile

def well_center_cells(df,well_profiles,p=2):
  wcc = []
  for w in well_profiles.index:
    diffs = (abs(df[df['Image_Metadata_Well_DAPI'] == w].iloc[:,11:] - well_profiles.loc[w])**p)
    diffs_sum = diffs.sum(axis=1)**(1/p)
    diffs_min = diffs_sum.min()
    wcc.append(diffs[diffs_sum == diffs_min].index[0])
  
  return wcc

In [18]:
# Compount/Concentration Profiles
def CC_Profile(nm):
  cc =  nm.groupby(['Image_Metadata_Compound','Image_Metadata_Concentration']).median().iloc[:,-256:]
  return cc

ccp = CC_Profile(nm)
ccp

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
Image_Metadata_Compound,Image_Metadata_Concentration,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
ALLN,3.00,0.017081,-0.988462,-1.464466,1.321767,-0.504487,1.134424,0.741944,-1.176179,0.206689,0.911776,...,0.036158,-3.279941,-1.054231,0.745838,-0.934788,-0.963632,0.598804,-1.875561,1.504205,1.159413
ALLN,100.00,-2.880006,-0.447961,-1.290977,-0.876084,1.671228,1.318244,-0.254659,1.841938,1.161055,-2.340252,...,2.315969,1.482379,0.235648,2.007025,1.276002,-0.976427,1.654652,-2.068170,0.263408,-0.633963
AZ-A,0.10,0.340437,-1.289059,-0.809800,-2.584209,1.043596,0.203519,2.204541,-0.317576,-1.106602,-2.410073,...,-1.879795,0.464195,-0.678599,0.595501,1.529669,-1.620397,0.641331,0.536533,1.461800,0.072432
AZ-A,0.30,-0.235774,0.676192,-1.294827,-2.204907,-0.816374,-0.860257,-0.625899,0.063125,-1.833860,-0.194444,...,-0.950566,-0.115478,-0.324740,-0.170305,0.937828,-0.898056,0.444940,1.124071,0.162410,0.387235
AZ-A,1.00,-0.035717,-0.371279,1.161689,-0.613616,0.229554,0.271101,1.067837,-0.839245,-0.952492,-1.685497,...,-1.344659,2.134080,0.084457,-0.190118,-1.244428,0.874029,-0.944954,0.389654,-0.176657,-1.249034
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
vincristine,0.03,0.552455,-0.403518,-3.376607,-0.641183,1.911475,2.190751,-2.146626,2.080235,1.550187,-2.008544,...,0.987178,0.953620,-3.051330,-0.978127,0.363129,2.483466,0.629518,-1.549022,-0.867686,-1.178651
vincristine,0.10,-1.427478,-0.269944,-0.788872,1.335124,-3.028841,0.003757,-0.596179,2.634957,-0.253881,2.611570,...,3.873374,-2.082667,1.519566,1.776839,2.329321,1.608731,-0.132060,0.732994,0.699666,-0.518309
vincristine,0.30,-1.547221,1.914800,-0.360947,1.193521,2.287811,0.252378,1.014646,-2.512022,2.872985,-0.233791,...,3.067448,-0.172361,-0.706334,-2.980991,0.417384,1.826716,1.237555,0.549271,-2.957500,0.796399
vincristine,1.00,0.526094,-1.518952,0.654153,-0.977310,1.393262,0.797621,-0.246117,-1.625960,2.120208,0.997166,...,1.876896,-2.596971,-0.661215,1.924254,0.623905,1.806457,-0.852041,0.560599,3.205488,-0.776584


In [19]:
# function to get the cell closest to each Compound/Concentration profile

def cc_center_cells(df,cc_profiles,p=2):
  cc_center_cells = []
  for cc in ccp.index:
    diffs = (abs(df[(df['Image_Metadata_Compound'] == cc[0]) & (nm['Image_Metadata_Concentration'] == cc[1])].iloc[:,-256:] - cc_profiles.loc[cc]))**p
    diffs_sum = diffs.sum(axis=1)**(1/p)
    diffs_min = diffs_sum.min()
    cc_center_cells.append(diffs[diffs_sum == diffs_min].index[0])
  
  return cc_center_cells

cc_center_cells(nm, ccp, p=2)

[747,
 474,
 56,
 497,
 177,
 95,
 695,
 317,
 525,
 575,
 1077,
 562,
 320,
 417,
 9,
 467,
 15,
 988,
 934,
 1098,
 507,
 156,
 379,
 832,
 274,
 49,
 965,
 454,
 522,
 59,
 534,
 75,
 556,
 1138,
 392,
 332,
 1144,
 655,
 650,
 997,
 610,
 760,
 163,
 1002,
 659,
 617,
 742,
 250,
 50,
 335,
 134,
 691,
 653,
 393,
 164,
 494,
 940,
 303,
 756,
 842,
 719,
 1113,
 503,
 681,
 203,
 147,
 178,
 901,
 933,
 1043,
 994,
 252,
 900,
 941,
 601,
 202,
 69,
 820,
 297,
 79,
 578,
 1118,
 686,
 439,
 917,
 113,
 246,
 652,
 1089,
 980,
 1055,
 516,
 139,
 445,
 943,
 1057,
 72,
 790,
 231,
 518,
 237,
 955,
 363,
 387]