# **Set up the environment**

In [None]:
!pip install pytorch_lightning
!pip install kornia

In [1]:
import os

# Make CUDA kernel launches synchronous so errors surface at the failing call
# (a debugging aid that slows down GPU execution).
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

import numpy as np

from math import *

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

from pytorch_lightning import LightningDataModule, LightningModule, Callback, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from torchvision import models
from torchvision.transforms import Resize
from torchvision.utils import save_image

from kornia.losses import ssim_loss, psnr_loss

from matplotlib import pyplot as plt
from PIL import Image

from tqdm import tqdm

  from pandas import MultiIndex, Int64Index


In [2]:
# Set to True on Google Colab: the path cell then mounts Google Drive and reads the data files from it.
USE_COLAB: bool = False

In [3]:
# --- Hardware resources ---
AVAIL_GPUS = torch.cuda.device_count()  # already >= 0, no clamping needed
BATCH_SIZE = 1
# os.cpu_count() returns None when the count cannot be determined; fall back to 1 worker.
NUM_WORKERS = os.cpu_count() or 1

# Image size that we are going to use
IMG_SIZE = 256

# Sinogram geometry. IMG_SIZE * sqrt(2) is the length of the image diagonal.
N_THETA = 180
N_RHO = ceil(IMG_SIZE * sqrt(2))
NB_PROJECTIONS = N_THETA * N_RHO

# Our images are grayscale (1 channel)
N_CHANNELS = 1

# Number of unrolled QNM-RNN cells
N_ITER = 20

# --- Output locations (all directory paths end with "/" because later code concatenates file names) ---
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "models/LQNM_RNN/")
SAVE_NAME = f"LQNM_RNN_{IMG_SIZE}_N_ITER_{N_ITER}/"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_PATH, SAVE_NAME)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

SAVE_VALID = f"Results/validation/{N_ITER}_ITER/"
SAVE_TEST = f"Results/test/LQNM_RNN_{IMG_SIZE}_N_ITER_{N_ITER}/"

RESNET18_MODEL_PATH = "models/resnet18.pth"

# --- Input files (directory prefix is added in the next cell) ---
PATH_DATASETS = f"Lung-CT-sinogram-data-lack-{IMG_SIZE}.npy"
PATH_SYSTEM_MATRIX = f"System_matrix_{IMG_SIZE}.npy"

print("Num Available GPUs: ", AVAIL_GPUS)
print("Num Available workerss : ", NUM_WORKERS)

Num Available GPUs:  4
Num Available workerss :  48


In [4]:
# Resolve the data file locations: Google Drive on Colab, ./data/ otherwise.
# Only the base file name is kept before joining, so re-running this cell does not
# stack prefixes (e.g. "data/data/...").
if USE_COLAB:
    from google.colab import drive
    drive.mount("/content/gdrive")
    DATA_ROOT = "gdrive/MyDrive/"
    # NOTE: no GIF_TEST path is defined in the configuration cell, so nothing else is remapped here.
else:
    DATA_ROOT = "data/"

PATH_DATASETS = os.path.join(DATA_ROOT, os.path.basename(PATH_DATASETS))
PATH_SYSTEM_MATRIX = os.path.join(DATA_ROOT, os.path.basename(PATH_SYSTEM_MATRIX))

print("Data set file : ", PATH_DATASETS)
print("System matrix file : ", PATH_SYSTEM_MATRIX)

Data set file :  data/Lung-CT-sinogram-data-lack-256.npy
System matrix file :  data/System_matrix_256.npy


# **Utility function**

In [5]:
def display_func(display_list, save=False, epoch=0):
    """Plot sinogram / target / reconstruction side by side and optionally save the reconstruction.

    Args:
        display_list: ``[angle, step, sinogram, target, prediction]``. The first two
            entries are tensors used only in titles and file names; every entry from
            index 2 on is an image plotted in its own column. ``prediction`` is optional.
        save: if True, write ``display_list[4]`` as a PNG under ``SAVE_VALID``.
        epoch: epoch number embedded in the saved file name.

    Raises:
        ValueError: if ``save`` is True but no prediction (index 4) is provided.
    """
    has_prediction = len(display_list) > 4
    if save and not has_prediction:
        raise ValueError("save=True requires a predicted image at display_list[4]")

    theta = torch.squeeze(display_list[0]).cpu().numpy()
    step = torch.squeeze(display_list[1]).cpu().numpy()

    input_title = "Input Sinogram : " + str(theta) + " Angle range projection, " + str(step) + " Step projection"
    predicted = "Predicted Object"

    if has_prediction:
        target = display_list[3].cpu()
        prediction = display_list[4].cpu()
        # kornia's psnr_loss is the negated PSNR and ssim_loss is (1 - SSIM) / 2; invert both.
        psnr_p = -psnr_loss(target, prediction, 1)
        ssim_p = -2 * ssim_loss(target, prediction, 5) + 1
        predicted = predicted + ", PSNR=" + ('%.2f' % psnr_p) + ", SSIM=" + ('%.2f' % ssim_p)

    title = [input_title, 'True Mask', predicted]
    images = display_list[2:]

    plt.figure(figsize=(25, 25))
    # Exactly one column per image, so no empty panels are drawn.
    for col, image in enumerate(images):
        plt.subplot(1, len(images), col + 1)
        plt.title(title[col])
        plt.imshow(torch.squeeze(image).cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()

    if save:
        # save_image does not create parent directories.
        os.makedirs(SAVE_VALID, exist_ok=True)
        r_im = torch.squeeze(prediction)
        save_image(r_im, SAVE_VALID + "epoch_" + str(epoch) + "_angle=" + str(theta) + "_step=" + str(step)
                   + "_psnr=" + str(psnr_p.numpy()) + "_ssim=" + str(ssim_p.numpy()) + ".png")

# **Load the dataset**

In [None]:
# load the numpy data array
with open(PATH_DATASETS, 'rb') as f:
    X_dataset = np.load(f)
    Y_dataset = np.load(f)
    angles_dataset = np.load(f)
    steps_dataset = np.load(f)

In [None]:
# Exploratory 95% / 5% train / validation split of the whole dataset.
DATASET_SIZE = len(X_dataset)

train_size = int(0.95 * DATASET_SIZE)
val_size = DATASET_SIZE - train_size

tensor_x, tensor_y, tensor_angle, tensor_step = (
    torch.Tensor(array) for array in (X_dataset, Y_dataset, angles_dataset, steps_dataset)
)

dataset = TensorDataset(tensor_x, tensor_y, tensor_angle, tensor_step)
train_set, val_set = random_split(dataset, [train_size, val_size])

# From here on both names refer to DataLoaders (default batch size 1, no shuffling).
train_set, val_set = DataLoader(train_set), DataLoader(val_set)

In [None]:
# `train_set` / `val_set` are DataLoaders at this point, so count the samples of their
# underlying datasets (len() of a loader counts batches). The second split is the
# validation set, not a test set.
print("Size all dataset: ", len(dataset))
print("Size training dataset: ", len(train_set.dataset))
print("Size validation dataset: ", len(val_set.dataset))

In [None]:
# Sanity check on one training sample: the target is passed twice, standing in for the
# prediction. It is not saved: this image is ground truth, not a validation
# reconstruction, so it does not belong in SAVE_VALID next to the real epoch outputs.
inp, re, angle, step = next(iter(train_set))
display_func([angle, step, inp, re, re])

In [6]:
class CTDataModule(LightningDataModule):
    """LightningDataModule serving the sinogram dataset stored at ``data_dir``.

    The ``.npy`` file holds four arrays saved back to back (inputs, targets, angles,
    steps). The data is split exactly once into train / val / test subsets of about
    95% / 3.75% / 1.25%. The split uses a generator seeded with ``seed``, so it is
    reproducible and the test subset never overlaps the training subset, even when
    ``setup`` is called separately for fitting and testing.
    """

    def __init__(self, data_dir: str = PATH_DATASETS, batch_size: int = BATCH_SIZE,
                 num_workers: int = NUM_WORKERS, seed: int = 42):
        super(CTDataModule, self).__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        self.dataset = None
        self._splits = None


    def _load_dataset(self):
        """Load the four arrays into a TensorDataset (only once) and compute split sizes."""
        if self.dataset is not None:
            return

        print("BEGIN__Loading the dataset__")
        with open(self.data_dir, 'rb') as f:
            X_dataset = torch.Tensor(np.load(f))
            Y_dataset = torch.Tensor(np.load(f))
            angles_dataset = torch.Tensor(np.load(f))
            steps_dataset = torch.Tensor(np.load(f))
        print("__DONE__Loading the dataset__")

        self.dataset = TensorDataset(X_dataset, Y_dataset, angles_dataset, steps_dataset)

        self.dataset_size = len(self.dataset)
        self.train_size = int(0.95 * self.dataset_size)
        self.val_size = int(0.75 * (self.dataset_size - self.train_size))
        self.test_size = self.dataset_size - (self.train_size + self.val_size)


    def _get_splits(self):
        """Return (train, val, test) subsets, splitting with the seeded generator on first use."""
        if self._splits is None:
            generator = torch.Generator().manual_seed(self.seed)
            self._splits = random_split(
                self.dataset, [self.train_size, self.val_size, self.test_size], generator=generator)
        return self._splits


    def prepare_data(self):
        # Loading is cached, so repeated fit/test runs do not re-read the file.
        self._load_dataset()


    def setup(self, stage=None):
        # setup also loads, in case prepare_data ran in another process.
        self._load_dataset()
        train_set, val_set, test_set = self._get_splits()

        # Assign train/val datasets
        if stage in ("fit", "validate", None):
            self.train_set, self.val_set = train_set, val_set

        # Assign test dataset
        if stage in ("test", None):
            self.test_set = test_set


    def train_dataloader(self):
        # Reshuffle the training subset every epoch.
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)


    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers)


    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers)

# **Build the QNMs-RNN model**

## QNMs-RNN cell

In [7]:
class RegularizationBlock(nn.Module):
    """Small CNN that produces the learned regularization term of a QNM-RNN cell.

    It applies three 5x5 convolutions with 'same' padding
    (in_channels -> filters -> filters -> in_channels), with a ReLU after the first
    two. Every singleton dimension of the result is squeezed away, so a
    (1, 1, H, W) input yields an (H, W) tensor.
    """

    def __init__(self, in_channels=1, filters=48):
        super().__init__()

        def conv5(c_in, c_out):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=5, padding='same')

        self.conv1 = conv5(in_channels, filters)
        self.conv2 = conv5(filters, filters)
        self.outputconv = conv5(filters, in_channels)

        self.relu = nn.ReLU()


    def forward(self, x_t):
        hidden = x_t
        for layer in (self.conv1, self.conv2):
            hidden = self.relu(layer(hidden))
        return self.outputconv(hidden).squeeze()

In [None]:
# Sanity check: pass the sample target image through an untrained regularization block
# and show the (squeezed) result.
reg_block = RegularizationBlock()
x_reg = reg_block(re)
plt.imshow(np.squeeze(x_reg.detach()), cmap="gray")

In [8]:
class QNM_RNNCell(nn.Module):
    """One unrolled iteration of the QNM-RNN reconstruction.

    Each call computes an update direction
    ``sigmaE = -lamb * (A^T scale(scale(A x_t) - y) + reg_block(x_t))``, steps the
    estimate by ``H_t @ sigmaE`` (both rescaled to [0, 1]), and produces a new diagonal
    matrix ``H`` from a learned linear map of the previous diagonal.
    NOTE(review): ``H`` presumably plays the role of a quasi-Newton inverse-Hessian
    approximation; confirm against the method's reference.

    Assumes batch size 1: ``x_t`` is reshaped to ``N*N`` elements.

    Args:
        N: image side length (images are N x N).
        N_THETA, N_RHO: sinogram dimensions (stored only; not used in the computation).
    """

    def __init__(self, N, N_THETA, N_RHO):
        super(QNM_RNNCell, self).__init__()

        self.N = N
        self.N_THETA = N_THETA
        self.N_RHO = N_RHO

        # Sub-modules are placed on the GPU when one is available. On machines without
        # CUDA they stay on the CPU instead of crashing in `.cuda()`.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.reg_block = RegularizationBlock().to(device)
        self.h_linear = nn.Linear(in_features=self.N, out_features=self.N).to(device)
        # Learnable step-size factor applied to the update direction.
        self.lamb = nn.Parameter(torch.tensor(1.))

        self.splus = nn.Softplus()

        self.mse_loss = nn.MSELoss()

        self.apply(self.__init_weights__)


    def __init_weights__(self, m):
        """Xavier-uniform weights and a 0.01 bias for every Linear layer."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)


    def scale_tensor(self, X, a=0, b=1):
        """Linearly rescale ``X`` so its minimum maps to ``a`` and its maximum to ``b``.

        A constant tensor (max == min) is returned unchanged.
        """
        x_min = torch.min(X)
        x_max = torch.max(X)
        q = (X - x_min)*(b-a)
        d = x_max - x_min

        if d != 0:
            # `+ a` shifts the result into [a, b]; it was missing, which only mattered for a != 0.
            s_X = q / d + a
        else:
            s_X = X
        return s_X


    def compute_next_H_t(self, H_t, s_t, z_t):
        """Return the next diagonal ``H`` and its auxiliary loss.

        The new diagonal is ``softplus(h_linear(diag(H_t)))``. The loss is
        MSE(H_next, H_t) + MSE(H_next @ z, s), computed on min-max scaled ``s_t``/``z_t``.
        """
        s_t = self.scale_tensor(s_t)
        z_t = self.scale_tensor(z_t)

        diag_next_H_t = self.splus(self.h_linear(torch.diagonal(H_t, 0)))
        next_H_t = torch.diag(diag_next_H_t)

        h_t_loss = self.mse_loss(next_H_t, H_t) + self.mse_loss(next_H_t @ z_t, s_t)
        return next_H_t, h_t_loss


    def forward(self, As, x_t, y, H_t, sigmaE_t):
        """Run one iteration.

        Args:
            As: sparse system matrix (rows = sinogram values, columns = N*N pixels).
            x_t: current estimate, shape (1, 1, N, N).
            y: flattened measured sinogram.
            H_t: current (N, N) diagonal matrix.
            sigmaE_t: previous (N, N) update direction.

        Returns:
            (next_x_t of shape (1, 1, N, N), next_H_t, next_sigmaE_t, h_t_loss)
        """
        AsT = torch.transpose(As, 0, 1)

        # Learned regularization term (reg_block expects the 4-D estimate).
        x_reg = self.reg_block(x_t)

        # Data term: project the estimate, compare with y, back-project the scaled residual.
        x_t = torch.reshape(x_t, (self.N*self.N, ))
        y_t = torch.sparse.mm(As, x_t[:, None])
        y_t = self.scale_tensor(torch.squeeze(y_t))

        y_dif = y_t - y
        y_dif = self.scale_tensor(y_dif)

        x_update_t = torch.sparse.mm(AsT, y_dif[:, None])
        x_update_t = torch.squeeze(x_update_t)
        x_update_t = torch.reshape(x_update_t, (self.N, self.N))

        next_sigmaE_t = -self.lamb * (x_update_t + x_reg)

        # Step: delta_x = H_t @ sigmaE, rescaled to [0, 1].
        delta_x_t = H_t @ next_sigmaE_t
        delta_x_t = self.scale_tensor(delta_x_t)

        # Next estimate, rescaled to [0, 1].
        x_t = torch.reshape(torch.squeeze(x_t), (self.N, self.N))
        next_x_t = x_t + delta_x_t
        next_x_t = self.scale_tensor(next_x_t)

        # Changes in the estimate (s_t) and in the update direction (z_t) drive the H update.
        s_t = next_x_t - x_t
        z_t = next_sigmaE_t - sigmaE_t

        next_H_t, h_t_loss = self.compute_next_H_t(H_t, s_t, z_t)

        next_x_t = next_x_t[None, None, :, :]
        return next_x_t, next_H_t, next_sigmaE_t, h_t_loss

In [9]:
class QNM_RNN(nn.Module):
    """Unrolled QNM-RNN: ``N_ITER`` chained QNM_RNNCell iterations starting from a zero image.

    Args:
        N: image side length (images are N x N).
        N_THETA, N_RHO: sinogram dimensions; the input is flattened to N_RHO*N_THETA values.
        N_ITER: number of unrolled cells.
        BETA: scale of the initial matrix H (BETA * identity).
        SYS_MAT_FILE: path of the ``.npy`` system matrix.
    """

    def __init__(self, N, N_THETA, N_RHO, N_ITER, BETA=1., SYS_MAT_FILE=PATH_SYSTEM_MATRIX):
        super(QNM_RNN, self).__init__()

        self.N = N
        self.N_THETA = N_THETA
        self.N_RHO = N_RHO
        self.N_ITER = N_ITER
        self.BETA = BETA
        # Use the caller-supplied file; the global path is only the default.
        self.sys_mat_dir = SYS_MAT_FILE

        self.qnm_rnn_cells = nn.ModuleList([QNM_RNNCell(N=self.N, N_THETA=self.N_THETA, N_RHO=self.N_RHO) for _ in range(self.N_ITER)])

        # Plain attribute (not a parameter or buffer): excluded from state_dict and from .to().
        self.A = self.__init_system_matrix__()


    def __init_system_matrix__(self):
        """Load the system matrix from ``self.sys_mat_dir`` and return it as a sparse COO tensor.

        Converting to sparse once here avoids a dense-to-sparse conversion of the full
        matrix on every forward pass.
        """
        print("__BEGIN__Loading the system matrix__")
        with open(self.sys_mat_dir, 'rb') as f:
            A = torch.Tensor(np.load(f)).to_sparse()
        print("__DONE__Loading the system matrix__")

        return A


    def forward(self, y):
        """Reconstruct an image from the sinogram ``y`` (batch size 1 assumed).

        Returns:
            (x of shape (1, 1, N, N), mean auxiliary H loss over the N_ITER cells)
        """
        tot_h_loss = 0

        H_t = self.BETA * torch.eye(self.N).type_as(y)
        sigmaE_t = torch.zeros((self.N, self.N)).type_as(y)

        y = torch.reshape(y, (self.N_RHO*self.N_THETA, ))

        x_t = torch.zeros((self.N, self.N)).type_as(y)
        x_t = x_t[None, None, :, :]

        # Match y's device and dtype. The moved matrix is stored back on the module, so
        # later calls on the same module skip the transfer.
        if self.A.device != y.device or self.A.dtype != y.dtype:
            self.A = self.A.to(device=y.device, dtype=y.dtype)
        As = self.A

        for qnm_model in self.qnm_rnn_cells:
            x_t, H_t, sigmaE_t, h_t_loss = qnm_model(As, x_t, y, H_t, sigmaE_t)
            tot_h_loss += h_t_loss

        return x_t, (tot_h_loss/self.N_ITER)

In [None]:
# Standalone (non-Lightning) model for the sanity checks below; this loads the system matrix.
qnm_model = QNM_RNN(N_ITER=N_ITER, N=IMG_SIZE, N_THETA=N_THETA, N_RHO=N_RHO)

In [None]:
# List the names of all learnable parameters, in registration order.
for param_name in dict(qnm_model.named_parameters()):
    print(param_name)

In [None]:
# Forward pass of the untrained standalone model on the sample sinogram.
# Gradients are not needed here; no_grad avoids keeping the autograd graph of all
# N_ITER unrolled cells alive in GPU memory.
with torch.no_grad():
    x_out, h_loss = qnm_model(inp.cuda())
print(h_loss)
plt.imshow(np.squeeze(x_out.cpu().detach()), cmap='gray')

# **Lightning training module (QNM-RNN, single optimizer)**

In [10]:
class LQNM(LightningModule):
    """LightningModule that trains QNM_RNN with manual optimization.

    Loss = L1 + kornia SSIM loss between the reconstruction and the target, plus the
    mean auxiliary H loss returned by QNM_RNN. A single Adam optimizer updates every
    QNM_RNN parameter.
    """

    def __init__(
        self,
        N_ITER,
        N,
        N_THETA,
        N_RHO,
        BETA=1.,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        **kwargs
    ):
        super(LQNM, self).__init__()
        self.save_hyperparameters()

        # Optimizer steps are taken explicitly in training_step.
        self.automatic_optimization = False

        # Number of optimizer steps taken so far (read by DisplayCallback).
        self.train_step_idx = 0

        self.qnm_model = QNM_RNN(N=N, N_THETA=N_THETA, N_RHO=N_RHO, N_ITER=N_ITER, BETA=BETA)


    def __train_step_idx__(self):
        """Return the number of training steps taken so far."""
        return self.train_step_idx


    def forward(self, y):
        """Return ``(reconstruction, mean H loss)`` for sinogram ``y``."""
        return self.qnm_model(y)


    def QNM_loss(self, rec_img, tar_img):
        """Return the L1 loss plus the kornia SSIM loss (window size 5)."""
        mae_loss = F.l1_loss(rec_img, tar_img)
        s_loss = ssim_loss(rec_img, tar_img, 5)
        return mae_loss + s_loss


    def training_step(self, batch, batch_idx):
        opt_qnm = self.optimizers()

        y, tar_img, angle, step = batch

        rec_img, h_loss = self(y)

        # Optimize QNM-RNN
        qnm_loss = self.QNM_loss(rec_img, tar_img) + h_loss

        opt_qnm.zero_grad()
        self.manual_backward(qnm_loss)
        opt_qnm.step()

        self.train_step_idx += 1
        self.log_dict({"QNM-RNN_loss": qnm_loss}, prog_bar=True)


    def validation_step(self, batch, batch_idx):
        y, tar_img, _, _ = batch

        rec_img, h_loss = self(y)
        val_qnm_loss = self.QNM_loss(rec_img, tar_img) + h_loss

        self.log_dict({"val_qnm_loss": val_qnm_loss}, prog_bar=True)


    def test_step(self, batch, batch_idx):
        """Evaluate one test sample: log its loss, save and plot the reconstruction."""
        os.makedirs(SAVE_TEST, exist_ok=True)

        y, tar_img, angle, step = batch

        # Lightning already puts the module in eval mode for testing. It is not toggled
        # here, because calling self.train() would leave it in train mode after the test.
        with torch.no_grad():
            rec_img, h_loss = self(y)
            test_qnm_loss = self.QNM_loss(rec_img, tar_img) + h_loss

            # Logged under its own name so it is not confused with the validation metric.
            self.log_dict({"test_qnm_loss": test_qnm_loss}, prog_bar=True)

            # kornia's psnr_loss is the negated PSNR and ssim_loss is (1 - SSIM) / 2.
            psnr_p = -psnr_loss(tar_img, rec_img, 1).cpu()
            ssim_p = (-2 * ssim_loss(tar_img, rec_img, 5) + 1).cpu()
            file = (SAVE_TEST + "idx_" + str(batch_idx) + "_angle=" + str(angle.cpu().numpy())
                    + "_step=" + str(step.cpu().numpy())
                    + "_psnr=" + str(psnr_p.numpy()) + "_ssim=" + str(ssim_p.numpy()) + ".png")
            save_image(rec_img, file)
            display_list = [angle, step, y, tar_img, rec_img]
            display_func(display_list)


    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_qnm = torch.optim.Adam(self.qnm_model.parameters(), lr=lr, betas=(b1, b2))

        return opt_qnm

In [11]:
class DisplayCallback(Callback):
    """Plot a reconstruction of the first validation sample during training.

    The plot is shown every ``every_n_steps`` training steps and at the end of every
    epoch. The end-of-epoch reconstruction is also saved under SAVE_VALID.
    """

    def __init__(self, every_n_steps=500):
        super().__init__()
        self.every_n_steps = every_n_steps

    @staticmethod
    def _reconstruct_first_val_sample(trainer, pl_module):
        """Reconstruct the first validation sample (eval mode, no grad) and return a display list."""
        val_dataset = trainer.val_dataloaders[0].dataset
        y, tar_img, angle, step = next(iter(val_dataset))

        # Add the batch dimension the model expects.
        y = y.to(pl_module.device)[None, :, :, :]

        with torch.no_grad():
            pl_module.eval()
            rec_img, _ = pl_module(y)
            pl_module.train()

        return [angle, step, y, tar_img[None, :, :, :], rec_img]

    def on_train_batch_end(self, trainer, pl_module, *args):
        # Only touch the validation data when a plot is actually due.
        if pl_module.__train_step_idx__() % self.every_n_steps == 0:
            display_func(self._reconstruct_first_val_sample(trainer, pl_module))

    def on_train_epoch_end(self, trainer, pl_module, *args):
        os.makedirs(SAVE_VALID, exist_ok=True)

        display_list = self._reconstruct_first_val_sample(trainer, pl_module)
        display_func(display_list, save=True, epoch=trainer.current_epoch)

# **Training the model**

In [None]:
ct_data = CTDataModule()

# Keep the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,
    dirpath=CHECKPOINT_PATH,
    monitor="val_qnm_loss",
    mode="min",
)

trainer = Trainer(
    strategy="dp",
    gpus=AVAIL_GPUS,
    max_epochs=10,
    callbacks=[
        checkpoint_callback,
        DisplayCallback(every_n_steps=2000),
        LearningRateMonitor("epoch"),
    ],
)

# Reuse an existing checkpoint if there is one. The listing is sorted so the choice
# does not depend on filesystem order.
ckpt_files = sorted(x for x in os.listdir(CHECKPOINT_PATH) if x.endswith(".ckpt"))
ckpt_found = bool(ckpt_files)

if ckpt_found:
    pretrained_filename = os.path.join(CHECKPOINT_PATH, ckpt_files[0])
    print(f"Found pretrained model at {pretrained_filename}, loading...")
    model = LQNM.load_from_checkpoint(pretrained_filename)
    trainer.test(model, ct_data)
else:
    model = LQNM(N_ITER=N_ITER, N=IMG_SIZE, N_THETA=N_THETA, N_RHO=N_RHO)
    trainer.fit(model, ct_data)
    # Evaluate the best checkpoint by val_qnm_loss rather than the last-epoch weights.
    trainer.test(model, ct_data, ckpt_path="best")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


__BEGIN__Loading the system matrix__
