In [1]:
import glob
import os

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

from gridnext.densenet import DenseNet
from gridnext.gridnet_models import GridNetHexMM
from gridnext.multimodal_datasets import MMAnnDataset, MMAnnGridDataset, MMStackDataset
from gridnext.plotting import plot_label_tensor
from gridnext.visium_datasets import create_visium_anndata_img, create_visium_dataset

In [2]:
# Download data from Zenodo record 10372917: https://zenodo.org/records/10372917
# (the `/uploads/10372917` form of this link is Zenodo's depositor-side page and
# may not open for other users).

# Root of the downloaded test data; every glob pattern in this notebook is built from it.
data_dir = '../data/BA44_testdata'

In [3]:
# Collect the per-array inputs. Every list is sorted; the i-th entries of the three
# lists are presumably meant to describe the same Visium array (relies on consistent
# file naming -- verify against the downloaded data).
spaceranger_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*')))
fullres_image_files = sorted(glob.glob(os.path.join(data_dir, 'fullres_images', '*.jpg')))
annot_files = sorted(glob.glob(os.path.join(data_dir, 'annotations', '*.csv')))

# glob returns an empty list (no error) when nothing matches, e.g. when the data has
# not been downloaded into `data_dir`. Fail here with a clear message instead of
# handing empty or unevenly sized lists to the dataset constructors further down.
_input_lists = {
    'spaceranger directories': spaceranger_dirs,
    'full-resolution images': fullres_image_files,
    'annotation files': annot_files,
}
for _label, _paths in _input_lists.items():
    if not _paths:
        raise FileNotFoundError('No %s found under %r' % (_label, data_dir))
if len({len(_paths) for _paths in _input_lists.values()}) != 1:
    raise ValueError('Per-array inputs differ in length: %s'
                     % {_label: len(_paths) for _label, _paths in _input_lists.items()})

# Annotation classes: six cortical layers followed by white matter.
class_names = [f'Layer_{layer}' for layer in range(1, 7)]
class_names.append('White_matter')

## AnnData-based workflow for pairing image and count data

### Instantiation of AnnData with image information

In [4]:
# From pre-extracted image patches:
# look for a `*_patches100um` directory inside each Spaceranger output directory and
# pass those directories, together with the annotation files, to the AnnData builder.
imgpatch_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*', '*_patches100um')))
adata_img1 = create_visium_anndata_img(
    spaceranger_dirs,
    imgpatch_dirs=imgpatch_dirs,
    annot_files=annot_files,
)
print(adata_img1)

  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles


AnnData object with n_obs × n_vars = 22088 × 28463
    obs: 'x', 'y', 'x_px', 'y_px', 'array', 'annotation', 'imgpath'


In [5]:
# On-demand image patch extraction:
# instead of patch directories, supply the full-resolution images and a patch size.
ondemand_patch_options = dict(
    fullres_image_files=fullres_image_files,
    patch_size_um=200,
    annot_files=annot_files,
)
adata_img2 = create_visium_anndata_img(spaceranger_dirs, **ondemand_patch_options)
print(adata_img2)

  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles
  adata_arr.obs['imgpath'] = imfiles


AnnData object with n_obs × n_vars = 22088 × 28463
    obs: 'x', 'y', 'x_px', 'y_px', 'array', 'annotation', 'imgpath'


### Instantiation of Dataset from AnnData

In [6]:
# Per-spot metadata table; the 'annotation', 'array' and 'imgpath' columns shown here
# are the ones the dataset classes below are pointed at.
adata_img1.obs

Unnamed: 0,x,y,x_px,y_px,array,annotation,imgpath
V003-CGND-HRA-02744-A_51_5,51,5,7335,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_52_4,52,4,7448,1902,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_53_5,53,5,7561,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_54_4,54,4,7674,1902,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_55_5,55,5,7788,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
...,...,...,...,...,...,...,...
V005-CGND-HRA-02751-B_13_71,13,71,2887,15128,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_14_70,14,70,3000,14931,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_16_70,16,70,3226,14931,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_18_70,18,70,3453,14931,V005-CGND-HRA-02751-B,Layer_4,../data/BA44_testdata/spaceranger/V005-CGND-HR...


In [11]:
# Patch-level multimodal dataset built from the AnnData object, using the
# 'annotation' column of `obs` as the label source.
mm_pdat = MMAnnDataset(adata_img1, 'annotation')
print(len(mm_pdat))

# Each item unpacks as ((image, counts), label).
(xi, xc), y = mm_pdat[0]
print(xi.shape, xc.shape, y)

22088
torch.Size([3, 227, 227]) torch.Size([28463]) tensor(6)


In [14]:
# Grid-level multimodal dataset: same AnnData and label column as the patch-level
# dataset, with the 'array' column of `obs` passed as the additional key.
mm_gdat = MMAnnGridDataset(adata_img1, 'annotation', 'array')
print(len(mm_gdat))

# Each item unpacks as ((image grid, count grid), label).
first_grid_item = mm_gdat[0]
(xi, xc), y = first_grid_item
print(*(xi.shape, xc.shape, y))

6
torch.Size([78, 64, 3, 227, 227]) torch.Size([28463, 78, 64]) 76


In [9]:
# NOTE(review): this repeats the earlier `adata_img1.obs` display; presumably it is a
# check that building the datasets left `obs` unchanged -- confirm, or drop the cell.
adata_img1.obs

Unnamed: 0,x,y,x_px,y_px,array,annotation,imgpath
V003-CGND-HRA-02744-A_51_5,51,5,7335,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_52_4,52,4,7448,1902,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_53_5,53,5,7561,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_54_4,54,4,7674,1902,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_55_5,55,5,7788,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
...,...,...,...,...,...,...,...
V005-CGND-HRA-02751-B_13_71,13,71,2887,15128,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_14_70,14,70,3000,14931,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_16_70,16,70,3226,14931,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_18_70,18,70,3453,14931,V005-CGND-HRA-02751-B,Layer_4,../data/BA44_testdata/spaceranger/V005-CGND-HR...


## Dataset instantiation directly from Spaceranger

In [4]:
# Image-only spatial dataset built straight from the Spaceranger outputs.
image_only_options = dict(
    use_image=True,
    use_count=False,
    spatial=True,
    fullres_image_files=fullres_image_files,
    annot_files=annot_files,
)
img_data = create_visium_dataset(spaceranger_dirs, **image_only_options)

In [5]:
# Count-only spatial dataset built straight from the Spaceranger outputs.
# NOTE(review): `minimum_detection_rate=0.02` presumably filters rarely detected
# genes -- confirm against create_visium_dataset.
count_only_options = dict(
    use_image=False,
    use_count=True,
    spatial=True,
    minimum_detection_rate=0.02,
    annot_files=annot_files,
)
count_data = create_visium_dataset(spaceranger_dirs, **count_only_options)

In [6]:
# Multimodal (image + count) spatial dataset built straight from the Spaceranger outputs.
mm_gdat = create_visium_dataset(
    spaceranger_dirs,
    spatial=True,
    use_image=True,
    use_count=True,
    fullres_image_files=fullres_image_files,
    annot_files=annot_files,
)

# Optional per-array inspection: plots each label grid and prints the shape and value
# range of the image and count tensors. Off by default because it loads every array;
# flip the flag to run it. (Previously this loop lived in a triple-quoted string, where
# it was never executed or syntax-checked.)
SHOW_ARRAY_PREVIEWS = False
if SHOW_ARRAY_PREVIEWS:
    for i in range(len(mm_gdat)):
        (xi, xc), y = mm_gdat[i]
        fig, ax = plt.subplots(1, figsize=(2, 2))
        plot_label_tensor(y, class_names=class_names, Visium=True, ax=ax)
        print(xi.shape, xi.min(), xi.max())
        print(xc.shape, xc.min(), xc.max())
        plt.show()
        plt.close(fig)  # release the figure so a long loop does not accumulate them

print(len(mm_gdat))

6


### Smaller dummy dataset

In [19]:
# Dimensions of the synthetic data: number of arrays, grid width (the grid is aw x aw),
# image patch width in pixels, number of genes, and number of label classes.
n, aw, imw, ngenes, nclasses = 6, 4, 224, 1000, 7

# Draw from a dedicated, seeded generator so the dummy data is identical on every run
# without touching the global RNG state.
SEED = 0
rng = torch.Generator().manual_seed(SEED)

# Image grid: (n, aw, aw, 3, imw, imw), uniform in [0, 1).
imdat = torch.rand((n, aw, aw, 3, imw, imw), generator=rng).float()
# Count grid: (n, ngenes, aw, aw), integer values 0..9 stored as float.
cdat = torch.randint(0, 10, (n, ngenes, aw, aw), generator=rng).float()
# Label grid: (n, aw, aw), class indices in [0, nclasses).
ldat = torch.randint(0, nclasses, (n, aw, aw), generator=rng).long()

# Pair each modality with the same label tensor, then stack the two datasets into one
# multimodal dataset. This rebinds `mm_gdat` to the small synthetic dataset.
dummy_image_ds = TensorDataset(imdat, ldat)
dummy_count_ds = TensorDataset(cdat, ldat)
mm_gdat = MMStackDataset(dummy_image_ds, dummy_count_ds)

## Model instantiation

In [20]:
# Read the model's input shapes off the first dataset item.
(xi, xc), y = mm_gdat[0]

# The image tensor's two leading axes are the grid; the remaining axes are one patch.
grid_shape, image_shape = xi.shape[:2], xi.shape[2:]
# The count tensor's leading axis is kept as a 1-element shape.
count_shape = xc.shape[:1]

for derived_shape in (image_shape, count_shape, grid_shape):
    print(derived_shape)

torch.Size([3, 224, 224])
torch.Size([1000])
torch.Size([4, 4])


In [21]:
# Count-based patch classifier: a small fully connected network that maps the leading
# axis of `xc` to one output per class.
# NOTE(review): each pair of back-to-back Linear layers has no activation in between.
n_count_features = xc.shape[0]
n_output_classes = len(class_names)

count_layers = [
    # First stage
    nn.Linear(n_count_features, 500),
    nn.Linear(500, 100),
    nn.BatchNorm1d(100),
    nn.ReLU(),
    # Second stage
    nn.Linear(100, 100),
    nn.Linear(100, 50),
    nn.BatchNorm1d(50),
    nn.ReLU(),
    # Output layer
    nn.Linear(50, n_output_classes),
]
f_count = nn.Sequential(*count_layers)

In [22]:
# Image-based patch classifier with one output per class.
# NOTE(review): these settings look like the usual DenseNet-121 configuration --
# confirm against gridnext.densenet.
densenet_options = dict(
    num_classes=len(class_names),
    small_inputs=False,
    efficient=False,
    growth_rate=32,
    block_config=(6, 12, 24, 16),
    num_init_features=64,
    bn_size=4,
    drop_rate=0,
)
f_image = DenseNet(**densenet_options)

In [23]:
# Combine the image and count patch classifiers into one grid-level model.
g = GridNetHexMM(f_image, f_count, image_shape, count_shape, grid_shape, len(class_names))

# One grid per batch, in dataset order.
dl = DataLoader(mm_gdat, batch_size=1, shuffle=False)

# Shape-only smoke test: print the shape of the patch-level predictions and of the
# final model output for every grid. No gradients are needed here, so run under
# no_grad to avoid building an autograd graph for each forward pass.
with torch.no_grad():
    for x, y in dl:
        pp = g.patch_predictions(x)
        print(pp.shape)
        out = g(x)
        print(out.shape)

torch.Size([1, 14, 4, 4])
torch.Size([1, 7, 4, 4])
torch.Size([1, 14, 4, 4])
torch.Size([1, 7, 4, 4])
torch.Size([1, 14, 4, 4])
torch.Size([1, 7, 4, 4])
torch.Size([1, 14, 4, 4])
torch.Size([1, 7, 4, 4])
torch.Size([1, 14, 4, 4])
torch.Size([1, 7, 4, 4])
torch.Size([1, 14, 4, 4])
torch.Size([1, 7, 4, 4])
