In [11]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
%matplotlib inline
import seaborn as sns

In [193]:
import torch
import torchvision
from torchvision import transforms

# Load the MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=None, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=None, download=True)

# Create data loaders to load the data in batches
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [194]:
N = 1000

transform = transforms.Compose([
    transforms.ToTensor(),
])

mnist_nines = []
for image, target in train_dataset:
    if target == 9:
        # Convert the PIL image to a PyTorch tensor
        image = transforms.ToTensor()(image)
        mnist_nines.append(image.numpy().squeeze())

# X = train_dataset.data.numpy()[:N]
X =  np.array(mnist_nines)[:N]

X = X - X.mean(axis=0)

X = X.reshape(N, -1)

cov = np.dot(X.T, X) / X.shape[0] # get the data covariance matrix

cov.shape

(784, 784)

In [195]:
px.imshow(mnist_nines[0], color_continuous_scale='RdBu', color_continuous_midpoint=0)

In [196]:
px.imshow((cov @ X[0]).reshape(28, 28), color_continuous_scale='RdBu', color_continuous_midpoint=0)

In [197]:
num_eigenvectors = 5
U,S,V = np.linalg.svd(cov)

U.shape

(784, 784)

In [198]:
U_reduced = U[:,:num_eigenvectors].T # get the first 5 eigenvectors

In [110]:
S[:num_eigenvectors]

array([6.6365595, 3.7891724, 3.230307 , 1.9516282, 1.7772009],
      dtype=float32)

In [199]:
px.imshow(U_reduced[0].reshape(28, 28), color_continuous_scale='RdBu', color_continuous_midpoint=0)

In [200]:
px.imshow(U_reduced[1].reshape(28, 28), color_continuous_scale='RdBu', color_continuous_midpoint=0)

## Toy Dataset

In [156]:
# Set the random seed for reproducibility
np.random.seed(123)

# Generate a random dataset of size 1000 with 3 variables
n = 1000
age = np.random.normal(loc=0.5, scale=0.2, size=n)
smoke = np.random.normal(loc=0.5, scale=0.2, size=n)
cancer = 0.3 * (age / 100) + 0.6 * smoke + np.random.normal(loc=0, scale=0.1, size=n)

# Ensure that age and how much they smoke are between the desired range
# age = np.clip(age, 15, 100)
# smoke = np.clip(smoke, 0, 1)
# cancer = np.clip(cancer, 0, 1)

data = pd.DataFrame({'age': age, 'smoke': smoke, 'cancer': cancer})

In [187]:
px.scatter_3d(data, x="age", y="smoke", z="cancer", size_max=0)

In [158]:
X = np.stack([age, smoke, cancer]).T

X = X - X.mean(axis=0)

cov = np.dot(X.T, X) / X.shape[0] # get the data covariance matrix

cov.shape

(3, 3)

In [159]:
px.imshow(cov, color_continuous_scale='RdBu', color_continuous_midpoint=0, x = ['age', 'smoke', 'cancer'], y = ['age', 'smoke', 'cancer'])

In [161]:
num_eigenvectors = 3
U,S,V = np.linalg.svd(cov)

U_reduced = U[:,:num_eigenvectors].T # get the first 5 eigenvectors

In [162]:
U_reduced

array([[ 0.13263883, -0.80151787, -0.58307465],
       [ 0.99112467,  0.11252567,  0.07078045],
       [ 0.00887907, -0.5872879 ,  0.8093294 ]])

In [163]:
S

array([0.05274767, 0.03983853, 0.0067268 ])

In [191]:
import plotly.graph_objects as go

fig = go.Figure(data=go.Scatter3d(
    x=age - age.mean(),
    y=smoke - smoke.mean(),
    z=cancer - cancer.mean(),
    mode='markers',
    marker=dict(
        size=5,
        color=cancer,
        colorscale='Viridis',
    ),
))

v1 = U_reduced[0] 
v2 = U_reduced[1] 
v3 = U_reduced[2] / 2

fig.add_trace(go.Scatter3d(
    x=[0, v1[0]], y=[0, v1[1]], z=[0, v1[2]],
    mode='lines',
    line=dict(width=5, color='red'),
    name='Vector 1'
))

fig.add_trace(go.Scatter3d(
    x=[0, v2[0]], y=[0, v2[1]], z=[0, v2[2]],
    mode='lines',
    line=dict(width=5, color='blue'),
    name='Vector 2'
))

fig.add_trace(go.Scatter3d(
    x=[0, v3[0]], y=[0, v3[1]], z=[0, v3[2]],
    mode='lines',
    line=dict(width=5, color='green'),
    name='Vector 3'
))

fig.update_layout(
    scene=dict(
        xaxis_title='Age',
        yaxis_title='Smoke',
        zaxis_title='Cancer',
    ),
)
fig.show()