In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
import itertools
import json
import os
import warnings

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
import torch
import tqdm
import umap

import celltrip


# Load Data and Policy

In [None]:
# Read data files
# adata_prefix = 's3://nkalafut-celltrip/Dyngen'
adata_prefix = '../data/Dyngen'  # NOTE(review): not used below; the reads use the hard-coded S3 URIs
gex_path = 's3://nkalafut-celltrip/dyngen/logcounts.h5ad'
protein_path = 's3://nkalafut-celltrip/dyngen/counts_protein.h5ad'
adatas = celltrip.utility.processing.read_adatas(gex_path, protein_path, backed=True)

# Checkpoint prefix shared by the .weights, .pre and .mask files, and the training step to load
# prefix, training_step = 's3://nkalafut-celltrip/checkpoints/Dyngen-250920', 800  # 8 dim
# prefix, training_step = 's3://nkalafut-celltrip/checkpoints/Dyngen-251015', 800  # 32 dim
prefix, training_step = 's3://nkalafut-celltrip/checkpoints/Dyngen-251025', 800  # 32 dim, extra feature processing hidden layer

# Load the preprocessing saved with this checkpoint
preprocessing = celltrip.utility.processing.Preprocessing().load(f'{prefix}.pre')

# Load the saved boolean cell mask and attach it to the first modality as 'Training'
with celltrip.utility.general.open_s3_or_local(f'{prefix}.mask', 'rb') as mask_file:
    mask = np.loadtxt(mask_file).astype(bool)
adatas[0].obs['Training'] = mask

# Build a throwaway environment from the first two cells only, so the agent can be
# constructed from it (kind of a dumb workaround, TODO). Presumably only its shapes matter.
m1, m2 = [
    preprocessing.transform(adata[:2].X, subset_modality=modality)[0]
    for modality, adata in enumerate(adatas)]
env = celltrip.environment.EnvironmentBase(
    torch.tensor(m1), torch.tensor(m2), compute_rewards=False, dim=32).eval().to('cuda')

# Load policy: build the agent, then restore the weights for the chosen training step
policy = celltrip.policy.create_agent_from_env(
    env, forward_batch_size=1_000, vision_size=1_000).eval().to('cuda')
policy.load_checkpoint(f'{prefix}-{training_step:04}.weights');


# Plot Steady State

In [None]:
# m1, m2 = [preprocessing.transform(ad.X[:], subset_modality=i)[0] for i, ad in enumerate(adatas)]
# env = celltrip.environment.EnvironmentBase(
#     torch.tensor(m1), torch.tensor(m2), compute_rewards=False, dim=32).eval(time_scale=1).to('cuda')
# ret = celltrip.train.simulate_until_completion(env, policy, skip_states=100, store_states='cpu', progress_bar=True)
# steady_state = ret[-1][-1, :, :env.dim]

# # GEX
# with torch.no_grad():
#     imputed_steady_state_0 = policy.pinning[0](steady_state.to('cuda')).detach().cpu().numpy()
# imputed_steady_state_0, = preprocessing.inverse_transform(imputed_steady_state_0, subset_modality=0)

# # Protein
# with torch.no_grad():
#     imputed_steady_state_1 = policy.pinning[1](steady_state.to('cuda')).detach().cpu().numpy()
# imputed_steady_state_1, = preprocessing.inverse_transform(imputed_steady_state_1, subset_modality=1)

## Preview

In [None]:
# # Generate data GEX
# X0 = adatas[0].X[:]
# X0_pred = imputed_steady_state_0
# red = umap.UMAP()  # random_state=42
# Y0 = red.fit_transform(X0)
# Y0_pred = red.transform(X0_pred)

# # Generate data protein
# X1 = adatas[1].X[:]
# X1_pred = imputed_steady_state_1
# red = umap.UMAP()  # random_state=42
# Y1 = red.fit_transform(X1)
# Y1_pred = red.transform(X1_pred)

# # Plot figure
# fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# for i, j in itertools.product(*[np.arange(2) for _ in range(2)]):
#     # Get axis
#     ax = axs[i, j]

#     # Get data
#     df = pd.DataFrame(index=adatas[0].obs_names)
#     df[['x', 'y']] = [Y0, Y0_pred, Y1, Y1_pred][2*i+j]
#     df['Trajectory'] = adatas[0].obs['traj_sim']
#     df['Validation'] = ~adatas[0].obs['Training']

#     # Plot
#     legend = (i==0)*(j==1)
#     sns.scatterplot(df, x='x', y='y', hue='Trajectory', style='Validation', edgecolor='black', legend=legend, ax=ax)
#     if legend: sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.05, 1))

#     # Format
#     ax.set(xlabel=None, ylabel=None)
#     sns.despine(ax=ax)
#     if j == 1:
#         # Set xlim
#         ax_alt = axs[i, 0]
#         ax_alt_xlim = ax_alt.get_xlim()
#         ax_xlim = ax.get_xlim()
#         xlim = np.stack([ax_xlim, ax_alt_xlim], axis=0).max(axis=0)
#         ax.set_xlim(xlim)
        
#         # Set ylim
#         ax_alt = axs[i, 0]
#         ax_alt_ylim = ax_alt.get_ylim()
#         ax_ylim = ax.get_ylim()
#         ylim = np.stack([ax_ylim, ax_alt_ylim], axis=0).max(axis=0)
#         ax.set_ylim(ylim)

#     # Title
#     if i == 0:
#         if j == 0: ax.set_title('Observed')
#         if j == 1: ax.set_title('Reconstructed')
#     if j == 0:
#         if i == 0: ax.set_ylabel('Gene Expression', fontsize='large')
#         if i == 1: ax.set_ylabel('Protein Counts', fontsize='large')


# Estimate Gene Significance via Single-Gene Knockdown

In [None]:
# Params
SEED = 42
np.random.seed(SEED)
# The policy and environment run in torch, whose RNG is independent of numpy's;
# seed it as well so any stochastic sampling during simulation is reproducible.
torch.manual_seed(SEED)
genes_to_survey = adatas[0].var_names
sim_time = 1.  # Horizon passed to env.set_max_time by reset_env for every run

# Mute warnings (array wrap and indexing). NOTE: this is global for the whole kernel session.
warnings.simplefilter('ignore')

# Create anndata: one layer per run holding the final latent state of every cell.
# Copy obs so this container does not alias the (backed) source AnnData's obs frame.
ad_pert = ad.AnnData(obs=adatas[0].obs.copy(), var=pd.DataFrame(index=[f'Feature {i}' for i in range(env.dim)]))
# ad_pert0 = ad.AnnData(obs=adatas[0].obs, var=adatas[0].var)
# ad_pert1 = ad.AnnData(obs=adatas[1].obs, var=adatas[1].var)
def add_layers(states, gene):
    """Save one run's ``states`` into the global ``ad_pert`` under the layer name ``gene``.

    Args:
        states: array stored as-is in ``ad_pert.layers``; callers in this notebook pass
            the final latent positions of every cell.
        gene: layer name (a gene name, or 'Steady' / 'Control').
    """
    target_layers = ad_pert.layers
    target_layers[gene] = states
    # Per-modality variants are currently disabled:
    # ad_pert0.layers[gene] = states_0
    # ad_pert1.layers[gene] = states_1

# Add results
results = []  # One dict per (gene, cell type) run; converted to a DataFrame when saving
def add_record(states, states_0, states_1, gene, ct):
    """Append summary statistics of one simulated trajectory to the global ``results``.

    ``states`` (latent), ``states_0`` and ``states_1`` (decoded modalities) are indexed by
    simulation step along their first axis. Assumes a (step, cell, feature) layout as
    produced by run_and_record — TODO confirm for other callers.

    For each array two values are stored:
      * Effect Size: mean squared difference between the last and the first step.
      * Trajectory Length: mean squared displacement of each step, summed over steps.
    """
    metrics = (
        ('Effect Size (Latent)', 'Trajectory Length (Latent)', states),
        ('Effect Size (Modality 0)', 'Trajectory Length (Modality 0)', states_0),
        ('Effect Size (Modality 1)', 'Trajectory Length (Modality 1)', states_1),
    )
    row = {'Gene': gene, 'Cell Type': ct}
    for effect_key, length_key, trajectory in metrics:
        net_change = trajectory[-1] - trajectory[0]
        row[effect_key] = np.square(net_change).mean()
        step_changes = trajectory[1:] - trajectory[:-1]
        row[length_key] = np.square(step_changes).mean(axis=(-2, -1)).sum()
    results.append(row)
    
# Reset function
def reset_env(env, steady_pos, steady_vel, modal_dict=None, max_time=None):
    """Reset ``env`` and load a previously captured state into it.

    Args:
        env: celltrip environment, modified in place.
        steady_pos: positions passed to ``env.set_positions``.
        steady_vel: velocities passed to ``env.set_velocities``.
        modal_dict: optional ``{modality index: data}`` entries written into
            ``env.modalities`` after the reset. ``None`` (default) writes nothing;
            it replaces the previous shared mutable ``{}`` default.
        max_time: horizon passed to ``env.set_max_time``; when ``None`` the
            notebook-level ``sim_time`` is read at call time.
    """
    if max_time is None:
        max_time = sim_time
    env.set_max_time(max_time).reset()  # TODO: Maybe longer?, early stopping?
    env.set_positions(steady_pos)
    env.set_velocities(steady_vel)  # Maybe 0 manually?
    for k, v in (modal_dict or {}).items():
        env.modalities[k] = v

# Running function
def run_and_record(samples, env, policy, preprocessing, gene, gene_idx):
    """Simulate one knockdown from the env's current state and record the results.

    ``gene_idx`` is passed with a target of 0 to celltrip's
    ``clamp_inverted_features_hook`` (env hook) and ``move_toward_targets_hook``
    (action hook) on modality 0. The stored latent trajectory is then decoded into
    both modalities.

    Args:
        samples: obs index labels of the simulated cells, in env order.
        env: celltrip environment, already reset to the starting state.
        policy: trained agent; ``policy.pinning[k]`` maps latent positions to
            modality ``k`` before ``preprocessing.inverse_transform``.
        preprocessing: fitted Preprocessing used by the hooks and for inversion.
        gene: label used as the layer name and in the 'Gene' column.
        gene_idx: modality-0 feature indices to knock down (``[]`` for control).

    Side effects:
        Writes the final latent state to ``ad_pert`` via ``add_layers``. Appends one
        record for all cells, plus one per 'traj_sim' group, via ``add_record``.
    """
    # Run and keep the stored states on the CPU
    states = celltrip.train.simulate_until_completion(
        env, policy,
        env_hooks=[
            celltrip.utility.hooks.clamp_inverted_features_hook(
                gene_idx, preprocessing, feature_targets=0., modality_idx=0),
        ],
        action_hooks=[
            celltrip.utility.hooks.move_toward_targets_hook(
                gene_idx, feature_targets=0., pinning=policy.pinning[0],
                preprocessing=preprocessing, modality_idx=0,
                factor=1, device=env.device),
        ],
        store_states='cpu')[-1]
    states_pos = states[..., :env.dim]

    # Decode the latent trajectory into both modalities (one host->GPU copy shared by both)
    states_pos_gpu = states_pos.to('cuda')
    imputed_states_0 = _impute_modality(states_pos_gpu, policy, preprocessing, 0)
    imputed_states_1 = _impute_modality(states_pos_gpu, policy, preprocessing, 1)
    del states_pos_gpu
    latent = states_pos.numpy()

    # Record all cells
    add_layers(latent[-1], gene)  # , imputed_states_0[-1], imputed_states_1[-1]
    add_record(latent, imputed_states_0, imputed_states_1, gene, 'All')

    # Record each trajectory group. Build one plain numpy boolean mask per group
    # instead of re-subsetting the AnnData and indexing a torch tensor with a pandas Series.
    trajectories = adatas[0][samples].obs['traj_sim']
    for ct in trajectories.unique():
        in_group = (trajectories == ct).to_numpy()
        add_record(
            latent[:, in_group],
            imputed_states_0[:, in_group],
            imputed_states_1[:, in_group],
            gene, ct)


def _impute_modality(states_pos_gpu, policy, preprocessing, modality_idx):
    """Decode latent positions with ``policy.pinning[modality_idx]`` and invert preprocessing.

    Returns the numpy array produced by ``preprocessing.inverse_transform`` for that modality.
    """
    with torch.no_grad():
        decoded = policy.pinning[modality_idx](states_pos_gpu).detach().cpu().numpy()
    decoded, = preprocessing.inverse_transform(decoded, subset_modality=modality_idx)
    return decoded

# Subset and preprocess the data
samples = adatas[0].obs.index
# NOTE(review): raw_m1 is not read anywhere in this notebook, yet it costs a full pass over the data
raw_m1 = celltrip.utility.processing.chunk_X(adatas[0][samples], chunk_size=2_000)
m1, m2 = [
    celltrip.utility.processing.chunk_X(
        adata[samples], chunk_size=2_000,
        # Bind the modality index now (i=i) so the callback can never see a later loop value
        func=lambda x, i=i: preprocessing.transform(x, subset_modality=i)[0])
    for i, adata in enumerate(adatas)]

# Initialize environment on the full data, with the same latent dim as the sample env
env = celltrip.environment.EnvironmentBase(
    torch.tensor(m1), torch.tensor(m2), compute_rewards=False, dim=env.dim).eval(time_scale=1).to('cuda')

# Simulate to steady state and snapshot it. Clone so the snapshot is decoupled from the
# env's live buffers and cannot drift if the env updates positions/velocities in place.
env.reset()
celltrip.train.simulate_until_completion(env, policy)
steady_pos, steady_vel = env.pos.clone(), env.vel.clone()

# Run control (nothing clamped). Every run receives fresh copies of the steady state,
# so no run can alter the starting point of the next one.
reset_env(env, steady_pos.clone(), steady_vel.clone())
add_layers(steady_pos.cpu().numpy(), 'Steady')
run_and_record(samples, env, policy, preprocessing, 'Control', [])

# Perturb: knock down each surveyed gene individually, always from the steady state
for gene in tqdm.tqdm(genes_to_survey, miniters=10, maxinterval=30):
    gene_idx = np.flatnonzero(adatas[0].var_names == gene)
    reset_env(env, steady_pos.clone(), steady_vel.clone())  # {0: torch.tensor(m1).cuda()}
    run_and_record(samples, env, policy, preprocessing, gene, gene_idx)

# Convert and save (create the output directory if it does not exist yet)
out_dir = '../plots/dyngen'
os.makedirs(out_dir, exist_ok=True)
pd.DataFrame(results).to_csv(f'{out_dir}/knockdown.csv', index=False)
ad_pert.write_h5ad(f'{out_dir}/knockdown_results.h5ad')
# ad_pert0.write_h5ad(f'{out_dir}/knockdown_results_modality_0.h5ad')
# ad_pert1.write_h5ad(f'{out_dir}/knockdown_results_modality_1.h5ad')


# Perform Module Knockdown

In [None]:
# # Load modules
# with celltrip.utility.general.open_s3_or_local('../plots/dyngen/dyngen_tfs.json', 'rb') as f:
#     tf_modules = json.load(f)
#     tf_modules.pop('NaN')

# # Params
# np.random.seed(42)
# sim_time = 128.

# # Mute warnings (array wrap and indexing)
# import warnings
# warnings.simplefilter('ignore')

# # Create anndata
# ad_pert = ad.AnnData(obs=adatas[0].obs, var=pd.DataFrame(index=[f'Feature {i}' for i in range(env.dim)]))
# # ad_pert0 = ad.AnnData(obs=adatas[0].obs, var=adatas[0].var)
# # ad_pert1 = ad.AnnData(obs=adatas[1].obs, var=adatas[1].var)
# def add_layers(states, gene):
#     ad_pert.layers[gene] = states
#     # ad_pert0.layers[gene] = states_0
#     # ad_pert1.layers[gene] = states_1

# # Add results
# results = []
# def add_record(states, states_0, states_1, gene, ct):
#     results.append({
#         'Gene': gene, 'Cell Type': ct,
#         'Effect Size (Latent)': np.square(states[-1] - states[0]).mean(),
#         'Trajectory Length (Latent)': np.square(states[1:] - states[:-1]).mean(axis=(-2, -1)).sum(),
#         'Effect Size (Modality 0)': np.square(states_0[-1] - states_0[0]).mean(),
#         'Trajectory Length (Modality 0)': np.square(states_0[1:] - states_0[:-1]).mean(axis=(-2, -1)).sum(),
#         'Effect Size (Modality 1)': np.square(states_1[-1] - states_1[0]).mean(),
#         'Trajectory Length (Modality 1)': np.square(states_1[1:] - states_1[:-1]).mean(axis=(-2, -1)).sum()})
    
# # Reset function
# def reset_env(env, steady_pos, steady_vel, modal_dict={}):
#     env.set_max_time(sim_time).reset()  # TODO: Maybe longer?, early stopping?
#     env.set_positions(steady_pos)
#     env.set_velocities(steady_vel)  # Maybe 0 manually?
#     for k, v in modal_dict.items():
#         env.modalities[k] = v

# # Running function
# def run_and_record(samples, env, policy, preprocessing, gene, gene_idx):
#     # Run and impute
#     states = celltrip.train.simulate_until_completion(
#         env, policy,
#         env_hooks=[
#             celltrip.utility.hooks.clamp_inverted_features_hook(
#                 gene_idx, preprocessing, feature_targets=0., modality_idx=0),
#         ],
#         action_hooks=[
#             celltrip.utility.hooks.move_toward_targets_hook(
#                 gene_idx, feature_targets=0., pinning=policy.pinning[0],
#                 preprocessing=preprocessing, modality_idx=0,
#                 factor=1, device=env.device),
#         ],
#         skip_states=100, store_states='cpu')[-1]
#     states_pos = states[..., :env.dim]
#     with torch.no_grad():
#         imputed_states_0 = policy.pinning[0](states_pos.to('cuda')).detach().cpu().numpy()
#         imputed_states_1 = policy.pinning[1](states_pos.to('cuda')).detach().cpu().numpy()
#     imputed_states_0, = preprocessing.inverse_transform(imputed_states_0, subset_modality=0)
#     imputed_states_1, = preprocessing.inverse_transform(imputed_states_1, subset_modality=1)
#     # Record
#     add_layers(states_pos.numpy()[-1], gene)  # , imputed_states_0[-1], imputed_states_1[-1]
#     add_record(states_pos.numpy(), imputed_states_0, imputed_states_1, gene, 'All')
#     for ct in adatas[0][samples].obs['traj_sim'].unique():
#         add_record(
#             states_pos[:, adatas[0][samples].obs['traj_sim']==ct].numpy(),
#             imputed_states_0[:, adatas[0][samples].obs['traj_sim']==ct],
#             imputed_states_1[:, adatas[0][samples].obs['traj_sim']==ct],
#             gene, ct)

# # Subset and preprocess the data
# samples = adatas[0].obs.index
# raw_m1 = celltrip.utility.processing.chunk_X(adatas[0][samples], chunk_size=2_000)
# m1, m2 = [
#     celltrip.utility.processing.chunk_X(
#         ad[samples], chunk_size=2_000,
#         func=lambda x: preprocessing.transform(x, subset_modality=i)[0])
#         for i, ad in enumerate(adatas)]

# # Initialize environment
# env = celltrip.environment.EnvironmentBase(
#     torch.tensor(m1), torch.tensor(m2), compute_rewards=False, dim=env.dim).eval(time_scale=5).to('cuda')

# # Simulate to steady state
# env.reset()
# celltrip.train.simulate_until_completion(env, policy)
# steady_pos, steady_vel = (env.pos, env.vel)

# # Run control
# reset_env(env, steady_pos, steady_vel)
# add_layers(steady_pos.cpu().numpy(), 'Steady')
# run_and_record(samples, env, policy, preprocessing, 'Control', [])

# # Perturb
# for module, genes in (pbar := tqdm.tqdm(tf_modules.items())):
#     # Get gene idx and run
#     pbar.set_description(module)
#     gene_idx = np.argwhere(np.isin(adatas[0].var_names, genes)).flatten()
#     reset_env(env, steady_pos, steady_vel)  # {0: torch.tensor(m1).cuda()}
#     run_and_record(samples, env, policy, preprocessing, module, gene_idx)

# # Convert and save
# pd.DataFrame(results).to_csv('../plots/dyngen/knockdown_full.csv', index=None)
# ad_pert.write_h5ad('../plots/dyngen/knockdown_full_results.h5ad')


## Preview

In [None]:
# # Load modules
# with celltrip.utility.general.open_s3_or_local('../plots/dyngen/dyngen_tfs.json', 'rb') as f:
#     tf_modules = json.load(f)
#     tf_modules.pop('NaN')

# # Load perturbations
# ad_pert = ad.read_h5ad('../plots/dyngen/knockdown_full_results.h5ad')
# control_gex, = preprocessing.inverse_transform(policy.pinning[0](torch.tensor(ad_pert.layers['Control']).cuda()).detach().cpu().numpy(), subset_modality=0)

# # Load module
# module_pivot = []
# for module in tf_modules.keys():
#     pert_gex, = preprocessing.inverse_transform(policy.pinning[0](torch.tensor(ad_pert.layers[module]).cuda()).detach().cpu().numpy(), subset_modality=0)
#     effect_sizes = pd.DataFrame(index=adatas[0].obs.index)
#     effect_sizes['Trajectory'] = adatas[0].obs['traj_sim']
#     effect_sizes['Effect Size'] = np.sqrt(np.square(pert_gex - control_gex).mean(axis=-1))
#     # tf_mask, tg_mask = adatas[0].var['is_tf'], ~adatas[0].var['is_tf']
#     tf_mask, tg_mask, hk_mask = adatas[0].var['is_tf']*~adatas[0].var['is_hk'], ~adatas[0].var['is_tf']*~adatas[0].var['is_hk'], adatas[0].var['is_hk']
#     effect_sizes['TF Effect Size'] = np.sqrt(np.square(pert_gex[..., tf_mask] - control_gex[..., tf_mask]).mean(axis=-1))
#     effect_sizes['TG Effect Size'] = np.sqrt(np.square(pert_gex[..., tg_mask] - control_gex[..., tg_mask]).mean(axis=-1))
#     effect_sizes['HK Effect Size'] = np.sqrt(np.square(pert_gex[..., hk_mask] - control_gex[..., hk_mask]).mean(axis=-1))
#     effect_sizes_pivot = effect_sizes[['Trajectory', 'TF Effect Size', 'TG Effect Size', 'HK Effect Size']].melt(id_vars='Trajectory', var_name='Type', value_name='Effect Size')
#     effect_sizes_pivot['Type'] = effect_sizes_pivot['Type'].str.split(' ').apply(lambda a: a[0])
#     effect_sizes_pivot['Module'] = module
#     module_pivot.append(effect_sizes_pivot)
# module_pivot = pd.concat(module_pivot, axis=0)
# module_pivot = module_pivot.reset_index(drop=True)

In [None]:
# # Get means and stds
# group_cols = ['Module', 'Type']
# grouped_means = module_pivot.groupby(group_cols, observed=True)[['Effect Size']].mean()
# grouped_stds = module_pivot.groupby(group_cols, observed=True)[['Effect Size']].std()

# # Normalize mean
# module_pivot[['Effect Size']] = (
#     module_pivot[['Effect Size']]
#     - grouped_means.loc[
#         list(
#             module_pivot[group_cols]
#                 .itertuples(name=None, index=False))].reset_index(drop=True))

# # Normalize std
# module_pivot[['Effect Size']] = (
#     module_pivot[['Effect Size']]
#     / grouped_stds.loc[
#         list(
#             module_pivot[group_cols]
#                 .itertuples(name=None, index=False))].reset_index(drop=True))

In [None]:
# fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# df = module_pivot.loc[module_pivot['Module'] == 'A1']
# # df = df.groupby(['Trajectory', 'Type'], observed=True)[['Effect Size']].mean().reset_index()
# sns.violinplot(df, x='Trajectory', y='Effect Size', hue='Type', split=False, inner='quart', ax=ax)
# # sns.boxplot(df, x='Trajectory', y='Effect Size', hue='Type', ax=ax)
# # sns.barplot(df, x='Trajectory', y='Effect Size', hue='Type', ax=ax)
# ax.set_xticks(ax.get_xticks())
# ax.set_xticklabels(ax.get_xticklabels(), rotation=30)
# # ax.set_ylim(bottom=0, top=2)
# # ax.set_ylim(-.5, .5)
# sns.despine(ax=ax)