In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, LayerNormalization, Flatten, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ConvNeXtBase
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
from datetime import datetime

# Parameters
num_classes = 6          # number of units in the softmax output layer
image_size = (224, 224)  # target_size the data generators resize images to
dropout_rate = 0.1       # Dropout rate applied just before the Dense classifier
batch_size = 32          # batch size used by both the training and validation generators
epochs = 20              # upper bound passed to model.fit; EarlyStopping may end training sooner

# ConvNeXt Model
def create_convnext_model(input_shape, num_classes, dropout=None):
    """Build a classifier made of a frozen ConvNeXt-Base backbone and a small head.

    The head is LayerNormalization -> GlobalAveragePooling2D -> Dropout ->
    Dense(softmax).

    Args:
        input_shape: Shape of one input image, e.g. ``(224, 224, 3)``.
        num_classes: Number of units in the softmax output layer.
        dropout: Dropout rate applied before the output layer. When ``None``
            (the default) the module-level ``dropout_rate`` is used, which is
            what this function always did before the parameter existed.

    Returns:
        An uncompiled ``tf.keras`` ``Model`` mapping images to class
        probabilities.

    Note:
        Per the Keras Applications documentation, ConvNeXt models include
        their own input normalization by default and expect raw pixel values
        in the [0, 255] range.
    """
    if dropout is None:
        dropout = dropout_rate

    inputs = Input(shape=input_shape)

    # ConvNeXt Backbone. It is built from the shape alone and then called on
    # `inputs` once; binding it to `inputs` via `input_tensor` as well would
    # wire the same tensor in twice.
    convnext = ConvNeXtBase(include_top=False, weights="imagenet", input_shape=input_shape)
    convnext.trainable = False  # only the head below is trained
    x = convnext(inputs)

    # Classification head
    x = LayerNormalization(epsilon=1e-6)(x)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout)(x)
    outputs = Dense(num_classes, activation="softmax")(x)

    return Model(inputs=inputs, outputs=outputs)

# Simplified Metrics Callback
class MetricsCallback(tf.keras.callbacks.Callback):
    """Console logger that prints an epoch header and one report per batch.

    Each batch report shows the batch position out of ``total_batches``, the
    current wall-clock time, and the ``accuracy`` and ``loss`` values found in
    the batch logs (0 when a key is missing).
    """

    def __init__(self, total_batches):
        """Remember how many batches make up one epoch (used for display only)."""
        super().__init__()
        self.total_batches = total_batches

    def on_epoch_begin(self, epoch, logs=None):
        """Print a 1-based ``Epoch i/N`` header; N comes from ``self.params``."""
        total_epochs = self.params['epochs']
        print(f"\nEpoch {epoch + 1}/{total_epochs}")

    def on_batch_end(self, batch, logs=None):
        """Print the batch counter, a timestamp, and the logged accuracy/loss."""
        if not logs:
            logs = {}
        stamp = datetime.now().strftime("%H:%M:%S")
        acc, loss_value = (logs.get(key, 0) for key in ('accuracy', 'loss'))
        print(f"Batch {batch+1}/{self.total_batches} ━━━━━━━━━━━━━━━━━━━━ {stamp}")
        print(f"Accuracy: {acc:.4f} - Loss: {loss_value:.4f}\n")

# Build the network, then compile it for integer-labelled classification
input_shape = (224, 224, 3)
model = create_convnext_model(input_shape, num_classes)

adam = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=adam,
    loss="sparse_categorical_crossentropy",  # labels arrive as class indices
    metrics=["accuracy"],
)

# Data Loading and Preprocessing
csv_path = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Ultrasound_Fetal\Data\FETAL_PLANES_DB_data.csv"
images_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Ultrasound_Fetal\Data\Images"
df = pd.read_csv(csv_path, delimiter=";")

# Shuffle the DataFrame for a random split
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Add .png extension to each image name
df["Image_name"] = df["Image_name"].apply(lambda x: f"{x}.png")

# Image data generator with train-validation split.
# No `rescale` here: Keras ConvNeXt models normalize their input internally
# and expect raw pixel values in [0, 255]. Dividing by 255 first would feed
# the pretrained backbone values far outside the range it was trained on.
datagen = ImageDataGenerator(validation_split=0.2)

# Arguments shared by both subsets, kept in one place so they cannot drift apart
flow_kwargs = dict(
    dataframe=df,
    directory=images_dir,
    x_col="Image_name",
    y_col="Plane",
    target_size=image_size,
    class_mode="sparse",
    batch_size=batch_size,
)

train_gen = datagen.flow_from_dataframe(subset="training", **flow_kwargs)

# shuffle=False keeps the validation batches in a fixed order, so predictions
# can later be compared position-by-position with `val_gen.classes`
val_gen = datagen.flow_from_dataframe(subset="validation", shuffle=False, **flow_kwargs)

# Stop once val_loss has not improved for 3 epochs and roll back to the best weights
early_stopping = EarlyStopping(restore_best_weights=True, patience=3, monitor='val_loss')

# Per-batch console logger; it replaces Keras' own progress bar (hence verbose=0 below)
batch_logger = MetricsCallback(total_batches=len(train_gen))

# Training with Metrics Callback
history = model.fit(
    train_gen,
    epochs=epochs,
    validation_data=val_gen,
    verbose=0,
    callbacks=[early_stopping, batch_logger],
)

# Loss and accuracy on the validation subset
val_loss, val_accuracy = model.evaluate(val_gen)
print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}")

# Ground-truth class indices and the model's most probable class per sample
val_labels = val_gen.classes
val_preds = np.argmax(model.predict(val_gen), axis=1)

# Weighted precision / recall / F1, computed with the same arguments for each scorer
precision, recall, f1 = (
    scorer(val_labels, val_preds, average='weighted')
    for scorer in (precision_score, recall_score, f1_score)
)
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

# Plot the training curves: accuracy on the left, loss on the right
import matplotlib.pyplot as plt

# (subplot index, training key, validation key, training label, validation label, y-axis label, title)
panels = (
    (1, 'accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
     'Accuracy', 'Training and Validation Accuracy'),
    (2, 'loss', 'val_loss', 'Training Loss', 'Validation Loss',
     'Loss', 'Training and Validation Loss'),
)

plt.figure(figsize=(12, 4))
for position, train_key, val_key, train_label, val_label, y_label, heading in panels:
    plt.subplot(1, 2, position)
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.title(heading)
plt.show()



Found 9920 validated image filenames belonging to 6 classes.
Found 2480 validated image filenames belonging to 6 classes.

Epoch 1/20


  self._warn_if_super_not_called()


Batch 1/310 ━━━━━━━━━━━━━━━━━━━━ 23:59:56
Accuracy: 0.4062 - Loss: 1.7511

Batch 2/310 ━━━━━━━━━━━━━━━━━━━━ 00:00:47
Accuracy: 0.3750 - Loss: 1.7549

Batch 3/310 ━━━━━━━━━━━━━━━━━━━━ 00:01:37
Accuracy: 0.3542 - Loss: 1.7538

Batch 4/310 ━━━━━━━━━━━━━━━━━━━━ 00:02:27
Accuracy: 0.3281 - Loss: 1.7598

Batch 5/310 ━━━━━━━━━━━━━━━━━━━━ 00:03:16
Accuracy: 0.3375 - Loss: 1.7876

Batch 6/310 ━━━━━━━━━━━━━━━━━━━━ 00:04:05
Accuracy: 0.3281 - Loss: 1.7785

Batch 7/310 ━━━━━━━━━━━━━━━━━━━━ 00:04:54
Accuracy: 0.3214 - Loss: 1.7853

Batch 8/310 ━━━━━━━━━━━━━━━━━━━━ 00:05:42
Accuracy: 0.3281 - Loss: 1.7841

Batch 9/310 ━━━━━━━━━━━━━━━━━━━━ 00:06:31
Accuracy: 0.3299 - Loss: 1.7733

Batch 10/310 ━━━━━━━━━━━━━━━━━━━━ 00:07:19
Accuracy: 0.3438 - Loss: 1.7547

Batch 11/310 ━━━━━━━━━━━━━━━━━━━━ 00:08:08
Accuracy: 0.3466 - Loss: 1.7449

Batch 12/310 ━━━━━━━━━━━━━━━━━━━━ 00:08:57
Accuracy: 0.3516 - Loss: 1.7370

Batch 13/310 ━━━━━━━━━━━━━━━━━━━━ 00:09:45
Accuracy: 0.3606 - Loss: 1.7253

Batch 14/310 ━━━━━━━━