In [None]:
from tqdm import tqdm
from skimage.transform import resize
from ply_creation_lib import create_ply
from skimage.measure import block_reduce
import matplotlib.pyplot as plt
import pydicom as dicom
import nibabel as nib
import torchio as tio
import numpy as np
import pickle
import time
import copy
import json
import os

In [None]:
def read_nii(file_path: str, order: int = 1) -> np.ndarray:
    """Load a NIfTI volume, resample it and rotate the resampled z-axis to the front.

    Axes 0 and 1 are resized to a fixed 128 x 128 grid. Axis 2 is resized to
    ``max(1, round(n_z * spacing_z / 0.25) // 4)`` voxels, i.e. the z-extent
    sampled at 4 * 0.25 = 1 spacing unit per voxel (presumably mm, from the
    header zooms -- TODO confirm). The resized volume is then rotated in the
    (0, 2) plane so the resampled z-axis becomes axis 0.

    Args:
        file_path: Path to a file readable by ``nibabel.load``.
        order: Interpolation order passed to ``skimage.transform.resize``.
            The default ``1`` (linear) reproduces the previous behaviour; pass
            ``0`` (nearest neighbour) for label maps to avoid fractional labels.

    Returns:
        A float32 array of shape ``(new_depth, 128, 128)``.
    """
    nii = nib.load(file_path)
    data = nii.get_fdata()

    original_spacing = nii.header.get_zooms()
    target_spacing = 0.25

    z_zoom = original_spacing[2] / target_spacing
    new_z = int(round(data.shape[2] * z_zoom))
    # The extra // 4 makes the effective z-spacing 4 * target_spacing. Clamp to 1
    # so a very thin volume cannot produce a zero-length axis.
    new_depth = max(1, new_z // 4)

    resampled_data = resize(
        data,
        output_shape=(128, 128, new_depth),
        order=order,
        preserve_range=True,
        anti_aliasing=False,
    ).astype(np.float32)

    # rot90 over axes (0, 2) swaps those axes: (128, 128, D) -> (D, 128, 128).
    return np.rot90(resampled_data, k=1, axes=(0, 2))

def downscale(input: np.ndarray, block_size: tuple) -> np.ndarray:
    """Mean-pool ``input`` over non-overlapping blocks of shape ``block_size``.

    Thin wrapper around ``skimage.measure.block_reduce`` using ``np.mean``.
    """
    reducer = np.mean
    pooled = block_reduce(input, block_size=block_size, func=reducer)
    return pooled

def downscale_data(data: np.ndarray, new_size: int, mask: bool) -> np.ndarray:
    """Shrink a volume by mean-pooling with a cubic block.

    The block edge is ``data.shape[0] // new_size`` and is used for every axis,
    so only a cubic input ends up exactly ``new_size`` along each axis.

    Returns:
        A boolean array (pooled value > 0.5) when ``mask`` is true, otherwise the
        pooled values cast to int16.
    """
    edge = data.shape[0] // new_size
    pooled = downscale(data, (edge,) * 3)
    return pooled > 0.5 if mask else pooled.astype(np.int16)

def save_data(filename: str, array: np.ndarray):
    """Write ``array`` to ``filename`` in NumPy ``.npy`` format via ``np.save``."""
    np.save(file=filename, arr=array)

def _normalize_intensity(image: np.ndarray, min_scale: int = 100, max_scale: int = 700) -> np.ndarray:
    """Clip intensities to [min_scale, max_scale] and map that window linearly to [0, 1]."""
    window = max_scale - min_scale
    return np.clip(image - min_scale, 0, window) / window


def get_scans_masks(folder: str, output_folder: str):
    """Resample every NIfTI file in ``folder`` and save it as ``.npy``.

    Files whose name contains ``'label'`` are treated as masks and written to
    ``<output_folder>/masks/<id>.npy`` as boolean arrays. Every other file is a
    scan: it is clipped to the intensity window [100, 700], rescaled to [0, 1]
    and written to ``<output_folder>/scans/<id>.npy``. ``id`` is the integer
    before the first ``'.'`` in the filename, so a scan and its mask must share
    that prefix. Each volume is cropped to at most 128 voxels along every axis.
    The ``scans``/``masks`` sub-directories are created if missing.
    """
    scans_dir = os.path.join(output_folder, 'scans')
    masks_dir = os.path.join(output_folder, 'masks')
    os.makedirs(scans_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    for filename in tqdm(os.listdir(folder), desc='rewriting data'):
        is_mask = 'label' in filename
        volume = read_nii(os.path.join(folder, filename))[:128, :128, :128]
        sample_id = int(filename.split('.')[0])

        if is_mask:
            # read_nii interpolates linearly, so label boundaries become fractional.
            # astype(bool) would mark any voxel with a partial label as foreground
            # and dilate the mask; threshold at 0.5 as downscale_data does.
            save_data(os.path.join(masks_dir, f'{sample_id}.npy'), volume > 0.5)
        else:
            save_data(os.path.join(scans_dir, f'{sample_id}.npy'), _normalize_intensity(volume))

# Windows-specific absolute locations. NOTE(review): consider a configurable data root.
folder, output_folder, augmented_folder = (
    'd:\\dicom\\nii_dataset_original_data\\all',  # input read by get_scans_masks
    'd:\\dicom\\my_dataset\\nii_dataset',          # written by get_scans_masks, read by augment_nii
    'd:\\dicom\\my_dataset\\nii_augmented',        # written by augment_nii
)

In [None]:
# Resample the raw NIfTI files into output_folder/{scans,masks}.
get_scans_masks(folder=folder, output_folder=output_folder)

In [None]:
def read_datum(folder: str, filename: str) -> tuple[np.ndarray, np.ndarray]:
    """Load a matching (scan, mask) pair stored under ``folder``.

    Args:
        folder: Dataset root containing ``scans`` and ``masks`` sub-directories.
        filename: ``.npy`` file name shared by the scan and its mask.

    Returns:
        ``(scan, mask)`` exactly as loaded by ``np.load``.
    """
    scan = np.load(os.path.join(folder, 'scans', filename))
    mask = np.load(os.path.join(folder, 'masks', filename))
    return scan, mask

def crop_to_128(image: np.ndarray) -> np.ndarray:
    """Fit ``image`` into a zero-filled 128 x 128 x 128 float64 volume.

    Axes longer than 128 are truncated (the leading 128 entries are kept);
    shorter axes are zero-padded at the end. The input dtype is not preserved:
    the result is always float64.
    """
    side = 128
    head = image[:side, :side, :side]
    canvas = np.zeros((side,) * 3)
    canvas[tuple(slice(0, extent) for extent in head.shape)] = head
    return canvas

def save_datum(folder: str, id: int, datum: tuple[np.ndarray, np.ndarray]) -> None:
    """Save a (scan, mask) pair as C-contiguous float32 ``.npy`` files.

    Writes ``<folder>/scans/<id>.npy`` and ``<folder>/masks/<id>.npy``. The
    sub-directories are created if missing, so a fresh output folder does not
    make ``np.save`` fail. Boolean masks are stored as 0.0 / 1.0.
    """
    scan, mask = datum[0], datum[1]
    for subdir, array in (('scans', scan), ('masks', mask)):
        target_dir = os.path.join(folder, subdir)
        os.makedirs(target_dir, exist_ok=True)
        np.save(
            os.path.join(target_dir, f'{id}.npy'),
            np.ascontiguousarray(array, dtype=np.float32),
        )

def rotation_transformation():
    """Build the random torchio augmentation pipeline applied to each subject.

    Each step is applied with probability 0.5:
    - a random flip over axes (1, 2),
    - a random gamma change,
    - additive noise (``mean=0, std=0.1`` as passed to ``RandomNoise``),
    - a random affine with ``scales=(0.9, 1.1)`` and ``translation=(5, 5, 5)``
      (its other arguments use torchio's defaults).
    """
    steps = [
        tio.RandomFlip(axes=(1, 2), p=0.5),
        tio.RandomGamma(p=0.5),
        tio.RandomNoise(mean=0, std=0.1, p=0.5),
        tio.RandomAffine(scales=(0.9, 1.1), translation=(5, 5, 5), p=0.5),
    ]
    return tio.Compose(steps)

def get_subject(scan: np.ndarray, mask: np.ndarray) -> tio.Subject:
    """Wrap a scan/mask pair into a torchio ``Subject``.

    A leading channel axis is added to both arrays. The scan becomes a
    ``ScalarImage`` under key ``'scan'``; the mask becomes a ``LabelMap`` under
    key ``'mask'``.
    """
    scan_image = tio.ScalarImage(tensor=scan[np.newaxis, ...])
    mask_image = tio.LabelMap(tensor=mask[np.newaxis, ...])
    return tio.Subject(scan=scan_image, mask=mask_image)

def augment(scan: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Apply one random draw of ``rotation_transformation`` to a scan/mask pair.

    Returns the augmented scan and mask as NumPy arrays with the channel axis
    (added by ``get_subject``) removed again.
    """
    subject = get_subject(scan, mask)
    pipeline = rotation_transformation()
    transformed = pipeline(subject)
    return tuple(transformed[key].numpy()[0] for key in ('scan', 'mask'))

def augment_nii(folder: str, output_folder: str):
    """Write four augmented copies of every (scan, mask) pair found in ``folder``.

    Pairs are listed from ``<folder>/scans`` (``.npy`` files only) and processed
    in ascending numeric order of the filename prefix. Each pair is cropped or
    padded to 128^3. Then, for each quarter-turn ``j`` in 0..3 in the plane of
    axes (1, 2), the pair is rotated, randomly augmented and saved under
    ``output_folder`` with id ``4 * i + j``, where ``i`` is the pair's position
    in that order.
    """
    scans_dir = os.path.join(folder, 'scans')
    data_files = sorted(
        (name for name in os.listdir(scans_dir) if name.endswith('.npy')),
        key=lambda name: int(name.split('.')[0]),
    )

    for i, filename in enumerate(tqdm(data_files, desc='augmenting data')):
        scan, mask = read_datum(folder, filename)

        scan = crop_to_128(scan)
        mask = crop_to_128(mask)

        for j in range(4):
            rot_scan = np.ascontiguousarray(np.rot90(scan, k=j, axes=(1, 2)))
            rot_mask = np.ascontiguousarray(np.rot90(mask, k=j, axes=(1, 2)))
            aug_scan, aug_mask = augment(rot_scan, rot_mask)

            # get_scans_masks stores scans as floats in [0, 1]. The previous
            # astype(np.int16) truncated nearly every voxel to 0, so keep float32.
            aug_scan = aug_scan.astype(np.float32)
            aug_mask = aug_mask.astype(bool)

            save_datum(output_folder, i * 4 + j, (aug_scan, aug_mask))

# Build the augmented dataset (four variants per pair) from the resampled one.
augment_nii(folder=output_folder, output_folder=augmented_folder)