# GFP mutant prediction

In this notebook, we classify GFP mutants into a "bright" and a "dark" class (median brightness threshold of 2.5), using UniRep representations of the GFP protein sequences as features.

A conda environment file containing all dependencies needed to re-run this analysis is available at `paper/environment-gfp.yml`.

In [None]:
from functools import partial
from pathlib import Path

import janitor
import numpy as np
import pandas as pd
from sklearn.pipeline import make_pipeline

In [None]:
# Wild-type GFP reference sequence. mut2seq uses each mutation's position as a
# 0-based index into this string.
# NOTE(review): the string starts at "S"; confirm the dataset's position
# numbering matches this offset.
wt ="SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"

In [None]:
def aa_pos(mut):
    """
    Return mutation position from mutation string.

    The first two characters and the last character are stripped and the
    remainder is parsed as an integer. Presumably the two leading characters
    are a one-letter prefix followed by the wild-type residue — confirm
    against the data.

    Example: SA111C -> 111

    Falsy input (``""`` or ``None``) is returned unchanged.
    """
    if not mut:
        return mut
    # Skip the two-character prefix and the trailing substituted residue.
    return int(mut[2:-1])


def aa_letter(mut):
    """
    Return the new (substituted) residue letter of a mutation string.

    The letter is the final character, e.g. SA111C -> C.
    Falsy input (``""`` or ``None``) is returned unchanged.
    """
    return mut[-1] if mut else mut


def mut2seq(mutation_string, wt_sequence, delimiter=":"):
    """
    Reconstruct full mutant sequence given mutation string.

    Each mutation is parsed with `aa_pos` / `aa_letter`; the parsed position
    is used directly as a 0-based index into `wt_sequence` (no offset is
    subtracted).

    Example mutation_strings (default ":" delimiter):
    - SA111C
    - SA111T:SV130A
    - SA111T:SQ194R:SN249I:SN251Y:SH255Y

    Parameters
    ----------
    mutation_string : str or None
        Delimiter-separated mutations; ``None`` or ``""`` means wild type.
    wt_sequence : str
        Reference (wild-type) sequence.
    delimiter : str, default ":"
        Separator between individual mutations.

    Returns
    -------
    str
        The mutant sequence, same length as `wt_sequence`.

    Raises
    ------
    ValueError
        If a mutation has position 0, or a position that is not a valid
        index into `wt_sequence`.
    """
    if mutation_string is None or mutation_string == "":
        return wt_sequence

    mutant_sequence = list(wt_sequence)
    for mut in mutation_string.split(delimiter):
        mut = mut.strip()
        if not mut:
            # Tolerate stray/trailing delimiters instead of failing later on
            # an empty token.
            continue
        position = aa_pos(mut)
        letter = aa_letter(mut)
        if position == 0:
            # NOTE(review): positions are used as 0-based indices below, so
            # rejecting 0 makes the first residue unmutable — confirm this
            # matches the dataset's numbering.
            raise ValueError(
                f"""
            The mutation string {mut} is invalid.
            It has "0" as its position.
            """
            )
        # `>=` (not `>`): index len(wt_sequence) is already out of range, and
        # negative positions would silently index from the end.
        if position < 0 or position >= len(wt_sequence):
            raise ValueError(
                f"""
            The mutation string {mut} is invalid.
            Its position is outside the range of the WT sequence.
            """
            )
        mutant_sequence[position] = letter
    return "".join(mutant_sequence)

In [None]:
def count_mutations(x, delimiter=":"):
    """
    Count the mutations in a delimiter-separated mutation string.

    Parameters
    ----------
    x : str or None
        Mutation string such as "SA111T:SV130A"; ``""`` or ``None`` denotes
        wild type.
    delimiter : str, default ":"
        Separator between mutations (same default as `mut2seq`).

    Returns
    -------
    int
        Number of non-empty mutation tokens (0 for wild type).
    """
    if not x:
        return 0
    # Ignore empty tokens produced by stray/trailing delimiters.
    return sum(1 for token in x.split(delimiter) if token.strip())


In [None]:
# mut2seq with the GFP wild-type sequence bound: maps a mutation string
# directly to the full mutant GFP sequence.
mut2gfp = partial(mut2seq, wt_sequence=wt)

## Prepare data

In [None]:
# Load genotypes and brightness, then derive per-variant columns:
# - missing `aaMutations` (wild type) become "" so parsing works downstream,
# - genotypes containing "*" are dropped,
# - `sequence`: full mutant sequence, `log_bright`: natural log of
#   `medianBrightness`, `length`: sequence length, `mutation_count`: number
#   of mutations in `aaMutations`.
data = (pd.read_csv("data/amino_acid_genotypes_to_brightness.tsv", sep='\t')
        .fill_empty("aaMutations", "")
        # Raw string: the regex needs a literal backslash before "*", and
        # "\*" in a normal literal is an invalid escape sequence.
        .filter_string("aaMutations", search_string=r"\*", complement=True)
        .transform_column("aaMutations", mut2gfp, "sequence")
        .transform_column("medianBrightness", np.log, "log_bright")
        .transform_column("sequence", len, "length")
        .transform_column("aaMutations", count_mutations, "mutation_count")
       )

In [None]:
# Peek at the first rows, including the derived sequence/length/mutation_count columns.
data.head(3)

In [None]:
# Show one variant whose mutation string contains a letter-"2"-letter pattern,
# i.e. a mutation at position 2 (useful for spot-checking position indexing).
# Boolean indexing instead of `query`: method calls like `.str.contains` inside
# a query string can fail under the numexpr engine.
has_pos2 = data["aaMutations"].str.contains(r"[A-Z]2[A-Z]")
data.loc[has_pos2].head(1)

In [None]:
# Distinct numbers of mutations per variant.
data.mutation_count.unique()

In [None]:
# Distinct sequence lengths; mut2seq only substitutes residues, so every
# sequence is expected to have the WT length.
data.length.unique()

## Get reps

We embed the sequences in chunks so as not to run out of memory. If you have plenty of memory available, you can increase the chunk size or compute the representations of all sequences at once (see the commented-out code below).

In [None]:
from jax_unirep import get_reps

In [None]:
# from: https://stackoverflow.com/questions/434287/what-is-the-most-pythonic-way-to-iterate-over-a-list-in-chunks
def chunker(seq, size):
    """Return a generator of consecutive slices of ``seq`` of length ``size``; the last may be shorter."""
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)

In [None]:
# Per-chunk accumulators for the three UniRep outputs.
h_final_l = []
c_final_l = []
h_avg_l = []

In [None]:
# Embed CHUNK_SIZE sequences at a time to limit memory use.
# The accumulators are reset here so re-running this cell doesn't append duplicates.
CHUNK_SIZE = 1000
h_final_l, c_final_l, h_avg_l = [], [], []
for seqs in chunker(data.sequence.values, CHUNK_SIZE):
    hf, cf, ha = get_reps(seqs)
    h_final_l.append(hf)
    c_final_l.append(cf)
    h_avg_l.append(ha)

In [None]:
# Stitch the per-chunk arrays back together along the sequence axis.
h_final, c_final, h_avg = (
    np.concatenate(chunks, axis=0) for chunks in (h_final_l, c_final_l, h_avg_l)
)

In [None]:
# Single-call alternative to the chunked loop above (requires enough memory
# to embed every sequence at once):
# h_final, c_final, h_avg = get_reps(data.sequence.values)

In [None]:
# First axis: one row per sequence (chunks were concatenated along axis 0).
h_avg.shape

In [None]:
# Store each representation as a per-row list column on `data`.
for column, reps in [("h_avg", h_avg), ("h_final", h_final), ("c_final", c_final)]:
    data[column] = reps.tolist()

In [None]:
def fusion(x):
    """
    Concatenate a row's UniRep outputs into a single vector, in the order
    h_final, c_final, h_avg.
    """
    parts = (x.h_final, x.c_final, x.h_avg)
    return np.concatenate(parts)

In [None]:
# Add a "unirep_fusion" column holding fusion(row) for every row.
data = data.join_apply(fusion, "unirep_fusion")

In [None]:
# Inspect the representation columns added above.
data.head(3)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Plot styling: white background, "paper" context with fonts scaled 1.5x.
sns.set_style("white")
sns.set_context(context="paper", font_scale=1.5)

In [None]:
# One row per variant: the h_avg features as integer-numbered columns plus the
# raw median brightness, indexed by the mutation string.
# Built directly from the list column (no dict round-trip + transpose of a
# large frame, and no in-place mutation).
df = pd.DataFrame(data["h_avg"].tolist(), index=data.index)
df["brightness"] = data["medianBrightness"].values
df = df.set_index(data["aaMutations"])
df.head(3)

In [None]:
# Binary target: 1 = "Bright" (brightness >= threshold), 0 = "Dark".
# This matches `labels = [0, 1]` / `classes = ["Dark", "Bright"]` used for the
# confusion matrix; the previous `< 2.5` coding put dark variants in class 1.
BRIGHTNESS_THRESHOLD = 2.5  # same cut-off as the dashed line in the histogram
df_bnry = df.copy()
df_bnry["brightness"] = (df["brightness"] >= BRIGHTNESS_THRESHOLD).astype(int)

In [None]:
# Shuffle with a fixed seed so the (unshuffled) KFold splits below, and hence
# the reported scores, are reproducible.
SEED = 42
X, y = (df_bnry
        .shuffle(random_state=SEED)
        .get_features_targets(target_column_names=['brightness']))

In [None]:
# Features and targets must stay row-aligned after the split.
assert len(X) == len(y), "X and y have different numbers of rows"
X.shape, y.shape

# Training and Testing

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, ShuffleSplit, cross_val_predict, KFold
from sklearn import preprocessing

### Logistic Regression on binary brightness

In [None]:
# Standardize each feature column (zero mean, unit variance).
# NOTE(review): the statistics come from all rows before cross-validation, so
# test-fold information leaks into training; fitting the scaler per fold
# (e.g. inside a Pipeline) avoids this.
X_pp = preprocessing.scale(X)

In [None]:
# Logistic regression with max_iter raised above scikit-learn's default
# (presumably to avoid convergence warnings — confirm).
logreg = LogisticRegression(max_iter=300)

In [None]:
# 5-fold CV without shuffling (rows were already shuffled when X/y were built).
# The scaler sits inside the pipeline, so it is refit on each training fold and
# test-fold statistics never reach the model (unlike pre-scaling all of X).
y_flat = y.values.ravel()
model = make_pipeline(preprocessing.StandardScaler(), logreg)
cv = KFold(n_splits=5)
scores = cross_val_score(model, X, y_flat, cv=cv, scoring='accuracy')
preds = cross_val_predict(model, X, y_flat, cv=cv)

In [None]:
# Per-fold accuracy from the 5-fold CV.
scores

In [None]:
# Mean accuracy across folds.
scores.mean()

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Row-normalized confusion matrix: rows are true classes, columns are
# predicted classes, and each row sums to 1 (the diagonal is per-class recall).
# `labels` must be passed by keyword; recent scikit-learn rejects it positionally.
labels = [0, 1]
cm = confusion_matrix(y.values.ravel(), preds, labels=labels)
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

In [None]:
# Row-normalized confusion matrix (rows: true class, columns: predicted class).
cm

In [None]:
# Display names for labels 0 and 1, in the order used by the confusion matrix.
classes = ["Dark", "Bright"]

In [None]:
fig, (ax1, ax2) = plt.subplots(figsize=(12, 6), nrows=1, ncols=2)

# Left panel: brightness distribution with the bright/dark cut-off. distplot
# overlays a KDE, so the y-axis is a density.
sns.distplot(data["medianBrightness"], ax=ax1)
ax1.axvline(x=2.5, c="r", ls="--", lw=2)
ax1.set(xlabel="median Brightness", ylabel="density")
sns.despine()

# Right panel: row-normalized confusion matrix of the cross-validated predictions.
df_cm = pd.DataFrame(cm, index=classes, columns=classes)

heatmap = sns.heatmap(df_cm,
                      annot=True,
                      fmt=".2f",
                      cmap="Blues",
                      cbar=False,
                      ax=ax2
                     )

ax2.tick_params(axis="y", labelrotation=90, labelsize=14)
ax2.tick_params(axis="x", labelrotation=0, labelsize=14)
ax2.set_ylabel('True label')
ax2.set_xlabel('Predicted label')

# Create the output directory so saving works on a fresh checkout.
figure_path = Path("figures") / "top_model.png"
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, bbox_inches='tight', dpi=200)