# Pain in the Net - Laplacian Pyramid Translation Network (LPTN)
Application of Laplacian Pyramid Translation Network (LPTN) to domain adaptation of diffusion MRI.


Code by:

Tyler Spears - tas6hh@virginia.edu

Dr. Tom Fletcher

---

Based on the following work(s):

* J. Liang, H. Zeng, and L. Zhang, “High-Resolution Photorealistic Image Translation in Real-Time: A Laplacian Pyramid Translation Network,” in *Proc. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2021, pp. 9392–9400. Accessed: Aug. 26, 2021. [Online]. Available: https://openaccess.thecvf.com/content/CVPR2021/html/Liang_High-Resolution_Photorealistic_Image_Translation_in_Real-Time_A_Laplacian_Pyramid_Translation_CVPR_2021_paper.html


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 1

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import ants
import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dipy.viz
import dipy.viz.regtools
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import IPython

# Try importing GPUtil for printing GPU specs.
# May not be installed if using CPU only.
try:
    import GPUtil
except ImportError:
    warnings.warn("WARNING: Package GPUtil not found, cannot print GPU specs")
from tabulate import tabulate
from IPython.display import display, Markdown
import ipyplot

# Data management libraries.
import nibabel as nib
import natsort
from natsort import natsorted
import addict
from addict import Addict
import box
from box import Box
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchio
import pytorch_lightning as pl
import monai

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

# Matplotlib figure defaults: automatic layout and an opaque white background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Compact, non-scientific printing of ndarrays and tensors.
np.set_printoptions(threshold=100, linewidth=88, suppress=True)
torch.set_printoptions(threshold=100, linewidth=88, sci_mode=False)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Run `/usr/bin/env` inside direnv's context for the current directory; its output lists
# every environment variable defined in that sub-process. Arguments are passed as a list
# (no shell), so a working directory containing spaces or shell metacharacters is safe.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
try:
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd())
except FileNotFoundError:
    # Best effort: continue without direnv's variables rather than crash the notebook.
    warnings.warn("WARNING: direnv not found; notebook environment not updated")
    proc_out = ""
else:
    if proc.returncode != 0:
        warnings.warn(f"WARNING: direnv exited with code {proc.returncode}")
    # Store and format the subprocess' output.
    proc_out = proc.stdout.strip().decode("utf-8")
# Use python-dotenv to load the environment variables by using the output of
# 'direnv exec ...' as a 'dummy' .env file.
dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True);

In [None]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# You made me do this, poor python package management!!
if "PROJECT_ROOT" in os.environ:
    lib_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    lib_location = str(Path("../../").resolve())
if lib_location not in sys.path:
    sys.path.insert(0, lib_location)
import lib as pitn

# Include the top-level lib module along with its submodules.
%aimport lib
# Grab all submodules of lib, not including modules outside of the package.
includes = list(
    filter(
        lambda m: m.startswith("lib."),
        map(lambda x: x[1].__name__, inspect.getmembers(pitn, inspect.ismodule)),
    )
)
# Run aimport magic with constructed includes.
ipy = IPython.get_ipython()
ipy.run_line_magic("aimport", ", ".join(includes))

In [None]:
# torch setup: use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To force the CPU regardless, uncomment:
# device = torch.device('cpu')
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():

    # GPU information
    # Taken from
    # <https://www.thepythoncode.com/article/get-hardware-system-information-python>.
    # If GPUtil is not installed, skip this step.
    try:
        gpus = GPUtil.getGPUs()
        print("=" * 50, "GPU Specs", "=" * 50)
        list_gpus = []
        for gpu in gpus:
            # get the GPU id
            gpu_id = gpu.id
            # name of GPU
            gpu_name = gpu.name
            driver_version = gpu.driver
            cuda_version = torch.version.cuda
            # get total memory
            gpu_total_memory = f"{gpu.memoryTotal}MB"
            gpu_uuid = gpu.uuid
            list_gpus.append(
                (
                    gpu_id,
                    gpu_name,
                    driver_version,
                    cuda_version,
                    gpu_total_memory,
                    gpu_uuid,
                )
            )

        print(
            tabulate(
                list_gpus,
                headers=(
                    "id",
                    "Name",
                    "Driver Version",
                    "CUDA Version",
                    "Total Memory",
                    "uuid",
                ),
            )
        )
    except NameError:
        print("CUDA Version: ", torch.version.cuda)

else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# `cap` holds the stdout captured by the `%%capture --no-stderr cap` magic in the specs cell.
print(str(cap))

Author: Tyler Spears

Last updated: 2021-09-24T18:31:24.463846+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-84-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: ee869f65131c73fb8cc6d2ecf04f2d31014b8426

seaborn          : 0.11.1
IPython          : 7.23.1
pandas           : 1.2.3
addict           : 2.4.0
torch            : 1.9.0
dipy             : 1.4.1
nibabel          : 3.2.1
torchio          : 0.18.37
numpy            : 1.20.2
ipywidgets       : 7.6.3
natsort          : 7.1.1
matplotlib       : 3.4.1
box              : 5.4.1
skimage          : 0.18.1
json             : 2.0.9
pytorch_lightning: 1.4.5
sys              : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
GPUtil           : 1.4.0
ants             : 0.2.7
monai            : 0.7.dev2138
scipy            : 1.5.3

  id  Name       Driver Version      CUDA Version

### Data Variables & Definitions Setup

In [None]:
# Set up directories from environment variables (loaded by the direnv cell).
data_dir = pathlib.Path(os.environ["DATA_DIR"])
hcp_source_data_dir = data_dir / "hcp"
clinic_source_data_dir = data_dir / "uva/chronic_pain_head_and_neck"
processed_data_dir = pathlib.Path(os.environ["WRITE_DATA_DIR"])
hcp_processed_data_dir = processed_data_dir / "hcp"
clinic_processed_data_dir = processed_data_dir / "uva/chronic_pain_head_and_neck"
results_dir = pathlib.Path(os.environ["RESULTS_DIR"])
tmp_results_dir = pathlib.Path(os.environ["TMP_RESULTS_DIR"])

# Fail with a message naming every missing directory. A bare `assert a and b` does not
# say which one is missing, and asserts are skipped entirely under `python -O`.
_missing_dirs = [
    str(required_dir)
    for required_dir in (
        hcp_source_data_dir,
        clinic_source_data_dir,
        hcp_processed_data_dir,
        clinic_processed_data_dir,
        results_dir,
        tmp_results_dir,
    )
    if not required_dir.exists()
]
if _missing_dirs:
    raise FileNotFoundError(f"Required directories do not exist: {_missing_dirs}")

### Experiment Logging Setup

In [None]:
# tensorboard experiment logging setup.
EXPERIMENT_NAME = "debug_da_gan"

ts = datetime.datetime.now().replace(microsecond=0).isoformat()
# Break ISO format because many programs don't like having colons ':' in a filename.
ts = ts.replace(":", "_")
experiment_name = ts + "__" + EXPERIMENT_NAME
run_name = experiment_name
print(experiment_name)

# Results go to a temporary directory first, in case the experiment does not finish.
# Cull old temporary results so that, once this run's directory is created, at most
# `n_tmp_to_keep` remain. Only directories are considered; `shutil.rmtree` cannot
# remove plain files.
tmp_dirs = [p for p in tmp_results_dir.glob("*") if p.is_dir()]

n_tmp_to_keep = 3
# Existing directories allowed to survive, leaving room for this run's new one.
n_existing_to_keep = max(n_tmp_to_keep - 1, 0)
if len(tmp_dirs) > n_existing_to_keep:
    print(f"At least {n_tmp_to_keep} temporary results, culling to the most recent")
    # Directories created by this cell start with a timestamp, so a natural sort puts
    # the oldest first. NOTE(review): differently-named entries sort arbitrarily.
    sorted_tmp_dirs = natsorted([str(tmp_dir) for tmp_dir in tmp_dirs])
    # Explicit end index: the slice `[: -k]` would select nothing when k == 0.
    tmps_to_delete = sorted_tmp_dirs[: len(sorted_tmp_dirs) - n_existing_to_keep]
    for tmp_dir in tmps_to_delete:
        shutil.rmtree(tmp_dir)
        print("Deleted temporary results directory ", tmp_dir)

experiment_results_dir = tmp_results_dir / experiment_name
# Final target directory, to be made when experiment is complete.
final_experiment_results_dir = results_dir / experiment_name

In [None]:
# TensorBoard logger handed to the pytorch-lightning Trainer, so the training/testing
# loops can log directly.
pl_logger = pl.loggers.TensorBoardLogger(
    tmp_results_dir,
    name=experiment_name,
    version="",
    default_hp_metric=False,
    log_graph=False,
)
# Lower-level experiment object; used below for `log_dir` and later for `add_text`
# (histograms, images, etc. can be logged through it as well).
logger = pl_logger.experiment

# Separate plain-text log for events & info besides parameters & results.
log_txt_file = Path(logger.log_dir) / "log.txt"
header_lines = [
    f"Experiment Name: {experiment_name}\n",
    f"Timestamp: {ts}\n",
    # `cap` comes from the %%capture magic in the specs cell.
    f"Environment and Hardware Info:\n {cap}\n\n",
]
with open(log_txt_file, "a+") as log_f:
    log_f.writelines(header_lines)

### Experiment Parameters

In [None]:
# Parameters
params = Box(default_box=True)

# Data params.
params.num_channels = 6
params.hcp.num_subjects = 5
params.clinic.num_subjects = 1
params.clamp_percentiles = (0.01, 99.99)
# Every dimension must be a multiple of 2**num_laplace_high_freq (checked below).
params.patch_size = (32, 32, 32)

# Network params.
params.num_laplace_high_freq = 3
params.discriminator_downscale_factors = [1, 2, 4]
params.lambda_adversary_loss = 0.1
params.optim.lr = 0.001
params.optim.betas = (0.9, 0.999)

# Training, validation, & testing params
params.batch_size = 12
params.samples_per_subj_per_epoch = 100
params.max_epochs = 50

# Fail early if the patch size violates the Laplacian-pyramid divisibility requirement.
_pyramid_divisor = 2 ** params.num_laplace_high_freq
if any(dim_size % _pyramid_divisor != 0 for dim_size in params.patch_size):
    raise ValueError(
        f"params.patch_size {params.patch_size} must be divisible by "
        f"2**num_laplace_high_freq = {_pyramid_divisor} in every dimension"
    )

## Data Loading

In [None]:
# Transformation pipeline applied to every subject dict.
# Inputs to the Laplacian pyramid must be divisible by 2 once per high-frequency level,
# i.e. by 2**num_laplace_high_freq.
laplace_pyramid_divisible_by_shape = 2 ** params.num_laplace_high_freq
clip_lower_pct, clip_upper_pct = params.clamp_percentiles

pre_process_pipeline = monai.transforms.Compose(
    [
        # Crop both images to the mask's foreground (plus a margin) ...
        monai.transforms.CropForegroundd(["dti", "mask"], source_key="mask", margin=3),
        # ... then pad them up to a pyramid-compatible size.
        monai.transforms.DivisiblePadd(
            ["dti", "mask"], laplace_pyramid_divisible_by_shape
        ),
        # Percentile clipping of the DTI (project transform; per-channel, nonzero voxels).
        pitn.transforms.ClipPercentileTransformd(
            "dti",
            clip_lower_pct,
            clip_upper_pct,
            nonzero=True,
            channel_wise=True,
        ),
        # Final dtypes: float DTI, boolean mask.
        monai.transforms.ToTensord("dti", dtype=torch.float),
        monai.transforms.ToTensord("mask", dtype=torch.bool),
    ]
)

### Load and Pre-Process HCP Data

In [None]:
# Locate the processed data directory of each selected HCP subject.
hcp_subj_dirs: dict = dict()

possible_ids = [
    "397154",
    "224022",
    "140117",
    "751348",
    "894774",
    "156637",
    "227432",
    "303624",
    "185947",
    "810439",
    "753251",
    "644246",
    "141422",
    "135528",
    "103010",
    "700634",
]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(possible_ids, params.hcp.num_subjects)
if len(possible_ids) > params.hcp.num_subjects:
    warnings.warn(
        "WARNING: Sub-selecting participants for dev and debugging. "
        f"Subj IDs selected: {selected_ids}"
    )
# Debug only: uncomment to deliberately mix training and testing subjects.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])

# HCP subject IDs are handled as ints from here on.
selected_ids = natsorted([int(id_str) for id_str in selected_ids])

for subj_id in selected_ids:
    hcp_subj_dirs[subj_id] = (
        hcp_processed_data_dir
        / "derivatives/diffusion/mean_downsample/scale-2.00mm"
        / f"sub-{subj_id}"
    )
    assert hcp_subj_dirs[subj_id].exists()
ppr(hcp_subj_dirs)

In [None]:
# Record the selected HCP subjects in both the text log and TensorBoard.
subj_log_line = f"Selected HCP Subjects: {selected_ids}\n"
with open(log_txt_file, "a+") as log_f:
    log_f.write(subj_log_line)

logger.add_text("hcp_subjs", pprint.pformat(selected_ids))

In [None]:
# Load each selected HCP subject's DTI and mask, pre-process them, and collect the results.
hcp_subj_data = list()
# NIfTI reader shared by every subject.
nib_reader = monai.data.NibabelReader(as_closest_canonical=False)

# Sub-directory names for each image to be read.
dti_file_prefix = "dti"
mask_file_prefix = "mask"

for subj_id, subj_dir in hcp_subj_dirs.items():
    subj_id = int(subj_id)
    subj_data = {"subj_id": subj_id}

    # Same procedure for both images; the dict key doubles as the filename suffix.
    for img_key, sub_dir_name in (("dti", dti_file_prefix), ("mask", mask_file_prefix)):
        matches = list((subj_dir / sub_dir_name).glob(f"sub-{subj_id}*{img_key}.nii.gz"))
        # The glob pattern must match exactly one file.
        assert len(matches) == 1
        img, metadata = nib_reader.get_data(nib_reader.read(matches[0]))
        subj_data[img_key] = img
        # monai's default metadata key name.
        subj_data[f"{img_key}_meta_dict"] = metadata

    hcp_subj_data.append(pre_process_pipeline(subj_data))

# Dataset containing every HCP subject.
hcp_subj_dataset = monai.data.Dataset(hcp_subj_data)

### Load & Pre-Process Clinical Data

In [None]:
# Locate the processed data directory of each selected clinical subject.
clinic_subj_dirs: dict = dict()

possible_ids = ["001"]

## Sub-set the chosen participants for dev and debugging!
selected_ids = random.sample(possible_ids, params.clinic.num_subjects)
if len(possible_ids) > params.clinic.num_subjects:
    warnings.warn(
        "WARNING: Sub-selecting participants for dev and debugging. "
        f"Subj IDs selected: {selected_ids}"
    )
# Debug only: uncomment to deliberately mix training and testing subjects.
# warnings.warn("WARNING: Mixing training and testing subjects")
# selected_ids.append(selected_ids[0])

# Clinical IDs stay (zero-padded) strings.
selected_ids = natsorted(selected_ids)

for subj_id in selected_ids:
    clinic_subj_dirs[subj_id] = (
        clinic_processed_data_dir / "derivatives/diffusion" / f"sub-{subj_id}" / "ses-01"
    )
    assert clinic_subj_dirs[subj_id].exists()
ppr(clinic_subj_dirs)

In [None]:
# Record the selected clinical subjects in both the text log and TensorBoard.
subj_log_line = f"Selected Clinically-Scanned Subjects: {selected_ids}\n"
with open(log_txt_file, "a+") as log_f:
    log_f.write(subj_log_line)

logger.add_text("clinic_data_subjs", pprint.pformat(selected_ids))

In [None]:
# Load each selected clinical subject's DTI and mask, pre-process them, and collect the results.
clinic_subj_data = list()
# NIfTI reader shared by every subject.
nib_reader = monai.data.NibabelReader(as_closest_canonical=False)

# Sub-directory names for each image to be read.
dti_file_prefix = "dti"
mask_file_prefix = "mask"

for subj_id, subj_dir in clinic_subj_dirs.items():
    subj_data = {"subj_id": subj_id}

    # Same procedure for both images; the dict key doubles as the filename suffix.
    for img_key, sub_dir_name in (("dti", dti_file_prefix), ("mask", mask_file_prefix)):
        matches = list((subj_dir / sub_dir_name).glob(f"sub-{subj_id}*{img_key}.nii.gz"))
        # The glob pattern must match exactly one file.
        assert len(matches) == 1
        img, metadata = nib_reader.get_data(nib_reader.read(matches[0]))
        subj_data[img_key] = img
        # monai's default metadata key name.
        subj_data[f"{img_key}_meta_dict"] = metadata

    clinic_subj_data.append(pre_process_pipeline(subj_data))

# Create dataset with all "clinical quality" subjects included.
clinic_subj_dataset = monai.data.Dataset(clinic_subj_data)

In [None]:
# Source domain (HCP): random patches drawn from within each subject's mask.
source_part_patch_sampler = functools.partial(
    pitn.samplers.random_patches_from_mask,
    num_patches=params.samples_per_subj_per_epoch,
    patch_size=params.patch_size,
)


def source_random_patch_sample_f(p):
    """Sample source-domain patches from one subject dict's DTI, restricted to its mask."""
    return source_part_patch_sampler(p["dti"], p["mask"])


source_train_dataset = monai.data.PatchDataset(
    hcp_subj_dataset,
    patch_func=source_random_patch_sample_f,
    samples_per_image=params.samples_per_subj_per_epoch,
)

# Target domain (clinic): enough patches per subject that the total matches the length
# of the source-domain dataset.
num_clinic_samples_per_img = len(source_train_dataset) // params.clinic.num_subjects
target_part_patch_sampler = functools.partial(
    pitn.samplers.random_patches_from_mask,
    num_patches=num_clinic_samples_per_img,
    patch_size=params.patch_size,
)


def target_random_patch_sample_f(p):
    """Sample target-domain patches from one subject dict's DTI, restricted to its mask."""
    return target_part_patch_sampler(p["dti"], p["mask"])


target_train_dataset = monai.data.PatchDataset(
    clinic_subj_dataset,
    patch_func=target_random_patch_sample_f,
    samples_per_image=num_clinic_samples_per_img,
)

# Zip source and target domains together to feed each training step. The paired
# patches are deliberately *not* aligned: this is unpaired I2I.
train_dataset = monai.data.ZipDataset([source_train_dataset, target_train_dataset])

train_loader = monai.data.DataLoader(
    train_dataset,
    batch_size=params.batch_size,
    pin_memory=True,
    num_workers=0,
    # Options tried during development: shuffle=True, persistent_workers=True
)

In [None]:
# tio_subject_ds = torchio.SubjectsDataset(
#     [
#         torchio.Subject(
#             {
#                 **dict(
#                     tio_img=torchio.ScalarImage(
#                         tensor=s["dti"], affine=s["dti_meta_dict"]["affine"]
#                     )
#                 ),
#                 **s,
#             }
#         )
#         for s in hcp_subj_dataset
#     ]
# )

## Model Definition

In [None]:
class ClinicMatchGAN(pl.LightningModule):
    """GAN translating source-domain (HCP) DTI patches toward the target (clinical) domain.

    The generator is ``pitn.nn.gan.generative.LPTN`` and the adversary is
    ``pitn.nn.gan.adversarial.MultiDiscriminator``. Each training step updates either
    the generator or the discriminator, selected by ``optimizer_idx``.

    Least-squares adversarial losses, with discriminator labels real = 1 and fake = -1,
    and generator target 0:

    * generator: sum-reduced MSE between source and translated patches, plus
      ``lambda_adversary_loss`` times half the LS loss of translated patches against 0;
    * discriminator: the average of the LS loss of real target patches against 1 and
      of translated patches against -1.

    Parameters
    ----------
    num_channels : int
        Number of image channels seen by both the generator and the discriminator.
    gen_num_high_freq : int
        Number of high-frequency levels in the generator's Laplacian pyramid.
    discriminator_downsample_factors : list of int, optional
        Downsample factors passed to the MultiDiscriminator; defaults to ``[1, 2, 4]``.
    lambda_adversary_loss : float
        Weight of the adversarial term in the generator loss.
    lr : float
        Adam learning rate for both optimizers.
    betas : tuple of float
        Adam betas for both optimizers.
    """

    # Position of each optimizer in the list returned by ``configure_optimizers``.
    # Defined on the class so ``training_step`` never depends on
    # ``configure_optimizers`` having already run.
    _GENERATOR_OPTIMIZER_IDX = 0
    _DISCRIMINATOR_OPTIMIZER_IDX = 1

    def __init__(
        self,
        num_channels: int,
        gen_num_high_freq: int = 3,
        discriminator_downsample_factors=None,
        lambda_adversary_loss: float = 0.1,
        lr=0.001,
        betas=(0.9, 0.999),
    ):
        super().__init__()
        # None default avoids a shared mutable list default. Resolve it *before*
        # save_hyperparameters(), which reads the __init__ arguments from this frame,
        # so the saved hparams still contain the concrete list.
        if discriminator_downsample_factors is None:
            discriminator_downsample_factors = [1, 2, 4]
        self.save_hyperparameters()

        self.generator = pitn.nn.gan.generative.LPTN(
            num_channels, num_high_freq_levels=self.hparams.gen_num_high_freq
        )

        self.discriminator = pitn.nn.gan.adversarial.MultiDiscriminator(
            num_channels, discriminator_downsample_factors
        )

        # Plain {global_step: loss} records, kept alongside the Lightning logger.
        self.plain_log = Box(loss_gen=dict(), loss_discrim=dict())

    def forward(self, x):
        """Translate ``x`` with the generator."""
        return self.generator(x)

    def reconstruction_loss(self, y_source, y_pred):
        """Sum-reduced squared error between the source patches and their translation."""
        # (input, target) order per F.mse_loss convention; the value is symmetric.
        return F.mse_loss(y_pred, y_source, reduction="sum")

    def ls_adversarial_loss(self, sample, label: int):
        """Least-squares adversarial loss of the discriminator's output on ``sample``.

        Returns the mean squared difference between ``self.discriminator(sample)`` and
        a constant tensor of the same shape filled with ``label``.
        """
        sample_pred = self.discriminator(sample)
        target = torch.full_like(sample_pred, label)
        return ((sample_pred - target) ** 2).mean()

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run a generator or discriminator update, selected by ``optimizer_idx``.

        Parameters
        ----------
        batch : sequence
            ``(source_samples, target_samples)``; the two domains are not paired.
        batch_idx : int
            Unused.
        optimizer_idx : int
            ``_GENERATOR_OPTIMIZER_IDX`` or ``_DISCRIMINATOR_OPTIMIZER_IDX``.

        Returns
        -------
        collections.OrderedDict
            ``"loss"`` (attached to the graph, for backprop) plus detached copies
            under ``"progress_bar"`` and ``"log"``.

        Raises
        ------
        RuntimeError
            If ``optimizer_idx`` is neither known index.
        """
        source_samples, target_samples = batch

        if optimizer_idx == self._GENERATOR_OPTIMIZER_IDX:
            translated_samples = self.generator(source_samples)
            l_g_reconstruct = self.reconstruction_loss(
                source_samples, translated_samples
            )
            # Push the discriminator's score for translated patches toward 0.
            l_g_adversarial = (
                self.ls_adversarial_loss(translated_samples, label=0) / 2
            )
            loss = l_g_reconstruct + (
                l_g_adversarial * self.hparams.lambda_adversary_loss
            )
            loss_name = "loss_gen"

        elif optimizer_idx == self._DISCRIMINATOR_OPTIMIZER_IDX:
            # Real (target-domain) images.
            loss_real = self.ls_adversarial_loss(target_samples, label=1)
            # Translated (i.e., fake) images. The generator is not updated in this
            # step, so its graph is not needed.
            with torch.no_grad():
                translated_samples = self.generator(source_samples)
            loss_fake = self.ls_adversarial_loss(translated_samples, label=-1)
            loss = (loss_fake / 2) + (loss_real / 2)
            loss_name = "loss_discrim"

        else:
            raise RuntimeError(f"ERROR: Invalid optimizer index {optimizer_idx}")

        # Record a plain float for this step.
        self.plain_log[loss_name][self.global_step] = float(loss.item())
        # Detached copy so the logged values do not keep the autograd graph alive.
        tqdm_dict = {loss_name: loss.detach()}
        output = collections.OrderedDict(
            {"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
        )
        print(
            f"Step {self.global_step} optimizer {optimizer_idx} | loss {loss_name}: {output['loss']} | batch_size {len(source_samples)}"
        )
        return output

    def configure_optimizers(self):
        """Create one Adam optimizer for the generator and one for the discriminator.

        The list order must match ``_GENERATOR_OPTIMIZER_IDX`` and
        ``_DISCRIMINATOR_OPTIMIZER_IDX``; no LR schedulers are used.
        """
        lr = self.hparams.lr
        betas = self.hparams.betas

        opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=lr, betas=betas
        )
        return [opt_gen, opt_discriminator], []

## Model Training

In [None]:
# Training loop

# Instantiate model.
model = ClinicMatchGAN(
    params.num_channels,
    gen_num_high_freq=params.num_laplace_high_freq,
    discriminator_downsample_factors=params.discriminator_downscale_factors,
    lambda_adversary_loss=params.lambda_adversary_loss,
    lr=params.optim.lr,
    betas=params.optim.betas,
)

# Create trainer object. Request a GPU only when CUDA is available, consistent with the
# CPU fallback used when selecting `device`; a hard-coded `gpus=1` would fail on
# CPU-only machines.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=params.max_epochs,
    logger=pl_logger,
    terminate_on_nan=True,
)

# NOTE: fitting emits many warnings. To silence them, wrap the call below in
# `with warnings.catch_warnings(): warnings.simplefilter("ignore")`.
trainer.fit(model, train_dataloaders=train_loader)

---

In [None]:
# Scratch check: a WeightedRandomSampler without replacement yields `num_samples`
# distinct elements per pass over the loader; the loop below prints several epochs to
# check whether the draw changes between passes.
ds = torch.utils.data.TensorDataset(torch.arange(0, 5, step=1 / 9).reshape(5, 3, 3))

sampler = torch.utils.data.WeightedRandomSampler(
    torch.ones(len(ds)), num_samples=3, replacement=False
)
loader = torch.utils.data.DataLoader(ds, sampler=sampler)

print(len(ds), len(loader))

for epoch in range(3):
    print(epoch)
    for i, el in enumerate(loader):
        print(i, el)

In [None]:
class MaskFilteredPatchDataset3d(torch.utils.data.Dataset):
    """Map-style dataset of fixed 3D patches whose center voxel lies in a mask.

    All patch positions are computed once at construction. Only patches that fit
    entirely inside the image are kept. Indexing returns the patch across all channels,
    of shape ``[C x pH x pW x pD]``.
    """

    def __init__(self, img: torch.Tensor, mask: torch.Tensor, patch_size: tuple):
        """Dataset of static patches where the center voxel lies in the foreground/mask.

        Parameters
        ----------
        img : torch.Tensor
            Source image, of shape [C x H x W x D]; a 3D [H x W x D] image is treated
            as a single channel.
        mask : torch.Tensor
            Mask that corresponds to the target locations in the img, of shape
            [H x W x D]; for a 4D mask only the first channel is used.
        patch_size : int or tuple of 3 ints
            Spatial patch size; a single int is used for all three dims.

        Raises
        ------
        ValueError
            If ``patch_size`` is a sequence whose length is not 3.
        """
        super().__init__()
        self.patch_size = self._as_3tuple(patch_size)
        self.img = img

        if mask.ndim == 4:
            mask = mask[0]
        if self.img.ndim == 3:
            self.img = self.img[None, ...]

        # (3, N) voxel coordinates of every foreground voxel.
        patch_centers = torch.stack(torch.where(mask))
        # Use the channel-first image so 3D inputs are bounded by their spatial shape.
        spatial_shape = self.img.shape[1:]
        for i_dim, (patch_dim_size, img_dim_size) in enumerate(
            zip(self.patch_size, spatial_shape)
        ):
            offset_lower = patch_dim_size // 2
            # ceil(patch_dim_size / 2): the patch extends this far past its center.
            offset_upper = patch_dim_size - offset_lower
            in_bounds = (patch_centers[i_dim] >= offset_lower) & (
                patch_centers[i_dim] <= img_dim_size - offset_upper
            )
            patch_centers = patch_centers[:, in_bounds]

        half_sizes = torch.as_tensor(
            [p // 2 for p in self.patch_size], dtype=patch_centers.dtype
        )
        # (N, 3) lower corner of every kept patch.
        self.patch_starts = (patch_centers - half_sizes[:, None]).T

    @staticmethod
    def _as_3tuple(patch_size) -> tuple:
        """Broadcast a scalar to a 3-tuple, or validate and return a length-3 sequence.

        Mirrors ``monai.utils.misc.ensure_tuple_rep(patch_size, 3)``.
        """
        if isinstance(patch_size, (torch.Tensor, np.ndarray)):
            patch_size = patch_size.tolist()
        if np.isscalar(patch_size):
            return (patch_size,) * 3
        patch_size = tuple(patch_size)
        if len(patch_size) != 3:
            raise ValueError(f"Sequence must have length 3, got {len(patch_size)}.")
        return patch_size

    def __len__(self):
        """Number of valid patch positions."""
        return len(self.patch_starts)

    def __getitem__(self, index):
        """Return the patch at ``index`` across all channels."""
        start_h, start_w, start_d = self.patch_starts[index].tolist()
        size_h, size_w, size_d = self.patch_size
        return self.img[
            :,
            start_h : start_h + size_h,
            start_w : start_w + size_w,
            start_d : start_d + size_d,
        ]


class ConcatDatasetBalancedRandomSampler(torch.utils.data.Sampler):
    """Sample indices of a ``ConcatDataset`` with a capped count from each member.

    Every iteration draws, without replacement, up to the requested number of indices
    from each member dataset. It offsets them into the concatenated index space (the
    cumulative lengths, in the given order) and yields all of them in random order.
    """

    def __init__(self, datasets, max_samples_per_dataset, generator=None):
        """Sampler that draws a given number of samples from each dataset.

        Parameters
        ----------
        datasets : List[torch.utils.data.Dataset]
            Member datasets, in the same order as given to the ``ConcatDataset``.
            Each Dataset *must* be a Map-style Dataset.
        max_samples_per_dataset : int or List[int]
            Give a single integer to make sample amounts the same for all datasets, or
            a list of integers with length equal to the number of datasets to specify
            a sample number particular to each dataset. If a dataset is smaller than
            the requested number of samples, the entire length of the dataset will be
            used instead. Integral floats (e.g. ``2.0``) are accepted.
        generator : torch.Generator, optional
            RNG for every permutation; ``None`` uses torch's default generator.

        Raises
        ------
        ValueError
            If a requested count is negative or non-integral, or if a list of counts
            does not have one entry per dataset.
        """
        self.ds_lens = [len(ds) for ds in datasets]
        if np.isscalar(max_samples_per_dataset):
            requested = [self._as_count(max_samples_per_dataset)] * len(self.ds_lens)
        else:
            requested = [self._as_count(n) for n in max_samples_per_dataset]
            if len(requested) != len(self.ds_lens):
                raise ValueError(
                    "Must request sample sizes with length equal to the number of datasets"
                )
        # Make sure we don't assign more samples to a dataset than there are elements
        # in that dataset.
        self.sample_sizes = [
            min(ds_len, n_requested)
            for ds_len, n_requested in zip(self.ds_lens, requested)
        ]
        cum_lens = list(itertools.accumulate(self.ds_lens))
        # Offset of each member dataset within the concatenated index space.
        self.start_idx = [0] + cum_lens[:-1]
        self._total_samples = sum(self.sample_sizes)
        self.generator = generator

    @staticmethod
    def _as_count(n) -> int:
        """Return ``n`` as an ``int``, requiring it to be integral and non-negative."""
        if int(n) != n or n < 0:
            raise ValueError(f"Sample counts must be non-negative integers, got {n!r}")
        return int(n)

    def __iter__(self):
        samples = list()
        for sample_size, len_i, i_start in zip(
            self.sample_sizes, self.ds_lens, self.start_idx
        ):
            # Prefix of a random permutation == sampling without replacement.
            idx_i = (
                i_start
                + torch.randperm(len_i, generator=self.generator)[:sample_size]
            )
            samples.extend(idx_i.tolist())

        # Shuffle across datasets so consecutive indices mix the members.
        order = torch.randperm(len(samples), generator=self.generator).tolist()
        return (samples[i] for i in order)

    def __len__(self):
        return self._total_samples

In [None]:
# Two small integer datasets; the second is scaled by 10 so its samples are easy to spot.
first_member_ds = torch.utils.data.TensorDataset(torch.randint(0, 10, (4, 4, 4)))
second_member_ds = torch.utils.data.TensorDataset(torch.randint(0, 10, (4, 4, 4)) * 10)
cat_ds = [first_member_ds, second_member_ds]
# Optional third member of a different length:
# cat_ds.append(torch.utils.data.TensorDataset(torch.randint(0, 10, (2, 4, 4))*100))

ds = torch.utils.data.ConcatDataset(cat_ds)
# Draw 2 samples from each member per iteration.
sampler = ConcatDatasetBalancedRandomSampler(cat_ds, 2)

In [None]:
# One pass of the balanced sampler: show each drawn index and the element it selects.
for sample_idx in sampler:
    print(sample_idx)
    print(ds[sample_idx])