# Imports & config

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Work from the project root: later cells use relative 'configs/...' paths
# and import the local 'src' package (presumably resolved from the working
# directory -- confirm). Set the DIGIPANCA_ROOT environment variable to point
# at another checkout; the original developer path is kept as the fallback.
os.chdir(os.environ.get('DIGIPANCA_ROOT', 'C:\\Users\\Usuario\\TFG\\digipanca\\'))

In [3]:
import json
import random

import numpy as np
import torch

from src.utils.config import load_config
from src.data.dataset2d import PancreasDataset2D
from src.data.dataset3d import PancreasDataset3D
from src.data.transforms import build_transforms_from_config
from src.data.augmentation import build_augmentations_from_config
from src.training.setup.dataset_factory import get_dataset

# __PancreasDataset2D__

In [4]:
# Read the experiment file once and take both pipeline sections from it
one_deep_config = load_config('configs/experiments/one_deep.yaml')
transforms_config = one_deep_config.get('transforms', None)
print(transforms_config)
aug_config = one_deep_config.get('augmentations', None)
print(aug_config)
# Callables handed to every get_dataset call in the cells below
transforms = build_transforms_from_config(transforms_config)
augment = build_augmentations_from_config(aug_config)

[{'ApplyWindow': {'window_level': 50, 'window_width': 400}}, {'Normalize': {}}, {'CropBorders': {'crop_size': 120}}, {'Resize': {'size': [8, 8]}}, {'ToTensor': {}}]
[{'Affine': {'scale': [0.95, 1.05], 'translate_percent': [0.02, 0.02], 'rotate': [-10, 10], 'p': 0.2}}, {'RandomBrightnessContrast': {'brightness_limit': 0.2, 'contrast_limit': 0.2, 'p': 0.3}}, {'GaussianBlur': {'blur_limit': [3, 7], 'p': 0.3}}, {'ElasticTransform': {'alpha': 1.0, 'sigma': 50, 'p': 0.3}}, {'GridDistortion': {'num_steps': 5, 'distort_limit': 0.3, 'p': 0.3}}, {'ToTensorV2': {}}]


In [5]:
# Config passed to get_dataset in the next cell to build the train and val datasets
config = load_config('configs/experiments/deep_aug_5.yaml')

In [6]:
# Training set: transforms plus augmentations
train_2d = get_dataset(config=config, split_type="train", transform=transforms, augment=augment)
# Validation set: same transforms, no augmentations
val_2d = get_dataset(config=config, split_type="val", transform=transforms)

📊 Loading dataset... 7004 slices found.
📊 Loading dataset... 1830 slices found.


## __KFold test__

In [7]:
config = load_config('configs/experiments/test_kfcv.yaml')
n_splits = config['training']['n_splits']

# Build the train and val datasets of every fold to check that each one loads
for i in range(n_splits):
    print(f"Fold {i+1}/{n_splits}")
    train_2d = get_dataset(config=config, split_type="train", fold_idx=i, transform=transforms, augment=augment)
    val_2d = get_dataset(config=config, split_type="val", fold_idx=i, transform=transforms)

Fold 1/5
📊 Loading dataset... 7004 slices found.
📊 Loading dataset... 1830 slices found.
Fold 2/5
📊 Loading dataset... 7080 slices found.
📊 Loading dataset... 1754 slices found.
Fold 3/5
📊 Loading dataset... 7076 slices found.
📊 Loading dataset... 1758 slices found.
Fold 4/5
📊 Loading dataset... 6991 slices found.
📊 Loading dataset... 1843 slices found.
Fold 5/5
📊 Loading dataset... 7185 slices found.
📊 Loading dataset... 1649 slices found.


# __Old method__

In [15]:
def get_dataset_old(
    config,
    split_type='train',
    data_folder='train',
    fold_idx=None,
    transform=None,
    augment=None
):
    """Initialize dataset based on configuration.

    The split file named by ``config['data']['split_file']`` is read as
    JSON. Without ``fold_idx`` it is indexed directly by ``split_type``;
    with ``fold_idx`` it is first indexed by the fold and then by
    ``split_type``.

    Parameters
    ----------
    config : dict
        Configuration dictionary. Reads ``config['data']['split_file']``,
        ``config['data']['processed_dir']`` and the optional
        ``config['data']['is_3d']`` / ``config['data']['load_into_memory']``
        flags (both default to False).
    split_type : str, optional
        Split type (train/val/test), by default 'train'.
    data_folder : str, optional
        Data folder name inside the processed directory, by default 'train'.
    fold_idx : int, optional
        Fold index for cross-validation, by default None.
    transform : callable, optional
        Transform function, by default None.
    augment : callable, optional
        Augmentation function, by default None. Only forwarded to the 2D
        dataset; the 3D dataset is built without it.

    Returns
    -------
    PancreasDataset2D or PancreasDataset3D
        Pancreas dataset object.

    Raises
    ------
    ValueError
        If ``split_type`` is not one of train/val/test, or if ``augment``
        is given for a split other than 'train'.
    """
    # Validate the split name first so a typo is reported as such instead of
    # being masked by the augmentation check below
    if split_type not in ('train', 'val', 'test'):
        raise ValueError(f"Invalid split type: {split_type}")
    # Check that there is not augmentation for validation/test sets
    if split_type != 'train' and augment is not None:
        raise ValueError("Augmentations are only allowed for the training set.")

    # Load the split file once; both layouts share the same read
    with open(config['data']['split_file'], 'r') as f:
        splits = json.load(f)
    # With cross-validation the file holds one split mapping per fold
    if fold_idx is not None:
        splits = splits[fold_idx]
    patient_ids = splits[split_type]

    # data_folder is the folder name in the processed directory
    # e.g. 'train' or 'test'
    data_dir = os.path.join(config['data']['processed_dir'], data_folder)
    load_into_memory = config['data'].get('load_into_memory', False)

    if config['data'].get('is_3d', False):
        return PancreasDataset3D(
            data_dir=data_dir,
            transform=transform,
            load_into_memory=load_into_memory,
            patient_ids=patient_ids
        )
    return PancreasDataset2D(
        data_dir=data_dir,
        transform=transform,
        augment=augment,
        load_into_memory=load_into_memory,
        patient_ids=patient_ids
    )

# __New method__

In [19]:
def get_dataset(
    config,
    data,
    split_type='train',
    data_folder='train',
    transform=None,
    augment=None,
    verbose=True
):
    """Initialize dataset from an already-loaded split mapping.

    NOTE: this definition shadows the ``get_dataset`` imported from
    ``src.training.setup.dataset_factory`` at the top of the notebook.
    Cells that run before this one use the imported version (no ``data``
    argument); cells that run after it use this one.

    Parameters
    ----------
    config : dict
        Configuration dictionary. Reads ``config['data']['processed_dir']``
        and the optional ``config['data']['is_3d']`` /
        ``config['data']['load_into_memory']`` flags (both default to False).
    data : dict
        Mapping from split name to the patient IDs of that split, e.g. one
        element obtained by iterating a ``SplitManager``.
    split_type : str, optional
        Split type (train/val/test), by default 'train'.
    data_folder : str, optional
        Data folder name inside the processed directory, by default 'train'.
    transform : callable, optional
        Transform function, by default None.
    augment : callable, optional
        Augmentation function, by default None. Only forwarded to the 2D
        dataset; the 3D dataset is built without it.
    verbose : bool, optional
        Print the selected patient IDs, by default True.

    Returns
    -------
    PancreasDataset2D or PancreasDataset3D
        Pancreas dataset object.

    Raises
    ------
    ValueError
        If ``split_type`` is not one of train/val/test, or if ``augment``
        is given for a split other than 'train'.
    KeyError
        If ``data`` has no entry for ``split_type``.
    """
    # Validate the split name first so a typo is reported as such instead of
    # being masked by the augmentation check below
    if split_type not in ('train', 'val', 'test'):
        raise ValueError(f"Invalid split type: {split_type}")
    # Check that there is not augmentation for validation/test sets
    if split_type != 'train' and augment is not None:
        raise ValueError("Augmentations are only allowed for the training set.")

    # Patient IDs of the requested split, taken from the caller-provided mapping
    patient_ids = data[split_type]

    if verbose:
        print(patient_ids)

    # data_folder is the folder name in the processed directory
    # e.g. 'train' or 'test'
    data_dir = os.path.join(config['data']['processed_dir'], data_folder)
    load_into_memory = config['data'].get('load_into_memory', False)

    if config['data'].get('is_3d', False):
        return PancreasDataset3D(
            data_dir=data_dir,
            transform=transform,
            load_into_memory=load_into_memory,
            patient_ids=patient_ids
        )
    return PancreasDataset2D(
        data_dir=data_dir,
        transform=transform,
        augment=augment,
        load_into_memory=load_into_memory,
        patient_ids=patient_ids
    )

## __Test new method__

In [9]:
# Dump the raw split file of each experiment. The recorded output shows the
# first file holds a single split (JSON object) and the second a list of
# per-fold splits (JSON array).
split_config_paths = (
    'configs/experiments/deep_aug_5.yaml',
    'configs/experiments/test_kfcv.yaml',
)

for cfg_idx, cfg_path in enumerate(split_config_paths):
    if cfg_idx > 0:
        # Visual separator between the two dumps
        print()
        print('='*65)
        print()

    config = load_config(cfg_path)
    with open(config['data']['split_file'], 'r') as f:
        split_data = json.load(f)
    print(type(split_data))

    # Wrap a single split so both layouts iterate as a list of folds,
    # without rebinding split_data to a different kind of object
    folds = split_data if isinstance(split_data, list) else [split_data]
    for data in folds:
        print(data)

<class 'dict'>
{'train': ['rtum6', 'rtum76', 'rtum2', 'rtum37', 'rtum46', 'rtum35', 'rtum18', 'rtum7', 'rtum14', 'rtum45', 'rtum38', 'rtum41', 'rtum50', 'rtum5', 'rtum24', 'rtum48', 'rtum40', 'rtum16', 'rtum54', 'rtum65', 'rtum27', 'rtum75', 'rtum32', 'rtum52', 'rtum21', 'rtum31', 'rtum12', 'rtum25', 'rtum44', 'rtum17', 'rtum72', 'rtum15', 'rtum60', 'rtum42', 'rtum80', 'rtum9', 'rtum59', 'rtum49', 'rtum55', 'rtum51', 'rtum88', 'rtum23', 'rtum73', 'rtum34', 'rtum47', 'rtum66', 'rtum62', 'rtum53', 'rtum8', 'rtum61', 'rtum39', 'rtum85', 'rtum63', 'rtum67', 'rtum78', 'rtum43', 'rtum36', 'rtum10', 'rtum57', 'rtum29', 'rtum11', 'rtum30', 'rtum83', 'rtum77', 'rtum84', 'rtum28', 'rtum64', 'rtum74', 'rtum22', 'rtum56'], 'val': ['rtum79', 'rtum1', 'rtum33', 'rtum3', 'rtum20', 'rtum70', 'rtum19', 'rtum26', 'rtum13', 'rtum71', 'rtum87', 'rtum69', 'rtum58', 'rtum82', 'rtum86', 'rtum68', 'rtum4', 'rtum81'], 'test': []}


<class 'list'>
{'train': ['rtum10', 'rtum11', 'rtum12', 'rtum14', 'rtum15', 'rt

In [10]:
from src.data.split_manager import SplitManager

In [11]:
# Iterate SplitManager over the single-split file, then the k-fold file
for cfg_idx, cfg_path in enumerate((
    'configs/experiments/deep_aug_5.yaml',
    'configs/experiments/test_kfcv.yaml',
)):
    if cfg_idx:
        # Visual separator between the two dumps
        print()
        print('='*65)
        print()
    config = load_config(cfg_path)
    spman = SplitManager(split_data=config['data']['split_file'])
    for data in spman:
        print(data)

{'train': ['rtum6', 'rtum76', 'rtum2', 'rtum37', 'rtum46', 'rtum35', 'rtum18', 'rtum7', 'rtum14', 'rtum45', 'rtum38', 'rtum41', 'rtum50', 'rtum5', 'rtum24', 'rtum48', 'rtum40', 'rtum16', 'rtum54', 'rtum65', 'rtum27', 'rtum75', 'rtum32', 'rtum52', 'rtum21', 'rtum31', 'rtum12', 'rtum25', 'rtum44', 'rtum17', 'rtum72', 'rtum15', 'rtum60', 'rtum42', 'rtum80', 'rtum9', 'rtum59', 'rtum49', 'rtum55', 'rtum51', 'rtum88', 'rtum23', 'rtum73', 'rtum34', 'rtum47', 'rtum66', 'rtum62', 'rtum53', 'rtum8', 'rtum61', 'rtum39', 'rtum85', 'rtum63', 'rtum67', 'rtum78', 'rtum43', 'rtum36', 'rtum10', 'rtum57', 'rtum29', 'rtum11', 'rtum30', 'rtum83', 'rtum77', 'rtum84', 'rtum28', 'rtum64', 'rtum74', 'rtum22', 'rtum56'], 'val': ['rtum79', 'rtum1', 'rtum33', 'rtum3', 'rtum20', 'rtum70', 'rtum19', 'rtum26', 'rtum13', 'rtum71', 'rtum87', 'rtum69', 'rtum58', 'rtum82', 'rtum86', 'rtum68', 'rtum4', 'rtum81'], 'test': []}


{'train': ['rtum10', 'rtum11', 'rtum12', 'rtum14', 'rtum15', 'rtum16', 'rtum17', 'rtum18', 'rt

In [20]:
config = load_config('configs/experiments/deep_aug_5.yaml')
spman = SplitManager(split_data=config['data']['split_file'])
print("TRAIN-VAL SPLIT")
# Each element yielded by the manager is passed to get_dataset as `data`
for i, split in enumerate(spman):
    print(f"Fold {i+1}/{len(spman)}")
    # Augmentations go to the training set only
    train_2d = get_dataset(config=config, data=split, split_type="train", transform=transforms, augment=augment)
    val_2d = get_dataset(config=config, data=split, split_type="val", transform=transforms)

TRAIN-VAL SPLIT
Fold 1/1
['rtum6', 'rtum76', 'rtum2', 'rtum37', 'rtum46', 'rtum35', 'rtum18', 'rtum7', 'rtum14', 'rtum45', 'rtum38', 'rtum41', 'rtum50', 'rtum5', 'rtum24', 'rtum48', 'rtum40', 'rtum16', 'rtum54', 'rtum65', 'rtum27', 'rtum75', 'rtum32', 'rtum52', 'rtum21', 'rtum31', 'rtum12', 'rtum25', 'rtum44', 'rtum17', 'rtum72', 'rtum15', 'rtum60', 'rtum42', 'rtum80', 'rtum9', 'rtum59', 'rtum49', 'rtum55', 'rtum51', 'rtum88', 'rtum23', 'rtum73', 'rtum34', 'rtum47', 'rtum66', 'rtum62', 'rtum53', 'rtum8', 'rtum61', 'rtum39', 'rtum85', 'rtum63', 'rtum67', 'rtum78', 'rtum43', 'rtum36', 'rtum10', 'rtum57', 'rtum29', 'rtum11', 'rtum30', 'rtum83', 'rtum77', 'rtum84', 'rtum28', 'rtum64', 'rtum74', 'rtum22', 'rtum56']
📊 Loading dataset... 7004 slices found.
['rtum79', 'rtum1', 'rtum33', 'rtum3', 'rtum20', 'rtum70', 'rtum19', 'rtum26', 'rtum13', 'rtum71', 'rtum87', 'rtum69', 'rtum58', 'rtum82', 'rtum86', 'rtum68', 'rtum4', 'rtum81']
📊 Loading dataset... 1830 slices found.


In [18]:
config = load_config('configs/experiments/test_kfcv.yaml')
spman = SplitManager(split_data=config['data']['split_file'])
print("5-FOLD CV SPLIT")
# Build the train and val datasets of every fold yielded by the manager
for i, split in enumerate(spman):
    print(f"Fold {i+1}/{len(spman)}")
    print(split)
    # Augmentations go to the training set only
    train_2d = get_dataset(config=config, data=split, split_type="train", transform=transforms, augment=augment)
    val_2d = get_dataset(config=config, data=split, split_type="val", transform=transforms)

5-FOLD CV SPLIT
Fold 1/5
{'train': ['rtum10', 'rtum11', 'rtum12', 'rtum14', 'rtum15', 'rtum16', 'rtum17', 'rtum18', 'rtum2', 'rtum21', 'rtum22', 'rtum23', 'rtum24', 'rtum25', 'rtum27', 'rtum28', 'rtum29', 'rtum30', 'rtum31', 'rtum32', 'rtum34', 'rtum35', 'rtum36', 'rtum37', 'rtum38', 'rtum39', 'rtum40', 'rtum41', 'rtum42', 'rtum43', 'rtum44', 'rtum45', 'rtum46', 'rtum47', 'rtum48', 'rtum49', 'rtum5', 'rtum50', 'rtum51', 'rtum52', 'rtum53', 'rtum54', 'rtum55', 'rtum56', 'rtum57', 'rtum59', 'rtum6', 'rtum60', 'rtum61', 'rtum62', 'rtum63', 'rtum64', 'rtum65', 'rtum66', 'rtum67', 'rtum7', 'rtum72', 'rtum73', 'rtum74', 'rtum75', 'rtum76', 'rtum77', 'rtum78', 'rtum8', 'rtum80', 'rtum83', 'rtum84', 'rtum85', 'rtum88', 'rtum9'], 'val': ['rtum1', 'rtum13', 'rtum19', 'rtum20', 'rtum26', 'rtum3', 'rtum33', 'rtum4', 'rtum58', 'rtum68', 'rtum69', 'rtum70', 'rtum71', 'rtum79', 'rtum81', 'rtum82', 'rtum86', 'rtum87']}
📊 Loading dataset... 7004 slices found.
📊 Loading dataset... 1830 slices found.
Fol

In [22]:
config = load_config('configs/experiments/one_deep_kf.yaml')
spman = SplitManager(split_data=config['data']['split_file'])
print("2-FOLD CV SPLIT")
# Build the train and val datasets of every fold yielded by the manager
for i, split in enumerate(spman):
    print(f"Fold {i+1}/{len(spman)}")
    print(split)
    # Augmentations go to the training set only
    train_2d = get_dataset(config=config, data=split, split_type="train", transform=transforms, augment=augment)
    val_2d = get_dataset(config=config, data=split, split_type="val", transform=transforms)

2-FOLD CV SPLIT
Fold 1/2
{'train': ['rtum10'], 'val': ['rtum1']}
['rtum10']
📊 Loading dataset... 56 slices found.
['rtum1']
📊 Loading dataset... 91 slices found.
Fold 2/2
{'train': ['rtum1'], 'val': ['rtum14']}
['rtum1']
📊 Loading dataset... 91 slices found.
['rtum14']
📊 Loading dataset... 125 slices found.
