[Thanks to Tiffany Tang for the original R version of this]

For our exploration of PCA, we will look at gene expression data from The Cancer Genome Atlas (TCGA) for patients with breast cancer (BRCA). The genomic basis of breast cancer has been studied extensively in the scientific literature. In particular, scientists have classified breast cancer occurrences into four subtypes, each with its own defining characteristics and clinical implications (The Cancer Genome Atlas Research Network 2012).

Below, I have gathered the TCGA BRCA gene expression data for 244 patients, along with their cancer subtype and survival status. With 17814 genes in this dataset, we cannot feasibly visualize every marginal distribution or every pair of features, but dimension reduction via PCA may provide a good starting point for visualization.
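
To get a feel for what PCA can return when there are far more features than samples, here is a minimal sketch on synthetic data. The sizes and the names `X_demo`, `pca_demo`, and `scores_demo` are made up for illustration and are not part of the BRCA analysis. With `n` samples and `p > n` features, PCA yields at most `n` components, and after centering the last one carries essentially no variance.

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic "wide" data: few samples, many features (hypothetical sizes).
rng = np.random.default_rng(0)
n_samples, n_features = 20, 500
X_demo = rng.normal(size=(n_samples, n_features))

# With n_components left unset, PCA keeps min(n_samples, n_features) components.
pca_demo = PCA().fit(X_demo)
scores_demo = pca_demo.transform(X_demo)

print("Scores shape:", scores_demo.shape)
print("Variance ratio of last PC:", pca_demo.explained_variance_ratio_[-1])
```

By the same reasoning, the BRCA data with 244 patients can yield at most 244 principal components, no matter how many genes are measured.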

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy.linalg import svd
from time import time

In [None]:
# load .Rdata file
from pyreadr import read_r
brca_data = read_r('data/tcga_brca.Rdata')

In [None]:
X = brca_data['X']  # gene expression matrix (patients x genes)
Y = brca_data['Y']  # patient-level labels (subtype, survival status)
# X = StandardScaler().fit_transform(X)  # optional: standardize each gene

# PCA using sklearn
pca = PCA(n_components=0.95)  # retain 95% variance
X_pca = pca.fit_transform(X)

In [None]:
#| label: fig1
# Create a DataFrame for scores
brca_scores = pd.DataFrame(X_pca, columns=[f'PC{i+1}' for i in range(X_pca.shape[1])])
# Assign by position (not by index) so a non-default index on Y cannot introduce NaNs
brca_scores['Subtype'] = Y['BRCA_Subtype_PAM50'].to_numpy()

# Plot PC scores
plt.figure(figsize=(10, 6))
sns.scatterplot(data=brca_scores, x='PC1', y='PC2', hue='Subtype')
plt.title('PCA of BRCA Data')
plt.show()

In [None]:
#| label: fig2
# Cumulative variance explained
cum_var_explained = np.cumsum(pca.explained_variance_ratio_)

# Plot cumulative variance explained
plt.figure(figsize=(10, 6))
# Start the x-axis at 1 so it matches the number of PCs
plt.plot(np.arange(1, len(cum_var_explained) + 1), cum_var_explained, marker='o')
plt.axhline(0.75, color='red', linestyle='--')
plt.title('Cumulative Proportion of Variance Explained')
plt.xlabel('Number of PCs')
plt.ylabel('Cumulative Proportion of Variance Explained')
plt.grid()
plt.show()

In [None]:
# How many PCs are needed to explain 75% of the variance?
n_pcs = np.argmax(cum_var_explained >= 0.75) + 1
print(f'Number of PCs to explain 75% of variance: {n_pcs}')
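
As a sanity check on the thresholding logic above, the sketch below uses synthetic data to compare the manual `np.argmax` count with what scikit-learn selects when `n_components` is given as a fraction, as in `PCA(n_components=0.95)` earlier. The data and the names `X_demo`, `cum_demo`, `n_needed`, and `pca_thresh` are illustrative assumptions.

```python
import numpy as np
from sklearn.decomposition import PCA

# Synthetic data whose columns have decreasing spread (illustrative only).
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(60, 30)) * np.linspace(5.0, 0.5, 30)

# Manual count: first position where the cumulative ratio reaches the threshold.
cum_demo = np.cumsum(PCA().fit(X_demo).explained_variance_ratio_)
n_needed = int(np.argmax(cum_demo >= 0.75)) + 1

# Let scikit-learn pick the count from a fractional n_components.
pca_thresh = PCA(n_components=0.75).fit(X_demo)

print("Manual count:", n_needed)
print("sklearn count:", pca_thresh.n_components_)
```

The two counts should agree, so either approach can be used to choose how many PCs to keep for a given variance target.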

In [None]:
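# SVD for PCA (applied to X as loaded, without centering)
U, s, Vt = svd(X, full_matrices=False)
# Scores are U scaled column-wise by the singular values (same as U @ np.diag(s))
brca_scores_svd = pd.DataFrame(U * s, columns=[f'PC{i+1}' for i in range(len(s))])
brca_scores_svd['Subtype'] = Y['BRCA_Subtype_PAM50'].to_numpy()  # assign by position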

In [None]:
#| label: fig3
# Compare plots from SVD and sklearn PCA
plt.figure(figsize=(10, 6))
sns.scatterplot(data=brca_scores_svd, x='PC1', y='PC2', hue='Subtype')
plt.title('PCA of BRCA Data using SVD')
plt.show()

In [None]:

# Timing comparisons
npcs = 5

# Timing PCA with sklearn
start_time = time()
pca_out = PCA(n_components=npcs).fit(X)
time_sklearn = time() - start_time

# Timing SVD
start_time = time()
U, s, Vt = svd(X, full_matrices=False)
time_svd = time() - start_time

print(f"Time taken by sklearn PCA: {time_sklearn:.4f} seconds")
print(f"Time taken by SVD: {time_svd:.4f} seconds")

# Variance explained
var_ex1 = (s[:npcs]**2) / np.sum(s**2)
var_ex2 = pca_out.explained_variance_ratio_  # from the npcs-component fit timed above

print(f"Variance explained using sklearn PCA (first {npcs} PCs):", var_ex2)
print(f"Variance explained using SVD (first {npcs} PCs):", var_ex1)

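Why is the result from SVD different? scikit-learn's PCA centers the data (it subtracts each feature's mean) before decomposing it, whereas we applied the SVD to the uncentered matrix. As a result, the leading singular vector can largely reflect the overall mean expression level rather than variation between patients. Whether to center and/or scale is a judgment call we would need to make for this dataset.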
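
The effect of centering can be checked on a small synthetic matrix. The sketch below is illustrative only: the data and the names `X_demo`, `var_raw`, `var_centered`, and `pca_demo` are assumptions, not part of the BRCA analysis. It compares the variance proportions from an SVD with and without centering against scikit-learn's PCA, and compares scores up to sign, since the sign of each component is arbitrary.

```python
import numpy as np
from scipy.linalg import svd
from sklearn.decomposition import PCA

# Synthetic data with a clearly nonzero mean in every column.
rng = np.random.default_rng(0)
X_demo = rng.normal(loc=3.0, size=(50, 8))

# SVD on the raw matrix: the first component is dominated by the mean.
_, s_raw, _ = svd(X_demo, full_matrices=False)
var_raw = s_raw**2 / np.sum(s_raw**2)

# SVD after subtracting each column's mean.
X_centered = X_demo - X_demo.mean(axis=0)
U, s, Vt = svd(X_centered, full_matrices=False)
var_centered = s**2 / np.sum(s**2)
scores_svd = U * s

# scikit-learn PCA centers internally.
pca_demo = PCA(svd_solver="full").fit(X_demo)
scores_pca = pca_demo.transform(X_demo)

print("Uncentered SVD, first ratio:", var_raw[0])
print("Centered SVD matches PCA ratios:",
      np.allclose(var_centered, pca_demo.explained_variance_ratio_))
print("Scores match up to sign:",
      np.allclose(np.abs(scores_svd), np.abs(scores_pca)))
```

Applying the same centering step to `X` before calling `svd` should bring the SVD results for the BRCA data in line with those from `PCA`.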

In [None]:
#| label: fig4
# Pair plot using seaborn
# Copy the slice so adding a column does not trigger SettingWithCopyWarning
pair_data = brca_scores_svd.iloc[:, :npcs].copy()  # first npcs PCs only
pair_data['Subtype'] = brca_scores_svd['Subtype']
sns.pairplot(pair_data, hue='Subtype')
plt.suptitle(f'Pair Plots of First {npcs} PCs', y=1.02)
plt.show()