In [1]:
import numpy as np
from twaidata.torchdatasets.whole_brain_dataset import MRISegmentationDataset
import torch
import os
from trustworthai.utils.augmentation.standard_transforms import NormalizeImg, PairedCompose, LabelSelect, PairedCentreCrop, CropZDim

### load existing domain datasets

In [2]:
# Root directory that contains one sub-folder per source domain.
root_dir = os.path.join("/disk", "scratch", "s2208943", "data")

In [3]:
# In-plane centre-crop size passed to PairedCentreCrop.
# A 192x192 crop suits Singapore and Utrecht, but GE3T has a different in-plane shape.
# 192 = 3 * 2^6 can be halved six times, whereas 224 = 7 * 2^5 and 160 = 5 * 2^5 can
# only be halved five times; hopefully that still won't cause any problems downstream.
CROP_SIZE = (224, 160)

# Label value kept by LabelSelect (the WMH class).
WMH_LABEL = 1

# Per-domain (start, end) bounds used by CropZDim to crop each volume's z stack.
# Keys must match the domain folder names under root_dir.
domain_dcrop_map = dict(
    Singapore=(0, 48),
    Utrecht=(0, 48),
    GE3T=(30, 78),
    domainA=(0, 40),
    domainB=(5, 48),
    domainC=(5, 48),  # NOTE: the author doubts the 'domain C' images were correctly identified
    domainD=(7, 52),
)

def get_transforms(dcrop_bounds):
    """Build the paired (image, label) transform pipeline for one domain.

    Args:
        dcrop_bounds: (start, end) slice bounds along the z axis for this domain,
            as stored in ``domain_dcrop_map``.

    Returns:
        A ``PairedCompose`` that centre-crops to ``CROP_SIZE``, crops the z stack
        using the given bounds, and selects ``WMH_LABEL`` from the label.
    """
    z_start, z_end = dcrop_bounds
    z_depth = z_end - z_start

    pipeline = [
        PairedCentreCrop(CROP_SIZE),        # cut out the in-plane centre region
        CropZDim(z_depth, z_start, z_end),  # crop the z stack to this domain's bounds
        LabelSelect(WMH_LABEL),             # keep only the WMH label
    ]
    # Per-image normalisation (NormalizeImg(p=1)) is intentionally left out for now,
    # so that normalisation strategies can be experimented with later.
    return PairedCompose(pipeline)

In [4]:
# Absolute path of each domain folder; folder names must match domain_dcrop_map keys.
domains = [
    os.path.join(root_dir, domain_name)
    for domain_name in ("Singapore", "Utrecht", "GE3T", "domainA", "domainB", "domainC", "domainD")
]


In [5]:
# Pair each domain path with its dataset; the folder name selects the z-crop bounds.
datasets_domains = [
    (
        dom_path,
        MRISegmentationDataset(
            dom_path,
            transforms=get_transforms(domain_dcrop_map[os.path.basename(dom_path)]),
        ),
    )
    for dom_path in domains
]

### processing

In [6]:
def _collect_pairs(ds):
    """Iterate ``ds`` once and split its (img, label) items into two parallel lists."""
    imgs, labels = [], []
    for sample_img, sample_label in ds:
        imgs.append(sample_img)
        labels.append(sample_label)
    return imgs, labels


# One (domain_path, imgs, labels) entry per domain.
data_domains = []
for dom, dataset in datasets_domains:
    print(f"reading in domain: {dom}")
    data_domains.append((dom, *_collect_pairs(dataset)))

reading in domain: /disk/scratch/s2208943/data/Singapore
reading in domain: /disk/scratch/s2208943/data/Utrecht
reading in domain: /disk/scratch/s2208943/data/GE3T
reading in domain: /disk/scratch/s2208943/data/domainA
reading in domain: /disk/scratch/s2208943/data/domainB
reading in domain: /disk/scratch/s2208943/data/domainC
reading in domain: /disk/scratch/s2208943/data/domainD


In [17]:
# Stack each domain's per-sample imgs/labels into single arrays (new leading sample axis).
numpy_data_domains = []
for dom, img_data, label_data in data_domains:
    print(f"stacking {dom} into numpy arrays")
    stacked_imgs = np.stack(img_data, axis=0)
    stacked_labels = np.stack(label_data, axis=0)
    numpy_data_domains.append((dom, stacked_imgs, stacked_labels))

stacking /disk/scratch/s2208943/data/Singapore into numpy arrays
stacking /disk/scratch/s2208943/data/Utrecht into numpy arrays
stacking /disk/scratch/s2208943/data/GE3T into numpy arrays
stacking /disk/scratch/s2208943/data/domainA into numpy arrays
stacking /disk/scratch/s2208943/data/domainB into numpy arrays
stacking /disk/scratch/s2208943/data/domainC into numpy arrays
stacking /disk/scratch/s2208943/data/domainD into numpy arrays


### saving

In [28]:
# Destination directory for the per-domain .npy arrays (a sub-folder of root_dir).
out_dir = os.path.join(root_dir, "merged_data")

In [29]:
# Save each domain's stacked arrays into out_dir as <domain>_imgs.npy / <domain>_labels.npy.
# `dom` is an absolute path, and os.path.join discards all earlier components when a
# later one is absolute. Joining `dom` directly would therefore write the files next to
# the source folders instead of into out_dir, so only the domain's folder name is used.
os.makedirs(out_dir, exist_ok=True)
for dom, img_arr, label_arr in numpy_data_domains:
    print("saving domain: ", dom)
    dom_name = os.path.basename(os.path.normpath(dom))  # normpath drops any trailing separator
    out_file_imgs = os.path.join(out_dir, dom_name + "_imgs.npy")
    out_file_labels = os.path.join(out_dir, dom_name + "_labels.npy")
    np.save(out_file_imgs, img_arr)
    np.save(out_file_labels, label_arr)

saving domain:  /disk/scratch/s2208943/data/Singapore
saving domain:  /disk/scratch/s2208943/data/Utrecht
saving domain:  /disk/scratch/s2208943/data/GE3T
saving domain:  /disk/scratch/s2208943/data/domainA
saving domain:  /disk/scratch/s2208943/data/domainB
saving domain:  /disk/scratch/s2208943/data/domainC
saving domain:  /disk/scratch/s2208943/data/domainD


In [None]:
# Next, run rsync in a terminal to copy the merged data from scratch back to DFS:
# rsync --archive --update --compress --progress ${src_path}/ ${dest_path}

#### Testing the 2D dataset


In [1]:
from twaidata.torchdatasets.slice_dataset_2D import MRISegmentation2DDataset

In [2]:
dataset = MRISegmentation2DDataset(
    "/disk/scratch/s2208943/data/merged_data",  # folder the merged *_imgs.npy / *_labels.npy were saved to
    "Singapore",                                # domain name (presumably the file prefix — TODO confirm)
    None,                                       # third argument left unset (presumably transforms — TODO confirm)
)

In [3]:
# Number of items in the 2D dataset (presumably one per slice — TODO confirm).
num_items = len(dataset)
num_items

960

In [4]:
# Fetch one (img, label) pair to check that indexing works.
# NOTE(review): this cell prints nothing itself, so the "n: ..., d: ..." output below
# appears to be debug printing inside the dataset class.
sample_idx = 100
img, label = dataset[sample_idx]

n: 2, d: 4
