In [2]:
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt


In [3]:
# Load the cropped image
image = Image.open('/Users/mikaildemir/Desktop/Image-Augmentation-Salmonella/Array.tif')

# Define the new patch size based on the updated image dimensions
patch_size = (55, 53)  # Width x Height

# Convert the image to a numpy array for easy manipulation
image_array = np.array(image)

# Define the number of rows and columns in the grid
rows = 7  # 7 rows
cols = 10  # 10 columns

# Initialize a list to store all patches
patches = []

# Loop through each row and column to extract all patches
for row in range(rows):
    for col in range(cols):
        # Calculate the coordinates for the top-left corner of each patch
        x = col * patch_size[0]
        y = row * patch_size[1]
        
        # Extract the patch
        patch = image_array[y:y + patch_size[1], x:x + patch_size[0]]
        patches.append(patch)
patches = patches[:61] + [patches[-1]]
# The list `patches` now contains all extracted patches from the image
print(f"Total patches extracted: {len(patches)}")


Total patches extracted: 62


In [4]:
# Your list of concentration levels (62 values)
concentration_levels = [
    0, 0, 20, 20, 0, 0, 20, 10, 0, 10, 10, 0, 50, 0, 0, 50, 0, 50, 20, 0, 
    0, 0, 0, 20, 0, 10, 10, 50, 0, 0, 20, 50, 10, 0, 10, 0, 20, 0, 0, 20, 
    0, 0, 10, 0, 0, 50, 50, 0, 10, 50, 50, 0, 50, 20, 0, 10, 0, 0, 20, 0, 
    0, 0
]

# Make sure the number of patches matches the number of concentration levels
assert len(patches) == len(concentration_levels), "Mismatch between patches and concentration levels."

# Create a list of tuples where each tuple contains a patch and its corresponding concentration level
labeled_patches = [(patch, concentration_levels[i]) for i, patch in enumerate(patches)]



In [11]:
import albumentations as A
import cv2
import os

# Define augmentations
rgb_shift_augmentation = A.Compose([
    A.RGBShift(r_shift_limit=100, g_shift_limit=100, b_shift_limit=100, p=1.0)
])

channel_shuffle_augmentation = A.Compose([
    A.ChannelShuffle(p=1.0)
])

# Define flip and rotation augmentations
flip_rotation_augmentations = [
    A.HorizontalFlip(p=1.0),               # Horizontal flip
    A.VerticalFlip(p=1.0),                 # Vertical flip
    A.Rotate(limit=45, p=1.0),             # Rotate by +45 degrees
    A.Rotate(limit=-45, p=1.0)             # Rotate by -45 degrees
]

# Create directories if they don’t exist
os.makedirs("original_patches", exist_ok=True)
os.makedirs("augmented_images", exist_ok=True)

# List to store all augmented images along with their labels
augmented_labeled_patches = []

# Loop through all patches and generate exactly 8 augmented versions
for idx, (patch, label) in enumerate(labeled_patches):
    # Save original patch
    original_path = f"original_patches/patch_{idx}.png"
    cv2.imwrite(original_path, cv2.cvtColor(patch, cv2.COLOR_RGB2BGR))
    
    # Initialize a list to store this patch’s augmented images
    augmented_images = []
    
    # Generate 2 RGB Shift Augmentations
    for _ in range(2):
        augmented_image = rgb_shift_augmentation(image=patch)['image']
        augmented_images.append(augmented_image)
        augmented_labeled_patches.append((augmented_image, label))  # Track with label

    # Generate 2 Channel Shuffle Augmentations
    for _ in range(2):
        augmented_image = channel_shuffle_augmentation(image=patch)['image']
        augmented_images.append(augmented_image)
        augmented_labeled_patches.append((augmented_image, label))  # Track with label

    # Generate 4 Flip/Rotation Augmentations
    for transform in flip_rotation_augmentations:
        augmented_image = transform(image=patch)['image']
        augmented_images.append(augmented_image)
        augmented_labeled_patches.append((augmented_image, label))  # Track with label
    
    # Save all 8 augmented images at once
    for aug_idx, augmented_image in enumerate(augmented_images, start=1):
        augmented_path = f"augmented_images/augmented_{idx}_{aug_idx}.png"
        cv2.imwrite(augmented_path, cv2.cvtColor(augmented_image, cv2.COLOR_RGB2BGR))

print("All images saved successfully.")


All images saved successfully.


In [19]:
len(augmented_labeled_patches)

496

In [21]:
# Check the total number of augmented images
total_augmented_images = len(augmented_labeled_patches)
print(f"Total number of augmented images: {total_augmented_images}")

# Check if each patch has exactly 8 augmented images
# Assuming `labeled_patches` has the original patches
num_patches = len(labeled_patches)
expected_total_augmented_images = num_patches * 8

if total_augmented_images == expected_total_augmented_images:
    print(f"Each patch has exactly 8 augmented images. Total is correct: {total_augmented_images}")
else:
    print(f"Error: Expected {expected_total_augmented_images} augmented images, but got {total_augmented_images}")


Total number of augmented images: 496
Each patch has exactly 8 augmented images. Total is correct: 496


In [23]:
# Check if concentration levels are correctly labeled for augmented images
correct_labels = True

# Loop over each original patch and its label in `labeled_patches`
for idx, (_, original_label) in enumerate(labeled_patches):
    # Get the start and end indices for the 8 augmented images of this patch
    start_idx = idx * 8
    end_idx = start_idx + 8
    
    # Loop over the 8 augmented images for this patch
    for augmented_image, augmented_label in augmented_labeled_patches[start_idx:end_idx]:
        # Check if the augmented label matches the original label
        if augmented_label != original_label:
            print(f"Error: Augmented image label does not match for patch {idx}")
            correct_labels = False
            break

if correct_labels:
    print("All augmented images have correct labels.")
else:
    print("There were errors in labeling augmented images.")


All augmented images have correct labels.


In [25]:
# Print the label of each augmented image
for i, (augmented_image, label) in enumerate(augmented_labeled_patches):
    print(f"Augmented Image {i + 1}: Label = {label}")


Augmented Image 1: Label = 0
Augmented Image 2: Label = 0
Augmented Image 3: Label = 0
Augmented Image 4: Label = 0
Augmented Image 5: Label = 0
Augmented Image 6: Label = 0
Augmented Image 7: Label = 0
Augmented Image 8: Label = 0
Augmented Image 9: Label = 0
Augmented Image 10: Label = 0
Augmented Image 11: Label = 0
Augmented Image 12: Label = 0
Augmented Image 13: Label = 0
Augmented Image 14: Label = 0
Augmented Image 15: Label = 0
Augmented Image 16: Label = 0
Augmented Image 17: Label = 20
Augmented Image 18: Label = 20
Augmented Image 19: Label = 20
Augmented Image 20: Label = 20
Augmented Image 21: Label = 20
Augmented Image 22: Label = 20
Augmented Image 23: Label = 20
Augmented Image 24: Label = 20
Augmented Image 25: Label = 20
Augmented Image 26: Label = 20
Augmented Image 27: Label = 20
Augmented Image 28: Label = 20
Augmented Image 29: Label = 20
Augmented Image 30: Label = 20
Augmented Image 31: Label = 20
Augmented Image 32: Label = 20
Augmented Image 33: Label = 0
Aug

In [30]:
import csv
import os

# Path to save CSV
csv_path = "image_labels.csv"

# Initialize list to hold image data
image_data = []

# Process original patches
for filename in sorted(os.listdir("original_patches")):
    if filename.endswith(".png"):
        # Extract index from filename (e.g., patch_0.png -> 0)
        idx = int(filename.split("_")[1].split(".")[0])
        label = concentration_levels[idx]  # Get concentration level based on index
        image_path = os.path.join("original_patches", filename)
        image_data.append([image_path, label])

# Process augmented images
for filename in sorted(os.listdir("augmented_images")):
    if filename.endswith(".png"):
        # Extract index from filename (e.g., augmented_0_1.png -> 0)
        idx = int(filename.split("_")[1])
        label = concentration_levels[idx]  # Get concentration level based on index
        image_path = os.path.join("augmented_images", filename)
        image_data.append([image_path, label])

# Write image paths and labels to CSV file
with open(csv_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["Image_Path", "Concentration_Level"])  # Header
    writer.writerows(image_data)

print(f"CSV file saved at {csv_path} with image paths and concentration levels.")


CSV file saved at image_labels.csv with image paths and concentration levels.
