In [1]:
# TODO
# Add key argument in forward which is required for memories
# Add static prefix for memory based on key
# Reduce redundancy in state storage (now data is duplicated 'nodes' # of times), maybe add indexing function to memorybuffer
# If memory still used too much, implement file-storage

# Randomize positions within -1 to 1, no matter the environment size
# Make distance reward env-size and dataset-agnostic (i.e. spawn nodes in range 0-1 (or at origin), normalize dist per dataset (maybe by average dist))
# Add FCL between raw features and the features appended to vector
# Check that rewards are normalized after (?) advantage

# Fix off-center positioning in large environments (kinda solved with post-centering?)
# Revise distance reward
# Try using running average early stopping
# Save checkpoint models

# Try full MMD-MA data
# Try real data
# Add parallel envs of different sizes, with different data to help generality

In [2]:
%load_ext autoreload
%autoreload 2
%env WANDB_NOTEBOOK_NAME train.ipynb
%env WANDB_SILENT true

env: WANDB_NOTEBOOK_NAME=train.ipynb
env: WANDB_SILENT=true


In [3]:
from collections import defaultdict
import itertools
import os

import inept
import numpy as np
import pandas as pd
import torch
import wandb

# Set params
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_FOLDER = os.path.join(os.path.abspath(''), '../data')
MODEL_FOLDER = os.path.join(os.path.abspath(''), 'temp/trained_models')

# Script arguments
# import sys
# arg1 = int(sys.argv[1])

In [4]:
# Original paper (pg 24)
# https://arxiv.org/pdf/1909.07528.pdf

# Original blog
# https://openai.com/research/emergent-tool-use

# Gym
# https://gymnasium.farama.org/

# Slides
# https://glouppe.github.io/info8004-advanced-machine-learning/pdf/pleroy-hide-and-seek.pdf

# PPO implementation
# https://github.com/nikhilbarhate99/PPO-PyTorch/blob/master/PPO.py#L38

# Residual SA
# https://github.com/openai/multi-agent-emergence-environments/blob/bafaf1e11e6398624116761f91ae7c93b136f395/ma_policy/layers.py#L89

In [5]:
# Reproducibility
seed = 42
torch.manual_seed(seed)
if DEVICE == 'cuda': torch.cuda.manual_seed(seed)
np.random.seed(seed)

note_kwargs = {'seed': seed}

### Create Environment

In [6]:
# Dataset loading
dataset_name = 'BrainChromatin'
if dataset_name == 'BrainChromatin':
    M1 = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_rna_counts.tsv'), delimiter='\t', nrows=1_000).transpose()  # TODO: Raise number of features
    M2 = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_atac_gene_activities.tsv'), delimiter='\t', nrows=1_000).transpose()  # TODO: Raise number of features
    M2 = M2.transpose()[M1.index].transpose()
    meta = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_cell_metadata.txt'), delimiter='\t')
    meta_names = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_cluster_names.txt'), delimiter='\t')
    meta_names = meta_names[meta_names['Assay'] == 'Multiome ATAC']
    meta = pd.merge(meta, meta_names, left_on='ATAC_cluster', right_on='Cluster.ID', how='left')
    meta.index = meta['Cell.ID']
    T1 = T2 = np.array(meta.transpose()[M1.index].transpose()['Cluster.Name'])
    F1, F2 = M1.columns, M2.columns
    M1, M2 = M1.to_numpy(), M2.to_numpy()

    del meta, meta_names

elif dataset_name == 'scGEM':
    M1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/GeneExpression.txt'), delimiter=' ', header=None).to_numpy()
    M2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/DNAmethylation.txt'), delimiter=' ', header=None).to_numpy()
    T1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/type1.txt'), delimiter=' ', header=None).to_numpy()
    T2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/type2.txt'), delimiter=' ', header=None).to_numpy()
    F1 = np.loadtxt(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/gex_names.txt'), dtype='str')
    F2 = np.loadtxt(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/dm_names.txt'), dtype='str')

# MMD-MA data
elif dataset_name == 'MMD-MA':
    M1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_mapped1.txt'), delimiter='\t', header=None).to_numpy()
    M2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_mapped2.txt'), delimiter='\t', header=None).to_numpy()
    T1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_type1.txt'), delimiter='\t', header=None).to_numpy()
    T2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_type2.txt'), delimiter='\t', header=None).to_numpy()

# Random data
elif dataset_name == 'Random':
    num_nodes = 100
    M1 = torch.rand((num_nodes, 8), device=DEVICE)
    M2 = torch.rand((num_nodes, 16), device=DEVICE)

else: assert False, 'No matching dataset found.'

# Parameters
num_nodes = 50

# Modify data
M1, M2 = inept.utilities.normalize(M1, M2)  # Normalize
M1, M2 = inept.utilities.pca_features(M1, M2, num_features=(16, 16))  # PCA features
M1, M2, T1, T2 = inept.utilities.subsample_nodes(M1, M2, T1, T2, num_nodes=num_nodes)  # Subsample nodes
# M1, M2 = inept.utilities.subsample_features(M1, M2, num_features=(16, 16))  # Subsample features

# Cast types
M1 = torch.tensor(M1, dtype=torch.float32, device=DEVICE)
M2 = torch.tensor(M2, dtype=torch.float32, device=DEVICE)

In [7]:
# Record data kwargs
modalities = (M1, M2)
data_kwargs = {
    'dataset': dataset_name,
    'num_nodes': num_nodes,
}

# Environment
# x, y, vx, vy
num_dims = 2
env_kwargs = {
    'dim': num_dims,
    'pos_bound': 3,
    'vel_bound': 1,
    'delta': .1,
    'reward_distance': 10,
    # 'reward_origin': 1,
    'penalty_bound': 1,
    'penalty_velocity': 1,
    'penalty_action': 1,
    'reward_distance_type': 'euclidean',
}
env = inept.environments.trajectory(*modalities, **env_kwargs, device=DEVICE)

### Train Policy

In [8]:
# Tracking parameters
# Use `watch -d -n 0.5 nvidia-smi` to watch CUDA memory usage
# Use `top` to watch system memory usage
use_wandb = True

# Policy parameters
input_dims = 2*num_dims+sum([m.shape[1] for m in modalities])
batch_split_factor = 1  # Set to high if large number of features
update_minibatch = int( 4e4 * (10 / num_nodes) / batch_split_factor )
update_max_batch = batch_split_factor * update_minibatch  # Only run one minibatch, empirically the benefit is minimal compared to time loss
policy_kwargs = {
    'num_features_per_node': input_dims,
    'output_dim': num_dims,
    'action_std_init': .6,
    'action_std_decay': .05,
    'action_std_min': .1,
    'actor_lr': 3e-4,
    'critic_lr': 1e-3,
    'lr_gamma': 1,
    'update_minibatch': update_minibatch,  # Based on no minibatches needed with 10 nodes at 4k update timesteps
    'update_max_batch': update_max_batch,  # Try making larger, e.g. 20x minibatches
    'device': DEVICE,
}
policy = inept.models.PPO(**policy_kwargs)

# Training parameters
max_ep_timesteps = 3e2  # 2e2
max_timesteps = 1e6
update_timesteps = 4e3  # 20 * max_ep_timesteps
train_kwargs = {
    'max_ep_timesteps': max_ep_timesteps,
    'max_timesteps': max_timesteps,
    'update_timesteps': update_timesteps,
}

# Early stopping parameters
es_kwargs = {
    'buffer': 3 * int(update_timesteps / max_ep_timesteps),
    'delta': .01,
}
early_stopping = inept.utilities.EarlyStopping(**es_kwargs)

# Initialize wandb
if use_wandb: wandb.init(
    project='INEPT',
    config={
        **{'note/'+k:v for k, v in note_kwargs.items()},
        **{'data/'+k:v for k, v in data_kwargs.items()},
        **{'env/'+k:v for k, v in env_kwargs.items()},
        **{'policy/'+k:v for k, v in policy_kwargs.items()},
        **{'train/'+k:v for k, v in train_kwargs.items()},
        **{'es/'+k:v for k, v in es_kwargs.items()},
    },
)

# Initialize logging vars
torch.cuda.reset_peak_memory_stats()
timer = inept.utilities.time_logger(discard_first_sample=True)
timestep = 0; episode = 1

# CLI
print('Beginning training')
print(f'Subsampling {update_max_batch} states with minibatches of size {update_minibatch} from {int(update_timesteps * num_nodes)} total.')

# Simulation loop
while timestep < max_timesteps:
    # Reset environment
    env.reset()
    # env.reward_scales['reward_origin'] = episode / 1e3
    timer.log('Reset Environment')

    # Start episode
    ep_timestep = 0; ep_reward = 0; ep_itemized_reward = defaultdict(lambda: 0)
    while ep_timestep < max_ep_timesteps:
        with torch.no_grad():
            # Get current state
            state = env.get_state(include_modalities=True)

            # Get self features for each node
            self_entity = state

            # Get node features for each state
            idx = torch.zeros((num_nodes, num_nodes), dtype=torch.bool)
            for i, j in itertools.product(*[range(x) for x in idx.shape]):
                idx[i, j] = i!=j
            node_entities = state.unsqueeze(0).expand(num_nodes, *state.shape)
            node_entities = node_entities[idx].reshape(num_nodes, num_nodes-1, input_dims)
            timer.log('Environment Setup')

            # Get actions from policy
            actions = policy.act(self_entity, node_entities).detach()
            timer.log('Calculate Actions')

            # Step environment and get reward
            rewards, finished, itemized_rewards = env.step(actions, return_rewards=True)
            timer.log('Step Environment')

            # Record rewards
            for key in range(num_nodes):
                policy.memory.rewards.append(rewards[key].item())  # Could just add lists
                policy.memory.is_terminals.append(finished)
            ep_reward = ep_reward + rewards.cpu().mean()
            for k, v in itemized_rewards.items():
                ep_itemized_reward[k] += v.cpu().mean()
            timer.log('Record Rewards')

        # Iterate
        timestep += 1
        ep_timestep += 1

        # Update model
        if timestep % update_timesteps == 0:
            print(f'Updating model with average reward {np.mean(policy.memory.rewards)} on episode {episode} and timestep {timestep}', end='')
            policy.update()
            print(f' ({torch.cuda.max_memory_allocated() / 1024**3:.2f} GB CUDA)')
            torch.cuda.reset_peak_memory_stats()
            timer.log('Update Policy')

        # Escape if finished
        if finished: break

    # Upload stats
    ep_reward = (ep_reward / ep_timestep).item()
    if use_wandb:
        wandb.log({
            **{
            'episode': episode,
            'update': int(timestep / update_timesteps),
            'end_timestep': timestep,
            'average_reward': ep_reward,
            'action_std': policy.action_std,
            },
            **{'rewards/'+k: (v / ep_timestep).item() for k, v in ep_itemized_reward.items()},
        })
    timer.log('Record Stats')

    # Decay model std
    if early_stopping(ep_reward):
        # End if already at minimum
        if policy.action_std <= policy.action_std_min:
            print(f'Ending early on episode {episode} and timestep {timestep}')
            break

        # Decay and reset early stop
        policy.decay_action_std()
        early_stopping.reset()

        # CLI
        print(f'Decaying std to {policy.action_std} on episode {episode} and timestep {timestep}')
    timer.log('Early Stopping')

    # Iterate
    episode += 1

# CLI Timer
print()
timer.aggregate('sum')

# Save model
wgt_file = os.path.join(MODEL_FOLDER, 'policy.wgt')
torch.save(policy.state_dict(), wgt_file)  # Save just weights
if use_wandb: wandb.save(wgt_file)
mdl_file = os.path.join(MODEL_FOLDER, 'policy.mdl')
torch.save(policy, mdl_file)  # Save whole model
if use_wandb: wandb.save(mdl_file)

# Finish wandb
if use_wandb: wandb.finish()

Beginning training
Subsampling 8000 states with minibatches of size 8000 from 200000 total.


Updating model with average reward -0.33742163390720614 on episode 14 and timestep 4000

 (3.79 GB CUDA)


Updating model with average reward -0.3549472073540208 on episode 27 and timestep 8000

 (3.79 GB CUDA)


Updating model with average reward -0.3217626810658292 on episode 40 and timestep 12000

 (3.79 GB CUDA)


Updating model with average reward -0.18903329477637423 on episode 54 and timestep 16000

 (3.79 GB CUDA)


Updating model with average reward -0.1753807219015801 on episode 67 and timestep 20000

 (3.79 GB CUDA)


Updating model with average reward -0.15970467612843378 on episode 80 and timestep 24000

 (3.79 GB CUDA)


Updating model with average reward -0.05058263150967599 on episode 94 and timestep 28000

 (3.79 GB CUDA)


Updating model with average reward -0.04379719832922332 on episode 107 and timestep 32000

 (3.79 GB CUDA)


Updating model with average reward -0.09646117235771497 on episode 120 and timestep 36000

 (3.79 GB CUDA)


Updating model with average reward 0.02062136474990868 on episode 134 and timestep 40000

 (3.79 GB CUDA)


Decaying std to 0.5499999999999999 on episode 144 and timestep 43200


Updating model with average reward -0.0006054904202988837 on episode 147 and timestep 44000

 (3.79 GB CUDA)


Updating model with average reward 0.009089347269616847 on episode 160 and timestep 48000

 (3.79 GB CUDA)


Updating model with average reward 0.10277069826098392 on episode 174 and timestep 52000

 (3.79 GB CUDA)


Updating model with average reward 0.09137276851738803 on episode 187 and timestep 56000

 (3.79 GB CUDA)


Updating model with average reward 0.07134784532654535 on episode 200 and timestep 60000

 (3.79 GB CUDA)


Updating model with average reward 0.11170690088717733 on episode 214 and timestep 64000

 (3.79 GB CUDA)


Updating model with average reward 0.09893521088420006 on episode 227 and timestep 68000

 (3.79 GB CUDA)


Updating model with average reward 0.10597655123192817 on episode 240 and timestep 72000

 (3.79 GB CUDA)


Updating model with average reward 0.18363375334954762 on episode 254 and timestep 76000

 (3.79 GB CUDA)


Updating model with average reward 0.16789298196339747 on episode 267 and timestep 80000

 (3.79 GB CUDA)


Updating model with average reward 0.11941682263829163 on episode 280 and timestep 84000

 (3.79 GB CUDA)


Updating model with average reward 0.1888702464319137 on episode 294 and timestep 88000

 (3.79 GB CUDA)


Decaying std to 0.49999999999999994 on episode 299 and timestep 89700


Updating model with average reward 0.17605088317867484 on episode 307 and timestep 92000

 (3.79 GB CUDA)


Updating model with average reward 0.21031232490027557 on episode 320 and timestep 96000

 (3.79 GB CUDA)


Updating model with average reward 0.26831197150129416 on episode 334 and timestep 100000

 (3.79 GB CUDA)


Updating model with average reward 0.18731416712173699 on episode 347 and timestep 104000

 (3.79 GB CUDA)


Updating model with average reward 0.23808216849937075 on episode 360 and timestep 108000

 (3.79 GB CUDA)


Updating model with average reward 0.2984386812020786 on episode 374 and timestep 112000

 (3.79 GB CUDA)


Updating model with average reward 0.22191581168130214 on episode 387 and timestep 116000

 (3.79 GB CUDA)


Updating model with average reward 0.24465242811225005 on episode 400 and timestep 120000

 (3.79 GB CUDA)


Updating model with average reward 0.2910533218130132 on episode 414 and timestep 124000

 (3.79 GB CUDA)


Decaying std to 0.44999999999999996 on episode 417 and timestep 125100


Updating model with average reward 0.24788639416606573 on episode 427 and timestep 128000

 (3.79 GB CUDA)


Updating model with average reward 0.30766983818562893 on episode 440 and timestep 132000

 (3.79 GB CUDA)


Updating model with average reward 0.3503954257119264 on episode 454 and timestep 136000

 (3.79 GB CUDA)


Decaying std to 0.39999999999999997 on episode 458 and timestep 137400


Updating model with average reward 0.3080150988750678 on episode 467 and timestep 140000

 (3.79 GB CUDA)


Updating model with average reward 0.32998147204501493 on episode 480 and timestep 144000

 (3.79 GB CUDA)


Updating model with average reward 0.4126788856174017 on episode 494 and timestep 148000

 (3.79 GB CUDA)


Updating model with average reward 0.32510941633527574 on episode 507 and timestep 152000

 (3.79 GB CUDA)


Updating model with average reward 0.3541722079621762 on episode 520 and timestep 156000

 (3.79 GB CUDA)


Decaying std to 0.35 on episode 529 and timestep 158700


Updating model with average reward 0.4144421791120357 on episode 534 and timestep 160000

 (3.79 GB CUDA)


Updating model with average reward 0.380069047601865 on episode 547 and timestep 164000

 (3.79 GB CUDA)


Updating model with average reward 0.3958445609928155 on episode 560 and timestep 168000

 (3.79 GB CUDA)


Updating model with average reward 0.46837464590896255 on episode 574 and timestep 172000

 (3.79 GB CUDA)


Updating model with average reward 0.3750162097250587 on episode 587 and timestep 176000

 (3.79 GB CUDA)


Updating model with average reward 0.4100927825526509 on episode 600 and timestep 180000

 (3.79 GB CUDA)


Updating model with average reward 0.4765679738095391 on episode 614 and timestep 184000

 (3.79 GB CUDA)


Updating model with average reward 0.40451776972861114 on episode 627 and timestep 188000

 (3.79 GB CUDA)


Decaying std to 0.3 on episode 634 and timestep 190200


Updating model with average reward 0.4631964902380924 on episode 640 and timestep 192000

 (3.79 GB CUDA)


Updating model with average reward 0.4824164319141599 on episode 654 and timestep 196000

 (3.79 GB CUDA)


Updating model with average reward 0.4589753293728032 on episode 667 and timestep 200000

 (3.79 GB CUDA)


Updating model with average reward 0.48384943080373166 on episode 680 and timestep 204000

 (3.79 GB CUDA)


Updating model with average reward 0.5198656593912734 on episode 694 and timestep 208000

 (3.79 GB CUDA)


Decaying std to 0.25 on episode 694 and timestep 208200


Updating model with average reward 0.46440515182117176 on episode 707 and timestep 212000

 (3.79 GB CUDA)


Updating model with average reward 0.5034648174642693 on episode 720 and timestep 216000

 (3.79 GB CUDA)


Updating model with average reward 0.5928505786625939 on episode 734 and timestep 220000

 (3.79 GB CUDA)


Updating model with average reward 0.5089905271143306 on episode 747 and timestep 224000

 (3.79 GB CUDA)


Decaying std to 0.2 on episode 755 and timestep 226500


Updating model with average reward 0.538795898058479 on episode 760 and timestep 228000

 (3.79 GB CUDA)


Updating model with average reward 0.553259618443727 on episode 774 and timestep 232000

 (3.79 GB CUDA)


Updating model with average reward 0.4877369771728634 on episode 787 and timestep 236000

 (3.79 GB CUDA)


Decaying std to 0.15000000000000002 on episode 798 and timestep 239400


Updating model with average reward 0.4692792590717132 on episode 800 and timestep 240000

 (3.79 GB CUDA)


Updating model with average reward 0.5963940169394474 on episode 814 and timestep 244000

 (3.79 GB CUDA)


Updating model with average reward 0.5179716011224118 on episode 827 and timestep 248000

 (3.79 GB CUDA)


Updating model with average reward 0.5111584251204919 on episode 840 and timestep 252000

 (3.79 GB CUDA)


Decaying std to 0.10000000000000002 on episode 853 and timestep 255900


Updating model with average reward 0.5629214499404432 on episode 854 and timestep 256000

 (3.79 GB CUDA)


Updating model with average reward 0.496296186981346 on episode 867 and timestep 260000

 (3.79 GB CUDA)


Updating model with average reward 0.449132731425655 on episode 880 and timestep 264000

 (3.79 GB CUDA)


Updating model with average reward 0.5811640221439699 on episode 894 and timestep 268000

 (3.79 GB CUDA)


Decaying std to 0.1 on episode 901 and timestep 270300


Updating model with average reward 0.452586037782519 on episode 907 and timestep 272000

 (3.79 GB CUDA)


Updating model with average reward 0.486601121726086 on episode 920 and timestep 276000

 (3.79 GB CUDA)


Updating model with average reward 0.5074503203861603 on episode 934 and timestep 280000

 (3.79 GB CUDA)


Updating model with average reward 0.48698903586917325 on episode 947 and timestep 284000

 (3.79 GB CUDA)


Ending early on episode 948 and timestep 284400

Reset Environment: 0.113290155989489
Environment Setup: 3156.4125311535554
Calculate Actions: 1620.861532955194
Step Environment: 202.36833615327942


Record Rewards: 175.31642860978718
Record Stats: 0.2782003049742343
Early Stopping: 0.017291964014475525
Update Policy: 2209.811145688008
Total: 7365.178756984802
