### PnP ADMM

In [64]:
import numpy as np
from Unet import Unet
from pnp import pnp_admm
from utils_pnp import conv2d_from_kernel, compute_psnr, ImagenetDataset, myplot
import torch
from torch import nn
from torch.utils.data import DataLoader
from dataset import Galaxy_Dataset
import matplotlib.pyplot as plt

%matplotlib inline

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

### Convolution operators and PnP-ADMM module

In [65]:
def conv2d_from_kernel(kernel, channels, device, stride=1):
    """
    Return ``nn.Conv2d`` / ``nn.ConvTranspose2d`` modules built from a 2D kernel.

    The kernel is normalised to unit sum and applied depthwise (one copy per
    channel, ``groups=channels``). The transposed convolution uses the same
    weights, so it is the adjoint of the convolution (exactly so for
    ``stride=1``; for ``stride>1`` the transposed output can be smaller than
    the original input when ``H - kH`` is not a multiple of ``stride``).
    No padding is used, so the forward operator is a "valid" convolution:
    for ``stride=1`` an ``(H, W)`` input maps to ``(H - kH + 1, W - kW + 1)``.

    Note: this definition shadows the ``conv2d_from_kernel`` imported from
    ``utils_pnp`` at the top of the notebook.

    Args:
        kernel: 2D kernel (tensor or array-like). Leading singleton axes,
            e.g. shape ``(1, kH, kW)`` or ``(1, 1, kH, kW)``, are dropped.
        channels: number of image channels.
        device: device the returned modules are moved to.
        stride: convolution stride shared by both modules.

    Returns:
        ``(conv, conv_adjoint)`` whose weights have ``requires_grad=False``.

    Raises:
        ValueError: if the kernel is not 2D after dropping singleton leading
            axes, or if its entries sum to zero (it cannot be normalised).
    """
    kernel = torch.as_tensor(kernel)
    # Accept PSFs that carry singleton batch/channel axes, e.g. (1, kH, kW).
    while kernel.dim() > 2 and kernel.shape[0] == 1:
        kernel = kernel[0]
    if kernel.dim() != 2:
        raise ValueError(
            f"expected a 2D kernel, got shape {tuple(kernel.shape)}"
        )
    total = kernel.sum()
    if total == 0:
        raise ValueError("kernel entries sum to zero; cannot normalise")

    kernel_size = tuple(kernel.shape)
    # Depthwise layout (channels, 1, kH, kW): this is the weight shape of both
    # Conv2d and ConvTranspose2d when in_channels == out_channels == groups.
    weight = (kernel / total).repeat(channels, 1, 1, 1)

    conv = nn.Conv2d(
        in_channels=channels, out_channels=channels,
        kernel_size=kernel_size, groups=channels, bias=False, stride=stride,
    )
    # Frozen weights: the operator is fixed, not learned.
    conv.weight = nn.Parameter(weight.clone(), requires_grad=False)

    conv_adjoint = nn.ConvTranspose2d(
        in_channels=channels, out_channels=channels,
        kernel_size=kernel_size, groups=channels, bias=False, stride=stride,
    )
    conv_adjoint.weight = nn.Parameter(weight.clone(), requires_grad=False)

    return conv.to(device), conv_adjoint.to(device)

class PnP_ADMM(nn.Module):
    """
    Plug-and-Play ADMM deconvolution with a user-supplied denoiser.

    Approximately solves ``min_x 0.5 * ||A x - y||^2 + R(x)``. Here ``A`` is a
    depthwise "valid" convolution with the unit-sum version of ``kernel``, and
    the proximal step of the implicit prior ``R`` is replaced by
    ``denoiser``. Each iteration performs::

        x <- argmin_x 0.5||A x - y||^2 + (rho/2)||x - (v - u)||^2   # CG solve
        v <- denoiser(x + u)
        u <- u + x - v

    and the final ``v`` is returned.

    Args:
        n_iter: number of ADMM iterations.
        denoiser: callable (e.g. an ``nn.Module``) mapping an image tensor to
            a denoised tensor. Optional at construction so ``PnP_ADMM(n_iter)``
            still works, but it must be set (``module.denoiser = ...``)
            before ``forward`` is called.
        rho: ADMM penalty parameter; must be > 0.
        max_cgiter: maximum conjugate-gradient iterations per x-update.
        cg_tol: relative residual tolerance (||r|| <= cg_tol * ||b||) for CG.
    """

    def __init__(self, n_iter, denoiser=None, rho=1.0, max_cgiter=100,
                 cg_tol=1e-6):
        super(PnP_ADMM, self).__init__()
        if rho <= 0:
            raise ValueError(f"rho must be positive, got {rho}")
        self.n_iter = n_iter
        self.denoiser = denoiser
        self.rho = rho
        self.max_cgiter = max_cgiter
        self.cg_tol = cg_tol

    def forward(self, x, kernel):
        """
        Reconstruct an image from the blurred observation ``x``.

        Args:
            x: observation ``y`` of shape (N, C, H, W).
            kernel: 2D blur kernel; leading singleton axes are allowed.

        Returns:
            The final denoised iterate ``v``. Because ``A`` is a valid
            convolution, the x-iterates have shape (N, C, H + kH - 1,
            W + kW - 1); this assumes ``denoiser`` preserves shape.

        Raises:
            ValueError: if no denoiser is set, ``x`` is not 4D, or the kernel
                is invalid.
        """
        if self.denoiser is None:
            raise ValueError(
                "PnP_ADMM needs a denoiser; pass one to __init__ or set .denoiser"
            )
        y = x
        if y.dim() != 4:
            raise ValueError(
                f"expected observation of shape (N, C, H, W), got {tuple(y.shape)}"
            )
        channels = y.shape[1]
        weight = self._depthwise_weight(kernel, channels, y.dtype, y.device)
        rho = self.rho

        def apply_a(z):
            # Forward blur A (valid depthwise convolution).
            return nn.functional.conv2d(z, weight, groups=channels)

        def apply_at(z):
            # Adjoint A^T: transposed convolution with the same weights.
            return nn.functional.conv_transpose2d(z, weight, groups=channels)

        def normal_op(z):
            # SPD operator of the x-update: (A^T A + rho I) z.
            return apply_at(apply_a(z)) + rho * z

        aty = apply_at(y)
        x = torch.zeros_like(aty)
        u = torch.zeros_like(aty)
        v = torch.zeros_like(aty)
        for _ in range(self.n_iter):
            b = aty + rho * (v - u)
            # Warm-start CG from the previous x.
            x = self._conjugate_gradient(normal_op, b, x, self.max_cgiter,
                                         self.cg_tol)
            v = self.denoiser(x + u)
            # Out-of-place update so no tensor handed to the denoiser is mutated.
            u = u + (x - v)
        return v

    @staticmethod
    def _depthwise_weight(kernel, channels, dtype, device):
        """Return the unit-sum kernel as a (channels, 1, kH, kW) weight."""
        kernel = torch.as_tensor(kernel, dtype=dtype, device=device)
        while kernel.dim() > 2 and kernel.shape[0] == 1:
            kernel = kernel[0]
        if kernel.dim() != 2:
            raise ValueError(
                f"expected a 2D kernel, got shape {tuple(kernel.shape)}"
            )
        total = kernel.sum()
        if total == 0:
            raise ValueError("kernel entries sum to zero; cannot normalise")
        return (kernel / total).repeat(channels, 1, 1, 1)

    @staticmethod
    def _conjugate_gradient(apply_op, b, x0, max_iter, tol):
        """Solve ``apply_op(x) = b`` for an SPD ``apply_op`` by conjugate gradient.

        All tensor elements are treated as one vector. Stops after
        ``max_iter`` steps or once ``||r|| <= tol * ||b||``.
        """
        x = x0
        r = b - apply_op(x)
        p = r
        rs_old = torch.sum(r * r)
        threshold = (tol ** 2) * torch.sum(b * b)
        for _ in range(max_iter):
            if rs_old <= threshold:
                break
            ap = apply_op(p)
            alpha = rs_old / torch.sum(p * ap)
            x = x + alpha * p
            r = r - alpha * ap
            rs_new = torch.sum(r * r)
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x


### Deconvolution

In [66]:
import os

# Project root. Set GALAXY_DECONV_ROOT to point at your own checkout instead of
# editing this cell; the default reproduces the original hard-coded paths.
DATA_ROOT = os.environ.get('GALAXY_DECONV_ROOT',
                           '/Users/luke/Desktop/Galaxy-Deconvolution')
SAMPLE_INDEX = 4  # dataset item to reconstruct
NUM_ITER = 50     # number of PnP-ADMM iterations
CHANNELS = 3      # the Unet denoiser below is built for 3-channel input

# os.path.join(..., '') keeps the trailing separator the original strings had.
Dataset = Galaxy_Dataset(data_path=os.path.join(DATA_ROOT, 'dataset', ''),
                         COSMOS_path=os.path.join(DATA_ROOT, 'data', ''))
Data_loader = DataLoader(Dataset, batch_size=1, shuffle=False)  # not used in this cell
(obs, psf, M), gt = Dataset[SAMPLE_INDEX]

# Replicate the single-channel observation across the denoiser's input channels.
test_image = obs.repeat(1, CHANNELS, 1, 1).to(device)
forward, forward_adjoint = conv2d_from_kernel(psf, CHANNELS, device)

model = Unet(CHANNELS, CHANNELS, chans=64).to(device)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True for a plain state dict).
model.load_state_dict(torch.load('denoiser.pth', map_location=device))
print('#Parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))

# Run plug and play. The observation is used directly as y rather than
# re-blurring with forward(test_image) -- presumably obs is already the
# PSF-blurred measurement; confirm against Galaxy_Dataset.
y = test_image
with torch.no_grad():
    model.eval()
    x = pnp_admm(y, forward, forward_adjoint, model, num_iter=NUM_ITER)
    x = x.clip(0, 1)

# Plot observation, reconstruction (first channel) and ground truth side by side.
panels = [
    ('Observation', obs.squeeze(dim=0).squeeze(dim=0)),
    ('Reconstruction', x.squeeze(dim=0)[0, :, :]),
    ('Ground Truth', gt.squeeze(dim=0).squeeze(dim=0)),
]
plt.figure(figsize=(13, 6))
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    plt.imshow(image.cpu())
plt.show()

#Parameters: 31025027


KeyboardInterrupt: 