In [1]:
import os
import glob
import torch
import numpy as np
import scanpy as sc
import anndata as ad
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader

from gridnext.gridnet_models import GridNetHexOddr
from gridnext.visium_datasets import create_visium_dataset, create_visium_anndata
from gridnext.count_datasets import anndata_to_tensordataset, anndata_arrays_to_tensordataset
from gridnext.training import train_spotwise, train_gridwise

In [2]:
data_dir = '../data/BA44_testdata'

# TODO: instructions for downloading example data from Zenodo

In [3]:
spaceranger_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*')))
annot_files = sorted(glob.glob(os.path.join(data_dir, 'annotations', '*.csv')))

### 1.1 Load spot data

### a) Load large dataset (low memory usage, slow accession)
Map-style PyTorch dataset for lazy loading of spots from large datasets. Use when full data cannot fit into memory.

In [4]:
# First, create a CountDataset using all arrays to generate unified countfiles:
# (once this has been done once, you can skip straight to train/test set creation)
pdat = create_visium_dataset(spaceranger_dirs, annot_files=annot_files, 
                             use_count=True, use_image=False, spatial=False)

# Create training and validation datasets
n_val = 1
train_srd = spaceranger_dirs[n_val:]
train_ann = annot_files[n_val:]

val_srd = spaceranger_dirs[:n_val]
val_ann = annot_files[:n_val]

train_pdat = create_visium_dataset(train_srd, annot_files=train_ann,
                                   use_count=True, use_image=False, spatial=False)
val_pdat = create_visium_dataset(val_srd, annot_files=val_ann,
                                 use_count=True, use_image=False, spatial=False)
print('%d training and %d validation spots' % (len(train_pdat), len(val_pdat)))

assert np.array_equal(train_pdat.classes, val_pdat.classes), "Classes in train/val data do not match!"
class_names = train_pdat.classes
print('%d classes' % len(class_names))

1920 un-annotated spots
228 un-annotated spots
1692 un-annotated spots
19540 training and 2548 validation spots
7 classes


### b) Load small dataset (high memory usage, fast accession)
Load full dataset into memory for fast accession

In [6]:
# Create AnnData representation of full Visium data (all arrays)
destfile = os.path.join(data_dir, 'adata_tutorial.h5ad')

#adata = create_visium_anndata(spaceranger_dirs, annot_files=annot_files, destfile=destfile)
# Use for future accession after running previous line once:
adata = ad.read_h5ad(destfile)  # add "backed='r'" to avoid reading full AnnData into memory

In [7]:
# Perform desired preprocessing (normalization, log-transform, HVG selection, etc...)
sc.pp.normalize_total(adata, 1e4)
sc.pp.log1p(adata)

n_hvgs = 2000
cvar = np.abs(adata.X.std(axis=0) / adata.X.mean(axis=0))
cvar[adata.X.mean(axis=0) == 0] = 0  # disregard genes with 0 mean as HVGs
thresh_val = np.sort(cvar)[-n_hvgs]
adata = adata[:, cvar >= thresh_val]

  cvar = np.abs(adata.X.std(axis=0) / adata.X.mean(axis=0))


In [9]:
# Create training and validation datasets
n_val = 1
val_arrays = adata.obs.array.unique()[:n_val]
train_arrays = adata.obs.array.unique()[n_val:]
adata_val = adata[adata.obs.array.isin(val_arrays)]
adata_train = adata[adata.obs.array.isin(train_arrays)]

train_pdat, train_classes  = anndata_to_tensordataset(adata_train, 'annotation')
val_pdat, val_classes = anndata_to_tensordataset(adata_val, 'annotation')
print('%d training and %d validation spots' % (len(train_pdat), len(val_pdat)))

assert np.array_equal(train_classes, val_classes), "Classes in train/val data do not match!"
class_names = train_classes
print('%d classes' % len(class_names))

19540 training and 2548 validation spots
7 classes


### 1.2. Train spot classifier

In [10]:
# Create data loaders for training loop
dataloader_spots = {
    'train': DataLoader(train_pdat, batch_size=128, shuffle=True),
    'val': DataLoader(val_pdat, batch_size=128)
}

In [13]:
# Instantiate fully-connected network to be used as spot classifier (f)
x, _ = train_pdat[0]
input_size = x.shape[0]

f = nn.Sequential(
    nn.Linear(input_size, 500),
    nn.Linear(500, 100),
    nn.BatchNorm1d(100),
    nn.ReLU(),

    nn.Linear(100, 100),
    nn.Linear(100, 50),
    nn.BatchNorm1d(50),
    nn.ReLU(),

    nn.Linear(50, len(class_names))
)

In [14]:
# Perform model training and save parameters
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(f.parameters(), lr=1e-4)

output_dir = '../models'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
outfile = os.path.join(output_dir, 'tutorial_f_count.pth')

f, f_val_hist, f_train_hist = train_spotwise(f, dataloader_spots, loss, optimizer, 
                                             num_epochs=10, display=False, outfile=outfile)

Epoch 0/9
----------
train Loss: 1.9473 Acc: 0.2692
val Loss: 1.9068 Acc: 0.3206

Epoch 1/9
----------
train Loss: 1.7212 Acc: 0.3662
val Loss: 1.9032 Acc: 0.3257

Epoch 2/9
----------
train Loss: 1.5985 Acc: 0.3854
val Loss: 1.9400 Acc: 0.3277

Epoch 3/9
----------
train Loss: 1.5220 Acc: 0.3908
val Loss: 1.9986 Acc: 0.3289

Epoch 4/9
----------
train Loss: 1.4742 Acc: 0.3946
val Loss: 2.0482 Acc: 0.3285

Epoch 5/9
----------
train Loss: 1.4407 Acc: 0.3951
val Loss: 2.0861 Acc: 0.3281

Epoch 6/9
----------
train Loss: 1.4170 Acc: 0.3961
val Loss: 2.1462 Acc: 0.3277

Epoch 7/9
----------
train Loss: 1.4017 Acc: 0.3964
val Loss: 2.1860 Acc: 0.3269

Epoch 8/9
----------
train Loss: 1.3903 Acc: 0.3960
val Loss: 2.2122 Acc: 0.3250

Epoch 9/9
----------
train Loss: 1.3801 Acc: 0.3983
val Loss: 2.2560 Acc: 0.3285

Training complete in 0m 24s
Best val loss: 1.903159


### 2.1. Load grid data

### a) Large dataset

In [15]:
# Create training and validation datasets
train_gdat = create_visium_dataset(train_srd, annot_files=train_ann,
                                   use_count=True, use_image=False, spatial=True)
val_gdat = create_visium_dataset(val_srd, annot_files=val_ann,
                                 use_count=True, use_image=False, spatial=True)
print('%d training and %d validation arrays' % (len(train_gdat), len(val_gdat)))

class_names = train_gdat.classes
print('%d classes' % len(class_names))

5 training and 1 validation arrays
7 classes


### b) Small dataset

In [18]:
train_pdat, train_classes = anndata_arrays_to_tensordataset(adata_train, 'annotation', 'array')
val_pdat, val_classes = anndata_arrays_to_tensordataset(adata_val, 'annotation', 'array')
print('%d training and %d validation grids' % (len(train_pdat), len(val_pdat)))

assert np.array_equal(train_classes, val_classes), "Classes in train/val data do not match!"
class_names = train_classes
print('%d classes' % len(class_names))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [18:03<00:00, 216.72s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:30<00:00, 90.60s/it]

5 training and 1 validation grids
7 classes





### 2.2. Training spatial corrector network

In [16]:
# Create data loaders for training loop
dataloader_grids = {
    'train': DataLoader(train_gdat, batch_size=1, shuffle=True),
    'val': DataLoader(val_gdat, batch_size=1)
}

In [17]:
# Instantiate g network
H_VISIUM = 78
W_VISIUM = 64

g = GridNetHexOddr(f, (input_size,), (H_VISIUM, W_VISIUM), n_classes=len(class_names), use_bn=True)

In [None]:
# Train g network
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(g.corrector.parameters(), lr=1e-3)

# Fixing the parameters of the patch classifier allows slightly faster training, even when only
# optimizing the parameters of the corrector. 
for param in g.patch_classifier.parameters():
    param.requires_grad = False

outfile = os.path.join(output_dir, 'tutorial_g_count')

g, g_val_hist, g_train_hist = train_gridwise(g, dataloader_grids, loss, optimizer, 
                                             num_epochs=10, outfile=outfile)

### 3. Visualizing train/validation performance