In [1]:
# Import dependencies
import os
# TensorFlow C++ logging controls. They are set here, ahead of the
# `import tensorflow` further down in this cell.
os.environ.update({
    'TF_CPP_MIN_VLOG_LEVEL': '0',
    'TF_CPP_MIN_LOG_LEVEL': '2',
})


import pandas as pd
import numpy as np
import tensorflow as tf
import logging
from tqdm import tqdm

# Import Tensorflow Keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, BinaryFocalCrossentropy
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint

# Import local modules
from src.utils.consts import TF_RECORD_DATASET, MODELS_PATH, TF_BUFFER_SIZE, NUM_CLASSES, TF_SHUFFLE_SIZE, TF_BATCH_SIZE
from src.model.tensorflow_utils import load_dataset, apply_augmentation_to_dataset, oversample_minority_classes, optimize_dataset, count_dataset_size
from src.model.tensorflow_utils import setup_logger, setup_training_logger, setup_metrics_monitor, setup_loss_monitor, setup_garbage_collector, get_metrics
from src.model.tensorflow_utils import calculate_class_weights, show_class_weights, start_or_resume_training, analyze_class_distribution
from src.model.densnet.tensorflow_dense_net_121 import build_densenet121
from src.model.loss.tensorflow_no_finding_binary_crossentropy import NoFindingBinaryCrossentropy

# Input Data
# Resume configuration for this run.
#   model_name      - sub-directory under MODELS_PATH that holds this model's artifacts
#   initial_epoch   - epoch number at which training is resumed
#   resume_training - forwarded to the metrics / loss monitors set up later
model_name      = "DenseNet121_v3_2"
initial_epoch   = 30
resume_training = True

# Checkpoint written at the end of the previous epoch. It is derived from
# MODELS_PATH / model_name / initial_epoch instead of a hard-coded absolute
# path, so the notebook is portable and the three values cannot drift apart.
# The file name follows the `cp-{epoch:04d}.keras` pattern that this notebook
# uses when saving per-epoch checkpoints.
checkpoint_path = f"{MODELS_PATH}/{model_name}/checkpoints/cp-{initial_epoch - 1:04d}.keras"

In [2]:
# Load the train / validation / test splits from their TFRecord files.
# The generator is consumed left to right, so the load order is train, val, test.
train_ds, val_ds, test_ds = (
    load_dataset(f"{TF_RECORD_DATASET}/{split_name}.tfrecord", TF_BUFFER_SIZE)
    for split_name in ("train", "val", "test")
)

I0000 00:00:1743320162.775732 25102790 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1743320162.776359 25102790 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [3]:
# Optimize Dataset for rare classes
# Class weights are computed twice: once on the loaded training set to drive
# the oversampling, and once more on the oversampled result so that the
# `class_weights` used afterwards describe the dataset that is actually trained on.
# NOTE(review): `train_ds` is overwritten here, so re-running this cell without
# re-running the loading cell oversamples an already oversampled dataset.
pre_oversampling_weights = calculate_class_weights(train_ds, NUM_CLASSES)
train_ds = oversample_minority_classes(train_ds, pre_oversampling_weights)
class_weights = calculate_class_weights(train_ds, NUM_CLASSES)

In [4]:
# Step counts are derived from the sample counts before the datasets are
# wrapped by optimize_dataset. int() truncates, so a trailing partial batch
# is not counted as a step.
train_sample_count = count_dataset_size(train_ds, None)
val_sample_count   = count_dataset_size(val_ds, None)
steps_per_epoch    = int(train_sample_count / TF_BATCH_SIZE)
validation_steps   = int(val_sample_count / TF_BATCH_SIZE)

# Training pipeline: shuffle (re-shuffled every iteration) -> augmentation -> optimize_dataset.
train_ds = optimize_dataset(
    apply_augmentation_to_dataset(
        train_ds.shuffle(TF_SHUFFLE_SIZE, reshuffle_each_iteration=True)
    ),
    TF_BATCH_SIZE,
)

# Validation pipeline: no shuffling or augmentation, only optimize_dataset.
val_ds = optimize_dataset(val_ds, TF_BATCH_SIZE)

Counting samples: 108109 samples [01:18, 1375.23 samples/s] 
Counting samples: 15391 samples [00:17, 869.87 samples/s]


In [5]:
# Setup Model Deps
# Resume settings shared by the metrics monitor and the loss monitor.
resume_kwargs = dict(resume_training=resume_training, initial_epoch=initial_epoch)

# Setup Loggers
# Kept in the original creation order; the monitors log while being set up.
logger            = setup_logger()
training_logger   = setup_training_logger(logger, TF_BATCH_SIZE, 100)
metrics_monitor   = setup_metrics_monitor(MODELS_PATH, model_name, logger, **resume_kwargs)
loss_monitor      = setup_loss_monitor(MODELS_PATH, model_name, logger, val_ds, **resume_kwargs)
garbage_collector = setup_garbage_collector(logger)
metrics           = get_metrics()

# Setup compile arguments
loss = NoFindingBinaryCrossentropy(
    10,
    with_sigmoid=True,
    lambda_value=0.1,
    from_logits=False,
    label_smoothing=0.01,
)

# Every callback below tracks the validation F1 score, where higher is better.
f1_monitor = dict(monitor="val_f1_score", mode="max")

# Halve the learning rate after 3 epochs without improvement, never going below 1e-6.
reduce_lr = ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-6, verbose=1, **f1_monitor)

# Per-epoch checkpoints: `{epoch:04d}` is left as a literal placeholder in the
# path for ModelCheckpoint to fill in; save_best_only=False keeps every epoch.
epoch_mode           = 'cp-{epoch:04d}'
save_checkpoint_path = "/".join([MODELS_PATH, model_name, "checkpoints", epoch_mode + ".keras"])
checkpoint           = ModelCheckpoint(save_checkpoint_path, save_best_only=False, **f1_monitor)

# Single "best so far" model file, overwritten only when val_f1_score improves.
model_path      = "/".join([MODELS_PATH, model_name + ".keras"])
best_checkpoint = ModelCheckpoint(model_path, save_best_only=True, **f1_monitor)

2025-03-30 09:40:22 - INFO - Resuming from existing metrics file: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/train_metrics.csv
2025-03-30 09:40:22 - INFO - Resuming from existing validation metrics file: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/val_metrics.csv
2025-03-30 09:40:22 - INFO - Cleaned training metrics file, kept 97962 records before epoch 30
2025-03-30 09:40:22 - INFO - Cleaned validation metrics file, kept 29 records before epoch 30
2025-03-30 09:40:22 - INFO - Found 97962 existing training records
2025-03-30 09:40:22 - INFO - Found 29 existing validation records
2025-03-30 09:40:22 - INFO - Cleaned loss analysis metrics file, kept 435 records before epoch 30
2025-03-30 09:40:22 - INFO - Resuming from existing loss analysis file: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/loss_analysis_metrics.csv


In [6]:
# Model Training
# Build the network first, then the optimizer and compile arguments.
model = build_densenet121(NUM_CLASSES, use_se=True)

optimizer      = Adam(learning_rate=1e-4, clipnorm=1.0)
compile_kwargs = dict(optimizer=optimizer, loss=loss, metrics=metrics)

# Callback order is preserved from the original list.
training_callbacks = [
    checkpoint,
    best_checkpoint,
    reduce_lr,
    training_logger,
    metrics_monitor,
    loss_monitor,
    garbage_collector,
]

# Fifth positional argument of start_or_resume_training.
# NOTE(review): presumably the total / final epoch number; confirm against the helper's signature.
target_epoch = 30

history, model = start_or_resume_training(
    model,
    compile_kwargs,
    train_ds,
    val_ds,
    target_epoch,
    steps_per_epoch,
    validation_steps,
    class_weights=class_weights,
    callbacks=training_callbacks,
    checkpoint_path=checkpoint_path,
    initial_epoch=initial_epoch,
    output_dir=MODELS_PATH,
    model_name=model_name,
    logger=logger,
)

2025-03-30 09:40:23 - INFO - Resuming from existing metrics file: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/train_metrics.csv
2025-03-30 09:40:23 - INFO - Resuming from existing validation metrics file: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/val_metrics.csv
2025-03-30 09:40:24 - INFO - Cleaned training metrics file, kept 97962 records before epoch 30
2025-03-30 09:40:24 - INFO - Cleaned validation metrics file, kept 29 records before epoch 30
2025-03-30 09:40:24 - INFO - Found 97962 existing training records
2025-03-30 09:40:24 - INFO - Found 29 existing validation records
2025-03-30 09:40:24 - INFO - Cleaned loss analysis metrics file, kept 435 records before epoch 30
2025-03-30 09:40:24 - INFO - Resuming from existing loss analysis file: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/loss_analysis_metrics.csv


Loading full model from checkpoint: /Users/piotr.r/Projects/codebook/studies/bachelor-thesis/models/DenseNet121_v3_2/checkpoints/cp-0029.keras


2025-03-30 09:40:26 - INFO - 
=== Training Started ===

2025-03-30 09:40:26 - INFO - Batch Size: 32
2025-03-30 09:40:26 - INFO - Optimizer: Adam
2025-03-30 09:40:26 - INFO - 

2025-03-30 09:40:26 - INFO - 
=== Starting Epoch 30 ===



Epoch 30/30
[1m3378/3378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 606ms/step - accuracy: 0.9215 - auc: 0.8290 - f1_score: 0.3686 - loss: 0.2203 - precision: 0.6698 - recall: 0.2876   

2025-03-30 10:17:19 - INFO - 
=== Epoch 30 Summary ===
2025-03-30 10:17:19 - INFO - Time: 2213.02s
2025-03-30 10:17:19 - INFO - Training   - accuracy: 0.9207 - auc: 0.8272 - f1_score: 0.3557 - loss: 0.2207 - precision: 0.6653 - recall: 0.2758 - learning_rate: 0.0001
2025-03-30 10:17:19 - INFO - Validation - accuracy: 0.9232 - auc: 0.7623 - f1_score: 0.3699 - loss: 0.3264 - precision: 0.5989 - recall: 0.3101



[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 155ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 134ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 135ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 136ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 133ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 130ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 136ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 141ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 131ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 130ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 134ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 129ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1

2025-03-30 10:18:55 - INFO - 
Loss Analysis - Atelectasis
2025-03-30 10:18:55 - INFO - Confidence Distribution:
2025-03-30 10:18:55 - INFO - -- High (>0.9): 0.06%
2025-03-30 10:18:55 - INFO - -- Medium (0.6-0.9): 1.77%
2025-03-30 10:18:55 - INFO - -- Uncertain (0.4-0.6): 3.45%
2025-03-30 10:18:55 - INFO - -- Low (<0.4): 94.72%
2025-03-30 10:18:55 - INFO - Performance:
2025-03-30 10:18:55 - INFO - -- True Positives: 202
2025-03-30 10:18:55 - INFO - -- False Positives: 305
2025-03-30 10:18:55 - INFO - -- Loss Contribution: 0.3286
2025-03-30 10:18:55 - INFO - Average Confidence:
2025-03-30 10:18:55 - INFO - -- Correct Predictions: 7.89%
2025-03-30 10:18:55 - INFO - -- Incorrect Predictions: 23.64%
2025-03-30 10:18:55 - INFO - 
Loss Analysis - Cardiomegaly
2025-03-30 10:18:55 - INFO - Confidence Distribution:
2025-03-30 10:18:55 - INFO - -- High (>0.9): 0.10%
2025-03-30 10:18:55 - INFO - -- Medium (0.6-0.9): 0.94%
2025-03-30 10:18:55 - INFO - -- Uncertain (0.4-0.6): 0.89%
2025-03-30 10:18:

[1m3378/3378[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2316s[0m 662ms/step - accuracy: 0.9215 - auc: 0.8290 - f1_score: 0.3686 - loss: 0.2203 - precision: 0.6698 - recall: 0.2876 - val_accuracy: 0.9232 - val_auc: 0.7623 - val_f1_score: 0.3699 - val_loss: 0.3264 - val_precision: 0.5989 - val_recall: 0.3101 - learning_rate: 1.0000e-04


2025-03-30 10:19:02 - INFO - 
=== Training Completed! ===

2025-03-30 10:19:02 - INFO - Final Metrics: accuracy: 0.9207 - auc: 0.8272 - f1_score: 0.3557 - loss: 0.2207 - precision: 0.6653 - recall: 0.2758 - val_accuracy: 0.9232 - val_auc: 0.7623 - val_f1_score: 0.3699 - val_loss: 0.3264 - val_precision: 0.5989 - val_recall: 0.3101

