In [21]:
import wandb
from src.data import DataModule
from src.config import radiomics_folder, lesion_level_labels_csv


# Start a W&B run, then fill in its config one attribute at a time.
# NOTE(review): how each key is interpreted is decided by DataModule and the
# training code, neither of which is visible in this notebook — confirm there.
wandb.init()
wandb.config.roi_selection_method = "crop"
wandb.config.aggregation_function = "min"
wandb.config.roi_size = 150
wandb.config.optimizer = "adamw"
wandb.config.weight_decay = 0.00001
wandb.config.model = "SEResNet50"
wandb.config.dropout = 0.07292136035956572
wandb.config.momentum = 0
wandb.config.pretrained = False
wandb.config.learning_rate_max = 0.000023059510738335888
wandb.config.sampler = "stratified"
wandb.config.dim = 2
# Derived from `dim`, which must already be set: 128 when dim == 3, otherwise 256.
wandb.config.size = 128 if wandb.config.dim == 3 else 256
wandb.config.test_center = None  # "amphia"
wandb.config.lesion_target = "lesion_response"
wandb.config.patient_target = "response"
# Also derived from `dim`: 6 when dim == 3, otherwise 32.
wandb.config.max_batch_size = 6 if wandb.config.dim == 3 else 32
wandb.config.seed = 0
wandb.config.max_epochs = 100
wandb.config.patience = 10
wandb.config.lr_min = 1e-7
wandb.config.T_0 = 10

# Build the data module from the radiomics folder, the lesion-level label CSV
# and the run config populated above.
dm = DataModule(
    radiomics_folder,
    lesion_level_labels_csv,
    wandb.config,
)

2023-01-16 10:17:49,238 - Created a temporary directory at /tmp/tmpqrniwuaa
2023-01-16 10:17:49,240 - Writing /tmp/tmpqrniwuaa/_remote_module_non_scriptable.py
2023-01-16 10:17:49,597 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mrenstermaat[0m ([33mpremium[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [22]:
# Run the data module's setup step before asking it for dataloaders.
# NOTE(review): presumably this builds the train/validation split — confirm in src.data.
dm.setup()

In [23]:
dl = dm.train_dataloader()

# Walk the training loader once, recording each batch as a list of
# (patient, label) tuples.
batches = []
for batch in dl:
    batches.append([
        (patient, label)
        for patient, label in zip(batch['patient'], batch['label'].numpy().tolist())
    ])

Loading dataset: 100%|██████████| 1332/1332 [00:00<00:00, 1000503.75it/s]


In [24]:
val_dl = dm.val_dataloader()

# Walk the validation loader once, recording each batch as a list of
# (patient, label) tuples.
val_batches = []
for batch in val_dl:
    val_batches.append([
        (patient, label)
        for patient, label in zip(batch['patient'], batch['label'].numpy().tolist())
    ])

Loading dataset: 100%|██████████| 568/568 [00:00<00:00, 596485.90it/s]


In [43]:
from collections import defaultdict

def count_per_patient(batches, max_per_patient=5):
    """Check that no patient contributes more than `max_per_patient` cases.

    Args:
        batches: Iterable of batches. Each batch is an iterable of cases, and
            the first element of every case is the patient identifier.
        max_per_patient: Largest number of cases a single patient may have
            across all batches. Defaults to 5, the previously hard-coded limit.

    Raises:
        AssertionError: If any patient occurs more than `max_per_patient`
            times. The message lists the offending patients and their counts.
    """
    count = defaultdict(int)
    for batch in batches:
        for case in batch:
            count[case[0]] += 1

    # default=0 keeps the check well-defined when there are no cases at all;
    # a bare max() over an empty collection raises ValueError instead.
    highest = max(count.values(), default=0)
    assert highest <= max_per_patient, (
        f"patients exceeding {max_per_patient} cases: "
        f"{ {patient: n for patient, n in count.items() if n > max_per_patient} }"
    )

def no_overlap(train, val):
    """Assert that the training and validation batches share no patient.

    Both arguments are iterables of batches; the first element of every case
    in a batch is taken as the patient identifier.

    Raises:
        AssertionError: If at least one patient occurs in both `train` and `val`.
    """
    patients_in_train = {case[0] for batch in train for case in batch}
    patients_in_val = {case[0] for batch in val for case in batch}

    assert patients_in_train.isdisjoint(patients_in_val)

# Run both sanity checks on the batches collected in the earlier cells; each
# one raises AssertionError if its condition is violated and is silent otherwise.
count_per_patient(batches)
no_overlap(batches, val_batches)