In [None]:
"""Imports"""
import sys
import os
import re
from typing import Optional, List, Any
import math
import statistics
from itertools import combinations
from copy import deepcopy
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

SCRIPT_DIR = os.path.dirname(os.path.abspath('__file__'))
sys.path.append(os.path.dirname(SCRIPT_DIR))
print(f'Here is: {SCRIPT_DIR}')

from experiments.partition_search_opt.partition_search_config import *

from libs.random_sample_generator.random_sample_generator import random_sample_generator
from libs.partition_property_finder.cost_finder import find_cost
from libs.utils.utils import one_indexing_to_zero_indexing
from libs.partitioner.partitioner import find_cuttable_positions

In [1]:
# Hard-coded example sequence; presumably a protein in one-letter amino-acid codes — TODO confirm its origin.
# NOTE(review): the name `s` is rebound later by the trial loop, which assigns the generated sequence to `s`.
s = "MRFPSIFTAVLFAASSALAAPVNTITEDETAQIPAEAVIGYSDLEGDFDVAVLPFSNSTNNGLLFMNTTTASIAAKEEVVSLEKREAEALLRRQEEASTASQAGGGRPIQAVDPPSPPGPLSFTGTKLVNDRDHPWRPLRNGDIRGPCPGLNTLASHGYLPRDGVATPTQLINAIQEGFNFDHTAAVSATYLGHLLNGNLVTDLLSIGGKTSKTGPPPPPPAHAGGLNVHGTFEGDAGLTRADEFFGDNHSFNQTLFDKLVDFSNRFGGGFYNLTVAGELRYSRIQDSIATNPQFSFKNVRFLTAYGETVFPINLFVDGRQTERKLSLDHAASFFRDMRFPPDFHRAAQPSSAEGVAEVIAAHPWLPGGNEDGQLNNYVVDPNSADFTDPCSLHRFVLGSLQELYPSPTGVLRRNLIKNINFWYTAAFAPAGCPELFPFGQL"
# Display the sequence length as the cell output.
len(s)

442

In [None]:
def _summarise_distribution(values: list, fragment_count: int) -> dict:
    """Bundle a per-fragment list with its population std, its mean and the total fragment count."""
    return {
        'list': values,
        'std': statistics.pstdev(values),
        'mean': sum(values) / len(values),
        'fragment_count': fragment_count,
    }


def measure_mutation_distribution(seq: str, partition: tuple, mutations_0idx: Optional[List[Any]],
                                  linked_mutations_0idx: Optional[List[Any]]) -> tuple:
    """Measure how mutations are spread over the fragments defined by ``partition``.

    The sequence is split at the 0-indexed cut positions in ``partition`` into half-open
    fragments ``[start, end)``. For every fragment two numbers are collected: how many
    mutation positions fall inside it, and how many fragment variants it needs (the
    product of the per-position variation counts; 1 for a fragment without mutations).

    Args:
        seq: The full sequence; only its length is used.
        partition: Cut positions (0-indexed). Any iterable of ints is accepted.
        mutations_0idx: Mutations as dicts with a ``'position'`` (0-indexed) and an
            ``'aa'`` collection of alternative residues. ``None``/empty means no mutations.
        linked_mutations_0idx: Linked mutation sets; each set is an iterable of items whose
            element ``[1]`` is the 0-indexed position. ``None``/empty means no linkage.

    Returns:
        ``(mut_positions_distribution, mut_variation_distribution)``. Each is a dict with
        keys ``'list'`` (one value per fragment), ``'std'`` (population standard deviation),
        ``'mean'`` and ``'fragment_count'`` (sum of the per-fragment variant counts, the
        same value in both dicts).
    """
    boundaries = (0,) + tuple(partition) + (len(seq),)

    # Map every linked position to the positions of the first linked set that contains it
    # (one position is assumed to belong to at most one linked set).
    linked_positions_by_position: dict = {}
    for mut_set in linked_mutations_0idx or []:
        positions = [mut[1] for mut in mut_set]  # example: [13, 15]
        for position in positions:
            linked_positions_by_position.setdefault(position, positions)

    all_frag_variations = []
    all_number_of_mut_positions = []
    for frag_start, frag_end in zip(boundaries, boundaries[1:]):
        frag_variations = []
        number_of_mut_positions = 0
        seen_linked_positions: set = set()  # linked sets already counted for this fragment
        for mutation in mutations_0idx or []:
            position = mutation['position']
            # Half-open interval: a mutation on the first residue of a fragment (index 0 or
            # a cut index) belongs to that fragment and must not be dropped.
            if not frag_start <= position < frag_end:
                continue
            number_of_mut_positions += 1
            linked_positions = linked_positions_by_position.get(position)
            if linked_positions is None:
                # independent position: wild type plus every alternative residue
                frag_variations.append(len(mutation['aa']) + 1)
            elif position not in seen_linked_positions:
                # linked mutations have a variation of 2: wild type, or all linked
                # mutations present at the same time; count each set once per fragment
                frag_variations.append(2)
                seen_linked_positions.update(linked_positions)
        all_frag_variations.append(math.prod(frag_variations))
        all_number_of_mut_positions.append(number_of_mut_positions)

    # Summaries are computed once, after all fragments have been visited.
    fragment_count = sum(all_frag_variations)
    mut_positions_distribution = _summarise_distribution(all_number_of_mut_positions, fragment_count)
    mut_variation_distribution = _summarise_distribution(all_frag_variations, fragment_count)

    return mut_positions_distribution, mut_variation_distribution

In [None]:
def bulk_measure_mutation_distribution_n_cost(cuttable_positions, cut_number_range, s, mutations_0idx, linked_mutations_0idx, trial):
    """Evaluate every candidate partition and collect its distribution summaries and cost.

    For each number of cuts in ``cut_number_range`` all combinations of that many
    ``cuttable_positions`` are evaluated with ``measure_mutation_distribution`` and
    ``find_cost`` (enzyme ``'BsaI'``, ``ENZYME_INFO``). Every partition yields two records,
    appended in this order: the ``'mut_variation'`` summary, then the ``'mut_positions'``
    summary. Each record is a deep copy of the summary dict extended with ``'cost'``,
    ``'number_of_fragments'``, ``'mutation_property'`` and ``'trial'``.

    Returns:
        The flat list of record dicts.
    """
    records = []
    for number_of_cuts in cut_number_range:
        if len(cuttable_positions) > 0:
            candidate_partitions = combinations(cuttable_positions, number_of_cuts)
        else:
            # NOTE(review): with no cuttable positions a single empty partition is evaluated
            # for every number_of_cuts, yet it is still labelled number_of_cuts + 1 fragments.
            candidate_partitions = [tuple()]

        for partition in candidate_partitions:
            positions_summary, variation_summary = measure_mutation_distribution(
                s, partition, mutations_0idx, linked_mutations_0idx
            )
            cost = find_cost(s, partition, mutations_0idx, linked_mutations_0idx, 'BsaI', ENZYME_INFO)

            # variation record first, positions record second
            labelled_summaries = (
                ('mut_variation', variation_summary),
                ('mut_positions', positions_summary),
            )
            for property_name, summary in labelled_summaries:
                record = deepcopy(summary)
                record['cost'] = cost
                record['number_of_fragments'] = number_of_cuts + 1
                record['mutation_property'] = property_name
                record['trial'] = trial
                records.append(record)
    return records

In [None]:
def scattered_plot(x_axis, res_data, cut_number_range):
    """Draw two log-log scatter panels of estimated cost against a distribution statistic.

    The left panel uses the records whose ``'mutation_property'`` is ``'mut_positions'``,
    the right panel those with ``'mut_variation'``. Points are coloured by
    ``'number_of_fragments'`` and the most expensive point(s) of each panel are annotated
    with their cost.

    Args:
        x_axis: Record key plotted on the x axis (e.g. ``'std'``). The same key is used for
            the scatter points and for placing the annotation.
        res_data: Iterable of record dicts holding at least ``'mutation_property'``,
            ``'number_of_fragments'``, ``'cost'`` and the ``x_axis`` key.
        cut_number_range: Numbers of cuts to show; records are kept when
            ``number_of_fragments - 1`` is in this range.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can call ``.show()`` on it.
    """
    sel_param = "cost"

    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False, figsize=(16, 6))
    plt.subplots_adjust(wspace=0.5)
    fig.patch.set_facecolor('white')  # Set the background of the figure to white

    # one colour per number of fragments (= number of cuts + 1)
    colors = sns.color_palette("flare", len(cut_number_range))
    color_map = {n_cuts + 1: color for n_cuts, color in zip(cut_number_range, colors)}
    scales = {'mut_positions': {'x': 'log', 'y': 'log'}, 'mut_variation': {'x': 'log', 'y': 'log'}}

    # Keep the established wording for 'std'; otherwise derive the label from the key.
    x_stat_label = 'Standard Deviation' if x_axis == 'std' else str.title(re.sub('_', ' ', x_axis))

    for ax, mutation_property in zip((ax1, ax2), ["mut_positions", "mut_variation"]):
        property_label = str.title(re.sub('_', ' ', mutation_property))
        df = pd.DataFrame([d for d in res_data
                           if d['mutation_property'] == mutation_property
                           and d['number_of_fragments'] - 1 in cut_number_range])
        # An empty selection yields a frame without columns; skip everything that needs data
        # instead of failing in groupby.
        has_data = not df.empty

        if has_data:
            for key, group in df.groupby('number_of_fragments'):
                sns.regplot(ax=ax,
                            x=group[x_axis],
                            y=group[sel_param],
                            scatter=True,
                            label=key,
                            color=color_map[key],
                            scatter_kws={'alpha': 0.5, 's': 1},
                            fit_reg=False)

        # Labels are set after plotting so they override the ones derived from column names.
        ax.set_ylabel("Estimated Cost (€)")
        ax.set_xlabel(f"{x_stat_label} of {property_label} per Fragment")
        if has_data:
            ax.legend(title='Number of Fragments', loc='upper left', ncols=2)
        ax.set_xscale(scales[mutation_property]['x'])
        ax.set_yscale(scales[mutation_property]['y'])
        ax.set_facecolor('white')  # Set the background of the axes to white
        ax.set_title(f'Cost in Relation to Mutation Distribution and Number of Fragments \n'
                     f'\n '
                     f'{property_label}')

        if has_data:
            # annotate every point that reaches the maximum cost of this panel
            max_y = df[sel_param].max()
            most_expensive = df[df[sel_param] == max_y]
            for x, y in zip(most_expensive[x_axis].values, most_expensive[sel_param].values):
                ax.annotate(f"{round(y, 0)} €", (x, y))

    return plt

In [None]:
mutations_log = {}
mut_distribution_lst_agg = []
for trial in range(1):

    # Prepare mutations: draw a random sequence with mutations, then convert to 0-indexing
    s, mutations_1idx, linked_mutations_1idx = random_sample_generator(
        min_aa_length=60,
        max_aa_length=300,
        min_number_of_positions=3,
        max_number_of_positions=50,
        min_variations_per_position=1,
        max_variations_per_position=10,
        max_positions_per_linked_mutation_set=2,
        max_number_of_mutation_linked_mutation_sets=1,
    )
    mutations_0idx, linked_mutations_0idx = one_indexing_to_zero_indexing(
        mutations_1idx=mutations_1idx,
        linked_mutations_1idx=linked_mutations_1idx,
    )
    # keep the generated sample of every trial for later inspection
    mutations_log[trial] = (s, mutations_0idx, linked_mutations_0idx)

    # Cut at the middle of each pair of consecutive mutation positions.
    # Each allowed position is a (start, end) pair with both ends at midpoint + 1.
    allow_cut_positions = []
    for left_mutation, right_mutation in zip(mutations_0idx, mutations_0idx[1:]):
        midpoint = (left_mutation['position'] + right_mutation['position']) // 2
        allow_cut_positions.append((midpoint + 1, midpoint + 1))

    cuttable_positions = find_cuttable_positions(
        s=s,
        mutations_0idx=mutations_0idx,
        linked_mutations_0idx=linked_mutations_0idx,
        min_aa_length=MIN_FRAGMENT_LENGTH,
        provider_max_dna_len=1500,
        enzyme='BsaI',
        allowed_cut_positions_1idx=allow_cut_positions,
        enzyme_info_dic=ENZYME_INFO,
    )

    # Measure distribution and cost for every partition with 1 to 5 cuts
    mut_distribution_lst = bulk_measure_mutation_distribution_n_cost(
        cuttable_positions=cuttable_positions,
        cut_number_range=range(1, 6),
        s=s,
        mutations_0idx=mutations_0idx,
        linked_mutations_0idx=linked_mutations_0idx,
        trial=trial,
    )

    mut_distribution_lst_agg.extend(mut_distribution_lst)


In [None]:
"""Plotting - Scatter Plots"""
# Render both scatter panels from the records aggregated over all trials.
plt = scattered_plot(
    x_axis='std',
    res_data=mut_distribution_lst_agg,
    cut_number_range=range(1, 6),
)
plt.show()