In [None]:
%run Include.ipynb
%run FileIO.ipynb
%run Medical_IO.ipynb
%run Argparser.ipynb
import nibabel as nib
import glob
import cv2
from scipy import signal
from torch.utils.data.sampler import SubsetRandomSampler
from monai.data import ImageDataset
from torch.utils.data import Dataset,DataLoader

from monai.transforms import AddChannel, Compose, ScaleIntensity, RandAffine, EnsureType,Rand3DElastic

In [None]:
def return_settings_extend():
    """Return the static configuration of the "extend" augmentation run.

    Returns:
        dict: two sections.
            "Basic" -- branch name, model-resume flags and the data
                       extension / dataset identifier.
            "Path"  -- three output folders (save_path0/1/2) and three input
                       folders (ori_path, data_path, data_path2).
    """
    # Other branch names listed in the original comment:
    # "map_patch64", "cremi_patch64", "isbi_patch64", "retina_patch128".
    basic_section = {
        "branch_name":    "extend",
        "continue_model": False,
        "model_step":     0,
        "data_extension": "nii",
        "dataset":        "nii",
    }

    path_section = {
        "save_path0": "/home/zhilin/TopoTxR_backup/data/extend_ori2",
        "save_path1": "/home/zhilin/TopoTxR_backup/data/extend_dim1_",
        "save_path2": "/home/zhilin/TopoTxR_backup/data/extend_dim2_",
        "ori_path":   "/home/zhilin/TopoTxR_backup/data/data_ori",
        "data_path":  "/home/zhilin/TopoTxR_backup/data/whole_256_th2_dim1_dil0_fv",
        "data_path2": "/home/zhilin/TopoTxR_backup/data/whole_256_th2_dim2_dil0_fv",
    }

    # Sections that are currently disabled for this notebook:
    #   "Monitor": {"print_step": 20, "save_step": 3000}
    #   "GPU":     {"gpu_num": 1, "gpu_enable": True, "cudnn_benchmark": True}
    return {"Basic": basic_section, "Path": path_section}


def return_data_settings_extend():
    """Return the data-loading options used by this notebook.

    Returns:
        dict: a fresh dictionary with the loader options (epochs, batch size,
        worker count, shuffle / drop_last flags, transform flag), the split
        scheme with its fold parameters, and the random seed.
    """
    return dict(
        epochs=500,
        batch_size=1,
        batch_workers=0,
        shuffle=True,
        drop_last=False,
        transform=True,
        datasplit_scheme="Test",  # options: All|Test|Valid
        test_split=0.2,           # obsolete parameter
        xfold=10,
        fold_idx=0,
        random_seed=64,
    )

In [None]:
# Materialise both configuration dictionaries and echo them into the cell output.
settings = return_settings_extend()
data_params = return_data_settings_extend()
print(settings)
print(data_params)

# Expose every data-loading option as a notebook-level variable so that the
# cells below can pass them around by name.
(
    epochs,
    batch_size,
    batch_workers,
    shuffle,
    drop_last,
    transform,
    datasplit_scheme,
    test_split,
    xfold,
    fold_idx,
    random_seed,
) = (
    data_params[option]
    for option in (
        "epochs",
        "batch_size",
        "batch_workers",
        "shuffle",
        "drop_last",
        'transform',
        "datasplit_scheme",
        "test_split",
        "xfold",
        "fold_idx",
        "random_seed",
    )
)

In [None]:
class FileIO_MEDICAL(object):
    """Thin static helpers for reading and writing NIfTI volumes with nibabel.

    NOTE(review): Medical_IO.ipynb is %run at the top of this notebook; if it
    already defines a class with this name, this definition replaces it --
    confirm that is intended.
    """

    @staticmethod
    def load_nii(pathIn):
        """Load a NIfTI file and return its voxel data.

        Args:
            pathIn: path of the NIfTI file to read.

        Returns:
            The array produced by nibabel's ``get_fdata()`` for that image.
        """
        struct = nib.load(pathIn)
        return struct.get_fdata()

    @staticmethod
    def save_nii(data, pathOut, affine=None):
        """Write an array to disk as a NIfTI-1 image.

        Args:
            data: array holding the voxel values.
            pathOut: destination file path.
            affine: optional 4x4 voxel-to-world matrix. When omitted the
                identity matrix is used, which is what this helper always did
                before the parameter existed; pass the source image's affine
                to keep its spatial orientation.
        """
        if affine is None:
            affine = np.eye(4)
        struct = nib.Nifti1Image(data, affine)
        nib.save(struct, pathOut)



In [None]:
def fetch_dataset_wValidation_extend(name, data_path, batch_size, batch_workers, shuffle, drop_last, scalor, datasplit_scheme, test_split, xfold, fold_idx, random_seed=-1):
    """Build a DataLoader that elastically deforms every volume found in a folder.

    Every file in ``data_path`` whose extension is ``name`` is collected. Its
    integer label is taken from the third '_'-separated field of the file name.
    Each volume goes through ScaleIntensity -> AddChannel -> EnsureType ->
    Rand3DElastic (always applied, prob=1).

    Args:
        name: file extension to collect, without the leading dot (e.g. "nii").
        data_path: folder that holds the volumes.
        batch_size: batch size of the returned loader.
        batch_workers: number of DataLoader worker processes.
        shuffle: when true, the file list is shuffled once with the global
            NumPy RNG before the dataset is built. The loader itself never
            shuffles.
        drop_last: forwarded to the DataLoader.
        scalor: unused; accepted for interface compatibility.
        datasplit_scheme: only "Test" is implemented.
        test_split: unused; accepted for interface compatibility.
        xfold: unused; the x-fold validation split is disabled.
        fold_idx: unused; the x-fold validation split is disabled.
        random_seed: forwarded to ``set_global_random_seed``.

    Returns:
        tuple: ``(train_loader, address_train)`` where ``address_train`` is the
        list of file paths in exactly the order the loader yields them.

    Raises:
        ValueError: for an unsupported ``datasplit_scheme`` or a file name that
            does not carry an integer label in its third '_'-separated field.
        FileNotFoundError: when ``data_path`` is not an existing folder.
    """
    # Fail fast: any other scheme used to fall through and return None, which
    # only surfaced later as an unpacking error at the call site.
    if datasplit_scheme != "Test":
        raise ValueError(
            "Unsupported datasplit_scheme %r: only 'Test' is implemented." % (datasplit_scheme,)
        )

    set_global_random_seed(random_seed)

    print("Test mode in data_aug fetcher.")

    if not os.path.isdir(data_path):
        raise FileNotFoundError("Data folder does not exist: %s" % (data_path,))

    # Glob with the folder in the pattern instead of os.chdir(), so the
    # notebook's working directory is left untouched. Sorting removes the
    # filesystem-dependent ordering and makes the seeded shuffle reproducible.
    pattern = os.path.join(glob.escape(data_path), "*." + name)
    data_address = sorted(glob.glob(pattern))

    data_labels = []
    for address in data_address:
        file_name = os.path.basename(address)
        try:
            data_labels.append(int(file_name.split('_')[2]))
        except (IndexError, ValueError) as exc:
            raise ValueError(
                "Cannot read an integer label from the third '_'-separated field of %r" % (file_name,)
            ) from exc

    if shuffle:
        # One permutation applied to both lists keeps paths and labels paired.
        order = list(range(len(data_address)))
        np.random.shuffle(order)
        data_address = [data_address[k] for k in order]
        data_labels = [data_labels[k] for k in order]

    # Random elastic deformation combined with rotation, shear, translation and
    # scaling; prob=1 means every sample is deformed.
    train_transform = Compose([
        ScaleIntensity(),
        AddChannel(),
        EnsureType(),
        Rand3DElastic(
            sigma_range=(5, 8),
            magnitude_range=(100, 200),
            prob=1,
            padding_mode='zeros',
            rotate_range=(np.pi, np.pi, np.pi),
            shear_range=(np.pi, np.pi, np.pi),
            translate_range=(32, 32, 32),
            scale_range=(0.15, 0.15, 0.15),
        ),
    ])

    label_train = np.array(data_labels, dtype=np.int64)
    address_train = data_address

    train_ds = ImageDataset(image_files=address_train, labels=label_train, transform=train_transform)
    # shuffle stays False on purpose: callers pair the i-th loader sample with
    # the i-th entry of address_train.
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=batch_workers,
        drop_last=drop_last,
        pin_memory=torch.cuda.is_available(),
    )

    return train_loader, address_train

In [None]:
# One loader per input folder: original images, dim-1 volumes and dim-2 volumes.
# Everything except the folder is shared between the three calls.
# NOTE(review): every call builds its own Rand3DElastic transform, so the three
# folders presumably receive different random deformations -- confirm that
# unaligned augmentation across the folders is intended.
_dataset_ext = settings['Basic']['dataset']
_shared_args = (batch_size, batch_workers, shuffle, drop_last, 0.5, datasplit_scheme, test_split, xfold, fold_idx, random_seed)

train_loader0, address_train0 = fetch_dataset_wValidation_extend(_dataset_ext, settings['Path']['ori_path'], *_shared_args)
train_loader1, address_train1 = fetch_dataset_wValidation_extend(_dataset_ext, settings['Path']['data_path'], *_shared_args)
train_loader2, address_train2 = fetch_dataset_wValidation_extend(_dataset_ext, settings['Path']['data_path2'], *_shared_args)

In [None]:
# a = address_train0[0]
# print(a.split('/')[-1].split('.')[0])
# data_address.append(os.path.join(data_path, file))


def write_address_Out(address_train,out_folder):
    """Derive the output path of every augmented volume.

    The output folder is created when it does not exist yet. No volume is
    written here; only the destination paths are computed.

    Args:
        address_train: iterable of input file paths ('/'-separated).
        out_folder: folder that will receive the augmented files.

    Returns:
        list: one path per input, ``<out_folder>/<stem>_aug.nii``, where
        ``<stem>`` is the input file name cut at its first '.'.
    """
    if not os.path.exists(out_folder):
        os.makedirs(out_folder)

    def _augmented_path(source_path):
        # Last '/'-separated component, truncated at the first dot.
        base_name = source_path.rsplit('/', 1)[-1]
        stem = base_name.partition('.')[0]
        return os.path.join(out_folder, stem + '_aug.nii')

    return [_augmented_path(source_path) for source_path in address_train]

# Destination paths for the augmented copies of each input folder; the i-th
# entry corresponds to the i-th entry of the matching address_train list.
_save_dirs = settings['Path']
address_out0 = write_address_Out(address_train0, _save_dirs['save_path0'])
address_out1 = write_address_Out(address_train1, _save_dirs['save_path1'])
address_out2 = write_address_Out(address_train2, _save_dirs['save_path2'])

print(address_out0)


In [None]:
def extend_dataset(train_loader,address_out):
    """Save every augmented volume with a non-zero label to its output path.

    Args:
        train_loader: iterable of batches ``(volumes, labels)``. ``volumes``
            has a leading batch axis; ``labels`` holds one integer per sample.
        address_out: output paths, one per sample, in the same order the
            loader yields the samples.

    Returns:
        int: always 0.

    Samples whose label is 0 are skipped, but they still consume their slot in
    ``address_out`` so later samples stay aligned with their paths.
    NOTE(review): presumably class 0 needs no augmented copies -- confirm.
    """
    sample_idx = 0  # running index across batches, not the batch counter
    for data in train_loader:
        volumes = data[0]
        labels = data[1].numpy().reshape(-1)
        # Handle each sample on its own so the function also works when the
        # loader uses batch_size > 1. The previous `label == 0` test raised on
        # multi-element arrays, and paths were indexed by batch number.
        for j in range(len(labels)):
            if labels[j] != 0:
                # Drop the channel axis when it is a singleton.
                vol = volumes[j].squeeze(0).numpy()
                FileIO_MEDICAL.save_nii(vol, address_out[sample_idx])
            sample_idx += 1
    return 0
        
# Write the augmented volumes for the dim-1 and dim-2 folders.
# The call for the original images is currently disabled:
#_ = extend_dataset(train_loader0,address_out0)
for _loader, _targets in ((train_loader1, address_out1), (train_loader2, address_out2)):
    _ = extend_dataset(_loader, _targets)
