This notebook visualizes how the obtained models compare to each other, and visualizes the rules that were used to create them.

In [None]:
import pandas as pd

def load_rules(motif,mod,data_dir='./data'):
    """Load all rule files of one (motif, modification) experiment into one DataFrame.

    Reads ``<data_dir>/rules_<motif>_<mod>/p<p>_<i>.txt`` for every p in
    (5, 10, 25, 50, 75, 90) and every replicate i in 0..4. The first three
    lines of each file are skipped. Each remaining non-blank line must be
    ``<match>\\t<value>``.

    Args:
        motif: motif name used in the directory name, e.g. 'CpG' or 'GpC'.
        mod: modification name used in the directory name, e.g. 'meth'.
        data_dir: directory that contains the ``rules_<motif>_<mod>`` folders.

    Returns:
        DataFrame with one row per rule and these columns:
        ``p`` (90 relabelled as 85), ``i``, ``match``, ``value`` (negated),
        ``mod_pos`` (text before the first 'M' of ``match``),
        ``base_tmp`` (text after it; '' for modification-only rules),
        ``base_pos`` / ``base`` (1st / 2nd character of ``base_tmp``,
        NaN when ``base_tmp`` is empty).
    """
    all_rules_list = []
    for p in [5,10,25,50,75,90]:
        for i in range(5):
            path = data_dir+"/rules_"+motif+"_"+mod+"/p"+str(p)+'_'+str(i)+".txt"
            with open(path, "r") as f:
                lines = f.readlines()
            # Skip the 3-line header and any blank line (e.g. a trailing newline).
            # A blank line has no value field and would raise IndexError.
            fields = [x.strip().split('\t') for x in lines[3:] if x.strip()]
            # dict() keeps the last value if a rule key is repeated in a file.
            rules = dict((x[0], float(x[1])) for x in fields)
            all_rules_list.extend([p, i, match, value] for match, value in rules.items())

    df_rules = pd.DataFrame(all_rules_list, columns=['p','i','match','value'])
    # The p90 runs are labelled 85, as elsewhere in this notebook.
    # NOTE(review): the reason for the relabel is not visible here.
    df_rules['p'] = df_rules['p'].replace({90: 85})

    # Keys are presumably '<mod_pos>M<base_pos><base>' (e.g. '2M3A') or
    # '<mod_pos>M' for modification-only rules (cf. load_one_ruleset's parsing).
    df_rules[['mod_pos', 'base_tmp']] = df_rules['match'].str.split('M', n=1, expand=True)
    df_rules['base_pos'] = df_rules['base_tmp'].str[0]
    df_rules['base'] = df_rules['base_tmp'].str[1]

    df_rules['value'] = -df_rules['value']
    return df_rules



In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Matplotlib cycle colour per base; 'M' (modified base) is grey.
base_colors = dict(
    A='C2',
    C='C0',
    G='C1',
    T='C3',
    M='C7',
)


def plot_rules(df_rules,motif,mod):
    """Plot a 6x6 grid of rule-effect boxplots for one (motif, mod) rule set.

    Rows are the training percentages 5, 10, 25, 50, 75, 85 (the ``p`` column
    of `df_rules`). Column 0 shows modification-only rules (``base_tmp == ''``)
    per mod position. Columns 1-5 show the rules for mod positions 0-4
    (1-5 for 'GpC'), per base position, coloured by base using `base_colors`.
    The figure is saved to ``./data/rules_<motif>_<mod>.pdf``.

    Args:
        df_rules: rule table as returned by `load_rules`; it is not modified.
        motif: motif name. Used in the title and file name, and to offset
            mod positions.
        mod: modification name, used in the title and file name.
    """
    percentages = [5, 10, 25, 50, 75, 85]
    last_row = len(percentages) - 1

    sns.set(rc={'figure.figsize':(24,24)})
    fig, axs = plt.subplots(ncols=6,nrows=len(percentages),sharey='row',sharex='col',
                            gridspec_kw={'width_ratios': [1]+[3]*5})

    # Sort with a fixed base order. assign() returns a copy, so the caller's
    # DataFrame is not mutated.
    df_rules = (
        df_rules
        .assign(base=pd.Categorical(df_rules['base'], categories=['A','C','G','T']))
        .sort_values(['base_pos','base','mod_pos'])
    )

    # Columns 1-5: base effects, one panel per mod position.
    for row, p in enumerate(percentages):
        df_p = df_rules[df_rules['p'] == p]
        for j in range(5):
            my_ax = axs[row, 1+j]
            # GpC panels use mod positions 1-5 instead of 0-4.
            # NOTE(review): presumably because the modified C follows the G — confirm.
            mod_pos = j+1 if motif == 'GpC' else j
            sns.boxplot(data=df_p[df_p['mod_pos'] == str(mod_pos)], x="base_pos", y="value",
                        hue='base', ax=my_ax, palette=base_colors)

            # White separators between the four base-position groups, plus a
            # dashed black line at x=j-0.5.
            # NOTE(review): the dashed line presumably marks where the modified
            # dinucleotide sits among the base positions — confirm.
            for k in range(1,4):
                my_ax.axvline(x=k-0.5, color='white', linestyle='-', linewidth=1)
            my_ax.axvline(x=j-0.5, color='k', linestyle='--', linewidth=1)

            my_ax.title.set_text(None)
            my_ax.set_xlabel(None)
            my_ax.set_ylabel(None)
            if row == 0:
                my_ax.title.set_text('Base effect, mod position='+str(mod_pos))
            if row == last_row:
                my_ax.set_xlabel('Base position')
            my_ax.set_xlim((-0.55,3.55))

    # Column 0: modification-only effects per mod position.
    df_mod = df_rules[df_rules['base_tmp']=='']
    for row, p in enumerate(percentages):
        my_ax = axs[row, 0]
        sns.boxplot(data=df_mod[df_mod['p']==p], x="mod_pos", y="value", ax=my_ax, color='C7')

        my_ax.title.set_text(None)
        my_ax.set_xlabel(None)
        if row == 0:
            my_ax.title.set_text('Mod effect')
        if row == last_row:
            my_ax.set_xlabel('Mod position')
        my_ax.set_ylabel(r'$\Delta$'+'C rules at T='+str(p))

    plt.suptitle('Rules for: '+motif+' '+mod)
    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.0, 1, 0.985])
    plt.savefig('./data/rules_'+motif+'_'+mod+'.pdf')

    # df_base

In [None]:
all_sets = [('CpG','meth'),('GpC','meth'),('CpG','gluc')]
rule_sets = {}
for motif, mod in all_sets:
    # One rule table per (motif, modification) experiment, keyed by the pair.
    rule_sets[motif, mod] = load_rules(motif, mod)

In [None]:
# df_rules.groupby(['mod_pos','base_pos','base']).mean()

# sns.set(rc={'figure.figsize':(24,24),'text.usetex':False})
for motif, mod in all_sets:
    # Each call draws and saves one 6x6 panel grid.
    plot_rules(df_rules=rule_sets[motif, mod], motif=motif, mod=mod)


In [None]:
all_sets = [('CpG','meth'),('GpC','meth'),('CpG','gluc')]

sns.set(rc={'figure.figsize':(18,6)})
# One panel per rule set. squeeze=False keeps `axs` 2-D, so the indexing
# below also works if `all_sets` has a single entry.
fig, axs = plt.subplots(ncols=len(all_sets),nrows=1,sharey='row',sharex='col',squeeze=False)
for i,(motif,mod) in enumerate(all_sets):
    my_ax = axs[0, i]
    # Number of rules in each trained model, i.e. per (p, i) pair.
    matches = rule_sets[(motif,mod)].groupby(['p','i'])['match'].count().reset_index()
    sns.boxplot(data=matches, x="p", y="match",ax=my_ax)

    my_ax.title.set_text('Number of rules for '+motif+' '+mod+' models')
    my_ax.set_xlabel('Percentage of 6-mers in training')
    my_ax.set_ylabel('Count')
plt.tight_layout()

plt.savefig('./data/rules_count.pdf')

In [None]:
import numpy as np

# Base order of the four per-position columns in ruleset tables.
order = "ACGT"

def load_one_ruleset(file_path, mer_size=6):
    """Load one rule file into a (mod position x base-position/base) table.

    The first three lines of the file are skipped. Each remaining non-blank
    line must be ``<key>\\t<value>``. Keys are parsed as ``<mod_pos>M``
    (stored in column 'M') or ``<mod_pos>M<base_pos><base>`` (stored in
    column ``'<base_pos+1><base>'``); positions may have several digits.
    Values are negated, and cells without a rule are NaN. The number of
    rules read is printed.

    Args:
        file_path: path of the rule file.
        mer_size: k-mer length. The table has ``mer_size - 1`` rows and
            ``4 * mer_size + 1`` columns.

    Returns:
        DataFrame with columns ``['M', '1A', '1C', '1G', '1T', '2A', ...]``.
    """
    with open(file_path, "r") as f:
        lines = f.readlines()
    # Skip the 3-line header and blank lines, which have no value field.
    fields = [x.strip().split('\t') for x in lines[3:] if x.strip()]
    rules = dict((x[0], float(x[1])) for x in fields)
    print(len(rules))

    table = np.full((mer_size-1, 4*mer_size+1), np.nan)
    for key, value in rules.items():
        # Use partition() rather than single-character indexing so that
        # positions >= 10 are parsed correctly.
        mod_part, _, base_part = key.partition('M')
        row = int(mod_part)
        col = 0
        if base_part:
            bpos = int(base_part[:-1])
            col = bpos * 4 + order.index(base_part[-1]) + 1
        table[row, col] = -value

    colnames = ['M']+[str(p)+b for p in range(1,mer_size+1) for b in order]
    return pd.DataFrame(table, columns=colnames)


def plot_ruleset(ruleset, set_name, mer_size=6):
    """Draw an annotated heatmap of a ruleset table from `load_one_ruleset`.

    Draws on the current matplotlib axes. Rows are mod positions. Columns are
    the 'M' column followed by four columns per base position. In each row,
    the two 4-column blocks starting at the mod position are labelled 'M'
    and 'G'. NaN cells outside those blocks are drawn on a neutral
    background and marked with '<'.

    Args:
        ruleset: DataFrame of shape (mer_size-1, 4*mer_size+1).
        set_name: used in the title as '<set_name> ruleset'.
        mer_size: k-mer length. Controls the y-ticks, the M/G labels and
            the block separators.
    """
    sns.set(rc={'axes.facecolor':'white'})
    values = np.array(ruleset)

    # argwhere on the transpose yields (column, row) pairs. In each row,
    # columns 1+4*row .. 1+4*row+7 are the blocks labelled 'M' and 'G' below.
    missing = [
        (col, row) for col, row in np.argwhere(np.isnan(values.T))
        if not 1 + row*4 <= col < 1 + row*4 + 8
    ]

    background = np.full(values.shape, np.nan)
    for col, row in missing:
        background[row, col] = 0

    sns.heatmap(background,center = 0,annot=False,cbar=False, cmap='coolwarm', vmin=-1, vmax=1)
    sns.heatmap(ruleset,center = 0,annot=True, fmt=".2f",cmap='coolwarm')
    plt.title(set_name + ' ruleset')
    plt.xlabel('Base position and base')
    plt.ylabel('Modification position')
    plt.yticks(np.arange(mer_size-1)+0.5,range(1,mer_size))

    # Thick separators at the start of each base-position block (x = 1, 5, 9, ...).
    # Derived from mer_size instead of a fixed 9, so 6-mer tables do not get
    # lines past their last column.
    for x in range(1, 4*mer_size+1, 4):
        plt.axvline(x=x, color='gray',linewidth=4)

    # Label the centre of the modified-base block ('M') and the block after it ('G').
    for m_pos in range(mer_size-1):
        for label, x in (('M', m_pos*4 + 3), ('G', m_pos*4 + 7)):
            plt.text(x, m_pos + 0.5, label,
                horizontalalignment='center',
                verticalalignment='center',
                fontsize=16,
                color="black",
                weight="bold",
                rotation=0)

    for col, row in missing:
        plt.text(col+0.5, row+0.5, '<',
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=12,
            color="black",
            rotation=0)

def mark_kmer(kmer):
    """Overlay a white '*' on each ruleset-heatmap cell that `kmer` matches.

    The heatmap row is the index of the first 'MG' in `kmer`. The modified
    base is marked in the 'M' column (column 0), and the G right after it is
    skipped. Every other base at position i is marked in column
    ``i*4 + order.index(base) + 1``.
    """
    mod_row = kmer.index('MG')
    for pos, base in enumerate(kmer):
        if pos == mod_row + 1:
            continue
        column = 0 if pos == mod_row else pos * 4 + order.index(base) + 1
        plt.text(column + 0.5, mod_row + 0.3, '*',
                 horizontalalignment='center',
                 verticalalignment='center',
                 fontsize=16,
                 color="white",
                 rotation=0)

# Default nanopolish R9 CpG 6-mer ruleset.
sns.set(rc={'figure.figsize': (24, 3.8)})
nanopolish_ruleset = load_one_ruleset(file_path="data/rules_CpG_r9_nanopolish")
plot_ruleset(ruleset=nanopolish_ruleset, set_name='Nanopolish R9 CpG')
plt.savefig('./data/nanopolish_ruleset.pdf')
# annotate_row(cols=[0,3,6,9,],row=3)

        
# mark_kmer('GTAMGC')


# dict_main, dict_alt, dict_main_sd, list_diff = load_polishmodel('../nanopolish/nanopolish/etc/r9-models/r9.4_450bps.cpg.6mer.template.model')
# list_kmers = [x[0] for x in list_diff]
# list_diff = [x[1] for x in list_diff]
# list_alt = [dict_alt[x] for x in list_kmers]
# list_main = [dict_main[x.replace('M','C')] for x in list_kmers]
# df_model = pd.DataFrame({'k-mer':list_kmers,'Canonical':list_main,'Modified':list_alt,'Difference':list_diff})
# df_model.set_index('k-mer',inplace=True)
# print(df_model)
# df_model.loc[['MGAAAA','MGATGG']+sorted(['AACMGC','ACTMGA','CTGMGT','TTGMGT','GTAMGT','TGCMGC'])+['GATCMG','TTGTMG']]
# # plt.figure()
# # full_lambda_ruleset = load_one_ruleset("data/rules_CpG_r9_full_lambda")
# # plot_ruleset(full_lambda_ruleset, set_name='Full Lambda R9 CpG')
# df_model.loc[['GTAMGC']]

# nanopolish_ruleset

In [None]:

# R10 9-mer ruleset; mer_size must match the k-mer length of the rule file.
nanopolish_ruleset_R10 = load_one_ruleset(file_path="./data/f5c/rq_rerun.rules", mer_size=9)
plot_ruleset(ruleset=nanopolish_ruleset_R10, set_name='Nanopolish R10 CpG', mer_size=9)
plt.savefig('./data/nanopolish_ruleset_r10.pdf')

In [None]:

# R10 9-mer ruleset from the "85" rerun.
nanopolish_ruleset_R10_85 = load_one_ruleset(file_path="./data/f5c/rq_rerun_85.rules", mer_size=9)
plot_ruleset(ruleset=nanopolish_ruleset_R10_85, set_name='Nanopolish R10 CpG, 85 rules', mer_size=9)

In [None]:
# # df_model
# less_order = 'ATC'
# small_model = {}
# for fb in less_order:
#     for sb in less_order:
#         if fb+sb == 'CG':
#             continue
#         # small_model['M'+fb+sb] = df_model[df_model.index.str[2:6] == 'MG'+fb+sb].median()
#         small_model[fb+'M'+sb] = df_model[df_model.index.str[2:6] == fb+'MG'+sb].median()
#         # small_model[fb+sb+'M'] = df_model[df_model.index.str[2:6] == fb+sb+'MG'].median()

# small_model = pd.DataFrame(small_model).T

# small_model['mod_index'] = small_model.index.str.index('M')

# small_model.sort_values('mod_index',inplace=True)
# pd.set_option('display.precision', 1)
# small_model


In [None]:
# Model loading functions

def load_polishmodel(model_path,mod_motif='CG',mod_motif_M='MG'):
    """Read a nanopolish-style pore model and compute modified-minus-canonical levels.

    Lines starting with '#', blank lines and the 'kmer' header line are
    skipped. On each remaining line, column 1 is read as the level mean and
    column 2 as the level standard deviation.

    The modified base is assumed to be written as 'M'. A k-mer without any
    'M' is canonical. A k-mer that contains 'M' is kept as modified only if
    it contains `mod_motif_M`, contains no unmodified `mod_motif`, and every
    'M' belongs to a `mod_motif_M` occurrence. All other 'M'-containing
    k-mers are ignored.

    Args:
        model_path: path of the model file.
        mod_motif: canonical motif, e.g. 'CG'.
        mod_motif_M: the same motif with the modified base, e.g. 'MG'.

    Returns:
        Tuple (dict_main, dict_alt, dict_main_sd, list_diff):
        dict_main    -- canonical k-mer -> level mean
        dict_alt     -- modified k-mer -> level mean
        dict_main_sd -- k-mer -> level standard deviation, for every kept
                        k-mer (canonical and modified)
        list_diff    -- sorted list of (modified k-mer, alt mean - canonical
                        mean). The canonical k-mer is obtained by replacing
                        every `mod_motif_M` with `mod_motif`.

    Raises:
        KeyError: if a modified k-mer's canonical counterpart is not in the file.
    """
    dict_main = {}
    dict_main_sd = {}
    dict_alt = {}
    with open(model_path) as fp:
        for line in fp:
            splitline = line.split()
            # Skip blank lines, comments and the column-header line.
            if not splitline or line[0] == '#' or splitline[0] == 'kmer':
                continue
            kmer = splitline[0]

            if 'M' in kmer:
                # A modified k-mer must contain the modified motif and no
                # unmodified copy of it. A k-mer with both could not be trained
                # as a clean modification k-mer.
                if mod_motif_M not in kmer or mod_motif in kmer:
                    continue
                canon = kmer.replace(mod_motif_M,mod_motif)
                # An M left over is not part of a proper modification motif.
                if 'M' in canon:
                    continue
                dict_alt[kmer] = float(splitline[1])
            else:
                # No 'M' at all: a canonical k-mer.
                dict_main[kmer] = float(splitline[1])
            dict_main_sd[kmer] = float(splitline[2])

    list_diff = [
        (kmer, value - dict_main[kmer.replace(mod_motif_M,mod_motif)])
        for kmer, value in dict_alt.items()
    ]
    return dict_main, dict_alt, dict_main_sd, sorted(list_diff)


def load_all_models(impute='replaced'):
    """Collect [label, list_diff] pairs for the model comparison cells.

    The list always starts with the default nanopolish R9.4 CpG model and the
    round-4 full-lambda model (both labelled "not-imputed"). When `impute` is
    'replaced', each of them is followed by its '.replaced' counterpart,
    labelled "fully-imputed". Next come the requant models of every
    refcuts_CG_test_2 run p<p>_<i> (p from 90 down to 5, i in 0..4, p=90
    labelled T=85), read from '<run>.<impute>.model'. The tag is
    "fully-imputed" for 'replaced' and "missing-imputed" otherwise.
    """
    fully_imputed = impute == 'replaced'
    tag = 'fully-imputed' if fully_imputed else 'missing-imputed'

    def diff_of(path):
        # load_polishmodel returns (main, alt, main_sd, list_diff); keep only the diff.
        return load_polishmodel(path)[3]

    models = [['nanopolish, not-imputed',
               diff_of('../../nanopolish/nanopolish/etc/r9-models/r9.4_450bps.cpg.6mer.template.model')]]
    if fully_imputed:
        models.append(['nanopolish, ' + tag,
                       diff_of('./data/debug_model_out.nanopolish.' + impute + '.model')])

    models.append(['full lambda, not-imputed',
                   diff_of('./data/output_CpG_meth/event_align/lambda_phage_methylated_cpg/nanopolish_train/r9.4_450bps.cpg.6mer.template.round4.model')])
    if fully_imputed:
        models.append(['full lambda, ' + tag,
                       diff_of('./data/output_CpG_meth/event_align/lambda_phage_methylated_cpg/requant/full_lambda.' + impute + '.model')])

    for p in (90, 75, 50, 25, 10, 5):
        label_p = 85 if p == 90 else p
        for i in range(5):
            run = 'p' + str(p) + '_' + str(i)
            path = './data/refcuts_CG_test_2/' + run + '/requant/' + run + '.' + impute + '.model'
            models.append(['T=' + str(label_p) + ':' + str(i) + ' ' + tag, diff_of(path)])
    return models


def load_np_models():
    """Collect [label, list_diff] pairs for the not-imputed models only.

    The list holds the default nanopolish R9.4 CpG model, the round-4
    full-lambda model, and then the round-4 nanopolish_train model of every
    refcuts_CG_test_2 run p<p>_<i> (p from 90 down to 5, i in 0..4, p=90
    labelled T=85).
    """
    def diff_of(path):
        # load_polishmodel returns (main, alt, main_sd, list_diff); keep only the diff.
        return load_polishmodel(path)[3]

    models = [
        ['nanopolish, not-imputed',
         diff_of('../../nanopolish/nanopolish/etc/r9-models/r9.4_450bps.cpg.6mer.template.model')],
        ['full lambda, not-imputed',
         diff_of('./data/output_CpG_meth/event_align/lambda_phage_methylated_cpg/nanopolish_train/r9.4_450bps.cpg.6mer.template.round4.model')],
    ]
    for p in (90, 75, 50, 25, 10, 5):
        label_p = 85 if p == 90 else p
        for i in range(5):
            run_dir = './data/refcuts_CG_test_2/p' + str(p) + '_' + str(i)
            models.append(['T=' + str(label_p) + ':' + str(i) + ' not-imputed',
                           diff_of(run_dir + '/nanopolish_train/r9.4_450bps.cpg.6mer.template.round4.model')])
    return models

In [None]:
# Heatmap for all fully-imputed models, and full default nanopolish and fully trained lambda model
all_models_dfs = [pd.DataFrame(diff, columns=['kmer', label]) for label, diff in load_all_models()]
# One column per model, indexed by modified k-mer.
dfs = [frame.set_index('kmer') for frame in all_models_dfs]
sns.set(rc={'figure.figsize': (12, 24)})
sns.heatmap(pd.concat(dfs, axis=1), cmap="bwr", center=0)
plt.savefig('./data/all_models_heatmap.pdf')

In [None]:
# Heatmap for all missing-imputed models, and full default nanopolish and fully trained lambda model
all_models_dfs = [pd.DataFrame(diff, columns=['kmer', label]) for label, diff in load_all_models(impute='added')]
# One column per model, indexed by modified k-mer.
dfs = [frame.set_index('kmer') for frame in all_models_dfs]
sns.set(rc={'figure.figsize': (12, 24)})
sns.heatmap(pd.concat(dfs, axis=1), cmap="bwr", center=0)

In [None]:
# Heatmap for all not-imputed models, and full default nanopolish and fully trained lambda model
all_models_dfs = [pd.DataFrame(diff, columns=['kmer', label]) for label, diff in load_np_models()]
# One column per model, indexed by modified k-mer.
dfs = [frame.set_index('kmer') for frame in all_models_dfs]
sns.set(rc={'figure.figsize': (12, 24)})
sns.heatmap(pd.concat(dfs, axis=1), cmap="bwr", center=0)

In [None]:
# Make a correlation plot of the fully imputed models

import numpy as np

all_models_dfs = [pd.DataFrame(diff, columns=['kmer', label]) for label, diff in load_all_models()]
dfs = [frame.set_index('kmer') for frame in all_models_dfs]

# Pairwise (Pearson, the DataFrame.corr default) correlation between models.
corr = pd.concat(dfs, axis=1).corr()

# Hide the upper triangle including the diagonal; it mirrors the lower one.
mask = np.triu(np.ones(corr.shape, dtype=bool))

f, ax = plt.subplots(figsize=(11, 9))
sns.heatmap(corr, mask=mask, ax=ax,
            square=True, linewidths=.5, cbar_kws={"shrink": .5})

plt.tight_layout()
plt.savefig('./data/all_models_correlation.pdf')

In [None]:
# Clustered heatmap of the fully-imputed models loaded in the previous cell (`dfs`).
df_fullimp = pd.concat(dfs, axis='columns')
sns.clustermap(df_fullimp, cmap="bwr", center=0)

