# Super resolution with implicit neural representations 

In [None]:
%matplotlib inline
from nn_mri import ImageFitting_set, Siren, get_mgrid, cases, calculate_contrast, save_dicom
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
from skimage.color import rgb2gray, gray2rgb

from skimage import data, img_as_float
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage.metrics import peak_signal_noise_ratio
from skimage.util import random_noise
from skimage.transform import rescale, resize, downscale_local_mean
import os
import SimpleITK as sitk

# Shared defaults for the cells below (several cells override some of them).
seg, scale, total_steps = 50, 1, 3000  # checkpoint interval (steps), upsampling factor, training steps
# Marker style passed to cv2.circle when drawing the cancer location.
radius, color, thickness = 5, (255, 0, 0), 1

# 1. Directional approach
## Training a separate NN for each of the x, y and z directions

##### Repeat the process 5 times (seeds 0-4) and compare the contrast (C) and contrast-to-noise ratio (CNR) across runs

In [None]:
# Experiment 1: for each of 5 seeds and every case, fit one SIREN per gradient
# direction (x, y, z). Every `seg` steps the contrast metrics of the mean image,
# the current reconstruction, the running-average super-resolution and its
# NLM-denoised version are appended to `filename`; the direction-averaged
# results are logged afterwards under direction 'x+y+z'.
total_steps = 3000
gland_start = 0
focus_size = 128
show_cancer = False
weighted = False
display = False
do_detection = False
sigma_est = 2
scale = 1
burn_in = 50  # checkpoints reported as-is before running averaging starts
patch_kw = dict(patch_size=3,      # 3x3 patches
                patch_distance=3)
metrics = ['C', 'CNR']
filename = '../experiments/f.csv'
with open(filename, 'w') as f:
    f.write('seed,patient,direction,epoch,image,metric,performance\n')


def _fit_siren_checkpoints(dataset, *, scale, total_steps, seg, burn_in,
                           weighted, progress=False, **siren_kw):
    """Train a new SIREN on every image in `dataset`; yield at each checkpoint.

    Each step makes one optimizer update per image in `dataset`. Every `seg`
    steps (step 0 included) the network is evaluated on a grid `scale` times
    finer than the training images and ``(step, pr, out_img)`` is yielded:
    ``pr`` is that reconstruction; ``out_img`` is ``pr`` for the first
    `burn_in` checkpoints and afterwards the mean of checkpoints
    ``burn_in .. current``. Extra keyword arguments are passed to ``Siren``.
    """
    size = dataset.shape
    out_shape = (scale * size[0], scale * size[1])
    img_siren = Siren(in_features=2, out_features=1, **siren_kw)
    img_siren.cuda()
    torch.cuda.empty_cache()
    optim = torch.optim.Adam(lr=0.0003, params=img_siren.parameters())

    # Loop-invariant: move the training data to the GPU and build the loss
    # weights once instead of on every step.
    samples = []
    for sample in range(len(dataset)):
        ground_truth = dataset.pixels[sample].cuda()
        model_input = dataset.coords[sample].cuda()
        if weighted:
            weights = ground_truth / ground_truth.sum()
            weights -= weights.min()
            weights += 0.000001
        else:
            weights = 1
        samples.append((model_input, ground_truth, weights))
    # NOTE(review): get_mgrid(n, 2) is presumably an n x n grid, so this assumes
    # square training images (size[0] == size[1]) — TODO confirm.
    coords2 = get_mgrid(size[0] * scale, 2).cuda()

    running_sum = None
    ctr = 0
    for step in (tqdm(range(total_steps)) if progress else range(total_steps)):
        for model_input, ground_truth, weights in samples:
            model_output, _ = img_siren(model_input)
            loss = (weights * (model_output - ground_truth) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        if step % seg:
            continue
        superres, _ = img_siren(coords2)
        pr = superres.cpu().view(*out_shape).detach().numpy()
        if ctr < burn_in:
            out_img = pr
        else:
            # Divide by the number of checkpoints actually summed
            # (burn_in..ctr inclusive); the earlier code summed one extra.
            running_sum = pr if running_sum is None else running_sum + pr
            out_img = running_sum / (ctr - burn_in + 1)
        ctr += 1
        yield step, pr, out_img


def _append_metrics(csv_path, row_keys, case, images, metric_names, scale, gland_start):
    """Append one CSV row per (image, metric) pair to `csv_path`.

    `row_keys` are the leading fields (seed, patient, direction, epoch).
    `calculate_contrast` is called once per image; its i-th return value is
    written as the score of ``metric_names[i]``.
    """
    with open(csv_path, 'a') as f:
        for image_name, image in images.items():
            scores = calculate_contrast(case, scale, image, gland_start)
            for inx, metric in enumerate(metric_names):
                f.write('{},{},{},{},{},{},{}\n'.format(*row_keys, image_name, metric, scores[inx]))


directions = ['x', 'y', 'z']
# Step index of the last checkpoint (checkpoints are taken when step % seg == 0).
last_checkpoint = ((total_steps - 1) // seg) * seg

for seed in range(5):
    torch.manual_seed(seed)
    for case in cases:
        _slice = case.cancer_slice
        pt_no = case.pt_id.split('-')[-1]
        if show_cancer:  # TODO add titles and pt_id descriptions
            orig = case.dwi[:, :, _slice, :].mean(-1)
            center_coordinates = case.cancer_loc[::-1]
            img = gray2rgb(orig * 255 / orig.max())
            img = np.ascontiguousarray(img, dtype=np.uint8)
            cv2.circle(img, center_coordinates, radius, color, thickness)
            plt.figure()
            plt.imshow(img, cmap='gray')

        # Acquisitions of direction d occupy indices [starts[d], ends[d]) of the last dwi axis.
        ends = np.cumsum(case.acquisitions)
        starts = ends - case.acquisitions

        predicted_XYZ = []
        original_XYZ = []
        reconst_XYZ = []
        for direction in range(3):  # gradient directions x, y, z
            # Create a dataset for training SIREN from this direction's acquisitions
            img_dataset = [
                Image.fromarray(case.dwi[gland_start:gland_start + focus_size,
                                         gland_start:gland_start + focus_size,
                                         _slice,
                                         acq])
                for acq in range(starts[direction], ends[direction])
            ]
            dataset = ImageFitting_set(img_dataset)
            orig = dataset.mean
            original_XYZ.append(orig)
            orig2 = rescale(orig, scale, anti_aliasing=False)

            for step, pr, out_img in _fit_siren_checkpoints(
                    dataset, scale=scale, total_steps=total_steps, seg=seg,
                    burn_in=burn_in, weighted=weighted, progress=True,
                    hidden_features=128, hidden_layers=2, outermost_linear=True):
                nlm = denoise_nl_means(out_img, h=1.15 * sigma_est, fast_mode=True, **patch_kw)
                images = {'mean': orig2, 'reconst': pr, 'superres': out_img, 'NLM': nlm}
                _append_metrics(filename, (seed, pt_no, directions[direction], step),
                                case, images, metrics, scale, gland_start)

            predicted_XYZ.append(out_img)
            reconst_XYZ.append(pr)

        predicted = sum(predicted_XYZ) / len(predicted_XYZ)
        orig = sum(original_XYZ) / len(original_XYZ)
        noisy = predicted
        denoise = denoise_nl_means(noisy, h=1.15 * sigma_est, fast_mode=False,
                                   **patch_kw)
        nlm = denoise
        out_img = noisy
        # Score the direction-averaged images; reusing the last direction's
        # `images` dict here would log the z-direction results as 'x+y+z'.
        images = {'mean': rescale(orig, scale, anti_aliasing=False),
                  'reconst': sum(reconst_XYZ) / len(reconst_XYZ),
                  'superres': out_img,
                  'NLM': nlm}
        _append_metrics(filename, (seed, pt_no, 'x+y+z', last_checkpoint),
                        case, images, metrics, scale, gland_start)

        if display:
            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 15),
                                   sharex=True, sharey=True)
            panels = [(rescale(orig, scale, anti_aliasing=False), 'original'),
                      (noisy, 'superres'),
                      (denoise, 'superres + NLM')]
            for axis, (image, title) in zip(ax, panels):
                axis.imshow(image, cmap='gray')
                axis.axis('off')
                axis.set_title(title)
            fig.tight_layout()
            plt.show()

        if do_detection:
            # Overlay pixels brighter than 95% of each image's maximum.
            a = rescale(orig, 1, anti_aliasing=True)
            _, blackAndWhiteImage = cv2.threshold(a, a.max() * 0.95, 255, cv2.THRESH_BINARY)
            _, blackAndWhiteImage2 = cv2.threshold(noisy, noisy.max() * 0.95, 255, cv2.THRESH_BINARY)
            _, blackAndWhiteImage3 = cv2.threshold(denoise, denoise.max() * 0.95, 255, cv2.THRESH_BINARY)

            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 15),
                                   sharex=True, sharey=True)
            ax[0].imshow(rescale(a, scale, anti_aliasing=False), cmap='gray')
            ax[0].imshow(rescale(blackAndWhiteImage, scale, anti_aliasing=True), 'hot', alpha=0.5)
            ax[1].imshow(noisy, cmap='gray')
            ax[1].imshow(blackAndWhiteImage2, 'hot', alpha=0.5)
            ax[2].imshow(denoise, cmap='gray')
            ax[2].imshow(blackAndWhiteImage3, 'hot', alpha=0.5)
            for axis, title in zip(ax, ['original', 'superres', 'superres + NLM']):
                axis.axis('off')
                axis.set_title(title)



### Save 3x super-resolved reconstructions as DICOM images

In [None]:

# For every case, fit one SIREN per gradient direction at 3x resolution and
# save, as DICOM: each direction's mean and super-resolved image, then the
# direction-averaged mean, super-resolution and its NLM-denoised version.
total_steps = 3000
gland_start = 0
focus_size = 128
scale = 3
show_cancer = False
weighted = False
display = True
sigma_est = 2
burn_in = 50  # checkpoints reported as-is before running averaging starts
patch_kw = dict(patch_size=3,      # 3x3 patches
                patch_distance=3)

out_folder = '../output_images/'
os.makedirs(out_folder, exist_ok=True)


def _fit_siren_checkpoints(dataset, *, scale, total_steps, seg, burn_in,
                           weighted, progress=False, **siren_kw):
    """Train a new SIREN on every image in `dataset`; yield at each checkpoint.

    Each step makes one optimizer update per image in `dataset`. Every `seg`
    steps (step 0 included) the network is evaluated on a grid `scale` times
    finer than the training images and ``(step, pr, out_img)`` is yielded:
    ``pr`` is that reconstruction; ``out_img`` is ``pr`` for the first
    `burn_in` checkpoints and afterwards the mean of checkpoints
    ``burn_in .. current``. Extra keyword arguments are passed to ``Siren``.
    """
    size = dataset.shape
    out_shape = (scale * size[0], scale * size[1])
    img_siren = Siren(in_features=2, out_features=1, **siren_kw)
    img_siren.cuda()
    torch.cuda.empty_cache()
    optim = torch.optim.Adam(lr=0.0003, params=img_siren.parameters())

    # Loop-invariant: move the training data to the GPU and build the loss
    # weights once instead of on every step.
    samples = []
    for sample in range(len(dataset)):
        ground_truth = dataset.pixels[sample].cuda()
        model_input = dataset.coords[sample].cuda()
        if weighted:
            weights = ground_truth / ground_truth.sum()
            weights -= weights.min()
            weights += 0.000001
        else:
            weights = 1
        samples.append((model_input, ground_truth, weights))
    # NOTE(review): get_mgrid(n, 2) is presumably an n x n grid, so this assumes
    # square training images (size[0] == size[1]) — TODO confirm.
    coords2 = get_mgrid(size[0] * scale, 2).cuda()

    running_sum = None
    ctr = 0
    for step in (tqdm(range(total_steps)) if progress else range(total_steps)):
        for model_input, ground_truth, weights in samples:
            model_output, _ = img_siren(model_input)
            loss = (weights * (model_output - ground_truth) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        if step % seg:
            continue
        superres, _ = img_siren(coords2)
        pr = superres.cpu().view(*out_shape).detach().numpy()
        if ctr < burn_in:
            out_img = pr
        else:
            # Divide by the number of checkpoints actually summed
            # (burn_in..ctr inclusive); the earlier code summed one extra.
            running_sum = pr if running_sum is None else running_sum + pr
            out_img = running_sum / (ctr - burn_in + 1)
        ctr += 1
        yield step, pr, out_img


def _dcm_path(folder, pt_no, tag):
    """Return '<folder>/sr1_exp_1_<pt_no>_<tag>_wide.dcm'."""
    return os.path.join(folder, 'sr1_exp_1_' + pt_no + '_' + tag + '_wide.dcm')


directions = ['x', 'y', 'z']
for case in cases:
    _slice = case.cancer_slice
    pt_no = case.pt_id.split('-')[-1]
    if show_cancer:  # TODO add titles and pt_id descriptions
        orig = case.dwi[:, :, _slice, :].mean(-1)
        center_coordinates = case.cancer_loc[::-1]
        img = gray2rgb(orig * 255 / orig.max())
        img = np.ascontiguousarray(img, dtype=np.uint8)
        cv2.circle(img, center_coordinates, radius, color, thickness)
        plt.figure()
        plt.imshow(img, cmap='gray')

    # Acquisitions of direction d occupy indices [starts[d], ends[d]) of the last dwi axis.
    ends = np.cumsum(case.acquisitions)
    starts = ends - case.acquisitions

    predicted_XYZ = []
    original_XYZ = []
    for direction in range(3):  # gradient directions x, y, z
        # Create a dataset for training SIREN from this direction's acquisitions
        img_dataset = [
            Image.fromarray(case.dwi[gland_start:gland_start + focus_size,
                                     gland_start:gland_start + focus_size,
                                     _slice,
                                     acq])
            for acq in range(starts[direction], ends[direction])
        ]
        dataset = ImageFitting_set(img_dataset)
        orig = dataset.mean
        original_XYZ.append(orig)

        for _, _, out_img in _fit_siren_checkpoints(
                dataset, scale=scale, total_steps=total_steps, seg=seg,
                burn_in=burn_in, weighted=weighted,
                hidden_features=128, hidden_layers=2, outermost_linear=True):
            pass  # only the final running average is saved

        predicted_XYZ.append(out_img)
        save_dicom(orig, _dcm_path(out_folder, pt_no, 'mean_' + directions[direction]))
        save_dicom(out_img, _dcm_path(out_folder, pt_no, 'super_' + directions[direction]))

    predicted = sum(predicted_XYZ) / len(predicted_XYZ)
    orig = sum(original_XYZ) / len(original_XYZ)
    nlm = denoise_nl_means(predicted, h=1.15 * sigma_est, fast_mode=False,
                           **patch_kw)

    save_dicom(orig, _dcm_path(out_folder, pt_no, 'mean'))
    save_dicom(predicted, _dcm_path(out_folder, pt_no, 'super'))
    save_dicom(nlm, _dcm_path(out_folder, pt_no, 'NLM'))

## 2. Train a single network on all acquisitions (all gradient directions together)

In [None]:
# Experiment 2: for 5 seeds and every case, fit a single SIREN to all
# acquisitions of the cancer slice and append contrast metrics every `seg`
# steps to '<out_folder>/<method_name>_exp_<exp_no>.csv' (direction 'all').
out_folder = '../experiments/'

total_steps = 3000
seg = 50
gland_start = 40
focus_size = 50
# This flag was previously ignored (the loss was always unweighted), so the
# recorded runs are unweighted. The flag now takes effect; it is False to keep
# reproducing those runs. Set True for the intensity-weighted loss.
weighted = False
sigma_est = 2
hidden_layers = 2
hidden_features = 128
scale = 1
burn_in = 5  # checkpoints reported as-is before running averaging starts
patch_kw = dict(patch_size=3, patch_distance=3)

metrics = ['C', 'CNR']

method_name = 'sr1'
exp_no = 5
filename = os.path.join(out_folder, method_name + '_exp_' + str(exp_no) + '.csv')
with open(filename, 'w') as f:
    f.write('seed,patient,direction,epoch,image,metric,performance\n')


def _fit_siren_checkpoints(dataset, *, scale, total_steps, seg, burn_in,
                           weighted, progress=False, **siren_kw):
    """Train a new SIREN on every image in `dataset`; yield at each checkpoint.

    Each step makes one optimizer update per image in `dataset`. Every `seg`
    steps (step 0 included) the network is evaluated on a grid `scale` times
    finer than the training images and ``(step, pr, out_img)`` is yielded:
    ``pr`` is that reconstruction; ``out_img`` is ``pr`` for the first
    `burn_in` checkpoints and afterwards the mean of checkpoints
    ``burn_in .. current``. Extra keyword arguments are passed to ``Siren``.
    """
    size = dataset.shape
    out_shape = (scale * size[0], scale * size[1])
    img_siren = Siren(in_features=2, out_features=1, **siren_kw)
    img_siren.cuda()
    torch.cuda.empty_cache()
    optim = torch.optim.Adam(lr=0.0003, params=img_siren.parameters())

    # Loop-invariant: move the training data to the GPU and build the loss
    # weights once instead of on every step.
    samples = []
    for sample in range(len(dataset)):
        ground_truth = dataset.pixels[sample].cuda()
        model_input = dataset.coords[sample].cuda()
        if weighted:
            weights = ground_truth / ground_truth.sum()
            weights -= weights.min()
            weights += 0.000001
        else:
            weights = 1
        samples.append((model_input, ground_truth, weights))
    # NOTE(review): get_mgrid(n, 2) is presumably an n x n grid, so this assumes
    # square training images (size[0] == size[1]) — TODO confirm.
    coords2 = get_mgrid(size[0] * scale, 2).cuda()

    running_sum = None
    ctr = 0
    for step in (tqdm(range(total_steps)) if progress else range(total_steps)):
        for model_input, ground_truth, weights in samples:
            model_output, _ = img_siren(model_input)
            loss = (weights * (model_output - ground_truth) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        if step % seg:
            continue
        superres, _ = img_siren(coords2)
        pr = superres.cpu().view(*out_shape).detach().numpy()
        if ctr < burn_in:
            out_img = pr
        else:
            # Divide by the number of checkpoints actually summed
            # (burn_in..ctr inclusive); the earlier code summed one extra.
            running_sum = pr if running_sum is None else running_sum + pr
            out_img = running_sum / (ctr - burn_in + 1)
        ctr += 1
        yield step, pr, out_img


def _append_metrics(csv_path, row_keys, case, images, metric_names, scale, gland_start):
    """Append one CSV row per (image, metric) pair to `csv_path`.

    `row_keys` are the leading fields (seed, patient, direction, epoch).
    `calculate_contrast` is called once per image; its i-th return value is
    written as the score of ``metric_names[i]``.
    """
    with open(csv_path, 'a') as f:
        for image_name, image in images.items():
            scores = calculate_contrast(case, scale, image, gland_start)
            for inx, metric in enumerate(metric_names):
                f.write('{},{},{},{},{},{},{}\n'.format(*row_keys, image_name, metric, scores[inx]))


for seed in range(5):
    torch.manual_seed(seed)
    for case in cases:
        _slice = case.cancer_slice
        pt_no = case.pt_id.split('-')[-1]

        # Create a dataset for training SIREN from every acquisition of the slice
        img_dataset = [
            Image.fromarray(case.dwi[gland_start:gland_start + focus_size,
                                     gland_start:gland_start + focus_size,
                                     _slice,
                                     acq])
            for acq in range(sum(case.acquisitions))
        ]
        dataset = ImageFitting_set(img_dataset)
        orig = dataset.mean
        orig2 = rescale(orig, scale, anti_aliasing=False)

        # NOTE(review): unlike Experiment 1, outermost_linear is not passed here — confirm intended.
        for step, pr, out_img in _fit_siren_checkpoints(
                dataset, scale=scale, total_steps=total_steps, seg=seg,
                burn_in=burn_in, weighted=weighted, progress=True,
                hidden_features=hidden_features, hidden_layers=hidden_layers):
            nlm = denoise_nl_means(out_img, h=1.15 * sigma_est, fast_mode=True, **patch_kw)
            images = {'mean': orig2, 'reconst': pr, 'superres': out_img, 'NLM': nlm}
            _append_metrics(filename, (seed, pt_no, 'all', step),
                            case, images, metrics, scale, gland_start)


In [None]:
# Visual check: for every case, mark the cancer location on the mean (over all
# acquisitions) DWI of the cancer slice, titled with the patient id.
# This cell only plots. It used to re-open '../experiments/f.csv' with mode 'w',
# which erased the Experiment 1 results; that write has been removed.
show_cancer = True

for case in cases:
    _slice = case.cancer_slice
    if show_cancer:  # TODO add pt_id descriptions
        orig = case.dwi[:, :, _slice, :].mean(-1)
        # cancer_loc is reversed before drawing, presumably (row, col) -> cv2's (x, y) — confirm.
        center_coordinates = case.cancer_loc[::-1]
        img = gray2rgb(orig * 255 / orig.max())
        img = np.ascontiguousarray(img, dtype=np.uint8)

        cv2.circle(img, center_coordinates, radius, color, thickness)
        plt.figure()
        plt.title(case.pt_id)
        plt.imshow(img, cmap='gray')

In [None]:
import scipy.io as sio

# Settings for the single-patient 3x super-resolution / ADC cells below.
total_steps, gland_start, focus_size, scale = 3000, 0, 128, 3
show_cancer, weighted, display = False, False, True
sigma_est = 2
patch_kw = {'patch_size': 3,      # 3x3 patches
            'patch_distance': 3}

out_folder = '../output_images/'

# Hard-coded metadata for the patient studied below.
pt_id = '18-1681-08'
cancer_loc = (79, 71)
collateral_loc = (79, 59)
cancer_slice = 10
acquisitions = (8, 7, 8)  # acquisitions per gradient direction (x, y, z)
pt_no = pt_id.split('-')[-1]

# Load the DWI data, the mean b0 image and the ADC volume from .mat files;
# `filename` ends up holding the ADC path.
filename = f'../anon_data/pat{pt_no}_alldata.mat'
dwi = sio.loadmat(filename)['data']
filename = f'../anon_data/pat{pt_no}_mean_b0.mat'
b0 = sio.loadmat(filename)['data_mean_b0']
filename = f'../anon_data/pat{pt_no}_ADC_alldata_mm.mat'
adc = sio.loadmat(filename)['ADC_alldata_mm']

In [None]:
# Show the loaded ADC for the cancer slice, averaged over the last axis.
plt.imshow(np.mean(adc[:, :, cancer_slice, :], axis=-1), cmap='gray')

In [None]:
# Show the mean b0 image of the cancer slice.
plt.imshow(X=b0[:, :, cancer_slice], cmap='gray')

In [None]:
# Average all acquisitions (last dwi axis) of the cancer slice.
mean_1500 = np.mean(dwi[:, :, cancer_slice, :], axis=-1)

In [None]:
# Mono-exponential ADC estimate -ln(S / S0) / b, with S the mean DWI and S0 the
# b0 slice. b = 1500 is presumably the DWI b-value (cf. `mean_1500`) — confirm.
# The 1e-7 keeps the division finite where b0 is zero.
adc = -np.log(mean_1500 / (b0[:, :, cancer_slice] + 1e-7)) / 1500

In [None]:
# Fit one SIREN per gradient direction for the patient loaded above, average
# the three super-resolved images, denoise the average with NLM, and save both.
burn_in = 50  # checkpoints reported as-is before running averaging starts


def _fit_siren_checkpoints(dataset, *, scale, total_steps, seg, burn_in,
                           weighted, progress=False, **siren_kw):
    """Train a new SIREN on every image in `dataset`; yield at each checkpoint.

    Each step makes one optimizer update per image in `dataset`. Every `seg`
    steps (step 0 included) the network is evaluated on a grid `scale` times
    finer than the training images and ``(step, pr, out_img)`` is yielded:
    ``pr`` is that reconstruction; ``out_img`` is ``pr`` for the first
    `burn_in` checkpoints and afterwards the mean of checkpoints
    ``burn_in .. current``. Extra keyword arguments are passed to ``Siren``.
    """
    size = dataset.shape
    out_shape = (scale * size[0], scale * size[1])
    img_siren = Siren(in_features=2, out_features=1, **siren_kw)
    img_siren.cuda()
    torch.cuda.empty_cache()
    optim = torch.optim.Adam(lr=0.0003, params=img_siren.parameters())

    # Loop-invariant: move the training data to the GPU and build the loss
    # weights once instead of on every step.
    samples = []
    for sample in range(len(dataset)):
        ground_truth = dataset.pixels[sample].cuda()
        model_input = dataset.coords[sample].cuda()
        if weighted:
            weights = ground_truth / ground_truth.sum()
            weights -= weights.min()
            weights += 0.000001
        else:
            weights = 1
        samples.append((model_input, ground_truth, weights))
    # NOTE(review): get_mgrid(n, 2) is presumably an n x n grid, so this assumes
    # square training images (size[0] == size[1]) — TODO confirm.
    coords2 = get_mgrid(size[0] * scale, 2).cuda()

    running_sum = None
    ctr = 0
    for step in (tqdm(range(total_steps)) if progress else range(total_steps)):
        for model_input, ground_truth, weights in samples:
            model_output, _ = img_siren(model_input)
            loss = (weights * (model_output - ground_truth) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        if step % seg:
            continue
        superres, _ = img_siren(coords2)
        pr = superres.cpu().view(*out_shape).detach().numpy()
        if ctr < burn_in:
            out_img = pr
        else:
            # Divide by the number of checkpoints actually summed
            # (burn_in..ctr inclusive); the earlier code summed one extra.
            running_sum = pr if running_sum is None else running_sum + pr
            out_img = running_sum / (ctr - burn_in + 1)
        ctr += 1
        yield step, pr, out_img


predicted_XYZ = []
directions = ['x', 'y', 'z']
# Acquisitions of direction d occupy indices [starts[d], ends[d]) of the last dwi axis.
ends = np.cumsum(acquisitions)
starts = ends - acquisitions
for direction in range(3):  # gradient directions x, y, z
    # Create a dataset for training SIREN from this direction's acquisitions
    img_dataset = [
        Image.fromarray(dwi[gland_start:gland_start + focus_size,
                            gland_start:gland_start + focus_size,
                            cancer_slice,
                            acq])
        for acq in range(starts[direction], ends[direction])
    ]
    dataset = ImageFitting_set(img_dataset)

    for _, _, out_img in _fit_siren_checkpoints(
            dataset, scale=scale, total_steps=total_steps, seg=seg,
            burn_in=burn_in, weighted=weighted,
            hidden_features=128, hidden_layers=2):
        pass  # only the final running average is kept
    predicted_XYZ.append(out_img)

predicted = sum(predicted_XYZ) / len(predicted_XYZ)
nlm = denoise_nl_means(predicted, h=1.15 * sigma_est, fast_mode=False,
                       **patch_kw)

filename = os.path.join(out_folder, 'sr1_exp_5_' + pt_no + '_super_wide.dcm')
save_dicom(predicted, filename)
filename = os.path.join(out_folder, 'sr1_exp_5_' + pt_no + '_NLM_wide.dcm')
save_dicom(nlm, filename)

In [None]:
# Save the direction-averaged DWI super-resolution, then fit a separate SIREN
# to the cropped b0 slice (24000 steps) and save its 3x super-resolution.
# `img` (b0 crop), `predicted` and `out_img` are used by later cells.
burn_in = 5  # checkpoints reported as-is before running averaging starts


def _fit_siren_checkpoints(dataset, *, scale, total_steps, seg, burn_in,
                           weighted, progress=False, **siren_kw):
    """Train a new SIREN on every image in `dataset`; yield at each checkpoint.

    Each step makes one optimizer update per image in `dataset`. Every `seg`
    steps (step 0 included) the network is evaluated on a grid `scale` times
    finer than the training images and ``(step, pr, out_img)`` is yielded:
    ``pr`` is that reconstruction; ``out_img`` is ``pr`` for the first
    `burn_in` checkpoints and afterwards the mean of checkpoints
    ``burn_in .. current``. Extra keyword arguments are passed to ``Siren``.
    """
    size = dataset.shape
    out_shape = (scale * size[0], scale * size[1])
    img_siren = Siren(in_features=2, out_features=1, **siren_kw)
    img_siren.cuda()
    torch.cuda.empty_cache()
    optim = torch.optim.Adam(lr=0.0003, params=img_siren.parameters())

    # Loop-invariant: move the training data to the GPU and build the loss
    # weights once instead of on every step.
    samples = []
    for sample in range(len(dataset)):
        ground_truth = dataset.pixels[sample].cuda()
        model_input = dataset.coords[sample].cuda()
        if weighted:
            weights = ground_truth / ground_truth.sum()
            weights -= weights.min()
            weights += 0.000001
        else:
            weights = 1
        samples.append((model_input, ground_truth, weights))
    # NOTE(review): get_mgrid(n, 2) is presumably an n x n grid, so this assumes
    # square training images (size[0] == size[1]) — TODO confirm.
    coords2 = get_mgrid(size[0] * scale, 2).cuda()

    running_sum = None
    ctr = 0
    for step in (tqdm(range(total_steps)) if progress else range(total_steps)):
        for model_input, ground_truth, weights in samples:
            model_output, _ = img_siren(model_input)
            loss = (weights * (model_output - ground_truth) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        if step % seg:
            continue
        superres, _ = img_siren(coords2)
        pr = superres.cpu().view(*out_shape).detach().numpy()
        if ctr < burn_in:
            out_img = pr
        else:
            # Divide by the number of checkpoints actually summed
            # (burn_in..ctr inclusive); the earlier code summed one extra.
            running_sum = pr if running_sum is None else running_sum + pr
            out_img = running_sum / (ctr - burn_in + 1)
        ctr += 1
        yield step, pr, out_img


predicted = sum(predicted_XYZ) / len(predicted_XYZ)

scale = 3
filename = os.path.join(out_folder, 'sr1_exp_5_' + pt_no + '_super.dcm')
save_dicom(predicted, filename)

img = b0[gland_start:gland_start + focus_size,
         gland_start:gland_start + focus_size,
         cancer_slice]
b0_dataset = [Image.fromarray(img)]
dataset = ImageFitting_set(b0_dataset)

for _, _, out_img in _fit_siren_checkpoints(
        dataset, scale=scale, total_steps=24000, seg=seg,
        burn_in=burn_in, weighted=weighted,
        hidden_features=128, hidden_layers=2):
    pass  # only the final running average is kept

filename = os.path.join(out_folder, 'sr1_exp_5_' + pt_no + '_super_b0.dcm')
save_dicom(out_img, filename)

In [None]:
# Save the cropped b0 slice (`img`, set by the b0 training cell) as DICOM.
filename = os.path.join(out_folder, f'sr1_exp_5_{pt_no}_b0.dcm')
save_dicom(img, filename)

In [None]:
# Shift both super-resolved images (in place) so their minimum is exactly 0.
# NOTE(review): this leaves zero-valued pixels, which the log ratio in the
# ADC cell must guard against — confirm the shift is intended.
out_img -= out_img.min()
predicted -= predicted.min()

In [None]:
# ADC-style log ratio of the super-resolved DWI (`predicted`) to the
# super-resolved b0 (`out_img`). Both were shifted to a minimum of 0 in the
# previous cell. The epsilon is therefore added to the numerator as well as
# the denominator: otherwise the zero pixels give log(0) = -inf (i.e. adc = +inf)
# and near-zero b0 pixels give a huge ratio.
adc = -np.log((predicted + 1e-7) / (out_img + 1e-7))
# NOTE(review): 9000 is an unexplained divisor (the b0-based ADC cell above
# uses 1500) — confirm the intended b-value / scaling.
adc /= 9000

In [None]:
# Show the current `adc` array in grayscale.
plt.imshow(X=adc, cmap='gray')

In [None]:
filename = os.path.join(out_folder, 'sr1_exp_5_' + pt_no + '_adc.dcm')
# `adc` is a 2-D map when computed from the super-resolved images (cells
# above). It is a 4-D array (indexed [:, :, slice, :]) when freshly loaded from
# the .mat file. Running top to bottom previously hit an IndexError here on the
# 2-D map, so handle both.
save_dicom(adc if adc.ndim == 2 else adc[:, :, cancer_slice, :].mean(axis=-1), filename)

In [None]:
# Inspect the largest value in the current `adc` array.
np.max(adc)

In [None]:
# Reload the ADC volume from its .mat file, replacing the computed map.
filename = f'../anon_data/pat{pt_no}_ADC_alldata_mm.mat'
adc = sio.loadmat(filename)['ADC_alldata_mm']

In [None]:
# Show the reloaded ADC for the cancer slice, averaged over the last axis.
plt.imshow(np.mean(adc[:, :, cancer_slice, :], axis=-1), cmap='gray')