In [None]:
import os
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import nibabel as nib
from scipy import ndimage
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve
from sklearn.utils.class_weight import compute_class_weight

# Define paths to your dataset
data_dir = "/kaggle/input/all-zendo-dataset/DATA"
t2wi_dir = os.path.join(data_dir, "T2WI")
csv_path = "/kaggle/input/all-zendo-dataset/DATA/all_centers_combined.csv"

# Load the CSV file
df = pd.read_csv(csv_path)
df['image_name'] = df['image_name'].str.replace('.nii.gz', '.nii')

# Helper functions for preprocessing
def read_nifti_file(filepath):
    scan = nib.load(filepath).get_fdata()
    return scan

def normalize(volume):
    min_hu, max_hu = -1000, 400
    volume = np.clip(volume, min_hu, max_hu)
    volume = (volume - min_hu) / (max_hu - min_hu)
    return volume.astype("float32")

def resize_volume(img):
    desired_depth, desired_width, desired_height = 64, 128, 128
    current_depth, current_width, current_height = img.shape[-1], img.shape[0], img.shape[1]
    depth_factor = current_depth / desired_depth
    width_factor = current_width / desired_width
    height_factor = current_height / desired_height
    img = ndimage.rotate(img, 90, reshape=False)
    img = ndimage.zoom(img,
                       (1/width_factor, 1/height_factor, 1/depth_factor),
                       order=1)
    return img

def process_scan(path):
    volume = read_nifti_file(path)
    volume = normalize(volume)
    return resize_volume(volume)

# Load and pair T2WI images with labels
nmbic_scan_paths = []
mbic_scan_paths = []
for subfolder in os.listdir(t2wi_dir):
    folder = os.path.join(t2wi_dir, subfolder)
    if not os.path.isdir(folder):
        continue
    files = [f for f in os.listdir(folder) if f.endswith(".nii")]
    if not files:
        continue
    path = os.path.join(folder, files[0])
    img_id = subfolder.split(".")[0]
    row = df[df['image_name'].str.replace('.nii', '') == img_id]
    if row.empty:
        print(f"No label found for T2WI ID {img_id}")
        continue
    label = int(row['label'].iloc[0])
    if label == 0:
        nmbic_scan_paths.append(path)
    else:
        mbic_scan_paths.append(path)

print(f"NMBIC scans: {len(nmbic_scan_paths)}")
print(f"MBIC scans: {len(mbic_scan_paths)}")

# Process scans
nmbic_scans = np.array([process_scan(p) for p in nmbic_scan_paths])
mbic_scans = np.array([process_scan(p) for p in mbic_scan_paths])

# Assign labels
nmbic_labels = np.zeros(len(nmbic_scans), dtype=int)
mbic_labels = np.ones(len(mbic_scans), dtype=int)

# Split data into training, validation, and test (60-20-20)
def split_data(X, y, train_ratio=0.6, val_ratio=0.2):
    n_total = len(X)
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)
    return X[:n_train], y[:n_train], X[n_train:n_train+n_val], y[n_train:n_train+n_val], X[n_train+n_val:], y[n_train+n_val:]

x_tr_n, y_tr_n, x_val_n, y_val_n, x_test_n, y_test_n = split_data(nmbic_scans, nmbic_labels)
x_tr_p, y_tr_p, x_val_p, y_val_p, x_test_p, y_test_p = split_data(mbic_scans, mbic_labels)

x_train = np.concatenate((x_tr_p, x_tr_n), axis=0)
y_train = np.concatenate((y_tr_p, y_tr_n), axis=0)
x_val = np.concatenate((x_val_p, x_val_n), axis=0)
y_val = np.concatenate((y_val_p, y_val_n), axis=0)
x_test = np.concatenate((x_test_p, x_test_n), axis=0)
y_test = np.concatenate((y_test_p, y_test_n), axis=0)

print(f"Number of samples in train, validation, and test are {x_train.shape[0]}, {x_val.shape[0]}, and {x_test.shape[0]}.")

# Address class imbalance via class weights
classes = np.unique(y_train)
cw = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
class_weight_dict = dict(zip(classes, cw))
print("Computed class weights:", class_weight_dict)

# Data augmentation
def rotate(volume):
    def scipy_rotate(vol):
        angles = [-20, -10, -5, 5, 10, 20]
        angle = np.random.choice(angles)
        vol = ndimage.rotate(vol, angle, reshape=False)
        vol = np.clip(vol, 0, 1)
        return vol
    return tf.numpy_function(scipy_rotate, [volume], tf.float32)

def train_preprocessing(volume, label):
    volume = rotate(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

def validation_preprocessing(volume, label):
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

# Define data loaders
batch_size = 4

# Evaluate on the validation set
metrics = model.evaluate(validation_dataset, verbose=0)
val_loss, val_accuracy, val_precision, val_recall, val_auc = metrics
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Validation Precision: {val_precision:.4f}")
print(f"Validation Recall: {val_recall:.4f}")
print(f"Validation AUC: {val_auc:.4f}")

# Evaluate on the test set
metrics = model.evaluate(test_dataset, verbose=0)
test_loss, test_accuracy, test_precision, test_recall, test_auc = metrics
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")
print(f"Test AUC: {test_auc:.4f}")

# Plot training and validation metrics
plt.figure(figsize=(12, 8))

plt.subplot(2, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy'); plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.legend()

plt.subplot(2, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss'); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend()

plt.subplot(2, 2, 3)
plt.plot(history.history['recall'], label='Training Recall')
plt.plot(history.history['val_recall'], label='Validation Recall')
plt.title('Recall'); plt.xlabel('Epoch'); plt.ylabel('Recall'); plt.legend()

plt.subplot(2, 2, 4)
plt.plot(history.history['auc'], label='Training AUC')
plt.plot(history.history['val_auc'], label='Validation AUC')
plt.title('AUC'); plt.xlabel('Epoch'); plt.ylabel('AUC'); plt.legend()

plt.tight_layout()
plt.savefig('metrics.png')
plt.show()

# Collect true labels and predictions for validation set
y_val_list = np.array([y.numpy() for _, y in validation_dataset.unbatch()]).astype(int)
y_val_pred = model.predict(validation_dataset)
y_val_pred_binary = (y_val_pred > 0.5).astype(int)

# DataFrame of validation predictions
val_df = pd.DataFrame({
    'Sample Index': np.arange(len(y_val_list)),
    'True Label': ['MBIC' if l==1 else 'NMBIC' for l in y_val_list],
    'Pred Label': ['MBIC' if p==1 else 'NMBIC' for p in y_val_pred_binary.flatten()],
    'Pred Prob': y_val_pred.flatten()
})
print(val_df)
val_df.to_csv('validation_predictions.csv', index=False)

# Classification report & confusion matrix for validation set
print("Validation Classification Report:")
print(classification_report(y_val_list, y_val_pred_binary, target_names=["NMBIC","MBIC"]))
cm = confusion_matrix(y_val_list, y_val_pred_binary)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['NMBIC','MBIC'], yticklabels=['NMBIC','MBIC'])
plt.title('Validation Confusion Matrix (threshold=0.5)')
plt.xlabel('Predicted'); plt.ylabel('Actual')
plt.show()

# Threshold optimization for validation set
precisions, recalls, thresholds = precision_recall_curve(y_val_list, y_val_pred)
opt_thr_f1 = thresholds[np.argmax(2*(precisions*recalls)/(precisions+recalls+1e-8))]
print(f"Validation Optimal threshold (max F1): {opt_thr_f1:.2f}")
y_val_pred_opt = (y_val_pred > opt_thr_f1).astype(int)
print("Validation Classification Report (Optimal Threshold):")
print(classification_report(y_val_list, y_val_pred_opt, target_names=["NMBIC","MBIC"]))
cm_opt = confusion_matrix(y_val_list, y_val_pred_opt)
sns.heatmap(cm_opt, annot=True, fmt='d', cmap='Blues',
            xticklabels=['NMBIC','MBIC'], yticklabels=['NMBIC','MBIC'])
plt.title(f'Validation Confusion Matrix (threshold={opt_thr_f1:.2f})')
plt.xlabel('Predicted'); plt.ylabel('Actual')
plt.show()

# Collect true labels and predictions for test set
y_test_list = np.array([y.numpy() for _, y in test_dataset.unbatch()]).astype(int)
y_test_pred = model.predict(test_dataset)
y_test_pred_binary = (y_test_pred > 0.5).astype(int)

# DataFrame of test predictions
test_df = pd.DataFrame({
    'Sample Index': np.arange(len(y_test_list)),
    'True Label': ['MBIC' if l==1 else 'NMBIC' for l in y_test_list],
    'Pred Label': ['MBIC' if p==1 else 'NMBIC' for p in y_test_pred_binary.flatten()],
    'Pred Prob': y_test_pred.flatten()
})
print(test_df)
test_df.to_csv('test_predictions.csv', index=False)

# Classification report & confusion matrix for test set
print("Test Classification Report:")
print(classification_report(y_test_list, y_test_pred_binary, target_names=["NMBIC","MBIC"]))
cm = confusion_matrix(y_test_list, y_test_pred_binary)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['NMBIC','MBIC'], yticklabels=['NMBIC','MBIC'])
plt.title('Test Confusion Matrix (threshold=0.5)')
plt.xlabel('Predicted'); plt.ylabel('Actual')
plt.show()

# Threshold optimization for test set
precisions, recalls, thresholds = precision_recall_curve(y_test_list, y_test_pred)
opt_thr_f1 = thresholds[np.argmax(2*(precisions*recalls)/(precisions+recalls+1e-8))]
print(f"Test Optimal threshold (max F1): {opt_thr_f1:.2f}")
y_test_pred_opt = (y_test_pred > opt_thr_f1).astype(int)
print("Test Classification Report (Optimal Threshold):")
print(classification_report(y_test_list, y_test_pred_opt, target_names=["NMBIC","MBIC"]))
cm_opt = confusion_matrix(y_test_list, y_test_pred_opt)
sns.heatmap(cm_opt, annot=True, fmt='d', cmap='Blues',
            xticklabels=['NMBIC','MBIC'], yticklabels=['NMBIC','MBIC'])
plt.title(f'Test Confusion Matrix (threshold={opt_thr_f1:.2f})')
plt.xlabel('Predicted'); plt.ylabel('Actual')
plt.show()

# Save the model
model.save('cnn_bladder_massive.keras')
print("Model saved as 'cnn_bladder_massive.keras'")
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(2)
)
validation_dataset = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .map(validation_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(2)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(validation_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(2)
)

# Define 3D CNN model
def get_model(width=128, height=128, depth=64):
    inputs = keras.Input((width, height, depth, 1))
    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(units=1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs, name="3dcnn")

# Build and compile model
model = get_model(width=128, height=128, depth=64)
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=[
        "accuracy",
        keras.metrics.Precision(name="precision"),
        keras.metrics.Recall(name="recall"),
        keras.metrics.AUC(name="auc")
    ],
    run_eagerly=True,
)
model.summary()

# Define callbacks
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_bladder_classification.keras", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="val_recall",
    mode="max",
    patience=15,
    restore_best_weights=True,
    verbose=1
)

# Train the model with class weights
epochs = 100
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    class_weight=class_weight_dict,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)


2025-05-04 22:06:11.668529: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746396371.853881      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746396371.910749      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


NMBIC scans: 140
MBIC scans: 80
Number of samples in train, validation, and test are 132, 44, and 44.
Computed class weights: {0: 0.7857142857142857, 1: 1.375}


I0000 00:00:1746396578.169612      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Epoch 1/100


I0000 00:00:1746396582.921476      31 cuda_dnn.cc:529] Loaded cuDNN version 90300


33/33 - 22s - 659ms/step - accuracy: 0.5076 - auc: 0.4781 - loss: 0.8380 - precision: 0.3559 - recall: 0.4375 - val_accuracy: 0.3636 - val_auc: 0.4330 - val_loss: 0.6954 - val_precision: 0.3636 - val_recall: 1.0000
Epoch 2/100
33/33 - 15s - 461ms/step - accuracy: 0.5530 - auc: 0.5370 - loss: 0.7684 - precision: 0.4154 - recall: 0.5625 - val_accuracy: 0.3636 - val_auc: 0.4375 - val_loss: 0.7190 - val_precision: 0.3636 - val_recall: 1.0000
Epoch 3/100
33/33 - 15s - 453ms/step - accuracy: 0.5303 - auc: 0.5707 - loss: 0.7037 - precision: 0.4205 - recall: 0.7708 - val_accuracy: 0.6364 - val_auc: 0.5000 - val_loss: 0.6872 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/100
33/33 - 15s - 458ms/step - accuracy: 0.5000 - auc: 0.5355 - loss: 0.7188 - precision: 0.3750 - recall: 0.5625 - val_accuracy: 0.3636 - val_auc: 0.4464 - val_loss: 0.7128 - val_precision: 0.3636 - val_recall: 1.0000
Epoch 5/100
33/33 - 15s - 453ms/step - accuracy: 0.5379 - auc: 0.5833 - loss: 0.6822 - precision

In [12]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import cv2
from scipy.ndimage import zoom
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import imgaug.augmenters as iaa
from tqdm import tqdm
import uuid

# Define paths to your dataset
data_dir = "/kaggle/input/all-zendo-dataset/DATA"
t2wi_dir = os.path.join(data_dir, "T2WI")
csv_path = "/kaggle/input/all-zendo-dataset/DATA/all_centers_combined.csv"

# Load the CSV file
df = pd.read_csv(csv_path)
df['image_name'] = df['image_name'].str.replace('.nii.gz', '.nii')
df['mask_name'] = df['mask_name'].str.replace('.nii.gz', '.nii')

# NEW: Print DataFrame info to verify contents
print(f"Loaded DataFrame with {len(df)} rows")
print("Sample image names:", df['image_name'].head().tolist())
print("Columns in DataFrame:", df.columns.tolist())

# Helper functions for preprocessing (inspired by preprocessing_for_classification.py)
def read_nifti_file(filepath):
    try:
        scan = nib.load(filepath).get_fdata()
        return scan
    except Exception as e:
        print(f"Error loading {filepath}: {e}")
        return None

def linear_normalizing(data):
    """Normalize image intensities to [0, 1]."""
    data_min = np.min(data)
    data_max = np.max(data)
    return (data - data_min) / (data_max - data_min + 1e-6)

def centre_window_cropping(data, reshapesize=(128, 128)):
    """Crop or pad to target size around the center."""
    or_size = data.shape
    target_size = (reshapesize[0], reshapesize[1], or_size[2])
    
    # Pad if original size is smaller than target
    if target_size[0] > or_size[0]:
        pad_size = (target_size[0] - or_size[0]) // 2
        data = np.pad(data, ((pad_size, pad_size), (0, 0), (0, 0)), mode='constant')
    if target_size[1] > or_size[1]:
        pad_size = (target_size[1] - or_size[1]) // 2
        data = np.pad(data, ((0, 0), (pad_size, pad_size), (0, 0)), mode='constant')
    
    # Center crop
    cur_size = data.shape
    centre_x = cur_size[0] // 2
    centre_y = cur_size[1] // 2
    dx = target_size[0] // 2
    dy = target_size[1] // 2
    data = data[centre_x - dx:centre_x + dx, centre_y - dy:centre_y + dy, :]
    
    # Resize each slice
    data_resize = np.zeros((reshapesize[0], reshapesize[1], cur_size[2]))
    for kk in range(cur_size[2]):
        data_resize[:, :, kk] = cv2.resize(data[:, :, kk], reshapesize, interpolation=cv2.INTER_NEAREST)
    
    return data_resize

def block_dividing(data, deep=1, step=1):
    """Divide 3D volume into 2D slices."""
    data_group = []
    o_data_deep = data.shape[2]
    
    # NEW: Log the depth of the volume
    print(f"Block dividing: Volume depth = {o_data_deep}")
    
    if o_data_deep <= deep:
        tmp_data = np.zeros((data.shape[0], data.shape[1], deep))
        tmp_data[:, :, :o_data_deep] = data
        data_group.append(tmp_data[:, :, 0])
    else:
        blocks = (o_data_deep - deep) // step + 2
        if (o_data_deep - deep) % step == 0:
            blocks -= 1
        for i in range(blocks - 1):
            tmp_data = data[:, :, i * step:i * step + deep]
            data_group.append(tmp_data[:, :, 0])
        tmp_data = data[:, :, o_data_deep - deep:o_data_deep]
        data_group.append(tmp_data[:, :, 0])
    
    # NEW: Log the number of slices generated
    print(f"Generated {len(data_group)} slices")
    return data_group

def preprocess_image(image_path, mask_path, reshapesize=(128, 128)):
    """Preprocess a single image and mask pair."""
    # NEW: Check if files exist
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        return [], []
    if not os.path.exists(mask_path):
        print(f"Mask not found: {mask_path}")
        return [], []
    
    # Load NIfTI files
    image = read_nifti_file(image_path)
    mask = read_nifti_file(mask_path)
    
    if image is None or mask is None:
        print(f"Skipping {image_path} due to loading error")
        return [], []
    
    # NEW: Check if volumes are empty
    if image.size == 0 or mask.size == 0:
        print(f"Empty volume for {image_path}")
        return [], []
    
    # Binarize mask
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    
    # Crop and resize
    image = centre_window_cropping(image, reshapesize)
    mask = centre_window_cropping(mask, reshapesize)
    
    # Normalize image
    image = linear_normalizing(image)
    
    # Divide into 2D slices
    image_slices = block_dividing(image)
    mask_slices = block_dividing(mask)
    
    # NEW: Verify that slices are generated
    if len(image_slices) == 0 or len(mask_slices) == 0:
        print(f"No slices generated for {image_path}")
        return [], []
    
    return image_slices, mask_slices

# Custom Dataset (inspired by Data_loader.py)
class SegmentationDataset(Dataset):
    def __init__(self, df, t2wi_dir, augmentation=False, reshapesize=(128, 128)):
        self.df = df
        self.t2wi_dir = t2wi_dir
        self.augmentation = augmentation
        self.reshapesize = reshapesize
        self.image_slices = []
        self.mask_slices = []
        
        # NEW: Check if DataFrame is empty
        if len(df) == 0:
            print("Warning: Input DataFrame is empty")
            return
        
        # Preprocess all images and masks
        valid_pairs = 0
        for idx, row in df.iterrows():
            image_path = os.path.join(t2wi_dir, row['image_name'])
            print("right HERE",image_path)
            mask_path = os.path.join(t2wi_dir, row['image_name'].replace('.nii', '_mask.nii'))
            print(f"Processing {image_path}")  # NEW: Log processing
            img_slices, msk_slices = preprocess_image(image_path, mask_path, reshapesize)
            if img_slices and msk_slices:
                self.image_slices.extend(img_slices)
                self.mask_slices.extend(msk_slices)
                valid_pairs += 1
            else:
                print(f"Skipped {image_path} due to invalid slices")
        
        # NEW: Log the number of valid pairs and slices
        print(f"Found {valid_pairs} valid image-mask pairs")
        print(f"Total slices: {len(self.image_slices)}")
    
    def __len__(self):
        # NEW: Raise error if dataset is empty
        if len(self.image_slices) == 0:
            raise ValueError("Dataset is empty. No valid image-mask pairs found.")
        return len(self.image_slices)
    
    def __getitem__(self, idx):
        image = self.image_slices[idx]
        mask = self.mask_slices[idx]
        
        # Convert to 3-channel image (RGB-like)
        image = np.stack([image, image, image], axis=0)  # Shape: (3, H, W)
        mask = mask[np.newaxis, :, :]  # Shape: (1, H, W)
        
        if self.augmentation:
            # Prepare for augmentation
            image = np.transpose(image, (1, 2, 0))  # Shape: (H, W, 3)
            mask = mask[0, :, :]  # Shape: (H, W)
            # Apply augmentation
            aug_det = self.seq.to_deterministic()
            image = aug_det.augment_image(image)
            mask = aug_det.augment_image(mask)
            # Convert back
            image = np.transpose(image, (2, 0, 1))  # Shape: (3, H, W)
            mask = mask[np.newaxis, :, :]  # Shape: (1, H, W)
        
        return torch.FloatTensor(image), torch.FloatTensor(mask)

# U-Net Model
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        
        self.pool = nn.MaxPool2d(2)
        
        self.bottleneck = conv_block(512, 1024)
        
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)
        
        self.final = nn.Conv2d(64, out_channels, 1)
    
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        # Bottleneck
        b = self.bottleneck(self.pool(e4))
        
        # Decoder
        d4 = self.upconv4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        
        d3 = self.upconv3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        return torch.sigmoid(self.final(d1))

# Metrics
def dice_score(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

def iou_score(pred, target, smooth=1e-6):
    pred = (pred > 0.5).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return (intersection + smooth) / (union + smooth)

# Training and Evaluation
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    train_losses = []
    val_dices = []
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, masks = images.to(device), masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * images.size(0)
        
        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)
        
        # Validation
        model.eval()
        val_dice = 0.0
        val_iou = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                val_dice += dice_score(outputs, masks).item() * images.size(0)
                val_iou += iou_score(outputs, masks).item() * images.size(0)
        
        val_dice /= len(val_loader.dataset)
        val_iou /= len(val_loader.dataset)
        val_dices.append(val_dice)
        
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}")
    
    return train_losses, val_dices

# Visualize predictions
def visualize_predictions(model, loader, device, num_samples=3):
    model.eval()
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            preds = (preds > 0.5).float()
            
            for i in range(min(num_samples, images.size(0))):
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.title("Input Image")
                plt.imshow(images[i, 0, :, :].cpu().numpy(), cmap='gray')
                plt.axis('off')
                
                plt.subplot(1, 3, 2)
                plt.title("Ground Truth")
                plt.imshow(masks[i, 0, :, :].cpu().numpy(), cmap='gray')
                plt.axis('off')
                
                plt.subplot(1, 3, 3)
                plt.title("Prediction")
                plt.imshow(preds[i, 0, :, :].cpu().numpy(), cmap='gray')
                plt.axis('off')
                
                plt.show()
            
            break

# Main execution
if __name__ == "__main__":
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Split dataset
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    
    # NEW: Print DataFrame splits
    print(f"Training set: {len(train_df)} samples")
    print(f"Validation set: {len(val_df)} samples")
    
    # Create datasets and loaders
    train_dataset = SegmentationDataset(train_df, t2wi_dir, augmentation=True)
    val_dataset = SegmentationDataset(val_df, t2wi_dir, augmentation=False)
    
    # NEW: Check dataset sizes before creating loaders
    print(f"Train dataset size: {len(train_dataset)} slices")
    print(f"Validation dataset size: {len(val_dataset)} slices")
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    
    # Initialize model, loss, and optimizer
    model = UNet(in_channels=3, out_channels=1).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Train model
    num_epochs = 20
    train_losses, val_dices = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)
    
    # Save model
    torch.save(model.state_dict(), "/kaggle/working/unet_segmentation.pth")
    
    # Visualize predictions
    visualize_predictions(model, val_loader, device)
    
    # Plot training metrics
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Training Loss")
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(val_dices, label="Validation Dice")
    plt.title("Validation Dice Score")
    plt.xlabel("Epoch")
    plt.ylabel("Dice Score")
    plt.legend()
    
    plt.tight_layout()
    plt.savefig("/kaggle/working/training_metrics.png")
    plt.show()

Loaded DataFrame with 274 rows
Sample image names: ['c1_001.nii', 'c1_002.nii', 'c1_003.nii', 'c1_004.nii', 'c1_005.nii']
Columns in DataFrame: ['label', 'image_name', 'mask_name', 'Age (years)', 'Gender', 'Pathological T stage', 'Pathological grade', "Type of patient's tumor number"]
Training set: 219 samples
Validation set: 55 samples
right HERE /kaggle/input/all-zendo-dataset/DATA/T2WI/c1_011.nii
Processing /kaggle/input/all-zendo-dataset/DATA/T2WI/c1_011.nii
Mask not found: /kaggle/input/all-zendo-dataset/DATA/T2WI/c1_011_mask.nii
Skipped /kaggle/input/all-zendo-dataset/DATA/T2WI/c1_011.nii due to invalid slices
right HERE /kaggle/input/all-zendo-dataset/DATA/T2WI/c4_13.nii
Processing /kaggle/input/all-zendo-dataset/DATA/T2WI/c4_13.nii
Mask not found: /kaggle/input/all-zendo-dataset/DATA/T2WI/c4_13_mask.nii
Skipped /kaggle/input/all-zendo-dataset/DATA/T2WI/c4_13.nii due to invalid slices
right HERE /kaggle/input/all-zendo-dataset/DATA/T2WI/c1_093.nii
Processing /kaggle/input/all-ze

ValueError: Dataset is empty. No valid image-mask pairs found.

In [6]:
# Main execution
if __name__ == "__main__":
    # Device configuration
    device = torch.device("cuda")
    
    # Split dataset
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    print(train_df)
    # Create datasets and loaders
    train_dataset = SegmentationDataset(train_df, t2wi_dir, augmentation=True)
    val_dataset = SegmentationDataset(val_df, t2wi_dir, augmentation=False)
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    
    # Initialize model, loss, and optimizer
    model = UNet(in_channels=3, out_channels=1).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Train model
    num_epochs = 20
    train_losses, val_dices = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device)
    
    # Save model
    torch.save(model.state_dict(), "/kaggle/working/unet_segmentation.pth")
    
    # Visualize predictions
    visualize_predictions(model, val_loader, device)
    
    # Plot training metrics
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Training Loss")
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(val_dices, label="Validation Dice")
    plt.title("Validation Dice Score")
    plt.xlabel("Epoch")
    plt.ylabel("Dice Score")
    plt.legend()
    
    plt.tight_layout()
    plt.savefig("/kaggle/working/training_metrics.png")
    plt.show()

     label  image_name        mask_name  Age (years)  Gender  \
10       0  c1_011.nii    c1_011.nii.gz           64    Male   
256      1   c4_13.nii     c4_13.nii.gz           80    Male   
120      0  c1_093.nii  c1_093_1.nii.gz           73  Female   
33       0  c1_032.nii    c1_032.nii.gz           80    Male   
173      0   c2_13.nii   c2_13_2.nii.gz           73    Male   
..     ...         ...              ...          ...     ...   
188      0   c2_28.nii     c2_28.nii.gz           63    Male   
71       1  c1_060.nii    c1_060.nii.gz           74    Male   
106      0  c1_081.nii    c1_081.nii.gz           71    Male   
270      1   c4_27.nii     c4_27.nii.gz           82    Male   
102      0  c1_079.nii  c1_079_1.nii.gz           70    Male   

    Pathological T stage Pathological grade Type of patient's tumor number  
10                    Ta                Low                         Single  
256                   T2                Low                         Single  


ValueError: num_samples should be a positive integer value, but got num_samples=0

<__main__.SegmentationDataset object at 0x7cb0d07cb6d0>


ValueError: num_samples should be a positive integer value, but got num_samples=0