# **Set up the environment**

In [1]:
import os

from math import *

import numpy as np

import torch
from torch import nn

from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

from torchvision.utils import save_image

from pytorch_lightning import LightningDataModule, LightningModule, Callback, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from kornia.losses import ssim_loss, psnr_loss

from matplotlib import pyplot as plt

In [2]:
# Set to True when running on Google Colab: the data paths are then resolved under the mounted Google Drive.
USE_COLAB = False

In [6]:
# Hardware / data-loading resources
AVAIL_GPUS = torch.cuda.device_count()
BATCH_SIZE = 1
# os.cpu_count() may return None when the count cannot be determined.
NUM_WORKERS = min(os.cpu_count() or 1, 4)

# Image size that we are going to use (N x N pixels)
N = 128

# Sinogram size is N_THETA x N_RHO; N_RHO is ceil(2*pi*N) divided by the
# number of detectors, rounded down.
N_detectors = 4
N_THETA = 179
N_RHO = ceil(pi * 2 * N) // N_detectors
NB_PROJECTIONS = N_THETA * N_RHO

# Our images are grayscale (1 channel)
N_CHANNELS = 1

# Number of unrolled LEARN iterations
N_ITER = 2

SAVE_TEST = f"results/test/LEARN_{N}_N_ITER_{N_ITER}/"

CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "models/LEARN/")
SAVE_NAME = f"LEARN_{N}_N_ITER_{N_ITER}/"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_PATH, SAVE_NAME)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

print("Num Available GPUs: ", AVAIL_GPUS)
print("Num Available workers : ", NUM_WORKERS)

Num Available GPUs:  1
Num Available workerss :  4


In [7]:
# Data files are looked up in the mounted Google Drive on Colab and in the
# local "data/" directory otherwise.
if USE_COLAB:
    from google.colab import drive
    drive.mount("/content/gdrive")

PATH_DATASETS, PATH_SYSTEM_MATRIX = (
    ("gdrive/MyDrive/" if USE_COLAB else "data/") + file_name
    for file_name in ('Lung-CCST-128-4detectors-data.npy', 'CCST-System-matrix.txt')
)

print("Data set file : ", PATH_DATASETS)
print("System matrix file : ", PATH_SYSTEM_MATRIX)
print("Check point file : ", CHECKPOINT_PATH)
print("Results file : ", SAVE_TEST)

Data set file :  data/Lung-CCST-128-4detectors-data.npy
System matrix file :  data/CCST-System-matrix.txt
Check point file :  models/LEARN/LEARN_128_N_ITER_2/
Results file :  results/test/LEARN_128_N_ITER_2/


# **Utility functions**

In [None]:
def display_func(display_list, save=False, epoch=0):
    """Show the tensors of ``display_list`` side by side in one matplotlib figure.

    Panels are titled "Input Data", "True Mask" and "Predicted Object : ".
    When a third (predicted) tensor is present, its PSNR and SSIM against the
    second (ground-truth) tensor are appended to its title. They are derived
    from kornia's psnr_loss (max value 1) and ssim_loss (window size 5).

    ``save`` and ``epoch`` are accepted but currently unused.
    """
    plt.figure(figsize=(25, 25))

    prediction_title = "Predicted Object : "
    if len(display_list) > 2:
        truth = display_list[1].cpu()
        prediction = display_list[2].cpu()
        psnr_value = -psnr_loss(truth, prediction, 1)
        ssim_value = -2 * ssim_loss(truth, prediction, 5) + 1
        prediction_title += ", PSNR=" + '%.2f' % psnr_value + ", SSIM=" + '%.2f' % ssim_value

    titles = ("Input Data", 'True Mask', prediction_title)
    n_panels = len(display_list)
    for panel, image in enumerate(display_list):
        plt.subplot(1, n_panels, panel + 1)
        plt.title(titles[panel])
        plt.imshow(torch.squeeze(image).cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()

In [None]:
def fp_system_torch(x, sys_mat, n, nr, nt):
    """Forward projection with an explicit system matrix.

    Flattens the n x n image ``x`` into an (n*n, 1) column, multiplies it by
    ``sys_mat`` (shape (nr*nt, n*n)) and returns the product as an (nr, nt) tensor.
    """
    column = torch.reshape(x, (n * n, 1))
    return torch.mm(sys_mat, column).reshape(nr, nt)

In [None]:
def bp_system_torch(y, sys_mat, n, nr, nt):
    """Back projection, i.e. multiplication by the transposed system matrix.

    Flattens the (nr, nt) tensor ``y`` into an (nr*nt, 1) column, multiplies it
    by ``sys_mat.T`` and returns the result as an n x n image.
    """
    column = y.reshape(nr * nt, 1)
    return torch.mm(sys_mat.T, column).reshape(n, n)

# **Load the dataset**

In [None]:
# The .npy file stores two arrays back to back: the inputs first, then the targets.
with open(PATH_DATASETS, "rb") as npy_file:
    X_dataset, Y_dataset = np.load(npy_file), np.load(npy_file)

In [None]:
# Exploratory 95 % / 5 % split used to inspect the data; training itself goes
# through CCSTDataModule.
DATASET_SIZE = len(X_dataset)
train_size = int(0.95 * DATASET_SIZE)
val_size = DATASET_SIZE - train_size  # the remainder goes to validation

tensor_x, tensor_y = torch.Tensor(X_dataset), torch.Tensor(Y_dataset)
dataset = TensorDataset(tensor_x, tensor_y)

# Wrap each random subset in a default DataLoader (batch size 1, no shuffling).
train_set, val_set = (DataLoader(subset) for subset in random_split(dataset, [train_size, val_size]))

In [None]:
# Report sample counts. The wrapped datasets are measured (not the loaders) so the
# numbers stay sample counts whatever the batch size; the second split is the
# validation set, not a test set.
print("Size all dataset: ", len(dataset))
print("Size training dataset: ", len(train_set.dataset))
print("Size validation dataset: ", len(val_set.dataset))

In [None]:
# Display the first (input data, true mask) pair yielded by the exploratory training loader.
x_inp, y_re = next(iter(train_set))
display_func([x_inp, y_re])

In [None]:
class CCSTDataModule(LightningDataModule):
    """LightningDataModule for the CCST dataset stored in a single .npy file.

    The file contains two arrays written back to back (inputs, then targets).
    They are wrapped in a TensorDataset and split 90 % train, half of the
    remainder validation, and the rest test.

    Args:
        data_dir: path of the .npy file.
        batch_size: batch size of every DataLoader.
        num_workers: number of worker processes of every DataLoader.
        seed: seed of the generator used for the train/val/test split. It makes
            the split reproducible and identical for the 'fit' and 'test' stages.
    """

    def __init__(self, data_dir: str = PATH_DATASETS, batch_size: int = BATCH_SIZE,
                 num_workers: int = NUM_WORKERS, seed: int = 42):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        self.dataset = None
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def _load_dataset(self):
        """Read the .npy file (only once) and compute the split sizes."""
        if self.dataset is not None:
            return
        print("BEGIN__Loading the dataset__")
        with open(self.data_dir, 'rb') as f:
            X_dataset = torch.Tensor(np.load(f))
            Y_dataset = torch.Tensor(np.load(f))
        print("__DONE__Loading the dataset__")

        self.dataset = TensorDataset(X_dataset, Y_dataset)  # create the dataset

        self.dataset_size = len(self.dataset)
        self.train_size = int(0.90 * self.dataset_size)
        self.val_size = int(0.5 * (self.dataset_size - self.train_size))
        self.test_size = self.dataset_size - (self.train_size + self.val_size)

    def prepare_data(self):
        # Nothing to download: the data is a local file, read here.
        self._load_dataset()

    def setup(self, stage=None):
        """Create the train/val/test subsets.

        ``stage`` is accepted for Lightning's interface. All three subsets come
        from one seeded split made on the first call. Drawing a separate random
        split per stage would let test samples overlap the training samples.
        """
        # Lightning may run prepare_data and setup in different processes, so make
        # sure the data is available here too (no-op when already loaded).
        self._load_dataset()
        if self.train_set is None:
            generator = torch.Generator().manual_seed(self.seed)
            self.train_set, self.val_set, self.test_set = random_split(
                self.dataset,
                [self.train_size, self.val_size, self.test_size],
                generator=generator,
            )

    def train_dataloader(self):
        # Reshuffle the training samples every epoch.
        return DataLoader(self.train_set, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers)

# **Build the LEARN model**

In [None]:
class RegularizationBlock(nn.Module):
    """Learned regularization term used inside each LEARN iteration.

    The block consists of three 5x5, 'same'-padded convolutions
    (in_channels -> filters -> filters -> in_channels). The first two are each
    followed by the same (shared) PReLU. The result goes through torch.squeeze,
    which drops every singleton dimension, including a batch dimension of 1.
    """

    def __init__(self, in_channels=1, filters=48):
        super().__init__()

        conv_kwargs = dict(kernel_size=5, padding='same')
        self.conv1 = nn.Conv2d(in_channels, filters, **conv_kwargs)
        self.conv2 = nn.Conv2d(filters, filters, **conv_kwargs)
        self.outputconv = nn.Conv2d(filters, in_channels, **conv_kwargs)

        # A single PReLU module (one learnable slope) serves both hidden activations.
        self.prelu = nn.PReLU()

    def forward(self, x_t):
        hidden = x_t
        for conv in (self.conv1, self.conv2):
            hidden = self.prelu(conv(hidden))
        return self.outputconv(hidden).squeeze()

In [None]:
class LEARN_RNNCell(nn.Module):
    """One unrolled LEARN iteration.

    Given the measured data ``d`` and the current estimate ``x_t``, it returns

        x_{t+1} = x_t - lamb * A^T (A x_t - d) - R(x_t)

    where A is ``sys_mat`` (applied through fp_system_torch / bp_system_torch),
    ``lamb`` is a learnable step size initialised to 0, and R is a
    RegularizationBlock.

    Args:
        n: image side length (images are n x n).
        nr, nt: shape the forward projection is reshaped to.
    """

    def __init__(self, n, nr, nt):
        super().__init__()

        self.n = n
        self.nr = nr
        self.nt = nt

        self.reg_block = RegularizationBlock()
        self.lamb = nn.Parameter(torch.tensor(0.))

    def forward(self, d, x_t, sys_mat):
        """Apply one iteration and return the new estimate as a (1, 1, n, n) tensor.

        Only a single image is supported: the projections flatten the whole
        tensor to n*n (resp. nr*nt) values.
        """
        # Learned regularization term R(x_t).
        x_reg = self.reg_block(x_t)

        # Data-consistency step: lamb * A^T (A x_t - d).
        d_t = fp_system_torch(x_t, sys_mat, self.n, self.nr, self.nt)
        d_dif = d_t - d
        x_update_t = self.lamb * bp_system_torch(d_dif, sys_mat, self.n, self.nr, self.nt)

        next_x_t = x_t - x_update_t - x_reg

        # Always hand a 4-D (1, 1, n, n) image to the next cell's Conv2d layers.
        # x_t is already 4-D, so indexing with [None, None] would produce 6-D.
        return next_x_t.reshape(1, 1, self.n, self.n)

In [None]:
class LEARN_RNN(nn.Module):
    """Unrolled LEARN network: ``n_its`` LEARN_RNNCell iterations from a zero image.

    Args:
        n: image side length (images are n x n).
        nr, nt: shape the forward projection is reshaped to.
        n_its: number of unrolled iterations (one LEARN_RNNCell each).
        sys_mat_path: comma-separated text file holding the system matrix. It
            is transposed after loading and used as the forward operator.
    """

    def __init__(self, n, nr, nt, n_its, sys_mat_path=PATH_SYSTEM_MATRIX):
        super().__init__()

        self.n = n
        self.nr = nr
        self.nt = nt
        self.n_its = n_its

        # np.loadtxt opens and closes the file itself. The transpose and the
        # float32 conversion happen in a single contiguous copy.
        sys_mat = np.loadtxt(sys_mat_path, delimiter=",")
        sys_mat = torch.from_numpy(np.ascontiguousarray(sys_mat.T, dtype=np.float32))
        # Non-persistent buffer: it follows the module across devices
        # (.to()/.cuda(), DataParallel replicas) but is not written to checkpoints.
        self.register_buffer("sys_mat", sys_mat, persistent=False)

        self.learn_rnn_cells = nn.ModuleList(
            [LEARN_RNNCell(n=self.n, nr=self.nr, nt=self.nt) for _ in range(self.n_its)]
        )

    def forward(self, d):
        """Run all iterations on the data ``d`` and return the (1, 1, n, n) estimate."""
        x_t = torch.zeros((self.n, self.n)).type_as(d)
        x_t = x_t[None, None, :, :]

        # Match the input's dtype/device without mutating the stored buffer.
        sys_mat = self.sys_mat.type_as(d)

        for learn_model in self.learn_rnn_cells:
            x_t = learn_model(d, x_t, sys_mat)

        return x_t

# **Lightning training module**

In [None]:
class L_LEARN(LightningModule):
    """LightningModule training LEARN_RNN with an MSE loss and manual optimization.

    Args:
        n, nr, nt, n_its: forwarded to LEARN_RNN.
        lr, b1, b2: Adam learning rate and betas.
        **kwargs: accepted and stored by save_hyperparameters, otherwise unused.
    """

    def __init__(self, n, nr, nt, n_its,
                 lr: float = 0.0002,
                 b1: float = 0.5,
                 b2: float = 0.999,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # The optimizer is stepped by hand in training_step.
        self.automatic_optimization = False

        self.learn_model = LEARN_RNN(n=n, nr=nr, nt=nt, n_its=n_its)

    @staticmethod
    def _as_image(t):
        """Reshape a single image tensor (any leading singleton dims) to (1, 1, H, W)."""
        return t.reshape(1, 1, *t.shape[-2:])

    def forward(self, y):
        return self.learn_model(y)

    def LEARN_loss(self, y, y_true):
        """Mean squared error between prediction and target.

        Both sides are squeezed so their shapes match exactly. Otherwise MSELoss
        broadcasts a (1, 1, H, W) prediction against an (H, W) target and warns.
        """
        return nn.functional.mse_loss(torch.squeeze(y), torch.squeeze(y_true))

    def training_step(self, batch, batch_idx):
        opt_learn = self.optimizers()

        d, y_true = batch
        y_true = torch.squeeze(y_true)

        y = self(d)

        ######################
        # Optimize LEARN-RNN #
        ######################
        learn_loss = self.LEARN_loss(y, y_true)

        opt_learn.zero_grad()
        self.manual_backward(learn_loss)
        opt_learn.step()

        self.log_dict({"learn_rnn_loss": learn_loss}, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        d, y_true = batch
        y_true = torch.squeeze(y_true)

        y = self.learn_model(d)
        val_learn_loss = self.LEARN_loss(y, y_true)

        self.log_dict({"val_learn_rnn_loss": val_learn_loss}, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log the test loss, save the prediction as PNG and display it with its metrics."""
        os.makedirs(SAVE_TEST, exist_ok=True)

        d, y_true = batch

        # Lightning's test loop already puts the module in eval mode, so the mode
        # is not toggled here.
        with torch.no_grad():
            y = self(d)
            test_learn_loss = self.LEARN_loss(y, y_true)

            self.log_dict({"test_Learn_rnn_loss": test_learn_loss}, prog_bar=True)

            # Feed kornia's metrics and save_image 4-D (B, C, H, W) tensors. The
            # model output is already 4-D, so [None, None] would make it 6-D.
            y_true, y = self._as_image(y_true), self._as_image(y)

            psnr_p = -psnr_loss(y, y_true, 1).cpu()
            ssim_p = (-2 * ssim_loss(y, y_true, 5) + 1).cpu()
            file = SAVE_TEST + "idx_" + str(batch_idx) + "_psnr=" + str(psnr_p.numpy()) + "_ssim=" + str(ssim_p.numpy()) + ".png"
            save_image(y, file)

            display_func([d, y_true, y])

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        # The trainable parameters all live in the wrapped LEARN_RNN.
        return torch.optim.Adam(self.learn_model.parameters(), lr=lr, betas=(b1, b2))

In [None]:
class DisplayCallback(Callback):
    """Every ``every_n_epochs`` training epochs, show input, ground truth and
    prediction for the first validation sample.

    A non-positive ``every_n_epochs`` disables the display.
    """

    def __init__(self, every_n_epochs=5):
        super().__init__()
        self.every_n_epochs = every_n_epochs

    @staticmethod
    def _as_image(t):
        """Reshape a single image tensor (any leading singleton dims) to (1, 1, H, W)."""
        return t.reshape(1, 1, *t.shape[-2:])

    def on_train_epoch_end(self, trainer, pl_module, *args):
        # current_epoch is 0-based: display after epochs n, 2n, 3n, ...
        if self.every_n_epochs <= 0 or (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return

        # Depending on the Lightning version this is a list of loaders or a single loader.
        val_loaders = trainer.val_dataloaders
        val_dataloader = val_loaders[0] if isinstance(val_loaders, (list, tuple)) else val_loaders

        d, y_true = val_dataloader.dataset[0]
        d = d.to(pl_module.device)

        was_training = pl_module.training
        pl_module.eval()
        with torch.no_grad():
            y = pl_module(d)
        pl_module.train(was_training)

        # display_func's SSIM expects 4-D images; the prediction is already 4-D, so
        # reshape instead of adding dimensions with [None, None].
        display_func([d, self._as_image(y_true), self._as_image(y)])

# **Train the model**

In [None]:
ccst_data = CCSTDataModule()

# Save weights only, tracking the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,
    dirpath=CHECKPOINT_PATH,
    monitor="val_learn_rnn_loss",
    mode="min",
)

# Use DataParallel over all visible GPUs. Fall back to a single CPU device when
# none is visible, since accelerator='gpu' cannot be used with devices=0.
if AVAIL_GPUS > 0:
    hardware_kwargs = dict(accelerator='gpu', devices=AVAIL_GPUS, strategy="dp")
else:
    hardware_kwargs = dict(accelerator='cpu', devices=1)

trainer = Trainer(
    max_epochs=10,
    callbacks=[
        checkpoint_callback,
        DisplayCallback(every_n_epochs=1),
        LearningRateMonitor("epoch"),
    ],
    **hardware_kwargs,
)

# Reuse an existing checkpoint if there is one (first in name order, for a
# deterministic choice); otherwise train from scratch.
pretrained_filename = CHECKPOINT_PATH
ckpt_found = False
for ckpt_name in sorted(os.listdir(CHECKPOINT_PATH)):
    if ckpt_name.endswith(".ckpt"):
        ckpt_found = True
        pretrained_filename = os.path.join(CHECKPOINT_PATH, ckpt_name)
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        break

if ckpt_found:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model = L_LEARN.load_from_checkpoint(
        pretrained_filename, map_location=None if AVAIL_GPUS > 0 else "cpu")
else:
    model = L_LEARN(N, N_RHO, N_THETA, N_ITER)
    trainer.fit(model, ccst_data)

trainer.test(model, ccst_data)