In [None]:
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from sklearn.neighbors import KNeighborsTransformer
from torchvision.datasets import MNIST
from umap import UMAP

from txtox.utils import get_paths

paths = get_paths()

# --- Configuration -------------------------------------------------------------------------
# Fixed seed so the UMAP layout (and therefore every coordinate written downstream) is
# reproducible across Restart & Run All. umap-learn disables parallelism when random_state
# is set, so the embedding is slower but deterministic.
SEED = 0
UMAP_N_NEIGHBORS = 15
UMAP_MIN_DIST = 0.8
EMBEDDING_SCALE = 5  # each in-plane axis is standardised, then multiplied by this factor
Z_SECTION = 1.0  # constant third coordinate: all points lie on a single z-plane

# --- Load MNIST as flattened pixel vectors in [0, 1] -----------------------------------------
raw_data = MNIST("mnist", download=True)
(images, labels) = zip(*raw_data)
images = np.asarray(images) / 255
X = images.reshape(len(images), -1)
labels = np.asarray(labels)

# --- 2-D UMAP embedding used as pseudo-"spatial" coordinates ---------------------------------
Y = UMAP(
    n_components=2,
    min_dist=UMAP_MIN_DIST,
    n_neighbors=UMAP_N_NEIGHBORS,
    random_state=SEED,
).fit_transform(X)
# Standardise each axis to mean 0 / std 1, then scale so the std becomes EMBEDDING_SCALE.
Y = (Y - Y.mean(axis=0)) / Y.std(axis=0) * EMBEDDING_SCALE
# Append the constant z column so the coordinates are 3-D (x, y, z).
Y = np.column_stack((Y, np.full(len(Y), Z_SECTION)))

obs = pd.DataFrame(Y, columns=["x_section", "y_section", "z_section"])
# Digit label per image, stored as a categorical column.
obs["subclass"] = pd.Categorical(labels)
display(obs.head(3))

Unnamed: 0,x_section,y_section,z_section,subclass
0,0.289193,-1.356731,1.0,5
1,10.618962,0.060141,1.0,0
2,0.270434,9.131705,1.0,4


In [2]:
# Sanity-check the embedding scaling on the numeric coordinate columns only:
# x/y were standardised and multiplied by 5 (expect mean ≈ 0, std ≈ 5); z is constant.
obs.describe(include="number")

Unnamed: 0,x_section,y_section,z_section
count,60000.0,60000.0,60000.0
mean,-5e-06,-3e-06,1.0
std,5.000069,5.000065,0.0
min,-8.996947,-12.954169,1.0
25%,-3.733734,-3.192264,1.0
50%,-0.696206,0.26352,1.0
75%,2.760081,3.411531,1.0
max,12.853412,10.896443,1.0


In [3]:
# AnnData works with string observation names (it warns and converts otherwise), so relabel
# the default integer index as "0", "1", ... before building the AnnData object.
obs = obs.set_axis(obs.index.astype(str), axis="index")

In [None]:
# Build the AnnData object: flattened pixels as X, the x/y/z section coordinates as the
# "spatial" representation used for the neighbour graph.
adata = ad.AnnData(X=X, obs=obs)
adata.obsm["spatial"] = adata.obs[["x_section", "y_section", "z_section"]].to_numpy()

# Exact (sklearn) k-NN graph on the spatial coordinates, stored under key "spatial".
# k is passed to both scanpy and the transformer, so keep the two values in sync via one constant.
N_SPATIAL_NEIGHBORS = 15
sc.pp.neighbors(
    adata,
    n_neighbors=N_SPATIAL_NEIGHBORS,
    n_pcs=None,
    use_rep="spatial",
    knn=True,
    transformer=KNeighborsTransformer(n_neighbors=N_SPATIAL_NEIGHBORS, metric="minkowski", p=2),
    key_added="spatial",
    copy=False,
)

# One pastel colour per digit class. Colours are assigned in sorted category order so that the
# mapping does not depend on which class appears first in the data, and they are stored as hex
# strings because anndata/h5py cannot serialise an obs column of RGB tuples to .h5ad.
unique_subclasses = adata.obs["subclass"].cat.categories
pastel_palette = sns.color_palette("pastel", len(unique_subclasses)).as_hex()
subclass_color_map = dict(zip(unique_subclasses, pastel_palette))
adata.obs["subclass_color"] = adata.obs["subclass"].astype(int).map(subclass_color_map)
assert adata.obs["subclass_color"].notna().all(), "every subclass must have a colour"

# Join with Path so a data_root without a trailing separator still yields a valid file path.
adata.write_h5ad(Path(paths["data_root"]) / "mnist.h5ad")

In [5]:
# Show a summary of the AnnData object; the neighbour graph built with key_added="spatial"
# appears under uns['spatial'] and obsp['spatial_distances'] / obsp['spatial_connectivities'].
adata

AnnData object with n_obs × n_vars = 60000 × 784
    obs: 'x_section', 'y_section', 'z_section', 'subclass'
    uns: 'spatial'
    obsm: 'spatial'
    obsp: 'spatial_distances', 'spatial_connectivities'