#### Encode the data

In [2]:
import pickle
import sys
import numpy as np
from sys import exit
from Bio.Seq import Seq
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord

file_name = "../data/PF01494_MSA.fasta"
query_seq_id = "B8M9J8"  # TropB

# Load the whole alignment, preserving file order.
#   id_list / seq_list : parallel lists, one entry per record (duplicates kept)
#   seq_dict           : ID -> upper-cased aligned sequence (a repeated ID
#                        keeps the sequence of its last occurrence)
records = list(SeqIO.parse(file_name, "fasta"))
id_list = [rec.id for rec in records]
seq_list = [str(rec.seq.upper()) for rec in records]
seq_dict = dict(zip(id_list, seq_list))

# Write the upper-cased alignment back out, each record with an empty
# description.
FMO = [
    SeqRecord(Seq(aligned), id=rec_id, description="")
    for rec_id, aligned in zip(id_list, seq_list)
]
with open("../data/PF01494_MSA_nodes.fasta", "w") as handle:
    SeqIO.write(FMO, handle, "fasta")
    
# Remove gaps in the query sequence: drop every alignment column in which the
# query (TropB) has '-' or '.', so all sequences share the query's positions.
# NOTE: after this step each seq_dict value is a list of single characters
# (not a str); that list form is what gets pickled below.
gap_chars = ("-", ".")
query_seq = seq_dict[query_seq_id]  # with gaps
aln_len = len(query_seq)
keep_cols = [i for i, s in enumerate(query_seq) if s not in gap_chars]
for k in seq_dict:
    seq = seq_dict[k]
    # A well-formed MSA has equal-length rows; fail loudly instead of
    # silently truncating or raising an opaque IndexError.
    if len(seq) != aln_len:
        raise ValueError(
            f"sequence {k!r} has aligned length {len(seq)}, "
            f"expected {aln_len} (length of query {query_seq_id!r})"
        )
    seq_dict[k] = [seq[i] for i in keep_cols]
query_seq = seq_dict[query_seq_id]  # without gaps

# Remove sequences with too many gaps: more than 20% of the query length.
# num_gaps keeps the gap count of every sequence (removed ones included), in
# the original key order.
max_gap_frac = 0.2
len_query_seq = len(query_seq)
seq_id = list(seq_dict.keys())  # snapshot, since entries are popped below
num_gaps = []
for k in seq_id:
    n_gap = sum(seq_dict[k].count(c) for c in gap_chars)
    num_gaps.append(n_gap)
    if n_gap > max_gap_frac * len_query_seq:
        seq_dict.pop(k)

with open("../data/encoding/seq_dict.pkl", 'wb') as file_handle:
    pickle.dump(seq_dict, file_handle)

# Integer code for each residue: both gap symbols share code 0, and the 20
# amino-acid letters below get codes 1..20 in the order they are listed.
aa = ['R', 'H', 'K',
      'D', 'E',
      'S', 'T', 'N', 'Q',
      'C', 'G', 'P',
      'A', 'V', 'I', 'L', 'M', 'F', 'Y', 'W']
aa_index = dict.fromkeys(('-', '.'), 0)
aa_index.update({residue: code for code, residue in enumerate(aa, start=1)})
with open("../data/encoding/aa_index.pkl", 'wb') as file_handle:
    pickle.dump(aa_index, file_handle)
    
# Encode every remaining sequence as integer residue codes via aa_index.
# A sequence containing any symbol without a code (anything other than the 20
# listed amino-acid letters and the two gap symbols, e.g. 'X', 'Z', 'O') is
# skipped, so an unexpected letter can never raise a KeyError here.
seq_msa = []
keys_list = []
for k, seq in seq_dict.items():
    if not set(seq).issubset(aa_index):
        continue
    seq_msa.append([aa_index[s] for s in seq])
    keys_list.append(k)
seq_msa = np.array(seq_msa)

# Split by position in keys_list (FASTA order, no shuffling): the first
# n_train encoded sequences form the training set, the rest the test set.
# NOTE(review): 33972 is a fixed count from the original analysis -- confirm
# how it was chosen before re-running on a different alignment.
n_train = 33972
training_keys_list = keys_list[:n_train]
testing_keys_list = keys_list[n_train:]

# Training keys
with open("../data/encoding/keys_list.pkl", 'wb') as file_handle:
    pickle.dump(training_keys_list, file_handle)

# Testing keys
with open("../data/encoding/t_keys_list.pkl", 'wb') as file_handle:
    pickle.dump(testing_keys_list, file_handle)

# Split the training and testing datasets (rows line up with the key lists)
training_seq_msa = seq_msa[:n_train, :]
testing_seq_msa = seq_msa[n_train:, :]

# Reweight sequences.
# note: only reweighted sequences are used for training.
# For each column j, training sequence i with residue code c at j receives
#     w[i, j] = (1 / n_types_j) * (1 / count_j(c))
# where n_types_j is the number of distinct codes in column j of the training
# set and count_j(c) is how many training sequences carry c there. A
# sequence's weight is the sum of its per-column weights, normalised so that
# all training weights sum to 1.
# training_seq_msa.shape[0]: number of training sequences
# training_seq_msa.shape[1]: sequence length
seq_weight = np.zeros(training_seq_msa.shape)
for j in range(training_seq_msa.shape[1]):
    # inverse[i] is the index of training_seq_msa[i, j] within aa_type, so
    # aa_counts[inverse] gives every sequence's column count in one lookup.
    aa_type, inverse, aa_counts = np.unique(
        training_seq_msa[:, j], return_inverse=True, return_counts=True
    )
    num_type = len(aa_type)
    seq_weight[:, j] = (1.0 / num_type) * (1.0 / aa_counts[inverse])
tot_weight = np.sum(seq_weight)
seq_weight = seq_weight.sum(1) / tot_weight
with open("../data/encoding/seq_weight.pkl", 'wb') as file_handle:
    pickle.dump(seq_weight, file_handle)

# Testing sequence weight: uniform (1 / number of test sequences), float32.
# NOTE(review): training weights above stay float64 -- confirm the consumer
# expects the differing dtypes.
t_seq_weight = np.ones(testing_seq_msa.shape[0]) / testing_seq_msa.shape[0]
t_seq_weight = t_seq_weight.astype(np.float32)
with open("../data/encoding/t_seq_weight.pkl", 'wb') as file_handle:
    pickle.dump(t_seq_weight, file_handle)

# Remove positions where too many sequences have gaps: keep a column only if
# at most 20% of sequences have code 0 (gap) there.
# NOTE(review): the gap fraction is computed over seq_msa, i.e. training AND
# testing sequences together, so the held-out set influences which columns
# are kept. Using training_seq_msa instead would avoid that but may change the
# retained positions -- confirm which is intended.
max_pos_gap_frac = 0.2
gap_counts = np.sum(seq_msa[:, :training_seq_msa.shape[1]] == 0, axis=0)
pos_idx = np.flatnonzero(
    gap_counts <= seq_msa.shape[0] * max_pos_gap_frac
).tolist()  # plain Python ints, as before
with open("../data/encoding/seq_pos_idx.pkl", 'wb') as file_handle:
    pickle.dump(pos_idx, file_handle)

# Explicit integer dtype so an empty pos_idx still indexes correctly
# (np.array([]) is float64 and cannot be used as an index).
pos_arr = np.asarray(pos_idx, dtype=int)
training_seq_msa = training_seq_msa[:, pos_arr]
testing_seq_msa = testing_seq_msa[:, pos_arr]
with open("../data/encoding/seq_msa.pkl", 'wb') as file_handle:
    pickle.dump(training_seq_msa, file_handle)
with open("../data/encoding/t_seq_msa.pkl", 'wb') as file_handle:
    pickle.dump(testing_seq_msa, file_handle)

# One-hot ("binary") encode the integer residue codes: code c becomes row c of
# the K x K identity matrix, so each (num_seq, len_seq_msa) integer array turns
# into a (num_seq, len_seq_msa, K) float array.
K = 21  # num of classes of aa: gap code 0 plus amino-acid codes 1..20
D = np.identity(K)
training_num_seq, len_seq_msa = training_seq_msa.shape
testing_num_seq = testing_seq_msa.shape[0]

# Integer-array indexing picks one identity row per residue in a single step.
training_seq_msa_binary = D[training_seq_msa]
testing_seq_msa_binary = D[testing_seq_msa]

# Training binary
with open("../data/encoding/seq_msa_binary.pkl", 'wb') as file_handle:
    pickle.dump(training_seq_msa_binary, file_handle)
print('Training binary shape: ', training_seq_msa_binary.shape)

# Testing binary
with open("../data/encoding/t_seq_msa_binary.pkl", 'wb') as file_handle:
    pickle.dump(testing_seq_msa_binary, file_handle)
print('Testing binary shape: ', testing_seq_msa_binary.shape)

Training binary shape:  (33972, 352, 21)
Testing binary shape:  (1091, 352, 21)
