In [1]:
## Checks
# Check that rewards are normalized after (?) advantage

## Improvements
# Fix off-center positioning in large environments
# Revise distance reward
# Try using running average early stopping
# Add parallel envs of different sizes, with different data to help generality

## QOL
# Save checkpoint models

## Runs
# Try full real data

In [2]:
# Automatically reload edited modules (e.g. `inept`) before executing each cell
%load_ext autoreload
%autoreload 2
# Configure wandb through environment variables: the notebook's filename and silent mode.
# NOTE: comments must stay on their own lines — text after a `%env` value would become part of it.
%env WANDB_NOTEBOOK_NAME train.ipynb
%env WANDB_SILENT true

env: WANDB_NOTEBOOK_NAME=train.ipynb
env: WANDB_SILENT=true


In [3]:
from collections import defaultdict
import os

import inept
import numpy as np
import pandas as pd
import torch
import wandb

# Set params
# Prefer the GPU when one is visible to torch, otherwise fall back to CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Folders are resolved relative to the kernel's current working directory
DATA_FOLDER = os.path.join(os.getcwd(), '../data')
MODEL_FOLDER = os.path.join(os.getcwd(), 'temp/trained_models')

# Script arguments (uncomment when running as a script instead of a notebook)
# import sys
# arg1 = int(sys.argv[1])

In [4]:
# Original paper (pg 24)
# https://arxiv.org/pdf/1909.07528.pdf

# Original blog
# https://openai.com/research/emergent-tool-use

# Gym
# https://gymnasium.farama.org/

# Slides
# https://glouppe.github.io/info8004-advanced-machine-learning/pdf/pleroy-hide-and-seek.pdf

# PPO implementation
# https://github.com/nikhilbarhate99/PPO-PyTorch/blob/master/PPO.py#L38

# Residual SA
# https://github.com/openai/multi-agent-emergence-environments/blob/bafaf1e11e6398624116761f91ae7c93b136f395/ma_policy/layers.py#L89

In [5]:
# Reproducibility: seed numpy and torch, and the current CUDA device on GPU runs
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
if DEVICE == 'cuda':
    torch.cuda.manual_seed(seed)

# Free-form run notes; the training cell logs these to wandb under the `note/` prefix
note_kwargs = dict(seed=seed)

### Load and Prepare Data

In [6]:
# Dataset loading
# Every branch produces numpy feature matrices M1, M2 (one row per node/cell)
# and per-node type labels T1, T2; some branches also set feature names F1, F2.
dataset_name = 'MMD-MA'
if dataset_name == 'BrainChromatin':
    M1 = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_rna_counts.tsv'), delimiter='\t', nrows=1_000).transpose()  # TODO: Raise number of features
    M2 = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_atac_gene_activities.tsv'), delimiter='\t', nrows=1_000).transpose()  # TODO: Raise number of features
    M2 = M2.loc[M1.index]  # Align M2 rows (cells) to M1's order
    meta = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_cell_metadata.txt'), delimiter='\t')
    meta_names = pd.read_csv(os.path.join(DATA_FOLDER, 'brainchromatin/multiome_cluster_names.txt'), delimiter='\t')
    meta_names = meta_names[meta_names['Assay'] == 'Multiome ATAC']
    meta = pd.merge(meta, meta_names, left_on='ATAC_cluster', right_on='Cluster.ID', how='left')
    meta.index = meta['Cell.ID']
    # Both modalities come from the same cells, so they share one label vector
    T1 = T2 = np.array(meta.loc[M1.index, 'Cluster.Name'])
    F1, F2 = M1.columns, M2.columns
    M1, M2 = M1.to_numpy(), M2.to_numpy()

    del meta, meta_names

elif dataset_name == 'scGEM':
    M1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/GeneExpression.txt'), delimiter=' ', header=None).to_numpy()
    M2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/DNAmethylation.txt'), delimiter=' ', header=None).to_numpy()
    T1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/type1.txt'), delimiter=' ', header=None).to_numpy()
    T2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/type2.txt'), delimiter=' ', header=None).to_numpy()
    F1 = np.loadtxt(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/gex_names.txt'), dtype='str')
    F2 = np.loadtxt(os.path.join(DATA_FOLDER, 'UnionCom/scGEM/dm_names.txt'), dtype='str')

# MMD-MA data
elif dataset_name == 'MMD-MA':
    M1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_mapped1.txt'), delimiter='\t', header=None).to_numpy()
    M2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_mapped2.txt'), delimiter='\t', header=None).to_numpy()
    T1 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_type1.txt'), delimiter='\t', header=None).to_numpy()
    T2 = pd.read_csv(os.path.join(DATA_FOLDER, 'UnionCom/MMD/s1_type2.txt'), delimiter='\t', header=None).to_numpy()

# Random data
elif dataset_name == 'Random':
    num_nodes = 100
    # Built with (seeded) numpy so it follows the same normalize/subsample/cast
    # path as the real datasets; torch tensors here broke the cast below
    M1 = np.random.rand(num_nodes, 8)
    M2 = np.random.rand(num_nodes, 16)
    # Placeholder labels: `subsample_nodes` below requires T1/T2 for every dataset
    T1 = T2 = np.array(['Random'] * num_nodes)

else:
    # `assert False` would be stripped under `python -O`; fail loudly instead
    raise ValueError(f'No matching dataset found for {dataset_name!r}.')

# Parameters
num_nodes = 50  # Use M1.shape[0] to keep every node

# Modify data
M1, M2 = inept.utilities.normalize(M1, M2)  # Normalize
# M1, M2 = inept.utilities.pca_features(M1, M2, num_features=(16, 16))  # PCA features
M1, M2, T1, T2 = inept.utilities.subsample_nodes(M1, M2, T1, T2, num_nodes=num_nodes)  # Subsample nodes
# M1, M2 = inept.utilities.subsample_features(M1, M2, num_features=(16, 16))  # Subsample features

# Cast types
M1 = torch.tensor(M1, dtype=torch.float32, device=DEVICE)
M2 = torch.tensor(M2, dtype=torch.float32, device=DEVICE)
modalities = (M1, M2)

### Parameters

In [7]:
# Data parameters
data_kwargs = {
    'dataset': dataset_name,
    'num_nodes': num_nodes,
}

# Environment parameters
env_kwargs = {
    'dim': 2,  # Spatial dimensions; each node carries 2*dim features (x, y, vx, vy for dim=2)
    'pos_bound': 5,
    'pos_rand_bound': 1,
    'vel_bound': 1,
    'delta': .1,
    'reward_distance': 10,
    'penalty_bound': 1,
    'penalty_velocity': 1,
    'penalty_action': 1,
    'reward_distance_type': 'euclidean',
}

# Training parameters
# Step counts are ints: they serve as loop bounds and `%` operands in the
# training loop, and they log to wandb as integers
train_kwargs = {
    'max_ep_timesteps': int(3e2),
    'max_timesteps': int(1e6),
    'update_timesteps': int(4e3),
}

# Policy parameters
policy_kwargs = {
    'modal_sizes': [M.shape[1] for M in modalities],
    'num_features_per_node': 2*env_kwargs['dim'],
    'output_dim': env_kwargs['dim'],
    'action_std_init': .6,
    'action_std_decay': .05,
    'action_std_min': .1,
    'actor_lr': 3e-4,
    'critic_lr': 1e-3,
    'lr_gamma': 1,
    'update_minibatch': int(2e3),  # If too high, the kernel will crash (can also often crash machine)
    'update_max_batch': int(2e3),  # Alternative (use every recorded state): int(train_kwargs['update_timesteps'] * data_kwargs["num_nodes"])
    'device': DEVICE,
}

# Early stopping parameters
es_kwargs = {
    # ~3 update cycles, measured in full-length episodes per cycle
    # (early stopping is evaluated once per episode in the training loop)
    'buffer': 3 * (train_kwargs['update_timesteps'] // train_kwargs['max_ep_timesteps']),
    'delta': .01,
}

### Train Policy

In [8]:
# Tracking parameters
# Use `watch -d -n 0.5 nvidia-smi` to watch CUDA memory usage
# Use `top` to watch system memory usage
use_wandb = True
# CUDA memory statistics only exist on GPU runs: `torch.cuda.reset_peak_memory_stats`
# raises on CPU-only machines, so every CUDA memory query below is guarded
use_cuda = DEVICE == 'cuda'

# Initialize classes
env = inept.environments.trajectory(*modalities, **env_kwargs, device=DEVICE)
policy = inept.models.PPO(**policy_kwargs)
early_stopping = inept.utilities.EarlyStopping(**es_kwargs)

# Initialize wandb, namespacing each parameter group in the run config
if use_wandb: wandb.init(
    project='INEPT',
    config={
        **{'note/'+k:v for k, v in note_kwargs.items()},
        **{'data/'+k:v for k, v in data_kwargs.items()},
        **{'env/'+k:v for k, v in env_kwargs.items()},
        **{'policy/'+k:v for k, v in policy_kwargs.items()},
        **{'train/'+k:v for k, v in train_kwargs.items()},
        **{'es/'+k:v for k, v in es_kwargs.items()},
    },
)

# Initialize logging vars
if use_cuda: torch.cuda.reset_peak_memory_stats()
timer = inept.utilities.time_logger(discard_first_sample=True)
timestep = 0; episode = 1

# CLI
print('Beginning training')
print(f'Subsampling {policy_kwargs["update_max_batch"]} states with minibatches of size {policy_kwargs["update_minibatch"]} from {int(train_kwargs["update_timesteps"] * data_kwargs["num_nodes"])} total.')

# Simulation loop: run episodes until the global timestep budget is spent
while timestep < train_kwargs['max_timesteps']:
    # Reset environment
    env.reset()
    # env.reward_scales['reward_origin'] = episode / 1e3
    timer.log('Reset Environment')

    # Start episode; rewards are summed per step (as the mean of the reward tensor)
    ep_timestep = 0; ep_reward = 0; ep_itemized_reward = defaultdict(int)
    while ep_timestep < train_kwargs['max_ep_timesteps']:
        # Rollout only: acting, stepping and recording need no gradients
        with torch.no_grad():
            # Get current state
            state = env.get_state(include_modalities=True)
            timer.log('Environment Setup')

            # Get actions from policy
            actions = policy.act_macro(state, keys=list(range(num_nodes))).detach()
            timer.log('Calculate Actions')

            # Step environment and get reward
            rewards, finished, itemized_rewards = env.step(actions, return_rewards=True)
            timer.log('Step Environment')

            # Record rewards for policy
            policy.memory.record(
                rewards=rewards.cpu().tolist(),
                is_terminals=finished,
            )

            # Record rewards for logging
            ep_reward = ep_reward + rewards.cpu().mean()
            for k, v in itemized_rewards.items():
                ep_itemized_reward[k] += v.cpu().mean()
            timer.log('Record Rewards')

        # Iterate
        timestep += 1
        ep_timestep += 1

        # Update model every `update_timesteps` global steps (this can fall mid-episode)
        if timestep % train_kwargs['update_timesteps'] == 0:
            print(f'Updating model with average reward {np.mean(policy.memory.storage["rewards"])} on episode {episode} and timestep {timestep}', end='')
            policy.update()
            if use_cuda:
                print(f' ({torch.cuda.max_memory_allocated() / 1024**3:.2f} GB CUDA)')
                torch.cuda.reset_peak_memory_stats()
            else:
                print()  # Terminate the progress line started above
            timer.log('Update Policy')

        # Escape if finished
        if finished: break

    # Upload stats (averaged over the steps actually taken this episode)
    ep_reward = (ep_reward / ep_timestep).item()
    if use_wandb:
        wandb.log({
            **{
                'episode': episode,
                'update': int(timestep / train_kwargs['update_timesteps']),
                'end_timestep': timestep,
                'average_reward': ep_reward,
                'action_std': policy.action_std,
            },
            **{'rewards/'+k: (v / ep_timestep).item() for k, v in ep_itemized_reward.items()},
        })
    timer.log('Record Stats')

    # When early stopping triggers, decay the policy's action std; once the std is
    # already at its minimum, end training instead
    if early_stopping(ep_reward):
        # End if already at minimum
        if policy.action_std <= policy.action_std_min:
            print(f'Ending early on episode {episode} and timestep {timestep}')
            break

        # Decay and reset early stop
        policy.decay_action_std()
        early_stopping.reset()

        # CLI
        print(f'Decaying std to {policy.action_std} on episode {episode} and timestep {timestep}')
    timer.log('Early Stopping')

    # Iterate
    episode += 1

# CLI Timer
print()
timer.aggregate('sum')

# Save model; torch.save does not create missing directories
os.makedirs(MODEL_FOLDER, exist_ok=True)
wgt_file = os.path.join(MODEL_FOLDER, 'policy.wgt')
torch.save(policy.state_dict(), wgt_file)  # Save just weights
if use_wandb: wandb.save(wgt_file)
mdl_file = os.path.join(MODEL_FOLDER, 'policy.mdl')
torch.save(policy, mdl_file)  # Save whole model
if use_wandb: wandb.save(mdl_file)

# Finish wandb
if use_wandb: wandb.finish()

Beginning training
Subsampling 2000 states with minibatches of size 2000 from 200000 total.


Updating model with average reward -1.4661225230419554 on episode 14 and timestep 4000

 (2.26 GB CUDA)


Updating model with average reward -1.3995731101571152 on episode 27 and timestep 8000

 (2.26 GB CUDA)


Updating model with average reward -1.2955645820164354 on episode 40 and timestep 12000

 (2.26 GB CUDA)


Updating model with average reward -1.2820697113142476 on episode 54 and timestep 16000

 (2.26 GB CUDA)


Updating model with average reward -1.0854975688864725 on episode 67 and timestep 20000

 (2.26 GB CUDA)


Updating model with average reward -1.0145903051113552 on episode 80 and timestep 24000

 (2.26 GB CUDA)


Updating model with average reward -0.9177235155997833 on episode 94 and timestep 28000

 (2.26 GB CUDA)


Updating model with average reward -0.8252967280298774 on episode 107 and timestep 32000

 (2.26 GB CUDA)


Updating model with average reward -0.791715454409087 on episode 120 and timestep 36000

 (2.26 GB CUDA)


Updating model with average reward -0.7602193795348701 on episode 134 and timestep 40000

 (2.26 GB CUDA)


Updating model with average reward -0.6834847722081832 on episode 147 and timestep 44000

 (2.26 GB CUDA)


Updating model with average reward -0.6926262863448326 on episode 160 and timestep 48000

 (2.26 GB CUDA)


Updating model with average reward -0.6404163023043669 on episode 174 and timestep 52000

 (2.26 GB CUDA)


Updating model with average reward -0.5653562939555896 on episode 187 and timestep 56000

 (2.26 GB CUDA)


Updating model with average reward -0.5359236595778424 on episode 200 and timestep 60000

 (2.26 GB CUDA)


Updating model with average reward -0.5369159328065534 on episode 214 and timestep 64000

 (2.26 GB CUDA)


Updating model with average reward -0.5724982478319796 on episode 227 and timestep 68000

 (2.26 GB CUDA)


Updating model with average reward -0.5509149498046501 on episode 240 and timestep 72000

 (2.26 GB CUDA)


Updating model with average reward -0.5558103483584488 on episode 254 and timestep 76000

 (2.26 GB CUDA)


Updating model with average reward -0.5107438504061871 on episode 267 and timestep 80000

 (2.26 GB CUDA)


Updating model with average reward -0.511919118231957 on episode 280 and timestep 84000

 (2.26 GB CUDA)


Updating model with average reward -0.4970103059374634 on episode 294 and timestep 88000

 (2.26 GB CUDA)


Updating model with average reward -0.5350929928694816 on episode 307 and timestep 92000

 (2.26 GB CUDA)


Decaying std to 0.5499999999999999 on episode 307 and timestep 92100


Updating model with average reward -0.431552339922658 on episode 320 and timestep 96000

 (2.26 GB CUDA)


Updating model with average reward -0.4902271403066124 on episode 334 and timestep 100000

 (2.26 GB CUDA)


Updating model with average reward -0.4076772393411209 on episode 347 and timestep 104000

 (2.26 GB CUDA)


Decaying std to 0.49999999999999994 on episode 354 and timestep 106200


Updating model with average reward -0.4656626635821513 on episode 360 and timestep 108000

 (2.26 GB CUDA)


Updating model with average reward -0.37703744952603274 on episode 374 and timestep 112000

 (2.26 GB CUDA)


Updating model with average reward -0.3326513629827125 on episode 387 and timestep 116000

 (2.26 GB CUDA)


Updating model with average reward -0.37954372694434596 on episode 400 and timestep 120000

 (2.26 GB CUDA)


Updating model with average reward -0.39124295167613743 on episode 414 and timestep 124000

 (2.26 GB CUDA)


Decaying std to 0.44999999999999996 on episode 414 and timestep 124200


Updating model with average reward -0.3414929917163284 on episode 427 and timestep 128000

 (2.26 GB CUDA)


Updating model with average reward -0.3765534237758373 on episode 440 and timestep 132000

 (2.26 GB CUDA)


Updating model with average reward -0.3864445135001905 on episode 454 and timestep 136000

 (2.26 GB CUDA)


Updating model with average reward -0.31745408327577723 on episode 467 and timestep 140000

 (2.26 GB CUDA)


Updating model with average reward -0.3317890581210388 on episode 480 and timestep 144000

 (2.26 GB CUDA)


Updating model with average reward -0.3513525380703807 on episode 494 and timestep 148000

 (2.26 GB CUDA)


Updating model with average reward -0.388704078069745 on episode 507 and timestep 152000

 (2.26 GB CUDA)


Decaying std to 0.39999999999999997 on episode 512 and timestep 153600


Updating model with average reward -0.31196305839742533 on episode 520 and timestep 156000

 (2.26 GB CUDA)


Updating model with average reward -0.31237375038002735 on episode 534 and timestep 160000

 (2.26 GB CUDA)


Updating model with average reward -0.31439124231106397 on episode 547 and timestep 164000

 (2.26 GB CUDA)


Updating model with average reward -0.3010550847993055 on episode 560 and timestep 168000

 (2.26 GB CUDA)


Updating model with average reward -0.2662360331223643 on episode 574 and timestep 172000

 (2.26 GB CUDA)


Updating model with average reward -0.24666351719386803 on episode 587 and timestep 176000

 (2.26 GB CUDA)


Updating model with average reward -0.30342723789240555 on episode 600 and timestep 180000

 (2.26 GB CUDA)


Decaying std to 0.35 on episode 603 and timestep 180900


Updating model with average reward -0.2599547951165179 on episode 614 and timestep 184000

 (2.26 GB CUDA)


Updating model with average reward -0.19686930809329925 on episode 627 and timestep 188000

 (2.26 GB CUDA)


Updating model with average reward -0.2118393329700944 on episode 640 and timestep 192000

 (2.26 GB CUDA)


Updating model with average reward -0.21424126919050177 on episode 654 and timestep 196000

 (2.26 GB CUDA)


Decaying std to 0.3 on episode 660 and timestep 198000


Updating model with average reward -0.22496214543220994 on episode 667 and timestep 200000

 (2.26 GB CUDA)


Updating model with average reward -0.18087189533821277 on episode 680 and timestep 204000

 (2.26 GB CUDA)


Updating model with average reward -0.22921007111213293 on episode 694 and timestep 208000

 (2.26 GB CUDA)


Decaying std to 0.25 on episode 706 and timestep 211800


Updating model with average reward -0.21652804490802344 on episode 707 and timestep 212000

 (2.26 GB CUDA)


Updating model with average reward -0.1970403344487145 on episode 720 and timestep 216000

 (2.26 GB CUDA)


Updating model with average reward -0.19608376258822682 on episode 734 and timestep 220000

 (2.26 GB CUDA)


Updating model with average reward -0.18420502094891752 on episode 747 and timestep 224000

 (2.26 GB CUDA)


Decaying std to 0.2 on episode 757 and timestep 227100


Updating model with average reward -0.17961907109428604 on episode 760 and timestep 228000

 (2.26 GB CUDA)


Updating model with average reward -0.1795995095731273 on episode 774 and timestep 232000

 (2.26 GB CUDA)


Updating model with average reward -0.16030648075955833 on episode 787 and timestep 236000

 (2.26 GB CUDA)


Updating model with average reward -0.14936960001314037 on episode 800 and timestep 240000

 (2.26 GB CUDA)


Decaying std to 0.15000000000000002 on episode 801 and timestep 240300


Updating model with average reward -0.3005952910213775 on episode 814 and timestep 244000

 (2.26 GB CUDA)


Updating model with average reward -0.2486820237475554 on episode 827 and timestep 248000

 (2.26 GB CUDA)


Updating model with average reward -0.23956325555971242 on episode 840 and timestep 252000

 (2.26 GB CUDA)


Updating model with average reward -0.24292002940733248 on episode 854 and timestep 256000

 (2.26 GB CUDA)


Updating model with average reward -0.2133300249191623 on episode 867 and timestep 260000

 (2.26 GB CUDA)


Updating model with average reward -0.14686671350040562 on episode 880 and timestep 264000

 (2.26 GB CUDA)


Updating model with average reward -0.2089101224562223 on episode 894 and timestep 268000

 (2.26 GB CUDA)


Updating model with average reward -0.17981014749438837 on episode 907 and timestep 272000

 (2.26 GB CUDA)


Decaying std to 0.10000000000000002 on episode 909 and timestep 272700


Updating model with average reward -0.23195439373547722 on episode 920 and timestep 276000

 (2.26 GB CUDA)


Updating model with average reward -0.2572739524953401 on episode 934 and timestep 280000

 (2.26 GB CUDA)


Updating model with average reward -0.2441360156008651 on episode 947 and timestep 284000

 (2.26 GB CUDA)


Decaying std to 0.1 on episode 954 and timestep 286200


Updating model with average reward -0.229629640105336 on episode 960 and timestep 288000

 (2.26 GB CUDA)


Updating model with average reward -0.3415543417282912 on episode 974 and timestep 292000

 (2.26 GB CUDA)


Updating model with average reward -0.3029267623892048 on episode 987 and timestep 296000

 (2.26 GB CUDA)


Ending early on episode 994 and timestep 298200

Reset Environment: 0.13452185901837765
Environment Setup: 9.941536761042698
Calculate Actions: 871.5679537820158
Step Environment: 237.8564275458748


Record Rewards: 57.445545384943216
Record Stats: 0.4735558680047234
Early Stopping: 0.12236400199390118
Update Policy: 4842.703110512002
Total: 6020.245015714896
