# Dataset creation

![flowchart](./doc/dataset_creation_flowchart.drawio.png)


In [18]:
import os
import pandas as pd
import numpy as np

from utils.file_utils import open_json, write_json
from utils.dataset_creation import *
from utils.dataset_mapping import *


In [19]:
FRESH_START = True
UPDATE_MAPPING = True


In [20]:
COLUMNS = ["pdbs", "uniprot", "wild_aa", "mutated_chain", "mutation_position",
           "mutated_aa", "pH",
           "sequence", "length", "chain_start", "chain_end",
           "AlphaFoldDB", "Tm", "ddG", "dTm",
           "dataset_source", "source_id", "infos_found"]

SUBSET_DUPLICATES = ["wild_aa", "mutation_position",
                     "mutated_aa", "pH", "sequence"]

NAME = "all_dTm_source_id"
DIR = "./data/main_dataset_creation"
OUTPUT_DIR = DIR+'/outputs/'+NAME

LOCAL_UNIPROT_INFOS_PATH = DIR+"/uniprot_infos.json"
PDB_UNIPROT_MAPPING_PATH = DIR+"/mapping/pdb_uniprot_mapping.json"
LINKED_UNIPROT_MAPPING_PATH = DIR+"/mapping/linked_uniprot_mapping.json"
SEQUENCE_UNIPROT_MAPPING_PATH = DIR + \
    "/mapping/sequence_uniprot_mapping.json"
PDB_NO_UNIPROT_PATH = DIR+"/mapping/pdb_no_uniprot.json"
SEQUENCE_NO_UNIPROT_PATH = DIR+"/mapping/sequence_no_uniprot.json"

DATASET_OUTPUT_PATH_RAW = OUTPUT_DIR+f"/dataset_raw.csv"
DATASET_OUTPUT_PATH_ONLY_INFOS = OUTPUT_DIR+f"/dataset_only_infos.csv"


# Infos for dataset creation


In [21]:
local_uniprot_infos = open_json(LOCAL_UNIPROT_INFOS_PATH)
dataset_config = open_json(DIR+"/dataset_config.json")

print(f"loaded {len(local_uniprot_infos)} uniprot infos from local storage")


loaded 578 uniprot infos from local storage


In [22]:
# prepare output dir
if not os.path.exists(OUTPUT_DIR):
    print(f"creating {OUTPUT_DIR} folder")
    os.mkdir(OUTPUT_DIR)


creating ./data/main_dataset_creation/outputs/all_dTm_source_id folder


### Loop through all the required dataset


In [23]:
if not FRESH_START:
    main_df = pd.read_csv(DATASET_OUTPUT_PATH_RAW)
else:
    main_df = pd.DataFrame()
    main_df = add_missing_column(main_df, COLUMNS)

    for dataset_source in dataset_config["dataset_to_process"]:
        errors = {
            "no_sequence_in_data": 0,
            "not_in_local": 0,
            "wrong_position": 0,
            "no_uniprot": 0,
            "no_pdb": 0,
            "no_sequence": 0,
        }

        individual_config = dataset_config[dataset_source]
        # load csv
        if dataset_source == "thermomutdb":
            df = pd.read_json(individual_config["data_path"])
            df = df[df.mut_count.eq(0)]
            df[df.uniprot.eq('-')] = np.nan
        elif dataset_source == "Q3421":
            df = pd.read_csv(individual_config["data_path"],
                             delimiter="\s+").iloc[1:].reset_index(drop=True)
            df.ddG = df.ddG.astype(float)
            df.pH = df.pH.astype(float)
            df["Pos(PDB)"] = df["Pos(PDB)"].astype(float)
        else:
            df = pd.read_csv(individual_config["data_path"],
                             sep=individual_config.get("sep", ',')).drop_duplicates()
        
        if dataset_source == "fireprotdb":
            # Invalid PDB structures, we use AF2 structures in FireProtDB
            df = df.dropna(subset=['pdb_id']).reset_index(drop=True)

        # rename columns
        df.rename(columns=individual_config["renaming_dict"],
                  inplace=True)
        if dataset_source in ["S140", "S2648", "Q1744", "Q3214"]:
            df["mutated_chain"] = df.pdbs.str[-1]
            df.pdbs = df.pdbs.str[:-1]
        # add missing columns
        df = add_missing_column(df, COLUMNS)
        # ad source_id
        df["source_id"] = dataset_source+"_"+df.index.astype(str)

        # split mutation code if needed
        if individual_config["need_mutation_code_split"]:
            df = df.apply(apply_split_mutation_code, axis=1)
        # remove nan mutation_code
        df = df[~df["mutation_position"].isna()]
        # keep only COLUMNS
        df = df[COLUMNS]
        # drop duplicates
        df.drop_duplicates(inplace=True)
        # add dataset_source
        df["dataset_source"] = dataset_source
        # index start at 0
        df["mutation_position"] = df["mutation_position"].apply(lambda x: x-1)
        # max precision of pH: .1 (to avoid duplicates)
        df["pH"] = df["pH"].round(1)

        # apply target corrections
        df["ddG"] = df["ddG"].apply(
            lambda x: x*individual_config["corrections"]["ddG"])
        df["dTm"] = df["dTm"].apply(
            lambda x: x*individual_config["corrections"]["dTm"])
        # better to initialize infos_found at 0 than nan
        df["infos_found"] = 0
        # by default chain is "A"
        df["mutated_chain"].fillna("A")
        df["mutated_chain"] = df["mutated_chain"].astype("str")
        df["mutated_chain"].str.replace('0', 'A')

        # check number of rows without uniprot
        # check validity of uniprot, and add the infos for those
        df = df.apply(lambda row: apply_valid_uniprot(
            row, local_uniprot_infos, dataset_config, errors), axis=1)
        
        # if target is not "all", we keep only the target, 
        # in order to make sure data is not lost when we remove duplicates
        if dataset_config["general_config"]["target"]=="ddG":
            l = len(df)
            df = df[~(df.ddG.isna())]
            print(f"target is ddG, so we removed {l-len(df)} rows")
        elif dataset_config["general_config"]["target"] == "dTm":
            l = len(df)
            df = df[~(df.dTm.isna())]
            print(f"target is dTm, so we removed {l-len(df)} rows")


        print(f"processed {dataset_source}:")
        print(f"{errors=}")

        main_df = pd.concat([main_df, df], ignore_index=True)
        main_df.drop_duplicates(SUBSET_DUPLICATES, inplace=True)

    # save
    write_json(LOCAL_UNIPROT_INFOS_PATH, local_uniprot_infos)
    main_df.to_csv(DATASET_OUTPUT_PATH_RAW, index=False)


  df = pd.read_csv(individual_config["data_path"],


target is dTm, so we removed 12098 rows
processed fireprotdb:
errors={'no_sequence_in_data': 0, 'not_in_local': 0, 'wrong_position': 226, 'no_uniprot': 0, 'no_pdb': 0, 'no_sequence': 0}
target is dTm, so we removed 6118 rows
processed thermomutdb:
errors={'no_sequence_in_data': 0, 'not_in_local': 0, 'wrong_position': 2348, 'no_uniprot': 0, 'no_pdb': 0, 'no_sequence': 0}
target is dTm, so we removed 2557 rows
processed O2567_new:
errors={'no_sequence_in_data': 0, 'not_in_local': 0, 'wrong_position': 0, 'no_uniprot': 0, 'no_pdb': 0, 'no_sequence': 0}
target is dTm, so we removed 3639 rows
processed prothermdb:
errors={'no_sequence_in_data': 0, 'not_in_local': 0, 'wrong_position': 1112, 'no_uniprot': 0, 'no_pdb': 0, 'no_sequence': 0}
target is dTm, so we removed 630 rows
processed S630:
errors={'no_sequence_in_data': 0, 'not_in_local': 0, 'wrong_position': 0, 'no_uniprot': 0, 'no_pdb': 0, 'no_sequence': 0}
target is dTm, so we removed 3568 rows
processed S3568:
errors={'no_sequence_in_dat

### update mapping and try to add infos


In [24]:
if not UPDATE_MAPPING:
    # don't go beyond here with Run All
    assert False


In [25]:
# update pdb to uniprot mapping
update_pdb_uniprot_mapping(LOCAL_UNIPROT_INFOS_PATH,
                           PDB_UNIPROT_MAPPING_PATH,
                           LINKED_UNIPROT_MAPPING_PATH)

pdb_uniprot_mapping = open_json(PDB_UNIPROT_MAPPING_PATH)
linked_uniprot_mapping = open_json(LINKED_UNIPROT_MAPPING_PATH)
pdb_without_uniprot = open_json(PDB_NO_UNIPROT_PATH)


added 0 entries to pdb_uniprot_mapping


In [26]:
# add infos based on pdb not uniprot
df = pd.read_csv(DATASET_OUTPUT_PATH_RAW)

with_infos = df.infos_found.sum()
df = df.apply(lambda row: apply_infos_from_pdb(row, local_uniprot_infos, pdb_uniprot_mapping,
                                               linked_uniprot_mapping, dataset_config,
                                               pdb_without_uniprot, errors),
              axis=1)
print(
    f"added {df.infos_found.sum()-with_infos} new infos thanks to uniprot_from_pdb")
df.to_csv(DATASET_OUTPUT_PATH_RAW, index=False)


added 546.0 new infos thanks to uniprot_from_pdb


In [27]:
# update sequence to uniprot mapping
update_sequence_uniprot_mapping(LOCAL_UNIPROT_INFOS_PATH,
                                SEQUENCE_UNIPROT_MAPPING_PATH,
                                LINKED_UNIPROT_MAPPING_PATH)

sequence_uniprot_mapping = open_json(SEQUENCE_UNIPROT_MAPPING_PATH)
sequence_without_uniprot = open_json(SEQUENCE_NO_UNIPROT_PATH)


added 0 entries to sequence_uniprot_mapping


In [28]:
# add infos based on sequence not pdb or uniprot

df = pd.read_csv(DATASET_OUTPUT_PATH_RAW)

with_infos = df.infos_found.sum()
df = df.apply(lambda row: apply_infos_from_sequence(row, local_uniprot_infos, sequence_uniprot_mapping,
                                                    linked_uniprot_mapping, dataset_config,
                                                    sequence_without_uniprot, errors),
              axis=1)
print(
    f"added {df.infos_found.sum()-with_infos} new infos thanks to uniprot_from_sequence")

df.to_csv(DATASET_OUTPUT_PATH_RAW, index=False)


added 916.0 new infos thanks to uniprot_from_sequence


In [29]:
# make sure mapping and other data is saved
write_json(LOCAL_UNIPROT_INFOS_PATH, local_uniprot_infos)
write_json(PDB_NO_UNIPROT_PATH, pdb_without_uniprot)
write_json(SEQUENCE_NO_UNIPROT_PATH, sequence_without_uniprot)


In [30]:
d = main_df.head()

In [31]:
def correct_chain(row):
    if type(row["mutated_chain"]) != type(""):
        row["mutated_chain"] = "A"
    if len(row["mutated_chain"]) != 1:
        row["mutated_chain"] = "A"
    if row["mutated_chain"] in ['_', '-']:
        row["mutated_chain"] = "A"

    row["mutated_chain"] = row["mutated_chain"].upper()
    return row

## Final filtering

In [32]:
SEARCH_ON = ['uniprot', 'wild_aa', 'mutation_position',
             'mutated_aa', 'sequence']  # same as SUBSET but without pH

main_df = pd.read_csv(DATASET_OUTPUT_PATH_RAW)
dataset_config = open_json(DIR+"/dataset_config.json")
# remove pdbs
main_df = main_df[COLUMNS[1:]]
print(0, main_df.infos_found.eq(0.0).sum())

# remove record without uniprot infos
print(1, len(main_df))
main_df = main_df.loc[main_df.infos_found == 1]
print(2, len(main_df))
mean_pH = main_df.pH.mean()
count = 0
num_no_pH = len(main_df.loc[main_df.pH.isna()])
print(f"{num_no_pH=}")

# # try to find if rows without pH are already somewhere else in the df
# for _, row in main_df.loc[main_df.pH.isna()].iterrows():
#     if (len(main_df.loc[(main_df["wild_aa"] == row["wild_aa"]) &
#                         (main_df["mutation_position"] == row["mutation_position"]) &
#                         (main_df["mutated_aa"] == row["mutated_aa"]) &
#                         (main_df["sequence"] == row["sequence"])]) == 0):
#         # the row with no pH is not present in the rest of the df
#         row["pH"] = mean_pH
#         count += 1
# print(f"out of {num_no_pH} rows without pH we kept {count} because those had no duplicate")
# #  ==> always 0


main_df = main_df.loc[~main_df.pH.isna()]
print(3, len(main_df))

# check for errors in chain and correct them
main_df = main_df.apply(correct_chain, axis=1)
# remove duplicates
print(4, len(main_df))
main_df.drop_duplicates(subset=SUBSET_DUPLICATES, inplace=True)
print(5, len(main_df))

dataset_infos = {
    "total_len": len(main_df),
    "dataset_processed": dataset_config["dataset_to_process"],
    "general_config": dataset_config["general_config"],
    "dataset_source_repartition": main_df.dataset_source.value_counts().to_dict(),
    "unique_uniprot": len(main_df.uniprot.unique()),
    "ddG": (len(main_df)-main_df.ddG.isna().sum()),
    "dTm": (len(main_df)-main_df.dTm.isna().sum()),
    "Tm": (len(main_df)-main_df.Tm.isna().sum()),
    "nan_repartition": main_df.isna().sum().to_dict(),
    "no_pH_repartition": main_df[main_df.pH.isna()].dataset_source.value_counts().to_dict(),
}

main_df.to_csv(DATASET_OUTPUT_PATH_ONLY_INFOS, index=False)
write_json(OUTPUT_DIR+"/dataset_config.json", dataset_config)
write_json(OUTPUT_DIR+"/dataset_infos.json", dataset_infos)


0 1255
1 8102
2 6847
num_no_pH=1549
3 5298
4 5298
5 5146


In [33]:
from pprint import pprint
pprint(dataset_infos)

{'Tm': 3019,
 'dTm': 5146,
 'dataset_processed': ['fireprotdb',
                       'thermomutdb',
                       'O2567_new',
                       'prothermdb',
                       'S630',
                       'S3568',
                       'jinyuan_sun_test',
                       'jinyuan_sun_train',
                       'datasetDDG_train',
                       'datasetDDG_test',
                       'all_train_data_v17',
                       'S140',
                       'S2648',
                       'Q1744',
                       'Q3214',
                       'Q3421'],
 'dataset_source_repartition': {'fireprotdb': 2735,
                                'prothermdb': 385,
                                'thermomutdb': 2026},
 'ddG': 593,
 'general_config': {'fill_na_pH': False, 'target': 'dTm'},
 'nan_repartition': {'AlphaFoldDB': 1049,
                     'Tm': 2127,
                     'chain_end': 0,
                     'chain_start': 0,
     

In [34]:
main_df = pd.read_csv(DATASET_OUTPUT_PATH_ONLY_INFOS)
print(len(main_df))

{
    "total_len": len(main_df),
    "unique_uniprot": len(main_df.uniprot.unique()),
    "ddG": (len(main_df)-main_df.ddG.isna().sum()),
    "dTm": (len(main_df)-main_df.dTm.isna().sum()),
    "Tm": (len(main_df)-main_df.Tm.isna().sum()),
    "nan_repartition": main_df.isna().sum().to_dict(),
}


5146


{'total_len': 5146,
 'unique_uniprot': 320,
 'ddG': 593,
 'dTm': 5146,
 'Tm': 3019,
 'nan_repartition': {'uniprot': 0,
  'wild_aa': 0,
  'mutated_chain': 0,
  'mutation_position': 0,
  'mutated_aa': 0,
  'pH': 0,
  'sequence': 0,
  'length': 0,
  'chain_start': 0,
  'chain_end': 0,
  'AlphaFoldDB': 1049,
  'Tm': 2127,
  'ddG': 4553,
  'dTm': 0,
  'dataset_source': 0,
  'source_id': 0,
  'infos_found': 0}}