In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import argparse
import importlib.machinery
import os
import random
import shlex

import ray

import celltrip

# Detect Cython: a compiled build loads `celltrip.utility.general` from a native
# extension module (e.g. `general.cpython-311-x86_64-linux-gnu.so` on Linux or
# `general.cp311-win_amd64.pyd` on Windows) instead of a `.py` source file.
# EXTENSION_SUFFIXES lists every suffix this interpreter imports native modules
# from, so the check also works on platforms whose suffix is not `.so`.
CYTHON_ACTIVE = (celltrip.utility.general.__file__ or '').endswith(
    tuple(importlib.machinery.EXTENSION_SUFFIXES))
print(f'Cython is{" not" if not CYTHON_ACTIVE else ""} active')


Cython is not active


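# Arguments

Parse training options from the command line, or from a preset command string when running inside a notebook.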

In [3]:
# Arguments
# NOTE: It is not recommended to use s3 with credentials unless the creds are permanent, the bucket is public, or this is run on AWS
# ArgumentDefaultsHelpFormatter appends each option's default to its `--help` text.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Train CellTRIP model',
)

# Input: which h5ad files to read and how to interpret them
group = parser.add_argument_group('Input')
group.add_argument(
    'input_files', nargs='*', type=str,
    help='h5ad files to be used for input')
group.add_argument(
    # Repeatable flag; each occurrence appends its own list, giving a list of lists
    '--merge_files', nargs='+', action='append', type=str,
    help='h5ad files to merge as input')
group.add_argument(
    '--partition_cols', nargs='+', type=str,
    help='Columns for data partitioning, found in `adata.obs` DataFrame')
group.add_argument(
    '--backed', action='store_true',
    help='Read data directly from disk or s3, saving memory at the cost of time')
group.add_argument(
    '--input_modalities', nargs='+', type=int,
    help='Input modalities to give to CellTRIP')
group.add_argument(
    '--target_modalities', nargs='+', type=int,
    help='Target modalities to emulate, dictates environment reward')

# Algorithm: latent size, model variant and train/validation split
group = parser.add_argument_group('Algorithm')
group.add_argument(
    '--dim', default=16, type=int,
    help='Dimensions in the output latent space')
group.add_argument(
    '--discrete', action='store_true',
    help='Use the discrete model rather than continuous')
group.add_argument(
    '--train_split', default=1.0, type=float,
    help='Fraction of input data to use as training')
group.add_argument(
    '--train_partitions', action='store_true',
    help='Split training/validation data across partitions rather than samples')

# Computation: GPU, learner and runner counts
group = parser.add_argument_group('Computation')
group.add_argument(
    '--num_gpus', default=1, type=int,
    help='Number of GPUs to use during computation')
group.add_argument(
    '--num_learners', default=1, type=int,
    help='Number of learners used in backward computation, cannot exceed GPUs')
group.add_argument(
    '--num_runners', default=1, type=int,
    help='Number of workers for environment simulation')

# Training: rollout/update schedule
group = parser.add_argument_group('Training')
group.add_argument(
    '--update_timesteps', default=1_000_000, type=int,
    help='Number of timesteps recorded before each update')
group.add_argument(
    '--max_timesteps', default=2_000_000_000, type=int,
    help='Maximum number of timesteps to compute before exiting')
group.add_argument(
    '--dont_sync_across_nodes', action='store_true',
    help='Avoid memory sync across nodes, saving overhead time at the cost of stability')

# Logging: log destination, flush cadence and checkpointing (file saves)
group = parser.add_argument_group('Logging')
group.add_argument(
    '--logfile', default='cli', type=str,
    help='Location for log file, can be `cli`, `<local_file>`, or `<s3 location>`')
group.add_argument(
    '--flush_iterations', default=25, type=int,
    help='Number of iterations to wait before flushing logs')
group.add_argument(
    '--checkpoint', type=str,
    help='Checkpoint to use for initializing model')
group.add_argument(
    '--checkpoint_iterations', default=100, type=int,
    help='Number of updates to wait before recording checkpoints')
group.add_argument(
    '--checkpoint_dir', default='./checkpoints', type=str,
    help='Directory for checkpoints')
group.add_argument(
    '--checkpoint_name', type=str,
    help='Run name, for checkpointing')

# Notebook defaults and script handling
# As a script, options come from the real command line. Inside a notebook, a preset
# command string is built below instead, parsed with the same parser, and echoed as
# the equivalent `python train.py ...` invocation.
if not celltrip.utility.notebook.is_notebook():
    # ray job submit -- python train.py...
    config = parser.parse_args()
else:
    # Run identifier, reused below for the log file path and the checkpoint name
    experiment_name = 'VCC-250818'
    bucket_name = 'nkalafut-celltrip'
    # bucket_name = 'arn:aws:s3:us-east-2:245432013314:accesspoint/ray-nkalafut-celltrip'
    # Adjacent string literals are concatenated implicitly. Every fragment except the
    # last ends with a space, so dataset presets and options can be toggled simply by
    # (un)commenting individual lines; typically only one dataset preset is active.
    command = (
        # MERFISH
        # f's3://{bucket_name}/MERFISH/expression.h5ad s3://{bucket_name}/MERFISH/spatial.h5ad --target_modalities 1 '
        # scGLUE
        # f's3://{bucket_name}/scGLUE/Chen-2019-RNA.h5ad s3://{bucket_name}/scGLUE/Chen-2019-ATAC.h5ad '
        # f's3://{bucket_name}/scGLUE/Chen-2019-RNA.h5ad s3://{bucket_name}/scGLUE/Chen-2019-ATAC.h5ad --input_modalities 0 --target_modalities 0 '
        # f'../data/scglue/Chen-2019-RNA.h5ad ../data/scglue/Chen-2019-ATAC.h5ad --input_modalities 0 --target_modalities 0 '
        # Flysta3D
        # f' '.join([f'--merge_files ' + ' ' .join([f's3://{bucket_name}/Flysta3D/{p}_{m}.h5ad' for p in ('E14-16h_a', 'E16-18h_a', 'L1_a', 'L2_a', 'L3_b')]) for m in ('expression', 'spatial')]) + ' '
        # f'--target_modalities 1 '
        # f'--partition_cols development '
        # Particular stage Flysta
        # f' '.join([f'--merge_files ' + ' ' .join([f's3://{bucket_name}/Flysta3D/{p}_{m}.h5ad' for p in ('L3_b',)]) for m in ('expression', 'spatial')]) + ' '
        # f'--target_modalities 1 '
        # f'--partition_cols development '
        # Tahoe-100M
        # f'--merge_files ' + ' '.join([f's3://{bucket_name}/Tahoe/plate{i}_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad' for i in range(1, 15)]) + ' '
        # f'--partition_cols sample '
        # scMultiSim
        # f's3://{bucket_name}/scMultiSim/expression.h5ad s3://{bucket_name}/scMultiSim/peaks.h5ad '
        # MERFISH Bench
        # f's3://{bucket_name}/MERFISH_Bench/expression.h5ad s3://{bucket_name}/MERFISH_Bench/spatial.h5ad '
        # f'--target_modalities 1 '
        # TemporalBrain
        # f's3://{bucket_name}/TemporalBrain/expression.h5ad s3://{bucket_name}/TemporalBrain/peaks.h5ad '
        # f'--partition_cols "Donor ID" '
        # Virtual Cell Challenge
        f's3://{bucket_name}/VirtualCell/vcc_flt_data.h5ad '
        f'--partition_cols target_gene '

        # Data loading and model options
        f'--backed '
        # f'--dim 2 '
        # f'--dim 8 '
        f'--dim 4 '
        # f'--discrete '

        # Sample split
        f'--train_split .8 '
        # Partition split
        # f'--train_split .6 '
        # f'--train_partitions '
        # Single slice
        # f'--train_split .0001 '
        # f'--train_partitions '

        # Resources, schedule, logging and checkpointing
        f'--num_gpus 2 --num_learners 2 --num_runners 2 '
        f'--update_timesteps 1_000_000 '
        f'--max_timesteps 800_000_000 '
        # f'--update_timesteps 100_000 '
        # f'--max_timesteps 100_000_000 '
        f'--dont_sync_across_nodes '
        f'--logfile s3://{bucket_name}/logs/{experiment_name}.log '
        f'--flush_iterations 1 '
        # f'--checkpoint s3://nkalafut-celltrip/checkpoints/MERFISH_Bench-250805-2-0800.weights '
        f'--checkpoint_iterations 100 '
        f'--checkpoint_dir s3://{bucket_name}/checkpoints '
        f'--checkpoint_name {experiment_name}')
    # shlex.split tokenizes like a POSIX shell, so quoted values such as "Donor ID"
    # in the TemporalBrain preset stay a single argument
    config = parser.parse_args(shlex.split(command))
    # Echo the equivalent script invocation so the run can be reproduced outside the notebook
    print(f'python train.py {command}')
    
# Defaults
if config.checkpoint_name is None:
    # No run name given: use `RUN_` plus a random 32-bit id zero-padded to 10 digits
    # (the width of 2**32 - 1). `random.randint` includes its upper bound, so
    # `randint(0, 2**32)` could yield 2**32; `getrandbits(32)` stays within [0, 2**32).
    config.checkpoint_name = f'RUN_{random.getrandbits(32):0>10}'
    print(f'Run Name: {config.checkpoint_name}')
# print(config)  # CLI


python train.py s3://nkalafut-celltrip/VirtualCell/vcc_flt_data.h5ad --partition_cols target_gene --backed --dim 4 --train_split .8 --num_gpus 2 --num_learners 2 --num_runners 2 --update_timesteps 1_000_000 --max_timesteps 800_000_000 --dont_sync_across_nodes --logfile s3://nkalafut-celltrip/logs/VCC-250818.log --flush_iterations 1 --checkpoint_iterations 100 --checkpoint_dir s3://nkalafut-celltrip/checkpoints --checkpoint_name VCC-250818


# Deploy Remotely

Connect to a Ray cluster and run training as a remote Ray task.

In [None]:
# Start Ray: drop any connection left over from a previous run of this cell, then
# attach to the cluster via a `ray://` (Ray Client) address.
ray.shutdown()
a = ray.init(
    # address='ray://100.85.187.118:10001',
    address='ray://localhost:10001',
    runtime_env=dict(
        # Make the local `celltrip` package available to the cluster's workers
        py_modules=[celltrip],
        # Requirements file path; presumably resolved relative to the current working directory
        pip='../requirements.txt',
        env_vars=dict(
            # **access_keys,
            # NCCL_SOCKET_IFNAME='tailscale',  # lo,en,wls,docker,tailscale
            RAY_DEDUP_LOGS='0',
        ),
    ),
    # Doesn't really work for scripts
    _system_config=dict(enable_worker_prestart=True),
)


In [None]:
@ray.remote(num_cpus=1e-4)
def train(config):
    """Build CellTRIP initializers from the parsed ``config`` and run training.

    Executed as a Ray task that reserves only a tiny fraction of a CPU
    (``num_cpus=1e-4``). GPU, learner and runner counts are forwarded to
    ``celltrip.train.train_celltrip`` from ``config``. Returns nothing.
    """
    # Imported inside the task body so the name is available in the Ray worker process
    import celltrip

    # Initialization
    initializers = celltrip.train.get_initializers(
        input_files=config.input_files,
        merge_files=config.merge_files,
        backed=config.backed,
        partition_cols=config.partition_cols,
        # Alternative dataloader settings: {'num_nodes': 20, 'pca_dim': 128}
        dataloader_kwargs=dict(
            num_nodes=[2**9, 2**11],
            mask=config.train_split,
            mask_partitions=config.train_partitions,
        ),
        environment_kwargs=dict(
            input_modalities=config.input_modalities,
            target_modalities=config.target_modalities,
            dim=config.dim,
            discrete=config.discrete,
            # spherical=config.discrete,
        ),
        policy_kwargs=dict(discrete=config.discrete),
        # Skips casting, cutting time significantly for relatively small batch sizes
        memory_kwargs=dict(device='cuda:0'),
    )

    # Stages: optional hooks passed to `train_celltrip`; the commented examples
    # step `env.set_delta` down from .1 to .005
    stage_functions = [
        # lambda w: w.env.set_delta(.1),
        # lambda w: w.env.set_delta(.05),
        # lambda w: w.env.set_delta(.01),
        # lambda w: w.env.set_delta(.005),
    ]

    # Run function
    celltrip.train.train_celltrip(
        initializers=initializers,
        num_gpus=config.num_gpus,
        num_learners=config.num_learners,
        num_runners=config.num_runners,
        max_timesteps=config.max_timesteps,
        update_timesteps=config.update_timesteps,
        sync_across_nodes=not config.dont_sync_across_nodes,
        flush_iterations=config.flush_iterations,
        checkpoint_iterations=config.checkpoint_iterations,
        checkpoint_dir=config.checkpoint_dir,
        checkpoint=config.checkpoint,
        checkpoint_name=config.checkpoint_name,
        stage_functions=stage_functions,
        logfile=config.logfile,
    )


# Submit the task and block until training returns
ray.get(train.remote(config))


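# Run Locally

The cells below are commented out. They build the environment, policy, and memory in-process so rollouts and updates can be debugged and line-profiled without the Ray cluster.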

In [None]:
# import numpy as np
# import torch
# torch.random.manual_seed(42)
# np.random.seed(42)

# # Initialize locally
# os.environ['AWS_PROFILE'] = 'waisman-admin'
# config.update_timesteps = 100_000
# config.max_timesteps = 20_000_000

# dataloader_kwargs = {'num_nodes': [2**9, 2**11], 'mask': config.train_split}  # {'num_nodes': [2**9, 2**11], 'pca_dim': 128}
# environment_kwargs = {
#     'input_modalities': config.input_modalities,
#     'target_modalities': config.target_modalities, 'dim': config.dim}
# env_init, policy_init, memory_init = celltrip.train.get_initializers(
#     input_files=config.input_files, merge_files=config.merge_files,
#     partition_cols=config.partition_cols,
#     backed=config.backed, dataloader_kwargs=dataloader_kwargs,
#     policy_kwargs={'minibatch_size': 10_000},
#     # memory_kwargs={'device': 'cuda:0'},  # Skips casting, cutting time significantly for relatively small batch sizes
#     environment_kwargs=environment_kwargs)

# # Environment
# # os.environ['CUDA_LAUNCH_BLOCKING']='1'
# try: env
# except: env = env_init().to('cuda')

# # Policy
# policy = policy_init(env).to('cuda')

# # Memory
# memory = memory_init(policy)


In [None]:
# # Forward
# import line_profiler
# memory.mark_sampled()
# memory.cleanup()
# prof = line_profiler.LineProfiler(
#     celltrip.train.simulate_until_completion,
#     celltrip.policy.PPO.forward,
#     celltrip.policy.EntitySelfAttentionLite.forward,
#     celltrip.policy.ResidualAttention.forward,
#     celltrip.environment.EnvironmentBase.step)
# ret = prof.runcall(celltrip.train.simulate_until_completion, env, policy, memory, max_memories=config.update_timesteps, reset_on_finish=True)
# print('ROLLOUT: ' + f'total: {ret[2]:.3f}, ' + ', '.join([f'{k}: {v:.3f}' for k, v in ret[3].items()]))
# # memory.feed_new(policy.reward_standardization)
# memory.compute_advantages()  # moving_standardization=policy.reward_standardization
# prof.print_stats(output_unit=1)


tensor([23.3764, 34.5881, 67.3236,  ..., 28.3253, 24.7067, 43.3550],
       device='cuda:0')
tensor([23.4403, 34.5967, 67.3538,  ..., 28.3907, 24.7294, 43.4552],
       device='cuda:0')
tensor([23.3492, 34.6029, 67.5473,  ..., 28.3453, 24.7068, 43.4087],
       device='cuda:0')
tensor([23.5119, 34.5351, 67.5567,  ..., 28.3330, 24.6583, 43.2891],
       device='cuda:0')
tensor([23.3925, 34.5835, 67.2903,  ..., 28.4343, 24.7068, 43.4626],
       device='cuda:0')
tensor([23.4300, 34.5289, 67.3157,  ..., 28.3361, 24.7067, 43.4820],
       device='cuda:0')
tensor([23.3340, 34.6085, 67.2487,  ..., 28.3315, 24.6761, 43.4621],
       device='cuda:0')
tensor([23.4502, 34.5984, 67.2822,  ..., 28.4285, 24.7070, 43.3872],
       device='cuda:0')
tensor([23.4115, 34.6331, 67.4580,  ..., 28.3361, 24.7655, 43.4114],
       device='cuda:0')
tensor([23.3178, 34.5520, 67.2211,  ..., 28.3614, 24.7184, 43.6287],
       device='cuda:0')
tensor([23.3612, 34.5404, 67.1771,  ..., 28.3285, 24.6666, 43.5539],
 

tensor([23.4561, 34.5837, 67.3328,  ..., 28.3716, 24.7132, 43.3648],
       device='cuda:0')
tensor([23.3126, 34.5459, 67.5058,  ..., 28.3377, 24.6027, 43.6120],
       device='cuda:0')
tensor([23.3556, 34.5437, 67.4349,  ..., 28.3498, 24.6632, 43.3484],
       device='cuda:0')
tensor([23.3647, 34.5711, 67.4977,  ..., 28.3324, 24.7125, 43.5646],
       device='cuda:0')
tensor([23.3789, 34.5903, 67.4968,  ..., 28.3450, 24.5565, 43.3904],
       device='cuda:0')
tensor([23.3890, 34.6002, 67.3328,  ..., 28.3406, 24.6291, 43.4824],
       device='cuda:0')
tensor([23.3859, 34.5525, 67.4654,  ..., 28.3473, 24.6278, 43.4824],
       device='cuda:0')
tensor([23.4802, 34.5198, 67.4215,  ..., 28.3514, 24.6294, 43.4554],
       device='cuda:0')
tensor([23.4312, 34.5192, 67.7230,  ..., 28.3354, 24.5486, 43.7063],
       device='cuda:0')
tensor([23.4369, 34.6252, 67.7474,  ..., 28.3750, 24.6333, 43.5159],
       device='cuda:0')
tensor([23.5438, 34.5236, 67.5900,  ..., 28.3349, 24.5526, 43.4814],
 

In [10]:
# # Memory pull
# import line_profiler
# prof = line_profiler.LineProfiler(
#     celltrip.memory.AdvancedMemoryBuffer.__getitem__)
# ret = prof.runcall(memory.__getitem__, np.random.choice(len(memory), 10_000, replace=False))
# memory.compute_advantages()
# prof.print_stats(output_unit=1)


In [None]:
# # Updates
# import line_profiler
# prof = line_profiler.LineProfiler(
#     # memory.fast_sample, policy.actor_critic.forward,
#     celltrip.policy.ResidualAttentionBlock.forward,
#     policy.calculate_losses, policy.update,
#     celltrip.memory.AdvancedMemoryBuffer.__getitem__)
# ret = prof.runcall(policy.update, memory, verbose=True)
# print('UPDATE: ' + ', '.join([f'{k}: {v:.3f}' for ret_dict in ret[1:] for k, v in ret_dict.items()]))
# prof.print_stats(output_unit=1)


Iteration 01 - Total (1.90447) + PPO (0.44890) + critic (1.46693) + entropy (-11.36146) + KL (3.53687) :: Moving Return Mean (-0.00725), Moving Return STD (1.01281), Return Mean (-0.53739), Return STD (1.62391), Moving Input Mean (0.00435), Moving Input STD (1.04395), Input Mean (0.32519), Input STD (6.75260), Explained Variance (-0.44245)
Iteration 05 - Total (1.42196) + PPO (0.48111) + critic (0.95228) + entropy (-11.43396) + KL (3.54111) :: Moving Return Mean (-0.06725), Moving Return STD (1.11182), Return Mean (-0.53801), Return STD (1.62348), Moving Input Mean (0.04051), Moving Input STD (1.20664), Input Mean (0.32304), Input STD (6.75808), Explained Variance (0.07293)
UPDATE: Total: 1.562, PPO: 0.447, critic: 1.126, entropy: -11.399, KL: 3.444, Moving Return Mean: -0.038, Moving Return STD: 1.064, Return Mean: -0.538, Return STD: 1.623, Moving Input Mean: 0.023, Moving Input STD: 1.134, Input Mean: 0.324, Input STD: 6.756, Explained Variance: -0.092
Timer unit: 1 s

Total time: 1

In [None]:
# for _ in range(int(config.max_timesteps / config.update_timesteps)):
#     # Forward
#     memory.mark_sampled()
#     memory.cleanup()
#     ret = celltrip.train.simulate_until_completion(
#         env, policy, memory,
#         max_memories=config.update_timesteps,
#         # max_timesteps=100,
#         reset_on_finish=True)
#     print('ROLLOUT: ' + f'iterations: {ret[0]: 5.0f}, ' + f'total: {ret[2]: 5.3f}, ' + ', '.join([f'{k}: {v: 5.3f}' for k, v in ret[3].items()]))
#     memory.compute_advantages()

#     # Update
#     # NOTE: Training often only improves when PopArt and actual distribution match
#     ret = policy.update(memory, verbose=False)
#     print('UPDATE: ' + ', '.join([f'{k}: {v: 5.3f}' for ret_dict in ret[1:] for k, v in ret_dict.items()]))


tensor([[-0.6789,  0.6088,  0.5714,  ..., -1.1509, -1.6819,  2.0672],
        [ 2.3578, -1.3243, -2.1103,  ...,  0.4661, -0.2684, -0.1553],
        [ 0.8144, -1.5807,  0.1408,  ...,  0.9140, -1.8973, -0.8437],
        ...,
        [-0.8506,  0.5550, -1.2938,  ..., -0.5367,  0.6879,  1.3929],
        [-1.1079,  0.1883,  0.9195,  ...,  0.2999, -2.9068, -1.6411],
        [ 0.7353, -1.9406, -0.6259,  ...,  0.6976,  1.1831,  1.1897]],
       device='cuda:0')
tensor([-1.7893, -1.9561, -1.1821,  ..., -1.1454, -1.7943, -1.1448],
       device='cuda:0')
tensor([[-0.9200, -0.3403, -1.2042,  ...,  0.5757,  3.0719, -1.0956],
        [ 0.5124, -1.1635,  0.1780,  ..., -1.9831, -0.6756, -2.2262],
        [ 1.2168, -0.7434, -0.7732,  ..., -0.6749,  0.1009,  0.3238],
        ...,
        [ 0.2259,  0.1847,  0.4070,  ...,  2.5144, -4.0476,  0.5622],
        [-1.3349,  1.0668,  0.1303,  ..., -0.9193, -1.2485,  0.6354],
        [-1.0493,  1.3721,  1.2776,  ...,  0.3216, -0.9740,  0.9439]],
       device='

KeyboardInterrupt: 