In [1]:
import math
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import matplotlib as mpl
import matplotlib.cm as cm

In [2]:
def splitSampleName(x):
    """ Parse a sample name into (patient, treatment, response, rep).

    The name is split on underscores: the second field is the patient
    number and the third field has the form ``<treatment>[-<replicate>]``.
    When no replicate suffix is present the replicate defaults to '1'.
    Patients 5, 126 and 62 are labelled 'responsive'; every other patient
    is labelled 'resistant'.
    """
    responders = (5, 126, 62)
    fields = x.split('_')
    patient = int(fields[1])
    # Split the third field once and reuse both halves
    treatmentParts = fields[2].split('-')
    treatment = treatmentParts[0]
    rep = treatmentParts[1] if len(treatmentParts) > 1 else '1'
    response = 'responsive' if patient in responders else 'resistant'
    return patient, treatment, response, rep


def defaultPlotting(size=9, width=180, ratio=0.5):
    """ Apply the shared seaborn/matplotlib figure settings.

    size  -- font size used for every text element (titles, labels,
             ticks, legend).
    width -- figure width in millimetres.
    ratio -- figure height expressed as a fraction of the width.
    """
    mm = 1 / 25.4  # inches per millimetre
    textColour = '#444444'

    # All text elements share a single font size
    fontKeys = (
        'font.size', 'axes.titlesize', 'axes.labelsize',
        'xtick.labelsize', 'ytick.labelsize',
        'legend.fontsize', 'legend.title_fontsize',
    )
    settings = {key: size for key in fontKeys}
    # Axis labels and ticks share a single dark-grey colour
    for key in ('axes.labelcolor', 'xtick.color', 'ytick.color'):
        settings[key] = textColour
    settings['font.family'] = 'sans-serif'
    settings['lines.linewidth'] = 1.5
    settings['figure.figsize'] = (width * mm, width * ratio * mm)
    # Hide the top and right axis spines
    settings['axes.spines.top'] = False
    settings['axes.spines.right'] = False

    sns.set(rc=settings, style='white')


def formatP(p):
    """ Return a formatted p-value label for use in a plot title.

    The label is a comparison operator followed by the value, e.g.
    '= .032 *'. Values above 0.999 are shown as '> 0.999' and values
    below 0.001 as '< .001'. Significance stars are appended:
    '***' for p < 0.001, '**' for p < 0.01 and '*' for p < 0.05.
    A missing (NaN) p-value is returned as '= NA' with no stars.
    """
    # NaN fails every comparison below and would otherwise fall through
    # to the slicing branch and be rendered as the nonsensical '= an'.
    if math.isnan(p):
        return '= NA'

    if p > 0.999:
        pformat = '> 0.999'
    elif p < 0.001:
        pformat = '< .001'
    else:
        # Strip the leading zero: '0.032' -> '.032'
        pformat = '= ' + f'{p:.3f}'[1:]

    # Most stringent threshold first so the first match wins
    for threshold, stars in ((0.001, ' ***'), (0.01, ' **'), (0.05, ' *')):
        if p < threshold:
            return pformat + stars
    return pformat
    
def getColour(x):
    """ Return an 'R,G,B' colour string based on the sign of x.

    Negative values map to blue ('0,0,255'); everything else,
    including zero, maps to red ('255,0,0').
    """
    return '0,0,255' if x < 0 else '255,0,0'

In [3]:
# Significance threshold for q-values: when the gene- and transcript-level
# results are loaded, rows with qval < alpha are flagged in the 'DE' column.
alpha = 0.01

In [4]:
# Column names (in file order) and dtypes for the headerless 12-column BED
# annotation. The eleventh column is 'blockSizes' (it was previously
# misspelled 'hlockSizes').
dtypes = {
    'chrom': str, 'chromStart': int, 'chromEnd': int, 'name': str, 'score': float,
    'strand': str, 'thickStart': str, 'thickEnd': str, 'itemRgb': str, 'blockCount': int,
    'blockSizes': str, 'blockStarts': str
}
# Pass the names as a list rather than a dict_keys view
genes = pd.read_csv(
    '../analysis/annotation/GRCh38.bed12', dtype=dtypes, names=list(dtypes), sep='\t')

# Keep only chromosomes 1-22, X and M; every other sequence name is dropped
validChroms = [str(i) for i in range(1, 23)] + ['X', 'M']

genes = genes.loc[genes['chrom'].isin(validChroms)]

In [5]:
# Load the gene-level TMM abundance table and reshape it to long format:
# one row per (gene, sample) pair with the abundance in the 'TMM' column.
geneCounts = (
    pd.read_csv('../analysis/featureCounts/allCounts-TMM.tsv.gz', sep='\t')
    .reset_index()
    .melt(id_vars='index', value_name='TMM', var_name='sample')
    .rename(columns={'index': 'ens_gene'})
)
# Keep only the part of the gene ID before the first '.'
geneCounts['ens_gene'] = [geneID.split('.')[0] for geneID in geneCounts['ens_gene']]

# Recover the experimental annotations encoded in each sample name
(geneCounts['Patient'], geneCounts['Treatment'],
 geneCounts['Response'], geneCounts['rep']) = zip(
    *(splitSampleName(sampleName) for sampleName in geneCounts['sample']))
# 'group' is the patient number concatenated with the replicate label
geneCounts['group'] = geneCounts['Patient'].astype(str).str.cat(geneCounts['rep'].astype(str))

In [11]:
# Per-model results, keyed by model name ('Response', 'Treatment') or by
# patient number. Each entry stores: 'genesToPlot', 'prefix', 'diffTx',
# 'diffGene' and 'counts'.
data = defaultdict(dict)
with pd.ExcelWriter('allDE-GeneResults.xlsx') as writer:
    for model in ['Response', 'Treatment', 5, 15, 17, 49, 62, 69, 126]:
        # This gene is always plotted, whether or not it is among the top DEGs
        data[model]['genesToPlot'] = {'ENSG00000141526'}

        # Patient models live under patient/<id>/<id>-*, the other models
        # under <model>/<model>-* (lower-cased)
        if isinstance(model, int):
            prefix = f'patient/{model}/{model}'
        else:
            prefix = f'{model.lower()}/{model.lower()}'
        data[model]['prefix'] = prefix

        # Load gene-level data
        diffGene = pd.read_csv(f'{prefix}-gene.csv')
        diffGene['DE'] = diffGene['qval'] < alpha

        # Find gene symbols shared by more than one gene ID (ambiguous) ...
        idsPerSymbol = diffGene.groupby('ext_gene')['target_id'].nunique()
        ambiguousSymbols = idsPerSymbol[idsPerSymbol > 1].index
        # ... and translate them to the gene IDs carrying those symbols.
        # The filters below compare gene IDs, so matching them against the
        # symbols themselves (as was done before) never removed anything.
        multiMapped = diffGene.loc[
            diffGene['ext_gene'].isin(ambiguousSymbols), 'target_id'].unique()

        # Remove ambiguous gene IDs
        diffGene = diffGene.loc[~diffGene['target_id'].isin(multiMapped)]

        # Load transcript-level data
        diffTx = pd.read_csv(f'{prefix}-Tx.csv')
        diffTx['DE'] = diffTx['qval'] < alpha

        # Remove transcripts belonging to ambiguous gene IDs
        diffTx = diffTx.loc[~diffTx['ens_gene'].isin(multiMapped)]
        data[model]['diffTx'] = diffTx

        # Summarise each gene by the median beta ('b') of its transcripts,
        # and keep the absolute value of that median alongside it
        medianBeta = diffTx.groupby('ens_gene')['b'].median()
        geneResults = pd.merge(diffGene, medianBeta, left_on='target_id', right_index=True)
        geneResults['abs(b)'] = geneResults['b'].abs()
        data[model]['diffGene'] = geneResults
        geneResults.to_excel(writer, sheet_name=str(model), index=False)

        # Load count data for all conditions
        counts = pd.read_csv(f'{prefix}-obsNormCounts.csv.gz')
        # Label transcript ID in count data with gene ID and gene symbol
        counts = pd.merge(
            counts, diffTx[['target_id', 'ens_gene', 'ext_gene']],
            left_on='target_id', right_on='target_id')
        # Recover experimental group information encoded in sample name
        counts['Patient'], counts['Treatment'], counts['Response'], counts['rep'] = (
            zip(*counts['sample'].apply(splitSampleName)))
        counts['group'] = counts['Patient'].astype(str) + counts['rep'].astype(str)
        data[model]['counts'] = counts

        # Number of top DEGs to plot
        n = 100
        # Retrieve the n most significant DEGs (smallest q-value first)
        DEGs = (
            diffGene.loc[diffGene['DE']]
            .sort_values('qval', ascending=True)
            .head(n)['target_id'].tolist())

        # Attach transcript-level results to the BED annotation
        df = pd.merge(genes, diffTx, left_on='name', right_on='target_id')

        # Colour each transcript by the sign of its beta value
        df['itemRgb'] = df['b'].apply(getColour)
        # Score is 100 * floor(-log10(q)), capped to the 0-1000 range.
        # Missing q-values score 0 and q == 0 (infinite -log10) scores 1000;
        # clipping before the integer cast avoids casting NaN/inf to int.
        df['score'] = (
            (-np.log10(df['qval']))
            .fillna(0)
            .clip(lower=0, upper=10)
            .astype(int) * 100)
        df.loc[df['chrom'].isin(validChroms), list(dtypes)].to_csv(
            f'{prefix}-diffTx.bed12', header=False, index=False, sep='\t')

        data[model]['genesToPlot'].update(DEGs)

In [8]:
# For every patient model, draw one gene-level TMM plot per selected gene and
# one TPM plot per transcript of that gene, comparing the two treatments.
# Figures are written as SVG files under <pathPrefix>/plots/<gene>/.
defaultPlotting(width=90, ratio=1)
order = ['CTRL', 'BEVA']
for pt in [5, 15, 17, 49, 62, 69, 126]:
    print(f'Processing PDX{pt}')
    ptData = data[pt]
    ptGenes, ptTx, ptCounts = ptData['diffGene'], ptData['diffTx'], ptData['counts']
    # Output directory for this patient; the same for every gene
    pathPrefix = ptData['prefix'].removesuffix(f'/{pt}')

    for gene in ptData['genesToPlot']:
        geneData = geneCounts.loc[(geneCounts['ens_gene'] == gene) & (geneCounts['Patient'] == pt)]
        # Take the first matching row explicitly: calling float() directly
        # on a one-element Series is deprecated in recent pandas
        geneRow = ptGenes.loc[ptGenes['target_id'] == gene].iloc[0]
        p_adj = float(geneRow['qval'])
        symbol = geneRow['ext_gene']
        # Fall back to the gene ID when the symbol is missing (not a string)
        outName = symbol if isinstance(symbol, str) else gene
        Path(f'{pathPrefix}/plots/{outName}/transcripts/').mkdir(parents=True, exist_ok=True)

        fig, ax = plt.subplots()
        sns.violinplot(
            x='Treatment', y='TMM', inner=None, color='White', cut=0,
            linewidth=2, order=order, data=geneData, ax=ax)
        sns.stripplot(
            x='Treatment', y='TMM', alpha=0.75,
            order=order, data=geneData, ax=ax)
        ax.set_title(f'PDX{pt} - {outName}, $p_{{adj}}$ {formatP(p_adj)}', loc='left')
        ax.set_ylabel('Trimmed Means of M values (TMM)')
        fig.tight_layout()
        fig.savefig(f'{pathPrefix}/plots/{outName}/PDX{pt}-{outName}-GeneLevel.svg')
        # Close this specific figure so figures do not accumulate in memory
        plt.close(fig)

        allTranscripts = ptTx.loc[ptTx['ens_gene'] == gene, 'target_id'].unique()
        for transcript in allTranscripts:
            countsTx = ptCounts.loc[ptCounts['target_id'] == transcript]
            p_adj = float(ptTx.loc[ptTx['target_id'] == transcript, 'qval'].iloc[0])
            fig, ax = plt.subplots()
            sns.violinplot(
                x='Treatment', y='tpm', inner=None, color='White', cut=0,
                linewidth=2, order=order, data=countsTx, ax=ax)
            sns.stripplot(
                x='Treatment', y='tpm', alpha=0.75,
                order=order, data=countsTx, ax=ax)
            ax.set_ylabel('Transcripts Per Million (TPM)')
            ax.set_title(f'PDX{pt} - {outName} ({transcript})\n$p_{{adj}}$ {formatP(p_adj)}', loc='left')
            fig.tight_layout()
            fig.savefig(f'{pathPrefix}/plots/{outName}/transcripts/{transcript}.svg')
            plt.close(fig)

Processing PDX5
Processing PDX15
Processing PDX17
Processing PDX49
Processing PDX62
Processing PDX69
Processing PDX126


In [9]:
# For the 'Response' model, draw one gene-level TMM plot per selected gene and
# one TPM plot per transcript of that gene, with one violin per patient and
# points coloured by treatment. Figures are written under response/plots/.
defaultPlotting(width=180, ratio=0.5)
# Patient order along the x-axis
order = [49, 17, 69, 15, 5, 126, 62]
info = data['Response']
for gene in info['genesToPlot']:
    geneData = geneCounts.loc[(geneCounts['ens_gene'] == gene)]
    # Take the first matching row explicitly: calling float() directly on a
    # one-element Series is deprecated in recent pandas
    geneRow = info['diffGene'].loc[info['diffGene']['target_id'] == gene].iloc[0]
    p_adj = float(geneRow['qval'])
    symbol = geneRow['ext_gene']
    # Fall back to the gene ID when the symbol is missing (not a string)
    outName = symbol if isinstance(symbol, str) else gene
    Path(f'response/plots/{outName}/transcripts/').mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots()
    sns.violinplot(
        x='Patient', y='TMM', inner=None, color='White', cut=0,
        linewidth=2, data=geneData, order=order, ax=ax)
    sns.stripplot(
        x='Patient', y='TMM', hue='Treatment', alpha=0.75, order=order,
        data=geneData, ax=ax)
    ax.set_title(f'{outName}, $p_{{adj}}$ {formatP(p_adj)}', loc='left')
    ax.set_ylabel('Trimmed Means of M values (TMM)')
    fig.tight_layout()
    fig.savefig(f'response/plots/{outName}/Response-{outName}-GeneLevel.svg')
    # Close this specific figure so figures do not accumulate in memory
    plt.close(fig)

    allTranscripts = info['diffTx'].loc[info['diffTx']['ens_gene'] == gene, 'target_id'].unique()
    for transcript in allTranscripts:
        countsTx = info['counts'].loc[info['counts']['target_id'] == transcript]
        p_adj = float(info['diffTx'].loc[info['diffTx']['target_id'] == transcript, 'qval'].iloc[0])
        fig, ax = plt.subplots()
        sns.violinplot(
            x='Patient', y='tpm', inner=None, color='White', cut=0,
            linewidth=2, order=order, data=countsTx, ax=ax)
        sns.stripplot(
            x='Patient', y='tpm', hue='Treatment', alpha=0.75,
            order=order, data=countsTx, ax=ax)
        ax.set_ylabel('Transcripts Per Million (TPM)')
        # Include the transcript ID so each transcript plot is identifiable
        # from its title alone
        ax.set_title(f'{outName} ({transcript}), $p_{{adj}}$ {formatP(p_adj)}', loc='left')
        fig.tight_layout()
        fig.savefig(f'response/plots/{outName}/transcripts/{transcript}.svg')
        plt.close(fig)