In [6]:
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from matplotlib import pyplot as plt
from gridnext.densenet import DenseNet
from gridnext.plotting import plot_label_tensor
from gridnext.visium_datasets import create_visium_dataset
from gridnext.gridnet_models import GridNetHexMM
from gridnext.multimodal_datasets import MMStackDataset

In [11]:
# Download data from: https://zenodo.org/uploads/10372917

data_dir = '../data/BA44_testdata'

In [12]:
spaceranger_dirs = sorted(glob.glob(os.path.join(data_dir, 'spaceranger', '*')))
fullres_image_files = sorted(glob.glob(os.path.join(data_dir, 'fullres_images', '*.jpg')))
annot_files = sorted(glob.glob(os.path.join(data_dir, 'annotations', '*.csv')))

class_names = ['Layer_%d' % i for i in range(1, 7)] + ['White_matter']

### Dataset instantiation

In [4]:
img_data = create_visium_dataset(spaceranger_dirs, 
                                 use_image=True, use_count=False, spatial=True,
                                 fullres_image_files=fullres_image_files,
                                 annot_files=annot_files)

In [5]:
count_data = create_visium_dataset(spaceranger_dirs, 
                                   use_image=False, use_count=True, spatial=True,
                                   minimum_detection_rate=0.02,
                                   annot_files=annot_files)

In [6]:
mm_gdat = create_visium_dataset(spaceranger_dirs, spatial=True,
                                use_image=True, use_count=True, 
                                fullres_image_files=fullres_image_files,
                                annot_files=annot_files)

'''
for i in range(len(mm_gdat)):
    (xi,xc), y = mm_gdat[i]
    fig, ax = plt.subplots(1, figsize=(2,2))
    plot_label_tensor(y, class_names=class_names, Visium=True, ax=ax)
    print(xi.shape, xi.min(), xi.max())
    print(xc.shape, xc.min(), xc.max())
'''
print(len(mm_gdat))

6


### Smaller dummy dataset

In [21]:
n, aw, imw, ngenes, nclasses = 6, 4, 224, 1000, 7
imdat = torch.rand((n, aw, aw, 3, imw, imw))
cdat = torch.randint(0, 10, (n, ngenes, aw, aw))
ldat = torch.randint(0, nclasses, (n, aw, aw))

mm_gdat = MMStackDataset(TensorDataset(imdat, ldat), TensorDataset(cdat, ldat))

### Model instantiation

In [22]:
(xi,xc), y = mm_gdat[0]
image_shape = xi.shape[2:]
count_shape = xc.shape[0:1]
grid_shape = xi.shape[:2]

print(image_shape)
print(count_shape)
print(grid_shape)

torch.Size([3, 224, 224])
torch.Size([1000])
torch.Size([4, 4])


In [23]:
f_count = nn.Sequential(
    nn.Linear(xc.shape[0], 500),
    nn.Linear(500, 100),
    nn.BatchNorm1d(100),
    nn.ReLU(),

    nn.Linear(100, 100),
    nn.Linear(100, 50),
    nn.BatchNorm1d(50),
    nn.ReLU(),

    nn.Linear(50, len(class_names))
)

In [24]:
f_image = DenseNet(num_classes=len(class_names), small_inputs=False, efficient=False,
                   growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0)

In [25]:
g = GridNetHexMM(f_image, f_count, image_shape, count_shape, grid_shape, len(class_names))

dl = DataLoader(mm_gdat, batch_size=1, shuffle=False)

for x, y in dl:
    pp = g.patch_predictions(x)
    print(pp.shape)
    out = g(x)
    print(out.shape)

RuntimeError: mat1 and mat2 must have the same dtype, but got Long and Float