### Import

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

In [None]:
# Raw inputs for the analysis.
#   counts.csv     -> counts_raw: first column (gene IDs) becomes the row index,
#                     remaining columns are per-sample counts.
#   conditions.csv -> conditions: per-sample table, first column as index.
#   tx2gene.tsv    -> mapping: tab-separated transcript/gene identifier table.
# NOTE(review): `conditions` is not referenced again in this notebook.
counts_raw = pd.read_csv('counts.csv', index_col=0)
conditions = pd.read_csv('conditions.csv', index_col=0)
mapping = pd.read_csv(
    'tx2gene.tsv',
    sep='\t',
)

### Exploratory Data Analysis

In [None]:
# Structural overview of the raw counts table: size, sample columns,
# column dtypes and per-column missing-value counts.
print(f"Shape of dataset: {counts_raw.shape}")
print(f"\nColumn names: {counts_raw.columns.tolist()}")
print(f"\nData types:\n {counts_raw.dtypes}")
print(f"\nMissing values:\n {counts_raw.isna().sum()}")

In [None]:
# Compare the gene identifiers available in the tx2gene mapping with those
# present in the counts index.
print(f"Shape of dataset: {mapping.shape}")
print(f"\nNumber of unique genes in mapping: {mapping['gene_id'].nunique()}")
print(f"\nNumber of unique genes in counts: {counts_raw.index.nunique()}")


In [None]:
# Per-sample descriptive statistics (count, mean, std, min, quartiles, max).
print(f"\nSummary statistics:\n {counts_raw.describe()}")

In [None]:
# Overview of zero vs. non-zero values. Every value in counts_raw is one
# (gene, sample) count, so the first three totals are numbers of ENTRIES
# across all genes x samples, not numbers of genes. The sample count is
# read from the table rather than hard-coded as 4.
n_samples = counts_raw.shape[1]
print(f"\nNumber of entries in the dataset (genes x {n_samples} samples):", counts_raw.size)
print("Number of entries with non-zero counts:", (counts_raw > 0).sum().sum())
print("Number of entries with zero counts:", (counts_raw == 0).sum().sum())
# Per-gene view: genes detected in at least one sample vs. silent everywhere.
print("Number of genes with a non-zero count in at least one sample:", (counts_raw > 0).any(axis=1).sum())
print("Number of genes with zero counts in every sample:", (counts_raw == 0).all(axis=1).sum())

In [None]:
# Number (and percentage) of zero counts per sample. Percentages are taken
# relative to the actual number of genes in the table instead of a fixed
# constant, and every sample column is reported, not only the first four.
n_genes = len(counts_raw)
for sample_no, (sample_name, sample_counts) in enumerate(counts_raw.items(), start=1):
    n_zero = (sample_counts == 0).sum()
    print(f"\nNumber of zeroes sample {sample_no}:", n_zero, "=", f"{n_zero / n_genes * 100:.2f}%")

In [None]:
# Number (and percentage) of counts < 1 per sample. Percentages are taken
# relative to the actual number of genes in the table instead of a fixed
# constant, and every sample column is reported, not only the first four.
n_genes = len(counts_raw)
for sample_no, (sample_name, sample_counts) in enumerate(counts_raw.items(), start=1):
    n_below_one = (sample_counts < 1).sum()
    print(f"\nNumber of counts < 1 sample {sample_no}:", n_below_one, "=", f"{n_below_one / n_genes * 100:.2f}%")

In [None]:
# Preview how many genes a "count > 1 in at least k samples" filter keeps,
# for k = 2 and k = 3 (the filter applied later uses k = 2).
samples_above_one = (counts_raw > 1).sum(axis=1)
print("Total number of genes:", len(counts_raw))
for min_samples in (2, 3):
    print(f"Number of genes with count > 1 in at least {min_samples} samples:",
          (samples_above_one >= min_samples).sum())


In [None]:
# Bar chart of library size (total counts) per sample. Columns named
# Sample3 / Sample4 are drawn in red, all others in the default color C0.
# NOTE(review): presumably these are the KNO3 replicates - confirm.
highlighted_samples = {'Sample3', 'Sample4'}
sample_sums = counts_raw.sum(axis=0)
colors = ['red' if name in highlighted_samples else 'C0' for name in sample_sums.index]
ax = sample_sums.plot.bar(color=colors)
ax.set_ylabel('Total Counts')
ax.set_title('Total Counts per Sample')
plt.show()


In [None]:
# Histogram of log-transformed counts per sample (log2(CPM + 1)).
# CPM divides each sample by its library size (column total) and scales to
# one million, so samples sequenced to different depths are comparable.
library_sizes = counts_raw.sum(axis=0)
cpm = counts_raw.div(library_sizes, axis=1) * 1e6
log_cpm = np.log2(cpm + 1)

# Use high-contrast colors for each sample. tab10 starts with blue, orange,
# green, red (#1f77b4, #ff7f0e, #2ca02c, #d62728). Sizing the palette to the
# number of samples means every sample gets a histogram; a fixed 4-color
# list zipped against the columns would silently skip a fifth sample onward.
contrast_colors = sns.color_palette('tab10', n_colors=log_cpm.shape[1])

plt.figure(figsize=(8, 5))
for col, color in zip(log_cpm.columns, contrast_colors):
    sns.histplot(log_cpm[col], bins=50, kde=False, label=col, color=color, alpha=0.5)
plt.xlabel('log2(CPM + 1)')
plt.ylabel('Gene count')
plt.title('Histogram of log2(CPM + 1) per sample')
plt.legend()
plt.show()

### Data Prep for DE

In [None]:
# Keep genes with a count > 1 in at least 2 samples. This drops all-zero and
# near-silent genes, not just rows of zeros. Then convert to the integer
# counts DESeq2 expects.
# Round rather than truncate: if counts.csv holds non-integer estimated counts
# (e.g. transcript-level quantification aggregated with tx2gene), astype(int)
# alone would floor every value (2.9 -> 2). For integer input this is a no-op.
keep_gene = (counts_raw > 1).sum(axis=1) >= 2
counts = counts_raw[keep_gene].round().astype(int)
counts

In [None]:
# Sanity check on the filter: shape of the filtered table, then the raw one.
for table in (counts, counts_raw):
    print(table.shape)

In [None]:
# Transpose so that samples are rows and genes are columns: the
# samples x genes orientation that is passed to DeseqDataSet below.
counts_t = counts.T
print(counts_t.shape)
counts_t.head()

In [None]:
# Sample metadata: one row per sample (index named 'sample') holding its
# experimental condition. Labels are assigned by POSITION, so this assumes the
# first two samples are KCL and the next two KNO3. zip stops at the shorter
# sequence, so any samples past the fourth get no row.
# NOTE(review): the `conditions` table loaded from conditions.csv is not used
# here; building metadata from it would avoid the order-dependent mapping.
condition_labels = ['KCL', 'KCL', 'KNO3', 'KNO3']
metadata = pd.DataFrame(
    list(zip(counts_t.index, condition_labels)),
    columns=['sample', 'condition'],
).set_index('sample')
metadata

In [None]:
# PyDESeq2 dataset: samples x genes integer counts plus metadata indexed by the
# same sample names, with "condition" as the only design factor.
# NOTE(review): newer PyDESeq2 releases replace `design_factors` with a
# `design` formula - check against the installed version.
dds = DeseqDataSet(
    metadata=metadata,
    counts=counts_t,
    design_factors="condition",
)


In [None]:
# Preliminary PCA on the unnormalised count matrix (dds.X), before DESeq2 runs.
import scanpy as sc
sc.tl.pca(dds)

# Percentage of variance captured by the first two components (axis labels).
explained_var = dds.uns['pca']['variance_ratio'] * 100
pc1_var, pc2_var = explained_var[0], explained_var[1]

# PC1 vs PC2 scatter colored by condition; show=False keeps the figure open
# so sample labels can be added before display.
sc.pl.pca(dds, color='condition', size=200, show=False)

# Label every point with its sample name, nudged off the marker.
pca_coords = dds.obsm['X_pca']
for sample, (pc1, pc2) in zip(dds.obs_names, pca_coords[:, :2]):
    plt.annotate(sample, (pc1, pc2), xytext=(5, 5), textcoords='offset points',
                 fontsize=10, alpha=0.8)

plt.xlabel(f"PC1 ({pc1_var:.1f}%)")
plt.ylabel(f"PC2 ({pc2_var:.1f}%)")
plt.show()

### DESeq2

In [None]:
# Run DESeq2 analysis: fits the PyDESeq2 model on `dds` in place (size
# factors, dispersions, log fold changes). Presumably this is also where the
# 'normed_counts' layer used further down is created - it is only read after
# this cell. The bare `dds` displays the updated object.
dds.deseq2()
dds

In [None]:
# Wald tests for the KCL vs KNO3 contrast. With contrast
# [factor, tested level, reference level], a positive log2FoldChange means
# higher expression in KCL. Statistics are computed on 8 CPUs.
stats = DeseqStats(
    dds,
    contrast=['condition', 'KCL', 'KNO3'],
    n_cpus=8,
)
stats.summary()

In [None]:
# Extract results DataFrame: one row per gene (indexed by gene_id) with columns
# such as baseMean, log2FoldChange and padj, populated by stats.summary().
res = stats.results_df
res

In [None]:
# Drop lowly expressed genes (baseMean < 10) from the results table.
# .copy() makes `res` an independent frame, so adding columns to it later
# (e.g. gene symbols) does not raise pandas' SettingWithCopyWarning.
# NOTE(review): padj was computed over the unfiltered gene set; this filter
# does not re-adjust p-values.
res = res[res['baseMean'] >= 10].copy()
res

In [None]:
# Build a gene-level (gene_id, gene_name) lookup from the transcript-level
# tx2gene table: remove the transcript column, then keep one row per distinct
# (gene_id, gene_name) pair.
mapping_df_no_transcript = mapping.drop(columns='transcript_id')
mapping_df_no_duplicates = mapping_df_no_transcript.drop_duplicates(
    subset=['gene_id', 'gene_name']
)

In [None]:
# Attach a gene symbol to every result row by matching gene_id (the index of
# `res`); IDs missing from the mapping get 'Unknown'.
# A gene_id listed with more than one gene_name would leave the lookup index
# non-unique, and Index.map cannot use a Series mapper with duplicate labels,
# so keep the first name seen for each gene_id.
gene_id_to_symbol = (
    mapping_df_no_duplicates
    .drop_duplicates(subset='gene_id')
    .set_index('gene_id')['gene_name']
)
# assign() builds a new frame, so this never writes into a view of another frame.
res = res.assign(Symbol=res.index.map(gene_id_to_symbol).fillna('Unknown'))
res

In [None]:
#  Write results to CSV file to port to R
#res.to_csv('results.csv', index=True)

In [None]:
# Significant DE genes: padj < 0.05 and |log2FoldChange| > 1. Rows with a NaN
# padj compare False and are therefore excluded.
passes_padj = res.padj < 0.05
passes_lfc = res.log2FoldChange.abs() > 1
sigs = res[passes_padj & passes_lfc]
sigs_symbols = dict(sigs['Symbol'].items())
sigs.shape

### PCA

In [None]:
# Recompute the PCA on the DESeq2-normalised counts (layer 'normed_counts'),
# overwriting the preliminary result stored in dds.uns / dds.obsm.
# NOTE(review): these counts are not log- or variance-stabilised, so highly
# expressed genes may dominate the components.
sc.tl.pca(dds, layer='normed_counts')

In [None]:
# Plot PC1 vs PC2 of the most recent PCA, colored by condition, with the
# percentage of variance explained by each component in the axis labels.
explained_var = dds.uns['pca']['variance_ratio'] * 100
pc1_var, pc2_var = explained_var[:2]

sc.pl.pca(dds, color='condition', size=200, show=False)
plt.xlabel(f"PC1 ({pc1_var:.1f}%)")
plt.ylabel(f"PC2 ({pc2_var:.1f}%)")
plt.show()

### Heatmap

In [None]:
# Add log1p transformed counts to layers: natural log of (normalised count + 1),
# stored as 'log1p' and used as the heatmap input further down.
dds.layers['log1p'] = np.log1p(dds.layers['normed_counts'])
dds

In [None]:
# Subset the DeseqDataSet to include only significant genes: all samples (rows)
# are kept, and only the gene columns whose IDs appear in `sigs`.
dds_sigs = dds[:,  sigs.index]
dds_sigs

In [None]:
# Genes x samples table of log1p-normalised counts for the significant genes
# (the layer is samples x genes, hence the transpose).
grapher = pd.DataFrame(
    dds_sigs.layers['log1p'].T,
    index=dds_sigs.var_names,
    columns=dds_sigs.obs_names,
)
grapher

In [None]:
# Clustered heatmap of all significant genes. Rows (genes) are z-scored
# (z_score=0), and a color strip above the columns marks each sample's
# condition. Conditions are assigned by column position (KCL, KCL, KNO3, KNO3).
col_colors = pd.Series(['KCL', 'KCL', 'KNO3', 'KNO3'], index=grapher.columns)
lut = {'KCL': "#D5D837", 'KNO3': "#8028D8"}
col_colors = col_colors.map(lut)

sns.clustermap(grapher,
               z_score=0,
               cmap='RdYlBu_r',
               col_colors=col_colors)

# Add a legend for the conditions: empty scatters serve only as legend handles.
# They are drawn on pyplot's current axes (whichever one seaborn left current),
# and the legend is anchored to the figure's top-right corner.
for label in lut:
    plt.scatter([], [], color=lut[label], label=label)
plt.legend(title='Condition', bbox_to_anchor=(1, 1), bbox_transform=plt.gcf().transFigure)

# Add a label "condition" on the right, above the gene symbols
# NOTE(review): axes[-1] is the last axes in the figure, assumed here to be the
# colorbar. xy=(18.1, -0.125) is a hand-tuned position in that axes' fraction
# coordinates and will drift if the figure size or layout changes.
g = plt.gcf().axes[-1]  # colorbar axis
g.annotate('Condition', xy=(18.1, -0.125), xycoords='axes fraction', ha='center', va='bottom', fontsize=12, fontweight='bold', rotation=0)

### Volcano


In [None]:
# Select the genes to highlight. Among significantly DOWNregulated genes
# (padj < 0.05 and log2FoldChange < -1), take the 10 with the lowest padj and,
# of those, the 5 with the most negative log2FoldChange. Mirror this for
# UPregulated genes (log2FoldChange > 1, keeping the 5 largest fold changes).
# The downregulated threshold must be -1: `< 1` also admitted unchanged and
# mildly upregulated genes (-1 <= LFC < 1), which are not in `sigs`
# (|LFC| > 1) and could crowd out genuinely downregulated ones.
down = res[(res['padj'] < 0.05) & (res['log2FoldChange'] < -1)]
down_10 = down.nsmallest(10, 'padj')
down_5 = down_10.nsmallest(5, 'log2FoldChange')

# Select genes with padj < 0.05 and log2FC > 1 (upregulated)
up = res[(res['padj'] < 0.05) & (res['log2FoldChange'] > 1)]
up_10 = up.nsmallest(10, 'padj')
up_5 = up_10.nlargest(5, 'log2FoldChange')

# Concatenate the results (downregulated first, then upregulated)
top_genes = pd.concat([down_5, up_5])

gene_selection = top_genes['Symbol'].tolist()
gene_selection


In [None]:
# Create volcano plot data
padj_thresh = 0.05
lfc_thresh = 1.0

volcano_data = res.copy()
# padj can underflow to exactly 0 for very strong effects, and -log10(0) is
# inf, so such a point (and its label, if highlighted) would not be drawn.
# For the y coordinate only, clip padj to the smallest non-zero value so those
# genes sit at the top of the plot. NaN padj stays NaN and is not plotted.
padj_values = volcano_data['padj']
smallest_nonzero_padj = padj_values[padj_values > 0].min()
volcano_data['-log10(padj)'] = -np.log10(padj_values.clip(lower=smallest_nonzero_padj))

# Define significance categories (from the unclipped padj)
volcano_data['Category'] = 'Not Significant'
volcano_data.loc[(volcano_data['padj'] < padj_thresh) & (volcano_data['log2FoldChange'] > lfc_thresh), 'Category'] = 'Upregulated'
volcano_data.loc[(volcano_data['padj'] < padj_thresh) & (volcano_data['log2FoldChange'] < -lfc_thresh), 'Category'] = 'Downregulated'
volcano_data.loc[(volcano_data['padj'] < padj_thresh) & (abs(volcano_data['log2FoldChange']) <= lfc_thresh), 'Category'] = 'Significant'

# Highlight top genes
highlight_genes = top_genes.index.tolist()
volcano_data['Highlight'] = volcano_data.index.isin(highlight_genes)

# Create the plot
plt.figure(figsize=(10, 8))

# Plot points by category
categories = volcano_data['Category'].unique()
colors = {'Not Significant': 'lightgray', 'Upregulated': 'red', 'Downregulated': 'blue', 'Significant': 'orange'}

for category in categories:
    subset = volcano_data[volcano_data['Category'] == category]
    plt.scatter(subset['log2FoldChange'], subset['-log10(padj)'],
                c=colors[category], alpha=0.6, s=20, label=category)

# Highlight top genes (drawn on top of their category points)
highlight_data = volcano_data[volcano_data['Highlight']]
plt.scatter(highlight_data['log2FoldChange'], highlight_data['-log10(padj)'],
            c='black', s=100, alpha=0.8, edgecolors='white', linewidth=2, label='Top Genes')

# Add gene labels for highlighted genes
for idx, row in highlight_data.iterrows():
    plt.annotate(row['Symbol'], (row['log2FoldChange'], row['-log10(padj)']),
                 xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.8)

# Add threshold lines
plt.axhline(y=-np.log10(padj_thresh), color='gray', linestyle='--', alpha=0.7)
plt.axvline(x=lfc_thresh, color='gray', linestyle='--', alpha=0.7)
plt.axvline(x=-lfc_thresh, color='gray', linestyle='--', alpha=0.7)

# Formatting
plt.xlabel('log2(Fold Change)')
plt.ylabel('-log10(adjusted p-value)')
plt.title('Volcano Plot: KCL vs KNO3')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()


In [None]:
# gene_id -> symbol for the highlighted top genes that are also in `sigs`,
# kept in `sigs` order. Match on gene_id rather than symbol: symbols are not
# unique (every unmapped gene is 'Unknown', and distinct IDs can share a name),
# so matching on symbol could pull unrelated genes into the selection.
top_gene_ids = set(top_genes.index)
gene_selection_named = {idx: symbol for idx, symbol in sigs['Symbol'].items() if idx in top_gene_ids}
gene_selection_named

### Subset Heatmap

In [None]:
# Keep all samples but only the selected top genes (the dict's gene_id keys).
dds_selected = dds[:, list(gene_selection_named)]

In [None]:
# Genes x samples table of log1p-normalised counts for the selected genes.
grapher_selected = pd.DataFrame(
    dds_selected.layers['log1p'].T,
    index=dds_selected.var_names,
    columns=dds_selected.obs_names,
)
grapher_selected

In [None]:
# Relabel rows from gene IDs to gene symbols for display.
grapher_symbol = grapher_selected.rename(gene_selection_named, axis='index')
grapher_symbol

In [None]:
# Create a heatmap for the selected genes: rows (gene symbols) are z-scored
# (z_score=0) and clustered, and a color strip marks each sample's condition.
# Conditions are assigned by column position (KCL, KCL, KNO3, KNO3).
col_colors = pd.Series(['KCL', 'KCL', 'KNO3', 'KNO3'], index=grapher_symbol.columns)
lut = {'KCL': "#D5D837", 'KNO3': "#8028D8"}
col_colors = col_colors.map(lut)

sns.clustermap(grapher_symbol,
               z_score=0,
               cmap='RdYlBu_r',
               col_colors=col_colors)

# Add a legend for the conditions: empty scatters serve only as legend handles.
# They are drawn on pyplot's current axes (whichever one seaborn left current),
# and the legend is anchored to the figure's top-right corner.
for label in lut:
    plt.scatter([], [], color=lut[label], label=label)
plt.legend(title='Condition', bbox_to_anchor=(1, 1), bbox_transform=plt.gcf().transFigure)

# Add a label "condition" on the right, above the gene symbols
# NOTE(review): axes[-1] is the last axes in the figure, assumed here to be the
# colorbar. xy=(18.1, -0.125) is a hand-tuned position in that axes' fraction
# coordinates and will drift if the figure size or layout changes.
g = plt.gcf().axes[-1]  # colorbar axis
g.annotate('Condition', xy=(18.1, -0.125), xycoords='axes fraction', ha='center', va='bottom', fontsize=12, fontweight='bold', rotation=0)

In [None]:
# DESeq2-normalised (untransformed) counts for the selected genes, as a
# genes x samples table whose rows are relabelled with gene symbols.
grapher_counts = pd.DataFrame(
    dds_selected.layers['normed_counts'].T,
    index=dds_selected.var_names,
    columns=dds_selected.obs_names,
)
grapher_symbol_counts = grapher_counts.rename(gene_selection_named, axis='index')
grapher_symbol_counts

In [None]:
# Grouped bar charts of normalised counts for the selected genes, split into
# upregulated (log2FoldChange > 0, i.e. higher in KCL for the KCL-vs-KNO3
# contrast) and downregulated genes. Each gene gets one bar per sample,
# colored by that sample's condition.
selected_genes_data = grapher_symbol_counts.T  # samples as rows, genes as columns

# Separate genes into up and downregulated based on log2FoldChange from res
upregulated_genes = []
downregulated_genes = []
for gene_id, symbol in gene_selection_named.items():
    if res.loc[gene_id, 'log2FoldChange'] > 0:
        upregulated_genes.append(symbol)
    else:
        downregulated_genes.append(symbol)

# Create subplots for up and downregulated genes
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

# Condition per sample, by column position (same order as the metadata cell).
condition_colors = {'KCL': "#E93A3A", 'KNO3': "#3450DF"}
sample_conditions = ['KCL', 'KCL', 'KNO3', 'KNO3']
colors = [condition_colors[cond] for cond in sample_conditions]

# Bar width shared by both panels. Defining it only in the upregulated branch
# made the downregulated panel fail with NameError when no gene was upregulated.
width = 0.2


def _plot_gene_group(ax, genes, title):
    """Draw a group of per-sample bars for each gene in `genes` on `ax`."""
    group_data = selected_genes_data[genes]
    x = np.arange(len(genes))
    # zip stops at the shorter sequence, so only samples with a color are drawn.
    samples_and_colors = list(zip(group_data.index, colors))
    n_bars = len(samples_and_colors)
    for i, (sample, color) in enumerate(samples_and_colors):
        # Centre each group on its tick: offsets are symmetric around 0
        # (-1.5w, -0.5w, 0.5w, 1.5w for four samples).
        offset = (i - (n_bars - 1) / 2) * width
        ax.bar(x + offset, group_data.loc[sample], width,
               label=sample, color=color, alpha=0.8)
    ax.set_ylabel('Normalized Counts')
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(genes, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3)


if upregulated_genes:
    _plot_gene_group(ax1, upregulated_genes, 'Expression Levels of Upregulated Genes')

if downregulated_genes:
    _plot_gene_group(ax2, downregulated_genes, 'Expression Levels of Downregulated Genes')

plt.tight_layout()
plt.show()