In [None]:
# Load the input table (Snakemake input "syn_df") and drop ambiguous complexes
import pandas as pd
import numpy as np

raw_df = pd.read_csv(snakemake.input["syn_df"])
# A complex ID that occurs more than once is ambiguous, so *all* of its copies are
# discarded (keep=False), not just the repeats (76 such rows at the time of writing).
ambiguous = raw_df["complex"].duplicated(keep=False)
df = raw_df[~ambiguous].copy()
df.head()


In [None]:
# Peek at a single record to see every column and its value format
first_record = df.iloc[0]
first_record

In [None]:
# Add one row per PDB for the unmutated complex: its label is the wild-type ΔG
# (wt_foldx_dg) and its complex ID is simply the PDB code.
# NOTE(review): takes the first wt_foldx_dg per PDB, assuming the value is identical
# across all mutants of that PDB — confirm.
originals = (
    df.groupby(["pdb"])["wt_foldx_dg"]
    .first()
    .reset_index()
    .assign(
        complex=lambda frame: frame["pdb"],
        mut_foldx_dg=lambda frame: frame["wt_foldx_dg"],
    )
)

merged = pd.concat([df, originals], ignore_index=True)

In [None]:
# Convert FoldX ΔG values (kcal/mol, matching the units of R below) into -log10(Kd) labels.
# From ΔG = R·T·ln(Kd):  -log10(Kd) = -ΔG / (R·T·ln 10).
# More negative (favourable) ΔG gives a higher label, so lower labels mean weaker binding
# (sign convention checked in the publication).
GAS_CONSTANT_KCAL = 0.001987  # R in kcal/(mol·K)
TEMPERATURE_K = 293.15  # 20 °C in kelvin
LOG10_E = 1 / np.log(10)  # ≈ 0.4343 = log10(e); converts a natural log into log10

merged["-log(Kd)"] = -(merged["mut_foldx_dg"] / (TEMPERATURE_K * GAS_CONSTANT_KCAL)) * LOG10_E

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Distribution of the -log(Kd) labels before filtering.
# `distplot` is deprecated since seaborn 0.11; histplot with a KDE and density scaling replaces it.
fig, ax = plt.subplots(figsize=(3, 3))
sns.histplot(merged["-log(Kd)"], kde=True, stat="density", ax=ax)
ax.set_xlim(-10, 30)
ax.set(xlabel="-log(Kd)", ylabel="Density");

In [None]:
# Row count before filtering implausible labels
merged.shape[0]

In [None]:
# Discard rows with implausible affinity labels: below the configured minimum,
# or with an absolute value above the configured maximum.
# NOTE(review): the parameter is spelled `fliter_logkd_max`; it must match the name in the
# Snakefile's params exactly — confirm before renaming either side.
logkd_label = merged["-log(Kd)"]
out_of_range = (logkd_label < snakemake.params.filter_logkd_min) | (
    logkd_label.abs() > snakemake.params.fliter_logkd_max
)
merged = merged.drop(index=merged.index[out_of_range])
len(merged)

In [None]:
# Use the complex ID as the row label; where it is missing, fall back to the PDB code.
row_labels = merged["complex"].fillna(merged["pdb"])
merged.index = row_labels


In [None]:
# The mutation code is the second "_"-separated field of the complex ID; IDs without "_"
# (e.g. the wild-type rows, whose complex is the bare PDB code) get an empty string.
merged["mutation"] = merged["complex"].astype(str).str.split("_").str[1].fillna("")


def _structure_filename(row):
    """Return the structure file name for one row: plain `<pdb>.pdb` for wild-type rows,
    `<pdb>_<ab_chain>_<ag_chain>_<mutation>.pdb` for mutants."""
    if not row.mutation:
        return f"{row.pdb}.pdb"
    return f"{row.pdb}_{row.ab_chain}_{row.ag_chain}_{row.mutation}.pdb"


merged["filename"] = merged.apply(_structure_filename, axis=1)
merged = merged.rename(columns={"mut_foldx_dg": "delta_g"})

In [None]:
# Spot-check the first and last rows, showing only columns populated for every row
fully_populated = merged.dropna(axis="columns", how="any")
fully_populated.iloc[[0, -1]]

## Now split into absolute and relative parts

In [None]:
# Default every row to "not test"; the flag is recomputed from the validation split further down.
merged = merged.assign(test=False)

## Assign validation/test cluster numbers

In [None]:
# Read the sequence-clustering file into a (pdb, cluster_id) table.
# Layout parsed here: ">Cluster <id>" header lines, each followed by member lines whose
# identifier comes after a ">" and starts with the 4-character PDB code
# (presumably CD-HIT .clstr output — confirm). Blank lines and lines starting with ";;" are skipped.


def read_cluster_members(path):
    """Parse the clustering file at ``path``.

    Returns
    -------
    (list[int], list[str])
        Parallel lists: the cluster id of each member line and its 4-character PDB code.

    Raises
    ------
    ValueError
        If a member line appears before any ">Cluster" header, or has no ">" marker.
    """
    cluster_ids, codes = [], []
    current_cluster = None  # guards against members before the first header
    with open(path) as handle:
        for line_no, line in enumerate(handle, start=1):
            if line.startswith('>Cluster'):
                current_cluster = int(line.split()[1])
            elif line.strip() and not line.startswith(';;'):
                if current_cluster is None:
                    raise ValueError(f"{path}:{line_no}: member line before any '>Cluster' header")
                if '>' not in line:
                    raise ValueError(f"{path}:{line_no}: member line without a '>' identifier")
                cluster_ids.append(current_cluster)
                codes.append(line.split('>')[1][:4])
    return cluster_ids, codes


clusters, pdb_codes = read_cluster_members(snakemake.input.cdr_clusters)

# Create a DataFrame
cluster_df = pd.DataFrame({'pdb': pdb_codes, 'cluster_id': clusters})


In [None]:
# split relative, according to clustering
import math

def assign_clusters(df, num_splits, cluster_table=None):
    """
    Greedily assign whole sequence clusters to validation splits, writing ``df["validation"]`` in place.

    Clusters are visited in increasing ``cluster_id`` order starting at 0. A cluster is used
    only if *all* of its PDB codes occur in ``df["pdb"]``. Its PDBs are added to the current
    split until that split holds at least ``n_unique_pdbs // num_splits`` distinct PDB codes;
    then the next split is filled.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``pdb`` column. A ``validation`` column is written in place: the split
        index ``0 .. num_splits - 1`` for assigned rows, and ``num_splits`` for leftover rows
        whose PDB was not placed in any split.
    num_splits : int
        Number of splits to fill; must be >= 1.
    cluster_table : pandas.DataFrame, optional
        Table with ``pdb`` and ``cluster_id`` columns. Defaults to the notebook-level
        ``cluster_df`` (backward compatible with the original signature).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``num_splits`` < 1.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")
    if cluster_table is None:
        cluster_table = cluster_df  # notebook-level table parsed from the clustering file

    all_pdbs = set(df["pdb"].drop_duplicates().tolist())
    df["validation"] = num_splits  # sentinel label: "not assigned to any split"
    min_split_size = len(all_pdbs) // num_splits  # remainder PDBs stay as leftovers

    # Group once instead of re-filtering the whole table for every cluster id.
    members_by_cluster = {
        cid: members.tolist()
        for cid, members in cluster_table.groupby("cluster_id")["pdb"]
    }
    max_cluster_id = max(members_by_cluster, default=-1)

    cluster_i = 0
    for split_i in range(num_splits):
        # Count *distinct* PDB codes: a cluster may list the same PDB several times,
        # while min_split_size is measured in distinct PDBs.
        split_pdbs = set()
        while len(split_pdbs) < min_split_size and cluster_i <= max_cluster_id:
            members = members_by_cluster.get(cluster_i, [])
            if all(pdb in all_pdbs for pdb in members):
                split_pdbs.update(members)
            cluster_i += 1

        # NOTE(review): a PDB listed in several clusters ends up with the label of the
        # last split that claimed it.
        df.loc[df["pdb"].isin(split_pdbs), "validation"] = split_i

In [None]:
# Fill 18 folds from whole clusters; rows left unassigned keep the sentinel label 18
assign_clusters(df=merged, num_splits=18)

In [None]:
# The folds are slightly imbalanced in size; this should not matter much, even for cross-validation.
split_sizes = merged["validation"].value_counts()
split_sizes


In [None]:
# Rows still carrying the sentinel split label (= num_splits passed to assign_clusters, 18)
# belong to PDBs not placed in any split; they may overlap other clusters, so drop them.
# A boolean mask is used rather than drop-by-label: the index (complex IDs) is never checked
# for uniqueness, and dropping a duplicated label would also remove assigned rows.
LEFTOVER_SPLIT = 18
merged = merged.loc[merged["validation"] != LEFTOVER_SPLIT].copy()

In [None]:
# Split 0 doubles as the held-out test set
merged = merged.assign(test=merged["validation"].eq(0))

In [None]:
# Remove every column that has a missing value in any row (e.g. fields the appended
# wild-type rows lack), plus the redundant wt_foldx_dg and complex columns
# (the complex ID is still written out as the index).
full_table = (
    merged
    .dropna(axis="columns", how="any")
    .drop(columns=["wt_foldx_dg", "complex"])
)
full_table.to_csv(snakemake.output["full"])
