## Create Model
* Create 5mer, 6mer and 7mer models

In [1]:
from signalalign.hiddenMarkovModel import HmmModel
import os 

def create_rrna_model(original_rna_model, output_dir, noise=0, new_variants = "abc",
                     replacement_bases = "AAA"):
    alphabet = "ACGT"
    if noise > 0:  
        tmp_rna_model = os.path.join(output_dir, "rna_r94_5mer_{}_noise_"+str(noise)+".model")
    else:
        tmp_rna_model = os.path.join(output_dir, "rna_r94_5mer_{}.model")

    current_model_path = original_rna_model
    for i, (variant, rep_base) in enumerate(zip(new_variants, replacement_bases)):
        alphabet += variant
        rna_model = HmmModel(current_model_path, rna=True)
        print(alphabet)
        rna_model.write_new_model(tmp_rna_model.format(alphabet), alphabet, rep_base,  noise=noise)
        if i > 0:
            os.remove(current_model_path)
        current_model_path = tmp_rna_model.format(alphabet)


In [2]:
original_rna_model = "small_model/testModelR9p4_5mer_acgt_RNA_180mv.model"
output_dir = "small_model"
create_rrna_model(original_rna_model, output_dir, noise=0, new_variants = "abc", replacement_bases = "AAA")


ACGTa
ACGTab
ACGTabc


  self.transitions_expectations[i + to_state] = self.transitions_expectations[i + to_state] / j


## Smallest Model

Using the same modification for each distribution creates some overlapping kmer distributions which, in aggregate may give us a better picture of the underlying expected current signal. However, when there are only a few kmers contributing, it usually gives bimodal distributions that do not model any of the position's kmer distributions well. So, I want to create a model, specific to the 18S and 25S S.cer rRNA's that has the smallest alphabet so that each position has a set of unique kmers covering that position. 

Hopefully this is smaller than the 21 letter alphabet I am using now

1) Get kmers for each position  
2) Find overlapping kmers  
3) the max set of overlapping kmers is the max number of extra characters I need  


In [3]:
import numpy as np
import pandas as pd
from rrna_analysis.kmer_pos_mapping import KmerPosMapping, get_kmers_from_seq, get_covered_kmers, get_kmer_permutations
from py3helpers.seq_tools import ReferenceHandler
from collections import namedtuple, defaultdict
from signalalign.utils.sequenceTools import CustomAmbiguityPositions, reverse_complement
import itertools
import operator
from itertools import zip_longest

In [4]:
mods_csv = "small_model/mod_file.csv"
reference = "../training/reference/yeast_25S_18S.fa"
alt_c_positions = "small_model/yeast_18S_25S_variants.positions"
# positions = "/Users/andrewbailey/CLionProjects/rrna_analysis/reference/mod_files/yeast_18S_25S_variants.positions"
kpm = KmerPosMapping(reference, alt_c_positions, mods_csv)
rh = ReferenceHandler(reference)

In [5]:
positions_data = CustomAmbiguityPositions.parseAmbiguityFile(alt_c_positions)


In [6]:
def get_mod_positions(mods_csv, k=5):
    mods_data = KmerPosMapping.read_in_mod_data(mods_csv)
    mod_positions = {}
    grouped_positions = {}
    for i, data in mods_data.groupby(["contig", "strand"]):
        contig = i[0]
        strand = i[1]
    #     print(data)
        mod_positions[contig+strand] = data["reference_index"].values
        values = sorted(data["reference_index"].values)
        recent = values[0]
        groups = []
        group = [recent]
        for x in values[1:]:
            if x - recent <= k:
                group.append(x)
            else:
                groups.append(group)
                group = [x]
            recent = x
        grouped_positions[contig+strand] = groups
    return mod_positions, grouped_positions

In [7]:
def get_list_full_sequence(rh, positions_data, edit=True):
    full_sequence = {}
    for i, data in positions_data.groupby(["contig", "strand"]):
        contig = i[0]
        strand = i[1]
        sequence = rh.get_sequence(contig, None, None)
    #     sequence = self.ref_handler.get_sequence(contig, None, None)
        if strand == "-":
            sequence = rc.complement(sequence)
        list_seq = list(sequence)
        if edit:
            for i, row in data.iterrows():
                assert list_seq[row.position] == row.change_from, f"ref base:{list_seq[row.position]} != position change_from:{row.change_from}"             
                list_seq[row.position] = row.change_to
        full_sequence[contig+strand] = list_seq
    return full_sequence

In [8]:
def get_kmers_from_list(sequence, index, k, rna):
    sub_seq = sequence[index-k+1:index+k]
    if rna:
        sub_seq = sub_seq[::-1]
    kmers = get_kmer_permutations(sub_seq)
    all_kmers = []
    for x in kmers:
        all_kmers.append(get_kmers_from_seq(x, k))
    all_kmers = [set(x) for x in zip(*all_kmers)]
    if rna:
        all_kmers = all_kmers[::-1]
    return all_kmers

In [9]:
def get_kmers_from_full_sequence(full_sequence, mod_positions, k=5, rna=True):
    contig_strand_position_kmers = {}
    for contig_strand, ref_indices in mod_positions.items():
        position_kmers = {}
        contig_sequence = full_sequence[contig_strand]
        for index in ref_indices:
            all_kmers = get_kmers_from_list(contig_sequence, index, k, rna)
            position_kmers[str(index)] = all_kmers
        contig_strand_position_kmers[contig_strand] = position_kmers
    return contig_strand_position_kmers

In [10]:
def get_expected_total(group, k=5):
    total = 0
    curr_window = []
    prev_min = group[0]
    for index in range(group[0]-k+1, group[-1]+1):
        if index+(k-1) in group:
            curr_window.append(index+(k-1))
        if len(curr_window) > 0 and index > curr_window[0]:
            curr_window.pop(0)
        kmers = 2**len(curr_window)
        total += kmers
#         print(index, kmers, curr_window)
    return total

In [11]:
def get_contig_strand_group_kmers(contig_strand_position_kmers, grouped_positions, k=5):
    contig_strand_group_kmers = {}
    seen_kmers = {}
    for contig_strand, position_kmers in contig_strand_position_kmers.items():
        gp = grouped_positions[contig_strand]
        group_kmers_list = []
        for group in gp:
    #         print("NEW GROUP", group)
            group_kmers = []
    #         total = 0
            for index in group:
    #             list_of_pos = list(itertools.chain.from_iterable([list(x) for x in position_kmers[str(index)]]))
    #             total += len(list_of_pos)
                group_kmers.extend(set.union(*position_kmers[str(index)]))
            group_kmers = set(group_kmers)
            expected_total = get_expected_total(group, k=k)
    #         print(len(group), group)
    #         print(expected_total, len(group_kmers), len(set(group_kmers)))
    #             print(group_kmers)
            assert expected_total == len(group_kmers), f"expected_total != len(group_kmers), {expected_total} != {len(group_kmers)}"
            group_kmers_list.append(group_kmers)
        contig_strand_group_kmers[contig_strand] = group_kmers_list
    return contig_strand_group_kmers

In [12]:
def get_fix_positions(contig_strand_group_kmers):
    all_kmers_set = set()
    for cstrand in contig_strand_group_kmers.keys():
        all_kmers_set |= set.union(*contig_strand_group_kmers[cstrand])
    all_kmers_dict = { key : [] for key in all_kmers_set }
    fix_positions = defaultdict(list)
    for cstrand in contig_strand_group_kmers.keys():
        for i, group in enumerate(contig_strand_group_kmers[cstrand]):
            for kmer in group:
                all_kmers_dict[kmer].append(cstrand+str(i))
                if len(all_kmers_dict[kmer]) > 1:
                    fix_positions[cstrand+str(i)] += [kmer]
    return fix_positions

In [13]:
def get_max_pos_max_len(fix_positions, skip=0):
    max_pos = -1
    max_len = -1
    prev_bests = []
    for pos, k_list in fix_positions.items():
        k_list_len = len(k_list)
        if k_list_len > max_len:
            max_len = k_list_len
            max_pos = pos
            prev_bests.insert(0, [max_pos, max_len])
    if skip >= len(prev_bests):
        max_pos = prev_bests[-1][0]
        max_len = prev_bests[-1][1]
        over = True
    else:
        max_pos = prev_bests[skip][0]
        max_len = prev_bests[skip][1]
        over = False

    return max_pos, max_len, over

In [14]:
def get_change_base_index(max_pos, fix_positions, grouped_positions, contig_strand_position_kmers, rna=True, k=5, skip=0):
    contig_strand = max_pos[:8]
    pos = int(max_pos[8:])
    kmers = set(fix_positions[max_pos])
#     print(grouped_positions)
#     print(contig_strand)
#     print(grouped_positions[contig_strand])
#     print(max_pos, pos, len(grouped_positions[contig_strand]))
    target_positions = grouped_positions[contig_strand][pos]
    best_pos = None
    best_pos_index = None
    best_kmers = None
    curr_max = 0
    for x in target_positions:
        kmer_sets = contig_strand_position_kmers[contig_strand][str(x)]
        for i, kmer_set in enumerate(kmer_sets):
            overlap_kmers = kmers & kmer_set
            n_kmers = len(overlap_kmers)
            if n_kmers > curr_max:
                curr_max = n_kmers
                best_pos = x
                best_pos_index = i
                best_kmers = overlap_kmers
#     print(best_pos, best_pos_index, best_kmers)
    sub_pos = None
    canonical = {"A", "T", "G", "C"}
    sub_pos_list = []
    for i, x in enumerate(zip_longest(*best_kmers)):
        bases = set(x)
        canonical & bases
        n_bases = len(set(x))
        if n_bases == 1 and len(canonical & bases) == 1:
            sub_pos = i
            sub_pos_list.insert(0, sub_pos)
    
    if skip >= len(sub_pos_list):
        sub_pos = sub_pos_list[-1]
        over = True
    else:
        sub_pos = sub_pos_list[skip]
        over = False
    if rna:
        sub_index = k-sub_pos-1
    change_base_index = (best_pos - k+1) + best_pos_index + sub_index
    return change_base_index, over

In [15]:
def get_kmer_mapping(og_sequence, contig_strand_position_kmers, k=5, rna=True):
    kmer_mapping = []
    for contig_strand, pos_kmer_map in contig_strand_position_kmers.items():
        for pos, kmer_list in pos_kmer_map.items():
    #         print(contig_strand, pos, kmer_list)
            index = int(pos)
            curr_kmers = get_kmers_from_list(og_sequence[contig_strand], index, k, rna)
    #         print(curr_kmers)
            for kmer_pair in zip(curr_kmers, kmer_list):
                k1 = list(kmer_pair[0])[0]
                k2 = list(kmer_pair[1])
    #             print(k1, k2)
                kmer_mapping.append([k1, k2])
    return kmer_mapping

In [16]:
def create_efficient_positions_file(reference, ambig_positions, mods_csv, output_path, k=5, rna=True):
    rh = ReferenceHandler(reference)
    mod_positions, grouped_positions = get_mod_positions(mods_csv, k=k)
    positions_data = CustomAmbiguityPositions.parseAmbiguityFile(ambig_positions)
    positions_data.loc[:, "change_to"] = "ab"
    # positions_data = CustomAmbiguityPositions.parseAmbiguityFile(path)
    full_sequence = get_list_full_sequence(rh, positions_data)

    fix_positions = [1, 1]
    iterration = 0
    kmer_mapping = []

    pos_skip = 0
    base_skip = 0
    while len(fix_positions) > 0 and iterration < 1000:
        print("ANOTHER ITER", len(fix_positions))
        contig_strand_position_kmers = get_kmers_from_full_sequence(full_sequence, mod_positions, k=k, rna=rna)
        contig_strand_group_kmers = get_contig_strand_group_kmers(contig_strand_position_kmers, grouped_positions=grouped_positions, k=k)
        fix_positions = get_fix_positions(contig_strand_group_kmers)
        if len(fix_positions) == 0:
            break
        max_pos, max_len, pos_over = get_max_pos_max_len(fix_positions, skip=pos_skip)
        change_base_index, base_over = get_change_base_index(max_pos, fix_positions, grouped_positions, contig_strand_position_kmers, k=k, rna=rna, skip=base_skip)            


        contig_strand = max_pos[:8]
        contig = max_pos[:7]
        position = change_base_index
        strand = max_pos[7]
        change_from = full_sequence[contig_strand][change_base_index]
        change_to = "c"

        try:
            full_sequence[contig_strand][change_base_index] = change_to
            contig_strand_position_kmers = get_kmers_from_full_sequence(full_sequence, mod_positions, k=k, rna=rna)
            contig_strand_group_kmers = get_contig_strand_group_kmers(contig_strand_position_kmers, grouped_positions=grouped_positions, k=k)
    #  will not run if causes error

            new_row = {"contig": max_pos[:7],
               "position": change_base_index,
               "strand": max_pos[7],
               "change_from": change_from,
               "change_to": change_to}
            positions_data = positions_data.append(new_row, ignore_index=True)
            print("PASS", max_pos, change_base_index, pos_skip, base_skip)
            fix_positions = get_fix_positions(contig_strand_group_kmers)
            if len(fix_positions) == 0:
                print("DONE")


        except AssertionError:
            full_sequence[contig_strand][change_base_index] = change_from
            if pos_over:
                base_skip += 1
            if not pos_over:
                pos_skip += 1
            if pos_over and base_over:
                raise AssertionError("ERROR", contig_strand, change_base_index, pos_skip, base_skip)            
            print("SKIPPING", contig_strand, change_base_index, pos_skip, base_skip)


        iterration += 1
    positions_data.to_csv(output_path, sep="\t", header=False, index=False)
    #     get kmer mapping
    og_sequence = get_list_full_sequence(rh, positions_data, edit=False)
    kmer_mapping = get_kmer_mapping(og_sequence, contig_strand_position_kmers, k=k, rna=rna)

    return kmer_mapping

In [17]:
def edit_model_with_kmer_mapping(model, kmer_mapping, rna=True):
    rna_model = HmmModel(model, rna=rna)
    for k, change in kmer_mapping:
        normal_mean, normal_sd = rna_model.get_event_mean_gaussian_parameters(k)
        for k2 in change:
            rna_model.set_kmer_event_mean_params(k2, normal_mean, normal_sd)
    rna_model.normalized = True
    rna_model.write(model)

In [18]:
mods_csv = "small_model/mod_file.csv"
reference = "../training/reference/yeast_25S_18S.fa"
alt_c_positions = "small_model/yeast_18S_25S_variants.positions"
output_path = "small_model/small_model_yeast_18S_25S_variants.positions"
k=5
rna=True
kmer_mapping = create_efficient_positions_file(reference, alt_c_positions, mods_csv, output_path, k=k, rna=rna)
rna_model = "small_model/rna_r94_5mer_ACGTabc.model"
edit_model_with_kmer_mapping(rna_model, kmer_mapping, rna=rna)


ANOTHER ITER 2
PASS RDN25-1+45 2944 0 0
ANOTHER ITER 39
PASS RDN25-1+21 2129 0 0
ANOTHER ITER 39
PASS RDN25-1+26 2260 0 0
ANOTHER ITER 39
PASS RDN25-1+45 2946 0 0
ANOTHER ITER 39
PASS RDN25-1+21 2130 0 0
ANOTHER ITER 39
PASS RDN25-1+22 2137 0 0
ANOTHER ITER 39
PASS RDN25-1+32 2417 0 0
ANOTHER ITER 39
PASS RDN25-1+32 2418 0 0
ANOTHER ITER 39
PASS RDN25-1+34 2629 0 0
ANOTHER ITER 39
PASS RDN18-1+4 206 0 0
ANOTHER ITER 39
PASS RDN18-1+27 1410 0 0
ANOTHER ITER 39
PASS RDN18-1+29 1573 0 0
ANOTHER ITER 39
PASS RDN25-1+2 774 0 0
ANOTHER ITER 39
PASS RDN25-1+11 982 0 0
ANOTHER ITER 39
PASS RDN25-1+14 1049 0 0
ANOTHER ITER 39
PASS RDN25-1+21 2131 0 0
ANOTHER ITER 38
PASS RDN25-1+22 2139 0 0
ANOTHER ITER 38
PASS RDN25-1+22 2140 0 0
ANOTHER ITER 37
PASS RDN25-1+31 2347 0 0
ANOTHER ITER 36
PASS RDN25-1+33 2617 0 0
ANOTHER ITER 36
PASS RDN25-1+34 2632 0 0
ANOTHER ITER 36
PASS RDN25-1+34 2634 0 0
ANOTHER ITER 35
PASS RDN25-1+37 2730 0 0
ANOTHER ITER 35
PASS RDN25-1+39 2811 0 0
ANOTHER ITER 35
PASS R

In [19]:
kmer_mapping

[['ATACT', ['bTACT', 'aTACT']],
 ['TATAC', ['TbTAC', 'TaTAC']],
 ['GTATA', ['GTbTA', 'GTaTA']],
 ['CGTAT', ['CGTbT', 'CGTaT']],
 ['TCGTA', ['TCGTa', 'TCGTb']],
 ['ACTCG', ['aCTCG', 'bCTCG']],
 ['TACTC', ['TbCTC', 'TaCTC']],
 ['TTACT', ['TTbCT', 'TTaCT']],
 ['ATTAC', ['ATTaC', 'ATTbC']],
 ['AATTA', ['AATTa', 'AATTb']],
 ['TAAAT', ['aAAAT', 'bAAAT']],
 ['CTAAA', ['CbAAA', 'CaAAA']],
 ['ACTAA', ['ACaAA', 'ACbAA']],
 ['GACTA', ['GACbA', 'GACaA']],
 ['TGACT', ['TGACb', 'TGACa']],
 ['TATTT', ['bATTT', 'aATTT']],
 ['TTATT', ['TbATT', 'TaATT']],
 ['TTTAT', ['TTbAT', 'TTaAT']],
 ['GTTTA', ['GTTbA', 'GTTaA']],
 ['AGTTT', ['AGTTb', 'AGTTa']],
 ['TATTT', ['aATcc', 'bATcc']],
 ['TTATT', ['TbATc', 'TaATc']],
 ['ATTAT', ['ATaAT', 'ATbAT']],
 ['GATTA', ['GATbA', 'GATaA']],
 ['AGATT', ['AGATb', 'AGATa']],
 ['TAAAC', ['aAAAC', 'bAAAC']],
 ['TTAAA', ['TaAAA', 'TbAAA']],
 ['TTTAA', ['TTbAA', 'TTaAA']],
 ['CTTTA', ['CTTaA', 'CTTbA']],
 ['TCTTT', ['TCTTa', 'TCTTb']],
 ['CTACA', ['bTACA', 'aTACA']],
 ['CCTAC