In [1]:
# Automatically re-import edited modules (e.g. the local celltrip package) before each cell runs
%load_ext autoreload
%autoreload 2


In [2]:
import itertools
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import mpl_toolkits
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
import sklearn.neighbors
import sklearn.neural_network
import torch
import tqdm

import celltrip

# AWS profile used for the s3:// reads below. setdefault keeps a profile the
# user already exported instead of unconditionally overwriting it.
os.environ.setdefault('AWS_PROFILE', 'waisman-admin')
# Font type 42 for PDF/PS figure exports (see matplotlib rcParams docs)
mpl.rcParams['pdf.fonttype'] = mpl.rcParams['ps.fonttype'] = 42
sns.set_theme(context='paper', style='darkgrid', palette='colorblind')
# Seed the global numpy/torch RNGs so stochastic steps (presumably including
# environment resets) are reproducible across Restart & Run All
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)


# Load Data and Policy

In [3]:
# Read every development stage of each modality from S3 (backed mode keeps X on
# disk) and merge the stages, giving adatas = [expression AnnData, spatial AnnData]
adatas = []
for modality in ('expression', 'spatial'):
    stage_paths = [
        f's3://nkalafut-celltrip/Flysta3D/{stage}_{modality}.h5ad'
        for stage in ('E14-16h_a', 'E16-18h_a', 'L1_a', 'L2_a', 'L3_b')
        # for stage in ('L2_a',)
    ]
    adatas.append(celltrip.utility.processing.merge_adatas(
        *celltrip.utility.processing.read_adatas(*stage_paths, backed=True), backed=True))

# Checkpoint to evaluate; `prefix` is shared by the .weights, .pre and .mask files
# prefix, training_step = 's3://nkalafut-celltrip/checkpoints/flysta-250909-5', 800  # Double-standard
prefix, training_step = 's3://nkalafut-celltrip/checkpoints/flysta-250909-4', 800  # Regular

# Preprocessing stored next to the checkpoint
preprocessing = celltrip.utility.processing.Preprocessing().load(f'{prefix}.pre')

# Boolean per-cell training mask; attached to obs (held in memory even when
# backed) so it is included in metadata exports
with celltrip.utility.general.open_s3_or_local(f'{prefix}.mask', 'rb') as mask_file:
    mask = np.loadtxt(mask_file).astype(bool)
adatas[0].obs['Training'] = mask

# Throwaway two-cell environment, presumably only needed so create_agent_from_env
# can size the policy (kind of a dumb workaround, TODO)
m1, m2 = (
    preprocessing.transform(adata[:2].X, subset_modality=modality_idx)[0]
    for modality_idx, adata in enumerate(adatas))
env = celltrip.environment.EnvironmentBase(
    torch.tensor(m1), torch.tensor(m2),
    target_modalities=[1], compute_rewards=False, dim=8,
).eval().to('cuda')

# Build the policy for this environment and load the trained weights
policy = celltrip.policy.create_agent_from_env(
    env, forward_batch_size=1_000, vision_size=1_000, pinning_spatial=[1],
).eval().to('cuda')
policy.load_checkpoint(f'{prefix}-{training_step:04}.weights');


# Generate Steady States

In [None]:
# for dev in (pbar := tqdm.tqdm(adatas[0].obs['development'].unique(), desc='')):
#     # Subset and preprocess the data
#     pbar.set_description(f'{dev} (Preprocessing)')
#     samples = adatas[0].obs.index[adatas[0].obs['development']==dev]
#     # if len(samples) > 10_000: samples = np.random.choice(samples, 10_000, replace=False)  # For runtime, TESTING
#     m1, m2 = [
#         celltrip.utility.processing.chunk_X(
#             ad[samples], chunk_size=2_000,
#             func=lambda x: preprocessing.transform(x, subset_modality=i)[0])
#             for i, ad in enumerate(adatas)]
#     # Initialize environment
#     pbar.set_description(f'{dev} (Initializing)')
#     env = celltrip.environment.EnvironmentBase(
#         torch.tensor(m1), torch.tensor(m2), target_modalities=[1], compute_rewards=False, dim=8).eval(time_scale=1).to('cuda')  # 32/env.max_time
#     # Simulate to steady state
#     pbar.set_description(f'{dev} (Running)')
#     # env.train().eval(time_scale=1)
#     env.reset()
#     ret = celltrip.train.simulate_until_completion(env, policy, skip_states=100, store_states='cpu')  # progress_bar=True
#     steady_state = ret[-1][-1, :, :env.dim]
#     target_state = env.modalities[env.target_modalities[0]].cpu()
#     with torch.no_grad():
#         imputed_steady_state = policy.pinning[0](steady_state.to('cuda'), Y=target_state.to('cuda')).detach().cpu().numpy()
#     imputed_steady_state, = preprocessing.inverse_transform(imputed_steady_state, subset_modality=1)
#     # Save
#     pbar.set_description(f'{dev} (Saving)')
#     np.save(f'../plots/flysta/CellTRIP_{dev}.npy', imputed_steady_state)
#     np.save(f'../plots/flysta/spatial_{dev}.npy', adatas[1][samples].X)
#     adatas[0].obs.loc[samples].to_csv(f'../plots/flysta/meta_{dev}.csv', index=False);


## Run Comparison Methods

In [None]:
# # Load full data
# X, Y = [
#     celltrip.utility.processing.chunk_X(
#         ad, chunk_size=2_000,
#         func=lambda x: preprocessing.transform(x, subset_modality=i)[0])
#         for i, ad in enumerate(adatas)]


In [None]:
# # Train MLP and export predictions
# model = sklearn.neural_network.MLPRegressor(max_iter=100, verbose=True).fit(X[mask], Y[mask])
# for dev in (pbar := tqdm.tqdm(adatas[0].obs['development'].unique())):
#     # Subset and preprocess the data
#     pbar.set_description(f'{dev} (Preprocessing)')
#     samples = adatas[0].obs.index[adatas[0].obs['development']==dev]
#     X_dev, Y_dev = [
#         celltrip.utility.processing.chunk_X(
#             ad[samples], chunk_size=2_000,
#             func=lambda x: preprocessing.transform(x, subset_modality=i)[0])
#             for i, ad in enumerate(adatas)]
#     # Run model
#     pbar.set_description(f'{dev} (Running)')
#     Y_pred = model.predict(X_dev)
#     imputed_steady_state, = preprocessing.inverse_transform(Y_pred, subset_modality=1)
#     # Save
#     pbar.set_description(f'{dev} (Saving)')
#     np.save(f'../plots/flysta/MLP_{dev}.npy', imputed_steady_state)


In [None]:
# # Export KNN predictions
# model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=10).fit(X[mask], Y[mask])
# for dev in (pbar := tqdm.tqdm(adatas[0].obs['development'].unique())):
#     # Subset and preprocess the data
#     pbar.set_description(f'{dev} (Preprocessing)')
#     samples = adatas[0].obs.index[adatas[0].obs['development']==dev]
#     X_dev, Y_dev = [
#         celltrip.utility.processing.chunk_X(
#             ad[samples], chunk_size=2_000,
#             func=lambda x: preprocessing.transform(x, subset_modality=i)[0])
#             for i, ad in enumerate(adatas)]
#     # Run model
#     pbar.set_description(f'{dev} (Running)')
#     Y_pred = model.predict(X_dev)
#     imputed_steady_state, = preprocessing.inverse_transform(Y_pred, subset_modality=1)
#     # Save
#     pbar.set_description(f'{dev} (Saving)')
#     np.save(f'../plots/flysta/KNN_{dev}.npy', imputed_steady_state)


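# Recover Validation State

Interpolate a held-out (validation) development stage: bring pseudocells of the stage before it to steady state, swap in the expression of the stage after it, and save the imputed transition states.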

In [None]:
# # Separate training and validation stages
# development = np.array(['E14-16h_a', 'E16-18h_a', 'L1_a', 'L2_a', 'L3_b'])  # Ordered
# development_training = adatas[0].obs.loc[mask, 'development'].unique()
# development_validation = adatas[0].obs.loc[~mask, 'development'].unique()
# assert len(np.intersect1d(development_training, development_validation)) == 0  # Properly partitioned
# # Get possible interpolation stages
# possible_interpolated_stages = []
# for i in np.argwhere(np.isin(development, development_validation)).flatten():
#     if i == 0 or i == len(development)-1: continue
#     possible_interpolated_stages.append(development[i-1:i+2])
# # Set interpolation
# start_stage, interp_stage, end_stage = possible_interpolated_stages[-1]

In [None]:
# # Grab data
# # Interpolation: simulate start-stage pseudocells to steady state, then swap in
# # end-stage expression and record the transition (imputed as the held-out stage)
# start_idx = np.argwhere(adatas[0].obs['development'] == start_stage).flatten()
# end_idx = np.argwhere(adatas[0].obs['development'] == end_stage).flatten()
# start_exp = celltrip.utility.processing.chunk_X(
#     adatas[0][start_idx], chunk_size=2_000,
#     func=lambda x: preprocessing.transform(x, subset_modality=0)[0])
# end_exp = celltrip.utility.processing.chunk_X(
#     adatas[0][end_idx], chunk_size=2_000,
#     func=lambda x: preprocessing.transform(x, subset_modality=0)[0])
# start_obs = celltrip.utility.general.transform_and_center(celltrip.utility.processing.chunk_X(adatas[1][start_idx], chunk_size=2_000))
# end_obs = celltrip.utility.general.transform_and_center(celltrip.utility.processing.chunk_X(adatas[1][end_idx], chunk_size=2_000))

# # Use K-Means to create start and end pseudocells
# import sklearn.cluster  # NOTE(review): sklearn.cluster is not imported in the imports cell; needed for KMeans below
# start_n_pcells = end_n_pcells = 5_000
# start_pcell_ids = sklearn.cluster.KMeans(n_clusters=start_n_pcells, random_state=42).fit_predict(start_obs)
# # NOTE(review): previously fit on start_obs, so end_pcell_ids did not index end_exp/end_obs rows
# end_pcell_ids = sklearn.cluster.KMeans(n_clusters=end_n_pcells, random_state=42).fit_predict(end_obs)

# # Get expression and spatial for pseudocells
# start_processed_exp = np.stack([start_exp[np.argwhere(start_pcell_ids==i).flatten()].mean(axis=0) for i in range(start_n_pcells)], axis=0)
# start_processed_obs = np.stack([start_obs[np.argwhere(start_pcell_ids==i).flatten()].mean(axis=0) for i in range(start_n_pcells)], axis=0)
# end_processed_exp = np.stack([end_exp[np.argwhere(end_pcell_ids==i).flatten()].mean(axis=0) for i in range(end_n_pcells)], axis=0)
# end_processed_obs = np.stack([end_obs[np.argwhere(end_pcell_ids==i).flatten()].mean(axis=0) for i in range(end_n_pcells)], axis=0)

# # Calculate OT matrix
# import ot
# a, b = ot.utils.unif(start_processed_obs.shape[0]), ot.utils.unif(end_processed_obs.shape[0])
# M = ot.dist(start_processed_obs, end_processed_obs)
# M /= M.max()
# OT_mat = ot.emd(a, b, M, numItermax=1_000_000)
# # OT_mat = ot.solve(M, a, b)
# # OT_mat = ot.sinkhorn(a, b, M, 1e-1)

# # Calculate pseudocells
# # Pair each start pseudocell with the mean of every end pseudocell it sends mass to
# pcells = [([i], np.argwhere(OT_mat[i] > 0).flatten()) for i in range(OT_mat.shape[0]) if OT_mat[i].sum() > 0]
# start_pcells_exp, end_pcells_exp = [], []
# for pcell_start, pcell_end in pcells:
#     start_pcells_exp.append(start_processed_exp[pcell_start].mean(axis=0))
#     end_pcells_exp.append(end_processed_exp[pcell_end].mean(axis=0))
# start_pcells_exp = np.stack(start_pcells_exp, axis=0)
# end_pcells_exp = np.stack(end_pcells_exp, axis=0)

# # Create env
# m1_start, m1_end = start_pcells_exp, end_pcells_exp  # preprocessing.transform
# env = celltrip.environment.EnvironmentBase(
#     torch.tensor(m1_start), target_modalities=None, compute_rewards=False, dim=8).eval(time_scale=1).to('cuda')

# # Get transition states
# env.reset()
# celltrip.train.simulate_until_completion(env, policy, store_states=False)  # Set env at steady state
# env.time = 0  # Reset timing
# env.set_modalities([torch.tensor(m1_end)]).to('cuda')  # Set to ending expression
# transition_states = celltrip.train.simulate_until_completion(env, policy, skip_states=50, store_states='cpu', progress_bar=True)[-1][..., :env.dim].cpu()

# # Impute transition states
# with torch.no_grad():
#     imputed_transition_states = policy.pinning[0](transition_states.to('cuda')).detach().cpu().numpy()
# imputed_transition_states, = preprocessing.inverse_transform(imputed_transition_states, subset_modality=1)
# np.save(f'../plots/flysta/Interpolated_{interp_stage}.npy', imputed_transition_states)


1281it [00:17, 71.60it/s]


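# Perform Knockdown

For each development stage, simulate the environment to steady state, then knock down each gene in turn and record how far the imputed spatial state moves (overall and per annotated cell type).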

In [None]:
# Knockdown settings
KNOCKDOWN_STEPS = 5  # Environment steps simulated after each knockdown (TODO: run longer)
KNOCKDOWN_RETAINED = 0.  # Fraction of a gene's isolated processed contribution kept (0 = full knockdown)
KNOCKDOWN_PATH = '../plots/flysta/knockdown.csv'


def summarize_knockdown(states, gene, dev, cell_type):
    """Summarize one imputed knockdown trajectory as a results record.

    `states` is assumed to be indexed (time, cell, feature), matching how
    simulation states are sliced elsewhere in this notebook.

    'Effect Size' is the mean squared difference between the last and first
    states; 'Trajectory Length' is the mean squared difference between
    consecutive states (averaged over steps, cells and features).
    """
    return {
        'Gene': gene,
        'Development': dev,
        'Cell Type': cell_type,
        'Effect Size': np.square(states[-1] - states[0]).mean(axis=-1).mean(),
        'Trajectory Length': np.square(states[1:] - states[:-1]).mean(axis=(0, -1)).mean()}


results = []
# Fail fast on a missing output directory rather than after hours of simulation
os.makedirs(os.path.dirname(KNOCKDOWN_PATH), exist_ok=True)
for dev in adatas[0].obs['development'].unique():
    # Subset and preprocess the data
    samples = adatas[0].obs.index[adatas[0].obs['development']==dev]
    raw_m1 = celltrip.utility.processing.chunk_X(adatas[0][samples], chunk_size=2_000)
    m1, = preprocessing.transform(raw_m1, subset_modality=0)
    m2 = celltrip.utility.processing.chunk_X(
        adatas[1][samples], chunk_size=2_000,
        func=lambda x: preprocessing.transform(x, subset_modality=1)[0])
    # Per-cell-type masks depend only on the stage, so build them once (not per gene)
    annotations = adatas[0].obs.loc[samples, 'annotation']
    cell_type_masks = [(ct, (annotations == ct).to_numpy()) for ct in annotations.unique()]
    # Initialize environment
    env = celltrip.environment.EnvironmentBase(
        torch.tensor(m1), torch.tensor(m2), target_modalities=[1], compute_rewards=False, dim=8).eval(time_scale=1).to('cuda')
    # Simulate to steady state and snapshot it. Clones guarantee that later
    # simulations (which may update env tensors in place) cannot alter the baseline.
    env.reset()
    celltrip.train.simulate_until_completion(env, policy)
    steady_pos, steady_vel = env.pos.clone(), env.vel.clone()
    steady_m1 = env.modalities[0].clone()
    # Perturb each gene independently, always starting from the unperturbed steady state
    for i, gene in enumerate(tqdm.tqdm(adatas[0].var.index, desc=str(dev))):
        # Reset environment to the steady state
        env.set_max_time(KNOCKDOWN_STEPS*env.delta).reset()
        env.set_positions(steady_pos.clone())
        env.set_velocities(steady_vel.clone())  # Or maybe 0?
        # Knock down gene i relative to the *baseline* expression so that
        # knockdowns from previous genes do not accumulate
        iso_modality, = preprocessing.transform(raw_m1, subset_features=[i], subset_modality=0)
        iso_modality = torch.tensor(iso_modality).to(env.device)
        env.modalities[0] = steady_m1 - (1 - KNOCKDOWN_RETAINED)*iso_modality
        # Simulate
        states = celltrip.train.simulate_until_completion(env, policy, store_states='cpu')[-1]
        # Impute
        with torch.no_grad():
            imputed_states = policy.pinning[0](states[..., :env.dim].to('cuda')).detach().cpu().numpy()
        imputed_states, = preprocessing.inverse_transform(imputed_states, subset_modality=1)
        # Record overall and per-cell-type summaries
        results.append(summarize_knockdown(imputed_states, gene, dev, 'All'))
        for ct, ct_mask in cell_type_masks:
            results.append(summarize_knockdown(imputed_states[:, ct_mask], gene, dev, ct))
    # Checkpoint after every stage so an interrupted run keeps completed stages
    pd.DataFrame(results).to_csv(KNOCKDOWN_PATH, index=None)

  0%|          | 53/13668 [02:27<10:31:28,  2.78s/it]


KeyboardInterrupt: 