In [1]:
from __future__ import division
from __future__ import print_function

import time
import argparse
import pickle
import os
import datetime

import torch.optim as optim
from torch.optim import lr_scheduler

from utils import *
from modules import *
from sklearn.metrics import precision_score
from sklearn import metrics

## Modify the arguments and run all the cells in order to start training.

Here is an example of running on the MHSC-GM dataset using the Non-specific ChIP-Seq network.

In [2]:
# Command-line style configuration. In the notebook the defaults below are what
# actually gets used; edit them here to change a run.
parser = argparse.ArgumentParser()

# --- runtime / optimisation -------------------------------------------------
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=180,
                    help='Number of epochs to train.')
parser.add_argument('--batch-size', type=int, default=11,
                    help='Number of samples per batch.')
parser.add_argument('--lr', type=float, default=0.0005,
                    help='Initial learning rate.')

# --- model size -------------------------------------------------------------
parser.add_argument('--encoder-hidden', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--decoder-hidden', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--temp', type=float, default=0.5,
                    help='Temperature for Gumbel softmax.')

# --- data set description ---------------------------------------------------
parser.add_argument('--num-atoms', type=int, default=301,
                    help='Number of genes.')
parser.add_argument('--num-tfs', type=int, default=82,
                    help='Number of transcription factors.')
# The density is a fraction (default 0.03), so it must be parsed as a float;
# with type=int any value given on the command line (e.g. 0.1) is rejected.
parser.add_argument('--density', type=float, default=0.03,
                    help='Density of edges in the network.')

# --- architecture selection -------------------------------------------------
parser.add_argument('--encoder', type=str, default='mlp',
                    help='Type of path encoder model (mlp).')
parser.add_argument('--decoder', type=str, default='mlp',
                    help='Type of decoder model (mlp or sim).')
parser.add_argument('--no-factor', action='store_true', default=False,
                    help='Disables factor graph model.')
parser.add_argument('--suffix', type=str, default='_rna',
                    help='Suffix for training data (e.g. "_charged".')
parser.add_argument('--encoder-dropout', type=float, default=0.0,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--decoder-dropout', type=float, default=0.0,
                    help='Dropout rate (1 - keep probability).')

# --- checkpoints ------------------------------------------------------------
# Default resolves to the directory that would contain a 'logs' entry, i.e. the
# current working directory as an absolute path.
parser.add_argument('--save-folder', type=str, default=os.path.dirname(os.path.abspath('logs')),
                    help='Where to save the trained model, leave empty to not save anything.')
parser.add_argument('--load-folder', type=str, default='',
                    help='Where to load the trained model if finetunning. ' +
                         'Leave empty to train from scratch')

# --- input layout -----------------------------------------------------------
parser.add_argument('--edge-types', type=int, default=2,
                    help='The number of edge types to infer.')
parser.add_argument('--dims', type=int, default=1,
                    help='The number of input dimensions (position + velocity).')
parser.add_argument('--timesteps', type=int, default=1,
                    help='The number of time steps per sample.')
parser.add_argument('--prediction-steps', type=int, default=1, metavar='N',
                    help='Num steps to predict before re-using teacher forcing.')

# --- learning-rate schedule and loss ----------------------------------------
parser.add_argument('--lr-decay', type=int, default=80,
                    help='After how epochs to decay LR by a factor of gamma.')
parser.add_argument('--gamma', type=float, default=0.4,
                    help='LR decay factor.')
parser.add_argument('--skip-first', action='store_true', default=False,
                    help='Skip first edge type in decoder, i.e. it represents no-edge.')
parser.add_argument('--var', type=float, default=5e-5,
                    help='Output variance.')

# NOTE(review): the three flags below combine action='store_true' with
# default=True, so they are True whether or not the flag is passed and cannot
# be switched off from the command line. Change the default here to disable one.
parser.add_argument('--hard', action='store_true', default=True,
                    help='Uses discrete samples in training forward pass.')
parser.add_argument('--prior', action='store_true', default=True,
                    help='Whether to use sparsity prior.')
parser.add_argument('--dynamic-graph', action='store_true', default=False,
                    help='Whether test with dynamically re-computed graph.')
parser.add_argument('--compute-acc', action='store_true', default=True,
                    help='whether to compute accuracy for each batch..')




_StoreTrueAction(option_strings=['--compute-acc'], dest='compute_acc', nargs=0, const=True, default=True, type=None, choices=None, help='whether to compute accuracy for each batch..', metavar=None)

In [3]:
# parse_known_args() is used so that extra argv entries (such as those the
# notebook kernel is launched with) land in `unknown` instead of aborting.
args, unknown = parser.parse_known_args()

# Derived switches: use the GPU only when it is both allowed and present.
args.cuda = (not args.no_cuda) and torch.cuda.is_available()
args.factor = not args.no_factor
print(args)

Namespace(no_cuda=False, seed=42, epochs=180, batch_size=11, lr=0.0005, encoder_hidden=64, decoder_hidden=64, temp=0.5, num_atoms=301, num_tfs=82, density=0.03, encoder='mlp', decoder='mlp', no_factor=False, suffix='_rna', encoder_dropout=0.0, decoder_dropout=0.0, save_folder='C:\\Users\\CSE-Admin\\Documents\\Gene Regulatory Network Project\\NRI-master\\NRI-master', load_folder='', edge_types=2, dims=1, timesteps=1, prediction_steps=1, lr_decay=80, gamma=0.4, skip_first=False, var=5e-05, hard=True, prior=True, dynamic_graph=False, compute_acc=True, cuda=True, factor=True)


In [4]:
# Seed NumPy and PyTorch (plus the CUDA generator when the GPU is in use)
# from the single --seed value.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.dynamic_graph:
    print("Testing with dynamically re-computed graph.")

In [5]:
# Save model and meta-data. Always saves in a new, timestamped sub-folder of
# args.save_folder, so re-running the notebook never collides with (or
# overwrites) the results of an earlier run.
if args.save_folder:
    # strftime instead of isoformat(): the ':' characters produced by
    # isoformat() are not valid in Windows path names.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S-%f')
    save_folder = os.path.join(args.save_folder, 'result_{}'.format(timestamp))
    os.makedirs(save_folder)
    meta_file = os.path.join(save_folder, 'metadata.pkl')
    encoder_file = os.path.join(save_folder, 'encoder.pt')
    decoder_file = os.path.join(save_folder, 'decoder.pt')

    # The log handle is intentionally left open: later cells flush it and
    # print the best epoch into it.
    log_file = os.path.join(save_folder, 'log.txt')
    log = open(log_file, 'w')

    # Persist the run configuration; the handle is closed as soon as it is written.
    with open(meta_file, "wb") as meta_handle:
        pickle.dump({'args': args}, meta_handle)
else:
    print("WARNING: No save_folder provided!" +
          "Testing (within this script) will throw an error.")

In [6]:
#preparing mask 
ex_tfs = np.zeros([args.num_atoms, args.num_atoms])
for i in range(args.num_tfs,args.num_atoms):
    for j in range(args.num_atoms):
        ex_tfs[i,j] = 1
        
off_diag = np.ones([args.num_atoms, args.num_atoms]) - np.eye(args.num_atoms) 
off_diag = off_diag - ex_tfs

off_diag[off_diag == -1] = 0

In [7]:
# Row / column indices of every entry kept by the mask: rows are used for the
# sender matrix, columns for the receiver matrix.
_send_nodes, _recv_nodes = np.where(off_diag)

# One-hot encode the indices (encode_onehot comes from utils) and hand them to
# PyTorch as float tensors.
rel_rec = torch.FloatTensor(np.array(encode_onehot(_recv_nodes), dtype=np.int16))
rel_send = torch.FloatTensor(np.array(encode_onehot(_send_nodes), dtype=np.int16))

In [8]:
# Number of rows in rel_rec; train() uses it as the length of the per-edge
# prediction accumulator (np.zeros(sh)).
# presumably one row per candidate edge kept by off_diag — confirm against encode_onehot.
sh = rel_rec.shape[0]

In [9]:
# Build the encoder. --encoder is not consulted here: an MLPEncoder (from
# modules) is always constructed, sized from the flattened per-node input
# (timesteps * dims), the hidden width, the number of edge types and the
# dropout rate.
encoder = MLPEncoder(
    args.timesteps * args.dims,
    args.encoder_hidden,
    args.edge_types,
    args.encoder_dropout
)

Using MLP encoder.


In [10]:
# Build the decoder selected by --decoder.
if args.decoder == 'mlp':
    decoder = MLPDecoder(n_in_node=args.dims,
                         edge_types=args.edge_types,
                         msg_hid=args.decoder_hidden,
                         msg_out=args.decoder_hidden,
                         n_hid=args.decoder_hidden,
                         do_prob=args.decoder_dropout,
                         skip_first=args.skip_first)

elif args.decoder == 'sim':
    # NOTE(review): loc_max / loc_min are only assigned by the load_data(...)
    # cell further down, so on a fresh kernel this branch raises NameError
    # unless that cell has been run first.
    decoder = SimulationDecoder(loc_max, loc_min, args.suffix)

else:
    # Fail here with a clear message instead of a NameError on `decoder`
    # when the optimizer is created.
    raise ValueError(
        "Unknown --decoder value {!r}; expected 'mlp' or 'sim'.".format(args.decoder))

Using learned interaction net decoder.


In [11]:
# Optionally warm-start both networks from a previous run.
if args.load_folder:
    encoder_file = os.path.join(args.load_folder, 'encoder.pt')
    decoder_file = os.path.join(args.load_folder, 'decoder.pt')

    # map_location='cpu': the models have not been moved to the GPU yet (that
    # happens in a later cell), and mapping to the CPU lets a checkpoint that
    # was saved from a GPU run be loaded on a CPU-only machine as well.
    encoder.load_state_dict(torch.load(encoder_file, map_location='cpu'))
    decoder.load_state_dict(torch.load(decoder_file, map_location='cpu'))

    # Disables checkpoint saving for this run, so the loaded files are not overwritten.
    args.save_folder = False


In [12]:
# A single Adam optimizer updates the encoder and decoder parameters jointly.
optimizer = optim.Adam(
    [*encoder.parameters(), *decoder.parameters()],
    lr=args.lr,
)

# Multiply the learning rate by --gamma every --lr-decay scheduler steps.
# NOTE(review): `verbose` is reported as deprecated by recent PyTorch
# releases — confirm against the installed version.
scheduler = lr_scheduler.StepLR(
    optimizer,
    step_size=args.lr_decay,
    gamma=args.gamma,
    verbose=True,
)

Adjusting learning rate of group 0 to 5.0000e-04.


In [13]:
if args.prior:
    # Two-entry prior built from --density: entry 0 gets 1 - density, entry 1
    # gets density.
    prior = np.array([1-args.density, args.density])  # TODO: hard coded for now
    print("Using prior")

    # Log-probabilities with two leading singleton axes -> shape (1, 1, 2).
    log_prior = Variable(
        torch.FloatTensor(np.log(prior)).unsqueeze(0).unsqueeze(0))

    if args.cuda:
        log_prior = log_prior.cuda()

Using prior


In [14]:
# Move both networks and the relation matrices to the GPU when it is in use.
if args.cuda:
    for _net in (encoder, decoder):
        _net.cuda()
    rel_rec, rel_send = rel_rec.cuda(), rel_send.cuda()


In [15]:
# load_data comes from utils; it returns the training loader together with
# loc_max / loc_min (the latter two are only consumed by the 'sim' decoder branch).
train_loader, loc_max, loc_min = load_data(
    args.num_tfs,
    args.batch_size,
    args.suffix,
    args.compute_acc
)

In [16]:
# Reference adjacency used by calc_eprec_net to score predicted edges.
# NOTE(review): allow_pickle=True executes pickled objects on load — only use
# it with a trusted data/true_edges.npy.
true_edges = np.load('data/true_edges.npy', allow_pickle=True)

# Sender (row) and receiver (column) index of every entry kept by off_diag,
# in the same order for both arrays.
_mask_rows, _mask_cols = np.where(off_diag)
rel_send_idx = np.array(_mask_rows, dtype=np.int16)
rel_rec_idx = np.array(_mask_cols, dtype=np.int16)

In [17]:
def train(epoch, best_eprec):
    """Run one optimisation epoch, re-score the training set, and checkpoint.

    The function works on the notebook-level objects (encoder, decoder,
    optimizer, scheduler, train_loader, rel_rec, rel_send, args, ...).

    Steps:
      1. Training pass over ``train_loader`` (encoder -> Gumbel-softmax edge
         sample -> decoder), minimising ``loss_nll + loss_kl``.
      2. One ``scheduler.step()``.
      3. A second pass over the same ``train_loader`` in eval mode, without
         gradient tracking, to collect edge predictions.
      4. Print the epoch report and, when saving is enabled, also write it to
         the log file.
      5. Save both state dicts when the better of the two precision ratios
         beats ``best_eprec``.

    Args:
        epoch: Index of the current epoch (used for reporting only).
        best_eprec: Best precision ratio seen so far; a checkpoint is written
            only when this epoch's value is strictly greater.

    Returns:
        The larger of the training-pass and eval-pass values returned by
        ``calc_eprec_net``.
    """
    t = time.time()
    nll_train = []
    kl_train = []
    mse_train = []
    acc_train = []
    all_edges_preds = np.zeros(sh)

    encoder.train()
    decoder.train()

    # ---- training pass -----------------------------------------------------
    for data, relations in train_loader:
        if args.cuda:
            data, relations = data.cuda(), relations.cuda()

        optimizer.zero_grad()

        logits = encoder(data, rel_rec, rel_send, args.num_tfs)

        # Honour --hard instead of a hard-coded True (the flag defaults to True).
        edges = gumbel_softmax(logits, tau=args.temp, hard=args.hard)
        prob = my_softmax(logits, -1)

        output = torch.squeeze(decoder(data, edges, rel_rec, rel_send, args.num_tfs, 0))
        # The decoder is trained to reproduce its own input.
        target = torch.squeeze(data)

        loss_nll = nll_gaussian(output, target, args.var)

        if args.prior:
            loss_kl = kl_categorical(prob, log_prior, args.num_atoms)
        else:
            loss_kl = kl_categorical_uniform(prob, args.num_atoms,
                                             args.edge_types)

        if args.compute_acc:
            acc_train.append(edge_accuracy(logits, relations))
        mse_train.append(F.mse_loss(output, target).item())
        nll_train.append(loss_nll.item())
        kl_train.append(loss_kl.item())
        all_edges_preds = edge_precision_util(logits, all_edges_preds)

        loss = loss_nll + loss_kl
        loss.backward()
        optimizer.step()

    scheduler.step()
    eprecR_train_whole = calc_eprec_net(all_edges_preds)

    # ---- evaluation pass (same loader, eval mode) --------------------------
    acc_eval = []
    all_edges_preds = np.zeros(sh)

    encoder.eval()
    decoder.eval()
    # No parameters are updated here, so skip building the autograd graph.
    with torch.no_grad():
        for data, relations in train_loader:
            if args.cuda:
                data, relations = data.cuda(), relations.cuda()

            logits = encoder(data, rel_rec, rel_send, args.num_tfs)

            # NOTE(review): the sampled edges and the decoder output are not
            # used by anything below. The calls are kept so that seeded runs
            # consume random numbers exactly as before; drop them if that
            # reproducibility is not required.
            edges = gumbel_softmax(logits, tau=args.temp, hard=True)
            decoder(data, edges, rel_rec, rel_send, args.num_tfs, 0)

            if args.compute_acc:
                acc_eval.append(edge_accuracy(logits, relations))
            all_edges_preds = edge_precision_util(logits, all_edges_preds)

    eprecR_eval_whole = calc_eprec_net(all_edges_preds)

    # ---- report ------------------------------------------------------------
    rule = "-----------------------------------------------"
    report = [
        'Epoch: {:04d}'.format(epoch),
        'nll_train: {:.10f}'.format(np.mean(nll_train)),
        'kl_train: {:.10f}'.format(np.mean(kl_train)),
        'mse_train: {:.10f}'.format(np.mean(mse_train)),
    ]
    if args.compute_acc:
        report.append('acc_train: {:.10f}'.format(np.mean(acc_train)))
    report.append('eprecR_train: {:.10f}'.format(eprecR_train_whole))
    if args.compute_acc:
        report.append('acc_eval: {:.10f}'.format(np.mean(acc_eval)))
    report.extend([
        'eprecR_eval: {:.10f}'.format(eprecR_eval_whole),
        'time: {:.4f}s'.format(time.time() - t),
        rule,
        rule,
    ])
    print(*report)
    if args.save_folder:
        # Mirror the epoch line into the run's log file, which was previously
        # flushed but never written to during training.
        print(*report, file=log)

    # Model selection uses whichever of the two passes scored higher.
    max_eprec = max(eprecR_eval_whole, eprecR_train_whole)

    if args.save_folder and max_eprec > best_eprec:
        torch.save(encoder.state_dict(), encoder_file)
        torch.save(decoder.state_dict(), decoder_file)
        print('Best model so far, saving...')
        print('Epoch: {:04d}'.format(epoch))
        log.flush()

    return max_eprec


In [18]:
def edge_precision_util(preds,  all_edges_preds):
    """Fold one batch of edge logits into the running per-edge 0/1 mask.

    For every sample in the batch the arg-max edge type of each edge is added
    to the running mask and the result is capped at 1, so an entry ends up at
    1 as soon as any sample predicted a non-zero edge type for it.

    Args:
        preds: Edge logits for one batch; the last axis indexes edge types.
            assumes shape (batch, num_edges, edge_types) — TODO confirm
            against the encoder.
        all_edges_preds: Running mask of length num_edges (NumPy array, or a
            CPU tensor). It is not modified in place.

    Returns:
        A new NumPy array holding the updated mask.
    """
    _, edge_types = preds.max(-1)

    # Sum the per-sample predictions over the batch axis in one reduction and
    # do a single device-to-host copy, instead of one `.cpu()` per sample.
    batch_votes = edge_types.sum(0).cpu().numpy()

    merged = np.asarray(all_edges_preds) + batch_votes
    # Cap at 1: the mask records "predicted at least once", not a count.
    merged[merged > 1] = 1
    return merged
    
    

In [19]:
def calc_eprec_net(preds):
    """Score a per-edge prediction mask against the reference network.

    An entry of ``preds`` equal to 1 counts as a predicted edge. It is a true
    positive when ``true_edges`` holds 1 at the (sender, receiver) position
    given by ``rel_send_idx`` / ``rel_rec_idx`` for that entry, and a false
    positive otherwise.

    Args:
        preds: 1-D mask with one entry per candidate edge, in the same order
            as ``rel_send_idx`` / ``rel_rec_idx`` (NumPy array or CPU tensor).

    Returns:
        Precision (tp / (tp + fp), or 0 when nothing was predicted) divided by
        ``args.density``.
    """
    predicted = np.asarray(preds) == 1
    n_predicted = int(np.count_nonzero(predicted))

    if n_predicted == 0:
        # No edge predicted at all: define precision as 0 rather than 0/0.
        prec = 0
    else:
        # Look up all predicted edges in the reference matrix in one indexing
        # operation instead of a Python loop over every candidate edge.
        hits = true_edges[rel_send_idx[predicted], rel_rec_idx[predicted]] == 1
        tp = int(np.count_nonzero(hits))
        prec = tp / n_predicted
    return 1. * prec/args.density

In [20]:
# Train model
t_total = time.time()
best_eprec = 0
best_epoch = 0
for epoch in range(args.epochs):
    torch.cuda.empty_cache()
    eprec = train(epoch, best_eprec) 
    if eprec > best_eprec:
        best_eprec = eprec
        best_epoch = epoch
print("Optimization Finished!")
print("Best Epoch: {:04d}".format(best_epoch))
print(best_eprec)
if args.save_folder:
    print("Best Epoch: {:04d}".format(best_epoch), file=log)
    log.flush()
 


Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0000 nll_train: 77.3127313985 kl_train: -5.0855905221 mse_train: 0.0077312733 acc_train: 0.4725694666 eprecR_train: 1.0067750678 acc_eval: 0.5048753012 eprecR_eval: 1.0067750678 time: 8.5625s ----------------------------------------------- -----------------------------------------------
Best model so far, saving...
Epoch: 0000
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0001 nll_train: 1.4214158559 kl_train: -9.0757377766 mse_train: 0.0001421416 acc_train: 0.4660859777 eprecR_train: 1.0067750678 acc_eval: 0.4459094711 eprecR_eval: 1.0067750678 time: 7.4092s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0002 nll_train: 0.8705251290 kl_train: -9.2223109963 mse_train: 0.0000870525 acc_train: 0.4925435019 eprecR_train: 1.0067750678 acc_eval: 0.4789900460 eprecR_eval: 1.0067750678 time: 7.4122s ----------------------

Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0024 nll_train: 0.2742622421 kl_train: -9.3233741772 mse_train: 0.0000274262 acc_train: 0.8508886973 eprecR_train: 1.0067750678 acc_eval: 0.8152313662 eprecR_eval: 1.0067750678 time: 7.5119s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0025 nll_train: 0.3545241870 kl_train: -9.3234846504 mse_train: 0.0000354524 acc_train: 0.8270611363 eprecR_train: 1.0067750678 acc_eval: 0.8227374112 eprecR_eval: 1.0067750678 time: 7.4690s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0026 nll_train: 0.1795066325 kl_train: -9.3245392670 mse_train: 0.0000179507 acc_train: 0.8604381312 eprecR_train: 1.0067750678 acc_eval: 0.8734891067 eprecR_eval: 1.0067750678 time: 7.5478s ----------------------------------------------- ----------------

Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0047 nll_train: 0.1919743867 kl_train: -9.3272366641 mse_train: 0.0000191974 acc_train: 0.9750620782 eprecR_train: 1.0154851231 acc_eval: 0.9907230287 eprecR_eval: 1.4724877293 time: 7.4596s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0048 nll_train: 0.0631177235 kl_train: -9.3294410764 mse_train: 0.0000063118 acc_train: 0.9713662572 eprecR_train: 1.0066886456 acc_eval: 0.9898284562 eprecR_eval: 1.2554633842 time: 7.5149s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0049 nll_train: 0.0639754635 kl_train: -9.3294655129 mse_train: 0.0000063975 acc_train: 0.9693828377 eprecR_train: 1.0077582465 acc_eval: 0.9901424615 eprecR_eval: 1.3048635824 time: 7.5987s ----------------------------------------------- ----------------

Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0071 nll_train: 0.2273196441 kl_train: -9.3305157791 mse_train: 0.0000227320 acc_train: 0.9914756265 eprecR_train: 1.3475888184 acc_eval: 0.9917244461 eprecR_eval: 1.0276172126 time: 14.1556s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0072 nll_train: 0.0595993390 kl_train: -9.3311777645 mse_train: 0.0000059599 acc_train: 0.9915495870 eprecR_train: 1.0346756152 acc_eval: 0.9917925060 eprecR_eval: 0.8130081301 time: 14.2519s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 5.0000e-04.
Epoch: 0073 nll_train: 0.0270011268 kl_train: -9.3314184436 mse_train: 0.0000027001 acc_train: 0.9915714355 eprecR_train: 1.0664229129 acc_eval: 0.9917723810 eprecR_eval: 0.6289308176 time: 14.2998s ----------------------------------------------- -------------

Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0095 nll_train: 0.0062056763 kl_train: -9.3321027815 mse_train: 0.0000006206 acc_train: 0.9916600868 eprecR_train: 1.1334721258 acc_eval: 0.9918004039 eprecR_eval: 0.9592326139 time: 14.1920s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0096 nll_train: 0.0066339137 kl_train: -9.3321255166 mse_train: 0.0000006634 acc_train: 0.9916602236 eprecR_train: 1.1797362942 acc_eval: 0.9917996739 eprecR_eval: 1.0489510490 time: 14.2955s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0097 nll_train: 0.0057716651 kl_train: -9.3321205481 mse_train: 0.0000005772 acc_train: 0.9916654754 eprecR_train: 0.9644789461 acc_eval: 0.9918060359 eprecR_eval: 1.1320754717 time: 14.2858s ----------------------------------------------- -------------

Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0119 nll_train: 0.0157989095 kl_train: -9.3322758321 mse_train: 0.0000015799 acc_train: 0.9917794628 eprecR_train: 1.5277777778 acc_eval: 0.9918291466 eprecR_eval: 1.8372703412 time: 14.2209s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0120 nll_train: 0.0800492359 kl_train: -9.3320509063 mse_train: 0.0000080049 acc_train: 0.9917720718 eprecR_train: 1.5784586815 acc_eval: 0.9918304241 eprecR_eval: 1.6025641026 time: 14.2229s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0121 nll_train: 0.0329340921 kl_train: -9.3321843206 mse_train: 0.0000032934 acc_train: 0.9917951116 eprecR_train: 1.1019283747 acc_eval: 0.9918313822 eprecR_eval: 1.5686274510 time: 14.1414s ----------------------------------------------- -------------

Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0142 nll_train: 0.0043536581 kl_train: -9.3325259009 mse_train: 0.0000004354 acc_train: 0.9918288729 eprecR_train: 0.0000000000 acc_eval: 0.9918346214 eprecR_eval: 3.9215686275 time: 14.1607s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0143 nll_train: 0.0359930505 kl_train: -9.3324675442 mse_train: 0.0000035993 acc_train: 0.9918308803 eprecR_train: 0.0000000000 acc_eval: 0.9918340283 eprecR_eval: 4.5977011494 time: 14.2985s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 2.0000e-04.
Epoch: 0144 nll_train: 0.2048493482 kl_train: -9.3311225102 mse_train: 0.0000204849 acc_train: 0.9918219381 eprecR_train: 1.5594541910 acc_eval: 0.9918344389 eprecR_eval: 4.7619047619 time: 14.1919s ----------------------------------------------- -------------

Adjusting learning rate of group 0 to 8.0000e-05.
Epoch: 0166 nll_train: 0.0021644690 kl_train: -9.3328060397 mse_train: 0.0000002164 acc_train: 0.9918350320 eprecR_train: 0.0000000000 acc_eval: 0.9918353514 eprecR_eval: 0.0000000000 time: 14.2567s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 8.0000e-05.
Epoch: 0167 nll_train: 0.0021750174 kl_train: -9.3328107963 mse_train: 0.0000002175 acc_train: 0.9918348495 eprecR_train: 0.0000000000 acc_eval: 0.9918353514 eprecR_eval: 0.0000000000 time: 14.1834s ----------------------------------------------- -----------------------------------------------
Adjusting learning rate of group 0 to 8.0000e-05.
Epoch: 0168 nll_train: 0.0022362783 kl_train: -9.3328095895 mse_train: 0.0000002236 acc_train: 0.9918348495 eprecR_train: 0.0000000000 acc_eval: 0.9918353058 eprecR_eval: 16.6666666667 time: 14.1261s ----------------------------------------------- ------------