In [22]:
import torch
import torchvision
import numpy as np
import umap.umap_ as umap
import plotly.express as px

# Load the MNIST training split (fetched into ./data when download=True needs it)
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# NOTE(review): `.data` is read directly below, so `transform` appears not to
# influence X -- confirm whether it is needed by other cells.
# Scale by 1/255 and flatten each image: (60000, 28, 28) -> (60000, 784)
X = (train_dataset.data.float() / 255.0).view(-1, 784)
y = train_dataset.targets

# Subsample to 10,000 points.
# The permutation is drawn from a dedicated, seeded generator so the same subset
# is selected on every run; with the unseeded global RNG the input to UMAP would
# change between runs even though UMAP itself uses a fixed random_state.
n_samples = 10000
subsample_rng = torch.Generator().manual_seed(42)
indices = torch.randperm(X.shape[0], generator=subsample_rng)[:n_samples]
# Convert the selected rows to NumPy arrays for the downstream (non-torch) libraries.
X_sub = X[indices].numpy()
y_sub = y[indices].numpy()

# Apply UMAP: project the subsample down to 2 components.
# A fixed random_state makes the embedding reproducible; umap-learn then forces
# single-threaded execution and emits an "n_jobs ... overridden to 1" warning
# unless n_jobs is already 1, so it is passed explicitly to keep the output clean.
umap_model = umap.UMAP(
    n_components=2,
    random_state=42,
    n_neighbors=30,
    min_dist=0.1,
    n_jobs=1,
)
projections = umap_model.fit_transform(X_sub)

# Interactive Plotly scatter of the 2-D embedding, coloured by digit label
digit_labels = y_sub.astype(str)  # labels passed as strings, as in px's categorical colouring
axis_titles = {'x': 'UMAP Component 1', 'y': 'UMAP Component 2'}

fig = px.scatter(
    x=projections[:, 0],
    y=projections[:, 1],
    color=digit_labels,
    title='UMAP of MNIST Digits',
    labels=axis_titles,
    render_mode='webgl',
)

# Small, semi-transparent markers so overlapping points stay distinguishable
fig.update_traces(marker={'size': 4, 'opacity': 0.6})

# Centre the title and place a horizontal legend just above the plotting area
legend_style = {
    'title': 'Digit',
    'orientation': 'h',
    'yanchor': 'bottom',
    'y': 1.02,
    'xanchor': 'center',
    'x': 0.5,
}
fig.update_layout(
    width=900,
    height=700,
    title={'x': 0.5, 'xanchor': 'center'},
    legend=legend_style,
)

# Render through the "browser" renderer rather than inline in the notebook
fig.show(renderer="browser")


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.

