# Metrics on uniform vs skewed mutagenesis sequence distributions

Previously, we discussed how using sequences from a skewed mutagenesis distribution (e.g., the Mason and Brij datasets) causes problems when calculating performance metrics (accuracy, precision/recall, etc.), because positions important for binding are overrepresented in both binders and non-binders.


Below, I use a minimal simulation to show that, when calculating performance metrics, weighting the individual sequences of a dataset with a skewed character distribution can be used to estimate what the performance would have been on a dataset with a uniform character distribution. Each sequence is weighted by P(sequence|uniform distribution)/P(sequence|skewed distribution).
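Written out, and assuming (as in the simulation below) that positions are sampled independently, the weight of a sequence $s = s_1 \dots s_L$ factorizes over positions. The notation is introduced here for illustration only: $q_i(c)$ is the skewed frequency of character $c$ at position $i$, and $|\Sigma|$ is the alphabet size.

$$
w(s) = \frac{P(s \mid \text{uniform})}{P(s \mid \text{skewed})} = \prod_{i=1}^{L} \frac{1/|\Sigma|}{q_i(s_i)}
$$

A weighted metric is then a self-normalized importance-sampling estimate of the same metric under the uniform distribution. For accuracy over $N$ sequences with labels $y_n$ and predictions $\hat{y}_n$:

$$
\widehat{\mathrm{acc}}_{\text{uniform}} = \frac{\sum_{n=1}^{N} w(s_n)\,\mathbb{1}[\hat{y}_n = y_n]}{\sum_{n=1}^{N} w(s_n)}
$$

This rests on the identity $\mathbb{E}_{\text{skewed}}[w(s)\,f(s)] = \mathbb{E}_{\text{uniform}}[f(s)]$. The identity requires that every sequence possible under the uniform distribution also has nonzero probability under the skewed one, and that P(binding | sequence) is the same in both cases.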

In [7]:
import random
import pandas as pd
import plotly.express as px

from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, roc_auc_score, confusion_matrix

In [39]:
def get_skewed_sequence(alphabet="ABCD", mutagenesis_weights=((25,25,25,25), (25,25,25,25), (25,25,25,25), (25,25,25,25), (25,25,25,25))):
    '''
    Sample one sequence, position by position.

    len(mutagenesis_weights) is the length of the sequence that will be created;
    each tuple inside mutagenesis_weights gives the relative percentage of each alphabet character at that position.

    example:
    with alphabet="ABCD" and mutagenesis_weights=((25,25,25,25), (85,5,5,5))
    - a sequence of length 2 will be created
    - position 1 will be A, B, C or D with uniform probabilities
    - position 2 will be A with 85% probability, and B, C or D with 5% probability each
    '''
    # random.choices(..., k=1) returns a one-element list, so take its only element
    return "".join(random.choices(alphabet, weights=weights, k=1)[0] for weights in mutagenesis_weights)


def is_binding(seq, signal_fn, prob_binding_if_signal=0.2, prob_binding_if_not_signal=0.02):
    '''
    Simulate whether seq binds: with probability prob_binding_if_signal if signal_fn(seq) is True,
    otherwise with probability prob_binding_if_not_signal.
    '''
    random_cutoff = random.uniform(0, 1)

    if signal_fn(seq):
        return random_cutoff <= prob_binding_if_signal
    return random_cutoff <= prob_binding_if_not_signal


def get_sequence_weight(seq, alphabet, mutagenesis_weights):
    '''
    Returns the sequence weight according to how much more or less often the given sequence
    would be observed in a uniform distribution compared to the skewed mutagenesis distribution. 
    
    For alphabet ABCD, uniform percentage would be 25%
    If A in position 1 is skewed to 85%:
        if sequence has A in position 1: sequence_weight would be 25/85 = 0.29
        if sequence has B/C/D in position 1: sequence_weight would be 25/5 = 5

    Multiply the weights of all positions to get the total weight for the sequence
    (this assumes that positions are sampled independently, as in get_skewed_sequence).
    '''
    sequence_weight = 1

    uniform_percentage = 100 / len(alphabet)

    for position in range(len(seq)):
        mutagenesis_percentage = mutagenesis_weights[position][alphabet.index(seq[position])]
        position_weight = uniform_percentage / mutagenesis_percentage
        sequence_weight = sequence_weight*position_weight

    return sequence_weight


# Utility functions for making mutagenesis_weights tuples with a 'signal' consisting of one or more 'A'

def get_skewed_a_weights(a_percentage, alphabet="ABCD"):
    assert "A" in alphabet, "A not in alphabet"

    number_of_not_a = len(alphabet) - 1

    non_a_percentages = (100 - a_percentage) / number_of_not_a

    weights = [a_percentage] + [non_a_percentages for _ in range(number_of_not_a)]

    return tuple(weights)

def get_multi_position_skewed_a_weights(a_percentages, alphabet="ABCD"):
    return tuple([get_skewed_a_weights(a_percentage, alphabet) for a_percentage in a_percentages])


The `experiment` function below creates a sequence dataset in which the sequences are sampled according to a given positional amino acid frequency distribution (i.e., skewed mutagenesis). It assigns binding labels with `is_binding`, predicts a sequence to be a binder exactly when it contains the signal, and calculates several performance metrics (accuracy, balanced accuracy, precision, recall/sensitivity, specificity). The metrics are calculated in 2 ways: with and without weighting. When weighting is used, each sequence is weighted by P(sequence|uniform distribution)/P(sequence|skewed distribution).
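To make concrete what `sample_weight` does in the scikit-learn metrics used here, the sketch below recomputes a weighted accuracy and a weighted precision by hand and compares them with `accuracy_score` and `precision_score`. The labels, predictions and weights are made-up illustrative values, not output of `experiment`.

```python
from sklearn.metrics import accuracy_score, precision_score

# Made-up labels, predictions and per-sequence weights (illustrative only)
y_true = [True, True, False, False, False]
y_pred = [True, False, True, False, False]
weights = [0.5, 2.0, 0.5, 5.0, 5.0]

# Weighted accuracy: total weight of correctly classified sequences / total weight
manual_accuracy = sum(w for w, t, p in zip(weights, y_true, y_pred) if t == p) / sum(weights)

# Weighted precision: weighted true positives / weighted predicted positives
weighted_tp = sum(w for w, t, p in zip(weights, y_true, y_pred) if t and p)
weighted_fp = sum(w for w, t, p in zip(weights, y_true, y_pred) if p and not t)
manual_precision = weighted_tp / (weighted_tp + weighted_fp)

sk_accuracy = accuracy_score(y_true, y_pred, sample_weight=weights)
sk_precision = precision_score(y_true, y_pred, sample_weight=weights)

print(manual_accuracy, sk_accuracy)    # both 10.5 / 13, about 0.808
print(manual_precision, sk_precision)  # both 0.5 / 1.0 = 0.5
```

Without weights, the accuracy of this toy example would be 3/5 = 0.6. The two correctly classified sequences with weight 5 pull the weighted accuracy up to about 0.81.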

Explanation of parameters:
- n_sequences: the number of sequences generated in each repeat
- alphabet: the sequence alphabet. It is set to a limited alphabet for simplicity; to use real amino acid sequences, set it to "ACDEFGHIKLMNPQRSTVWY"
- mutagenesis_weights: for each position in the sequence, a tuple with percentages representing the relative frequency distribution of each character in the alphabet. len(mutagenesis_weights) == len(sequence), and len(mutagenesis_weights[i]) == len(alphabet)
- signal_fn: a function which takes in a sequence and returns a boolean value representing whether the sequence contains the 'signal'
- prob_binding_if_signal: P(binding|sequence contains signal)
- prob_binding_if_not_signal: P(binding|sequence does not contain signal)
- repeat: how many times to repeat the complete experiment, each time with new random draws. Each repeat adds one row to the resulting dataframe.
- random_seed: random seed that is set once, before the repeated experiments start
- perform_checks: whether to sanity-check the parameters. The check that each position's percentages sum to exactly 100 may fail when mutagenesis_weights contains non-integer percentages (e.g. 100/3), so in that case it can be turned off.
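One way to keep the checks on while allowing fractional percentages is to compare the sum with a tolerance using `math.isclose`. This is sketched below as a hypothetical helper; `check_mutagenesis_weights` is not part of the notebook. The helper also raises `ValueError` rather than using `assert`, because assert statements are skipped when Python runs with `-O`.

```python
import math


def check_mutagenesis_weights(mutagenesis_weights, alphabet="ABCD", rel_tol=1e-9):
    '''
    Hypothetical tolerant variant of the per-position checks in check_params:
    compares each position's sum to 100 with math.isclose instead of ==.
    '''
    for pos_weights in mutagenesis_weights:
        if len(pos_weights) != len(alphabet):
            raise ValueError(f"Alphabet length {len(alphabet)} does not match {pos_weights}")
        if not math.isclose(sum(pos_weights), 100, rel_tol=rel_tol):
            raise ValueError(f"{pos_weights} does not sum to 100")


# 100/3 is not exactly representable as a float, which makes an exact == 100 check fragile
check_mutagenesis_weights(((100 / 3, 100 / 3, 100 / 3),), alphabet="ABC")

try:
    check_mutagenesis_weights(((50, 20, 20, 20),))  # sums to 110
except ValueError as err:
    print("rejected:", err)
```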

In [54]:
def check_params(n_sequences, mutagenesis_weights, prob_binding_if_signal, prob_binding_if_not_signal, alphabet, signal_fn):
    for pos_weights in mutagenesis_weights:
        assert len(alphabet) == len(pos_weights), f"Alphabet length {len(alphabet)} does not match {pos_weights}"
        assert sum(pos_weights) == 100, f"{pos_weights} does not sum to 100"

    assert 1 >= prob_binding_if_signal >= 0, f"prob_binding_if_signal should be in range [0,1] found {prob_binding_if_signal}"
    assert 1 >= prob_binding_if_not_signal >= 0, f"prob_binding_if_not_signal should be in range [0,1] found {prob_binding_if_not_signal}"

def experiment(n_sequences = 10000,
               alphabet="ABCD",
               mutagenesis_weights = ((25, 25, 25, 25), (25, 25, 25, 25), (25, 25, 25, 25), (25, 25, 25, 25), (25, 25, 25, 25)),
               signal_fn=lambda seq: seq[1:-1] == "AAA",
               prob_binding_if_signal = 0.2,
               prob_binding_if_not_signal = 0.02,
               repeat=10,
               random_seed=2022,
               perform_checks=True):
    random.seed(random_seed)

    if perform_checks:
        check_params(n_sequences, mutagenesis_weights, prob_binding_if_signal, prob_binding_if_not_signal, alphabet, signal_fn)

    dfs = []

    for _ in range(repeat):
        seqs = [get_skewed_sequence(alphabet=alphabet, mutagenesis_weights=mutagenesis_weights) for _ in range(n_sequences)]
        binding = [is_binding(seq, signal_fn=signal_fn,
                              prob_binding_if_signal=prob_binding_if_signal,
                              prob_binding_if_not_signal=prob_binding_if_not_signal)
                   for seq in seqs]
        prediction = [signal_fn(seq) for seq in seqs]

        sample_weights = [get_sequence_weight(seq, alphabet=alphabet, mutagenesis_weights=mutagenesis_weights) for seq in seqs]

        metrics = {"percentage_A": ",".join([str(pos_weights[0]) for pos_weights in mutagenesis_weights]),  # could add: percentage_B, percentage_C, percentage_D as well
                   "prob_binding_if_signal": prob_binding_if_signal,
                   "prob_binding_if_not_signal": prob_binding_if_not_signal,
                   "fraction_binders": [sum(binding) / n_sequences],
                   "fraction_predicted_binders": [sum(prediction) / n_sequences],
                   "accuracy": [accuracy_score(y_true=binding, y_pred=prediction)],
                   "w_accuracy": [accuracy_score(y_true=binding, y_pred=prediction, sample_weight=sample_weights)],
                   "balanced_accuracy": [balanced_accuracy_score(y_true=binding, y_pred=prediction)],
                   "w_balanced_accuracy": [balanced_accuracy_score(y_true=binding, y_pred=prediction, sample_weight=sample_weights)],
                   "precision": [precision_score(y_true=binding, y_pred=prediction)],
                   "w_precision": [precision_score(y_true=binding, y_pred=prediction, sample_weight=sample_weights)],
                   "recall/sensitivity": [recall_score(y_true=binding, y_pred=prediction)],
                   "w_recall/sensitivity": [recall_score(y_true=binding, y_pred=prediction, sample_weight=sample_weights)],
                   "specificity": [recall_score(y_true=binding, y_pred=prediction, pos_label=False)],
                   "w_specificity": [recall_score(y_true=binding, y_pred=prediction, pos_label=False, sample_weight=sample_weights)]}

        dfs.append(pd.DataFrame(metrics))

    return pd.concat(dfs)

In [55]:
# Utility code to reshape/combine multiple dataframes created by experiment()

param_cols = ["percentage_A", "prob_binding_if_signal", "prob_binding_if_not_signal", "fraction_binders", "fraction_predicted_binders"]
metric_cols = ["accuracy", "balanced_accuracy", "precision", "recall/sensitivity", "specificity"]
weighted_metric_cols = [f"w_{metric}" for metric in metric_cols]


def collapse_weight_cols(df):
    df_unweighted = df[param_cols + metric_cols].copy()
    df_weighted = df[param_cols + weighted_metric_cols].copy()
    # strip only a leading "w_" so that, e.g., "w_accuracy" becomes "accuracy"
    df_weighted.rename(columns=lambda x: x[2:] if x.startswith("w_") else x, inplace=True)

    df_unweighted["type"] = "unweighted"
    df_weighted["type"] = "weighted"

    return pd.concat([df_unweighted, df_weighted])


def merge_uni_weighted_dfs(df_uniform, df_mutagenesis):
    '''
    input: dfs created by 'experiment'; one with uniform sequences and one mutagenesis-based
    '''
    df_uni = df_uniform[param_cols + metric_cols].copy()
    df_uni["type"] = "uniform"
    df_muta = collapse_weight_cols(df_mutagenesis)
    df = pd.concat([df_uni, df_muta])
    return df


def metrics_to_long_format(df):
    '''
    input: df as created by merge_uni_weighted_dfs()

    '''
    return pd.melt(df, id_vars=param_cols + ["type"], value_vars=metric_cols, var_name="metric", value_name="value")


Here I run the experiment function under 6 different conditions:
- 3 different signals, namely A in position 1, AA in positions 1-2, and AAA in positions 1-3
- for each signal, a uniform sequence distribution (each character occurs with 25% probability at each position) and a skewed mutagenesis distribution. In the skewed mutagenesis case, amino acid A is oversampled (85% instead of 25%) at the positions that make up the signal, and B, C and D each occur with 5% probability there.
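For these settings, the effect of the weights on the signal frequency can be worked out exactly. The sketch below uses plain `fractions.Fraction` arithmetic, and its variable names are illustrative. It compares how often each signal occurs under uniform and under skewed sampling, and shows that multiplying the skewed frequency by the weight of a signal sequence recovers the uniform frequency. It also checks that the expected weight at a skewed position is 1: 0.85 · (0.25/0.85) + 3 · 0.05 · (0.25/0.05) = 0.25 + 0.75.

```python
from fractions import Fraction

UNIFORM_P = Fraction(25, 100)        # P(any character) per position under the uniform distribution
SKEWED_A = Fraction(85, 100)         # P(A) at a signal position under skewed mutagenesis
SKEWED_OTHER = (1 - SKEWED_A) / 3    # P(B) = P(C) = P(D) = 5% each at a signal position

# Expected weight at one skewed position: sum over characters of q(c) * (0.25 / q(c))
expected_position_weight = (SKEWED_A * (UNIFORM_P / SKEWED_A)
                            + 3 * SKEWED_OTHER * (UNIFORM_P / SKEWED_OTHER))

results = {}
for k, name in [(1, "A-1"), (2, "AA-1"), (3, "AAA-1")]:
    p_uniform = UNIFORM_P ** k                    # fraction of signal sequences, uniform sampling
    p_skewed = SKEWED_A ** k                      # fraction of signal sequences, skewed sampling
    signal_weight = (UNIFORM_P / SKEWED_A) ** k   # weight of a signal sequence (other positions contribute 1)
    results[name] = (p_uniform, p_skewed, p_skewed * signal_weight)
    print(f"{name}: uniform={float(p_uniform):.4f}  skewed={float(p_skewed):.4f}  "
          f"reweighted={float(p_skewed * signal_weight):.4f}")

print("expected weight per skewed position:", expected_position_weight)
```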

In [56]:
df_uniform_1pos = experiment(signal_fn=lambda seq: seq[0]=="A",  mutagenesis_weights=get_multi_position_skewed_a_weights([25, 25, 25, 25, 25]))
df_mutagenesis_1pos = experiment(signal_fn=lambda seq: seq[0]=="A", mutagenesis_weights=get_multi_position_skewed_a_weights([85, 25, 25, 25, 25]))

df_uniform_2pos = experiment(signal_fn=lambda seq: seq[0:2]=="AA",  mutagenesis_weights=get_multi_position_skewed_a_weights([25, 25, 25, 25, 25]))
df_mutagenesis_2pos = experiment(signal_fn=lambda seq: seq[0:2]=="AA", mutagenesis_weights=get_multi_position_skewed_a_weights([85, 85, 25, 25, 25]))

df_uniform_3pos = experiment(signal_fn=lambda seq: seq[0:3]=="AAA",  mutagenesis_weights=get_multi_position_skewed_a_weights([25, 25, 25, 25, 25]))
df_mutagenesis_3pos = experiment(signal_fn=lambda seq: seq[0:3]=="AAA", mutagenesis_weights=get_multi_position_skewed_a_weights([85, 85, 85, 25, 25]))

As an example, the tables for signal 'A-1' are printed below to show their contents. Each row in a table is one repetition of the experiment with new random draws, so the parameter columns are identical across rows and only the observed fractions (fraction_binders, fraction_predicted_binders) and the metrics differ. Weighted versions of the metrics have the prefix 'w_'.

For the uniform table, metric == w_metric, because the simulated mutagenesis used a uniform distribution, i.e., P(sequence|uniform) == P(sequence|mutagenesis), so every sequence gets weight 1.
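As a quick check of this argument, here is a minimal sketch with a stand-alone copy of the weighting logic. `sequence_weight` is a hypothetical helper mirroring `get_sequence_weight`, and the sequences and labels are made up. With every position at 25%, each per-position factor is 25/25 = 1, so every weight is exactly 1 and the weighted accuracy equals the unweighted one.

```python
from sklearn.metrics import accuracy_score

ALPHABET = "ABCD"
UNIFORM_WEIGHTS = ((25, 25, 25, 25),) * 5  # five positions, 25% per character


def sequence_weight(seq, alphabet, mutagenesis_weights):
    """Stand-alone re-implementation of the P(uniform)/P(skewed) weight."""
    weight = 1.0
    for char, pos_weights in zip(seq, mutagenesis_weights):
        weight *= (100 / len(alphabet)) / pos_weights[alphabet.index(char)]
    return weight


seqs = ["AAAAA", "ABCDA", "DCBAD", "BBBBB"]
weights = [sequence_weight(s, ALPHABET, UNIFORM_WEIGHTS) for s in seqs]
y_true = [True, False, False, True]   # made-up binding labels
y_pred = [s[0] == "A" for s in seqs]  # "signal" = A in position 1

print(weights)  # [1.0, 1.0, 1.0, 1.0]
print(accuracy_score(y_true, y_pred), accuracy_score(y_true, y_pred, sample_weight=weights))
```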

The percentage_A column shows how amino acid 'A' was skewed in each of the 5 positions of the sequence (I could have shown percentage_B, C and D as well, but only character A is used for the signal here). Note that in the mutagenesis table, A in position 1 was set to 85%.

In [57]:
print("Uniform table:")
print(df_uniform_1pos)


print("\n\nMutagenesis table:")
print(df_mutagenesis_1pos)

Uniform table:
     percentage_A  prob_binding_if_signal  prob_binding_if_not_signal  \
0  25,25,25,25,25                     0.2                        0.02   
0  25,25,25,25,25                     0.2                        0.02   
0  25,25,25,25,25                     0.2                        0.02   
0  25,25,25,25,25                     0.2                        0.02   
0  25,25,25,25,25                     0.2                        0.02   

   fraction_binders  fraction_predicted_binders  accuracy  w_accuracy  \
0            0.0622                      0.2486    0.7850      0.7850   
0            0.0639                      0.2502    0.7863      0.7863   
0            0.0660                      0.2603    0.7727      0.7727   
0            0.0646                      0.2531    0.7837      0.7837   
0            0.0679                      0.2556    0.7805      0.7805   

   balanced_accuracy  w_balanced_accuracy  precision  w_precision  \
0           0.778042             0.778

Congratulations, you have arrived at the real point of this document!

The figure below compares 3 values in each panel: 
- Uniform: the performance metric on a dataset with a uniform character distribution
- Unweighted: the performance metric on a dataset with a skewed character distribution, in which the signal of interest is oversampled
- Weighted: a weighted version of the performance metric on the skewed dataset (exact same dataset as 'unweighted')

The row and column facets respectively show different metrics and signals. 
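The same comparison can be reproduced for signal A-1 in a compact, self-contained sketch. The helper names (`sample_sequences`, `sequence_weight`, `run_a1`) are illustrative and not part of the notebook. The binding model mirrors `is_binding` with the default probabilities 0.2 and 0.02, and the prediction is "sequence has A in position 1", as in `experiment`. Worked out by hand, the expected accuracy is 0.25 · 0.2 + 0.75 · 0.98 = 0.785 on uniform data but 0.85 · 0.2 + 0.15 · 0.98 = 0.317 on the skewed data. The expected recall is 0.05/0.065 ≈ 0.77 versus 0.17/0.173 ≈ 0.98. Up to sampling noise, the weighted versions on the skewed data should land near the uniform values.

```python
import random

from sklearn.metrics import accuracy_score, recall_score

ALPHABET = "ABCD"


def sample_sequences(mutagenesis_weights, n, rng):
    """Draw n sequences, sampling each position independently with the given percentages."""
    return ["".join(rng.choices(ALPHABET, weights=pos_weights, k=1)[0] for pos_weights in mutagenesis_weights)
            for _ in range(n)]


def sequence_weight(seq, mutagenesis_weights):
    """P(seq | uniform) / P(seq | skewed), assuming independent positions."""
    weight = 1.0
    for char, pos_weights in zip(seq, mutagenesis_weights):
        weight *= (100 / len(ALPHABET)) / pos_weights[ALPHABET.index(char)]
    return weight


def run_a1(mutagenesis_weights, n=100_000, seed=0):
    """Return accuracy, weighted accuracy, recall and weighted recall for signal A-1."""
    rng = random.Random(seed)
    seqs = sample_sequences(mutagenesis_weights, n, rng)
    has_signal = [seq[0] == "A" for seq in seqs]  # the prediction is simply "contains the signal"
    # binding model: P(binding) = 0.2 with the signal, 0.02 without it
    binding = [rng.random() < (0.2 if signal else 0.02) for signal in has_signal]
    weights = [sequence_weight(seq, mutagenesis_weights) for seq in seqs]
    return (accuracy_score(binding, has_signal),
            accuracy_score(binding, has_signal, sample_weight=weights),
            recall_score(binding, has_signal),
            recall_score(binding, has_signal, sample_weight=weights))


uniform = ((25, 25, 25, 25),) * 5
skewed = ((85, 5, 5, 5),) + ((25, 25, 25, 25),) * 4

uni_acc, _, uni_rec, _ = run_a1(uniform)
skw_acc, skw_w_acc, skw_rec, skw_w_rec = run_a1(skewed)
print(f"accuracy: uniform={uni_acc:.3f}  unweighted={skw_acc:.3f}  weighted={skw_w_acc:.3f}")
print(f"recall:   uniform={uni_rec:.3f}  unweighted={skw_rec:.3f}  weighted={skw_w_rec:.3f}")
```

The weighted recall is the noisiest estimate. It depends on the few non-signal binders, each of which carries weight 5, so a fairly large n is used here to keep that noise small.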

In [47]:
df_1pos = metrics_to_long_format(merge_uni_weighted_dfs(df_uniform_1pos, df_mutagenesis_1pos))
df_2pos = metrics_to_long_format(merge_uni_weighted_dfs(df_uniform_2pos, df_mutagenesis_2pos))
df_3pos = metrics_to_long_format(merge_uni_weighted_dfs(df_uniform_3pos, df_mutagenesis_3pos))

df_1pos["signal"] = "A-1"
df_2pos["signal"] = "AA-1"
df_3pos["signal"] = "AAA-1"

df = pd.concat([df_1pos, df_2pos, df_3pos])


fig = px.strip(df, y="value", x="type", facet_row="metric", facet_col="signal", stripmode="overlay",
               labels={
                   "value": "Metric value",
                   "type": "Experiment type",
               })

fig.update_layout(
    autosize=False,
    width=1000,
    height=1000)

fig.show()