# Table of Contents
 <p>

In [None]:
import pandas as pd
import numpy as np
import scipy as scipy
import statsmodels.api as sm
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
import os

# import own libraries
import pretty_table as pretty

# plotting settings
# NOTE: `os` is already imported above; this second import is redundant.
import os
# Render all text through LaTeX with the cmbright package and a Helvetica sans-serif font.
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{cmbright}')
rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})

%matplotlib inline

# Inline figure formats are set to PNG and 'retina' (not SVG).
%config InlineBackend.figure_formats = {'png', 'retina'}

# JB's favorite Seaborn settings for notebooks
# NOTE: this rebinds the name `rc` to a dict, shadowing the matplotlib `rc`
# function imported above; the rc(...) calls must therefore run before this line.
rc = {'lines.linewidth': 2, 
      'axes.labelsize': 18, 
      'axes.titlesize': 18, 
      'axes.facecolor': 'DFDFE5'}
sns.set_context('notebook', rc=rc)
sns.set_style("dark")

# more parameters: tick-label and legend font sizes
mpl.rcParams['xtick.labelsize'] = 16 
mpl.rcParams['ytick.labelsize'] = 16 
mpl.rcParams['legend.fontsize'] = 14

In [16]:
class model:
    """
    False-positive / false-negative noise model for a three-way comparison.

    Every class is identified by a binary label (i, j, k) with one position per
    genotype.  Class sizes live in the 2x2x2 array ``self.M`` and are indexed
    directly by the label, e.g. ``self.M[1, 0, 1]`` for label 101.

    Attributes set by ``__init__``:
        N    -- list [N_100, N_010, N_001], one count per genotype
        L    -- 8x3 numpy matrix of labels, in the fixed row order
                100, 010, 001, 110, 101, 011, 111, 000
        fp   -- false positive rate
        fn   -- false negative rate
        N_T  -- sum of all eight class sizes passed to the constructor
        M    -- 2x2x2 array of class sizes, inflated by 1/(1 - fn)
        fps  -- 2x2x2 array of expected false positive hits per class
        fns  -- 2x2x2 array of expected false negative flow per class

    Attributes set later:
        snr               -- threshold, set by ``signal_threshold``
        accepted, signal  -- length-8 arrays (row order of ``L``), set by
                             ``test_classes``
    """

    def __init__(self, N_100, N_010, N_001, M000, M100, M010, M001, M110, M101, M011, M111, fp, fn):
        self.N = [N_100, N_010, N_001]

        # index matrix, standard for all 3-way comparisons
        self.L = np.matrix([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1],
                            [1, 1, 0],
                            [1, 0, 1],
                            [0, 1, 1],
                            [1, 1, 1],
                            [0, 0, 0]])

        # parameters:
        self.fp = fp
        self.fn = fn
        self.fps = np.zeros(shape=(2, 2, 2))
        self.fns = np.zeros(shape=(2, 2, 2))
        self.N_T = M000 + M100 + M010 + M001 + M110 + M101 + M011 + M111

        # initialize the M classes:
        self.make_M(M000, M100, M010, M001, M110, M101, M011, M111)

        # noise calculations, one genotype at a time
        for t in range(3):
            self.false_positive(t)

            # false negatives only flow out of labels that contain genotype t
            for l in self.find_labels(t):
                self.false_negative(t, l)

    def make_M(self, M000, M100, M010, M001, M110, M101, M011, M111):
        """
        Store the Mijk entries in the 2x2x2 array ``self.M``, inflating them
        for false negatives.

        Each differentially expressed class is divided by ``1 - self.fn``; the
        000 class receives whatever is left of ``self.N_T``.  ``M000`` itself
        is not read here (it only enters through ``self.N_T``).
        """
        # inflation factor:
        correction = (1 - self.fn)
        # total DE genes counted:
        DE = M100 + M010 + M001 + M110 + M101 + M011 + M111

        # make matrix and input coefficients:
        M_mat = np.zeros(shape=(2, 2, 2))

        M_mat[1, 0, 0] = M100/correction
        M_mat[0, 1, 0] = M010/correction
        M_mat[0, 0, 1] = M001/correction
        M_mat[1, 1, 0] = M110/correction
        M_mat[0, 1, 1] = M011/correction
        M_mat[1, 0, 1] = M101/correction
        M_mat[1, 1, 1] = M111/correction
        M_mat[0, 0, 0] = self.N_T - DE/correction

        self.M = M_mat

    def find_labels(self, t):
        """Given a genotype t, return the rows of ``self.L`` that include that genotype."""
        return self.L[np.where(self.L[:, t] == 1)[0]]

    def find_sub_M(self, t):
        """
        Return the entries of ``self.M`` whose label includes genotype t.

        The result is a 1-D array with one entry per row of
        ``self.find_labels(t)``, in the same order.
        """
        # M is indexed by label (i, j, k), not by row number of L, so look each
        # label up explicitly.
        labels = np.asarray(self.find_labels(t))
        return self.M[labels[:, 0], labels[:, 1], labels[:, 2]]

    def false_positive(self, t):
        """
        Accumulate the expected false positive distribution for genotype t.

        t - the genotype currently being assessed

        For each label containing t, ``self.fp * fraction * self.N[t]`` is
        added to ``self.fps`` at that label.  ``fraction`` is the class's share
        of ``sum(self.M)``; for singly labelled classes the 000 class is added
        to the numerator.  Nothing is returned.
        """
        total = np.sum(self.M)
        for label in self.find_labels(t):
            i, j, k = label[0, 0], label[0, 1], label[0, 2]
            M_label = self.M[i, j, k]

            # calculate distribution
            if np.sum(label) > 1:
                fraction = M_label/total
            else:
                fraction = (M_label + self.M[0, 0, 0])/total

            # save
            self.fps[i, j, k] += self.fp*fraction*self.N[t]

    def false_negative(self, t, l):
        """Given a single label, l, find its adjacent label and model false negative flow.

        params:
        t - genotype currently being assessed
        l - current label, a numpy (1x3) matrix with a 1 at position t

        ``self.fn`` times the size of class ``l`` is subtracted from
        ``self.fns`` at ``l`` and added at the adjacent label (``l`` with
        position t cleared).  Nothing is returned.

        raises:
        ValueError - if ``l`` has a 0 at position t
        """
        if l[0, t] == 0:
            raise ValueError('label must contain a 1 at the `th` position')

        # adjacent label: same label with genotype t switched off
        curr = np.zeros(3)
        curr[t] = 1
        l_adj = (l - curr).astype(int)

        m_t = self.M[l[0, 0], l[0, 1], l[0, 2]]
        self.fns[l[0, 0], l[0, 1], l[0, 2]] -= self.fn*m_t
        self.fns[l_adj[0, 0], l_adj[0, 1], l_adj[0, 2]] += self.fn*m_t

    def signal_threshold(self, alpha):
        """Set the signal-to-noise threshold used by ``test_classes``."""
        self.snr = alpha

    def test_classes(self, M_obs):
        """
        Accept every class whose observed size stands out from the modelled noise.

        M_obs - 2x2x2 array of observed class sizes, indexed by label

        The signal of a class is ``M_obs / (fns + fps + 0.0001)``.  A class is
        accepted when its signal exceeds ``self.snr`` or is negative.  Results
        are stored in ``self.accepted`` (observed size, truncated to int) and
        ``self.signal`` (the ratio, as float); rejected classes are left at 0.
        ``signal_threshold`` must have been called first.
        """
        accepted = np.zeros(8, dtype=int)
        # the ratio is generally not an integer, so keep it as float
        sn = np.zeros(8, dtype=float)
        # the small constant avoids division by zero for noise-free classes
        SN = M_obs/(self.fns + self.fps + 0.0001)
        for row, l in enumerate(self.L):
            idx = (l[0, 0], l[0, 1], l[0, 2])
            signal = SN[idx]
            if (signal > self.snr) or (signal < 0):
                accepted[row] = M_obs[idx]
                sn[row] = signal
        self.accepted = accepted
        self.signal = sn

In [17]:
def find_k_min(classes):
    """
    Build the minimal starting vector of class sizes from `classes`.

    `classes` is a length-8 sequence of class sizes.
    - If entry 7 is positive, only that entry is kept.
    - Otherwise, if any entry from index 3 onward is positive, entries 3..7
      are kept.
    - Otherwise `classes` itself is returned unchanged (same object).

    In the first two cases a new float array is returned with all other
    entries set to zero.
    """
    last = classes[7]
    if last > 0:
        kept = np.zeros(8)
        kept[7] = last
        return kept

    tail = classes[3:]
    if (tail > 0).any():
        kept = np.zeros(8)
        kept[3:] = tail
        return kept

    return classes

In [156]:
def iterate(M_obs, N_A, N_B, N_C, k_min, fn=0.1, fp=0.1, alpha=5):
    """
    Refine the set of accepted classes until it stops changing.

    params:
    M_obs - object whose `.M` attribute holds the observed 2x2x2 class sizes
    N_A, N_B, N_C - per-genotype counts passed to `model`
    k_min - starting class sizes, unpacked after the three counts as
            (M000, M100, M010, M001, M110, M101, M011, M111)
    fn, fp - false negative / false positive rates used for every model
    alpha - signal-to-noise threshold passed to `signal_threshold`

    output:
    the first model whose `accepted` vector equals that of the model it was
    built from.

    NOTE: the loop only ends on such a fixed point; there is no iteration cap.
    """
    prev_model = model(N_A, N_B, N_C, *k_min, fn=fn, fp=fp)
    prev_model.signal_threshold(alpha)
    prev_model.test_classes(M_obs.M)

    while True:
        # `accepted` follows the row order of model.L, which ends with the 000
        # class, so the last entry is moved to the front to match the
        # constructor's M000-first argument order.
        curr_model = model(N_A, N_B, N_C, prev_model.accepted[-1],
                           *prev_model.accepted[0:-1], fn=fn, fp=fp)
        curr_model.signal_threshold(alpha)
        curr_model.test_classes(M_obs.M)

        if (curr_model.accepted == prev_model.accepted).all():
            return curr_model
        prev_model = curr_model

In [157]:
# Class sizes entered without any noise correction (fn=0, fp=0).
# Positional order: N_100, N_010, N_001, then M000, M100, M010, M001, M110, M101, M011, M111.
N_t = 21000
M_obs = model(2800, 481, 2214, N_t, 1800, 78, 1226, 106, 720, 57, 242, fn=0, fp=0)
# Starting point for the iteration: only M111 (242) is non-zero besides M000.
k_min = (N_t-242, 0, 0, 0, 0, 0, 0, 242)
# iterate() is called with its default fn, fp and alpha.
final = iterate(M_obs, 2800, 481, 2214, k_min)
final.accepted

array([ 1800,     0,  1226,     0,   720,     0,   242, 21000])

In [161]:
def random_set():
    """
    Draw a random set of class sizes for a simulated three-way comparison.

    Seven on/off flags are drawn (redrawn until at least three are on); each
    class that is on gets a random size in [10, 2000), the others are 0.

    Returns (N, true_classes, sizes):
        N            -- [N_A, N_B, N_C], sums of the classes at indices
                        (0, 3, 4, 6), (1, 3, 5, 6) and (2, 4, 5, 6)
        true_classes -- the seven class sizes followed by N_T
        sizes        -- N_T followed by the seven class sizes
    """
    N_T = 21000

    # require at least three non-empty classes
    while True:
        present = np.random.randint(0, 2, 7)
        if np.sum(present) >= 3:
            break

    classes = np.random.randint(10, 2000, 7)
    classes[np.where(present == 0)] = 0

    members = ((0, 3, 4, 6), (1, 3, 5, 6), (2, 4, 5, 6))
    N = [classes[a] + classes[b] + classes[c] + classes[d]
         for a, b, c, d in members]

    sizes = np.append(np.array(N_T), classes)
    true_classes = np.append(classes, [N_T])
    return N, true_classes, sizes

def noise_model(N, sizes, fn, fp):
    """
    Turn class sizes into a noisy "observed" model.

    A model is built from `N` and `sizes` with the given fn/fp rates; its
    false positive and false negative arrays are added to its class sizes and
    negative entries are set to 0.  The result is wrapped in a second model
    built with fn=0 and fp=0, which is returned.
    """
    truth = model(*N, *sizes, fn=fn, fp=fp)

    noisy = truth.M + truth.fps + truth.fns
    noisy[noisy < 0] = 0

    # argument order expected by model(): M000, M100, M010, M001, M110, M101, M011, M111
    label_order = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
                   (1, 1, 0), (1, 0, 1), (0, 1, 1), (1, 1, 1)]
    observed = tuple(noisy[label] for label in label_order)

    return model(*N, *observed, fn=0, fp=0)

def loop(M_obs, N, true_classes, sizes, alpha=5, fp=0.1, fn=0.1):
    """
    Run the iterative fit once and compare it against the true classes.

    The starting vector is `find_k_min(sizes)`.  Returns (correct, signal):
    `correct` is the elementwise product of "class was accepted" and "class is
    truly non-empty"; `signal` is the `signal` attribute of the final model.
    """
    fitted = iterate(M_obs, *N, find_k_min(sizes), alpha=alpha, fp=fp, fn=fn)

    is_real = true_classes > 0
    is_called = fitted.accepted > 0

    return is_called*is_real, fitted.signal

In [162]:
def run(alpha, fp, fn, iters=10000):
    """
    Simulate `iters` random class sets and record how each one is classified.

    params:
    alpha - signal-to-noise threshold
    fp, fn - false positive / false negative rates, used both to generate the
             noisy observations and to fit them
    iters - number of simulated data sets (default 10000)

    output:
    random_classes - (iters, 8) array, the true class sizes of each set
    classifications - (iters, 8) array, the `correct` vector from `loop`
    signals - (iters, 8) array, the `signal` vector from `loop`
    """
    random_classes = np.empty(shape=(iters, 8))
    classifications = np.empty(shape=(iters, 8))
    signals = np.empty(shape=(iters, 8))
    for i in range(iters):
        N, true_classes, sizes = random_set()
        M_obs = noise_model(N, sizes, fp=fp, fn=fn)
        # pass this simulation's own sizes, not a notebook-level k_min
        correct, signal = loop(M_obs, N, true_classes, sizes,
                               alpha=alpha, fp=fp, fn=fn)
        random_classes[i, :] = true_classes
        classifications[i] = correct
        signals[i] = signal
    return random_classes, classifications, signals

In [208]:
def noise(sigma=0.05, iters=10000):
    """
    Return `iters` draws from a zero-mean normal distribution with standard
    deviation `sigma`.

    The default length is a literal 10000 so that the definition does not
    depend on a module-level `iters` variable existing.
    """
    return np.random.normal(0, sigma, iters)

def running_mean(x, N):
    """
    Return the moving average of `x` over windows of `N` consecutive values.

    Each window sum is the difference of two entries of a zero-prefixed
    cumulative sum.
    """
    totals = np.insert(x, 0, 0).cumsum()
    window_sums = totals[N:] - totals[:-N]
    return window_sums / float(N)

def rc(classes, classifications, entry, subset=None):
    """
    Return one column of class sizes and its jittered classifications, sorted
    by class size.

    params:
    classes - 2-D array of class sizes, one row per simulation
    classifications - 2-D array of classification results, same layout
    entry - column index to extract
    subset - if truthy, number of rows to sample (with replacement) first

    output:
    r - the selected class sizes in ascending order
    c - the matching classifications plus normal noise from `noise`

    NOTE: this function reuses the name `rc`, which the top of the file also
    imports from matplotlib.
    """
    r = classes[:, entry]
    c = classifications[:, entry]

    if subset:
        index = np.random.randint(0, len(r), subset)
        r = r[index]
        c = c[index]

    sorter = np.argsort(r)
    r = r[sorter]
    # size the jitter to the data rather than relying on noise()'s default length
    c = c[sorter] + noise(iters=len(r))
    return r, c

def interpol(classes, classifications, entry, s=10):
    """
    Fit a smoothing spline to the running means of one class column.

    params:
    classes, classifications, entry - forwarded to `rc`
    s - smoothing factor passed to UnivariateSpline

    output:
    xs - 50 evenly spaced points from 50 to 1750
    f(xs) - the spline evaluated at those points
    """
    # import the submodule explicitly; a bare `import scipy` is not guaranteed
    # to expose `scipy.interpolate`
    from scipy.interpolate import UnivariateSpline

    r, c = rc(classes, classifications, entry)
    f = UnivariateSpline(running_mean(r, 10), running_mean(c, 10), s=s)
    xs = np.linspace(50, 1750, 50)
    return xs, f(xs)

def run_mean(classes, classifications, entry, interval):
    """
    Return the running means (window `interval`) of the sorted class sizes and
    of the jittered classifications produced by `rc` for column `entry`.
    """
    sorted_sizes, jittered_calls = rc(classes, classifications, entry)
    smoothed_sizes = running_mean(sorted_sizes, interval)
    smoothed_calls = running_mean(jittered_calls, interval)
    return smoothed_sizes, smoothed_calls

In [164]:
# Run the simulation with alpha=5, fp=0.1, fn=0.1 and keep the results at notebook level.
random_classes, classifications, signals = run(5, 0.1, 0.1)

In [211]:
def make_pretty_plots(classes, classifications):
    """
    Plot classification results against class size for three class columns.

    params:
    classes - 2-D array of class sizes, one row per simulation
    classifications - 2-D array of classification results, same layout

    Each panel shows a 500-point random subset as dots and the 50-point
    running mean as a red line, for columns 1, 5 and 6 respectively.

    output:
    fig, ax - the matplotlib figure and its array of three axes
    """
    panels = [(1, 'Singly labelled class'),
              (5, 'Doubly labelled class'),
              (6, 'Triply labelled class')]

    fig, ax = plt.subplots(ncols=3, figsize=(15, 4))
    for axis, (entry, title) in zip(ax, panels):
        # use the data passed in, not a notebook-level variable
        axis.plot(*rc(classes, classifications, entry, subset=500), 'o', alpha=0.5)
        axis.plot(*run_mean(classes, classifications, entry, 50), 'r')
        axis.set_title(title)

    ax[0].set_ylabel('Correct classification')
    ax[1].set_xlabel('Number of DE genes in class')
    return fig, ax

In [218]:
def study(alpha, fp, fn):
    """
    Run one simulation with the given parameters and plot its results.

    params:
    alpha - signal-to-noise threshold
    fp - false positive rate (shown as q in the title)
    fn - false negative rate (shown as f in the title)

    output:
    fig, ax - the figure and axes returned by `make_pretty_plots`
    """
    random_classes, classifications, signals = run(alpha, fp, fn)
    fig, ax = make_pretty_plots(random_classes, classifications)
    # raw string with \alpha so the title shows the Greek letter in math mode
    fig.suptitle(r'Params: $\alpha$={0}, $q$={1}, $f$={2}'.format(alpha, fp, fn))
    return fig, ax

In [None]:
# alpha=4, fp=0.1, fn=0.1
fig, ax = study(4, 0.1, 0.1)

In [None]:
# alpha=5, fp=0.1, fn=0.1
fig, ax = study(5, 0.1, 0.1)

In [None]:
# alpha=4, fp=0.01, fn=0.1
fig, ax = study(4, 0.01, 0.1)

In [None]:
# alpha=4, fp=0.1, fn=0.3
fig, ax = study(4, 0.1, 0.3)