In [1]:
import os
import json
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from scipy.io import mmread

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from sklearn.metrics import pairwise_distances
from scipy.sparse.csgraph import shortest_path

from datasets import TileDataset
from models import SimSiam

%matplotlib widget

In [2]:
errs = defaultdict(list)
class ErrorTileDataset(TileDataset):
    def __getitem__(self, idx):
        try:
            tile = super().__getitem__(idx)
        except Exception as e:
            global errs
            errs[self.name].append(idx)
            print(f'Check {self.name} {idx=} {e}')
            tile = torch.ones((3, 86, 86))
        return tile

In [3]:
mpath = '/mnt/data5/spatial/runs/all-slides/checkpoints/0999.pt'
chkpt = torch.load(mpath)

model = SimSiam(
    backbone='resnet50',
    projector_hidden_dim=2048,
    predictor_hidden_dim=512,
    output_dim=2048,
)
model.load_state_dict(chkpt['state_dict'])
model.to('cuda')
model.eval()

SimSiam(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          

In [4]:
all_embeddings = {}
for slide in [1, 2, 3, 4]:
    for section in ['A', 'B', 'C', 'D']:
        eval_section = f'slide{slide}/{section}1'

        cols = [
            "barcode",
            "in_tissue",
            "array_row",
            "array_col",
            "pxl_row_in_fullres",
            "pxl_col_in_fullres",
        ]

        cpath = os.path.join("/mnt/data5/spatial/count", eval_section, "outs/spatial")
        pos_df = pd.read_csv(
            os.path.join(cpath, "tissue_positions_list.csv"),
            header=None,
            names=cols,
        )
        pos_df = pos_df.loc[pos_df["in_tissue"] == 1].reset_index(drop=True)
        with open(os.path.join(cpath, "scalefactors_json.json")) as f:
            scale_factors = json.loads(f.read())
        spot_radius = scale_factors["spot_diameter_fullres"] / 2

        mean = chkpt['data_mean']
        std = chkpt['data_std']

        eval_tile_dir = os.path.join('/mnt/data5/spatial/tiles', eval_section)
        eval_ds = ErrorTileDataset(
            tile_dir=eval_tile_dir,
            file_ext='.png',
            transform=T.Normalize(mean=mean, std=std),
            name=eval_section,
        )

        # use the ordering of the tiles in the metadata
        new_tile_paths = eval_tile_dir + '/' + pos_df['barcode'] + '.png'
        ntps = new_tile_paths.sort_values().reset_index(drop=True)
        otps = pd.Series(eval_ds.tile_paths).sort_values().reset_index(drop=True)
        # check that actual tiles match metadata
        # if it doesnt, it's because the tile at the
        # spot was skipped for being non-square 
        # i.e. its on the border of the slide
        # assert ntps.equals(otps)
        eval_ds.tile_paths = new_tile_paths

        eval_dl = DataLoader(
            eval_ds,
            batch_size=256,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        embeddings = []
        for eval_step, tiles in enumerate(tqdm(eval_dl)):
            tiles = tiles.to('cuda')
            with torch.no_grad():
                embedding = model.encoder(tiles).to('cpu')
                embeddings.append(embedding)
        all_embeddings[eval_section] = torch.concatenate(embeddings, axis=0)

# for spots which didn't have tiles, replace with dummy embedding that will be maximally far
for eval_section, idx in errs.items():
    all_embeddings[eval_section][idx] = torch.full((2048,), fill_value=999, dtype=torch.float32)

os.makedirs('/mnt/data5/spatial/embeddings/all-slides-model', exist_ok=True)
torch.save(all_embeddings, f'/mnt/data5/spatial/embeddings/all-slides-model/embeddings.pt')

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:02<00:00,  1.20it/s]
100%|██████████| 3/3 [00:00<00:00,  4.45it/s]
100%|██████████| 12/12 [00:03<00:00,  3.71it/s]
100%|██████████| 9/9 [00:02<00:00,  4.25it/s]
100%|██████████| 3/3 [00:00<00:00,  4.63it/s]
100%|██████████| 4/4 [00:00<00:00,  5.05it/s]
100%|██████████| 3/3 [00:00<00:00,  3.97it/s]
100%|██████████| 4/4 [00:00<00:00,  4.51it/s]
100%|██████████| 15/15 [00:03<00:00,  4.32it/s]
100%|██████████| 15/15 [00:03<00:00,  4.29it/s]
100%|██████████| 17/17 [00:04<00:00,  4.24it/s]
100%|██████████| 14/14 [00:03<00:00,  4.25it/s]
100%|██████████| 13/13 [00:03<00:00,  4.11it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

Check slide4/B1 idx=0 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/CTGGGATCGCCCAGAT-1.png'
Check slide4/B1 idx=1 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/TTCCACATTTCTCGTC-1.png'
Check slide4/B1 idx=2 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/AGACGGGATTGGTATA-1.png'
Check slide4/B1 idx=3 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/TACAAATTGCGGAGGT-1.png'
Check slide4/B1 idx=4 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/CATTGCAAAGCATAAT-1.png'
Check slide4/B1 idx=5 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/TTCTTCGCAATAGAGC-1.png'
Check slide4/B1 idx=6 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/ATTCAGGATCGCCTCT-1.png'
Check slide4/B1 idx=7 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/B1/GCTGTCTGTGATCGAC-1.png'
Check slide4/B1 idx=8 [Errno 2] No such file or 

100%|██████████| 11/11 [00:02<00:00,  4.00it/s]
100%|██████████| 15/15 [00:03<00:00,  3.98it/s]


Check slide4/C1 idx=3786 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/C1/CCGCTCCGGATAAGCT-1.png'
Check slide4/C1 idx=3787 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/C1/ACGGACGCAGCGACAA-1.png'
Check slide4/C1 idx=3788 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/C1/ATTTATACTGGTAAAG-1.png'
Check slide4/C1 idx=3789 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/C1/CGGTATGGGCACTCTG-1.png'
Check slide4/C1 idx=3790 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/C1/ACCACACGGTTGATGG-1.png'


  0%|          | 0/13 [00:00<?, ?it/s]

Check slide4/D1 idx=0 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/TGCATGGCAGTCTTGC-1.png'
Check slide4/D1 idx=1 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/AGCTGCATTTGAGGTG-1.png'
Check slide4/D1 idx=2 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/CCATCATAAGAACAGG-1.png'
Check slide4/D1 idx=3 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/ATTCAGATGAATCCCT-1.png'
Check slide4/D1 idx=4 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/AGCTGCTGTGCCGAAT-1.png'
Check slide4/D1 idx=5 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/GCTCGACCGAACTGAA-1.png'
Check slide4/D1 idx=6 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/CGGCTGAAGGTTACGC-1.png'
Check slide4/D1 idx=7 [Errno 2] No such file or directory: '/mnt/data5/spatial/tiles/slide4/D1/ATACGACAGATGGGTA-1.png'


100%|██████████| 13/13 [00:02<00:00,  4.36it/s]
