# **Set up the environment**

In [None]:
import os

from math import *

import numpy as np

import torch
from torch import nn

from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split

from torchvision.utils import save_image

from pytorch_lightning import LightningDataModule, LightningModule, Callback, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from kornia.losses import ssim_loss, psnr_loss

from matplotlib import pyplot as plt

In [None]:
# Set to True when running on Google Colab: the data files are then read from
# the mounted Google Drive instead of the local "data/" directory.
USE_COLAB: bool = False

In [None]:
# Number of visible CUDA devices (0 on a CPU-only machine).
AVAIL_GPUS = torch.cuda.device_count()
BATCH_SIZE = 1
# Fixed number of DataLoader workers (int(os.cpu_count()) is an alternative).
NUM_WORKERS = 2

# Image size that we are going to use (images are N x N pixels)
N = 128

N_detectors = 4
N_THETA = 179
# NOTE: ceil() is applied before the division and int() then truncates, so
# N_RHO == 201 for N=128 and 4 detectors. The forward projection is reshaped to
# (N_RHO, N_THETA), so this must stay consistent with the system matrix file.
N_RHO = int(ceil(pi * 2 * N) / N_detectors)
NB_PROJECTIONS = N_THETA * N_RHO

# Our images are grayscale (1 channel)
N_CHANNELS = 1

# Number of unrolled MLEM iterations in the learned reconstruction
N_ITER = 10

SAVE_TEST = f"results/test/L_MLEM{N}_N_ITER_{N_ITER}/"

CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "models/L_MLEM/")
SAVE_NAME = f"L_MLEM_{N}_N_ITER_{N_ITER}/"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_PATH, SAVE_NAME)
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

print("Num Available GPUs: ", AVAIL_GPUS)
print("Num Available workers : ", NUM_WORKERS)

In [None]:
# File names of the dataset and of the system matrix; the directory they live
# in depends on whether the notebook runs on Colab (Google Drive) or locally.
PATH_DATASETS = 'Lung-CCST-128-4detectors-data.npy'
PATH_SYSTEM_MATRIX = 'CCST-System-matrix.txt'

if USE_COLAB:
    from google.colab import drive
    drive.mount("/content/gdrive")

_data_root = "gdrive/MyDrive/" if USE_COLAB else "data/"
PATH_DATASETS = _data_root + PATH_DATASETS
PATH_SYSTEM_MATRIX = _data_root + PATH_SYSTEM_MATRIX

print("Data set file : ", PATH_DATASETS)
print("System matrix file : ", PATH_SYSTEM_MATRIX)
print("Check point file : ", CHECKPOINT_PATH)
print("Results file : ", SAVE_TEST)

# **Utility functions**

In [None]:
def display_func(display_list, save=False, epoch=0):
    """Plot a row of images, annotating the prediction with PSNR and SSIM.

    Args:
        display_list: image tensors in the order [input data, true mask,
            predicted object]; singleton dimensions are squeezed before
            plotting. When a third (predicted) image is present, its PSNR
            (max value 1) and SSIM (window size 5) against the true mask are
            appended to its title.
        save: unused.
        epoch: unused.
    """
    plt.figure(figsize=(25, 25))

    predicted = "Predicted Object : "

    if len(display_list) > 2:
        # kornia's ssim_loss only accepts (B, C, H, W) inputs, while the model
        # returns plain (H, W) images, so both images are reshaped to
        # (1, 1, H, W) before computing the metrics.
        target = torch.squeeze(display_list[1]).detach().cpu()
        prediction = torch.squeeze(display_list[2]).detach().cpu()
        target = target.reshape(1, 1, *target.shape[-2:])
        prediction = prediction.reshape(1, 1, *prediction.shape[-2:])

        psnr_p = -psnr_loss(prediction, target, 1)
        # The loss is (1 - SSIM) / 2, so invert it to report the SSIM itself.
        ssim_p = -2 * ssim_loss(prediction, target, 5) + 1
        predicted += f", PSNR={float(psnr_p):.2f}, SSIM={float(ssim_p):.2f}"

    title = ["Input Data", 'True Mask', predicted]

    for i, image in enumerate(display_list):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        # detach() so tensors that still require grad can be converted to numpy.
        plt.imshow(torch.squeeze(image).detach().cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.show()

In [None]:
def fp_system_torch(y, sys_mat, n, nr, nt):
    """Forward-project an image through the system matrix.

    The image ``y`` (n * n elements) is flattened into a column vector and
    multiplied by ``sys_mat`` (which must have n * n columns and nr * nt rows);
    the resulting projections are returned as an (nr, nt) tensor.
    """
    column = y.reshape(n * n, 1)
    return sys_mat.mm(column).reshape(nr, nt)

In [None]:
def bp_system_torch(d, sys_mat, n, nr, nt):
    """Back-project data through the transposed system matrix.

    The data ``d`` (nr * nt elements) is flattened into a column vector and
    multiplied by ``sys_mat`` transposed; the result is returned as an (n, n)
    image.
    """
    flat = d.reshape(nr * nt, 1)
    image = sys_mat.t().mm(flat)
    return image.reshape(n, n)

# **Dataset Loader**

In [None]:
# Load the numpy data arrays: the file stores two arrays back to back,
# the input data first and the target images second.
with open(PATH_DATASETS, 'rb') as f:
    X_dataset, Y_dataset = np.load(f), np.load(f)

In [None]:
DATASET_SIZE = len(X_dataset)

# 95 % of the samples go to training, the remainder to validation.
train_size = int(0.95 * DATASET_SIZE)
val_size = DATASET_SIZE - train_size

tensor_x, tensor_y = torch.Tensor(X_dataset), torch.Tensor(Y_dataset)

dataset = TensorDataset(tensor_x, tensor_y)  # create your dataset
train_set, val_set = random_split(dataset, [train_size, val_size])

# Both names are rebound to DataLoaders (default batch size of 1) for the
# exploratory cells below.
train_set, val_set = DataLoader(train_set), DataLoader(val_set)

In [None]:
# Sizes of the exploratory split. The loaders use batch size 1, so their lengths
# equal the number of samples; the "testing" line reports the validation loader.
print(f"Size all dataset:  {len(dataset)}")
print(f"Size training dataset:  {len(train_set)}")
print(f"Size testing dataset:  {len(val_set)}")

In [None]:
# Draw one (input, target) pair from the exploratory training loader and show it.
x_inp, y_re = next(iter(train_set))
display_func(display_list=[x_inp, y_re])

In [None]:
class CCSTDataModule(LightningDataModule):
    """Lightning data module serving the CCST (input, target) tensor pairs.

    The .npy file at ``data_dir`` stores two arrays back to back (inputs, then
    targets). The samples are split roughly 90 % / 5 % / 5 % into train /
    validation / test subsets. The split is computed once, with a generator
    seeded by ``seed``, so every stage (and every run) sees the same disjoint
    partition and the test set never overlaps the training data.

    Args:
        data_dir: path of the .npy dataset file.
        batch_size: batch size of every DataLoader.
        num_workers: number of worker processes of every DataLoader.
        seed: seed of the random train/val/test split.
    """

    def __init__(self, data_dir: str = PATH_DATASETS, batch_size: int = BATCH_SIZE,
                 num_workers: int = NUM_WORKERS, seed: int = 42):
        super(CCSTDataModule, self).__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        self.dataset = None   # loaded lazily by _load_dataset()
        self._splits = None   # (train, val, test) subsets, computed once

    def _load_dataset(self):
        """Load the dataset file (once) and compute the split sizes."""
        if self.dataset is not None:
            return
        print("BEGIN__Loading the dataset__")
        with open(self.data_dir, 'rb') as f:
            X_dataset = torch.Tensor(np.load(f))
            Y_dataset = torch.Tensor(np.load(f))
        print("__DONE__Loading the dataset__")

        self.dataset = TensorDataset(X_dataset, Y_dataset)  # create the dataset

        self.dataset_size = len(self.dataset)
        self.train_size = int(0.90 * self.dataset_size)
        self.val_size = int(0.5 * (self.dataset_size - self.train_size))
        self.test_size = self.dataset_size - (self.train_size + self.val_size)

    def _split(self):
        """Return the (train, val, test) subsets, splitting once with a fixed seed."""
        if self._splits is None:
            self._load_dataset()
            generator = torch.Generator().manual_seed(self.seed)
            self._splits = random_split(
                self.dataset,
                [self.train_size, self.val_size, self.test_size],
                generator=generator,
            )
        return self._splits

    def prepare_data(self):
        # Lightning may run prepare_data in a single process and does not share
        # state assigned here, so only check that the file exists; the data is
        # loaded in setup().
        if not os.path.isfile(self.data_dir):
            raise FileNotFoundError(f"Dataset file not found: {self.data_dir}")

    def setup(self, stage=None):
        # All stages share the same seeded split, so assigning every subset
        # regardless of ``stage`` keeps train/val/test disjoint.
        self.train_set, self.val_set, self.test_set = self._split()

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers)

# **Build the MLEM model**

In [None]:
class CNN(nn.Module):
    """Five 7x7 convolutions, each followed by a PReLU, applied to one image.

    The input is expected to be a 2-D (H, W) tensor; it is given batch and
    channel dimensions of size 1, and singleton dimensions of the output are
    squeezed away. ``padding=3`` keeps the spatial size unchanged.
    """

    def __init__(self):
        super(CNN, self).__init__()
        channel_pairs = [(1, 8), (8, 8), (8, 8), (8, 8), (8, 1)]
        layers = []
        for in_ch, out_ch in channel_pairs:
            layers.append(nn.Conv2d(in_ch, out_ch, 7, padding=(3, 3)))
            layers.append(nn.PReLU())
        # The attribute name is part of the state_dict keys: keep ``CNN_cell``.
        self.CNN_cell = nn.Sequential(*layers)

    def forward(self, x_t):
        batched = x_t[None, None]  # (H, W) -> (1, 1, H, W)
        return self.CNN_cell(batched).squeeze()

In [None]:
class Learned_MLEM(nn.Module):
    """Unrolled MLEM reconstruction with a learned CNN correction per iteration.

    Args:
        n: reconstructed images are n x n.
        nr, nt: the measured data is handled as an (nr, nt) array.
        num_its: number of unrolled MLEM iterations.
        sys_mat_path: comma-separated text file holding the system matrix. It
            is transposed after loading so that ``sys_mat`` multiplies a
            flattened n * n image to give the nr * nt projections.
    """

    def __init__(self, n, nr, nt, num_its, sys_mat_path=PATH_SYSTEM_MATRIX):
        super(Learned_MLEM, self).__init__()
        self.n = n
        self.nr = nr
        self.nt = nt

        # Honour the sys_mat_path argument; passing the path lets np.loadtxt
        # open and close the file itself (no leaked file handle).
        self.sys_mat = torch.Tensor(np.loadtxt(sys_mat_path, delimiter=",")).T
        self.num_its = num_its
        self.cnn = CNN()

    def forward(self, d):
        """Reconstruct an (n, n) image from the measured data ``d``.

        Singleton dimensions of ``d`` are squeezed and the result is transposed
        before being used as the (nr, nt) data of the projectors.
        """
        d = torch.squeeze(d).T
        # sys_mat is a plain attribute (not a registered buffer), so it is
        # moved to d's device/dtype here.
        self.sys_mat = self.sys_mat.type_as(d)

        y = torch.ones(self.n, self.n).type_as(d)
        # Sensitivity image A^T 1: the normalisation term of the MLEM update.
        sens_y = bp_system_torch(torch.ones_like(d), self.sys_mat, self.n, self.nr, self.nt)

        for _ in range(self.num_its):
            fpdata = fp_system_torch(y, self.sys_mat, self.n, self.nr, self.nt)
            ratio = d / (fpdata + 1.0e-9)
            correction = bp_system_torch(ratio, self.sys_mat, self.n, self.nr, self.nt)
            # MLEM update y <- y / (A^T 1) * A^T (d / (A y)); the epsilon keeps
            # pixels with zero sensitivity finite.
            y = y * correction / (sens_y + 1.0e-9)
            # Learned residual correction; abs keeps the estimate non-negative
            # for the next multiplicative update.
            y = torch.abs(y + self.cnn(y))

        return y

In [None]:
# Stand-alone (untrained) model used for a quick forward-pass check below.
lm = Learned_MLEM(n=N, nr=N_RHO, nt=N_THETA, num_its=N_ITER)

In [None]:
# Forward pass of the untrained model on the sample drawn above. no_grad avoids
# building an autograd graph through all the unrolled iterations.
with torch.no_grad():
    y = lm(x_inp)
display_func([y.detach(), y_re])

In [None]:
class LP_Learned_MLEM(LightningModule):
    """LightningModule training ``Learned_MLEM`` with an MSE loss.

    Optimisation is manual: one Adam step per training batch. The test step
    also writes each reconstruction to ``SAVE_TEST`` as a PNG whose file name
    carries its PSNR and SSIM, and displays input / target / prediction.

    Args:
        n, nr, nt, num_its: forwarded to ``Learned_MLEM``.
        lr, b1, b2: Adam learning rate and betas.
    """

    def __init__(self, n, nr, nt, num_its,
                 lr: float = 0.0002,
                 b1: float = 0.5,
                 b2: float = 0.999,
                 **kwargs):
        super(LP_Learned_MLEM, self).__init__()
        # Records the constructor arguments in self.hparams (and checkpoints).
        self.save_hyperparameters()

        self.automatic_optimization = False

        self.l_mlem = Learned_MLEM(n, nr, nt, num_its)

    def forward(self, d):
        return self.l_mlem(d)

    def L_MLEM_loss(self, y, y_true):
        """Mean squared error between the reconstruction and the ground truth."""
        return nn.functional.mse_loss(y, y_true)

    @staticmethod
    def _as_image_batch(img):
        """Squeeze singleton dims of ``img`` and reshape it to (1, 1, H, W)."""
        img = torch.squeeze(img)
        return img.reshape(1, 1, *img.shape[-2:])

    def training_step(self, batch, batch_idx):
        opt_lmlem = self.optimizers()

        d, y_true = batch
        y_true = torch.squeeze(y_true)

        y = self(d)

        ######################
        # Optimize L_MLEM    #
        ######################
        lmlem_loss = self.L_MLEM_loss(y, y_true)

        opt_lmlem.zero_grad()
        self.manual_backward(lmlem_loss)
        opt_lmlem.step()

        self.log_dict({"Learned_MLEM_loss": lmlem_loss}, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        d, y_true = batch
        y_true = torch.squeeze(y_true)

        y = self(d)
        val_lmlem_loss = self.L_MLEM_loss(y, y_true)

        self.log_dict({"val_Learned_MLEM_loss": val_lmlem_loss}, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # Create the output directory if it does not exist yet.
        os.makedirs(SAVE_TEST, exist_ok=True)

        d, y_true = batch
        y_true = torch.squeeze(y_true)

        # The train/eval mode is deliberately not toggled here: switching back to
        # train() inside a test step would leave the module in training mode for
        # the remaining test batches.
        with torch.no_grad():
            y = self(d)
            test_lmlem_loss = self.L_MLEM_loss(y, y_true)

            self.log_dict({"test_Learned_MLEM_loss": test_lmlem_loss}, prog_bar=True)

            # kornia's ssim_loss requires (B, C, H, W) tensors; y is (H, W).
            y_4d = self._as_image_batch(y)
            y_true_4d = self._as_image_batch(y_true)
            psnr_p = -psnr_loss(y_4d, y_true_4d, 1).cpu()
            ssim_p = (-2 * ssim_loss(y_4d, y_true_4d, 5) + 1).cpu()
            file = SAVE_TEST+"idx_"+str(batch_idx)+"_psnr="+str(psnr_p.numpy())+"_ssim="+str(ssim_p.numpy())+".png"
            save_image(y, file)

            display_list = [d, y_true, y]
            display_func(display_list)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_lmlem = torch.optim.Adam(self.l_mlem.parameters(), lr=lr, betas=(b1, b2))

        return opt_lmlem

In [None]:
class DisplayCallback(Callback):
    """Show input / target / prediction for the first validation sample.

    Args:
        every_n_epochs: display only at the end of every n-th training epoch
            (epochs counted from 1); a value <= 0 disables the display.
    """

    def __init__(self, every_n_epochs=5):
        super(DisplayCallback, self).__init__()
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(self, trainer, pl_module, *args):
        if self.every_n_epochs <= 0 or (trainer.current_epoch + 1) % self.every_n_epochs != 0:
            return

        # Depending on the Lightning version, val_dataloaders may be a list of
        # loaders, a single loader, or None if validation has not run yet.
        val_dataloader = trainer.val_dataloaders
        if isinstance(val_dataloader, (list, tuple)):
            val_dataloader = val_dataloader[0] if val_dataloader else None
        if val_dataloader is None:
            return
        val_dataset = val_dataloader.dataset

        d, y_true = next(iter(val_dataset))

        # Reconstruct images
        d = d.to(pl_module.device)

        # Restore whatever mode the module was in, even if the forward pass fails.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                y = pl_module(d)
        finally:
            pl_module.train(was_training)

        display_list = [d, y_true, y]
        display_func(display_list)

# **Training the model**

In [None]:
ccst_data = CCSTDataModule()

# Keep the checkpoint with the lowest validation loss. The monitored key must be
# exactly the name logged in LP_Learned_MLEM.validation_step.
checkpoint_callback = ModelCheckpoint(
    save_weights_only=True,
    dirpath=CHECKPOINT_PATH,
    monitor="val_Learned_MLEM_loss",
    mode="min",
)

# Use the GPUs with DataParallel when available, otherwise fall back to the CPU.
trainer = Trainer(
    accelerator="gpu" if AVAIL_GPUS > 0 else "cpu",
    devices=AVAIL_GPUS if AVAIL_GPUS > 0 else 1,
    strategy="dp" if AVAIL_GPUS > 0 else None,
    max_epochs=10,
    callbacks=[
        checkpoint_callback,
        DisplayCallback(every_n_epochs=1),
        LearningRateMonitor("epoch"),
    ],
)

# Reuse an existing checkpoint from CHECKPOINT_PATH when there is one (sorted for
# a deterministic choice); otherwise train from scratch.
pretrained_filename = CHECKPOINT_PATH
ckpt_files = sorted(x for x in os.listdir(CHECKPOINT_PATH) if x.endswith(".ckpt"))
ckpt_found = bool(ckpt_files)

if ckpt_found:
    pretrained_filename = os.path.join(CHECKPOINT_PATH, ckpt_files[0])
    print(f"Found pretrained model at {pretrained_filename}, loading...")
    model = LP_Learned_MLEM.load_from_checkpoint(pretrained_filename)
else:
    model = LP_Learned_MLEM(N, N_RHO, N_THETA, N_ITER)
    trainer.fit(model, ccst_data)

trainer.test(model, ccst_data)