In [1]:
import pandas as pd
import numpy as np
import os
from datasets import load_dataset

In [2]:
df = load_dataset(
    "proteinea/ppb_affinity",
    "raw",
    trust_remote_code=True,
    split="train",
).to_pandas()

In [3]:
# these columns are created for processing purposes, they will be removed later
df["ligand"] = df["Ligand Sequences"].str.replace(",", "")
df["receptor"] = df["Receptor Sequences"].str.replace(",", "")
df["combined"] = df["ligand"] + df["receptor"]

In [4]:
df.head()

Unnamed: 0,Source Data Set,Complex ID,PDB,Mutations,Ligand Chains,Receptor Chains,Ligand Name,Receptor Name,KD(M),Affinity Method,...,PDB PubMed ID,PDB Release Date,Affinity PubMed ID,Affinity Release Date,Subgroup,Ligand Sequences,Receptor Sequences,ligand,receptor,combined
0,SKEMPI v2.0,"1A22:A, B::PMID=7504735",1A22,,A,B,Human growth hormone,hGH binding protein,9e-10,SPR,...,9571026.0,1998-04-29,7504735,1993 Dec 5,,FPTIPLSRLFDNAMLRAHRLHQLAFDTYQEFEEAYIPKEQKYSFLQ...,PKFTKCRSPERETFSCHWTDEVHHGTKNLGPIQLFYTRRNTQEWTQ...,FPTIPLSRLFDNAMLRAHRLHQLAFDTYQEFEEAYIPKEQKYSFLQ...,PKFTKCRSPERETFSCHWTDEVHHGTKNLGPIQLFYTRRNTQEWTQ...,FPTIPLSRLFDNAMLRAHRLHQLAFDTYQEFEEAYIPKEQKYSFLQ...
1,SKEMPI v2.0,"1A4Y:A, B::PMID=9050852",1A4Y,,A,B,Ribonuclease inhibitor,Angiogenin,5e-16,Other,...,9311977.0,1998-10-14,9050852,1997 Mar 4,,SLDIQSLDIQCEELSDARWAELLPLLQQCQVVRLDDCGLTEARCKD...,QDNSRYTHFLTQHYDAKPQGRDDRYCESIMRRRGLTSPCKDINTFI...,SLDIQSLDIQCEELSDARWAELLPLLQQCQVVRLDDCGLTEARCKD...,QDNSRYTHFLTQHYDAKPQGRDDRYCESIMRRRGLTSPCKDINTFI...,SLDIQSLDIQCEELSDARWAELLPLLQQCQVVRLDDCGLTEARCKD...
2,SKEMPI v2.0,"1ACB:E, I::PMID=9048543",1ACB,,E,I,Bovine alpha-chymotrypsin,Eglin c,1.49e-12,IASP,...,1583684.0,1993-10-31,9048543,1997 Feb 18,,CGVPAIQPVLSGLSRIVNGEEAVPGSWPWQVSLQDKTGFHFCGGSL...,KSFPEVVGKTVDQAREYFTLHYPQYDVYFLPEGSPVTLDLRYNRVR...,CGVPAIQPVLSGLSRIVNGEEAVPGSWPWQVSLQDKTGFHFCGGSL...,KSFPEVVGKTVDQAREYFTLHYPQYDVYFLPEGSPVTLDLRYNRVR...,CGVPAIQPVLSGLSRIVNGEEAVPGSWPWQVSLQDKTGFHFCGGSL...
3,SKEMPI v2.0,"1AHW:A, B, C::PMID=9480775",1AHW,,"A, B",C,Immunoglobulin fab 5G9,Tissue factor,3.4e-09,IASP,...,9480775.0,1998-02-25,9480775,1998 Feb 6,,DIKMTQSPSSMYASLGERVTITCKASQDIRKYLNWYQQKPWKSPKT...,TNTVAAYNLTWKSTNFKTILEWEPKPVNQVYTVQISTKSGDWKSKC...,DIKMTQSPSSMYASLGERVTITCKASQDIRKYLNWYQQKPWKSPKT...,TNTVAAYNLTWKSTNFKTILEWEPKPVNQVYTVQISTKSGDWKSKC...,DIKMTQSPSSMYASLGERVTITCKASQDIRKYLNWYQQKPWKSPKT...
4,SKEMPI v2.0,"1AK4:A, D::PMID=9223641",1AK4,,A,D,Cyclophilin A,HIV-1 capsid protein,1.2e-05,SPR,...,8980234.0,1997-10-15,9223641,1997 Jun 27,,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...,PIVQNLQGQMVHQAISPRTLNAWVKVVEEKAFSPEVIPMFSALSEG...,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...,PIVQNLQGQMVHQAISPRTLNAWVKVVEEKAFSPEVIPMFSALSEG...,MVNPTVFFDIAVDGEPLGRVSFELFADKVPKTAENFRALSTGEKGF...


In [5]:
df["KD(M)"] = pd.to_numeric(df["KD(M)"])

In [6]:
# filter out rows where any ligand or receptor chain is shorter than 40
# amino acids
minlen = 40
f = lambda seqs: any([len(seq) < minlen for seq in seqs])
min_len_mask = df["Ligand Sequences"].str.split(",").apply(f) | df[
    "Receptor Sequences"
].str.split(",").apply(f)
df = df[~min_len_mask].reset_index(drop=True)
print(
    f"Dropped {min_len_mask.sum()} rows with ligand or receptor chain(s) "
    f"shorter than {minlen}"
)
print(f"Remaining rows: {len(df)}")

Dropped 2753 rows with ligand or receptor chain(s) shorter than 40
Remaining rows: 9295


In [None]:
# replace empty strings in "Affinity Method" with "Unknown"
df["Affinity Method"] = df["Affinity Method"].apply(
    lambda x: x if x != "" else "Unknown"
)

In [8]:
# check affinity method distribution in the dataset, we will use the Affinity
# Method for handling duplicated sequences
df["Affinity Method"].value_counts(dropna=False)

Affinity Method
Unknown    2931
SPR        2418
FL         1090
Other       786
ITC         543
IASP        461
SP          336
RA          320
ELISA       172
IARA        125
BLI         104
IAGE          9
Name: count, dtype: int64

In [9]:
def filter_duplicates_by_method(
    df,
    seq_col="combined",
    kd_col="KD(M)",
    method_col="Affinity Method",
):
    """
    Group by 'seq_col' (the combined sequence) and resolve duplicates by method priority:
      1. Within each group (all rows sharing the same 'combined' sequence), find the highest-priority method present.
      2. Average the KD(M) values for all rows of that method.
      3. Keep the *first row* from that subset for all other columns.
    """

    # Priority list of methods to choose from
    method_priority = [
        "SPR",
        "ITC",
        "BLI",
        "RA",
        "FL",
        "SP",
        "ELISA",
        "IASP",
        "IARA",
        "IAGE",
        "Other",
        "Unknown",
    ]

    def choose_best_method(group):
        """
        For a subset of df with the same 'combined' sequence:
          - Iterate over the 'method_priority' list in order.
          - If the group contains that method, average KD(M) across those rows.
          - Return the *first row* of that subset for all other columns.
        """
        methods_in_group = group[method_col].unique()

        for meth in method_priority:
            if meth in methods_in_group:
                # Get only the rows of the current top-priority method
                subset = group[group[method_col] == meth]

                # Average KD(M) and pKD for this subset
                avg_kd = subset[kd_col].mean()

                # Take the *first row* from THIS subset
                # (the subset that contributes to the average)
                first_row = subset.iloc[0].copy()

                # Update the KD(M) and method to reflect the average and the chosen method
                first_row[kd_col] = avg_kd
                first_row[method_col] = meth

                return first_row
        raise ValueError(
            "None of the methods in the priority list were found in the group!"
        )

    # Apply the 'choose_best_method' to each group of duplicated sequences
    result_df = df.groupby(seq_col, as_index=False).apply(
        choose_best_method, include_groups=False
    )
    return result_df.reset_index(drop=True)

In [10]:
df = filter_duplicates_by_method(df)
print(f"Remaining rows after filtering duplicates: {len(df)}")

Remaining rows after filtering duplicates: 8207


In [13]:
# fasta_id will be a unique identifier for each row to be used in fasta files
# instead of complex id as it may contain spaces
df["fasta_id"] = df["PDB"] + "_" + df.index.astype(str)
df.head()

Unnamed: 0,combined,Source Data Set,Complex ID,PDB,Mutations,Ligand Chains,Receptor Chains,Ligand Name,Receptor Name,KD(M),...,PDB PubMed ID,PDB Release Date,Affinity PubMed ID,Affinity Release Date,Subgroup,Ligand Sequences,Receptor Sequences,ligand,receptor,fasta_id
0,AAAQYPVVNTNYGKIRGLRTPLPNEILGPVEQYLGVPYASPPTGER...,PDBbind v2020,"2WQZ:B, C::PMID=18093521",2WQZ,,B,C,"Neuroligin4, X-Linked",Neurexin-1-beta,1.32e-07,...,18093521.0,2009-09-08,18093521,2007 Dec 20,,AAAQYPVVNTNYGKIRGLRTPLPNEILGPVEQYLGVPYASPPTGER...,HAGTTYIFSKGGGQITYKWPPNDRPSTRADRLAIGFSTVQKEAVLV...,AAAQYPVVNTNYGKIRGLRTPLPNEILGPVEQYLGVPYASPPTGER...,HAGTTYIFSKGGGQITYKWPPNDRPSTRADRLAIGFSTVQKEAVLV...,2WQZ_0
1,AADWDVYCSQDESIPAKFISRLVTSKDQALEKTEINCSNGLVPITQ...,PDBbind v2020,"5IIA:A, B, C, D::PMID=28622512",5IIA,,"B, D","A, C",sperm lysinR,red abalone egg VERL repeat 3 (VR3),1.81e-09,...,28622512.0,2017-06-14,28622512,2017 Jun 15,,AADWDVYCSQDESIPAKFISRLVTSKDQALEKTEINCSNGLVPITQ...,FLNKAFEVALKVQIIAGFDRGLVKWLRVHGRTLSTVQKKALYFVNR...,AADWDVYCSQDESIPAKFISRLVTSKDQALEKTEINCSNGLVPITQ...,FLNKAFEVALKVQIIAGFDRGLVKWLRVHGRTLSTVQKKALYFVNR...,5IIA_1
2,AAEEEDEVEWVVESIAGFLRGPDWSIPILDFVEQKCEVFDDEEESK...,PDBbind v2020,"4ZI2:A, C::PMID=26455799",4ZI2,,C,A,BART-like domain of BARTL1/CCDC104 1-113,Arl3FL-GppNHp,4.2e-07,...,26455799.0,2015-11-11,26455799,2015 Nov 3,,AAEEEDEVEWVVESIAGFLRGPDWSIPILDFVEQKCEVFDDEEESK...,LLSILRKLKSAPDQEVRILLLGLDNAGKTTLLKQLASEDISHITPT...,AAEEEDEVEWVVESIAGFLRGPDWSIPILDFVEQKCEVFDDEEESK...,LLSILRKLKSAPDQEVRILLLGLDNAGKTTLLKQLASEDISHITPT...,4ZI2_2
3,AAILGDEYLWSGGVIPYTFAGVSGADQSAILSGMQELEEKTCIRFV...,PDBbind v2020,"6SAZ:A, B::PMID=31604990",6SAZ,,A,B,Crayfish Astacin,Cleaved human fetuin-b,1.4e-10,...,31604990.0,2019-10-23,31604990,2019 Oct 11,,AAILGDEYLWSGGVIPYTFAGVSGADQSAILSGMQELEEKTCIRFV...,ALNPSALLSRGCNDSDVLAVAGFALRDINKDRKDGYVLRLNRVNDA...,AAILGDEYLWSGGVIPYTFAGVSGADQSAILSGMQELEEKTCIRFV...,ALNPSALLSRGCNDSDVLAVAGFALRDINKDRKDGYVLRLNRVNDA...,6SAZ_3
4,AAKEGWLHFRPLVTDKGKRVGGSIRPWKQMYVVLRGTTPSEEEHSL...,PDBbind v2020,"2J59:D, P::PMID=17347647",2J59,,P,D,RHO-GTPASE ACTIVATING PROTEIN 10,ADP-RIBOSYLATION FACTOR 1,5.5e-08,...,17347647.0,2007-02-20,17347647,2007 Apr 4,,AAKEGWLHFRPLVTDKGKRVGGSIRPWKQMYVVLRGTTPSEEEHSL...,GSMRILMVGLDAAGKTTILYKLKLGEIVTTIPTIGFNVETVEYKNI...,AAKEGWLHFRPLVTDKGKRVGGSIRPWKQMYVVLRGTTPSEEEHSL...,GSMRILMVGLDAAGKTTILYKLKLGEIVTTIPTIGFNVETVEYKNI...,2J59_4


In [14]:
os.makedirs("data_files", exist_ok=True)
with open("data_files/all_seqs.fasta", "w") as f:
    for i, row in df.iterrows():
        fasta_id = row["fasta_id"]
        sequence = row["combined"]
        f.write(f">{fasta_id}\n{sequence}\n")

We cluster all the sequnces into as many clusters as it takes such that for each cluster the minimum sequence identity between the representative of the cluster and any other sequence in the cluster is at least 30%. We use default `easy-cluster` coverage parameters of both query and key sequences with minimum coverage of 80%. We also use the `connected component` algorithm to cluster the sequences to ensure that a single cluster can cover more remote homologs.

In [15]:
!mmseqs easy-cluster data_files/all_seqs.fasta data_files/cluster tmp --min-seq-id 0.3 --remove-tmp-files --cluster-mode 1 --single-step-clustering

easy-cluster data_files/all_seqs.fasta data_files/cluster tmp --min-seq-id 0.3 --remove-tmp-files --cluster-mode 1 --single-step-clustering 

MMseqs Version:                     	a2815df9a6c6da173589fb65b3f71639ea08336d
Substitution matrix                 	aa:blosum62.out,nucl:nucleotide.out
Seed substitution matrix            	aa:VTML80.out,nucl:nucleotide.out
Sensitivity                         	4
k-mer length                        	0
Target search mode                  	0
k-score                             	seq:2147483647,prof:2147483647
Alphabet size                       	aa:21,nucl:5
Max sequence length                 	65535
Max results per query               	20
Split database                      	0
Split mode                          	2
Split memory limit                  	0
Coverage threshold                  	0.8
Coverage mode                       	0
Compositional bias                  	1
Compositional bias                  	1
Diagonal scoring                    	true
E

Index table: fill
Index statistics
Entries:          2184101
DB size:          500 MB
Avg k-mer size:   0.034127
Top 10 k-mers
    VHFAQS	1149
    VKYPVT	1139
    WNGLGV	1114
    LYLSTV	1072
    SVCLFY	934
    STSSLT	915
    TVWSLT	854
    AAGLDY	851
    YFEVSW	847
    PALSLY	847
Time for index table init: 0h 0m 0s 424ms
Process prefiltering step 1 of 1

k-mer similarity threshold: 118
Starting prefiltering scores calculation (step 1 of 1)
Query db start 1 to 4818
Target db start 1 to 4818

204.596370 k-mers per position
62098 DB matches per sequence
0 overflows
18 sequences passed prefiltering per query sequence
20 median result list length
0 sequences with 0 size result lists
Time for merging to pref: 0h 0m 0s 6ms
Time for processing: 0h 0m 0s 961ms
align tmp/1471709256671452136/clu_tmp/16603263552163503201/input_step_redundancy tmp/1471709256671452136/clu_tmp/16603263552163503201/input_step_redundancy tmp/1471709256671452136/clu_tmp/16603263552163503201/pref tmp/1471709256671452136/

In [16]:
cluster_df = pd.read_csv(
    "data_files/cluster_cluster.tsv", sep="\t", header=None
)
cluster_df = cluster_df.rename(columns={0: "cluster_rep_id", 1: "fasta_id"})
cluster_df

Unnamed: 0,cluster_rep_id,fasta_id
0,5IIA_1,5IIA_1
1,6SAZ_3,6SAZ_3
2,2J59_6,2J59_6
3,2J59_6,2J59_4
4,2J59_6,2J59_7
...,...,...
8202,4IDJ_6476,4IDJ_6476
8203,3FKU_6477,3FKU_6477
8204,3FKU_6477,3FKU_6479
8205,3FKU_6477,3FKU_6480


In [17]:
cluster_df_count = cluster_df.groupby(["cluster_rep_id"]).count().sort_values(
    "fasta_id", ascending=False
).rename(columns={"fasta_id": "cluster_size"})
cluster_df_count

Unnamed: 0_level_0,cluster_size
cluster_rep_id,Unnamed: 1_level_1
1R0R_374,311
3SGB_3801,298
1DEE_6384,276
1A22_2752,254
1PPF_4211,247
...,...
6PLK_8036,1
6PPG_2463,1
6QB3_6603,1
6QB6_6602,1


We construct the training set by assigning the largest clusters first to the training set until the training set lengh is approx. 75% of the total number of sequences. We don't want to reach 80% right now, because later we will perform several iterations to transfer some sequences from test to train set if they are too similar to the training set. Also, we found that some entries in the data may have the same PDB but the ligand and receptor chains are swapped. Such entries are not guaranteed to be clustered together, but they must not end in two different sets. Therefore, for each cluster member when we assign it to the training set, we also assign all other members of the same PDB to the training set. This way we ensure that all entries with the same PDB are in the same set.

In [18]:
train_fasta_ids = set()
target_train_count = round(0.75 * len(df))
train_clusters = set()
for cluster_rep_id, cluster_size in cluster_df_count.iterrows():
    # add largest cluster first to the training set
    train_clusters.add(cluster_rep_id)
    cluster_members = cluster_df[
        cluster_df["cluster_rep_id"] == cluster_rep_id
    ]["fasta_id"].tolist()
    for member in cluster_members:
        if member in train_fasta_ids:
            continue
        train_fasta_ids.add(member)
        member_pdb = df[df["fasta_id"] == member]["PDB"].values[0]
        [
            train_fasta_ids.add(fasta_id)
            for fasta_id in df[df["PDB"] == member_pdb]["fasta_id"].tolist()
        ]
    if len(train_fasta_ids) >= target_train_count:
        break

df["split"] = df["fasta_id"].apply(
    lambda x: "train" if x in train_fasta_ids else "test"
)
(df["split"].value_counts() / len(df)).round(2)

split
train    0.75
test     0.25
Name: count, dtype: float64

Next we perform as many iterations as needed to transfer sequences from the test set to the training set if they are too similar to the training set. In each iteration we do the following:
1. Calculate the sequence identity between the training and test set.
2. Transfer all sequences in test set that have sequence identity of at least 30% with any sequence in the training set to the training set along with all other sequences from the same PDB (for the same reason described above).
3. Repeat until no more sequences are transferred.

Finally, we remove all sequences that are too similar to the training set from the test set.

In [19]:
from subprocess import run


def do_split_iteration(df, splits):
    split1, split2 = splits
    with open(f"data_files/{split1}.fasta", "w") as f:
        for _, row in df[df["split"] == split1].iterrows():
            fasta_id = row["fasta_id"]
            seq = row["combined"]
            f.write(f">{fasta_id}\n{seq}\n")

    with open(f"data_files/{split2}.fasta", "w") as f:
        for _, row in df[df["split"] == split2].iterrows():
            fasta_id = row["fasta_id"]
            seq = row["combined"]
            f.write(f">{fasta_id}\n{seq}\n")

    command = [
        "mmseqs",
        "easy-search",
        f"data_files/{split1}.fasta",
        f"data_files/{split2}.fasta",
        "data_files/matches.tsv",
        "tmp",
        "--min-seq-id",
        "0.3",
        "-c",
        "0.8",
    ]

    result = run(command, capture_output=True, text=True)
    if result.returncode != 0:
        print(result.stderr)
        raise ValueError("Error running mmseqs")
    try:
        matches_df = pd.read_csv(
            "data_files/matches.tsv", sep="\t", header=None
        )
    except pd.errors.EmptyDataError:
        print("No matches found")
        return df, 0
    matches_df = matches_df.rename(
        columns={0: f"{split1}_id", 1: f"{split2}_id", 2: "pident"}
    )
    fasta_ids_to_change_split = set(matches_df[f"{split2}_id"].unique())
    pdbs_to_change_split = df[
        df["fasta_id"].isin(fasta_ids_to_change_split)
    ]["PDB"].unique()
    for pdb in pdbs_to_change_split.tolist():
        fasta_ids_to_change_split.update(
            df[df["PDB"] == pdb]["fasta_id"].tolist()
        )
    print(
        f"Changing {len(fasta_ids_to_change_split)} sequences from {split2} "
        f"to {split1}"
    )
    df["split"] = df.apply(
        lambda row: (
            split1
            if row["fasta_id"] in fasta_ids_to_change_split
            else row["split"]
        ),
        axis=1,
    )
    return df, len(fasta_ids_to_change_split)

In [20]:
while True:
    df, total_changed = do_split_iteration(df, ("train", "test"))
    print(
        f"Current split ratios: {(df['split'].value_counts() / len(df)).round(2)}"
    )
    print("*" * 80)
    if (df["split"] == "test").sum() == 0 or total_changed == 0:
        break

Changing 273 sequences from test to train
Current split ratios: split
train    0.78
test     0.22
Name: count, dtype: float64
********************************************************************************
Changing 47 sequences from test to train
Current split ratios: split
train    0.79
test     0.21
Name: count, dtype: float64
********************************************************************************
Changing 8 sequences from test to train
Current split ratios: split
train    0.79
test     0.21
Name: count, dtype: float64
********************************************************************************
No matches found
Current split ratios: split
train    0.79
test     0.21
Name: count, dtype: float64
********************************************************************************


Now the train and test sets are properly separated. We now need to split the test set into two approximately equal parts: validation and test set. We cluster the test set using the same parameters as before and start assigning the largest cluster to the test set and the second largest to the validation set and so on untill all clusters are assigned. We also ensure that all entries with the same PDB are in the same set as before.

In [21]:
with open("data_files/test_seqs.fasta", "w") as f:
    for i, row in df[df["split"] == "test"].iterrows():
        fasta_id = row["fasta_id"]
        sequence = row["combined"]
        f.write(f">{fasta_id}\n{sequence}\n")

In [22]:
!mmseqs easy-cluster data_files/test_seqs.fasta data_files/cluster tmp --min-seq-id 0.3 --remove-tmp-files --cluster-mode 1 --single-step-clustering

easy-cluster data_files/test_seqs.fasta data_files/cluster tmp --min-seq-id 0.3 --remove-tmp-files --cluster-mode 1 --single-step-clustering 

MMseqs Version:                     	a2815df9a6c6da173589fb65b3f71639ea08336d
Substitution matrix                 	aa:blosum62.out,nucl:nucleotide.out
Seed substitution matrix            	aa:VTML80.out,nucl:nucleotide.out
Sensitivity                         	4
k-mer length                        	0
Target search mode                  	0
k-score                             	seq:2147483647,prof:2147483647
Alphabet size                       	aa:21,nucl:5
Max sequence length                 	65535
Max results per query               	20
Split database                      	0
Split mode                          	2
Split memory limit                  	0
Coverage threshold                  	0.8
Coverage mode                       	0
Compositional bias                  	1
Compositional bias                  	1
Diagonal scoring                    	true


In [23]:
cluster_df = pd.read_csv(
    "data_files/cluster_cluster.tsv", sep="\t", header=None
)
cluster_df = cluster_df.rename(columns={0: "cluster_rep_id", 1: "fasta_id"})
cluster_df

Unnamed: 0,cluster_rep_id,fasta_id
0,2AZE_2521,2AZE_2521
1,5JHL_2524,5JHL_2524
2,5D50_2525,5D50_2525
3,4DXA_2526,4DXA_2526
4,4NQW_2527,4NQW_2527
...,...,...
1717,6FE4_6240,6FE4_6242
1718,6FE4_6240,6FE4_6243
1719,5TOJ_6246,5TOJ_6246
1720,5TOJ_6246,5TOJ_6244


In [24]:
cluster_df_count = cluster_df.groupby(["cluster_rep_id"]).count().sort_values(
    "fasta_id", ascending=False
).rename(columns={"fasta_id": "cluster_size"})
cluster_df_count

Unnamed: 0_level_0,cluster_size
cluster_rep_id,Unnamed: 1_level_1
6EWB_6311,7
5FV2_2194,6
2C7N_5023,6
6FE4_6240,6
1TH1_5987,6
...,...
1B6C_3200,1
6J14_5441,1
6JDJ_896,1
6JHW_2954,1


In [25]:
val_fasta_ids = set()
test_fasta_ids = set(df[df["split"] == "test"]["fasta_id"].tolist())
add_to_val = False
for cluster_rep_id, cluster_size in cluster_df_count.iterrows():
    cluster_members = cluster_df[
        cluster_df["cluster_rep_id"] == cluster_rep_id
    ]["fasta_id"].tolist()
    if add_to_val:
        for member in cluster_members:
            val_fasta_ids.add(member)
            member_pdb = df[df["fasta_id"] == member]["PDB"].values[0]
            [
                val_fasta_ids.add(fasta_id)
                for fasta_id in df[df["PDB"] == member_pdb][
                    "fasta_id"
                ].tolist()
            ]
    add_to_val = not add_to_val
test_fasta_ids = test_fasta_ids - val_fasta_ids
df["split"] = df.apply(
    lambda x: "val" if x["fasta_id"] in val_fasta_ids else x["split"], axis=1
)
df["split"] = df.apply(
    lambda x: "test" if x["fasta_id"] in test_fasta_ids else x["split"], axis=1
)
(df["split"].value_counts() / len(df)).round(2)

split
train    0.79
val      0.12
test     0.09
Name: count, dtype: float64

In [31]:
df.columns

Index(['combined', 'Source Data Set', 'Complex ID', 'PDB', 'Mutations',
       'Ligand Chains', 'Receptor Chains', 'Ligand Name', 'Receptor Name',
       'KD(M)', 'Affinity Method', 'Structure Method', 'Temperature(K)',
       'Resolution(Å)', 'PDB PubMed ID', 'PDB Release Date',
       'Affinity PubMed ID', 'Affinity Release Date', 'Subgroup',
       'Ligand Sequences', 'Receptor Sequences', 'ligand', 'receptor',
       'fasta_id', 'split'],
      dtype='object')

In [41]:
df = df.drop(
    columns=[
        "ligand",
        "receptor",
        "combined",
        "fasta_id",
    ]
)
df.to_csv("PPB_Affinity_processed_filtered.csv", index=False)