## Step 2: Training a simple model

In [1]:
# install pathology-whole-slide-data if needed
!pip3 install git+https://github.com/DIAGNijmegen/pathology-whole-slide-data@main

Collecting git+https://github.com/DIAGNijmegen/pathology-whole-slide-data@main
  Cloning https://github.com/DIAGNijmegen/pathology-whole-slide-data (to revision main) to /tmp/pip-req-build-0bowo24x
  Running command git clone --filter=blob:none --quiet https://github.com/DIAGNijmegen/pathology-whole-slide-data /tmp/pip-req-build-0bowo24x
  Resolved https://github.com/DIAGNijmegen/pathology-whole-slide-data to commit 1e4c6ca939c5e372a0b626739c24443f31bff505
  Preparing metadata (setup.py) ... [?25ldone


In [2]:
# install detectron2 if needed
!pip3 install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html

Looking in links: https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html


In [3]:
import os
import time
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt

from wholeslidedata.interoperability.detectron2.iterator import WholeSlideDetectron2Iterator
from wholeslidedata.interoperability.detectron2.trainer import WholeSlideDectectron2Trainer
from wholeslidedata.interoperability.detectron2.predictor import Detectron2DetectionPredictor
from wholeslidedata.iterators import create_batch_iterator
from wholeslidedata.visualization.plotting import plot_boxes

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.modeling import build_model

Setting up the training configuration and parameters (can also be defined in a separate yaml file).

In [None]:
# Settings for wholeslidedata, passed both to create_batch_iterator and to the trainer.
# Base settings come from the YAML file named in 'yaml_source'; the keys below are
# presumably merged on top of it — confirm against the wholeslidedata documentation.
user_config = {
    'wholeslidedata': {
        'default': {
            'yaml_source': "./configs/training_sample.yml",
            # "seed": 42,  # NOTE(review): disabled here; without a seed, sampled patches may differ between runs
            "image_backend": "asap",
            # Label name -> integer id.
            'labels': {"ROI": 0, "lymphocytes": 1},
            # Patches per batch, pixel spacing, patch shape, and label-array shape.
            # y_shape's first dimension (1000) appears to mirror max_number_objects below.
            'batch_shape': {
                'batch_size': 10,
                'spacing': 0.5,
                'shape': [128, 128, 3],
                'y_shape': [1000, 6],
            },
            # Presumably restricts patch sampling to annotations labelled 'roi'.
            "annotation_parser": {"sample_label_names": ['roi']},
            'point_sampler_name': "RandomPointSampler",
            # "${...}" looks like config interpolation referencing batch_shape.spacing;
            # the negative buffer presumably shrinks the sampling region — confirm.
            'point_sampler': {
                "buffer": {'spacing': "${batch_shape.spacing}", 'value': -64},
            },
            'patch_label_sampler_name': 'DetectionPatchLabelSampler',
            'patch_label_sampler': {
                "max_number_objects": 1000,
                "detection_labels": ['lymphocytes'],
            },
        }
    }
}

Creating the training batch generator, which yields detectron2-compatible batches of whole-slide patches.

In [None]:
# Training-mode iterator whose batches are detectron2-style dicts
# (produced by WholeSlideDetectron2Iterator), sampled according to `user_config`.
# A YAML path such as r'./configs/training_config.yml' may be passed as
# `user_config` instead of the in-notebook dict.
# NOTE(review): unlike the evaluation cell, this iterator is not opened with `with`,
# so it is never explicitly shut down.
training_batch_generator = create_batch_iterator(
    mode='training',
    user_config=user_config,
    iterator_class=WholeSlideDetectron2Iterator,
    cpus=1,
)

Visualizing sample training batches: the first 8 patches of each of 20 batches, with their ground-truth boxes.

In [None]:
# Plot the first few patches of several training batches with their ground-truth boxes.
N_BATCHES_TO_SHOW = 20
N_PATCHES_PER_ROW = 8

for _ in range(N_BATCHES_TO_SHOW):
    batch_dicts = next(training_batch_generator)
    # Never ask for more panels than the batch holds.
    n_show = min(N_PATCHES_PER_ROW, len(batch_dicts))
    # squeeze=False keeps `axes` 2-D even when n_show == 1.
    fig, axes = plt.subplots(1, n_show, figsize=(20, 10), squeeze=False)
    for i, ax in enumerate(axes[0]):
        sample = batch_dicts[i]
        # The image tensor is channel-first; transpose to (H, W, C) for imshow.
        patch = sample['image'].cpu().detach().numpy().transpose(1, 2, 0).astype('uint8')
        gt_boxes = sample['instances'].gt_boxes.tensor.cpu().detach().numpy()
        # plot_boxes is fed 6 columns: the 4 box coordinates, then two extra columns filled
        # with 1 (presumably label and confidence — confirm against plot_boxes).
        boxes = np.ones((len(gt_boxes), 6))
        boxes[:, :4] = gt_boxes
        # detectron2 Instances.image_size is (height, width), not (width, height).
        max_height, max_width = sample['instances'].image_size
        ax.imshow(patch)
        plot_boxes(boxes, max_width=max_width, max_height=max_height, axes=ax)
    plt.show()

Creating the output folder for saving the model and results.

In [None]:
# Folder for checkpoints and results. parents/exist_ok make this idempotent on re-run.
output_folder = Path('./outputs')
output_folder.mkdir(parents=True, exist_ok=True)
# Number of CPUs handed to the whole-slide detectron2 trainer.
cpus = 4

### Training the model

In [None]:
# Build the detectron2 config from the model-zoo Faster R-CNN (X-101 32x8d + FPN)
# and adapt it to single-class detection, then train.
cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
)

cfg.DATASETS.TRAIN = ("detection_dataset2",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 1

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # one foreground class: 'lymphocytes'
# A single list of anchor sizes is reused by detectron2 for every FPN level.
cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[16, 24, 32]]

# Keep the solver's images-per-batch in sync with the patch batch size sampled by wholeslidedata.
cfg.SOLVER.IMS_PER_BATCH = user_config['wholeslidedata']['default']['batch_shape']['batch_size']
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.MAX_ITER = 2000  # enough for this toy dataset; a practical dataset may need longer training
# NOTE(review): with GAMMA = 0.5 the LR halves at iterations 10, 100 and 250, so it stays at
# 0.000125 for the remaining 1750 iterations — confirm this early decay is intended.
cfg.SOLVER.STEPS = (10, 100, 250)
cfg.SOLVER.WARMUP_ITERS = 0
cfg.SOLVER.GAMMA = 0.5

cfg.OUTPUT_DIR = str(output_folder)
output_folder.mkdir(parents=True, exist_ok=True)

# Count trainable parameters on a throwaway model, then drop it: the trainer builds its
# own model from `cfg`, so keeping this one would hold two full networks in memory.
model = build_model(cfg)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameter Count:\n{pytorch_total_params}")
del model

trainer = WholeSlideDectectron2Trainer(cfg, user_config=user_config, cpus=cpus)
trainer.resume_or_load(resume=False)
trainer.train()

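### Evaluation

Rebuild the configuration for inference, load the trained weights (`model_final.pth`), and visualize the predicted boxes. The patches are drawn from the training sampler, so this is a qualitative sanity check, not a held-out evaluation.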

In [None]:
# Rebuild the config for inference. The architecture settings (base config, NUM_CLASSES,
# anchor sizes) must match the training cell so the saved weights fit the network.
cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
)

cfg.DATASETS.TRAIN = ("detection_dataset2",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 1

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512  # training-time setting; mirrors the training cell
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[16, 24, 32]]

# Solver settings mirror the training run for reference; they do not affect inference.
cfg.SOLVER.IMS_PER_BATCH = 10
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.MAX_ITER = 2000
cfg.SOLVER.STEPS = (10, 100, 250)
cfg.SOLVER.WARMUP_ITERS = 0
cfg.SOLVER.GAMMA = 0.5

# Strict NMS: a box overlapping a higher-scoring box with IoU > 0.1 is suppressed.
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.1

cfg.OUTPUT_DIR = str(output_folder)

# Fail early with a clear message instead of deep inside checkpoint loading.
weights_path = Path(output_folder) / "model_final.pth"
if not weights_path.is_file():
    raise FileNotFoundError(
        f"Trained weights not found at {weights_path}; run the training cell first."
    )
cfg.MODEL.WEIGHTS = str(weights_path)

# Parameter count only; build_model does not load weights, so this model is discarded.
model = build_model(cfg)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameter Count:\n{pytorch_total_params}")
del model

In [None]:
# Run the trained model on freshly sampled (training-mode) patches and plot the
# predicted boxes whose confidence exceeds a threshold. Qualitative check only.
CONFIDENCE_THRESHOLD = 0.3  # minimum confidence for a predicted box to be drawn
N_BATCHES_TO_SHOW = 10

predictor = Detectron2DetectionPredictor(cfg)
with create_batch_iterator(
    user_config=user_config,
    mode='training',
    cpus=cpus,
) as training_batch_generator:
    for _ in range(N_BATCHES_TO_SHOW):
        batch_x, batch_y, info = next(training_batch_generator)
        predicted_batch = predictor.predict_on_batch(batch_x)
        # One panel per patch; squeeze=False keeps `axes` 2-D for a single-patch batch.
        fig, axes = plt.subplots(1, len(batch_x), figsize=(20, 10), squeeze=False)
        for i, ax in enumerate(axes[0]):
            patch = batch_x[i]
            prediction = predicted_batch[i]
            filtered_boxes = [
                box
                for box, confidence in zip(prediction['boxes'], prediction['confidences'])
                if confidence > CONFIDENCE_THRESHOLD
            ]
            # Bound boxes by the actual patch size (128x128 here), not a hard-coded 64.
            max_height, max_width = np.asarray(patch).shape[:2]
            ax.imshow(patch)
            plot_boxes(filtered_boxes, max_width=max_width, max_height=max_height, axes=ax)
        plt.show()