In [40]:
import pandas as pd

# Read split files (test/val/train for the FLIP split and the EPA "median" split).
SPLITS_DIR = "/hpi/fs00/scratch/tobias.fiedler/hotprot_data/splits"

flip_test = pd.read_csv(f"{SPLITS_DIR}/test_FLIP.csv")
flip_val = pd.read_csv(f"{SPLITS_DIR}/val_FLIP.csv")
flip_train = pd.read_csv(f"{SPLITS_DIR}/train_FLIP.csv")
# Keep file and variable names aligned: test_median -> epa_test, val_median -> epa_val.
epa_test = pd.read_csv(f"{SPLITS_DIR}/test_median.csv")
epa_val = pd.read_csv(f"{SPLITS_DIR}/val_median.csv")
epa_train = pd.read_csv(f"{SPLITS_DIR}/train_median.csv")

# Extract the unique sequences of each held-out split.
test_sequences = set(flip_test["sequence"])
val_sequences = set(flip_val["sequence"])
epa_test_sequences = set(epa_test["sequence"])
epa_val_sequences = set(epa_val["sequence"])
# Leakage is evaluated on test + val combined for each split family.
total_val_sequences = test_sequences | val_sequences
epa_total_val_sequences = epa_test_sequences | epa_val_sequences

In [41]:
# Read ESM2 pretraining validation clusters.
# Each non-blank line holds two whitespace-separated tokens, each of the form
# "<prefix>_<accession>": the first is the protein, the second its cluster
# (presumably "UniRef100_<acc> UniRef50_<acc>" — TODO confirm against the file).
esm_eval_clusters = {}  # cluster accession -> list of member protein accessions
esm_eval_ids = set()    # all protein accessions in the ESM2 validation set
with open("./uniref201803_ur100_valid_headers.txt") as txt_file:
    for line in txt_file:
        parts = line.split()  # also drops the line ending and any stray "\r"/extra spaces
        if not parts:
            continue  # tolerate blank or trailing empty lines
        protein_id = parts[0].split("_")[1]  # avoid shadowing the builtin `id`
        cluster = parts[1].split("_")[1]
        esm_eval_ids.add(protein_id)
        esm_eval_clusters.setdefault(cluster, []).append(protein_id)

In [42]:
# Complete dataset from FLIP with all measurements
def read_fasta(filepath="full_dataset_sequences.fasta"):
    """Parse a FLIP FASTA file into one dict per record.

    Header lines start with ``>`` and are split on single spaces:
    - The first token (with ``>`` removed) is cut at the first ``_`` to get the ID.
    - The second token must look like ``<key>=<number>`` and supplies the temperature.

    All following non-blank, non-header lines are concatenated into the
    record's sequence.

    Args:
        filepath: Path to the FASTA file.

    Returns:
        List of dicts, in file order, with these keys:
        ``"id"``, ``"header"`` (the header line without its line ending),
        ``"temp"`` (float) and ``"sequence"``. The final record in the file
        is included.

    Raises:
        ValueError: if sequence data appears before the first header line.
    """
    dataset = []
    entry = None  # record currently being filled; None until the first header
    with open(filepath) as fasta:
        for raw_line in fasta:
            line = raw_line.rstrip("\r\n")
            if line.startswith(">"):
                # A new header closes the previous record.
                if entry is not None:
                    dataset.append(entry)
                header_tokens = line.split(" ")
                entry = {
                    "id": header_tokens[0].replace(">", "").split("_")[0],
                    "header": line,
                    "temp": float(header_tokens[1].split("=")[1]),
                    "sequence": "",
                }
            elif line.strip():
                if entry is None:
                    raise ValueError(f"{filepath}: sequence data before first FASTA header")
                entry["sequence"] += line.strip()

    # The last record has no following header to close it, so flush it here.
    if entry is not None:
        dataset.append(entry)
    return dataset


flip_dataset = read_fasta(filepath="full_dataset_sequences.fasta")  # all FLIP records, in file order

In [43]:
# Sequences of FLIP proteins whose IDs appear in the ESM2 pretraining validation set
# (FLIP records are linked to ESM2 headers by ID, then reduced to their sequences).
eval_proteins = {
    record["sequence"]
    for record in flip_dataset
    if record["id"] in esm_eval_ids
}

In [46]:
def data_leakage(total_val_sequences_of_split, esm_eval_proteins):
    """Estimate the fraction of a split's held-out sequences counted as leaked.

    A held-out (test + val) sequence counts as "not leaked" only if it also
    appears in ``esm_eval_proteins``, the sequences whose IDs are in the
    ESM2 pretraining validation set. Every other held-out sequence is
    counted as leaked, on the assumption that it was in ESM2's pretraining
    training data.

    NOTE(review): sequences that are in neither ESM2 split are also counted
    as leaked, so this presumably over-estimates the true leakage.

    Args:
        total_val_sequences_of_split: Iterable of held-out sequences for one split.
        esm_eval_proteins: Iterable of sequences in the ESM2 validation set.

    Returns:
        Float in [0, 1]: ``1 - |held_out & esm_eval| / |held_out|``.

    Raises:
        ValueError: if ``total_val_sequences_of_split`` is empty.
    """
    held_out = set(total_val_sequences_of_split)
    if not held_out:
        raise ValueError("total_val_sequences_of_split is empty; leakage is undefined")
    # Held-out sequences also in the ESM2 validation set are treated as unseen in pretraining.
    not_leaked = held_out.intersection(esm_eval_proteins)
    return 1.0 - len(not_leaked) / len(held_out)

In [47]:
# Leakage for each split family, computed EPA first and then FLIP.
epa_data_leakage, flip_data_leakage = (
    data_leakage(split_sequences, eval_proteins)
    for split_sequences in (epa_total_val_sequences, total_val_sequences)
)
print("Data leakage in EPA split: ", epa_data_leakage)
print("Data leakage in FLIP split: ", flip_data_leakage)

Data leakage in EPA split:  0.0
Data leakage in FLIP split:  0.9083484573502723
