## finalizingReadPlotting_matplotlib.ipynb
### Marcus Viscardi     June 16, 2022

So I tried to just run what was in testingReadPlotting_matplotlib.py with tba-1 and tba-2 because Josh wanted those for grants, but the script didn't work. I am fairly sure this is because these genes are on the negative strand... I was also having a hard time deciphering what was happening in the other code, because it was heavily hard-coded to plot ubl-1 for my RNA Society poster.

The new plan is to keep that code functional and rewrite things here to be a little more reusable. Then, once I have troubleshot the hell out of all of this, I'll drop it into a real script that I can package and give to folks. Potentially with the addition of coverage plotting on top of this? Would be cool!

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
pd.set_option('display.width', 100)
pd.set_option('display.max_columns', None)

import re

import seaborn as sea
import matplotlib.pyplot as plt
import sys

import warnings
sys.path.insert(0, '/data16/marcus/scripts/nanoporePipelineScripts')
from nanoporePipelineCommon import *

print("imports done")

In [None]:
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

def _make_rectangle_patch(genome_start, length, y_center, thickness, color='gray'):
    """Build a filled, zero-line-width rectangle centered vertically on ``y_center``.

    The rectangle starts at ``genome_start`` on the x axis and extends
    ``length`` units; on the y axis it is ``thickness`` tall, split evenly
    above and below ``y_center``. ``color`` is applied to both face and edge.
    """
    lower_left_corner = (genome_start, y_center - (thickness / 2))
    patch_style = dict(facecolor=color, edgecolor=color, fill=True, lw=0)
    return Rectangle(lower_left_corner, length, thickness, **patch_style)


def _add_patches_from_cigars_and_gen_pos(axes, cigar, gen_start, y, strand, color='black', plot_introns=True, tail_length=None):
    """Draw one aligned read on ``axes`` as a row of rectangles at height ``y``.

    The CIGAR string is walked left to right, starting at ``gen_start``:

    * ``M``, ``=`` and ``X`` are drawn as thick (0.8) blocks and advance the
      position.
    * ``D`` shorter than 50 is drawn as a thick block; a longer ``D`` and every
      ``N`` are treated as introns and drawn as a 0.001-thick rectangle, but
      only when ``plot_introns`` is true. All of them advance the position.
    * ``S`` and ``I`` draw nothing and do not advance the position.

    The rectangles above are added to ``axes`` as a single ``PatchCollection``
    created with ``color``. If ``tail_length`` is a non-NaN float, one extra
    green, 0.4-thick rectangle of that length is added: after the last aligned
    position when ``strand == "+"``, otherwise extending leftwards from
    ``gen_start``.

    Parameters:
        axes: matplotlib axes to draw on.
        cigar: CIGAR string of the alignment.
        gen_start: x coordinate of the first aligned position.
        y: vertical center of the read.
        strand: "+" puts the tail on the right; anything else on the left.
        color: color handed to the ``PatchCollection``.
        plot_introns: whether intron-like gaps are drawn at all.
        tail_length: tail length to draw, or ``None``/NaN for no tail.

    Returns:
        The number of x-axis positions covered by the alignment, i.e. the sum
        of all ``M``/``=``/``X``/``D``/``N`` lengths.
    """
    # Deletions shorter than this are drawn as part of the read body rather
    # than as an intron-like gap.
    max_deletion_drawn_as_body = 50
    body_thickness = 0.8
    intron_thickness = 0.001

    # '=' and 'X' are included because, like 'M', they consume reference;
    # skipping them would shift every later block to the left.
    parsed_cigar = [(int(num), op) for num, op in re.findall(r'(\d+)([MDNSIX=])', cigar)]

    genome_loc = gen_start
    rectangle_patch_list = []
    for length, code in parsed_cigar:
        if code in ('S', 'I'):
            # Soft clips and insertions do not consume reference.
            continue
        if code in ('M', '=', 'X') or (code == 'D' and length < max_deletion_drawn_as_body):
            rectangle_patch_list.append(
                _make_rectangle_patch(genome_loc, length, y, thickness=body_thickness))
        elif plot_introns:
            # Long deletion or 'N': intron-like gap.
            rectangle_patch_list.append(
                _make_rectangle_patch(genome_loc, length, y, thickness=intron_thickness))
        genome_loc += length
    axes.add_collection(PatchCollection(rectangle_patch_list, color=color))

    # NaN is a float too, so it must be excluded explicitly; otherwise a
    # rectangle with NaN width would be handed to matplotlib.
    if isinstance(tail_length, float) and not np.isnan(tail_length):
        if strand == "+":
            axes.add_patch(_make_rectangle_patch(genome_loc, tail_length, y, thickness=0.4, color='green'))
        else:
            # Negative width: the tail extends to the left of the read start.
            axes.add_patch(_make_rectangle_patch(gen_start, -tail_length, y, thickness=0.4, color='green'))

    genomic_read_length = genome_loc - gen_start
    return genomic_read_length


def _row_apply_plot_cigar(row, axes, plot_introns=True):
    """Plot the read stored in one DataFrame row; meant for ``DataFrame.apply(axis=1)``.

    The row must provide ``cigar``, ``original_chr_pos``, ``t5``,
    ``polya_length`` and ``strand``. The row's index label (``row.name``) is
    used as the y position. Reads with ``t5 == '-'`` are drawn in black, all
    other reads in red.

    Returns the value returned by ``_add_patches_from_cigars_and_gen_pos``.
    """
    read_color = 'black' if row.t5 == '-' else 'red'
    return _add_patches_from_cigars_and_gen_pos(
        axes,
        row.cigar,
        row.original_chr_pos,
        row.name,
        row.strand,
        color=read_color,
        plot_introns=plot_introns,
        tail_length=row.polya_length,
    )


def _get_gene_coordinates(
        gene_id=None, gene_name=None,
        parsed_gtf_path="/data16/marcus/genomes/elegansRelease100/Caenorhabditis_elegans.WBcel235.100.gtf.parquet"
) -> (str, str, int, int):
    # First make sure we got something to look up:
    gene_id_bool = isinstance(gene_id, str)
    gene_name_bool = isinstance(gene_name, str)
    if not gene_id_bool and not gene_name_bool:
        raise NotImplementedError(f"Please pass a gene_id or a gene_name!")
    # Load the parsed gtf_file
    try:
        gtf_df = pd.read_parquet(parsed_gtf_path)[["gene_id",
                                                   "gene_name",
                                                   "feature",
                                                   "chr",
                                                   "start",
                                                   "end",
                                                   "strand"]].query("feature == 'gene'")
    except FileNotFoundError:
        raise FileNotFoundError(f"Please make sure there is a parsed gtf file at: {parsed_gtf_path}")

    # Get the gene of interest!
    try:
        if gene_id_bool:
            entry_of_interest = gtf_df.query(f"gene_id == '{gene_id}'").reset_index(drop=True).iloc[0].to_dict()
            gene_name = entry_of_interest["gene_name"]
        else:  # if gene_name_bool
            entry_of_interest = gtf_df.query(f"gene_name == '{gene_name}'").reset_index(drop=True).iloc[0].to_dict()
            gene_id = entry_of_interest["gene_id"]
    except IndexError:
        raise IndexError(f"Gene of interest (gene_id: {gene_id} / gene_name: {gene_name}) not found!")
    chromosome = entry_of_interest["chr"]
    start = entry_of_interest["start"]
    end = entry_of_interest["end"]
    strand = entry_of_interest["strand"]
    print(f"Found entry for {gene_name} ({gene_id}) on chromosome {chromosome:>5} at ({start}, {end}) on the '{strand}' strand")
    return gene_name, chromosome, start, end, strand


def plot_reads(reads_df, gene_id_to_plot=None, gene_name_to_plot=None,
               save_dir=None, save_suffix="", plot_width_and_height=(25,5),
               subsample_fraction=None, subsample_number=None,
               t5_pos_count=None, t5_neg_count=None,
               pad_x_axis_bounds_by=None, only_keep_reads_matched_to_gene=True):
    """Plot every read assigned to one gene as a stacked row of rectangles.

    The gene is located with ``_get_gene_coordinates``. Reads are optionally
    subsampled, filtered to the gene, split by their ``t5`` value ('+' / '-'),
    optionally subsampled again per group, sorted, and drawn one per row with
    ``_row_apply_plot_cigar``. Rows whose ``t5`` is neither '+' nor '-' are
    not plotted.

    ``reads_df`` must contain the columns ``gene_name``, ``t5``, ``chr_pos``,
    ``original_chr_pos`` and ``read_length`` (used here for filtering and
    sorting) plus ``cigar``, ``strand`` and ``polya_length`` (read by
    ``_row_apply_plot_cigar``).

    Parameters:
        reads_df: DataFrame with one row per read.
        gene_id_to_plot / gene_name_to_plot: gene to plot; at least one must
            be a string.
        save_dir: if a string, the figure is saved there as both .svg and
            .png. The directory is not created by this function.
        save_suffix: text appended to the saved file name.
        plot_width_and_height: figure size passed to ``plt.subplots``.
        subsample_fraction: if a float, fraction of *all* reads sampled before
            filtering to the gene. Takes precedence over ``subsample_number``.
        subsample_number: if an int, number of *all* reads sampled before
            filtering to the gene.
        t5_pos_count / t5_neg_count: if an int, maximum number of t5 '+' / '-'
            reads to plot; when fewer reads exist, all of them are plotted.
        pad_x_axis_bounds_by: if an int, padding added on both sides of the
            gene when setting the x limits.
        only_keep_reads_matched_to_gene: must be true; the alternative is not
            implemented.

    Returns:
        The sorted, re-indexed DataFrame of the reads that were plotted.

    Raises:
        NotImplementedError: ``only_keep_reads_matched_to_gene`` is false, or
            no gene identifier was given.
    """
    gene_name, _chromosome, genomic_start, genomic_end, gene_strand = _get_gene_coordinates(
        gene_name=gene_name_to_plot, gene_id=gene_id_to_plot)

    if isinstance(subsample_fraction, float):
        subsampled_reads_df = reads_df.sample(frac=subsample_fraction)
    elif isinstance(subsample_number, int):
        subsampled_reads_df = reads_df.sample(n=subsample_number)
    else:
        subsampled_reads_df = reads_df  # Just to have the same variable name!

    if not only_keep_reads_matched_to_gene:
        raise NotImplementedError("This doesn't currently work...")
    all_gene_reads = subsampled_reads_df[subsampled_reads_df["gene_name"] == gene_name]

    # Split by t5 flag and optionally cap each group. The cap is clamped to
    # the group size because DataFrame.sample raises when asked for more rows
    # than exist (without replacement).
    t5_groups = []
    for t5_flag, requested_count in (("+", t5_pos_count), ("-", t5_neg_count)):
        group = all_gene_reads[all_gene_reads["t5"] == t5_flag]
        if isinstance(requested_count, int):
            group = group.sample(min(requested_count, len(group)))
        t5_groups.append(group)
    gene_df = pd.concat(t5_groups)

    plt.style.use('default')
    fig, ax = plt.subplots(figsize=plot_width_and_height)

    if gene_strand == "-":
        sort_order = ["t5", "chr_pos", "original_chr_pos", "read_length"]
        sort_order_ascending = [False, True, False, False]
    else:  # gene_strand == "+":
        sort_order = ["t5", "original_chr_pos", "chr_pos", "read_length"]
        sort_order_ascending = [False, False, False, False]
    tqdm.pandas(desc="Plotting Reads...")
    # After reset_index the row label doubles as the y position of each read.
    gene_df = gene_df.sort_values(sort_order, ascending=sort_order_ascending).reset_index(drop=True)
    gene_df.progress_apply(lambda row: _row_apply_plot_cigar(row, ax), axis=1)

    number_of_plotted_reads = gene_df.shape[0]
    ax.set_ylim(-1, number_of_plotted_reads + 1)

    x_padding = pad_x_axis_bounds_by if isinstance(pad_x_axis_bounds_by, int) else 0
    ax.set_xlim(genomic_start - x_padding, genomic_end + x_padding)

    # Work on this figure's own axes rather than pyplot's "current" figure.
    ax.set_xticks([])
    ax.set_yticks([])
    if isinstance(save_dir, str):
        # NOTE(review): get_dt is not defined in this notebook; presumably it
        # comes from the nanoporePipelineCommon star import - confirm.
        save_path = f"{save_dir}/{get_dt(for_file=True)}_readPlotting_{gene_name}{save_suffix}"
        print(f"Saving plot to {save_path} + .png/.svg...")
        fig.savefig(save_path + ".svg")
        fig.savefig(save_path + ".png")
    return gene_df

In [None]:
# Load and merge the two libraries once; the results are kept as "raw" frames
# so later cells can take fresh copies without reloading.
# NOTE(review): load_and_merge_lib_parquets is not defined in this notebook;
# presumably it comes from the nanoporePipelineCommon star import - confirm.
reads_df_genes_raw, compressed_df_genes_raw = load_and_merge_lib_parquets(
    ["xrn-1-5tera", "xrn-1-5tera-smg-6"],
    drop_sub_n=1,
    add_tail_groupings=False,
    drop_failed_polya=False,
    group_by_t5=True,
)
print("library load done.")

In [None]:
# Take working copies so the raw frames loaded above are never modified.
reads_df, compressed_df = (
    raw_frame.copy() for raw_frame in (reads_df_genes_raw, compressed_df_genes_raw)
)

In [None]:
gene_to_plot = "zip-1"

# Alternative call for the smg-6 library (left disabled):
# plot_reads(
#     reads_df.query("lib == 'xrn-1-5tera-smg-6'"),
#     gene_name_to_plot=gene_to_plot,
#     # t5_pos_count=1, t5_neg_count=30,
#     pad_x_axis_bounds_by=100,
#     save_dir="./outputDir",
#     save_suffix="_smg-6-KO_allReads",
# )

# Plot every read of the chosen gene from the 'xrn-1-5tera' library.
plot_reads(
    reads_df.query("lib == 'xrn-1-5tera'"),
    gene_name_to_plot=gene_to_plot,
    # t5_pos_count=10, t5_neg_count=80,
    pad_x_axis_bounds_by=100,
    save_dir="./outputDir",
    save_suffix="_WT_allReads",
)