In [3]:
"""
TODO:

1. Downsample 4 images into 1 - DONE
2. Enhance fluorescent images (intensity & color) - DONE
3. Build and train Pix2Pix model - DONE
4. Figure out what the best distance is for virtual staining

"""

'\nTODO:\n\n1. Downsample 4 images into 1 - DONE\n2. Enhance fluorescent images (intensity & color) - DONE\n3. Build and train Pix2Pix model - DONE\n4. Figure out what the best distance is for virtual staining\n\n'

In [1]:
import numpy as np
import os, shutil, random
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance

In [7]:
# Folder with the raw .tif images; the enhancement loop in this cell overwrites its fluorescence files in place.
dataset_dir: str = "datasets"

def enhance_fluorescence(image_path, threshold, color=(0, 255, 0)):
    """Recolor the bright pixels of an image so the fluorescent signal stands out.

    Every pixel whose grayscale ("L" mode) intensity is strictly greater than
    ``threshold`` is replaced by ``color``; every other pixel keeps its original
    RGB value.

    Args:
        image_path: Path of the image file to read.
        threshold: Grayscale cutoff; pixels with intensity > threshold are recolored.
        color: RGB triple written into the bright pixels. Defaults to pure green.

    Returns:
        A new RGB ``PIL.Image.Image``; the file on disk is not modified here.
    """
    # Close the file handle as soon as the pixels are decoded: callers save the
    # result back to the same path, which can fail while the source is still open.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    img_array = np.array(img)               # writable copy of the RGB pixels
    gray_array = np.array(img.convert("L"))

    mask = gray_array > threshold
    img_array[mask] = color
    return Image.fromarray(img_array)

# Enhance every fluorescence image (file name ending in "f.tif") and overwrite it in place.
fluorescence_files = (name for name in os.listdir(dataset_dir) if name.endswith("f.tif"))
for name in fluorescence_files:
    target_path = os.path.join(dataset_dir, name)
    enhance_fluorescence(target_path, threshold=7).save(target_path)

print("Fluorescence enhanced and images saved.")

Fluorescence enhanced and images saved.


In [None]:
# Raw .tif images are read from dataset_dir; side-by-side training pairs go to output_dir.
dataset_dir, output_dir = "datasets", "paired_patches"
os.makedirs(output_dir, exist_ok=True)

def split_image(image_path, patch_size=(512, 512), downsample_size=(256, 256)):
    """Tile an image into non-overlapping patches and downsample each one.

    Only full patches are kept: a strip at the right or bottom edge that is
    narrower than ``patch_size`` is discarded. Patches are produced column by
    column (outer loop over x, inner loop over y).

    Args:
        image_path: Path of the image file to read (converted to RGB).
        patch_size: (width, height) of each crop taken from the original image.
        downsample_size: (width, height) each crop is resized to (Lanczos filter).

    Returns:
        A list of ``(x, y, patch)`` tuples, where ``(x, y)`` is the crop's
        top-left corner in the original image and ``patch`` is the resized RGB image.
    """
    patch_w, patch_h = patch_size

    # Decode and release the file handle immediately; crops come from the copy.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    w, h = img.size

    patches = []
    # Upper bounds stop at the last origin where a whole patch still fits.
    for x in range(0, w - patch_w + 1, patch_w):
        for y in range(0, h - patch_h + 1, patch_h):
            patch = img.crop((x, y, x + patch_w, y + patch_h))
            patches.append((x, y, patch.resize(downsample_size, Image.LANCZOS)))
    return patches


# Files per acquisition group: the code treats the first 11 as inputs and the 12th
# as the target image (presumably the fluorescence "f.tif" file - TODO confirm).
GROUP_SIZE = 12

# os.listdir() returns entries in arbitrary, filesystem-dependent order; sort so each
# consecutive run of GROUP_SIZE files belongs together and the target sits last.
image_files = sorted(f for f in os.listdir(dataset_dir) if f.endswith(".tif"))

for i in range(0, len(image_files), GROUP_SIZE):
    batch = image_files[i:i + GROUP_SIZE]
    if len(batch) < GROUP_SIZE:
        raise ValueError(
            f"Incomplete group of {len(batch)} files starting at {batch[0]!r}; "
            f"expected {GROUP_SIZE}."
        )
    # Only the "_0.5" inputs are paired for now; batch[:GROUP_SIZE - 1] would use all inputs.
    input_images = [img for img in batch if img.endswith("_0.5.tif")]
    output_image = batch[GROUP_SIZE - 1]

    output_patches = split_image(os.path.join(dataset_dir, output_image))

    image_counter = 0
    for input_image in input_images:
        input_patches = split_image(os.path.join(dataset_dir, input_image))

        # zip stops at the shorter list, so images of different sizes pair only
        # their common leading patches.
        for (_, _, out_patch), (_, _, in_patch) in zip(output_patches, input_patches):
            # Side-by-side pair: input on the left, target on the right.
            paired_img = Image.new(
                "RGB",
                (in_patch.width + out_patch.width, max(in_patch.height, out_patch.height)),
            )
            paired_img.paste(in_patch, (0, 0))
            paired_img.paste(out_patch, (in_patch.width, 0))

            # NOTE(review): drops the first 8 characters of the input name, and the
            # counter restarts for every group, so names may collide across groups
            # if their suffixes match - confirm against the dataset naming scheme.
            patch_filename = f"{input_image[8:]}_{image_counter}.png"
            image_counter += 1

            paired_img.save(os.path.join(output_dir, patch_filename))

print("Images paired and saved")

Images paired and saved


In [16]:
dataset_dir, train_dir, val_dir = "paired_patches", "paired_patches/train", "paired_patches/test"
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

image_files = [f for f in os.listdir(dataset_dir) if f.endswith(".png")]
random.shuffle(image_files)

train_size = int(0.85 * len(image_files))
for f in image_files[:train_size]: shutil.move(os.path.join(dataset_dir, f), os.path.join(train_dir, f))
for f in image_files[train_size:]: shutil.move(os.path.join(dataset_dir, f), os.path.join(val_dir, f))

print(f"Training set: {len(image_files[:train_size])} images, Validation set: {len(image_files[train_size:])} images.")


Training set: 34 images, Validation set: 6 images.
