In [22]:
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import torch

import celltrip


In [24]:
# Read data
fnames = ['../data/MERFISH/expression.h5ad', '../data/MERFISH/spatial.h5ad']
partition_cols = None
adatas = celltrip.utility.processing.read_adatas(*fnames, on_disk=False)
celltrip.utility.processing.test_adatas(*adatas, partition_cols=partition_cols)

# Construct dataloader
dataloader = celltrip.utility.processing.PreprocessFromAnnData(
    *adatas, partition_cols=partition_cols, num_nodes=200, pca_dim=128, seed=42)
modalities, adata_obs, adata_vars = dataloader.sample()



In [4]:
policy = celltrip.policy.PPO(6, [m.shape[1] for m in modalities], 3, device='cuda')
env = celltrip.environment.EnvironmentBase(*[torch.from_numpy(m).to('cuda') for m in modalities], dim=3, device='cuda')

# from collections import defaultdict
# memory = defaultdict(lambda: [])
# for _ in range(1000):
#     state = env.get_state(include_modalities=True).to('cuda')
#     state_split, action, action_log, state_val = policy.act_macro(state)
#     rewards, finished = env.step(action.to('cpu'))
#     memory['states'].append([s.detach() for s in state_split])
#     memory['actions'].append(action.detach())
#     memory['action_logs'].append(action_log.detach())
#     memory['state_vals'].append(state_val.detach())
#     memory['rewards'].append(rewards.detach())

# Append
# for k, v in memory.items():
#     if k == 'states': memory[k] = [torch.concat([s[i] for s in v], dim=0) for i in range(2)]
#     else: memory[k] = torch.concat(v, dim=0)

In [None]:
for _ in range(80):
    loss = policy.backward(**memory)[0]
    loss.backward()


In [5]:
for _ in range(1000):
    state = env.get_state(include_modalities=True)
    state_split, action, action_log, state_val = policy.act_macro(state)
    rewards, finished = env.step(action)

In [6]:
policy = celltrip.policy.PPO(6, [m.shape[1] for m in modalities], 3, device='cuda')
env = celltrip.environment.EnvironmentBase(*[torch.from_numpy(m).to('cpu') for m in modalities], dim=3, device='cpu')


In [7]:
for _ in range(1000):
    state = env.get_state(include_modalities=True).to('cuda')
    state_split, action, action_log, state_val = policy.act_macro(state)
    rewards, finished = env.step(action.to('cpu'))

In [118]:
import ray
import ray.util.collective as col

@ray.remote(num_gpus=1)
class Actor:
    def __init__(self, world_size, rank):
        # Environment
        self.env = celltrip.environment.EnvironmentBase(*[torch.from_numpy(m) for m in modalities], dim=3)

        # Rewards
        self.rewards_buffer = []
        
        # Group
        col.init_collective_group(world_size, rank, 'nccl')
        self.actions = torch.empty([modalities[0].shape[0], 3], device='cuda')

        # Main loop
        self.loop()

    def observe(self):
        obs = self.env.get_state(include_modalities=True)
        col.send(obs.to('cuda'), 0)

    def act(self):
        col.recv(self.actions, 0)
        rewards, finished = self.env.step(self.actions.to('cpu'))
        # col.send(rewards, 0)
        # col.send(finished, 0)

    def loop(self):
        while True:
            self.observe()
            self.act()


@ray.remote(num_gpus=1)
class Learner:
    def __init__(self, world_size):
        self.world_size = world_size
        self.policy = celltrip.policy.PPO(6, [m.shape[1] for m in modalities], 3, device='cuda')

        # Group
        col.init_collective_group(world_size, 0, 'nccl')
        self.obs = torch.empty([modalities[0].shape[0], 6+sum([m.shape[1] for m in modalities])], device='cuda')

    def act(self, rank):
        col.recv(self.obs, rank)
        actions = self.policy.act_macro(self.obs)[0]
        col.send(actions, rank)

    def loop(self):
        for _ in range(1000):
            for i in range(self.world_size-1):
                self.act(i+1)


ray.shutdown()
ray.init(
    address='ray://100.64.246.20:10001',
    runtime_env={
        'env_vars': {
            # NOTE: Important, NCCL will timeout if network device is non-standard
            'NCCL_SOCKET_IFNAME': 'tailscale',
            # 'NCCL_DEBUG': 'WARN',
            'RAY_DEDUP_LOGS': '0',
        }
    }
)

2025-03-23 23:06:34,339	INFO client_builder.py:244 -- Passing the following kwargs to ray.init() on the server: log_to_driver
SIGTERM handler is not set because current thread is not the main thread.


0,1
Python version:,3.10.16
Ray version:,2.43.0
Dashboard:,http://100.64.246.20:8265


In [28]:
learner = Learner.remote(2)
actors = [Actor.remote(2, i+1) for i in range(1)]
ray.get(learner.loop.remote())

In [115]:
ray.available_resources()

{'node:100.64.246.20': 1.0,
 'accelerator_type:G': 1.0,
 'node:__internal_head__': 1.0,
 'node:100.85.187.118': 1.0,
 'CPU': 46.0,
 'memory': 155214666958.0,
 'accelerator_type:RTX': 1.0,
 'VRAM': 57291112448.0,
 'GPU': 2.0,
 'object_store_memory': 67043780472.0}

In [73]:
torch.ones([2000, 500]).nbytes / 2**20

3.814697265625

In [108]:
import numpy as np
import time

In [117]:
@ray.remote(num_gpus=1)
def large(world_size, rank):
    pol = celltrip.policy.PPO(6, [256, 3], 3).to('cuda')
    t1 = time.perf_counter()
    col.init_collective_group(world_size, rank, 'nccl')
    col.barrier()
    t2 = time.perf_counter()
    for k, w in pol.state_dict().items():
        col.allreduce(w)
        w /= world_size
    t3 = time.perf_counter()
    col.destroy_collective_group()
    t4 = time.perf_counter()
    print(f'{t2-t1}, {t3-t2}, {t4-t3}')

ray.get([large.remote(2, 0), large.remote(2, 1)])

[None, None]

In [None]:
@ray.remote(
    # memory=32*2**30,
    resources={'node:100.64.246.20': 1.0},
    max_calls=1,
)
def large():
    pol = celltrip.policy.PPO(6, [256 ,3], 3)
    return pol.state_dict()

ray.get(large.remote())

OrderedDict([('actor.action_var', tensor([0.3600, 0.3600, 0.3600])),
             ('actor.scale_tril',
              tensor([[[0.6000, 0.0000, 0.0000],
                       [0.0000, 0.6000, 0.0000],
                       [0.0000, -0.0000, 0.6000]]])),
             ('actor.layer_norm.self embedding.weight',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])),
             ('actor.layer_norm.self embedding.bias',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0.

In [8]:
assert False

AssertionError: 

In [None]:
memory['states'][0].shape

torch.Size([2000, 136])