In [None]:
import math
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import matplotlib as mpl
import matplotlib.cm as cm

In [None]:
def splitSampleName(x):
    """ Parse an underscore-delimited sample name.

        The second field is read as the patient number and the third
        field as 'treatment' or 'treatment-replicate'. Returns the tuple
        (patient, treatment, response, rep), where rep falls back to '1'
        if the third field carries no '-replicate' suffix and response is
        'responsive' for the hard-coded responder patients and
        'resistant' for everyone else. """
    fields = x.split('_')
    patient = int(fields[1])
    treatmentParts = fields[2].split('-')
    treatment = treatmentParts[0]
    # Only the part directly after the first '-' is used as the replicate
    rep = treatmentParts[1] if len(treatmentParts) > 1 else '1'
    response = 'responsive' if patient in (5, 126, 62) else 'resistant'
    return patient, treatment, response, rep


def defaultPlotting(size=9, width=180, ratio=0.5):
    """ Apply the shared seaborn figure settings.

        size  - font size used for every text element
        width - figure width in millimetres
        ratio - figure height as a fraction of the width """
    inchPerMM = 1 / 25.4  # matplotlib figure sizes are given in inches
    textColour = '#444444'
    # All text elements share one font size
    fontSizeKeys = (
        'font.size', 'axes.titlesize', 'axes.labelsize',
        'xtick.labelsize', 'ytick.labelsize',
        'legend.fontsize', 'legend.title_fontsize')
    rc = {key: size for key in fontSizeKeys}
    # Axis labels and ticks share one dark grey colour
    for key in ('axes.labelcolor', 'xtick.color', 'ytick.color'):
        rc[key] = textColour
    rc['font.family'] = 'sans-serif'
    rc['lines.linewidth'] = 1.5
    rc['figure.figsize'] = (width * inchPerMM, width * ratio * inchPerMM)
    # Hide the top and right axes spines
    rc['axes.spines.top'] = False
    rc['axes.spines.right'] = False
    sns.set(rc=rc, style='white')


def formatP(p):
    """ Return formatted p for title.

        The value is rendered as '< .001', '> 0.999' or '= .xxx' (three
        decimals, leading zero dropped) followed by significance stars:
        '***' for p < 0.001, '**' for p < 0.01 and '*' for p < 0.05.
        Missing values (None / NaN) are rendered as '= NA'. """
    # NaN fails every comparison below and would otherwise end up in the
    # final branch, where slicing 'nan' yields the meaningless '= an'
    if pd.isna(p):
        return '= NA'
    if p > 0.999:
        return '> 0.999'
    if p < 0.001:
        return '< .001 ***'
    # Drop the leading zero of the fixed-point representation
    pformat = '= ' + f'{p:.3f}'[1:]
    if p < 0.01:
        return pformat + ' **'
    if p < 0.05:
        return pformat + ' *'
    return pformat
    
def getColour(x):
    """ Return an 'R,G,B' colour string: blue for negative values,
        red for everything else. """
    return '0,0,255' if x < 0 else '255,0,0'

In [None]:
# Significance threshold for q-values: genes and transcripts with
# qval < alpha are flagged as differentially expressed (DEG)
alpha = 0.01

In [None]:
# Column names (in file order) and dtypes of the headerless 12-column BED
# annotation. The same mapping is reused to select the columns written to
# the BED12 output files.
dtypes = ({
    'chrom': str, 'chromStart': int, 'chromEnd': int, 'name': str, 'score': float,
    'strand': str, 'thickStart': str, 'thickEnd': str, 'itemRgb': str, 'blockCount': int,
    'blockSizes': str, 'blockStarts': str
})
genes = pd.read_csv(
    '../analysis/annotation/GRCh38.bed12', dtype=dtypes, names=list(dtypes), sep='\t')

# Keep only records on chromosomes 1-22 and X
validChroms = [str(i) for i in range(1, 23)] + ['X']

genes = genes.loc[genes['chrom'].isin(validChroms)]

In [None]:
# Load the edgeR gene level abundances (TMM) and reshape them to long
# format: one row per (gene, sample) pair.
# NOTE(review): reset_index() assumes the gene IDs are read as the index of
# the table - confirm against the layout of the TSV.
geneCounts = (
    pd.read_csv('../analysis/featureCounts/allCounts-TMM.tsv.gz', sep='\t')
    .reset_index()
    .melt(id_vars='index', value_name='TMM', var_name='sample')
    .rename(columns={'index': 'ens_gene'}))
# Keep only the part of the gene ID in front of the first '.'
geneCounts['ens_gene'] = [geneID.split('.')[0] for geneID in geneCounts['ens_gene']]

# Annotate every row with the experimental factors encoded in its sample name
geneCounts['patient'], geneCounts['treatment'], geneCounts['response'], geneCounts['rep'] = zip(
    *map(splitSampleName, geneCounts['sample']))
# Identifier built by concatenating patient number and replicate
geneCounts['group'] = geneCounts['patient'].astype(str).str.cat(geneCounts['rep'].astype(str))

In [None]:
# Per-model results: data[model] holds 'diffGene', 'diffTx' and 'counts'
data = defaultdict(dict)
# Union over both models of the top-ranked DEG gene IDs selected for plotting
DEGlist = set()
# Number of top DEGs (smallest q-value) collected per model
n = 10
for model in ['treatment', 'response']:
    # Load gene-level data
    diffGene = pd.read_csv(f'{model}/{model}-gene.csv')
    diffGene['DEG'] = diffGene['qval'] < alpha
    # Gene symbols that are shared by more than one gene ID are ambiguous
    idsPerSymbol = diffGene.groupby('ext_gene')['target_id'].nunique()
    ambiguousSymbols = idsPerSymbol[idsPerSymbol > 1].index
    # Collect the gene IDs behind those symbols. The filters below operate on
    # gene ID columns, so they must be given IDs - matching gene IDs against
    # the symbols themselves never removes anything.
    multiMapped = diffGene.loc[diffGene['ext_gene'].isin(ambiguousSymbols), 'target_id']
    # Remove ambiguous gene IDs
    diffGene = diffGene.loc[~diffGene['target_id'].isin(multiMapped)]

    data[model]['diffGene'] = diffGene

    # Load transcript-level data
    diffTx = pd.read_csv(f'{model}/{model}-Tx.csv')
    diffTx['DEG'] = diffTx['qval'] < alpha
    # Remove transcripts of ambiguous gene IDs. Copy the filtered frame so
    # the column added below is not written to a slice of the parent frame.
    diffTx = diffTx.loc[~diffTx['ens_gene'].isin(multiMapped)].copy()
    # Absolute value of the 'b' column (described as fold change in this notebook)
    diffTx['abs(FC)'] = diffTx['b'].abs()
    data[model]['diffTx'] = diffTx

    # Get the largest absolute 'b' across the transcripts of each gene
    absFC = diffTx.groupby('ens_gene')['abs(FC)'].max()
    # Inner join: genes without any transcript left in diffTx are dropped
    # from the local diffGene used for the DEG selection below
    diffGene = pd.merge(diffGene, absFC, left_on='target_id', right_index=True)

    # Load count data for all conditions
    counts = pd.read_csv(f'{model}/{model}-obsNormCounts.csv.gz')
    # Label transcript ID in count data with gene ID and gene symbol
    counts = pd.merge(
        counts, diffTx[['target_id', 'ens_gene', 'ext_gene']], on='target_id')
    # Recover experimental group information encoded in sample name
    counts['patient'], counts['treatment'], counts['response'], counts['rep'] = (
        zip(*counts['sample'].apply(splitSampleName)))
    counts['group'] = counts['patient'].astype(str) + counts['rep'].astype(str)
    data[model]['counts'] = counts

    # Retrieve the top n DEGs ranked by q-value (smallest first)
    DEGs = diffGene.loc[diffGene['DEG']].sort_values('qval').head(n)['target_id'].tolist()
    DEGlist.update(DEGs)

    # Attach the transcript-level results to the BED12 annotation records
    df = pd.merge(genes, data[model]['diffTx'], left_on='name', right_on='target_id')

    # Colour each record by the sign of 'b': blue if negative, red otherwise
    df['itemRgb'] = df['b'].apply(getColour)
    # Score = 100 * floor(-log10(q)), limited to the 0-1000 range of BED
    # scores. q = 0 maps to the maximum score and a missing q to 0; both
    # would otherwise break the integer cast.
    negLogQ = -np.log10(df['qval'].clip(lower=1e-300))
    df['score'] = negLogQ.clip(lower=0, upper=10).fillna(0).astype(int) * 100
    df.loc[df['chrom'].isin(validChroms), list(dtypes)].to_csv(
        f'{model}-diffTx.bed12', header=False, index=False, sep='\t')

In [None]:
def _firstMatch(table, keyCol, key, valueCol):
    """ Return `valueCol` of the first row of `table` whose `keyCol`
        equals `key`, or NaN if no row matches. """
    matches = table.loc[table[keyCol] == key, valueCol]
    return matches.iloc[0] if len(matches) else np.nan


def _pLabel(p):
    """ Format an adjusted p-value for a panel title; missing values
        are shown as '= NA'. """
    return '= NA' if pd.isna(p) else formatP(float(p))


def _plotPanel(ax, plotData, model, y, title, violinKws=None, stripKws=None):
    """ Draw one panel: a violin per level of `model` overlaid with the
        individual samples coloured by patient. """
    sns.violinplot(
        x=model, y=y, inner=None, color='White', cut=0,
        linewidth=2, data=plotData, ax=ax, **(violinKws or {}))
    g = sns.stripplot(
        x=model, y=y, hue='patient', data=plotData, ax=ax, **(stripKws or {}))
    # Join the points of the two treatment conditions with dashed lines.
    # NOTE(review): this assumes collections 0-1 are the violin bodies and
    # 2-3 the strip points of each condition, and that both conditions list
    # their samples in the same order - confirm for the seaborn version used.
    if model == 'treatment' and len(g.collections) >= 4:
        c1 = g.collections[2].get_offsets()
        c2 = g.collections[3].get_offsets()
        for a, b in zip(c1, c2):
            ax.plot(
                [a.data[0], b.data[0]], [a.data[1], b.data[1]],
                color='grey', linewidth=0.5, ls='--', zorder=1)
    ax.set_xlabel('Condition')
    ax.set_title(title, loc='left')
    # The per-panel legend is replaced by one shared figure legend
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


def _saveFigure(fig, axes, outPath):
    """ Add the shared patient legend, write `fig` to `outPath` and close
        it so figures do not accumulate in memory across the loop. """
    handles, labels = [], []
    # Prefer the right-hand panel; fall back to the left one if it is empty
    for ax in reversed(list(axes)):
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            break
    lgd = fig.legend(
        handles, labels, loc='right', ncol=1, title='Patient',
        bbox_to_anchor=(1.10, 0.50))
    fig.tight_layout()
    fig.savefig(outPath, bbox_extra_artists=(lgd,), bbox_inches='tight')
    plt.close(fig)


defaultPlotting()
# Models shown as the left and right panel of every figure
plotModels = ['response', 'treatment']
for name in DEGlist:
    # Resolve the gene symbol up front, from the first model that provides
    # one, so titles and output paths never reuse a previous gene's symbol
    symbol = np.nan
    for model in plotModels:
        symbol = _firstMatch(data[model]['diffGene'], 'target_id', name, 'ext_gene')
        if isinstance(symbol, str):
            break
    # Fall back to the gene ID when no symbol is available
    outName = symbol if isinstance(symbol, str) else name
    Path(f'plots/{outName}/transcripts/').mkdir(parents=True, exist_ok=True)

    # Gene-level figure; both panels show the same TMM values, grouped by
    # a different model. Skipped if the gene has no TMM values.
    geneData = geneCounts.loc[geneCounts['ens_gene'] == name]
    if not geneData.empty:
        fig, axes = plt.subplots(1, 2, sharey=True)
        for ax, model in zip(axes, plotModels):
            p_adj = _firstMatch(data[model]['diffGene'], 'target_id', name, 'qval')
            _plotPanel(
                ax, geneData, model, 'TMM',
                f'{outName}, $p_{{adj}}$ {_pLabel(p_adj)}',
                stripKws={'alpha': 0.8})
        axes[0].set_ylabel('Trimmed Means of M values (TMM)')
        axes[1].set_ylabel('')
        _saveFigure(fig, axes, f'plots/{outName}/{outName}-GeneLevel.svg')

    # Transcript-level figures for every transcript of this gene that is
    # present in the transcript table of either model
    allTranscripts = pd.concat([
        data[model]['diffTx'].loc[data[model]['diffTx']['ens_gene'] == name, 'target_id']
        for model in plotModels]).unique()
    for transcript in allTranscripts:
        fig, axes = plt.subplots(1, 2, sharey=True)
        plotted = False
        for ax, model in zip(axes, plotModels):
            modelCounts = data[model]['counts']
            countsTx = modelCounts.loc[modelCounts['target_id'] == transcript]
            if countsTx.empty:
                ax.axis('off')
                continue
            p_adj = _firstMatch(data[model]['diffTx'], 'target_id', transcript, 'qval')
            _plotPanel(
                ax, countsTx, model, 'tpm',
                f'{outName} ({transcript}), $p_{{adj}}$ {_pLabel(p_adj)}',
                violinKws={'alpha': 0.8})
            plotted = True
        if not plotted:
            # Nothing to show for this transcript in either model
            plt.close(fig)
            continue
        axes[0].set_ylabel('Transcripts Per Million (TPM)')
        axes[1].set_ylabel('')
        _saveFigure(fig, axes, f'plots/{outName}/transcripts/{transcript}.svg')
