In [None]:
from utils.datasets import WSIDataset
from utils.wsi_infer import predict_over_wsi, wsi_heatmap
from torch.utils.data import DataLoader
import os

In [4]:
# Whole-slide image (WSI) dataset and its loader.
# batch_size=1 -> one dataset item (presumably one slide) per batch;
# shuffle=False -> items are visited in the dataset's own index order.
# NOTE(review): 'exmaple_wsis' looks like a misspelling of 'example_wsis';
# confirm the directory name on disk before correcting it.
wsi_dir = 'data/exmaple_wsis'
wsi_dataset = WSIDataset(img_dir=wsi_dir)
wsi_dataloader = DataLoader(
    wsi_dataset,
    batch_size=1,
    shuffle=False,
)

In [9]:
# --- Feature extractor and patch scorer -------------------------------------
import pickle

from rl_benchmarks.models import iBOTViT

# iBOT ViT backbone ("vit_base_pancan" architecture), built from its teacher
# encoder and the local checkpoint below, then switched to eval mode for
# inference (assumes a torch nn.Module -- TODO confirm).
weights_path = 'data/ibot_vit_base_pancan.pth'
ibot_base_pancancer = iBOTViT(
    architecture="vit_base_pancan",
    encoder="teacher",
    weights_path=weights_path,
)
ibot_base_pancancer.eval()
# Image preprocessing exposed by the extractor itself.
transform = ibot_base_pancancer.transform

# The scorer is a skorch-wrapped Chowder classifier saved with pickle; the
# unpickled object is indexed by the key 'model' to reach the skorch net.
# SECURITY: pickle.load can execute arbitrary code embedded in the file --
# only load model files you produced yourself or otherwise trust.
scores_model_file = 'data/example_model.pkl'
with open(scores_model_file, 'rb') as model_file:
    example_model = pickle.load(model_file)

# Keep only the Chowder "extreme layer" (its scoring portion). `.module_` is
# the torch module held inside the skorch wrapper; per the original author's
# note, a plain (non-skorch) torch model should skip the "['model'].module_"
# lookup and take `.extreme_layer` directly.
scores_model = example_model['model'].module_.extreme_layer

[32m2024-02-26 14:51:00.024[0m | [1mINFO    [0m | [36mrl_benchmarks.models.feature_extractors.ibot_vit[0m:[36m__init__[0m:[36m78[0m - [1mPretrained weights found at data/ibot_vit_base_pancan.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v', 'head.last_layer2.weight_g', 'head.last_layer2.weight_v'])[0m


In [None]:
# Run inference over every slide in `wsi_dataloader` using the feature
# extractor and scorer built earlier, with outputs directed to `outdir`.
outdir = 'data/example_out'
# Patch size and magnification are specific to the feature-extraction
# method, so they must match the extractor in use.
wsi_patch_size = (224, 224)
req_mag = 20

# Create the output directory; no-op if it already exists.
os.makedirs(outdir, exist_ok=True)

# NOTE(review): keyword names -- including the 'postive_' spelling -- must
# match predict_over_wsi's signature in utils.wsi_infer; confirm there
# before renaming any of them.
predict_over_wsi(outdir=outdir,
                 wsi_patch_size=wsi_patch_size,
                 req_mag=req_mag,  # use the configured value, not a second hard-coded 20
                 feature_extractor=ibot_base_pancancer,
                 feature_extractor_transforms=transform,
                 feature_extractor_device='cpu',
                 scorer_model=scores_model,
                 scorer_model_device='cpu',
                 wsi_dataloader=wsi_dataloader,
                 postive_tissue_threshold=1.0,
                 num_patches='all',
                 min_patch_number=10,
                 feature_batch_size=100,
                 cache_dir='cache')

In [None]:
# make and save heatmaps from inference result file
def wsi_heatmap(wsi_dataloader,
                scores_df,
                col_to_show:str,
                file_col: str,
                outdir:str, 
                coord_col_1: str = 'start_coord_0',
                coord_col_2: str = 'start_coord_1',
                patch_size_col:str = 'patch_size',
                display_mag:int = 2,
                cmap:str = 'viridis'):