In [None]:
from math import sqrt
import random
import src.generate_encodings as ge
import src.prediction_models as pm
import src.predictor_optimizer as pop
from copy import copy
import warnings
import os, sys
from tqdm import tqdm
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np

In [None]:
# Global run switches for this notebook.
# `debugging` is not read by any cell visible here — presumably a leftover verbosity toggle; TODO confirm.
debugging = False
# When True, the initial hyperparameter search is skipped and hard-coded parameter sets are used instead.
demo_case = True
# When True, the hyperparameter optimizer is re-run at the end of every MLDE cycle.
update_params_after_each_cycle = True

In [None]:
class HiddenPrints:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is redirected to ``os.devnull``; on exit the stream
    that was active at entry is restored and the devnull handle is closed.
    Exceptions raised inside the ``with`` block are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference so that __exit__ closes exactly the handle we
        # opened, even if the body of the with-block swapped sys.stdout again.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout is usable again even if closing fails.
        sys.stdout = self._original_stdout
        self._devnull.close()
        return False


class HiddenWarnings():
    """Context manager that silences all warnings inside a ``with`` block.

    Saving and restoring the warning state is delegated to
    ``warnings.catch_warnings`` instead of rebinding ``warnings.filters`` by
    hand, so the module's own bookkeeping stays consistent.
    Exceptions raised inside the ``with`` block are not suppressed.
    """

    def __enter__(self):
        # Let the standard library snapshot the current warning configuration.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        # Ignore all warnings
        warnings.filterwarnings("ignore")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original warning filter settings
        self._catcher.__exit__(exc_type, exc_val, exc_tb)
        return False

In [None]:
"""defining Parameters for Input_Data"""

# --- dataset / sequence settings ---
dataset = "D7PM05"
repr_type = "blosum80"
# Score threshold that separates active from inactive mutants; the same value
# is used for all datasets (the data has been preprocessed).
activity_threshold = 1


"""Defining further ModelParameter"""

# --- MLDE benchmark settings ---
# Percentage cutoff of the top-scoring datapoints that the MLDE cycles should identify.
# NOTE(review): this is a fraction, not a count.
n_top = 0.95
# Number of new sequences tested per cycle; realistic range: [5, 10, 20, 100].
n_samples = 10
# Size of the initial training set (up to 200 points would be common practice).
n_starting_points = 1000
# Upper limit for the scores allowed in the initial training set, counted from
# the lowest-scoring datapoint. Default: median (0.5).
highest_starting_fraction = 0.40
n_cycles = 100

# --- predictor settings ---
# Supported: xgboost, rf, lightgbm, adaboost, svr, linear, ridge, lasso
model_type = "svr"
cv_folds = 5
early_stopping_fraction = 0.1

# --- hyperparameter tuning settings ---
# Trials per parameter group for the one-off search before the benchmark starts.
if model_type in ["linear", "ridge", "lasso"]:
    initial_trials = 10000
else:
    initial_trials = 500
# Trials per parameter group for the search repeated after each cycle.
n_trials = 200

In [None]:
#prepare Dataset and Embeddings

# NOTE(review): "esmc300m" has no underscore while "esmc_600m" does — confirm
# which spelling the rest of the project uses.
if repr_type in ["esmc_600m", "esmc300m"]:
    # Fail fast: without this, the cell would continue and crash further down
    # on the undefined `data` variable.
    raise NotImplementedError("Loading ESM Embeddings not implemented yet.")

import_data = f"../Data/Protein_Gym_Datasets/{dataset}.csv"
ids = []
sequences = []
scores = []
wildtype_seq = None

with open(import_data, "r") as infile:
    next(infile, None)  # skip the header row
    for raw_line in infile:
        # Strip only the line terminator; slicing off the last character
        # would corrupt the final row when the file has no trailing newline.
        fields = raw_line.rstrip("\r\n").split(",")
        if fields == [""]:
            continue  # tolerate blank lines
        # Column layout used here: 0 = id, 1 = sequence, 4 = score, 6 = wildtype sequence.
        ids.append(fields[0])
        sequences.append(fields[1])
        scores.append(fields[4])
        if wildtype_seq is None:
            # The wildtype sequence is taken from the first data row.
            wildtype_seq = fields[6]

if not ids:
    raise ValueError(f"No datapoints found in {import_data}")

# Each record is (id, sequence, encoding, score); the score is rounded to 3 decimals.
# The encoder is called once per sequence because its batching behaviour is
# not visible from this notebook.
data = [(mutant_id,
         mutant_sequence,
         ge.generate_sequence_encodings(repr_type, [mutant_sequence])[0],
         round(float(raw_score), 3))
        for mutant_id, mutant_sequence, raw_score in zip(ids, sequences, scores)]

# Ascending by score, so the first/last records hold the minimum/maximum.
data = sorted(data, key=lambda isxy: isxy[3])
lowest_data_score = data[0][3]
highest_data_score = data[-1][3]

print("Applied Data-Set:", dataset)
print("Number of datapoints:", len(data))
print("Wildtype Sequence:", wildtype_seq)
print("maximum score:", highest_data_score)
print("minimum score:", lowest_data_score)
print("Representation Type:", repr_type)
print("Activity Threshold:", activity_threshold)
print("Maximum Starting Score:", round(highest_data_score * highest_starting_fraction, 3))
print("Data has been loaded and encoded.")

In [None]:
#Performance evaluation

def pearson_correlation(x, y):
    """Return the Pearson correlation coefficient of two equally long sequences.

    The coefficient is computed from mean-centred values. The raw-sum formula
    subtracts large, nearly equal numbers and can yield a negative radicand,
    which makes ``** 0.5`` return a complex number.

    Args:
        x: sequence of numbers.
        y: sequence of numbers with the same length as ``x``.

    Returns:
        The correlation as a float in [-1, 1].

    Raises:
        ValueError: if the lengths differ, the input is empty, or one of the
            sequences has no variance (correlation undefined).
    """
    if len(x) != len(y):
        raise ValueError("Lists x and y must have the same length.")
    n = len(x)
    if n == 0:
        raise ValueError("Denominator is zero, correlation undefined.")

    mean_x = sum(x) / n
    mean_y = sum(y) / n
    dev_x = [xi - mean_x for xi in x]
    dev_y = [yi - mean_y for yi in y]

    numerator = sum(dx * dy for dx, dy in zip(dev_x, dev_y))
    # Both factors are sums of squares, so the radicand cannot be negative.
    denominator = sqrt(sum(dx * dx for dx in dev_x) * sum(dy * dy for dy in dev_y))

    if denominator == 0:
        raise ValueError("Denominator is zero, correlation undefined.")
    # Clamp away rounding noise that would push the result just outside [-1, 1].
    return max(-1.0, min(1.0, numerator / denominator))

def spearman_correlation(x, y):
    """Return the Spearman rank correlation of two equally long sequences.

    Both inputs are converted to 1-based ranks (tied values share the average
    of the ranks they span) and the ranks are passed to ``pearson_correlation``.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length.
    """
    if len(x) != len(y):
        raise ValueError("Lists x and y must have the same length.")

    def rank(data):
        # Positions of `data` ordered by value, ties broken by position.
        order = sorted(range(len(data)), key=lambda position: (data[position], position))
        total = len(order)
        ranks = [0] * total
        start = 0
        while start < total:
            pivot = data[order[start]]
            # Advance `stop` to one past the last entry tied with the pivot.
            stop = start
            while stop < total and data[order[stop]] == pivot:
                stop += 1
            # Mean of the 1-based ranks start+1 .. stop.
            shared_rank = (start + stop - 1) / 2 + 1
            for position in order[start:stop]:
                ranks[position] = shared_rank
            start = stop
        return ranks

    return pearson_correlation(rank(x), rank(y))


def r2(y_trues: list, y_preds: list) -> float:
    """Return the coefficient of determination (R2) of predictions.

    R2 = 1 - RSS / TSS, where RSS is the sum of squared residuals and TSS is
    the sum of squared deviations of the TRUE values from their mean.

    Args:
        y_trues: observed values (anything ``float()`` accepts).
        y_preds: predicted values, same length as ``y_trues``.

    Returns:
        The R2 score. When the true values have no variance, a small epsilon
        replaces TSS so the division stays defined.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    if len(y_trues) != len(y_preds):
        raise ValueError("y_trues and y_preds must have the same length.")
    if len(y_trues) == 0:
        raise ValueError("R2 is undefined for empty input.")

    truths = [float(y) for y in y_trues]
    predictions = [float(y) for y in y_preds]
    true_mean = sum(truths) / len(truths)

    rss = sum((y_true - y_pred) ** 2 for y_true, y_pred in zip(truths, predictions))
    # The total sum of squares is taken over the observed values, not the predictions.
    tss = sum((y_true - true_mean) ** 2 for y_true in truths)
    EPSILON = 1e-9  # `1 ** -9` would evaluate to 1.0

    return 1 - (rss / EPSILON) if tss == 0 else 1 - (rss / tss)


def rmse(y_trues: list, y_preds: list) -> float:
    """Return the root mean squared error between observed and predicted values.

    Args:
        y_trues: observed values (anything ``float()`` accepts).
        y_preds: predicted values, same length as ``y_trues``.

    Returns:
        The RMSE; 0.0 for a perfect prediction.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    if len(y_trues) != len(y_preds):
        raise ValueError("y_trues and y_preds must have the same length.")
    n = len(y_preds)
    if n == 0:
        raise ValueError("RMSE is undefined for empty input.")

    mse = sum((float(y_true) - float(y_pred)) ** 2 for y_true, y_pred in zip(y_trues, y_preds)) / n
    # sqrt(0) is well defined, so no epsilon is needed; the former
    # `1 ** -9` "epsilon" was 1.0 and turned a perfect fit into RMSE 1.0.
    return sqrt(mse)

In [None]:
#Function to display chosen datapoints

# Reads the notebook-level globals `data` (score range of the full dataset)
# and `dataset` (name shown in the plot title).

def display_datapoints_distribution(y_mlde, remaining_data, iteration: int = None, show_in_browser: bool = False):
    """Plot the score distribution of known versus still undiscovered datapoints.

    Scores are snapped to the nearest of 100 evenly spaced bin edges spanning
    the score range of the global ``data`` and drawn as two overlaid bar
    series on a logarithmic count axis.

    Args:
        y_mlde: scores of the datapoints already available to the MLDE model.
        remaining_data: records whose element at index 3 is the score.
        iteration: optional cycle number mentioned in the plot title.
        show_in_browser: render in the browser instead of inline in the notebook.

    Returns:
        The plotly figure that was shown.

    Side effects:
        Sets ``plotly.io.renderers.default`` and displays the figure.
    """
    import plotly.io as pio

    all_scores = [float(isxy[3]) for isxy in data]
    min_score = round(min(all_scores), 3)
    max_score = round(max(all_scores), 3)

    n_bins = 100
    plotwidth = 1000

    binned_mlde_scores = dict()
    binned_remaining_data = dict()

    bin_edges = np.linspace(min_score, max_score, n_bins)

    def bin_value(y, bin_edges):
        """Return the bin edge closest to ``y``; ties go to the upper edge."""
        y = float(y)
        # Index of the first edge that is >= y.
        upper_index = int(np.searchsorted(bin_edges, y, side="left"))
        if upper_index == 0:
            # At or below the first edge. The former lower sentinel was
            # `10 ** -9`, a positive number close to zero that could be
            # returned as a bogus bin.
            return float(bin_edges[0])
        if upper_index == len(bin_edges):
            return float(bin_edges[-1])
        lower = float(bin_edges[upper_index - 1])
        upper = float(bin_edges[upper_index])
        return lower if abs(y - lower) < abs(y - upper) else upper

    # Count the known (MLDE) scores per bin
    for y in y_mlde:
        binned = bin_value(y, bin_edges)
        binned_mlde_scores[binned] = binned_mlde_scores.get(binned, 0) + 1

    # Count the scores of the remaining datapoints per bin
    for isxy in remaining_data:
        binned = bin_value(isxy[3], bin_edges)
        binned_remaining_data[binned] = binned_remaining_data.get(binned, 0) + 1

    # Create ONE subplot only
    scoring_plt = make_subplots(rows=1, cols=1)

    # Add grey bars first (background)
    scoring_plt.add_trace(
        go.Bar(
            name="Remaining-Data-Points",
            x=list(binned_remaining_data.keys()),
            y=list(binned_remaining_data.values()),
            marker=dict(color="grey")
        ),
        row=1, col=1
    )

    # Add red bars on top (foreground)
    scoring_plt.add_trace(
        go.Bar(
            name=f"MLDE-Starting-Points ({sum(binned_mlde_scores.values())} Points)",
            x=list(binned_mlde_scores.keys()),
            y=list(binned_mlde_scores.values()),
            marker=dict(color="red")
        ),
        row=1, col=1
    )

    if iteration is not None:
        title = f"Known vs To-Discover-Datapoints from {dataset}-Dataset used within {iteration}. MLDE Cycle"
    else:
        title = f"Known vs To-Discover-Datapoints from {dataset}-Dataset used within MLDE Cycle"

    scoring_plt.update_layout(
        title_text=title,
        title_font=dict(color="black", size=20),
        showlegend=True,
        barmode='overlay',  # Important for overlapping bars
        paper_bgcolor='rgb(233,233,233)',
        plot_bgcolor='rgb(233,233,233)',
        width=plotwidth,
        legend=dict(
            font=dict(color="black", size=12)
        )
    )

    scoring_plt.update_yaxes(
        dict(
            type="log",
            title_text="Number of sequences",
            title_font=dict(color="black"),
            color='black',
            showgrid=True,
            gridcolor='grey',
            griddash="dot",
            gridwidth=0.1
        )
    )

    scoring_plt.update_xaxes(
        dict(
            title_text="Binned Activity Score of sequence",
            title_font=dict(color="black"),
            range=[min_score, max_score],
            color='black',
            showgrid=True,
            gridcolor='grey',
            griddash="dot",
            gridwidth=0.1
        )
    )

    if show_in_browser:
        pio.renderers.default = "browser"
    else:
        pio.renderers.default = "notebook_connected"
    scoring_plt.show()
    return scoring_plt

In [None]:
#Preparing Datapoints for Benchmark

# Records in `data` are (id, sequence, encoding, score) tuples.
# Membership tests below compare ids only; comparing whole tuples would also
# compare the encodings. Assumes ids are unique within the dataset — TODO confirm.

min_starting_fraction = 0.2  # share of the starting set drawn from the lowest-scoring datapoints

#check whether MLDE Starting Criteria have been chosen wisely

max_starting_score = round(max(isxy[3] for isxy in data) * highest_starting_fraction, 3)

print(f"Highest possible Activity Score within initial MLDE Startingset: {max_starting_score}")
if max_starting_score < activity_threshold:
    warnings.warn("The highest starting Datapoint is still in the range of inactive classified sequences. The to be trained model will perform poorly")
else:
    print("The highest possible starting Datapoint exceeds the range of inactive classified sequences.\n"
          "The MLDE model will benefit herefrom.")

try:
    potential_starting_points = [isxy for isxy in data if isxy[3] <= max_starting_score]
    # Computed once instead of once per record.
    lowest_score = min(isxy[3] for isxy in potential_starting_points)
    lowest_scoring_points = [isxy for isxy in potential_starting_points if isxy[3] == lowest_score]
    min_starting_points = random.sample(lowest_scoring_points,
                                        round(n_starting_points * min_starting_fraction))
    min_starting_ids = {isxy[0] for isxy in min_starting_points}
    remaining_starting_points = random.sample(
        [isxy for isxy in potential_starting_points if isxy[0] not in min_starting_ids],
        n_starting_points - len(min_starting_points))
except ValueError as error:
    # random.sample / min raise ValueError when too few candidates exist.
    raise ValueError(
        "Amount of available, allowed Starting Points does not meet the Criteria of "
        "n_starting_points, min_starting_fraction and highest_starting_fraction!") from error

mlde_datapoints = min_starting_points + remaining_starting_points

sequences_mlde = [isxy[1] for isxy in mlde_datapoints]
x_mlde = [isxy[2] for isxy in mlde_datapoints]
y_mlde = [isxy[3] for isxy in mlde_datapoints]
print("Declared MLDE Starting-Points and remaining dataset")
mlde_ids = {isxy[0] for isxy in mlde_datapoints}
remaining_data = [isxy for isxy in data if isxy[0] not in mlde_ids]
random.shuffle(remaining_data)

# Turn n_top into an integer count before slicing.
# NOTE(review): a fractional n_top is interpreted as a quantile cutoff
# (0.95 -> best 5 % of the remaining data); confirm this is the intended meaning.
ranked_remaining = sorted(remaining_data, key=lambda isxy: isxy[3], reverse=True)
if isinstance(n_top, float) and 0 < n_top < 1:
    n_top_count = max(1, round(len(ranked_remaining) * (1 - n_top)))
else:
    n_top_count = int(n_top)
top_candidates = ranked_remaining[:n_top_count]

top_mutants = [isxy[0] for isxy in top_candidates if isxy[3] >= activity_threshold]
top_x = [isxy[2] for isxy in top_candidates]
top_y = [isxy[3] for isxy in top_candidates]
# print([round(float(y),2) for y in top_y])
with HiddenWarnings():
    with HiddenPrints():
        display_datapoints_distribution(y_mlde=y_mlde, remaining_data=remaining_data, show_in_browser=True)

In [None]:
"""Create a suitable hyperparameter-set for the initial MLDE-Model"""
# Runs only when demo_case is False; otherwise `mlde_params` is not set by this cell.
if not demo_case:
    # The optimizer receives the MLDE starting set (x_mlde / y_mlde), starts from an
    # empty parameter dict and gets `initial_trials` as its trials-per-group budget.
    # NOTE(review): the semantics of the keyword arguments are defined in
    # src.predictor_optimizer, which is not visible from this notebook.
    mlde_optimizer = pop.Sequential_Optimizer(model_type=model_type, cv_folds=cv_folds, x_arr=x_mlde, y_arr=y_mlde,
                                              initial_params={},
                                              trials_per_group=initial_trials,
                                              early_stopping_fraction=early_stopping_fraction,
                                              n_jobs=1)

    mlde_optimizer.optimize_stepwise()
    # `best_trial` is not read by any cell visible here.
    best_trial = mlde_optimizer.get_best_trial()
    # Parameter set handed to the predictor in the benchmark loop.
    mlde_params = mlde_optimizer.get_best_params()


In [None]:
# Hard-coded parameter sets used when demo_case is True.
mlde_params_xgboost = {'subsample': 0.4919977609964233, 'colsample_bytree': 0.44409804279321397, 'max_depth': 82,
                       'min_child_weight': 0.2037190569366557, 'reg_alpha': 3.1708589729977215,
                       'reg_lambda': 6.826403977807785}

mlde_params_lightgbm = {'min_data_in_leaf': 5, 'num_leaves': 15, 'min_data_in_bin': 1,
                        'feature_fraction': 0.2721281361692248, 'learning_rate': 0.2999956879962157, 'max_bin': 35,
                        'n_estimators': 287, 'bagging_fraction': 0.5763358085600521}

if demo_case:
    demo_params_by_model = {"xgboost": mlde_params_xgboost, "lightgbm": mlde_params_lightgbm}
    if model_type not in demo_params_by_model:
        # Previously every model type other than xgboost received the LightGBM
        # preset, whose keys belong to a different estimator.
        # NOTE(review): assumes the predictor accepts an empty parameter dict,
        # as the optimizer does for `initial_params={}` — confirm.
        warnings.warn(f"No demo parameter set available for model_type '{model_type}'; using an empty parameter set.")
    # Copy so that later changes to mlde_params leave the presets untouched.
    mlde_params = dict(demo_params_by_model.get(model_type, {}))


In [None]:
# MLDE benchmark loop.
# Records in `remaining_data` are (id, sequence, encoding, score) tuples:
# index 2 is the model input, index 3 the true score.
finished = False
print(
    f"starting the MLDE-Benchmark with {n_cycles} cycles, starting at {n_starting_points} sequences. Discovery of {n_samples} per Cycle to identify the top {n_top} sequences from list."
)
print()

results = []  # per cycle: (cycle number, highest true score, mean true score, standard deviation) of the selected sequences
val_performances = []  # per cycle: (R2, RMSE) reported by the trained model
test_performances = []  # per cycle: (R2, RMSE) on all remaining datapoints

previous_R2 = -999
current_R2 = -999
best_R2 = float(current_R2)

max_training_attempts = 5000  # retrain cap per cycle
top_mutant_ids = set(top_mutants)  # ids of the sequences the benchmark tries to find
cycles_run = 0

for i in range(n_cycles):
    if not remaining_data:
        print("No undiscovered datapoints left; stopping the benchmark.")
        break
    cycles_run = i + 1

    print(f"########################## Starting Cycle {i + 1}/{n_cycles} ##########################")
    print()
    counter = 0
    if i > 0:
        print("Training a more effective MLDE predictor for the next iteration")

    best_cycle_R2 = None
    best_cycle_model = None
    training_successful = True

    # Retrain until the validation R2 beats the previous cycle's R2 or the
    # attempt cap is reached. current_R2 equals previous_R2 on entry, so at
    # least one model is trained per cycle.
    while not float(current_R2) > float(previous_R2):
        mlde_model = pm.ActivityPredictor(model_type=model_type,
                                          x_arr=x_mlde,
                                          y_arr=y_mlde,
                                          shuffle_data=True,
                                          early_stopping=10,
                                          params=mlde_params)
        with HiddenPrints():
            with HiddenWarnings():
                mlde_model.train(k_folds=cv_folds)
                current_R2 = mlde_model.get_performance()[0]
        counter += 1

        # A fresh model object is built on every attempt, so keeping a
        # reference is enough to remember the best one.
        if best_cycle_R2 is None or current_R2 > best_cycle_R2:
            best_cycle_R2 = current_R2
            best_cycle_model = mlde_model

        if counter >= max_training_attempts:
            mlde_model = best_cycle_model
            # Keep the tracked R2 in sync with the model that is actually used.
            current_R2 = best_cycle_R2
            training_successful = False
            break

    performance = mlde_model.get_performance()
    val_R2 = round(performance[0], 3)
    # NOTE(review): presumably index 1 holds the RMSE (index 0 is the R2) — confirm
    # against prediction_models.ActivityPredictor.get_performance.
    val_RMSE = round(performance[1], 3)
    val_performances.append((val_R2, val_RMSE))

    if training_successful:
        print(
            f"Cycle No {i + 1}/{n_cycles}: Model Performance after {counter} training attempts: "
            f"{val_R2}, {val_RMSE}")

    else:
        print(
            f"Current Model Performance does not exceed the last iterations model  after {counter} iterations of Training."
            f" Therefore training has been stopped and the best attempt is now used for prediction.")

    list_predictions = []
    with tqdm(total=len(remaining_data), desc="Predicting all remaining datapoints within the dataset...") as pbar:
        with HiddenPrints():
            with HiddenWarnings():
                for isxy in remaining_data:
                    list_predictions.append(mlde_model.predict([isxy[2]])[0])
                    pbar.update(1)

    remaining_true_scores = [isxy[3] for isxy in remaining_data]
    r2_test = round(r2(remaining_true_scores, list_predictions), 3)
    rmse_test = round(rmse(remaining_true_scores, list_predictions), 3)

    print(f"Model Performance on all Datapoints: {r2_test}, {rmse_test}")
    test_performances.append((r2_test, rmse_test))

    # Pair every record with its prediction and keep the n_samples best-predicted ones.
    top_hits = sorted(zip(remaining_data, list_predictions),
                      key=lambda hit: hit[1], reverse=True)[:n_samples]
    top_predictions = [(isxy[0], isxy[2], prediction) for isxy, prediction in top_hits]

    print(f"Top {n_samples} predicted sequences' pred scores: {[round(sxy[2], 2) for sxy in top_predictions]}")

    # "Measure" the selected sequences: reveal their true scores and add them to the training set.
    top_predictions_y_trues = [isxy[3] for isxy, _ in top_hits]
    x_mlde.extend(isxy[2] for isxy, _ in top_hits)
    y_mlde.extend(top_predictions_y_trues)

    mlde_data = [(x, y) for x, y in zip(x_mlde, y_mlde)]
    random.shuffle(mlde_data)
    x_mlde = [xy[0] for xy in mlde_data]
    y_mlde = [xy[1] for xy in mlde_data]

    # Compare ids rather than scores: equal rounded scores do not imply the same sequence.
    selected_ids = {isxy[0] for isxy, _ in top_hits}
    for mutant_id in selected_ids:
        if mutant_id in top_mutant_ids:
            print(f"Found one of the top {n_top} sequence during Cycle {i + 1}.")
            finished = True

    # Use the actual number of selected sequences; it is smaller than
    # n_samples when fewer datapoints remain.
    n_selected = len(top_predictions_y_trues)
    highest_score = max(top_predictions_y_trues)
    mean_score = round((sum(top_predictions_y_trues) / n_selected), 3)
    standard_dev = round(sqrt(sum([(y - mean_score) ** 2 for y in top_predictions_y_trues]) / n_selected), 3)
    results.append((i + 1, highest_score, mean_score, standard_dev))

    print(
        f"Top {n_samples} predicted sequences' true scores: {[round(y_true, 2) for y_true in top_predictions_y_trues]}")
    print()

    if finished:
        break

    #update remaining data
    remaining_data = [isxy for isxy in remaining_data if isxy[0] not in selected_ids]
    random.shuffle(remaining_data)

    #update last iterations previous_R2
    previous_R2 = current_R2

    #optional: update params
    if update_params_after_each_cycle:
        print("Updating MLDE Parameters...")
        mlde_optimizer = pop.Sequential_Optimizer(model_type=model_type, cv_folds=cv_folds, x_arr=x_mlde, y_arr=y_mlde,
                                                  initial_params=copy(mlde_params),
                                                  trials_per_group=int(n_trials),
                                                  early_stopping_fraction=early_stopping_fraction,
                                                  n_jobs=1)

        with HiddenPrints():
            mlde_optimizer.optimize_stepwise()
            best_trial = mlde_optimizer.get_best_trial()
        optimized_params = mlde_optimizer.get_best_params()
        if mlde_params != optimized_params:
            mlde_params = optimized_params
            print("Hyperparameters updated.")
        else:
            print("Current Hyperparameters maintained.")

    print()

# Built outside the f-string: nested double quotes are a syntax error before Python 3.12.
outcome = "successfully" if finished else "without success"
print(f"MLDE-Performance-Ranking finished after {cycles_run}/{n_cycles} cycles {outcome}.")

In [None]:
"""Display Results: Development of Predictions and Mutant-Selection"""

# make_subplots and go are imported in the notebook's first cell.
import plotly.io as pio


def _line_trace(name, xs, ys, color):
    """Build one line trace in the style shared by all result curves."""
    return go.Scatter(name=name, x=list(xs), y=list(ys),
                      marker=dict(color=color, size=3),
                      mode="lines")


# subplot_titles expects one title per subplot, not a single string.
results_plot = make_subplots(
    subplot_titles=["Scores of the selected mutants", "Model performance"],
    rows=1, cols=2)

# First plot: results tracking. Each entry of `results` is
# (cycle number, highest score, mean score, standard deviation).
result_cycles = [result[0] for result in results]
results_plot.add_trace(
    _line_trace("highest scoring mutant", result_cycles,
                [round(float(result[1]), 3) for result in results], "darkcyan"),
    row=1, col=1)

results_plot.add_trace(
    _line_trace("average of mutants' score", result_cycles,
                [round(float(result[2]), 3) for result in results], "red"),
    row=1, col=1)

results_plot.add_trace(
    _line_trace("Standard Deviation of mutant scores", result_cycles,
                [round(float(result[3]), 3) for result in results], "grey"),
    row=1, col=1)

# Second plot: performance tracking. Cycle numbers start at 1, matching the first plot.
val_cycles = range(1, len(val_performances) + 1)
test_cycles = range(1, len(test_performances) + 1)

# The validation curves read val_performances (they used to plot the test metrics).
results_plot.add_trace(
    _line_trace("R2 for MLDE Validational Datapoints", val_cycles,
                [scores[0] for scores in val_performances], "lightgreen"),
    row=1, col=2)

results_plot.add_trace(
    _line_trace("RMSE for MLDE Validational Datapoints", val_cycles,
                [scores[1] for scores in val_performances], "darkmagenta"),
    row=1, col=2)

results_plot.add_trace(
    _line_trace("R2 for remaining Dataset", test_cycles,
                [scores[0] for scores in test_performances], "darkgreen"),
    row=1, col=2)

results_plot.add_trace(
    _line_trace("RMSE for remaining Dataset", test_cycles,
                [scores[1] for scores in test_performances], "magenta"),
    row=1, col=2)

# `dataset` replaces the undefined name `data_set_name_shortened`.
results_plot.update_layout(
    title_text=f"Development of Predictions and Mutant-Selection to identify the {n_top} active sequences testing {n_samples} new samples per iteration with {model_type}-{repr_type} for {dataset}",
    title_font=dict(color="black", size=20),
    showlegend=True,
    paper_bgcolor='rgb(233,233,233)',
    plot_bgcolor='rgb(233,233,233)',
    # height=1000,
    width=1500,
    legend=dict(font=dict(color="black",
                          size=12)))

# Grid and colour settings shared by every axis.
axis_style = dict(
    title_font=dict(color="black"),
    color='black',
    showgrid=True,
    gridcolor='grey',
    griddash="dot",
    gridwidth=0.2)

results_plot.update_yaxes(dict(title_text="Activity-Score", **axis_style), row=1, col=1)
results_plot.update_xaxes(dict(title_text="cycles", **axis_style), row=1, col=1)
results_plot.update_yaxes(dict(title_text="R2 / RMSE", **axis_style), row=1, col=2)
results_plot.update_xaxes(dict(title_text="cycles", **axis_style), row=1, col=2)

pio.renderers.default = "browser"
with HiddenWarnings():
    results_plot.show()
    results_plot.write_image(f"MLDE_Development_{dataset}_{repr_type}_{n_samples}-samples.jpg")