In [1]:
##### kernel: lan_1118

import torch
import numpy as np

def render_image_at_depth_gpu(A, U, V, lambda_, f_obj, f_eye, lambda_oil, lambda_coverslip, lambda_sample, z, NA, padding_factor, N_obj_sam):
    """
    Render the intensity image of the field ``A`` at depth ``z``.

    A single frequency-domain transfer function ``H`` is built from five
    quadratic-phase terms (objective, eyepiece, immersion oil, coverslip,
    sample medium) and zeroed outside a numerical-aperture cutoff. It is then
    applied to a zero-padded copy of ``A`` via FFT. The propagated field is
    cropped back to ``N_obj_sam x N_obj_sam`` and its squared magnitude is
    returned.

    Despite the ``_gpu`` suffix, nothing here is CUDA-specific. All work runs
    on whatever device ``A``, ``U`` and ``V`` live on.

    Parameters
    ----------
    A : torch.Tensor
        Input field, expected shape (N_obj_sam, N_obj_sam).
    U, V : torch.Tensor
        Spatial-frequency grids with the same shape as ``A``. They are assumed
        to be in centred (fftshift) order, with zero frequency at index
        ``N_obj_sam // 2``; the FFT shifts below rely on that layout.
    lambda_ : float
        Wavelength used by the objective/eyepiece phase terms. Lengths must
        use one consistent unit; the notebook uses mm.
    f_obj, f_eye : float
        Objective and eyepiece focal lengths.
    lambda_oil, lambda_coverslip, lambda_sample : float
        Wavelengths inside the oil, coverslip and sample media. The notebook
        passes ``lambda_ / n`` for each medium.
    z : float or 0-d torch.Tensor
        Depth at which to render.
    NA : float
        Numerical aperture. Frequencies with radius greater than
        ``NA / lambda_oil`` are removed.
    padding_factor : float
        The field is zero-padded to ``round(padding_factor * N_obj_sam)``
        pixels per side before the FFT. Must not shrink the image.
    N_obj_sam : int
        Side length of ``A`` and of the returned image.

    Returns
    -------
    torch.Tensor
        Real tensor of shape (N_obj_sam, N_obj_sam) holding ``|B|**2``. This
        is an intensity, even though callers store it as ``amplitude``.

    Raises
    ------
    ValueError
        If ``round(padding_factor * N_obj_sam) < N_obj_sam``.
    """
    rho2 = U**2 + V**2  # squared radial spatial frequency, shared by all terms

    # All five transfer functions have the form exp(i*pi*c*rho2) with a scalar
    # c, so their product is one exponential of the summed coefficients.
    phase_coeff = (
        - lambda_ * z / f_obj      # thin-lens term, objective
        - lambda_ * z / f_eye      # thin-lens term, eyepiece
        + lambda_oil * z           # Fresnel term, immersion oil
        + lambda_coverslip * z     # Fresnel term, coverslip
        + lambda_sample * z        # Fresnel term, sample medium
    )
    # rho2 comes first so the result keeps rho2's complex dtype (complex64 for float32 grids).
    H = torch.exp(1j * np.pi * rho2 * phase_coeff)

    # NOTE(review): the coherent cutoff is usually NA / (vacuum wavelength),
    # since NA already contains the refractive index. Dividing by lambda_oil
    # enlarges the pupil by n_oil. Kept as-is; confirm intent.
    H[torch.sqrt(rho2) > NA / lambda_oil] = 0

    padded_size = round(padding_factor * N_obj_sam)
    if padded_size < N_obj_sam:
        raise ValueError(
            f"padding_factor={padding_factor} gives padded size {padded_size} "
            f"< N_obj_sam={N_obj_sam}"
        )
    # Place the original block so its centre (index N//2) lands on the padded
    # centre (index P//2). This keeps H's zero frequency on the DC bin and
    # gives exactly padded_size even when the margin is odd.
    pad_before = padded_size // 2 - N_obj_sam // 2
    pad_after = padded_size - N_obj_sam - pad_before
    pads = (pad_before, pad_after, pad_before, pad_after)  # (last dim, then first dim)
    A_padded = torch.nn.functional.pad(A, pads, mode='constant', value=0)
    # NOTE(review): H is sampled on the unpadded N_obj_sam grid and then
    # zero-padded. The FFT of the padded field, however, has a frequency
    # spacing padding_factor times finer, so the samples do not line up. H is
    # effectively evaluated at padding_factor x the true frequency, and the
    # band beyond 1/padding_factor of the original extent is zeroed. Building
    # U/V on the padded grid would fix this; left unchanged because it alters
    # every rendered image. Confirm intent.
    H_padded = torch.nn.functional.pad(H, pads, mode='constant', value=0)

    # Centred FFT convention: ifftshift -> fft -> fftshift, and the mirror for the inverse.
    A_fft = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(A_padded)))
    B = torch.fft.fftshift(torch.fft.ifft2(torch.fft.ifftshift(A_fft * H_padded)))

    # Undo the padding and return the intensity.
    crop = slice(pad_before, pad_before + N_obj_sam)
    B_cropped = B[crop, crop]
    return torch.abs(B_cropped) ** 2


In [2]:
import os
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from skimage.io import imread, imsave
from skimage.color import rgb2gray
from torch import tensor
from skimage.util import img_as_ubyte
from torchvision import transforms

# ---------------------------------------------------------------------------
# Rendering configuration
# ---------------------------------------------------------------------------

# Side length (pixels) of the square region cropped from the centre of each
# input image; smaller sizes increase rendering speed.
crop_size = 500

# Depth schedule. `dz` holds the absolute z values passed to the renderer,
# one per saved frame (they are not increments). The schedule is:
#   ramp 0 -> +Z_AMPLITUDE,
#   then N_SWEEPS alternating full sweeps (+Z -> -Z, -Z -> +Z, ...),
#   then ramp +Z_AMPLITUDE -> 0.
# Neighbouring segments share their endpoint, so each turning point appears
# twice in a row. This is kept so frame numbering stays unchanged.
Z_AMPLITUDE = 1e-5
RAMP_STEPS = 20
SWEEP_STEPS = 40
N_SWEEPS = 14
sweeps = [
    torch.linspace(Z_AMPLITUDE, -Z_AMPLITUDE, SWEEP_STEPS) if k % 2 == 0
    else torch.linspace(-Z_AMPLITUDE, Z_AMPLITUDE, SWEEP_STEPS)
    for k in range(N_SWEEPS)
]
dz = torch.cat([
    torch.linspace(0, Z_AMPLITUDE, RAMP_STEPS),
    *sweeps,
    torch.linspace(Z_AMPLITUDE, 0, RAMP_STEPS),
])
assert dz.numel() == 2 * RAMP_STEPS + N_SWEEPS * SWEEP_STEPS

# Optical parameters (lengths in mm).
lambda_ = 632.8e-6      # mm, vacuum wavelength (divided by n below for each medium)
NA = 1.45               # numerical aperture for oil immersion
f_obj = 50e-3           # mm, objective focal length
f_eye = 20e-3           # mm, eyepiece focal length
M = 100                 # total magnification
pixel_size = 6.5e-3     # mm, camera pixel pitch
N_obj_sam = crop_size   # samples per side of the simulated field
padding_factor = 2      # FFT zero-padding factor
plotW = 4               # mm; not referenced in the visible cells
n_oil = 1.515           # refractive index, immersion oil
n_coverslip = 1.515     # refractive index, coverslip
n_sample = 1.33         # refractive index, sample medium
lambda_oil = lambda_ / n_oil
lambda_coverslip = lambda_ / n_coverslip
lambda_sample = lambda_ / n_sample
dx = pixel_size / M           # object-plane sample spacing
du = 1 / (N_obj_sam * dx)     # frequency-grid spacing

# Spatial-frequency grid in centred (fftshift) order. For even N_obj_sam the
# zero frequency sits at index N_obj_sam // 2.
u = torch.linspace(-N_obj_sam / 2, N_obj_sam / 2 - 1, N_obj_sam) * du
# indexing='ij' gives MATLAB ndgrid-style output (U varies along rows).
# MATLAB's meshgrid corresponds to indexing='xy'.
U, V = torch.meshgrid(u, u, indexing='ij')

# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
U = U.to(DEVICE)
V = V.to(DEVICE)

# Input folders are the sub-directories of the working directory whose names
# start with "rgb". They are sorted so the processing order is reproducible.
folder_list = sorted(f for f in os.listdir() if os.path.isdir(f) and f.startswith("rgb"))

# Each saved depth slice is centre-cropped to this size (pixels).
SAVED_CROP_SIZE = 350
center_crop = transforms.CenterCrop(SAVED_CROP_SIZE)


# Render every image in every input folder at each depth in `dz`. Each
# normalised, centre-cropped slice is saved as Output/P<xx>_R<yy>/Image_<idx>.png.
device = U.device  # keep the images on the same device as the frequency grid

for folder_name in folder_list:
    print(folder_name)

    # Sorted so the processing order is reproducible across runs and filesystems.
    image_files = sorted(
        f for f in os.listdir(folder_name)
        if f.startswith("rgb_image") and f.endswith(".png")
    )
    if not image_files:
        continue

    # The output directory depends only on the folder. The last four
    # characters of e.g. "rgbFourBall1020" become "P10_R20".
    part1 = folder_name[-4:-2]  # e.g. "10"
    part2 = folder_name[-2:]    # e.g. "20"
    type_name = f"P{part1}_R{part2}"
    output_folder = os.path.join("Output", type_name)
    os.makedirs(output_folder, exist_ok=True)

    # With one source image, slices keep their plain "Image_<idx>.png" names.
    # With several, each is prefixed with its source stem; otherwise every
    # image would overwrite the previous image's slices.
    use_prefix = len(image_files) > 1

    for image_file in image_files:
        image_path = os.path.join(folder_name, image_file)
        A = torch.tensor(rgb2gray(imread(image_path)), dtype=torch.float32, device=device)

        # Centre-crop to crop_size x crop_size, the grid size U/V were built for.
        height, width = A.shape
        if height < crop_size or width < crop_size:
            raise ValueError(
                f"{image_path} is {height}x{width}; expected at least {crop_size}x{crop_size}"
            )
        start_row = (height - crop_size) // 2
        start_col = (width - crop_size) // 2
        A = A[start_row:start_row + crop_size, start_col:start_col + crop_size]

        # Running sum of the per-depth intensities.
        # NOTE(review): final_image is accumulated but never saved or displayed.
        final_image = torch.zeros((N_obj_sam, N_obj_sam), device=device)

        prefix = f"{os.path.splitext(image_file)[0]}_" if use_prefix else ""

        for idx, z in enumerate(dz, start=1):
            amplitude = render_image_at_depth_gpu(A, U, V, lambda_, f_obj, f_eye, lambda_oil, lambda_coverslip, lambda_sample, z, NA, padding_factor, N_obj_sam)
            final_image += amplitude

            # Min-max normalise to [0, 1]. A constant slice would divide by
            # zero (NaN), so it is written as all zeros instead.
            lo = amplitude.min()
            span = amplitude.max() - lo
            if span > 0:
                amplitude_standardized = (amplitude - lo) / span
            else:
                amplitude_standardized = torch.zeros_like(amplitude)

            amplitude_standardized = center_crop(amplitude_standardized)
            amplitude_standardized_uint8 = img_as_ubyte(amplitude_standardized.cpu().numpy())
            imsave(os.path.join(output_folder, f"{prefix}Image_{idx}.png"), amplitude_standardized_uint8)


rgbFourBall0000
rgbFourBall0010
rgbFourBall0020
rgbFourBall0030
rgbFourBall0040
rgbFourBall0050
rgbFourBall0060
rgbFourBall0070
rgbFourBall1010
rgbFourBall1020
rgbFourBall1030
rgbFourBall1040
rgbFourBall1050
rgbFourBall1060
rgbFourBall1070
rgbFourBall2020
rgbFourBall2030
rgbFourBall2040
rgbFourBall2050
rgbFourBall2060
rgbFourBall2070
rgbFourBall3030
rgbFourBall3040
rgbFourBall3050
rgbFourBall3060
rgbFourBall3070
rgbFourBall4040
rgbFourBall4050
rgbFourBall4060
rgbFourBall4070
rgbFourBall5050
rgbFourBall5060
rgbFourBall5070
rgbFourBall6060
rgbFourBall6070
rgbFourBall7070


In [6]:
# Show the last output directory assigned by the rendering loop.
# NOTE(review): the saved output below ('rgbFourBall6070/output') does not
# match the os.path.join("Output", f"P.._R..") path the rendering loop builds.
# It appears to be stale from an earlier run; re-run to refresh.
output_folder

'rgbFourBall6070/output'