#### Exclude the sequences outside of the range of 300–600 aa

In [2]:
from Bio import SeqIO

def filter_fasta_by_length(input_fasta, output_fasta, min_length=300, max_length=600):
    # Parse the input FASTA file and filter sequences based on length
    filtered_records = [record for record in SeqIO.parse(input_fasta, "fasta") 
                        if min_length <= len(record.seq) <= max_length]

    # Write the filtered sequences to the output FASTA file
    SeqIO.write(filtered_records, output_fasta, "fasta")
    
# Training dataset
input_fasta_file = "../data/PF01494_20201216.fasta"
output_fasta_file = "../data/PF01494_20201216_300_600aa.fasta"
filter_fasta_by_length(input_fasta_file, output_fasta_file)

# Testing dataset
input_fasta_file = "../data/PF01494_testing_dataset_20230916_markdup.fasta"
output_fasta_file = "../data/PF01494_testing_dataset_20230916_markdup_300_600aa.fasta"
filter_fasta_by_length(input_fasta_file, output_fasta_file)


#### Process the train dataset

In [4]:
import pickle
import sys
import numpy as np
from sys import exit
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# Function to split sequences into training and testing datasets
def split_fasta(input_fasta, output_fasta1, output_fasta2, split_index=39948):
    # Load all sequences from the input FASTA into a list
    all_sequences = list(SeqIO.parse(input_fasta, "fasta"))
    
    # Split the list into two parts
    train_dataset = all_sequences[:split_index]
    test_dataset = all_sequences[split_index:]

    # Write sequences to the respective output files
    SeqIO.write(train_dataset, output_fasta1, "fasta")
    SeqIO.write(test_dataset, output_fasta2, "fasta")
    
    # print(f"First part written to {output_fasta1}")
    # print(f"Second part written to {output_fasta2}")

    return train_dataset, test_dataset

# 33948 sequences in the train dataset
# 1550 sequences in the test dataset
# Specify your file names
input_fasta_file = "../data/PF01494_20201216_300_600aa_train_test_MSA.fasta"
output_fasta_file1 = "../data/PF01494_20201216_300_600aa_train_MSA.fasta"
output_fasta_file2 = "../data/PF01494_20201216_300_600aa_test_MSA.fasta"

train_dataset, test_dataset = split_fasta(input_fasta_file, output_fasta_file1, output_fasta_file2)

MSA_file = "../data/PF01494_20201216_300_600aa_train_MSA.fasta"
query_seq_id = "B8M9J8"

# Define a list of amino acid characters
AA = ['R', 'H', 'K',
      'D', 'E',
      'S', 'T', 'N', 'Q',
      'C', 'G', 'P',
      'A', 'V', 'I', 'L', 'M', 'F', 'Y', 'W']

aa_to_index = {}
aa_to_index['-'] = 0
aa_to_index['.'] = 0
index_to_aa = {}
index_to_aa[0] = '-'
for idx, aa in enumerate(AA, start=1):
    aa_to_index[aa] = idx
    index_to_aa[idx] = aa

# Save aa_to_index
with open("../data/encoding/aa_to_index.pkl", 'wb') as file_handle:
    pickle.dump(aa_to_index, file_handle)

# Save index_to_aa
with open("../data/encoding/index_to_aa.pkl", 'wb') as file_handle:
    pickle.dump(index_to_aa, file_handle)

# Read all the sequences into a dictionary
def MSA_to_dict(MSA_file):
    seq_dict = {record.id: str(record.seq).upper() for record in SeqIO.parse(MSA_file, "fasta")}
    return seq_dict

# Remove a column if the query has a gap at that position
def remove_query_gap(seq_dict, query_seq_id):

    query_seq = seq_dict[query_seq_id] ## potentially with gaps
    idx = [ s == "-" or s == "." for s in query_seq]
    for k in seq_dict.keys():
        seq_dict[k] = [seq_dict[k][i] for i in range(len(seq_dict[k])) if idx[i] == False]
    query_seq = seq_dict[query_seq_id] ## without gaps

    return seq_dict, query_seq

# Remove gappy sequences
def remove_gappy_sequneces(seq_dict, query_seq_id):

    len_query_seq = len(query_seq)
    deleted_id_list = []
    deletion_reason = []
    for k in list(seq_dict.keys()):
        if seq_dict[k].count('X') > 0 or seq_dict[k].count('Z') > 0 or seq_dict[k].count('O') > 0:
            seq_dict.pop(k)
            deleted_id_list.append(k)
            deletion_reason.append('XZO')

        elif seq_dict[k].count("-") + seq_dict[k].count(".") > 0.2*len_query_seq:
            seq_dict.pop(k)
            deleted_id_list.append(k)
            deletion_reason.append('gappy')

    return seq_dict, deleted_id_list, deletion_reason

seq_dict = MSA_to_dict(MSA_file)
MSA_query = seq_dict[query_seq_id]

# Save MSA query seq
with open("../data/encoding/MSA_query.pkl", 'wb') as file_handle:
    pickle.dump(MSA_query, file_handle)

seq_dict, query_seq = remove_query_gap(seq_dict, query_seq_id)

# Save nogap query seq
with open("../data/encoding/nogap_query.pkl", 'wb') as file_handle:
    pickle.dump(query_seq, file_handle)

seq_dict, deleted_id_list, deletion_reason = remove_gappy_sequneces(seq_dict, query_seq_id)

# Save deleted_id_list
with open("../data/encoding/deleted_id_list.pkl", 'wb') as file_handle:
    pickle.dump(deleted_id_list, file_handle)

id_list = []
seq_msa = []
for k in seq_dict.keys():
    id_list.append(k)
    seq_msa.append([aa_to_index[s] for s in seq_dict[k]])
seq_msa = np.array(seq_msa)

# Remove positions where too many sequences have gaps
pos_idx = []
for i in range(seq_msa.shape[1]):
    if np.sum(seq_msa[:,i] == 0) <= seq_msa.shape[0]*0.2:
        pos_idx.append(i)

with open("../data/encoding/seq_pos_idx.pkl", 'wb') as file_handle:
    pickle.dump(pos_idx, file_handle)

seq_msa = seq_msa[:, np.array(pos_idx)]

# Reweight sequences
seq_weight = np.zeros(seq_msa.shape)
for j in range(seq_msa.shape[1]):
    aa_type, aa_counts = np.unique(seq_msa[:,j], return_counts = True)
    num_type = len(aa_type)
    aa_dict = {}
    for a in aa_type:
        aa_dict[a] = aa_counts[list(aa_type).index(a)]
    for i in range(seq_msa.shape[0]):
        seq_weight[i,j] = (1.0/num_type) * (1.0/aa_dict[seq_msa[i,j]])
tot_weight = np.sum(seq_weight)
seq_weight = seq_weight.sum(1) / tot_weight

# Save seq_weight
with open("../data/encoding/seq_weight.pkl", 'wb') as file_handle:
    pickle.dump(seq_weight, file_handle)

# Save seq_msa (one-hot encoded version)
with open("../data/encoding/seq_msa.pkl", 'wb') as file_handle:
    pickle.dump(seq_msa, file_handle)

# Decode idx back to aa
seq_msa_aa = []
for k in range(seq_msa.shape[0]):
    decoded_seq = "".join(index_to_aa[idx] for idx in seq_msa[k])
    seq_msa_aa.append(decoded_seq)

seq_dict_truncated = {id: seq for id, seq in zip(id_list, seq_msa_aa)}

# Save seq_dict_truncated
with open("../data/encoding/seq_dict_truncated.pkl", 'wb') as file_handle:
     pickle.dump(seq_dict_truncated, file_handle)

# Save the position number of the residues that are kept
template_pos_idx = []
for i, res in enumerate(MSA_query):
    # If the sequence is empty, break the loop
    if not seq_dict_truncated[query_seq_id]:
        break
    if res == seq_dict_truncated[query_seq_id][0]:
        template_pos_idx.append(i+1)
        seq_dict_truncated[query_seq_id] = seq_dict_truncated[query_seq_id][1:]

# Save template_pos_idx
with open("../data/encoding/template_pos_idx.pkl", 'wb') as file_handle:
     pickle.dump(template_pos_idx, file_handle)

# Save keys_list
with open("../data/encoding/keys_list.pkl", 'wb') as file_handle:
    pickle.dump(id_list, file_handle)

# Save seq_dict_truncated
with open("../data/encoding/seq_dict_truncated.pkl", 'wb') as file_handle:
    pickle.dump(seq_dict_truncated, file_handle)

# Change aa numbering into binary
K = 21 ## num of classes of aa
D = np.identity(K)
num_seq = seq_msa.shape[0]
len_seq_msa = seq_msa.shape[1]
seq_msa_binary = np.zeros((num_seq, len_seq_msa, K))
for i in range(num_seq):
    seq_msa_binary[i,:,:] = D[seq_msa[i]]

# Save seq_msa_binary
with open("../data/encoding/seq_msa_binary.pkl", 'wb') as file_handle:
    pickle.dump(seq_msa_binary, file_handle)

# Print train_seq_msa_binary shape
print('Train binary shape: ', seq_msa_binary.shape)

Train binary shape:  (33937, 353, 21)


#### Process the test dataset

In [5]:
test_MSA_file = "../data/PF01494_20201216_300_600aa_test_MSA.fasta"
test_seq_dict = MSA_to_dict(test_MSA_file)
query_seq_id = 'B8M9J8'

with open('../data/encoding/aa_to_index.pkl', 'rb') as file:
    aa_to_index = pickle.load(file)

with open('../data/encoding/index_to_aa.pkl', 'rb') as file:
    index_to_aa = pickle.load(file)

with open('../data/encoding/MSA_query.pkl', 'rb') as file:
    MSA_query = pickle.load(file)

with open('../data/encoding/nogap_query.pkl', 'rb') as file:
    query_seq = pickle.load(file)

with open('../data/encoding/seq_pos_idx.pkl', 'rb') as file:
    pos_idx = pickle.load(file)

# Remove a column if the query has a gap at that position
def remove_test_query_gap(seq_dict, MSA_query):

    idx = [ s == "-" or s == "." for s in MSA_query]
    for k in seq_dict.keys():
        seq_dict[k] = [seq_dict[k][i] for i in range(len(seq_dict[k])) if idx[i] == False]

    return seq_dict

# Remove gappy sequences
def remove_test_gappy_sequneces(seq_dict, query_seq):

    len_query_seq = len(query_seq)
    deleted_id_list = []
    deletion_reason = []
    for k in list(seq_dict.keys()):
        if seq_dict[k].count('X') > 0 or seq_dict[k].count('Z') > 0 or seq_dict[k].count('O') > 0:
            seq_dict.pop(k)
            deleted_id_list.append(k)
            deletion_reason.append('XZO')

        elif seq_dict[k].count("-") + seq_dict[k].count(".") > 0.2*len_query_seq:
            seq_dict.pop(k)
            deleted_id_list.append(k)
            deletion_reason.append('gappy')

    return seq_dict, deleted_id_list, deletion_reason

test_seq_dict = remove_test_query_gap(test_seq_dict, MSA_query)
test_seq_dict, test_deleted_id_list, test_deletion_reason = remove_gappy_sequneces(test_seq_dict, query_seq_id)

# Save deleted_id_list
with open("../data/encoding/test_deleted_id_list.pkl", 'wb') as file_handle:
    pickle.dump(test_deleted_id_list, file_handle)

test_id_list = []
test_seq_msa = []
for k in test_seq_dict.keys():
    test_id_list.append(k)
    test_seq_msa.append([aa_to_index[s] for s in test_seq_dict[k]])
test_seq_msa = np.array(test_seq_msa)

# Remove positions where too many sequences have gaps
test_seq_msa = test_seq_msa[:, np.array(pos_idx)]

# Assuming seq_msa_test represents your test sequences
num_test_samples = test_seq_msa.shape[0]
# Create an array of ones with a length equal to the number of test sequences
test_seq_weight = np.ones(num_test_samples)

# Save seq_weight
with open("../data/encoding/test_seq_weight.pkl", 'wb') as file_handle:
    pickle.dump(test_seq_weight, file_handle)

# Save seq_msa (one-hot encoded version)
with open("../data/encoding/test_seq_msa.pkl", 'wb') as file_handle:
    pickle.dump(test_seq_msa, file_handle)

# Decode idx back to aa
test_seq_msa_aa = []
for k in range(test_seq_msa.shape[0]):
    decoded_seq = "".join(index_to_aa[idx] for idx in test_seq_msa[k])
    test_seq_msa_aa.append(decoded_seq)

test_seq_dict_truncated = {id: seq for id, seq in zip(id_list, test_seq_msa_aa)}

# Save seq_dict_truncated
with open("../data/encoding/test_seq_dict_truncated.pkl", 'wb') as file_handle:
     pickle.dump(test_seq_dict_truncated, file_handle)

# Save keys_list
with open("../data/encoding/test_keys_list.pkl", 'wb') as file_handle:
    pickle.dump(test_id_list, file_handle)

# Change aa numbering into binary
K = 21 ## num of classes of aa
D = np.identity(K)
num_seq = test_seq_msa.shape[0]
len_seq_msa = test_seq_msa.shape[1]
test_seq_msa_binary = np.zeros((num_seq, len_seq_msa, K))
for i in range(num_seq):
    test_seq_msa_binary[i,:,:] = D[test_seq_msa[i]]

# Save seq_msa_binary
with open("../data/encoding/test_seq_msa_binary.pkl", 'wb') as file_handle:
    pickle.dump(test_seq_msa_binary, file_handle)

# Print test_seq_msa_binary shape
print('Test binary shape: ', test_seq_msa_binary.shape)

Test binary shape:  (1389, 353, 21)
