In [214]:
import numpy as np
import sympy as sym
import scipy as sp
import matplotlib.pyplot as plt

In [215]:
def construct_frequency_matrix(path_to_pairwise_counts_unbound : str, path_to_pairwise_counts_bound : str):
    """Constructs the frequency matrix from the pairwise count files output by the dmMIME simulator.
    First the bound and unbound counts are added together to get the total counts of the initial pool.
    Then a square n x n matrix is constructed where n is the number of possible mutations, so sequence length * 3.
    For mutations at the same position the entry is 0, which are 3x3 blocks along the diagonal.
    For the mutation of interest the entry is 1, which is 1 along the diagonal.
    For mutations at different positions the entry is the pairwise frequency of both mutations - the pairwise frequency of the mutation and the wildtype.


    Args:
        path_to_pairwise_counts_unbound (str): path to the pairwise count file for unbound sequences output by the dmMIME simulator (8.txt)
        path_to_pairwise_counts_bound (str): path to the pairwise count file for bound sequences output by the dmMIME simulator (7.txt)
    """

    unbound_counts = np.loadtxt(path_to_pairwise_counts_unbound, skiprows=1, delimiter='\t')
    bound_counts = np.loadtxt(path_to_pairwise_counts_bound, skiprows=1, delimiter='\t')

    # first 2 columns for positions stay the same, the rest are the counts
    counts = np.hstack((unbound_counts[:, :2], unbound_counts[:, 2:] + bound_counts[:, 2:]))

    # construct the frequency matrix
    n_pos = np.where(counts[:, 0] == 1)[0].shape[0] + 1
    n_mut =  n_pos * 3
    freq_matrix = np.zeros((n_mut, n_mut))

    for pos1 in range(n_pos):
        for pos2 in range(n_pos):
            if pos1 == pos2:
                for mut1 in range(3):
                    for mut2 in range(3):
                    # if mutations are at the same position
                        if mut1 == mut2:
                            freq_matrix[pos1*3 + mut1, pos2*3 + mut2] = 1
                        else:
                            freq_matrix[pos1*3 + mut1, pos2*3 + mut2] = 0

            else:
                for mut1 in range(3):
                    for mut2 in range(3):
                        if pos1 < pos2:
                            pairwise_counts = counts[np.where((counts[:, 0] == pos1 + 1) & (counts[:, 1] == pos2 + 1))[0], 2:]
                            # get pairwise frequency of the mutation 2 at position 2 and the wildtype at position 1
                            freq_mut2_wt1 = pairwise_counts[0, mut2+1]/(pairwise_counts[0, 0] + pairwise_counts[0, 1] + pairwise_counts[0, 2] + pairwise_counts[0, 3])
                            # get pairwise frequency of the mutation 2 at position 2 and the mutation 1 at position 1
                            freq_mut2_mut1 = pairwise_counts[0, (mut1+1)*4 + mut2+1]/(pairwise_counts[0, (mut1+1)*4] + pairwise_counts[0, (mut1+1)*4 + 1] + pairwise_counts[0, (mut1+1)*4 + 2] + pairwise_counts[0, (mut1+1)*4 + 3])

                            freq_matrix[pos1*3 + mut1, pos2*3 + mut2] = freq_mut2_mut1 - freq_mut2_wt1

                        if pos1 > pos2:
                            pairwise_counts = counts[np.where((counts[:, 0] == pos2 + 1) & (counts[:, 1] == pos1 + 1))[0], 2:]
                            # get pairwise frequency of the mutation 2 at position 2 and the wildtype at position 1
                            freq_mut2_wt1 = pairwise_counts[0, (mut2+1)*4]/(pairwise_counts[0, 0] + pairwise_counts[0, 4] + pairwise_counts[0, 8] + pairwise_counts[0, 12])
                            # get pairwise frequency of the mutation 2 at position 2 and the mutation 1 at position 1
                            freq_mut2_mut1 = pairwise_counts[0, (mut2+1)*4 + mut1+1]/(pairwise_counts[0, (mut2+1)*4] + pairwise_counts[0, (mut2+1)*4 + 1] + pairwise_counts[0, (mut2+1)*4 + 2] + pairwise_counts[0, (mut2+1)*4 + 3])

                            freq_matrix[pos1*3 + mut1, pos2*3 + mut2] = freq_mut2_mut1 - freq_mut2_wt1

    print(freq_matrix.shape)
    print(np.round(freq_matrix[:12, :12], 2))

    return freq_matrix


In [216]:
path_to_data = '/home/user/data_directory/MIME_sim_data/no_epistasis/secondFromProt1/prot1/2d/'
construct_frequency_matrix(path_to_data + '8.txt', path_to_data + '7.txt')

(300, 300)
[[ 1.    0.    0.    0.03  0.03  0.03  0.    0.    0.    0.   -0.    0.  ]
 [ 0.    1.    0.    0.16  0.04  0.06  0.    0.01  0.01  0.   -0.    0.  ]
 [ 0.    0.    1.    0.29  0.19  0.08  0.    0.01  0.01  0.   -0.    0.01]
 [ 0.03  0.17  0.3   1.    0.    0.    0.01  0.04  0.    0.   -0.    0.  ]
 [ 0.02  0.04  0.2   0.    1.    0.    0.01  0.02  0.05  0.01 -0.    0.01]
 [ 0.03  0.06  0.08  0.    0.    1.    0.45  0.01  0.03  0.    0.    0.01]
 [-0.    0.01  0.01  0.01  0.01  0.46  1.    0.    0.    0.03  0.    0.04]
 [-0.    0.    0.01  0.04  0.02  0.02  0.    1.    0.    0.05  0.01  0.04]
 [-0.    0.    0.01  0.    0.05  0.03  0.    0.    1.    0.49  0.01  0.04]
 [-0.    0.01  0.01  0.    0.01  0.01  0.03  0.06  0.5   1.    0.    0.  ]
 [-0.01 -0.   -0.   -0.01 -0.    0.   -0.    0.01  0.02  0.    1.    0.  ]
 [-0.    0.    0.01 -0.    0.    0.01  0.03  0.05  0.05  0.    0.    1.  ]]


In [217]:
def matrix_equation(x : np.ndarray):
    """this function is the input for fsolve. It returns the matrix equation that needs to be solved to estimate the true Kds. 
    The equation has the form 0 = x + Ax - y, where x is the array of log(true Kds), A is a constructed frequency matrix, and y is the array of log(estimated Kds).
    

    Args:
        x (np.ndarray): array of log(true Kds) that is going to be estimated by solving the matrix equation
    """    