# Implementation of the invertible Learned Primal-Dual (iLPD) method

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
np.random.seed(42)
import torch
from torch import nn
import torch.optim
import tensorboardX
import memcnn

import odl
from odl.contrib import fom
from odl.contrib import torch as odl_torch

from mayo_generator import generate_mayo

# cuDNN settings: use cuDNN and let it benchmark convolution algorithms.
# Benchmark mode generally pays off when input sizes stay fixed between calls.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type that arrays/modules are converted to via ``.type(dtype)``:
# 32-bit float tensors on the CUDA device selected above.
dtype = torch.cuda.FloatTensor

# Run name: appears in the tensorboard writer comments and is the file path
# the training loop saves checkpoints to.
name = 'invertible'

### Create ODL data structures to define the Ray Transform operator

In [None]:
# image resolution: size x size pixels
size = 512
# image domain is the square [-xlim, xlim]^2
xlim = 128
space = odl.uniform_discr([-xlim]*2, [xlim]*2, [size]*2, dtype='float32')
# 1000 projection angles over a full turn and 1000 detector cells on [-360, 360]
angle_partition = odl.uniform_partition(0, 2 * np.pi, 1000)
detector_partition = odl.uniform_partition(-360, 360, 1000)
geometry = odl.tomo.FanBeamGeometry(angle_partition, detector_partition,src_radius=500, det_radius=500)
operator = odl.tomo.RayTransform(space, geometry)

# wrap the operator and its adjoint as PyTorch modules
pt_op = odl_torch.OperatorModule(operator)
pt_op_adj = odl_torch.OperatorModule(operator.adjoint)

# Estimated operator norm. Outputs of the operator and its adjoint (and the
# measured data) are divided by it for numerical stability.
opnorm = operator.norm(estimate=True)
print("Norm of the operator:", opnorm) 

### Load the data

In [None]:
ppp = 10000 # number of photons per pixel determines the amount of noise in the projection data

# Create the test and validation sets once and cache them on disk; later runs
# load the cached arrays. A cache only counts as present when BOTH the image
# file and the data file exist. Otherwise an interrupted first run (image saved,
# data missing) would never be regenerated, and the later np.load would fail.
if not (os.path.exists('test_image_full_noisy.npy')
        and os.path.exists('test_data_full_noisy.npy')):
    test_image, test_data = next(generate_mayo(operator, 'test', 210, photons_per_pixel=ppp))
    np.save("test_image_full_noisy.npy", test_image)
    np.save("test_data_full_noisy.npy", test_data)

# passed positionally to generate_mayo here and in the training cell
val_ration = 0.01
if not (os.path.exists('val_image_full_noisy.npy')
        and os.path.exists('val_data_full_noisy.npy')):
    val_image, val_data = next(generate_mayo(operator, 'validate', 21, val_ration, photons_per_pixel=ppp))
    np.save("val_image_full_noisy.npy", val_image)
    np.save("val_data_full_noisy.npy", val_data)

# validation set as GPU float tensors
val_image = np.load("val_image_full_noisy.npy")
val_data = np.load("val_data_full_noisy.npy")
val_image_pt = torch.from_numpy(val_image).type(dtype)
val_data_pt = torch.from_numpy(val_data).type(dtype)

## Definition of the iLPD architecture

In [None]:
n_data = 1 # batch size (also the batch size requested from the training generator)
n_iter = 20 # number of unrolled iterations
# number of primal and dual memory channels
# (the primal/dual updates in each iteration output this many channels)
n_primal = 1 
n_dual = 1


class Iteration(nn.Module):
    """One unrolled, invertible primal-dual step of the iLPD network.

    ``forward`` applies two additive (residual) updates, dual first:

        dual   <- dual   + dualblock([A(primal) / opnorm, y / opnorm])
        primal <- primal + primalblock(A^T(dual) / opnorm)

    Here ``A`` is ``pt_op`` and ``A^T`` is ``pt_op_adj``. ``inverse`` subtracts
    the same updates in the reverse order, so it exactly undoes ``forward``.
    The measured data ``y`` is passed through unchanged in both directions.
    """

    def __init__(self):
        super().__init__()
        # forward / adjoint ray-transform modules (module-level globals)
        self.op = pt_op
        self.op_adj = pt_op_adj
        self.filters = 32

        # image-space update: n_primal channels in and out
        self.primalblock = self._conv_stack(n_primal, n_primal, self.filters)
        # data-space update: dual memory + measured data stacked as channels
        self.dualblock = self._conv_stack(n_dual + 1, n_dual, self.filters)

    @staticmethod
    def _conv_stack(in_channels, out_channels, width):
        """Build conv3x3 -> PReLU -> conv3x3 -> PReLU -> conv3x3 (padding keeps spatial size)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, width, 3, padding=1),
            nn.PReLU(num_parameters=width, init=0.0),
            nn.Conv2d(width, width, 3, padding=1),
            nn.PReLU(num_parameters=width, init=0.0),
            nn.Conv2d(width, out_channels, 3, padding=1))

    def _dual_input(self, primal, y):
        """Normalised forward projection of ``primal`` stacked with normalised ``y``."""
        return torch.cat([self.op(primal) / opnorm, y / opnorm], dim=1)

    def _primal_input(self, dual):
        """Normalised back-projection of the dual variable."""
        return self.op_adj(dual) / opnorm

    def forward(self, primal, dual, y):
        dual = dual + self.dualblock(self._dual_input(primal, y))
        primal = primal + self.primalblock(self._primal_input(dual))
        return primal, dual, y

    def inverse(self, primal, dual, y):
        primal = primal - self.primalblock(self._primal_input(dual))
        dual = dual - self.dualblock(self._dual_input(primal, y))
        return primal, dual, y
    
class IterativeNetwork(nn.Module):
    """Unrolled iLPD reconstruction network built from ``n_iter`` invertible iterations.

    Iteration ``i`` is an ``Iteration`` wrapped in
    ``memcnn.InvertibleModuleWrapper`` and stored as attribute ``iteration_<i>``.
    These attribute names become the ``state_dict`` keys of saved checkpoints,
    so they are kept stable.
    """

    def __init__(self):
        super().__init__()
        self.loss = torch.nn.MSELoss()

        for i in range(n_iter):
            iteration = Iteration()
            # NOTE(review): keep_input=True keeps the inputs alive. This is presumably
            # needed because y is returned unchanged and shared by all iterations
            # -- confirm before changing.
            inv_iteration = memcnn.InvertibleModuleWrapper(fn=iteration,
                                                           keep_input=True,
                                                           keep_input_inverse=True)
            setattr(self, 'iteration_{}'.format(i), inv_iteration)

    def _iterations(self):
        """Yield the wrapped iterations in the order they are applied."""
        for i in range(n_iter):
            yield getattr(self, 'iteration_{}'.format(i))

    def forward(self, y, true):
        """Reconstruct from data ``y`` and compute the MSE against ``true``.

        Args:
            y: measured data batch; its batch size and trailing dimensions size
                the dual variable.
            true: ground-truth image batch; its trailing dimensions size the
                primal variable, and it is the target of the MSE loss.

        Returns:
            ``(res, loss)``: the first primal channel and ``MSELoss(res, true)``.
        """
        # Take the batch size from y instead of the global n_data. Validation and
        # testing call the network with batches of 1, which would otherwise break
        # as soon as n_data != 1. Allocating with new_zeros places the zeros on
        # y's device with y's dtype.
        batch = y.shape[0]
        primal = y.new_zeros((batch, n_primal) + tuple(true.shape[2:]))
        dual = y.new_zeros((batch, n_dual) + tuple(y.shape[2:]))

        for iteration in self._iterations():
            primal, dual, y = iteration(primal, dual, y)

        res = primal[:, 0:1, ...]
        loss = self.loss(res, true)
        return res, loss

    def inv_module_eval(self):
        """Put every wrapped iteration in eval mode with ``num_bwd_passes = 0``."""
        for iteration in self._iterations():
            iteration.num_bwd_passes = 0 # number of backward passes
            iteration.eval()

    def inv_module_train(self):
        """Put every wrapped iteration in train mode with ``num_bwd_passes = 1``."""
        for iteration in self._iterations():
            iteration.num_bwd_passes = 1 # number of backward passes
            iteration.train()

# Xavier (Glorot) uniform initialisation of the convolution weights; the
# training depends strongly on this initialisation.
def weights_init(m):
    """Xavier/Glorot-uniform initialisation for ``Conv2d`` layers.

    Intended for ``module.apply(weights_init)``. For a weight of shape
    (out, in, kh, kw), values are drawn from U(-lim, lim) with
    lim = sqrt(6 / ((in + out) * kh * kw)) = sqrt(6 / (fan_in + fan_out)).
    Biases, when present, are set to zero. Modules whose class name does not
    contain 'Conv2d' are left untouched.

    Args:
        m: any ``nn.Module``.
    """
    if 'Conv2d' not in m.__class__.__name__:
        return

    shape = m.weight.shape
    lim = np.sqrt(6/(shape[0] + shape[1])/shape[2]/shape[3])
    m.weight.data.uniform_(-lim, lim)
    # Conv2d(..., bias=False) has bias=None; there is nothing to zero then.
    if m.bias is not None:
        m.bias.data.fill_(0)
        
# Build the network, cast its parameters to ``dtype``, then run weights_init
# over every sub-module.
iter_net = IterativeNetwork()
iter_net = iter_net.type(dtype)
iter_net.apply(weights_init)

### Define summary functions for viewing logs in TensorBoard

In [None]:
def summary_image(writer, name, image, it):
    """Log channel 0 of batch element 0 as a grayscale image.

    The plane is min-max rescaled to approximately [0, 1]. The 1e-5 term in the
    denominator avoids division by zero for a constant image. The result is
    written with ``writer.add_image(name, ..., it, dataformats='HW')``.
    """
    plane = image[0, 0]
    lo = torch.min(plane)
    span = torch.max(plane) - lo + 1e-5
    writer.add_image(name, (plane - lo) / span, it, dataformats='HW')
        
def summaries(writer, result, true, loss, it, do_print=False):
    """Log loss, PSNR and a preview image of ``result`` at step ``it``.

    PSNR uses the dynamic range of ``true`` (max - min) as the peak value:
    ``20*log10(range) - 10*log10(mse)``. When ``do_print`` is set, the step,
    the MSE and the PSNR are also printed.
    """
    mse = torch.mean((result - true) ** 2)
    data_range = torch.max(true) - torch.min(true)
    psnr = 20 * torch.log10(data_range) - 10 * torch.log10(mse)

    if do_print:
        print(it, mse.item(), psnr.item())

    for tag, value in (('loss', loss), ('psnr', psnr)):
        writer.add_scalar(tag, value, it)

    summary_image(writer, 'result', result, it)

# one writer for training summaries, one for validation summaries
train_writer = tensorboardX.SummaryWriter(comment="/train_{}".format(name))
test_writer = tensorboardX.SummaryWriter(comment="/test_{}".format(name))

## Training

In [None]:
maximum_steps = 100001
starter_learning_rate = 0.5 * 1e-3
optimizer = torch.optim.Adam(iter_net.parameters(), lr=starter_learning_rate, betas=(0.9, 0.99))
# cosine decay of the learning rate over the whole run (stepped once per iteration)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, maximum_steps)

train_generator = generate_mayo(operator, 'train', n_data, val_ration, photons_per_pixel=ppp)

for i in range(maximum_steps):
    iter_net.train()
    iter_net.inv_module_train()
    # draw a fresh training batch every 10 steps; the same batch is reused in between
    if i % 10 == 0:
        x_true, y = next(train_generator)
        x_true_pt = torch.from_numpy(x_true).type(dtype)
        y_pt = torch.from_numpy(y).type(dtype)

    optimizer.zero_grad()
    output, loss = iter_net(y_pt, x_true_pt)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(iter_net.parameters(), max_norm=1.0, norm_type=2)
    optimizer.step()
    scheduler.step()

    if i % 10 == 0:
        summaries(train_writer, output, x_true_pt, loss, i, do_print=False)
    if i % 100 == 0:
        iter_net.eval()
        iter_net.inv_module_eval()
        # Validation needs no gradients. Without no_grad, autograd would record a
        # graph through all unrolled iterations only to throw it away.
        with torch.no_grad():
            output_test, loss_test = iter_net(val_data_pt[0:1], val_image_pt[0:1])
        summaries(test_writer, output_test, val_image_pt[0:1], loss_test, i, do_print=True)

    # periodic checkpoint to the file named by `name`
    if i > 0 and i % 1000 == 0:
        torch.save(iter_net.state_dict(), name)

## Testing

In [None]:
# restore the saved checkpoint and switch the network to inference mode
iter_net.load_state_dict(torch.load(name))
iter_net.eval()
iter_net.inv_module_eval()

# test slices cached on disk by the data-loading cell, converted to `dtype` tensors
test_image = np.load("test_image_full_noisy.npy")
test_data = np.load("test_data_full_noisy.npy")
test_image_pt, test_data_pt = (torch.from_numpy(arr).type(dtype)
                               for arr in (test_image, test_data))

def evaluate(test_image, test_data, plot=False, clim=(0.8, 1.2)):
    """Reconstruct one batch with the trained ``iter_net`` and optionally plot it.

    Args:
        test_image: ground-truth image tensor; passed to ``iter_net`` (which
            uses it for the output shape and its loss) and shown when plotting.
        test_data: measured data tensor to reconstruct from.
        plot: if True, show the ground truth and the reconstruction side by side.
        clim: ``(vmin, vmax)`` display window for ``imshow``. Defaults to the
            previously hard-coded window.

    Returns:
        The reconstruction as a NumPy array.
    """
    # Inference only: skip building an autograd graph through the unrolled iterations.
    with torch.no_grad():
        result, _ = iter_net(test_data, test_image)
    result = result.detach().cpu().numpy()
    true = test_image.detach().cpu().numpy()

    if plot:
        figsize = 10
        _fig, row = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(2*figsize, figsize))
        for title, img, ax in zip(["True", "iLPD"], [true, result], row):
            ax.set_title(title)
            ax.imshow(img.squeeze(), clim=list(clim), cmap="bone")
            ax.set_axis_off()
        plt.show()

    return result

# Reconstruct the test set one slice at a time (batches of 1), then join the
# per-slice results along the batch axis so `res` lines up with `test_image`.
res = np.concatenate(
    [evaluate(test_image_pt[k:k + 1], test_data_pt[k:k + 1])
     for k in range(test_image_pt.shape[0])],
    axis=0)
l1 = fom.mean_squared_error(res, test_image)
l2 = fom.psnr(res, test_image)
l3 = fom.ssim(res, test_image)
print('Final result. Mean squared error: {}, PSNR: {}, SSIM: {}'.format(l1, l2, l3))