In [5]:
import argparse
import os

import torch
import torch.optim as optim

from hover_net.dataloader import get_dataloader
from hover_net.models import HoVerNetExt
from hover_net.process import proc_valid_step_output, train_step, valid_step
from hover_net.tools.coco import coco_evaluation_pipeline
from hover_net.tools.utils import (dump_yaml, read_yaml, update_accumulated_output)

## Set configurations

In [6]:
config = {
    "DATA": {
        "TRAIN_COCO_JSON": "./data/PanNuke_COCO/Fold 1/fold1.json",
        "VALID_COCO_JSON": "./data/PanNuke_COCO/Fold 1/fold1.json",
        "CLASSES": ["neoplastic", "inflammatory", "softtissue", "dead", "epithelial"],
        "NUM_TYPES": 5,
        "PATCH_SIZE": 256
    },
    "TRAIN": {
        "DEVICE": "cpu",
        "EPOCHS": 160,
        "BATCH_SIZE": 4
    },
    "EVAL": {
        "COCO_EVAL_STEP": 5,
        "COCO_EVAL_CAT_IDS": [1, 2, 3, 4, 5]
    },
    "MODEL": {
        "BACKBONE": "resnet",
        "PRETRAINED": "./pretrained/resnet50-0676ba61.pth",
        "NUM_TYPES": 3
    },
    "LOGGING": {
        "SAVE_STEP": 5,
        "SAVE_PATH": "./experiments/initial/",
        "VERBOSE": True
    }
}

In [7]:
train_dataloader = get_dataloader(
        dataset_type="coco",
        ann_file=config["DATA"]["TRAIN_COCO_JSON"],
        classes=config["DATA"]["CLASSES"],
        input_shape=(
            config["DATA"]["PATCH_SIZE"],
            config["DATA"]["PATCH_SIZE"]
        ),
        mask_shape=(
            config["DATA"]["PATCH_SIZE"],
            config["DATA"]["PATCH_SIZE"]
        ),
        batch_size=config["TRAIN"]["BATCH_SIZE"],
        run_mode="train",
    )

val_dataloader = get_dataloader(
        dataset_type="coco",
        ann_file=config["DATA"]["VALID_COCO_JSON"],
        classes=config["DATA"]["CLASSES"],
        input_shape=(
            config["DATA"]["PATCH_SIZE"],
            config["DATA"]["PATCH_SIZE"]
        ),
        mask_shape=(
            config["DATA"]["PATCH_SIZE"],
            config["DATA"]["PATCH_SIZE"]
        ),
        batch_size=config["TRAIN"]["BATCH_SIZE"],
        run_mode="val",
    )

loading annotations into memory...
Done (t=0.99s)
creating index...
index created!
loading annotations into memory...
Done (t=0.97s)
creating index...
index created!


## Define model

In [8]:
model = HoVerNetExt(
        backbone_name=config["MODEL"]["BACKBONE"],
        pretrained_backbone=config["MODEL"]["PRETRAINED"],
        num_types=config["MODEL"]["NUM_TYPES"],
    )

Loading: ./pretrained/resnet50-0676ba61.pth


In [9]:
optimizer = optim.Adam(model.parameters(), lr=1.0e-4, betas=(0.9, 0.999))

In [10]:
model.to(config["TRAIN"]["DEVICE"])

HoVerNetExt(
  (backbone): ResNetExt(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0):

In [11]:
os.makedirs(config["LOGGING"]["SAVE_PATH"], exist_ok=True)

In [12]:
dump_yaml(
        os.path.join(
            config["LOGGING"]["SAVE_PATH"],
            "config.yaml"
        ),
        config
    )

## Train model

In [None]:
for epoch in range(config["TRAIN"]["EPOCHS"]):
    accumulated_output = {}
    for step_idx, data in enumerate(train_dataloader):
        train_result_dict = train_step(
            epoch,
            step_idx,
            batch_data=data,
            model=model,
            optimizer=optimizer,
            device=config["TRAIN"]["DEVICE"],
            show_step=1,
            verbose=config["LOGGING"]["VERBOSE"],
        )

    for step_idx, data in enumerate(val_dataloader):
        valid_result_dict = valid_step(
            epoch, step_idx,
            batch_data=data,
            model=model,
            device=config["TRAIN"]["DEVICE"]
        )
        update_accumulated_output(accumulated_output, valid_result_dict)

    out_dict = proc_valid_step_output(accumulated_output)

    print(
        f"[Epoch {epoch + 1} / {config['TRAIN']['EPOCHS']}] Val || "
        f"ACC={out_dict['scalar']['np_acc']:.3f} || "
        f"DICE={out_dict['scalar']['np_dice']:.3f} || "
        f"MSE={out_dict['scalar']['hv_mse']:.3f}"
    )

    if (epoch + 1) % config["LOGGING"]["SAVE_STEP"] == 0:
        torch.save(
            model.state_dict(),
            os.path.join(
                config["LOGGING"]["SAVE_PATH"],
                f"epoch_{epoch + 1}.pth"
            )
        )

    if (epoch + 1) % config["EVAL"]["COCO_EVAL_STEP"] == 0:
        coco_evaluation_pipeline(
            dataloader=val_dataloader,
            model=model,
            device=config["TRAIN"]["DEVICE"],
            nr_types=config["DATA"]["NUM_TYPES"],
            cat_ids=config["EVAL"]["COCO_EVAL_CAT_IDS"]
        )

In [None]:
torch.save(
    model.state_dict(),
    os.path.join(config["LOGGING"]["SAVE_PATH"], "latest.pth")
)