## Deconvolution with Unrolled ADMM

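This notebook is a tutorial on deconvolving galaxy images with the proposed unrolled ADMM model. It loads an observed galaxy image, its PSF, and the ground truth. It then runs a pretrained Poisson plug-and-play (PnP) unrolled ADMM network and plots the deconvolved result next to the inputs.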

In [None]:
import sys
sys.path.append('../')
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift
import galsim
from models.Unrolled_ADMM import Unrolled_ADMM
from utils.utils import PSNR, estimate_shear

%matplotlib inline

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')  # first GPU if present, else CPU

## Read in Images

In [None]:
# Load the PSF, the observed (blurred, noisy) galaxy and the ground-truth galaxy.
# Each file is assumed to hold a single 2-D (H, W) image tensor -- TODO confirm.
# The images are kept 2-D here: the deconvolution step adds its own batch/channel
# dims, and plt.imshow needs 2-D arrays. An extra .unsqueeze(0) would make both fail.
# map_location='cpu' lets tensors saved from a GPU session load on a CPU-only machine.
psf = torch.load("psf_23.5_0.pth", map_location='cpu')
obs = torch.load("obs_23.5_0.pth", map_location='cpu')
gt = torch.load("gt_23.5_0.pth", map_location='cpu')

# Min-max normalize the observation and the ground truth to [0, 1].
obs = (obs - obs.min()) / (obs.max() - obs.min())
gt = (gt - gt.min()) / (gt.max() - gt.min())

## Load in Trained Model

In [None]:
# Build the Poisson plug-and-play unrolled ADMM network and restore its trained weights.
n_iters = 8  # number of unrolled ADMM iterations; must match the checkpoint name below
model_file = f'../saved_models/Poisson_PnP_{n_iters}iters_LSST23.5_50epochs.pth'

model = Unrolled_ADMM(n_iters=n_iters, llh='Poisson', PnP=True)
model.load_state_dict(torch.load(model_file, map_location=device))
model = model.to(device).eval()  # move to the target device and switch to inference mode
print(f'Successfully loaded in {model_file}.')

## Deconvolution

In [None]:
# Calculate the average photon level (mean pixel value of the normalized observation).
# It is reshaped to (1, 1, 1, 1) to match the 4-D inputs built below.
alpha = obs.mean().float().view(1, 1, 1, 1)

# The model is fed 4-D inputs (presumably (batch, channel, H, W) -- confirm against
# Unrolled_ADMM). Reshape from the last two (spatial) dims so this works whether the
# images were loaded as (H, W) or (1, H, W). Unsqueezing twice on a (1, H, W) tensor
# would produce a 5-D tensor instead.
obs_batch = obs.reshape(1, 1, *obs.shape[-2:]).to(device)
psf_batch = psf.reshape(1, 1, *psf.shape[-2:]).to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model(obs_batch, psf_batch, alpha.to(device))

# Rescale the output by alpha, then drop the batch/channel dims to get a 2-D numpy image.
rec = (output.cpu() * alpha.cpu()).squeeze(dim=0).squeeze(dim=0).numpy()

## Visualization

In [None]:
# Show the observation, PSF, ground truth and reconstruction side by side in a 2x2 grid.
cmap = 'magma'
panels = [
    (obs, 'Observed Galaxy'),
    (psf, 'PSF'),
    (gt, 'Ground Truth'),
    (rec, 'Deconvolved Galaxy'),
]

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax, (img, title) in zip(axes.ravel(), panels):
    # imshow only accepts (M, N) / (M, N, 3|4) arrays, so drop any singleton
    # batch/channel dims (e.g. a (1, H, W) tensor) before plotting.
    ax.imshow(np.squeeze(np.asarray(img)), cmap=cmap)
    ax.set_title(title)
plt.show()