In [None]:
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from sklearn.neighbors import KNeighborsTransformer
from torchvision.datasets import MNIST
from umap import UMAP

from txtox.utils import get_paths

paths = get_paths()

# Layout constants for the synthetic "section" coordinates.
UMAP_SEED = 0
COORD_SCALE = 5  # multiplier applied after per-axis standardisation
Z_SECTION = 1.0  # every digit is placed on the same z plane

raw_data = MNIST("mnist", download=True)
# Read the dataset's backing tensors directly instead of iterating the dataset
# (which wraps every sample in a PIL image); pixel values and labels are the same.
images = raw_data.data.numpy() / 255
labels = raw_data.targets.numpy()
X = images.reshape(len(images), -1)  # one flattened pixel vector per row
assert len(X) == len(labels), "image/label count mismatch"

# 2-D UMAP of the pixel vectors, standardised per axis and then scaled.
Y = UMAP(n_components=2, min_dist=0.8, n_neighbors=15, random_state=UMAP_SEED).fit_transform(X)
Y = (Y - Y.mean(axis=0)) / Y.std(axis=0) * COORD_SCALE
# Append a constant z column -> (n_samples, 3) coordinates.
Y = np.column_stack((Y, np.full(len(Y), Z_SECTION)))

obs = pd.DataFrame(Y, columns=["x_section", "y_section", "z_section"])
obs["subclass"] = pd.Categorical(labels)
display(obs.head(3))

In [None]:
# Summary statistics of the numeric coordinate columns (the categorical
# `subclass` column is not included). x/y are expected to be centred with a
# spread of roughly COORD_SCALE; z is constant.
obs.describe(include="number")

In [None]:
# AnnData expects string observation names; relabel the rows with their
# stringified positional index.
obs = obs.set_axis(obs.index.astype(str))

In [None]:
# Inspect the obs table that seeds the AnnData object. `adata` itself is only
# created in the next cell, so referencing it here fails on a fresh kernel.
obs.dtypes

In [None]:
adata = ad.AnnData(X=X, obs=obs)
# Copy the (x, y, z) layout into the obsm slot used for coordinates.
adata.obsm["spatial"] = adata.obs[["x_section", "y_section", "z_section"]].to_numpy()

# kNN graph over the raw pixel vectors (use_rep="X"); scanpy documents that
# key_added="spatial" places it in obsp["spatial_connectivities"] /
# obsp["spatial_distances"] with parameters in uns["spatial"].
# NOTE(review): the graph is built in pixel space, not from obsm["spatial"];
# confirm that is intended given the key name.
sc.pp.neighbors(
    adata,
    n_neighbors=15,
    use_rep="X",
    knn=True,
    transformer=None,
    metric="euclidean",
    random_state=0,
    key_added="spatial",
    copy=False,
)

# One pastel colour per subclass, assigned in sorted category order so the
# mapping does not depend on row order. Colours are hex strings because an obs
# column of RGB tuples cannot be serialised by write_h5ad.
unique_subclasses = adata.obs["subclass"].cat.categories
pastel_palette = sns.color_palette("pastel", len(unique_subclasses)).as_hex()
subclass_color_map = dict(zip(unique_subclasses, pastel_palette))
adata.obs["subclass_color"] = adata.obs["subclass"].astype(int).map(subclass_color_map)
# Same palette in scanpy's per-category colour slot so sc.pl.* plots agree.
adata.uns["subclass_colors"] = list(pastel_palette)

In [None]:
# Join with pathlib so the file lands inside data_root whether or not the
# configured root ends with a path separator.
adata.write_h5ad(Path(paths["data_root"]) / "mnist.h5ad")
adata