# Addressing Class Imbalance and Splitting Dataset for U-Net

In [2]:
import os
import cv2
import numpy as np
import random
from tqdm import tqdm
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import shutil


In [3]:
# Dataset paths.
# The processed U-Net dataset root is machine-specific, so it can be overridden
# with the UNET_PROCESSED_DIR environment variable; when the variable is unset
# the original location is used unchanged.
base_path = os.environ.get(
    "UNET_PROCESSED_DIR",
    r"D:\IIT\Subjects\(4606)Machine Vision\CW\Develo\DataSet\U-Net\processed",
)
image_dir = os.path.join(base_path, "images")  # one sub-folder per class
mask_dir = os.path.join(base_path, "masks")    # addressed with the same class/file names as image_dir
target_count = 1426  # Match the largest class (glioma)


In [4]:
# Random geometric transform ranges shared by the image and mask generators.
# Both generators must use identical ranges so that, given the same seed, they
# draw the same transform and the augmented mask stays aligned with its image.
_AUGMENT_PARAMS = dict(
    rotation_range=20,
    zoom_range=0.15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    fill_mode="nearest",
)

# Images: generator defaults (interpolation_order=1) are fine for intensities.
datagen = ImageDataGenerator(**_AUGMENT_PARAMS)

# Masks: nearest-neighbour resampling (order 0) so rotation/zoom/shift never
# blends neighbouring pixels into label values that were not in the mask.
mask_datagen = ImageDataGenerator(interpolation_order=0, **_AUGMENT_PARAMS)


def augment_class(class_name, current_count):
    """Top up one class with augmented image/mask pairs until it reaches ``target_count``.

    Source images are read (as grayscale) from ``image_dir/class_name`` and each
    mask from the file of the same name under ``mask_dir/class_name``. Every
    augmented pair is written next to its source as ``aug_<n>_<source name>``,
    using the same file name for the image and the mask.

    Files already named ``aug_*`` are never used as sources, so re-running after
    an interrupted run does not augment earlier augmentations, and existing
    ``aug_<n>_`` files are not overwritten.

    Args:
        class_name: Name of the class sub-folder under ``image_dir`` / ``mask_dir``.
        current_count: Number of samples the class currently has; the function
            generates ``target_count - current_count`` new pairs.

    Raises:
        ValueError: If the class folder contains no source files.
        RuntimeError: If a full pass over the folder yields no readable
            image/mask pair (which would otherwise loop forever).
    """
    img_class_path = os.path.join(image_dir, class_name)
    mask_class_path = os.path.join(mask_dir, class_name)

    needed = target_count - current_count
    if needed <= 0:
        # Class is already at (or above) the target; nothing to generate.
        return

    # Sorted for a stable order; os.listdir gives no ordering guarantee.
    sources = sorted(
        name for name in os.listdir(img_class_path) if not name.startswith("aug_")
    )
    if not sources:
        raise ValueError(f"No source images found in {img_class_path}")

    generated = 0
    index = 0            # position in the round-robin walk over `sources`
    serial = 0           # running number used in the output file name
    unreadable_run = 0   # consecutive sources that could not be loaded

    while generated < needed:
        img_name = sources[index % len(sources)]
        index += 1

        img = cv2.imread(os.path.join(img_class_path, img_name), cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(os.path.join(mask_class_path, img_name), cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None (non-image file, missing mask, ...).
        if img is None or mask is None:
            unreadable_run += 1
            if unreadable_run >= len(sources):
                raise RuntimeError(
                    f"No readable image/mask pair found for class {class_name}"
                )
            continue
        unreadable_run = 0

        # The generators expect a rank-4 batch: (1, height, width, 1).
        img = np.expand_dims(img, axis=(0, -1))
        mask = np.expand_dims(mask, axis=(0, -1))

        # Same seed for both flows -> same random transform for image and mask.
        seed = random.randint(0, 10000)
        img_aug = datagen.flow(img, batch_size=1, seed=seed)
        mask_aug = mask_datagen.flow(mask, batch_size=1, seed=seed)

        aug_img = next(img_aug)[0].astype(np.uint8)
        aug_mask = next(mask_aug)[0].astype(np.uint8)

        # Pick the next free serial so files from an earlier run are preserved.
        save_name = f"aug_{serial}_{img_name}"
        while os.path.exists(os.path.join(img_class_path, save_name)):
            serial += 1
            save_name = f"aug_{serial}_{img_name}"

        cv2.imwrite(os.path.join(img_class_path, save_name), aug_img)
        cv2.imwrite(os.path.join(mask_class_path, save_name), aug_mask)

        serial += 1
        generated += 1

    print(f"✅ Augmented {needed} samples for {class_name}")


In [5]:
# Bring every class folder that is below target_count up to the target.
for class_name in os.listdir(image_dir):
    # Current size of the class = number of entries in its image folder.
    img_count = len(os.listdir(os.path.join(image_dir, class_name)))

    # Classes already at or above the target are left untouched.
    if not img_count < target_count:
        continue

    augment_class(class_name, img_count)


✅ Augmented 718 samples for meningioma
✅ Augmented 496 samples for pituitary


In [6]:
splits = ['train', 'val', 'test']
split_ratio = {'train': 0.7, 'val': 0.2, 'test': 0.1}
split_base_dir = os.path.join(base_path, 'split')

# Fixed seed + sorted listings make the split identical on every run, so
# re-running this cell does not scatter one file across several split folders.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)


def _source_name(file_name):
    """Return the original file name behind ``aug_<n>_<name>`` (or the name itself).

    Repeated ``aug_<n>_`` prefixes are stripped one by one; a name that does not
    match the pattern is returned unchanged.
    """
    while True:
        parts = file_name.split("_", 2)
        if len(parts) != 3 or parts[0] != "aug" or not parts[1].isdigit():
            return file_name
        file_name = parts[2]


for class_name in sorted(os.listdir(image_dir)):
    class_img_dir = os.path.join(image_dir, class_name)
    if not os.path.isdir(class_img_dir):
        # Stray files next to the class folders are not classes.
        continue

    img_files = sorted(os.listdir(class_img_dir))

    # Group every file with the source image it was derived from. Splitting
    # whole groups keeps an original and its augmented copies in the same
    # split, so near-duplicates of a training image cannot leak into val/test.
    groups = {}
    for file in img_files:
        groups.setdefault(_source_name(file), []).append(file)

    group_keys = list(groups)
    split_rng.shuffle(group_keys)

    # Ratios are applied to the number of groups; 'test' takes the remainder.
    total = len(group_keys)
    train_end = int(split_ratio['train'] * total)
    val_end = train_end + int(split_ratio['val'] * total)

    split_groups = {
        'train': group_keys[:train_end],
        'val': group_keys[train_end:val_end],
        'test': group_keys[val_end:]
    }
    split_files = {
        split: [file for key in split_groups[split] for file in groups[key]]
        for split in splits
    }

    for split in splits:
        os.makedirs(os.path.join(split_base_dir, split, 'images', class_name), exist_ok=True)
        os.makedirs(os.path.join(split_base_dir, split, 'masks', class_name), exist_ok=True)

        for file in split_files[split]:
            # The mask is looked up under the same class and file name as the image.
            src_img = os.path.join(image_dir, class_name, file)
            src_mask = os.path.join(mask_dir, class_name, file)

            dst_img = os.path.join(split_base_dir, split, 'images', class_name, file)
            dst_mask = os.path.join(split_base_dir, split, 'masks', class_name, file)

            shutil.copy(src_img, dst_img)
            shutil.copy(src_mask, dst_mask)

    print(f"✅ Split completed for class: {class_name}")


✅ Split completed for class: glioma
✅ Split completed for class: meningioma
✅ Split completed for class: pituitary
