In [1]:
import nibabel as nib
import os
from augmentation import *
import torch
import numpy as np
import random
import torch.nn.functional as F
from tqdm import tqdm

# File names expected inside each sample folder. Further down in this notebook
# `mod5` is loaded as the image volume and `mod8` as its mask; the remaining
# entries are presumably the other MRI sequences of the same study (verify
# against the dataset description).
mods = [
    'DWI_800.nii.gz',
    'GED1.nii.gz',
    'GED2.nii.gz',
    'GED3.nii.gz',
    'GED4.nii.gz',
    'T1.nii.gz',
    'T2.nii.gz',
    'mask_GED4.nii.gz',
]

# Keep the positional aliases that the rest of the notebook refers to.
mod1, mod2, mod3, mod4, mod5, mod6, mod7, mod8 = mods

In [2]:
def aug_pair(sample_dir, src, dst, aug_root=None):
    """
    Augment one (source, destination) sample pair and save the result.

    Loads the ``mod5`` image and ``mod8`` mask (module-level file names) of both
    samples, resamples the source onto the destination grid when the spatial
    shapes differ, runs ``ins_aug`` on the resulting two-element batch and
    writes an augmented image/mask pair as NIfTI files to
    ``<aug_root>/<src>_<dst>/``.

    :param sample_dir: Directory containing the sample folders
    :param src: Source sample folder name
    :param dst: Destination sample folder name
    :param aug_root: Root directory for the augmented output. When None, the
        original hard-coded ``augmented`` folder under the project data
        directory is used.
    :return: None. The pair is skipped (with a printed message) when a sample
        folder is missing or a mask is empty.
    """
    src_path = os.path.join(sample_dir, src)
    dst_path = os.path.join(sample_dir, dst)
    if not os.path.exists(src_path) or not os.path.exists(dst_path):
        print(f"Source or destination path does not exist: {src_path}, {dst_path}")
        return

    if aug_root is None:
        aug_root = os.path.join('C:\\Users\\SCoulY\\Downloads\\care2025_liver_biodreamer\\data', 'augmented')

    ### read src and dst imgs and masks
    src_img = nib.load(os.path.join(src_path, mod5)).get_fdata()
    src_mask = nib.load(os.path.join(src_path, mod8)).get_fdata()

    # Keep the destination image object around: its affine is reused when saving
    # so the output retains voxel-to-world geometry instead of being written
    # with affine=None.
    dst_nii = nib.load(os.path.join(dst_path, mod5))
    dst_img = dst_nii.get_fdata()
    dst_mask = nib.load(os.path.join(dst_path, mod8)).get_fdata()

    # An empty mask is reported and the pair skipped. (This used to drop into
    # pdb, which stalls or aborts a non-interactive "Run All".)
    if src_mask.sum() == 0 or dst_mask.sum() == 0:
        print(f"Source or destination mask is empty: {src_path}, {dst_path}")
        return

    # Add batch (and, for images, channel) dimensions in front of the 3 spatial axes.
    src_img = torch.from_numpy(src_img).unsqueeze(0).unsqueeze(0)  # bs, c, + 3 spatial dims
    src_mask = torch.from_numpy(src_mask).unsqueeze(0)  # bs, + 3 spatial dims

    tgt_img = torch.from_numpy(dst_img).unsqueeze(0).unsqueeze(0)  # bs, c, + 3 spatial dims
    tgt_mask = torch.from_numpy(dst_mask).unsqueeze(0)  # bs, + 3 spatial dims

    ### if the img size is not the same, resize the source onto the target grid
    if src_img.shape[2:] != tgt_img.shape[2:]:
        src_img = F.interpolate(src_img, size=tgt_img.shape[2:], mode='trilinear', align_corners=False)
        # Nearest-neighbour keeps label values intact; interpolate needs a float
        # tensor with a channel axis, which is removed again afterwards.
        src_mask = F.interpolate(src_mask.unsqueeze(1).float(), size=tgt_img.shape[2:], mode='nearest').squeeze(1).long()

    # Nearest-neighbour resampling can drop a very small mask entirely, so check again.
    if src_mask.sum() == 0 or tgt_mask.sum() == 0:
        print(f"Source or destination mask is empty after resizing: {src_path}, {dst_path}")
        return

    ### concat src and tgt images and masks as a batch
    imgs = torch.cat([src_img, tgt_img], dim=0)
    masks = torch.cat([src_mask, tgt_mask], dim=0)

    # augment the batch
    aug_imgs, aug_masks = ins_aug(imgs, masks)

    # save augmented image and mask
    aug_path = os.path.join(aug_root, f'{src}_{dst}')
    os.makedirs(aug_path, exist_ok=True)

    # NOTE(review): the previous loop wrote every batch element to the same two
    # file names, so only the LAST element ever remained on disk. That on-disk
    # result is preserved here without the redundant gzip writes. Confirm whether
    # the earlier element(s) returned by ins_aug should be persisted as well.
    outputs = list(zip(aug_imgs, aug_masks))
    if outputs:
        aug_img, aug_mask = outputs[-1]
        # Assumes ins_aug keeps the destination voxel grid, so the destination
        # affine is still valid for the output — TODO confirm.
        affine = dst_nii.affine
        nib.save(nib.Nifti1Image(aug_img.squeeze(0).numpy(), affine), os.path.join(aug_path, f'{src}_{dst}_img.nii.gz'))
        nib.save(nib.Nifti1Image(aug_mask.squeeze(0).numpy(), affine), os.path.join(aug_path, f'{src}_{dst}_mask.nii.gz'))


In [3]:

# Vendor_A sample directory. Note that the driver loop later in this notebook
# reuses the name `sample_dir` as its loop variable and therefore rebinds it.
sample_dir = r'C:\Users\SCoulY\Downloads\care2025_liver_biodreamer\data\Vendor_A\Vendor_A'

def aug_vendor(sample_dir, folds=5, seed=None):
    """
    Augment all labeled samples in a vendor directory.

    A sample counts as labeled when its folder contains ``mask_GED4.nii.gz``.
    Each labeled sample is used once as the source and paired with up to
    ``folds`` other labeled samples drawn at random (without replacement);
    ``aug_pair`` is called for every such pair.

    :param sample_dir: Directory containing the vendor samples
    :param folds: Number of destination samples to draw per source sample.
        Clamped to the number of other labeled samples available.
    :param seed: Optional seed for the destination sampling. When None the
        global ``random`` state is used, as before.
    :return: None
    """
    if not os.path.isdir(sample_dir):
        print(f"Sample directory does not exist: {sample_dir}")
        return
    # normpath so a trailing path separator does not yield an empty vendor name
    vendor = os.path.basename(os.path.normpath(sample_dir))

    # check all available labels (sorted: os.listdir order is arbitrary)
    label_names = []
    for sample in sorted(os.listdir(sample_dir)):
        sample_path = os.path.join(sample_dir, sample)
        # stray files next to the sample folders would make os.listdir raise
        if os.path.isdir(sample_path) and 'mask_GED4.nii.gz' in os.listdir(sample_path):
            label_names.append(sample)
    print(f'Available {len(label_names)} labels in vendor {vendor}:', label_names)

    rng = random if seed is None else random.Random(seed)

    # Wrap the list itself (not enumerate(...)) so tqdm knows the total and can
    # show a percentage and an ETA.
    for i, src in enumerate(tqdm(label_names, desc=f'Processing {vendor} samples')):
        rest_labels = label_names[:i] + label_names[i+1:]  # exclude the current src
        # random.sample raises ValueError when asked for more items than exist
        n_dst = min(max(folds, 0), len(rest_labels))
        dsts = rng.sample(rest_labels, n_dst)  # distinct destinations, never src itself
        for dst in dsts:
            aug_pair(sample_dir, src, dst)


In [5]:
# Full set of vendor directories (kept for reference; only Vendor_B2 is enabled below):
# sample_dirs = ['C:\\Users\\SCoulY\\Downloads\\care2025_liver_biodreamer\\data\\Vendor_A\\Vendor_A',
#                'C:\\Users\\SCoulY\\Downloads\\care2025_liver_biodreamer\\data\\Vendor_B2\\Vendor_B2',
#                'C:\\Users\\SCoulY\\Downloads\\care2025_liver_biodreamer\\data\\Vendor_B1']

# Vendor directories processed in this run.
sample_dirs = [
    'C:\\Users\\SCoulY\\Downloads\\care2025_liver_biodreamer\\data\\Vendor_B2\\Vendor_B2',
]

# Augment each vendor directory in turn and report when it is finished.
for sample_dir in tqdm(sample_dirs, desc='Processing all vendor directories'):
    aug_vendor(sample_dir, folds=5)
    print(f'Augmentation completed for {sample_dir}')

Processing all vendor directories:   0%|          | 0/1 [00:00<?, ?it/s]

Available 10 labels in vendor Vendor_B2: ['1031-B2-S1', '1041-B2-S3', '1053-B2-S3', '1070-B2-S4', '1075-B2-S4', '1076-B2-S4', '1086-B2-S4', '1097-B2-S4', '1098-B2-S4', '1115-B2-S4']


Processing Vendor_B2 samples: 10it [04:11, 25.12s/it]
Processing all vendor directories: 100%|██████████| 1/1 [04:11<00:00, 251.23s/it]

Augmentation completed for C:\Users\SCoulY\Downloads\care2025_liver_biodreamer\data\Vendor_B2\Vendor_B2



