# Preprocessing for the Augmentation Experiment

## 1. Check Class Balance in Original Image Dataset

In [1]:
import os

def check_class_balance(directory):
    """
    Count the visible entries inside every class subfolder of ``directory``.

    Every entry whose name does not start with a dot is counted, whatever
    its type, so the totals are "entries per class folder" rather than
    strictly "images per class".

    Args:
        directory (str): Path to the directory containing class subfolders.

    Returns:
        dict: Maps each subfolder name to the number of non-hidden entries
            found directly inside it.
    """
    counts_by_class = {}

    for entry in os.listdir(directory):
        subfolder = os.path.join(directory, entry)
        # Loose files sitting next to the class folders are not classes.
        if not os.path.isdir(subfolder):
            continue

        visible = 0
        for item in os.listdir(subfolder):
            if item.startswith('.'):
                continue  # hidden entry, not counted
            visible += 1
        counts_by_class[entry] = visible

    return counts_by_class


# Specify the directory to check (expected to contain one subfolder per class)
target_dir = "Data/Dataset_BUSI_with_GT"

# Check class balance: one count per class subfolder
class_balance = check_class_balance(target_dir)

# Display the results.
# NOTE: the count is the number of non-hidden entries in each class folder;
# the label printed below calls them "images".
print("Class Distribution:")
for class_name, count in class_balance.items():
    print(f"{class_name}: {count} images")



Class Distribution:
normal: 266 images
benign: 891 images
malignant: 421 images


## 2. Generating Image and Mask Augmentation for Class Balancing

In [None]:
import os
import uuid
import random
from PIL import Image
from torchvision.transforms import functional as F

# Augmentation function with synchronized transformations
def synchronized_transform(image, masks, resize=256, crop_size=224, degrees=70, p_flip=0.5):
    """
    Apply one randomly drawn geometric augmentation to an image and its masks.

    The pipeline is: resize -> rotate -> center crop -> optional horizontal
    flip -> tensor conversion. The rotation angle and the flip decision are
    each drawn once and then applied to the image and to every mask, which
    is what keeps them aligned.

    Args:
        image: Input image, passed to the torchvision functional transforms.
        masks (list): Masks belonging to ``image``; each receives exactly the
            same geometric operations as the image.
        resize: Size argument forwarded to ``F.resize``.
        crop_size: Size argument forwarded to ``F.center_crop``.
        degrees (float): The angle is drawn uniformly from
            ``[-degrees, degrees]``.
        p_flip (float): Probability of applying the horizontal flip.

    Returns:
        tuple: ``(image, masks)`` where ``image`` is the result of
            ``F.to_tensor`` on the transformed image and ``masks`` is the
            list of ``F.to_tensor`` results for the transformed masks.
    """
    # Draw both random parameters as independent samples from the module RNG.
    # Reseeding with the same seed before each draw would make
    # random.random() return the very sample random.uniform() used for the
    # angle, tying the flip decision to the sign of the rotation.
    angle = random.uniform(-degrees, degrees)
    flip = random.random() < p_flip

    # Resize image and masks
    # NOTE(review): no interpolation mode is passed, so masks are resampled
    # with the library default; if that is not nearest-neighbour the masks
    # will contain intermediate gray values -- TODO confirm.
    image = F.resize(image, resize)
    masks = [F.resize(mask, resize) for mask in masks]

    # Rotation with the shared angle; uncovered corners are filled with 0
    image = F.rotate(image, angle, fill=0)
    masks = [F.rotate(mask, angle, fill=0) for mask in masks]

    # Center crop
    image = F.center_crop(image, crop_size)
    masks = [F.center_crop(mask, crop_size) for mask in masks]

    # Horizontal flip, applied to everything or to nothing
    if flip:
        image = F.hflip(image)
        masks = [F.hflip(mask) for mask in masks]

    # Convert to tensor
    image = F.to_tensor(image)
    masks = [F.to_tensor(mask) for mask in masks]

    return image, masks

# Prepare the dataset object
def _mask_owner(stem):
    """Return the image base name a mask stem belongs to, or None if the stem is not `<base>_mask` / `<base>_mask_<n>`."""
    head, sep, index = stem.rpartition("_")
    if sep and index.isdigit() and head.endswith("_mask"):
        stem = head  # drop the numeric suffix of an additional mask
    if stem.endswith("_mask"):
        return stem[:-len("_mask")]
    return None


def prepare_dataset(data_dir, categories):
    """
    Collect every base image of the given categories together with its masks.

    A file is a base image when its name does not contain ``_mask``. A file
    is a mask of image ``<base>.<ext>`` when its name without extension is
    exactly ``<base>_mask`` or ``<base>_mask_<number>``. Matching on the
    exact name (instead of a substring test) keeps the masks of derived
    files such as ``<base>_aug_0_mask.png`` from being attached to
    ``<base>`` itself. Hidden files (leading dot) are ignored.

    Args:
        data_dir (str): Directory containing one subfolder per category.
        categories (list of str): Names of the subfolders to scan.

    Returns:
        list: One ``(image_path, mask_paths, category)`` tuple per base
            image, in directory-listing order. ``mask_paths`` is an empty
            list when the image has no mask.
    """
    dataset = []
    for category in categories:
        category_dir = os.path.join(data_dir, category)
        # List the folder once instead of once per image.
        file_names = [name for name in os.listdir(category_dir) if not name.startswith('.')]

        # Group mask paths by the image base name they belong to.
        masks_by_base = {}
        for file_name in file_names:
            if '_mask' not in file_name:
                continue
            owner = _mask_owner(os.path.splitext(file_name)[0])
            if owner is not None:
                masks_by_base.setdefault(owner, []).append(os.path.join(category_dir, file_name))

        for file_name in file_names:
            if '_mask' in file_name:
                continue  # masks are attached to their image, not listed on their own
            image_path = os.path.join(category_dir, file_name)
            base_name = os.path.splitext(file_name)[0]
            dataset.append((image_path, masks_by_base.get(base_name, []), category))
    return dataset

# Augmentation target counts
augmentation_target = dict(
    malignant=25,  # extra augmented samples to generate for 'malignant'
    normal=179,    # extra augmented samples to generate for 'normal'
)

def augment_class(dataset, class_name, additional_samples, output_dir):
    """
    Generate augmented copies of one class and write them to disk.

    The samples of ``class_name`` are visited round-robin until
    ``additional_samples`` augmented images have been written. Each image is
    saved as ``<base>_aug_<uuid>.png`` and its masks as
    ``<base>_mask_aug_<uuid>_<i>.png`` inside ``output_dir/class_name``.

    Args:
        dataset (list): ``(image_path, mask_paths, category)`` tuples.
        class_name (str): Category to augment.
        additional_samples (int): Number of augmented images to create.
        output_dir (str): Root output directory; a subfolder named after the
            class is created in it.

    Returns:
        None

    Raises:
        ValueError: If samples are requested but ``dataset`` contains none
            for ``class_name`` (cycling over an empty list would never
            terminate).
    """
    from itertools import cycle, islice

    if additional_samples <= 0:
        return  # nothing requested; do not touch the filesystem

    class_samples = [data for data in dataset if data[2] == class_name]
    if not class_samples:
        raise ValueError(
            f"No samples of class '{class_name}' in the dataset; "
            f"cannot create {additional_samples} augmented samples."
        )

    # The output folder is the same for every sample, so create it once.
    output_class_dir = os.path.join(output_dir, class_name)
    os.makedirs(output_class_dir, exist_ok=True)

    for img_path, mask_paths, _ in islice(cycle(class_samples), additional_samples):
        # Load the image and masks; convert() returns an independent copy,
        # so the underlying files can be closed right away.
        with Image.open(img_path) as source:
            image = source.convert("RGB")
        masks = []
        for mask_path in mask_paths:
            with Image.open(mask_path) as mask_source:
                masks.append(mask_source.convert("L"))

        # Apply synchronized transformations to image and masks
        augmented_image, augmented_masks = synchronized_transform(image, masks)

        # One id per augmented sample ties the image to its masks on disk.
        unique_id = str(uuid.uuid4())
        base_name = os.path.splitext(os.path.basename(img_path))[0]

        # Save augmented image
        img_save_path = os.path.join(output_class_dir, f"{base_name}_aug_{unique_id}.png")
        F.to_pil_image(augmented_image).save(img_save_path)

        # Save augmented masks
        for i, mask in enumerate(augmented_masks):
            mask_save_path = os.path.join(output_class_dir, f"{base_name}_mask_aug_{unique_id}_{i}.png")
            F.to_pil_image(mask).save(mask_save_path)

# Directories: where the source data is read from and where augmented files are written
original_dir = "Data/Dataset_BUSI_with_GT"
augmented_dir = "Augmented_data"
categories = ['normal', 'benign', 'malignant']

# Prepare the dataset: (image_path, mask_paths, category) tuples for all categories
dataset = prepare_dataset(original_dir, categories)

# Augment the minority classes.
# Only the classes listed in augmentation_target are processed; each gets the
# requested number of additional samples written under augmented_dir.
for class_name, additional_samples in augmentation_target.items():
    print(f"Augmenting class '{class_name}' to add {additional_samples} samples...")
    augment_class(dataset, class_name, additional_samples, augmented_dir)


## 3. Check Class Balance in Augmented Image Dataset

In [None]:
import os

def check_class_balance(directory):
    """
    Report how many non-hidden entries each class subfolder holds.

    Entries are counted regardless of their type; only names starting with a
    dot are left out.

    Args:
        directory (str): Path to the directory containing class subfolders.

    Returns:
        dict: Subfolder name -> number of non-hidden entries directly inside
            that subfolder. Non-directory entries of ``directory`` are
            skipped.
    """
    def count_visible(folder):
        # Hidden entries (leading dot) do not contribute to the total.
        return sum(1 for name in os.listdir(folder) if not name.startswith('.'))

    tally = {}
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        if os.path.isdir(path):
            tally[name] = count_visible(path)
    return tally


# Specify the directory to check (the output folder of the augmentation step)
target_dir = "Augmented_data"

# Check class balance: one count per class subfolder
class_balance = check_class_balance(target_dir)

# Display the results.
# NOTE: the count is the number of non-hidden entries in each class folder;
# the label printed below calls them "images".
print("Class Distribution:")
for class_name, count in class_balance.items():
    print(f"{class_name}: {count} images")



## 4. Renaming Augmented Images and Masks

In [None]:
import os

def rename_augmented_files(dataset_dir, categories):
    """
    Give augmented images and their masks short sequential names.

    Inside each category folder an augmented image ``<base>_aug_<id>.png``
    becomes ``<base>_aug_<n>.png``, where ``n`` counts up from 0 per base
    name in sorted file order. Its masks ``<base>_mask_aug_<id>_<i>.png``
    become ``<base>_aug_<n>_mask.png`` for ``i == 0`` and
    ``<base>_aug_<n>_mask_<i>.png`` for every further mask, so no mask is
    left behind under its old name.

    Files whose ``<id>`` is already a plain number are treated as renamed by
    an earlier run: they are left untouched and their number is reserved, so
    running the function again never overwrites an existing file.

    Args:
        dataset_dir (str): Directory containing the category subfolders.
        categories (list of str): Subfolder names to process; missing
            folders are reported and skipped.

    Returns:
        None
    """
    for category in categories:
        category_path = os.path.join(dataset_dir, category)
        if not os.path.exists(category_path):
            print(f"Warning: {category_path} does not exist. Skipping.")
            continue

        # Snapshot of the folder; sorting makes the numbering deterministic.
        all_files = sorted(os.listdir(category_path))

        print(f"Processing category: {category}")

        # First pass: separate files that still need a new name from those
        # that already carry a sequential number.
        used_indices = {}
        pending = []
        for file_name in all_files:
            if "_aug_" not in file_name or "mask" in file_name:
                continue  # not an augmented image
            base_name = file_name.split("_aug_")[0]
            unique_id = file_name.split("_aug_")[1].split(".png")[0]
            if unique_id.isdigit():
                used_indices.setdefault(base_name, set()).add(int(unique_id))
            else:
                pending.append((file_name, base_name, unique_id))

        # Second pass: rename each pending image together with all its masks.
        for file_name, base_name, unique_id in pending:
            taken = used_indices.setdefault(base_name, set())
            index = 0
            while index in taken:
                index += 1
            taken.add(index)

            new_image_name = f"{base_name}_aug_{index}.png"
            os.rename(os.path.join(category_path, file_name),
                      os.path.join(category_path, new_image_name))
            print(f"Renamed {file_name} to {new_image_name}")

            mask_prefix = f"{base_name}_mask_aug_{unique_id}_"
            mask_names = [name for name in all_files
                          if name.startswith(mask_prefix) and name.endswith(".png")]
            if not mask_names:
                print(f"Warning: No matching mask found for {file_name}")
                continue

            for mask_name in mask_names:
                mask_index = mask_name[len(mask_prefix):-len(".png")]
                if mask_index == "0":
                    new_mask_name = f"{base_name}_aug_{index}_mask.png"
                else:
                    new_mask_name = f"{base_name}_aug_{index}_mask_{mask_index}.png"
                os.rename(os.path.join(category_path, mask_name),
                          os.path.join(category_path, new_mask_name))
                print(f"Renamed {mask_name} to {new_mask_name}")


# Define the path to the dataset directory (containing 'normal', 'benign', 'malignant')
dataset_dir = "Augmented_data"

# Categories to process
categories = ['normal', 'benign', 'malignant']

# Process each category
rename_augmented_files(dataset_dir, categories)


# Preprocessing for the Intersection Experiment

## 1. Processing Images with Intersection Masks

In [None]:
import os
import cv2
import torch
import numpy as np
from torchvision.transforms import ToTensor, Resize
from PIL import Image

def process_image_and_masks(image_path, mask_paths, output_dir, target_size=(224, 224)):
    """
    Keep only the image region covered by all masks and save the result.

    The image and every mask are resized to ``target_size``. A pixel belongs
    to the intersection when it is non-zero in every mask; all other pixels
    of the image are set to zero. The result is written to
    ``output_dir/<image base name>_processed.png``.

    Any exception raised while processing is caught and reported with a
    printed message, so one bad file does not abort a batch run.

    Args:
        image_path (str): Path to the input image.
        mask_paths (list of str): List of paths to the masks corresponding to the image.
        output_dir (str): Directory to save the processed images.
        target_size (tuple): Target size for resizing images and masks (width, height).

    Returns:
        None: Saves the processed image with the intersection mask applied.
    """
    try:
        # Load the image
        image = Image.open(image_path).convert("RGB")  # Ensure image is RGB
        image = image.resize(target_size)  # Resize to target size
        image_tensor = ToTensor()(image)  # Convert to tensor with shape (3, H, W)

        # Load the masks
        masks = []
        for mask_path in mask_paths:
            mask = Image.open(mask_path).convert("L")  # Ensure mask is grayscale
            mask = mask.resize(target_size)  # Resize to target size
            mask_tensor = ToTensor()(mask)  # Convert to tensor with shape (1, H, W)
            masks.append(mask_tensor)

        # Stack the single-channel masks along the channel dimension
        masks_tensor = torch.cat(masks, dim=0)  # Shape: (N, H, W), where N is the number of masks

        # Intersection: a pixel counts as "inside" a mask as soon as it is non-zero
        intersection_mask = torch.all(masks_tensor.bool(), dim=0).float()  # Shape: (H, W)

        # Apply the intersection mask to the image
        processed_image = image_tensor * intersection_mask.unsqueeze(0)  # Broadcasting to match image shape

        # Save the processed image
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        save_path = os.path.join(output_dir, f"{base_name}_processed.png")

        # Convert to HWC and scale to 0-255. Round before the integer cast: a
        # plain cast truncates, so a value landing just below an integer
        # would drop to the next lower level.
        pixels = processed_image.permute(1, 2, 0).numpy() * 255
        pixels = np.ascontiguousarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))

        # The array is in RGB channel order. Saving through PIL keeps that
        # order (cv2.imwrite would interpret it as BGR) and raises on failure
        # instead of returning False, so the message below is only printed
        # for files that were really written.
        Image.fromarray(pixels).save(save_path)
        print(f"Processed and saved: {save_path}")

    except Exception as e:
        print(f"Error processing {image_path}: {e}")

# Dataset Directories
dataset_dir = "Augmented_data"
output_dir = "Intersection_data"
categories = ['normal', 'benign', 'malignant']

# Create output directories for each category
for category in categories:
    os.makedirs(os.path.join(output_dir, category), exist_ok=True)

# Process images and masks
for category in categories:
    category_dir = os.path.join(dataset_dir, category)
    output_category_dir = os.path.join(output_dir, category)

    # List the folder once instead of once per image.
    file_names = os.listdir(category_dir)

    # Group masks by the exact image they belong to. A mask is named
    # "<base>_mask" or "<base>_mask_<number>" (before the extension). A plain
    # substring test would also attach "<base>_aug_0_mask.png" to the image
    # "<base>", mixing masks of different images into one intersection.
    masks_by_image = {}
    for file_name in file_names:
        if '_mask' not in file_name:
            continue
        stem = os.path.splitext(file_name)[0]
        head, sep, index = stem.rpartition("_")
        if sep and index.isdigit() and head.endswith("_mask"):
            stem = head  # drop the numeric suffix of an additional mask
        if stem.endswith("_mask"):
            owner = stem[:-len("_mask")]
            masks_by_image.setdefault(owner, []).append(os.path.join(category_dir, file_name))

    for file_name in file_names:
        if '_mask' in file_name:
            continue  # Only base images are processed directly

        image_path = os.path.join(category_dir, file_name)

        # Collect corresponding masks
        base_name = os.path.splitext(file_name)[0]
        mask_paths = masks_by_image.get(base_name, [])

        if mask_paths:  # Only process if there are masks
            process_image_and_masks(image_path, mask_paths, output_category_dir, target_size=(224, 224))


## 2. Visualizing Weighted Intersection and Corrected Regions with Fallback

In [None]:
def visualize_weighted_intersection(image_path, mask_paths, target_size=(224, 224), weight_threshold=0.3, output_dir="./corrected_intersections"):
    """
    Visualize the weighted intersection process of an image and its corresponding mask(s) with fallback,
    and save the masked image.

    The masks are averaged pixel-wise; pixels whose average exceeds
    ``weight_threshold`` are kept. If no pixel passes the threshold, the
    union of the masks (any non-zero pixel) is used instead. The kept region
    is applied to the image, every step is plotted, and the masked image is
    written to ``output_dir/<image base name>_corrected.png``.

    Args:
        image_path (str): Path to the input image.
        mask_paths (list of str): List of paths to the masks corresponding to the image.
        target_size (tuple): Target size for resizing images and masks (width, height).
        weight_threshold (float): Minimum average mask value for a pixel to be retained.
        output_dir (str): Directory the corrected image is saved to; created if missing.

    Returns:
        None: Displays the results at each step and saves the corrected image.
    """
    # Local import: the plotting calls below need pyplot, and this cell does
    # not import it.
    import matplotlib.pyplot as plt

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Load the image
    image = Image.open(image_path).convert("RGB")  # Ensure image is RGB
    image_resized = image.resize(target_size)  # Resize to target size
    image_tensor = ToTensor()(image_resized)  # Convert to tensor with shape (3, H, W)

    # Load the masks
    masks = []
    for mask_path in mask_paths:
        mask = Image.open(mask_path).convert("L")  # Ensure mask is grayscale
        mask_resized = mask.resize(target_size)  # Resize to target size
        mask_tensor = ToTensor()(mask_resized)  # Convert to tensor with shape (1, H, W)
        masks.append(mask_tensor)

    # Stack the single-channel masks along the channel dimension
    masks_tensor = torch.cat(masks, dim=0)  # Shape: (N, H, W), where N is the number of masks

    # Compute the weighted intersection of the masks
    weighted_mask = masks_tensor.mean(dim=0)  # Average mask value per pixel
    thresholded_mask = (weighted_mask > weight_threshold).float()  # Retain pixels above the threshold

    # If the thresholded mask is empty, use the union of masks as fallback
    if thresholded_mask.sum() == 0:
        thresholded_mask = (masks_tensor.sum(dim=0) > 0).float()  # Union of masks (logical OR)

    # Apply the resulting mask to the image
    intersection_applied = image_tensor * thresholded_mask.unsqueeze(0)  # Broadcasting to match image shape

    # Plot and visualize: original, each mask, the combined mask, the masked image
    n_panels = len(masks) + 3
    plt.figure(figsize=(15, 8))
    plt.subplot(1, n_panels, 1)
    plt.title("Original Image")
    plt.imshow(image_resized)
    plt.axis("off")

    for i, mask_tensor in enumerate(masks):
        plt.subplot(1, n_panels, i + 2)
        plt.title(f"Mask {i+1}")
        plt.imshow(mask_tensor.squeeze(0), cmap="gray")
        plt.axis("off")

    plt.subplot(1, n_panels, len(masks) + 2)
    plt.title("Weighted Intersection Mask")
    plt.imshow(thresholded_mask, cmap="gray")
    plt.axis("off")

    plt.subplot(1, n_panels, len(masks) + 3)
    plt.title("Intersection Applied")
    plt.imshow(intersection_applied.permute(1, 2, 0).numpy())
    plt.axis("off")

    plt.show()

    # Save the processed image
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    save_path = os.path.join(output_dir, f"{base_name}_corrected.png")

    # Convert to HWC and scale to 0-255, rounding rather than truncating.
    pixels = intersection_applied.permute(1, 2, 0).numpy() * 255
    pixels = np.ascontiguousarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))

    # The array is in RGB channel order; PIL writes it as such (cv2.imwrite
    # would interpret it as BGR) and raises if the file cannot be written.
    Image.fromarray(pixels).save(save_path)
    print(f"Processed and saved corrected image: {save_path}")

# Example usage: one image with three masks.
# The corrected image is written directly into output_dir (no category
# subfolder is added by the function).
image_path = "Augmented_data/benign/benign (195).png"  # Replace with actual image path
mask_paths = [
    "Augmented_data/benign/benign (195)_mask.png",
    "Augmented_data/benign/benign (195)_mask_1.png",
    "Augmented_data/benign/benign (195)_mask_2.png"
]  # Replace with actual mask paths
visualize_weighted_intersection(image_path, mask_paths, target_size=(224, 224), weight_threshold=0.3, output_dir="./corrected_intersections")


## 3. Organizing and Renaming Corrected Files in Intersection Dataset

In [None]:
import os
import shutil

# Paths: corrected images are moved from corrected_dir into intersection_dir
corrected_dir = "corrected_intersections"
intersection_dir = "Intersection_data"

# Process corrected files (walks corrected_dir recursively)
for root, dirs, files in os.walk(corrected_dir):
    for file in files:
        # Full path to the corrected file
        corrected_file_path = os.path.join(root, file)

        # Determine the category (e.g., normal, benign, malignant) from the
        # first path component below corrected_dir.
        # NOTE(review): for files lying directly in corrected_dir the relative
        # path is "." and so is the category, which sends them to the root of
        # intersection_dir instead of a class folder -- confirm this is intended.
        relative_path = os.path.relpath(root, corrected_dir)
        category = relative_path.split(os.sep)[0]

        # Define the target directory in the intersection data folder
        target_dir = os.path.join(intersection_dir, category)
        os.makedirs(target_dir, exist_ok=True)  # Ensure the directory exists

        # Append "_processed" to the file name.
        # NOTE(review): the suffix is added to the full base name, so
        # "<name>_corrected.png" becomes "<name>_corrected_processed.png";
        # presumably this is meant to sit next to "<name>_processed.png"
        # rather than replace it -- verify.
        base_name = os.path.splitext(file)[0]  # Get the file name without extension
        new_file_name = f"{base_name}_processed.png"  # Append "_processed"

        # Full path for the new file in the target directory
        target_file_path = os.path.join(target_dir, new_file_name)

        # Remove existing file in the target directory if it has the same name
        if os.path.exists(target_file_path):
            os.remove(target_file_path)
            print(f"Removed existing file: {target_file_path}")

        # Move and rename the corrected file
        shutil.move(corrected_file_path, target_file_path)
        print(f"Moved and renamed: {corrected_file_path} -> {target_file_path}")

print("All corrected files have been processed, renamed, and moved to the correct directories.")


## 4. Center Cropping and Resizing Tumor Images

In [None]:
import os
import numpy as np
from PIL import Image
from torchvision.transforms import functional as F

def center_crop_image(image_path, output_dir, target_size=(224, 224)):
    """
    Center crop an intersection image to make the tumor region central.

    The tumor region is taken to be the non-zero pixels of the image. A
    square window whose side equals the smaller image dimension is placed so
    that the center of the region's bounding box is the center of the
    window. Parts of the window that fall outside the image are filled with
    black, so the window always stays square and centered (clamping it to
    the image would shift the tumor off-center and change the aspect ratio).
    The window is then resized to ``target_size`` and saved as
    ``output_dir/<image base name>_centered.png``.

    Args:
        image_path (str): Path to the input intersection image.
        output_dir (str): Directory to save the cropped and resized images.
        target_size (tuple): Target size for resizing images (width, height).

    Returns:
        None: Saves the cropped and resized image. Images without any
            non-zero pixel are skipped with a printed warning.
    """
    # Load the image
    image = Image.open(image_path).convert("RGB")  # Ensure image is RGB

    # Locate the non-zero pixels on a grayscale copy
    image_array = np.array(image.convert("L"))
    y_indices, x_indices = np.where(image_array > 0)

    if len(x_indices) == 0 or len(y_indices) == 0:
        print(f"Warning: No non-zero region found in {image_path}. Skipping.")
        return

    # Center of the bounding box of the non-zero region (plain Python ints)
    x_center = (int(x_indices.min()) + int(x_indices.max())) // 2
    y_center = (int(y_indices.min()) + int(y_indices.max())) // 2

    # Square window centered on the region; it may extend past the image
    width, height = image.size
    crop_size = min(width, height)  # Use the smaller dimension of the image as the crop size
    left = x_center - crop_size // 2
    top = y_center - crop_size // 2

    # Part of the window that actually overlaps the image
    src_left = max(0, left)
    src_top = max(0, top)
    src_right = min(width, left + crop_size)
    src_bottom = min(height, top + crop_size)

    # Paste the overlapping part onto a black canvas at the matching offset,
    # which leaves the out-of-image part of the window black.
    cropped_image = Image.new("RGB", (crop_size, crop_size))
    cropped_image.paste(
        image.crop((src_left, src_top, src_right, src_bottom)),
        (src_left - left, src_top - top),
    )

    # Resize the cropped image to the target size
    resized_image = cropped_image.resize(target_size, Image.BILINEAR)

    # Save the processed image
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    save_path = os.path.join(output_dir, f"{base_name}_centered.png")
    resized_image.save(save_path)
    print(f"Processed and saved: {save_path}")

# Dataset Directories: input is the intersection output, results go to a new folder
intersection_dir = "Intersection_data"
output_dir = "Cropped_intersection_data"
categories = ['normal', 'benign', 'malignant']

# Create output directories for each category
for category in categories:
    os.makedirs(os.path.join(output_dir, category), exist_ok=True)

# Process images: every entry of each category folder is passed on as an image.
# NOTE(review): entries are not filtered, so a hidden or non-image file in a
# category folder is handed to center_crop_image as well -- confirm the
# folders only contain images.
for category in categories:
    category_dir = os.path.join(intersection_dir, category)
    output_category_dir = os.path.join(output_dir, category)

    for file_name in os.listdir(category_dir):
        image_path = os.path.join(category_dir, file_name)
        center_crop_image(image_path, output_category_dir, target_size=(224, 224))
