In [100]:
import anndata as ad
import scanpy as sc
from sklearn.model_selection import train_test_split
import pylab as pl
import seaborn as sns
from limix_core.util.preprocess import gaussianize, regressOut
import scipy.stats as st
from sklearn.impute import SimpleImputer
import scipy.linalg as la
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
from os.path import join
from mtgwas import VCTEST
from mtgwas.utils import df_match

import sys

from torchvision.utils import make_grid
import json
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
import torch
from models.progressive_gan import ProgressiveGAN as PGAN
from mpl_toolkits.axes_grid1 import ImageGrid


In [101]:
#Settings 
path = '/Users/andrewdenny/Desktop/NIHPostBac/repos/HistoGWAS_PGAN/PGAN/Notebooks'
os.chdir(path)
outliers = 0.01
extreme = 0.05
np.random.seed(42)
days = [15, 20, 25, 30]
interpolations = 5
analysis_group = 'TCF7L2_ko'

In [102]:
class Generator():
    
    def __init__(self, config, checkpoint, useGPU=True):
        with open(config, 'rb') as file:
            config = json.load(file)
        self.pgan = PGAN(useGPU=useGPU, storeAVG=True, **config)
        self.pgan.load(checkpoint)
        self.netG = self.pgan.netG
        self.device = self.pgan.device
        
    def forward(self, x, eps=None):
        if eps is None:
            eps = torch.randn(x.shape[0], 512)
        if type(x)==np.ndarray:
            x = torch.Tensor(x)
        if type(eps)==np.ndarray:
            eps = torch.Tensor(eps)
        x = x.to(self.device)
        eps = eps.to(self.device)
        with torch.no_grad():
            out = self.netG(eps, x).data.cpu()
            out = 0.5 * (out + 1)
            out = torch.clip(out, 0, 1)
        return out
    
def load_image_torch(path, size):
    if type(path) in [list, np.ndarray]:
        return torch.cat([load_image_torch(_, size) for _ in path])
    return pil_to_tensor(Image.open(path).resize((size, size)))[None] / 255.


def torch_imshow(x):
    pl.imshow(x.permute(1, 2, 0))

In [103]:
checkpoint = '/Users/andrewdenny/Desktop/NIHPostBac/repos/Organoid_s6_i736000.pt'
config = '/Users/andrewdenny/Desktop/NIHPostBac/repos/HistoGWAS_PGAN/PGAN/config/config_OrganoidLocal.json'
#checkpoint = '/Users/dennyal/Desktop/repos/Organoid_s6_i320000.pt'
generator = Generator(config, checkpoint, useGPU=True)

here False
Average network found !


In [104]:
tissue = 'Organoid'
outdir = f'visualization/{tissue}'
os.makedirs(outdir, exist_ok=True)

In [105]:
dfX = pd.read_csv("./org_features_metadata.csv.gz")
dfX['SampleBarcode'] = dfX['SampleBarcode'].astype('category')
print(dfX.columns)

Index(['Day', 'Well', 'Well literal', 'SampleBarcode', 'Plate', 'Cell_density',
       'Run_ID', 'MinDiameter_shape', 'MaxDiameter_shape',
       'MeanDiameter_shape',
       ...
       'Eccentricity_halo', 'Orientation_halo', 'Compactness_halo',
       'analysis_group', 'edit_id_-/-', 'edit_id_CC', 'edit_id_CT',
       'edit_id_TT', 'edit_id_WT/-', 'edit_id_WT/WT'],
      dtype='object', length=185)


In [106]:
#trait = input("Trait: ")
trait = 'MeanDiameter_shape'

In [107]:
def filter_df(df, day, analysis_group):
    #Get Day
    day_df = df[(df['Day'] == day)].copy()
    day_df = day_df.drop(columns = ['Day', 'Plate', 'Well', 'Cell_density', 'Well literal', "Run_ID"])

    #Mean by SampleBarcode and analysis group
    mean_df = day_df.groupby(["SampleBarcode", "analysis_group"]).mean().reset_index()

    #filter for analysis_group

    day_ag_df = mean_df[mean_df['analysis_group'] == analysis_group]

    # Drop tha NAs Values
    final_df = day_ag_df.dropna(subset=["edit_id_-/-", trait])
    return final_df



In [108]:
def vc_preproccess(df, trait, analysis_group):
    if analysis_group == 'TCF7L2_ko':
        filter_array = ['edit_id_-/-', 'edit_id_WT/-', 'edit_id_WT/WT']
    else:
        filter_array = ['edit_id_CC', 'edit_id_CT', 'edit_id_TT']

    X = df.loc[:, filter_array].copy()
    trait_df = pd.DataFrame(np.asarray(df[trait]))
    y = trait_df.values
    F = np.zeros((df.shape[0], 1))
    return X, y, F


In [109]:
def generate(emb1, emb2, interpolations):
    inter = np.linspace(0, 1, interpolations)[:, None]

    embs = emb1 * (1 - inter) + emb2 * inter
    eps = np.random.randn(1, 512) * np.ones([embs.shape[0], 1])
    xinter = generator.forward(embs, eps)
    print(xinter.shape)
    return xinter



In [112]:
def create_visualization(images, trait, analysis_group):
    n_rows = len(days)
    n_cols = interpolations

    fig = plt.figure(figsize=(38, 28), dpi=150, constrained_layout=True)
    grid = ImageGrid(fig, 111,nrows_ncols=(n_rows, n_cols),
        axes_pad=(0.25, 0.25))

    np_images = images.numpy()
    count = 0
    for ax, img in zip(grid, np_images):
        if count == 0:
            ax.set_title('WT/WT', fontsize=50)
        if count == 2:
            ax.set_title('WT/-', fontsize=50)
        if count == 4:
            ax.set_title('-/-', fontsize=50)
        print(ax)

        if count % (interpolations) == 0:
            print(count)
            ax.set_ylabel(f'Day {days[count // interpolations]}', fontsize=50)

        ax.imshow(img.squeeze(), cmap='gray')
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        count += 1
    plt.suptitle(f"{trait} {analysis_group}", fontsize=60, y=0.995)  # adjust y to move it up/down
    return fig
   

In [None]:
image_acc = torch.zeros((0, 1, 256,256))
for day in days:  
       df = filter_df(dfX, day = day, analysis_group = analysis_group)
       X,y, F = vc_preproccess(df, trait, analysis_group= analysis_group)
       yr = regressOut(y, F)
       Xr = regressOut(X, F)
       vc = VCTEST()
       res = vc.fit(Xr, yr, compute_pvals=True, normalize_X=False)
       ystar = vc.predict_loo()
       features =  X.to_numpy()
       df['ystar'] = vc.predict(features).ravel()

       df_embeddings = df.drop(columns = ['SampleBarcode', 'analysis_group','ystar',  'edit_id_-/-', 'edit_id_CC', 'edit_id_CT', 'edit_id_TT', 'edit_id_WT/-',
              'edit_id_WT/WT']).copy()


       q1, q2, Q1, Q2 = np.quantile(df['ystar'].values, [outliers, extreme, 1 - extreme, 1 - outliers])
       Ih1 = np.logical_and(df['ystar'].values>=Q1, df['ystar'].values<=Q2)
       Il1 = np.logical_and(df['ystar'].values>=q1, df['ystar'].values<=q2)

       # interpolates
       emb1 = df_embeddings[Il1].mean().values

       emb2 = df_embeddings[Ih1].mean().values


       xinter_grid = generate(emb1, emb2, interpolations)
       image_acc = torch.concat((image_acc, xinter_grid))

fig = create_visualization(image_acc, trait, analysis_group)
plt.show(fig)
fig.savefig(f'{outdir}/{trait}_interpolation_{analysis_group}.png', dpi=150)



#plt.title(f'{trait} interpolation Day {day} TCF7L2_ko')

# emb1 = dfX_30_embeddings.iloc[np.where(dfX_30_embeddings[Il1][trait] == min(dfX_30_embeddings[Il1][trait]))].values
# emb2 = dfX_30_embeddings.iloc[np.where(dfX_30_embeddings[Ih1][trait] == max(dfX_30_embeddings[Ih1][trait]))].values




  mean_df = day_df.groupby(["SampleBarcode", "analysis_group"]).mean().reset_index()


100%|██████████| 100/100 [00:00<00:00, 6074.30it/s]
100%|██████████| 100/100 [00:00<00:00, 8631.85it/s]
100%|██████████| 100/100 [00:00<00:00, 8913.43it/s]
100%|██████████| 100/100 [00:00<00:00, 9860.83it/s]
100%|██████████| 100/100 [00:00<00:00, 9897.36it/s]
100%|██████████| 100/100 [00:00<00:00, 10291.26it/s]
100%|██████████| 100/100 [00:00<00:00, 9363.54it/s]
100%|██████████| 100/100 [00:00<00:00, 9599.71it/s]
100%|██████████| 100/100 [00:00<00:00, 9464.11it/s]
100%|██████████| 100/100 [00:00<00:00, 9233.88it/s]
100%|██████████| 100/100 [00:00<00:00, 10051.05it/s]
  mean_df = day_df.groupby(["SampleBarcode", "analysis_group"]).mean().reset_index()


torch.Size([5, 1, 256, 256])


100%|██████████| 100/100 [00:00<00:00, 8433.64it/s]
100%|██████████| 100/100 [00:00<00:00, 6861.51it/s]
100%|██████████| 100/100 [00:00<00:00, 8786.64it/s]
100%|██████████| 100/100 [00:00<00:00, 8126.61it/s]
100%|██████████| 100/100 [00:00<00:00, 8451.66it/s]
100%|██████████| 100/100 [00:00<00:00, 8171.41it/s]
100%|██████████| 100/100 [00:00<00:00, 6858.26it/s]
100%|██████████| 100/100 [00:00<00:00, 7180.18it/s]
100%|██████████| 100/100 [00:00<00:00, 6725.41it/s]
100%|██████████| 100/100 [00:00<00:00, 8107.13it/s]
100%|██████████| 100/100 [00:00<00:00, 8116.07it/s]
  mean_df = day_df.groupby(["SampleBarcode", "analysis_group"]).mean().reset_index()


torch.Size([5, 1, 256, 256])


100%|██████████| 100/100 [00:00<00:00, 8856.03it/s]
100%|██████████| 100/100 [00:00<00:00, 8958.17it/s]
100%|██████████| 100/100 [00:00<00:00, 10539.51it/s]
100%|██████████| 100/100 [00:00<00:00, 8818.97it/s]
100%|██████████| 100/100 [00:00<00:00, 8819.16it/s]
100%|██████████| 100/100 [00:00<00:00, 8704.04it/s]
100%|██████████| 100/100 [00:00<00:00, 10091.68it/s]
100%|██████████| 100/100 [00:00<00:00, 8764.24it/s]
100%|██████████| 100/100 [00:00<00:00, 9045.10it/s]
100%|██████████| 100/100 [00:00<00:00, 8010.97it/s]
100%|██████████| 100/100 [00:00<00:00, 8799.55it/s]
  mean_df = day_df.groupby(["SampleBarcode", "analysis_group"]).mean().reset_index()


torch.Size([5, 1, 256, 256])


100%|██████████| 100/100 [00:00<00:00, 10352.22it/s]
100%|██████████| 100/100 [00:00<00:00, 9197.83it/s]
100%|██████████| 100/100 [00:00<00:00, 9338.11it/s]
100%|██████████| 100/100 [00:00<00:00, 9820.43it/s]
100%|██████████| 100/100 [00:00<00:00, 9310.12it/s]
100%|██████████| 100/100 [00:00<00:00, 9564.90it/s]
100%|██████████| 100/100 [00:00<00:00, 7843.78it/s]
100%|██████████| 100/100 [00:00<00:00, 7877.07it/s]
100%|██████████| 100/100 [00:00<00:00, 8709.46it/s]
100%|██████████| 100/100 [00:00<00:00, 8585.91it/s]
100%|██████████| 100/100 [00:00<00:00, 8850.42it/s]


torch.Size([5, 1, 256, 256])


TypeError: create_visualization() missing 2 required positional arguments: 'trait' and 'analysis_group'