# ELBO channel ablation

In the previous sections we aimed to show, both qualitatively and quantitatively, that the Integrated Cell enables us to model the organization of subcellular structures by leveraging the reference channels, i.e. the cell membrane and the DNA localization.
An important next question is to what extent the reference channels by themselves inform the prediction of subcellular structure organization.

Specifically, we quantify the sensitivity of our model to the coupling between each subcellular structure and a reference channel (say, cell membrane) by comparing the ELBO for the unperturbed image with the ELBO of a perturbed version of that image, where the reference channel (e.g. membrane or DNA) is replaced by the same channel from a randomly selected other cell in the population.

Here we load a best-performing target model from `1) Model Compare.ipynb`

In [None]:
import json
import integrated_cell
from integrated_cell import model_utils, utils
import os
import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline


from integrated_cell.utils.plots import tensor2im, imshow

# --- Notebook setup: pick a GPU, load the trained networks, prepare output dirs ---

# Restrict CUDA to the listed device ids (hard-coded for the original workstation).
gpu_ids = [7]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(ID) for ID in gpu_ids])
if len(gpu_ids) == 1:
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

torch.cuda.empty_cache()

parent_dir = '/allen/aics/modeling/gregj/results/integrated_cell/'


# Training run to load and the checkpoint suffix within it.
model_dir = "/allen/aics/modeling/gregj/results/integrated_cell/test_cbvae_3D_avg_inten/2019-10-22-15:24:09/"
suffix = "_93300"


# `networks` is indexed by 'enc'/'dec' below; `dp` is the data provider used by
# every later cell; `args` holds the run configuration (read for 'kwargs_model').
networks, dp, args = utils.load_network_from_dir(model_dir, parent_dir, suffix=suffix)

dp.default_return_mesh = False
dp.default_return_patch = False

recon_loss = utils.load_losses(args)['crit_recon']

enc = networks['enc']
dec = networks['dec']

# Switch both networks out of training mode before computing ELBOs.
enc.train(False)
dec.train(False)

# All figures and cached embeddings from this notebook are written below results_dir.
results_dir = '{}/results/kl_demo{}/'.format(model_dir, suffix)
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

print("Results dir: {}".format(results_dir))

# Override where the data provider looks for images.
dp.image_parent = '/allen/aics/modeling/gregj/results/ipp/scp_19_04_10/'


save_dir = results_dir

In [None]:


# import imp
# imp.reload(integrated_cell.metrics.embeddings)
import integrated_cell.utils.plots as plots
from integrated_cell.metrics.embeddings_target import get_latent_embeddings
from integrated_cell.utils.target import sample2im

mode = 'test'

def sample(mode, inds):
    """Return the unperturbed sample for `inds` from the notebook-level data provider `dp`.

    Thin wrapper around `dp.get_sample`; callers in this notebook unpack the
    result as a 3-tuple (image, class/label, reference channels).
    """
    return dp.get_sample(mode, inds)


def permute_channels(mode, inds, perm_channel):
    """Return the samples at `inds` with selected reference channels swapped in from other cells.

    Args:
        mode: data split passed through to `dp.get_sample` / `dp.get_n_dat`.
        inds: sequence of integer sample indices to load.
        perm_channel: index (or list of indices) into dimension 1 of the
            reference tensor; those channels are replaced.

    Returns:
        The `(x, classes, ref)` tuple from `dp.get_sample(mode, inds)`, where
        `ref[:, perm_channel]` has been overwritten (in place) with the same
        channels taken from randomly drawn donor samples of the same split.

    Donors are drawn uniformly from the *other* samples of the split: the
    original draw could pick a sample's own index, in which case the
    "permuted" sample was actually unperturbed.
    """
    x, classes, ref = dp.get_sample(mode, inds)

    n_dat = dp.get_n_dat(mode)
    own_inds = np.asarray(inds)

    if n_dat > 1:
        # Draw from n_dat - 1 slots, then shift every draw at or above the
        # sample's own index up by one; the result is uniform over all
        # indices except the sample's own.
        rand_inds = np.random.randint(0, n_dat - 1, len(own_inds))
        rand_inds[rand_inds >= own_inds] += 1
    else:
        # Only one sample available: there is no other donor to choose from.
        rand_inds = np.random.randint(0, n_dat, len(own_inds))

    _, _, ref_perm = dp.get_sample(mode, rand_inds)

    ref[:, perm_channel] = ref_perm[:, perm_channel]

    return x, classes, ref
    
# Default sampler: permute reference channel 1. The lambda looks up the global
# `perm_channel` at call time, so rebinding it later changes what this sampler does.
perm_channel = [1]
sampler = lambda mode, inds: permute_channels(mode, inds, perm_channel)
    

    

In [None]:
from aicsimageio.writers import OmeTiffWriter

# Scratch cell: fetch a default batch from the test split and write one image to disk.
struct, label, ref = dp.get_sample('test')

# Stack along dim 1 as [reference channel 0, structure, reference channel 1].
im = torch.cat([ref[:,[0]], struct, ref[:,[1]]], 1)

plots.imshow(im[[0]])


from aicsimageio.writers import OmeTiffWriter

# Writes to the current working directory, overwriting any existing test.tiff.
# NOTE(review): a later cell writes the same file with axis order (0, 3, 1, 2);
# confirm which ordering the writer expects.
with OmeTiffWriter('test.tiff', overwrite_file=True) as writer:
    writer.save(im[0].numpy().transpose(3, 0, 1, 2))



In [None]:
# Display the array shape produced by the (3, 0, 1, 2) axis order used for the TIFF export.
im[0].numpy().transpose(3, 0, 1, 2).shape

In [None]:
import imp

from aicsimageio.writers import OmeTiffWriter

# Second export attempt with a different axis order (0, 3, 1, 2); overwrites test.tiff again.
with OmeTiffWriter('test.tiff', overwrite_file = True) as writer:
    writer.save(im[0].numpy().transpose(0, 3, 1, 2))

# OmeTiffWriter.imwrite("test.tiff", im[0].numpy().transpose(0, 3, 1, 2))

In [None]:
# Visual sanity check on a single test image: original, then with each
# reference channel permuted in turn.
img_index = [1032]

# Unperturbed image.
img, _, ref = sample(mode, img_index)
plots.imshow(sample2im(img,ref))


# Reference channel 0 replaced (labelled 'perm_cell' in ELBO_permutation).
perm_channel = [0]
sampler = lambda mode, inds: permute_channels(mode, inds, perm_channel)
img, _, ref = sampler(mode, img_index)
plots.imshow(sample2im(img,ref))


# Reference channel 1 replaced (labelled 'perm_nuc' in ELBO_permutation).
perm_channel = [1]
sampler = lambda mode, inds: permute_channels(mode, inds, perm_channel)
img, _, ref = sampler(mode, img_index)
plots.imshow(sample2im(img,ref))



In [None]:
import tqdm

def ELBO_permutation(enc, dec, dp, recon_loss, modes, batch_size, results_dir, n_permutations):
    """Compute (or load from cache) ELBOs for unperturbed and reference-permuted data.

    Args:
        enc, dec: encoder / decoder networks forwarded to `get_latent_embeddings`.
        dp: data provider forwarded to `get_latent_embeddings`.
        recon_loss: reconstruction criterion forwarded to `get_latent_embeddings`.
        modes: list of data splits; only the first entry is evaluated, because
            the returned dictionaries hold a single split.
        batch_size: batch size forwarded to `get_latent_embeddings`.
        results_dir: directory under which `elbo_test/` is created and used as
            an on-disk cache of embeddings.
        n_permutations: number of independent permutation passes per channel.

    Returns:
        (elbo, elbo_perm, embeddings, perm_structures) where
        - elbo['target'] is the per-sample ELBO of the unperturbed data,
        - elbo_perm[name]['target'] stacks the per-sample ELBO of every
          permutation pass along axis 1; elbo_perm[name]['ref'] is left as an
          empty list,
        - embeddings is the full unperturbed embedding dictionary,
        - perm_structures is ['perm_cell', 'perm_nuc'] (reference channels 0 and 1).

    Raises:
        ValueError: if `modes` is empty.

    Notes:
        Results are cached as .pth files; an existing file is loaded instead of
        being recomputed, so stale caches must be deleted by hand after
        changing the model, the data or `beta`.
        `beta` is read from the notebook-level `args` and is now used for the
        permuted passes as well as the baseline (the permuted passes previously
        hard-coded beta=1, making the two ELBOs incomparable whenever the
        model's beta is not 1).
    """
    if not modes:
        raise ValueError("modes must contain at least one data split")

    # Use the caller's split rather than the notebook-level `mode` global.
    mode = modes[0]

    if 'beta' in args['kwargs_model']:
        beta = args['kwargs_model']['beta']
    else:
        beta = 1

    save_dir = "{}/elbo_test".format(results_dir)
    os.makedirs(save_dir, exist_ok=True)

    # Baseline: ELBO of the unperturbed data.
    save_path = "{}/embeddings.pth".format(save_dir)
    if not os.path.exists(save_path):
        embeddings = get_latent_embeddings(enc, dec, dp, recon_loss = recon_loss, modes=[mode], beta=beta, batch_size = batch_size)
        torch.save(embeddings, save_path)
    else:
        embeddings = torch.load(save_path)

    elbo = {}
    elbo['target'] = embeddings[mode]['target']['elbo'].numpy()

    perm_channels = [[0], [1]]
    perm_structures = ['perm_cell', 'perm_nuc']

    elbo_perm = {}

    for perm_channel, perm_structure in zip(perm_channels, perm_structures):
        perm_dir = "{}/{}/".format(save_dir, perm_structure)
        os.makedirs(perm_dir, exist_ok=True)

        elbo_perm[perm_structure] = {}
        elbo_perm[perm_structure]['target'] = list()
        elbo_perm[perm_structure]['ref'] = list()

        # Bind the channel as a default so the sampler does not depend on the
        # loop variable's later values.
        sampler = lambda mode, inds, perm_channel=perm_channel: permute_channels(mode, inds, perm_channel)

        for i in tqdm.tqdm(range(n_permutations)):
            save_path = "{}/embeddings_{}.pth".format(perm_dir, i)
            if not os.path.exists(save_path):
                embeddings_perm = get_latent_embeddings(enc, dec, dp, recon_loss = recon_loss, modes=[mode], sampler = sampler, beta=beta, batch_size = batch_size)
                torch.save(embeddings_perm, save_path)
            else:
                embeddings_perm = torch.load(save_path)

            elbo_perm[perm_structure]['target'].append(embeddings_perm[mode]['target']['elbo'].numpy())
    #        elbo_perm[perm_structure]['ref'].append(embeddings_perm[mode]['ref']['elbo'].numpy())

        # One column per permutation pass.
        elbo_perm[perm_structure]['target'] = np.stack(elbo_perm[perm_structure]['target'], 1)
#    elbo_perm[perm_structure]['ref'] = np.stack(elbo_perm[perm_structure]['ref'], 1)

    return elbo, elbo_perm, embeddings, perm_structures

# Number of independent permutation passes per reference channel.
n_permutations = 100

# Results are cached under results_dir/elbo_test, so re-running this cell is cheap
# once the .pth files exist.
elbo, elbo_perm, embeddings, perm_structures = ELBO_permutation(enc, dec, dp, recon_loss, modes = [mode], batch_size=32, results_dir=results_dir, n_permutations = n_permutations)

In [None]:
# Class id of every test-set sample, in the same order as the ELBO arrays.
classes = embeddings['test']['target']['class'].numpy()

# Distinct class ids present in the test split, plus each sample's position in that list.
u_classes, class_inds = np.unique(classes, return_inverse=True)

# Flag the control classes by name so they can be listed first.
controls = np.array(['Control - ' in c for c in dp.label_names[u_classes]], dtype=bool)

# Controls first, then the real structures (relative order preserved within each group).
u_classes = np.hstack([u_classes[controls], u_classes[~controls]])

# Look names up by class id in the provider's table. The previous
# `class_names[u_classes]` indexed a position-ordered array by class id, which
# is only correct when the ids present are exactly 0..n-1.
class_names = dp.label_names[u_classes]

In [None]:
# Bootstrap test per class and per permuted channel: how often is the summed
# ELBO of resampled unperturbed cells <= the summed ELBO of resampled permuted cells?
n_perms = 10000

# Per-sample p-values, keyed by permuted channel name.
p_self={}

for perm_structure in perm_structures:

    p_vals = np.zeros(len(u_classes))

    for i, u_class in enumerate(u_classes):
        class_inds = classes == u_class


        class_elbo = elbo['target'][class_inds]
        # Pool every permutation pass of every cell in the class into one null sample.
        perm_elbo = elbo_perm[perm_structure]['target'][class_inds].flatten()

        n_test = len(class_elbo)
        n_null = len(perm_elbo)

        # n_perms bootstrap resamples (with replacement), each of size n_test.
        test_stat = np.random.choice(class_elbo, [n_perms, n_test])

        null_stat = np.random.choice(perm_elbo, [n_perms, n_test])

        p_vals[i] = np.mean(np.sum(test_stat,1) <= np.sum(null_stat,1))

    p_vals = p_vals
    sort_inds = np.argsort(p_vals)

    # NOTE(review): p_save is rebuilt on every iteration and never read in this file.
    p_save = {}

    for i in sort_inds:
        p_save[class_names[i]] = p_vals[i]

    # Per-sample fraction of permutation passes whose ELBO is >= the unperturbed ELBO.
    p_self[perm_structure] = np.mean(np.tile(np.expand_dims(elbo['target'], 1), [1, n_permutations]) <= elbo_perm[perm_structure]['target'],1)

# NOTE(review): after the loop `p_vals` and `sort_inds` hold the values of the
# LAST entry of perm_structures only; any later cell that reads them is not
# affected by the `perm_structure` selector.




In [None]:
# Choose which permuted reference channel the per-sample p-value boxplot
# (which indexes p_self[perm_structure]) should show.
perm_structure = 'perm_cell'
# perm_structure = 'perm_nuc'

In [None]:
# Bar plot of the per-class bootstrap p-values with a multiple-comparison threshold.
# Reads `p_vals`, `sort_inds` and `class_names` as left behind by the bootstrap cell.

# 0.05 divided by the number of classes tested (Bonferroni-style threshold).
alpha = 0.05 / len(u_classes)

x_pos = np.arange(len(sort_inds))

plt.figure(figsize=[12, 6])

# plt.subplot(1, 2, 1)
plt.bar(x_pos, p_vals[sort_inds])
# rotation must be a number: the string '45' is rejected by current matplotlib.
plt.xticks(x_pos, class_names[sort_inds], rotation=45, ha='right')
plt.ylim([0, 0.1])

# Dashed horizontal line at the corrected significance threshold.
xlims = plt.xlim()
plt.plot(xlims, [alpha, alpha], 'gray', linestyle='--')

plt.show()

In [None]:
# Boxplot of the per-sample p-values of each class for the selected permuted
# channel, ordered by the class mean.
data = list()

means = np.zeros(len(u_classes))

for i, u_class in enumerate(u_classes):
    class_inds = classes == u_class

    p_tmp = p_self[perm_structure][class_inds]

    means[i] = np.mean(p_tmp)
    data.append(p_tmp)

sort_inds = np.argsort(means)

# Keep the groups as a list of 1-D arrays. Classes generally have different
# sizes, and np.array() on such ragged input fails on recent NumPy; with equal
# sizes it yields a 2-D array that boxplot would read column-wise.
data_sorted = [data[i] for i in sort_inds]

plt.figure(figsize=[8,8])
plt.boxplot(data_sorted, labels = class_names[sort_inds])


# rotation must be a number: the string '45' is rejected by current matplotlib.
plt.xticks(rotation=45, ha='right')

plt.title('distribution of p-values')

plt.show()

### Computing the Fractional Information:

We want the fractional information of the target given the reference, relative to the information of the target alone, i.e. how much the reference informs where the target is:

FI = I(t|r) / I(t)

where I(x) = -log(P(x)) = -ELBO(x)

FI = -ELBO(t|r) / ( -log( 1/n * sum( exp( ELBO(t|r_shuffle) + ELBO(r_shuffle) ) ) ) )


In [None]:
# Display the permuted-channel names returned by ELBO_permutation.
perm_structures

In [None]:
from scipy.special import logsumexp

# Per-cell "coupling" scores and a per-class scatter grid.
# The score for a channel is 1 - ELBO(unperturbed) / mean over permutations of ELBO(permuted).
dpi = 120

# elbo_sensitivity_x = np.mean(elbo_perm[perm_structures[0]]['target'],1) - elbo['target']
# elbo_sensitivity_y = np.mean(elbo_perm[perm_structures[1]]['target'],1) - elbo['target']

# First permuted channel ('perm_cell') -> x axis.
elbo_t_shuffle = elbo_perm[perm_structures[0]]['target']
# NOTE(review): the 'ref' entry is left as an empty list by ELBO_permutation and is not used below.
elbo_r_shuffle = elbo_perm[perm_structures[0]]['ref']
# x1 = logsumexp(elbo_t_shuffle+elbo_t_shuffle - logsumexp(elbo_t_shuffle),1) - np.log(elbo_t_shuffle.shape[1])

elbo_sensitivity_x = 1-(elbo['target']/np.mean(elbo_t_shuffle,1))
# elbo_sensitivity_x = (elbo['target'] / (logsumexp(elbo_t_shuffle,1) - np.log(elbo_t_shuffle.shape[1])))

# Second permuted channel ('perm_nuc') -> y axis.
elbo_t_shuffle = elbo_perm[perm_structures[1]]['target']
elbo_r_shuffle = elbo_perm[perm_structures[1]]['ref']
# NOTE(review): x2 is computed but never read in this file.
x2 = logsumexp(elbo_t_shuffle+ elbo_t_shuffle - logsumexp(elbo_t_shuffle),1) - np.log(elbo_t_shuffle.shape[1])

elbo_sensitivity_y = 1-(elbo['target']/np.mean(elbo_t_shuffle,1))
# elbo_sensitivity_y = (elbo['target'] / (logsumexp(elbo_t_shuffle,1) - np.log(elbo_t_shuffle.shape[1])))

# Colour points by unperturbed ELBO, clipped to the 2nd-98th percentile range.
colors = elbo['target']
crange = np.percentile(colors, [2, 98])


# Shared axis limits for every panel (also read by later cells).
lims_new = [-0.1, 1.1]

plt.style.use('default')

plt.figure(figsize=[12,12], dpi=120)
plt.set_cmap('jet')

for i, u_class in enumerate(u_classes):
    class_inds = classes == u_class

    x = elbo_sensitivity_x[class_inds]
    y = elbo_sensitivity_y[class_inds]

    c = colors[class_inds]

    # NOTE(review): fixed 5x5 grid, so at most 25 classes can be drawn.
    plt.subplot(5,5,i+1)
    plt.scatter(x, y, s = 10, c=c, vmin = crange[0], vmax = crange[1])
    plt.title(dp.label_names[u_class])

    plt.xlim(lims_new)
    plt.ylim(lims_new)
    # Identity line: points on it have equal scores for both channels.
    plt.plot(lims_new, lims_new, c='gray', linestyle="--")

    if i == 0:
        plt.xlabel('cell coupling')
        plt.ylabel('dna coupling')

plt.subplots_adjust(hspace=0.5, wspace = 0.5)

# NOTE(review): later cells save to this same path and overwrite this figure.
plt.savefig('{}/elbo_ablation.png'.format(results_dir), dpi=dpi, bbox_inches='tight')

plt.show()

In [None]:
# Bootstrap test, per coupling score and per class: do cells flagged as mitotic
# (mito_state_binary_ind == 1) have a different mean score than the class as a whole?
#
# p_vals[s, c, 0] = fraction of resamples with mean(mitotic) < mean(class)
# p_vals[s, c, 1] = fraction of resamples with mean(mitotic) > mean(class)
# where s indexes elbo_sens (0: x score, 1: y score) and c indexes u_classes.
n_perms = 100000

p_self={}

elbo_sens = [elbo_sensitivity_x, elbo_sensitivity_y]

p_vals_all = list()

p_vals = np.zeros([len(elbo_sens), len(u_classes), 2])

for i, elbo_sensitivity in enumerate(elbo_sens):

    for j, u_class in enumerate(u_classes):
        class_inds = classes == u_class

        mito_inds_tmp = dp.data['test']['mito_state_binary_ind'][class_inds]

        elbo_tmp = elbo_sensitivity[class_inds]
        mito_tmp = elbo_tmp[mito_inds_tmp == 1]

        n_test = len(mito_tmp)
        n_null = len(elbo_tmp)

        if n_test == 0:
            # No mitotic cells in this class: the test is undefined. Without
            # this guard the empty resamples give NaN means, both comparisons
            # are False everywhere, and the class is reported with p = 0
            # (i.e. "significant") in both directions. NaN compares False
            # against any threshold downstream.
            p_vals[i, j, :] = np.nan
            continue

        # n_perms resamples (with replacement) of n_test cells from each group.
        test_stat = np.random.choice(mito_tmp, [n_perms, n_test])
        null_stat = np.random.choice(elbo_tmp, [n_perms, n_test])

        test_means = np.mean(test_stat, 1)
        null_means = np.mean(null_stat, 1)

        p_vals[i, j, 0] = np.mean(test_means < null_means)
        p_vals[i, j, 1] = np.mean(test_means > null_means)

    # Note: this appends a reference to the same array on every pass.
    p_vals_all.append(p_vals)

# Multiple-comparison correction: scale by the total number of tests performed.
p_vals = p_vals * np.size(p_vals)
#     p_vals = p_vals
#     sort_inds = np.argsort(p_vals)

#     p_save = {}

#     for i in sort_inds:
#         p_save[class_names[i]] = p_vals[i]

#     p_self[perm_structure] = np.mean(np.tile(np.expand_dims(elbo['target'], 1), [1, n_permutations]) <= elbo_perm[perm_structure]['target'],1)


In [None]:
from matplotlib import cm
from matplotlib.lines import Line2D

# Per-class scatter of the two coupling scores, coloured by the binary mitosis
# flag, with arrow markers where the corrected bootstrap p-value is below p_val.
# Also collects the per-class data into elbo_sensitivity_dict for later cells.
plt.style.use('default')

plt.figure(figsize=[12,12], dpi=120)
plt.set_cmap('Paired_r')

# Rebinds `colors`/`crange` (previously ELBO-based) to the mitosis flag and its min/max.
colors = dp.data['test']['mito_state_binary_ind']
crange = np.percentile(colors, [0, 100])

# Where to draw the significance arrows inside each panel.
p_pos_fraction = 0.1
p_pos = lims_new[0] + (lims_new[1] - lims_new[0]) * p_pos_fraction
# NOTE(review): unlike p_pos, this offset is not scaled by the axis range.
p_pos_end = lims_new[1] - p_pos_fraction

p_val = 0.05

elbo_sensitivity_dict = {}

for i, u_class in enumerate(u_classes):
    class_inds = classes == u_class

    x = elbo_sensitivity_x[class_inds]
    y = elbo_sensitivity_y[class_inds]

    c = colors[class_inds]


    # NOTE(review): fixed 5x5 grid, so at most 25 classes can be drawn.
    plt.subplot(5,5,i+1)
    # Draw flag == 0 first so flag == 1 points end up on top.
    plt.scatter(x[c==0], y[c==0], s = 10, c=c[c==0], vmin = crange[0], vmax = crange[1])
    plt.scatter(x[c==1], y[c==1], s = 10, c=c[c==1], vmin = crange[0], vmax = crange[1])

    # Break long titles before the parenthesised part.
    title_name = dp.label_names[u_class]
    if '(' in title_name:
        ind = title_name.find('(')
        title_name = title_name[0:ind] + '\n' + title_name[ind:]

    plt.title(title_name, fontsize=8)

    plt.xlim(lims_new)
    plt.ylim(lims_new)
    plt.plot(lims_new, lims_new, c='gray', linestyle="--")

    # NOTE(review): p_vals[., ., 0] is the fraction of resamples where the
    # mitotic mean is BELOW the class mean, so a small value suggests a shift
    # towards larger scores; the direction words in the comments below look
    # inverted relative to the marker shapes — confirm.
    if p_vals[0, i, 0] < p_val: #elbo went left
        plt.scatter(p_pos_end, p_pos, marker = ">", color = 'k')
    if p_vals[0, i, 1] < p_val: #elbo went right
        plt.scatter(p_pos_end, p_pos, marker = "<", color = 'k')
    if p_vals[1, i, 0] < p_val: #elbo went up
        plt.scatter(p_pos, p_pos_end, marker = "^", color = 'k')
    if p_vals[1, i, 1] < p_val: #elbo went down
        plt.scatter(p_pos, p_pos_end, marker = "v", color = 'k')

    if i == 0:
        plt.xlabel('cell coupling', fontsize=8)
        plt.ylabel('dna coupling', fontsize=8)

    # Stash everything the summary bar chart needs, keyed by class id.
    elbo_sensitivity_dict[u_class] = {}
    elbo_sensitivity_dict[u_class]['x'] = x
    elbo_sensitivity_dict[u_class]['y'] = y
    elbo_sensitivity_dict[u_class]['mito'] = c
    elbo_sensitivity_dict[u_class]['pval'] = p_vals[:, i, :]
    elbo_sensitivity_dict[u_class]['name'] = dp.label_names[u_class]

# Rebinds `colors` once more, now to two entries of the Paired colormap used
# for the legend (and reused by the summary bar chart).
colors = cm.Paired([255, 0])
legend_elements = [Line2D([0], [0], marker='o', color = 'w', markerfacecolor=colors[0], markersize=10, label=dp.mito_state_names[0]),
                   Line2D([0], [0], marker='o', color = 'w', markerfacecolor=colors[1], markersize=10, label=dp.mito_state_names[1])]

plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1.0), frameon=False)

plt.subplots_adjust(hspace=0.5, wspace = 0.5)

# NOTE(review): same path as the previous scatter figure, which is overwritten.
plt.savefig('{}/elbo_ablation.png'.format(results_dir), dpi=dpi, bbox_inches='tight')

plt.show()

In [None]:
# Settings for the summary bar chart.
box_width = 0.75
ylim = 2.5


# Which scalar summary of (x, y) to plot; selects proj_func below.
# proj_meth = ["scalar_proj", 'relative_diff']
# proj_meth = 'diff'
# proj_meth = 'relative_diff'
proj_meth = 'scalar_proj'

def scalar_rejection(x,y):
    """Project the points (y, x) onto the anti-diagonal through (0.5, 0.5).

    Each point is shifted so that (0.5, 0.5) is the origin and then expressed
    as a multiple of the direction [0, 1] - [0.5, 0.5] = [-0.5, 0.5]; that
    coefficient is returned. For this fixed direction the result works out to
    x - y.

    Args:
        x, y: equal-length 1-D arrays of scores.

    Returns:
        1-D array with one projection coefficient per input pair.
    """
    centre = np.array([0.5, 0.5])
    direction = np.array([0, 1]) - centre

    # One row per point, columns ordered (y, x), expressed relative to the centre.
    shifted_points = np.vstack((y, x)).T - centre

    squared_length = np.linalg.norm(direction) ** 2
    return shifted_points.dot(direction) / squared_length


# Summary bar chart: per class, the median of a scalar summary of the two
# coupling scores, drawn separately for mito flag == 0 and flag == 1 cells.
# Reads elbo_sensitivity_dict and `colors` as left behind by the mito scatter cell.
if proj_meth == "scalar_proj":
    proj_func = scalar_rejection
elif proj_meth == 'relative_diff':
    proj_func = lambda x,y: (x-y)/(x+y)
elif proj_meth == 'diff':
    proj_func = lambda x,y: (x-y)

for u_class in elbo_sensitivity_dict:
    elbo_sensitivity_dict[u_class]['proj'] = proj_func(elbo_sensitivity_dict[u_class]['x'], elbo_sensitivity_dict[u_class]['y'])

# Classes ordered by descending median summary value.
sorted_inds = np.argsort([np.median(elbo_sensitivity_dict[u_class]['proj']) for u_class in u_classes])[::-1]



stats = list()

stats_median = list()
stats_median_mito = list()


for i, u_class in enumerate(u_classes[sorted_inds]):

    x = elbo_sensitivity_dict[u_class]['x']
    y = elbo_sensitivity_dict[u_class]['y']
    is_mito = elbo_sensitivity_dict[u_class]['mito']
    p_val = elbo_sensitivity_dict[u_class]['pval']


    stats.append(elbo_sensitivity_dict[u_class]['proj'])

    # NOTE(review): the median of an empty selection is NaN (e.g. a class with no flag == 1 cells).
    stats_median.append(np.median(elbo_sensitivity_dict[u_class]['proj'][is_mito==0]))
    stats_median_mito.append(np.median(elbo_sensitivity_dict[u_class]['proj'][is_mito==1]))

labels = [elbo_sensitivity_dict[u_class]['name'] for u_class in u_classes[sorted_inds]]

# Shorten the two longest class names for the tick labels.
for i in range(len(labels)):
    if labels[i] == "Nucleolus (Dense Fibrillar Component)":
        labels[i] = "Nucleolus (DFC)"


    if labels[i] == "Nucleolus (Granular Component)":
        labels[i] = "Nucleolus (GC)"


plt.figure(figsize=[2,5], dpi=120)
# boxes = plt.boxplot(stats, positions = np.arange(0, len(u_classes)), showfliers=False, widths=box_width, vert=False, whis=False, labels=labels, patch_artist=True)

# Wide bars: flag == 0 medians; narrow bars drawn over them: flag == 1 medians.
bars = plt.barh(labels, stats_median, color=colors[0], height=box_width)
bars = plt.barh(np.arange(0, len(u_classes)), stats_median_mito, color=colors[1], height=box_width/2)
# plt.yticks(np.arange(0, len(u_classes)), label = labels)

# for item in ['boxes', 'whiskers', 'fliers', 'caps']:
#     for element in boxes[item]:
#         plt.setp(element, color='k')
# #         patch.set_facecolor('k')

# for item in ['medians']:
#     for element in boxes[item]:
#         plt.setp(element, color=[1,1,1,1])


# plt.tick_params(top=False, bottom=False, left=True, right=True, labelleft=True, labelbottom=False)

# Vertical reference line at zero.
plt.plot([0,0], [-0.5, len(u_classes)-0.5], linestyle="-", color = 'gray', linewidth=1)

#


plt.ylim(-0.5, len(u_classes)- 0.5)
plt.xlim(-1.1, 1.1)
# plt.ylabel("(x-y)/(x+y)")

plt.gca().set_xticks([-1,1])
plt.gca().set_xticklabels(["DNA\ncoupling", "cell\ncoupling"])

# plt.title("$1/n \sum(cell\ coupling - dna\ coupling)$", fontsize=8)

plt.show()
plt.close()
    
    


In [None]:
# Display the (shortened) class labels in the order used by the bar chart.
labels

In [None]:
# Scratch check of the projection formula used in scalar_rejection: the centre
# point (0.5, 0.5) itself should give 0.
origin = np.array([0.5, 0.5])

vector = np.array([0, 1])

b = vector - origin
a = np.array([0.5, 0.5]) - origin

np.dot(a,b)/(np.linalg.norm(b)**2)

In [None]:
# Same scatter as the full per-class grid, restricted to six hand-picked classes (2x3 grid).
structures_to_plot = ['Control - Blank', 'Control - DNA', 'Control - Memb', 'Mitochondria', 'Nuclear envelope', 'Tight junctions']

plt.style.use('default')

plt.figure(figsize=[7.2,4.8], dpi=120)
plt.set_cmap('Paired_r')

# Colour by the binary mitosis flag; crange is its min/max.
colors = dp.data['test']['mito_state_binary_ind']
crange = np.percentile(colors, [0, 100])

# Where to draw the significance arrows inside each panel.
p_pos_fraction = 0.1
p_pos = lims_new[0] + (lims_new[1] - lims_new[0]) * p_pos_fraction
p_pos_end = lims_new[1] - p_pos_fraction

p_val = 0.05


for i, u_class_name in enumerate(structures_to_plot):

    # Class id for this name; raises IndexError if the name is not in dp.label_names.
    u_class = np.where(dp.label_names == u_class_name)[0][0]

    class_inds = classes == u_class

    x = elbo_sensitivity_x[class_inds]
    y = elbo_sensitivity_y[class_inds]

    c = colors[class_inds]

    plt.subplot(2,3,i+1)
    plt.scatter(x[c==0], y[c==0], s = 10, c=c[c==0], vmin = crange[0], vmax = crange[1])
    plt.scatter(x[c==1], y[c==1], s = 10, c=c[c==1], vmin = crange[0], vmax = crange[1])

    title_name = dp.label_names[u_class]
    if '(' in title_name:
        ind = title_name.find('(')
        title_name = title_name[0:ind] + '\n' + title_name[ind:]

    plt.title(title_name)

    plt.xlim(lims_new)
    plt.ylim(lims_new)
    plt.plot(lims_new, lims_new, c='gray', linestyle="--")

    # p_vals is indexed by position in u_classes, not by class id.
    p_ind = np.where(u_classes == u_class)[0][0]

    if p_vals[0, p_ind, 0] < p_val: #elbo went left
        plt.scatter(p_pos_end, p_pos, marker = ">", color = 'k')
    if p_vals[0, p_ind, 1] < p_val: #elbo went right
        plt.scatter(p_pos_end, p_pos, marker = "<", color = 'k')
    if p_vals[1, p_ind, 0] < p_val: #elbo went up
        plt.scatter(p_pos, p_pos_end, marker = "^", color = 'k')
    if p_vals[1, p_ind, 1] < p_val: #elbo went down
        plt.scatter(p_pos, p_pos_end, marker = "v", color = 'k')

    # Axis labels only on the first panel of the bottom row.
    if i == 3:
        plt.xlabel('cell coupling')
        plt.ylabel('DNA coupling')

    plt.xticks([0, 0.5, 1])
    plt.yticks([0, 0.5, 1])

colors = cm.Paired([255, 0])
legend_elements = [Line2D([0], [0], marker='o', color = 'w', markerfacecolor=colors[0], markersize=10, label=dp.mito_state_names[0]),
                   Line2D([0], [0], marker='o', color = 'w', markerfacecolor=colors[1], markersize=10, label=dp.mito_state_names[1])]

plt.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1.0), frameon=False)

plt.subplots_adjust(hspace=0.35, wspace = 0.35)

# NOTE(review): same path as the earlier scatter figures, which are overwritten.
plt.savefig('{}/elbo_ablation.png'.format(results_dir), dpi=dpi, bbox_inches='tight')

plt.show()

In [None]:
def imfunc(im_id, train_or_test='test'):
    """Render one sample as an image array with an extra transparency channel.

    Args:
        im_id: index of the sample within the chosen split.
        train_or_test: data split handed to `dp.get_sample` (default 'test').

    Returns:
        The array from `plots.tensor2im(..., proj_xy=False)` with one more
        channel appended along axis 2; that channel is non-zero exactly where
        the sum over the existing channels is positive.
    """
    struct_img, _, ref_img = dp.get_sample(train_or_test, [im_id])

    rendered = plots.tensor2im(sample2im(struct_img, ref_img), proj_xy=False)

    # Pixels with no signal in any channel become fully transparent.
    opacity = np.sum(rendered, 2) > 0

    return np.concatenate([rendered, np.expand_dims(opacity, 2)], 2)



# Quick preview. NOTE(review): `i` is whatever value an earlier cell's loop left behind.
plt.imshow(imfunc(i))

In [None]:
from matplotlib import cm
from integrated_cell.utils.plots import scatter_im
import integrated_cell.utils.plots as plots


# plt.figure(figsize=[12,12], dpi=120)
# One large figure per class: every cell is drawn as an image thumbnail at its
# (x, y) coupling score, with an inset scatter coloured by unperturbed ELBO.
# Reads `crange` as left behind by an earlier plotting cell.
plt.style.use('dark_background')
# lims_new = np.percentile(elbo['target'], [0.02, 99.8])

for u_class in tqdm.tqdm(u_classes):
    # Integer positions (not a boolean mask) so thumbnails can be looked up by sample index.
    class_inds = np.where(classes == u_class)[0]

    elbo_tmp = elbo['target'][class_inds]
#     elbo_perm_tmp = np.median(elbo_perm[perm_structure]['target'][class_inds],1)

    x = elbo_sensitivity_x[class_inds]
    y = elbo_sensitivity_y[class_inds]

    c = elbo['target'][class_inds]

    X = np.vstack([x, y]).T

    # Maps a row of X back to the image of the corresponding test sample.
    def myfunc(ind):
        return imfunc(class_inds[ind])


    plt.figure(figsize=(30, 30))

    ax, ax_inset = scatter_im(X, myfunc, zoom = 1, inset = True, inset_colors = c, inset_width_and_height=0.17, inset_scatter_size = 50, inset_clims = crange)

    plt.sca(ax)
    lims = np.vstack([plt.xlim(), plt.ylim()])
#     lims_new = [np.min(lims[:,0]), np.max(lims[:,1])]
    lims_new = [-0.1, 1.1]

    # Identity line on the main axes.
    plt.plot(lims_new, lims_new, linewidth=4, c='gray', linestyle="--")

#     ax_inset.set_xlim(lims_new[0], lims_new[1])
#     ax_inset.set_ylim(lims_new[0], lims_new[1])
    plt.sca(ax_inset)
    plt.ylabel('DNA coupling', fontsize = 15)
    plt.xlabel('Membrane coupling', fontsize = 15)
    plt.title(dp.label_names[u_class], fontsize = 15)

#     lims = np.vstack([plt.xlim(), plt.ylim()])
#     lims_new = [np.min(lims[:,0]), np.max(lims[:,1])]
    plt.xlim(lims_new)
    plt.ylim(lims_new)
    plt.plot(lims_new, lims_new, c='gray', linestyle="--")

#     plt.show()

    plt.savefig('{}/elbo_ablation_{}.png'.format(results_dir, dp.label_names[u_class]), dpi=dpi, bbox_inches='tight')

    plt.close()

plt.style.use('default')

# Message typo fixed ("fined" -> "find"), matching the sibling cell.
print('please find images in {}'.format(results_dir))

In [None]:
from matplotlib import cm
from integrated_cell.utils.plots import scatter_im

# plt.figure(figsize=[12,12], dpi=120)
# Same per-class thumbnail scatter as the previous cell, but without the inset
# axes; saved under a different file name prefix.
plt.style.use('dark_background')
# lims_new = np.percentile(elbo['target'], [0.02, 99.8])

for u_class in tqdm.tqdm(u_classes):
    # Integer positions so thumbnails can be looked up by sample index.
    class_inds = np.where(classes == u_class)[0]

    elbo_tmp = elbo['target'][class_inds]
#     elbo_perm_tmp = np.median(elbo_perm[perm_structure]['target'][class_inds],1)

    x = elbo_sensitivity_x[class_inds]
    y = elbo_sensitivity_y[class_inds]

    c = elbo['target'][class_inds]

    X = np.vstack([x, y]).T

    # Maps a row of X back to the image of the corresponding test sample.
    def myfunc(ind):
        return imfunc(class_inds[ind])


    plt.figure(figsize=(30, 30))

    # With inset = False a single axes object is returned.
    ax = scatter_im(X, myfunc, zoom = 1, inset = False)

    plt.sca(ax)
#     lims = np.vstack([plt.xlim(), plt.ylim()])
#     lims_new = [np.min(lims[:,0]), np.max(lims[:,1])]

    lims_new = [-0.1, 1.1]

    # Identity line.
    plt.plot(lims_new, lims_new, linewidth=4, c='gray', linestyle="--")

    ax.set_xlim(lims_new[0], lims_new[1])
    ax.set_ylim(lims_new[0], lims_new[1])

    plt.savefig('{}/elbo_ablation_no_subplot_{}.png'.format(results_dir, dp.label_names[u_class]), dpi=dpi, bbox_inches='tight')

    plt.close()

plt.style.use('default')

print('please find images in {}'.format(results_dir))

# Now do the drugs

In [None]:
# `imp` is deprecated and was removed in Python 3.12; importlib.reload is the
# supported equivalent.
import importlib

# Pick up any edits made to the utils modules since the notebook started.
importlib.reload(integrated_cell.utils.utils)
importlib.reload(utils)
# importlib.reload(utils.load_drug_data_provider)

# Data provider for the drug-perturbation dataset, built from the main
# provider and the run configuration.
dp_drugs = utils.load_drug_data_provider(dp, args)

# Cached embeddings for the drug dataset are kept apart from the main results.
results_dir_drugs = "{}/drugs/".format(results_dir)

# Show the first image of a default test batch as a sanity check.
imshow(dp_drugs.get_sample('test')[0][[0]])



In [None]:
# Debug output: provider type and basic intensity statistics of a test batch.
print(type(dp_drugs))

# NOTE(review): the batch is drawn from `dp`, not `dp_drugs` — confirm this is intended.
struct, label, ref = dp.get_sample('test')

# `struct` is summarised over the whole batch; the two reference channels are
# summarised for the first sample only.
for im in [struct, ref[0,0], ref[0,1]]:
    im_sum = torch.sum(im)
    im_min = torch.min(im)
    im_max = torch.max(im)

    print("sum: {}, min: {}, max: {}".format(im_sum, im_min, im_max))




# print(ref.shape)

In [None]:
# Same permutation analysis on the drug dataset; caches live under results_dir_drugs.
elbo_drugs, elbo_perm_drugs, embeddings_drugs, perm_structures_drugs = ELBO_permutation(enc, dec, dp_drugs, recon_loss, modes = [mode], batch_size=32, results_dir=results_dir_drugs, n_permutations = n_permutations)

In [None]:
# Coupling scores for the drug dataset, computed exactly like the non-drug
# scores: 1 - ELBO(unperturbed) / mean over permutations of ELBO(permuted).
# elbo_sensitivity_x = np.mean(elbo_perm[perm_structures[0]]['target'],1) - elbo['target']
# elbo_sensitivity_y = np.mean(elbo_perm[perm_structures[1]]['target'],1) - elbo['target']

# First permuted channel -> x score.
elbo_t_shuffle = elbo_perm_drugs[perm_structures_drugs[0]]['target']
# NOTE(review): the 'ref' entry is an empty list and is not used below.
elbo_r_shuffle = elbo_perm_drugs[perm_structures_drugs[0]]['ref']
# x1 = logsumexp(elbo_t_shuffle+elbo_t_shuffle - logsumexp(elbo_t_shuffle),1) - np.log(elbo_t_shuffle.shape[1])

elbo_sensitivity_x_drugs = 1-(elbo_drugs['target']/np.mean(elbo_t_shuffle,1))
# elbo_sensitivity_x = (elbo['target'] / (logsumexp(elbo_t_shuffle,1) - np.log(elbo_t_shuffle.shape[1])))

# Second permuted channel -> y score.
elbo_t_shuffle = elbo_perm_drugs[perm_structures_drugs[1]]['target']
elbo_r_shuffle = elbo_perm_drugs[perm_structures_drugs[1]]['ref']
# x2 = logsumexp(elbo_t_shuffle + elbo_t_shuffle - logsumexp(elbo_t_shuffle),1) - np.log(elbo_t_shuffle.shape[1])

elbo_sensitivity_y_drugs = 1-(elbo_drugs['target']/np.mean(elbo_t_shuffle,1))
# elbo_sensitivity_y = (elbo['target'] / (logsumexp(elbo_t_shuffle,1) - np.log(elbo_t_shuffle.shape[1])))

# colors = elbo['target']
# crange = np.percentile(colors, [2, 98])

In [None]:
# Scratch cell. NOTE: this rebinds the notebook-global `classes` to the DRUG
# dataset's labels, so earlier non-drug cells must not be re-run after this
# point without recomputing `classes`.
classes = embeddings_drugs['test']['target']['class'].numpy()
u_structures =  np.unique(embeddings_drugs['test']['target']['class'])

# Drug id of every test sample and the distinct ids present.
drugs = dp_drugs.drug_info['test']
u_drugs = np.unique(dp_drugs.drug_info['test'])

# NOTE(review): u_classes still comes from the non-drug dataset here — confirm
# that comparing it against the drug-dataset labels is intended.
class_inds = classes == u_classes[0]
drug_inds = drugs == u_drugs[0]

In [None]:
# Display the drug data provider object.
dp_drugs

In [None]:
# Scatter grid for the drug dataset: one row per drug, one column per structure,
# each panel showing the two coupling scores of the matching cells.
# Reads `crange` as left behind by an earlier plotting cell.

# NOTE: rebinds the notebook-global `classes` to the drug dataset's labels.
classes = embeddings_drugs['test']['target']['class'].numpy()
u_structures = np.unique(embeddings_drugs['test']['target']['class'])

drugs = dp_drugs.drug_info['test']
u_drugs = np.unique(dp_drugs.drug_info['test'])

n_drug_structures = len(u_structures)
n_drugs = len(u_drugs)

lims_new = [-0.1, 1.1]

plt.style.use('default')

plt.figure(figsize=[12,12], dpi=120)
plt.set_cmap('jet')

# Running panel index (row-major); incremented even for empty panels so that
# the grid positions stay aligned with (drug, structure).
counter = -1
for i, u_drug in enumerate(u_drugs):

    drug_inds = drugs == u_drug

    for j, u_structure in enumerate(u_structures):

        counter += 1

        class_inds = (classes == u_structure) & drug_inds

        # Leave the panel empty when this drug/structure combination has no cells.
        if np.sum(class_inds) == 0:
            continue

        x = elbo_sensitivity_x_drugs[class_inds]
        y = elbo_sensitivity_y_drugs[class_inds]

#         c = colors[class_inds]

        # Constant colour for every point.
        c = np.ones(int(np.sum(class_inds)))

        plt.subplot(n_drugs,n_drug_structures,counter + 1)
        plt.scatter(x, y, s = 10, c=c, vmin = crange[0], vmax = crange[1])
        # NOTE(review): names come from `dp`, not `dp_drugs` — presumably both
        # providers share the same label table; confirm.
        plt.title(dp.label_names[u_structure])



        plt.xlim(lims_new)
        plt.ylim(lims_new)
        plt.plot(lims_new, lims_new, c='gray', linestyle="--")

        if counter == 0:
            plt.xlabel('cell coupling')
            plt.ylabel('dna coupling')

        # First column: label the row with the drug name (this replaces the
        # 'dna coupling' y-label set just above for the very first panel).
        if j == 0:
            plt.ylabel(dp_drugs.drug_names[u_drug])

        plt.axis('equal')

plt.subplots_adjust(hspace=0.5, wspace = 0.5)

# Save next to the other drug results. Writing to results_dir (as before)
# overwrote the non-drug elbo_ablation.png figure.
os.makedirs(results_dir_drugs, exist_ok=True)
plt.savefig('{}/elbo_ablation.png'.format(results_dir_drugs), dpi=dpi, bbox_inches='tight')

plt.show()

plt.close()