# CNN-based Brain Tumour Segmentation Network
## Import packages
Please make sure you have all the required packages installed. 

In [None]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose, concatenate, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.regularizers import l2
from keras.preprocessing import image
from tensorflow.keras.utils import Sequence
from sklearn.model_selection import train_test_split
import natsort
from tensorflow.keras.utils import load_img, img_to_array
import imageio
import numpy as np
import nibabel as nib
import math
import shutil
import re


## Visualise MRI Volume Slices and Segmentation Maps
Each MRI image contains information about a three-dimensional (3D) volume of space. An MRI image is composed of voxels, which are the 3D counterpart of pixels in 2D images. Here we extract the axial plane (which usually has a higher resolution) of each volume and of its corresponding segmentation map, saving every slice as a 2D PNG image so the slices can be visualised and used to train a 2D network.

In [None]:
def _scale_to_uint8(slice_data, scale_max):
    """Linearly map ``[0, scale_max]`` onto ``[0, 255]`` and return a uint8 array.

    Returns an all-zero image when ``scale_max`` is not positive (e.g. an
    empty slice or volume). Values outside the range are clipped so that
    they cannot wrap around in the uint8 cast.
    """
    if scale_max <= 0:
        return np.zeros(slice_data.shape, dtype=np.uint8)
    scaled = slice_data / scale_max * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)


def mri_3d_to_2d_z(input_main_folder, output_main_folder, per_slice=False):
    """Export every axial (z) slice of each NIfTI volume as an 8-bit PNG.

    Expected input layout: ``input_main_folder/<case>/<name>.nii[.gz]``.
    Slices are written to
    ``output_main_folder/<case>/<name>/<name>_axial_<k>.png``.

    Args:
        input_main_folder: Folder whose sub-folders contain ``.nii`` /
            ``.nii.gz`` volumes. Files directly inside it are ignored.
        output_main_folder: Destination root; created if missing.
        per_slice: If True, scale each slice by its own maximum. This was
            the original behaviour, under which the same intensity or label
            value maps to a different grey level in different slices.
            Defaults to False, which scales every slice of a volume by the
            volume-wide maximum so grey levels are consistent across slices.

    Raises:
        ValueError: If a volume is not 3-dimensional.
    """
    os.makedirs(output_main_folder, exist_ok=True)

    for subfolder in os.listdir(input_main_folder):
        subfolder_path = os.path.join(input_main_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue  # only case sub-folders are processed

        output_subfolder = os.path.join(output_main_folder, subfolder)
        os.makedirs(output_subfolder, exist_ok=True)

        for filename in os.listdir(subfolder_path):
            if not filename.endswith(('.nii', '.nii.gz')):
                continue

            # Load the 3D volume as a float array.
            img_data = nib.load(os.path.join(subfolder_path, filename)).get_fdata()

            # splitext twice so that both ".nii" and ".nii.gz" are removed,
            # e.g. "001_fla.nii.gz" -> "001_fla".
            base_filename = os.path.splitext(os.path.splitext(filename)[0])[0]

            # One folder per volume holds all of its slices.
            patient_folder = os.path.join(output_subfolder, base_filename)
            os.makedirs(patient_folder, exist_ok=True)

            _, _, z = img_data.shape  # raises ValueError unless the volume is 3D
            volume_max = img_data.max() if img_data.size else 0

            for slice_index in range(z):
                slice_data = img_data[:, :, slice_index]
                scale_max = slice_data.max() if per_slice else volume_max
                slice_filename = os.path.join(patient_folder, f'{base_filename}_axial_{slice_index}.png')
                imageio.imwrite(slice_filename, _scale_to_uint8(slice_data, scale_max))

# Change these to your own paths: the input folder holds one sub-folder per
# case with the .nii/.nii.gz volumes; the output folder receives PNG slices.
input_main_folder = "C://Users//zhangjw//Documents//WeChat Files//wxid_kmwrbkl7akml22//FileStorage//File//2025-02//tech_winter_school_2025//dataset_segmentation//train"
output_main_folder = "C://Users//zhangjw//Documents//WeChat Files//wxid_kmwrbkl7akml22//FileStorage//File//2025-02//tech_winter_school_2025//dataset_segmentation//trainA"
mri_3d_to_2d_z(
    input_main_folder=input_main_folder,
    output_main_folder=output_main_folder,
)


# Source folder of per-case PNG slice folders, and the flat destination
# folders for image slices (photo) and mask slices.
source_dir = r"C:\Users\zhangjw\Desktop\testA"
photo_dir = r"C:\Users\zhangjw\Desktop\photo1"
mask_dir = r"C:\Users\zhangjw\Desktop\mask1"

# Create both destination folders (no error if they already exist).
for _target_dir in (photo_dir, mask_dir):
    os.makedirs(_target_dir, exist_ok=True)

# Next file number to use in each destination folder (1.png, 2.png, ...).
photo_index = mask_index = 1

# 提取文件名中的最后一个数字（用于排序）
def extract_number(filename):
    """Return the last run of digits in *filename* as an int (sort key).

    Names without any digits get ``float('inf')`` so they sort last.
    """
    digit_runs = re.findall(r'\d+', filename)
    if not digit_runs:
        return float('inf')
    return int(digit_runs[-1])

def _move_pngs_sequentially(src_dir, dst_dir, next_index):
    """Move the PNGs in *src_dir* into *dst_dir* as ``<n>.png``, from *next_index* up.

    Files are taken in order of the last integer in their name (the slice
    index). Returns the next unused index.
    """
    for name in sorted(os.listdir(src_dir), key=extract_number):
        if not name.endswith(".png"):
            continue
        shutil.move(os.path.join(src_dir, name), os.path.join(dst_dir, f"{next_index}.png"))
        next_index += 1
    return next_index


# Walk the numbered case folders 210..251 (range end is exclusive). For each
# case, move its FLAIR slices into photo_dir and its segmentation slices into
# mask_dir, renumbering them consecutively; missing sub-folders are skipped.
for folder_num in range(210, 252):
    case_id = f"{folder_num:03d}"
    case_dir = os.path.join(source_dir, case_id)
    fla_path = os.path.join(case_dir, f"{case_id}_fla")
    seg_path = os.path.join(case_dir, f"{case_id}_seg")

    if os.path.exists(fla_path):
        photo_index = _move_pngs_sequentially(fla_path, photo_dir, photo_index)

    if os.path.exists(seg_path):
        mask_index = _move_pngs_sequentially(seg_path, mask_dir, mask_index)

print(f"所有 fla 图片已移动到 {photo_dir}，共 {photo_index - 1} 张，文件名按 1.png, 2.png ... 命名。")
print(f"所有 seg 图片已移动到 {mask_dir}，共 {mask_index - 1} 张，文件名按 1.png, 2.png ... 命名。")



## Data preprocessing (Optional)

Images in the original dataset often come in different sizes, so we sometimes need to resize and normalise them (z-score normalisation is commonly used for MRI images) to fit the CNN model. Depending on the images you choose for training your model, other preprocessing methods may also be useful. If a preprocessing method such as cropping is applied, remember to convert the segmentation result back to its original size.

## Train-time data augmentation
Generalizability is crucial for a deep learning model; it refers to how much the model's performance differs between seen data (training data) and unseen data (testing data). Improving the generalizability of these models has always been a difficult challenge.

**Data Augmentation** is an effective way of improving generalizability, because the augmented data represent a more comprehensive set of possible data samples and minimize the distance between the training and validation/testing sets.

There are many data augmentation methods you can choose from in this project, including rotation, shifting, flipping, etc.

You are encouraged to try different augmentation methods to get the best segmentation result.


## Get the data generator ready

In [None]:
# Folder for prediction outputs (created up front).
save_dir = r'C:\Users\21508\PycharmProjects\pythonProject8\111\predict'
os.makedirs(save_dir, exist_ok=True)

# Flat folders of training slices (images) and their masks.
path_imgs = r"C:\Users\21508\PycharmProjects\pythonProject8\111\photo"
path_masks = r"C:\Users\21508\PycharmProjects\pythonProject8\111\mask"

# Natural sort so that "10.png" follows "9.png". Images and masks are paired
# purely by their position in these two lists.
imagesList = natsort.natsorted(os.listdir(path_imgs))
maskList = natsort.natsorted(os.listdir(path_masks))

# Positional pairing requires equal counts. Raise instead of assert:
# assertions are stripped when Python runs with -O.
if len(imagesList) != len(maskList):
    raise ValueError("图片和掩码数量不匹配，请检查数据集！")

# Hyperparameters.
img_row, img_col, img_chan = 240, 240, 1
epochnum = 30  # upper bound; EarlyStopping may end training sooner
input_size = (img_row, img_col, img_chan)
batch_size = 32

# 80/10/10 train/validation/test split with a fixed seed.
# NOTE(review): if every file is a single slice, this random split puts
# slices of the same patient into train, validation and test, which makes
# validation/test scores optimistic; consider splitting by patient instead.
train_img_paths, test_img_paths, train_mask_paths, test_mask_paths = train_test_split(imagesList, maskList, test_size=0.2, random_state=42)
val_img_paths, test_img_paths, val_mask_paths, test_mask_paths = train_test_split(test_img_paths, test_mask_paths, test_size=0.5, random_state=42)

print(f"训练集大小: {len(train_img_paths)}, 验证集大小: {len(val_img_paths)}, 测试集大小: {len(test_img_paths)}")

# 数据增强
import tensorflow.keras.preprocessing.image as img_prep

# Augmentation settings shared by image and mask. Both generators are built
# once instead of on every call. The mask generator uses nearest-neighbour
# interpolation (interpolation_order=0), so rotated masks keep their original
# label values instead of being blended into intermediate values at object
# boundaries, which bilinear interpolation (the default) produces.
_AUGMENT_KWARGS = dict(rotation_range=20, horizontal_flip=True)
_image_augmenter = img_prep.ImageDataGenerator(**_AUGMENT_KWARGS)
_mask_augmenter = img_prep.ImageDataGenerator(interpolation_order=0, **_AUGMENT_KWARGS)


def random_transform(img, mask):
    """Apply the same random rotation / horizontal flip to an image and its mask.

    Args:
        img: Image array; its shape is passed to ``get_random_transform``
            (the data generator in this notebook passes (H, W, 1) arrays).
        mask: Mask array with the same shape as ``img``.

    Returns:
        Tuple ``(img, mask)`` of the transformed arrays.
    """
    # Draw one set of random parameters and apply it to both arrays so they
    # stay aligned.
    params = _image_augmenter.get_random_transform(img.shape)
    img = _image_augmenter.apply_transform(img, params)
    mask = _mask_augmenter.apply_transform(mask, params)
    return img, mask

# 数据生成器
class DataGenerator(Sequence):
    """Keras ``Sequence`` yielding (image, mask) batches of grayscale PNG slices.

    Each image and mask is loaded with ``load_img`` (resized to
    ``img_size``), scaled by 1/255 and returned in float32 arrays of shape
    ``(batch_size, *img_size, 1)``. A trailing partial batch is dropped.

    Args:
        img_paths: Image file names, paired with ``mask_paths`` by position.
        mask_paths: Mask file names.
        img_dir: Folder that ``img_paths`` are relative to.
        mask_dir: Folder that ``mask_paths`` are relative to.
        batch_size: Samples per batch.
        img_size: ``(height, width)`` target size.
        augment: Apply ``random_transform`` to every sample. Defaults to True
            (the original behaviour); disable it for validation/test data so
            monitored metrics are computed on unmodified images.
        shuffle: Shuffle the sample order at construction and after every
            epoch. Defaults to False (the original behaviour).
    """

    def __init__(self, img_paths, mask_paths, img_dir, mask_dir, batch_size, img_size,
                 augment=True, shuffle=False):
        super().__init__()
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.augment = augment
        self.shuffle = shuffle
        # Order in which samples are drawn; permuted in place when shuffling.
        self.indexes = np.arange(len(self.img_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        # Number of full batches; a trailing partial batch is dropped.
        return len(self.img_paths) // self.batch_size

    def __getitem__(self, index):
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_imgs = np.zeros((self.batch_size, *self.img_size, 1), dtype=np.float32)
        batch_masks = np.zeros_like(batch_imgs)

        for i, idx in enumerate(batch_indexes):
            img = self._load_gray(self.img_dir, self.img_paths[idx])
            mask = self._load_gray(self.mask_dir, self.mask_paths[idx])
            if self.augment:
                img, mask = random_transform(img, mask)
            batch_imgs[i], batch_masks[i] = img, mask

        return batch_imgs, batch_masks

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when ``shuffle`` is set."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _load_gray(self, folder, name):
        """Load ``folder/name`` as a grayscale (H, W, 1) array scaled by 1/255."""
        pil_img = load_img(os.path.join(folder, name), target_size=self.img_size,
                           color_mode="grayscale")
        return img_to_array(pil_img) / 255.0


# Training data is augmented and reshuffled every epoch; validation data is
# left untouched so val_loss (used by EarlyStopping/ReduceLROnPlateau) is
# not perturbed by random transforms.
train_generator = DataGenerator(train_img_paths, train_mask_paths, path_imgs, path_masks, batch_size,
                                (img_row, img_col), augment=True, shuffle=True)
val_generator = DataGenerator(val_img_paths, val_mask_paths, path_imgs, path_masks, batch_size,
                              (img_row, img_col), augment=False, shuffle=False)

## Define a metric for the performance of the model
Dice score is used here to evaluate the performance of your model.
More details about the Dice score and other metrics can be found at 
https://towardsdatascience.com/metrics-to-evaluate-your-semantic-segmentation-model-6bcb99639aa2. The Dice score can also be used as the loss function for training your model.

In [None]:
def dsc(y_true, y_pred):
    """Soft Dice similarity coefficient over all elements of the batch.

    Computes ``(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``.
    The +1 smoothing makes two empty masks score 1 and avoids division by
    zero. Both inputs are flattened, so this is one batch-wide score rather
    than a mean of per-image scores.
    """
    smooth = 1.0
    # Cast the target to the prediction's dtype so that integer/uint8 masks
    # do not cause a dtype mismatch in the multiplication below.
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice score from ``dsc`` (0 for perfect overlap)."""
    score = dsc(y_true, y_pred)
    return 1 - score

## Build your own model here
The U-Net (https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) structure is widely used for the medical image segmentation task. You can build your own model or modify the UNet by changing the hyperparameters for our task. If you choose to use Keras, more information about the Keras layers including Conv2D, MaxPooling and Dropout can be found at https://keras.io/api/layers/.

In [None]:
def ConvBlock(in_fmaps, num_fmaps):
    """Two 3x3 ReLU convolutions with batch norm and dropout in between.

    Layer order: Conv2D -> BatchNormalization -> Dropout(0.3) -> Conv2D.
    Both convolutions use 'same' padding and an L2(1e-4) kernel penalty, so
    the spatial size is unchanged and the output has ``num_fmaps`` channels.
    """
    def conv3x3():
        # A fresh layer (with its own regularizer) on every call.
        return Conv2D(num_fmaps, (3, 3), activation='relu', padding='same',
                      kernel_regularizer=l2(1e-4))

    x = conv3x3()(in_fmaps)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    return conv3x3()(x)

def Network():
    """Build the U-Net used for segmentation.

    The encoder has four ``ConvBlock`` stages (32, 32, 64, 64 filters), each
    followed by 2x2 max pooling, and leads into a 128-filter bottleneck. The
    decoder mirrors it: each stage upsamples with a 2x2, stride-2 transposed
    convolution, concatenates the matching encoder output and applies a
    ``ConvBlock``. A 1x1 sigmoid convolution produces a single-channel map.

    The input shape is the module-level ``input_size``. Its height and width
    must be divisible by 16 so the skip connections line up.
    """
    inputs = Input(shape=input_size)

    skips = []
    x = inputs
    for filters in (32, 32, 64, 64):
        x = ConvBlock(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = ConvBlock(x, 128)

    for filters, skip in zip((64, 64, 32, 32), reversed(skips)):
        upsampled = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([upsampled, skip], axis=3)
        x = ConvBlock(x, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

## Train your model here
Once you defined the model and data generator, you can start training your model.

In [None]:
# Build the U-Net, optimise the Dice loss directly and report the Dice score.
model = Network()
model.compile(optimizer=Adam(learning_rate=0.0005), loss=dice_loss, metrics=[dsc])

# Stop after 3 epochs without val_loss improvement (restoring the best
# weights); halve the learning rate after 2 such epochs, down to 1e-6.
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
lr_schedule = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6)
callbacks = [early_stop, lr_schedule]

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochnum,
    verbose=1,
    callbacks=callbacks,
)
print("训练完成 ✅")

## Save the model
Once your model is trained, remember to save it for testing.

In [None]:
# Persist the trained model in the native Keras format for the test phase.
saved_model_path = "version2.13.keras"
model.save(saved_model_path)

## Run the model on the test set
After your last Q&A session, you will be given the test set. Run your model on the test set to get the segmentation results and submit your results in a .zip file. If the MRI image is named '100_fla.nii.gz', save your segmentation result as '100_seg.nii.gz'. 

In [None]:
# load_model is not imported in the first cell, so import it here.
from keras.models import load_model

# The custom loss/metric must be supplied to deserialise the compiled model.
model = load_model("version2.13.keras", custom_objects={'dsc': dsc, 'dice_loss': dice_loss})

# Folder of test slices to segment, and folder for the predicted maps.
test_img_dir = r"C:\Users\21508\PycharmProjects\pythonProject8\photo1"
save_dir = r"C:\Users\21508\PycharmProjects\pythonProject8\test_predict"
os.makedirs(save_dir, exist_ok=True)

# Natural sort so slices are processed in numeric order.
test_images = natsort.natsorted(os.listdir(test_img_dir))

for img_name in test_images:
    # Load as grayscale, scale to [0, 1] and add a batch axis -> (1, 240, 240, 1).
    img_path = os.path.join(test_img_dir, img_name)
    img = load_img(img_path, target_size=(240,240), color_mode="grayscale")
    img = img_to_array(img) / 255.0
    img = np.expand_dims(img, axis=0)

    # verbose=0 avoids printing a progress bar for every single slice.
    pred = model.predict(img, verbose=0)
    pred = np.squeeze(pred)  # drop batch and channel axes -> (240, 240)

    # vmin/vmax pin the grey scale to the sigmoid range [0, 1]. Without them,
    # imsave stretches each image to its own min/max, so a slice whose
    # probabilities are all near 0 would be saved as bright "tumour".
    pred_img_path = os.path.join(save_dir, img_name)
    plt.imsave(pred_img_path, pred, cmap='gray', vmin=0.0, vmax=1.0)

    print(f"预测完成：{img_name}")

print("所有图片预测完成 ✅")

# 设置图像目录和保存路径
img_dir = r"C:\Users\21508\Desktop\winter_school\dataset_segmentation\trainA\001\001_seg"  # 图像文件夹
save_dir = r"C:\Users\21508\PycharmProjects\pythonProject8\001seg_nii_files"  # 3D图像保存路径
sample_nii_path = r"C:\Users\21508\PycharmProjects\pythonProject8\211_fla.nii" # 样本NIfTI文件路径

# 获取样本NIfTI文件的尺寸
sample_nii = nib.load(sample_nii_path)
sample_data = sample_nii.get_fdata()

# 获取样本的尺寸 (height, width, depth)
sample_shape = sample_data.shape
img_row, img_col, num_images_per_file = sample_shape[0], sample_shape[1], 155

# 创建保存目录
os.makedirs(save_dir, exist_ok=True)

# 获取所有图片并排序
images_list = natsort.natsorted(os.listdir(img_dir))

# 处理图像并保存为 3D NIfTI 文件
for i in range(0, len(images_list), num_images_per_file):
    # 选择当前批次的155张图像
    batch_images = images_list[i:i + num_images_per_file]

    # 加载并调整图像尺寸
    img_stack = []
    for img_name in batch_images:
        img_path = os.path.join(img_dir, img_name)
        img = load_img(img_path, target_size=(img_row, img_col), color_mode="grayscale")
        img = img_to_array(img) / 255.0  # 归一化
        img_stack.append(img)

    # 将图像堆叠成3D数组，并确保尺寸与样本相同
    img_stack = np.stack(img_stack, axis=-1)
    img_stack_resized = np.resize(img_stack, (img_row, img_col, num_images_per_file))

    # 创建NIfTI图像
    nifti_img = nib.Nifti1Image(img_stack_resized, affine=np.eye(4))  # 使用单位矩阵作为仿射矩阵

    # 保存为nii.gz格式
    nifti_filename = os.path.join(save_dir, f"{str(i // num_images_per_file + 1).zfill(3)}.nii.gz")
    nib.save(nifti_img, nifti_filename)

    print(f"保存3D图像：{nifti_filename}")

print("所有文件转换完成 ✅")

# 设置目标目录
root_dir = r"C:\Users\21508\PycharmProjects\pythonProject8\111\nii_files"  # 修改为实际路径

# 获取所有文件夹
folders = [f for f in os.listdir(root_dir)]

# 目标名称范围
start_num = 211
end_num = 251
def extract_number(filename):
    """Return the last run of digits in *filename* as an int (sort key).

    Names without any digits get ``float('inf')`` so they sort last.
    """
    digit_runs = re.findall(r'\d+', filename)
    if not digit_runs:
        return float('inf')
    return int(digit_runs[-1])
# Rename the entries to "<case>_seg.nii.gz". Case numbers start_num..end_num
# are assigned in order of the last integer in each current name (numeric,
# not alphabetical, order). Nothing is renamed if the counts do not match.
if len(folders) != (end_num - start_num + 1):
    print("文件夹数量和目标命名数量不匹配，请检查！")
else:
    for i, folder in enumerate(sorted(folders, key=extract_number)):
        # :03d zero-pads (e.g. 7 -> "007"); the previous :3d padded with
        # spaces, producing names like "  7_seg.nii.gz" for numbers < 100.
        new_name = f"{(start_num + i):03d}_seg.nii.gz"
        old_path = os.path.join(root_dir, folder)
        new_path = os.path.join(root_dir, new_name)
        os.rename(old_path, new_path)
        print(f"重命名: {folder} -> {new_name}")

    print("所有文件夹重命名完成！")

