In [None]:
import math
import os
import random

from Bio import SeqIO

In [None]:
# Input file paths (relative to the working directory — presumably this notebook's folder)

# Outputs of the outlier-removal step: amino-acid sequences and per-residue disorder labels
_remove_outliers_dir = "../remove_outliers/out/"
allseq_path = _remove_outliers_dir + "allseq.fasta"
alldisorder_path = _remove_outliers_dir + "alldisorder.fasta"

# Raw output of the sequence-clustering step
cluster_path = "../cluster_seqs/out/allseq.clstr"

In [None]:
# Read the raw cluster output and collect the accession of each cluster's
# representative protein, in file order.
# In CD-HIT .clstr output the representative member's line ends with "*".
# Checking the line ending (rather than `"*" in line`) keeps a member whose
# header text merely contains "*" from being mistaken for a representative.
rep_proteins = []
with open(cluster_path) as file:
    for line in file:
        if line.rstrip().endswith("*"):
            # Header text follows ">"; the accession is its first "|"-separated field
            protein_acc = line.split("|")[0].split(">")[1]
            rep_proteins.append(protein_acc)

In [None]:
# Parse both FASTA files with Biopython and keep every record in memory as a list,
# so the records can be iterated again by later cells.
seq_fasta = [record for record in SeqIO.parse(allseq_path, 'fasta')]
disorder_fasta = [record for record in SeqIO.parse(alldisorder_path, 'fasta')]

In [None]:
# Lookup of every protein, keyed by accession (first "|"-separated header field):
#   {"Q9UPN6": ["Q9UPN6|sequence|SR-rel...", "MEAVKTFNSELYSLND...", "100000000000..."]}

# Start each entry as [full description, amino-acid sequence]
all_protein_dict = {
    record.description.split("|")[0]: [record.description, str(record.seq)]
    for record in seq_fasta
}

# Append the matching per-residue disorder string as the third field.
# An accession present only in the disorder file raises KeyError.
for record in disorder_fasta:
    accession = record.description.split("|")[0]
    all_protein_dict[accession].append(str(record.seq))

In [None]:
# Keep only the cluster representatives. Iteration follows all_protein_dict, i.e. the
# order of allseq.fasta; the seeded shuffle later depends on this order being stable.
rep_protein_set = set(rep_proteins)  # O(1) membership instead of scanning a list per key
rep_protein_dict = {
    key: info for key, info in all_protein_dict.items() if key in rep_protein_set
}

# Later cells index record[2] and unpack every record as (description, sequence, labels).
# Fail here with a clear message instead of midway through writing the output files.
incomplete_reps = [key for key, info in rep_protein_dict.items() if len(info) != 3]
if incomplete_reps:
    raise ValueError(
        f"{len(incomplete_reps)} representative protein(s) do not have exactly one "
        f"disorder entry, e.g. {incomplete_reps[:5]}"
    )

In [None]:
# rep_protein_dict maps accession -> [description, sequence, disorder labels].
# Only the values are needed for shuffling and splitting, so collect them into a list.
rep_protein_list = [*rep_protein_dict.values()]

# Display one entry as a sanity check
rep_protein_list[1]

In [None]:
# Shuffle and split into train / test / validation (roughly 80 / 10 / 10).
# A copy is shuffled with its own seeded RNG so re-running this cell always gives the
# same split. Shuffling rep_protein_list in place would make a re-run reshuffle an
# already-shuffled list and silently produce a different split.
SPLIT_SEED = 1
shuffled_proteins = list(rep_protein_list)
random.Random(SPLIT_SEED).shuffle(shuffled_proteins)

# Both sizes are rounded up, so validation takes whatever remains; it can be slightly
# under 10%, or even empty for very small inputs.
n_proteins = len(shuffled_proteins)
train_length = math.ceil(0.8 * n_proteins)
test_length = math.ceil(0.1 * n_proteins)

train = shuffled_proteins[:train_length]
test = shuffled_proteins[train_length:train_length + test_length]
validation = shuffled_proteins[train_length + test_length:]

In [None]:
# Sanity check: show the FASTA description of one validation record
sample_validation_record = validation[1]
sample_validation_record[0]

In [None]:
# Write each split's amino-acid sequences (record[1]) to its own FASTA file,
# using the full description (record[0]) as the header line.
# Files: out/val_as_fasta.fasta, out/train_as_fasta.fasta, out/test_as_fasta.fasta
os.makedirs('out', exist_ok=True)

for split_name, split_records in (("val", validation), ("train", train), ("test", test)):
    with open(f"out/{split_name}_as_fasta.fasta", "w") as fasta_file:
        for record in split_records:
            fasta_file.write(">" + record[0] + "\n" + record[1] + "\n")

In [None]:
# Write each split's binary disorder labels (record[2]) to its own FASTA file,
# using the full description (record[0]) as the header line.
# Files: out/val_labels_as_fasta.fasta, out/train_labels_as_fasta.fasta,
#        out/test_labels_as_fasta.fasta
# Each file is opened exactly once inside `with`, so no handle is left open.
os.makedirs('out', exist_ok=True)

for split_name, split_records in (("val", validation), ("train", train), ("test", test)):
    with open(f"out/{split_name}_labels_as_fasta.fasta", "w") as labels_file:
        for record in split_records:
            labels_file.write(">" + record[0] + "\n" + record[2] + "\n")

In [None]:
# Summarize each split: protein and residue counts, plus the balance of ordered ('0')
# versus disordered ('1') residues in the label strings.
for subset_name, subset in zip(['train', 'test', 'validation'], [train, test, validation]):
    residue_num = 0
    order_num = 0
    disorder_num = 0
    for _, _, labels in subset:
        residue_num += len(labels)
        order_num += labels.count('0')
        disorder_num += labels.count('1')

    # Train and test sizes are rounded up, so a split (typically validation) can be
    # empty for tiny inputs. Report NaN fractions instead of raising ZeroDivisionError.
    order_fraction = order_num / residue_num if residue_num else float('nan')
    disorder_fraction = disorder_num / residue_num if residue_num else float('nan')

    print(subset_name.upper())
    print('Number of proteins:', len(subset))
    print('Number of residues:', residue_num)
    print('Number of ordered residues:', order_num)
    print('Fraction of ordered residues:', order_fraction)
    print('Number of disordered residues:', disorder_num)
    print('Fraction of disordered residues:', disorder_fraction)
    print()
print('Subsets sum to total protein number:', sum(len(subset) for subset in (train, test, validation)) == len(rep_protein_dict))