# In [None]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import (Activation, Add, BatchNormalization, Conv3D,
                                     Dense, Dropout, GlobalAveragePooling3D,
                                     Input, MaxPooling3D)
from tensorflow.keras.models import Model

# ✅ Training paths
train_dir = "/home/shrishailterniofficial/mri-data/train"
model_output_path = "/home/shrishailterniofficial/final_model_best_3.h5"

# ✅ File lists
# Image and label batch files are paired purely by their position in these
# two sorted lists (the training loop zips them), so both lists must contain
# the same number of files with matching sort order.
_train_dir_entries = os.listdir(train_dir)
train_images = sorted(f for f in _train_dir_entries if "images_batch" in f)
train_labels = sorted(f for f in _train_dir_entries if "labels_batch" in f)

# Without this guard the script would "train" on nothing and still report success.
if not train_images:
    raise FileNotFoundError(f"No 'images_batch' files found in {train_dir!r}")

# zip() stops at the shorter list, which would silently drop batches or
# mis-pair images with labels; fail loudly instead.
if len(train_images) != len(train_labels):
    raise ValueError(
        f"Found {len(train_images)} image batch files but {len(train_labels)} "
        f"label batch files in {train_dir!r}; they must pair one-to-one."
    )

# ✅ Residual Block Definition
def residual_block(x, filters, kernel_size=3, strides=1):
    """Apply a two-convolution 3D residual block to the Keras tensor ``x``.

    Main path: Conv3D(strides) -> BatchNorm -> ReLU -> Conv3D -> BatchNorm.
    Shortcut: identity, or a 1x1x1 Conv3D + BatchNorm projection whenever the
    identity would not match the main path's shape (channel count differs, or
    ``strides`` is not 1). The two paths are added and passed through ReLU.

    Args:
        x: Input Keras tensor. Assumes channels-last layout (the channel check
            reads ``shape[-1]``).
        filters: Number of output channels of the block.
        kernel_size: Kernel size of the two main-path convolutions.
        strides: Stride of the first main-path convolution (int or per-axis
            sequence). The projection shortcut uses the same stride so both
            paths end up with the same spatial size.

    Returns:
        The block's output Keras tensor with ``filters`` channels.
    """
    shortcut = x

    # Main path
    x = Conv3D(filters, kernel_size, padding='same', strides=strides)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv3D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)

    # Shortcut path. A strided main path shrinks the spatial dims, so the
    # shortcut must be downsampled identically; otherwise Add() below fails
    # with a shape mismatch even when the channel counts agree.
    strided = bool(np.any(np.asarray(strides) != 1))
    if strided or shortcut.shape[-1] != x.shape[-1]:
        shortcut = Conv3D(filters, kernel_size=1, padding='same',
                          strides=strides)(shortcut)
        shortcut = BatchNormalization()(shortcut)

    # Add and activate
    x = Add()([shortcut, x])
    x = Activation('relu')(x)

    return x

# ✅ Build model
def build_residual_3d_cnn(input_shape=(80, 128, 128, 1)):
    """Assemble the residual 3D CNN binary classifier (uncompiled).

    Layout:
      * stem: Conv3D(16, 3) -> BatchNorm -> ReLU
      * three stages of ``residual_block`` + MaxPooling3D(2) with 16, 32 and
        64 filters
      * head: GlobalAveragePooling3D -> Dense(64, relu) -> Dropout(0.4)
        -> Dense(1, sigmoid)

    Args:
        input_shape: Shape of a single input sample, excluding the batch axis.

    Returns:
        A ``Model`` mapping the input volume to one sigmoid probability.
    """
    volume = Input(shape=input_shape)

    # Stem convolution
    features = Conv3D(16, kernel_size=3, padding='same')(volume)
    features = BatchNormalization()(features)
    features = Activation('relu')(features)

    # Residual stages, each followed by 2x max-pooling over every spatial axis.
    for stage_filters in (16, 32, 64):
        features = residual_block(features, stage_filters)
        features = MaxPooling3D(pool_size=2)(features)

    # Classification head
    pooled = GlobalAveragePooling3D()(features)
    hidden = Dense(64, activation='relu')(pooled)
    hidden = Dropout(0.4)(hidden)
    probability = Dense(1, activation='sigmoid')(hidden)

    return Model(inputs=volume, outputs=probability)

# ✅ Compile model
# The network ends in a single sigmoid unit, hence binary cross-entropy.
model = build_residual_3d_cnn()
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# ✅ Callbacks
# Save the full model to model_output_path, but only when the monitored
# val_loss improves on the best value this callback has seen.
checkpoint = ModelCheckpoint(
    model_output_path,
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
)

# Stop after 2 epochs without val_loss improvement and roll back to the best
# weights. Keras counts this patience in epochs of a single fit() call.
early_stop = EarlyStopping(
    monitor='val_loss',
    verbose=1,
    patience=2,
    restore_best_weights=True,
)

# ✅ Training config
EPOCHS = 4
BATCH_SIZE = 1

# ✅ Batch-wise training loop
# Each outer "epoch" is one pass over every image/label batch file. fit() is
# called once per file with epochs=1, so only one file is held in memory at a
# time. Because every fit() call is a fresh one-epoch run, Keras'
# EarlyStopping can never reach its patience inside a call (its state resets
# on each fit). Early stopping is therefore applied here, at the outer-epoch
# level, using the mean val_loss over all files. The patience and
# restore_best_weights settings are taken from `early_stop`.
best_val_loss = np.inf
best_weights = None
epochs_without_improvement = 0

for epoch in range(EPOCHS):
    print(f"\n🚀 Training Epoch {epoch+1}/{EPOCHS}...\n")
    epoch_val_losses = []

    for img_file, lbl_file in zip(train_images, train_labels):
        X = np.load(os.path.join(train_dir, img_file))
        y = np.load(os.path.join(train_dir, lbl_file)).flatten()
        X = X.reshape(-1, 80, 128, 128, 1)
        if len(X) != len(y):
            raise ValueError(
                f"{img_file} holds {len(X)} samples but {lbl_file} holds "
                f"{len(y)} labels"
            )

        history = model.fit(
            X, y,
            batch_size=BATCH_SIZE,
            epochs=1,
            validation_split=0.1,
            # NOTE(review): the checkpoint compares val_loss computed on a
            # different file's hold-out slice each call, so its "best" is
            # noisy. Consider a fixed validation set.
            callbacks=[checkpoint],
            verbose=1
        )
        epoch_val_losses.extend(history.history.get('val_loss', []))
        del X, y, history  # Free memory

    if not epoch_val_losses:
        continue

    mean_val_loss = float(np.mean(epoch_val_losses))
    print(f"Epoch {epoch+1}: mean val_loss over files = {mean_val_loss:.4f}")

    if mean_val_loss < best_val_loss:
        best_val_loss = mean_val_loss
        epochs_without_improvement = 0
        if early_stop.restore_best_weights:
            best_weights = model.get_weights()
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stop.patience:
            print(f"Early stopping after epoch {epoch+1}: no val_loss "
                  f"improvement for {early_stop.patience} epochs.")
            # Mirror Keras: roll back to the best weights only when stopping early.
            if best_weights is not None:
                model.set_weights(best_weights)
            break

print("✅ Training complete.")