In [112]:
from tqdm import tqdm
from array_lib import *
from skimage.transform import resize
from ply_creation_lib import create_ply
from skimage.measure import block_reduce
from numpy.lib.stride_tricks import sliding_window_view
import matplotlib.pyplot as plt
import pydicom as dicom
import nibabel as nib
import torchio as tio
import numpy as np
import cv2 as cv
import itertools
import pickle
import shutil
import time
import copy
import json
import os

In [127]:
# Seed NumPy's global RNG; np.random.choice in augment_vessels draws from it.
np.random.seed(35138)

# Input directory: "<current working directory>\annotations". The dataset
# driver loads "<id>_annotation.npy", "<id>_calcinates.npy" and "<id>_main.npy"
# from here.
folder = f'{os.getcwd()}\\annotations'
# Output roots for the base patch dataset and its augmented copy.
# NOTE(review): hard-coded absolute Windows paths with '\\' separators — the
# notebook only runs as-is on a machine that has this d: layout.
result_folder = f'd:\\dicom\\calcinate_dataset\\base'
augmented_folder = f'd:\\dicom\\calcinate_dataset\\augmented'
# Names of the entries in `folder` whose file name contains 'annotation'.
# NOTE(review): not referenced by any other code visible in this notebook.
vessels = [i for i in os.listdir(folder) if 'annotation' in i]

In [111]:
def recreate_folder(output_folder):
    """Reset `output_folder` to an empty dataset layout.

    Any existing directory at `output_folder` is deleted together with its
    contents, then the folder is created again with two empty sub-folders,
    `vessels` and `calcinates`.

    Missing parent directories are created as well, so the call no longer
    fails with FileNotFoundError when e.g. the dataset root does not exist yet.

    Args:
        output_folder: Path of the directory to wipe and re-create.
    """
    if os.path.exists(output_folder):
        # Destructive: everything previously stored under the folder is lost.
        shutil.rmtree(output_folder)

    # The '\\' separator is kept on purpose: the rest of the notebook builds
    # its paths the same way.
    for path in (output_folder,
                 f'{output_folder}\\vessels',
                 f'{output_folder}\\calcinates'):
        os.makedirs(path, exist_ok=True)

def extract_joint_patches(scan, vessel_mask, calcinate_mask, patch_size=64, stride=32, vessel_thresh=0):
    """Cut aligned cubic patches out of a scan and its two masks.

    A cubic window of side `patch_size` is moved over the three volumes with
    step `stride` along every axis. Only the window positions whose vessel-mask
    patch has more than `vessel_thresh` non-zero voxels are returned.

    Args:
        scan: 3D image volume.
        vessel_mask: 3D vessel mask with the same shape as `scan`.
        calcinate_mask: 3D calcinate mask with the same shape as `scan`.
        patch_size: Side length of the cubic patch.
        stride: Step between neighbouring window positions on every axis.
        vessel_thresh: A patch is kept when its count of non-zero vessel
            voxels is strictly greater than this value.

    Returns:
        Tuple `(scans, vessels, calcinates)` of arrays shaped
        `(N, patch_size, patch_size, patch_size)`; index `i` in each array
        refers to the same window position. Patches are ordered row-major
        over the window grid.

    Raises:
        ValueError: If the three volumes do not share one shape (the patches
            could not be aligned), or if a volume is not 3D / is smaller than
            `patch_size` (raised by `sliding_window_view`).
    """
    shapes = (np.shape(scan), np.shape(vessel_mask), np.shape(calcinate_mask))
    if not (shapes[0] == shapes[1] == shapes[2]):
        raise ValueError(
            f'scan, vessel_mask and calcinate_mask must share one shape, got {shapes}')

    window_shape = (patch_size,) * 3

    def _strided_windows(volume):
        # Zero-copy 6D view: (grid_x, grid_y, grid_z, patch, patch, patch).
        return sliding_window_view(volume, window_shape)[::stride, ::stride, ::stride]

    # Decide which windows to keep directly on the view. Reshaping the strided
    # view to (N, p, p, p) first would copy every patch of every volume, even
    # the ones that are discarded right afterwards.
    vessel_windows = _strided_windows(vessel_mask)
    vessel_counts = np.count_nonzero(vessel_windows, axis=(3, 4, 5))
    keep = vessel_counts > vessel_thresh

    # Boolean indexing over the three grid axes copies only the selected
    # patches and yields them in row-major grid order.
    scans = _strided_windows(scan)[keep]
    vessels = vessel_windows[keep]
    calcinates = _strided_windows(calcinate_mask)[keep]
    return (scans, vessels, calcinates)

def save_data(filename: str, array: np.ndarray):
    """Write `array` to `filename` in NumPy's .npy format.

    Thin wrapper around `np.save`, which appends the '.npy' extension when
    `filename` does not already end with it.
    """
    np.save(file=filename, arr=array)

def _save_patch_class(label: int, patches: list, desc: str) -> None:
    """Append `patches` to the base dataset under the class prefix `label`.

    Files are named `<label>.<counter>.npy`. The counter continues after the
    highest index already present in `<result_folder>\\vessels` for that
    label, so repeated calls (one per source scan) never overwrite each other.
    """
    prefix = f'{label}.'
    existing = [name for name in os.listdir(f'{result_folder}\\vessels')
                if name.startswith(prefix)]
    counter = 0
    if existing:
        counter = max(int(name.split('.')[1]) for name in existing) + 1
        print(counter)
    for scan, vessel, calc in tqdm(patches, desc=desc):
        # Keep scan intensities inside the vessel mask, zero everywhere else.
        colored_vessel = np.where(vessel, scan, 0)
        save_data(f'{result_folder}\\vessels\\{label}.{counter}.npy', colored_vessel)
        save_data(f'{result_folder}\\calcinates\\{label}.{counter}.npy', calc)
        counter += 1


def save_datums(split_data: tuple[np.ndarray, np.ndarray, np.ndarray], min_calc_voxels: int = 50):
    """Store extracted patches in the base dataset, split into two classes.

    Each patch triple `(scan, vessel, calc)` is classified by the sum of its
    calcinate mask:

    * sum > `min_calc_voxels`  -> class 1 ("with"), files `1.<n>.npy`
    * sum == 0                 -> class 0 ("without"), files `0.<n>.npy`
    * anything in between      -> skipped (neither class)

    For every stored patch the vessel-masked scan is written to
    `<result_folder>\\vessels` and the calcinate mask to
    `<result_folder>\\calcinates` under the same file name.

    Args:
        split_data: `(scans, vessels, calcinates)` as returned by
            `extract_joint_patches`; the three sequences are iterated in
            lock-step.
        min_calc_voxels: Calcinate-mask sum a patch must exceed to count as
            class 1. Defaults to the previously hard-coded value 50.
    """
    with_calcinates = []
    without_calcinates = []
    # Single pass: the calcinate sum of each patch is computed only once.
    for datum in zip(*split_data):
        calc_voxels = np.sum(datum[2])
        if calc_voxels > min_calc_voxels:
            with_calcinates.append(datum)
        elif calc_voxels == 0:
            without_calcinates.append(datum)

    _save_patch_class(1, with_calcinates, 'with')
    _save_patch_class(0, without_calcinates, 'without')

# Identifiers of the annotated scans that are turned into the base dataset.
possible = ['20241209_17', '20250222_37', '20250224_45', '20250224_46']
# Start from an empty result folder (deletes any previously generated dataset).
recreate_folder(result_folder)
# NOTE(review): the enumerate index `i` is never used inside the loop.
for i, selected in enumerate(possible):
    # Both masks are cast to bool after loading; the image keeps its stored dtype.
    vessel_mask: np.ndarray = np.load(f'{folder}\\{selected}_annotation.npy').astype(bool)
    calcinate_mask: np.ndarray = np.load(f'{folder}\\{selected}_calcinates.npy').astype(bool)
    image: np.ndarray = np.load(f'{folder}\\{selected}_main.npy')
    # 64-voxel cubes with step 32; only patches with more than 50 vessel voxels are kept.
    split_data = extract_joint_patches(image, vessel_mask, calcinate_mask, patch_size=64, stride=32, vessel_thresh=50)
    # Appends this scan's patches to the dataset, continuing the per-class file numbering.
    save_datums(split_data)


with: 100%|██████████| 36/36 [00:00<00:00, 347.60it/s]
without: 100%|██████████| 364/364 [00:00<00:00, 523.43it/s]


36


with: 100%|██████████| 40/40 [00:00<00:00, 519.48it/s]

364



without: 100%|██████████| 317/317 [00:00<00:00, 512.12it/s]


76


with: 100%|██████████| 20/20 [00:00<00:00, 487.87it/s]


681


without: 100%|██████████| 261/261 [00:00<00:00, 505.82it/s]


96


with: 0it [00:00, ?it/s]


942


without: 100%|██████████| 578/578 [00:01<00:00, 474.16it/s]


In [131]:
def read_datum(folder: str, filename: str) -> tuple[np.ndarray, np.ndarray]:
    """Load one dataset sample.

    Args:
        folder: Dataset root containing the `vessels` and `calcinates`
            sub-folders.
        filename: File name of the sample (including the `.npy` extension);
            the same name is looked up in both sub-folders.

    Returns:
        Tuple `(scan, mask)`: the array stored under `vessels` and the array
        stored under `calcinates`.
    """
    scan = np.load(f'{folder}\\vessels\\{filename}')
    mask = np.load(f'{folder}\\calcinates\\{filename}')
    return scan, mask

def save_datum(folder: str, id: int, datum: tuple[np.ndarray, np.ndarray]) -> None:
    """Write one `(scan, mask)` sample into a dataset folder.

    `datum[0]` is saved as `<folder>\\vessels\\<id>.npy` and `datum[1]` as
    `<folder>\\calcinates\\<id>.npy`. Both arrays are converted to float32 and
    made C-contiguous before they are written.
    """
    for position, subfolder in enumerate(('vessels', 'calcinates')):
        as_float = datum[position].astype(np.float32)
        np.save(f'{folder}\\{subfolder}\\{id}.npy', np.ascontiguousarray(as_float))

def define_augmentations():
    """Build the TorchIO augmentation pipeline used for every sample.

    The composed transform chains, in this order: random gamma (p=0.5),
    random noise with mean 0 and std 0.1 (p=0.5), random elastic deformation
    (p=0.2) and random affine with scales (0.9, 1.1) (p=0.5).
    """
    gamma = tio.RandomGamma(p=0.5)
    noise = tio.RandomNoise(mean=0, std=0.1, p=0.5)
    elastic = tio.RandomElasticDeformation(p=0.2)
    affine = tio.RandomAffine(scales=(0.9, 1.1), p=0.5)
    return tio.Compose([gamma, noise, elastic, affine])

def get_subject(scan: np.ndarray, mask: np.ndarray) -> tio.Subject:
    """Wrap a scan and its mask in a TorchIO subject.

    A leading axis of length 1 is added to both arrays before they are handed
    to TorchIO. The scan becomes a `ScalarImage` under the key 'scan', the
    mask a `LabelMap` under the key 'mask'.
    """
    scan_with_channel = np.expand_dims(scan, axis=0)
    mask_with_channel = np.expand_dims(mask, axis=0)
    return tio.Subject(
        scan=tio.ScalarImage(tensor=scan_with_channel),
        mask=tio.LabelMap(tensor=mask_with_channel),
    )

def augment(scan: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Apply one random draw of the augmentation pipeline to a sample.

    The scan and mask are bundled into a single subject and passed through
    the pipeline from `define_augmentations` together. A fresh pipeline
    object is built on every call.

    Returns:
        Tuple `(aug_scan, aug_mask)` with the leading axis added by
        `get_subject` removed again.
    """
    augmented = define_augmentations_for(get_subject(scan, mask)) if False else None
    subject = get_subject(scan, mask)
    pipeline = define_augmentations()
    augmented = pipeline(subject)
    # Index 0 drops the length-1 leading axis of each image.
    return tuple(augmented[key].numpy()[0] for key in ('scan', 'mask'))

def generate_48_transform_params():
    """Enumerate the distinct rotation+flip transforms of a cubic volume.

    Every candidate is "rotate, rotate again, optionally flip axis 0", where
    each rotation is 1-3 quarter turns in one of the three axis planes. The
    candidates are applied to a 2x2x2 cube with eight distinct values, and a
    candidate is kept only when the resulting arrangement has not been
    produced before.

    Returns:
        List of `((k1, axes1), (k2, axes2), flip)` tuples in the format
        expected by `apply_transform`, in order of first occurrence.
    """
    reference = np.arange(8).reshape(2, 2, 2)
    rotation_planes = ((0, 1), (0, 2), (1, 2))
    # k=0 (no rotation) is left out; the identity still appears as a pair of
    # rotations that cancel each other.
    quarter_turns = (1, 2, 3)
    single_rotations = [(plane, k) for plane in rotation_planes for k in quarter_turns]

    # Maps the byte signature of a transformed cube to the first parameter
    # set that produced it (dicts keep insertion order).
    found = {}
    for plane1, k1 in single_rotations:
        once = np.rot90(reference, k=k1, axes=plane1)
        for plane2, k2 in single_rotations:
            twice = np.rot90(once, k=k2, axes=plane2)
            for mirrored in (False, True):
                candidate = np.flip(twice, axis=0) if mirrored else twice
                found.setdefault(
                    candidate.tobytes(),
                    ((k1, plane1), (k2, plane2), mirrored))

    return list(found.values())

def apply_transform(volume, transform):
    """Apply one rotation+flip parameter set to `volume`.

    Args:
        volume: Array to transform.
        transform: `((k1, axes1), (k2, axes2), flip)` — two `np.rot90`
            rotations applied in that order, followed by a flip along axis 0
            when `flip` is true.

    Returns:
        The transformed array.
    """
    first_rotation, second_rotation, mirror = transform
    result = volume
    for turns, plane in (first_rotation, second_rotation):
        result = np.rot90(result, k=turns, axes=plane)
    return np.flip(result, axis=0) if mirror else result

def augment_vessels(folder: str, output_folder: str):
    """Create a class-balanced, augmented copy of a patch dataset.

    All class-1 files (`1.*`) found in `<folder>\\vessels` are used, plus an
    equally sized random subset of the class-0 files (`0.*`). Every selected
    sample is expanded with each transform from
    `generate_48_transform_params`, and each transformed copy is additionally
    passed through the random pipeline in `augment`. Results are written to
    `output_folder` with consecutive integer file names starting at 0.

    Args:
        folder: Source dataset root; samples are listed in and read from here.
        output_folder: Destination dataset root; must already contain the
            `vessels` and `calcinates` sub-folders.
    """
    # os.listdir returns names in arbitrary order; sorting makes the seeded
    # random subset below independent of the file system.
    data_files = sorted(os.listdir(f'{folder}\\vessels'))
    no_calc = [name for name in data_files if name.startswith('0.')]
    calc = [name for name in data_files if name.startswith('1.')]

    # Sampling without replacement cannot draw more items than exist, so cap
    # the request when there are fewer negatives than positives.
    n_negatives = min(len(calc), len(no_calc))
    no_calc_reduced = []
    if n_negatives > 0:
        no_calc_reduced = list(np.random.choice(no_calc, n_negatives, replace=False))
    both_classes = calc + no_calc_reduced

    all_transforms = generate_48_transform_params()

    counter = 0
    for filename in tqdm(both_classes, desc='augmenting data'):
        # Read from the `folder` argument rather than the notebook-level
        # `result_folder`, so listing and loading use the same dataset.
        scan, mask = read_datum(folder, filename)

        for transform in all_transforms:
            rotated_scan = np.ascontiguousarray(apply_transform(scan, transform))
            rotated_mask = np.ascontiguousarray(apply_transform(mask, transform))
            aug_scan, aug_mask = augment(rotated_scan, rotated_mask)
            # save_datum stores float32, so these casts only make the scan
            # integer-valued and the mask 0/1 before saving.
            # NOTE(review): with integer-scaled intensities the int16 cast
            # removes most of the std<=0.1 noise added in augment — confirm
            # this is intended.
            aug_scan = aug_scan.astype(np.int16)
            aug_mask = aug_mask.astype(bool)
            save_datum(output_folder, counter, (aug_scan, aug_mask))
            counter += 1

# Wipe the augmented dataset folder, then rebuild it from the base dataset.
# Slow cell: the recorded run processed 192 source files in about 21 minutes.
recreate_folder(augmented_folder)
augment_vessels(result_folder, augmented_folder)
    

augmenting data: 100%|██████████| 192/192 [21:17<00:00,  6.65s/it]
