In [3]:
import torch
import torch.nn as nn
from math import floor
import os
import random
import numpy as np
import pandas as pd
import pdb
import time
from datasets.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag
from torch.utils.data import DataLoader
from models.resnet_custom import resnet50_baseline
import argparse
from utils.utils import print_network, collate_features
from utils.file_utils import save_hdf5
from PIL import Image
import h5py

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [2]:
def compute_w_loader(file_path, output_path, model, batch_size=8, verbose=0,
                     print_every=20, pretrained=True, target_patch_size=-1,
                     num_workers=4):
    """Run `model` over every patch of one bag and write the features to HDF5.

    Patches are read from `file_path` through `Whole_Slide_Bag`, batched by a
    `DataLoader`, and moved to the module-level `device`. Each batch's model
    output (as a numpy array) is written with its patch coordinates to
    `output_path` via `save_hdf5`.

    args:
        file_path: path of the bag (.h5 file) to read patches from
        output_path: path of the .h5 file the computed features are written to;
            written with mode 'w' for the first batch and 'a' for later ones
        model: pytorch model used as the feature extractor
        batch_size: number of patches per forward pass
        verbose: level of feedback; > 0 also prints a per-bag summary line
        print_every: print a progress line every `print_every` batches;
            values <= 0 disable the per-batch progress line
        pretrained: forwarded to `Whole_Slide_Bag` (presumably selects the
            ImageNet-pretrained preprocessing — confirm in datasets.dataset_h5)
        target_patch_size: forwarded to `Whole_Slide_Bag`
            (NOTE(review): -1 presumably means "keep native size" — confirm)
        num_workers: DataLoader worker processes, only used when `device` is CUDA
    returns:
        output_path
    """
    dataset = Whole_Slide_Bag(file_path=file_path, pretrained=pretrained,
                              target_patch_size=target_patch_size)
    # Extra workers and pinned host memory are only requested for GPU runs.
    loader_kwargs = ({'num_workers': num_workers, 'pin_memory': True}
                     if device.type == "cuda" else {})
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        collate_fn=collate_features, **loader_kwargs)
    num_batches = len(loader)

    if verbose > 0:
        print('processing {}: total of {} batches'.format(file_path, num_batches))

    # The first write creates/truncates the output file; later batches append.
    mode = 'w'
    with torch.no_grad():  # inference only: no autograd graph needed
        for count, (batch, coords) in enumerate(loader):
            if print_every > 0 and count % print_every == 0:
                print('batch {}/{}, {} files processed'.format(count, num_batches, count * batch_size))
            batch = batch.to(device, non_blocking=True)
            features = model(batch).cpu().numpy()

            asset_dict = {'features': features, 'coords': coords}
            save_hdf5(output_path, asset_dict, attr_dict=None, mode=mode)
            mode = 'a'

    return output_path

def create_folder(folder_path, clean_folder = True):
    """Ensure `folder_path` exists as a directory, optionally emptying it.

    Missing parent directories are created too. If the folder already exists
    and `clean_folder` is True, every file and symlink directly inside it is
    deleted. Sub-directories are left in place; a plain `os.remove` would
    raise IsADirectoryError on them.

    Args:
        folder_path: directory to create or clean.
        clean_folder: when the folder already exists, delete its files.
    """
    if not os.path.isdir(folder_path):
        # makedirs (not mkdir) so nested paths such as 'FEATURE_DIRECTORY/h5_files'
        # work even when the parent directory does not exist yet.
        os.makedirs(folder_path, exist_ok=True)
        return

    if clean_folder:
        # Collect targets first so the directory is not modified mid-iteration.
        with os.scandir(folder_path) as entries:
            targets = [entry.path for entry in entries
                       if entry.is_symlink() or entry.is_file(follow_symlinks=False)]
        for path in targets:
            os.remove(path)

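Expected layout of the input results directory: one `.h5` patch bag per slide under `h5/`, plus a CSV under `csv/` listing the slide IDs.

```
RESULTS_DIRECTORY_ST/   (e.g. RESULTS_test_prostate)
├── h5/
│   ├── prostate_adenocarcinoma_1.3.0.h5
│   └── ...
└── csv/
    └── slide_id.csv
```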

`slide_id.csv` has a `slide_id` header followed by one slide ID per row, for example:

```
slide_id
prostate_adenocarcinoma_1.3.0
```

In [4]:
# SET UP FILE PATHS
print('initializing dataset')
data_dir = './RESULTS_test_prostate'
csv_path = f'{data_dir}/csv/slide_id.csv'
feat_dir = './FEATURE_DIRECTORY'
slide_ext = '.tif'
batch_size = 256
target_patch_size = 224
pretrained = True

bags_dataset = Dataset_All_Bags(csv_path)
total = len(bags_dataset)
# Fail early with a clear message rather than an IndexError / silent no-op loop.
assert total > 0, f'no slides listed in {csv_path}'

os.makedirs(feat_dir, exist_ok=True)

# LOAD MODEL
print('loading model checkpoint')
model = resnet50_baseline(pretrained=pretrained)
model = model.to(device)

# Spread batches across all visible GPUs when there is more than one.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# Inference mode; trailing ';' keeps the notebook from dumping the model repr.
model.eval();

initializing dataset
loading model checkpoint


In [5]:
# Create (and empty) the output folders ONCE, before the loop. Cleaning them
# inside the loop would delete every previously processed slide's outputs,
# leaving only the last slide's features on disk.
output_folder = os.path.join(feat_dir, 'h5_files')
pt_folder = os.path.join(feat_dir, 'pt_files')
create_folder(output_folder, clean_folder=True)
create_folder(pt_folder, clean_folder=True)

for bag_candidate_idx in range(total):
    slide_id = bags_dataset[bag_candidate_idx].split(slide_ext)[0]
    bag_name = slide_id + '.h5'
    bag_candidate = os.path.join(data_dir, 'h5', bag_name)

    print('\nprogress: {}/{}'.format(bag_candidate_idx, total))
    print(bag_name)

    output_path = os.path.join(output_folder, bag_name)
    time_start = time.time()
    output_file_path = compute_w_loader(bag_candidate, output_path,
                                        model=model, batch_size=batch_size,
                                        verbose=1, print_every=20,
                                        pretrained=pretrained,
                                        target_patch_size=target_patch_size)
    time_elapsed = time.time() - time_start
    print('\ncomputing features for {} took {} s'.format(output_file_path, time_elapsed))

    # Context manager closes the HDF5 handle (previously one leaked per slide).
    with h5py.File(output_file_path, "r") as h5_file:
        features = h5_file['features'][:]
        print('features size: ', features.shape)
        print('coordinates size: ', h5_file['coords'].shape)

    features = torch.from_numpy(features)
    # bag_name is slide_id + '.h5', so slide_id is already the base name.
    torch.save(features, os.path.join(pt_folder, slide_id + '.pt'))


progress: 0/1
prostate_adenocarcinoma_1.3.0.h5
pretrained: True
transformations: <torchvision.transforms.Compose object at 0x7fc9f1bbe7d0>
target_size:  (224, 224)
processing ./RESULTS_test_prostate/h5/prostate_adenocarcinoma_1.3.0.h5: total of 18 batches
batch 0/18, 0 files processed

computing features for ./FEATURE_DIRECTORY/h5_files/prostate_adenocarcinoma_1.3.0.h5 took 6.6557183265686035 s
features size:  (4371, 1024)
coordinates size:  (4371, 2)
