In [None]:
import random
from pathlib import Path

import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

import utils

In [None]:
# data settings
BATCH_SIZE = 32  # number of samples per DataLoader batch
IMAGE_SIZE = [64, 64]  # image size handed to the dataset
NO_VALIDATION_PATIENTS = 2  # number of patients held out for validation

# directory with the training data: the "TrainingData" folder located in the
# parent of the current working directory (paths are handled with pathlib)
DATA_DIR = Path.cwd().parent.joinpath("TrainingData")

In [None]:
# find patient folders in the training directory
# keep only directories and skip hidden entries (name starts with "."); only the
# entry's own name is inspected, so a hidden folder somewhere higher up in the
# absolute path does not exclude every patient
# the result is sorted because the order returned by glob is filesystem dependent
patients = sorted(
    path
    for path in DATA_DIR.glob("*")
    if path.is_dir() and not path.name.startswith(".")
)

# fail early with a clear message: a validation count outside
# [0, number of patients) would leave the training partition empty
if not 0 <= NO_VALIDATION_PATIENTS < len(patients):
    raise ValueError(
        f"Found {len(patients)} patient folder(s) in {DATA_DIR}; cannot hold out "
        f"{NO_VALIDATION_PATIENTS} of them for validation."
    )

# shuffle with a dedicated, seeded generator so that the split is identical
# after every kernel restart; change SPLIT_SEED to obtain a different split
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(patients)

# split in training/validation after shuffling
# an explicit index is used because patients[:-0] would be an empty list
split_index = len(patients) - NO_VALIDATION_PATIENTS
partition = {
    "train": patients[:split_index],
    "validation": patients[split_index:],
}

In [None]:
# training set: build the dataset from the "train" partition, then wrap it in a
# DataLoader that batches, shuffles and skips a trailing incomplete batch
dataset = utils.ProstateMRDataset(partition["train"], IMAGE_SIZE)
dataloader = DataLoader(
    dataset=dataset,
    pin_memory=True,
    drop_last=True,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# load validation data
# no shuffling and no dropped last batch: every validation sample is visited
# exactly once per pass and in the same order each time, so validation results
# cover the whole partition and are comparable between passes
# (the final batch may therefore hold fewer than BATCH_SIZE samples)
valid_dataset = utils.ProstateMRDataset(partition["validation"], IMAGE_SIZE)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

In [None]:
# draw a single (input, target) batch from the training DataLoader for inspection
# NOTE(review): `input` shadows the Python builtin of the same name from here on
input, target = next(iter(dataloader))

In [None]:
# print the shape of the input batch
# presumably (batch, channel, height, width) — TODO confirm against the dataset
print(input.shape)

In [None]:
# visual sanity check: show one sample of the batch next to its target
# sample 5 is shown when the batch is large enough, otherwise the last sample
sample_index = min(5, len(input) - 1)

# explicit figure/axes instead of the implicit pyplot state machine
fig, (ax_input, ax_target) = plt.subplots(1, 2)

# index 0 along the second axis is displayed (presumably the channel axis)
ax_input.imshow(input[sample_index, 0, ...], cmap="gray")
ax_input.set_title(f"input (sample {sample_index})")
ax_input.axis("off")

ax_target.imshow(target[sample_index, 0, ...], cmap="gray")
ax_target.set_title(f"target (sample {sample_index})")
ax_target.axis("off")

plt.show()