In [1]:
import os

# Pin the GPU *before* torch is imported: CUDA reads CUDA_VISIBLE_DEVICES when
# it initialises, so setting it here guarantees it takes effect. Inside this
# process the selected device is then addressed as 'cuda' / 'cuda:0'.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import sys
import json
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm.auto import tqdm  # notebook widget when available, text bar otherwise

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Project-local modules live in ./src (relative to the notebook's working dir).
sys.path.append('./src')
from dataloaders import TileDataset
from models import SimSiam, Triplet

In [2]:
# model_name = 'simsiam-all-slides'
# epoch = '0999'
# mpath = f'/mnt/data5/spatial/runs/{model_name}/checkpoints/{epoch}.pt'
# chkpt = torch.load(mpath)

# model = SimSiam(
#     backbone='resnet50',
#     projector_hidden_dim=2048,
#     predictor_hidden_dim=512,
#     output_dim=2048,
# )

In [3]:
# Checkpoint of the Triplet model whose encoder is used to embed the tiles.
model_name = 'triplet-gi'
epoch = '0999'  # zero-padded checkpoint epoch; also part of the saved embedding filename
RUNS_DIR = '/mnt/data5/spatial/runs'
mpath = f'{RUNS_DIR}/{model_name}/checkpoints/{epoch}.pt'
# map_location='cpu': a checkpoint saved from e.g. cuda:1 cannot be restored
# onto that device when only one GPU is visible, so deserialise to CPU and
# move the model to the GPU after the weights are loaded.
# NOTE: torch.load unpickles the file — only use it on trusted checkpoints.
chkpt = torch.load(mpath, map_location='cpu')

model = Triplet(
    backbone='resnet50',
    projector_hidden_dim=2048,
    output_dim=2048,
)

In [4]:
# Restore the trained weights (load_state_dict matches keys strictly by
# default), then move the network onto the visible GPU and put it in eval
# mode. Module.to() and Module.eval() both return the module itself, so the
# assignment rebinds `model` to the same object and the cell displays nothing.
model.load_state_dict(chkpt['state_dict'])
model = model.to('cuda').eval()

In [5]:
# Tile folders of every section to embed, built as
# <data_root>/<organ>/<condition>/<section>/tiles and kept in a fixed order.
data_root = '/mnt/data5/spatial/data'
sections = [
    # (organ, condition, section letters)
    ('colon', 'CD', 'ABCD'),
    ('colon', 'UC', 'ABCD'),
    ('colon', 'normal', 'A'),
    ('colon', 'C.diff', 'ABC'),
    ('stomach', 'normal', 'A'),
    ('stomach', 'H.pylori', 'ABC'),
]
tile_dirs = [
    f'{data_root}/{organ}/{condition}/{letter}/tiles'
    for organ, condition, letters in sections
    for letter in letters
]

In [6]:
# Per-channel mean/std are computed once over the tiles of all sections and
# reused to normalise every per-section evaluation dataset in the next cell
# (presumably matching the training-time preprocessing — confirm).
ds = TileDataset(name='train', tile_dirs=tile_dirs)
mean, std = ds.get_mean_std()
norm = T.Normalize(mean, std)

Computing Train Dataset Norm: 100%|██████████| 36896/36896 [00:34<00:00, 1077.99it/s]


In [7]:
# Embed every in-tissue spot of every section with the model's encoder and
# save one tensor per section. Row i of a saved tensor corresponds to row i of
# that section's in-tissue entries in tissue_positions_list.csv.

# The positions file has no header row, so its column names are supplied here.
cols = [
    "barcode",
    "in_tissue",
    "array_row",
    "array_col",
    "pxl_row_in_fullres",
    "pxl_col_in_fullres",
]

count = 0  # in-tissue spots across all sections; checked against `ds` at the end
for tile_dir in tile_dirs:
    # Parent of the `tiles` folder. Unlike tile_dir[:-5], this does not depend
    # on the folder name's length or break on a trailing slash.
    section_dir = os.path.dirname(os.path.normpath(tile_dir))
    ppath = os.path.join(section_dir, 'outs/spatial/tissue_positions_list.csv')
    pos_df = pd.read_csv(ppath, header=None, names=cols)
    pos_df = pos_df[pos_df['in_tissue'] == 1].reset_index(drop=True)
    count += len(pos_df)

    eval_ds = TileDataset(
        name='eval',
        tile_dirs=[tile_dir],
        transform=lambda ds, idx, x: norm(x),
    )

    # use the ordering of the tiles in the metadata
    new_tile_paths = tile_dir + '/' + pos_df['barcode'] + '.png'
    # check that the tiles on disk are exactly the in-tissue spots in the metadata
    ntps = new_tile_paths.sort_values().reset_index(drop=True)
    otps = pd.Series(eval_ds.tile_paths).sort_values().reset_index(drop=True)
    assert ntps.equals(otps), f'tiles in {tile_dir} do not match {ppath}'
    # Assign a plain list rather than a pandas Series, so positional indexing
    # by the dataset does not go through Series label lookup
    # (presumably the same type TileDataset builds itself — confirm).
    eval_ds.tile_paths = new_tile_paths.tolist()

    eval_dl = DataLoader(
        eval_ds,
        batch_size=256,
        shuffle=False,  # keep metadata order so embedding rows line up with pos_df
        num_workers=0,
        pin_memory=True,
    )
    embeddings = []
    with torch.no_grad():
        for tiles in tqdm(eval_dl, desc=section_dir):
            # non_blocking pairs with pin_memory=True for async host->GPU copies
            tiles = tiles.to('cuda', non_blocking=True)
            embeddings.append(model.encoder(tiles).to('cpu'))
    if not embeddings:
        # torch.cat on an empty list would fail with an opaque error
        raise ValueError(f'no in-tissue tiles to embed in {tile_dir}')
    embeddings = torch.cat(embeddings, dim=0)
    assert len(embeddings) == len(pos_df), f'embedding/spot count mismatch for {tile_dir}'

    embedding_dir = os.path.join(section_dir, 'embeddings')
    os.makedirs(embedding_dir, exist_ok=True)
    torch.save(embeddings, os.path.join(embedding_dir, f'{model_name}-{epoch}.pt'))

# the per-section in-tissue spot counts must add up to the full tile set
assert count == len(ds.tile_paths), f'{count} in-tissue spots vs {len(ds.tile_paths)} tiles'

  0%|          | 0/15 [00:00<?, ?it/s]

100%|██████████| 15/15 [00:05<00:00,  2.79it/s]
100%|██████████| 11/11 [00:02<00:00,  4.23it/s]
100%|██████████| 17/17 [00:03<00:00,  4.38it/s]
100%|██████████| 13/13 [00:02<00:00,  4.34it/s]
100%|██████████| 13/13 [00:03<00:00,  3.97it/s]
100%|██████████| 15/15 [00:03<00:00,  4.36it/s]
100%|██████████| 15/15 [00:03<00:00,  4.10it/s]
100%|██████████| 14/14 [00:03<00:00,  4.37it/s]
100%|██████████| 3/3 [00:00<00:00,  3.99it/s]
100%|██████████| 3/3 [00:00<00:00,  4.71it/s]
100%|██████████| 12/12 [00:02<00:00,  4.44it/s]
100%|██████████| 9/9 [00:02<00:00,  4.40it/s]
100%|██████████| 3/3 [00:00<00:00,  4.68it/s]
100%|██████████| 4/4 [00:00<00:00,  4.67it/s]
100%|██████████| 3/3 [00:00<00:00,  3.65it/s]
100%|██████████| 4/4 [00:00<00:00,  4.70it/s]
