# In [None]:
import os
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import nibabel as nib
from scipy import ndimage
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve
from sklearn.utils.class_weight import compute_class_weight

# Define paths to your dataset
data_dir = "/kaggle/input/all-zendo-dataset/DATA"
t2wi_dir = os.path.join(data_dir, "T2WI")
csv_path = os.path.join(data_dir, "all_centers_combined.csv")

# Load the CSV file
df = pd.read_csv(csv_path)
# Rewrite the '.nii.gz' suffix to '.nii'. regex=False makes the match literal
# on every pandas version; when the pattern is treated as a regex, each '.'
# matches any character.
df['image_name'] = df['image_name'].str.replace('.nii.gz', '.nii', regex=False)

# Helper functions for preprocessing
def read_nifti_file(filepath):
    """Load the image at `filepath` with nibabel and return its `get_fdata()` array."""
    image = nib.load(filepath)
    return image.get_fdata()

def normalize(volume, min_value=-1000, max_value=400):
    """Clip `volume` to an intensity window and rescale it to [0, 1].

    Args:
        volume: Array of voxel intensities.
        min_value: Lower bound of the window; values below it map to 0.
        max_value: Upper bound of the window; values above it map to 1.

    Returns:
        A float32 array of the same shape with values in [0, 1].

    Raises:
        ValueError: If `max_value` is not greater than `min_value`.
    """
    # NOTE(review): the defaults (-1000, 400) look like a CT Hounsfield-unit
    # window, while the scans in this file come from a "T2WI" folder --
    # confirm the window suits those intensities or pass explicit bounds.
    if max_value <= min_value:
        raise ValueError("max_value must be greater than min_value")
    clipped = np.clip(volume, min_value, max_value)
    scaled = (clipped - min_value) / (max_value - min_value)
    return scaled.astype("float32")

def resize_volume(img):
    """Rotate a volume by 90 degrees and resample it to 128 x 128 x 64.

    The first two axes are scaled to 128 each and the last axis to 64.
    The rotation keeps the input shape (`reshape=False`) and the resampling
    uses spline order 1.
    """
    target_w, target_h, target_d = 128, 128, 64
    src_w, src_h = img.shape[0], img.shape[1]
    src_d = img.shape[-1]

    # Zoom factor per axis, computed as the reciprocal of (current / desired).
    zoom_factors = tuple(
        1 / (src / target)
        for src, target in ((src_w, target_w), (src_h, target_h), (src_d, target_d))
    )

    rotated = ndimage.rotate(img, 90, reshape=False)
    return ndimage.zoom(rotated, zoom_factors, order=1)

def process_scan(path):
    """Read the scan at `path`, normalize its intensities, and resize it."""
    return resize_volume(normalize(read_nifti_file(path)))

# Load and pair T2WI images with labels
# Image IDs with the '.nii' suffix removed. Computed once rather than per
# folder, and with regex=False so '.' is matched literally on every pandas
# version.
image_ids = df['image_name'].str.replace('.nii', '', regex=False)

nmbic_scan_paths = []
mbic_scan_paths = []
# os.listdir returns entries in arbitrary order. Sorting keeps the scan order,
# and therefore the positional train/val/test split made later, reproducible.
for subfolder in sorted(os.listdir(t2wi_dir)):
    folder = os.path.join(t2wi_dir, subfolder)
    if not os.path.isdir(folder):
        continue
    # Sorted so that files[0] is a deterministic choice when a folder holds
    # more than one .nii file.
    files = sorted(f for f in os.listdir(folder) if f.endswith(".nii"))
    if not files:
        continue
    path = os.path.join(folder, files[0])
    # The folder name up to its first '.' is used as the image ID.
    img_id = subfolder.split(".")[0]
    row = df[image_ids == img_id]
    if row.empty:
        print(f"No label found for T2WI ID {img_id}")
        continue
    # If several CSV rows share the ID, the first one is used.
    label = int(row['label'].iloc[0])
    if label == 0:
        nmbic_scan_paths.append(path)
    else:
        mbic_scan_paths.append(path)

print(f"NMBIC scans: {len(nmbic_scan_paths)}")
print(f"MBIC scans: {len(mbic_scan_paths)}")

# Process scans: every path is loaded, normalized and resized, then the
# resulting volumes are gathered into one array per class.
nmbic_scans = np.array(list(map(process_scan, nmbic_scan_paths)))
mbic_scans = np.array(list(map(process_scan, mbic_scan_paths)))

# Assign labels: 0 for every NMBIC scan, 1 for every MBIC scan.
nmbic_labels = np.zeros(len(nmbic_scans), dtype=int)
mbic_labels = np.ones(len(mbic_scans), dtype=int)

# Split data into training, validation, and test (60-20-20)
def split_data(X, y, train_ratio=0.6, val_ratio=0.2):
    """Split `X` and `y` positionally into train, validation and test parts.

    The first `int(train_ratio * len(X))` items form the training part, the
    next `int(val_ratio * len(X))` items the validation part, and whatever
    remains the test part. No shuffling is performed.

    Returns:
        A 6-tuple `(X_train, y_train, X_val, y_val, X_test, y_test)`.
    """
    total = len(X)
    train_end = int(train_ratio * total)
    val_end = train_end + int(val_ratio * total)

    train_part = slice(None, train_end)
    val_part = slice(train_end, val_end)
    test_part = slice(val_end, None)

    return (
        X[train_part], y[train_part],
        X[val_part], y[val_part],
        X[test_part], y[test_part],
    )

# Split each class separately so every subset receives scans from both classes.
(x_tr_n, y_tr_n,
 x_val_n, y_val_n,
 x_test_n, y_test_n) = split_data(nmbic_scans, nmbic_labels)
(x_tr_p, y_tr_p,
 x_val_p, y_val_p,
 x_test_p, y_test_p) = split_data(mbic_scans, mbic_labels)

# Merge the per-class parts, MBIC first and NMBIC second.
x_train, x_val, x_test = (
    np.concatenate(pair, axis=0)
    for pair in ((x_tr_p, x_tr_n), (x_val_p, x_val_n), (x_test_p, x_test_n))
)
y_train, y_val, y_test = (
    np.concatenate(pair, axis=0)
    for pair in ((y_tr_p, y_tr_n), (y_val_p, y_val_n), (y_test_p, y_test_n))
)

print(f"Number of samples in train, validation, and test are {x_train.shape[0]}, {x_val.shape[0]}, and {x_test.shape[0]}.")

# Address class imbalance via class weights computed from the training labels
classes = np.unique(y_train)
cw = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
class_weight_dict = {cls: weight for cls, weight in zip(classes, cw)}
print("Computed class weights:", class_weight_dict)

# Data augmentation
def rotate(volume):
    """Rotate `volume` by a randomly chosen small angle (data augmentation).

    The rotation runs in NumPy/SciPy and is wrapped with
    `tf.numpy_function`; the result is returned as a float32 tensor.
    """
    candidate_angles = [-20, -10, -5, 5, 10, 20]

    def _rotate_and_clip(vol):
        degrees = np.random.choice(candidate_angles)
        turned = ndimage.rotate(vol, degrees, reshape=False)
        # Clamp the rotated values back into [0, 1].
        return np.clip(turned, 0, 1)

    return tf.numpy_function(_rotate_and_clip, [volume], tf.float32)

def train_preprocessing(volume, label):
    """Augment a training volume by rotation and append a new axis at position 3."""
    augmented = rotate(volume)
    return tf.expand_dims(augmented, axis=3), label

def validation_preprocessing(volume, label):
    """Append a new axis at position 3 to the volume; no augmentation is applied."""
    return tf.expand_dims(volume, axis=3), label

# Define data loaders
batch_size = 2

# Training pipeline: shuffle over the whole set, augment, batch, prefetch.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(len(x_train))
train_dataset = train_dataset.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.batch(batch_size).prefetch(2)


def _make_eval_dataset(volumes, labels):
    """Build an unshuffled, batched pipeline that applies `validation_preprocessing`."""
    dataset = tf.data.Dataset.from_tensor_slices((volumes, labels))
    dataset = dataset.map(validation_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(2)


validation_dataset = _make_eval_dataset(x_val, y_val)
test_dataset = _make_eval_dataset(x_test, y_test)
# Model definition and training go here; the code below expects `model` and `history` to be defined.


# Evaluate on the validation set
# NOTE(review): the unpacking assumes model.evaluate returns exactly five
# values in the order loss, accuracy, precision, recall, AUC -- confirm
# against the model's compile() call.
metrics = model.evaluate(validation_dataset, verbose=0)
(val_loss, val_accuracy, val_precision, val_recall, val_auc) = metrics
print(
    f"Validation Loss: {val_loss:.4f}",
    f"Validation Accuracy: {val_accuracy:.4f}",
    f"Validation Precision: {val_precision:.4f}",
    f"Validation Recall: {val_recall:.4f}",
    f"Validation AUC: {val_auc:.4f}",
    sep="\n",
)

# Evaluate on the test set
metrics = model.evaluate(test_dataset, verbose=0)
(test_loss, test_accuracy, test_precision, test_recall, test_auc) = metrics
print(
    f"Test Loss: {test_loss:.4f}",
    f"Test Accuracy: {test_accuracy:.4f}",
    f"Test Precision: {test_precision:.4f}",
    f"Test Recall: {test_recall:.4f}",
    f"Test AUC: {test_auc:.4f}",
    sep="\n",
)

# Plot training and validation metrics in a 2x2 grid, one panel per metric.
plt.figure(figsize=(12, 8))

# (training key, validation key, training legend, validation legend, panel name)
panels = (
    ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Loss'),
    ('recall', 'val_recall', 'Training Recall', 'Validation Recall', 'Recall'),
    ('auc', 'val_auc', 'Training AUC', 'Validation AUC', 'AUC'),
)
for position, (train_key, val_key, train_legend, val_legend, panel_name) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.plot(history.history[train_key], label=train_legend)
    plt.plot(history.history[val_key], label=val_legend)
    plt.title(panel_name)
    plt.xlabel('Epoch')
    plt.ylabel(panel_name)
    plt.legend()

plt.tight_layout()
plt.savefig('metrics.png')
plt.show()

# Class ids and display names shared by every report below. Passing explicit
# labels keeps the reports and confusion matrices 2x2 even when a split or
# the predictions contain only one class.
_CLASS_IDS = [0, 1]
_CLASS_NAMES = ["NMBIC", "MBIC"]


def _labels_in_order(dataset):
    """Return the labels of a batched (volume, label) dataset as a 1-D int array."""
    return np.array([y.numpy() for _, y in dataset.unbatch()]).astype(int)


def _predictions_frame(y_true, y_prob, y_pred):
    """Build the per-sample table of true label, predicted label and probability."""
    return pd.DataFrame({
        'Sample Index': np.arange(len(y_true)),
        'True Label': ['MBIC' if l == 1 else 'NMBIC' for l in y_true],
        'Pred Label': ['MBIC' if p == 1 else 'NMBIC' for p in np.ravel(y_pred)],
        'Pred Prob': np.ravel(y_prob),
    })


def _report(header, y_true, y_pred, plot_title):
    """Print a classification report and show the confusion-matrix heatmap.

    Returns the confusion matrix (rows: actual, columns: predicted).
    """
    y_pred = np.ravel(y_pred)
    print(header)
    print(classification_report(y_true, y_pred,
                                labels=_CLASS_IDS, target_names=_CLASS_NAMES))
    matrix = confusion_matrix(y_true, y_pred, labels=_CLASS_IDS)
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=_CLASS_NAMES, yticklabels=_CLASS_NAMES)
    plt.title(plot_title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()
    return matrix


def _max_f1_threshold(y_true, y_prob):
    """Return the decision threshold that maximises F1 on the given data."""
    precisions, recalls, thresholds = precision_recall_curve(y_true, np.ravel(y_prob))
    f1_scores = 2 * (precisions * recalls) / (precisions + recalls + 1e-8)
    # precision/recall hold one more entry than thresholds (the final
    # (1, 0) point has no threshold), so it is excluded from the search.
    return thresholds[np.argmax(f1_scores[:-1])]


# Collect true labels and predictions for validation set
# NOTE(review): assumes model.predict yields one probability per sample
# (e.g. shape (n, 1)) -- confirm against the model definition.
y_val_list = _labels_in_order(validation_dataset)
y_val_pred = model.predict(validation_dataset)
y_val_pred_binary = (y_val_pred > 0.5).astype(int)

# DataFrame of validation predictions
val_df = _predictions_frame(y_val_list, y_val_pred, y_val_pred_binary)
print(val_df)
val_df.to_csv('validation_predictions.csv', index=False)

# Classification report & confusion matrix for validation set
cm = _report("Validation Classification Report:",
             y_val_list, y_val_pred_binary,
             'Validation Confusion Matrix (threshold=0.5)')

# Threshold optimization on the validation set
opt_thr_f1 = _max_f1_threshold(y_val_list, y_val_pred)
print(f"Validation Optimal threshold (max F1): {opt_thr_f1:.2f}")
# precision_recall_curve counts scores >= threshold as positive, so the same
# comparison is used here; '>' would drop the sample that defines the optimum.
y_val_pred_opt = (y_val_pred >= opt_thr_f1).astype(int)
cm_opt = _report("Validation Classification Report (Optimal Threshold):",
                 y_val_list, y_val_pred_opt,
                 f'Validation Confusion Matrix (threshold={opt_thr_f1:.2f})')

# Collect true labels and predictions for test set
y_test_list = _labels_in_order(test_dataset)
y_test_pred = model.predict(test_dataset)
y_test_pred_binary = (y_test_pred > 0.5).astype(int)

# DataFrame of test predictions
test_df = _predictions_frame(y_test_list, y_test_pred, y_test_pred_binary)
print(test_df)
test_df.to_csv('test_predictions.csv', index=False)

# Classification report & confusion matrix for test set
cm = _report("Test Classification Report:",
             y_test_list, y_test_pred_binary,
             'Test Confusion Matrix (threshold=0.5)')

# Apply the validation-selected threshold to the test set. Tuning the
# threshold on the test labels would leak them into the reported metrics
# and make the test scores optimistic.
print(f"Test threshold (selected on validation, max F1): {opt_thr_f1:.2f}")
y_test_pred_opt = (y_test_pred >= opt_thr_f1).astype(int)
cm_opt = _report("Test Classification Report (Validation-Optimal Threshold):",
                 y_test_list, y_test_pred_opt,
                 f'Test Confusion Matrix (threshold={opt_thr_f1:.2f})')

# Save the trained model in the native Keras format and confirm on stdout.
model_path = 'cnn_bladder_massive.keras'
model.save(model_path)
print("Model saved as 'cnn_bladder_massive.keras'")