In [None]:
import re
import itertools
from pathlib import Path
from typing import Optional
from numpy.random import RandomState
import plotly.express as px
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

from hmmer import HMMER, read_domtbl, read_domtbl
import gff_io
from gff_io.interval import PyInterval, RInterval
import hmmer_reader
import pandas as pd
from iseq_prof import pfam, sam
import iseq_prof_analysis as analysis
from fasta_reader import read_fasta
import sam_io
from dna_features_viewer import GraphicFeature, GraphicRecord

In [None]:
# random = RandomState(183)
# meta_filepath = Path("/Users/horta/db/pfam/Pfam-A.hmm.meta.pkl.gz")
# dombtbl_filepath = Path("/Users/horta/ebi/chlamydia/output/assembly/domtblout.txt")

In [None]:
# Load the project configuration; `analysis.config` is read by the next cell.
analysis.load_config(verbose=True)

# Upper bound on the iseq hit E-value; hits above it are dropped when
# building `features_db`. (Name spelled as used by the filtering cell.)
EVALUE_THRSHOLD: float = 1e-9
# NOTE(review): not referenced in the visible cells — confirm its purpose.
RNAME: str = "2"

In [None]:
# Dataset locations come from the loaded config; this notebook reads the
# results stored under a run-specific subdirectory of the dataset root.
chlamydia_cfg = analysis.config.chlamydia
root = chlamydia_cfg.root_dir
hybrid_consensus = chlamydia_cfg.hybrid_consensus
output_dir = root / "output_depth49_200targets"

In [None]:
# Materialise every record of the hybrid consensus FASTA so it can be indexed;
# the second record is the one whose length sizes the plots below.
hybrid = [fasta_record for fasta_record in read_fasta(root / hybrid_consensus)]
print(f"# {hybrid_consensus}")
print(f"Number of targets: {len(hybrid)}")
target2 = hybrid[1]
print(f"Target 2: >{target2.defline}")

In [None]:
# Pfam clan lookup, queried below with `clans.get(profile_accession)`.
# NOTE(review): presumably returns None for profiles without a clan, which is
# why results go through normalize_clan_name — confirm against iseq_prof.pfam.
clans = pfam.Clans()
def normalize_clan_name(name: Optional[str]) -> str:
    """Return a display label for a Pfam clan.

    A missing clan (``name is None``) is reported as ``"Unclassified"`` so
    that such profiles share one colour; any other value is returned as is.
    """
    return "Unclassified" if name is None else name

In [None]:
# Index the Prokka annotations by their GFF `ID` attribute so HMMER query
# names (Prokka CDS IDs) can be mapped back to a location on the assembly.
# If an ID repeats, the later record replaces the earlier one.
prokka_gff_path = output_dir / "prokka" / "assembly.gff"
assembly_gffs = {
    gff_record.attributes_asdict()["ID"]: gff_record
    for gff_record in gff_io.read_gff(prokka_gff_path)
}

In [None]:
# Clan -> colour map shared by every panel, filled lazily in first-seen order
# from a cycling Plotly palette (colours repeat once the palette is exhausted).
# The cycle is stateful: re-run this cell before re-running the cells that
# consume it to reproduce the same assignment.
palette = px.colors.qualitative.Plotly
colors = itertools.cycle(palette)
clan_colors = dict()

## Prokka + HMMER3 domain annotations on the assembly

In [None]:
def _domain_genome_interval(ali_interval, gene_interval, strand: str) -> PyInterval:
    """Project a HMMER alignment span on a translated CDS onto the assembly.

    Parameters
    ----------
    ali_interval
        Alignment span on the protein (``domtbl_row.ali_coord.interval``),
        in residues.
    gene_interval
        Location of the CDS on the assembly (``assembly_gff.interval``).
    strand
        GFF strand of the CDS, ``"+"`` or ``"-"``.

    Returns
    -------
    PyInterval
        The nucleotide span on the assembly covered by the domain. Each
        residue covers three nucleotides. On ``+`` residue 0 starts at
        ``gene_interval.start``. On ``-`` the protein is read from
        ``gene_interval.end`` towards ``gene_interval.start``, so the span is
        mirrored rather than offset from the start.

    NOTE(review): assumes both intervals are 0-based and end-exclusive like
    ``PyInterval`` — confirm against the hmmer / gff_io conventions.
    """
    nt_start = ali_interval.start * 3
    nt_end = ali_interval.end * 3
    if strand == "-":
        return PyInterval(gene_interval.end - nt_end, gene_interval.end - nt_start)
    return PyInterval(nt_start, nt_end).offset(gene_interval.start)


# One GraphicFeature per HMMER3 domain hit on a Prokka CDS, placed in assembly
# coordinates and coloured by Pfam clan.
consensus_features = []
for domtbl_row in read_domtbl(output_dir / "assembly" / "domtblout.txt"):
    # The domtbl query is a Prokka CDS; find where it sits on the assembly.
    assembly_gff = assembly_gffs[domtbl_row.query.name]
    interval = _domain_genome_interval(
        domtbl_row.ali_coord.interval, assembly_gff.interval, assembly_gff.strand
    )

    profile_name = domtbl_row.target.name
    profile_clan = normalize_clan_name(clans.get(domtbl_row.target.accession))
    if profile_clan not in clan_colors:
        clan_colors[profile_clan] = next(colors)

    # NOTE(review): unlike the iseq cell, no E-value threshold is applied here;
    # confirm the domtbl output was already filtered upstream.
    strand = int(assembly_gff.strand + "1")  # "+"/"-" -> +1/-1
    consensus_features.append(
        GraphicFeature(
            start=interval.start,
            end=interval.end,
            strand=strand,
            color=clan_colors[profile_clan],
            label=profile_name,
        )
    )

# Keep the historical name bound to the same list.
features = consensus_features

In [None]:
# Alignment of the iseq target sequences; used below to project hit
# coordinates back onto the consensus (`back_to_query`) and to shade the
# aligned span of each target (`full_query_interval`).
alignment_sam = output_dir / "alignment.sam"
sam_map = sam.SAMMap(alignment_sam)

In [None]:
# iseq domain hits grouped by target sequence (GFF seqid), each projected onto
# consensus coordinates and coloured by Pfam clan.
features_db = defaultdict(list)
hybrid_length = len(hybrid[1].sequence)
seqid_interval = {}  # NOTE(review): not used in the visible cells

for gff_record in gff_io.read_gff(output_dir / "output.gff"):
    attributes = gff_record.attributes_asdict()

    # Drop weak hits.
    evalue = float(attributes["E-value"])
    if evalue > EVALUE_THRSHOLD:
        continue

    name = attributes["Profile_name"]
    clan = normalize_clan_name(clans.get(attributes["Profile_acc"]))

    # Hits that cannot be placed on the consensus are skipped: KeyError
    # (presumably a seqid absent from the SAM file) or a None projection.
    try:
        consensus_interval = sam_map.back_to_query(gff_record.seqid, gff_record.interval)
    except KeyError:
        continue
    if consensus_interval is None:
        continue

    # Assign colours lazily so only clans that are actually drawn consume one.
    if clan not in clan_colors:
        clan_colors[clan] = next(colors)

    # NOTE(review): the strand is taken from the target's GFF record as is;
    # confirm back_to_query accounts for reverse-complemented alignments.
    features_db[gff_record.seqid].append(
        GraphicFeature(
            start=consensus_interval.start,
            end=consensus_interval.end,
            strand=int(gff_record.strand + "1"),
            color=clan_colors[clan],
            label=name,
        )
    )

In [None]:
# One row per panel: the Prokka + HMMER3 annotation on top, then one row per
# iseq target with at least one retained hit, all on a shared consensus x-axis.
n = len(features_db) + 1
# squeeze=False keeps `axs` 2-D even when n == 1 (no iseq hit survived the
# filters); without it plt.subplots returns a bare Axes and `axs[0]` fails.
fig, axs = plt.subplots(n, 1, figsize=(24, 3 * n), sharex=True, squeeze=False)
axs = axs[:, 0]

record = GraphicRecord(sequence_length=hybrid_length, features=consensus_features)
record.plot(ax=axs[0])
axs[0].set_title("Prokka + HMMER3")

for ax, (seqid, seq_features) in zip(axs[1:], features_db.items()):
    record = GraphicRecord(sequence_length=hybrid_length, features=seq_features)
    # Shade the part of the consensus covered by this target's alignment.
    covered = sam_map.full_query_interval(seqid)
    ax.fill_between(
        [covered.start, covered.end], -1, 1, facecolor="peachpuff", alpha=0.5, zorder=-1
    )
    record.plot(ax=ax)
    ax.set_title(f"iseq: {seqid}")

fig.savefig('iseq_on_200_sequences.png', bbox_inches='tight')