# Profile HMMs

---
## Learning Objectives

* Apply HMMs to describe PSSMs
* Develop the model structure for a profile HMM
* Build and apply the model



---

Today we will review profile HMMs in class, including a demonstration of how to implement them using our existing framework.

This is a diagram of the hidden Markov model used in HMMER (from the HMMER User Guide by Sean Eddy). The chain of match (M), insert (I), and deletion (D) states can be extended to match the length of the multiple sequence alignment that is used as the training set to produce a model. Individual sequences may then be aligned to the model and scored based on the probability that the model would emit that sequence.

<center><img src='./figures/HMM_Diagram.PNG'/></center>
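
To make the diagram concrete, the sketch below enumerates the allowed transitions for a model with a given number of match positions. The helper name `profile_transitions` is hypothetical and is not part of HMMER or HMM.py. It assumes the simplified layout used in this notebook: an `I0` insert state before the first match column and no explicit begin or end states.

```python
def profile_transitions(n_match):
    """List (source, target) pairs allowed in a profile HMM with n_match columns.

    Hypothetical helper for illustration only. Assumes no explicit begin/end
    states: I0 plays the role of the insert state before the first column.
    """
    allowed = [("I0", "I0")]
    for i in range(1, n_match + 1):
        # States at column i-1 that can move forward into column i.
        previous = ["I0"] if i == 1 else [f"M{i-1}", f"I{i-1}", f"D{i-1}"]
        for source in previous:
            allowed.append((source, f"M{i}"))
            allowed.append((source, f"D{i}"))
        # Insert state i is entered from column i and can loop on itself.
        for source in (f"M{i}", f"D{i}", f"I{i}"):
            allowed.append((source, f"I{i}"))
    return allowed


for source, target in profile_transitions(2):
    print(f"{source} -> {target}")
```

For two match positions this lists 15 transitions. Every other state pair is disallowed and would carry probability 0.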


We will be implementing a Profile HMM using the BAR domain discussed in the slides.

To accomplish this, we will implement two functions. First, `get_valid_states()` will report, for each position in the alignment, whether it meets our heuristic threshold for a match state in the model (denoted as *s in the slides). Second, `build_profileHMM()` will use our existing HMM class structure (included here in HMM.py) to build a model with the above structure.
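
As a small illustration of the thresholding idea, the sketch below applies the rule used by `get_valid_states()` in this notebook to a toy alignment held in memory rather than a FASTA file. The function name `match_columns` and the four sequences are invented for this example. The rule assumed here is that a column counts as a match state when its gap fraction is at most `1 - threshold`.

```python
from collections import Counter


def match_columns(alignment, threshold=0.5):
    """Return one boolean per alignment column: True if it is a match column.

    Hypothetical helper: a column is kept when its fraction of gap
    characters ('-') is at most 1 - threshold.
    """
    flags = []
    for column in zip(*alignment):  # iterate over columns, not sequences
        gap_fraction = Counter(column)["-"] / len(column)
        flags.append(gap_fraction <= 1 - threshold)
    return flags


# Invented alignment: gap fractions per column are 0, 0.25, 0.5, and 0.75.
toy_alignment = ["ACDE", "ACD-", "AC--", "A---"]
print(match_columns(toy_alignment, threshold=0.5))
# [True, True, True, False]
```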

A few caveats: Our HMM implementation requires that all possible emissions and transitions exist in the dictionary. That is, every hidden state must have an emission probability for each character in the alphabet, and the transition matrix must have a probability for every state going to every other state in the model. These probabilities can be set to 0 to create the profile HMM structure, but they must be set.
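
This requirement can be sketched with plain dictionaries. The snippet below builds dense transition and emission tables for a miniature model, with every entry present and disallowed entries left at 0, and then checks that nothing is missing. The three state names, the two-letter alphabet, the probabilities, and the `is_complete` helper are all assumptions for illustration and are not part of HMM.py.

```python
toy_states = ["I0", "M1", "D1"]   # hypothetical miniature model
toy_alphabet = ["A", "C"]         # hypothetical two-letter alphabet

# Start with every transition and emission present and set to 0.
toy_A = {s: {t: 0.0 for t in toy_states} for s in toy_states}
toy_E = {s: {a: 0.0 for a in toy_alphabet} for s in toy_states}

# Then fill in only the entries the profile structure allows.
toy_A["I0"]["I0"] = 0.2
toy_A["I0"]["M1"] = 0.7
toy_A["I0"]["D1"] = 0.1
toy_E["I0"] = {"A": 0.5, "C": 0.5}
toy_E["M1"] = {"A": 0.9, "C": 0.1}


def is_complete(table, rows, columns):
    """True if table[row][column] exists for every row/column pair."""
    return all(c in table.get(r, {}) for r in rows for c in columns)


print(is_complete(toy_A, toy_states, toy_states))    # True
print(is_complete(toy_E, toy_states, toy_alphabet))  # True
```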

In [None]:
# Imports
from data_readers import get_fasta
from collections import Counter, defaultdict
from HMM import HMM
import json

In [None]:
def get_valid_states(fasta_file, threshold=0.7):
    ''' 
    Function to determine which positions in an alignment are valid (match) states in the profile HMM given a threshold
    
    Args: 
        fasta_file (str): fasta file containing alignments
        threshold (float): minimum fraction of non-gap characters a position needs to be a match state (default = 0.7)

    Returns:
        valid_states (list of bools): True for each position whose gap fraction is at most 1 - threshold, otherwise False
        
    Example:
        >>> get_valid_states("data/BAR.fa", 0.7) #doctest: +ELLIPSIS
        [True, True, True, True, False, ...]
    '''
    # Returns list of valid states positions

    state_counter = None
    valid_states = []
    
    #Estimate the states in the alignment that are match versus indel states
    for name, seq in get_fasta(fasta_file):
        if state_counter is None:
            state_counter = [Counter() for _ in range(0, len(seq))]
            
        for i, aa in enumerate(seq):
            state_counter[i].update(aa)
    
    
    for i, position_counter in enumerate(state_counter):
        if position_counter['-']/sum(position_counter.values()) > (1 - threshold):
            valid_states.append(False)
        else:
            valid_states.append(True)

    return valid_states

As mentioned in the slides, we estimate the transition and emission probabilities of the HMM from the observed counts using the following equations:

$a_{kl} = A_{kl} / \sum_{l'}A_{kl'}$

$e_{k}(a) = E_{k}(a) / \sum_{a'}E_{k}(a')$

Where $k$ and $l$ represent state indices, $a$ is an emitted symbol, $a_{kl}$ and $e_{k}(a)$ are transition and emission probabilities, respectively, and $A_{kl}$ and $E_{k}(a)$ are the corresponding observed counts.
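
As a worked example of the transition update, suppose a state $k$ has been observed moving to three states with counts 8, 1, and 1 (invented numbers). Dividing each count by the row total of 10 gives probabilities 0.8, 0.1, and 0.1. A minimal sketch of that step, using a hypothetical `normalize` helper:

```python
def normalize(counts):
    """Divide each count by the row total so the values sum to 1.

    Rows whose total is 0 are returned unchanged to avoid dividing by zero.
    """
    total = sum(counts.values())
    if total == 0:
        return dict(counts)
    return {key: value / total for key, value in counts.items()}


# Invented transition counts A_kl out of a single state k.
transition_counts = {"M2": 8, "I1": 1, "D2": 1}
print(normalize(transition_counts))
# {'M2': 0.8, 'I1': 0.1, 'D2': 0.1}
```

The same helper applies to emission counts, since both equations divide a row of counts by its sum.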


In [None]:
#Train model using the BAR domain data in data/BAR.fa

# In order to build our model, we will need to set default parameters in an 
# initialized HMM using pseudocounts and then update these values with the 
# information in the fasta file
def build_profileHMM(alphabet, valid_states, fasta_file, pseudocount=0.01):
    ''' 
    Function to initialize a Profile HMM structure
    
    Args: 
        alphabet (list): alphabet characters for the model
        valid_states (list of bools): all positions in the alignments that are in match states
        fasta_file (str): fasta file containing alignments
        pseudocount (float): initial count given to every allowed transition and emission before normalization

    Returns:
        profile_HMM (HMM): HMM object
        
    Pseudocode:
    Initialize full initial, emission, and transition matrix to 0s (our HMM object requires all transitions and emissions are set to at least 0)
    Initialize possible emissions and transitions to pseudocount
    Calculate probabilities at each position given valid_states and fasta_file
        
    '''
    
    #Initialize empty matrices A, E, and I
    A = defaultdict(dict)
    E = defaultdict(dict)
    I = {}
    hidden_states = []
    
    #generate all valid states
    #Initialize I0 state as insert at start
    hidden_states.append('I0')

    #Initialize our states based on the number of match states
    for i in range(1, sum(valid_states)+1):
        hidden_states.append('I'+str(i))
        hidden_states.append('M'+str(i))
        hidden_states.append('D'+str(i))
    
    # Initialize all transitions to 0
    for item in hidden_states:
        for next_item in hidden_states:
            A[item][next_item] = 0

            
    # Give all valid state transitions a pseudocount
    A['I0']['I0'] = pseudocount
    
    # Add pseudocounts for the allowed transitions at each match position
    for i in range(1, sum(valid_states)+1):
        # Transitions into insert state
        A['I'+str(i)]['I'+str(i)] = pseudocount # self
        A['D'+str(i)]['I'+str(i)] = pseudocount # from deletion      
        A['M'+str(i)]['I'+str(i)] = pseudocount # from match
        
        # Transitions into deletion state
        A['I'+str(i-1)]['D'+str(i)] = pseudocount # from previous insert
        if i > 1: #special case for first location so ignore these
            A['D'+str(i-1)]['D'+str(i)] = pseudocount # from previous deletion      
            A['M'+str(i-1)]['D'+str(i)] = pseudocount # from previous match
        
        # Transitions into match state
        A['I'+str(i-1)]['M'+str(i)] = pseudocount # from previous insert
        if i > 1: #special case for first location so ignore these
            A['D'+str(i-1)]['M'+str(i)] = pseudocount # from previous deletion      
            A['M'+str(i-1)]['M'+str(i)] = pseudocount # from previous match
        
    # Initialize start state similarly
    for item in hidden_states:
        I[item] = 0
    I['I0'] = pseudocount
    I['D1'] = pseudocount
    I['M1'] = pseudocount
    
    # Initialize emissions from Insert and Match states
    for i in range(0, sum(valid_states)+1):
        for j, aa in enumerate(alphabet):
            E['D' + str(i)][aa] = 0
            E['I' + str(i)][aa] = pseudocount
            if i > 0:
                E['M' + str(i)][aa] = pseudocount
            else:
                E['M' + str(i)][aa] = 0

    
    # now iterate through the file and train all of the states 
    for name, seq in get_fasta(fasta_file):
        last_state = None  # None until this sequence enters its first state
        state_number = 1   # index of the next match/deletion state
        for i, aa in enumerate(seq):
            if valid_states[i]:
                # match column: a gap is a deletion, a residue is a match
                state = ('D' if aa == '-' else 'M') + str(state_number)
                state_number += 1
            elif aa != '-':
                # insert column with a residue: insert after the previous match
                state = 'I' + str(state_number - 1)
            else:
                # gap in an insert column: no state is visited, so skip it
                continue

            if last_state is None:
                # first state of this sequence counts toward the start
                # probabilities (I0, M1, or D1), even if leading gaps in
                # insert columns were skipped
                I[state] += 1
            else:
                A[last_state][state] += 1

            if not state.startswith('D'):  # deletion states emit nothing
                E[state][aa] += 1
            last_state = state

    # now we normalize to a sum of 1
    for items in A:
        A_sum = sum(A[items].values())
        if A_sum > 0:
            for next_item in A[items]:
                A[items][next_item] = A[items][next_item] / A_sum

    for items in E:
        E_sum = sum(E[items].values())
        if E_sum > 0:
            for next_item in E[items]:
                E[items][next_item] = E[items][next_item] / E_sum
            
    I_sum = sum(I.values())
    for items in I:
        I[items] = I[items] / I_sum
    
    return HMM(alphabet, hidden_states, A, E, I)

In [None]:
valid_states = get_valid_states("data/BAR_Short.fa", 0.5)
alphabet = list('GALMFWKQESPVICYHRNDT')
profile = build_profileHMM(alphabet, valid_states, "data/BAR_Short.fa")

In [None]:
# Exact example from slides
sequence = "TKLDDDFKE"

print("Forward:")
f_Px, f_matrix = profile.forward(list(sequence))
print(f_Px)


Expected output:
Forward:
9.605647218365268e-05
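
The value printed above is a forward probability: the probability of the sequence summed over all state paths through the model. The sketch below shows the textbook forward recursion on an invented two-state model. It is a generic illustration, not the implementation in HMM.py, whose internals are not shown here. It assumes every state emits one symbol per step and that there is no explicit end state.

```python
def forward(sequence, states, A, E, I):
    """Total probability of `sequence` summed over all state paths.

    Textbook forward recursion for an HMM in which every state emits one
    symbol per step. Illustrative sketch only, not the HMM.py code.
    """
    # f[k] = probability of the symbols seen so far, ending in state k.
    f = {k: I[k] * E[k][sequence[0]] for k in states}
    for symbol in sequence[1:]:
        f = {
            l: E[l][symbol] * sum(f[k] * A[k][l] for k in states)
            for l in states
        }
    return sum(f.values())


# Invented two-state model over the alphabet {"A", "C"}.
toy_hidden = ["X", "Y"]
toy_start = {"X": 0.6, "Y": 0.4}
toy_trans = {"X": {"X": 0.7, "Y": 0.3}, "Y": {"X": 0.4, "Y": 0.6}}
toy_emit = {"X": {"A": 0.9, "C": 0.1}, "Y": {"A": 0.2, "C": 0.8}}

print(forward("AC", toy_hidden, toy_trans, toy_emit, toy_start))
```

Because the toy model has no end state, the forward probabilities of all sequences of a fixed length sum to 1, which is a convenient sanity check.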

In [None]:
print(profile)