In [1]:
# Reload imported modules before each cell runs so edits to local packages
# are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2


In [2]:
import os
import time

# Set env vars
# Assigned before `import ray` below; '0' presumably turns off Ray's
# de-duplication of repeated worker log lines -- TODO confirm against Ray docs
os.environ['RAY_DEDUP_LOGS'] = '0'

import numpy as np
import ray
import torch

# Enable text output in notebooks
# Rebinds `tqdm.notebook.tqdm` to `tqdm.auto.tqdm`, so any later lookup of
# `tqdm.notebook.tqdm` resolves to the auto-selected progress bar class
import tqdm.auto
import tqdm.notebook
tqdm.notebook.tqdm = tqdm.auto.tqdm

import celltrip
import data

# Detect Cython
# A compiled build exposes `__file__` with an extension-module suffix instead
# of `.py`; `.pyd` is the suffix CPython uses for extension modules on Windows
CYTHON_ACTIVE = os.path.splitext(celltrip.utility.general.__file__)[1] in ('.c', '.so', '.pyd')
print(f'Cython is{" not" if not CYTHON_ACTIVE else ""} active')

# Set params
# Use the first CUDA device when one is available, otherwise fall back to CPU
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# abspath('') resolves to the current working directory, so the folders below
# are relative to wherever the kernel was started
BASE_FOLDER = os.path.abspath('')
DATA_FOLDER = os.path.join(BASE_FOLDER, '../data/')
MODEL_FOLDER = os.path.join(BASE_FOLDER, 'models/')


Cython is active


- High priority
  - Optimize cancels to only cancel non-running futures
  - Implement stages
  - Add partitioning
  - Make ray init command for bash and add to README
- Medium priority
  - Add seeding
  - Add a state manager to the env, then parallelize the analysis; maybe add an `analyze` function
  - Add parallelism on max_batch and update. For update, wrap the whole epoch in a Ray function so that splitting occurs inside it, using the inline ray.remote API to still allow non-Ray usage. Then add an adjustable policy weight sync interval (e.g. every 1 epoch or every 10 epochs)
- Low priority
  - Allow memory to pre-process keys and persistent storage
  - Add hooks for external loggers, e.g. wandb
  - Move preprocessing to manager
  - Figure out why CUDA-not-available errors are sometimes thrown
  - Better split_state reproducibility

In [3]:
# Load the 'MMD-MA' dataset from DATA_FOLDER
modalities, types, features = data.load_data('MMD-MA', DATA_FOLDER)

# Fit celltrip's preprocessing on the raw modalities and transform them
ppc = celltrip.utility.processing.Preprocessing(
    pca_dim=128,
    device=DEVICE)
processed_modalities, features = ppc.fit_transform(modalities, features)

# Replace the raw modalities with float32 copies of the processed ones
# modalities = ppc.cast(processed_modalities)
modalities = list(map(lambda m: m.astype(np.float32), processed_modalities))
# modalities = [np.concatenate([m for _ in range(10000)], axis=0) for m in modalities]
# modalities = [m[:100] for m in modalities]


In [None]:
# Behavioral functions
dim = 3


def policy_init(modalities):
    """Build a PPO policy whose modal input sizes match `modalities`.

    Each entry of `modalities` contributes its `shape[1]` as one modal
    dimension; the positional input is `2*dim` wide and the output `dim` wide.
    """
    return celltrip.policy.PPO(
        positional_dim=2*dim,
        modal_dims=[modality.shape[1] for modality in modalities],
        output_dim=dim,
        # BACKWARDS
        # epochs=5,
        # memory_prune=0,
        update_load_level='batch',
        update_cast_level='minibatch',
        update_batch=1e4,
        update_minibatch=3e3,
        # SAMPLING
        # max_batch=100,
        max_nodes=100,
        # DEVICE
        device='cpu')


# policy = policy_init(modalities)
# policy_init = lambda _: policy


def env_init(policy, modalities):
    """Build an environment over `modalities` on the same device as `policy`."""
    return celltrip.environment.EnvironmentBase(
        *modalities,
        dim=dim,
        # max_timesteps=1e2,
        penalty_bound=1,
        device=policy.device)


def memory_init(policy):
    """Build a memory buffer sized to the sum of the policy's modal dims."""
    return celltrip.memory.AdvancedMemoryBuffer(
        sum(policy.modal_dims),
        split_args=policy.split_args)

# Initialize ray and distributed
# Shut down any Ray instance left over from a previous run of this cell
ray.shutdown()
# Advertise the total memory of GPU 0 as a custom 'VRAM' resource. Querying
# device properties raises on a CPU-only host, so the resource is only declared
# when CUDA is available (the notebook already falls back to CPU for DEVICE).
# NOTE(review): '0.0.0.0' binds the dashboard to every network interface;
# confirm this exposure is intended on shared machines.
ray.init(
    resources=(
        {'VRAM': torch.cuda.get_device_properties(0).total_memory}
        if torch.cuda.is_available() else None),
    dashboard_host='0.0.0.0')
# The manager receives the data plus the three factory callables
dm = celltrip.train.DistributedManager(
    modalities,
    policy_init=policy_init,
    env_init=env_init,
    memory_init=memory_init)


2025-02-26 20:33:08,988	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://192.168.0.119:8265 [39m[22m


[36m(wrapper pid=2361463)[0m Timestep 100 - Reward -0.040
[36m(wrapper pid=2361463)[0m Timestep 200 - Reward -0.233
[36m(wrapper pid=2361463)[0m Timestep 300 - Reward -0.173
[36m(wrapper pid=2361463)[0m Timestep 400 - Reward -0.220
[36m(wrapper pid=2361463)[0m Timestep 500 - Reward -0.180
[36m(wrapper pid=2361463)[0m Timestep 600 - Reward -0.200
[36m(wrapper pid=2361463)[0m Timestep 700 - Reward -0.260
[36m(wrapper pid=2361463)[0m Timestep 800 - Reward -0.193
[36m(wrapper pid=2361463)[0m Timestep 900 - Reward -0.180
[36m(wrapper pid=2361463)[0m Timestep 1000 - Reward -0.253
[36m(wrapper pid=2361463)[0m Simulation finished in 1000 steps with mean reward -0.200
[36m(wrapper pid=2361465)[0m Timestep 100 - Reward -0.047
[36m(wrapper pid=2361464)[0m Timestep 100 - Reward -0.053
[36m(wrapper pid=2361470)[0m Timestep 100 - Reward -0.053
[36m(wrapper pid=2361466)[0m Timestep 100 - Reward -0.033
[36m(wrapper pid=2361465)[0m Timestep 200 - Reward -0.307
[36m(wrap

In [5]:
# Print the current wall-clock time (same text as str(datetime.now()))
from datetime import datetime
print(datetime.now().isoformat(sep=' '))


2025-02-26 20:33:09.830911


In [None]:
# Train loop iter
max_rollout_futures = 20
num_updates = 0
calibrated = False
# Keep scheduling work until 50 policy updates have completed
while num_updates < 50:
    # Retrieve active futures
    futures = dm.get_futures()
    num_futures = len(dm.get_all_futures())

    # CLI
    # print('; '.join([f'{k} ({len(v)})' for k, v in futures.items()]))
    # print(ray.available_resources())

    ## Check for futures to add
    no_update_queued = len(futures['update']) == 0

    # Queue an update when none is queued and dm.get_memory_len() reaches 1e6
    if no_update_queued and dm.get_memory_len() >= int(1e6):
        # assert False
        print(f'Queueing policy update {num_updates+1}')
        dm.cancel()  # Cancel all non-running (TODO)
        dm.update()

    # Otherwise top the queue back up to `max_rollout_futures`
    elif no_update_queued and num_futures < max_rollout_futures:
        rollouts_to_add = max_rollout_futures - num_futures
        print(f'Queueing {rollouts_to_add} rollouts')
        dm.rollout(rollouts_to_add, dummy=False)

    ## Check for completed futures
    # Completed rollouts (non-blocking poll)
    finished_rollouts = ray.wait(futures['rollout'], timeout=0)[0]
    if len(finished_rollouts) > 0:
        # Calibrate if needed; a recorded memory of 0 marks "not calibrated yet"
        all_variants_run = True  # TODO: Set to true if all partitions have been run
        rollout_uncalibrated = dm.resources['rollout']['core']['memory'] == 0
        if rollout_uncalibrated and all_variants_run:
            dm.calibrate()
            rollout_memory_gib = dm.resources["rollout"]["core"]["memory"] / 2**30
            rollout_vram_gib = dm.resources["rollout"]["custom"]["VRAM"] / 2**30
            print(
                'Calibrated rollout'
                f' memory ({rollout_memory_gib:.2f} GiB)'
                f' and VRAM ({rollout_vram_gib:.2f} GiB)')
            dm.cancel()  # Cancel all non-running (TODO)
            time.sleep(1)
            dm.policy_manager.release_locks.remote()
        # Clean if calibrated (re-read: calibrate() above may have changed it)
        if dm.resources['rollout']['core']['memory'] != 0:
            dm.clean('rollout')

    # Completed updates (non-blocking poll)
    finished_updates = ray.wait(futures['update'], timeout=0)[0]
    if len(finished_updates) > 0:
        num_updates += 1
        # Calibrate if needed
        if dm.resources['update']['core']['memory'] == 0:
            dm.calibrate()
            update_memory_gib = dm.resources["update"]["core"]["memory"] / 2**30
            update_vram_gib = dm.resources["update"]["custom"]["VRAM"] / 2**30
            print(
                'Calibrated update'
                f' memory ({update_memory_gib:.3f} GiB)'
                f' and VRAM ({update_vram_gib:.3f} GiB)')
        dm.clean('update')

    # Wait for a new completion
    num_futures = len(dm.get_all_futures())
    if num_futures > 0:
        # Presumably counts the futures that are already finished -- TODO confirm
        already_done = len(dm.wait(num_returns=num_futures, timeout=0))
        # Block for one more completion unless every future is already done
        if already_done != num_futures:
            dm.wait(num_returns=already_done+1)


Queueing 20 rollouts
Calibrated rollout memory (0.39 GiB) and VRAM (0.47 GiB)
Queueing 20 rollouts


In [None]:
# Print the current wall-clock time (same text as str(datetime.now()))
from datetime import datetime
print(datetime.now().isoformat(sep=' '))


In [None]:
# # Cancel
# # dm.cancel()
# # dm.clean()
# # dm.rollout(dummy=True)
# # dm.wait()

# # Clear locks
# dm.policy_manager.release_locks.remote()

# # Get policy
# device = DEVICE
# policy = policy_init(modalities).to(device)
# celltrip.training.set_policy_state(policy, ray.get(dm.policy_manager.get_policy_state.remote()))

# # Get memory
# memory = memory_init(policy)
# memory.append_memory(
#     *ray.get(dm.policy_manager.get_memory_storage.remote()))


In [None]:
# # Get state of job from ObjectRef
# import ray.util.state
# object_id = dm.futures['simulation'][0].hex()
# object_state = ray.util.state.get_objects(object_id)[0]
# object_state.task_status
