# Latent space visualization


## imports

In [None]:
import json
import os
from pathlib import Path
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch

import integrated_cell
from integrated_cell import model_utils, utils
from integrated_cell.utils.plots import tensor2im, imshow

%matplotlib inline
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

import PIL
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

## set up a place to save our results

In [None]:
# Location of the trained reference model and the checkpoint suffix to load.
parent_dir = "/allen/aics/modeling/rorydm/projects/pytorch_integrated_cell/examples/training_scripts"
ref_model_dir = f"{parent_dir}/bvae3D_actk_ref_seg_nomito_beta_1_2021-02-02"
ref_suffix = "_64880"

# Directory for this notebook's saved results (created if it does not exist yet).
results_dir = Path(parent_dir).joinpath("results/latent_space_vizualization_actk")
results_dir.mkdir(exist_ok=True, parents=True)

## set up cuda env

In [None]:
gpu_ids = [0]

# Restrict which GPUs CUDA exposes to this process (only takes effect if CUDA
# has not been initialised yet in this process).
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))

# Enable cuDNN with benchmark (autotuning) mode only for single-GPU runs.
single_gpu = len(gpu_ids) == 1
if single_gpu:
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

torch.cuda.empty_cache()

## load the reference model

In [None]:
# Load the reference model's networks and training args (without its data provider).
networks_ref, _, args_ref = utils.load_network_from_dir(
    ref_model_dir, parent_dir, suffix=ref_suffix, load_dataprovider=False
)

# Encoder / decoder pair of the reference model.
ref_enc, ref_dec = networks_ref['enc'], networks_ref['dec']

# Reconstruction criterion configured in the training args.
losses_ref = utils.load_losses(args_ref)
recon_loss = losses_ref['crit_recon']

## load the data provider

In [None]:
# Data provider kwargs: start from the training-time kwargs, then point at the
# image directory and batch size stored in the reference args.
data_provider_kwargs = {
    **args_ref["kwargs_dp"],
    "image_parent": args_ref["imdir"],
    "batch_size": args_ref["batch_size"],
}

# Remove keys that should not be passed to the DataProvider constructor here.
for _unwanted_key in ('save_path', 'im_dir'):
    data_provider_kwargs.pop(_unwanted_key, None)

In [None]:
import importlib

# Import the data provider module named in the reference model's training args.
dp_name = args_ref["dataProvider"]
dp_module_path = f"integrated_cell.data_providers.{dp_name}"
dp_module = importlib.import_module(dp_module_path)

In [None]:
# Instantiate the DataProvider class from the dynamically imported module,
# using the kwargs assembled from the reference model's training args.
dp_ref = dp_module.DataProvider(**data_provider_kwargs)

## Get metadata for all splits

In [None]:
_split_meta_cols = ["CellId", "PlateId", "CellIndex", "StructureDisplayName"]


def _split_metadata(split_name):
    """Return metadata rows for one split.

    The original csv row label is kept as ``UnsplitCsvIndex``, and a constant
    ``split`` column is appended.
    """
    split_inds = dp_ref.data[split_name]["inds"]
    # NOTE(review): `.loc` selects by row *label*, while `dfs_mito` uses `.iloc`
    # (row *position*); these agree only if csv_data has a default RangeIndex.
    meta = (
        dp_ref.csv_data.loc[split_inds, _split_meta_cols]
        .reset_index()
        .rename(columns={"index":"UnsplitCsvIndex"})
    )
    split_col = pd.DataFrame({"split":[split_name]*len(split_inds)})
    return pd.concat([meta, split_col], axis=1)


dfs_split = {split_name: _split_metadata(split_name) for split_name in dp_ref.data}

## Get mito annotations

In [None]:
# Mitotic-state annotations (draft + human); first csv column is the index.
mito_annotations_csv = "/allen/aics/modeling/rorydm/projects/IntegratedCellWorkingGroup/data/draft_plus_human_mito_annotations.csv"
df_mito = pd.read_csv(mito_annotations_csv, index_col=0)

In [None]:
# Per split: the split's csv rows (by position), left-joined with the mito
# annotations on their shared columns, so unannotated cells are kept with NaNs.
dfs_mito = {}
for _split_name, _split_info in dp_ref.data.items():
    _split_rows = dp_ref.csv_data.iloc[_split_info["inds"]].reset_index(drop=True)
    dfs_mito[_split_name] = _split_rows.merge(df_mito, how="left")

## Get features for all splits together (unsplit df)

In [None]:
# Per-cell feature JSONs live in a sibling directory of the image directory.
feats_path = Path(args_ref["imdir"]).parent.joinpath("singlecellfeatures", "cell_features")

In [None]:
# Load the per-cell feature JSON files into a single table, one row per file.
# Each file is assumed to hold one flat dict of feature values (TODO confirm).
# Records are collected first and the DataFrame is built once: the previous
# per-row DataFrame.append was quadratic and was removed in pandas 2.0.
# Files are sorted so the row order is deterministic across runs.
json_files = sorted(feats_path.glob('*.json'))

feature_records = []
for json_file in tqdm(json_files):
    with open(json_file) as f:
        feature_records.append(json.load(f))

df_feats = pd.DataFrame(feature_records)

In [None]:
# `df_feats_all` is read by later cells (merge with embeddings, feature-column
# selection, PCA) but was never assigned anywhere in this notebook, so this cell
# raised a NameError. Alias it to the full per-cell feature table.
# NOTE(review): if a separate, wider feature table was intended, load it here.
df_feats_all = df_feats

# Object heights along z (pixel units; converted to μm later in the notebook).
df_feats["cell_height"] = df_feats["cell_position_highest_z"] - df_feats["cell_position_lowest_z"]
df_feats["dna_height"] = df_feats["dna_position_highest_z"] - df_feats["dna_position_lowest_z"]

## Find the embeddings for each split

In [None]:
from integrated_cell.metrics.embeddings_reference import get_latent_embeddings

# Embed the test split with the reference encoder/decoder; the result is indexed
# below as embeds_test[split]["ref"]["mu"].
embeds_test = get_latent_embeddings(
    ref_enc, ref_dec, dp_ref, recon_loss,
    modes=['test'],
    batch_size=32,
)

In [None]:
# One DataFrame of latent means per split, with columns mu_0 ... mu_{d-1}.
# The column count now comes from the split being built, not a hard-coded "test".
dfs_embeds = {
    split: pd.DataFrame(
        embeds_test[split]["ref"]["mu"].numpy(),
        columns=[f"mu_{i}" for i in range(embeds_test[split]["ref"]["mu"].numpy().shape[1])]
    ) for split in ['test']
}

In [None]:
# Attach each split's UnsplitCsvIndex (row label in the unsplit csv) to its
# embeddings. The assignment aligns on the shared 0..n-1 row index.
for split_name, embed_df in dfs_embeds.items():
    split_meta = dfs_split[split_name]
    assert len(embed_df) == len(split_meta)
    embed_df["UnsplitCsvIndex"] = split_meta["UnsplitCsvIndex"]

In [None]:
# Save the test-split latent means (plus UnsplitCsvIndex) into the results directory.
test_embeddings_csv = Path(results_dir) / "test_embeddings_nomito.csv"
dfs_embeds["test"].to_csv(test_embeddings_csv)

In [None]:
# Standalone test-split embedding table (same content as dfs_embeds["test"]).
# TODO: remove once all splits can be embedded.
test_mu = embeds_test["test"]["ref"]["mu"].numpy()
df_embeds_test = pd.DataFrame(
    test_mu,
    columns=[f"mu_{i}" for i in range(test_mu.shape[1])]
)

test_meta = dfs_split["test"]
assert len(df_embeds_test) == len(test_meta)
df_embeds_test["UnsplitCsvIndex"] = test_meta["UnsplitCsvIndex"]

## Merge embeddings into metadata

In [None]:
# Inner-join each split's metadata with the per-cell features on their shared columns.
dfs_split = {name: meta.merge(df_feats) for name, meta in dfs_split.items()}

In [None]:
# Test-split metadata joined with mito annotations; shared prefix of both tables.
_test_meta_with_mito = dfs_split["test"].merge(dfs_mito["test"])

# Metadata + mito + features + latent means (inner joins on shared columns).
df_embeddings_plus_meta_test = (
    _test_meta_with_mito
    .merge(df_feats)
    .merge(df_embeds_test)
)

# Same, but joined with the df_feats_all feature table.
df_embeddings_plus_meta_test_all_feats = (
    _test_meta_with_mito
    .merge(df_feats_all)
    .merge(df_embeds_test)
)

# df_embeddings_plus_meta_test.to_csv("embeddings_plus_meta_test.csv")

## find top latent dims and drop others for grid plots

In [None]:
mu_cols = [col for col in df_embeds_test.columns if "mu" in col]

# Latent dims ordered from largest to smallest mean |mu| across test cells.
_mean_abs_mu = df_embeds_test[mu_cols].abs().mean()
top_mu_dims = list(_mean_abs_mu.sort_values().index)[::-1]

In [None]:
# Mean |mu| of each latent dim over the test set, plotted against its rank.
_mean_abs_mu = df_embeds_test[mu_cols].abs().mean()

plt.figure(figsize=(8,4))
g = sns.lineplot(x=_mean_abs_mu.rank(ascending=False), y=_mean_abs_mu)
g.set(xlabel="Reference latent space dimension rank", ylabel="Mean absolute value")
sns.despine()

In [None]:
# Pairwise scatter of the 10 highest-mean-|mu| dims on a random 1% of test cells.
_pair_sample = df_embeds_test[top_mu_dims[:10]].sample(frac=0.01, replace=False)
g = sns.pairplot(_pair_sample)
g.set(xlim=(-5, 5), ylim=(-5, 5));

In [None]:
# Same as above, but for 10 randomly chosen latent dims (for comparison).
_random_dims = df_embeds_test[mu_cols].sample(n=10, axis='columns')
_pair_sample = _random_dims.sample(frac=0.01, replace=False)
g = sns.pairplot(_pair_sample)
g.set(xlim=(-5, 5), ylim=(-5, 5));

In [None]:
# For each of the top-10 dims, decode a 9-step walk from -2 to +2 standard
# deviations along that dim (all other dims held at 0) and display the images.
# The latent width is taken from the encoder's mu rather than a hard-coded 512.
# This assumes the decoder's input width equals the mu width (standard for a VAE).
# Inputs go to the decoder's own device instead of a hard `.cuda()`.
_n_latent = embeds_test["test"]["ref"]["mu"].shape[1]
_decoder_device = next(ref_dec.parameters()).device

for dim in top_mu_dims[:10]:
    print(dim)
    int_dim = int(dim.split("_")[-1])

    latent_walk_path = np.float32(df_embeds_test[dim].std()*np.linspace(-2,2,9))
    latent_walk_path = torch.from_numpy(latent_walk_path)

    latent_input = torch.zeros([len(latent_walk_path), _n_latent], dtype=torch.float32)
    latent_input[:, int_dim] = latent_walk_path
    latent_input = latent_input.to(_decoder_device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        walk_output = ref_dec(latent_input)
    walk_im = tensor2im(walk_output.cpu())

    pil_im = Image.fromarray(np.uint8(walk_im*255))

    display(pil_im)
    print("")

In [None]:
# Spearman correlations among latent means and per-cell features.
_feature_names = set(df_feats.drop(columns=["CellId"]).columns)
_corr_cols = [
    col for col in df_embeddings_plus_meta_test.columns
    if "mu" in col or col in _feature_names
]
feature_corrs = df_embeddings_plus_meta_test[_corr_cols].corr(method="spearman")

In [None]:
# Re-rank latent dims (largest mean |mu| first) over the merged test table.
_mu_cols_merged = [col for col in df_embeddings_plus_meta_test.columns if "mu" in col]
_ascending = df_embeddings_plus_meta_test[_mu_cols_merged].abs().mean().sort_values()
top_mu_dims = list(_ascending.index)[::-1]

In [None]:
# Keep the metadata plus only the top-N latent dims (by mean |mu|) for grid plots.
# The grid plots below use the top 4. Previously `top_mu_dims` held *every* mu
# column, so `cols_to_drop` was always empty and nothing was dropped.
n_top_mu_dims = 4
_keep_mu_dims = set(top_mu_dims[:n_top_mu_dims])
cols_to_drop = [
    c for c in df_embeddings_plus_meta_test.columns
    if c.startswith("mu_") and c not in _keep_mu_dims
]
df_embeddings_plus_meta_test_top = df_embeddings_plus_meta_test.drop(columns=cols_to_drop)

## filter data for plots

In [None]:
# Drop every row that has a NaN in *any* column. This is broader than missing
# structure labels: e.g. cells with no match in the left-joined mito annotation
# table, or with missing features, are removed as well.
df_embeddings_plus_meta_test = df_embeddings_plus_meta_test.dropna()

# Exclude control lines (structure names beginning with "Control").
_is_control = df_embeddings_plus_meta_test["StructureDisplayName"].str.startswith("Control")
df_embeddings_plus_meta_test = df_embeddings_plus_meta_test[~_is_control]

## plot mito annotations

In [None]:
# Cells with no missing values, excluding mitotic state 'u'.
df_mito_search = df_embeddings_plus_meta_test_all_feats.dropna().copy()
df_mito_search = df_mito_search[df_mito_search["mito_state_resolved"] != 'u']

# Encode the mitotic states as ordinal integers (M0 -> 0 ... M6/M7 -> 4).
_mito_state_order = ['M0', 'M1/M2', 'M3', 'M4/M5', 'M6/M7']
df_mito_search["mito_state_int"] = df_mito_search["mito_state_resolved"].map(
    {state: code for code, state in enumerate(_mito_state_order)}
)

# Spearman correlation of each latent mean with the ordinal mitotic state:
# a one-column frame ("mito_state_int") with one row per mu dim.
mu_cols = [c for c in df_embeds_test.columns if c.startswith("mu_")]
_mito_corr_matrix = df_mito_search[mu_cols+["mito_state_int"]].corr(method="spearman")
mito_corrs = _mito_corr_matrix[["mito_state_int"]].drop(index=["mito_state_int"])

In [None]:
# |Spearman corr| of each latent dim with mitotic state, plotted against its rank.
# The figure is saved into results_dir (the notebook's results directory)
# rather than the current working directory.
_abs_mito_corr = mito_corrs["mito_state_int"].abs()

plt.figure(figsize=(4,4))

g = sns.lineplot(
    x=_abs_mito_corr.rank(ascending=False),
    y=_abs_mito_corr
);

g.set(
    xlabel="Dimension rank",
    ylabel="Spearman corr. with mitotic state",
);

sns.despine()

plt.savefig(results_dir / 'mito_dims_ranked.png', dpi=300, bbox_inches="tight")

In [None]:
# Latent dims ordered by descending |corr| with mitotic state.
_ascending_mito = mito_corrs.abs().squeeze().sort_values()
top_mito_dims = list(_ascending_mito.index)[::-1]

In [None]:
# For the 10 most mito-correlated dims: (dim, rank by mito corr, rank by mean |mu|).
mito_dims_compared_to_overall_mu_ranks = []
for _mito_rank, _mito_dim in enumerate(top_mito_dims[:10]):
    _overall_rank = top_mu_dims.index(_mito_dim)
    mito_dims_compared_to_overall_mu_ranks.append((_mito_dim, _mito_rank, _overall_rank))

mito_dims_compared_to_overall_mu_ranks

In [None]:
# Tabulate the rank comparison (0 = strongest) for display.
pd.DataFrame.from_records(
    mito_dims_compared_to_overall_mu_ranks,
    columns=["dim", "mito rank", "overall rank"],
)

## plots

In [None]:
# Human-readable column names for plot axes and legends.
_pretty_renames = {
    "mito_state_resolved": "Mitotic state",
    'dna_volume': 'DNA volume',
    'cell_volume': 'Cell volume',
    'dna_height': 'DNA height',
    'cell_height': 'Cell height',
}
df_embeddings_plus_meta_test_pretty_names = df_embeddings_plus_meta_test.rename(columns=_pretty_renames)

# Convert pixel-unit measurements to micrometres. Volumes scale by the cube of
# the pixel edge length, which assumes isotropic voxels.
pixel_length_in_micrometers = 0.29
_unit_conversions = [
    ('DNA volume (μm^3)', 'DNA volume', pixel_length_in_micrometers**3),
    ('DNA height (μm)', 'DNA height', pixel_length_in_micrometers),
    ('Cell volume (μm^3)', 'Cell volume', pixel_length_in_micrometers**3),
    ('Cell height (μm)', 'Cell height', pixel_length_in_micrometers),
]
for _new_col, _src_col, _factor in _unit_conversions:
    df_embeddings_plus_meta_test_pretty_names[_new_col] = df_embeddings_plus_meta_test_pretty_names[_src_col] * _factor

### mito plots

#### by top 2 dims that separate mito states

In [None]:
# Cells (state 'u' excluded) on the two latent dims found to separate mitotic
# states for this trained model, colored by mitotic state.
# NOTE(review): the dims are hard-coded; re-check against `top_mito_dims` if the
# model changes. LaTeX labels use raw strings ("\m" is an invalid escape in a
# plain literal). The figure is saved into results_dir rather than the CWD.
plt.figure(figsize=(3,3))

plt_mito = sns.scatterplot(
    data=df_embeddings_plus_meta_test_pretty_names[
        df_embeddings_plus_meta_test_pretty_names["Mitotic state"] != 'u'
    ].sort_values(
        by="Mitotic state"
    ),
    x="mu_452",
    y="mu_419",
    hue="Mitotic state",
    hue_order=['M0', 'M1/M2', 'M3', 'M4/M5', 'M6/M7'],
    # M0 in light grey, the remaining states in the first four tab10 colors.
    palette=["lightgrey"] + sns.color_palette("tab10")[:4],
    linewidth=0,
    alpha=0.5,
    s=20,
)

plt_mito.set(
    xlabel=r"$\mu_{452}$",
    ylabel=r"$\mu_{419}$",
);

plt.legend(bbox_to_anchor=(1.0, 0.7), frameon=False)
sns.despine()

plt.savefig(results_dir / 'latent_plot_mito_top_mito_dims.png', dpi=300, bbox_inches="tight")

#### by top two overall dims

In [None]:
# Cells (state 'u' excluded) on the two hard-coded "top overall" latent dims
# (mu_71, mu_419), colored by mitotic state.
# LaTeX labels use raw strings ("\m" is an invalid escape in a plain literal).
# The figure is saved into results_dir rather than the CWD.
plt.figure(figsize=(3,3))

plt_mito = sns.scatterplot(
    data=df_embeddings_plus_meta_test_pretty_names[
        df_embeddings_plus_meta_test_pretty_names["Mitotic state"] != 'u'
    ].sort_values(
        by="Mitotic state"
    ),
    x="mu_71",
    y="mu_419",
    hue="Mitotic state",
    hue_order=['M0', 'M1/M2', 'M3', 'M4/M5', 'M6/M7'],
    # M0 in light grey, the remaining states in the first four tab10 colors.
    palette=["lightgrey"] + sns.color_palette("tab10")[:4],
    linewidth=0,
    alpha=0.5,
    s=20,
)

plt_mito.set(
    xlabel=r"$\mu_{71}$",
    ylabel=r"$\mu_{419}$",
);

plt.legend(bbox_to_anchor=(1.0, 0.7), frameon=False)
sns.despine()

plt.savefig(results_dir / 'latent_plot_mito_top_mu_overall_dims.png', dpi=300, bbox_inches="tight")

### grid plots

In [None]:
# Data column name -> legend title (mathtext superscripts for matplotlib).
pretty_hue_names = dict(
    [
        ('DNA volume (μm^3)', 'DNA volume (μm$^3$)'),
        ('DNA height (μm)', 'DNA height (μm)'),
        ('Cell volume (μm^3)', 'Cell volume (μm$^3$)'),
        ('Cell height (μm)', 'Cell height (μm)'),
    ]
)

In [None]:
# Axis labels (matplotlib mathtext) for the hard-coded top latent dims.
# Built with raw f-strings: "\m" in a plain string literal is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
pretty_mu_names = {f"mu_{k}": rf"$\mu_{{{k}}}$" for k in (71, 419, 188, 465)}
pretty_mu_names

In [None]:
# For each morphology feature: a PairGrid of the top-4 mean-|mu| latent dims
# (state 'u' excluded). Off-diagonals are colored by that feature, with the color
# range clipped to its 1st-99th percentile. Changes from the original:
#  - kdeplot `fill` replaces the deprecated `shade` parameter;
#  - every "mu_<k>" axis label is rendered as LaTeX, not only the 4 hard-coded
#    ones (the output is identical for those 4);
#  - figures are saved into results_dir instead of the CWD.
hue_norm_tail = 0.01

for hue, hue_pretty_name in pretty_hue_names.items():
    df_plot = df_embeddings_plus_meta_test_pretty_names[
        df_embeddings_plus_meta_test_pretty_names["Mitotic state"] != 'u'
    ]
    hue_values = df_plot[hue]

    g = sns.PairGrid(
        df_plot,
        height=1.5,
        vars=top_mu_dims[:4]
    )
    g.map_diag(sns.kdeplot, color=".2", fill=False)
    g.map_offdiag(
        sns.scatterplot,
        hue=hue_values,
        palette="viridis",
        s=20,
        linewidth=0,
        alpha=0.5,
        hue_norm=(np.quantile(hue_values, hue_norm_tail), np.quantile(hue_values, 1-hue_norm_tail))
    );

    plt.legend(title=hue_pretty_name, bbox_to_anchor=(1.01, 1), frameon=False)

    # Replace raw "mu_<k>" axis labels with LaTeX, e.g. mu_71 -> $\mu_{71}$.
    for ax in g.axes.flat:
        for get_label, set_label in ((ax.get_xlabel, ax.set_xlabel), (ax.get_ylabel, ax.set_ylabel)):
            label = get_label()
            if label.startswith("mu_"):
                set_label(rf"$\mu_{{{label.split('_')[-1]}}}$")

    plt.subplots_adjust(top=0.95)
    g.fig.savefig(
        results_dir / f"latent_top_mu_dims_gridplot_color_{hue.lower().replace(' ', '_')}.png",
        dpi=300,
        bbox_inches="tight",
    )

### cell volume and height solo for main figure

In [None]:
# Cells (state 'u' excluded) on the hard-coded dims mu_71 vs mu_419, colored by
# cell height. The color range is clipped to the 1st-99th percentile of cell
# height over all rows (including state 'u'). LaTeX labels use raw strings, and
# the figure is saved into results_dir.
hue_norm_tail = 0.01
_hue_col = "Cell height (μm)"
_hue_all = df_embeddings_plus_meta_test_pretty_names[_hue_col]

plt.figure(figsize=(3,3))

ax = sns.scatterplot(
    data=df_embeddings_plus_meta_test_pretty_names[
        df_embeddings_plus_meta_test_pretty_names["Mitotic state"] != 'u'
    ].sort_values(
        by=_hue_col
    ),
    x="mu_71",
    y="mu_419",
    hue=_hue_col,
    palette="viridis",
    linewidth=0,
    alpha=0.5,
    s=20,
    hue_norm=(np.quantile(_hue_all, hue_norm_tail), np.quantile(_hue_all, 1-hue_norm_tail))
)

ax.set(
    xlabel=r"$\mu_{71}$",
    ylabel=r"$\mu_{419}$",
);

plt.legend(bbox_to_anchor=(1.0, 0.7), frameon=False)
sns.despine()

plt.savefig(results_dir / 'latent_plot_mito_cell_height_alone.png', dpi=300, bbox_inches="tight")

In [None]:
# Cells (state 'u' excluded) on the hard-coded dims mu_71 vs mu_419, colored by
# cell volume. The color range is clipped to the 1st-99th percentile of cell
# volume over all rows. `hue_norm_tail` is set here so the cell no longer depends
# on the previous cell having run. LaTeX labels use raw strings, and the figure
# is saved into results_dir.
hue_norm_tail = 0.01
_hue_col = "Cell volume (μm^3)"
_hue_all = df_embeddings_plus_meta_test_pretty_names[_hue_col]

plt.figure(figsize=(3,3))

ax = sns.scatterplot(
    data=df_embeddings_plus_meta_test_pretty_names[
        df_embeddings_plus_meta_test_pretty_names["Mitotic state"] != 'u'
    ].sort_values(
        by=_hue_col
    ),
    x="mu_71",
    y="mu_419",
    hue=_hue_col,
    palette="viridis",
    linewidth=0,
    alpha=0.5,
    s=20,
    hue_norm=(np.quantile(_hue_all, hue_norm_tail), np.quantile(_hue_all, 1-hue_norm_tail))
)

ax.set(
    xlabel=r"$\mu_{71}$",
    ylabel=r"$\mu_{419}$",
);

plt.legend(bbox_to_anchor=(1.0, 0.7), frameon=False)
sns.despine()

plt.savefig(results_dir / 'latent_plot_mito_cell_volume_alone.png', dpi=300, bbox_inches="tight")

## for the top N latent dims, find the top features correlated with them

In [None]:
# Spearman correlations over latent means plus all dna_* / cell_* features.
mu_cols = [c for c in df_embeds_test.columns if c.startswith("mu_")]
feat_cols = [c for c in df_feats_all.columns if c.startswith(("dna_", "cell_"))]
_corr_input = df_embeddings_plus_meta_test_all_feats[mu_cols + feat_cols]
corrs_all = _corr_input.corr(method="spearman")

In [None]:
# Cross block of the correlation matrix: rows = features, columns = latent dims.
corrs_all_cross = corrs_all.drop(index=mu_cols, columns=feat_cols)

In [None]:
# Re-rank latent dims by mean |mu| (largest first) over the current test table.
_ascending_mu = df_embeddings_plus_meta_test[mu_cols].abs().mean().sort_values()
top_mu_dims = list(_ascending_mu.index)[::-1]

In [None]:
# Mean |mu| per latent dim over the test table, plotted against its rank.
# The mean is computed once, and the figure is saved into results_dir.
_mean_abs_mu = df_embeddings_plus_meta_test[mu_cols].abs().mean()

plt.figure(figsize=(4,2))

g = sns.lineplot(
    x=_mean_abs_mu.rank(ascending=False),
    y=_mean_abs_mu
);

g.set(
    xlabel="Reference latent space dimension rank",
    ylabel="Mean absolute value",
);
sns.despine()

plt.savefig(results_dir / 'mu_overall_dims_ranked.png', dpi=300, bbox_inches="tight")

In [None]:
# |Spearman corr| between features (rows, clustered) and the top-32 latent dims
# (columns, kept in rank order), portrait layout. Saved into results_dir; the
# file name is a plain string (the old f-string had no placeholders).
g = sns.clustermap(
    corrs_all_cross[top_mu_dims[:32]].abs(),
    col_cluster=False,
    figsize=(10,20),
    cmap="Blues",
    vmin=0,
    vmax=1,
    xticklabels=True,
    yticklabels=True,
);

g.savefig(results_dir / "heatmap_mu_vs_feats_vertical.png", dpi=300, bbox_inches="tight")

In [None]:
# Transposed version: top-32 latent dims as rows (kept in rank order), features as
# clustered columns, landscape layout. Saved into results_dir; the file name is
# a plain string (the old f-string had no placeholders).
g = sns.clustermap(
    corrs_all_cross[top_mu_dims[:32]].T.abs(),
    row_cluster=False,
    figsize=(20,10),
    cmap="Blues",
    vmin=0,
    vmax=1,
    xticklabels=True,
    yticklabels=True,
);
g.savefig(results_dir / "heatmap_mu_vs_feats_horizontal.png", dpi=300, bbox_inches="tight")

In [None]:
N_latent_dims = 8
K_feats_per_dim = 3

# For each of the top-N latent dims (by mean |mu|): the K features with the
# largest |Spearman corr| to it, strongest first. One column per latent dim.
mu_feat_match_dict = {
    mu_dim: list(
        corrs_all_cross[mu_dim].abs().sort_values().index[-K_feats_per_dim:]
    )[::-1]
    for mu_dim in top_mu_dims[:N_latent_dims]
}

df_mu_feat_match = pd.DataFrame(mu_feat_match_dict)

In [None]:
# Display the table: one column per top latent dim; row i holds that dim's
# (i+1)-th most |Spearman|-correlated feature.
df_mu_feat_match

## draw samples from latent space along top N interesting dimensions

In [None]:
# For each selected latent dim, decode a 9-step walk from -2 to +2 standard
# deviations along that dim (all other dims held at 0). Save the image into
# results_dir and display it with the dim's most correlated feature.
# The latent width is taken from the encoder's mu rather than a hard-coded 512.
# This assumes the decoder's input width equals the mu width (standard for a VAE).
# Inputs go to the decoder's own device instead of a hard `.cuda()`.
_n_latent = embeds_test["test"]["ref"]["mu"].shape[1]
_decoder_device = next(ref_dec.parameters()).device

for dim in df_mu_feat_match.columns:
    int_dim = int(dim.split("_")[-1])
    top_feat_for_dim = df_mu_feat_match.loc[0, dim]

    latent_walk_path = np.float32(df_embeds_test[dim].std()*np.linspace(-2,2,9))
    latent_walk_path = torch.from_numpy(latent_walk_path)

    latent_input = torch.zeros([len(latent_walk_path), _n_latent], dtype=torch.float32)
    latent_input[:, int_dim] = latent_walk_path
    latent_input = latent_input.to(_decoder_device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        walk_output = ref_dec(latent_input)
    walk_im = tensor2im(walk_output.cpu())

    pil_im = Image.fromarray(np.uint8(walk_im*255))
    pil_im.save(results_dir / f"latent_walk_{dim}_top_feat_{top_feat_for_dim}.png", "PNG")

    print(f"{dim}, top feature = {top_feat_for_dim}")
    display(pil_im)
    print("")

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

In [None]:
# Fit a full PCA on standardized features (rows with any missing feature dropped).
_scaled_feats = scale(df_feats_all[feat_cols].dropna())
pca = PCA()
pca.fit(_scaled_feats)

In [None]:
# Explained-variance ratio of each feature PC, plotted against PC index.
# The series is built once instead of twice, and the figure is saved into results_dir.
_explained = pd.Series(pca.explained_variance_ratio_, name="Feature PCA explained variance")

plt.figure(figsize=(4,2))

g = sns.lineplot(
    x=_explained.index,
    y=_explained
);

g.set(
    xlabel="Feature PC rank",
    ylabel="Explained variance",
);
sns.despine()

plt.savefig(results_dir / 'pca_feats_explained_variance.png', dpi=300, bbox_inches="tight")

In [None]:
# Cumulative explained-variance ratio across feature PCs (displayed).
pd.Series(
    np.cumsum(pca.explained_variance_ratio_),
    name="Feature PCA explained variance",
)

In [None]:
# |Correlation| among the top-32 latent dims (pandas default method: Pearson,
# unlike the Spearman correlations used elsewhere). Saved into results_dir.
mu_corrs = df_embeddings_plus_meta_test_all_feats[top_mu_dims[:32]].corr()

plt.figure(figsize=(8,8))
g = sns.heatmap(
    mu_corrs.abs(),
    cmap="Blues",
    vmin=0,
    vmax=1,
    xticklabels=True,
    yticklabels=True,
    square=True,
    cbar_kws={"shrink": .82}
)

plt.savefig(results_dir / 'latent_dim_corrs.png', dpi=300, bbox_inches="tight")