In [None]:
import re
import itertools
from pathlib import Path
from typing import Optional
from numpy.random import RandomState
import plotly.express as px
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

from hmmer import HMMER, read_domtbl, read_domtbl
import gff_io
from gff_io.interval import PyInterval, RInterval
import hmmer_reader
import pandas as pd
from iseq_prof import pfam, sam
import iseq_prof_analysis as analysis
from fasta_reader import read_fasta
import sam_io
from dna_features_viewer import GraphicFeature, GraphicRecord

In [None]:
# Load the analysis configuration (verbose) before any cell reads `analysis.config`.
analysis.load_config(verbose=True)
# Upper bound on the GFF "E-value" attribute: ISEQ hits above it are skipped.
# NOTE: the name is misspelled ("THRSHOLD"); kept as-is because other cells reference it.
EVALUE_THRSHOLD = 1e-10
# Label used for the assembly track in figure texts and titles.
assembly_name = "hybrid-49depth-assembly"

In [None]:
# Locations come from the loaded configuration.
root = analysis.config.chlamydia.root_dir
hybrid_consensus = analysis.config.chlamydia.hybrid_consensus
# All pipeline outputs read by this notebook live under this directory.
output_dir = root / "output_depth49_251targets"
# Alignment map; later cells call `back_to_query` and `full_query_interval` on it.
sam_map = sam.SAMMap(output_dir / "alignment" / "alignment.sam")

In [None]:
# Read every record of the hybrid consensus FASTA into memory.
hybrid = list(read_fasta(root / hybrid_consensus))
# Length of the SECOND record (index 1); it is used as the x-extent of every track.
# Raises IndexError if the file holds fewer than two records.
# NOTE(review): presumably this record is the one labelled seqid "2" elsewhere — confirm.
HYBRID_LENGTH = len(hybrid[1].sequence)
print(f"# {hybrid_consensus}")
print(f"Number of consensus: {len(hybrid)}")
print(f"Target 2: >{hybrid[1].defline}")

In [None]:
# Number of records in the alignment FASTA; also embedded in the saved figure file names.
ntargets = len(read_fasta(output_dir / "alignment" / "alignment.fasta").read_items())
print(f"Number of raw targets: {ntargets}")

In [None]:
# Clan lookup; the feature builders call `clans.get(<profile accession>)` on it.
clans = pfam.Clans()
def normalize_clan_name(name: Optional[str]) -> str:
    """Return ``name`` unchanged, or ``"Unclassified"`` when ``name`` is ``None``.

    Only ``None`` is replaced; any other value (including an empty string) is
    returned as given.
    """
    return "Unclassified" if name is None else name

In [None]:
def get_prokka_map(gff_filepath):
    """Index the records of a GFF file by their ``ID`` attribute.

    Parameters
    ----------
    gff_filepath
        Path handed to ``gff_io.read_gff``.

    Returns
    -------
    dict
        Maps each record's ``ID`` attribute to the record itself, in file order.

    Raises
    ------
    AssertionError
        If two records share the same ``ID``.
    KeyError
        If a record has no ``ID`` attribute.
    """
    records_by_id = {}
    for record in gff_io.read_gff(gff_filepath):
        record_id = record.attributes_asdict()["ID"]
        # IDs must be unique, otherwise a later record would silently replace an earlier one.
        assert record_id not in records_by_id
        records_by_id[record_id] = record
    return records_by_id

In [None]:
# Endless iterator over Plotly's qualitative palette; the feature builders pull from it
# with `next(colors)` whenever they meet a clan that has no color yet.
colors = itertools.cycle(px.colors.qualitative.Plotly)
# clan name -> color; shared by all feature builders so a clan keeps one color across figures.
clan_colors = {}

## Prokka+HMMER3 on assembly

In [None]:
def get_assembly_features(output_dir, clans, clan_colors):
    """Build one ``GraphicFeature`` per hmmscan domain hit on the assembly.

    Every row of ``hmmscan_assembly/domtbl.txt`` is matched to the Prokka
    record whose ``ID`` equals the row's query name; the row's alignment
    interval is scaled by 3 and shifted by the Prokka record's start.

    Parameters
    ----------
    output_dir
        Directory (supporting ``/``) holding ``prokka_assembly/assembly.gff``
        and ``hmmscan_assembly/domtbl.txt``.
    clans
        Object whose ``get(accession)`` returns a clan name or ``None``.
    clan_colors : dict
        Clan name -> color. Updated in place: unseen clans receive the next
        color from the module-level ``colors`` iterator.

    Returns
    -------
    list
        The features, in domtbl row order.

    Raises
    ------
    KeyError
        If a domtbl query name has no matching ``ID`` in the Prokka GFF.
    """
    # GFF allows "+", "-", "." and "?" in the strand column; anything that is not
    # an explicit direction is drawn without one (strand 0) instead of crashing.
    strand_sign = {"+": 1, "-": -1}

    prokka_map = get_prokka_map(output_dir / "prokka_assembly" / "assembly.gff")
    features = []
    for domtbl_row in read_domtbl(output_dir / "hmmscan_assembly" / "domtbl.txt"):
        prokka = prokka_map[domtbl_row.query.name]

        # Factor 3: presumably amino-acid -> nucleotide coordinates (TODO confirm).
        # NOTE(review): the shift always uses the record's start, whatever its strand;
        # verify that this is right for "-" strand genes.
        hit = domtbl_row.ali_coord.interval
        interval = PyInterval(hit.start * 3, hit.end * 3)
        interval = interval.offset(prokka.interval.start)

        profile_name = domtbl_row.target.name
        profile_clan = normalize_clan_name(clans.get(domtbl_row.target.accession))
        if profile_clan not in clan_colors:
            clan_colors[profile_clan] = next(colors)

        features.append(
            GraphicFeature(
                start=interval.start,
                end=interval.end,
                strand=strand_sign.get(prokka.strand, 0),
                color=clan_colors[profile_clan],
                label=profile_name,
            )
        )

    return features

# Features of the assembly itself; drawn as the top "ground-truth" track of every figure.
assembly_features = get_assembly_features(output_dir, clans, clan_colors)

In [None]:
# Plot the assembly features alone on a record of length HYBRID_LENGTH.
assembly_record = GraphicRecord(sequence_length=HYBRID_LENGTH, features=assembly_features)
ax, _ = assembly_record.plot(figure_width=20);
ax.set_title(f"{assembly_name} ground-truth");

In [None]:
def get_iseq_features(output_dir, clans, clan_colors, sam_map):
    """Build ``GraphicFeature`` objects from the ISEQ scan output, grouped by seqid.

    Reads ``iseq_scan_targets/output.gff`` and keeps an item only when

    * its ``E-value`` attribute does not exceed the module-level
      ``EVALUE_THRSHOLD``, and
    * ``sam_map.back_to_query(seqid, interval)`` neither raises ``KeyError``
      nor returns ``None``.

    Parameters
    ----------
    output_dir
        Directory (supporting ``/``) holding ``iseq_scan_targets/output.gff``.
    clans
        Object whose ``get(accession)`` returns a clan name or ``None``.
    clan_colors : dict
        Clan name -> color. Updated in place (only for items that are kept):
        unseen clans receive the next color from the module-level ``colors``.
    sam_map
        Object providing ``back_to_query(seqid, interval)``.

    Returns
    -------
    collections.defaultdict
        seqid -> list of features, in file order.
    """
    # GFF allows "+", "-", "." and "?" in the strand column; anything that is not
    # an explicit direction is drawn without one (strand 0) instead of crashing.
    strand_sign = {"+": 1, "-": -1}

    features_db = defaultdict(list)
    for item in gff_io.read_gff(output_dir / "iseq_scan_targets" / "output.gff"):
        atts = item.attributes_asdict()
        if float(atts["E-value"]) > EVALUE_THRSHOLD:
            continue

        profile_name = atts["Profile_name"]
        profile_clan = normalize_clan_name(clans.get(atts["Profile_acc"]))

        # Deliberate best-effort: items that cannot be mapped are left out of the figure.
        try:
            interval = sam_map.back_to_query(item.seqid, item.interval)
        except KeyError:
            continue
        if interval is None:
            continue

        if profile_clan not in clan_colors:
            clan_colors[profile_clan] = next(colors)

        features_db[item.seqid].append(
            GraphicFeature(
                start=interval.start,
                end=interval.end,
                strand=strand_sign.get(item.strand, 0),
                color=clan_colors[profile_clan],
                label=profile_name,
            )
        )
    return features_db

# seqid -> ISEQ features that pass the E-value filter and map back onto the query.
iseq_features_db = get_iseq_features(output_dir, clans, clan_colors, sam_map)

In [None]:
def sort_features_db(features_db):
    """Return a plain dict with the entries of ``features_db`` in sorted key order.

    Keys are compared as strings, except that the key ``"2"`` is ranked as if
    it were ``"0"``.
    """
    ordered_ids = sorted(
        features_db,
        key=lambda seqid: "0" if seqid == "2" else seqid,
    )
    return dict((seqid, features_db[seqid]) for seqid in ordered_ids)

def seqid2label(key):
    """Return the module-level ``assembly_name`` for key ``"2"``; any other key is returned unchanged."""
    return assembly_name if key == "2" else key

In [None]:
def get_figure(features_db, assembly_features, assembly_name):
    """Draw the assembly track followed by one track per entry of ``features_db``.

    Parameters
    ----------
    features_db : mapping
        seqid -> list of features; rows follow ``sort_features_db`` order.
    assembly_features : list
        Features drawn on the first ("ground-truth") row.
    assembly_name : str
        Name written on the first row.

    Returns
    -------
    (fig, axs)
        The figure and a 1-D array holding ``len(features_db) + 1`` axes.

    Notes
    -----
    Reads the module-level ``HYBRID_LENGTH`` and ``sam_map``; row labels come
    from ``seqid2label``.
    """
    features_db = sort_features_db(features_db)
    n = len(features_db) + 1
    # squeeze=False always yields an (n, 1) array. Without it, n == 1 (empty
    # features_db) returns a bare Axes and the indexing below would fail.
    fig, grid = plt.subplots(n, 1, figsize=(24, 5 * n), sharex=True, squeeze=False)
    axs = grid[:, 0]

    record = GraphicRecord(sequence_length=HYBRID_LENGTH, features=assembly_features)
    axs[0].text(0, -0.86, f"{assembly_name} ground-truth")
    record.plot(ax=axs[0])

    for ax, (seqid, features) in zip(axs[1:], features_db.items()):
        record = GraphicRecord(sequence_length=HYBRID_LENGTH, features=features)
        # Shade the interval reported by sam_map for this sequence, behind the features.
        interval = sam_map.full_query_interval(seqid)
        x = [interval.start, interval.end]
        ax.fill_between(x, -1, 1, facecolor="peachpuff", alpha=0.5, zorder=-1)
        ax.text(interval.start, -0.8, seqid2label(seqid))
        record.plot(ax=ax)

    return fig, axs

# ISEQ on raw reads

In [None]:
# Assembly track plus one track per sequence with ISEQ features.
fig, axs = get_figure(iseq_features_db, assembly_features, assembly_name)
axs[0].set_title("ISEQ");
# Relative file name: the PNG is written to the current working directory.
fig.savefig(f"iseq_on_{ntargets}_sequences.png", bbox_inches='tight')

In [None]:
def get_hmmer_features(output_dir, clans, clan_colors, sam_map):
    """Build ``GraphicFeature`` objects from hmmscan hits on the targets, grouped by seqid.

    Every row of ``hmmscan_targets/domtbl.txt`` is matched to the Prokka
    record whose ``ID`` equals the row's query name. The row's alignment
    interval is scaled by 3, shifted by the Prokka record's start and then
    passed through ``sam_map.back_to_query``; rows for which that call raises
    ``KeyError`` or returns ``None`` are skipped.

    Parameters
    ----------
    output_dir
        Directory (supporting ``/``) holding ``prokka_targets/targets.gff``
        and ``hmmscan_targets/domtbl.txt``.
    clans
        Object whose ``get(accession)`` returns a clan name or ``None``.
    clan_colors : dict
        Clan name -> color. Updated in place (only for rows that are kept):
        unseen clans receive the next color from the module-level ``colors``.
    sam_map
        Object providing ``back_to_query(seqid, interval)``.

    Returns
    -------
    collections.defaultdict
        Prokka seqid -> list of features, in domtbl row order.

    Raises
    ------
    KeyError
        If a domtbl query name has no matching ``ID`` in the Prokka GFF.
    """
    # GFF allows "+", "-", "." and "?" in the strand column; anything that is not
    # an explicit direction is drawn without one (strand 0) instead of crashing.
    strand_sign = {"+": 1, "-": -1}

    features_db = defaultdict(list)
    prokka_map = get_prokka_map(output_dir / "prokka_targets" / "targets.gff")
    for domtbl_row in read_domtbl(output_dir / "hmmscan_targets" / "domtbl.txt"):
        prokka = prokka_map[domtbl_row.query.name]

        # Factor 3: presumably amino-acid -> nucleotide coordinates (TODO confirm).
        # NOTE(review): the shift always uses the record's start, whatever its strand;
        # verify that this is right for "-" strand genes.
        hit = domtbl_row.ali_coord.interval
        interval = PyInterval(hit.start * 3, hit.end * 3)
        interval = interval.offset(prokka.interval.start)

        profile_name = domtbl_row.target.name
        profile_clan = normalize_clan_name(clans.get(domtbl_row.target.accession))

        # Deliberate best-effort: rows that cannot be mapped are left out of the figure.
        try:
            interval = sam_map.back_to_query(prokka.seqid, interval)
        except KeyError:
            continue
        if interval is None:
            continue

        if profile_clan not in clan_colors:
            clan_colors[profile_clan] = next(colors)

        features_db[prokka.seqid].append(
            GraphicFeature(
                start=interval.start,
                end=interval.end,
                strand=strand_sign.get(prokka.strand, 0),
                color=clan_colors[profile_clan],
                label=profile_name,
            )
        )
    return features_db

# seqid -> Prokka+hmmscan features on the targets that map back onto the query.
hmmer_features_db = get_hmmer_features(output_dir, clans, clan_colors, sam_map)

# PROKKA+HMMER3 on raw reads

In [None]:
# Assembly track plus one track per sequence with Prokka+hmmscan features.
fig, axs = get_figure(hmmer_features_db, assembly_features, assembly_name)
axs[0].set_title("PROKKA+HMMER3");
# Relative file name: the PNG is written to the current working directory.
fig.savefig(f"hmmer3_on_{ntargets}_sequences.png", bbox_inches='tight')

# ISEQ & PROKKA+HMMER3 on raw reads

In [None]:
# Sequence ids that have features in BOTH databases; only these are compared side by side.
keys = set(iseq_features_db.keys()) & set(hmmer_features_db.keys())

In [None]:
def get_figure2(features_db1, features_db2, title1, title2, assembly_features, assembly_name):
    """Draw two feature databases side by side, one row per sequence id.

    The first row shows the assembly features in both columns. Each following
    row shows the same sequence id from ``features_db1`` (left column) and
    ``features_db2`` (right column). Rows follow ``sort_features_db`` order,
    so the two inputs may list their keys in any order.

    Parameters
    ----------
    features_db1, features_db2 : mapping
        seqid -> list of features; both are expected to hold the same seqids.
    title1, title2 : str
        Titles set on the top-left and top-right axes.
    assembly_features : list
        Features drawn on the first ("ground-truth") row.
    assembly_name : str
        Name written on the first row.

    Returns
    -------
    (fig, axs)
        The figure and a 2-D array of axes with ``len(features_db1) + 1`` rows
        and 2 columns.

    Raises
    ------
    AssertionError
        If, after sorting, the two databases disagree on the seqid of a row.

    Notes
    -----
    Reads the module-level ``HYBRID_LENGTH`` and ``sam_map``; row labels come
    from ``seqid2label``.
    """
    # Sort both inputs the same way so rows pair up by seqid, not by insertion order.
    sorted_db1 = sort_features_db(features_db1)
    sorted_db2 = sort_features_db(features_db2)

    n = len(sorted_db1) + 1
    # squeeze=False keeps axs 2-D even for n == 1, so axs[row][col] always works.
    fig, axs = plt.subplots(n, 2, figsize=(24, 5 * n), sharex=True, squeeze=False)

    record = GraphicRecord(sequence_length=HYBRID_LENGTH, features=assembly_features)
    for top_ax, title in zip(axs[0], (title1, title2)):
        top_ax.text(0, -0.86, f"{assembly_name} ground-truth")
        record.plot(ax=top_ax)
        top_ax.set_title(title)

    paired_rows = zip(axs[1:], sorted_db1.items(), sorted_db2.items())
    for row_axes, (seqid1, features1), (seqid2, features2) in paired_rows:
        assert seqid1 == seqid2, f"row mismatch: {seqid1!r} != {seqid2!r}"
        columns = ((row_axes[0], seqid1, features1), (row_axes[1], seqid2, features2))
        for ax, seqid, features in columns:
            track = GraphicRecord(sequence_length=HYBRID_LENGTH, features=features)
            # Shade the interval reported by sam_map for this sequence, behind the features.
            interval = sam_map.full_query_interval(seqid)
            x = [interval.start, interval.end]
            ax.fill_between(x, -1, 1, facecolor="peachpuff", alpha=0.5, zorder=-1)
            ax.text(interval.start, -0.8, seqid2label(seqid))
            track.plot(ax=ax)

    return fig, axs

In [None]:
# Restrict both databases to the shared sequence ids. Both dicts are built by iterating
# the same `keys` set, so they list their entries in the same order.
fig, axs = get_figure2({key: iseq_features_db[key] for key in keys},
                       {key: hmmer_features_db[key] for key in keys},
                       "ISEQ", "PROKKA+HMMER3",
                        assembly_features,
                        assembly_name)
# Column titles: left = ISEQ, right = Prokka+HMMER3.
axs[0][0].set_title("ISEQ");
axs[0][1].set_title("PROKKA+HMMER3");
# Relative file name: the PNG is written to the current working directory.
fig.savefig(f"iseq_vs_prokka_hmmer3_{ntargets}_sequences.png", bbox_inches='tight')