# Code for Model Construction

In [None]:
import copy
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # v1 fa docker gpu 0

import random
import shutil
from functools import partial
from glob import glob
from multiprocessing import Pool, cpu_count
from pathlib import Path

import cupy as cp
import cv2
import fastai
import matplotlib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import PIL
import SimpleITK
import SimpleITK as sitk
import sklearn.metrics as skm
import torchvision
import torchvision.models as models
from cupyx.scipy import ndimage as cu_ndimage
from fastai.callback.mixup import MixUp
from fastai.callback.tracker import SaveModelCallback
from fastai.data.core import DataLoaders, Tensor, explode_types, fastuple, show_image, tensor, typedispatch
from fastai.optimizer import SGD, Adam, QHAdam
from fastai.torch_basics import flatten_check, is_listy, itertools, nn, set_seed, to_np, torch
from fastai.torch_core import defaults
from fastai.vision.all import (
    Callback,
    Categorize,
    CategoryBlock,
    ClassificationInterpretation,
    CropPad,
    CrossEntropyLossFlat,
    DataBlock,
    Datasets,
    FileGetter,
    GrandparentSplitter,
    Hook,
    Image,
    ImageBlock,
    ImageDataLoaders,
    Interpretation,
    IntToFloatTensor,
    Learner,
    Module,
    MultiCategoryBlock,
    Normalize,
    PILBase,
    PILImage,
    PILImageBW,
    Pipeline,
    RandomCrop,
    RandomResizedCrop,
    RandTransform,
    Recorder,
    Resize,
    ResizeMethod,
    RocAuc,
    RocAucBinary,
    TensorImage,
    TensorImageBase,
    TfmdLists,
    ToTensor,
    Transform,
    TransformBlock,
    accuracy,
    aug_transforms,
    cast,
    create_body,
    create_head,
    detuplify,
    doc,
    first,
    get_files,
    get_grid,
    imagenet_stats,
    model_meta,
    params,
    parent_label,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
    show_at,
    skm_to_fastai,
    store_attr,
    to_image,
    uniqueify,
    xresnet18,
    xresnet34,
    xresnet50,
    xresnet101,
    xresnet152,
    xse_resnet152,
    xse_resnext18,
    xse_resnext50,
    xse_resnext101,
)
from fastcore.foundation import L
from PIL import Image
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from termcolor import colored
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm_notebook

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
# dropped the old aliases, so use whichever name this installation provides.
_seaborn_style = "seaborn-deep" if "seaborn-deep" in plt.style.available else "seaborn-v0_8-deep"
plt.style.use(_seaborn_style)

# Use a serif font for every figure produced by this notebook.
plt.rc("font", family="Times New Roman")
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# Config & Args

In [None]:
def random_seed(seed_value):
    """
    Seed every random number generator used by this notebook.

    Seeds Python's `random`, NumPy and PyTorch. When CUDA is available the
    CUDA generators are seeded as well and cuDNN is switched to deterministic,
    non-benchmarking mode.

    Args:
        seed_value: integer seed shared by all generators.
    """
    assert isinstance(seed_value, int)

    # CPU-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    # GPU-side generators (current device, then all devices).
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class Args:
    """
    Notebook-wide configuration, read through the module-level `args` instance.

    Attributes are plain class-level constants grouped as: experiment identity,
    CT intensity window, dataset locations, pre-processing, augmentation /
    sampling, optimisation hyper-parameters and execution flags.
    """

    exp_name = "ZYF1207benchmark3-test2auc77"  # embedded in the checkpoint file name
    seed = 2  # passed to random_seed() right after instantiation
    dpi = 200  # resolution of the matplotlib figures

    # Intensity window: values are clipped to [wl - ww / 2, wl + ww / 2] before display / training.
    ww = 1200
    wl = -500
    # Case IDs that are skipped when building the case lists.
    blacklist = [
        "345092",
        "353437",
    ]
    data_path = Path("/Projects/data/DJJ")
    # part 1 (train + validation + test0)
    excel_f_p1 = data_path / "part1/SPH-20191215.xlsx"
    ct_path_p1 = data_path / "part1"

    # test 1
    excel_f_t1 = data_path / "test1/test1_new.xlsx"
    ct_path_t1 = data_path / "test1"

    # test 2
    excel_f_t2 = data_path / "test2/test2_new.xlsx"
    ct_path_t2 = data_path / "test2"

    output_path = data_path / "processed_ct_r75"  # root folder of the extracted *.npy patches
    spacing = [0.6, 0.6, 0.6]  # target voxel spacing used when resampling

    fixed_half_edge = 75  # half the edge length of the extracted 3D nodule patch
    use_existing = True  # skip cases whose *.npy file already exists
    split_p = (0.8, 0.9)  # split points for training, validation, and test sets
    shuffle_dataset = False  # whether to (re)distribute part1 files into train/valid/test

    rotation = 180  # maximum random rotation per axis, in degrees
    r_pos = 1.1  # ratio of class1 : class0 during training
    radius_middle = 66  # half the edge length of the random crop
    resize_size = 112  # size of the network's input
    n_in = 3  # number of input channels

    batch_size = 64

    num_workers = 0  # value passed as `num_workers` to the DataLoaders
    ps = 0.5  # passed as `ps` to create_head
    lr = 1e-3
    wd = 1e-3
    mom = 0.9
    epoch = 150  # NOTE(review): not used in the visible part of the notebook
    freeze_epochs = 70  # number of epochs for the first learn.fit() call
    pct_start = 0.3  # NOTE(review): not used in the visible part of the notebook
    loss_weight = [1.0, 1.1]  # per-class weights of the cross-entropy loss

    test_code = True  # run the quick sanity checks / preview plots
    pre_processing = False  # whether need to pre-process raw CT data
    plot_validate_result = True  # plot sample slices of the extracted patches

    # NOTE(review): vmin / vmax are not used in the visible part of the notebook.
    vmin = -0.025
    vmax = 0.15


# Instantiate the configuration and seed all RNGs before any random operation runs.
args = Args()
random_seed(args.seed)

# Bare expression: only has an effect as the displayed output of a notebook cell.
defaults.callbacks

## Package Versions

In [None]:
# Report the versions of the main libraries, one "<name>: <version>" line each.
for _pkg_label, _pkg in (
    ("fastai", fastai),
    ("pytorch", torch),
    ("numpy", np),
    ("cupy", cp),
    ("pandas", pd),
    ("matplotlib", matplotlib),
    ("opencv", cv2),
    ("torchvision", torchvision),
):
    print(f"{_pkg_label}: {_pkg.__version__}")

# Data Statistics

In [None]:
def get_statistic(args: tuple, new_spacing: np.ndarray) -> tuple:
    """
    Extract the voxel spacing and the resampled nodule extent of one case.

    Args:
        args: `(case_path, label)`; `case_path` is the case folder holding
            `<name>.nii` and `R.acsv`, `label` is 0 or 1. (The parameter name
            shadows the module-level `args` inside this function.)
        new_spacing: 1-D array with the three target spacings.

    Returns:
        `(min_spacing, max_length, label)`: the smallest original spacing, the
        largest annotated extent after scaling to `new_spacing`, and the label.

    Raises:
        AssertionError: if an input or a file does not meet the expectations.
        ValueError: if the annotation file does not contain exactly two
            "point" lines.
    """
    assert isinstance(args, tuple) and len(args) == 2
    case_path, label = args
    assert isinstance(label, int) and label in (0, 1)
    # Validate case_path before it is used to build other paths.
    assert isinstance(case_path, Path) and case_path.exists()
    nii_img_path = case_path / f"{case_path.name}.nii"
    assert nii_img_path.exists()
    assert isinstance(new_spacing, np.ndarray) and new_spacing.ndim == 1 and new_spacing.size == 3

    nii_img = nib.load(str(nii_img_path))

    img_arr = nii_img.get_fdata()
    assert img_arr.ndim == 3
    assert -32768 < img_arr.min() < img_arr.max() < 32767
    img_arr = img_arr.astype(np.int16)  # HWD

    affine = nii_img.header.get_best_affine()

    # Spacing is taken from the non-zero entries of each row of the linear part
    # of the affine (this presumes an axis-aligned affine -- TODO confirm).
    spacing = np.concatenate([np.abs(row[np.nonzero(row)]) for row in affine[:3, :3]])  # WHD
    spacing = spacing[::-1]  # DHW

    annotation_file = case_path / "R.acsv"
    assert isinstance(annotation_file, Path) and annotation_file.exists()

    # `np.float` was removed from NumPy; np.float64 is the identical dtype.
    # The context manager guarantees the annotation file is closed.
    with open(annotation_file) as annotation:
        center_and_length = [np.array(line.split("|")[1:4]).astype(np.float64) for line in annotation if line.startswith("point")]  # WHD
    # Exactly two points are expected: the centre (unused here) and the extent.
    _, length = center_and_length

    # Convert 'length' to pixel coordinate system
    length = np.abs(np.matmul(np.linalg.inv(affine[:3, :3]), length))

    dhw_shape = np.array(img_arr.shape[::-1])
    new_shape = np.round(dhw_shape * spacing / new_spacing)
    resize_factor = new_shape / dhw_shape  # DHW
    length = length * resize_factor[2]

    return spacing.min(), length.max(), label


# Smoke test of get_statistic() on one known case; the case folder is derived
# from the configured part-1 directory instead of repeating the absolute path.
if args.pre_processing and args.test_code:
    min_spacing, max_length, label = get_statistic(args=(args.ct_path_p1 / "410759", 1), new_spacing=np.array(args.spacing))
    print(f"min_spacing = {min_spacing}, max_length = {max_length}, label = {label}")

In [None]:
def hist_one_dataset(dataset_path: Path, excel_path: Path, title: str, bins: int = 50):
    """
    Plot spacing and nodule-length histograms of one data set.

    Args:
        dataset_path: folder containing one sub-folder per case.
        excel_path: spreadsheet with the columns "住院号" (case ID) and "N2" (label).
        title: one of "PART1", "TEST1", "TEST2"; used in the figure titles.
        bins: number of histogram bins.

    Cases without a label, and cases listed in `args.blacklist`, are ignored.
    """
    assert isinstance(dataset_path, Path) and dataset_path.exists()
    assert isinstance(excel_path, Path) and excel_path.exists()
    assert isinstance(title, str) and title in ("PART1", "TEST1", "TEST2")
    assert isinstance(bins, int) and bins > 0

    label_pd = pd.read_excel(excel_path)[["住院号", "N2"]]
    label_pd = label_pd.set_index("住院号")
    # Normalise keys to str and labels to int so they match the folder names.
    label_by_id = {str(k): int(v) for k, v in label_pd.to_dict()["N2"].items()}

    cases = [
        (dataset_path / case_dir, label_by_id[case_dir.name])
        for case_dir in dataset_path.iterdir()
        if case_dir.is_dir() and case_dir.name in label_by_id and case_dir.name not in args.blacklist
    ]

    p_get_statistic = partial(get_statistic, new_spacing=np.array(args.spacing))
    # The context manager shuts the worker processes down once map() is done.
    with Pool(32) as pool:
        rslt = pool.map(p_get_statistic, cases)
    rslt_arr = np.array(rslt)  # columns: min spacing, max length, label

    plt.figure(figsize=(4, 4), dpi=args.dpi)
    plt.hist(rslt_arr[:, 0], bins=bins, alpha=0.7, density=True)
    plt.xlabel("spacing (in mm)")
    plt.ylabel("freq")
    plt.title(f"{title} Spacing Min Value")
    plt.show()

    plt.figure(figsize=(4, 4), dpi=args.dpi)
    plt.hist([rslt_arr[rslt_arr[:, 2] == 0, 1], rslt_arr[rslt_arr[:, 2] == 1, 1]], label=["class 0", "class1"], bins=bins, density=True)
    plt.xlabel("length (in pixels)")
    plt.ylabel("freq")
    plt.title(f"{title} Nodule Length Max Value")
    plt.legend(loc="upper right")
    plt.show()

# Pre-Processing (if Needed)

In [None]:
def load_scan_nii(case_path: Path) -> [np.ndarray, np.ndarray]:
    """
    Read the volume and affine matrix of one case.

    The NIfTI file is expected at `<case_path>/<case_path.name>.nii`.

    Returns:
        `(volume, affine)`: the squeezed 3-D int16 voxel array and the
        2-D float64 affine reported by the NIfTI header.
    """
    assert isinstance(case_path, Path) and case_path.exists()

    nii_file = case_path / f"{case_path.name}.nii"
    assert nii_file.exists()

    nifti = nib.load(str(nii_file))
    volume = np.squeeze(nifti.get_fdata()).astype(np.int16)
    affine = nifti.header.get_best_affine()

    # Sanity-check what is handed to the caller.
    assert isinstance(volume, np.ndarray)
    assert volume.ndim == 3 and volume.dtype == np.int16
    assert isinstance(affine, np.ndarray)
    assert affine.ndim == 2 and affine.dtype == np.float64
    return volume, affine

In [None]:
def re_sample(ct_array: np.ndarray, x: [int, float], y: [int, float], z: [int, float], r_mm: [int, float], spacing: np.ndarray, new_spacing: np.ndarray, order: int = 3) -> tuple:
    """
    Resample a CT volume to a new voxel spacing and rescale the annotation.

    The output shape is rounded to whole voxels, so the spacing actually
    achieved is recomputed and returned. `x`, `y`, `z` are scaled with the
    factor of axis 2, 1 and 0 respectively; `r_mm` (if given) with axis 2.

    Returns:
        `(ct_array, new_spacing, x, y, z, r_mm)` in the resampled grid.
    """

    def _is_coord(value):
        return isinstance(value, (int, float)) and value >= 0

    def _is_radius(value):
        return value is None or (isinstance(value, (int, float)) and value > 0)

    assert isinstance(ct_array, np.ndarray) and ct_array.ndim == 3 and ct_array.shape > (1, 1, 1) and ct_array.dtype == np.int16
    assert _is_coord(x)
    assert _is_coord(y)
    assert _is_coord(z)
    assert _is_radius(r_mm)
    assert isinstance(spacing, np.ndarray) and spacing.size == 3
    assert isinstance(new_spacing, np.ndarray) and new_spacing.size == 3
    assert isinstance(order, int) and 0 <= order <= 5
    assert (spacing > 0).all() and (new_spacing > 0).all()

    old_shape = np.array(ct_array.shape)
    target_shape = np.round(old_shape * spacing / new_spacing)
    new_spacing = spacing * old_shape / target_shape  # spacing after rounding
    resize_factor = target_shape / old_shape

    ct_array = zoom(ct_array, resize_factor, mode="nearest", order=order)
    z, y, x = z * resize_factor[0], y * resize_factor[1], x * resize_factor[2]
    if r_mm is not None:
        r_mm = r_mm * resize_factor[2]

    # Re-validate everything that is returned.
    assert isinstance(ct_array, np.ndarray) and ct_array.ndim == 3
    assert ct_array.shape > (1, 1, 1) and ct_array.dtype == np.int16
    assert isinstance(new_spacing, np.ndarray) and new_spacing.size == 3
    assert _is_coord(x)
    assert _is_coord(y)
    assert _is_coord(z)
    assert _is_radius(r_mm)

    return ct_array, new_spacing, x, y, z, r_mm

In [None]:
def get_nodule(ct_array: np.ndarray, x: [int, float], y: [int, float], z: [int, float], radius_mm: [int, float] = None, fixed_radius: int = None) -> np.ndarray:
    """
    Crop a cubic 3D patch around the annotated nodule position.

    Exactly one of `radius_mm` / `fixed_radius` must be given; it is the half
    edge length (in voxels) of the patch. The volume is padded with its minimum
    value so that patches near the border keep the full size.

    Args:
        ct_array: 3-D int16 volume indexed as [z, y, x].
        x, y, z: non-negative position of the patch centre (truncated to int).
        radius_mm: half edge length; truncated to int.
        fixed_radius: half edge length as an int.

    Returns:
        int16 array of shape `(2 * r, 2 * r, 2 * r)`.

    Raises:
        AssertionError: on invalid input or if the patch has an unexpected size
            (e.g. the centre lies outside the volume).
        ValueError: if both `radius_mm` and `fixed_radius` are given.
    """
    assert isinstance(ct_array, np.ndarray) and ct_array.ndim == 3 and ct_array.dtype == np.int16 and ct_array.shape > (1, 1, 1)
    assert isinstance(x, (int, float)) and x >= 0
    assert isinstance(y, (int, float)) and y >= 0
    assert isinstance(z, (int, float)) and z >= 0

    if fixed_radius is None:
        assert isinstance(radius_mm, (int, float)) and radius_mm > 0
    if radius_mm is None:
        assert isinstance(fixed_radius, int) and fixed_radius > 0
    if fixed_radius is not None and radius_mm is not None:
        raise ValueError("pass either radius_mm or fixed_radius, not both")

    half_edge = fixed_radius if fixed_radius is not None else int(radius_mm)
    edge = 2 * half_edge
    x, y, z = int(x), int(y), int(z)

    # After padding by half_edge, index p of the original volume sits at
    # p + half_edge, so [p : p + edge] is the window centred on p.
    padded = np.pad(ct_array, pad_width=half_edge, mode="constant", constant_values=ct_array.min())
    patch = padded[z : z + edge, y : y + edge, x : x + edge]  # DHW
    assert patch.shape == (edge, edge, edge), f"abnormal nodule size {patch.shape}"

    assert patch.dtype == np.int16
    return patch

In [None]:
def get_3d_roi(case_path: Path, diameter_size: int, new_spacing: np.ndarray, fixed_radius: int) -> np.ndarray:
    """
    Extract the fixed-size 3D region of interest around a case's nodule.

    Loads `<case_path>/<name>.nii`, reads the nodule centre from `R.acsv`,
    resamples the volume to `new_spacing` and crops a cube of edge
    `2 * fixed_radius` around the centre.

    Args:
        case_path: case folder.
        diameter_size: even positive int or None; only forwarded (halved) to
            `re_sample` as `r_mm`, it does not influence the returned patch.
        new_spacing: 1-D array with the three target spacings.
        fixed_radius: half edge length of the returned patch.

    Returns:
        int16 array of shape `(2 * fixed_radius,) * 3`.

    Raises:
        AssertionError: on invalid input / missing files.
        ValueError: if `R.acsv` does not contain exactly two "point" lines.
    """
    assert isinstance(case_path, Path) and case_path.exists()
    nii_img_path = case_path / f"{case_path.name}.nii"
    assert nii_img_path.exists()

    assert (isinstance(diameter_size, int) and diameter_size > 0 and diameter_size % 2 == 0) or diameter_size is None
    assert isinstance(new_spacing, np.ndarray)
    assert isinstance(fixed_radius, int) and fixed_radius > 0
    assert new_spacing.ndim == 1 and new_spacing.size == 3

    img_arr, affine = load_scan_nii(case_path=case_path)

    # Non-zero entries of each affine row (presumes an axis-aligned affine -- TODO confirm).
    old_spacing = np.concatenate([np.abs(row[np.nonzero(row)]) for row in affine[:3, :3]])  # WHD

    annotation_file = case_path / "R.acsv"
    assert isinstance(annotation_file, Path) and annotation_file.exists()

    # `np.float` was removed from NumPy; np.float64 is the identical dtype.
    # The context manager guarantees the annotation file is closed.
    with open(annotation_file) as annotation:
        center_and_length = [np.array(line.split("|")[1:4]).astype(np.float64) for line in annotation if line.startswith("point")]  # WHD

    # Exactly two points are expected; only the centre is needed for cropping.
    center, _ = center_and_length

    # Convert the centre into the voxel index space of the stored array.
    center = center - affine[:3, -1]
    center = np.matmul(np.linalg.inv(affine[:3, :3]), center)

    img_arr, new_spacing, x, y, z, r_mm = re_sample(
        ct_array=img_arr.transpose(2, 1, 0),
        x=center[0],
        y=center[1],
        z=center[2],
        r_mm=diameter_size // 2 if diameter_size is not None else None,
        spacing=old_spacing[::-1],
        new_spacing=new_spacing,
    )

    patch_3d = get_nodule(ct_array=img_arr, x=x, y=y, z=z, radius_mm=None, fixed_radius=fixed_radius)

    assert patch_3d.shape == (fixed_radius * 2, fixed_radius * 2, fixed_radius * 2)
    assert patch_3d.dtype == np.int16
    return patch_3d

In [None]:
def get_case_path_and_label(ct_root_path: Path, label_info_lst: list) -> list:
    """
    Attach the case folder to every label entry that has one.

    Args:
        ct_root_path: folder containing one sub-folder per case.
        label_info_lst: non-empty list of dicts with at least the key "ID".

    Returns:
        The entries of `label_info_lst` (same dict objects, in order) whose
        "ID" equals the stem of a sub-folder, each extended in place with a
        "path" key. Entries without a folder are reported and dropped.

    Raises:
        AssertionError: on invalid input or if no entry could be matched.
    """
    assert isinstance(ct_root_path, Path) and ct_root_path.exists() and ct_root_path.is_dir()
    assert isinstance(label_info_lst, list) and len(label_info_lst) > 0

    # iterdir() already yields paths that include ct_root_path; joining them
    # with the root again would duplicate it for relative roots.
    path_by_id = {}
    for sub_dir in ct_root_path.iterdir():
        if sub_dir.is_dir():
            path_by_id.setdefault(sub_dir.stem, sub_dir)  # first match wins

    matched = []
    for label in label_info_lst:
        case_path = path_by_id.get(label["ID"])
        if case_path is None:
            print(f'Corresponding folder not found: case {label["ID"]}')
            continue
        label["path"] = case_path
        matched.append(label)

    assert len(matched) > 0
    return matched

In [None]:
def worker_for_one_case(para: dict, processed_path: Path, category_names: list, fixed_radius: int, spacing: np.ndarray) -> None:
    """
    Extract the nodule patch of one case and store it as a *.npy file.

    The file is written to `<processed_path>/<category>/<ID>.npy`, where the
    category is `category_names[para["c"]]`. An already existing file is kept
    when `args.use_existing` is set. I/O errors are reported and swallowed so
    that a single bad case does not stop a batch run.
    """
    assert isinstance(para, dict)
    assert isinstance(processed_path, Path) and processed_path.is_dir()
    assert isinstance(category_names, list)
    assert isinstance(category_names[0], str)
    assert isinstance(fixed_radius, int) and fixed_radius > 0
    assert isinstance(spacing, np.ndarray) and spacing.size == 3 and (spacing > 0).all()

    npy_folder = processed_path / category_names[para["c"]]
    npy_file = npy_folder / f'{para["ID"]}.npy'

    if not npy_folder.exists():
        npy_folder.mkdir(parents=True, exist_ok=True)

    # Nothing to do when the result is already on disk.
    if args.use_existing and npy_file.exists():
        print(f"{npy_file} file existed, skip")
        return

    try:
        patch_3d = get_3d_roi(case_path=para["path"], diameter_size=None, new_spacing=spacing, fixed_radius=fixed_radius)
        assert patch_3d.dtype == np.int16
        np.save(str(npy_file), patch_3d)
    except IOError:
        print(f"Bug in {npy_file}")
        return

    print(f"file {npy_file} saved")

In [None]:
def final_pre_processing(ct_root: Path, excel_f: Path, dataset_name: str) -> None:
    """
    Process all patient data from the dataset into *.npy files.

    Args:
        ct_root: folder containing one sub-folder per case.
        excel_f: spreadsheet with the columns "住院号" (case ID) and "N2" (label).
        dataset_name: one of "part1", "test1", "test2"; used for log messages.

    Patches are written below `args.output_path / ct_root.stem` in the
    sub-folders "negative" and "positive". Cases listed in `args.blacklist`
    are skipped. The work is distributed over one process per CPU core.
    """

    assert isinstance(ct_root, Path) and ct_root.exists() and ct_root.is_dir()
    assert isinstance(excel_f, Path) and excel_f.exists() and excel_f.is_file()
    assert isinstance(dataset_name, str) and dataset_name in ("part1", "test1", "test2")

    label_pd = pd.read_excel(excel_f)[["住院号", "N2"]]
    label_pd = label_pd.set_index("住院号")

    label_lst = [{"ID": str(index), "c": int(item["N2"])} for index, item in label_pd.iterrows() if str(index) not in args.blacklist]
    case_path_label_list = get_case_path_and_label(label_info_lst=label_lst, ct_root_path=ct_root)

    output_path = args.output_path / ct_root.stem
    output_path.mkdir(parents=True, exist_ok=True)

    partial_worker = partial(worker_for_one_case, processed_path=output_path, category_names=["negative", "positive"], fixed_radius=args.fixed_half_edge, spacing=np.array(args.spacing))

    print(f"Start processing {dataset_name} in parallel.")
    # The context manager shuts the worker processes down once map() is done.
    with Pool(cpu_count()) as pool:
        pool.map(partial_worker, case_path_label_list)
    print(f"dataset {dataset_name} is processed.")

## Extract 3D Nodule Patches from All Cases

In [None]:
# Convert the raw CT cases of all three data sets into *.npy patches.
if args.pre_processing:
    for _ct_root, _excel_f, _dataset_name in (
        (args.ct_path_p1, args.excel_f_p1, "part1"),
        (args.ct_path_t1, args.excel_f_t1, "test1"),
        (args.ct_path_t2, args.excel_f_t2, "test2"),
    ):
        final_pre_processing(ct_root=_ct_root, excel_f=_excel_f, dataset_name=_dataset_name)

## Shuffle Dataset

In [None]:
def shuffle_dataset(path: Path, r: tuple, seed: int) -> None:
    """
    Distribute the *.npy files below `path` into train / valid / test folders.

    Files are grouped by category ("negative" / "positive", taken from the
    folder names between `path` and the file), shuffled per category with
    `seed`, and moved to `<path>/<split>/<category>/<file name>`.

    Args:
        path: root folder that is searched recursively.
        r: `(a, b)` with `0 < a < b < 1`; per category the first `int(a * n)`
            files go to train, up to `int(b * n)` to valid, the rest to test.
        seed: seed for Python's global `random` generator (which is re-seeded).

    Raises:
        AssertionError: on invalid arguments.
        ValueError: if a file is in neither a "negative" nor a "positive" folder.
    """
    assert isinstance(path, Path) and path.exists() and path.is_dir()
    assert isinstance(r, tuple) and len(r) == 2 and 0 < r[0] < r[1] < 1
    assert isinstance(seed, int)

    categories = ["negative", "positive"]
    original = [[], []]
    for child in path.rglob("*"):
        if not (child.is_file() and child.suffix == ".npy"):
            continue
        # Only folder names below `path` decide the category; matching against
        # the whole absolute path would misclassify files whenever a parent
        # directory name happens to contain "negative" or "positive".
        folders = child.relative_to(path).parts[:-1]
        if "negative" in folders:
            original[0].append(child)
        elif "positive" in folders:
            original[1].append(child)
        else:
            raise ValueError(f"cannot determine the category of {child}")

    print(len(original[0]))
    print(len(original[1]))
    # Sort before shuffling so the result does not depend on directory listing
    # order; the full path breaks ties between equal stems.
    original[0] = sorted(original[0], key=lambda f: (f.stem, str(f)))
    original[1] = sorted(original[1], key=lambda f: (f.stem, str(f)))

    random.seed(seed)
    random.shuffle(original[0])
    random.shuffle(original[1])

    for category, files in zip(categories, original):
        train_end = int(r[0] * len(files))
        valid_end = int(r[1] * len(files))
        for i, npy_f in enumerate(files):
            if i < train_end:
                split = "train"
            elif i < valid_end:
                split = "valid"
            else:
                split = "test"
            npy_f_new = path / split / category / npy_f.name
            npy_f_new.parent.mkdir(parents=True, exist_ok=True)
            print(npy_f, "=>", npy_f_new)
            # str() keeps shutil.move working on Python versions that mishandle Path sources.
            shutil.move(str(npy_f), str(npy_f_new))


# Optionally split the part-1 patches into train / valid / test folders.
# NOTE(review): the seed is hard-coded to 7 here rather than taken from args.seed.
if args.shuffle_dataset:
    shuffle_dataset(path=args.output_path / "part1", r=args.split_p, seed=7)

## Plot Extracted 3D RoI in Test 0 DataSet

In [None]:
def plot_and_check_test_set(test_set_path: Path, k=100):
    """
    Plot the middle slice of up to `k` randomly chosen patches for a visual check.

    Args:
        test_set_path: folder that is searched recursively for *.npy files.
        k: maximum number of patches to show; fewer are shown when the folder
            holds fewer files.

    Each slice is clipped to the intensity window defined by `args.wl` /
    `args.ww`, scaled to [0, 1] and titled "<file stem>/<parent folder>".
    """

    assert isinstance(test_set_path, Path) and test_set_path.exists() and test_set_path.is_dir()
    assert isinstance(k, int) and k > 0

    npy_files = list(test_set_path.glob("**/*.npy"))
    # random.sample raises if asked for more items than exist.
    k = min(k, len(npy_files))

    # Ten columns; at least ten rows (the original layout), more if k > 100.
    n_cols = 10
    n_rows = max(10, -(-k // n_cols))

    plt.figure(figsize=(15, 15), dpi=args.dpi)
    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    low = args.wl - args.ww / 2
    high = args.wl + args.ww / 2
    for count, npy_file in enumerate(random.sample(npy_files, k=k), start=1):
        file_name, c = npy_file.stem, npy_file.parent.stem

        x = np.load(str(npy_file))
        x = np.clip(x, a_min=low, a_max=high)
        x = (x - low) / args.ww
        x = x[x.shape[0] // 2]  # middle slice along the first axis

        plt.subplot(n_rows, n_cols, count)
        plt.imshow(x, "gray")
        plt.title(file_name + "/" + str(c), fontsize=10)
    plt.show()


# Visual check of the patches in the held-out split of part 1 ("test 0").
if args.plot_validate_result:
    plot_and_check_test_set(test_set_path=args.output_path / "part1" / "test")

## Plot Extracted 3D RoI in Test 1 DataSet

In [None]:
# Visual check of the patches extracted from the test1 data set.
if args.plot_validate_result:
    plot_and_check_test_set(test_set_path=args.output_path / "test1")

## Plot Extracted 3D RoI in Test 2 DataSet

In [None]:
# Visual check of the patches extracted from the test2 data set.
if args.plot_validate_result:
    plot_and_check_test_set(test_set_path=args.output_path / "test2")

# Build Fastai Model

## CaseImage (Customized Class for Loading CT Case Data)

In [None]:
def random_rotation(volume, rotation, length):
    """
    Rotate a 3D volume by random angles around its centre.

    One angle per axis is drawn uniformly from [-rotation, rotation] degrees;
    the combined rotation is applied with an affine transform (on the GPU via
    CuPy when CUDA is available, otherwise with SciPy on the CPU). Both paths
    use linear interpolation and replicate the border values.

    Args:
        volume: 3-D array; must be int16 for the final check to pass.
        rotation: positive int, maximum absolute angle in degrees.
        length: unused; kept for backward compatibility with existing callers.

    Returns:
        The rotated volume as an int16 NumPy array of the same shape.
    """
    assert isinstance(volume, np.ndarray) and volume.ndim == 3
    assert isinstance(rotation, int) and rotation > 0

    theta_x = np.pi / 180 * np.random.uniform(-rotation, rotation)
    theta_y = np.pi / 180 * np.random.uniform(-rotation, rotation)
    theta_z = np.pi / 180 * np.random.uniform(-rotation, rotation)

    rotation_matrix_x = np.array([[1, 0, 0], [0, np.cos(theta_x), -np.sin(theta_x)], [0, np.sin(theta_x), np.cos(theta_x)]])
    rotation_matrix_y = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], [0, 1, 0], [-np.sin(theta_y), 0, np.cos(theta_y)]])
    rotation_matrix_z = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], [np.sin(theta_z), np.cos(theta_z), 0], [0, 0, 1]])
    transform_matrix = rotation_matrix_x @ rotation_matrix_y @ rotation_matrix_z

    # Offset that keeps the centre of the volume fixed under the rotation.
    center_in = 0.5 * np.array(volume.shape)
    center_out = 0.5 * np.array(volume.shape)
    offset = center_in - center_out.dot(transform_matrix.T)

    if cp.cuda.is_available():  # affine transform on GPU
        volume_rotated = cu_ndimage.affine_transform(input=cp.asarray(volume), matrix=cp.asarray(transform_matrix), offset=cp.asarray(offset), order=1, mode="nearest", cval=0)
        volume_rotated = volume_rotated.get()
    else:  # affine transform on CPU
        # The bare name `affine_transform` was never imported; use the SciPy
        # module already in scope, with the same settings as the GPU path so
        # the augmentation does not depend on the device.
        volume_rotated = ndimage.affine_transform(volume, transform_matrix, offset=offset, order=1, mode="nearest", cval=0)

    assert isinstance(volume_rotated, np.ndarray) and volume_rotated.dtype == np.int16
    return volume_rotated


class CaseImage(fastuple):
    """
    An `(image, target)` pair for one case.

    The image holds three orthogonal slices as its three channels; `show`
    draws them next to each other in a single grayscale picture.
    """

    def show(self, ctx=None, **kwargs):
        img, target = self
        assert isinstance(img, (PIL.Image.Image, Tensor))
        assert target is None or isinstance(target, (str, int)), f"target data type：{type(target)}"

        # PIL images are channel-last; bring them to channel-first like tensors.
        chw = img if isinstance(img, Tensor) else tensor(img).permute(2, 0, 1)
        chw = chw.byte()

        # Concatenate the three channels along the width.
        side_by_side = torch.cat([chw[0], chw[1], chw[2]], dim=1)
        return show_image(side_by_side, cmap="gray", figsize=(5, 15), title=target, ctx=ctx, **kwargs)

## CaseTransform (Customized Transform for Loading CT Case Data)

In [None]:
class CaseTransform(Transform):
    """
    Input a npy path and output a CaseImage object.

    `encodes` loads the 3D patch, applies a random rotation to training files
    only, extracts the three orthogonal middle slices as image channels,
    applies the intensity window from `args.wl` / `args.ww` and pairs the
    image with its label index.
    """

    def __init__(self, files, label_func, splits):
        assert isinstance(files, L)
        assert hasattr(label_func, "__call__")
        assert isinstance(splits, tuple)

        self.vocab, self.label2id = uniqueify(list(map(label_func, files)), sort=True, bidir=True)
        self.label_func = label_func
        self.train_lst = [f for f in files[splits[0]]]
        # Set copy of train_lst: encodes() tests membership for every item.
        self._train_set = set(self.train_lst)

    def encodes(self, fn):
        assert isinstance(fn, Path) and fn.is_file(), f"Got {type(fn)}"

        f = np.load(str(fn)).astype(np.int16)

        _, h, w = f.shape
        # The same middle index is used on all three axes (patches are cubes).
        i_middle = f.shape[0] // 2

        if fn in self._train_set:
            f = random_rotation(volume=f, rotation=args.rotation, length=h)
        f = np.stack([f[i_middle, :, :], f[:, i_middle, :], f[:, :, i_middle]], axis=2)

        assert f.shape == (h, w, 3)

        # Map the window [wl - ww / 2, wl + ww / 2] to [0, 255].
        f = np.uint8(np.clip((f - (args.wl - args.ww / 2)) / args.ww * 255, a_min=0, a_max=255))
        assert f.dtype == np.uint8

        cls = self.label2id[self.label_func(fn)]
        return CaseImage(PILImage.create(f), cls)

    # NOTE(review): the return annotation is kept unchanged because fastai may
    # use it for type dispatch -- confirm before touching it.
    def decodes(self, x: PILImage) -> matplotlib.axes.SubplotBase:
        assert isinstance(x, PILImage), f"Wrong data type: {type(x)}"

        # Always bind `img`, whatever the input type is.
        img = x if isinstance(x, Tensor) else tensor(x).permute(2, 0, 1)

        img = img.byte()
        return show_image(torch.cat([img[0], img[1], img[2]], dim=1), cmap="gray", figsize=(5, 15))

## Customized Dataloaders

In [None]:
@typedispatch
def show_batch(x: CaseImage, y, samples, ctxs=None, max_n=6, nrows=None, ncols=2, figsize=None, **kwargs):
    """
    Display up to `max_n` cases of a batch of `CaseImage` items on a grid.

    `x[0]` holds the batched images and `x[1]` the batched targets. The number
    of rows is derived from the number of items and `ncols`; the `nrows`
    argument is validated but not forwarded to the grid.
    """
    assert isinstance(x, CaseImage)
    assert y is None
    assert ctxs is None
    assert samples is None
    assert isinstance(max_n, int) and max_n > 0
    assert nrows is None or isinstance(nrows, int) and nrows > 0
    assert isinstance(ncols, int) and ncols > 0
    assert figsize is None or isinstance(figsize, tuple)

    figsize = (ncols * 4, max_n // ncols * 2) if figsize is None else figsize
    if ctxs is None:
        n_shown = min(x[0].shape[0], max_n)
        ctxs = get_grid(n_shown, nrows=None, ncols=ncols, figsize=figsize)

    for idx, axis in enumerate(ctxs):
        case = CaseImage(x[0][idx], x[1][idx].item())
        case.show(ctx=axis)


def get_balance_files(files: L, train_name: str, valid_name: str) -> L:
    """
    Down-sample the negative class of the training split.

    Files outside the training split and all positive training files are kept
    in their original order. The negative training files are shuffled and only
    the first `int(n_positive / args.r_pos)` of them are appended at the end.
    `valid_name` is validated but otherwise unused.

    Raises:
        ValueError: if a training file is in neither a "positive" nor a
            "negative" folder.
    """
    assert isinstance(files, L)
    assert isinstance(train_name, str)
    assert isinstance(valid_name, str)

    kept = L()
    train_negatives = L()
    n_positive = 0
    for npy_file in files.items:
        split_name = npy_file.parent.parent.stem
        category = npy_file.parent.stem

        if split_name != train_name:
            kept.append(npy_file)
            continue

        if category == "positive":
            kept.append(npy_file)
            n_positive += 1
        elif category == "negative":
            train_negatives.append(npy_file)
        else:
            raise ValueError

    random.shuffle(train_negatives)
    n_negative = int(n_positive / args.r_pos)
    for npy_file in train_negatives[:n_negative]:
        kept.append(npy_file)
    return kept


class Sharpen(RandTransform):
    """
    Randomly applied sharpening filter for PIL images.

    Convolves the image with a fixed 3x3 kernel (centre weight 5, the four
    direct neighbours -1) using OpenCV.
    """

    def __init__(self, **kwargs):
        self.kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])

        super().__init__(**kwargs)

    def encodes(self, x: PIL.Image.Image):
        # ddepth=-1 keeps the depth of the input array.
        filtered = cv2.filter2D(np.array(x), -1, self.kernel)
        return PIL.Image.fromarray(filtered)


class Smoothing(RandTransform):
    """
    Randomly applied smoothing filter for PIL images.

    Convolves the image with the 3x3 kernel [[1, 2, 1], [2, 4, 2], [1, 2, 1]]
    divided by `denominator`. The raw weights sum to 16, so brightness is only
    preserved for `denominator=16`. `kernel_size` is stored but not used.
    """

    def __init__(self, kernel_size=2, denominator=9, **kwargs):
        self.kernel_size = kernel_size
        self.denominator = denominator
        weights = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]])
        self.kernel = weights / self.denominator
        super().__init__(**kwargs)

    def encodes(self, x: PIL.Image.Image):
        # ddepth=-1 keeps the depth of the input array.
        filtered = cv2.filter2D(np.array(x), -1, self.kernel)
        return PIL.Image.fromarray(filtered)


def get_dls(path, train_name, valid_name, verbose=True, balance=True):
    """
    Build the training / validation DataLoaders from the *.npy patches.

    Args:
        path: folder searched recursively for *.npy files.
        train_name: name of the grandparent folder marking training files.
        valid_name: name of the grandparent folder marking validation files.
        verbose: forwarded to fastai and enables a progress message.
        balance: when True, negative training files are down-sampled with
            `get_balance_files` before the DataLoaders are created.

    Returns:
        The DataLoaders created by `TfmdLists.dataloaders`.
    """
    assert isinstance(path, Path) and path.is_dir()
    assert isinstance(train_name, str)
    assert isinstance(valid_name, str)
    assert isinstance(balance, bool)

    npy_files = get_files(path=path, extensions=".npy", recurse=True)
    if balance:
        npy_files = get_balance_files(files=npy_files, train_name=train_name, valid_name=valid_name)

    split_idxs = GrandparentSplitter(train_name=train_name, valid_name=valid_name)(npy_files)
    case_tfm = CaseTransform(files=npy_files, label_func=parent_label, splits=split_idxs)
    tfmd_lists = TfmdLists(npy_files, tfms=case_tfm, splits=split_idxs, verbose=verbose)

    if verbose:
        print("Building Dataloaders")

    # NOTE(review): fastai may reorder these by each transform's `order`
    # attribute -- confirm where Sharpen / Smoothing run relative to ToTensor.
    item_tfms = [
        RandomCrop(size=args.radius_middle * 2),
        RandomResizedCrop(size=args.resize_size, min_scale=0.4, ratio=(0.75, 1.3333333333333333), resamples=(2, 0)),
        ToTensor,
        Sharpen(p=0.5),
        Smoothing(kernel_size=2, denominator=16, p=0.5),
    ]
    batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]

    return tfmd_lists.dataloaders(
        after_item=item_tfms,
        after_batch=batch_tfms,
        verbose=verbose,
        bs=args.batch_size,
        num_workers=args.num_workers,
    )


# Smoke test: build the DataLoaders once and report the training set size.
# A bare expression inside `if` is never displayed by the notebook, so print it.
if args.test_code:
    dls = get_dls(path=args.output_path / "part1", train_name="train", valid_name="valid")
    print(len(dls.train_ds))

### Build Training and Validation Dataset

In [None]:
# DataLoaders used for training and validation.
dls = get_dls(path=args.output_path / "part1", train_name="train", valid_name="valid")
if args.test_code:
    # A bare string inside `if` is never displayed by the notebook, so print it.
    print("----------Plot 25 Cases in Training Dataset----------")
    dls.show_batch(max_n=25, nrows=5, ncols=5, figsize=(30, 10), dpi=args.dpi)

In [None]:
if args.test_code:
    # A bare string inside `if` is never displayed by the notebook, so print it.
    print("----------Plot 25 Cases in Validation Dataset----------")
    dls.valid.show_batch(max_n=25, nrows=5, ncols=5, figsize=(30, 10), dpi=args.dpi)

### Build Test0 Dataset

In [None]:
# Test-0 DataLoader: the held-out "test" split of part 1, in fixed order.
test0_files = get_files(path=args.output_path / "part1" / "test", extensions=".npy", recurse=True,)
dl_t0 = dls.test_dl(test0_files, shuffle=False)
if args.test_code:
    # Preview a 5 x 5 grid of cases.
    dl_t0.show_batch(max_n=25, nrows=5, ncols=5, figsize=(30, 10), dpi=args.dpi)

### Build Test1 Dataset

In [None]:
# Test-1 DataLoader: all patches below the "test1" output folder, in fixed order.
test1_files = get_files(path=args.output_path / "test1", extensions=".npy", recurse=True,)
dl_t1 = dls.test_dl(test1_files, shuffle=False)
if args.test_code:
    # Preview a 5 x 5 grid of cases.
    dl_t1.show_batch(max_n=25, nrows=5, ncols=5, figsize=(30, 10), dpi=args.dpi)

### Build Test2 Dataset

In [None]:
# Test-2 DataLoader: all patches below the "test2" output folder, in fixed order.
test2_files = get_files(path=args.output_path / "test2", extensions=".npy", recurse=True,)
dl_t2 = dls.test_dl(test2_files, shuffle=False)
if args.test_code:
    # Preview a 5 x 5 grid of cases.
    dl_t2.show_batch(max_n=25, nrows=5, ncols=5, figsize=(30, 10), dpi=args.dpi)

## Performance Metrics

In [None]:
def sensitivity(y_pred, y_true, thresh=0.5):
    """
    Sensitivity (true positive rate) of a two-class prediction.

    Args:
        y_pred: raw scores of shape (N, 2); softmax is applied along dim 1 and
            the probability of class 1 is compared with `thresh`.
        y_true: binary targets (0 / 1) with N elements; flattened before use.
        thresh: a sample is predicted positive when its probability is > thresh.

    Returns:
        TP / (TP + FN) as a Python float, or 0.0 when there are no positive
        targets (same convention as `zero_division=0` in the sklearn wrappers).
    """
    prob_pos = F.softmax(y_pred, dim=1)[:, 1]
    # Flatten so that (N, 1) targets do not broadcast to an N x N matrix.
    is_pos = y_true.data.reshape(-1).float()

    TP = ((prob_pos > thresh).float() * is_pos).sum()
    # Predicted negative but actually positive.
    FN = ((prob_pos <= thresh).float() * is_pos).sum()

    total_pos = TP + FN
    if total_pos.item() == 0:
        return 0.0
    return (TP / total_pos).item()


def specificity(y_pred, y_true, thresh=0.5):
    """
    Specificity (true negative rate) of a two-class prediction.

    Args:
        y_pred: raw scores of shape (N, 2); softmax is applied along dim 1 and
            the probability of class 1 is compared with `thresh`.
        y_true: binary targets (0 / 1) with N elements; flattened before use.
        thresh: a sample is predicted positive when its probability is > thresh.

    Returns:
        TN / (FP + TN) as a Python float, or 0.0 when there are no negative
        targets (same convention as `zero_division=0` in the sklearn wrappers).
    """
    prob_pos = F.softmax(y_pred, dim=1)[:, 1]
    # Flatten so that (N, 1) targets do not broadcast to an N x N matrix.
    is_neg = (1 - y_true.data.reshape(-1)).float()

    FP = ((prob_pos > thresh).float() * is_neg).sum()
    TN = ((prob_pos <= thresh).float() * is_neg).sum()

    total_neg = FP + TN
    if total_neg.item() == 0:
        return 0.0
    return (TN / total_neg).item()


def Precision(axis=-1, labels=None, pos_label=1, average="binary", sample_weight=None):
    """
    Precision metric for single-label classification.

    Wraps `sklearn.metrics.precision_score` with `skm_to_fastai`, additionally
    passing `zero_division=0` (a supplement to the original API).
    """
    return skm_to_fastai(
        skm.precision_score,
        axis=axis,
        labels=labels,
        pos_label=pos_label,
        average=average,
        sample_weight=sample_weight,
        zero_division=0,
    )


def Recall(axis=-1, labels=None, pos_label=1, average="binary", sample_weight=None):
    """
    Build a fastai metric around ``skm.recall_score`` for single-label
    classification.

    Supplements the stock API by always passing ``zero_division=0`` to
    scikit-learn.
    """
    score_kwargs = dict(
        labels=labels,
        pos_label=pos_label,
        average=average,
        sample_weight=sample_weight,
        zero_division=0,
    )
    return skm_to_fastai(skm.recall_score, axis=axis, **score_kwargs)

## Build Neural Network

In [None]:
class MyNN(Module):
    """Two-stage network: an encoder followed by a classification head."""

    def __init__(self, encoder, head):
        # No explicit super().__init__() here; this relies on the `Module`
        # base class used by the file (presumably fastai's) - confirm.
        self.encoder = encoder
        self.head = head

    def forward(self, x):
        """Feed `x` through the encoder, then the head, and return the result."""
        return self.head(self.encoder(x))

In [None]:
# Build the encoder from a pretrained resnet152 with `args.n_in` input channels.
# cut=-2 presumably drops the architecture's final pooling/fc layers - confirm
# against the fastai version in use.
encoder = create_body(arch=resnet152, n_in=args.n_in, pretrained=True, cut=-2)
# Notebook display: inspect the first encoder layer.
encoder[0]

In [None]:
# Two-output classification head with dropout taken from `args.ps`.
# NOTE(review): 512 * 4 * 2 = 4096 input features; presumably 2048 encoder
# channels doubled for concat pooling. Newer fastai releases may double the
# feature count internally - confirm for the installed version.
head = create_head(512 * 4 * 2, n_out=2, ps=args.ps)
# Notebook display: show the head's layers.
head

In [None]:
# Assemble the full network from the encoder and head built above.
model = MyNN(encoder, head)
# Notebook display: show the complete architecture.
model


def my_nn_splitter(model):
    """Return the model's parameters as two groups: ``[encoder, head]``.

    Passed to ``Learner(splitter=...)`` so the two parts form separate
    parameter groups.
    """
    encoder_group = params(model.encoder)
    head_group = params(model.head)
    return [encoder_group, head_group]

## Loss Function

In [None]:
# Class-weighted cross-entropy; the weights come from `args.loss_weight` and are
# moved to the GPU (requires CUDA to be available).
loss_func = CrossEntropyLossFlat(weight=torch.tensor(args.loss_weight).cuda())

## Optimizer

In [None]:
# Adam factory for the Learner: momentum and weight decay come from `args`;
# decouple_wd=False requests non-decoupled weight decay (presumably classic
# L2 regularisation rather than AdamW-style decay - confirm in the fastai docs).
opt_func = partial(Adam, mom=args.mom, sqr_mom=0.999, wd=args.wd, decouple_wd=False, eps=1e-08)

## Balance Sample Callback

In [None]:
class DataBalanceCallback(Callback):
    """Replace the learner's DataLoaders with a freshly built set before training.

    `get_dls` is called without a `balance` argument here, so it uses that
    function's default (presumably balanced resampling - confirm in `get_dls`).
    """

    def before_train(self):
        # Invoked by fastai on the `before_train` event.
        part1_dir = args.output_path / "part1"
        fresh_dls = get_dls(path=part1_dir, train_name="train", valid_name="valid", verbose=False)
        self.learn.dls = fresh_dls

## Learner

In [None]:
# Assemble the Learner from the data, model, loss and optimizer defined above.
learn = Learner(
    dls=dls,
    model=model,
    loss_func=loss_func,
    opt_func=opt_func,
    splitter=my_nn_splitter,  # parameter groups: [encoder, head]
    path=Path("."),
    model_dir=Path("weights"),  # checkpoints are stored under ./weights
    # Metric order matters: validate_one_set reads the results by position
    # (loss, auc, accuracy, recall, precision).
    metrics=[RocAucBinary(), accuracy, Recall(), Precision()],
    wd_bn_bias=False,
    # MixUp augmentation plus rebuilding the DataLoaders before training.
    cbs=[MixUp(), DataBalanceCallback()],
)

# Notebook diagnostics: model summary, training-loop events and active callbacks.
learn.summary()
learn.show_training_loop()
print(learn.cbs)

# Model Training

## Training (head fine-tuning + full-network fine-tuning)

In [None]:
# Freeze the model for stage 1 (with the [encoder, head] splitter this
# presumably leaves only the head trainable), re-print the summary, then run
# the learning-rate finder.
learn.freeze()
learn.summary()
learn.lr_find()

In [None]:
# Stage 1: train the frozen model for `args.freeze_epochs` epochs.
# SaveModelCallback monitors `roc_auc_score` and writes the stage-1 checkpoint
# (model weights only, with_opt=False).
learn.fit(
    n_epoch=args.freeze_epochs, lr=args.lr, wd=args.wd, cbs=[SaveModelCallback(monitor="roc_auc_score", fname=f"bestmodel-stage1-{args.exp_name}", with_opt=False)],
)

# Stage 2: reload the stage-1 checkpoint, unfreeze everything and continue for
# `args.epoch` epochs at one tenth of the learning rate.
learn.load(file=f"bestmodel-stage1-{args.exp_name}")
learn.unfreeze()
learn.fit(
    n_epoch=args.epoch, lr=args.lr / 10, wd=args.wd, cbs=[SaveModelCallback(monitor="roc_auc_score", fname=f"bestmodel-stage2-{args.exp_name}", with_opt=False)],
)

In [None]:
# Plot the recorded loss curves, including the validation loss.
learn.recorder.plot_loss(with_valid=True)

# Plot the Results on the Five Datasets

## Results on Training Set

In [None]:
def plot_top_losses(x, y, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):
    """Draw one panel per sample, titled "Prediction/Actual/Loss/Probability".

    `samples` holds (image, target) pairs, `outs` the decoded outputs, `raws`
    the per-sample predictions and `losses` the per-sample losses. `x` and `y`
    are accepted but not used. Extra keyword arguments go to the image's
    `show` call.
    """
    grid_axes = get_grid(len(samples), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize, title="Prediction/Actual/Loss/Probability", dpi=args.dpi)
    for ax, sample, out, raw, loss in zip(grid_axes, samples, outs, raws, losses):
        image, actual = sample[0], sample[1]
        image.show(ctx=ax, **kwargs)
        # Predicted class = index of the largest entry of the decoded output.
        predicted = out[0].argmax(-1)
        ax.set_title(f"{predicted}/{actual} / {loss.item():.2f} / {raw.max().item():.2f}")


class MyInterpretation(Interpretation):
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    # Variant of fastai's `Interpretation` that wraps decoded inputs in
    # `CaseImage` for display and plots through this file's `plot_top_losses`.

    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        # Pure pass-through to the base-class constructor.
        # NOTE(review): this positional signature must match the installed
        # fastai `Interpretation.__init__` - confirm when upgrading fastai.
        super().__init__(dl, inputs, preds, targs, decoded, losses)

    def _pre_show_batch(self, b, max_n=9):
        "Decode `b` to be ready for `show_batch`"
        # `max_n` is accepted but not used in this implementation.
        b = self.dl.decode(b)
        its = L()
        # Pair each decoded input (wrapped as a CaseImage) with the matching
        # element of the batch's second member.
        for x1, y1 in zip(b[0], b[1]):
            its.append((CaseImage(x1, None), y1))
        # NOTE(review): `b` has already been indexed above, so this non-listy
        # branch is presumably never reached - confirm.
        if not is_listy(b):
            b, its = [b], L((o,) for o in its)
        # Split the decoded batch into inputs / targets at `dl.n_inp`.
        return detuplify(b[: self.dl.n_inp]), detuplify(b[self.dl.n_inp :]), its

    @classmethod
    def from_learner(cls, learn, ds_idx=-1, dl=None, act=None):
        "Construct interpretation object from a learner"
        # ds_idx 0 or 1: predictions come from the learner's own dataset index;
        # otherwise from the explicit `dl`. In both cases `dl` is stored on the
        # instance as given, so callers should pass the matching DataLoader.
        if ds_idx in (0, 1):
            return cls(dl, *learn.get_preds(ds_idx=ds_idx, with_input=True, with_loss=True, with_decoded=True, act=act))
        else:
            return cls(dl, *learn.get_preds(dl=dl, with_input=True, with_loss=True, with_decoded=True, act=act))

    def plot_top_losses(self, k, largest=True, **kwargs):
        # Plot the `k` samples with the largest (or smallest) losses.
        losses, idx = self.top_losses(k, largest)
        # Normalise the stored inputs to a tuple (this mutates self.inputs).
        if not isinstance(self.inputs, tuple):
            self.inputs = (self.inputs,)

        # Tensor inputs can be indexed directly; anything else is re-batched
        # through the DataLoader.
        if isinstance(self.inputs[0], Tensor):
            inps = tuple(o[idx] for o in self.inputs)
        else:
            inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))

        # Batch of (inputs..., targets...) for the selected samples.
        b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))

        x, y, its = self._pre_show_batch(b, max_n=k)

        # Same inputs paired with the decoded predictions instead of targets.
        b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))

        # x1 and y1 are unused; only `outs` is needed below.
        x1, y1, outs = self._pre_show_batch(b_out, max_n=k)
        if its is not None:
            # Calls the module-level plot_top_losses function, not this method.
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses, **kwargs)


class MyClassificationInterpretation(MyInterpretation):
    "Interpretation methods for classification models."

    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        super().__init__(dl, inputs, preds, targs, decoded, losses)
        # Class vocabulary taken from the DataLoader; `dl` must not be None.
        self.vocab = self.dl.vocab

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`."
        x = torch.arange(0, len(self.vocab))
        # NOTE(review): argmax over the last axis assumes `self.decoded` still
        # holds per-class scores (2-D) rather than class indices - confirm for
        # the fastai version / callbacks in use.
        d, t = flatten_check(self.decoded.argmax(-1), self.targs)
        # cm[i, j] = number of samples with target i and prediction j.
        cm = ((d == x[:, None]) & (t == x[:, None, None])).long().sum(2)
        return to_np(cm)

    def plot_confusion_matrix(self, normalize=False, title="Confusion matrix", cmap="Blues", norm_dec=2, plot_txt=True, **kwargs):
        "Plot the confusion matrix, with `title` and using `cmap`."
        # `kwargs` are forwarded to plt.figure (e.g. dpi, figsize).
        cm = self.confusion_matrix()
        if normalize:
            # Row-normalise so every actual class sums to 1.
            cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation="nearest", cmap=cmap)
        plt.title(title)
        tick_marks = np.arange(len(self.vocab))
        plt.xticks(tick_marks, self.vocab, rotation=90)
        plt.yticks(tick_marks, self.vocab, rotation=0)

        if plot_txt:
            # Use white text on cells above half the maximum, black otherwise.
            thresh = cm.max() / 2.0
            for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
                coeff = f"{cm[i, j]:.{norm_dec}f}" if normalize else f"{cm[i, j]}"
                plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center", color="white" if cm[i, j] > thresh else "black")

        ax = fig.gca()
        # Explicit y-limits so the first class is drawn at the top and the
        # outer rows are not clipped.
        ax.set_ylim(len(self.vocab) - 0.5, -0.5)

        plt.tight_layout()
        plt.ylabel("Actual")
        plt.xlabel("Predicted")
        plt.grid(False)

    def most_confused(self, min_val=1):
        "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
        cm = self.confusion_matrix()
        # Zero the diagonal so correct predictions are ignored.
        np.fill_diagonal(cm, 0)
        res = [(self.vocab[i], self.vocab[j], cm[i, j]) for i, j in zip(*np.where(cm >= min_val))]
        # NOTE(review): `itemgetter` must be in scope (operator.itemgetter);
        # its import is not visible in this part of the file - confirm.
        return sorted(res, key=itemgetter(2), reverse=True)

    def print_classification_report(self):
        "Print scikit-learn classification report"
        # NOTE(review): unlike confusion_matrix, this uses `self.decoded`
        # without argmax; one of the two presumably mismatches the actual
        # shape of `decoded` - confirm before relying on this report.
        d, t = flatten_check(self.decoded, self.targs)
        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()), target_names=[str(v) for v in self.vocab]))


def validate_one_set(learn, dl=None, ds_idx=-1, dataset_name=None):
    """Report metrics and diagnostic plots for one dataset.

    Runs `learn.validate`, prints loss / AUC / accuracy / recall / precision
    (read positionally from the result), then plots the confusion matrix and
    the 25 highest-loss cases. `dataset_name` must be a string; it is used in
    the printed line and the confusion-matrix title.
    """
    assert isinstance(dataset_name, str)
    res = learn.validate(ds_idx=ds_idx, dl=dl)

    print(f"{dataset_name}: loss = {res[0]}, auc = {res[1]}, accuracy = {res[2]}, sensitivity(recall) = {res[3]}, precision = {res[4]}")

    # Explicit softmax activation so the interpretation works on probabilities.
    softmax_act = partial(F.softmax, dim=-1)
    interp = MyClassificationInterpretation.from_learner(learn=learn, dl=dl, ds_idx=ds_idx, act=softmax_act)
    interp.plot_confusion_matrix(title=f"Confusion Matrix of {dataset_name} Set", dpi=args.dpi)
    interp.plot_top_losses(k=25, largest=True, figsize=(30, 10))


class Hook:
    """Context manager that records a module's forward output in `self.stored`."""

    def __init__(self, m):
        # Keep the handle so the hook can be detached again on exit.
        self.hook = m.register_forward_hook(self.hook_func)

    def hook_func(self, m, i, o):
        """Forward-hook callback: keep a detached copy of the output `o`."""
        detached = o.detach()
        self.stored = detached.clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Always unregister so the module is left without this hook.
        self.hook.remove()


class HookBwd:
    """Context manager that records the gradient w.r.t. a module's output.

    After a backward pass, `self.stored` holds a detached copy of the first
    output gradient.
    """

    def __init__(self, m):
        # Keep the handle so the hook can be detached again on exit.
        self.hook = m.register_backward_hook(self.hook_func)

    def hook_func(self, m, gi, go):
        """Backward-hook callback: keep a detached copy of `go[0]`."""
        first_output_grad = go[0]
        self.stored = first_output_grad.detach().clone()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Always unregister so the module is left without this hook.
        self.hook.remove()


def generate_hotmap(cls, x_idx, layer_idx, ax, x, y, x_dec):
    """Overlay a gradient-weighted activation heat map for class `cls` on `ax`.

    The activations and output gradients of ``learn.model.encoder[layer_idx]``
    are captured with hooks, each channel is weighted by its mean gradient, and
    the weighted channels are summed into a 2-D map. The map is resized to the
    spatial size of `x_dec`, repeated three times side by side and drawn
    semi-transparently on `ax`.

    Args:
        cls: class index to explain (0 or 1).
        x_idx: index of the sample; validated but otherwise unused.
        layer_idx: index into ``learn.model.encoder`` of the layer to hook.
        ax: matplotlib axes to draw on.
        x: input batch holding a single sample; moved to the GPU here.
        y: target tensor; validated but otherwise unused.
        x_dec: decoded input; its last two axes are assumed to be
            (height, width).

    Returns:
        The image object returned by ``ax.imshow`` (usable for a colorbar).

    Side effects: prints the softmax prediction and leaves ``learn.model`` in
    eval mode.
    """
    assert isinstance(cls, int) and cls in (0, 1)
    assert isinstance(x_idx, int) and x_idx >= 0
    assert isinstance(layer_idx, int)
    assert isinstance(x, Tensor)
    assert isinstance(y, Tensor)
    assert isinstance(x_dec, Tensor)

    target_layer = learn.model.encoder[layer_idx]
    with HookBwd(target_layer) as hookg:
        with Hook(target_layer) as hook:
            output = learn.model.eval()(x.cuda())
            print("prediction: ", F.softmax(output, dim=-1))
            act = hook.stored
        # Back-propagate only the score of the requested class.
        output[0, cls].backward()
        grad = hookg.stored

    # Per-channel weight = gradient averaged over the spatial axes.
    w = grad[0].mean(dim=[1, 2], keepdim=True)
    cam_map = (w * act[0]).sum(0)

    cam_map = cam_map.detach().cpu()
    # cv2.resize expects dsize as (width, height), while a tensor's trailing
    # shape is (height, width); swap them so non-square inputs are not
    # transposed.
    height, width = x_dec.shape[-2:]
    cam_map = tensor(cv2.resize(cam_map.numpy(), dsize=(int(width), int(height)), interpolation=cv2.INTER_CUBIC))
    # Repeat the map three times horizontally; presumably this matches a
    # three-panel layout drawn by CaseImage - confirm.
    im = ax.imshow(torch.cat([cam_map, cam_map, cam_map], dim=1), alpha=0.5, cmap="jet", vmin=args.vmin, vmax=args.vmax)

    return im


def interpret_one_case(dl, layer_idx, sample_index, title):
    """Show one case from the first batch of `dl` with per-class heat maps.

    Three stacked panels are drawn: the plain image, the image with the
    class-0 heat map, and the image with the class-1 heat map. A colorbar
    without tick labels is added for the last heat map. `title` is
    type-checked but not otherwise used.
    """
    assert isinstance(dl, fastai.data.core.TfmdDL)
    assert isinstance(layer_idx, int)
    assert isinstance(sample_index, int)
    assert isinstance(title, str)

    fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(30, 10), dpi=args.dpi)

    # Take the requested sample from the first batch, keeping a batch axis on x.
    xs, ys = first(dl)
    x = xs[sample_index][None]
    y = ys[sample_index]
    print("target =", y)
    x_dec = dl.decode((x,))[0][0]
    case_image = CaseImage(x_dec, None)

    # Panel 0: image only; panels 1 and 2: image plus heat map for class 0 / 1.
    for panel, ax in zip(range(3), axs.reshape(-1)):
        case_image.show(ctx=ax)
        if panel > 0:
            im = generate_hotmap(cls=panel - 1, x_idx=sample_index, layer_idx=layer_idx, ax=ax, x=x, y=y, x_dec=x_dec)

    cbar_ax = fig.add_axes([0.65, 0.15, 0.025, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax)

    # Hide tick marks and labels on the colorbar.
    cbar.ax.tick_params(labelsize=0, axis="both", which="both", length=0)

    plt.show()

In [None]:
# Restore the stage-2 checkpoint before evaluating.
learn.load(file=f"bestmodel-stage2-{args.exp_name}")
# Rebuild the DataLoaders with balance=False for evaluation.
learn.dls = get_dls(path=args.output_path / "part1", train_name="train", valid_name="valid", verbose=False, balance=False)

if args.plot_validate_result:
    # ds_idx=0 selects the training split.
    validate_one_set(learn=learn, dl=learn.dls.train, ds_idx=0, dataset_name="Train")

## Results on Validation Set

In [None]:
if args.plot_validate_result:
    # ds_idx=1 selects the validation split.
    validate_one_set(learn=learn, dl=learn.dls.valid, ds_idx=1, dataset_name="Validation")

## Results on Test0 Set

In [None]:
if args.plot_validate_result:
    validate_one_set(learn=learn, dl=dl_t0, dataset_name="Test0")

    # Two fixed seeds give two reproducible example cases; the index is capped
    # so it always falls inside the first batch.
    for case_seed in (8, 24):
        random.seed(case_seed)
        sample_index = random.choice(range(min(args.batch_size, len(dl_t0.items))))
        interpret_one_case(dl=dl_t0, layer_idx=-1, sample_index=sample_index, title="Test0")

## Results on Test1 Set

In [None]:
if args.plot_validate_result:
    validate_one_set(learn=learn, dl=dl_t1, dataset_name="Test1")

    # Two fixed seeds give two reproducible example cases; the index is capped
    # so it always falls inside the first batch.
    for case_seed in (8, 24):
        random.seed(case_seed)
        sample_index = random.choice(range(min(args.batch_size, len(dl_t1.items))))
        interpret_one_case(dl=dl_t1, layer_idx=-1, sample_index=sample_index, title="Test1")

## Results on Test2 Set

In [None]:
if args.plot_validate_result:
    validate_one_set(learn=learn, dl=dl_t2, dataset_name="Test2")

    # Two fixed seeds give two reproducible example cases; the index is capped
    # so it always falls inside the first batch.
    for case_seed in (8, 24):
        random.seed(case_seed)
        sample_index = random.choice(range(min(args.batch_size, len(dl_t2.items))))
        interpret_one_case(dl=dl_t2, layer_idx=-1, sample_index=sample_index, title="Test2")

# Save Prediction to Disk

In [None]:
def save_prediction_excel(ids, pred, targs, excel_path):
    """
    Write case ids, predicted scores and targets to an Excel file.

    Args:
        ids: 1-D sequence of case identifiers.
        pred: per-case prediction scores, shape (n_cases, n_classes) or
            (n_cases,).
        targs: 1-D targets, one per case (array-like, e.g. a CPU tensor).
        excel_path: destination path of the .xlsx file.

    The sheet has no index column; columns are labelled 0..k in the order
    id, one column per class score, target. Each column keeps its own dtype,
    so scores and targets are stored as numbers rather than being coerced to
    text by concatenation with the string ids.
    """
    id_col = np.asarray(ids).reshape(-1)
    scores = np.asarray(pred)
    if scores.ndim == 1:
        # Accept a single score column as well.
        scores = scores[:, np.newaxis]
    targ_col = np.asarray(targs).reshape(-1)

    # Iterating over scores.T yields one 1-D array per class column.
    columns = [id_col, *scores.T, targ_col]
    df = pd.DataFrame({position: column for position, column in enumerate(columns)})
    df.to_excel(excel_path, index=False)


# Create the per-experiment output folder; exist_ok=True already makes this a
# no-op when the folder is present, so no separate existence check is needed.
(args.output_path / args.exp_name).mkdir(parents=True, exist_ok=True)

# NOTE(review): the explicit F.softmax below assumes learn.get_preds returns
# raw logits in this setup - confirm for the fastai version in use.
# Case ids are the file names up to the first ".", taken via pathlib so the
# result does not depend on the path separator.

# train set
preds, targets = learn.get_preds(ds_idx=0, reorder=True)
ids = np.array([Path(str(i)).name.split(".")[0] for i in learn.dls.train.items])

save_prediction_excel(ids, F.softmax(preds, dim=-1).cpu().numpy(), targets, args.output_path / args.exp_name / "predict_train.xlsx")

# validation set
preds, targets = learn.get_preds(ds_idx=1)
ids = np.array([Path(str(i)).name.split(".")[0] for i in learn.dls.valid.items])

save_prediction_excel(ids, F.softmax(preds, dim=-1).cpu().numpy(), targets, args.output_path / args.exp_name / "predict_valid.xlsx")

# test0 set
preds, targets = learn.get_preds(dl=dl_t0)
ids = np.array([Path(str(i)).name.split(".")[0] for i in dl_t0.items])

save_prediction_excel(ids, F.softmax(preds, dim=-1).cpu().numpy(), targets, args.output_path / args.exp_name / "predict_test0.xlsx")

# test1 set
preds, targets = learn.get_preds(dl=dl_t1)
ids = np.array([Path(str(i)).name.split(".")[0] for i in dl_t1.items])

save_prediction_excel(ids, F.softmax(preds, dim=-1).cpu().numpy(), targets, args.output_path / args.exp_name / "predict_test1.xlsx")

# test2 set
preds, targets = learn.get_preds(dl=dl_t2)
ids = np.array([Path(str(i)).name.split(".")[0] for i in dl_t2.items])

save_prediction_excel(ids, F.softmax(preds, dim=-1).cpu().numpy(), targets, args.output_path / args.exp_name / "predict_test2.xlsx")

## Plot ROC Curve

In [None]:
# ROC data keyed 0..4 = Train, Validation, Test0, Test1, Test2.
fpr, tpr, roc_auc = {}, {}, {}

# Arguments handed to learn.get_preds for each dataset, in key order.
roc_sources = (
    {"ds_idx": 0},  # Train Set
    {"ds_idx": 1},  # Valid Set
    {"dl": dl_t0},  # Test0 Set
    {"dl": dl_t1},  # Test1 Set
    {"dl": dl_t2},  # Test2 Set
)
for roc_key, roc_source in enumerate(roc_sources):
    preds, targets = learn.get_preds(**roc_source)
    # Score = softmax probability of class 1.
    fpr[roc_key], tpr[roc_key], _ = skm.roc_curve(y_true=targets, y_score=F.softmax(preds, dim=-1).cpu().numpy()[:, 1])
    roc_auc[roc_key] = skm.auc(fpr[roc_key], tpr[roc_key])

plt.figure(dpi=args.dpi, figsize=(5, 5))
lw = 2

# Line colour and legend template per dataset, in the same key order.
roc_styles = (
    ("blue", "Train Set ROC curve (AUC = %0.2f)"),
    ("green", "Validation Set ROC curve (AUC = %0.2f)"),
    ("darkorange", "Test0 Set ROC curve (AUC = %0.2f)"),
    ("yellow", "Test1 Set ROC curve (AUC = %0.2f)"),
    ("red", "Test2 Set ROC curve (AUC = %0.2f)"),
)
for roc_key, (roc_color, roc_label) in enumerate(roc_styles):
    plt.plot(fpr[roc_key], tpr[roc_key], color=roc_color, lw=lw, label=roc_label % roc_auc[roc_key])

# Diagonal reference line.
plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("1 - Specificity", fontsize=11)
plt.ylabel("Sensitivity", fontsize=11)
plt.title("Receiver Operating Characteristic", fontsize=11)
plt.legend(loc="lower right")
plt.show()