# Deconvolution and PSF Estimation with the Richardson-Lucy Algorithm


## Imports & Environment Setup

### Imports

In [None]:
# Automatically re-import project-specific modules.
%load_ext autoreload
%autoreload 1

# imports
import collections
import functools
import io
import datetime
import time
import math
import itertools
import os
import shutil
import pathlib
import copy
import pdb
import inspect
import random
import subprocess
import sys
import warnings
from pathlib import Path
import typing
import zipfile

import ants
import dipy
import dipy.core
import dipy.reconst
import dipy.reconst.dti
import dipy.segment.mask
import dipy.viz
import dipy.viz.regtools
import dotenv

# visualization libraries
%matplotlib inline
import matplotlib as mpl
import mpl_toolkits
import matplotlib.pyplot as plt
import seaborn as sns

import IPython

# Try importing GPUtil for printing GPU specs.
# May not be installed if using CPU only.
try:
    import GPUtil
except ImportError:
    warnings.warn("WARNING: Package GPUtil not found, cannot print GPU specs")
from tabulate import tabulate
from IPython.display import display, Markdown
import ipyplot

# Data management libraries.
import nibabel as nib
import natsort
from natsort import natsorted
import box
from box import Box
import pprint
from pprint import pprint as ppr

# Computation & ML libraries.
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision

import skimage
import skimage.feature
import skimage.filters
import skimage.measure
import scipy

# Matplotlib defaults: automatic layout and an opaque white figure background.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.facecolor": [1.0, 1.0, 1.0, 1.0],
    }
)

# Compact, non-scientific printing for ndarrays and tensors (shared limits).
_print_limits = dict(edgeitems=2, threshold=100, linewidth=88)
np.set_printoptions(suppress=True, **_print_limits)
torch.set_printoptions(sci_mode=False, profile="short", **_print_limits)

In [None]:
# Update notebook's environment variables with direnv.
# This requires the python-dotenv package, and direnv be installed on the system.
# This will not work on Windows.
# NOTE: This is kind of hacky, and not necessarily safe. Be careful...
# Libraries needed on the python side:
# - os
# - subprocess
# - io
# - warnings
# - dotenv

# Command to run `env` in direnv's context; it prints every environment variable
# defined in that sub-process. The arguments are passed as a list (no shell), so a
# working directory containing spaces or shell metacharacters is passed through
# verbatim instead of being split or interpreted by the shell.
command = ["direnv", "exec", os.getcwd(), "/usr/bin/env"]
proc_out = ""
try:
    proc = subprocess.run(command, stdout=subprocess.PIPE, cwd=os.getcwd(), check=True)
except (OSError, subprocess.CalledProcessError) as err:
    # Best effort, as before: continue with the current environment, but say why.
    warnings.warn(f"WARNING: could not load environment via direnv: {err}")
else:
    proc_out = proc.stdout.strip().decode("utf-8")
    # Use python-dotenv to load the environment variables by using the output of
    # 'direnv exec ...' as a 'dummy' .env file.
    dotenv.load_dotenv(stream=io.StringIO(proc_out), override=True)

In [None]:
# Project-specific scripts
# It's easier to import it this way rather than make an entirely new package, due to
# conflicts with local packages and anaconda installations.
# You made me do this, poor python package management!!
# Resolve the project root from $PROJECT_ROOT when it is set; otherwise fall back to
# the directory three levels above the current working directory.
if "PROJECT_ROOT" in os.environ:
    lib_location = str(Path(os.environ["PROJECT_ROOT"]).resolve())
else:
    lib_location = str(Path("../../../").resolve())
# Prepend (once) so the project's `lib` package takes priority on the import path.
if lib_location not in sys.path:
    sys.path.insert(0, lib_location)
import lib as pitn

# Include the top-level lib module along with its submodules.
# (`%aimport` registers modules for reloading under `%autoreload 1`.)
%aimport lib
# Grab all submodules of lib, not including modules outside of the package.
# NOTE(review): inspect.getmembers only sees submodules already bound as attributes
# of `lib` (e.g. imported by its __init__) — confirm that covers everything needed.
includes = list(
    filter(
        lambda m: m.startswith("lib."),
        map(lambda x: x[1].__name__, inspect.getmembers(pitn, inspect.ismodule)),
    )
)
# Run aimport magic with constructed includes.
ipy = IPython.get_ipython()
ipy.run_line_magic("aimport", ", ".join(includes))

In [None]:
# Torch device selection: use CUDA when torch can see a GPU, else the CPU.
# To force CPU execution, replace the expression with torch.device("cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

### Specs Recording

In [None]:
%%capture --no-stderr cap
# Capture output and save to log. Needs to be at the *very first* line of the cell.
# Watermark
# Records author, update timestamp, Python/machine info, imported package versions
# and the current git hash.
%load_ext watermark
%watermark --author "Tyler Spears" --updated --iso8601  --python --machine --iversions --githash
if torch.cuda.is_available():

    # GPU information
    # Taken from
    # <https://www.thepythoncode.com/article/get-hardware-system-information-python>.
    # If GPUtil is not installed, skip this step.
    try:
        gpus = GPUtil.getGPUs()
        print("=" * 50, "GPU Specs", "=" * 50)
        # One row per GPU: id, name, driver, CUDA version (from torch), memory, uuid.
        list_gpus = []
        for gpu in gpus:
            # get the GPU id
            gpu_id = gpu.id
            # name of GPU
            gpu_name = gpu.name
            driver_version = gpu.driver
            cuda_version = torch.version.cuda
            # get total memory
            gpu_total_memory = f"{gpu.memoryTotal}MB"
            gpu_uuid = gpu.uuid
            list_gpus.append(
                (
                    gpu_id,
                    gpu_name,
                    driver_version,
                    cuda_version,
                    gpu_total_memory,
                    gpu_uuid,
                )
            )

        print(
            tabulate(
                list_gpus,
                headers=(
                    "id",
                    "Name",
                    "Driver Version",
                    "CUDA Version",
                    "Total Memory",
                    "uuid",
                ),
            )
        )
    except NameError:
        # `GPUtil` is unbound when its import failed in the imports cell; only the
        # CUDA version reported by torch is printed in that case.
        print("CUDA Version: ", torch.version.cuda)

else:
    print("CUDA not in use, falling back to CPU")

In [None]:
# `cap` holds the output captured by the `%%capture` magic in the previous cell;
# print its text form so the specs appear in this cell's output.
captured_text = str(cap)
print(captured_text)

Author: Tyler Spears

Last updated: 2021-11-30T18:21:05.225553+00:00

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.23.1

Compiler    : GCC 7.3.0
OS          : Linux
Release     : 5.4.0-90-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

Git hash: 0c2fd3efebc6346e7fa2ae08902e7677240f6ee5

skimage    : 0.18.1
torch      : 1.10.0
ipywidgets : 7.6.3
box        : 5.4.1
numpy      : 1.20.2
dipy       : 1.4.1
nibabel    : 3.2.1
seaborn    : 0.11.1
torchvision: 0.11.1
GPUtil     : 1.4.0
ants       : 0.2.7
sys        : 3.8.8 (default, Feb 24 2021, 21:46:12) 
[GCC 7.3.0]
scipy      : 1.5.3
matplotlib : 3.4.1
natsort    : 7.1.1
json       : 2.0.9
pandas     : 1.2.3
IPython    : 7.23.1

  id  Name              Driver Version      CUDA Version  Total Memory    uuid
----  ----------------  ----------------  --------------  --------------  ----------------------------------------
   0  NVIDIA TITAN RTX  470.82.00           

## Fashion MNIST Dataset

### Experiment Parameters

In [None]:
# Parameters
params = Box(default_box=True)

# Data params.
params.num_channels = 1

# Training and testing params.
params.batch_size = 128
params.num_epochs = 20
# NOTE(review): not read anywhere in the visible notebook; no early-stopping
# criterion is implemented yet.
params.rel_tol = 1e-6

# Reproducibility: seed every RNG the notebook draws from (random PSF
# initialization, shuffled DataLoaders, NumPy noise).
params.seed = 0
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)

### Data Variables & Definitions Setup

In [None]:
# Data and result locations come from environment variables; each one must
# already exist on disk.
data_dir = Path(os.environ["DATA_DIR"])
assert data_dir.exists()
# Fashion-MNIST lives (or is downloaded) beneath the data directory.
fmnist_root_dir = data_dir.joinpath("fashion_mnist")

results_dir = Path(os.environ["RESULTS_DIR"])
assert results_dir.exists()

tmp_results_dir = Path(os.environ["TMP_RESULTS_DIR"])
assert tmp_results_dir.exists()

In [None]:
def _fashion_mnist_split(is_train):
    """Load one Fashion-MNIST split as tensors, downloading it when missing."""
    return torchvision.datasets.FashionMNIST(
        str(fmnist_root_dir),
        train=is_train,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )


train_data = _fashion_mnist_split(True)
test_data = _fashion_mnist_split(False)

In [None]:
# `skimage.transform` is not among the top-level skimage imports; import it
# explicitly instead of relying on another submodule importing it.
import skimage.transform

# Compare one test image with a copy downscaled 2x by local averaging.
preview_idx = 6
preview_img = np.asarray(test_data[preview_idx][0][0])
preview_img_down = skimage.transform.downscale_local_mean(preview_img, (2, 2))

fig, (ax_down, ax_full) = plt.subplots(1, 2)
ax_down.imshow(preview_img_down, cmap="gray")
ax_down.set_title(f"2x downscaled {preview_img_down.shape}")
ax_full.imshow(preview_img, cmap="gray")
ax_full.set_title(f"Test image {preview_idx} {preview_img.shape}")
plt.show()

In [None]:
# Shuffled mini-batches over the full training split.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=params.batch_size,
    shuffle=True,
)

In [None]:
def lr_psf_update_coeff(psf, img_source, img_distort, conv_stride=2, eps=1e-5):
    """Richardson-Lucy multiplicative update coefficients for the PSF, batched.

    Forward (distortion) model: a true 2D convolution of each source image with
    the PSF, "valid" (no padding), sampled with stride ``conv_stride``:

        pred = conv(img_source, psf)[::conv_stride, ::conv_stride]

    With ``ratio = img_distort / (pred + eps)``, the coefficient returned for PSF
    location ``p`` of sample ``b`` is

        sum_n ratio[b, n] * img_source[b, conv_stride * n + (K - 1 - p)]

    i.e. the (unnormalized) RL/EM correction factor, which equals the gradient of
    ``sum(ratio * pred)`` with respect to the PSF.

    psf - shape 1 x 1 x kH x kW (single input and output channel).
    img_source - shape batch x 1 x H x W, with H >= kH and W >= kW.
    img_distort - shape batch x 1 x oH x oW, the same shape as the strided
        "valid" convolution of img_source with the PSF.
    conv_stride - stride of the forward convolution.
    eps - small constant added to the prediction to avoid division by zero.

    Returns a tensor of shape batch x 1 x kH x kW with per-sample coefficients.

    Raises ValueError if the PSF or the source images have more than one channel.
    """
    if psf.shape[0] != 1 or psf.shape[1] != 1 or img_source.shape[1] != 1:
        raise ValueError(
            "lr_psf_update_coeff supports single-channel PSFs and images only, got "
            f"psf {tuple(psf.shape)} and img_source {tuple(img_source.shape)}"
        )

    kernel_h, kernel_w = psf.shape[-2:]
    batch_size = img_source.shape[0]

    # F.conv2d computes cross-correlation. A true convolution correlates with the
    # kernel rotated by 180 degrees (flipped along both spatial axes); a transpose
    # is *not* equivalent for non-symmetric kernels.
    flipped_psf = torch.flip(psf, dims=(-2, -1))
    pred = F.conv2d(img_source, flipped_psf, stride=conv_stride)
    actual_vs_pred = img_distort / (pred + eps)

    # Adjoint of the strided correlation: correlate each source image with its own
    # ratio map (one group per sample), dilating the ratio by the forward stride.
    source_as_channels = img_source.reshape(1, batch_size, *img_source.shape[-2:])
    corr = F.conv2d(
        source_as_channels, actual_vs_pred, groups=batch_size, dilation=conv_stride
    )
    # Only offsets inside the kernel footprint contribute to the PSF; any trailing
    # rows/cols come from source pixels the strided forward pass never reaches.
    corr = corr[..., :kernel_h, :kernel_w]
    # Undo the 180-degree rotation applied in the forward model.
    grad = torch.flip(corr, dims=(-2, -1))

    return grad.reshape(batch_size, 1, kernel_h, kernel_w)

In [None]:
# Train & test loop.
# Richardson-Lucy style multiplicative PSF estimation on non-overlapping patches:
# each source patch is "distorted" by 2x area downsampling, and the PSF is updated
# multiplicatively with the coefficients from `lr_psf_update_coeff`.
psf_shape = (3, 3)
init_psf = torch.rand(1, 1, *psf_shape)

# Track the MSE between successive PSF iterates, per epoch.
error_over_epochs = dict()

# NOTE(review): `tol_reached` / `params.rel_tol` are not used; all epochs always run.
tol_reached = False
psf = init_psf
with torch.no_grad():
    for epoch in range(params.num_epochs):
        error_over_epochs[epoch] = list()
        for batch in train_dataloader:
            source_img = batch[0]

            # F.unfold yields (batch, C * kH * kW, n_patches). Move the patch axis
            # next to the batch axis *before* reshaping; viewing the unfold output
            # directly would group pixels from different patches together.
            source_patch = F.unfold(source_img, kernel_size=psf_shape, stride=psf_shape)
            source_patch = source_patch.transpose(1, 2).reshape(-1, 1, *psf_shape)

            distort_patch = F.interpolate(
                source_patch,
                scale_factor=(1 / 2, 1 / 2),
                mode="area",
                recompute_scale_factor=False,
            )
            psf_update = lr_psf_update_coeff(
                psf, source_patch, distort_patch, conv_stride=2
            )
            # Apply every per-patch coefficient multiplicatively (a product over the
            # patch axis, as before).
            # NOTE(review): an unnormalized product over this many patches easily
            # overflows/underflows; averaging the coefficients may be more stable.
            psf_ip1 = psf * psf_update.prod(dim=0, keepdim=True)

            mse = F.mse_loss(psf, psf_ip1).item()
            if not np.isfinite(mse):
                # Show the current (pre-update) PSF before aborting.
                print("PSF_i and PSF_i+1 MSE: ", mse)
                plt.matshow(psf[0, 0].cpu().numpy())
                plt.axis("off")
                plt.colorbar()
                raise RuntimeError(f"ERROR: Invalid MSE value {mse}")
            error_over_epochs[epoch].append(mse)

            psf = psf_ip1

plt.matshow(psf[0, 0].cpu().numpy())
plt.axis("off")
plt.colorbar();

### Learning a Deconvolution Kernel with Autograd, SGD, and PyTorch

In [None]:
# Scratch check (commented out, kept for reference): take two SGD steps on
# loss = sum(x / z) and print how much the first element of `x` changed.
# x = torch.randn(3, 4, 1, requires_grad=True)
# z = torch.randn(3, 4, 1, requires_grad=False)
# sgd = torch.optim.SGD([x], 0.01)
# y = x / z
# old_x = x.clone()
# loss = y.sum()
# sgd.zero_grad()
# loss.backward()
# sgd.step()

# y = x / z
# loss = y.sum()
# sgd.zero_grad()
# loss.backward()
# sgd.step()
# print((x - old_x)[0, 0].item())

In [None]:
# Train & test loop.
# Learn a deconvolution kernel with SGD: downsample each training image 2x
# ("area"), upsample it back with nearest-neighbor interpolation, then fit a kernel
# that maps the blocky upsampled image back to the source under an MSE loss.
psf_shape = (5, 5)
psf = torch.rand(1, 1, *psf_shape, requires_grad=True)
# Snapshot of the initial kernel, detached so it is not part of the autograd graph.
init_psf = psf.detach().clone()
optim = torch.optim.SGD([psf], lr=0.01)
print(psf)

# Zero padding that keeps the spatial size for any kernel size:
# (left, right, top, bottom) as expected by F.pad; k - 1 pixels total per axis.
kernel_h, kernel_w = psf_shape
same_pad = (kernel_w // 2, (kernel_w - 1) // 2, kernel_h // 2, (kernel_h - 1) // 2)

# Track the per-batch training loss, per epoch.
error_over_epochs = dict()

# NOTE(review): unused; no early-stopping criterion is implemented.
tol_reached = False
for epoch in range(params.num_epochs):
    error_over_epochs[epoch] = list()
    for batch in train_dataloader:
        source_img = batch[0]

        distort_img = F.interpolate(
            source_img,
            scale_factor=(1 / 2, 1 / 2),
            mode="area",
            recompute_scale_factor=False,
        )
        padded_distorted = F.interpolate(
            distort_img,
            scale_factor=(2, 2),
            mode="nearest",
            recompute_scale_factor=False,
        )
        recovered = F.conv2d(F.pad(padded_distorted, same_pad), psf)
        loss = F.mse_loss(recovered, source_img)

        # Update the PSF with gradients.
        optim.zero_grad()
        loss.backward()
        optim.step()

        error_over_epochs[epoch].append(loss.item())

print(psf)
print(torch.abs(psf.detach() - init_psf))

In [None]:
# Per-batch training losses across all epochs, in training order.
train_losses = list(itertools.chain.from_iterable(error_over_epochs.values()))

fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set(
    title="Deconvolution kernel training loss",
    xlabel="Batch (all epochs)",
    ylabel="MSE loss",
)
plt.show()
# torch.median returns the lower of the two middle values for an even count.
print(torch.median(torch.as_tensor(train_losses)).item())

In [None]:
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=params.batch_size, shuffle=False
)
viz_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)


def _distort_and_recover(source_img, psf):
    """Simulate the observation of `source_img` and apply the learned kernel.

    The source is downsampled 2x ("area") and upsampled back with nearest-neighbor
    interpolation (the observation). The observation is then zero-padded so the
    kernel keeps its spatial size, and cross-correlated with `psf` (the recovery).

    Returns (observed, recovered).
    """
    distort_img = F.interpolate(
        source_img,
        scale_factor=(1 / 2, 1 / 2),
        mode="area",
        recompute_scale_factor=False,
    )
    observed = F.interpolate(
        distort_img,
        scale_factor=(2, 2),
        mode="nearest",
        recompute_scale_factor=False,
    )
    kernel_h, kernel_w = psf.shape[-2:]
    # (left, right, top, bottom) for F.pad; k - 1 pixels of padding per axis.
    same_pad = (kernel_w // 2, (kernel_w - 1) // 2, kernel_h // 2, (kernel_h - 1) // 2)
    recovered = F.conv2d(F.pad(observed, same_pad), psf)
    return observed, recovered


losses = list()
with torch.no_grad():
    # A few randomly drawn test samples for visualization.
    viz_test_samples = Box(source=list(), observed=list(), pred=list())
    num_viz_samples = 7
    for batch in itertools.islice(viz_loader, num_viz_samples):
        source_img = batch[0]
        observed, recovered = _distort_and_recover(source_img, psf)

        viz_test_samples.source.append(source_img[0])
        viz_test_samples.observed.append(observed[0])
        viz_test_samples.pred.append(recovered[0])

    # Whole test split: accumulate the summed squared error so the final value is
    # the mean over all pixels (a plain mean of per-batch means over-weights a
    # smaller final batch).
    sum_sq_err = 0.0
    num_elements = 0
    for batch in test_loader:
        source_img = batch[0]
        _, recovered = _distort_and_recover(source_img, psf)
        batch_sq_err = F.mse_loss(recovered, source_img, reduction="sum")

        # Per-batch mean loss, kept for inspection.
        losses.append(batch_sq_err / source_img.numel())
        sum_sq_err += batch_sq_err.item()
        num_elements += source_img.numel()

total_loss = torch.as_tensor(sum_sq_err / num_elements if num_elements else float("nan"))
print(total_loss.item())

In [None]:
# Grid rows: source images, observed (2x down/up-sampled) images, kernel outputs;
# one column per visualized sample.
viz_grid = torchvision.utils.make_grid(
    [*viz_test_samples.source, *viz_test_samples.observed, *viz_test_samples.pred],
    nrow=len(viz_test_samples.source),
)
fig, ax = plt.subplots(dpi=120)
# The grid is channels-first; imshow wants channels-last. Predictions can fall
# outside [0, 1], so clamp explicitly instead of relying on imshow's clipping.
ax.imshow(viz_grid.permute(1, 2, 0).clamp(0, 1).cpu().numpy(), cmap="gray")
ax.axis("off")
plt.show()

---

In [None]:
# FFT-based "same"-size filtering of a test image with a 21x21 box (mean) kernel.
from scipy import fftpack


def _load_face():
    """Return SciPy's sample "face" image.

    `scipy.misc.face` is deprecated since SciPy 1.10 in favor of
    `scipy.datasets.face` (which needs the optional `pooch` package) and is removed
    in SciPy 1.12, so try the new location first.
    """
    try:
        from scipy import datasets

        return datasets.face()
    except ImportError:
        from scipy import misc

        return misc.face()


img = _load_face()[:, :, 0]
print(img.shape)
box_size = 21
kernel = np.ones((box_size, box_size)) / box_size ** 2
print(kernel.shape)
# Total amount of zero padding needed to grow the kernel to the image size.
sz = (
    img.shape[0] - kernel.shape[0],
    img.shape[1] - kernel.shape[1],
)
print(sz)
# Split the padding so that (for an odd-sized kernel) the kernel center lands at
# index n // 2 along each axis...
kernel = np.pad(
    kernel, (((sz[0] + 1) // 2, sz[0] // 2), ((sz[1] + 1) // 2, sz[1] // 2)), "constant"
)
print(kernel.shape)
# ...which ifftshift moves to the origin, so multiplying spectra gives a centered
# (circular) convolution.
kernel = fftpack.ifftshift(kernel)
print(kernel.shape)
filtered = np.real(fftpack.ifft2(fftpack.fft2(img) * fftpack.fft2(kernel)))
plt.imshow(filtered, vmin=0, vmax=255)
plt.show()
plt.imshow(img)
plt.show()

In [None]:
size = 256
# Parseval's theorem: for unit-variance white noise the mean (unnormalized) FFT
# power per bin should be close to the number of pixels, printed for comparison.
num_pixels = size ** 2
print(num_pixels)
noise = np.random.randn(size, size)
N = fftpack.fft2(noise)
power_spectrum = np.abs(N) ** 2

# Show the noise field, then its power spectrum.
for field in (noise, power_spectrum):
    plt.matshow(field)
    plt.colorbar()
    plt.show()

power_spectrum.mean()

In [None]:
# FFT magnitude computed two ways (np.abs vs. sqrt(re^2 + im^2)); they should
# agree up to floating-point rounding.
magnitude_abs = np.abs(N)
magnitude_manual = np.sqrt(N.real ** 2 + N.imag ** 2)
print(magnitude_abs)
print(magnitude_manual)