# Training Potts Models with Contrastive Divergence for Protein Design

## GREMLIN

https://github.com/whbpt/GREMLIN_PYTORCH/blob/master/GREMLIN_pytorch.ipynb

#### Import

In [1]:
# IMPORTANT, only tested using PYTHON 3!
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt
import pandas as pd
import torch
import torch.nn.functional as F

from scipy import stats
from scipy.spatial.distance import pdist,squareform
from torch import optim

#### Params

In [2]:
################
# note: if you are modifying the alphabet
# make sure last character is "-" (gap)
################
alphabet = "ARNDCQEGHILKMFPSTWYV-"
states = len(alphabet)
a2n = {}
for a,n in zip(alphabet,range(states)):
  a2n[a] = n
################

def aa2num(aa):
  '''convert aa into num'''
  if aa in a2n: return a2n[aa]
  else: return a2n['-']

In [3]:
## Convert FASTA to MSA np.array()

def parse_fasta(filename):
  '''function to parse fasta file'''
  header = []
  sequence = []
  lines = open(filename, "r")
  for line in lines:
    line = line.rstrip()
    if line[0] == ">":
      header.append(line[1:])
      sequence.append([])
    else:
      sequence[-1].append(line)
  lines.close()
  sequence = [''.join(seq) for seq in sequence]
  return np.array(header), np.array(sequence)

def one_hot(msa,states):
  one = np.eye(states)
  return one[msa]

def mk_msa(seqs):
  '''one hot encode msa'''
  
  ################
  alphabet = "ARNDCQEGHILKMFPSTWYV-"
  states = len(alphabet)
  a2n = {}
  for a,n in zip(alphabet,range(states)):
    a2n[a] = n

  def aa2num(aa):
    '''convert aa into num'''
    if aa in a2n: return a2n[aa]
    else: return a2n['-']
  ################
  
  msa = []
  for seq in seqs:
    msa.append([aa2num(aa) for aa in seq])
  msa_ori = np.array(msa)
  return msa_ori, one_hot(msa_ori,states)

In [4]:
names,seqs = parse_fasta("../pfamncamseed.fas.txt")
msa_ori, msa = mk_msa(seqs)

print(msa_ori.shape)
print(msa.shape)

(48, 113)
(48, 113, 21)


In [5]:
# collecting some information about input msa
N = msa.shape[0] # number of sequences
L = msa.shape[1] # length of sequence
A = msa.shape[2] # number of states (or categories)

In [30]:
class GREMLIN(torch.nn.Module):
  def __init__(self,L,A):
    super(GREMLIN, self).__init__()
    self.W0 = torch.nn.Parameter(torch.zeros(L*A,L*A), requires_grad=True) # this is J in the manuscript
    self.b0 = torch.nn.Parameter(torch.zeros(L*A), requires_grad=True) # this is H 
    self.MASK = (1.0 -torch.eye(L*A))
    
  def forward(self,X):
    X = X.reshape(-1,L*A)
    W = (self.W0+self.W0)/2.0 * self.MASK
    MSA_pred = (X.mm(W)+self.b0).reshape(-1,L,A)
    loss = torch.sum(- MSA_Input * F.log_softmax(MSA_pred, -1))
    L2_w = (W**2).sum() * 0.01 * 0.5 *L*A
    L2_b = (self.b0**2).sum() * 0.01
    loss = loss + L2_w + L2_b
    return loss

In [7]:
class Model(torch.nn.Module):
  def __init__(self,L,A):
    super(Model, self).__init__()
    self.GREMLIN_ = GREMLIN(L,A)
    
  def forward(self,X):
    loss = self.GREMLIN_(X)
    return loss

In [8]:
#enviroment setting
device = torch.device("cuda:0") # Uncomment this to run on GPU
MSA_Input = torch.from_numpy(msa.astype(np.float32))

model = Model(L,A)
learning_rate = 0.1*np.log(N)/L
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


for t in range(100):

    loss = model(MSA_Input)      
    optimizer.zero_grad()    
    loss.backward()
    optimizer.step()
    
    if (t) % (int(100/10)) == 0: 
      print(t, loss.item())

0 16513.498046875
10 7362.86572265625
20 6565.06982421875
30 6130.17431640625
40 5972.72900390625
50 5892.46337890625
60 5841.9091796875
70 5809.32177734375
80 5783.7626953125
90 5762.48681640625


In [9]:
w = model.GREMLIN_.W0.detach().numpy()
w = (w+w.T).reshape(L,A,L,A)

In [10]:
model(MSA_Input)  

tensor(5743.4326, grad_fn=<AddBackward0>)

In [31]:
# Use the equation for probability of Boltzmann distribution 
#(without the 1/Z term) to calculate likelihood.
boltzprob = torch.exp(model.GREMLIN_.b0 + model.GREMLIN_.W0)

### bmDCA

**Important Notes:**

*  All amino acids must be upper case

https://github.com/ranganathanlab/bmDCA

In [None]:
!git clone https://github.com/ranganathanlab/bmDCA.git

In [None]:
!sudo apt-get update
!sudo apt-get install git gcc g++ automake autoconf pkg-config \
  libarmadillo-dev libopenblas-dev libarpack++2-dev

In [None]:
%cd bmDCA
!bash autogen.sh --prefix=/usr/local && \
%cd ..

In [None]:
%%shell
cd bmDCA
make -j4 && \
make install
cd ..

In [None]:
!mkdir results

In [None]:
!cp pfam_hits.txt lcc.fasta

#### Training

100-245 of LCC?

In [None]:
import numpy as np

def read_fasta(fname):
    seqs = []
    s = ""
    with open(fname) as f:
        line = f.readline()
        while line:
            if line.startswith(">"):
                if s != "":
                    seqs.append(list(s))
                s = ""
            elif len(line) > 0:
                s += line.strip()
            line = f.readline()
        seqs.append(list(s))
    return np.array(seqs)

In [None]:
seqs = read_fasta("pfam_hits.txt")

In [None]:
mask = np.zeros(len(seqs[0]), dtype=np.bool)
for i in range(len(seqs[0])):
    gaps = 0
    for s in seqs:
        if s[i] == '-':
            gaps += 1
    if gaps/len(seqs) < 0.67:   # keep positions where less that 2/3rd are gaps
        mask[i] = True
seqs = seqs[:,mask]

In [None]:
towrite = ""
for i in range(len(seqs)):
    towrite += ">{}\n".format(i)
    towrite += "".join(seqs[i][100:]) + "\n"   # take positions 100-226
with open("lcc_short.fasta",'w') as f:
    f.write(towrite)

In [None]:
%%shell
rm results/*
bmdca -i lcc_short.fasta -r -d /content/results

In [None]:
!tar -czf boltzmann.tar.gz results/*

#### Sampling

Change temperature in a config file

In [None]:
%%shell
bmdca_sample -p parameters.txt -d /content/results -o samples.txt -c config.conf

In [None]:
!perl convert.pl lcc_pfam.txt lcc_pfam.fa

### Contrastive Divergence

In [None]:
import jax.numpy as jnp
from jax import random
from jax import grad
from jax.scipy.stats.norm import pdf
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats

key = random.PRNGKey(0)

### Model evaluation

In [None]:
!git clone https://github.com/igemto-drylab/CSBERG-ML.git
%cd CSBERG-ML
from util import *