In [1]:
import os
import sys
import json
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

sys.path.append('./src')
from dataloaders import TileDataset
from models import SimSiam, Triplet

In [2]:
# model_name = 'simsiam'
# epoch = '0999'
# mpath = f'/mnt/data5/spatial/runs/{model_name}-all-slides/checkpoints/{epoch}.pt'
# chkpt = torch.load(mpath)

# model = SimSiam(
#     backbone='resnet50',
#     projector_hidden_dim=2048,
#     predictor_hidden_dim=512,
#     output_dim=2048,
# )

In [2]:
# Build the Triplet model (ResNet-50 backbone) and load the checkpoint for
# `epoch` of the "<model_name>-all-slides" run. `model_name` and `epoch` are
# also used to name the output directory of the embedding cell.
model_name = 'triplet'
epoch = '0999'
mpath = f'/mnt/data5/spatial/runs/{model_name}-all-slides/checkpoints/{epoch}.pt'
# Map the checkpoint to CPU. The model is moved to the GPU explicitly after its
# weights are loaded, so this neither depends on the device the checkpoint was
# saved from nor briefly holds a second copy of the weights on the GPU.
chkpt = torch.load(mpath, map_location='cpu')

model = Triplet(
    backbone='resnet50',
    projector_hidden_dim=2048,
    output_dim=2048,
)

In [3]:
# Restore the trained weights, then move the model to the GPU and switch it to
# evaluation mode. The trailing semicolon stops the notebook from echoing the
# module repr.
model.load_state_dict(chkpt['state_dict'])
model.to('cuda').eval();

In [4]:
# Reference dataset over the "<section>1" tile directory of sections A-D on
# slides 1-4, ordered slide-major. It provides the normalisation statistics and
# the expected total number of tiles checked after embedding.
ds = TileDataset(
    name='train',
    tile_dirs=[
        f'/mnt/data5/spatial/tiles/slide{slide}/{section}1'
        for slide in range(1, 5)
        for section in 'ABCD'
    ],
)

In [5]:
# Normalise every evaluation tile with the statistics of the reference dataset
# (presumably per-channel mean/std -- confirm against TileDataset.get_mean_std).
mean, std = ds.get_mean_std()
norm = T.Normalize(mean, std)

Computing Train Dataset Norm: 100%|██████████| 36896/36896 [00:31<00:00, 1179.08it/s]


In [6]:
# Embed every in-tissue spot of every section with the trained encoder.
# Tiles are fed in the row order of the section's tissue_positions_list.csv,
# filtered to in_tissue == 1, with shuffle=False. Row i of each saved tensor
# therefore corresponds to row i of that filtered table.
cols = [
    "barcode",
    "in_tissue",
    "array_row",
    "array_col",
    "pxl_row_in_fullres",
    "pxl_col_in_fullres",
]
embedding_dir = f'/mnt/data5/spatial/embeddings/{model_name}-all-slides-{epoch}'

all_embeddings = {}
count = 0
for slide in [1, 2, 3, 4]:
    for section in ['A', 'B', 'C', 'D']:
        eval_section = f'slide{slide}/{section}1'
        pos_df = pd.read_csv(
            os.path.join('/mnt/data5/spatial/count', eval_section, 'outs/spatial/tissue_positions_list.csv'),
            header=None,
            names=cols,
        )
        pos_df = pos_df[pos_df['in_tissue'] == 1].reset_index(drop=True)
        # An empty section would otherwise only fail later, inside torch.cat([]).
        if pos_df.empty:
            raise ValueError(f'{eval_section}: no in-tissue spots in tissue_positions_list.csv')
        count += len(pos_df)

        eval_tile_dir = os.path.join('/mnt/data5/spatial/tiles', eval_section)
        eval_ds = TileDataset(
            name=eval_section,
            tile_dirs=[eval_tile_dir],
            transform=lambda ds, idx, x: norm(x),
        )

        # Use the ordering of the tiles in the metadata. First check that the
        # tiles on disk are exactly the in-tissue barcodes (order-insensitive).
        new_tile_paths = eval_tile_dir + '/' + pos_df['barcode'] + '.png'
        ntps = new_tile_paths.sort_values().reset_index(drop=True)
        otps = pd.Series(eval_ds.tile_paths).sort_values().reset_index(drop=True)
        assert ntps.equals(otps), f'{eval_section}: tiles on disk do not match in-tissue barcodes'
        # Store a plain list, not a Series. Series `[]` indexing is label-based,
        # which differs from positional indexing for e.g. negative indices.
        eval_ds.tile_paths = new_tile_paths.tolist()

        eval_dl = DataLoader(
            eval_ds,
            batch_size=256,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        embeddings = []
        with torch.no_grad():
            for tiles in tqdm(eval_dl, desc=eval_section):
                # non_blocking pairs with pin_memory=True. The .to('cpu') on the
                # output synchronises before the result is stored.
                tiles = tiles.to('cuda', non_blocking=True)
                embeddings.append(model.encoder(tiles).to('cpu'))
        section_embeddings = torch.cat(embeddings, dim=0)
        # Downstream code aligns rows with pos_df, so the counts must agree.
        assert len(section_embeddings) == len(pos_df), (
            f'{eval_section}: {len(section_embeddings)} embeddings for {len(pos_df)} spots'
        )
        all_embeddings[eval_section] = section_embeddings

# Every training tile must be accounted for by exactly one in-tissue spot.
assert count == len(ds.tile_paths), f'{count} in-tissue spots vs {len(ds.tile_paths)} training tiles'

os.makedirs(embedding_dir, exist_ok=True)
torch.save(all_embeddings, os.path.join(embedding_dir, 'embeddings.pt'))

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:02<00:00,  1.19it/s]
100%|██████████| 3/3 [00:00<00:00,  4.39it/s]
100%|██████████| 12/12 [00:03<00:00,  3.69it/s]
100%|██████████| 9/9 [00:02<00:00,  4.00it/s]
100%|██████████| 3/3 [00:00<00:00,  4.48it/s]
100%|██████████| 4/4 [00:00<00:00,  4.21it/s]
100%|██████████| 3/3 [00:00<00:00,  3.07it/s]
100%|██████████| 4/4 [00:00<00:00,  4.03it/s]
100%|██████████| 15/15 [00:04<00:00,  3.69it/s]
100%|██████████| 15/15 [00:03<00:00,  3.89it/s]
100%|██████████| 17/17 [00:04<00:00,  4.07it/s]
100%|██████████| 14/14 [00:03<00:00,  4.46it/s]
100%|██████████| 13/13 [00:03<00:00,  4.27it/s]
100%|██████████| 11/11 [00:02<00:00,  4.41it/s]
100%|██████████| 15/15 [00:03<00:00,  4.04it/s]
100%|██████████| 13/13 [00:03<00:00,  4.20it/s]
