In [1]:
import os
import pandas as pd
import numpy as np
os.chdir('/home/ubuntu/SaprotHub/colab')

In [2]:
df = pd.read_csv('eSol_output_SA_split.csv')

In [3]:
def extract_sequence(sequence):
    return sequence[::2]

In [4]:
df['sequence_only'] = df['sequence'].apply(extract_sequence)

In [5]:
df['index'] = [f"sample_{i+1}" for i in range(len(df))]

In [6]:
# df.to_csv('expression_data.csv', index=False)

In [7]:
# with open('exression_data.fasta', 'w') as fasta_file:
#     for index, row in df.iterrows():
#         fasta_file.write(f">{row['index']}\n")
#         fasta_file.write(f"{row['sequence_only']}\n") 

In [8]:
file = 'mmseq2_clusters/expression_si02/clusterRes_cluster.tsv'

In [9]:
def process_clusters(input_filename):
    representative_counts = {}
    # Open the TSV file and count the members for each representative
    with open(input_filename, 'r') as infile:
        # Skip the header line
        next(infile)
        for line in infile:
            cluster_representative = line.split("\t")[0].strip()
            # Increment the count for this representative or set to 1 if it’s the first occurrence
            representative_counts[cluster_representative] = representative_counts.get(cluster_representative, 0) + 1
    # Filter out representatives with only 1 member
    filtered_representatives = {rep: count for rep, count in representative_counts.items() if count > 0}
    # Print the results
    # Uncomment if you want to print individual representative counts
    # for rep, count in filtered_representatives.items():
    #     print(f”{rep}: {count} members”)
    # Print the total number of remaining representatives
    print(f"\nTotal number of representatives with more than one member for file '{input_filename}': {len(filtered_representatives)}")
    return filtered_representatives

In [10]:
reps_count_si02= process_clusters(file)


Total number of representatives with more than one member for file 'mmseq2_clusters/expression_si02/clusterRes_cluster.tsv': 2442


In [11]:
def get_pdb_clusters(input_filename):
    representative_counts = {}
    representative_files = {}  # Store the PDB files associated with each representative
    pdb_pattern = re.compile('sample')  # Regular expression pattern to match any 5-letter name before .pdb
    # Create a dictionary to store matches for each representative
    matches_dict = {}
    cluster_number = 1
    # Open the TSV file and count the members for each representative
    with open(input_filename, 'r') as infile:
        # Skip the header line
        next(infile)
        for line in infile:
            line_parts = line.split("\t")
            cluster_representative, pdb_file = line_parts[0].strip(), line_parts[1].strip()
            # Increment the count for this representative or set to 1 if it's the first occurrence
            representative_counts[cluster_representative] = representative_counts.get(cluster_representative, 0) + 1
            # Append the PDB file associated with this representative
            if cluster_representative in representative_files:
                representative_files[cluster_representative].append(pdb_file)
            else:
                representative_files[cluster_representative] = [pdb_file]
    # Code moved outside of the loop to process pdb_file and update matches_dict
    for cluster_representative, pdb_files in representative_files.items():
        matches = []
        for pdb_file in pdb_files:
            pdb_matches = pdb_pattern.findall(pdb_file)
            if len(pdb_matches) > 0:
                matches.append(pdb_matches[0])
        if len(matches) > 0:
            matches_dict[cluster_representative] = matches
        else:
            matches_dict[cluster_representative] = None
    # Filter out representatives with only 1 member
    filtered_representatives = {rep: count for rep, count in representative_counts.items() if count > 0}
    # Create a dictionary to store the results
    results_dict = {}
    # Process the results and store them in the results_dict
    for rep, count in filtered_representatives.items():
        pdb_files = representative_files[rep]
        matches = matches_dict[rep]
        # if matches is not None:
        #     print(matches)
        #     # Replace matches in each PDB file once
        #     for i in range(len(pdb_files)):
        #         pdb_files[i] = pdb_pattern.sub(lambda x: x.group() + '.pdb (rcsb structure)', pdb_files[i])
        pdb_files_str = pdb_files if pdb_files else "None"
        # Store the results in the results_dict
        results_dict[rep] = {
            "Count": count,
            "PDB Files": pdb_files_str,
            "Matches": matches,
            "Cluster": cluster_number
        }
        
        cluster_number+=1
    return results_dict

def print_results_with_pdb(results_dict):
    match_dict ={}
    for rep, result in results_dict.items():
        matches = result["Matches"]
        pdb_files_str = result["PDB Files"]
        if matches != "None":
            if matches is not None:
                matches_str = ', '.join(matches)  # Convert matches list to a comma-separated string
                # print(f"Cluster Representative: {rep}, Count: {result['Count']}, len_matches: {len(matches)}, Cluster:{result['Cluster']}")
                match_dict[rep] = {
                    "Count": result['Count'],
                    "PDB Files": pdb_files_str,
                    "rcsb_matches": matches_str,
                    "len_matches": len(matches),
                    "Cluster": result["Cluster"]  
                }
    return match_dict



In [14]:
import re
expression_si02_dict = get_pdb_clusters(file)
match_dict_si02 = print_results_with_pdb(expression_si02_dict)

In [18]:
exp_pdb_train =  [match_dict_si02[clust]["PDB Files"] for clust in match_dict_si02]

In [20]:
sum([len(a) for a in exp_pdb_train])

3150

In [24]:
train_size = 2800
val_size = 200
test_size = 150

cluster_index= 0
len_train =0
for i, cluster in enumerate(exp_pdb_train):
    if len_train < train_size:
        len_train+=len(cluster)
        cluster_index = i

train_cluster_index =  cluster_index
len_val =0 
for i, cluster in enumerate(exp_pdb_train[train_cluster_index+1:]):
    if len_val < val_size:
        len_val+=len(cluster)
        cluster_index = i+train_cluster_index

val_cluster_index = cluster_index

In [34]:
train_dataset= [item for sublist in exp_pdb_train[:train_cluster_index] for item in sublist]
val_dataset = [item for sublist in exp_pdb_train[train_cluster_index+1:val_cluster_index] for item in sublist]
test_dataset =[item for sublist in exp_pdb_train[val_cluster_index+1:] for item in sublist]

In [40]:
df_splits = pd.read_csv('expression_data.csv')

In [43]:
df_splits.loc[df['index'].isin(train_dataset), 'stage']='train'
df_splits.loc[df['index'].isin(val_dataset), 'stage']='val'
df_splits.loc[df['index'].isin(test_dataset), 'stage']='test'

In [47]:
df_splits.to_csv('eSol_output_SA_split

Unnamed: 0,sequence,label,stage,sequence_only,index
0,MdVwKkVkYwAfPaAwSkSfAfNpMlSlVfGdFhDlVfLkGiAfAg...,0.32,train,MVKVYAPASSANMSVGFDVLGAAVTPVDGALLGDVVTVEAAETFSL...,sample_1
1,MdKwLwYaNfLlKvDhHrNvEdQiVdSaFlAlQvAlVlTlQlGlLa...,0.18,train,MKLYNLKDHNEQVSFAQAVTQGLGKNQGLFFPHDLPEFSLTEIDEM...,sample_2
2,M#K#K#M#Q#S#I#V#L#A#L#S#L#V#L#V#A#P#M#A#A#Q#A#...,0.78,train,MKKMQSIVLALSLVLVAPMAAQAAEITLVPSVKLQIGDRDNRGYYW...,sample_3
3,MdLaIeLeIaSePfAaKqTaLwDdYlQpSdPdLdTlTdTpRdYaTd...,0.07,train,MLILISPAKTLDYQSPLTTTRYTLPELLDNSQQLIHEARKLTPPQI...,sample_4
4,M#TdDfKqLlTnSqLlRvQvYqTaTqVeVeAaDeTaGlDdIpAvAl...,0.85,train,MTDKLTSLRQYTTVVADTGDIAAMKLYQPQDATTNPSLILNAAQIP...,sample_5
...,...,...,...,...,...
3146,MdKeIeFaQaRaYdNdPlLlQvVvAlKvYvVlKlIvLdFqRwGdRk...,0.78,train,MKIFQRYNPLQVAKYVKILFRGRLYIKDVGAFEFDKGKILIPKVKD...,sample_3147
3147,MdRkIaFwVfYqGkS#L#R#H#K#Q#G#N#S#H#W#M#T#N#AkQa...,0.98,train,MRIFVYGSLRHKQGNSHWMTNAQLLGDFSIDNYQLYSLGHYPGAVP...,sample_3148
3148,M#V#K#K#S#EpFdEaRaGqDfIwVkLwVfGaFpDpPpAdSdGdHp...,0.94,train,MVKKSEFERGDIVLVGFDPASGHEQQGAGRPALVLSVQAFNQLGMT...,sample_3149
3149,MdNaYwHaQdYwYdPlVfDdIpVqNqGdPhGdTtRaCiTeLtFaVa...,0.69,train,MNYHQYYPVDIVNGPGTRCTLFVSGCVHECPGCYNKSTWRVNSGQP...,sample_3150


In [15]:
DEV_SIZE= 0.1
TEST_SIZE = 0.1

sum_total = sum(item['Count'] for item in expression_si02_dict.values())
# Calculate target lengths once and store them
dev_target_length = round(sum_total * DEV_SIZE * 0.9)
test_target_length = round(sum_total * DEV_SIZE * 1.1)
# Convert match_dict_si02 into a set for faster membership checking
match_dict_si02_set = set(match_dict_si02)
merged_keys = [key for key in expression_si02_dict.keys() if key not in match_dict_si02_set]
exp_pdb_train =  [match_dict_si02[clust]["PDB Files"] for clust in match_dict_si02]
exp_pdb_train = list(exp_pdb_train)
while True:
    indices = np.random.permutation(len(merged_keys))
    train_size = 1 - DEV_SIZE - TEST_SIZE
    train_idx = indices[: int(train_size * len(merged_keys))]
    dev_idx = indices[int(train_size * len(merged_keys)) : int((train_size + DEV_SIZE) * len(merged_keys))]
    test_idx = indices[int((train_size + DEV_SIZE) * len(merged_keys)) :]
    train_items = [merged_keys[i] for i in train_idx]
    dev_items = [merged_keys[i] for i in dev_idx]
    test_items = [merged_keys[i] for i in test_idx]
    train_proteins = [repres_counts_si01_PDB[clust]['PDB Files'] for clust in train_items]
    dev_proteins = [repres_counts_si01_PDB[clust]['PDB Files'] for clust in dev_items]
    test_proteins = [repres_counts_si01_PDB[clust]['PDB Files'] for clust in test_items]
    train_proteins = list(train_proteins)
    dev_proteins = list(dev_proteins)
    test_proteins = list(test_proteins)
    
    train_proteins+=exp_pdb_train
    
    # Check if the lengths of dev_proteins and test_proteins are within the target range
    if dev_target_length <= len(dev_proteins) <= test_target_length and \
       dev_target_length <= len(test_proteins) <= test_target_length:
        break  # Exit the loop if the condition is met
print('Per chain sizes:')
print("Train size: ", len(train_proteins))
print("Dev size: ", len(dev_proteins))
print("Test size: ", len(test_proteins))
print("Train size percentage: ",  len(train_proteins)*100/sum_total)
print("Dev size percentage: ", len(dev_proteins)*100/sum_total)
print("Test size percentage: ", len(test_proteins)*100/sum_total)
print("Train clusters: ", len(train_items))
print("Dev clusters: ", len(dev_items))
print("Test clusters: ", len(test_items))

KeyboardInterrupt: 