<a href="https://colab.research.google.com/github/ayshaw/complexCorrection/blob/master/notebooks/GREMLIN_TF_v2_weights_edit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GREMLIN_TF v2.1
GREMLIN implemented in TensorFlow

### Change log:
*   02Apr2019
 - fixing a few hard-coded values, to allow GREMLIN to work with any alphabet (binary, protein, rna etc)
*   22Jan2019
 - moving [GREMLIN_TF_simple](https://colab.research.google.com/github/sokrypton/GREMLIN_CPP/blob/master/GREMLIN_TF_simple.ipynb) to a separate notebook
*   19Jan2019
 - in the past we found that optimizing V first required fewer iterations for convergence. Since V can be computed exactly (assuming no W), we replace this first optimization step with a simple V initialization.
 - a few variables were renamed to be consistent with the c++ version
*   16Jan2019
 - updating how indices are handled (for easier/cleaner parsing)
 - minor speed up in how we symmetrize and zero the diagonal of W
*   15Jan2019
 - LBFGS optimizer replaced with a modified version of the ADAM optimizer
 - Added option for stochastic gradient descent (via batch_size)
  
### Method:
GREMLIN takes a multiple sequence alignment (MSA) and returns a Markov Random Field (MRF). The MRF consists of a one-body term (V) that encodes conservation, and a two-body term (W) that encodes co-evolution.
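
As a rough illustration of what V and W represent, the sketch below scores one integer-encoded sequence under a toy MRF. The array shapes mirror those used later in this notebook (V is `(ncol, states)`, W is `(ncol, states, ncol, states)`); the helper name `mrf_energy` and the random parameters are assumptions made for the example, not part of GREMLIN.

```python
import numpy as np

def mrf_energy(seq, V, W):
    """Hypothetical helper: one-body plus two-body score of a single sequence."""
    ncol = len(seq)
    pos = np.arange(ncol)
    # one-body term: V[i, seq[i]] summed over positions
    one_body = V[pos, seq].sum()
    # pair[i, j] = W[i, seq[i], j, seq[j]]
    pair = W[pos[:, None], seq[:, None], pos[None, :], seq[None, :]]
    # count each unordered pair (i < j) once
    two_body = np.triu(pair, 1).sum()
    return one_body + two_body

rng = np.random.default_rng(0)
ncol, states = 4, 3
V = rng.normal(size=(ncol, states))
W = rng.normal(size=(ncol, states, ncol, states))
W = W + W.transpose(2, 3, 0, 1)  # make W symmetric in (i,a) <-> (j,b)
seq = np.array([0, 2, 1, 1])
print(mrf_energy(seq, V, W))
```

Positions that are conserved contribute mostly through V, while co-varying pairs of positions contribute through W.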

For more details about the method see:
[Google slides](https://docs.google.com/presentation/d/1aooxoksosSv7CWs9-ktqhUjyXR3wrgbG5a6PCr92od4/) and accompanying [Google colab](https://colab.research.google.com/drive/17RJcExuyifnd7ShTcsZGh6mBpWq0-s60)

See [GREMLIN_TF_simple](https://colab.research.google.com/github/sokrypton/GREMLIN_CPP/blob/master/GREMLIN_TF_simple.ipynb) for a stripped-down version of this code (with no funky gap removal, sequence weighting, etc.). It is intended for educational purposes and could also be very useful for anyone trying to modify or improve the algorithm!


In [0]:
# ------------------------------------------------------------
# "THE BEERWARE LICENSE" (Revision 42):
# <so@g.harvard.edu> and <pkk382@g.harvard.edu> wrote this code.
# As long as you retain this notice, you can do whatever you want
# with this stuff. If we meet someday, and you think this stuff
# is worth it, you can buy us a beer in return.
# --Sergey Ovchinnikov and Peter Koo
# ------------------------------------------------------------

In [0]:
# !wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
# !unzip ngrok-stable-linux-amd64.zip

In [0]:
# LOG_DIR = './log'
# get_ipython().system_raw(
#     'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
#     .format(LOG_DIR)
# )

In [0]:
# get_ipython().system_raw('./ngrok http 6006 &')


In [0]:
# ! curl -s http://localhost:4040/api/tunnels | python3 -c \
#     "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

## libraries

In [1]:
# IMPORTANT, only tested using PYTHON 3!
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt
from scipy import stats
from scipy.spatial.distance import pdist,squareform
import pandas as pd
import os
import time
import pickle as pkl
#from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
import multiprocessing as mp
from keras.callbacks import TensorBoard
tbCallBack = TensorBoard(log_dir='./log', histogram_freq=1,
                         write_graph=True,
                         write_grads=True,
                         batch_size=None,
                         write_images=True)
from google.colab import drive
drive.mount('/content/drive/',force_remount=True)
os.chdir('/content/drive/My Drive/markslab/multimerCorrection')  
from scipy.spatial import distance
from multiprocessing import Pool,Process
import psutil
!ls



Using TensorFlow backend.


Mounted at /content/drive/
align_1
align_2
allpdb0148
batch_scripts
benchmark
concatenate.a2m
datasets
high_precision_complexes.csv
low_precision_complexes.csv
merged_top10_statistics_449.csv
notebooks
parallelperformance.xlsx
pd_mtx_allpdb0148.csv
pd_mtx_allpdb0148_weights.csv
pd_mtx_allpdb0777_confirmation.csv
pd_mtx_allpdb0777_reduced.csv
pd_mtx_allpdb0777_reduced_weight_prior.csv
pd_mtx_allpdb0777_reduced_weights.csv
pd_mtx_allpdb0777_reduced_weights_relu.csv
pd_mtx_allpdb0777_redu

In [2]:
!ls datasets

4FAZA.fas	allpdb0777_evcouplings.a2m  wb_ini
allpdb0148.a2m	allpdb0777_reduced.a2m


## Params

In [0]:
################
# note: if you are modifying the alphabet
# make sure last character is "-" (gap)
################
alphabet = "ARNDCQEGHILKMFPSTWYV-"
states = len(alphabet)
a2n = {}
for a,n in zip(alphabet,range(states)):
    a2n[a] = n
################

def aa2num(aa):
    '''convert aa into num'''
    if aa in a2n: return a2n[aa]
    else: return a2n['-']
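
A small sketch of how this mapping behaves on a short string. The `encode` helper is a hypothetical name introduced for the example; it applies the same rule as `aa2num`, so any character outside the alphabet (including lowercase insert states in a2m files) falls back to the gap index.

```python
alphabet = "ARNDCQEGHILKMFPSTWYV-"
a2n = {a: n for n, a in enumerate(alphabet)}

def encode(seq):
    """Hypothetical helper: map each character to its index; unknowns become gap."""
    gap = a2n["-"]
    return [a2n.get(aa, gap) for aa in seq]

print(encode("AR-XV"))  # [0, 1, 20, 20, 19]
```

Because the gap is looked up as the last state elsewhere in the notebook (`states-1`), keeping "-" at the end of the alphabet matters.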

## Functions for prepping the MSA (Multiple sequence alignment)

In [0]:
from functools import partial
# from fasta
def parse_fasta(filename,limit=-1):
    '''function to parse fasta'''
    header = []
    sequence = []
    with open(filename, "r") as lines:
        for line in lines:
            line = line.rstrip()
            if not line:
                continue  # skip blank lines (line[0] would raise IndexError)
            if line[0] == ">":
                if len(header) == limit:
                    break
                header.append(line[1:])
                sequence.append([])
            else:
                sequence[-1].append(line)
    sequence = [''.join(seq) for seq in sequence]
    return np.array(header), np.array(sequence)

def filt_gaps(msa,gap_cutoff=0.5):
    '''filters alignment to remove gappy positions'''
    tmp = (msa == states-1).astype(np.float64)
    non_gaps = np.where(np.sum(tmp.T,-1).T/msa.shape[0] < gap_cutoff)[0]
    return msa[:,non_gaps],non_gaps

def get_eff(msa,eff_cutoff=0.8):
    '''compute effective weight for each sequence'''
    # pairwise identity (1 - normalized hamming distance); needs O(nrow^2) memory
    msa_sm = 1.0 - squareform(pdist(msa,"hamming"))
    # mark sequences within eff_cutoff identity of each sequence (including itself)
    msa_w = (msa_sm >= eff_cutoff).astype(np.float64)
    # weight for each sequence = 1 / size of its cluster
    msa_w = 1/np.sum(msa_w,-1)
    return msa_w
  
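def f(msa,eff_cutoff,i):
  '''weight of sequence i: 1 / (number of sequences with >= eff_cutoff identity to it)'''
  rncol=(1/(msa.shape[1]))
  return 1/(np.sum(rncol*np.sum(msa==msa[i],axis=1,dtype=np.uint16)>=eff_cutoff,dtype=np.uint32))


def get_eff_lowmem(msa,eff_cutoff=0.8):
  '''row-by-row version of get_eff that avoids building the (nrow x nrow) identity matrix'''
  with mp.Pool(processes=1) as pool:
    return np.fromiter(pool.map(partial(f,msa,eff_cutoff),np.arange(msa.shape[0])),dtype=float)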

def mk_msa(seqs):
    '''converts list of sequences to msa'''

    msa_ori = []
    for seq in seqs:
        msa_ori.append(list(map(aa2num,seq)))
    msa_ori = np.array(msa_ori,dtype=int)
    # remove positions with 50% or more gaps
    msa, v_idx = filt_gaps(msa_ori,0.5)
    # compute effective weight for each sequence
    msa_weights = get_eff_lowmem(msa)
    ncol = msa.shape[1] # length of sequence
    w_idx = v_idx[np.stack(np.triu_indices(ncol,1),-1)]
    return {"msa_ori": msa_ori,
          "msa":msa,
          "weights":msa_weights,
          "neff":np.sum(msa_weights),
          "nrow":msa.shape[0],
          "ncol":ncol,
         "w_idx":w_idx,
         "v_idx":v_idx}


In [10]:

# process input sequences
names, seqs = parse_fasta("datasets/allpdb0609.a2m")
msa = mk_msa(seqs)

FileNotFoundError: ignored
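
To make the sequence reweighting concrete, here is a minimal NumPy sketch on a toy alignment. It follows the same rule as `get_eff` (each sequence is weighted by one over the number of sequences at or above the identity cutoff, itself included); the function name `seq_weights` and the toy data are assumptions for illustration only.

```python
import numpy as np

def seq_weights(msa, eff_cutoff=0.8):
    """Weight each row by 1 / (number of rows with >= eff_cutoff identity to it)."""
    # identity[i, j] = fraction of columns where rows i and j agree
    identity = (msa[:, None, :] == msa[None, :, :]).mean(-1)
    return 1.0 / (identity >= eff_cutoff).sum(-1)

msa = np.array([[0, 1, 2, 3, 4],
                [0, 1, 2, 3, 4],   # exact duplicate of row 0
                [4, 3, 2, 1, 0]])  # shares only one column with the others
w = seq_weights(msa)
print(w, w.sum())  # weights [0.5, 0.5, 1.0]; their sum plays the role of "neff"
```

The broadcasted comparison builds an `(nrow, nrow, ncol)` array, so this form only suits small alignments; the row-by-row approach used in `get_eff_lowmem` trades speed for memory.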

## GREMLIN

In [0]:
# external functions

def sym_w(w):
    '''symmetrize input matrix of shape (x,y,x,y)'''
    x = w.shape[0]
    w = w * np.reshape(1-np.eye(x),(x,1,x,1))
    w = w + tf.transpose(w,[2,3,0,1])
    return w

def opt_adam(loss, name, var_list=None, lr=1.0, b1=0.9, b2=0.999, b_fix=False):
    # adam optimizer
    # Note: this is a modified version of the adam optimizer. More specifically, we replace "vt"
    # with sum(g*g) instead of (g*g). Furthermore, we find that disabling the bias correction
    # (b_fix=False) speeds up convergence for our case.

    if var_list is None: var_list = tf.trainable_variables() 
    gradients = tf.gradients(loss,var_list)
    if b_fix: t = tf.Variable(1.0,trainable=False,name="t")  # step counter; starts at 1 so 1-b1**t is nonzero
    opt = []
    for n,(x,g) in enumerate(zip(var_list,gradients)):
        if g is not None:
            ini = dict(initializer=tf.zeros_initializer,trainable=False)
            mt = tf.get_variable(name+"_mt_"+str(n),shape=list(x.shape), **ini)
            vt = tf.get_variable(name+"_vt_"+str(n),shape=[], **ini)

            mt_tmp = b1*mt+(1-b1)*g
            vt_tmp = b2*vt+(1-b2)*tf.reduce_sum(tf.square(g))
            lr_tmp = lr/(tf.sqrt(vt_tmp) + 1e-8)

            if b_fix: lr_tmp = lr_tmp * tf.sqrt(1-tf.pow(b2,t))/(1-tf.pow(b1,t))

            opt.append(x.assign_add(-lr_tmp * mt_tmp))
            opt.append(vt.assign(vt_tmp))
            opt.append(mt.assign(mt_tmp))

    if b_fix: opt.append(t.assign_add(1.0))
    return(tf.group(opt))
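
The update rule in `opt_adam` (with `b_fix=False`) can be sketched in plain NumPy as below. This is a simplified illustration on a toy quadratic, not the TensorFlow implementation: the function name `adam_scalar_step` and the test problem are assumptions for the example. The key difference from standard Adam is that `vt` is a single scalar tracking `sum(g*g)` per variable, and no bias correction is applied.

```python
import numpy as np

def adam_scalar_step(x, g, mt, vt, lr=1.0, b1=0.9, b2=0.999, eps=1e-8):
    """One step of the Adam variant: scalar second moment, no bias correction."""
    mt = b1 * mt + (1 - b1) * g                     # first moment, same shape as x
    vt = b2 * vt + (1 - b2) * np.sum(np.square(g))  # scalar second moment
    x = x - lr / (np.sqrt(vt) + eps) * mt
    return x, mt, vt

# toy problem: minimize f(x) = sum(x**2), whose gradient is 2*x
x = np.array([3.0, -2.0])
mt, vt = np.zeros_like(x), 0.0
for _ in range(200):
    x, mt, vt = adam_scalar_step(x, 2 * x, mt, vt, lr=0.1)
print(x)  # close to the minimum at the origin
```

Using one scalar per variable keeps the relative scale of the gradient entries intact, so the step direction stays parallel to the momentum term rather than being rescaled per coordinate.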

In [0]:
def GREMLIN_weights(msa, l2_wb=0.01, wb_input=None,opt_type="adam", opt_iter=100, opt_rate=1.0, batch_size=512):
  
    ##############################################################
    # SETUP COMPUTE GRAPH
    ##############################################################
    # kill any existing tensorflow graph
    tf.reset_default_graph()

    ncol = msa["ncol"] # length of sequence
    nrow = msa["nrow"] # number of sequences
    print("ncol: {}, nrow: {}".format(ncol,nrow))
    if wb_input is None:
      wb_input=np.ones([nrow])
    # msa (multiple sequence alignment) 
    MSA = tf.placeholder(tf.int32,shape=(None,ncol),name="msa")

    # one-hot encode msa
    OH_MSA = tf.one_hot(MSA,states)

    # msa weights
    MSA_weights = tf.placeholder(tf.float32, shape=(None,), name="msa_weights")
    idx = tf.placeholder(tf.int64,shape=[None], name = 'idx')  # row indices into wb; length differs when all rows are fed

    # 1-body-term of the MRF
    V = tf.get_variable(name="V", 
                      shape=[ncol,states],
                      initializer=tf.zeros_initializer)

    # 2-body-term of the MRF
    W = tf.get_variable(name="W",
                      shape=[ncol,states,ncol,states],
                      initializer=tf.zeros_initializer)

    # weights for concatenation
    wb = tf.get_variable(name="wb",
                      shape=[nrow],
                      initializer=tf.ones_initializer,
                      constraint=lambda x: tf.clip_by_value(x, 0, np.inf)
                      )
    wb=tf.math.multiply(wb,wb_input)
    # symmetrize W
    W = sym_w(W)

    def L2(x): return tf.reduce_sum(tf.square(x))
    def L1(x): return tf.reduce_sum(tf.abs(x))

    ########################################
    # V + W
    ########################################
    VW = V + tf.tensordot(OH_MSA,W,2)

    # hamiltonian
    H = tf.reduce_sum(tf.multiply(OH_MSA,VW),axis=(1,2))

    # local Z (partition function)
    Z = tf.reduce_sum(tf.reduce_logsumexp(VW,axis=2),axis=1)

    # Pseudo-Log-Likelihood
    PLL = H - Z
    wb = tf.nn.relu(wb)
    # Regularization
    L2_V = 0.01 * L2(V)
    L2_W = 0.01 * L2(W) * 0.5 * (ncol-1) * (states-1)
    L2_wb = l2_wb * L2(tf.gather(wb,idx))

    # loss function to minimize
    #loss = -tf.reduce_sum(PLL*MSA_weights*tf.gather(wb,idx))/tf.reduce_sum(MSA_weights*tf.gather(wb,idx))
    loss = -tf.reduce_sum(PLL*MSA_weights*tf.gather(wb,idx))/tf.reduce_sum(MSA_weights*tf.gather(wb,idx))-tf.minimum(tf.reduce_min(wb),0)
    loss = loss + (L2_V + L2_W + L2_wb)/msa["neff"]
    #wb = tf.nn.softmax(wb)
    ##############################################################
    # MINIMIZE LOSS FUNCTION
    ##############################################################
    if opt_type == "adam":  
        opt = opt_adam(loss,"adam",lr=opt_rate)

    # generate input/feed
    def feed(feed_all=False):
        if batch_size is None or feed_all:
            return {MSA:msa["msa"], MSA_weights:msa["weights"],idx:np.arange(len(msa['weights']))}
        else:
            idx_val = np.random.randint(0,msa["nrow"],size=batch_size)
            return {MSA:msa["msa"][idx_val], MSA_weights:msa["weights"][idx_val],idx:idx_val}

    # optimize!
    with tf.Session() as sess:
        # initialize all variables (V, W, wb and the optimizer state)
        sess.run(tf.global_variables_initializer())
        # initialize V
        msa_cat = tf.keras.utils.to_categorical(msa["msa"],states)
        pseudo_count = 0.01 * np.log(msa["neff"])
        V_ini = np.log(np.sum(msa_cat.T * msa["weights"],-1).T + pseudo_count)
        V_ini = V_ini - np.mean(V_ini,-1,keepdims=True)
        wb_ini = sess.run(wb)
        sess.run(V.assign(V_ini))

        

        # compute loss (on a random mini-batch unless batch_size is None), scaled by neff
        get_loss = lambda: round(sess.run(loss,feed()) * msa["neff"],2)
        print("starting",get_loss())

#         if opt_type == "lbfgs":
#             lbfgs = tf.contrib.opt.ScipyOptimizerInterface
#             opt = lbfgs(loss,method="L-BFGS-B",options={'maxiter': opt_iter})
#             opt.minimize(sess,feed(feed_all=True))

        if opt_type == "adam":
            for i in range(opt_iter):
                sess.run(opt,feed())  
                if (i+1) % max(1, opt_iter//10) == 0:
                    print("iter",(i+1),get_loss())

        # save the V and W parameters of the MRF
        V_ = sess.run(V)
        W_ = sess.run(W)
        wb_ =sess.run(wb)

    # only return upper-right triangle of matrix (since it's symmetric)
    tri = np.triu_indices(ncol,1)
    W_ = W_[tri[0],:,tri[1],:]

    mrf = {"v": V_,
         "w": W_,
         "wb": wb_,
         "wb_ini":wb_ini,
          'w_idx':msa['w_idx'],
          'v_idx':msa['v_idx']}

    return mrf
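
For readers who want to check the pseudo-log-likelihood outside of a TensorFlow session, the core of the compute graph above (`VW`, `H`, `Z`, `PLL`) can be mirrored in NumPy. This is a sketch that omits the sequence weights, the `wb` term and the regularization; the function name `pseudo_log_likelihood` is an assumption for the example.

```python
import numpy as np

def pseudo_log_likelihood(msa, V, W):
    """Per-sequence PLL. msa: (nrow, ncol) ints, V: (ncol, states),
    W: (ncol, states, ncol, states), assumed symmetric with zero diagonal blocks."""
    states = V.shape[1]
    oh = np.eye(states)[msa]                  # one-hot, (nrow, ncol, states)
    VW = V + np.tensordot(oh, W, 2)           # logits per position, (nrow, ncol, states)
    H = np.sum(oh * VW, axis=(1, 2))          # score of the observed states
    m = VW.max(axis=2, keepdims=True)         # stabilized log-sum-exp over states
    Z = np.sum(m[..., 0] + np.log(np.exp(VW - m).sum(axis=2)), axis=1)
    return H - Z

ncol, states = 3, 4
msa = np.array([[0, 1, 2], [2, 1, 0]])
V = np.zeros((ncol, states))
W = np.zeros((ncol, states, ncol, states))
# with all-zero parameters every state is equally likely,
# so each sequence scores -ncol * log(states)
print(pseudo_log_likelihood(msa, V, W))
```

The `tensordot` over the last two axes of the one-hot MSA and the first two axes of W is the same contraction used in the TensorFlow graph.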

In [0]:
###################
def normalize(x):
  x = stats.boxcox(x - np.amin(x) + 1.0)[0]
  x_mean = np.mean(x)
  x_std = np.std(x)
  return((x-x_mean)/x_std)

def get_mtx(mrf):
  '''get mtx given mrf'''
  
  # l2norm of 20x20 matrices (note: we ignore gaps)
  raw = np.sqrt(np.sum(np.square(mrf["w"][:,:-1,:-1]),(1,2)))
  raw_sq = squareform(raw)

  # apc (average product correction)
  ap_sq = np.sum(raw_sq,0,keepdims=True)*np.sum(raw_sq,1,keepdims=True)/np.sum(raw_sq)
  apc = squareform(raw_sq - ap_sq, checks=False)

  mtx = {"i": mrf["w_idx"][:,0],
         "j": mrf["w_idx"][:,1],
         "raw": raw,
         "apc": apc,
         "zscore": normalize(apc)}
  return mtx

def output_couplingScores_csv(mrf_weights,l2_wb=0.01):
  '''write coupling scores to ../<complex_name>/<complex_name>_l2<l2_wb>_couplings_score.csv
  note: relies on the globals `msa` and `complex_name` being defined'''
  mtx_weights = get_mtx(mrf_weights)
  mtx_weights["i_aa"] = np.array([alphabet[msa['msa_ori'][0][i]]+"_"+str(i+1) for i in mtx_weights["i"]])
  mtx_weights["j_aa"] = np.array([alphabet[msa['msa_ori'][0][j]]+"_"+str(j+1) for j in mtx_weights["j"]])
  pd_mtx_weights = pd.DataFrame(mtx_weights,columns=["i","j","apc","zscore","i_aa","j_aa"])
  # create the output directory if it does not exist yet
  os.makedirs('../{}'.format(complex_name),exist_ok=True)
  pd_mtx_weights.to_csv('../{0}/{0}_l2{1}_couplings_score.csv'.format(complex_name,l2_wb))
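
The average product correction (APC) step in `get_mtx` can be illustrated on a small matrix. The sketch below uses a full rank-one background (diagonal included) as a simplifying assumption, whereas `get_mtx` works on a `squareform` matrix with a zero diagonal; the function name `apc` and the toy numbers are made up for the example.

```python
import numpy as np

def apc(raw_sq):
    """Subtract the expected background: (row sum * column sum) / total sum."""
    row = raw_sq.sum(0, keepdims=True)
    col = raw_sq.sum(1, keepdims=True)
    return raw_sq - row * col / raw_sq.sum()

# background signal: each position i has a bias b[i], giving raw[i, j] = b[i] * b[j]
b = np.array([1.0, 2.0, 3.0, 4.0])
background = np.outer(b, b)

# add one genuine coupling between positions 0 and 3
raw = background.copy()
raw[0, 3] += 5.0
raw[3, 0] += 5.0

corrected = apc(raw)
print(np.unravel_index(np.argmax(raw), raw.shape))              # (3, 3): dominated by bias
print(np.unravel_index(np.argmax(corrected), corrected.shape))  # (0, 3): the added coupling
```

A pure rank-one background is removed entirely by this correction, which is why the added coupling stands out afterwards.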

In [0]:
import time

In [9]:
%%time
# ===============================================================================
# RUN GREMLIN
# ===============================================================================
# Note: the original GREMLIN uses the "lbfgs" optimizer which is EXTREMELY slow 
# in tensorflow. The modified adam optimizer is much faster, but may 
# require adjusting number of iterations (opt_iter) to converge to the same 
# solution. The "lbfgs" branch is commented out in GREMLIN_weights above, so
# opt_type="lbfgs" only works after re-enabling it.
# ===============================================================================
mrf_weights = GREMLIN_weights(msa,wb_input=wb_input_ones,opt_iter=150,batch_size=1024)

NameError: ignored