In [None]:
from glob import glob
import math
import random
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import sys
import os

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Make the parent of the current working directory importable so that
# packages located there can be imported by later cells.
# The membership check keeps the cell idempotent: re-running it does not
# pile up duplicate entries in sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
# Load the table of sampled sequences and their per-sequence metrics.
# NOTE: read_pickle unpickles arbitrary Python objects; only load files that
# come from a trusted source.
df = pd.read_pickle("../saved/sampled_proteins.pkl")
# Show the first rows as a quick sanity check of the columns.
df.head()

Unnamed: 0,filename,rmsd,tm,ptm,len,recovery,seq
0,3p5pa02_seq30,2.260551,0.879183,0.79,202,0.30198,SNDDRKTAEIAQLKAKDADGDDHPYPPAKLQKVLNDQLANGAWKVA...
1,3p5pa02_seq40,1.364573,0.945448,0.72,202,0.336634,LLSPSQVAHVARLQATNARGEQEPRFKAALEVCRGSQLDNGSWGEA...
2,3p5pa02_seq26,1.723687,0.921527,0.84,202,0.351485,ARDAFGTAQVAQLKEPNPLGRLIPLYEGYLNYVSNHQLEDQSWGIE...
3,3p5pa02_seq62,1.414773,0.946349,0.84,202,0.316832,ENDHHEISHIALLQQTDIFGLKSPMFPSALTIVTEHREHNELWGVE...
4,3p5pa02_seq64,1.576217,0.936109,0.87,202,0.30198,KFDHYAVAAVATLYCRTCHERLAPAFPDALHFVQNHQLPNGRWGVK...


In [4]:
# Retain only the rows whose tm-score and ptm are both strictly above 0.8
passes_tm = df['tm'] > 0.8
passes_ptm = df['ptm'] > 0.8
df = df[passes_tm & passes_ptm]
df

Unnamed: 0,filename,rmsd,tm,ptm,len,recovery,seq
2,3p5pa02_seq26,1.723687,0.921527,0.84,202,0.351485,ARDAFGTAQVAQLKEPNPLGRLIPLYEGYLNYVSNHQLEDQSWGIE...
3,3p5pa02_seq62,1.414773,0.946349,0.84,202,0.316832,ENDHHEISHIALLQQTDIFGLKSPMFPSALTIVTEHREHNELWGVE...
4,3p5pa02_seq64,1.576217,0.936109,0.87,202,0.301980,KFDHYAVAAVATLYCRTCHERLAPAFPDALHFVQNHQLPNGRWGVK...
5,3p5pa02_seq12,1.284573,0.951054,0.86,202,0.321782,GDNYFTMAYVAQLTRTDANGNEAPDFPAALKYVKNNQLSDGSYGTA...
6,3p5pa02_seq23,1.536744,0.935105,0.81,202,0.400990,EHNFNLRADVARLQETEPQGGWRPRWPVNGSYVQNHQHADGSLGIE...
...,...,...,...,...,...,...,...
28308,3kt4a02_seq150,1.682975,0.961751,0.82,316,0.303797,EMGLHFGADHYPQDRVSPLGANVLSACRSMLEXXXXXXXXXXXXXX...
28309,3kt4a02_seq152,1.397122,0.964856,0.81,316,0.303797,NDSHYEFKYMYPKDVMDSLDMSTKLQNYQLLLXXXXXXXXXXXXXX...
28314,3kt4a02_seq157,1.486581,0.966366,0.81,316,0.281646,MRLIFLWKAVYPFCETDPLPERIINKMRTRLRXXXXXXXXXXXXXX...
28317,3kt4a02_seq158,1.550105,0.952713,0.82,316,0.278481,ALNRLFKYFVYPPNEVILLDDDTKKRTLERLNXXXXXXXXXXXXXX...


In [5]:
# Tally how many retained rows each domain has.
# The domain is the part of the filename before the first underscore.
domain_count = {}
for fname in df['filename']:
    key = fname.split("_")[0]
    domain_count[key] = domain_count.get(key, 0) + 1

# Number of domains that have at least 16 retained rows
sum(1 for n_rows in domain_count.values() if n_rows >= 16)

90

# Build the dataset: select a diverse 16-sequence subset for each domain

In [6]:
# Filtering left gaps in the index labels; renumber the rows 0..n-1
# (drop=True discards the old labels instead of keeping them as a column)
# so that label-based and positional access agree again.
df = df.reset_index(drop=True)
df.head()

Unnamed: 0,filename,rmsd,tm,ptm,len,recovery,seq
0,3p5pa02_seq26,1.723687,0.921527,0.84,202,0.351485,ARDAFGTAQVAQLKEPNPLGRLIPLYEGYLNYVSNHQLEDQSWGIE...
1,3p5pa02_seq62,1.414773,0.946349,0.84,202,0.316832,ENDHHEISHIALLQQTDIFGLKSPMFPSALTIVTEHREHNELWGVE...
2,3p5pa02_seq64,1.576217,0.936109,0.87,202,0.30198,KFDHYAVAAVATLYCRTCHERLAPAFPDALHFVQNHQLPNGRWGVK...
3,3p5pa02_seq12,1.284573,0.951054,0.86,202,0.321782,GDNYFTMAYVAQLTRTDANGNEAPDFPAALKYVKNNQLSDGSYGTA...
4,3p5pa02_seq23,1.536744,0.935105,0.81,202,0.40099,EHNFNLRADVARLQETEPQGGWRPRWPVNGSYVQNHQHADGSLGIE...


In [7]:
# Group the sequences by domain (the filename prefix before the first "_").
# The two columns are iterated in lockstep rather than looked up by integer
# label (df['filename'][i]): label lookups only work when the index is
# exactly 0..n-1 and are slow when done element by element.
domain_to_seqs = {}
for filename, seq in zip(df['filename'], df['seq']):
    name = filename.split("_")[0]
    domain_to_seqs.setdefault(name, []).append(seq)

In [None]:
from metrics.diversity import compute_diversity

In [None]:
# Tunables for the subset search.
SUBSET_SIZE = 16   # number of sequences kept per domain
N_STARTS = 160     # maximum number of greedy restarts per domain
SEED = 42          # seed for choosing the restart sequences (re-seeded per domain)


def _greedy_diverse_subset(start_seq, candidates, subset_size, diversity_fn):
    """Greedily grow a subset of `candidates` starting from `start_seq`.

    At every step the candidate whose addition gives the highest
    `diversity_fn(selected + [candidate])` is appended, until the subset
    holds `subset_size` sequences.

    Args:
        start_seq: Sequence used as the first member of the subset.
        candidates: Pool of sequences to choose from; entries equal to
            `start_seq` are excluded from the pool.
        subset_size: Target number of sequences in the subset.
        diversity_fn: Callable taking a list of sequences and returning a
            score that can be compared with `>`.

    Returns:
        `(subset, score)` where `score` is `diversity_fn(subset)`, or
        `(None, None)` if a subset of the requested size could not be built
        (pool too small, or no candidate scored above -1.0).
    """
    selected = [start_seq]
    remaining = [s for s in candidates if s != start_seq]

    while len(selected) < subset_size and remaining:
        best_cand = None
        best_div_after_add = -1.0
        for seq in remaining:
            div = diversity_fn(selected + [seq])
            if div > best_div_after_add:
                best_div_after_add = div
                best_cand = seq

        # No candidate improved on the sentinel score: give up on this start.
        if best_cand is None:
            break

        selected.append(best_cand)
        remaining.remove(best_cand)

    if len(selected) != subset_size:
        return None, None
    return selected, diversity_fn(selected)


# Maps each domain to the highest-scoring subset found for it.
domain_to_highest_diversity_seqs = {}

for domain, sampled_seqs in tqdm(domain_to_seqs.items(), desc="Processing Files", leave=True):
    # Drop repeated sequences (keeping first-occurrence order) so that the same
    # sequence can neither be picked twice nor be used as a start more than once.
    unique_seqs = list(dict.fromkeys(sampled_seqs))
    if len(unique_seqs) < SUBSET_SIZE:
        continue

    highest_diversity = -1.0   # best score seen so far for this domain
    best_combination = None    # subset that achieved it

    # Fresh RNG per domain, so each domain's starts do not depend on the
    # order in which domains are processed.
    rng = random.Random(SEED)
    start_candidates = rng.sample(unique_seqs, k=min(N_STARTS, len(unique_seqs)))

    for start_seq in start_candidates:
        selected, total_div = _greedy_diverse_subset(
            start_seq, unique_seqs, SUBSET_SIZE, compute_diversity
        )
        # Only complete subsets compete for the best score.
        if selected is not None and total_div > highest_diversity:
            highest_diversity = total_div
            best_combination = selected

    # Never store None: a domain without a complete subset is left out so that
    # the saved mapping only contains usable lists of sequences.
    if best_combination is not None:
        domain_to_highest_diversity_seqs[domain] = best_combination


Processing Files: 100%|██████████| 100/100 [4:32:14<00:00, 163.34s/it] 


In [10]:
# Preview the selected sequences of the first three domains
dict(itertools.islice(domain_to_highest_diversity_seqs.items(), 3))

{'3p5pa02': ['YRHWGATASVALLERLWNTGNRMPQYPSALSWVAATMKKDGSWGIWTKPYLWSRLRDTLSAAIALAKWQCAATETQKAIHWVNSEIQRMHSVSKRQPWFKVYLPKLFTEAVGLGLELPYDDAYIQNIIHERAEGLKERQENPEGLTQDQQPGIDGYTDVIDWKNVLRPRKPDGSFDPGLGPTATIYASTEDPRAGEYIKRHL',
  'KPDMYQQSEVAKLVRDDEPGSAAPLYESALAYVAKNQLDDGTWGDKENPQLSKQLKATSASILALKNFNTNEINLELGEKYFNKELHALNEDTPLIPCMDLLWNFLCDQLLKASIKFPYDLPFVKRLQEQMSTTLRAVASCESGMPEQMLSARGCMYYVLNLDKIQCFRESNGSYKGDSGATAHLLKMTSDEKAMEWLASKL',
  'SMDPFSKAYIALLVKQNSKGLLEPLFPGEVEWVKDHVQDNGYWGKSAQPNLDCEYQATTAGAIYLKEWGTGKEAIARNTRFLNEQMQKLTEQHVQVQMFRQDFPELLIKYQALDLELDYEIPYIAVMLHDRESFLQSIAQNLIGLPPRLLESFYGFTPCVDMNRVREYKEDNGSLCGMLANTAAYLAATNEWDALTYIREML',
  'CYSPVGLRFIAEMTKTTANGEDTPAFPSALLAVMNSTLVDGSYGKARNRKNWQRLLFSLRSIVALKKHHAGDKEVCRATKFINGTLDKLDRASSTIPGFEVLFPAVVISAQAKGITLPWDHKIMQTLIELRDQVLEGVANNADPLPEALVPQLAGLLGVADTQKIKKYQDADGSYDGSYDQTSAVLALTGDRKCGKYIKGEI',
  'ARDAFGTAQVAQLKEPNPLGRLIPLYEGYLNYVSNHQLEDQSWGIEDKYELDNRSASTLQCVIALVTWDSDTDNVSSGVTFINNNHDKEHDERVHRPDFNIYYPASWMVAADWGIQIPKDRPVIQAHIAEREAELQMVRNTPSGLPLNLLASIY

In [11]:
# Persist the per-domain selection as a pickle file
output_path = "../saved/RemoteFoldSet_domain_to_seqs.pkl"
with open(output_path, "wb") as pkl_file:
    pickle.dump(domain_to_highest_diversity_seqs, pkl_file)