In [3]:
import os
import random
from tqdm import tqdm
import numpy as np
import torch
import cv2
import pandas as pd
from torch.utils.data import DataLoader
from dataset import VolumeDataset
from scnas import ScNas
from glob import glob


In [4]:
class CFG:
    """Static configuration for this inference notebook.

    All values are class attributes and are read directly as ``CFG.<name>``;
    the class is never instantiated in this notebook.
    """
    seed          = 42
    debug         = False # set debug=False for Full Training
    exp_name      = 'baseline'
    output_dir    = './checkpoint/'
    model_name    = 'scnas'

    valid_bs      = 8                # batch size used for inference
    volume_shape  = [128, 128, 128]  # only the first entry is read as the subvolume edge
    overlap_size  = 64

    # Preferred GPU. A fixed "cuda:3" breaks on CUDA hosts that expose fewer than
    # four devices, so fall back to the first GPU there and to the CPU otherwise.
    gpu_index     = 3
    device        = torch.device(
        f"cuda:{gpu_index if gpu_index < torch.cuda.device_count() else 0}"
        if torch.cuda.is_available() else "cpu"
    )

    gt_df = "../data/gt.csv"
    data_root = "../data/blood-vessel-segmentation/"

    # Sub-directories of data_root to run inference on.
    valid_groups = ["train/kidney_1_dense", "train/kidney_2"]
    #valid_groups = ["test/kidney_5", "test/kidney_6"]



In [5]:
def set_seed(seed = 42):
    '''Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    the current CUDA device), and switches cuDNN to deterministic kernels with
    autotuning disabled so repeated runs pick the same algorithms.
    '''
    # Exported for child processes; the hash seed of the interpreter that is
    # already running was fixed at startup and is not changed by this.
    os.environ['PYTHONHASHSEED'] = str(seed)

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # cuDNN: prefer reproducible kernels over the fastest ones
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Seed all RNGs once, up front, with the value configured in CFG.
set_seed(CFG.seed)

# DataLoader

In [6]:
DATASET_FOLDER = "/kaggle/input/blood-vessel-segmentation"
# glob returns paths in arbitrary filesystem order. The submission ids are later
# derived from this list and paired positionally with per-slice RLE strings, so
# sort to get a reproducible dataset/slice order.
# NOTE(review): assumes the volume loader also reads slices in sorted filename order - confirm.
ls_images = sorted(glob(os.path.join(DATASET_FOLDER, "test", "*", "*", "*.tif")))
print(f"found images: {len(ls_images)}")

found images: 0


In [7]:
# Resolve one dataset directory per configured validation group.
inference_path = [
    os.path.join(CFG.data_root, group_name)
    for group_name in CFG.valid_groups
]

# NOTE(review): despite the variable name, this loads CFG.valid_groups[0].
kidney_5 = VolumeDataset(
    inference_path[0],
    subvol_size=CFG.volume_shape[0],
    overlap=CFG.overlap_size,
)
#kidney_6 = VolumeDataset(inference_path[1], subvol_size=CFG.volume_shape[0], overlap=CFG.overlap_size)


  0%|          | 0/2279 [00:00<?, ?it/s]

100%|██████████| 2279/2279 [00:14<00:00, 158.60it/s]


In [9]:
def infer_segmentation(model, dataset, subvol_size, overlap, batch_size, device):
    """
    Perform batched inference on subvolumes and reconstruct the segmented volume.

    Each subvolume is run through ``model``, passed through a sigmoid and
    accumulated into a full-size buffer at the coordinates reported by the
    dataset. Voxels covered by several overlapping subvolumes receive the mean
    of every prediction that touched them; voxels no subvolume covers stay 0.

    Parameters:
    model (torch.nn.Module): Trained PyTorch model for segmentation.
    dataset (VolumeDataset): The dataset containing the subvolumes. It must expose
        a 4-D ``volume`` array (unpacked here as C, Z, Y, X) and yield
        ``(subvolume, (z, y, x))`` items, where (z, y, x) is the subvolume origin.
    subvol_size (int): Edge length of the cubic subvolumes.
    overlap (int): Size of the overlap between subvolumes. Not read by this
        function (the dataset decides where subvolumes start); kept for
        interface compatibility.
    batch_size (int): The size of each batch.
    device (str or torch.device): Device to perform inference on ('cpu' or 'cuda').

    Returns:
    numpy.ndarray: The reconstructed segmented volume, float32, same shape as
        ``dataset.volume``.
    """
    model.to(device)
    model.eval()

    _, Z, Y, X = dataset.volume.shape
    segmented_volume = np.zeros_like(dataset.volume, dtype=np.float32)
    # Number of predictions accumulated per voxel. This must start at zero:
    # starting at one divides every sum by (hits + 1), which caps voxels that
    # are covered once at 0.5 and biases all others low.
    count_volume = np.zeros_like(dataset.volume, dtype=np.int32)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        for batch, coords in tqdm(loader, desc='Segmenting', unit='batch'):
            batch = batch.to(device)  # Send batch to device
            segmented_batch = torch.sigmoid(model(batch)).cpu().numpy()  # Perform inference and send to cpu
            z_coords, y_coords, x_coords = coords

            # Iterate through each subvolume in the batch
            for idx, segmented_subvol in enumerate(segmented_batch):
                z = z_coords[idx].item()  # Convert tensor to integer
                y = y_coords[idx].item()
                x = x_coords[idx].item()

                # Clip the tile to the volume so one that reaches past the far
                # edge contributes only its in-bounds part.
                z_end = min(z + subvol_size, Z)
                y_end = min(y + subvol_size, Y)
                x_end = min(x + subvol_size, X)

                segmented_volume[:, z:z_end, y:y_end, x:x_end] += \
                    segmented_subvol[..., :z_end - z, :y_end - y, :x_end - x]
                count_volume[:, z:z_end, y:y_end, x:x_end] += 1

    # Uncovered voxels have a zero sum and a zero count; divide those by one
    # so they stay 0 instead of becoming NaN.
    np.maximum(count_volume, 1, out=count_volume)
    segmented_volume /= count_volume

    return segmented_volume

In [10]:
def build_model(device):
    """Instantiate the ScNas network with this notebook's fixed hyperparameters.

    The freshly constructed (untrained) network is moved to ``device`` and returned.
    """
    architecture = dict(
        num_feature=8,
        num_layers=6,
        num_multiplier=2,
        num_classes=1,
        use_bridge=True,
        input_channel=1,
    )
    network = ScNas(**architecture)
    network.to(device)
    return network

In [11]:
def load_model(device, path):
    """Build the network on ``device`` and restore its weights from ``path``.

    Parameters:
    device (str or torch.device): Device the model is created on and the
        checkpoint tensors are mapped to.
    path (str): Path to a state-dict checkpoint file.

    Returns:
    The model returned by ``build_model`` with the checkpoint weights loaded.
    """
    model = build_model(device)
    # Remap stored tensors onto the requested device. Without map_location a
    # checkpoint saved on a GPU that is absent here (or on a different GPU
    # index) either fails to load or is first materialised on the wrong device.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    return model

In [12]:
# Load the trained weights from the checkpoint directory configured in CFG.
model = load_model(CFG.device, os.path.join(CFG.output_dir, "last_epoch.bin"))

# Inference

In [13]:
# Sliding-window inference over the first loaded volume; yields per-voxel probabilities.
kidney_5_mask = infer_segmentation(
    model,
    kidney_5,
    subvol_size=CFG.volume_shape[0],
    overlap=CFG.overlap_size,
    batch_size=CFG.valid_bs,
    device=CFG.device,
)


Segmenting: 100%|██████████| 1050/1050 [13:10<00:00,  1.33batch/s]


In [None]:
# Disabled: `kidney_6` is never created because its VolumeDataset line is commented
# out in the data-loading cell, so this call would raise NameError on a fresh
# Restart & Run All. Re-enable it together with the `kidney_6` dataset line and
# the matching `preds.append(...)` line.
#kidney_6_mask = infer_segmentation(model,
#                                   dataset = kidney_6,
#                                   subvol_size=CFG.volume_shape[0],
#                                   overlap=CFG.overlap_size,
#                                   batch_size=CFG.valid_bs,
#                                   device=CFG.device)

In [14]:
# Sanity check: number of NaN voxels in the averaged prediction (expected 0).
np.isnan(kidney_5_mask).sum()

0

In [17]:
# Binarise the probabilities at 0.5; one uint8 volume per inferred dataset.
preds = [
    (kidney_5_mask > 0.5).astype(np.uint8),
]
#preds.append((kidney_6_mask>0.5).astype(np.uint8))


In [16]:
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns the runs as a space-separated string of "start length" pairs, where
    start is the 1-based position in the row-major flattened array. A mask
    with no foreground is encoded as '1 0'.
    '''
    # Zero sentinels on both ends turn every run boundary into a value change.
    padded = np.concatenate([[0], img.reshape(-1), [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Entries alternate run start / one-past-end; turn each end into a length.
    boundaries[1::2] -= boundaries[::2]
    encoded = ' '.join(map(str, boundaries))
    return encoded if encoded else '1 0'

def remove_small_objects(img, min_size):
    """Drop connected foreground components smaller than ``min_size`` pixels.

    Parameters:
    img (numpy.ndarray): 2-D mask, non-zero = foreground.
        NOTE(review): presumably uint8, as OpenCV's connected-components routine
        expects an 8-bit single-channel image - confirm against the caller.
    min_size (int): Minimum component area (in pixels) to keep.

    Returns:
    numpy.ndarray: Array with the shape and dtype of ``img`` holding 1 where a
        kept component lies and 0 elsewhere.
    """
    # Find all connected components (labels); 8-connectivity joins diagonal neighbours
    _, labels, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=8)

    # Per-label keep flag. Label 0 (the background) is never kept, whatever its area.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size
    keep[0] = False

    # A single lookup through the label image replaces one full-image mask scan per component.
    return keep[labels].astype(img.dtype)

In [18]:
# One RLE string per 2-D slice, in volume order then slice order.
rles = []

for pred in preds:
    # squeeze() drops the singleton channel axis so iteration walks the slices
    for pred_by_slice in pred.squeeze():
        # discard connected components smaller than 10 pixels before encoding
        rles.append(rle_encode(remove_small_objects(pred_by_slice, 10)))

In [21]:
# Build one submission id per test image: "<dataset>_<slice file stem>".
ids = []
for p_img in tqdm(ls_images):
    # path layout: .../<dataset>/<sub-folder>/<slice>.tif
    *_, dataset, _, file_name = p_img.split(os.path.sep)
    slice_id = os.path.splitext(file_name)[0]
    ids.append(f"{dataset}_{slice_id}")

0it [00:00, ?it/s]


In [None]:
# ids and rles are paired purely by position, so they must be the same length.
# A mismatch means the globbed test images and the predicted volumes do not
# describe the same data (e.g. inference ran on a different group).
if len(ids) != len(rles):
    raise ValueError(
        f"submission is misaligned: {len(ids)} ids but {len(rles)} RLE strings"
    )

submission = pd.DataFrame.from_dict({
    "id": ids,
    "rle": rles
})
submission.to_csv("submission.csv", index=False)

In [None]:
# Display the assembled submission table as the cell output.
submission