In [1]:
import os
import glob
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset
from matplotlib import pyplot as plt
from gridnext.densenet import DenseNet
from gridnext.plotting import plot_label_tensor
from gridnext.visium_datasets import create_visium_dataset, read_feature_names, read_feature_matrix, visium_get_positions
from gridnext.gridnet_models import GridNetHexMM
from gridnext.multimodal_datasets import MMStackDataset

In [2]:
# Download data from: https://zenodo.org/uploads/10372917

data_dir = '../data/BA44_testdata'

In [3]:
spaceranger_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*')))
fullres_image_files = sorted(glob.glob(os.path.join(data_dir, 'fullres_images', '*.jpg')))
annot_files = sorted(glob.glob(os.path.join(data_dir, 'annotations', '*.csv')))

class_names = ['Layer_%d' % i for i in range(1, 7)] + ['White_matter']

In [4]:
# Image transforms to be applied prior to input to DenseNet image classifier (f)
ppx = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
patch_size = (3,224,224)

## AnnData-based workflow for pairing image and count

### Instantiation of AnnData w/image information

In [5]:
from gridnext.visium_datasets import create_visium_anndata_img

# From pre-extracted image patches:
imgpatch_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*', '*_patches100um')))
adata_img1 = create_visium_anndata_img(spaceranger_dirs, imgpatch_dirs=imgpatch_dirs, 
                                       annot_files=annot_files)
print(adata_img1)

AnnData object with n_obs × n_vars = 22088 × 28463
    obs: 'x', 'y', 'x_px', 'y_px', 'array', 'annotation', 'imgpath'
    var: 'gene_symbol'


In [6]:
adata_img1.obs

Unnamed: 0,x,y,x_px,y_px,array,annotation,imgpath
V003-CGND-HRA-02744-A_51_5,51,5,7335,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_52_4,52,4,7448,1902,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_53_5,53,5,7561,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_54_4,54,4,7674,1902,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
V003-CGND-HRA-02744-A_55_5,55,5,7788,2099,V003-CGND-HRA-02744-A,White_matter,../data/BA44_testdata/spaceranger/V003-CGND-HR...
...,...,...,...,...,...,...,...
V005-CGND-HRA-02751-B_13_71,13,71,2887,15128,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_14_70,14,70,3000,14931,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_16_70,16,70,3226,14931,V005-CGND-HRA-02751-B,Layer_3,../data/BA44_testdata/spaceranger/V005-CGND-HR...
V005-CGND-HRA-02751-B_18_70,18,70,3453,14931,V005-CGND-HRA-02751-B,Layer_4,../data/BA44_testdata/spaceranger/V005-CGND-HR...


In [None]:
# On-demand image patch extraction:
adata_img2 = create_visium_anndata_img(spaceranger_dirs, fullres_image_files=fullres_image_files, patch_size_um=200,
                                       annot_files=annot_files)
print(adata_img2)

### Instantiation of Dataset from AnnData

In [7]:
from gridnext.multimodal_datasets import MMAnnDataset, MMAnnGridDataset

mm_pdat = MMAnnDataset(adata_img1, 'annotation', img_transforms=ppx)
print(len(mm_pdat))

(xi, xc), y = mm_pdat[0]
print(xi.shape, xc.shape, y)

22088
torch.Size([3, 224, 224]) torch.Size([28463]) tensor(6)


In [8]:
mm_gdat = MMAnnGridDataset(adata_img1, 'annotation', 'array', img_transforms=ppx)
print(len(mm_gdat))

(xi, xc), y = mm_gdat[0]
print(xi.shape, xc.shape, y)

6
torch.Size([78, 64, 3, 224, 224]) torch.Size([28463, 78, 64]) 76


In [9]:
dl = DataLoader(mm_gdat, batch_size=2)
x, y = next(iter(dl))
print(x[0].shape, x[1].shape)
print(y.shape)

torch.Size([2, 78, 64, 3, 224, 224]) torch.Size([2, 28463, 78, 64])
torch.Size([2])


## Dataset instantiation directly from Spaceranger

In [15]:
img_data = create_visium_dataset(spaceranger_dirs, 
                                 use_image=True, use_count=False, spatial=True,
                                 fullres_image_files=fullres_image_files,
                                 annot_files=annot_files, 
                                 img_transforms=ppx)

In [16]:
count_data = create_visium_dataset(spaceranger_dirs, 
                                   use_image=False, use_count=True, spatial=True,
                                   minimum_detection_rate=0.02,
                                   annot_files=annot_files)

In [17]:
mm_gdat = create_visium_dataset(spaceranger_dirs, spatial=True,
                                use_image=True, use_count=True, 
                                fullres_image_files=fullres_image_files,
                                annot_files=annot_files, img_transforms=ppx)

'''
for i in range(len(mm_gdat)):
    (xi,xc), y = mm_gdat[i]
    fig, ax = plt.subplots(1, figsize=(2,2))
    plot_label_tensor(y, class_names=class_names, Visium=True, ax=ax)
    print(xi.shape, xi.min(), xi.max())
    print(xc.shape, xc.min(), xc.max())
'''
print(len(mm_gdat))

6


In [19]:
dl = DataLoader(mm_gdat, batch_size=2)
x, y = next(iter(dl))
print(x[0].shape, x[1].shape)
print(y[0].shape, y[1].shape)

torch.Size([2, 78, 64, 3, 224, 224]) torch.Size([2, 78, 64])
torch.Size([2, 10711, 78, 64]) torch.Size([2, 78, 64])


### Smaller dummy dataset

In [10]:
n, aw, imw, ngenes, nclasses = 6, 4, 224, 1000, 7
imdat = torch.rand((n, aw, aw, 3, imw, imw)).float()
cdat = torch.randint(0, 10, (n, ngenes, aw, aw)).float()
ldat = torch.randint(0, nclasses, (n, aw, aw)).long()

mm_gdat = MMStackDataset(TensorDataset(imdat, ldat), TensorDataset(cdat, ldat))

In [12]:
dl = DataLoader(mm_gdat, batch_size=2)
x, y = next(iter(dl))
print(x[0].shape, x[1].shape)
print(y.shape)

torch.Size([2, 4, 4, 3, 224, 224]) torch.Size([2, 1000, 4, 4])
torch.Size([2, 4, 4])


## Model instantiation

In [13]:
(xi,xc), y = mm_gdat[0]
image_shape = xi.shape[2:]
count_shape = xc.shape[0:1]
grid_shape = xi.shape[:2]

print(image_shape)
print(count_shape)
print(grid_shape)

torch.Size([3, 224, 224])
torch.Size([1000])
torch.Size([4, 4])


In [14]:
f_count = nn.Sequential(
    nn.Linear(xc.shape[0], 500),
    nn.Linear(500, 100),
    nn.BatchNorm1d(100),
    nn.ReLU(),

    nn.Linear(100, 100),
    nn.Linear(100, 50),
    nn.BatchNorm1d(50),
    nn.ReLU(),

    nn.Linear(50, len(class_names))
)

In [15]:
f_image = DenseNet(num_classes=len(class_names), small_inputs=False, efficient=False,
                   growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0)

In [16]:
g = GridNetHexMM(f_image, f_count, image_shape, count_shape, grid_shape, len(class_names))

dl = DataLoader(mm_gdat, batch_size=2, shuffle=False)

for x, y in dl:
    pp = g.patch_predictions(x)
    print(pp.shape)
    out = g(x)
    print(out.shape)

torch.Size([2, 14, 4, 4])
torch.Size([2, 7, 4, 4])
torch.Size([2, 14, 4, 4])
torch.Size([2, 7, 4, 4])
torch.Size([2, 14, 4, 4])
torch.Size([2, 7, 4, 4])


## Model training

In [17]:
from gridnext.training import train_gridwise

# Train g network
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(g.corrector.parameters(), lr=1e-3)

# Fixing the parameters of the patch classifier allows slightly faster training, even when only
# optimizing the parameters of the corrector. 
for param in g.patch_classifier.parameters():
    param.requires_grad = False

In [18]:
train_gridwise(g, {'train':dl, 'val':dl}, loss, optimizer)

Epoch 0/9
----------
train Loss: 2.4764 Acc: 0.1279
val Loss: 4.2891 Acc: 0.0930

Epoch 1/9
----------
train Loss: 1.9258 Acc: 0.2558
val Loss: 2.5540 Acc: 0.1512

Epoch 2/9
----------
train Loss: 1.6701 Acc: 0.3605
val Loss: 1.9742 Acc: 0.2907

Epoch 3/9
----------
train Loss: 1.5157 Acc: 0.4070
val Loss: 1.7174 Acc: 0.3256

Epoch 4/9
----------
train Loss: 1.4113 Acc: 0.4419
val Loss: 1.5540 Acc: 0.3953

Epoch 5/9
----------
train Loss: 1.3374 Acc: 0.5233
val Loss: 1.4580 Acc: 0.4186

Epoch 6/9
----------
train Loss: 1.2819 Acc: 0.5116
val Loss: 1.3821 Acc: 0.4419

Epoch 7/9
----------
train Loss: 1.2334 Acc: 0.5000
val Loss: 1.3188 Acc: 0.4419

Epoch 8/9
----------
train Loss: 1.1938 Acc: 0.5581
val Loss: 1.2677 Acc: 0.5233

Epoch 9/9
----------
train Loss: 1.1586 Acc: 0.5698
val Loss: 1.2214 Acc: 0.5814

Training complete in 2m 33s
Best val loss: 1.221394


(GridNetHexMM(
   (patch_classifier): DenseNet(
     (features): Sequential(
       (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
       (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu0): ReLU(inplace=True)
       (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
       (denseblock1): _DenseBlock(
         (denselayer1): _DenseLayer(
           (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
           (relu1): ReLU(inplace=True)
           (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
           (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
           (relu2): ReLU(inplace=True)
           (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         )
         (denselayer2): _DenseLayer(
           (norm1): BatchNo