In [1]:
# IPython.display is the public location of display/HTML; importing them from
# IPython.core.display is deprecated (since IPython 7.14).
from IPython.display import display, HTML
# Inject CSS so the notebook container uses the full browser width.
display(HTML('<style>.container { width:100% !important; }</style>'))

In [107]:
import os
import glob
import gzip
import gffutils
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pysam 

from collections import defaultdict
from plotly.subplots import make_subplots


# S288C reference sequence FASTA, opened through pysam.FastaFile.
# NOTE(review): `ref_fasta` is not used by the cells visible here; confirm it
# is needed before relying on it (pysam expects a matching .fai index).
ref_fasta_file = ('../rna_seq_alignment/S288C_reference/'
                  'S288C_reference_sequence_chr_renamed.fsa')
ref_fasta = pysam.FastaFile(ref_fasta_file)


### Summarize differential expression 
__Using the provided sample-level counts (RNAseq_counts) and sample sheet (calico_data_challenge_samplesheet.tsv), answer the following questions:__
* Which genes are differentially expressed with respect to:
    * Age
    * Genotype
    * Genotype-specific aging (i.e. an age-by-genotype interaction)
* What do these genes have in common?
* What is the largest factor affecting variability in this dataset?
* If you had more time, what aspect of this dataset would you explore to further address the motivating question, and what caveats may exist?


In [11]:
# Build an in-memory gffutils annotation database from the
# Saccharomyces_cerevisiae.R64-1-1.99 GTF. Gene and transcript features are
# not inferred (presumably because the GTF already contains them — verify).
# Gene sizes are derived from this database in a later cell.
annot_db = gffutils.create_db(
    '../rna_seq_alignment/Saccharomyces_cerevisiae.R64-1-1.99.gtf',
    dbfn=':memory:',
    disable_infer_genes=True,
    disable_infer_transcripts=True,
    merge_strategy='merge',
)


In [12]:
def read_rnaseq_counts(infile):
    """Read one per-sample counts file into a single-column DataFrame.

    The file is a headerless, tab-separated ``gene<TAB>count`` table. The
    sample (column) name is the file's base name with a trailing ``.counts``
    removed (left as-is if the name has no such suffix). The non-gene summary
    rows (``__no_feature``, ``__ambiguous``, ... — the counters htseq-count
    writes) are dropped when present.

    Args:
        infile: path to the counts file.

    Returns:
        DataFrame indexed by ``gene`` with a single column named after the
        sample.
    """
    file_name = os.path.basename(infile)
    suffix = '.counts'
    # Only strip the suffix when it is actually there; blindly slicing would
    # mangle names of files with any other extension.
    col_name = file_name[:-len(suffix)] if file_name.endswith(suffix) else file_name
    df = pd.read_csv(infile, sep='\t', header=None).rename(columns={0: 'gene', 1: col_name})
    # Drop non-applicable rows; a file lacking some of them is not an error.
    labels_to_drop = ['__no_feature', '__ambiguous', '__too_low_aQual',
                      '__not_aligned', '__alignment_not_unique']
    return df.set_index('gene').drop(labels_to_drop, axis=0, errors='ignore')



### Read RNA-seq counts data

In [66]:
# Directory with one `<sample>.counts` file per library.
# NOTE(review): the default is an absolute path on the author's machine; set
# the RNA_COUNTS_DIR environment variable to run this notebook elsewhere.
RNA_COUNTS_DIR = os.environ.get(
    'RNA_COUNTS_DIR',
    '/Users/jganbat/Documents/jupyter_notebooks/rna_seq_analysis/rna_counts')

# Only pick up `*.counts` files: the sample name is derived by stripping that
# suffix, so any other file in the directory would become a bogus column.
rnaseq_counts_files = sorted(glob.glob(os.path.join(RNA_COUNTS_DIR, '*.counts')))
if not rnaseq_counts_files:
    raise FileNotFoundError('No *.counts files found in {}'.format(RNA_COUNTS_DIR))

# One column per sample, rows aligned on gene ID.
combined_counts = pd.concat(
    [read_rnaseq_counts(counts_path) for counts_path in rnaseq_counts_files],
    axis=1,
).sort_index()

### Collect Sample grouping information

In [18]:
# Group sample columns by their underscore-delimited name fields:
# field 0 -> strain, field 1 -> age.
strain_group = defaultdict(list)
age_group = defaultdict(list)
for sample in combined_counts.columns:
    fields = sample.split('_')
    strain_group[fields[0]].append(sample)
    age_group[fields[1]].append(sample)

# Flatten each grouping into a single column order; groups appear in the
# order they were first encountered in combined_counts.columns.
age_ordered_cols = [col for members in age_group.values() for col in members]
strain_ordered_cols = [col for members in strain_group.values() for col in members]

In [19]:
# Gene length for every counted gene, used for length normalization.
# GTF coordinates are 1-based and inclusive, so a feature covers
# stop - start + 1 bases.
gene_size_list = []
for gene in combined_counts.index.tolist():
    # NOTE(review): this takes the *parents* of `gene` in annot_db. A gene with
    # several parents yields duplicate rows, and one with none is missing from
    # gene_size_df — confirm this matches the annotation.
    for feature in annot_db.parents(gene):
        gene_size_list.append([gene, feature.stop - feature.start + 1])

gene_size_df = pd.DataFrame.from_records(gene_size_list, columns=['gene', 'gene_size']).set_index('gene').sort_index()

In [15]:
# Normalize based on gene size
# Normalize based on gene size: divide each gene's counts by its length.
# NOTE: this rebinds `combined_counts` to the length-normalized table, so
# re-running this cell without re-loading the counts divides by length twice.
if not gene_size_df.index.is_unique:
    raise ValueError('gene_size_df has duplicate genes: {}'.format(
        sorted(set(gene_size_df.index[gene_size_df.index.duplicated()]))))
gene_sizes = gene_size_df['gene_size'].reindex(combined_counts.index)
missing_genes = gene_sizes.index[gene_sizes.isna()].tolist()
if missing_genes:
    raise KeyError('No gene size for: {}'.format(missing_genes))
combined_counts = combined_counts.div(gene_sizes, axis=0)
# Normalize counts by total sum: every sample (column) then sums to 1.
norm_df = combined_counts / combined_counts.sum(axis=0)

In [149]:
def plot_heatmap(df):
    """Build a heatmap trace with every row scaled to its own maximum.

    Scaling each row by its maximum puts all rows on a 0-1 scale (for
    non-negative input, presumably counts), so the pattern across columns is
    comparable between rows regardless of absolute magnitude.

    Args:
        df: DataFrame to plot; columns go on the x axis, index on the y axis.

    Returns:
        A ``plotly.graph_objects.Heatmap`` trace (not yet added to a figure).
        Rows whose maximum is 0 become all-NaN after scaling.
    """
    # Series.max() skips NaN; the builtin max() on a Series is order-dependent
    # when the row contains NaN.
    row_max = df.max(axis=1)
    scaled = df.div(row_max, axis=0)
    return go.Heatmap(x=scaled.columns, y=scaled.index, z=scaled)


def plot_deg_gene_matrix(combined_counts_df, col_order, deg_file, fold_change_threshold=3,
                         fdr_threshold=None, fdr_column='FDR'):
    """Show heatmaps of the differentially expressed genes from an edgeR table.

    ``deg_file`` is a tab-separated table with a ``genes`` column and a
    ``logFC`` column (plus ``fdr_column`` when ``fdr_threshold`` is given). Its
    file name must have the form ``<group1>_vs_<group2>.<ext>``; the two group
    names label the panels. Genes with ``logFC <= -fold_change_threshold`` go
    in the left panel ("Hyper expressed in <group1>") and genes with
    ``logFC >= fold_change_threshold`` in the right panel ("Hyper expressed in
    <group2>").

    NOTE(review): the panel labels assume a positive logFC means higher in
    group2 — confirm against the edgeR contrast that produced ``deg_file``.

    Args:
        combined_counts_df: expression table (genes x samples) whose values
            are plotted.
        col_order: sample columns to show, in display order.
        deg_file: path to the edgeR results table.
        fold_change_threshold: cut-off on the ``logFC`` column. edgeR reports
            log2 fold changes, so the default 3 corresponds to an 8-fold change.
        fdr_threshold: optional; when given, also keep only genes whose
            ``fdr_column`` value is <= this threshold. The default None applies
            the logFC cut-off alone.
        fdr_column: name of the adjusted p-value column in ``deg_file``.
    """
    # Diff Exp Genes from edgeR
    deg_df = pd.read_csv(deg_file, sep='\t').set_index('genes')
    if fdr_threshold is not None:
        deg_df = deg_df[deg_df[fdr_column] <= fdr_threshold]
    up_deg_df = deg_df[deg_df['logFC'] >= fold_change_threshold].sort_values(by='logFC', ascending=False)
    down_deg_df = deg_df[deg_df['logFC'] <= -fold_change_threshold].sort_values(by='logFC')

    # Take the group names from the file name only, so a directory prefix in
    # deg_file does not end up in the panel titles.
    comparison = os.path.splitext(os.path.basename(deg_file))[0]
    grp1_name, grp2_name = comparison.split('_vs_')
    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=('Hyper expressed in {}'.format(grp1_name),
                                        'Hyper expressed in {}'.format(grp2_name)))

    # Plot the table passed in, not the notebook-global `combined_counts`.
    fig.add_trace(plot_heatmap(combined_counts_df.loc[down_deg_df.index, col_order]), row=1, col=1)
    fig.add_trace(plot_heatmap(combined_counts_df.loc[up_deg_df.index, col_order]), row=1, col=2)
    fig.update_layout(height=600, width=1200)
    fig.show()


In [150]:
# edgeR result tables for each pairwise age comparison. Each is plotted with
# a logFC cut-off of 2 (presumably log2, i.e. 4-fold).
AGE_COMPARISON_FILES = [
    'age_0_vs_age_20.tsv',
    'age_0_vs_age_40.tsv',
    'age_20_vs_age_40.tsv',
]
for deg_file in AGE_COMPARISON_FILES:
    plot_deg_gene_matrix(combined_counts, age_ordered_cols, deg_file, fold_change_threshold=2)