In [42]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep  5 11:31:34 2018

@author: Lewis Iain Moffat


This script takes a single protein sequence (multiple sequences are not yet
supported) from a FASTA file or a plain text file and runs it through one
forward pass of the autoencoder, encoding and then decoding it. This produces
variants of the input sequence. The sequence is presumed not to be a metal
binder (all metal-binding flags are set to zero); if it is one, the generated
variation should not be affected. This is because the metal-binding
variational model is used instead of the grammar model, for simplicity's sake:
you do not need the grammar of the protein before running this script.

The only two arguments are -infile (the file containing the sequence) and
-numout (the number of output sequences wanted). The generated sequences are
written to standard out.

This assumes you are running on a CPU by default.

"""


# =============================================================================
# Imports
# =============================================================================

import torch
import torch.nn.functional as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
import argparse
import utils

from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split as tts

# =============================================================================
# Sort out Command Line arguments
# =============================================================================

parser = argparse.ArgumentParser()
parser.add_argument("-infile", type=str,
        help="file with sequence", default="examples/seq2seq_example.txt")
parser.add_argument("-numout", type=int,
        help="number of sequences generated", default=10)
# parse_known_args() reads the real command line, so
# `python seq_to_seq.py -infile ... -numout ...` is honoured. It also ignores
# the extra arguments Jupyter passes to the kernel (e.g. `-f kernel.json`),
# which would make parse_args() exit. Inside a notebook the defaults apply.
args, _unknown_args = parser.parse_known_args()
args_dict = vars(args)


# =============================================================================
# Pytorch Module
# =============================================================================
class VAE(torch.nn.Module):
    """
    Conditional VAE over flattened sequence vectors.

    The encoder sees the sequence features concatenated with a conditioning
    code of ``code_size`` values. It produces ``mu`` and ``sig`` for a latent
    of size ``hidden_sizes[3]``. The decoder receives a sampled latent
    concatenated with the same code and reconstructs ``input_size - code_size``
    sigmoid outputs, i.e. the sequence part only.

    As a debugging aid, ``forward`` prints intermediate shapes. It also writes
    the sampled latent, and the per-node maxima over the batch of the first two
    decoder layers, to ``sample_decoder.txt``, ``out4_nodes.txt`` and
    ``out5_nodes.txt`` in the working directory.

    Parameters
    ----------
    input_size : int
        Width of the encoder input: sequence features plus ``code_size``.
    hidden_sizes : sequence of int
        Four layer widths; the last one is the latent dimension.
    batch_size : int
        Stored for callers; ``forward`` takes the batch size from its input.
    code_size : int, optional
        Width of the conditioning code (default 8, the value the original
        architecture hard-coded).
    """

    def __init__(self, input_size, hidden_sizes, batch_size, code_size=8):
        super().__init__()

        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.batch_size = batch_size
        self.code_size = code_size

        # Encoder: (sequence + code) -> h0 -> h1 -> h2 -> (mu, sig) of size h3.
        # Attribute names must stay unchanged so saved state_dict keys match.
        self.fc = torch.nn.Linear(input_size, hidden_sizes[0])
        self.BN = torch.nn.BatchNorm1d(hidden_sizes[0])
        self.fc1 = torch.nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.BN1 = torch.nn.BatchNorm1d(hidden_sizes[1])
        self.fc2 = torch.nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.BN2 = torch.nn.BatchNorm1d(hidden_sizes[2])
        self.fc3_mu = torch.nn.Linear(hidden_sizes[2], hidden_sizes[3])
        self.fc3_sig = torch.nn.Linear(hidden_sizes[2], hidden_sizes[3])

        # Decoder: (latent + code) -> h2 -> h1 -> h0 -> sequence part only.
        self.fc4 = torch.nn.Linear(hidden_sizes[3] + code_size, hidden_sizes[2])
        self.BN4 = torch.nn.BatchNorm1d(hidden_sizes[2])
        self.fc5 = torch.nn.Linear(hidden_sizes[2], hidden_sizes[1])
        self.BN5 = torch.nn.BatchNorm1d(hidden_sizes[1])
        self.fc6 = torch.nn.Linear(hidden_sizes[1], hidden_sizes[0])
        self.BN6 = torch.nn.BatchNorm1d(hidden_sizes[0])
        self.fc7 = torch.nn.Linear(hidden_sizes[0], input_size - code_size)

    def sample_z(self, x_size, mu, log_var):
        """
        Return ``mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

        This is the reparameterisation trick; one independent ``eps`` row is
        drawn per batch element, so gradients flow through mu and log_var.

        NOTE(review): ``forward`` passes the softplus output ``sig`` here as
        ``log_var``. Since softplus >= 0, the implied std is always >= 1.
        It is left unchanged because altering it would presumably change the
        trained model's behaviour -- confirm before touching.
        """
        # Match mu's dtype/device so the model also works off the CPU.
        eps = torch.randn(x_size, self.hidden_sizes[-1],
                          dtype=mu.dtype, device=mu.device)
        return mu + torch.exp(log_var / 2) * eps

    @staticmethod
    def _dump_rows(path, rows):
        """Write each row of a 2-D tensor to ``path``, one row per line."""
        with open(path, "w") as f:
            f.write("\n".join(str(row) for row in rows.tolist()))
            f.write("\n")

    @staticmethod
    def _dump_batch_max(path, activations):
        """Print shape info and write each node's max over the batch to ``path``.

        The file gets the max values on one line and, on the next line, the
        batch index where each max occurred.
        """
        print(activations.shape)
        values, indices = activations.max(0)
        print(len(values))
        with open(path, "w") as f:
            f.write(str(values) + "\n")
            f.write(str(indices) + "\n")

    def forward(self, x, code, struc=None):
        """
        Encode ``x`` conditioned on ``code``, sample a latent and decode it.

        Parameters
        ----------
        x : torch.Tensor
            (batch, input_size - code_size) sequence features.
        code : torch.Tensor
            (batch, code_size) conditioning code.
        struc : optional
            Unused; kept for interface compatibility.

        Returns
        -------
        tuple of torch.Tensor
            ``(out7, mu, sig)``: the sigmoid reconstruction of the sequence
            part, the latent mean, and the softplus ``sig`` encoder output.
        """
        # ----- Encoder -----
        x = torch.cat((x, code), 1)
        out1 = nn.relu(self.BN(self.fc(x)))
        out2 = nn.relu(self.BN1(self.fc1(out1)))
        out3 = nn.relu(self.BN2(self.fc2(out2)))
        mu = self.fc3_mu(out3)
        sig = nn.softplus(self.fc3_sig(out3))

        # ----- Decoder -----
        sample = self.sample_z(x.size(0), mu, sig)
        print(x.size(0))  # batch size
        self._dump_rows("sample_decoder.txt", sample)

        # Condition the decoder on the same code as the encoder.
        sample = torch.cat((sample, code), 1)
        out4 = nn.relu(self.BN4(self.fc4(sample)))
        self._dump_batch_max("out4_nodes.txt", out4)
        out5 = nn.relu(self.BN5(self.fc5(out4)))
        self._dump_batch_max("out5_nodes.txt", out5)
        out6 = nn.relu(self.BN6(self.fc6(out5)))
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        out7 = torch.sigmoid(self.fc7(out6))
        return out7, mu, sig



# =============================================================================
# Create and Load model into memory
# =============================================================================

# The architecture must match the saved checkpoint exactly. 3088 input dims =
# 3080 sequence dims (140 * 22) + 8 metal-binding flags.
X_dim = 3088
hidden_size = [512, 256, 128, 16]

# One forward pass yields one variant per batch row, so the batch size is the
# number of output sequences requested on the command line.
batch_size = args_dict["numout"]
vae = VAE(X_dim, hidden_size, batch_size)

# Load the trained weights onto the CPU ("cpu" is equivalent to the identity
# map_location function). NOTE: torch.load unpickles the file, so only load
# checkpoints from a trusted source.
checkpoint_state = torch.load("models/metal16_nostruc", map_location="cpu")
vae.load_state_dict(checkpoint_state)

# =============================================================================
#  Define function to produce sequences. 
# =============================================================================

    
def newMetalBinder(model, data):
    """
    Generate new sequences from one input sequence in a single VAE pass.

    ``data`` is repeated ``model.batch_size`` times. Its first 3080 dims are
    the sequence (140 positions x 22 channels), and the 8 metal-binding flags
    are all set to zero. Each batch row gets its own latent sample, so every
    row decodes to a separate variant. No fold/structure input is used.

    Prints the reconstruction shape, the average sequence identity to the
    input, the raw decoder outputs and the decoded sequences. It also writes
    the tiled input to ``code_decoder.txt``.

    Parameters
    ----------
    model : VAE
        Trained model; ``model.batch_size`` is the number of variants.
    data : array-like
        Vectorised input sequence (presumably from ``utils.seq_to_vec`` --
        confirm); only its first 3080 entries are used.

    Returns
    -------
    None
    """
    n_out = model.batch_size
    len_aa = 140 * 22  # 140 positions x 22 channels = 3080 sequence dims
    model.eval()

    # All-zero metal-binding code: treat the input as a non-binder.
    code = np.zeros((n_out, 8))
    x = np.tile(data[:len_aa], (n_out, 1))
    with open("code_decoder.txt", "w") as f:
        f.write("\n".join(str(row) for row in x.tolist()) + "\n")

    X = torch.from_numpy(x).type(torch.FloatTensor)
    C = torch.from_numpy(code).type(torch.FloatTensor)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        x_sample, _, _ = model(X, C)
    out_seqs = x_sample[:, :len_aa].cpu().numpy()
    print(out_seqs.shape)

    # Reshape to (batch, positions, 22) and take the per-position argmax.
    # Use the model's own batch size rather than a module-level global.
    y_label = np.argmax(x[:, :len_aa].reshape(n_out, -1, 22), axis=2)
    y_pred = np.argmax(out_seqs.reshape(n_out, -1, 22), axis=2)
    scores = []
    for label_row, pred_row in zip(y_label, y_pred):
        # NOTE(review): np.argmax(label_row) is the first position holding the
        # largest channel index -- presumably the first padding position,
        # i.e. the sequence length. Confirm against utils.seq_to_vec.
        end = np.argmax(label_row)
        scores.append(accuracy_score(label_row[:end], pred_row[:end]))
    print("Average Sequence Identity to Input: {0:.1f}%".format(np.mean(scores)*100))

    for seq in out_seqs:
        print(seq)  # raw sigmoid decoder outputs
    for seq in out_seqs:
        print(utils.vec_to_seq(seq))



# =============================================================================
# Produce new sequence
# =============================================================================

# Read the input file (FASTA or plain text). Header lines starting with ">"
# are dropped, and the remaining lines are joined into one sequence string.
# Only a single sequence is supported, so a multi-record FASTA file would be
# merged into one string.
with open(args_dict["infile"], "r") as in_file:
    # strip() removes "\r" from Windows line endings and stray whitespace as
    # well as "\n"; replace("\n", "") alone left those inside the sequence.
    seq = [line.strip() for line in in_file]

seq_in = "".join(line for line in seq if not line.startswith(">"))

# now have a string which is the sequence
seq_in_vec = utils.seq_to_vec(seq_in)

newMetalBinder(vae, seq_in_vec)
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    


10
torch.Size([10, 128])
128
torch.Size([10, 256])
256
(10, 3080)
Average Sequence Identity to Input: 60.1%
[1.1633015e-01 6.1108196e-01 1.1680838e-02 ... 2.4804574e-16 4.8697825e-16
 1.0000000e+00]
[1.8130545e-01 6.9561148e-01 5.7744263e-03 ... 6.8005178e-17 1.5393639e-16
 1.0000000e+00]
[2.8535354e-01 3.2712936e-01 9.2944168e-03 ... 1.4238643e-15 1.8644940e-15
 1.0000000e+00]
[2.1861203e-03 7.5355351e-01 2.1442639e-02 ... 2.2769651e-19 4.2823748e-19
 1.0000000e+00]
[2.1523491e-01 5.3341520e-01 2.1010539e-02 ... 1.3831701e-17 3.2444681e-17
 1.0000000e+00]
[2.52371550e-01 6.57651424e-01 5.14393784e-02 ... 6.72852037e-18
 1.34941365e-17 1.00000000e+00]
[1.4901452e-01 5.7754391e-01 5.6306254e-03 ... 3.8031517e-16 8.9857162e-16
 1.0000000e+00]
[2.9872788e-02 9.3515557e-01 2.9262018e-03 ... 4.6611092e-22 1.6563628e-21
 1.0000000e+00]
[2.8725848e-01 8.0325373e-02 2.1072498e-03 ... 2.9573607e-15 5.4466654e-15
 1.0000000e+00]
[1.3291746e-01 7.0576584e-01 5.6505664e-03 ... 4.4419784e-16 1.0054

In [19]:
#python seq_to_seq.py -infile examples/seq2seq_example.txt -numout 10

SyntaxError: invalid syntax (<ipython-input-19-4885625ec1fb>, line 1)