In [None]:
"""
TODO:

1. Downsample 4 images into 1 - DONE
2. Enhance florescent images (intensity & color) - DONE
3. Build and train Pix2Pix model - DONE
4. Figure out what the best distance is for virtual staining
5. Improve marking centre of cells

"""

'\nTODO:\n\n1. Downsample 4 images into 1 - DONE\n2. Enhance florescent images (intensity & color) - DONE\n3. Build and train Pix2Pix model - DONE\n4. Figure out what the best distance is for virtual staining\n\n'

In [1]:
import random
import numpy as np
import os, shutil, random
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

In [10]:
# Raw captures are read from `dataset_dir`; centre-cropped PNGs go to `subset_dir`.
dataset_dir, subset_dir = "dataset-OVCAR/dataset", "datasets"
os.makedirs(name=subset_dir, exist_ok=True)

subset, filter_level = 4, 0
# subset:       optional subset size (0 <= subset <= 300)
# filter_level: files ending in f"{filter_level}.tif" are kept alongside the "f.tif" captures

def clip_and_convert_to_png(file, crop_size=1024, src_dir=None, dst_dir=None):
    """Centre-crop one capture to a `crop_size` x `crop_size` square and save it as PNG.

    The output keeps the input's base name with a ``.png`` extension.

    Args:
        file: file name (not a path) of the image inside `src_dir`.
        crop_size: side length, in pixels, of the square centre crop.
        src_dir: directory to read from; defaults to the module-level `dataset_dir`.
        dst_dir: directory to write to; defaults to the module-level `subset_dir`.

    Errors are printed and swallowed so that one unreadable file does not abort
    a whole batch run through ``executor.map``.
    """
    src_dir = dataset_dir if src_dir is None else src_dir
    dst_dir = subset_dir if dst_dir is None else dst_dir
    img_path = os.path.join(src_dir, file)
    new_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".png")

    try:
        with Image.open(img_path) as img:
            width, height = img.size
            if width < crop_size or height < crop_size:
                # The crop box then extends past the image and PIL fills the
                # out-of-bounds area, which would silently add empty borders.
                print(f"Warning: {file} is {width}x{height}, smaller than the "
                      f"{crop_size}px crop; the output will be padded")
            left, top = (width - crop_size) // 2, (height - crop_size) // 2
            cropped_img = img.crop((left, top, left + crop_size, top + crop_size))
            cropped_img.save(new_path, format="PNG")
    except Exception as e:
        print(f"Error processing {file}: {e}")

# Keep the captures ending in f"{filter_level}.tif" plus the fluorescence "f.tif" ones.
files = list(
    filter(lambda name: name.endswith((f"{filter_level}.tif", "f.tif")),
           os.listdir(dataset_dir))
)
# files = files[:subset]  # re-enable to work on a small subset only

worker_count = multiprocessing.cpu_count()
with ThreadPoolExecutor(max_workers=worker_count) as pool:
    pool.map(clip_and_convert_to_png, files)

print(f"Subset of data created with {len(files)} PNG images")


Subset of data created with 164 PNG images


In [11]:
dataset_dir = subset_dir  # all following stages read the cropped PNGs

def enhance_fluorescence(image_path, threshold, color=(0, 255, 0)):
    """Paint every pixel brighter than `threshold` with one solid colour.

    Brightness is measured on PIL's "L" (luma) conversion of the RGB image.
    Pixels whose luma is strictly greater than `threshold` are overwritten
    with `color`; all other pixels are left untouched.

    Args:
        image_path: path of the image to enhance.
        threshold: luma cut-off on the 0-255 scale.
        color: RGB triple written into the selected pixels.

    Returns:
        A new RGB ``PIL.Image``; the file on disk is not modified.
    """
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
    luma = np.array(rgb.convert("L"))
    pixels = np.array(rgb)
    pixels[luma > threshold] = color
    return Image.fromarray(pixels)

fluorescence_pngs = [name for name in os.listdir(dataset_dir) if name.endswith("f.png")]
for name in fluorescence_pngs:
    target = os.path.join(dataset_dir, name)
    # Overwrites the fluorescence image in place with its enhanced version.
    enhance_fluorescence(target, threshold=8).save(target)

print("Fluorescence enhanced and images saved.")

Fluorescence enhanced and images saved.


In [12]:
def rotate_and_zoom_pair(image_path, output_path, angle=90, expand=False):
    """Rotate one image counter-clockwise by `angle` degrees and save it as PNG.

    Despite the name, no zoom is applied; the image is only rotated.

    Args:
        image_path: source image path (converted to RGB before rotating).
        output_path: destination path; always written in PNG format.
        angle: rotation in degrees, counter-clockwise (PIL convention).
        expand: forwarded to ``Image.rotate``. With the default ``False`` the
            output keeps the input size, so non-square images lose their
            corners on a 90/270 degree turn; pass ``True`` to grow the canvas
            so the whole rotated image fits.
    """
    # Context manager closes the source file as soon as the pixels are converted.
    with Image.open(image_path) as source:
        rotated = source.convert("RGB").rotate(angle, expand=expand)
    rotated.save(output_path, format="PNG")

def increase_brightness_linear(image_path, output_path, min_factor, max_factor):
    """Scale brightness with a horizontal ramp and save the result.

    Column 0 is multiplied by `min_factor` and the last column by `max_factor`,
    with linear interpolation in between; results are clipped to [0, 255].
    Files whose path ends in ``f.png`` are copied through unchanged (as RGB),
    so only non-fluorescence images are jittered.
    """
    if image_path.endswith("f.png"):
        Image.open(image_path).convert("RGB").save(output_path)
        return

    pixels = np.array(Image.open(image_path).convert("RGB"), dtype=np.float32)
    width = pixels.shape[1]
    # Shape (1, width, 1) broadcasts over rows and colour channels.
    ramp = np.linspace(min_factor, max_factor, width)[np.newaxis, :, np.newaxis]
    scaled = np.clip(pixels * ramp, 0, 255).astype(np.uint8)
    Image.fromarray(scaled).save(output_path)

def increase_brightness_radial(image_path, output_path, min_factor, max_factor):
    """Scale brightness radially around the image centre and save the result.

    The centre pixel is multiplied by `max_factor`, and the farthest pixel is
    multiplied by `min_factor`. Pixels in between are interpolated linearly
    with distance, and results are clipped to [0, 255].

    Files whose path ends in ``f.png`` are copied through unchanged (as RGB).
    """
    if image_path.endswith("f.png"):
        with Image.open(image_path) as source:
            source.convert("RGB").save(output_path)
        return

    with Image.open(image_path) as source:
        pixels = np.array(source.convert("RGB"), dtype=np.float32)
    height, width = pixels.shape[:2]

    center_x, center_y = width // 2, height // 2
    y_indices, x_indices = np.ogrid[:height, :width]
    distances = np.sqrt((x_indices - center_x) ** 2 + (y_indices - center_y) ** 2)

    max_distance = distances.max()
    # A 1x1 image has max_distance == 0; treat its single pixel as the centre
    # instead of computing 0/0 = NaN, which would cast to an undefined uint8.
    if max_distance > 0:
        normalized = distances / max_distance
    else:
        normalized = np.zeros_like(distances)
    radial_factor = min_factor + (max_factor - min_factor) * (1 - normalized)

    # Trailing axis broadcasts the per-pixel factor over the three channels.
    brightened_pixels = np.clip(pixels * radial_factor[:, :, np.newaxis], 0, 255).astype(np.uint8)
    Image.fromarray(brightened_pixels).save(output_path)


def process_image(file, rotation_angle, brightness_factor=(0.75, 1.25), apply_brightness=False):
    """Write augmented copies of one image from `subset_dir` back into `subset_dir`.

    Three rotated copies are written, at 1x, 2x and 3x `rotation_angle`.
    They are named ``rotate_<angle>_<file>``, so the filename always matches
    the rotation actually applied. When `apply_brightness` is True, a linear
    and a radial brightness-jittered copy are also written
    (``linear_<file>``, ``radial_<file>``), using `brightness_factor` as
    ``(min_factor, max_factor)``.
    """
    image_path = os.path.join(subset_dir, file)
    for multiple in (1, 2, 3):
        angle = rotation_angle * multiple
        # ":g" keeps integer-valued angles free of a trailing ".0".
        output_path = os.path.join(subset_dir, f"rotate_{angle:g}_{file}")
        rotate_and_zoom_pair(image_path, output_path, angle)

    if apply_brightness:
        min_factor, max_factor = brightness_factor
        increase_brightness_linear(image_path, os.path.join(subset_dir, f"linear_{file}"),
                                   min_factor, max_factor)
        increase_brightness_radial(image_path, os.path.join(subset_dir, f"radial_{file}"),
                                   min_factor, max_factor)

# Prefixes written by process_image; files carrying one are already augmented copies.
AUGMENTED_PREFIXES = ("rotate_", "linear_", "radial_")

# Augment only source images. Skipping existing copies makes the cell safe to
# re-run; otherwise rotated copies would be rotated again
# ("rotate_90_rotate_90_...") and the dataset would grow on every run.
images = [
    f for f in os.listdir(subset_dir)
    if f.endswith(".png") and not f.startswith(AUGMENTED_PREFIXES)
]
original_size = len(images)

with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
    futures = {}
    for img in images:
        # Passed to process_image for its optional brightness augmentation.
        brightness_factor = (np.random.uniform(0.75, 1), np.random.uniform(1, 2))
        futures[executor.submit(process_image, img, 90, brightness_factor)] = img

# submit() stores any worker exception inside its Future; without this check a
# failed augmentation would go unnoticed.
for future, name in futures.items():
    error = future.exception()
    if error is not None:
        print(f"Error augmenting {name}: {error}")

print("Augmentation Done!")
print("Original Images:", original_size)
print("Augmented Images:", len(os.listdir(subset_dir)))

Augmentation Done!
Original Images: 164
Augmented Images: 656


In [31]:
# Augmentation test
# source = [f for f in os.listdir(".") if f.endswith(".png")]

# print(type(Image.open("img1.png")))


# for img in source:
#     for i in range(5):
#         min_factor, max_factor = np.random.uniform(0.5, 1), np.random.uniform(1, 2)
#         print(min_factor, max_factor)
#         increase_brightness_linear(img, f"{img[:-4]}_linear_{i}.png", min_factor, max_factor)
#     for i in range(5):
#         min_factor, max_factor = np.random.uniform(0.5, 1), np.random.uniform(1, 2)
#         print(min_factor, max_factor)
#         increase_brightness_radial(img, f"{img[:-4]}_radial_{i}.png", min_factor, max_factor)

In [15]:
output_dir = "virtual_staining_OVCAR"  # destination for the side-by-side training pairs
os.makedirs(output_dir, exist_ok=True)

# Suffix of the non-fluorescence (input) images. It is a str because it is
# spliced into filenames.
filter_level = str(0)

def split_image(image_path, patch_size=(256, 256), downsample_size=(256, 256)):
    """Cut an RGB image into non-overlapping, full-size patches.

    Partial patches at the right and bottom edges are dropped.

    Args:
        image_path: path of the image to split.
        patch_size: ``(width, height)`` of each patch, which is also the stride.
        downsample_size: currently ignored; the resize step that would use it
            is disabled.

    Returns:
        A list of ``(x, y, patch)`` tuples, where ``(x, y)`` is the patch's
        top-left corner. The list is ordered with x in the outer loop and y in
        the inner loop.
    """
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
    patch_w, patch_h = patch_size
    width, height = rgb.size
    # Range upper bounds guarantee every crop box lies fully inside the image.
    return [
        (left, top, rgb.crop((left, top, left + patch_w, top + patch_h)))
        for left in range(0, width - patch_w + 1, patch_w)
        for top in range(0, height - patch_h + 1, patch_h)
    ]

# Pair every fluorescence target ("...f.png") with its input image by NAME.
# os.listdir() order is arbitrary, so pairing neighbours (i, i+1) could match the
# wrong images, and one missing file would shift every following pair.
# Assumes the two images of a capture differ only in the character before
# ".png" (f"...{filter_level}.png" vs "...f.png"), mirroring the ".tif"
# selection used when cropping. TODO confirm against the raw file naming.
image_files = sorted(f for f in os.listdir(dataset_dir) if f.endswith(".png"))
print(len(image_files))

available_files = set(image_files)
target_suffix = "f.png"
input_suffix = f"{filter_level}.png"

image_pairs = []
unpaired = []
for output_image in image_files:
    if not output_image.endswith(target_suffix):
        continue
    input_image = output_image[: -len(target_suffix)] + input_suffix
    if input_image in available_files:
        image_pairs.append((input_image, output_image))
    else:
        unpaired.append(output_image)

if unpaired:
    print(f"Skipping {len(unpaired)} fluorescence images without a matching '{input_suffix}' input")

for input_image, output_image in image_pairs:
    input_patches = split_image(os.path.join(dataset_dir, input_image))
    output_patches = split_image(os.path.join(dataset_dir, output_image))

    # zip() stops at the shorter list, so only patch positions present in both
    # images are paired.
    for image_counter, ((_, _, in_patch), (_, _, out_patch)) in enumerate(
        zip(input_patches, output_patches)
    ):
        # Input on the left, target on the right. The patches are pasted as-is:
        # a float normalise/denormalise round trip adds nothing, and its
        # truncating uint8 cast can lower pixel values by one.
        paired_img = Image.new(
            "RGB",
            (in_patch.width + out_patch.width, max(in_patch.height, out_patch.height)),
        )
        paired_img.paste(in_patch, (0, 0))
        paired_img.paste(out_patch, (in_patch.width, 0))

        patch_filename = f"{input_image[:-4]}_{image_counter}.png"
        paired_img.save(os.path.join(output_dir, patch_filename))

print("Images paired and saved")

656
Images paired and saved


In [None]:
found = 0  # number of patches deleted below

DROP_SEED = 0  # seeds the random deletions so the filtered dataset is reproducible


def valid_image(image_path, target_color=(0, 255, 0), drop_probability=0.8, rng=None):
    """Decide whether a paired patch is kept, deleting most patches without signal.

    A patch has signal when at least one pixel equals `target_color` exactly.
    By default that is the pure green written by `enhance_fluorescence`.
    A patch without it is deleted from disk with probability
    `drop_probability`, so a fraction of "empty" examples survives.

    Args:
        image_path: path of the patch; the file may be removed.
        target_color: RGB triple to look for.
        drop_probability: chance of deleting a patch that lacks `target_color`.
        rng: optional object with a ``uniform(low, high)`` method
            (e.g. ``np.random.default_rng(seed)``). When None, the global NumPy
            RNG is used.

    Returns:
        False if the file was deleted, True otherwise.
    """
    # Context manager guarantees the handle is closed before any os.remove().
    with Image.open(image_path) as img:
        pixels = np.asarray(img.convert("RGB"))

    # Vectorised equivalent of `target_color in img.getdata()`, which would
    # iterate over every pixel in Python.
    if np.all(pixels == np.asarray(target_color), axis=-1).any():
        return True

    draw = np.random.uniform(0, 1) if rng is None else rng.uniform(0, 1)
    if draw < drop_probability:
        os.remove(image_path)
        return False
    return True


# NOTE: deletions are permanent, so re-running this cell removes a further share
# of the remaining empty patches. Sorting fixes which draw applies to which file.
deletion_rng = np.random.default_rng(DROP_SEED)
image_files = [os.path.join(output_dir, f) for f in sorted(os.listdir(output_dir)) if f.endswith(".png")]
for f in image_files:
    if not valid_image(f, rng=deletion_rng):
        found += 1

print(len(image_files), found)

3555 218


In [18]:
SPLIT_SEED = 42       # fixes the shuffle so the train/test split is reproducible
TRAIN_FRACTION = 0.8  # share of patches moved to the training folder

dataset_dir, train_dir, val_dir = output_dir, f"{output_dir}/train", f"{output_dir}/test"
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Sort before shuffling: os.listdir() order is arbitrary, so a seeded shuffle
# alone would not yield the same split on every run.
image_files = sorted(f for f in os.listdir(dataset_dir) if f.endswith(".png"))
random.Random(SPLIT_SEED).shuffle(image_files)

# NOTE(review): the split is per patch, so neighbouring patches and the rotated
# copies of one capture can land in both train and test. Consider splitting by
# capture to avoid leakage.
train_size = int(TRAIN_FRACTION * len(image_files))
train_files, val_files = image_files[:train_size], image_files[train_size:]
for f in train_files:
    shutil.move(os.path.join(dataset_dir, f), os.path.join(train_dir, f))
for f in val_files:
    shutil.move(os.path.join(dataset_dir, f), os.path.join(val_dir, f))

print(f"Training set: {len(train_files)} images, Validation set: {len(val_files)} images.")


Training set: 2669 images, Validation set: 668 images.


In [19]:
# Copy the train/test folders into the pix2pix repository's datasets directory.
# Any previous copy is removed first so stale patches do not linger.
pix2pix_datasets_root = "pytorch-CycleGAN-and-pix2pix/datasets"
train_test_dir = os.path.join(pix2pix_datasets_root, output_dir)

if os.path.exists(train_test_dir):
    shutil.rmtree(train_test_dir)
shutil.copytree(src=output_dir, dst=train_test_dir)

print("Train/Test image moved to pix2pix")

Train/Test image moved to pix2pix


In [None]:
# Mark the centres of fluorescent cells
from PIL import Image
import numpy as np
import cv2
from scipy.ndimage import center_of_mass, label

# Detection parameters (pixel units).
CENTRES_INPUT_PATH = "datasets/rotate_90_capture_19800_750_6742.8_f.png"
OPENING_KERNEL_SIZE = 50  # size of the elliptical structuring element for the opening
MIN_BLOB_PIXELS = 1000    # only blobs with strictly more pixels than this are marked
MARKER_RADIUS = 10        # radius of the filled dot drawn at each centre

img = Image.open(CENTRES_INPUT_PATH).convert("RGB")
img_np = np.array(img)

# Pixels that are exactly pure green (0, 255, 0).
green_mask = (img_np[:, :, 1] == 255) & (img_np[:, :, 0] == 0) & (img_np[:, :, 2] == 0)
binary = green_mask.astype(np.uint8)

# Morphological opening removes foreground structures smaller than the kernel.
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (OPENING_KERNEL_SIZE, OPENING_KERNEL_SIZE))
binary_cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

labeled_array, num_features = label(binary_cleaned)

# Pixel count of every label in one pass (index 0 is background). This avoids
# building a full-image boolean mask per blob.
blob_sizes = np.bincount(labeled_array.ravel(), minlength=num_features + 1)
large_labels = [lab for lab in range(1, num_features + 1) if blob_sizes[lab] > MIN_BLOB_PIXELS]

if large_labels:
    # Every labelled pixel of binary_cleaned is 1, so the per-label centre of
    # mass is the blob centroid, returned as (row, col).
    centres = center_of_mass(binary_cleaned, labeled_array, large_labels)
    for y, x in centres:
        # img_np is RGB, so (255, 0, 0) draws a red dot.
        cv2.circle(img_np, (int(x), int(y)), radius=MARKER_RADIUS, color=(255, 0, 0), thickness=-1)

result = Image.fromarray(img_np)
result.save("separated_cleaned_centers.png")