In [161]:
import pandas as pd
import numpy as np

In [162]:
hap_f_path = 'examples/CCDG_14151_B01_GRM_WGS_2020-08-05_chr19.filtered.shapeit2-duohmm-phased.2504.43-47Mb.ALL.maf01.haps.gz'
sample_f_path = 'examples/CCDG_14151_B01_GRM_WGS_2020-08-05_chr19.filtered.shapeit2-duohmm-phased.2504.43-47Mb.ALL.maf01.sample.gz'

In [163]:
total_numSNPs = 26246
examined_numSNPs = 50

In [164]:
import random
start_row = random.randint(0,total_numSNPs-examined_numSNPs)

In [165]:
import pickle

In [166]:
df = pd.read_csv(hap_f_path,header=None,sep=' ',skiprows=start_row,nrows=examined_numSNPs)

In [167]:
genetic_pos = df[2]

In [168]:
# make sure sort by position
df = df.sort_values(by=2)

In [169]:
# with open(f'training_data_{examined_numSNPs}.pkl','wb') as f:
#     pickle.dump((df,start_row,examined_numSNPs),f)

In [170]:
haps = df.loc[:,5:].values

In [171]:
labels = []
for i in range(haps.shape[1]//2):
    labels.append((haps[:,i],haps[:,haps.shape[1]//2+i]))
labels = np.array(labels)

In [172]:
genos = np.full(shape=(haps.shape[0],haps.shape[1]//2),fill_value=-1,dtype=int)
## create genotypes by combining each pair of haplotypes
for i in range(genos.shape[1]):
    genos[:,i] = haps[:,2*i] + haps[:,2*i+1]
genos = genos.T

In [173]:
genos.shape

(2504, 50)

In [174]:
B = 3

In [175]:
def construct_possible_haps(genotypes):
    all_possible_haps = set()
    all_possible_dips = set()
    num_samples,num_SNPs = genotypes.shape
    for i in range(num_samples):
        this_sample_possible_haps = [()]
        for j in range(num_SNPs):
            if genotypes[i,j] == 2:
                for k in range(len(this_sample_possible_haps)):
                    this_sample_possible_haps[k]+=(1,)
            elif genotypes[i,j] == 0:
                for k in range(len(this_sample_possible_haps)):
                    this_sample_possible_haps[k]+=(0,)
            else:
                new_this_sample_possible_haps = []
                previous_all_haps = this_sample_possible_haps.copy()
                for k in range(len(previous_all_haps)):
                    previous_all_haps[k]+=(1,)
                new_this_sample_possible_haps += previous_all_haps
                previous_all_haps = this_sample_possible_haps.copy()
                for k in range(len(previous_all_haps)):
                    previous_all_haps[k]+=(0,)
                new_this_sample_possible_haps += previous_all_haps
                this_sample_possible_haps = new_this_sample_possible_haps
        if len(this_sample_possible_haps) == 1:
            hap = this_sample_possible_haps[0]
            all_possible_haps.add(hap)
            all_possible_dips.add((hap,hap))
        else:
            for l in range(len(this_sample_possible_haps)//2):
                hap_1 = this_sample_possible_haps[l] 
                hap_2 = this_sample_possible_haps[len(this_sample_possible_haps)-l-1] 
                all_possible_haps.add(hap_1)
                all_possible_haps.add(hap_2)
                if hap_1<= hap_2:
                    all_possible_dips.add((hap_1,hap_2))
                else:
                    all_possible_dips.add((hap_2,hap_1))
#     all_possible_dips = set(all_possible_dips)
#     reverse_all_possible_haps_index = {}
#     for hap,index in all_possible_haps_index.items():
#         reverse_all_possible_haps_index[index] = hap
    return list(all_possible_dips),list(all_possible_haps)

In [176]:
class haplotypeSegmentGraph(object):
    '''
    H_g
    '''
    def __init__(self,genos,genetic_pos,B):
        '''
        B: number of hetero markers in each segment
        '''
        self.nodes = {}
        self.genetic_pos = genetic_pos
        self.B = B
        self.total_num_haplotypes = None
        self.build_haplotype_graph(genos)
    def __str__(self):
        nodes_count = 0
        for marker,nodes in self.nodes.items():
            nodes_count += len(nodes)
        log_num_haps = math.log(self.total_num_haplotypes)/math.log(2)
        output = f'''===========================
Number of haplotypes: {self.total_num_haplotypes} (~2^{log_num_haps})
Number of markers: {len(self.nodes)}
Number of nodes (# segment haplotypes(~=B) x # markers): {nodes_count}
==========================='''
        return output
    def __repr__(self):
        return self.__str__()
    def build_haplotype_graph(self,genos):
        B = self.B
        masking = genos.copy()
        masking[masking!=1] =0
        snp_hetero = masking.max(axis=0)
        cumsum_snp_hetero = np.cumsum(snp_hetero)
        bins = np.array([i*B for i in range(int(cumsum_snp_hetero[-1]//B)+1)])
        inds = np.digitize(cumsum_snp_hetero, bins,right=True)
        
        # for each segment
        marker_id = genos.shape[1]
        after_marker_nodes = []
        node_id = 0
        for unique_ind in np.unique(inds)[::-1]:
            segment_genos = genos[:,inds==unique_ind]
            dips,haps = construct_possible_haps(segment_genos)
            last_marker = True
            for i in list(range(segment_genos.shape[1]))[::-1]:
                marker_id -= 1
                marker_nodes = []
                for hap in haps:
                    assert len(hap) == segment_genos.shape[1]
                    pos = self.genetic_pos[marker_id]
                    node = haplotypeSegmentNode(node_id,marker_id,hap,hap[i],pos)
                    node_id += 1
                    marker_nodes.append(node)
#                     if len(after_marker_nodes) == 0:
#                         print(marker_id+1)
                    if last_marker:
                        node.type = 'inter'
                        node.outer_nodes = after_marker_nodes
                        for outer_node in node.outer_nodes:
                            outer_node.inner_nodes.append(node)
                    else:
                        node.type = 'intra'
                        node.outer_nodes = [n for n in after_marker_nodes if n.haplotype==node.haplotype]
                        for outer_node in node.outer_nodes:
                            outer_node.inner_nodes.append(node)
                after_marker_nodes = marker_nodes
                last_marker = False
                self.nodes[marker_id] = after_marker_nodes
        total_num_haplotypes = 0
        for node in self.nodes[0]:
            total_num_haplotypes += self.forward(node)
        self.total_num_haplotypes = total_num_haplotypes
        for node in self.nodes[genos.shape[1]-1]:
            self.backward(node)
        for node in self.nodes[0]:
            self.update_weight(node)
        
    def forward(self,node):
        if len(node.outer_nodes) == 0:
            node.outer_weight = 1
            return 1
        else:
            all_num_outer_haplotypes  = 0
            for outer_node in node.outer_nodes:
                if outer_node.outer_weight == None:
                    all_num_outer_haplotypes += self.forward(outer_node)
                else:
                    all_num_outer_haplotypes += outer_node.outer_weight
            node.outer_weight = all_num_outer_haplotypes
            return all_num_outer_haplotypes
    def backward(self,node):
        if len(node.inner_nodes) == 0:
            node.inner_weight = 1
            return 1
        else:
            all_num_inner_haplotypes  = 0
            for inner_node in node.inner_nodes:
                if inner_node.inner_weight == None:
                    all_num_inner_haplotypes += self.backward(inner_node)
                else:
                    all_num_inner_haplotypes += inner_node.inner_weight
            node.inner_weight = all_num_inner_haplotypes
            return all_num_inner_haplotypes
    def update_weight(self,node):
        if node.weight == None:
            node.weight = node.inner_weight * node.outer_weight
            for outer_node in node.outer_nodes:
                node.outer_weights[outer_node.id] = node.inner_weight*outer_node.outer_weight
                self.update_weight(outer_node)
        
class haplotypeSegmentNode(object):
    def __init__(self,node_id,marker,haplotype,allele,pos):
        self.id = node_id
        self.marker =  marker
        self.haplotype = haplotype
        self.allele = allele
        self.weight = None
        self.inner_weight = None
        self.outer_weight = None
        self.type = None
        self.inner_nodes = []
        self.outer_nodes = []
        self.outer_weights = {}
        self.pos = pos
    def __str__(self):
        log_weight = math.log(self.weight)/math.log(2)
        log_inner_weight = math.log(self.inner_weight)/math.log(2)
        log_outer_weight = math.log(self.outer_weight)/math.log(2)
        output = f'''===========================
Haplotype segment Node: represents a possible haplotype state for this marker in the whole dataset
--------------------------
Node id: {self.id}
Marker id: {self.marker}
Haplotype: {self.haplotype}
Allele: {self.allele}
Type(it connects to another segment[inter] or connects to the node in the same segment[intra]): {self.type}
Weight (# haplotypes going through this node): {self.weight}(~2^{log_weight})
Inner weight(# haplotypes ending at this node): {self.inner_weight} (~2^{log_inner_weight})
Outer weight weight(# haplotypes starting from this node): {self.outer_weight} (~2^{log_outer_weight})
# inner nodes (# nodes connect to it): {len(self.inner_nodes)}
# outer nodes (# nodes it connects to): {len(self.outer_nodes)}
Genetic position: {self.pos}
==========================='''
        return output
    def __repr__(self):
        return self.__str__()
    def dist(self,another_node):
        return np.abs(self.pos - another_node.pos)
    

In [177]:
class genoSegmentGraph(object):
    '''
    S_g graph
    '''
    def __init__(self,geno,B):
        self.nodes = {}
        self.B = B
        self.build_geno_graph(geno)
    def __str__(self):
        nodes_count = 0
        for marker,nodes in self.nodes.items():
            nodes_count += len(nodes)
        
        output = f'''===========================
Number of markers: {len(self.nodes)}
Number of nodes (# segment haplotypes(~=B) x # markers): {nodes_count}
==========================='''
        return output
    def __repr__(self):
        return self.__str__()
    def build_geno_graph(self,geno):
        '''
        B: number of hetero markers in each segment
        geno: genotype (num_of_markers)
        '''
        B = self.B
        # split the genotype by segments
        splitted_geno = np.split(geno, np.where(geno == 1)[0][:-1]+1)
        segments = [np.concatenate(splitted_geno[i:i+B]) for i in range(0,len(splitted_geno),B)]
        marker_id = geno.shape[0]
        after_marker_nodes = []
        node_id = 0
        for segment in segments[::-1]:
            segment_geno = np.array([segment])
            _,haps = construct_possible_haps(segment_geno)
            last_marker = True
            for i in list(range(segment_geno.shape[1]))[::-1]:
                marker_id -= 1
                marker_nodes = []
                for hap in haps:
                    node = genoSegmentNode(node_id,marker_id,hap,hap[i],segment)
                    node_id += 1
                    marker_nodes.append(node)
                    if last_marker:
                        node.type = 'inter'
                        node.outer_nodes = after_marker_nodes
                        for outer_node in node.outer_nodes:
                            outer_node.inner_nodes.append(node)
                    else:
                        node.type = 'intra'
                        node.outer_nodes = [n for n in after_marker_nodes if n.haplotype==node.haplotype]
                        for outer_node in node.outer_nodes:
                            outer_node.inner_nodes.append(node)
                if last_marker and after_marker_nodes == []:
                    leaves_nodes = marker_nodes
                after_marker_nodes = marker_nodes
                last_marker = False
                self.nodes[marker_id] = after_marker_nodes
class genoSegmentNode(object):
    def __init__(self,node_id,marker,haplotype,allele,segment):
        self.id = node_id
        self.marker =  marker
        self.haplotype = haplotype
        self.allele = allele
        self.segment = segment
        self.type = None
        self.inner_nodes = []
        self.outer_nodes = []
    def __str__(self):
        output = f'''===========================
Geno segment Node: represents a possible haplotype state for this marker in this sample genotype
--------------------------
Node id: {self.id}
Marker id: {self.marker}
Haplotype: {self.haplotype}
Segment genotype: {self.segment}
Allele: {self.allele}
Type(it connects to another segment[inter] or connects to the node in the same segment[intra]): {self.type}
# inner nodes (# nodes connect to it): {len(self.inner_nodes)}
# outer nodes (# nodes it connects to): {len(self.outer_nodes)}
==========================='''
        return output
    def __repr__(self):
        return self.__str__()

In [178]:
from scipy.special import logsumexp
import time
import math
from pathlib import Path
import pickle
from joblib import Parallel, delayed
from multiprocessing import Pool
class haplotypeHMM(object):
    '''
    hidden_states -> mosaic state (X,Y), indicating states of diploids from haplotypes X and Y
    transition -> transition prob for mosaic state
    emission -> prob of genotype condition on mosaic state
    inital -> prob initial mosaic state
    
    '''
    def __init__(self, hap_graph,savedir='./checkpoints/',pseudocount=1e-100,seed=42):
        self.hap_graph = hap_graph
        self.B = self.hap_graph.B
        self.total_num_haplotypes = self.hap_graph.total_num_haplotypes
        self.seed = seed
        self.pseudocount = pseudocount
        K = self.total_num_haplotypes
        self.theta = 1/(math.log(K) + 0.5772)
        self.population_size = 15000
        self.savedir = savedir
#     def emit_prob(self, this_state, haplotype):
#         return self.emissions[this_state][haplotype]
    
#     def transition_prob(self, this_state, next_state):
#         return self.transitions[this_state][next_state]
    
#     def init_prob(self, this_state):
#         return self.initial[this_state]

    def emit_prob(self, this_node, geno_graph_node):
        K = self.total_num_haplotypes
        theta = self.theta
        if geno_graph_node.allele == this_node.allele:
            delta = 1
        else:
            delta = 0
        prob = K/(K+theta) * delta + theta/(K+theta)/2
        assert prob >= 0 and prob <= 1
        if prob == 0:
            prob = self.pseudocount
        if prob == 1:
            prob -= self.pseudocount
        return prob
#         return self.emissions[this_node][genotype]
    
    def transition_prob(self, this_node, next_node):
        K = self.total_num_haplotypes
        rho = 1 - np.exp(-4*self.population_size*this_node.dist(next_node)/K)
        if next_node.id in this_node.outer_weights:
            edge_weight = this_node.outer_weights[next_node.id]
        else:
            edge_weight = 0
        prob = (1-rho) * edge_weight / this_node.weight + rho * next_node.weight / self.total_num_haplotypes
        assert prob >= 0 and prob <= 1
        if prob == 0:
            prob = self.pseudocount
        if prob == 1:
            prob -= self.pseudocount
        return prob
    
    def init_prob(self, node):
        prob = node.weight / self.total_num_haplotypes
        assert prob >= 0 and prob <= 1
        if prob == 0:
            prob = self.pseudocount
        if prob == 1:
            prob -= self.pseudocount
        return prob
    
    def logsumexp(self,lst):
        arr = np.array(lst)
#         log_arr =np.log(arr)
#         if np.any(arr<=0):
#             print(arr)
#             raise Exception()
        return logsumexp(arr)
    def forward(self, geno_graph):
        log_left_list = []
        log_left = np.empty(shape=(len(geno_graph.nodes[0]),len(self.hap_graph.nodes[0])))
        # for first marker
        for i,geno_graph_node in enumerate(geno_graph.nodes[0]):
            for j,hap_graph_node in enumerate(self.hap_graph.nodes[0]):
                log_left[i,j] = np.log(self.init_prob(hap_graph_node)) + np.log(self.emit_prob(hap_graph_node, geno_graph_node))
        log_left_list.append(log_left)
        all_marker_id = list(sorted(hap_graph.nodes.keys()))
        for marker_id in all_marker_id[1:]:
            log_left = np.empty(shape=(len(geno_graph.nodes[marker_id]),len(self.hap_graph.nodes[marker_id])))
            for i,geno_graph_node in enumerate(geno_graph.nodes[marker_id]):
                for j,hap_graph_node in enumerate(self.hap_graph.nodes[marker_id]):
                    all_temp_log_left = []
                    for u,prev_geno_graph_node in enumerate(geno_graph.nodes[marker_id-1]):
                        if prev_geno_graph_node in geno_graph_node.inner_nodes:
                            for v,prev_hap_graph_node in enumerate(self.hap_graph.nodes[marker_id-1]):
                                all_temp_log_left.append(log_left_list[marker_id-1][u,v] +  np.log(self.transition_prob(prev_hap_graph_node,hap_graph_node)))
                    
                    log_left[i,j] = np.log(self.emit_prob(hap_graph_node, geno_graph_node)) + self.logsumexp(all_temp_log_left)
            log_left_list.append(log_left)
        return log_left_list
    def backward(self, geno_graph):
        log_right_list = []
        log_right = np.empty(shape=(len(geno_graph.nodes[0]),len(self.hap_graph.nodes[0])))
        # for last marker
        all_marker_id = list(sorted(hap_graph.nodes.keys()))
        marker_id = all_marker_id[-1]
        for i,geno_graph_node in enumerate(geno_graph.nodes[marker_id]):
            for j,hap_graph_node in enumerate(self.hap_graph.nodes[marker_id]):
                log_right[i,j] = 0
        log_right_list.insert(0,log_right)
        for marker_id in all_marker_id[-2::-1]:
            log_right = np.empty(shape=(len(geno_graph.nodes[marker_id]),len(self.hap_graph.nodes[marker_id])))
            for i,geno_graph_node in enumerate(geno_graph.nodes[marker_id]):
                for j,hap_graph_node in enumerate(self.hap_graph.nodes[marker_id]):
                    all_temp_log_right = []
                    for u,next_geno_graph_node in enumerate(geno_graph.nodes[marker_id+1]):
                        if next_geno_graph_node in geno_graph_node.outer_nodes:
                            for v,next_hap_graph_node in enumerate(self.hap_graph.nodes[marker_id+1]):
                                all_temp_log_right.append(log_right_list[0][u,v] +  np.log(self.transition_prob(hap_graph_node,next_hap_graph_node)))
                    
                    log_right[i,j] = np.log(self.emit_prob(hap_graph_node, geno_graph_node)) + self.logsumexp(all_temp_log_right)
            log_right_list.insert(0,log_right)
        return log_right_list
    def expectation(self,geno_graph):
        log_alpha_list = self.forward(geno_graph)
        log_beta_list = self.backward(geno_graph)
        
        all_marker_id = list(sorted(hap_graph.nodes.keys()))
        marker_id = all_marker_id[0]
        init_log_marginal = np.empty(shape=(len(geno_graph.nodes[marker_id])))
        for i,geno_graph_node in enumerate(geno_graph.nodes[marker_id]):
            init_log_marginal[i] = self.logsumexp(log_beta_list[0][i,:]) - self.logsumexp(log_beta_list[0])
        
        log_marginal_list = [init_log_marginal]
        for marker_id in all_marker_id[0:-1]:
            log_marginal = np.empty(shape=(len(geno_graph.nodes[marker_id]),len(geno_graph.nodes[marker_id+1])))
            temp_log_marginal = []
            for i1,geno_graph_node_1 in enumerate(geno_graph.nodes[marker_id]):
                for i2,geno_graph_node_2 in enumerate(geno_graph.nodes[marker_id+1]):
                    for j1,hap_graph_node_1 in enumerate(self.hap_graph.nodes[marker_id]):
                        for j2,hap_graph_node_2 in enumerate(self.hap_graph.nodes[marker_id+1]):
                            temp_log_marginal.append(log_alpha_list[marker_id][i1,j1] + np.log(self.transition_prob(hap_graph_node_1,hap_graph_node_2))+\
                            np.log(self.emit_prob(hap_graph_node_2,geno_graph_node_2)) + log_beta_list[marker_id+1][i2,j2])
            index = 0
            for i1,geno_graph_node_1 in enumerate(geno_graph.nodes[marker_id]):
                for i2,geno_graph_node_2 in enumerate(geno_graph.nodes[marker_id+1]):
                    log_marginal[i1,i2] = temp_log_marginal[index] - self.logsumexp(temp_log_marginal)
                    index += 1
            log_marginal_list.append(log_marginal)
        return log_marginal_list
                            
                            
            
    
#     def check_convergence(self,this_px,previous_px,eps):
#         if np.linalg.norm(np.array(this_px) - np.array(previous_px))/np.linalg.norm(np.array(previous_px)) < eps:
#             return True
#         else:
#             return False
#     def check_compatible_with_genotype(self,genotype,hap_1,hap_2):
#         assert len(hap_1) == len(hap_2) == len(genotype) 
#         incompatible_snps = 0
#         for i in range(len(genotype)):
#             if  hap_1[i] + hap_2[i] != genotype[i]:
#                 incompatible_snps += 1
#         return incompatible_snps
    
#     def predict(self,geno_graph_nodes):
#         initials = np.empty((num_states))
#         transitions = np.empty((num_states,num_states))
        
    def predict(self,genos,choice='Viterbi',threads=1):
        if choice == 'Viterbi':
            genos_list = [geno for geno in genos]
#             with Pool(threads) as p:
#                 predictions = p.map(self.predict_Viterbi, genos_list)
            
            predictions = Parallel(n_jobs=threads)(delayed(self.predict_Viterbi)(geno) for geno in genos_list)
        return np.array(predictions)

#     def save_checkpoint(self,check_point_folder,it):
#         Path(check_point_folder).mkdir(exist_ok=True,parents=True)
#         timestr = time.strftime("%Y%m%d-%H%M%S")
#         with open(check_point_folder+f'/{it}_{timestr}.pt','wb') as f:
#             pickle.dump(self,f)
#     def predict_Viterbi(self, geno,geno_graph,log_marginal_list):
#         # Predict by Viterbi
#         previous_col_probs = {}
#         for i in range(log_marginal_list[0].shape[0]):
#             previous_col_probs[i] = log_marginal_list[0][i]
#         traceback = []

#         for t in range(1, len(log_marginal_list)): 
#             marker_id = t - 1
#             traceback_next = {}
#             previous_col_probs_next = {}
#             for i1,geno_graph_node in enumerate(geno_graph.nodes[marker_id]):
#                 k = previous_col_probs[i1] + log_marginal_list[t][i1,:]
#                 for i2,geno_graph_node in enumerate(geno_graph.nodes[marker_id+1]):
# #                     if geno[marker_id] >= 
#                     argmax_i2 = np.argmax(k)
#                 traceback_next[i1] = argmax_i2
#                 previous_col_probs_next[i1] = k[argmax_i2]
#             traceback.append(traceback_next)
#             previous_col_probs = previous_col_probs_next
            

#         max_final_state = None
#         max_final_prob = -np.inf
#         for state,prob in previous_col_probs.items():
#             if prob > max_final_prob:
#                 max_final_prob = prob
#                 max_final_state = state
#         all_marker_id = list(sorted(hap_graph.nodes.keys()))
#         result = [geno_graph.nodes[len(all_marker_id)-1][max_final_state]]
#         for t in range(len(all_marker_id)-2,-1,-1):
#             marker_id = all_marker_id[t]
#             result.append(geno_graph.nodes[marker_id][traceback[t][max_final_state]])
#             max_final_state = traceback[t][max_final_state]

#         return result[::-1]
    def predict_Viterbi(self, geno):
        geno_graph = genoSegmentGraph(geno,self.B)
        log_marginal_list = hmm.expectation(geno_graph)
        
        # Predict by Viterbi
        previous_col_probs = {}
        for i in range(log_marginal_list[0].shape[0]):
            for j in range(log_marginal_list[0].shape[0]):
                previous_col_probs[(i,j)] = log_marginal_list[0][i] + log_marginal_list[0][j]
        traceback = []
        incompatible_penalty = 0.1
        all_marker_id = list(sorted(geno_graph.nodes.keys()))
        for t in range(1, len(log_marginal_list)): 
            marker_id = all_marker_id[t-1]
            traceback_next = {}
            previous_col_probs_next = {}
            # next state
            for hap1_i2,hap1_geno_graph_node_i2 in enumerate(geno_graph.nodes[marker_id+1]):
                for hap2_i2,hap2_geno_graph_node_i2 in enumerate(geno_graph.nodes[marker_id+1]):
                        
                    best_prob = -np.inf
                    best_haps = None
                    # prev state
                    for hap1_i1,hap1_geno_graph_node_i1 in enumerate(geno_graph.nodes[marker_id]):
                        for hap2_i1,hap2_geno_graph_node_i1 in enumerate(geno_graph.nodes[marker_id]):

                            num_mismatches = np.abs(geno[marker_id] - (hap1_geno_graph_node_i1.allele + hap2_geno_graph_node_i1.allele)) + np.abs(geno[marker_id+1] - (hap1_geno_graph_node_i2.allele + hap2_geno_graph_node_i2.allele))
                            prob = previous_col_probs[(hap1_i1,hap2_i1)] + log_marginal_list[t][hap1_i1,hap1_i2] + \
                            log_marginal_list[t][hap2_i1,hap2_i2]
                            prob *= 1+incompatible_penalty * num_mismatches
                            if prob > best_prob:
                                best_prob = prob
                                best_haps = (hap1_i1,hap2_i1)
                    if best_haps != None:
                        traceback_next[(hap1_i2,hap2_i2)] = best_haps
                        previous_col_probs_next[(hap1_i2,hap2_i2)] = best_prob
            previous_col_probs = previous_col_probs_next
            traceback.append(traceback_next)
            

        max_final_state = None
        max_final_prob = -np.inf
        for state,prob in previous_col_probs.items():
            if prob > max_final_prob:
                max_final_prob = prob
                max_final_state = state
        
        nodes = geno_graph.nodes[all_marker_id[-1]]
#         if max_final_state[0] >= len(nodes) or max_final_state[1] >= len(nodes):
#             return geno,geno_graph,log_marginal_list
        result = [(nodes[max_final_state[0]],nodes[max_final_state[1]])]
        for t in range(len(all_marker_id)-2,-1,-1):
            marker_id = all_marker_id[t]
            nodes = geno_graph.nodes[marker_id]
            max_final_state = traceback[t][max_final_state]
            result.append((nodes[max_final_state[0]],nodes[max_final_state[1]]))
        results = result[::-1]
        hap1 = []
        hap2 = []
        for (hap1_node,hap2_node) in results:
            if hap1_node.type == 'inter':
                hap1 += hap1_node.haplotype
            if hap2_node.type == 'inter':
                hap2 += hap2_node.haplotype
        return [hap1,hap2]


In [179]:
hap_graph = haplotypeSegmentGraph(genos,genetic_pos,B)

In [180]:
hap_graph

Number of haplotypes: 471415610408960 (~2^48.74399286106018)
Number of markers: 50
Number of nodes (# segment haplotypes(~=B) x # markers): 374

In [181]:
hap_graph.nodes[0][0]

Haplotype segment Node: represents a possible haplotype state for this marker in the whole dataset
--------------------------
Node id: 366
Marker id: 0
Haplotype: (1, 0, 1)
Allele: 1
Type(it connects to another segment[inter] or connects to the node in the same segment[intra]): intra
Weight (# haplotypes going through this node): 58926951301120(~2^45.74399286106018)
Inner weight(# haplotypes ending at this node): 1 (~2^0.0)
Outer weight weight(# haplotypes starting from this node): 58926951301120 (~2^45.74399286106018)
# inner nodes (# nodes connect to it): 0
# outer nodes (# nodes it connects to): 1
Genetic position: 43959977

In [182]:
hmm = haplotypeHMM(hap_graph)

In [183]:
# for i in range(10):
#     geno_graph = genoSegmentGraph(genos[:,10])
#     geno_graph.nodes[0]

In [184]:
results = hmm.predict(genos,threads=20)

In [189]:
def get_switch_error(prediction,truth):
    SERs = []
    hetero_masking = truth[0] + truth[1] == 1
    hetero_prediction = prediction[0][hetero_masking]
    hetero_truth = truth[0][hetero_masking]
    if hetero_truth.shape[0] == 0:
        return 0
    hap_type = hetero_prediction == hetero_truth
    num_switch = 0
    for i in range(len(hap_type)-1):
        if hap_type[i] != hap_type[i+1]:
            num_switch += 1
    SERs.append(num_switch/hap_type.shape[0])
    
    hetero_prediction = prediction[1][hetero_masking]
    hetero_truth = truth[0][hetero_masking]
    if hetero_truth.shape[0] == 0:
        return 0
    hap_type = hetero_prediction == hetero_truth
    num_switch = 0
    for i in range(len(hap_type)-1):
        if hap_type[i] != hap_type[i+1]:
            num_switch += 1
    SERs.append(num_switch/hap_type.shape[0])
    
    return np.min(SERs)

In [190]:
def calculate_switch_error_rate(results,labels):
    all_SER = []
    for i in range(len(results)):
        SER = get_switch_error(results[i],labels[i])
        all_SER.append(SER)
    return all_SER

In [191]:
all_SER = calculate_switch_error_rate(results,labels)

In [192]:
np.median(all_SER)

0.125

In [None]:
#     def initialize_HMM_parameters_MACH(self, alphabet, states,seed):

#         theta = epsi = 0.01
#         H = len(self.haplotypes)
#         transitions = {}
#         emissions = {}
#         initial = {}
#         np.random.seed(seed=seed)
#         initial_rand = np.random.dirichlet(np.ones(len(self.diplotypes)))
#         for i, state in enumerate(self.get_diplotypes()):
#             transitions[state] = {}
#             emissions[state] = {}
#             initial[state] = initial_rand[i]
#             # update emission matrix
#             for j,geno in enumerate(self.get_genotypes()):
#                 if geno ==1:
#                     emissions[state][geno] = emissions_rand[j]
                
#             # update transition matrix
#             state_x, state_y = state
#             for j, next_state in enumerate(self.get_diplotypes()):
#                 next_state_x,next_state_y = next_state
#                 if (state_X != next_state_x) and (state_Y != next_state_y):
#                     transitions[state][next_state] =  theta**2/H**2
#                 elif (state_X != next_state_x) or (state_Y != next_state_y):
#                     transitions[state][next_state] =  (1-theta) * theta/H + theta**2/(H**2)
#                 else:
#                     transitions[state][next_state] =  (1-theta) ** 2+ 2*(1-theta)*theta/H + theta**2/(H**2)
#         return transitions,emissions,initial

In [23]:
def construct_possible_haps(genotypes):
    all_possible_haps_index = {}
    all_possible_dips = []
    num_samples,num_SNPs = genotypes.shape
    for i in range(num_samples):
        this_sample_possible_haps = [()]
        for j in range(num_SNPs):
            if genotypes[i,j] == 2:
                for k in range(len(this_sample_possible_haps)):
                    this_sample_possible_haps[k]+=(1,)
            elif genotypes[i,j] == 0:
                for k in range(len(this_sample_possible_haps)):
                    this_sample_possible_haps[k]+=(0,)
            else:
                new_this_sample_possible_haps = []
                previous_all_haps = this_sample_possible_haps.copy()
                for k in range(len(previous_all_haps)):
                    previous_all_haps[k]+=(1,)
                new_this_sample_possible_haps += previous_all_haps
                previous_all_haps = this_sample_possible_haps.copy()
                for k in range(len(previous_all_haps)):
                    previous_all_haps[k]+=(0,)
                new_this_sample_possible_haps += previous_all_haps
                this_sample_possible_haps = new_this_sample_possible_haps
        if len(this_sample_possible_haps) == 1:
            hap = this_sample_possible_haps[0]
            if hap not in all_possible_haps_index:
                all_possible_haps_index[hap] = len(all_possible_haps_index)
            all_possible_dips.append((all_possible_haps_index[hap],all_possible_haps_index[hap]))
        else:
            for l in range(len(this_sample_possible_haps)//2):
                hap_1 = this_sample_possible_haps[l] 
                hap_2 = this_sample_possible_haps[len(this_sample_possible_haps)-l-1] 
                if hap_1 not in all_possible_haps_index:
                    all_possible_haps_index[hap_1] = len(all_possible_haps_index)
                if hap_2 not in all_possible_haps_index:
                    all_possible_haps_index[hap_2] = len(all_possible_haps_index)
                if all_possible_haps_index[hap_1]<= all_possible_haps_index[hap_2]:
                    all_possible_dips.append((all_possible_haps_index[hap_1],all_possible_haps_index[hap_2]))
                else:
                    all_possible_dips.append((all_possible_haps_index[hap_2],all_possible_haps_index[hap_1]))
    all_possible_dips = set(all_possible_dips)
    reverse_all_possible_haps_index = {}
    for hap,index in all_possible_haps_index.items():
        reverse_all_possible_haps_index[index] = hap
    return all_possible_haps_index,all_possible_dips,reverse_all_possible_haps_index

In [None]:
import time
st = time.time()
training_data = genos[:,:]
possible_genotypes = [0,1,2]
all_possible_haps_index,all_possible_dips,reverse_all_possible_haps_index = construct_possible_haps(training_data)
model = haplotypeHMM(possible_genotypes, list(all_possible_dips),range(len(all_possible_haps_index)),reverse_all_possible_haps_index,savedir='./40_SNPs/',seed=70,pseudocount=1e-100)

model.fit(training_data,100,1e-3)
print(time.time()-st)

In [None]:
model.predict(training_data).shape

In [None]:
# class haplotypeHMM(object):
#     '''
#     hidden_states -> mosaic state (X,Y), indicating states of diploids from haplotypes X and Y
#     transition -> transition prob for mosaic state
#     emission -> prob of genotype condition on mosaic state
#     inital -> prob initial mosaic state
    
#     '''
#     def __init__(self, genotypes, diplotypes, possible_haplotypes,seed=42,pseudocount=1e-100):
#         self.genotypes = list(set(genotypes))
#         self.diplotypes = list(set(hidden_states))
#         self.haplotypes = possible_haplotypes
#         self.seed = seed
#         self.pseudocount = pseudocount
#         self.transitions,self.emissions,self.initial = self.initialize_HMM_parameters_randomly(self.genotypes, self.diplotypes, self.seed)
    
#     def emit_prob(self, this_state, haplotype):
#         return self.emissions[this_state][haplotype]
    
#     def transition_prob(self, this_state, next_state):
#         return self.transitions[this_state][next_state]
    
#     def init_prob(self, this_state):
#         return self.initial[this_state]

#     def get_diplotypes(self):
#         for state in self.diplotypes:
#             yield state
#     def get_genotypes(self):
#         for genotype in self.genotypes:
#             yield genotype
#     def initialize_HMM_parameters_randomly(self, alphabet, states,seed):
        
#         transitions = {}
#         emissions = {}
#         initial = {}
#         np.random.seed(seed=seed)
#         initial_rand = np.random.dirichlet(np.ones(len(self.diplotypes)))
#         for i, state in enumerate(self.get_diplotypes()):
#             transitions[state] = {}
#             emissions[state] = {}
#             initial[state] = initial_rand[i]
#             emissions_rand = np.random.dirichlet(np.ones(len(self.genotypes)))
#             transitions_rand = np.random.dirichlet(np.ones(len(self.diplotypes)))
#             for j,geno in enumerate(self.get_genotypes()):
#                 emissions[state][geno] = emissions_rand[j]
#             for j, next_state in enumerate(self.get_diplotypes()):
#                 transitions[state][next_state] = transitions_rand[j]
                
#         return transitions,emissions,initial
#     def log_sum_exp(self,x):
#         x_arr = np.array(x)
#         x_log_arr = np.log(x_arr)
#         x_log_max = x_log_arr.max()
#         return x_log_max + np.log(np.sum(np.e**(x_log_arr-x_log_max)))
#     def sum_normalize(self,x):
#         return np.exp(np.log(np.array(x))-self.log_sum_exp(x))
#     def calculate_s_value(self, seq_pos, previous_vars,single_pos_genotype):
#         """Calculate the next scaling variable for a sequence position (PRIVATE).
#         This utilizes the approach of choosing s values such that the
#         sum of all of the scaled f values is equal to 1.
#         Arguments:
#          - seq_pos -- The current position we are at in the sequence.
#          - previous_vars -- All of the forward or backward variables
#            calculated so far.
#         Returns:
#          - The calculated scaling variable for the sequence item.
#         """
#         # all of the different letters the state can have
#         state_letters = self.get_diplotypes()

#         # loop over all of the possible states
#         s_value = 0
#         for main_state in state_letters:
#             emission = self.emit_prob(main_state,single_pos_genotype)

#             # now sum over all of the previous vars and transitions
#             trans_and_var_sum = 0
#             for second_state in state_letters:
#                 # the value of the previous f or b value
#                 var_value = previous_vars[seq_pos - 1][second_state]

#                 # the transition probability
#                 trans_value = self.transition_prob(main_state,second_state)

#                 trans_and_var_sum += var_value * trans_value

#             s_value += emission * trans_and_var_sum

#         return s_value
#     def forward(self, genotype):
#         left_list = []
#         left = {}
#         for state in self.get_diplotypes():
#             left[state] = self.init_prob(state) * self.emit_prob(state, genotype[0])
#         left_list.append(left)
        
#         for i in range(1, len(genotype)):
#             s_value = self.calculate_s_value(i,left_list,genotype[i])
#             left = {}
#             for next_state in self.get_diplotypes(): 
#                 left[next_state] = 0
#                 for this_state in self.get_diplotypes():
#                     left[next_state] += left_list[i-1][this_state] * self.transition_prob(this_state, next_state) 
#                 left[next_state] =  self.emit_prob(next_state, genotype[i]) * left[next_state]
#             for next_state in self.get_diplotypes():
#                 left[next_state] /= s_value
#             left_list.append(left)
#         # rescale left
# #         print('LEFT')
# #         print(left_list)
# #         print(scale_list)
# #         for i in range(len(left_list)):
# #             for state in self.get_diplotypes():
# #                 left_list[i][state] *= np.sum(scale_list[:i+1])
#         posterior = 0
#         for state in self.get_diplotypes():
#             posterior += left_list[-1][state]

#         return posterior, left_list


#     def backward(self, genotype):
# #         scale_list = [1]
#         right_list = [] 
#         right = {}
#         for state in self.get_diplotypes():
#             right[state] = 1
#         right_list.append(right)

#         for i in range(len(genotype)-2, -1, -1):
#             s_value = self.calculate_s_value(len(genotype)-1-i,right_list[::-1],genotype[i])

#             right = {} 
# #             scale = 0 
#             for state in self.get_diplotypes():
#                 right[state] = 0
#                 for next_state in self.get_diplotypes():
#                     right[state] += right_list[0][next_state] * self.transition_prob(state, next_state) * self.emit_prob(next_state, genotype[i])
# #                     scale += right[state]
# #             for state in self.get_diplotypes():
# #                 right[state] /= (scale_list[i+1]-scale_list[i])
# #             scale_list.insert(0,scale)
#             right_list.insert(0,right)
        
#         # rescale left
# #         print('RIGHT')
# #         print(right_list)
# #         print(scale_list)
# #         scaled_right_list = right_list.copy()
# #         for i in range(len(right_list)-1,-1,-1):
# #             for state in self.get_diplotypes():
# #                 right_list[i][state] *= np.sum(scale_list[i:])
        
#         posterior = 0
#         for state in self.get_diplotypes():
#             posterior += right_list[0][state] * self.init_prob(state) * self.emit_prob(state, genotype[0])

#         return posterior, right_list
#     def check_convergence(self,all_iters_total_likelihood,eps):
#         if np.linalg.norm(np.array(all_iters_total_likelihood[-1]) - np.array(all_iters_total_likelihood[-2])) < eps:
#             return True
#         else:
#             return False
#     def fit(self, genotypes, max_it,eps=1e-10):
#         pseudocount = self.pseudocount
#         # Train by Baum-Weltch
#         # Inititalization
#         all_iters_KL_divergence = []
#         transitions,emissions,initial = self.initialize_HMM_parameters_randomly(self.genotypes, self.diplotypes, self.seed)
#         for it in range(max_it):
#             print(f'------------------Iter:{it}------------------')
#              # Expectation
#             sum_Px = 0
#             # get the sum over Px first
#             for j in range(len(genotypes)):
#                 genotype = genotypes[j]
#                 # forward and backward
#                 f_Px, _ = self.forward(genotype)
#     #             r_Px, r_matrix = self.backward(genotype)
#                 sum_Px += 1/f_Px
# #             if it >= 1 and self.check_convergence(all_iters_total_likelihood,eps):
# #                 print(all_iters_total_likelihood)
# #                 break
#             for m in range(len(genotypes)):
#                 genotype = genotypes[m]
#                 # forward and backward
#                 f_Px, f_matrix = self.forward(genotype)
#                 r_Px, r_matrix = self.backward(genotype)
#                 for k in self.get_diplotypes():
#                     # Update transition matrix A
#                     for l in self.get_diplotypes():
#                         A = 0
#                         for i in range(len(genotype)-1):
#                             A += f_matrix[i][k] * self.transition_prob(k,l) *  self.emit_prob(l, genotype[i+1]) * r_matrix[i+1][l]
#                         transitions[k][l] = pseudocount + sum_Px * A

#                     # Update emission matrix E
#                     for j, sigma in enumerate(self.get_genotypes()):
#                         E = 0
#                         for i in range(len(genotype)):
#                             if genotype[i] == sigma:
#                                 E += f_matrix[i][k] * r_matrix[i][k]

#                         emissions[k][sigma] = pseudocount + sum_Px * E

#                     # Update initial state matrix B
#                     initial[k] = sum_Px * f_matrix[0][k] * r_matrix[0][k] 
#              # Maximization
#             for k in self.get_diplotypes():
#                 sum_A = 0
#                 for l in self.get_diplotypes():
#                     sum_A += transitions[k][l]
#                 for l in self.get_diplotypes():
#                     transitions[k][l] = transitions[k][l]/sum_A


#                 sum_E = 0
#                 for j, sigma in enumerate(self.get_genotypes()):
#                     sum_E += emissions[k][sigma]
#                 for j, sigma in enumerate(self.get_genotypes()):
#                     emissions[k][sigma] = emissions[k][sigma]/sum_E
# #             print('Param')
# #             print(transitions[k])
# #             print(emissions[k])
#             sum_B = 0
#             for k in self.get_diplotypes():
#                 sum_B += initial[k]
#             for k in self.get_diplotypes():
#                 initial[k] = initial[k]/sum_B
# #             # Maximization
            
# #             for k in self.get_diplotypes():
# #                 all_a = []
# #                 for l in self.get_diplotypes():
# #                     all_a.append(transitions[k][l])
# #                 all_normalized_a = self.sum_normalize(all_a)
# #                 for l_index,l in enumerate(self.get_diplotypes()):
# #                     transitions[k][l] = all_normalized_a[l_index]
                
# #                 all_e = []
# #                 for j, sigma in enumerate(self.get_genotypes()):
# #                     all_e.append(emissions[k][sigma])
                
# #                 all_normalized_e = self.sum_normalize(all_e)
# #                 for sigma_index, sigma in enumerate(self.get_genotypes()):
# #                     emissions[k][sigma] = all_normalized_e[sigma_index]

# #             all_b = []
# #             for k in self.get_diplotypes():
# #                 all_b.append(initial[k])
# #             all_normalized_b = self.sum_normalize(all_b)
            
# #             for k_index,k in enumerate(self.get_diplotypes()):
# #                 initial[k] = all_normalized_b[k_index]
# #             if it >= 1 and self.check_convergence(transitions,eps):
# #                 print(all_iters_total_likelihood)
# #                 break
#             self.transitions = transitions
#             self.emissions = emissions
#             self.initial = initial
#     def predict(self, genotype):
#         # Predict by Viterbi
#         previous_col_probs = {} 
#         traceback = []
#         for state in self.get_diplotypes():
#             previous_col_probs[state] = np.log(self.init_prob(state)) + np.log(self.emit_prob(state, genotype[0]))

#         for t in range(1, len(genotype)): 
#             previous_col_probs_next = {}
#             traceback_next = {}

#             for next_state in self.get_diplotypes():  
#                 k = {}
#                 for this_state in self.get_diplotypes():
#                     k[this_state] = previous_col_probs[this_state] + np.log(self.transition_prob(this_state, next_state)) 
#                 max_k = -np.inf
#                 argmax_k = None
#                 for state,val in k.items():
#                     if val > max_k:
#                         argmax_k = state
#                         max_k = val
#                 previous_col_probs_next[next_state] =  np.log(self.emit_prob(next_state, genotype[t])) + k[argmax_k]
#                 traceback_next[next_state] = argmax_k

#             previous_col_probs = previous_col_probs_next
#             traceback.append(traceback_next)

#         max_final_state = None
#         max_final_prob = -np.inf
#         for state,prob in previous_col_probs.items():
#             if prob > max_final_prob:
#                 max_final_prob = prob
#                 max_final_state = state

#         result = [max_final_state]
#         for t in range(len(genotype)-2,-1,-1):
#             result.append(traceback[t][max_final_state])
#             max_final_state = traceback[t][max_final_state]

#         return result[::-1]


In [362]:
# class haplotypeHMM_fast(object):
#     '''
#     hidden_states -> mosaic state (X,Y), indicating states of diploids from haplotypes X and Y
#     transition -> transition prob for mosaic state
#     emission -> prob of genotype condition on mosaic state
#     inital -> prob initial mosaic state
    
#     '''
#     def __init__(self, alphabet, hidden_states, possible_haplotypes, seed=None):
#         self._alphabet = alphabet
#         self._hidden_states = hidden_states
#         self._haplotypes = possible_haplotypes
#         self._seed = seed
#         self._initialize_theta()
#         self._transitions, self._emissions, self._initial = self._initialize_HMM(self._seed)
# #         if(self._transitions == None):
# #             self._initialize_random(self._alphabet, self._hidden_states, self._seed)
    
#     def _emit(self, cur_state, symbol):
#         return self._emissions[cur_state][symbol]
    
#     def _transition(self, cur_state, next_state):
#         return self._transitions[cur_state][next_state]
    
#     def _init(self, cur_state):
#         return self._initial[cur_state]

#     def _states(self):
#         for k in self._hidden_states:
#             yield k
#     def _emissions_by_theta(self,theta,state,prev_state):
#         H = len(self._haplotypes)
#         state_X,state_Y = state
#         prev_state_X,prev_state_Y = prev_state
#         if (state_X != prev_state_X) and (state_Y != prev_state_Y):
#             return theta**2/H**2
#         elif (state_X != prev_state_X) or (state_Y != prev_state_Y):
#             return (1-theta) * theta/H + theta**2/(H**2)
#         else:
#             return (1-theta) ** 2+ 2*(1-theta)*theta/H + theta**2/(H**2)
#     def _initialize_theta(self,):
#         self._theta = np.ones(len(self._hidden_states))//100
#     def _initialize_HMM(self,seed):
#         transitions = {}
#         emissions = {}
#         initial = {}
#         # initailize the inital prob by dirichlet distribution
#         np.random.seed(seed=seed)
#         initial_rand = np.random.dirichlet(np.ones(len(self._hidden_states)))
#         for i, state in enumerate(self._states()):
#             emissions[state] = {}
#             transitions[state] = {}
#             initial[state] = initial_rand
#             E_rand = np.random.dirichlet(np.ones(len(self._alphabet)))
#             for j, sigma in enumerate(self._get_alphabet()):
#                 emissions[state][sigma] = E_rand[j]
#             for j, next_state in enumerate(self._states()):
#                 transitions[state][next_state] = self._emissions_by_theta(self._theta[j],next_state,state)
#         return transitions,emissions,initial
# #     def _initialize_random(self, alphabet, states, seed):
# #         alphabet = list(set(alphabet))
# #         alphabet.sort()
# #         states = list(set(states))
# #         states.sort()
# #         self._alphabet = alphabet
# #         self._hidden_states = states

# #         #Initialize empty matrices A and E with pseudocounts
# #         A = {}
# #         E = {}
# #         I = {}
# #         np.random.seed(seed=seed)
# #         I_rand = np.random.dirichlet(np.ones(len(self._hidden_states)))
# #         for i, state in enumerate(self._states()):
# #             E[state] = {}
# #             A[state] = {}
# #             I[state] = I_rand[i]
# #             E_rand = np.random.dirichlet(np.ones(len(self._alphabet)))
# #             A_rand = np.random.dirichlet(np.ones(len(self._hidden_states)))
# #             for j, sigma in enumerate(self._get_alphabet()):
# #                 E[state][sigma] = E_rand[j]
# #             for j, next_state in enumerate(self._states()):
# #                 A[state][next_state] = A_rand[j]
                
# #         self._transitions = A
# #         self._emissions = E
# #         self._initial = I
# #         return
        
#     def _get_alphabet(self):
#         for sigma in self._alphabet:
#             yield sigma

#     def _Ca(self,hap_a,previous_left_chain,previous_geno):
#         Ca = 0
#         for hap_b in self._haplotypes:
#             state = (hap_a,hap_b)
#             Ca += previous_left_chain[state] * self._emit(state,previous_geno)
#         return Ca
#     def _C(self,previous_left_chain,previous_geno):
#         C = 0
#         for hap_a in self._haplotypes:
#             C += self._Ca(hap_a,previous_left_chain,previous_geno)
#         return C
#     def forward(self, sequence):
#         H = len(self._haplotypes)
#         # calculate left chain prob
#         left_list = [] 
#         left = {}
#         for state in self._states():
#             left[state] = 1
#         left_list.append(left)

#         for j in range(1, len(sequence)):  # For each position in the sequence
#             left = {}
#             for state in self._states(): # For each state
#                 (x,y) = state
#                 # refer to MACH paper
#                 left[state] = left_list[j-1][state] * self._emit(state,sequence[j-1]) * (1-self._theta[j]) ** 2 + \
#                 self._Ca(x,left_list[j-1],sequence[j-1]) * (1-self._theta[j]) * self._theta[j] / H + \
#                 self._Ca(y,left_list[j-1],sequence[j-1]) * (1-self._theta[j]) * self._theta[j] / H + \
#                 self._C(left_list[j-1],sequence[j-1]) * self._theta[j]**2 / H**2

#             left_list.append(left)
#         Px = 0
#         for state in self._states():
#             Px += left_list[-1][state]

#         return Px, left_list
#     def backward(self, sequence):
#         H = len(self._haplotypes)
#         # calculate right chain prob
#         right_list = [] 
#         right = {}
#         for state in self._states():
#             right[state] = 1
#         right_list.append(right)

#         for j in range(len(sequence)-2,-1,-1):  # For each position in the sequence
#             right = {}
#             for state in self._states(): # For each state
#                 (x,y) = state
#                 # refer to MACH paper
#                 right[state] = right_list[0][state] * self._emit(state,sequence[j+1]) * (1-self._theta[j]) ** 2 + \
#                 self._Ca(x,right_list[0],sequence[j+1]) * (1-self._theta[j]) * self._theta[j] / H + \
#                 self._Ca(y,right_list[0],sequence[j+1]) * (1-self._theta[j]) * self._theta[j] / H + \
#                 self._C(right_list[0],sequence[j+1]) * self._theta[j]**2 / H**2

#             right_list.insert(0,right)
#         Px = 0
#         for state in self._states():
#             Px += right_list[0][state] * self._init(state) * self._emit(state, sequence[0])

#         return Px, right_list
    
#     def baum_welch(self, sequences, pseudocount=1e-100):
#         """ The baum-welch algorithm for unsupervised HMM parameter learning

#         Args:
#             sequence (list): a list of sequences containing valid emissions from the HMM
#             pseudocount (float): small pseudocount value (default: 1e-100)

#         Returns:
#             None but updates the current HMM model parameters:
#              self._transitions, self._emissions, self._initial
        
#         """   
#         # Inititalization
#         transition,emissions,initial = self._initialize_HMM(self._seed)

#         # set the max iteration to 1 here to print the first iteration
#         max_it = 1
#         for it in range(max_it):
#              # Expectation
#             sum_Px = 0
#             # get the sum over Px first
#             for j in range(len(sequences)):
#                 sequence = sequences[j]
#                 # forward and backward
#                 f_Px, _ = self.forward(sequence)
#     #             r_Px, r_matrix = self.backward(sequence)
#                 sum_Px += 1/f_Px
#             for j in range(len(sequences)):
#                 sequence = sequences[j]
#                 # forward and backward
#                 f_Px, f_matrix = self.forward(sequence)
#                 r_Px, r_matrix = self.backward(sequence)
#                 for k in self._states():
#                     # Update transition matrix A
#                     for l in self._states():
#                         A = 0
#                         for i in range(len(sequence)-1):
#                             A += f_matrix[i][k] * self._transition(k,l) *  self._emit(l, sequence[i+1]) * r_matrix[i+1][l]
#                         transition[k][l] = sum_Px * A

#                     # Update emission matrix E
#                     for j, sigma in enumerate(self._get_alphabet()):
#                         E = 0
#                         for i in range(len(sequence)):
#                             if sequence[i] == sigma:
#                                 E += f_matrix[i][k] * r_matrix[i][k]

#                         emissions[k][sigma] = sum_Px * E

#                     # Update initial state matrix B
#                     initial[k] = sum_Px * f_matrix[0][k] * r_matrix[0][k]

#             # Maximization
#             for k in self._states():
#                 sum_A = 0 
#                 for l in self._states():
#                     sum_A += transition[k][l]
#                 for l in self._states():
#                     transition[k][l] = transition[k][l]/sum_A


#                 sum_E = 0
#                 for j, sigma in enumerate(self._get_alphabet()):
#                     sum_E += emissions[k][sigma]
#                 for j, sigma in enumerate(self._get_alphabet()):
#                     emissions[k][sigma] = emissions[k][sigma]/sum_E

#             sum_B = 0
#             for k in self._states():
#                 sum_B += initial[k]
#             for k in self._states():
#                 initial[k] = initial[k]/sum_B

#     #         self.__init__(self._get_alphabet, self._states, A=None, E=None, B=None, seed=None):
#             self._transitions = transition
#             self._emissions = emissions
#             self._initial = initial
# #             print(self)
#         pass
#     def viterbi(self, sequence):
#         """ The viterbi algorithm for decoding a string using a HMM

#         Args:
#             sequence (list): a list of valid emissions from the HMM

#         Returns:
#             result (list): optimal path through HMM given the model parameters
#                            using the Viterbi algorithm
        
#         Pseudocode for Viterbi:
#             Initialization (𝑖=0): 𝑣𝑘(𝑖)=𝑒𝑘(𝜎)𝑏𝑘.
#             Recursion (𝑖=1…𝑇): 𝑣𝑙(𝑖)=𝑒𝑙(𝑥𝑖) max𝑘(𝑣𝑘(𝑖−1)𝑎𝑘𝑙); 
#                                 ptr𝑖(𝑙)= argmax𝑘(𝑣𝑘(𝑖−1)𝑎𝑘𝑙).
#             Termination: 𝑃(𝑥,𝜋∗)= max𝑘(𝑣𝑘(𝑙)𝑎𝑘0); 
#                              𝜋∗𝑙= argmax𝑘(𝑣𝑘(𝑙)𝑎𝑘0).
#             Traceback: (𝑖=𝑇…1): 𝜋∗𝑖−1= ptr𝑖(𝜋∗𝑖).
#         """

#         # Initialization (𝑖=0): 𝑣𝑘(𝑖)=𝑒𝑘(𝜎)𝑏𝑘.
#         # Initialize trellis and traceback matrices
#         # trellis will hold the vi data as defined by Durbin et al.
#         # and trackback will hold back pointers
#         trellis = {} # This only needs to keep the previous column probabilities
#         traceback = [] # This will need to hold all of the traceback data so will be an array of dicts()
#         for state in self._states():
#             trellis[state] = np.log10(self._init(state)) + np.log10(self._emit(state, sequence[0])) # b * e(0) for all k
            
#         # Next we do the recursion step:
#         # Recursion (𝑖=1…𝑇): 𝑣𝑙(𝑖)=𝑒𝑙(𝑥𝑖) max𝑘(𝑣𝑘(𝑖−1)𝑎𝑘𝑙); 
#         #                 ptr𝑖(𝑙)= argmax𝑘(𝑣𝑘(𝑖−1)𝑎𝑘𝑙).
#         for t in range(1, len(sequence)):  # For each position in the sequence
#             trellis_next = {}
#             traceback_next = {}

#             for next_state in self._states():    # Calculate maxk and argmaxk
#                 k={}
#                 for cur_state in self._states():
#                     k[cur_state] = trellis[cur_state] + np.log10(self._transition(cur_state, next_state)) # k(t-1) * a
#                 argmaxk = max(k, key=k.get)
#                 trellis_next[next_state] =  np.log10(self._emit(next_state, sequence[t])) + k[argmaxk] # k * e(t)
#                 traceback_next[next_state] = argmaxk
                
#             #Overwrite trellis 
#             trellis = trellis_next
#             #Keep trackback pointer matrix
#             traceback.append(traceback_next)
            
#         # Termination: 𝑃(𝑥,𝜋∗)= max𝑘(𝑣𝑘(𝑙)𝑎𝑘0); 
#         #                  𝜋∗𝑙= argmax𝑘(𝑣𝑘(𝑙)𝑎𝑘0).
#         max_final_state = max(trellis, key=trellis.get)
#         max_final_prob = trellis[max_final_state]
                
#         # Traceback: (𝑖=𝑇…1): 𝜋∗𝑖−1= ptr𝑖(𝜋∗𝑖).
#         result = [max_final_state]
#         for t in reversed(range(len(sequence)-1)):
#             result.append(traceback[t][max_final_state])
#             max_final_state = traceback[t][max_final_state]

#         return result[::-1]
    

In [180]:
-np.inf

-inf

In [None]:
import time
st = time.time()
hidden_states =  list(all_possible_dips)
alphabet = [0,1,2] # DNA Alphabet

model = haplotypeHMM(alphabet, hidden_states,range(len(all_possible_haps_index)),seed=70,pseudocount=1e-100)

model.fit(genos[:5,:10].T,10,eps=1e-6)
print(time.time()-st)

In [39]:
genos[:5,:10]

array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [2, 0, 1, 0, 1, 1, 2, 1, 1, 2],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [2, 0, 1, 0, 1, 1, 2, 1, 1, 2]])

In [38]:
baum_welch_scaling(genos[:5,:10].T,len(hidden_states),len(alphabet),len(genos[:5,:10].T))

ValueError: operands could not be broadcast together with shapes (37,) (37,5) 