# NPPC applied on MR image reconstruction

In [1]:
%matplotlib notebook
import os, sys
import logging
import numpy as np
import torch
import torchvision
from utils import complex_utils as cplx
from utils.datasets import SliceData
from MoDL_single import UnrolledModel
from pathlib import Path

from datasets import Namespace, create_data_loaders
from modl_infrastructure import MoDLParams, MoDLWrapper
from MoDLsinglechannel.demo_modl_singlechannel.subsample_fastmri import SaveableMask
from alon.config import *

from nppc_infrastructure import NPPCParams, NPPCForMoDLWrapper


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 0



## Load modules

### 1. Undersampling mask

In [2]:
# Load a mask
#   Ask for a UUID, else take the most recently-created one
mask_name = input('Enter mask name (if skipping, taking the most recently created): ')
if not mask_name.replace(' ', ''):
    latest_file = max(Path(MASKS_DIR).glob('*.pkl'), key=os.path.getctime)
    mask_name = latest_file.stem
    print(f'User skipped, thus loaded the latest mask {mask_name}')

mask = SaveableMask(masks_dir=MASKS_DIR)
mask.load(mask_name)

User skipped, thus loaded the latest mask Initial Poisson acc 6 calib (28 28)


ModuleNotFoundError: No module named 'alon.MoDLsinglechannel'

### 2. MoDL

In [4]:
modl_params = MoDLParams()
mr_data_loader = create_data_loaders(modl_params, mask)
modl_wrapper = MoDLWrapper(modl_params, MODL_CKPT_DIR)
single_MoDL = modl_wrapper.load(
    MODL_CKPT_DIR, model_name='MoDL 4-step regul 0.01',
    epoch='last'
)

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002452.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002467.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002469.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002512.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002629.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002876.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002888.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002914.h5.npy...
	/mnt/c/Users/along/brain_multicoil_t

**Optional test run of MoDL**

In [4]:
modl_wrapper.test_run(single_MoDL, mr_data_loader)

iter=15...
		Step 0/4...
		Step 1/4...
		Step 2/4...
		Step 3/4...


<IPython.core.display.Javascript object>

ValueError: invalid literal for int() with base 10: ''

### 3. NPPC

In [5]:
nppc_params = NPPCParams()
nppc_checkpoint_dir = Path('/home/alon_granek/PythonProjects/NPPC/alon/nppc_checkpoints')
nppc_wrapper = NPPCForMoDLWrapper(nppc_params, single_MoDL, save_dir=nppc_checkpoint_dir)

nppc_name: str = {
    # No DC loss, original NPPC edited for dual-channel data
    '1.0': 'modl_nppc_dc DUAL NO-DC fixed',
    
    # With DC loss                                              (July meeting)
    '1.1': 'modl_nppc_dc DUAL DC',
    
    # DC as an activation
    '2.0': 'modl_nppc_dc DUAL MaskAct 10dirs',
    
    # DC as an activation - but prior to Gram Schmidt
    '2.1': 'modl_nppc_dc DUAL MaskActPreGS',
    
    # Alternate between GS and DC Activation                    (Current as of August)
    '2.2': 'modl_nppc_dc DUAL AlterProj',
}[
    # --- Chosen ---
    '2.2'
]

nppc_net = nppc_wrapper.load(nppc_name)

## Run on a file

In [6]:
for iter, data in enumerate(mr_data_loader):
    # Unpack data: undersampled k-space, target image, k-space mask
    y, x_true, mask_tensor = (obj.to(device) for obj in data)
    break

# """ Or use a temporarily saved file """
# import pickle
# with open('/home/alon_granek/PythonProjects/NPPC/temp_data.pkl', 'rb') as file:
#     y, x_true, mask = pickle.load(file)

In [7]:
preproc = lambda x, pos: x[pos].permute(0, 3, 1, 2)

with torch.no_grad():
    x_recon, x_intermed = single_MoDL(y.float(), mask=mask_tensor, return_steps=True)
    w_mat_intermed = nppc_net(preproc(x_intermed, 0), preproc(x_intermed, -1), mask_tensor)

		Step 0/4...
		Step 1/4...
		Step 2/4...
		Step 3/4...


## Analysis

### _k-space error distribution_

In [30]:
import matplotlib.pyplot as plt
from alon.fastmri_preprocess import fftc
from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.transforms import fft2, ifft2

# Chosen example scan, out of the ones the MoDL-NPPC chain was applied on (by default 0 and 1)
example = 1


fig, ax = plt.subplots(2, 4, sharex=True, sharey=True)
vmin = -30

ax[0, 0].set_title('Target')
ax[0, 0].imshow(10 * np.log10(cplx.abs(ifft2(x_true[example]))), cmap='inferno', origin='lower', vmin=vmin)
ax[1, 0].imshow(cplx.abs(x_true[example]), cmap='Greys_r', origin='lower', vmin=0, vmax=1.5)

ax[0, 1].set_title('Measured')
im = ax[0, 1].imshow(10 * np.log10(np.maximum(cplx.abs(fft2(x_intermed[0][example])), 10 ** (vmin / 10))), cmap='inferno', origin='lower', vmin=vmin)
ax[1, 1].imshow(cplx.abs(x_intermed[0][example]), cmap='Greys_r', origin='lower', vmin=0, vmax=1.5)

ax[0, 2].set_title('Reconstructed')
ax[0, 2].imshow(10 * np.log10(cplx.abs(ifft2(x_recon[example]))), cmap='inferno', origin='lower', vmin=vmin)
ax[1, 2].imshow(cplx.abs(x_intermed[-1][example]), cmap='Greys_r', origin='lower', vmin=0, vmax=1.5)

ax[0, 3].set_title('Error')
ax[0, 3].imshow(10 * np.log10(cplx.abs(ifft2(x_recon[example] - x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=vmin)
ax[1, 3].imshow(cplx.abs(x_intermed[-1][example] - x_intermed[0][example]), cmap='Greys_r', origin='lower', vmin=0, vmax=1.5)
    
fig.set_size_inches(30, 12)
fig.subplots_adjust(right=0.8, wspace=0.05)
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x7fd76732c3d0>

In [ ]:
plt.savefig(FIGURE_SAVE_DIR.joinpath('k-space error example.png'))

### _PCs_

In [11]:
import matplotlib.pyplot as plt

# Chosen example scan, out of the ones the MoDL-NPPC chain was applied on (by default 0 and 1)
sample = 0

# Showing all PCs
fig, ax = plt.subplots(3, 5, sharex=True, sharey=True)
for c in range(5):
    eigval = (torch.linalg.norm(w_mat_intermed.detach()[sample, c, 0]) + torch.linalg.norm(w_mat_intermed.detach()[sample, c, 1])) / torch.prod(torch.tensor(w_mat_intermed.detach()[sample, c, 0].size()))
    ax[0, c].set_title('Eig. val.: {:e}'.format(eigval))
    for re_im in [0, 1]:
        im = ax[re_im, c].imshow(w_mat_intermed.detach()[sample, c, re_im], vmin=-0.3, vmax=0.3, cmap='RdBu', origin='lower')

    # k-space
    pc_images = w_mat_intermed.detach()[sample, c].permute(1, 2, 0)[None]
    ax[2, c].imshow(10 * np.log10(np.maximum(cplx.abs(fft2(pc_images))[0], 1e-4)), cmap='inferno', origin='lower')

fig.set_size_inches(30, 15)
fig.subplots_adjust(right=0.8, wspace=0.05)
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x7fd79ee8c1f0>

In [ ]:
plt.savefig(FIGURE_SAVE_DIR.joinpath('pc list.png'))

### _Relative k-space power distribution of PCs_

In [10]:
from scipy.ndimage import median_filter

# Chosen example scan, out of the ones the MoDL-NPPC chain was applied on (by default 0 and 1)
sample = 0

fts = torch.stack([cplx.abs(fft2(w_mat_intermed.detach()[sample, c].permute(1, 2, 0)[None]))[0] for c in range(nppc_params.n_dirs)])
normed = fts / fts.sum(dim=0, keepdims=True)

fig, ax = plt.subplots(1, 5, sharex=True, sharey=True)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    im = ax[c].imshow(median_filter(10 * np.log10(normed[c]), [8, 8]), cmap='plasma', origin='lower', vmin=-9)
    
fig.set_size_inches(20, 5)
fig.subplots_adjust(right=0.8, wspace=0.05)
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

<IPython.core.display.Javascript object>

  im = ax[c].imshow(median_filter(10 * np.log10(normed[c]), [8, 8]), cmap='plasma', origin='lower', vmin=-9)


<matplotlib.colorbar.Colorbar at 0x7fd88f4720e0>

In [ ]:
plt.savefig(FIGURE_SAVE_DIR.joinpath('pcs relative k-space power.png'))

-------------------------------------------------------------------------------------------

In [9]:
# import pickle
# # 
# save_path = '/home/alon_granek/PythonProjects/NPPC/temp_data.pkl'
# with open(save_path, 'wb') as file:
#     pickle.dump(data, file)

NameError: name 'data' is not defined

In [13]:
# """ Load models of different steps """
# # Show how they behave on a given file
# nppc_nets = [nppc_wrapper.load(f'modl_nppc_dc pos {pos}') for pos in range(6)]#single_MoDL.num_grad_steps - 1)]
# input_positions = np.arange(0, 6) #single_MoDL.num_grad_steps - 1)
# mmse_positions = input_positions + 1
# preproc = lambda x, pos: cplx.abs(x[pos])[:, None, ...]
# 
# for iter, data in enumerate(mr_data_loader):
#     # Unpack training data: subsampled k-space, target image, k-space mask
#     y, x_true, mask = (obj.to(device) for obj in data)
#     break
# 
# # Reconstruct, apply NPPC, calculate image-domain error
# with torch.no_grad():
#     x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
#     w_mat_intermed = [nppc_nets[i](preproc(x_intermed, i), preproc(x_intermed, j)) for i, j in zip(input_positions, mmse_positions)]


		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...


In [14]:
# """ Reconstruct covariance """
# 
# # W D W
# # D is the eigenvectors
# example = 0
# w_mat = w_mat_intermed[example]
# 
# w_mat_ = w_mat.flatten(3)
# w_norms = w_mat_.norm(dim=3)
# w_hat_mat = w_mat_ / w_norms[:, :, None]
# 
# # D = torch.diag(w_norms[:, 0])
# # W = w_hat_mat[:, 0]
# # DW = torch.einsum('ck, ki -> ci', D, W)
# # WDW = torch.einsum('ci, cj -> ij', W, DW)

In [15]:
# # WORK ON TOTAL COVARIANCE CALCULATION, IN CASE OF USING NPPC ON MULTIPLE STEPS
# 
# #todo Turn this into a differentiable PCWalk object. PC vectors' positions in the shared basis are pre-determined, the function merely places them together, then does Gram Schmidt to find a new basis, then finds the transfer to that basis, then sum covariances and de-project to the basis.
# 
# #todo We know: lstsq is backprop-able by autograd.
# 
# #Till 18:30 work on it
# 
# 
# 
# from nppc.nppc import gram_schmidt
# from torch.nn.functional import pad
# 
# example = 0
# w_mat_steps = tuple(w[example, :, 0].flatten(1) for w in w_mat_intermed)
# w_norms_steps = tuple(torch.linalg.norm(w, axis=1) for w in w_mat_steps)
# 
# # Get shared basis
# w_mat_stack = torch.cat(w_mat_steps)
# shared = gram_schmidt(w_mat_stack[None])[0]
# 
# # Define coordinates of bases in these
# n_steps = len(w_mat_intermed)           #todo This is going to be the number of steps in MoDL
# left_padding = lambda step: step * nppc_params.n_dirs
# right_padding = lambda step: (n_steps - step - 1) * nppc_params.n_dirs
# left_pads = list(map(left_padding, range(n_steps)))
# right_pads = list(map(right_padding, range(n_steps)))
# pad_mat = lambda W, left, right: pad(W, (0, 0, left, right))
# w_mat_steps = tuple(map(pad_mat, w_mat_steps, left_pads, right_pads))
# w_norms_steps = tuple(pad(w_norms_steps[step], (left_pads[step], right_pads[step])) for step in range(n_steps))
# 
# 
# # selection = lambda step: (step * nppc_params.n_dirs, (step + 1) * nppc_params.n_dirs)
# # sels = list(map(selection, range(n_steps)))
# # w_mat_steps = tuple(pad(w_mat_stack[s[0]:s[1]], (0, 0, l, r)) for s, l, r in zip(sels, left_pads, right_pads))
# 
# # Represent covariances with that basis
# covs_by_shared = list()
# for step, norms in enumerate(w_norms_steps):
#     trans_pcs_to_shared = torch.linalg.lstsq(shared.T, w_mat_steps[step].T).solution.T
#     cov_by_pcs = torch.diag(norms)
#     covs_by_shared.append(trans_pcs_to_shared.T @ cov_by_pcs @ trans_pcs_to_shared)
# 
# # Get a total covariance
# total_cov_by_shared = torch.stack(covs_by_shared).sum(0)
# 
# # Get total PCA
# pca_vals, pca_vecs_by_shared = torch.linalg.eigh(total_cov_by_shared)
# argsort_vecs = torch.argsort(-pca_vals)
# pca_vals = pca_vals[argsort_vecs]
# pca_vecs_by_shared = pca_vecs_by_shared[argsort_vecs]
# 
# # Transform to pixel domain
# dim1, dim2 = w_mat_intermed[0].shape[-2:]
# pca_vecs = torch.einsum('bi, bc -> ci', shared, pca_vecs_by_shared).reshape(shared.shape[0], dim1, dim2).flip(0)
# # trans_shared_to_pixel = torch.linalg.lstsq(shared.T, )
# # pca_vec = 
# 
# 
# 
# #     trans_0_to_shared = torch.linalg.lstsq(basis, w_mat[0].T).solution.T
# # cov_by_0 = torch.diag(w_norms[0])
# # cov_from__by_shared = trans_0_to_shared.T @ cov_by_0 @ trans_0_to_shared
# # 
# 
# # U, sigma, Vt = torch.linalg.svd(w_mat_stack)
# 
# # Enforce


In [18]:
# # Animation of adding or removing the PC
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# 
# c = 2
# example = 1
# 
# # Create a sample 3D array (e.g., 10 slices of 2D images)
# alphas = np.linspace(-1, 1, 50)
# image_series = mmse.detach()[example, 0][None] + alphas[:, None, None] * (w_mat.detach().numpy()[example, c, 0] - w_mat.detach().numpy()[example, c, 0].mean())[None]
# # mask = BrainMasker()(mmse.detach()[0, 0].numpy())
# 
# # Set up the figure and axis
# fig, ax = plt.subplots()
# im = ax.imshow(image_series[0], cmap='Greys_r', vmin=0, vmax=1.5)
# fig.set_size_inches(10, 10)
# 
# def update(frame):
#     # Update the image data
#     ax.set_title(round(alphas[frame], 2))
#     im.set_array(image_series[frame]) # * mask)
#     return [im]
# 
# # Create animation
# ani = animation.FuncAnimation(fig, update, frames=alphas.size, repeat=True, interval=30)
# 
# # Display the animation
# # plt.show()
# FFwriter = animation.FFMpegWriter(fps=10)
# ani.save('/home/alon_granek/PythonProjects/NPPC/ani3.avi', writer=FFwriter)

<IPython.core.display.Javascript object>

INFO:matplotlib.animation:Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
INFO:matplotlib.animation:MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s 1000x1000 -pix_fmt rgba -r 10 -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y /home/alon_granek/PythonProjects/NPPC/ani3.avi
