In [1]:
# The list of representative gene models (one per gene_id) is available at:
# https://www.arabidopsis.org/download_files/Genes/Araport11_genome_release/Araport11_blastsets/Araport11_seq_20220914_representative_gene_model.gz

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.patches as mp
import pandas as pd
import numpy as np
import pysam
%config InlineBackend.figure_format = 'retina'

# Load the BED12 annotation and index it by gene_id, so that a single gene
# can be fetched with ``araport11_isoform.loc[gene_id]``.
araport11_isoform_path = './FLEPSeq/genome_lib/araport11.representative.gene_model.bed'
araport11_isoform = pd.read_csv(
    araport11_isoform_path,
    sep='\t',
    names=[
        'chrom', 'chromStart', 'chromEnd', 'name', 'score', 'strand',
        'thickStart', 'thickEnd', 'itemRgb', 'blockCount', 'blockSizes', 'blockStarts',
    ],
)
# gene_id is the BED `name` field up to (not including) the first '.'.
araport11_isoform['gene_id'] = araport11_isoform['name'].map(
    lambda model_name: model_name.split('.', 1)[0]
)
araport11_isoform = araport11_isoform.set_index('gene_id')

# Class for alignment visualization

In [25]:
class IGV(object):
    '''IGV(gene_id)

    A class for bam alignment visualization, based on the bam file and the
    annotation bed12 file.

    The gene model is looked up in the module-level ``araport11_isoform``
    table (BED12 columns, indexed by gene_id).  ``plot`` stacks the tracks
    from top to bottom as: bedGraph tracks, gene model, BAM tracks.

    Attributes:
        gene_id: A string gene_id present in the bed12 table.
        chrom, start, end, strand: gene coordinates copied from the bed12 row.
        thickStart, thickEnd: bounds of the region drawn with tall boxes.
        blockSizes, blockStarts: exon sizes and absolute exon start positions.
        min, max: x-range of the gene (initialised to start / end).
        bam_list: dict mapping bam path -> list of reads kept for plotting.
        labels_bam: list of track labels, parallel to the keys of bam_list.
        bedgraph_list: list of bedGraph DataFrames added by add_bedgraph.
        labels_bedgraph: list reserved for bedGraph labels (never filled here).
    '''

    def __init__(self, gene_id):
        self.gene_id = gene_id
        self._get_gene_info()
        self.max = self.end
        self.min = self.start
        self.bam_list = {}
        self.bedgraph_list = []
        self.labels_bam = []
        self.labels_bedgraph = []

    def _get_gene_info(self):
        '''
        Get gene info from bed file

        Raises:
            KeyError: if gene_id is not in the annotation table.
        '''
        gene_info = araport11_isoform.loc[self.gene_id]  # araport11_isoform loaded from pandas
        self.chrom = gene_info.chrom
        self.start = gene_info.chromStart
        self.end = gene_info.chromEnd
        self.strand = gene_info.strand
        self.thickStart = gene_info.thickStart
        self.thickEnd = gene_info.thickEnd
        self.blockCount = gene_info.blockCount
        # True for minus-strand genes; compared against read.is_reverse later.
        self.strand_boo = self.strand != '+'
        self.blockSizes = np.fromstring(gene_info.blockSizes, sep=',', dtype='int')
        # BED12 block starts are relative to chromStart; make them absolute (0 base).
        self.blockStarts = np.fromstring(gene_info.blockStarts, sep=',', dtype='int') + self.start

    def _split_exon(self, exon_start, exon_size):
        '''Split one exon into thin and thick parts.

        The thick part is the overlap of the exon with
        [thickStart, thickEnd); everything else in the exon is thin.

        Args:
            exon_start: absolute start of the exon.
            exon_size: length of the exon.

        Returns:
            A list of (start, width, is_thick) tuples in genomic order.
            Zero-width parts are omitted when the exon has a thick part.
        '''
        exon_end = exon_start + exon_size
        thick_from = max(exon_start, self.thickStart)
        thick_to = min(exon_end, self.thickEnd)
        if thick_to <= thick_from:
            # No overlap with the thick region: the whole exon is thin.
            return [(exon_start, exon_size, False)]
        parts = [
            (exon_start, thick_from - exon_start, False),
            (thick_from, thick_to - thick_from, True),
            (thick_to, exon_end - thick_to, False),
        ]
        return [part for part in parts if part[1] > 0]

    def _plot_gene_model(self, ax, gene_color='k'):
        '''Draw the gene model (TSS arrow, intron line, exon boxes) on ``ax``.

        Args:
            ax: matplotlib Axes to draw on.
            gene_color: colour used for every element of the gene model.
        '''
        # plot TSS: a bent arrow starting at the 5' end and pointing downstream
        small_relative = 0.05 * (self.max - self.min)
        arrowprops = dict(arrowstyle="-|>", connectionstyle="angle", color=gene_color)
        if self.strand == '+':
            tss, arrow_tip = self.start, self.start + small_relative
        else:
            tss, arrow_tip = self.end, self.end - small_relative
        ax.annotate('', xy=(arrow_tip, .6), xytext=(tss, 0), arrowprops=arrowprops)
        ax.plot([self.start, self.end], [0, 0], color=gene_color)

        height = .3  # the height of gene model

        # Each exon is split on thickStart/thickEnd so that an exon holding
        # both boundaries is drawn thin-thick-thin instead of thick to its end.
        for exon_start, exon_size in zip(self.blockStarts, self.blockSizes):
            for part_start, part_width, is_thick in self._split_exon(exon_start, exon_size):
                if is_thick:
                    box = mp.Rectangle((part_start, 0 - height), part_width, height * 2,
                                       color=gene_color, linewidth=0)
                else:
                    box = mp.Rectangle((part_start, 0 - height / 2), part_width, height,
                                       color=gene_color, linewidth=0)
                ax.add_patch(box)

        ax.annotate(self.gene_id, xy=((self.start + self.end) / 2, 0.8), ha='center')
        for side in ('right', 'left', 'top', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.yaxis.set_major_locator(ticker.NullLocator())
        ax.xaxis.set_major_locator(ticker.NullLocator())
        ax.xaxis.set_ticks_position('none')
        ax.set_ylim(-.5, 1)

    def _find_exon(self, read):
        '''Split an aligned read into blocks separated by N (reference skip) operations.

        Args:
            read: object exposing ``reference_start`` and ``cigartuples``
                (as a pysam aligned segment does).

        Returns:
            An iterator of (block_start, block_size) pairs in reference coordinates.
        '''
        BAM_CREF_SKIP = 3  # CIGAR 'N'
        # only M/=/X (0/7/8) and D (2) are related to genome position
        match_or_deletion = {0, 2, 7, 8}
        blockStart = []
        blockSize = []
        exon_start = read.reference_start
        length = 0
        for op, nt in read.cigartuples:
            if op in match_or_deletion:
                length += nt
            elif op == BAM_CREF_SKIP:
                # Close the current block and jump over the skipped region.
                blockStart.append(exon_start)
                blockSize.append(length)
                exon_start += length + nt
                length = 0
        blockStart.append(exon_start)
        blockSize.append(length)
        return zip(blockStart, blockSize)

    def _plot_bam(self, ax, read_list, bam_num, read_color='#5D93C4'):
        '''Draw one row per read on ``ax``.

        Args:
            ax: matplotlib Axes to draw on.
            read_list: reads to draw, first read on top.
            bam_num: index of ``ax`` in the figure; the x axis is only
                labelled when this is the last axis.
            read_color: colour of the aligned blocks.
        '''
        ypos = 0
        height = .6
        for read in read_list:
            # Thin grey bar over the full reference span, thick boxes for aligned blocks.
            line = mp.Rectangle((read.reference_start, ypos - height / 4), read.reference_length,
                                height / 2, color='#A6A6A6', linewidth=0)
            ax.add_patch(line)
            for block_start, block_size in self._find_exon(read):
                exon = mp.Rectangle((block_start, ypos - height), block_size, height * 2,
                                    color=read_color, linewidth=0)
                ax.add_patch(exon)
            ypos += -1

        for side in ('right', 'left', 'top'):
            ax.spines[side].set_visible(False)
        ax.yaxis.set_major_locator(ticker.NullLocator())
        ax.set_ylim(ypos - int(len(read_list) * .1), 1 + int(len(read_list) * .25))

        if bam_num != len(self.bam_list) + len(self.bedgraph_list):
            ax.xaxis.set_major_locator(ticker.NullLocator())
            ax.xaxis.set_ticks_position('none')
        else:
            ax.set_xlabel('Length (nt)')

    def _filter_reads(self, read, five_prime_threshold=150):
        '''Return True when ``read`` must be discarded.

        Args:
            read: object exposing ``is_reverse``, ``reference_start`` and
                ``reference_end``.
            five_prime_threshold: maximal distance (nt) a read may start
                upstream of the gene's 5' end.
        '''
        # discard the antisense read
        if bool(read.is_reverse) != self.strand_boo:
            return True
        # discard reads generated from upstream gene
        if read.is_reverse:
            return read.reference_end - self.end > five_prime_threshold
        return self.start - read.reference_start > five_prime_threshold

    def _sort_bam(self, read_list, sort_method):
        '''Sort ``read_list`` in place.

        Args:
            read_list: list of reads.
            sort_method: '3' sorts by the 3' end of the reads, '5' by the
                5' end (both relative to the gene strand).  Any other value
                leaves the list order unchanged.
        '''
        plus = self.strand == '+'
        if sort_method == '3':
            if plus:
                read_list.sort(key=lambda read: read.reference_end)
            else:
                read_list.sort(key=lambda read: read.reference_start, reverse=True)
        elif sort_method == '5':
            if plus:
                read_list.sort(key=lambda read: read.reference_start)
            else:
                read_list.sort(key=lambda read: read.reference_end, reverse=True)

    def _register_bam(self, bam_path, reads, label):
        '''Store ``reads`` for ``bam_path`` and keep labels_bam aligned with bam_list.'''
        if bam_path in self.bam_list:
            # Re-adding a path replaces its reads; replace its label at the same position.
            position = list(self.bam_list).index(bam_path)
            if position < len(self.labels_bam):
                self.labels_bam[position] = label
            else:
                self.labels_bam.append(label)
        else:
            self.labels_bam.append(label)
        self.bam_list[bam_path] = reads

    def add_bam(self, *bam_paths, sort_method='3', label = ""):
        '''add_bam(self, *bam_paths, sort_method='3', label="")

        Read the alignments overlapping the gene from each bam file and add
        them to self.bam_list (one track per bam file).

        Args:
            *bam_paths: one or more bam file paths (the files are fetched by
                region, which requires an index).

            sort_method: {'3', '5'}, default '3'
                Sort reads strategy. -'3': sort by 3' end position of reads.
                -'5': sort by 5' end position of reads. Any other value keeps
                the order returned by the bam file.

            label: title shown above the track(s); when several paths are
                given, every one of them gets this label.
        '''
        for bam_path in bam_paths:
            kept_reads = []
            with pysam.AlignmentFile(bam_path, 'rb') as inbam:
                # extend search range downstream 1000nt, clamped to the chromosome
                if self.strand == '-':
                    start = max(self.start - 1000, 0)
                    end = self.end
                else:
                    start = self.start
                    end = min(self.end + 1000, inbam.get_reference_length(self.chrom))

                for read in inbam.fetch(self.chrom, start, end):
                    if not self._filter_reads(read):
                        kept_reads.append(read)

            # sort reads
            self._sort_bam(kept_reads, sort_method)
            # One label per track, so plot() can index labels by track position.
            self._register_bam(bam_path, kept_reads, label)

    def add_bedgraph(self, bedgraph_path):
        '''Load a 4-column, tab-separated, header-less bedGraph file as a new track.

        Args:
            bedgraph_path: path of the bedGraph file (chrom, start, end, value).
        '''
        complete_bedgraph = pd.read_csv(bedgraph_path, sep='\t', header=None,
                                        names=['chrom', 'start', 'end', 'value'], low_memory=False)
        self.bedgraph_list.append(complete_bedgraph)

    def _plot_bedgraph(self, ax, bedgraph,  extend_3_prime_by, bedgraph_color = "darkgrey"):
        '''Draw the bedGraph intervals lying inside the gene window as bars.

        Bar heights are scaled by the largest value in the window.  Nothing
        is drawn when the window holds no interval or its largest value is 0.

        Args:
            ax: matplotlib Axes to draw on.
            bedgraph: DataFrame with columns chrom, start, end, value.
            extend_3_prime_by: number of nt added to the window past the 3' end.
            bedgraph_color: colour of the bars.
        '''
        if self.strand == '-':
            in_window = (bedgraph['start'] >= self.start - extend_3_prime_by) & (bedgraph['end'] <= self.end)
        else:
            in_window = (bedgraph['start'] >= self.start) & (bedgraph['end'] <= self.end + extend_3_prime_by)
        subset_bedgraph = bedgraph[(bedgraph['chrom'] == self.chrom) & in_window]
        if self.strand == '-':
            # Minus-strand genes are drawn in reverse genomic order.
            subset_bedgraph = subset_bedgraph.iloc[::-1]

        max_value = subset_bedgraph['value'].max() if not subset_bedgraph.empty else 0
        ymax = 1.1

        if pd.notna(max_value) and max_value != 0:
            for start, end, value in zip(subset_bedgraph['start'],
                                         subset_bedgraph['end'],
                                         subset_bedgraph['value']):
                # NOTE(review): the +2 presumably closes rendering gaps between
                # adjacent bars -- confirm.
                width = end - start + 2
                rectangle = mp.Rectangle((start, 0), width, value / max_value,
                                         color=bedgraph_color, linewidth=0)
                ax.add_patch(rectangle)

        ax.set_ylim(0, ymax)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    def _tick_positions(self, extend_3_prime_by=0):
        '''Compute x tick positions and their labels (distance from the 5' end).

        Ticks cover the gene plus ``extend_3_prime_by`` nt past its 3' end.
        The spacing is ``span // 400 * 100``; for genes shorter than 400 nt
        (where that is 0) a quarter of the span is used instead.

        Returns:
            (positions, labels) as two numpy arrays of equal length.
        '''
        span = self.max - self.min
        step = span // 400 * 100
        if step <= 0:
            step = max(span // 4, 1)
        if self.strand == '+':
            positions = np.arange(self.min, self.max + extend_3_prime_by + step, step)
            labels = positions - self.min
        else:
            positions = np.arange(self.max, self.min - extend_3_prime_by - step, -step)
            labels = self.max - positions
        return positions, labels

    def plot(self, height=3, height_bedgraph = 3, width=6, extend_3_prime_by = 0):
        '''plot(self, height=3, height_bedgraph=3, width=6, extend_3_prime_by=0)

        Plot the alignment and save it as '<gene_id>.igv.svg' in the current
        directory.

        Args:
            height: the height of each bam track, default 3.
            height_bedgraph: the height of each bedGraph track, default 3.
            width: the width of the figure, default 6.
            extend_3_prime_by: number of nt shown past the 3' end of the gene, default 0.
        '''
        n_bedgraph = len(self.bedgraph_list)
        n_bam = len(self.bam_list)
        # Layout, top to bottom: bedGraph tracks, gene model, bam tracks.
        height_ratios = [height_bedgraph] * n_bedgraph + [0.6] + [height] * n_bam

        # squeeze=False always returns a 2-D array, even for a single row.
        fig, ax_grid = plt.subplots(
            nrows=len(height_ratios),
            gridspec_kw={'height_ratios': height_ratios},
            figsize=(width, sum(height_ratios)),
            sharex=True,
            squeeze=False,
        )
        axes = ax_grid[:, 0]

        # plot bedgraph files
        for track_idx, sub_graph in enumerate(self.bedgraph_list):
            self._plot_bedgraph(axes[track_idx], sub_graph, extend_3_prime_by)

        # plot gene_model
        self._plot_gene_model(axes[n_bedgraph])

        # plot bam files
        for bam_idx, reads in enumerate(self.bam_list.values()):
            axis_idx = n_bedgraph + 1 + bam_idx
            label = self.labels_bam[bam_idx] if bam_idx < len(self.labels_bam) else ''
            axes[axis_idx].set_title(label, y=0.95, pad=3, fontsize='medium', loc="left")
            self._plot_bam(axes[axis_idx], reads, axis_idx)

        # The x axis is shared, so configuring the bottom axis is enough.
        ax_ = axes[-1]
        xticks, xticklabels = self._tick_positions(extend_3_prime_by)
        ax_.set_xticks(xticks)
        ax_.set_xticklabels(xticklabels)
        if self.strand == '+':
            ax_.set_xlim(self.min, self.max + extend_3_prime_by)
        else:
            ax_.set_xlim(self.min - extend_3_prime_by, self.max)
            # Show minus-strand genes 5' -> 3' from left to right.
            ax_.invert_xaxis()

        fig.subplots_adjust(hspace=0.2)
        fig.savefig(f'{self.gene_id}.igv.svg', format='svg', bbox_inches='tight')
