### Check the original model

In [113]:
import cellbox
import os
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
import shutil
import argparse
import json
import glob
import time
from tensorflow.compat.v1.errors import OutOfRangeError
from cellbox.utils import TimeLogger
tf.disable_v2_behavior()

In [120]:
def set_seed(in_seed):
    """Seed TensorFlow's graph-level RNG and NumPy's global RNG with the same integer.

    Args:
        in_seed: anything accepted by ``int()``.
    """
    seed_value = int(in_seed)
    for apply_seed in (tf.compat.v1.set_random_seed, np.random.seed):
        apply_seed(seed_value)


def prepare_workdir(in_cfg):
    """Create and enter the output directory for one experiment run.

    Side effects, in order:
      * sets ``in_cfg.root_dir`` to the current working directory and ``in_cfg.node_index`` to the
        node table (read from ``in_cfg.node_index_file`` if present, else ``0..n_x-1``);
      * creates ``results/<experiment_id>_<md5>`` if missing, ``chdir``s into it and writes the
        non-DataFrame config entries to ``config.json``;
      * for "leave one out" experiments, appends ``drug_index`` to ``in_cfg.model_prefix``;
      * recreates ``<model_prefix>_<working_index:03d>`` from scratch, ``chdir``s into it and
        writes the header row of ``record_eval.csv``.

    Reads the module-level globals ``md5`` and ``working_index``.

    Args:
        in_cfg: experiment config object (attribute namespace); mutated in place.

    Returns:
        int: always 0.

    Raises:
        ValueError: for a "leave one out" experiment whose config has no ``drug_index``.
    """
    # Read data
    in_cfg.root_dir = os.getcwd()
    if hasattr(in_cfg, 'node_index_file'):
        in_cfg.node_index = pd.read_csv(in_cfg.node_index_file, header=None, names=None)
    else:
        in_cfg.node_index = pd.DataFrame(np.arange(in_cfg.n_x))

    # Create output folder (reused if it already exists)
    experiment_path = 'results/{}_{}'.format(in_cfg.experiment_id, md5)
    os.makedirs(experiment_path, exist_ok=True)

    # DataFrames (e.g. node_index) are not JSON-serialisable, so they are left out of config.json.
    out_cfg = {key: value for key, value in vars(in_cfg).items() if type(value) is not pd.DataFrame}
    os.chdir(experiment_path)
    with open('config.json', 'w') as f:
        json.dump(out_cfg, f, indent=4)

    if "leave one out" in in_cfg.experiment_type:
        # A missing drug index is a configuration error; report it explicitly.
        try:
            drug_index = in_cfg.drug_index
        except AttributeError as e:
            raise ValueError('Drug index not specified') from e
        in_cfg.model_prefix = '{}_{}'.format(in_cfg.model_prefix, drug_index)

    in_cfg.working_index = in_cfg.model_prefix + "_" + str(working_index).zfill(3)

    # Every run starts from an empty working directory.
    shutil.rmtree(in_cfg.working_index, ignore_errors=True)
    os.makedirs(in_cfg.working_index)
    os.chdir(in_cfg.working_index)

    with open("record_eval.csv", 'w') as f:
        f.write("epoch,iter,train_loss,valid_loss,train_mse,valid_mse,test_mse,time_elapsed\n")

    print('Working directory is ready at {}.'.format(experiment_path))
    return 0

# Experiment set-up for the TensorFlow reference run.
experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.random_partition.json"
# Read as a global by prepare_workdir() to name the run sub-directory.
working_index = 0
# NOTE(review): this `stage` dict is never used -- the loop below rebinds `stage` to each entry of
# cfg.stages -- so these sub-stage settings have no effect unless they are put into the config.
stage = {
    "nT": 100,
    "sub_stages":[
        {"lr_val": 0.1,"l1lambda": 0.01, "n_iter_patience":1000},
        {"lr_val": 0.01,"l1lambda": 0.01},
        {"lr_val": 0.01,"l1lambda": 0.0001},
        {"lr_val": 0.001,"l1lambda": 0.00001}
    ]}

cfg = cellbox.config.Config(experiment_config_path)
cfg.ckpt_path_full = os.path.join('./', cfg.ckpt_name)
# Read as a global by prepare_workdir() to name the results directory.
md5 = cellbox.utils.md5(cfg)
cfg.drug_index = 5         # Change this for testing purposes
# Offset the configured seed (or 1000 when absent) by the working index.
cfg.seed = working_index + cfg.seed if hasattr(cfg, "seed") else working_index + 1000
set_seed(cfg.seed)
print(vars(cfg))

# Creates results/<experiment_id>_<md5>/<prefix>_<index> and chdirs into it.
prepare_workdir(cfg)
# NOTE(review): `logger` is not used below; train_model() creates its own TimeLogger.
logger = cellbox.utils.TimeLogger(time_logger_step=1, hierachy=3)
# `args` aliases the config object. If dataset.factory returns a *new* object, `args` keeps
# pointing at the pre-dataset config -- presumably it mutates and returns the same one; TODO confirm.
args = cfg
# Build the dataset and model for the first stage only (the loop breaks after i == 0).
for i, stage in enumerate(cfg.stages):
    set_seed(cfg.seed)
    cfg = cellbox.dataset.factory(cfg)
    args.sub_stages = stage['sub_stages']
    args.n_T = stage['nT']
    model = cellbox.model.factory(args)
    if i == 0: break

{'experiment_id': 'Example_RP', 'model_prefix': 'seed', 'ckpt_name': 'model11.ckpt', 'export_verbose': 3, 'experiment_type': 'random partition', 'sparse_data': False, 'batchsize': 4, 'trainset_ratio': 0.7, 'validset_ratio': 0.8, 'n_batches_eval': None, 'add_noise_level': 0, 'dT': 0.1, 'ode_solver': 'heun', 'envelope_form': 'tanh', 'envelope': 0, 'pert_form': 'by u', 'ode_degree': 1, 'ode_last_steps': 2, 'n_iter_buffer': 50, 'n_iter_patience': 100, 'weight_loss': 'None', 'l1lambda': 0.0001, 'l2lambda': 0.0001, 'model': 'LinReg', 'pert_file': '/users/ngun7t/Documents/cellbox-jun-6/data/pert.csv', 'expr_file': '/users/ngun7t/Documents/cellbox-jun-6/data/expr.csv', 'node_index_file': '/users/ngun7t/Documents/cellbox-jun-6/data/node_Index.csv', 'n_protein_nodes': 82, 'n_activity_nodes': 87, 'n_x': 99, 'envelop_form': 'tanh', 'envelop': 0, 'n_epoch': 100, 'n_iter': 100, 'stages': [{'nT': 200, 'sub_stages': [{'lr_val': 0.001, 'l1lambda': 0.0001}]}], 'ckpt_path_full': './model11.ckpt', 'drug_i

In [121]:
class Screenshot(dict):
    """Snapshot of the best model seen so far during a training substage.

    The dict itself maps names (model parameter names, ``'y_hat'``, ``'summary_*'``) to
    DataFrames that `save` writes to CSV. The instance also keeps a moving window of recent
    validation losses, which drives early stopping.
    """

    def __init__(self, args, n_iter_buffer):
        """
        Args:
            args: config object; only ``args.export_verbose`` is read here.
            n_iter_buffer (int): size of the moving-average window used by `avg_n_iters_loss`.
        """
        super().__init__()
        # Sentinel "best" loss that the first real snapshot is expected to beat.
        self.loss_min = 1000
        # The window starts with the sentinel, so early averages are pulled towards it until it
        # scrolls out after n_iter_buffer new losses.
        self.saved_losses = [self.loss_min]
        self.n_iter_buffer = n_iter_buffer
        self.summary = {}
        self.substage_i = []
        # 0: nothing, >0: parameters, >1 (or -1): test predictions, >2: convergence summaries.
        self.export_verbose = args.export_verbose

    def avg_n_iters_loss(self, new_loss):
        """Append `new_loss` and return the mean of the last `n_iter_buffer` recorded losses."""
        self.saved_losses = (self.saved_losses + [new_loss])[-self.n_iter_buffer:]
        return sum(self.saved_losses) / len(self.saved_losses)

    def screenshot(self, sess, model, substage_i, node_index, loss_min, args):
        """Record the current model state as the new best.

        Args:
            sess: active TF session.
            model: TF model exposing ``params``, ``iter_eval``, ``eval_yhat``,
                ``convergence_metric`` and ``in_pert``.
            substage_i (int): current substage index (used in saved file names).
            node_index (pd.DataFrame): node names in column 0.
            loss_min (float): loss value that triggered this snapshot.
            args: config providing ``feed_dicts['test_set']`` and ``dataset['pert_*']``.
        """
        self.substage_i = substage_i
        self.loss_min = loss_min

        # Parameters, labelled by node name when the row count allows it.
        if self.export_verbose > 0:
            params = sess.run(model.params)
            for item in params:
                try:
                    params[item] = pd.DataFrame(params[item], index=node_index[0])
                except Exception:
                    params[item] = pd.DataFrame(params[item])
            self.update(params)

        # Test-set predictions; eval_model initialises model.iter_eval with the test feed itself.
        if self.export_verbose > 1 or self.export_verbose == -1:  # no params but y_hat
            y_hat = eval_model(sess, model.iter_eval, model.eval_yhat, args.feed_dicts['test_set'], return_avg=False)
            y_hat = pd.DataFrame(y_hat, columns=node_index[0])
            self.update({'y_hat': y_hat})

        # Convergence summaries per split. Best-effort: any failure leaves them out entirely.
        if self.export_verbose > 2:
            try:
                # TODO: not yet support data iterators
                columns = [node_index.values + '_mean', node_index.values + '_sd', node_index.values + '_dxdt']
                summaries = {}
                for split in ('train', 'test', 'valid'):
                    metric = sess.run(model.convergence_metric,
                                      feed_dict={model.in_pert: args.dataset['pert_' + split]})
                    summaries['summary_' + split] = pd.DataFrame(metric, columns=columns)
                self.update(summaries)
            except Exception:
                pass

    def save(self):
        """Write each stored DataFrame to ``<substage>_best.<key>.loss.<loss_min>.csv``.

        Files from an earlier snapshot of the same substage are deleted first.
        """
        for file in glob.glob(str(self.substage_i) + "_best.*.csv"):
            os.remove(file)
        for key in self:
            self[key].to_csv("{}_best.{}.loss.{}.csv".format(self.substage_i, key, self.loss_min))

In [122]:
def save_model(saver, sess, path):
    """Checkpoint the variables of `sess` to `path` using `saver`, then report the saved path."""
    saved_path = saver.save(sess, path)
    print("Model saved in path: " + str(saved_path))

def append_record(filename, contents):
    """Append one comma-separated row to the training-record CSV `filename`.

    Each value is rendered with ``'{}'.format`` (so ``None`` becomes ``"None"``). Values are joined
    by commas with no trailing delimiter, so each row has as many fields as the header. With a
    trailing comma, ``pandas.read_csv`` would silently use the first column as the index.
    """
    with open(filename, 'a') as f:
        f.write(','.join('{}'.format(content) for content in contents) + '\n')


def eval_model(sess, eval_iter, obj_fn, eval_dict, return_avg=True, n_batches_eval=None):
    """Simulate the model for prediction by running `obj_fn` over an evaluation iterator.

    Args:
        sess: TF session used for every run.
        eval_iter: dataset iterator; its ``initializer`` is run with `eval_dict` first.
        obj_fn: fetch (or tuple of fetches) passed to ``sess.run`` for each batch.
        eval_dict (dict): feed dict used for the initializer and for every batch.
        return_avg (bool): if True, return the element-wise mean over batches; otherwise
            stack the per-batch results with ``np.vstack``.
        n_batches_eval (int | None): stop after this many batches; None runs until the
            iterator is exhausted (``OutOfRangeError``).

    Returns:
        np.ndarray: averaged or stacked batch results.
    """
    sess.run(eval_iter.initializer, feed_dict=eval_dict)
    counter = 0
    eval_results = []
    while True:
        try:
            eval_results.append(sess.run(obj_fn, feed_dict=eval_dict))
        except OutOfRangeError:
            break
        counter += 1
        # `>=` so that exactly n_batches_eval batches are evaluated.
        if n_batches_eval is not None and counter >= n_batches_eval:
            break

    if eval_results:
        # Debug trace; np.shape also handles tuple fetches such as (loss, mse).
        print(f"eval_model eval_results: {np.shape(eval_results[0])} with len {len(eval_results)}")
    if return_avg:
        return np.mean(np.array(eval_results), axis=0)
    return np.vstack(eval_results)

In [123]:
def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n_iter_buffer, n_iter_patience, args):
    """
    Training function that does one stage of training. The stage training can be repeated and modified to give better
    training result.

    Args:
        model (CellBox): a CellBox instance
        sess (tf.Session): current session, need reinitialization for every nT
        lr_val (float): learning rate (read in from config file)
        l1_lambda (float): l1 regularization weight
        l2_lambda (float): l2 regularization weight
        n_epoch (int): maximum number of epochs
        n_iter (int): maximum number of iterations
        n_iter_buffer (int): window size of the moving average of the validation loss
        n_iter_patience (int): number of iterations without improvement tolerated before stopping
        args: Args or configs

    Returns:
        tuple: ``(best_params, losses)``. ``best_params`` is the `Screenshot` of the model with the
        lowest moving-average validation loss; ``losses`` maps "train", "train_mse", "val" and
        "val_mse" to one list per epoch of per-iteration loss values.
    """

    # Result files written by Screenshot.save are named "<substage>_best.<key>.loss.<loss>.csv".
    # Continue numbering after the highest substage on disk. The whole numeric prefix is parsed,
    # so substage 10 and above are not read as substage 1.
    prefixes = (name.split('_', 1)[0] for name in glob.glob("*best*.csv"))
    substage_i = 1 + max((int(p) for p in prefixes if p.isdecimal()), default=0)

    best_params = Screenshot(args, n_iter_buffer)

    n_unchanged = 0
    idx_iter = 0
    # Substage hyper-parameters are fed to the graph through every split's feed dict.
    for key in args.feed_dicts:
        args.feed_dicts[key].update({
            model.lr: lr_val,
            model.l1_lambda: l1_lambda,
            model.l2_lambda: l2_lambda
        })
    args.logger.log("--------- lr: {}\tl1: {}\tl2: {}\t".format(lr_val, l1_lambda, l2_lambda))
    sess.run(model.iter_monitor.initializer, feed_dict=args.feed_dicts['valid_set'])
    loss_train_across_epochs, loss_train_mse_across_epochs = [], []
    loss_val_across_epochs, loss_val_mse_across_epochs = [], []
    for idx_epoch in range(n_epoch):

        loss_train_l, loss_train_mse_l = [], []
        loss_val_l, loss_val_mse_l = [], []
        if idx_iter > n_iter or n_unchanged > n_iter_patience:
            break

        sess.run(model.iter_train.initializer, feed_dict=args.feed_dicts['train_set'])
        while True:
            if idx_iter > n_iter or n_unchanged > n_iter_patience:
                break
            t0 = time.perf_counter()
            try:
                _, loss_train_i, loss_train_mse_i = sess.run(
                    (model.op_optimize, model.train_loss, model.train_mse_loss), feed_dict=args.feed_dicts['train_set'])
            except OutOfRangeError:  # training iterator exhausted: end of epoch
                break
            print(f"Loss train i: {loss_train_i}")
            loss_train_l.append(loss_train_i)
            loss_train_mse_l.append(loss_train_mse_i)

            # record training
            loss_valid_i, loss_valid_mse_i = sess.run(
                (model.monitor_loss, model.monitor_mse_loss), feed_dict=args.feed_dicts['valid_set'])
            loss_val_l.append(loss_valid_i)
            loss_val_mse_l.append(loss_valid_mse_i)
            new_loss = best_params.avg_n_iters_loss(loss_valid_i)
            if args.export_verbose > 0:
                print(("Substage:{}\tEpoch:{}/{}\tIteration: {}/{}" + "\tloss (train):{:1.6f}\tloss (buffer on valid):"
                       "{:1.6f}" + "\tbest:{:1.6f}\tTolerance: {}/{}").format(substage_i, idx_epoch, n_epoch, idx_iter,
                                                                              n_iter, loss_train_i, new_loss,
                                                                              best_params.loss_min, n_unchanged,
                                                                              n_iter_patience))
            append_record("record_eval.csv",
                          [idx_epoch, idx_iter, loss_train_i, loss_valid_i, loss_train_mse_i,
                           loss_valid_mse_i, None, time.perf_counter() - t0])
            # early stopping on the moving-average validation loss
            idx_iter += 1
            if new_loss < best_params.loss_min:
                n_unchanged = 0
                best_params.screenshot(sess, model, substage_i, args=args,
                                       node_index=args.dataset['node_index'], loss_min=new_loss)
            else:
                n_unchanged += 1

        loss_train_across_epochs.append(loss_train_l)
        loss_train_mse_across_epochs.append(loss_train_mse_l)
        loss_val_across_epochs.append(loss_val_l)
        loss_val_mse_across_epochs.append(loss_val_mse_l)

    return best_params, {
        "train": loss_train_across_epochs,
        "train_mse": loss_train_mse_across_epochs,
        "val": loss_val_across_epochs,
        "val_mse": loss_val_mse_across_epochs
    }


    # Evaluation on valid set
    #t0 = time.perf_counter()
    #sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['valid_set'])
    #loss_valid_i, loss_valid_mse_i = eval_model(sess, model.iter_eval, (model.eval_loss, model.eval_mse_loss),
    #                                            args.feed_dicts['valid_set'], n_batches_eval=args.n_batches_eval)
    #append_record("record_eval.csv", [-1, None, None, loss_valid_i, None, loss_valid_mse_i, None, time.perf_counter() - t0])
#
    ## Evaluation on test set
    #t0 = time.perf_counter()
    #sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['test_set'])
    #loss_test_mse = eval_model(sess, model.iter_eval, model.eval_mse_loss,
    #                           args.feed_dicts['test_set'], n_batches_eval=args.n_batches_eval)
    #append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.perf_counter() - t0])
#
    #best_params.save()
    #args.logger.log("------------------ Substage {} finished!-------------------".format(substage_i))
    #save_model(args.saver, sess, './' + args.ckpt_name)
#
#
def append_record(filename, contents):
    """Append one comma-separated row to the training-record CSV `filename`.

    Each value is rendered with ``'{}'.format`` (so ``None`` becomes ``"None"``). Values are joined
    by commas with no trailing delimiter, so each row has as many fields as the header. With a
    trailing comma, ``pandas.read_csv`` would silently use the first column as the index.
    """
    with open(filename, 'a') as f:
        f.write(','.join('{}'.format(content) for content in contents) + '\n')

In [124]:
def train_model(model, args):
    """Train the TF model through every substage in ``args.sub_stages`` in a single session.

    Per-substage settings (``n_iter_buffer``, ``n_iter``, ``n_iter_patience``, ``n_epoch``,
    ``l1lambda``, ``l2lambda``) fall back to the top-level values in `args`; l1/l2 default to 0
    when `args` lacks them too. Tries to restore ``./<ckpt_name>`` before training. The session
    is closed and the default graph reset afterwards, even if training raises.

    Returns:
        tuple: ``(screenshot, losses)`` from the last substage, or ``(None, None)`` when
        ``args.sub_stages`` is empty.
    """
    args.logger = TimeLogger(time_logger_step=1, hierachy=2)

    # Check if all variables in scope
    # TODO: put variables under appropriate scopes
    for i in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='initialization'):
        print(i)

    # Initialization
    args.saver = tf.compat.v1.train.Saver()
    from tensorflow.core.protobuf import rewriter_config_pb2
    config = tf.compat.v1.ConfigProto()
    # Turn off the memory-optimization graph rewrite.
    config.graph_options.rewrite_options.memory_optimization = rewriter_config_pb2.RewriterConfig.OFF

    screenshot, d = None, None
    sess = tf.compat.v1.Session(config=config)
    try:
        sess.run(tf.compat.v1.global_variables_initializer())
        try:
            args.saver.restore(sess, './' + args.ckpt_name)
            print('Load existing model at {}...'.format(args.ckpt_name))
        except Exception:
            # No usable checkpoint: keep the freshly initialised variables.
            print('Create new model at {}...'.format(args.ckpt_name))

        # Training
        for substage in args.sub_stages:
            n_iter_buffer = substage['n_iter_buffer'] if 'n_iter_buffer' in substage else args.n_iter_buffer
            n_iter = substage['n_iter'] if 'n_iter' in substage else args.n_iter
            n_iter_patience = substage['n_iter_patience'] if 'n_iter_patience' in substage else args.n_iter_patience
            n_epoch = substage['n_epoch'] if 'n_epoch' in substage else args.n_epoch
            l1 = substage['l1lambda'] if 'l1lambda' in substage else getattr(args, 'l1lambda', 0)
            l2 = substage['l2lambda'] if 'l2lambda' in substage else getattr(args, 'l2lambda', 0)
            screenshot, d = train_substage(model, sess, substage['lr_val'], l1_lambda=l1, l2_lambda=l2,
                                           n_epoch=n_epoch, n_iter=n_iter, n_iter_buffer=n_iter_buffer,
                                           n_iter_patience=n_iter_patience, args=args)
    finally:
        # Terminate session
        sess.close()
        tf.compat.v1.reset_default_graph()

    return screenshot, d

### Train the original model

In [125]:
# Train the TF model built above; returns the last substage's best-model Screenshot and its loss history.
screenshot, d = train_model(model, args)

<tf.Variable 'initialization/W:0' shape=(99, 99) dtype=float32_ref>
<tf.Variable 'initialization/b:0' shape=(99, 1) dtype=float32_ref>
Create new model at model11.ckpt...
########   --------- lr: 0.001	l1: 0.0001	l2: 0.0001	   --time elapsed: 0.05
Loss train i: 3.614074230194092
Substage:1	Epoch:0/100	Iteration: 0/100	loss (train):3.614074	loss (buffer on valid):504.059919	best:1000.000000	Tolerance: 0/100
eval_model eval_results: (4, 99) with len 12
Loss train i: 3.607429265975952
Substage:1	Epoch:1/100	Iteration: 1/100	loss (train):3.607429	loss (buffer on valid):338.206011	best:504.059919	Tolerance: 0/100
eval_model eval_results: (4, 99) with len 12
Loss train i: 3.6007916927337646
Substage:1	Epoch:2/100	Iteration: 2/100	loss (train):3.600792	loss (buffer on valid):255.587957	best:338.206011	Tolerance: 0/100
eval_model eval_results: (4, 99) with len 12
Loss train i: 3.5941641330718994
Substage:1	Epoch:3/100	Iteration: 3/100	loss (train):3.594164	loss (buffer on valid):205.465756	bes

In [8]:
# Loss history: per-epoch lists of per-iteration train/validation losses.
print(d)

{'train': [[3.6140742, 7.227595, 6.70115, 3.6506257, 6.622177, 3.5594428, 5.5601854, 4.5101857, 3.3610966, 9.822557, 6.1879897, 7.4154177, 3.217002], [3.5472631, 7.108092, 6.5936656, 3.5849123, 6.519177, 3.498232, 5.4680862, 4.434257, 3.303298, 9.674823, 6.095251, 7.3001394, 3.158329], [3.4862378, 6.9968557, 6.4919667, 3.5228007, 6.420512, 3.4389234, 5.378127, 4.3604865, 3.2469144, 9.530659, 6.004282, 7.186862, 3.1008427], [3.4265602, 6.888028, 6.392318, 3.462064, 6.3238683, 3.38082, 5.289957, 4.2882223, 3.191661, 9.389341, 5.9149823, 7.0757785, 3.0445414], [3.3680334, 6.781307, 6.2945023, 3.402525, 6.2290397, 3.323862, 5.2035356, 4.2173944, 3.1375017, 9.250699, 5.8273, 6.96684, 2.9893854], [3.310612, 6.676588, 6.198433, 3.3441286, 6.135934, 3.2680173, 5.1187944, 4.1479607, 3.0844135, 9.114625, 5.7411876, 6.8599553, 2.9353518], [3.2542937, 6.57382, 6.104082, 3.286867, 6.0445, 3.2132654, 5.035672, 4.079875, 3.0323608, 8.981022, 5.6565876, 6.7550373, 2.882391], [3.199048, 6.4729385, 6.01

In [126]:
# Test-set predictions stored by the best-model snapshot (present when export_verbose > 1 or == -1).
screenshot["y_hat"]

Unnamed: 0,4EBP1pS65,RbpS807,MAPKpT202,MEKpS217,S6,PAI-1,AKTpS473,AMPKpT172,b-Catenin,BIM,...,aHDAC,aMDM2,aJAK,aBRAFm,aPKC,aSTAT3,amTOR,aPI3K,aCDK4,aSRC
0,-1.872056,-0.411044,-2.620821,3.104337,-0.743947,-1.272077,0.682799,0.207566,-1.073564,2.039607,...,-1.178446,-1.828852,-0.035953,3.698557,-0.128579,0.450156,0.881209,0.210576,3.003808,2.545527
1,-0.896746,-0.129591,-0.350443,0.459292,1.69051,1.403689,0.037429,-2.392194,-0.784808,0.67666,...,1.171187,-1.419314,1.225619,0.101867,0.759815,-0.589532,1.587433,0.790972,0.231998,1.461179
2,-2.950902,0.822978,0.196699,3.285514,-0.843176,-4.585529,3.83551,-1.913752,-2.905127,1.268464,...,-1.072808,-0.842429,-0.849605,-1.310692,1.285581,-0.155561,-1.602297,2.279411,-1.784619,2.301683
3,-1.948739,-0.83135,-2.651083,2.771353,1.006357,0.159703,-0.475347,-0.478002,-1.284924,2.604939,...,0.483027,-3.53301,-0.044614,4.611651,0.107005,-0.612658,1.764507,-0.814713,2.538998,1.85665
4,-2.641239,-1.453549,0.429106,1.725165,-1.37468,-1.001256,-0.345604,-1.005148,-0.929291,0.137005,...,-0.659138,-0.251867,0.349778,-1.672482,-1.083149,-1.345617,0.525038,-0.711037,1.817143,1.450305
5,-3.46064,0.617424,0.981222,-0.390109,-0.29222,-0.812624,-4.706495,-0.358261,3.002356,0.557624,...,-1.00801,-0.174146,0.204835,-0.671163,0.820317,-0.889003,-0.988392,1.924198,3.164851,2.74942
6,-3.928865,1.21409,-1.411724,3.795213,-0.799834,-3.893729,3.077866,-0.360755,-2.813392,0.655698,...,-2.62474,-0.776332,0.349412,-1.758432,-0.435169,-0.173215,-0.776527,2.436399,-2.21404,1.160645
7,0.32738,-0.891904,0.024044,0.930231,1.151862,0.310351,1.353396,-0.843355,0.009305,1.002714,...,-0.875759,0.19233,1.146967,-0.379891,0.878628,0.699762,0.325796,0.687188,0.829806,1.634269
8,-1.991185,-1.150827,-0.5513,1.7739,-0.389288,0.400987,-0.899439,-1.549552,-1.891864,0.429134,...,0.835332,-1.591986,-0.160672,-1.266939,-0.167861,-0.842814,1.888119,-0.172633,0.198788,0.826621
9,-2.259609,-0.637128,0.680766,-0.036886,-0.292708,-0.032379,-0.299428,-2.044994,-0.864947,1.392435,...,-0.050174,0.042111,0.271335,-0.672674,0.016119,-0.503781,-0.477123,0.370982,0.999096,0.983332


### TensorFlow draft code

In [9]:
# Inspect the prediction tensor that Screenshot evaluates to produce y_hat.
model.eval_yhat

<tf.Tensor 'add_2:0' shape=(?, 99) dtype=float32>

### Check out the PyTorch model

In [1]:
import torch
import torch.nn as nn
import cellbox
import os
import numpy as np
import pandas as pd
import shutil
import argparse
import json
import glob
import time
from cellbox.utils import TimeLogger

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.errors import OutOfRangeError
tf.disable_v2_behavior()

2023-06-28 00:53:01.144094: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-28 00:53:01.346694: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-28 00:53:05.200251: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /cm/shared/apps/lsf10/10.1/linux3.10-glibc2.17-x86_64/lib:/data/weirauchlab/opt/lib:/da

Instructions for updating:
non-resource variables are not supported in the long term
   _____     _ _ ____              
  / ____|   | | |  _ \             
 | |     ___| | | |_) | _____  __  
 | |    / _ \ | |  _ < / _ \ \/ /  
 | |___|  __/ | | |_) | (_) >  <   
  \_____\___|_|_|____/ \___/_/\_\  
Running CellBox scripts developed in Sander lab
Maintained by Bo Yuan, Judy Shen, and Augustin Luna; contributions by Daniel Ritter

        version 0.3.2
        -- Feb 10, 2023 --
        * Modify CellBox to support TF2     
        
Tutorials and documentations are available at https://github.com/sanderlab/CellBox
If you want to discuss the usage or to report a bug, please use the 'Issues' function at GitHub.
If you find CellBox useful for your research, please consider citing the corresponding publication.
For more information, please email us at boyuan@g.harvard.edu and c_shen@g.harvard.edu, augustin_luna@hms.harvard.edu
 ----------------------------------------------------------------

In [2]:
def set_seed(in_seed):
    """Seed TensorFlow's graph-level RNG and NumPy's global RNG with the same integer.

    Args:
        in_seed: anything accepted by ``int()``.
    """
    seed_value = int(in_seed)
    for apply_seed in (tf.compat.v1.set_random_seed, np.random.seed):
        apply_seed(seed_value)


def prepare_workdir(in_cfg):
    """Create and enter the output directory for one experiment run.

    Side effects, in order:
      * sets ``in_cfg.root_dir`` to the current working directory and ``in_cfg.node_index`` to the
        node table (read from ``in_cfg.node_index_file`` if present, else ``0..n_x-1``);
      * creates ``results/<experiment_id>_<md5>`` if missing, ``chdir``s into it and writes the
        non-DataFrame config entries to ``config.json``;
      * for "leave one out" experiments, appends ``drug_index`` to ``in_cfg.model_prefix``;
      * recreates ``<model_prefix>_<working_index:03d>`` from scratch, ``chdir``s into it and
        writes the header row of ``record_eval.csv``.

    Reads the module-level globals ``md5`` and ``working_index``.

    Args:
        in_cfg: experiment config object (attribute namespace); mutated in place.

    Returns:
        int: always 0.

    Raises:
        ValueError: for a "leave one out" experiment whose config has no ``drug_index``.
    """
    # Read data
    in_cfg.root_dir = os.getcwd()
    if hasattr(in_cfg, 'node_index_file'):
        in_cfg.node_index = pd.read_csv(in_cfg.node_index_file, header=None, names=None)
    else:
        in_cfg.node_index = pd.DataFrame(np.arange(in_cfg.n_x))

    # Create output folder (reused if it already exists)
    experiment_path = 'results/{}_{}'.format(in_cfg.experiment_id, md5)
    os.makedirs(experiment_path, exist_ok=True)

    # DataFrames (e.g. node_index) are not JSON-serialisable, so they are left out of config.json.
    out_cfg = {key: value for key, value in vars(in_cfg).items() if type(value) is not pd.DataFrame}
    os.chdir(experiment_path)
    with open('config.json', 'w') as f:
        json.dump(out_cfg, f, indent=4)

    if "leave one out" in in_cfg.experiment_type:
        # A missing drug index is a configuration error; report it explicitly.
        try:
            drug_index = in_cfg.drug_index
        except AttributeError as e:
            raise ValueError('Drug index not specified') from e
        in_cfg.model_prefix = '{}_{}'.format(in_cfg.model_prefix, drug_index)

    in_cfg.working_index = in_cfg.model_prefix + "_" + str(working_index).zfill(3)

    # Every run starts from an empty working directory.
    shutil.rmtree(in_cfg.working_index, ignore_errors=True)
    os.makedirs(in_cfg.working_index)
    os.chdir(in_cfg.working_index)

    with open("record_eval.csv", 'w') as f:
        f.write("epoch,iter,train_loss,valid_loss,train_mse,valid_mse,test_mse,time_elapsed\n")

    print('Working directory is ready at {}.'.format(experiment_path))
    return 0

# Experiment set-up for the PyTorch port (CellBox model config).
experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.random_partition.CellBox.json"
# Read as a global by prepare_workdir() to name the run sub-directory.
working_index = 0
# NOTE(review): this `stage` dict is never used -- the loop below rebinds `stage` to each entry of
# cfg.stages -- so these sub-stage settings have no effect unless they are put into the config.
stage = {
    "nT": 100,
    "sub_stages":[
        {"lr_val": 0.1,"l1lambda": 0.01, "n_iter_patience":1000},
        {"lr_val": 0.01,"l1lambda": 0.01},
        {"lr_val": 0.01,"l1lambda": 0.0001},
        {"lr_val": 0.001,"l1lambda": 0.00001}
    ]}

cfg = cellbox.config.Config(experiment_config_path)
cfg.ckpt_path_full = os.path.join('./', cfg.ckpt_name)
# Read as a global by prepare_workdir() to name the results directory.
md5 = cellbox.utils.md5(cfg)
cfg.drug_index = 5         # Change this for testing purposes
# Offset the configured seed (or 1000 when absent) by the working index.
cfg.seed = working_index + cfg.seed if hasattr(cfg, "seed") else working_index + 1000
set_seed(cfg.seed)
print(vars(cfg))

# Creates results/<experiment_id>_<md5>/<prefix>_<index> and chdirs into it.
prepare_workdir(cfg)
# NOTE(review): `logger` is not used below; train_model() creates its own TimeLogger.
logger = cellbox.utils.TimeLogger(time_logger_step=1, hierachy=3)
# `args` aliases the config object. If dataset_torch.factory returns a *new* object, `args` keeps
# pointing at the pre-dataset config -- presumably it mutates and returns the same one; TODO confirm.
args = cfg
# Build the torch dataset and model for the first stage only (the loop breaks after i == 0).
# train_model() later takes model[0], so the factory presumably returns a sequence -- TODO confirm.
for i, stage in enumerate(cfg.stages):
    set_seed(cfg.seed)
    cfg = cellbox.dataset_torch.factory(cfg)
    args.sub_stages = stage['sub_stages']
    args.n_T = stage['nT']
    model = cellbox.model_torch.factory(args)
    if i == 0: break

{'experiment_id': 'Example_RP', 'model_prefix': 'seed', 'ckpt_name': 'model11.ckpt', 'export_verbose': 3, 'experiment_type': 'random partition', 'sparse_data': False, 'batchsize': 4, 'trainset_ratio': 0.7, 'validset_ratio': 0.8, 'n_batches_eval': None, 'add_noise_level': 0, 'dT': 0.1, 'ode_solver': 'heun', 'envelope_form': 'tanh', 'envelope': 0, 'pert_form': 'by u', 'ode_degree': 1, 'ode_last_steps': 2, 'n_iter_buffer': 50, 'n_iter_patience': 100, 'weight_loss': 'None', 'l1lambda': 0.0001, 'l2lambda': 0.0001, 'model': 'CellBox', 'pert_file': '/users/ngun7t/Documents/cellbox-jun-6/data/pert.csv', 'expr_file': '/users/ngun7t/Documents/cellbox-jun-6/data/expr.csv', 'node_index_file': '/users/ngun7t/Documents/cellbox-jun-6/data/node_Index.csv', 'n_protein_nodes': 82, 'n_activity_nodes': 87, 'n_x': 99, 'envelop_form': 'tanh', 'envelop': 0, 'n_epoch': 100, 'n_iter': 100, 'stages': [{'nT': 200, 'sub_stages': [{'lr_val': 0.001, 'l1lambda': 0.0001}]}], 'ckpt_path_full': './model11.ckpt', 'drug_

In [142]:
def train_substage(model, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n_iter_buffer, n_iter_patience, args):
    """Run one training substage of the PyTorch model with moving-average early stopping.

    Args:
        model: torch model called as ``model(None, pert)``; must expose ``model.W.weight``.
        lr_val (float): learning rate for this substage; written into every param group of
            ``args.optimizer`` before training starts.
        l1_lambda (float): l1 regularization weight (only logged here).
        l2_lambda (float): l2 regularization weight (only logged here).
            NOTE(review): ``args.loss_fn`` does not receive l1/l2, so they only take effect if the
            loss function is configured with them elsewhere -- TODO confirm.
        n_epoch (int): maximum number of epochs.
        n_iter (int): maximum number of iterations.
        n_iter_buffer (int): window size of the moving average of the validation loss.
        n_iter_patience (int): iterations without improvement tolerated before stopping.
        args: config providing ``optimizer``, ``loss_fn``, ``iter_train``, ``iter_monitor``,
            ``device``, ``logger``, ``export_verbose`` and ``dataset['node_index']``.

    Returns:
        Screenshot: snapshot of the best model seen during this substage.
    """
    # Result files are named "<substage>_best.<key>.loss.<loss>.csv"; continue numbering after the
    # highest substage on disk. The whole numeric prefix is parsed, so substage 10+ is not read as 1.
    prefixes = (name.split('_', 1)[0] for name in glob.glob("*best*.csv"))
    substage_i = 1 + max((int(p) for p in prefixes if p.isdecimal()), default=0)

    best_params = Screenshot(args, n_iter_buffer)

    # Each substage trains with its own learning rate.
    for param_group in args.optimizer.param_groups:
        param_group['lr'] = lr_val

    n_unchanged = 0
    idx_iter = 0
    args.logger.log("--------- lr: {}\tl1: {}\tl2: {}\t".format(lr_val, l1_lambda, l2_lambda))

    for idx_epoch in range(n_epoch):

        if idx_iter > n_iter or n_unchanged > n_iter_patience:
            break

        for x_train, y_train in args.iter_train:
            if idx_iter > n_iter or n_unchanged > n_iter_patience:
                break

            # One optimisation step on the training minibatch.
            t0 = time.perf_counter()
            args.optimizer.zero_grad()
            prediction = model(None, x_train.to(args.device))
            loss_train_i, loss_train_mse_i = args.loss_fn(y_train.to(args.device), prediction, model.W.weight)
            loss_train_i.backward()
            args.optimizer.step()

            # Validation loss on one monitor batch.
            # NOTE(review): a fresh iterator is created every step, so this is always the first batch
            # unless iter_monitor shuffles -- TODO confirm this is the intended monitoring scheme.
            with torch.no_grad():
                x_valid, y_valid = next(iter(args.iter_monitor))
                loss_valid_i, loss_valid_mse_i = args.loss_fn(
                    y_valid.to(args.device), model(None, x_valid.to(args.device)), model.W.weight)

            # Plain floats keep the moving-average buffer and loss_min free of tensors.
            loss_train_f = loss_train_i.item()
            loss_valid_f = loss_valid_i.item()
            new_loss = best_params.avg_n_iters_loss(loss_valid_f)
            if args.export_verbose > 0:
                print(("Substage:{}\tEpoch:{}/{}\tIteration: {}/{}" + "\tloss (train):{:1.6f}\tloss (buffer on valid):"
                       "{:1.6f}" + "\tbest:{:1.6f}\tTolerance: {}/{}").format(substage_i, idx_epoch, n_epoch, idx_iter,
                                                                              n_iter, loss_train_f, new_loss,
                                                                              best_params.loss_min, n_unchanged,
                                                                              n_iter_patience))

            append_record("record_eval.csv",
                          [idx_epoch, idx_iter, loss_train_f, loss_valid_f, loss_train_mse_i.item(),
                           loss_valid_mse_i.item(), None, time.perf_counter() - t0])

            # Early stopping on the moving-average validation loss.
            idx_iter += 1
            if new_loss < best_params.loss_min:
                n_unchanged = 0
                best_params.screenshot(model, substage_i, args=args,
                                       node_index=args.dataset['node_index'], loss_min=new_loss)
            else:
                n_unchanged += 1

    # TODO: persist best_params (best_params.save()) and a model checkpoint once ported to torch.
    args.logger.log("------------------ Substage {} finished!-------------------".format(substage_i))

    return best_params


def append_record(filename, contents):
    """Append one comma-separated row to the training-record CSV `filename`.

    Each value is rendered with ``'{}'.format`` (so ``None`` becomes ``"None"``). Values are joined
    by commas with no trailing delimiter, so each row has as many fields as the header. With a
    trailing comma, ``pandas.read_csv`` would silently use the first column as the index.
    """
    with open(filename, 'a') as f:
        f.write(','.join('{}'.format(content) for content in contents) + '\n')


def eval_model(device, eval_iter, model, return_avg=True, n_batches_eval=None):
    """Simulate the model for prediction, without gradients, over an evaluation iterator.

    Args:
        device: torch device that the perturbation batches are moved to.
        eval_iter: iterable of ``(pert, expr)`` batches; only ``pert`` is used.
        model: called as ``model(None, pert)``.
        return_avg (bool): if True, return the element-wise mean over batches (all batches must
            share a shape); otherwise stack the batch predictions with ``np.vstack``.
        n_batches_eval (int | None): stop after this many batches; None evaluates the whole iterator.

    Returns:
        np.ndarray: averaged or stacked predictions.
    """
    eval_results = []
    with torch.no_grad():
        for pert, _expr in eval_iter:
            eval_results.append(model(None, pert.to(device)).detach().cpu().numpy())
            # `>=` so that exactly n_batches_eval batches are evaluated.
            if n_batches_eval is not None and len(eval_results) >= n_batches_eval:
                break

    if return_avg:
        return np.mean(np.array(eval_results), axis=0)
    return np.vstack(eval_results)


def train_model(model, args):
    """Train the PyTorch model through every substage in ``args.sub_stages``.

    Per-substage settings (``n_iter_buffer``, ``n_iter``, ``n_iter_patience``, ``n_epoch``,
    ``l1lambda``, ``l2lambda``) fall back to the top-level values in `args`; l1/l2 default to 0
    when `args` lacks them too.

    Args:
        model: indexable whose first element is the torch module to train; that module is
            moved to ``args.device``.
        args: experiment config (also receives a fresh ``args.logger``).

    Returns:
        Screenshot | None: best-model snapshot of the last substage, or None when
        ``args.sub_stages`` is empty.
    """
    args.logger = TimeLogger(time_logger_step=1, hierachy=2)
    model = model[0].to(args.device)

    # TODO: restore from './' + args.ckpt_name once checkpointing is ported to torch.

    screenshot = None
    for substage in args.sub_stages:
        n_iter_buffer = substage['n_iter_buffer'] if 'n_iter_buffer' in substage else args.n_iter_buffer
        n_iter = substage['n_iter'] if 'n_iter' in substage else args.n_iter
        n_iter_patience = substage['n_iter_patience'] if 'n_iter_patience' in substage else args.n_iter_patience
        n_epoch = substage['n_epoch'] if 'n_epoch' in substage else args.n_epoch
        l1 = substage['l1lambda'] if 'l1lambda' in substage else getattr(args, 'l1lambda', 0)
        l2 = substage['l2lambda'] if 'l2lambda' in substage else getattr(args, 'l2lambda', 0)
        screenshot = train_substage(model, substage['lr_val'], l1_lambda=l1, l2_lambda=l2, n_epoch=n_epoch,
                                    n_iter=n_iter, n_iter_buffer=n_iter_buffer, n_iter_patience=n_iter_patience,
                                    args=args)

    return screenshot
        

class Screenshot(dict):
    """Snapshot of a model's exportable state, stored as ``{name: DataFrame}``.

    Which entries ``screenshot()`` captures depends on ``args.export_verbose``:

    * ``> 0``: every tensor in ``model.W.state_dict()``, one DataFrame per key.
    * ``> 1`` or ``== -1``: ``'y_hat'``, the predictions on ``args.iter_eval``.
    * ``> 2``: convergence summaries (best effort, currently unported).

    ``save()`` writes every entry to CSV. The instance also keeps a rolling loss
    buffer used by ``avg_n_iters_loss``.
    """

    def __init__(self, args, n_iter_buffer):
        super().__init__()
        # Sentinel "best loss" until screenshot() records a real one.
        self.loss_min = 1000
        # Rolling loss buffer. NOTE: it is seeded with the sentinel, so the first
        # averages from avg_n_iters_loss are pulled toward 1000 until the sentinel
        # falls out of the n_iter_buffer window.
        self.saved_losses = [self.loss_min]
        self.n_iter_buffer = n_iter_buffer
        self.summary = {}
        self.substage_i = []
        self.export_verbose = args.export_verbose

    def avg_n_iters_loss(self, new_loss):
        """Append ``new_loss`` and return the mean of the last ``n_iter_buffer`` losses."""
        self.saved_losses = self.saved_losses + [new_loss]
        self.saved_losses = self.saved_losses[-self.n_iter_buffer:]
        return sum(self.saved_losses) / len(self.saved_losses)

    def screenshot(self, model, substage_i, node_index, loss_min, args):
        """Capture the current model state into this dict.

        Args:
            model: model whose ``W`` attribute is a ``torch.nn.Module``.
            substage_i: sub-stage identifier, used as the file-name prefix by ``save``.
            node_index: DataFrame whose column ``0`` holds the node names used as
                row/column labels.
            loss_min: best loss so far, recorded in the file names written by ``save``.
            args: run configuration; ``args.device`` and ``args.iter_eval`` are read
                when predictions are exported.
        """
        self.substage_i = substage_i
        self.loss_min = loss_min

        if self.export_verbose > 0:
            # Store each W parameter as a DataFrame (labelled by node name when the
            # shape allows) so that save() can write it with to_csv. Previously the
            # raw tensors were stored and save() failed on them.
            params = {}
            for name, tensor in model.W.state_dict().items():
                values = tensor.detach().cpu().numpy()  # .cpu() so GPU weights export too
                try:
                    params[name] = pd.DataFrame(values, index=node_index[0])
                except Exception:
                    # Shape does not match the node labels: keep an unlabelled frame.
                    params[name] = pd.DataFrame(values)
            self.update(params)

        if self.export_verbose > 1 or self.export_verbose == -1:  # no params but y_hat
            y_hat = eval_model(args.device, args.iter_eval, model, return_avg=False)
            self.update({'y_hat': pd.DataFrame(y_hat, columns=node_index[0])})

        if self.export_verbose > 2:
            try:
                # TODO: not yet ported from the TensorFlow version. `sess` is not
                # defined in this function and `model` is a torch module, so this
                # path is expected to fail. It also does not yet support data iterators.
                summary_train = sess.run(model.convergence_metric,
                                         feed_dict={model.in_pert: args.dataset['pert_train']})
                summary_test = sess.run(model.convergence_metric, feed_dict={model.in_pert: args.dataset['pert_test']})
                summary_valid = sess.run(model.convergence_metric,
                                         feed_dict={model.in_pert: args.dataset['pert_valid']})
                summary_train = pd.DataFrame(summary_train, columns=[node_index.values + '_mean', node_index.values +
                                                                     '_sd', node_index.values + '_dxdt'])
                summary_test = pd.DataFrame(summary_test, columns=[node_index.values + '_mean', node_index.values +
                                                                   '_sd', node_index.values + '_dxdt'])
                summary_valid = pd.DataFrame(summary_valid, columns=[node_index.values + '_mean', node_index.values +
                                                                     '_sd', node_index.values + '_dxdt'])
                self.update(
                    {'summary_train': summary_train, 'summary_test': summary_test, 'summary_valid': summary_valid}
                )
            except Exception as err:
                # Keep this best-effort, but make the failure visible instead of silent.
                import warnings
                warnings.warn("Screenshot: convergence summaries were not exported ({!r})".format(err))

    def save(self):
        """Write every entry to ``<substage_i>_best.<key>.loss.<loss_min>.csv``.

        Existing ``<substage_i>_best.*.csv`` files in the current working directory
        are deleted first, so only the latest snapshot is kept.
        """
        for stale in glob.glob(str(self.substage_i) + "_best.*.csv"):
            os.remove(stale)
        for key in self:
            self[key].to_csv("{}_best.{}.loss.{}.csv".format(self.substage_i, key, self.loss_min))
        


In [143]:
# Run on the CPU: train_model moves model[0] to args.device and eval_model moves
# each perturbation batch there.
args.device = torch.device("cpu")

In [145]:
# Train through every sub-stage in args.sub_stages; `screenshot` holds what the
# last train_substage call returned.
screenshot = train_model(model, args)

########   --------- lr: 0.001	l1: 0.0001	l2: 0.0001	   --time elapsed: 0.00
Substage:1	Epoch:0/100	Iteration: 0/100	loss (train):0.135471	loss (buffer on valid):500.177246	best:1000.000000	Tolerance: 0/100
Substage:1	Epoch:0/100	Iteration: 1/100	loss (train):0.096996	loss (buffer on valid):333.517914	best:500.177246	Tolerance: 0/100
Substage:1	Epoch:0/100	Iteration: 2/100	loss (train):0.224137	loss (buffer on valid):250.201248	best:333.517914	Tolerance: 0/100
Substage:1	Epoch:0/100	Iteration: 3/100	loss (train):0.383463	loss (buffer on valid):200.196426	best:250.201248	Tolerance: 0/100
Substage:1	Epoch:0/100	Iteration: 4/100	loss (train):0.228764	loss (buffer on valid):166.866867	best:200.196426	Tolerance: 0/100
Substage:1	Epoch:0/100	Iteration: 5/100	loss (train):0.269860	loss (buffer on valid):143.060318	best:166.866867	Tolerance: 0/100
Substage:1	Epoch:0/100	Iteration: 6/100	loss (train):0.286096	loss (buffer on valid):125.197006	best:143.060318	Tolerance: 0/100
Substage:1	Epoch:0/

In [15]:
# Display the module that train_model actually trains (it uses model[0]).
model[0]

LinReg(
  (W): Linear(in_features=99, out_features=99, bias=True)
)

In [104]:
# Print the shape of every tensor in the W layer's state_dict; these are the
# tensors Screenshot.screenshot exports when export_verbose > 0.
a = model[0].W.state_dict()
for item in a:
    print(a[item].numpy().shape)

(99, 99)
(99,)


In [146]:
# List the entries captured in the returned snapshot (presumably a Screenshot).
screenshot.keys()

dict_keys(['weight', 'bias', 'y_hat'])

In [150]:
# Predictions stored by Screenshot.screenshot when export_verbose > 1 or == -1;
# columns are labelled with the node names from node_index.
screenshot["y_hat"]

Unnamed: 0,4EBP1pS65,RbpS807,MAPKpT202,MEKpS217,S6,PAI-1,AKTpS473,AMPKpT172,b-Catenin,BIM,...,aHDAC,aMDM2,aJAK,aBRAFm,aPKC,aSTAT3,amTOR,aPI3K,aCDK4,aSRC
0,-0.034876,-0.183665,0.143622,-0.082877,0.001758,-0.003982,0.003231,0.034118,0.09392,0.200523,...,-0.122743,-0.030154,-0.080675,-0.072349,0.040038,-0.025492,0.180109,0.051663,0.151754,-0.064242
1,-0.294122,-0.09316,0.246818,-0.189556,-0.158506,0.135892,0.134888,0.073156,-0.123043,0.012068,...,0.085567,-0.244929,-0.160039,0.101084,-0.258373,0.125785,0.1514,-0.253566,0.086212,0.281912
2,-0.080961,0.007904,0.031583,-0.107364,0.011465,0.017704,-0.015875,-0.095701,0.049683,0.073469,...,0.104831,-0.120854,-0.177551,-0.053515,-0.108378,0.036214,0.050245,-0.080562,0.007598,0.000651
3,-0.097761,0.064553,0.112396,-0.099148,0.078064,0.081907,-0.05924,-0.025711,0.138789,0.032307,...,0.170786,-0.126457,-0.184178,-0.106063,-0.071826,-0.018416,0.020251,-0.095692,-0.02904,0.025155
4,-0.066278,-0.015814,0.045784,-0.067259,0.013513,0.027357,-0.00047,-0.074069,0.052443,0.112367,...,0.0989,-0.094037,-0.082585,-0.008271,-0.105168,-0.003399,0.054829,-0.021752,0.056861,0.055012
5,-0.088922,0.045411,-0.218302,-0.342581,-0.189004,0.107046,0.03169,-0.262903,0.13068,-0.132162,...,0.090093,-0.297775,-0.152631,-0.134203,0.102528,0.238491,-0.151672,0.062936,0.141745,-0.143992
6,-0.039349,-0.044991,0.051317,-0.078153,0.107723,0.101997,0.064801,-0.048378,0.015586,0.130743,...,0.09225,-0.085436,-0.153277,-0.012049,-0.101245,-0.06719,0.156362,-0.049721,0.134039,0.150682
7,-0.028227,-0.041266,0.079579,-0.081062,0.106948,0.074448,0.086504,-0.073258,0.043288,0.16192,...,0.094312,-0.085749,-0.142496,0.019146,-0.127461,-0.077492,0.13533,-0.030512,0.151584,0.11362
8,-0.038961,0.014991,0.065171,-0.075859,0.053893,0.046696,-0.041396,-0.082587,0.085177,0.041151,...,0.144618,-0.110327,-0.143367,-0.057226,-0.10028,0.002345,0.042485,-0.089582,0.020044,0.028817
9,-0.106964,-0.021721,0.013824,-0.099157,-0.027562,-0.001915,0.024394,-0.088323,0.019013,0.143317,...,0.060745,-0.105131,-0.118476,-0.005167,-0.113995,0.030314,0.061448,-0.014393,0.043757,0.024685
