In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import jkbc.utils.files as f
import h5py as h5py
import jkbc.utils.preprocessing as prep
import numpy as np
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt

In [3]:
# HDF5 file the reads are loaded from.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
filename = '/mnt/sdb/taiyaki_mapped/mapped_umi16to9.hdf5'
# Number of reference positions in each k-mer.
mer_size = 5
# A k-mer window spans `pre` positions before its centre index and `post`
# positions from the centre index onwards. `post` is derived from `pre` so
# that pre + post == mer_size holds for even sizes as well as odd ones.
pre  = mer_size // 2
post = mer_size - pre
# (start, stop) slice applied to the list of read ids in the file.
reads_range = (0, 76033)

In [4]:
def convert_to_string(lst):
    """Return the `str()` form of every element of `lst`, concatenated in order.

    An empty iterable yields the empty string.
    """
    return ''.join(map(str, lst))

def make_mer_dict(lst, length):
    """Build a dict with one empty list for every possible k-mer.

    Args:
        lst: iterable of symbols a k-mer can be made of.
        length: number of symbols per k-mer.

    Returns:
        A dict whose keys are the concatenated `str()` forms of every
        `length`-long combination (with repetition) of the symbols in `lst`,
        in `itertools.product` order. Each value is a fresh, independent
        empty list.
    """
    # Imported locally under its real name so the builtin `iter` is not shadowed.
    import itertools

    return {
        ''.join(str(symbol) for symbol in mer): []
        for mer in itertools.product(lst, repeat=length)
    }

def save_obj(obj, name):
    """Pickle `obj` to ``data/<name>.pkl`` relative to the working directory.

    Args:
        obj: any picklable object.
        name: file stem; '.pkl' is appended.

    The ``data`` directory is created if it does not exist yet.
    """
    import os

    # Without this, the first save in a fresh checkout fails with FileNotFoundError.
    os.makedirs('data', exist_ok=True)
    # The handle is not called `f`, which is the alias of jkbc.utils.files in this notebook.
    with open('data/' + name + '.pkl', 'wb') as handle:
        pickle.dump(obj, handle, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Return the object unpickled from ``data/<name>.pkl``.

    The path is relative to the current working directory.
    """
    path = 'data/' + name + '.pkl'
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)
    
def flatten_(lst):
    """Flatten exactly one level of nesting and return the items as a new list."""
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat
    

In [5]:
# Ids of the reads to process: the keys of the file's 'Reads' group,
# restricted to the (start, stop) slice given by reads_range.
with h5py.File(filename, mode='r') as h5file:
    read_idx = list(h5file['Reads'].keys())[slice(*reads_range)]

In [None]:
%%time
# One list of signal windows per possible k-mer, keyed by the k-mer's digit
# string (presumably the four bases encoded as 0..3 — TODO confirm).
mer_dict = make_mer_dict(range(4), mer_size)

for read_id in tqdm(read_idx):
    dac, ref_to_signal, reference = f.get_read_info_from_file(filename, read_id)
    signal = prep._standardize(dac)

    # Start at `pre` so `first` is never negative. A negative start makes the
    # slice wrap around: on long references it is merely empty, but on very
    # short ones it yields a truncated k-mer that is not a key of mer_dict.
    # NOTE(review): the upper bound never lets `last` reach len(reference), so
    # the final k-mer of each read is skipped. It is kept as-is because it is
    # not visible here whether ref_to_signal has an entry at that index.
    for mer in range(pre, len(reference)-post):
        first = mer-pre
        last  = mer+post
        group = convert_to_string(reference[first:last])

        # Signal samples between the mapped positions of the window's two
        # boundaries (assumes ref_to_signal maps a reference index to a signal
        # index — confirm against get_read_info_from_file).
        window_signal = signal[ref_to_signal[first]:ref_to_signal[last]]
        mer_dict[group].append(window_signal)

  8%|▊         | 6419/76033 [06:20<1:15:22, 15.39it/s] 

## Making a table of average values

In [None]:
%%time
# Mean signal value per k-mer, pooled over every sample of every window
# collected for that k-mer. k-mers without any samples are reported and
# left out of the table.
averaged_mer_dict = {}
for key, value in tqdm(mer_dict.items()):
    mer_count = sum(len(inner) for inner in value)
    if mer_count == 0:
        print(f'{key} has no examples')
        continue
    mer_sum = sum(sum(inner) for inner in value)
    averaged_mer_dict[key] = mer_sum/mer_count

In [None]:
# A dedicated name for the output file stem: reusing `filename` would replace
# the HDF5 input path, so re-running the extraction cell after this one would
# try to read from the wrong file.
output_name = f'{mer_size}mer_values_range{reads_range}'
save_obj(averaged_mer_dict, output_name)

# Statistics

### Find the k-mers with the highest/lowest signal spread around their mean

In [None]:
# Mean absolute deviation of each k-mer's signal samples from that k-mer's
# average value. k-mers without an average (no samples) are skipped.
mean_distanct_mer_dict = {}
for key, value in tqdm(mer_dict.items()):
    if key not in averaged_mer_dict:
        continue
    mean = averaged_mer_dict[key]
    flat_lst = flatten_(value)
    # Normalise by the sample count: a plain sum grows with how often a k-mer
    # occurs, so the ranking would reflect frequency rather than spread.
    # flat_lst is non-empty here because only k-mers with samples get an average.
    mean_distanct_mer_dict[key] = sum(abs(x-mean) for x in flat_lst)/len(flat_lst)

In [None]:
# k-mers ordered from the largest to the smallest deviation score.
sorted_dict = sorted(
    mean_distanct_mer_dict.keys(),
    key=lambda kmer: mean_distanct_mer_dict[kmer],
    reverse=True,
)

In [None]:
## BEST
# First entry of the descending ranking, i.e. the largest deviation score.
key = sorted_dict[0]
lst = mer_dict[key]
# Pool the samples of all windows of this k-mer into one flat list.
flat_list = [sample for window in lst for sample in window]
plt.hist(flat_list)
print(key, averaged_mer_dict[key])

In [None]:
## WORST
# Last entry of the descending ranking, i.e. the smallest deviation score.
key = sorted_dict[-1]
lst = mer_dict[key]
# Pool the samples of all windows of this k-mer into one flat list.
flat_list = [sample for window in lst for sample in window]
plt.hist(flat_list)
print(key, averaged_mer_dict[key])

### Validating whether the reverse k-mer has different output

In [None]:
# Absolute difference between the average signal of each k-mer and the
# average signal of the same k-mer read backwards. Pairs where the reversed
# k-mer has no average are skipped.
reverse_diff = {}
for key in averaged_mer_dict.keys():
    reverse = key[::-1]
    if reverse not in averaged_mer_dict:
        continue
    diff = abs(averaged_mer_dict[key]-averaged_mer_dict[reverse])
    reverse_diff[key] = diff
# Pass a concrete list rather than a dict view to matplotlib.
plt.hist(list(reverse_diff.values()))

sorted_reverse_dict = sorted(reverse_diff, key=reverse_diff.get, reverse=True)
key = sorted_reverse_dict[0]
# Derive the partner directly. Taking the second-ranked entry only works when
# no other pair ties for the top difference and the top k-mer is not a
# palindrome.
reverse_key = key[::-1]
print(key, reverse_key, reverse_diff[key])

In [None]:
## BIGGEST DIFF
# Histogram of all samples of the k-mer selected in the previous cell.
lst = mer_dict[key]
flat_list = [sample for window in lst for sample in window]
plt.hist(flat_list)
print(key, averaged_mer_dict[key])

In [None]:
# Histogram of all samples of the partner k-mer (`reverse_key`).
lst = mer_dict[reverse_key]
flat_list = flatten_(lst)
plt.hist(flat_list)
# Report the k-mer that is actually plotted; printing `key` here would label
# this histogram with the other k-mer's name and average.
print(reverse_key, averaged_mer_dict[reverse_key])

### Is the throughput dependent on the current k-mer?

In [None]:
# Mean window length (number of signal samples per occurrence) for each k-mer.
lengths = {}
for key, value in tqdm(mer_dict.items()):
    current_lengths = [len(x) for x in value]
    # k-mers that never occurred have no windows; skip them instead of
    # dividing by zero, mirroring how the averaging cell treats them.
    if not current_lengths:
        print(f'{key} has no examples')
        continue
    mean = sum(current_lengths)/len(current_lengths)
    lengths[key] = mean
# Pass a concrete list rather than a dict view to matplotlib.
plt.hist(list(lengths.values()))

# k-mers ordered from the longest to the shortest mean window length.
sorted_lengths_dict = sorted(lengths, key=lengths.get, reverse=True)