In [1]:
import sys
import os
import tensorflow as tf

from sklearn.model_selection import train_test_split

In [2]:
# Make the project root importable so the `src.*` modules below resolve.
# '..' is resolved against the kernel's current working directory (normally the
# notebook's own folder under Jupyter). The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.generator import BraTSGenerator
from src.architectures.unet_3d import build_unet_3d
from src.training.losses import dice_loss, dice_coef

In [3]:
# Training configuration: everything a reader might want to tune lives here.
DATA_DIR = '../data/02_processed'  # one entry per patient; relative to the kernel's cwd
BATCH_SIZE = 2
EPOCHS = 100                       # upper bound; the EarlyStopping callback may stop sooner
LR = 1e-4                          # initial Adam learning rate
IMG_SIZE = (128, 128, 128)         # volume size for the generators; the model adds 4 channels
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so weight
# initialisation and other random ops are repeatable between runs.
tf.keras.utils.set_random_seed(SEED)

In [4]:
# Split at the patient level: each directory entry is one patient id.
# os.listdir returns entries in an arbitrary, filesystem-dependent order, so the
# list is sorted first. Without this, random_state alone would not make the
# split reproducible across machines. Hidden entries (e.g. .gitkeep, .DS_Store,
# .ipynb_checkpoints) are not patients and are skipped.
patient_ids = sorted(
    entry for entry in os.listdir(DATA_DIR) if not entry.startswith('.')
)
train_ids, val_ids = train_test_split(patient_ids, test_size=0.2, random_state=42)

print(f"Trening: {len(train_ids)} pacjentów")
print(f"Walidacja: {len(val_ids)} pacjentów")

Trening: 295 pacjentów
Walidacja: 74 pacjentów


In [5]:
# Training and validation batch generators share identical settings and differ
# only in the patient ids they read.
# NOTE(review): if BraTSGenerator supports shuffling/augmentation, enable it for
# the training generator only; confirm against its signature.
train_gen, val_gen = (
    BraTSGenerator(subset_ids, DATA_DIR, batch_size=BATCH_SIZE, img_size=IMG_SIZE)
    for subset_ids in (train_ids, val_ids)
)

In [7]:
# 3D U-Net over IMG_SIZE volumes with 4 input channels
# (presumably the four BraTS MRI modalities; confirm against BraTSGenerator).
model = build_unet_3d(input_shape=(*IMG_SIZE, 4), start_filters=32)
optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
# NOTE: voxel-wise 'accuracy' is typically dominated by background voxels in
# segmentation, so dice_coef is the metric to watch; accuracy is informational.
model.compile(optimizer=optimizer, loss=dice_loss, metrics=[dice_coef, 'accuracy'])

checkpoint_path = "../models/best_model.keras"
# Create the target directory up front so the first checkpoint save cannot fail
# only after a full (long) epoch has already been spent.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

callbacks = [
    # Save the full model whenever validation Dice reaches a new maximum.
    tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        verbose=1,
        save_best_only=True,
        monitor='val_dice_coef',
        mode='max',
    ),
    # Halve the learning rate after 5 epochs without val_loss improvement.
    # mode='min' is what 'auto' already infers for a loss; spelled out for clarity.
    # NOTE(review): this watches val_loss while the other callbacks watch
    # val_dice_coef; equivalent only if dice_loss == 1 - dice_coef. Confirm.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        factor=0.5,
        patience=5,
        verbose=1,
    ),
    # Stop after 15 epochs without validation Dice improvement and roll the
    # in-memory model back to its best weights.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_dice_coef',
        patience=15,
        mode='max',
        verbose=1,
        restore_best_weights=True,
    ),
]

In [8]:
# Train with the callbacks defined above. ModelCheckpoint writes the best model
# at the end of each improving epoch, so an interrupted run still keeps the last
# best checkpoint on disk. `history` is only bound if fit() returns normally.
history = model.fit(
    x=train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

  self._warn_if_super_not_called()


Epoch 1/100
[1m 17/147[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m18:44[0m 9s/step - accuracy: 0.3736 - dice_coef: 0.3136 - loss: 0.6864

KeyboardInterrupt: 