In [1]:
from tqdm import tqdm
from workflow.fastani.remove_gunc_failed_contigs_by_contamination_sp_cluster import RemoveGuncFailedContigsByContaminationSpCluster
from workflow.config import PCT_VALUES
from workflow.external.gtdb_metadata import GtdbMetadataR207Full, GtdbMetadataR207
from workflow.external.gtdb_sp_clusters import GtdbSpClustersR207
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm

In [2]:
# GTDB R207 species clusters, indexed by species name (see output below).
DF_SP_CLUSTERS = (
    GtdbSpClustersR207()
    .output()
    .read_cached()
)

In [3]:
DF_SP_CLUSTERS.head(5)

Unnamed: 0_level_0,rep_genome,taxonomy,ani_radius,ani_mean,ani_min,af_mean,af_min,n_genomes
species,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
s__0-14-0-10-38-17 sp002774085,GB_GCA_002774085.1,d__Bacteria;p__Patescibacteria;c__Microgenomat...,95.0,,,,,1
s__0-14-0-20-30-16 sp002779075,GB_GCA_002779075.1,d__Archaea;p__Iainarchaeota;c__Iainarchaeia;o_...,95.0,,,,,1
s__0-14-0-20-30-16 sp903916665,GB_GCA_903916665.1,d__Archaea;p__Iainarchaeota;c__Iainarchaeia;o_...,95.0,,,,,1
s__0-14-0-20-34-12 sp002779065,GB_GCA_002779065.1,d__Archaea;p__Iainarchaeota;c__Iainarchaeia;o_...,95.0,,,,,1
s__0-14-0-20-40-13 sp002774285,GB_GCA_002774285.1,d__Bacteria;p__Patescibacteria;c__WWE3;o__0-14...,95.0,,,,,1


In [4]:
# Results indexed by (gid, pct) -- see the rendered output below.
DF = (
    RemoveGuncFailedContigsByContaminationSpCluster()
    .output()
    .read_cached()
)
print(DF.shape)

# Level 0 of the index holds the genome accession.
UNQ_GIDS = {gid for gid in DF.index.get_level_values(0)}
print(f'{len(UNQ_GIDS):,} failed gids')

DF.head()

(578539, 5)
35,723 failed gids


Unnamed: 0_level_0,Unnamed: 1_level_0,new_sp_rep,ani,af,type,same
gid,pct,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
GCA_000143435.1,1,GCF_001435955.1,97.327,0.875591,sp_rep,True
GCA_000143435.1,5,GCF_001435955.1,97.327,0.875591,sp_rep,True
GCA_000143435.1,15,GCF_001435955.1,97.6349,0.829228,sp_rep,True
GCA_000153745.1,5,,,,no_ani,True
GCA_000155005.1,1,GCF_003697165.2,96.5203,0.827668,sp_rep,True


In [5]:
# NOTE: the saved output of this cell is a KeyboardInterrupt, i.e. this load was
# interrupted -- expect it to be slow. The plotting helpers below read DF_META.
DF_META = (
    GtdbMetadataR207Full()
    .output()
    .read_cached()
)
DF_META.head(5)

KeyboardInterrupt: 

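Values of the `type` column in `DF`:
- `no_af`: the alignment fraction (AF) is below 0.5
- `no_ani`: no ANI hit to an existing species representative (new_sp_rep/ani/af are empty), i.e. a novel species cluster
- `sp_rep`: a species representative was found within the ANI radius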


In [None]:
from workflow.util.paths import get_gid_root
from collections import defaultdict
import os
import pandas as pd
import numpy as np



# Contamination percentages included in this analysis. This rebinds (shadows)
# the PCT_VALUES imported from workflow.config in the first cell.
PCT_VALUES = [0, 1, 5, 10, 15, 20, 30, 40, 50]

# Discrete 'bright' palette with one colour per PCT value; do_single_plot
# indexes it modulo its length when colouring reference lines.
cmap = sns.color_palette(
    'bright',
    n_colors=len(PCT_VALUES),
    as_cmap=False,
)


def get_data_for_gid(gid):
    """Load and stack the two FastANI result tables stored for ``gid``.

    The pct-0 table ('fastani_gunc_failed_pct_0.h5') comes first, followed by the
    by-contamination table; the result gets a fresh RangeIndex.

    Returns:
        The combined DataFrame, or None (after printing ``missing: <gid>``)
        when either HDF5 file does not exist.
    """
    gid_root = get_gid_root(gid)
    file_names = ('fastani_gunc_failed_pct_0.h5', 'fastani_gunc_failed_by_contamination.h5')
    try:
        frames = [pd.read_hdf(os.path.join(gid_root, name)) for name in file_names]
    except FileNotFoundError:
        print(f"missing: {gid}")
        return None
    return pd.concat(frames, ignore_index=True)



# Lowest ANI kept for plotting; also the origin (0 rad) of the polar ANI axis.
MIN_ANI = 93


def xform(x):
    """Map an ANI percentage to a polar angle in radians.

    ``MIN_ANI`` maps to 0 and 100 maps to 1.5*pi (three quarters of a turn),
    linearly in between; values below ``MIN_ANI`` give negative angles.
    """
    ani_span = 100 - MIN_ANI
    return 1.5 * np.pi * (x - MIN_ANI) / ani_span


def do_single_plot(gid, df, out_dir='/tmp/guncplots'):
    """Draw a polar ANI-vs-%-removed plot for one genome and save it as SVG.

    The angular axis is ANI (``xform`` maps MIN_ANI..100 onto 0..1.5*pi) and the
    radial axis is the percentage of the genome removed (``pct``). Each reference
    genome compared against ``gid`` becomes one line:

    * references whose ANI reaches the species' ANI radius at any pct are drawn
      in colour (circle marker if the reference is in ``gid``'s GTDB species,
      diamond otherwise);
    * all other references are drawn as faint grey lines.

    For every pct, the highest-ANI reference at or above the radius is marked
    with an 'x': green if it is in ``gid``'s species, red otherwise.

    Args:
        gid: genome accession; must be a row label of DF_META.
        df: comparisons for ``gid`` with at least the columns 'reference',
            'pct', 'ani' and 'af' (as returned by get_data_for_gid).
        out_dir: directory the SVG is written to; created if missing.

    Returns:
        None. Prints a message and returns early if no comparisons survive
        the filters.
    """
    meta_row = DF_META.loc[gid]
    species = meta_row['species']
    cur_sp_radius = DF_SP_CLUSTERS.loc[species, 'ani_radius']

    # The representative accession carries a 3-character prefix (presumably
    # 'RS_'/'GB_', cf. DF_SP_CLUSTERS.rep_genome); strip it before comparing.
    is_rep = meta_row['gtdb_genome_representative'][3:] == gid

    # Drop the self-comparison, keep only the analysed pct values, and keep hits
    # with AF >= 0.5 whose ANI lies inside the plotted range.
    df = df[df['reference'] != gid]
    df = df[df['pct'].isin(PCT_VALUES)]
    df = df[(df['af'] >= 0.5) & (df['ani'] >= MIN_ANI)]

    if len(df) == 0:
        print(f'no info left for: {gid}')
        return

    # Sorting by (reference, pct) makes each reference's points line up in pct order.
    df = df.sort_values(by=['reference', 'pct'])

    # References that reach the species ANI radius at one or more pct values.
    reps_that_exceed_radius = frozenset(df[df['ani'] >= cur_sp_radius]['reference'])

    d_rep_to_values = defaultdict(list)
    d_rep_to_pct = defaultdict(list)
    # Per pct: best ANI at/above the radius, and whether that reference shares gid's species.
    pct_to_rep_point = dict()
    pct_to_rep_point_correct = dict()
    for _, row in df.iterrows():
        cur_ani = row['ani']
        cur_pct = row['pct']
        d_rep_to_values[row['reference']].append(cur_ani)
        d_rep_to_pct[row['reference']].append(cur_pct)

        if cur_ani >= cur_sp_radius and cur_ani > pct_to_rep_point.get(cur_pct, 0):
            pct_to_rep_point_correct[cur_pct] = DF_META.loc[row['reference'], 'species'] == species
            pct_to_rep_point[cur_pct] = cur_ani

    # Text artists (e.g. axis labels, title) resolve their font size when they are
    # created, so rcParams must be set BEFORE the figure is built; updating them
    # afterwards only affects the next figure.
    plt.rcParams.update({'font.size': 12, 'svg.fonttype': 'none'})
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={'projection': 'polar'})

    next_colour_i = 0
    for rep, ani_values in d_rep_to_values.items():
        pct_values = d_rep_to_pct[rep]
        theta = [xform(x) for x in ani_values]

        if rep in reps_that_exceed_radius:
            rep_species = DF_META.loc[rep, 'species']
            marker = 'o' if rep_species == species else 'D'
            ax.plot(theta, pct_values, '-', marker=marker, label=rep_species,
                    color=cmap[next_colour_i % len(cmap)], zorder=20)
            next_colour_i += 1
        else:
            ax.plot(theta, pct_values, '-', marker='o', alpha=0.1, color='gray')

    # Mark the best representative hit at each pct (green = same species, red = different).
    for cur_pct, cur_ani in pct_to_rep_point.items():
        colour = 'g' if pct_to_rep_point_correct[cur_pct] else 'r'
        ax.scatter(xform(cur_ani), cur_pct, color=colour, marker='x', zorder=100, alpha=1, s=120)

    ax.vlines(x=xform(cur_sp_radius), ymin=-10, ymax=70, color='r',
              linestyle='--', alpha=0.4, label=f'ANI Radius: {cur_sp_radius:.2f}%')

    ani_ticks = list(range(MIN_ANI, 101))
    radial_ticks = [-10] + PCT_VALUES

    ax.grid(True)
    ax.set_xlim([0, np.pi * 1.5])
    ax.set_ylim([-10, max(PCT_VALUES) + 5])
    ax.set_theta_direction(-1)

    # Fix the tick positions first, then attach labels to exactly those ticks.
    ax.set_xticks([xform(i) for i in ani_ticks])
    ax.set_xticklabels([str(i) for i in ani_ticks])

    ax.set_yticks(radial_ticks)
    ax.set_yticklabels([str(i) if i not in {-10, 0} else '' for i in radial_ticks])
    ax.set_ylabel('% ANI')

    # Manual label for the radial (pct) axis.
    ax.text(x=-0.35, y=12, s='% of genome removed')

    if is_rep:
        title_text = f'{gid} (species representative)\n{meta_row["species"]}'
    else:
        title_text = f'{gid}\n{meta_row["species"]}'
    ax.set_title(title_text)
    ax.legend()

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{gid}_rep_{is_rep}.svg"))
    # Close this figure explicitly so looping over many genomes does not leak figures.
    plt.close(fig)

    print(gid)
    print('^^^^^^')
    return




def gen_data(only_curated=False):
    """Render a polar plot for every genome with a ``same == False`` row.

    A genome is selected when at least one of its rows in DF has ``same`` equal
    to False at a pct in PCT_VALUES (presumably: the assigned species changed
    after contig removal). Genomes whose ANI files are missing are skipped.

    Args:
        only_curated: if True, plot only the hand-picked accessions in
            GIDS_TO_KEEP below; the default (False) plots every selected genome.

    Returns:
        None; plots are written to disk by do_single_plot.
    """
    # Some were originally run with more pct values, but just reduce to the analysis set
    df_changed = DF[DF['same'].eq(False)]
    df_changed = df_changed[df_changed.index.get_level_values(1).isin(PCT_VALUES)]

    gids_to_check = frozenset(df_changed.index.get_level_values(0))
    print(f'{len(gids_to_check):,} gids to check')

    # Hand-picked genomes of interest; used only when only_curated=True.
    GIDS_TO_KEEP = {'GCF_015643835.1', 'GCA_015257755.1', 'GCF_900509435.1', 'GCA_013213925.1', 'GCA_902560595.1', 'GCA_017852475.1', 'GCA_900761595.1', 'GCA_018239885.1', 'GCA_900759145.1', 'GCA_903931905.1', 'GCA_017515185.1', 'GCA_008668795.1', 'GCA_018056875.1', 'GCA_902593295.1', 'GCA_018662785.1', 'GCA_002731855.1', 'GCA_009619015.1', 'GCF_000698005.1', 'GCA_017394825.1', 'GCA_902528895.1', 'GCA_900763945.1', 'GCA_008668585.1', 'GCA_900759525.1', 'GCA_903846615.1', 'GCA_002722235.1', 'GCA_905201125.1', 'GCA_001509115.1', 'GCA_900760075.1', 'GCA_905214645.1', 'GCA_905200745.1', 'GCA_017465765.1', 'GCA_905208535.1', 'GCA_900555595.1', 'GCA_009493725.1', 'GCA_900761055.1', 'GCA_011523145.1', 'GCA_900765305.1', 'GCF_002883995.1'}

    # Sorted so that repeated runs produce plots and log lines in the same order
    # (frozenset iteration order is not stable across interpreter runs).
    for gid in sorted(gids_to_check):
        if only_curated and gid not in GIDS_TO_KEEP:
            continue

        gid_df = get_data_for_gid(gid)
        if gid_df is None:
            continue

        do_single_plot(gid, gid_df)


# data = gen_data()


In [None]:
from workflow.util.paths import get_gid_root
from collections import defaultdict
import os
import pandas as pd
import numpy as np



# Analysed contamination percentages (shadows the PCT_VALUES imported from
# workflow.config in the first cell).
PCT_VALUES = [0, 1, 5, 10, 15, 20, 30, 40, 50]

# One 'bright' colour per PCT value; cycled modulo its length by do_single_plot.
n_palette_colours = len(PCT_VALUES)
cmap = sns.color_palette('bright', n_colors=n_palette_colours, as_cmap=False)


def get_data_for_gid(gid):
    """Return the stacked pct-0 and by-contamination ANI tables for ``gid``.

    Both HDF5 files are read from the genome's root directory and concatenated
    (pct-0 first) with a new RangeIndex. If either file is missing, prints
    ``missing: <gid>`` and returns None.
    """
    root = get_gid_root(gid)
    parts = []
    try:
        for fname in ('fastani_gunc_failed_pct_0.h5', 'fastani_gunc_failed_by_contamination.h5'):
            parts.append(pd.read_hdf(os.path.join(root, fname)))
    except FileNotFoundError:
        print(f"missing: {gid}")
        return None
    return pd.concat(parts, ignore_index=True)



# ANI values below this are filtered out before plotting.
MIN_ANI = 93


def xform(x):
    """Convert ANI (%) to a polar angle: MIN_ANI -> 0 rad, 100 -> 1.5*pi rad, linear between."""
    full_sweep = 1.5 * np.pi
    return full_sweep * (x - MIN_ANI) / (100 - MIN_ANI)


def do_single_plot(gid, df, out_dir='/tmp/guncplots'):
    """Plot ANI (y) against % of genome removed (x) for one genome and save it as SVG.

    Each reference genome compared against ``gid`` becomes one line:

    * references whose ANI reaches the species' ANI radius at any pct are drawn
      in colour (circle marker if the reference is in ``gid``'s GTDB species,
      diamond otherwise) and their pct/ANI values are printed;
    * all other references are drawn as faint grey lines.

    For every pct, the highest-ANI reference at or above the radius is marked
    with a large circle: green if it is in ``gid``'s species, red otherwise.
    The SVG name encodes the representative flag and the NCBI genome category.

    Args:
        gid: genome accession; must be a row label of DF_META.
        df: comparisons for ``gid`` with at least the columns 'reference',
            'pct', 'ani' and 'af' (as returned by get_data_for_gid).
        out_dir: directory the SVG is written to; created if missing.

    Returns:
        None. Prints a message and returns early if no comparisons survive
        the filters.
    """
    meta_row = DF_META.loc[gid]
    species = meta_row['species']
    cur_sp_radius = DF_SP_CLUSTERS.loc[species, 'ani_radius']

    # The representative accession carries a 3-character prefix (presumably
    # 'RS_'/'GB_', cf. DF_SP_CLUSTERS.rep_genome); strip it before comparing.
    is_rep = meta_row['gtdb_genome_representative'][3:] == gid

    # Drop the self-comparison, keep only the analysed pct values, and keep hits
    # with AF >= 0.5 and ANI >= MIN_ANI.
    df = df[df['reference'] != gid]
    df = df[df['pct'].isin(PCT_VALUES)]
    df = df[(df['af'] >= 0.5) & (df['ani'] >= MIN_ANI)]

    if len(df) == 0:
        print(f'no info left for: {gid}')
        return

    # Sorting by (reference, pct) makes each reference's points line up in pct order.
    df = df.sort_values(by=['reference', 'pct'])

    # References that reach the species ANI radius at one or more pct values.
    reps_that_exceed_radius = frozenset(df[df['ani'] >= cur_sp_radius]['reference'])

    d_rep_to_values = defaultdict(list)
    d_rep_to_pct = defaultdict(list)
    # Per pct: best ANI at/above the radius, and whether that reference shares gid's species.
    pct_to_rep_point = dict()
    pct_to_rep_point_correct = dict()
    for _, row in df.iterrows():
        cur_ani = row['ani']
        cur_pct = row['pct']
        d_rep_to_values[row['reference']].append(cur_ani)
        d_rep_to_pct[row['reference']].append(cur_pct)

        if cur_ani >= cur_sp_radius and cur_ani > pct_to_rep_point.get(cur_pct, 0):
            pct_to_rep_point_correct[cur_pct] = DF_META.loc[row['reference'], 'species'] == species
            pct_to_rep_point[cur_pct] = cur_ani

    # Text artists (e.g. axis labels, title) resolve their font size when they are
    # created, so rcParams must be set BEFORE the figure is built; updating them
    # afterwards only affects the next figure.
    plt.rcParams.update({'font.size': 12, 'svg.fonttype': 'none'})
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    next_colour_i = 0
    for rep, ani_values in d_rep_to_values.items():
        pct_values = d_rep_to_pct[rep]

        if rep in reps_that_exceed_radius:
            rep_species = DF_META.loc[rep, 'species']
            marker = 'o' if rep_species == species else 'D'
            ax.plot(pct_values, ani_values, '-', marker=marker, label=rep_species,
                    color=cmap[next_colour_i % len(cmap)], zorder=20)
            next_colour_i += 1

            print(pct_values)
            print(ani_values)
            print(f'----------^^^  {rep_species} ^^^-----------')
        else:
            ax.plot(pct_values, ani_values, '-', marker='o', alpha=0.1, color='gray')

    # Mark the best representative hit at each pct (green = same species, red = different).
    # zorder=1 puts these enlarged circles behind the reference lines (zorder 2 and 20).
    for cur_pct, cur_ani in pct_to_rep_point.items():
        colour = 'g' if pct_to_rep_point_correct[cur_pct] else 'r'
        ax.scatter(cur_pct, cur_ani, color=colour, marker='o', zorder=1, s=120)

    ax.hlines(y=cur_sp_radius, xmin=-10, xmax=70, color='r',
              linestyle='--', alpha=0.4, label=f'ANI Radius: {cur_sp_radius:.2f}%')

    ax.grid(True)
    ax.set_xlim(-1, max(PCT_VALUES) + 1)
    # NOTE(review): rows are kept down to MIN_ANI (93) but the view starts at 94,
    # so ANI values in [93, 94) are drawn off-axis -- confirm this zoom is intended.
    ax.set_ylim([94, 100])

    # Fix the tick positions first, then attach labels to exactly those ticks.
    ax.set_xticks(list(PCT_VALUES))
    ax.set_xticklabels([str(i) for i in PCT_VALUES])
    ax.set_xlabel('% of genome removed')
    ax.set_ylabel('% ANI')

    if is_rep:
        title_text = f'{gid} (species representative)\n{meta_row["species"]}'
    else:
        title_text = f'{gid}\n{meta_row["species"]}'
    ax.set_title(title_text)
    ax.legend()

    ncbi_cat = meta_row['ncbi_genome_category']

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{gid}_rep_{is_rep}_{ncbi_cat}.svg"))
    # Close this figure explicitly so looping over many genomes does not leak figures.
    plt.close(fig)

    print(gid)
    print(ncbi_cat)
    print('^^^^^^')
    return




def gen_data(gids_to_keep=None):
    """Render a Cartesian plot for selected genomes with a ``same == False`` row.

    A genome is eligible when at least one of its rows in DF has ``same`` equal
    to False at a pct in PCT_VALUES (presumably: the assigned species changed
    after contig removal). Only eligible genomes that are also in
    ``gids_to_keep`` are plotted; genomes whose ANI files are missing are skipped.

    Args:
        gids_to_keep: iterable of accessions to plot. Defaults to the three
            hand-picked genomes this cell was written for.

    Returns:
        None; plots are written to disk by do_single_plot.
    """
    # Some were originally run with more pct values, but just reduce to the analysis set
    df_changed = DF[DF['same'].eq(False)]
    df_changed = df_changed[df_changed.index.get_level_values(1).isin(PCT_VALUES)]

    gids_to_check = frozenset(df_changed.index.get_level_values(0))
    print(f'{len(gids_to_check):,} gids to check')

    if gids_to_keep is None:
        gids_to_keep = {'GCF_002026585.1', 'GCA_900751995.1', 'GCA_900759445.1'}
    gids_to_keep = frozenset(gids_to_keep)

    # Sorted so that repeated runs produce plots and log lines in the same order.
    for gid in sorted(gids_to_check & gids_to_keep):
        gid_df = get_data_for_gid(gid)
        # get_data_for_gid returns None when an ANI file is missing; skip such
        # genomes instead of crashing inside do_single_plot (None.copy()).
        if gid_df is None:
            continue

        do_single_plot(gid, gid_df)


# Render the plots for the selected genomes. gen_data() writes SVG files and
# prints progress; it never returns a value, so `data` is None.
data = gen_data()