In [1]:
import numpy as np
import cupy as cp
import os
from bioio import BioImage
from skimage.io import
import tifffile
import timeit
from matplotlib import pyplot as plt

seed = 363
rng = np.random.default_rng(seed)

def show_mips(image, recon):
  plt.figure(figsize=(20, 10))
  plt.subplot(1, 2, 1)
  plt.imshow(np.max(image.get(), axis=0), cmap='gray')
  plt.axis('off')
  plt.title('Original image MIP')
  plt.subplot(1, 2, 2)
  plt.imshow(np.max(recon, axis=0), cmap='gray')
  plt.axis('off')
  plt.title('Deconvolved image MIP')
  plt.show()

In [5]:
tg_image_path = r'D:\images\tnia-python-images\tg\2024_10_22_ts_decon'
tg_image_path = r'/home/bnorthan/images/tnia-python-images/tg/'
image_path = r'_small_data_\[AS-00304]\1x1_FOVs'
image_path = r'_small_data_/[AS-00304]/1x1_FOVs/'
image_name = r'Slide 1 - A02 - ROI 01_1x1_FOVs - DAPI.tif'

#image_path = r'_small_data_\[AS-00344]\3x3_FOVs'
#image_name = r'Slide 1 - Region 001 - ROI 02 - DAPI.tif'
#image_name = r'Slide 1 - Region 001 - ROI 02 - Cy 5.tif'
#image_name = r'Slide 1 - Region 001 - ROI 02 - FITC.tif'

full_name = os.path.join(tg_image_path, image_path, image_name)

try:
  bioimg = BioImage(full_name)

  image = np.squeeze(bioimg.data)
except:
  image = imread(full_name)
  
# Add new z-axis if we have 2D data
if image.ndim == 2:
  image = np.expand_dims(image, axis=0)

print(image.shape)

(1, 995, 1334)


In [9]:
from tnia.deconvolution.gaussian_psf import gaussian_2d
sigma = 2.0
psf_temp = gaussian_2d(255,sigma)

# Add new z-axis if we have 2D data
if psf_temp.ndim == 2:
  psf_temp = np.expand_dims(psf_temp, axis=0)

# Pad to size of image
psf = np.zeros(image.shape)
psf[:psf_temp.shape[0], :psf_temp.shape[1], :psf_temp.shape[2]] = psf_temp
for axis, axis_size in enumerate(psf.shape):
  psf = np.roll(psf, int(axis_size / 2), axis=axis)
for axis, axis_size in enumerate(psf_temp.shape):
  psf = np.roll(psf, -int(axis_size / 2), axis=axis)
psf = np.fft.ifftshift(psf)
psf = psf / np.sum(psf)

In [10]:
def fftconv(x, H):
	return cp.real(cp.fft.ifftn(cp.fft.fftn(x) * H))

def kldiv(p, q, mask=None):
  mask = p < mask_val
  p = p + 1E-4
  q = q + 1E-4
  p = p / cp.sum(p)
  q = q / cp.sum(q)
  kldiv = p * (cp.log(p) - cp.log(q))
  kldiv[cp.isnan(kldiv)] = 0
  if mask is not None:
    kldiv[mask] = 0
  kldiv = cp.mean(kldiv)
  return kldiv

def kldiv_org(p, q):
	p = p + 1E-4
	q = q + 1E-4
	p = p / cp.sum(p)
	q = q / cp.sum(q)
	kldiv = p * (cp.log(p) - cp.log(q))
	kldiv[cp.isnan(kldiv)] = 0
	kldiv = cp.sum(kldiv)
	return kldiv


# Load data and PSF onto GPU
image = cp.array(image, dtype=cp.float32)
psf = cp.array(psf, dtype=cp.float32)

# Calculate OTF and transpose
otf = cp.fft.fftn(psf)
otfT = cp.conjugate(otf)
del psf

# Get dimensions of data
num_z = image.shape[0]
num_y = image.shape[1]
num_x = image.shape[2]
num_pixels = num_z * num_y * num_x

# Calculate Richardson-Lucy iterations
HTones = fftconv(cp.ones_like(image), otfT)
recon = cp.mean(image) * cp.ones((num_z, num_y, num_x), dtype=cp.float32)
previous_recon = recon

num_iters = 0
prev_kldim = np.inf
prev_kld1 = np.inf
prev_kld2 = np.inf
start_time = timeit.default_timer()

#while True:
for i in range(100):
  iter_start_time = timeit.default_timer()

  # Split recorded image into 50:50 images
  # TODO: make this work on the GPU (for some reason, we get repeating blocks with a naive conversion to cupy)
  split1 = rng.binomial(image.get().astype('int64'), p=0.5)
  split1 = cp.array(split1)
  split2 = image - split1

  # Calculate prediction
  Hu = fftconv(recon, otf)

  # Calculate KL divergences and stop iterations if both have increased
  mask_val = 100
  print(f"Hu min: {cp.min(Hu)}, max: {cp.max(Hu)}") 
  kldim = kldiv(Hu, image, mask_val)
  kld1 = kldiv(Hu, split1, mask_val)
  kld2 = kldiv(Hu, split2, mask_val)
  if ((kld1 > prev_kld1) & (kld2 > prev_kld2)):
    recon = previous_recon
    print("Optimum result obtained after %d iterations with a total time of %1.1f seconds." % (num_iters - 1, timeit.default_timer() - start_time))
    #break
  del previous_recon
  prev_kldim = kldim
  prev_kld1 = kld1
  prev_kld2 = kld2

  # Calculate updates for split images and full images (H^T (d / Hu))
  HTratio1 = fftconv(split1 / (0.5 * (Hu + 1E-12)), otfT) / HTones
  del split1
  HTratio2 = fftconv(split2 / (0.5 * (Hu + 1E-12)), otfT) / HTones
  del split2
  HTratio = fftconv(image / (Hu + 1E-12), otfT) / HTones
  del Hu

  # Normalise update steps by H^T(1) and only update pixels in full estimate where split updates agree in 'sign'
  shouldNotUpdate = (HTratio1 - 1) * (HTratio2 - 1) < 0
  del HTratio1
  del HTratio2
  HTratio[shouldNotUpdate] = 1
  num_updated = num_pixels - cp.sum(shouldNotUpdate)
  del shouldNotUpdate

  # Save previous estimate in case KLDs increase after this iteration
  previous_recon = recon

  # Update estimate
  recon = recon * HTratio
  min_HTratio = cp.min(HTratio)
  max_HTratio = cp.max(HTratio)
  max_relative_delta = cp.max((recon - previous_recon) / cp.max(recon))
  del HTratio

  calc_time = timeit.default_timer() - iter_start_time
  print("Iteration %03d completed in %1.3f s. KLDs = %1.4f (image), %1.4f (split 1), %1.4f (split 2). %1.2f %% of image updated. Update range: %1.2f to %1.2f. Largest relative delta = %1.5f." % (num_iters + 1, calc_time, kldim, kld1, kld2, 100 * num_updated / num_pixels, min_HTratio, max_HTratio, max_relative_delta))


  num_iters = num_iters + 1


Hu min: 15.325589179992676, max: 15.325614929199219
Iteration 001 completed in 0.117 s. KLDs = 0.0000 (image), 0.0000 (split 1), 0.0000 (split 2). 97.55 % of image updated. Update range: 0.21 to 16.64. Largest relative delta = 0.93990.
Hu min: 3.3016974925994873, max: 254.86245727539062
Iteration 002 completed in 0.113 s. KLDs = -0.0000 (image), -0.0000 (split 1), -0.0000 (split 2). 19.42 % of image updated. Update range: 0.49 to 1.24. Largest relative delta = 0.13789.
Hu min: 3.138530969619751, max: 268.2447204589844
Optimum result obtained after 1 iterations with a total time of 0.3 seconds.
Iteration 003 completed in 0.114 s. KLDs = -0.0000 (image), -0.0000 (split 1), -0.0000 (split 2). 14.27 % of image updated. Update range: 0.63 to 1.15. Largest relative delta = 0.07481.
Hu min: 3.261080026626587, max: 251.4672088623047
Iteration 004 completed in 0.114 s. KLDs = -0.0000 (image), -0.0000 (split 1), -0.0000 (split 2). 18.07 % of image updated. Update range: 0.56 to 1.21. Largest rel

In [11]:
import napari
viewer=napari.Viewer()

viewer.add_image(image.get(), name='image')
viewer.add_image(recon.get(), name='recon')

<Image layer 'recon' at 0x7e03629575e0>