### Programming for Biomedical Informatics
#### Week 6 - Differential Gene Expression Analysis

We're going to perform differential expression analysis with the PyDESeq2 package on an RNA-Seq dataset from NCBI-GEO.

In [None]:
'''
Sources of Data

Original Publication
Tomaiuolo P, Piras IS, Sain SB, Picinelli C, Baccarin M, Castronovo P, Morelli MJ, Lazarevic D, Scattoni ML, Tonon G, Persico AM.
RNA sequencing of blood from sex- and age-matched discordant siblings supports immune and transcriptional dysregulation in autism spectrum disorder.
Sci Rep. 2023 Jan 16;13(1):807. doi: 10.1038/s41598-023-27378-w. PMID: 36646776; PMCID: PMC9842630.

GEO Entry: GSE212645
https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE212645

meta-data file
https://ftp.ncbi.nlm.nih.gov/geo/series/GSE212nnn/GSE212645/matrix/GSE212645_series_matrix.txt.gz

raw data file
https://www.ncbi.nlm.nih.gov/geo/download/?type=rnaseq_counts&acc=GSE212645&format=file&file=GSE212645_filtered_counts_GRCh38.p13_NCBI.tsv.gz

normalised data file
https://www.ncbi.nlm.nih.gov/geo/download/?type=rnaseq_counts&acc=GSE212645&format=file&file=GSE212645_norm_counts_FPKM_GRCh38.p13_NCBI.tsv.gz

genome annotation file
https://www.ncbi.nlm.nih.gov/geo/download/?format=file&type=rnaseq_counts&file=Human.GRCh38.p13.annot.tsv.gz
'''

In [None]:
# fetch the experimental data and meta-data

#setup
import urllib.request
import os
import pandas as pd
import matplotlib.pyplot as plt

# Load the count matrix and the sample meta-data from local copies in ./data.
# NOTE(review): the commented-out download below saves
# 'GSE212645_filtered_counts_...' while the read opens 'GSE212645_raw_counts_...';
# confirm which file name is correct before re-enabling the download.

# #fetch the count data
# counts_url = 'https://www.ncbi.nlm.nih.gov/geo/download/?type=rnaseq_counts&acc=GSE212645&format=file&file=GSE212645_filtered_counts_GRCh38.p13_NCBI.tsv.gz'
# urllib.request.urlretrieve(counts_url, './data/GSE212645_filtered_counts_GRCh38.p13_NCBI.tsv.gz')
# first column of the file becomes the row index; remaining columns are used as samples below
raw_counts = pd.read_csv('./data/GSE212645_raw_counts_GRCh38.p13_NCBI.tsv.gz', sep='\t', index_col=0)

# #fetch the meta-data
# metadata_url = 'https://ftp.ncbi.nlm.nih.gov/geo/series/GSE212nnn/GSE212645/matrix/GSE212645_series_matrix.txt.gz'
# urllib.request.urlretrieve(metadata_url, './data/GSE212645_series_matrix.txt.gz')
# read in the meta-data
# we need to skip the first 38 rows as they contain project rather than sample meta-data
series_matrix = pd.read_csv('./data/GSE212645_series_matrix.txt.gz', sep='\t', skiprows=38, header=None)

# we now tidy this up and retain the information we need
# keep the rows we need; .copy() gives an independent frame so the assignment
# below does not write into a slice of the parsed file (SettingWithCopyWarning)
row_numbers = [0,8,10,11,12]
metadata = series_matrix.iloc[row_numbers].copy()

# replace column 0 values (the row labels of the series matrix) with readable names
new_feature_names = ['number','gender', 'status', 'family', 'treatment']
metadata.iloc[:, 0] = new_feature_names

# promote the first row to column names, drop it, and index the rows by feature name
metadata.columns = metadata.iloc[0]
metadata = metadata.iloc[1:].set_index('number')

# transpose so that each row is a sample and each column a feature
metadata = metadata.T

# tidy up the column contents by stripping the literal 'key: ' text from each value
value_prefixes = {
    'gender': 'Sex: ',
    'status': 'genotype: ',
    'family': 'family: ',
    'treatment': 'treatment: ',
}
for column, prefix in value_prefixes.items():
    # regex=False: the prefixes are plain text, not patterns
    metadata[column] = metadata[column].str.replace(prefix, '', regex=False)

# the row index holds the sample identifiers; name it sample_id
metadata.index.name = 'sample_id'

In [None]:
# Minimum count that every sample must reach for a gene to be kept
MIN_COUNT = 10

# Keep only genes where all samples have counts >= MIN_COUNT
passes_filter = (raw_counts >= MIN_COUNT).all(axis=1)
filtered_counts = raw_counts[passes_filter]

# The removed genes are exactly the ones that fail the filter above, i.e. genes
# with at least one sample below MIN_COUNT. (Selecting rows where every sample
# is <= 10 would miss genes with mixed counts and would also include genes that
# were kept because all their counts equal 10.)
removed_counts = raw_counts[~passes_filter]

# mean expression for each removed gene, highest first
removed_means = removed_counts.mean(axis=1).sort_values(ascending=False)

#keep this for later

In [None]:
# look at the count data: print its (rows, columns) shape, then preview the first rows
# (rows are genes and columns are samples, as used by the filtering and plotting cells)
print(filtered_counts.shape)
filtered_counts.head()

In [None]:
# look at the meta-data: print its (rows, columns) shape, then preview the first rows
# (one row per sample, indexed by sample_id)
print(metadata.shape)
metadata.head()

In [None]:
# per-sample boxplots of the filtered counts, drawn on a log-scaled y axis
fig, ax = plt.subplots(figsize=(10, 6))

# one box per column (sample); rotate the sample labels so they stay readable
filtered_counts.boxplot(rot=90, ax=ax)
ax.set_yscale('log')
ax.set_title('Raw Gene Counts by Sample')
ax.set_ylabel('log(Raw Counts)')
plt.show()

In [None]:
# build an MA plot of the (filtered) raw counts from scratch

import numpy as np

# sample ids (metadata index) whose status is 'ASD'; used to select count columns
asd_samples = metadata.index[metadata['status'] == 'ASD']

# sample ids (metadata index) whose status is 'SIB'
sib_samples = metadata.index[metadata['status'] == 'SIB']

# per-gene mean count across the samples of each group
asd_mean = filtered_counts[asd_samples].mean(axis=1)
sib_mean = filtered_counts[sib_samples].mean(axis=1)

# one row per gene (gene ids as the index), one column per group mean
ma_data = pd.DataFrame({'ASD': asd_mean, 'SIB': sib_mean})

# log2 of each group mean, computed once and reused for both M and A
log2_asd = np.log2(ma_data['ASD'])
log2_sib = np.log2(ma_data['SIB'])

# M: log2 ratio of the two group means
ma_data['M'] = log2_asd - log2_sib

# A: average of the two log2 group means
ma_data['A'] = (log2_asd + log2_sib) / 2

# scatter M against A, one small point per gene
fig, ax = plt.subplots()
ax.scatter(ma_data['A'], ma_data['M'], s=1)
ax.set_xlabel('A')
ax.set_ylabel('M')
ax.set_title('MA Plot of Raw Counts')
plt.show()

In [None]:
# Simple PCA of the samples based on their filtered counts

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# PCA expects observations as rows: samples become rows, genes become features
transposed_filtered_counts = filtered_counts.T

# standardise every feature, then project onto the first two principal components
X = transposed_filtered_counts.values
X_scaled = StandardScaler().fit_transform(X)
pca = PCA(n_components=2)
pcs = pca.fit_transform(X_scaled)

# label the components, keep the sample ids as index and attach each sample's status
sample_pca = pd.DataFrame(
    pcs,
    index=transposed_filtered_counts.index,
    columns=['PC1', 'PC2'],
).join(metadata['status'])

#visualise
sample_pca.index.name = 'sample_id'

print(sample_pca)

In [None]:
# scatter plot of the first two principal components, coloured by status
from matplotlib.patches import Patch

# give every distinct status its own colour from the 'Set1' palette
statuses = list(sample_pca['status'].unique())
cmap = plt.get_cmap('Set1')
color_map = {}
for position, status in enumerate(statuses):
    # wrap around if there are more statuses than palette entries
    color_map[status] = cmap(position % cmap.N)
colors = sample_pca['status'].map(color_map)

# one point per sample
fig, ax = plt.subplots(figsize=(7,5))
sc = ax.scatter(
    sample_pca['PC1'],
    sample_pca['PC2'],
    s=50,
    edgecolor='k',
    c=list(colors),
)

# legend built from one coloured patch per status
legend_elements = []
for label, col in color_map.items():
    legend_elements.append(Patch(facecolor=col, edgecolor='k', label=label))
ax.legend(handles=legend_elements, title='status')

# axis labels report the share of variance explained by each component
pc1_share = pca.explained_variance_ratio_[0]*100
pc2_share = pca.explained_variance_ratio_[1]*100
ax.set_xlabel(f"PC1 ({pc1_share:.1f}%)")
ax.set_ylabel(f"PC2 ({pc2_share:.1f}%)")
ax.set_title('PCA')
plt.tight_layout()
plt.show()

In [None]:
# drop the outlier samples from both the counts (columns) and the meta-data (rows)
samples_to_remove = ['GSM6542253', 'GSM6542277']

# errors='ignore': ids that are not present are skipped instead of raising
filtered_counts_good_samples = filtered_counts.drop(columns=samples_to_remove, errors='ignore')
metadata_good_samples = metadata.drop(index=samples_to_remove, errors='ignore')

# report the remaining dimensions: counts first, then meta-data
for remaining in (filtered_counts_good_samples, metadata_good_samples):
    print(remaining.shape)

In [None]:
# now lets use DESeq2 to perform differential expression

import os
import pickle as pkl

from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
from pydeseq2.utils import *

# Switch controlling whether this notebook writes its outputs to disk.
# OUTPUT_PATH is only defined when SAVE is True, so every later use of
# OUTPUT_PATH must sit behind an `if SAVE:` guard.
SAVE = False  # whether to save the outputs of this notebook

if SAVE:
    # Replace this with the path to directory where you would like results to be
    # saved
    OUTPUT_PATH = "./data/asd_deseq2_results"
    os.makedirs(OUTPUT_PATH, exist_ok=True)  # Create path if it doesn't exist

In [None]:
## set up the DESeq2 object
## Note we are setting up a two-factor (paired analysis) where disease status is assessed in sibling pairs affected:unaffected

# number of CPUs handed to PyDESeq2's inference routines
N_CPUS = 8

# samples as rows, genes as columns
counts_by_sample = filtered_counts_good_samples.T

# Put the meta-data rows in exactly the same sample order as the count rows so
# the two tables cannot be silently mismatched; a sample that has counts but no
# meta-data raises a KeyError here.
# NOTE(review): DeseqDataSet is expected to require matching sample indexes
# for counts and metadata - confirm against the PyDESeq2 documentation.
metadata_by_sample = metadata_good_samples.loc[counts_by_sample.index]

inference = DefaultInference(n_cpus=N_CPUS)
dds = DeseqDataSet(
    counts=counts_by_sample,
    metadata=metadata_by_sample,
    design="~status+family",
    refit_cooks=True,
    inference=inference,
)

#ignore the error, this is just because the gene_ids are numbers

print(dds)

In [None]:
# compute normalisation factors
dds.fit_size_factors()

# extract the per-sample size factors to look at: name the Series 'sizefactor'
# and sort by value, largest first. sort_values() returns a new Series, so its
# result has to be kept; a Series has no `.columns`, so it is renamed via
# rename() instead.
size_factors = (
    dds.obs["size_factors"]
    .rename('sizefactor')
    .sort_values(ascending=False)
)

size_factors

In [None]:
# fit gene-wise dispersion estimates (results are stored on the dds object)
dds.fit_genewise_dispersions()

# fit dispersion priors; run after the gene-wise estimates above
dds.fit_dispersion_prior()

In [None]:
# have a look at the DESeq2 object after the fitting, to see what the fitting steps added
print(dds)

In [None]:
# compare the fitted dispersion with the gene-wise dispersion, one point per gene
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(
    dds.var['genewise_dispersions'],
    dds.var['fitted_dispersions'],
)
ax.set_xlabel('Gene-wise Dispersion')
ax.set_ylabel('Fitted Dispersion')
ax.set_title('Fitted Dispersion vs Gene-wise Dispersion')
plt.show()

# OK this looks very problematic, remember fitted vs gene dispersion should be a linear relationship - this is a red flag

In [None]:
# fit log fold changes and keep a handle on the fitted values for plotting below
dds.fit_LFC()
# the fitted values are read back from the dds object's varm["LFC"] slot
lfcs = dds.varm["LFC"]

In [None]:
# Distribution of the fitted log fold changes, drawn with seaborn
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# NOTE(review): in PyDESeq2 varm["LFC"] is expected to be a DataFrame with one
# column per design coefficient (intercept, status, every family level) -
# confirm. Flattening all of them would mix intercepts and family effects into
# the "log fold change" distribution, so only the status coefficient(s) are
# kept. If no column name starts with 'status' (or lfcs is a plain array) every
# value is used, as before.
if isinstance(lfcs, pd.DataFrame):
    status_columns = [col for col in lfcs.columns if str(col).startswith('status')]
    lfc_values = lfcs[status_columns] if status_columns else lfcs
else:
    lfc_values = lfcs
lfcs_arr = np.asarray(lfc_values, dtype=float).ravel()

# compute robust x-limits (1st to 99th percentiles) to avoid extreme outliers
vmin = np.nanpercentile(lfcs_arr, 1)
vmax = np.nanpercentile(lfcs_arr, 99)

# set the theme before creating the figure so it applies to the whole figure
sns.set_theme(style='whitegrid')
plt.figure(figsize=(10,6))

# histogram with KDE overlay; only the view is limited to the percentile range,
# the statistics below still use every value
sns.histplot(lfcs_arr, bins=200, kde=True, stat='density', color='C0')
plt.xlim(vmin, vmax)

# annotate mean and median (NaN values are ignored)
mean_val = np.nanmean(lfcs_arr)
median_val = np.nanmedian(lfcs_arr)
plt.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean = {mean_val:.3f}')
plt.axvline(median_val, color='green', linestyle='-.', linewidth=2, label=f'Median = {median_val:.3f}')

plt.xlabel('Log Fold Change')
plt.ylabel('Density')
plt.title('Distribution of Log Fold Changes (clipped 1st-99th percentiles)')
plt.legend()

# add a small text box with summary stats in the top-right corner of the axes
props = dict(boxstyle='round', facecolor='white', alpha=0.8)
plt.text(0.95, 0.95, f'n={lfcs_arr.size}\nmean={mean_val:.3f}\nmedian={median_val:.3f}', transform=plt.gca().transAxes, fontsize=10, verticalalignment='top', horizontalalignment='right', bbox=props)

plt.tight_layout()
plt.show()

In [None]:
# calculate Cook's distances and refit

# this step aims to identify outliers that would adversely affect the differential expression analysis and filters them out
dds.calculate_cooks()
# refit_cooks was set to True when the dds object was created
if dds.refit_cooks:
    # Replace outlier counts
    dds.refit()

In [None]:
# let's look at the dispersion plot produced by the dds object itself
dds.plot_dispersions()

# this looks OK, but is very flat: an unusually low amount of dispersion, especially at the high-count end

In [None]:
# save the results so far (only when SAVE is True; OUTPUT_PATH is defined under the same flag)
# the whole dds object is pickled; only unpickle such files from sources you trust
if SAVE:
    with open(os.path.join(OUTPUT_PATH, "dds_detailed_pipe.pkl"), "wb") as f:
        pkl.dump(dds, f)

In [None]:
# this is the main step where the differential expression is calculated
# the contrast compares status 'ASD' against status 'SIB', reusing the inference object created earlier
stat_res = DeseqStats(dds, contrast=['status','ASD','SIB'], inference=inference)

In [None]:
# create the summary stats
# runs the Wald test
# this test effectively assesses the robustness of the beta value estimation in the main DESeq2 GLM
# and then calculates the p-values based on the assumed normal distribution of the beta values

stat_res.summary()

In [None]:
# results table produced by the summary step
results = stat_res.results_df

# order the genes by p-value, smallest first
ranked = results.sort_values(by='pvalue', ascending=True)

# convert the gene identifiers in the index to integers ...
ranked.index = ranked.index.map(int)

# ... and move them out of the index into an ordinary column
sorted_results = ranked.reset_index()

sorted_results.head()

# it's clear here that the p-value adjustment is very heavily penalising the p-values

In [None]:
# optionally save the results
# if SAVE:
#     with open(os.path.join(OUTPUT_PATH, "stat_results_detailed_pipe.pkl"), "wb") as f:
#         pkl.dump(stat_res, f)

In [None]:
# load up the genome annotation file so that we can look at the gene names
# (urllib.request, os and pandas are imported in the setup cell at the top)

annotation_url = 'https://www.ncbi.nlm.nih.gov/geo/download/?format=file&type=rnaseq_counts&file=Human.GRCh38.p13.annot.tsv.gz'
annotation_path = './data/Human.GRCh38.p13.annot.tsv.gz'

# download the file into the ./data directory only when no local copy exists,
# so re-running the notebook does not fetch it again
if not os.path.exists(annotation_path):
    os.makedirs(os.path.dirname(annotation_path), exist_ok=True)
    # download to a temporary name first so an interrupted transfer is never
    # mistaken for a complete file on the next run
    partial_path = annotation_path + '.part'
    urllib.request.urlretrieve(annotation_url, partial_path)
    os.replace(partial_path, annotation_path)

# read into a data frame, using the first column as the index
annotation = pd.read_csv(annotation_path, sep='\t', index_col=0, low_memory=False)

#drop all columns except Symbol and Description
annotation = annotation[['Symbol', 'Description']]

annotation.head()

In [None]:
# attach Symbol and Description to every result row by matching on GeneID
results = pd.merge(
    sorted_results,
    annotation,
    left_on='GeneID',
    right_on='GeneID',
)

results.head(20)

In [None]:
# pull out the result rows for the genes cited in the paper to inspect their p-values
target_genes = ['EGR1', 'EGR2', 'IGKV6D-21', 'IGKV3D-15', 'S100B', 'CD80']

# boolean mask: True where the row's Symbol is one of the target genes
is_target_gene = results['Symbol'].isin(target_genes)
target_gene_results = results[is_target_gene]

target_gene_results

In [None]:
# bring in the per-group mean counts (and M/A values) computed for the MA plot:
# the GeneID column of the target gene results is matched against the index of ma_data
combined = pd.merge(
    target_gene_results,
    ma_data,
    left_on='GeneID',
    right_index=True,
)

combined.head(10)

In [None]:
# only target genes that are present in `combined` (i.e. survived the earlier QC) can be displayed
target_gene_set = set(target_genes)
symbols_in_combined = set(combined['Symbol'])

# target genes with data available / without data available
surviving_target_genes = target_gene_set.intersection(symbols_in_combined)
missing_genes = target_gene_set.difference(symbols_in_combined)

In [None]:
# out of interest: mean raw counts of the target genes that were excluded earlier

# annotation rows whose Symbol is one of the missing genes
is_missing_gene = annotation['Symbol'].isin(list(missing_genes))
missing_gene_info = annotation[is_missing_gene]

# attach the raw counts of those genes by matching on GeneID
missing_gene_counts = missing_gene_info.merge(raw_counts, on='GeneID')

# the first two columns are Symbol and Description; everything after them is count data
count_columns = missing_gene_counts.iloc[:, 2:]
missing_gene_counts['Mean'] = count_columns.mean(axis=1)

# Display the results
print(missing_gene_counts[['Symbol', 'Mean']])

In [None]:
# next we plot paired barcharts for each Symbol showing the log2 mean counts for each gene in the ASD and SIB samples using matplotlib
# NOTE(review): the values come from ma_data, which was built from filtered_counts;
# confirm that 'log normalised counts' is the intended axis label.

# sort the set so the panels appear in the same order on every run
genes_to_plot = sorted(surviving_target_genes)

if genes_to_plot:
    # squeeze=False always returns a 2-D array of axes, so indexing also works
    # when only a single gene survived the filtering
    fig, axes = plt.subplots(1, len(genes_to_plot), figsize=(20, 6), squeeze=False)

    for gene_ax, gene in zip(axes[0], genes_to_plot):
        # first matching row of the group means for this gene
        group_means = combined.loc[combined['Symbol'] == gene, ['ASD', 'SIB']].iloc[0]
        # take log2 on a local copy: `combined` is left untouched, so re-running
        # this cell does not apply the logarithm a second time
        log2_means = np.log2(group_means.astype(float).to_numpy())
        gene_ax.bar(['ASD', 'SIB'], log2_means)
        gene_ax.set_title(gene)
        gene_ax.set_ylabel('log normalised counts')
        gene_ax.set_xlabel('Status')

    plt.show()
else:
    # plt.subplots cannot create a grid with zero columns
    print('None of the target genes passed the count filter - nothing to plot.')

# the trends are the same as in the paper, but the p-values are not significant - NB in the paper only EGR1 and IGKV3D-15 were significant after correction.