# NPPC applied on MR image reconstruction

In [1]:
%matplotlib notebook
import os, sys
import logging
import numpy as np
import torch
import torchvision
from utils import complex_utils as cplx
from utils.datasets import SliceData
from MoDL_single import UnrolledModel
from pathlib import Path

from datasets import Namespace, create_data_loaders
from modl_infrastructure import MoDLParams, MoDLWrapper
from alon.MoDLsinglechannel.demo_modl_singlechannel.subsample_fastmri import SaveableMask
from alon.config import *

from nppc_infrastructure import NPPCParams, NPPCForMoDLWrapper


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 0



## Load modules

### 1. Undersampling mask

In [2]:
# Load a mask
#   Ask for a UUID, else take the most recently-created one
mask_name = input('Enter mask name (if skipping, taking the most recently created): ')
if not mask_name.replace(' ', ''):
    latest_file = max(Path(MASKS_DIR).glob('*.pkl'), key=os.path.getctime)
    mask_name = latest_file.stem
    print(f'User skipped, thus loaded the latest mask {mask_name}')

mask = SaveableMask(masks_dir=MASKS_DIR)
mask.load(mask_name)

User skipped, thus loaded the latest mask Initial Poisson acc 6 calib (28 28)
Successfully loaded mask generator from /home/alon_granek/PythonProjects/NPPC/alon/masks/Initial Poisson acc 6 calib (28 28).pkl


### 2. MoDL

In [3]:
modl_params = MoDLParams()
modl_train_loader = create_data_loaders(modl_params, mask)
modl_checkpoint_dir = Path('/home/alon_granek/PythonProjects/NPPC/alon/checkpoints2')
modl_wrapper = MoDLWrapper(modl_params, modl_checkpoint_dir)
single_MoDL = modl_wrapper.load(modl_checkpoint_dir,
                                # model_name='Initial 4-step MoDL (2)', #'test',
                                model_name='MoDL 4-step regul 0.01',
                                epoch='last')

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002452.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002467.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002469.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002512.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002629.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002876.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002888.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002914.h5.npy...
	/mnt/c/Users/along/brain_multicoil_t

**Optional test run of MoDL**

In [4]:
modl_wrapper.test_run(single_MoDL, modl_train_loader)

iter=15...
		Step 0/4...
		Step 1/4...
		Step 2/4...
		Step 3/4...


<IPython.core.display.Javascript object>

ValueError: invalid literal for int() with base 10: ''

### 3. NPPC

In [5]:
nppc_params = NPPCParams()
nppc_checkpoint_dir = Path('/home/alon_granek/PythonProjects/NPPC/alon/nppc_checkpoints')
nppc_wrapper = NPPCForMoDLWrapper(nppc_params, single_MoDL, save_dir=nppc_checkpoint_dir)

nppc_name: str = {
    # No DC loss, original NPPC edited for dual-channel data
    '1.0': 'modl_nppc_dc DUAL NO-DC fixed',
    
    # With DC loss                                              (July meeting)
    '1.1': 'modl_nppc_dc DUAL DC',
    
    # DC as an activation
    '2.0': 'modl_nppc_dc DUAL MaskAct 10dirs',
    
    # DC as an activation - but prior to Gram Schmidt
    '2.1': 'modl_nppc_dc DUAL MaskActPreGS',
    
    # Alternate between GS and DC Activation                    (Current as of August)
    '2.2': 'modl_nppc_dc DUAL AlterProj',
}[
    # --- Chosen ---
    '2.2'
]

nppc_net = nppc_wrapper.load(nppc_name)

## Run on a file

In [8]:
for iter, data in enumerate(modl_train_loader):
    # Unpack data: undersampled k-space, target image, k-space mask
    y, x_true, mask_tensor = (obj.to(device) for obj in data)
    break

# """ Or use a temporarily saved file """
# import pickle
# with open('/home/alon_granek/PythonProjects/NPPC/temp_data.pkl', 'rb') as file:
#     y, x_true, mask = pickle.load(file)

In [9]:
preproc = lambda x, pos: x[pos].permute(0, 3, 1, 2)

with torch.no_grad():
    x_recon, x_intermed = single_MoDL(y.float(), mask=mask_tensor, return_steps=True)
    w_mat_intermed = nppc_net(preproc(x_intermed, 0), preproc(x_intermed, -1), mask_tensor)

		Step 0/4...
		Step 1/4...
		Step 2/4...
		Step 3/4...


## Analysis

### _k-space error distribution_

In [10]:
import matplotlib.pyplot as plt
from alon.fastmri_preprocess import fftc
from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.transforms import fft2, ifft2

example = 0

fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
# ax[0].imshow(10 * np.log10(cplx.abs(fft2(x_intermed[0][example]))), cmap='inferno', origin='lower')
im = ax[0].imshow(10 * np.log10(cplx.abs(fft2(x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=-30)
ax[1].imshow(10 * np.log10(cplx.abs(ifft2(x_recon[example]))), cmap='inferno', origin='lower', vmin=-30)
ax[2].imshow(10 * np.log10(cplx.abs(ifft2(x_recon[example] - x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=-30)
             
             # Adjusting the figure size and layout
fig.set_size_inches(20, 5)
fig.subplots_adjust(right=0.8, wspace=0.05)

# Adding a separate axis for the colorbar
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

<IPython.core.display.Javascript object>

  im = ax[0].imshow(10 * np.log10(cplx.abs(fft2(x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=-30)


<matplotlib.colorbar.Colorbar at 0x7fa66320cfa0>

### _PCs_

In [11]:
import matplotlib.pyplot as plt
from alon.fastmri_preprocess import fftc
# plt.style.use('dark_background')
sample = 1

# Showing all PCs
fig, ax = plt.subplots(3, 5, sharex=True, sharey=True)
for c in range(5):
    eigval = (torch.linalg.norm(w_mat_intermed.detach()[sample, c, 0]) + torch.linalg.norm(w_mat_intermed.detach()[sample, c, 1])) / torch.prod(torch.tensor(w_mat_intermed.detach()[sample, c, 0].size()))
    ax[0, c].set_title('Eig. val.: {:e}'.format(eigval))
    for re_im in [0, 1]:
        # im = ax[re_im, c].imshow(w_mat_intermed.detach()[sample, c, re_im] / abs(w_mat_intermed.detach()[sample, c, re_im]).max(), vmin=-0.3, vmax=0.3, cmap='RdBu', origin='lower')
        im = ax[re_im, c].imshow(w_mat_intermed.detach()[sample, c, re_im], vmin=-0.3, vmax=0.3, cmap='RdBu', origin='lower')

    # # Showing spectra instead
    image = w_mat_intermed.detach()[sample, c].permute(1, 2, 0)[None]
    ax[2, c].imshow(10 * np.log10(np.maximum(cplx.abs(fft2(image))[0], 1e-4)), cmap='inferno', origin='lower')
    
    # im = ax[c].imshow(10 * np.log10(abs(fftc(image))), cmap='inferno', origin='lower')
    # im = ax[c].imshow(np.sign(fftc(image).real) * 10 * np.log10(abs(fftc(image))), origin='lower', cmap='RdBu')
    #im = ax[c].imshow(fftc(image).real, origin='lower', cmap='RdBu', vmin=-0.05, vmax=0.05)

# Adjusting the figure size and layout
fig.set_size_inches(30, 15)
# fig.subplots_adjust(right=0.8, wspace=0.05)
fig.subplots_adjust(right=0.8, wspace=0.05)

# Adding a separate axis for the colorbar
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

# Showing the figure
plt.show()

# fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp4 (5).png')
# fig.show()

<IPython.core.display.Javascript object>

### _Relative k-space power distribution of PCs_

In [ ]:
sample = 0
fts = torch.tensor([abs(fftc(w_mat_intermed.detach()[sample, c, 0])) for c in range(5)])
# normed = w_mat_intermed.detach()[sample, :, 0] / w_mat_intermed.detach()[sample, :, 0].sum(axis=0, keepdims=True)
normed = fts / fts.sum(dim=0, keepdims=True)

from scipy.ndimage import median_filter

fig, ax = plt.subplots(1, 5, sharex=True, sharey=True)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    im = ax[c].imshow(median_filter(10 * np.log10(normed[c]), [8, 8]), cmap='plasma', origin='lower', vmin=-9)
    
# Adjusting the figure size and layout
fig.set_size_inches(20, 5)
fig.subplots_adjust(right=0.8, wspace=0.05)

# Adding a separate axis for the colorbar
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

# Showing the figure
plt.show()
# fig.savefig('/home/alon_granek/PythonProjects/NPPC/kspace pcs DC.png')


...

In [ ]:
""" IN THAT ATTEMPT - TEST MODEL """
# Show how they behave on a given file
#nppc_net = nppc_wrapper.load(f'modl_nppc_dc END-TO-END NO-DC (faster)')
# nppc_net = nppc_wrapper.load(f'modl_nppc_dc END-TO-END DC test')#(l 8)')
# preproc = lambda x, pos: cplx.abs(x[pos])[:, None, ...]
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL NO-DC fixed')
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL DC')
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL DC20 NullErr')

# DC as an activation
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL MaskAct 10dirs')

# DC as an activation - but prior to Gram Schmidt
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL MaskActPreGS')

# Alternate between GS and Fourier Activation
nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL AlterProj')






""" Load new samples """
for iter, data in enumerate(modl_train_loader):
    # Unpack training data: subsampled k-space, target image, k-space mask
    y, x_true, mask = (obj.to(device) for obj in data)
    break
# """ Or use temporarily saved ones """
# import pickle
# with open('/home/alon_granek/PythonProjects/NPPC/temp_data.pkl', 'rb') as file:
#     y, x_true, mask = pickle.load(file)
# 

# Reconstruct, apply NPPC, calculate image-domain error
# from nppc_infrastructure import fourier_activation

preproc = lambda x, pos: x[pos].permute(0, 3, 1, 2) #[:, None, ...]

with torch.no_grad():
    x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    w_mat_intermed = nppc_net(preproc(x_intermed, 0), preproc(x_intermed, -1), mask)
    # w_mat_intermed = fourier_activation(w_mat_intermed, mask)


# Data preparation

### 1. Mask

In [5]:
from datasets import Namespace, create_data_loaders

### 2. MoDL setup

In [8]:
from modl_infrastructure import MoDLParams, MoDLWrapper
from alon.MoDLsinglechannel.demo_modl_singlechannel.subsample_fastmri import SaveableMask
from alon.config import *

# Load a mask
#   Ask for a UUID, else take the most recently-created one
mask_name = input('Enter mask name (if skipping, taking the most recently created): ')
if not mask_name.replace(' ', ''):
    latest_file = max(Path(MASKS_DIR).glob('*.pkl'), key=os.path.getctime)
    mask_name = latest_file.stem
    print(f'User skipped, thus loaded the latest mask {mask_name}')

mask = SaveableMask(masks_dir=MASKS_DIR)
mask.load(mask_name)


modl_params = MoDLParams()
modl_train_loader = create_data_loaders(modl_params, mask)
modl_checkpoint_dir = Path('/home/alon_granek/PythonProjects/NPPC/alon/checkpoints2')
modl_wrapper = MoDLWrapper(modl_params, modl_checkpoint_dir)

User skipped, thus loaded the latest mask Initial Poisson acc 6 calib (28 28)
Successfully loaded mask generator from /home/alon_granek/PythonProjects/NPPC/alon/masks/Initial Poisson acc 6 calib (28 28).pkl
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002452.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002467.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002469.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002512.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002629.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6002876.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_ba

### 3. Train MoDL, or load a MoDL model

In [9]:
""" Loading """
single_MoDL = modl_wrapper.load(modl_checkpoint_dir,
                                # model_name='Initial 4-step MoDL (2)', #'test',
                                model_name='MoDL 4-step regul 0.01',
                                epoch='last')

search_str:  Model MoDL 4-step regul 0.01  epoch *


In [5]:
""" MoDL test run """
modl_wrapper.test_run(single_MoDL, modl_train_loader)

iter=355...
		Step 0/4...
		Step 1/4...
		Step 2/4...
		Step 3/4...


<IPython.core.display.Javascript object>

Saved figure in: /home/alon_granek/PythonProjects/NPPC/alon/MoDLsinglechannel/temp1.png


<IPython.core.display.Javascript object>

Saved figure in: /home/alon_granek/PythonProjects/NPPC/alon/MoDLsinglechannel/temp2.png


ValueError: invalid literal for int() with base 10: ''

In [10]:
from nppc_infrastructure import NPPCParams, NPPCForMoDLWrapper

nppc_params = NPPCParams()
nppc_checkpoint_dir = Path('/home/alon_granek/PythonProjects/NPPC/alon/nppc_checkpoints')
nppc_wrapper = NPPCForMoDLWrapper(nppc_params, single_MoDL, save_dir=nppc_checkpoint_dir)

In [12]:
""" IN THAT ATTEMPT - TEST MODEL """
# Show how they behave on a given file
#nppc_net = nppc_wrapper.load(f'modl_nppc_dc END-TO-END NO-DC (faster)')
# nppc_net = nppc_wrapper.load(f'modl_nppc_dc END-TO-END DC test')#(l 8)')
# preproc = lambda x, pos: cplx.abs(x[pos])[:, None, ...]
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL NO-DC fixed')
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL DC')
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL DC20 NullErr')

# DC as an activation
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL MaskAct 10dirs')

# DC as an activation - but prior to Gram Schmidt
# nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL MaskActPreGS')

# Alternate between GS and Fourier Activation
nppc_net = nppc_wrapper.load('modl_nppc_dc DUAL AlterProj')
preproc = lambda x, pos: x[pos].permute(0, 3, 1, 2) #[:, None, ...]


""" Load new samples """
for iter, data in enumerate(modl_train_loader):
    # Unpack training data: subsampled k-space, target image, k-space mask
    y, x_true, mask = (obj.to(device) for obj in data)
    break
# """ Or use temporarily saved ones """
# import pickle
# with open('/home/alon_granek/PythonProjects/NPPC/temp_data.pkl', 'rb') as file:
#     y, x_true, mask = pickle.load(file)
# 

# Reconstruct, apply NPPC, calculate image-domain error
# from nppc_infrastructure import fourier_activation

with torch.no_grad():
    x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    w_mat_intermed = nppc_net(preproc(x_intermed, 0), preproc(x_intermed, -1), mask)
    # w_mat_intermed = fourier_activation(w_mat_intermed, mask)


		Step 0/4...
		Step 1/4...
		Step 2/4...
		Step 3/4...


In [9]:
import pickle

save_path = '/home/alon_granek/PythonProjects/NPPC/temp_data.pkl'
with open(save_path, 'wb') as file:
    pickle.dump(data, file)

NameError: name 'data' is not defined

In [18]:
import matplotlib.pyplot as plt
from alon.fastmri_preprocess import fftc
from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.transforms import fft2, ifft2

example = 0

fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
# ax[0].imshow(10 * np.log10(cplx.abs(fft2(x_intermed[0][example]))), cmap='inferno', origin='lower')
im = ax[0].imshow(10 * np.log10(cplx.abs(fft2(x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=-30)
ax[1].imshow(10 * np.log10(cplx.abs(ifft2(x_recon[example]))), cmap='inferno', origin='lower', vmin=-30)
ax[2].imshow(10 * np.log10(cplx.abs(ifft2(x_recon[example] - x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=-30)
             
             # Adjusting the figure size and layout
fig.set_size_inches(20, 5)
fig.subplots_adjust(right=0.8, wspace=0.05)

# Adding a separate axis for the colorbar
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

<IPython.core.display.Javascript object>

  im = ax[0].imshow(10 * np.log10(cplx.abs(fft2(x_intermed[0][example]))), cmap='inferno', origin='lower', vmin=-30)


<matplotlib.colorbar.Colorbar at 0x7ff1bd0bbe80>

In [39]:
cplx.abs(fft2(x_intermed[example])).shape

torch.Size([2, 372, 372])

In [17]:
sample = 0
fts = torch.tensor([abs(fftc(w_mat_intermed.detach()[sample, c, 0])) for c in range(5)])
# normed = w_mat_intermed.detach()[sample, :, 0] / w_mat_intermed.detach()[sample, :, 0].sum(axis=0, keepdims=True)
normed = fts / fts.sum(dim=0, keepdims=True)

from scipy.ndimage import median_filter

fig, ax = plt.subplots(1, 5, sharex=True, sharey=True)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    im = ax[c].imshow(median_filter(10 * np.log10(normed[c]), [8, 8]), cmap='plasma', origin='lower', vmin=-9)
    
# Adjusting the figure size and layout
fig.set_size_inches(20, 5)
fig.subplots_adjust(right=0.8, wspace=0.05)

# Adding a separate axis for the colorbar
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

# Showing the figure
plt.show()
# fig.savefig('/home/alon_granek/PythonProjects/NPPC/kspace pcs DC.png')


  fts = torch.tensor([abs(fftc(w_mat_intermed.detach()[sample, c, 0])) for c in range(5)])


<IPython.core.display.Javascript object>

In [23]:
import matplotlib.pyplot as plt
from alon.fastmri_preprocess import fftc
# plt.style.use('dark_background')
sample = 1

# Showing all PCs
fig, ax = plt.subplots(3, 5, sharex=True, sharey=True)
for c in range(5):
    eigval = (torch.linalg.norm(w_mat_intermed.detach()[sample, c, 0]) + torch.linalg.norm(w_mat_intermed.detach()[sample, c, 1])) / torch.prod(torch.tensor(w_mat_intermed.detach()[sample, c, 0].size()))
    ax[0, c].set_title('Eig. val.: {:e}'.format(eigval))
    for re_im in [0, 1]:
        # im = ax[re_im, c].imshow(w_mat_intermed.detach()[sample, c, re_im] / abs(w_mat_intermed.detach()[sample, c, re_im]).max(), vmin=-0.3, vmax=0.3, cmap='RdBu', origin='lower')
        im = ax[re_im, c].imshow(w_mat_intermed.detach()[sample, c, re_im], vmin=-0.3, vmax=0.3, cmap='RdBu', origin='lower')

    # # Showing spectra instead
    image = w_mat_intermed.detach()[sample, c].permute(1, 2, 0)[None]
    ax[2, c].imshow(10 * np.log10(np.maximum(cplx.abs(fft2(image))[0], 1e-4)), cmap='inferno', origin='lower')
    
    # im = ax[c].imshow(10 * np.log10(abs(fftc(image))), cmap='inferno', origin='lower')
    # im = ax[c].imshow(np.sign(fftc(image).real) * 10 * np.log10(abs(fftc(image))), origin='lower', cmap='RdBu')
    #im = ax[c].imshow(fftc(image).real, origin='lower', cmap='RdBu', vmin=-0.05, vmax=0.05)

# Adjusting the figure size and layout
fig.set_size_inches(30, 15)
# fig.subplots_adjust(right=0.8, wspace=0.05)
fig.subplots_adjust(right=0.8, wspace=0.05)

# Adding a separate axis for the colorbar
cbar_ax = fig.add_axes([0.83, 0.3, 0.015, 0.4])
fig.colorbar(mappable=im, cax=cbar_ax)

# Showing the figure
plt.show()

# fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp4 (5).png')
# fig.show()

<IPython.core.display.Javascript object>

In [32]:
def fourier_activation(w_mat: torch.Tensor, mask: torch.Tensor):
    """
    Activation of sampling mask

    :param w_mat:
    :param mask:
    :return:
    """
    w_mat_ksp = fft2(w_mat.permute(0, 3, 4, 1, 2))
    w_mat_ksp_filt = (1 - mask[0, :, :, 0][None, :, :, None, None]) * w_mat_ksp
    w_mat_filt = ifft2(w_mat_ksp_filt).permute(0, 3, 4, 1, 2)
    # 'cije, eijkc'
    return w_mat_filt

In [34]:
from nppc.nppc import gram_schmidt

w_mat = torch.clone(w_mat_intermed)
for i in range(5):
    w_mat = gram_schmidt(fourier_activation(w_mat, mask))

In [25]:
u = w_mat_intermed.detach()[sample, 0, 0] / abs(w_mat_intermed.detach()[sample, 0, 0]).max()
v = w_mat_intermed.detach()[sample, 1, 0] / abs(w_mat_intermed.detach()[sample, 1, 0]).max()

In [27]:
u.flatten().dot(v.flatten()) / u.flatten().dot(u.flatten())

tensor(0.9844)

In [25]:
plt.figure()
plt.imshow(10 * np.log10(abs(fftc(torch.linalg.norm(w_mat_intermed.detach()[sample, :, 0], axis=0)))))

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fee38221000>

In [43]:
fig, ax = plt.subplots(1, 2 + len(x_intermed), sharex=True, sharey=True)
ax[0].set_title('Target')
ax[0].imshow(torch.flipud(cplx.abs(x_true)[sample]), cmap='Greys_r', vmin=0, vmax=1.2)
ax[1].set_title('6x accelerated zero-filled')
ax[1].imshow(torch.flipud(cplx.abs(x_intermed[0])[sample].detach()), cmap='Greys_r', vmin=0, vmax=1.2)
for i in range(2, 2 + len(x_intermed)):
    ax[i].set_title(f'Recon step {i - 2}')
    ax[i].imshow(torch.flipud(cplx.abs(x_intermed[i - 2][sample].detach())), cmap='Greys_r', vmin=0,
                 vmax=1.2)
fig.set_size_inches(30, 7)
plt.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/modl.png')
fig.show()

<IPython.core.display.Javascript object>

In [7]:
for pos in range(single_MoDL.num_grad_steps - 1):
    print(f'TRAINING ON MODL POSITION {pos}')
    # if pos in [0, 1]:
    #     continue
    #todo Train for [6, 7]
    nppc_wrapper.train(modl_train_loader, nppc_position=pos, model_name=f'modl_nppc_dc pos {pos}')
    nppc_wrapper.reset()

TRAINING ON MODL POSITION 0
Iter 0...
		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...
	Reconstructed
	Estimated principal components
Iter 1...
		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...
	Reconstructed
	Estimated principal components
Iter 2...
		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...
	Reconstructed
	Estimated principal components
Iter 3...
		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...
	Reconstructed
	Estimated principal components
Loss: 7.190761089324951
Iter 2...
		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...
	Reconstructed
	Estimated principal components
Iter 3...
		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/

KeyboardInterrupt: 

In [13]:
""" Load models of different steps """
# Show how they behave on a given file
nppc_nets = [nppc_wrapper.load(f'modl_nppc_dc pos {pos}') for pos in range(6)]#single_MoDL.num_grad_steps - 1)]
input_positions = np.arange(0, 6) #single_MoDL.num_grad_steps - 1)
mmse_positions = input_positions + 1
preproc = lambda x, pos: cplx.abs(x[pos])[:, None, ...]

for iter, data in enumerate(modl_train_loader):
    # Unpack training data: subsampled k-space, target image, k-space mask
    y, x_true, mask = (obj.to(device) for obj in data)
    break

# Reconstruct, apply NPPC, calculate image-domain error
with torch.no_grad():
    x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    w_mat_intermed = [nppc_nets[i](preproc(x_intermed, i), preproc(x_intermed, j)) for i, j in zip(input_positions, mmse_positions)]


		Step 0/8...
		Step 1/8...
		Step 2/8...
		Step 3/8...
		Step 4/8...
		Step 5/8...
		Step 6/8...
		Step 7/8...


In [48]:
""" Display the PCs """
import matplotlib.pyplot as plt

sample = 0

# Showing all PCs
fig, ax = plt.subplots(2, 5, sharex=True, sharey=True)
for pos in [0, 1]:
    for c in range(5):
        ax[pos, c].set_title(f'PC {c + 1}')
        ax[pos, c].imshow(w_mat_intermed[pos].detach()[sample, c, 0], vmin=-0.5, vmax=0.5, cmap='RdBu')
# fig.set_size_inches(20, 6)
fig.set_size_inches(20, 10)
fig.tight_layout()
# fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp2.png')
fig.show()

<IPython.core.display.Javascript object>

In [ ]:
""" Branch new examples from the existing PC subspace """


In [ ]:
""" Insane ideas to combine them - PCA on the smallest required point cloud """

points = w_mat_intermed[0]

In [14]:
""" Reconstruct covariance """

# W D W
# D is the eigenvectors
example = 0
w_mat = w_mat_intermed[example]

w_mat_ = w_mat.flatten(3)
w_norms = w_mat_.norm(dim=3)
w_hat_mat = w_mat_ / w_norms[:, :, None]

# D = torch.diag(w_norms[:, 0])
# W = w_hat_mat[:, 0]
# DW = torch.einsum('ck, ki -> ci', D, W)
# WDW = torch.einsum('ci, cj -> ij', W, DW)

In [29]:
w_mat.shape

torch.Size([2, 5, 1, 372, 372])

In [15]:
#todo Turn this into a differentiable PCWalk object. PC vectors' positions in the shared basis are pre-determined, the function merely places them together, then does Gram Schmidt to find a new basis, then finds the transfer to that basis, then sum covariances and de-project to the basis.

#todo We know: lstsq is backprop-able by autograd.

#Till 18:30 work on it



from nppc.nppc import gram_schmidt
from torch.nn.functional import pad

example = 0
w_mat_steps = tuple(w[example, :, 0].flatten(1) for w in w_mat_intermed)
w_norms_steps = tuple(torch.linalg.norm(w, axis=1) for w in w_mat_steps)

# Get shared basis
w_mat_stack = torch.cat(w_mat_steps)
shared = gram_schmidt(w_mat_stack[None])[0]

# Define coordinates of bases in these
n_steps = len(w_mat_intermed)           #todo This is going to be the number of steps in MoDL
left_padding = lambda step: step * nppc_params.n_dirs
right_padding = lambda step: (n_steps - step - 1) * nppc_params.n_dirs
left_pads = list(map(left_padding, range(n_steps)))
right_pads = list(map(right_padding, range(n_steps)))
pad_mat = lambda W, left, right: pad(W, (0, 0, left, right))
w_mat_steps = tuple(map(pad_mat, w_mat_steps, left_pads, right_pads))
w_norms_steps = tuple(pad(w_norms_steps[step], (left_pads[step], right_pads[step])) for step in range(n_steps))


# selection = lambda step: (step * nppc_params.n_dirs, (step + 1) * nppc_params.n_dirs)
# sels = list(map(selection, range(n_steps)))
# w_mat_steps = tuple(pad(w_mat_stack[s[0]:s[1]], (0, 0, l, r)) for s, l, r in zip(sels, left_pads, right_pads))

# Represent covariances with that basis
covs_by_shared = list()
for step, norms in enumerate(w_norms_steps):
    trans_pcs_to_shared = torch.linalg.lstsq(shared.T, w_mat_steps[step].T).solution.T
    cov_by_pcs = torch.diag(norms)
    covs_by_shared.append(trans_pcs_to_shared.T @ cov_by_pcs @ trans_pcs_to_shared)

# Get a total covariance
total_cov_by_shared = torch.stack(covs_by_shared).sum(0)

# Get total PCA
pca_vals, pca_vecs_by_shared = torch.linalg.eigh(total_cov_by_shared)
argsort_vecs = torch.argsort(-pca_vals)
pca_vals = pca_vals[argsort_vecs]
pca_vecs_by_shared = pca_vecs_by_shared[argsort_vecs]

# Transform to pixel domain
dim1, dim2 = w_mat_intermed[0].shape[-2:]
pca_vecs = torch.einsum('bi, bc -> ci', shared, pca_vecs_by_shared).reshape(shared.shape[0], dim1, dim2).flip(0)
# trans_shared_to_pixel = torch.linalg.lstsq(shared.T, )
# pca_vec = 



#     trans_0_to_shared = torch.linalg.lstsq(basis, w_mat[0].T).solution.T
# cov_by_0 = torch.diag(w_norms[0])
# cov_from__by_shared = trans_0_to_shared.T @ cov_by_0 @ trans_0_to_shared
# 

# U, sigma, Vt = torch.linalg.svd(w_mat_stack)

# Enforce


In [30]:
w_mat_steps[0].T.shape

torch.Size([372, 377, 1, 5])

In [21]:
""" Display the PCs """
import matplotlib.pyplot as plt

sample = 0
mmse = preproc(x_intermed, -1)[sample, 0]

# Showing all PCs
n = 20 #10
fig, ax = plt.subplots(1, n + 1, sharex=True, sharey=True)
ax[0].imshow(mmse, vmin=0, vmax=1.5, cmap='Greys_r', origin='lower')
for c in range(n):
    ax[c + 1].set_title(f'PC {c + 1}\nEig. val. {pca_vals[c]}')
    ax[c + 1].imshow(pca_vecs[c], vmin=-0.2, vmax=0.2, cmap='RdBu', origin='lower')
# fig.set_size_inches(20, 6)
fig.set_size_inches(70, 10)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/NPPC random walk PCs (first 6 steps out of 8).png')
# fig.show()

<IPython.core.display.Javascript object>

In [38]:
mmse.shape

torch.Size([2, 1, 372, 372])

In [20]:
""" Attempt in rough skull stripping """

from alon.brain_masking import BrainMasker

mask = BrainMasker()(recon_image[sample])[0]


# Connecting MoDL to NPPC

**Naive approach:** Use the whole unrolled network for PC estimation. The PC subspace is expected to be data consistent, as there is no variance on the projection of k-space one-hot vectors corresponding to the measured k's.

In [9]:
from nppc_architecture import UNet
from nppc import PCWrapper

n_dirs = 5

# Initialize frozen sampling mask
for iter, data in enumerate(train_loader):
    mask = data[-1].to(device)
    mask_ifft = torch.fft.ifft2(torch.fft.ifftshift(mask, axis=(1, 2)), axis=(1, 2))
    break

nppc_net = PCWrapper(UNet(in_channels=1 + 1, out_channels=1 * n_dirs), n_dirs=n_dirs) #, mask_ifft=mask_ifft) #, mask=1) #mask)
nppc_net.__setattr__('ddp', Namespace(size=1))

nppc_net.to(device)
nppc_net.train()
nppc_optimizer = torch.optim.Adam(
    nppc_net.parameters(), lr=1e-4, #lr=1e-5,
    betas=(0.9, 0.999)
)
nppc_step = 0

restoration_n_steps = 4000 #1000 #3000
nppc_n_steps = 3000 #3000
batch_size = 8 #2 #64 #128 #64 #128       #256

second_moment_loss_lambda = 1e0 #0.2 #1e0
second_moment_loss_grace = 500

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/usr/lib/python3.10/shutil.py", line 731, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/usr/lib/python3.10/shutil.py", line 729, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-o5q0m1nh'


In [ ]:
from dataclasses import dataclass
from typing import Tuple


@dataclass
class NPPCParams:
    # Inference params
    n_dirs: int = 5
    nppc_step: int = 0
    restoration_n_steps: int = 4000
    nppc_n_steps: int = 3000
    
    # Train params
    #   How many samples to apply NPPC in parallel to
    parallel_est_size: int = 8
    #   How many parallel runs, to define as a batch? On each batch, we do a backprop.
    batch_size: int = 4
    #   Loss components
    second_moment_loss_lambda: float = 1e0
    second_moment_loss_grace: float = 500
    #       Data consistency loss component. Set to 0 if we don't want to calculate it
    dc_loss_lambda: float = 5e0 #1e0


train_loader_given_mask = create_data_loaders(params, frozen_mask=mask)


class NPPCForMoDLWrapper:
    def __init__(self, nppc_params: NPPCParams, modl: UnrolledModel = single_MoDL):
        self.nppc_params = nppc_params
        self.nppc_net = PCWrapper(UNet(in_channels=1 + 1, out_channels=1 * self.nppc_params.n_dirs), n_dirs=self.nppc_params.n_dirs) #, mask_ifft=mask_ifft)
        self.nppc_net.__setattr__('ddp', Namespace(size=1))
        self.nppc_net.to(device)
        self.nppc_net.train()
        self.nppc_optimizer = torch.optim.Adam(
            self.nppc_net.parameters(), lr=1e-4, #lr=1e-5,
            betas=(0.9, 0.999)
        )
        
        self.modl = modl
    
    def __call__(self):
        return
    
    def train(self, ):
        for rep in range(20): #12):
            start = repeats * rep
            losses = list()
            for iter, data in enumerate(train_loader_given_mask, start=start):
                print(f'Iter {iter}...')
                losses.append(get_batch_loss(data))
                if iter - start == repeats - 1:
                    break
            
            objective = torch.stack(losses).mean()
            
            # enum_lim = islice(enumerate(train_loader_given_mask, start=repeats * rep), repeats)
            # objective = torch.mean(torch.tensor([get_batch_loss(data) for iter, data in enum_lim]))
            nppc_optimizer.zero_grad()
            objective.backward() #retain_graph=True)
            nppc_optimizer.step()
            nppc_step += 1
        
            print(f'Loss: {objective.detach().item()}')
        
            torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed2test.pth')

        return
    
    def get_batch_loss(self, data: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], nppc_position: int = 0):
        """
        
        :param data: 
        :param nppc_position:       Position of input to the NPPC (not position of the MMSE)
        :return: 
        """
        assert nppc_position < params.num_grad_steps, f'Chosen NPPC placement ({nppc_position}) is out of bounds, not in [0, {params.num_grad_steps} - 1]'
        
        # Unpack training data: subsampled k-space, target image, k-space mask
        y, x_true, mask = (obj.to(device) for obj in data)
                
        # Reconstruct, apply NPPC, calculate image-domain error
        with torch.no_grad():
            x_recon, x_intermed = self.modl(y.float(), mask=mask, return_steps=True)
        print('\tReconstructed')
        
        step_input = cplx.abs(x_intermed[nppc_position])[:, None, ...]
        step_mmse = cplx.abs(x_intermed[nppc_position - 1])[:, None, ...]
        w_mat = nppc_net(step_input, step_mmse)
        print('\tEstimated principal components')
        
        w_mat_ = w_mat.flatten(2)
        w_norms = w_mat_.norm(dim=2)
        w_hat_mat = w_mat_ / w_norms[:, :, None]
        
        x_true = cplx.abs(x_true.detach())[:, None, ...]
        err = (x_true - step_mmse).flatten(1)
                
        ## Normalizing by the error's norm
        ## -------------------------------
        err_norm = err.norm(dim=1)
        err = err / err_norm[:, None]
        w_norms = w_norms / err_norm[:, None]
        
        ## W hat loss
        ## ----------
        err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
        reconst_err = 1 - err_proj.pow(2).sum(dim=1)
        
        ## W norms loss
        ## ------------
        second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
        
        ## (Alon, ShimronLab) Data-consistency loss
        ## ----------------------------------------
        dc_loss = 0
        if self.nppc_params.dc_loss_lambda != 0:
            MFW = torch.einsum(
                'cije, edij -> ec',
                mask, abs(torch.fft.fftshift(torch.fft.fft2(w_mat[:, :, 0], axis=(1, 2)), axis=(1, 2)))
            ) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
            dc_loss = torch.linalg.norm(MFW) / (MFW.shape[0] * MFW.shape[1])
        
        second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
        second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
        second_moment_loss_lambda *= second_moment_loss_lambda
        
        loss = reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean() + self.nppc_params.dc_loss_lambda * dc_loss
        return loss
        
    def reset(self, nppc_params: NPPCParams = None):
        self.__init__(self.nppc_params if nppc_params is None else nppc_params)


repeats = 4
batch_objectives = torch.zeros(repeats, dtype=torch.float32)


def get_batch_loss(data, nppc_position: int = 0):
    """
    
    :param data: 
    :param nppc_position:       Position of input to the NPPC (not position of the MMSE)
    :return: 
    """
    assert nppc_position < params.num_grad_steps, f'Chosen NPPC placement ({nppc_position}) is out of bounds, not in [0, {params.num_grad_steps} - 1]'
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
    
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
        print('\tReconstructed')
        
    # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
    # mmse = cplx.abs(x_recon)[:, None, ...]
    step_input = cplx.abs(x_intermed[nppc_position])[:, None, ...]
    step_mmse = cplx.abs(x_intermed[nppc_position - 1])[:, None, ...]
    # w_mat = nppc_net(problem_input, mmse)
    w_mat = nppc_net(step_input, step_mmse)
    
    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_true = cplx.abs(x_true.detach())[:, None, ...]
    err = (x_true - step_mmse).flatten(1)
    
    print('\tCalculated error')

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    ## (Alon, ShimronLab) Data-consistency loss
    ## ----------------------------------------
    # print('\tCalculating DC loss')
    # MW = torch.einsum('cije, edij -> ec', mask_ifft.float(), w_mat[:, :, 0].float()) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
    MFW = torch.einsum(
        'cije, edij -> ec',
        mask, abs(torch.fft.fftshift(torch.fft.fft2(w_mat[:, :, 0], axis=(1, 2)), axis=(1, 2)))
    ) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
    dc_loss = torch.linalg.norm(MFW) / (MFW.shape[0] * MFW.shape[1])
    dc_loss_lambda = 5e0 #1e0
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    
    return reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean() + dc_loss_lambda * dc_loss


## DAPA 2: Run on the last step in MoDL
**Rationale**
1. This is where **x** enters the true solution space
2. This is the smallest problem where uncertainty still exists (we trust the previous steps in MoDL)
3. It is a much easier problem

In [10]:
from alon.brain_masking import BrainMasker

In [38]:
for iter, data in enumerate(train_loader):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
        
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    
        # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
        mmse = cplx.abs(x_recon)[:, None, ...]
        problem_input = cplx.abs(x_intermed[-1])[:, None, ...]
        w_mat_orig = nppc_net(problem_input, mmse)
    
    """ Test - defining the error only on the masked image """
    # mask = BrainMasker()(mmse[0, 0])[0]#.flatten()
    # err *= mask[None]
    
    # w_mat = w_mat_orig * torch.tensor(BrainMasker()(mmse[0, 0])[0][None, None, None], dtype=torch.float32)
    w_mat = w_mat_orig
    """"""
    
    
    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_true = cplx.abs(x_true.detach())[:, None, ...]
    err = (x_true - mmse).flatten(1)
    

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    objective = reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean()

    nppc_optimizer.zero_grad()
    objective.backward()
    nppc_optimizer.step()
    nppc_step += 1
    
    
    print(f'Loss: {objective.detach().item()}')
    
    torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net3.pth')
        

		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.825690746307373
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.904462993144989
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8065555095672607
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8409633040428162
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7346203327178955
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.752144455909729
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8610484600067139
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7396346926689148
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8439711332321167
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8512775897979736
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.735094428062439
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8334081172943115
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7661327123641968
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7712322473526001
		Step 0/3...
		Step 1/3...


KeyboardInterrupt: 

## DAPA 2.1: On last MoDL step, but enforcing data consistency

It is hard not to train the posterior of the absolute value clean image given the absolute value sampling-noised image. That's why I do so in DAPA 2.

However, this degenerates the problem to a simple denoising problem, and so nothing makes the predicted PC subspace data-consistent, while it should be just that.

* Not entirely though. Since we train on the last MoDL iteration, the k-space interpolation is almost fully done. There is simply larger uncertainty where not measured vs where measured.

DAPA 2.1 attempts to continue learning on the absolute-values inverse problem, but ensure data-consistency on the PC subspace. That is in the hope that the only denoising done here is on the sampling noise.
 
We look for an inverse problem of the form

$$
\underset{\mathbf x}{\arg \min} \|\mathbf{MF}(\hat{\mathbf x} + \mathbf W \boldsymbol \alpha) - \mathbf y\|^2 + \lambda R(\hat{\mathbf x} + \mathbf W \boldsymbol \alpha)
$$

$\hat{\mathbf x}\quad$      MMSE estimate
$\mathbf W\quad$            Predicted PCs
$\boldsymbol \alpha\quad$   Coordinates with respect to PCs. Let any such coordinates.

Specifically for DC,
$$
\|\mathbf{MF}(\hat{\mathbf x} + \mathbf W \boldsymbol \alpha) - \mathbf y\|^2 \leq \|\mathbf{MF}\hat{\mathbf x} - \mathbf y\|^2 + \|\mathbf{MF}\mathbf W \boldsymbol \alpha\|^2
$$
Where the first term is already minimized by MoDL and is independent of PCs. We will attempt to add a $\|\mathbf{MW}\|^2$ loss.


**A major change is that we freeze the mask to properly train the NPPC for this inverse problem**

In [10]:
train_loader_given_mask = create_data_loaders(params, frozen_mask=mask)

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6003008.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000421.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000508.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000531.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_203_6000942.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_209_6001397.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002033.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002392.h5.npy...
	/mnt/c/Users/along/brain_multicoil

In [11]:
repeats = 4
batch_objectives = torch.zeros(repeats, dtype=torch.float32)


def get_batch_loss(data, nppc_position: int = 0):
    """
    
    :param data: 
    :param nppc_position:       Position of input to the NPPC (not position of the MMSE)
    :return: 
    """
    assert nppc_position < params.num_grad_steps, f'Chosen NPPC placement ({nppc_position}) is out of bounds, not in [0, {params.num_grad_steps} - 1]'
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
    
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
        print('\tReconstructed')
        
    # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
    # mmse = cplx.abs(x_recon)[:, None, ...]
    step_input = cplx.abs(x_intermed[nppc_position])[:, None, ...]
    step_mmse = cplx.abs(x_intermed[nppc_position - 1])[:, None, ...]
    # w_mat = nppc_net(problem_input, mmse)
    w_mat = nppc_net(step_input, step_mmse)
    
    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_true = cplx.abs(x_true.detach())[:, None, ...]
    err = (x_true - step_mmse).flatten(1)
    
    print('\tCalculated error')

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    ## (Alon, ShimronLab) Data-consistency loss
    ## ----------------------------------------
    # print('\tCalculating DC loss')
    # MW = torch.einsum('cije, edij -> ec', mask_ifft.float(), w_mat[:, :, 0].float()) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
    MFW = torch.einsum(
        'cije, edij -> ec',
        mask, abs(torch.fft.fftshift(torch.fft.fft2(w_mat[:, :, 0], axis=(1, 2)), axis=(1, 2)))
    ) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
    dc_loss = torch.linalg.norm(MFW) / (MFW.shape[0] * MFW.shape[1])
    dc_loss_lambda = 5e0 #1e0
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    
    return reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean() + dc_loss_lambda * dc_loss

In [ ]:
""" Training """

for rep in range(20): #12):
    start = repeats * rep
    losses = list()
    for iter, data in enumerate(train_loader_given_mask, start=start):
        print(f'Iter {iter}...')
        losses.append(get_batch_loss(data))
        if iter - start == repeats - 1:
            break
    
    objective = torch.stack(losses).mean()
    
    # enum_lim = islice(enumerate(train_loader_given_mask, start=repeats * rep), repeats)
    # objective = torch.mean(torch.tensor([get_batch_loss(data) for iter, data in enum_lim]))
    nppc_optimizer.zero_grad()
    objective.backward() #retain_graph=True)
    nppc_optimizer.step()
    nppc_step += 1

    print(f'Loss: {objective.detach().item()}')

    torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed2test.pth')


In [12]:
nppc_net = torch.load(
    # f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net3.pth'
    #f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed.pth'
    f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed2test.pth'
)

for iter, data in enumerate(train_loader_given_mask):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
        
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    
    # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
    mmse = cplx.abs(x_recon)[:, None, ...]
    problem_input = cplx.abs(x_intermed[-1])[:, None, ...]
    w_mat = nppc_net(problem_input, mmse)
    
    break

		Step 0/3...
		Step 1/3...
		Step 2/3...


In [28]:
from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.transforms import SenseModel_single

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
# ax[0].imshow(cplx.abs(x_intermed[0])[1])
ax[0].imshow(problem_input[1, 0])
ax[1].imshow(mmse[1, 0])

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f0d53714220>

In [23]:
""" Analysis of PCs in k-space - as they should all be 0 where not measured """
from alon.fastmri_preprocess import fftc, ifftc
import matplotlib.pyplot as plt

c = 4
example = 1

y_c = fftc(w_mat[example, c, 0].detach().numpy())
log_y_c = 10 * np.log10(abs(y_c))
m = mask[0, ..., example].detach().numpy().astype(bool)

# plt.figure()
# plt.hist([log_y_c[m], log_y_c[~m]], histtype='step', label=['measured', 'not measured'], density=True, bins=50)
# plt.legend()


fig, ax = plt.subplots(1, 2)
ax[0].imshow(ifftc(y_c * (1 - m)).real, vmin=-0.1, vmax=0.1, cmap='RdBu')
ax[0].set_title('zeroed')
ax[1].imshow(w_mat[example, c, 0].detach().numpy(), vmin=-0.1, vmax=0.1, cmap='RdBu')
ax[1].set_title('raw')

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'raw')

In [16]:
mask.shape

torch.Size([2, 372, 372, 2])

In [31]:
import matplotlib.pyplot as plt
# fig, ax = plt.subplots(1, 5)
# for c in range(5):
#     ax[c].imshow(w_mat[0, c, 0], cmap='inferno', origin='lower', vmin=-0.5, vmax=0.5)
# fig.set_size_inches(13, 5)

c = 0
alphas = np.linspace(-0.1, 0.1, 5)
fig, ax = plt.subplots(1, len(alphas))
for i, alpha in enumerate(alphas):
    ax[i].imshow(mmse.detach()[0, 0] + alpha * w_mat.detach()[0, c, 0], cmap='Greys_r', origin='lower', vmin=0, vmax=1.5)
fig.set_size_inches(30, 7)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp.png')

<IPython.core.display.Javascript object>

In [41]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# List of image file paths

# Initial image index
current_image_index = 0

c = 3
alphas = np.linspace(-1, 1, 4)

# Function to display the image at the current index
def display_image():
    img = mmse.detach()[0, 0] + alphas[current_image_index] * w_mat.detach()[0, c, 0] #mpimg.imread(image_files[current_image_index])
    ax.imshow(img, cmap='Greys_r', vmin=0, vmax=1.5, origin='lower')
    ax.set_title(f"Alpha {round(alphas[current_image_index], 2)}")
    fig.canvas.draw()

# Function to handle key press events
def on_key(event):
    global current_image_index
    if event.key == 'up':
        current_image_index = (current_image_index + 1) % alphas.size #len(image_files)
    elif event.key == 'down':
        current_image_index = (current_image_index - 1) % alphas.size #len(image_files)
    display_image()

# Create a figure and axis
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.2)
fig.set_size_inches(10, 10)
fig.tight_layout()

# Display the initial image
display_image()

# Connect the key press event to the handler
fig.canvas.mpl_connect('key_press_event', on_key)

# Display the plot
plt.show()


<IPython.core.display.Javascript object>

In [16]:
""" Showing all PCs """

import matplotlib.pyplot as plt

example = 1

fig, ax = plt.subplots(1, 5)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    # ax[c].imshow(w_mat.detach()[example, c, 0] - w_mat.detach()[example, c, 0].mean(), vmin=-0.2, vmax=0.2, cmap='RdBu')
    ax[c].imshow(w_mat.detach()[example, c, 0], vmin=-0.2, vmax=0.2, cmap='RdBu')
fig.set_size_inches(20, 6)
# fig.set_size_inches(10, 5)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp2.png')

<IPython.core.display.Javascript object>

In [15]:
from alon.fastmri_preprocess import fftc, ifftc

example = 1

fig, ax = plt.subplots(1, 5)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    ax[c].imshow(10 * np.log10(abs(fftc(w_mat.detach()[example, c, 0]))), cmap='magma')
fig.set_size_inches(20, 6)
# fig.set_size_inches(10, 5)
fig.tight_layout()
# fig.savefig('/home/alon_granek/PythonProjects/NPPC/pc kspace.png')

<IPython.core.display.Javascript object>

In [None]:
# Orthogonallity test

# for c in range(5):
#     w = w_mat[:, c]
#     dot = torch.einsum('ecij, ecij -> ec', mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0], w_mat[:, :, 0])
#     norm_w = torch.einsum('ecij, ecij -> ec', w_mat[:, :, 0], w_mat[:, :, 0])
#     norm_m = torch.einsum('ecij, ecij -> ec', mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0], mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0])
#     print(
#         f'{dot=}\n\n{norm_w=}\n\n{norm_m=}'
#     )
    
# mask_ifft.shape, w_mat.shape

from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.flare_utils import torch_fft2c

c = 0
ex = 0

fig, ax = plt.subplots()
# ax.imshow(10 * np.log10(abs(torch.fft.fftshift(torch.fft.fft2(w_mat[ex, c, 0])).detach())))
# ax.imshow(10 * np.log10(abs(torch.fft.fftshift(torch.fft.fft2(mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0])).detach()))[ex, 0])
ax.imshow(10 * np.log10(abs(torch.fft.fft2(mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0])).detach())[ex, 0])
# ax.imshow(mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0].detach()[ex, 0], vmin=-1, vmax=1)


In [18]:
# Animation of adding or removing the PC
import matplotlib.pyplot as plt
import matplotlib.animation as animation

c = 2
example = 1

# Create a sample 3D array (e.g., 10 slices of 2D images)
alphas = np.linspace(-1, 1, 50)
image_series = mmse.detach()[example, 0][None] + alphas[:, None, None] * (w_mat.detach().numpy()[example, c, 0] - w_mat.detach().numpy()[example, c, 0].mean())[None]
# mask = BrainMasker()(mmse.detach()[0, 0].numpy())

# Set up the figure and axis
fig, ax = plt.subplots()
im = ax.imshow(image_series[0], cmap='Greys_r', vmin=0, vmax=1.5)
fig.set_size_inches(10, 10)

def update(frame):
    # Update the image data
    ax.set_title(round(alphas[frame], 2))
    im.set_array(image_series[frame]) # * mask)
    return [im]

# Create animation
ani = animation.FuncAnimation(fig, update, frames=alphas.size, repeat=True, interval=30)

# Display the animation
# plt.show()
FFwriter = animation.FFMpegWriter(fps=10)
ani.save('/home/alon_granek/PythonProjects/NPPC/ani3.avi', writer=FFwriter)

<IPython.core.display.Javascript object>

INFO:matplotlib.animation:Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
INFO:matplotlib.animation:MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s 1000x1000 -pix_fmt rgba -r 10 -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y /home/alon_granek/PythonProjects/NPPC/ani3.avi


**Test of data consistency**

In [34]:
from alon.fastmri_preprocess import fftc

c = 4 #0
example = 1
# alphas = np.linspace(-1, 1, 100)
# image_series = mmse.detach()[example, 0][None] + alphas[:, None, None] * w_mat.detach().numpy()[example, c, 0][None]

# Sample
sigma = 2 #0.05
n_samples = 200 #100
samples = sigma * np.random.normal(0, 1, [n_dirs, n_samples])
# sample_images = mmse.detach()[example, 0] + alphas[:, None] * w_mat.detach().numpy()[example, :, 0]
sample_images = mmse.detach()[example, 0] + np.einsum('cs, cij -> sij', samples, w_mat.detach().numpy()[example, :, 0])

# Check data consistency
y_nppc_samples = fftc(sample_images)
# dc = np.linalg.norm(mask[0][None, ..., example] * y_nppc_samples - (y[0] + 1j * y[1]).detach().numpy()[..., example], axis=(1, 2))
dc = np.linalg.norm(mask[0][None, ..., example] * y_nppc_samples - (y[0] + 1j * y[1]).detach().numpy()[..., example], axis=(1, 2))


In [35]:
dc_mmse = np.linalg.norm(mask[0][..., example] * mmse.detach()[example, 0] - (y[0] + 1j * y[1]).detach().numpy()[..., example])

plt.figure()
plt.hist(dc, bins=50)
plt.axvline(dc_mmse, color='black')

<IPython.core.display.Javascript object>

<matplotlib.lines.Line2D at 0x7f8652f366b0>

In [66]:
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
ax[0].imshow(sample_images[dc.argmin()], vmin=0)
ax[1].imshow(x_true[example, ..., 0], vmin=0)
fig.set_size_inches(13, 8)

<IPython.core.display.Javascript object>

In [48]:
c = 3


# Create a sample 3D array (e.g., 10 slices of 2D images)
alphas = np.linspace(-0.5, 0.5, 10)
image_series = mmse.detach()[0, 0][None] + alphas[:, None, None] * w_mat.detach().numpy()[0, c, 0][None]
mask = BrainMasker()(mmse.detach()[0, 0].numpy())

fig, ax = plt.subplots()
im = ax.imshow(image_series[0], cmap='Greys_r', vmin=0, vmax=1.5)

def update(frame):
    # Update the image data
    ax.set_title(round(alphas[frame], 2))
    im.set_array(image_series[frame]) # * mask)
    return [im]

for i in range(image_series.shape[0]):
    update(i)

fig.set_size_inches(10, 10)
plt.show()

<IPython.core.display.Javascript object>

In [13]:
fig.savefig('/home/alon_granek/PythonProjects/NPPC/PCs 2.png')

# DAPA 2.1.1: Posterior-to-posterior mapping via a point-to-posterior mapping

In [ ]:
#todo Train network per step. Inference: Get mean, PCs per step, then sample a sphere, get total mean, covariance, then advance. Check if converges

In [12]:
""" Training """

modl_steps = np.arange(0, params.num_grad_steps)

for step in modl_steps:
    print(f'TRAINING ON MODL GRADIENT STEP {step}-{step + 1}')
    for rep in range(15): #12):
        start = repeats * rep
        losses = list()
        for iter, data in enumerate(train_loader_given_mask, start=start):
            print(f'Iter {iter}...')
            losses.append(get_batch_loss(data, step))
            if iter - start == repeats - 1:
                break
        
        objective = torch.stack(losses).mean()        
        nppc_optimizer.zero_grad()
        objective.backward() #retain_graph=True)
        nppc_optimizer.step()
        nppc_step += 1
    
        print(f'Loss: {objective.detach().item()}')
    
        torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed step {step}.pth')


TRAINING ON MODL GRADIENT STEP 0-1
Iter 0...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 1...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 2...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 3...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Loss: 10.983609199523926
Iter 4...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 5...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 6...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 7...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Loss: 6.298846244812012
Iter 8...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 9...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 10...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructe

KeyboardInterrupt: 


## DAPA 1: Run directly on the abs of input zero-filled and output

In [11]:
# from nppc import NPPCTrainer

# trainer = NPPCTrainer(
#     model=nppc_net,
#     batch_size=16,
#     max_chunk_size=8,
#     output_folder='./results/celeba_inpainting_eyes/nppc/',
#     max_benchmark_samples=256,
# )
# trainer.train(
#     n_steps=100, #20000,
#     log_every=20,
#     benchmark_every=None,
# )
from alon.fastmri_preprocess import ifftc, fftc

params_nppc = params
params_nppc.batch_size = 8
train_loader = create_data_loaders(params)


from IPython.display import display, clear_output

# fig, ax = plt.subplots()


nppc_objective_log = []
for iter, data in enumerate(train_loader):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y_distorted, x_org, mask = data
    x_org = x_org.to(device)
    mask = mask.to(device)
    y_distorted = y_distorted.to(device)
    # Zero-filled image ("distorted" in NPPC terms)
    x_distorted = torch.tensor(abs(ifftc(torch.tensor(fftc(x_org[..., 0] + 1j * x_org[..., 1])) * mask[..., 0]))).to(device).float()[:, None, ...]
    # x_distorted = abs(ifftc(y_distorted[..., 0] + 1j * y_distorted[..., 1]))
    
    with torch.no_grad():
        x_restored = single_MoDL(y_distorted.float(), mask=mask)
        # x_distorted = x_distorted[None]
        
        # Moving the "sample in batch" axis to be the first
        # x_distorted = torch.permute(x_distorted, [3, 0, 1, 2])
        # x_restored = torch.permute(x_distorted, [3, 0, 1, 2])
        # 
        # print(x_distorted.shape)
        #w_mat = nppc_net(x_distorted.permute([3, 0, 1, 2]), x_restored.permute([3, 0, 1, 2]))
        
    # zero-filling of distorted    
    
    
    # x_distorted = cplx.abs(x_distorted.detach())[:, None, ...]
    x_restored = cplx.abs(x_restored.detach())[:, None, ...]
    print(f'{x_restored.shape=}\t\t{x_distorted.shape=}')
    # break
    w_mat = nppc_net(x_distorted, x_restored)

    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_org = cplx.abs(x_org.detach())[:, None, ...]
    err = (x_org - x_restored).flatten(1)

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    objective = reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean()

    nppc_optimizer.zero_grad()
    objective.backward()
    nppc_optimizer.step()
    nppc_step += 1
    
    
    print(f'Loss: {objective.detach().item()}')
    
    
    if True: #nppc_step % 10 == 0:
        nppc_objective_log.append(objective.detach().item())

        torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net.pth')

        # ax.clear()
        # ax.plot(nppc_objective_log)
        # display(fig)

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6003008.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000421.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000508.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000531.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_203_6000942.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_209_6001397.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002033.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002392.h5.npy...
	/mnt/c/Users/along/brain_multicoil


KeyboardInterrupt



In [22]:
import plotly.express as px
# from nppc.auxil import imgs_to_grid
from IPython.core.display import Markdown

# samples_list = np.random.RandomState(1).randint(0, len(test_set), 10)
t_list = torch.linspace(-3, 3, 21).to(device)


def imgs_to_grid(imgs, nrows=None, **make_grid_args):
    imgs = imgs.detach().cpu()
    if imgs.ndim == 5:
        nrow = imgs.shape[1]
        imgs = imgs.reshape(imgs.shape[0] * imgs.shape[1], imgs.shape[2], imgs.shape[3], imgs.shape[4])
    elif nrows is None:
        nrow = int(np.ceil(imgs.shape[0] ** 0.5))

    make_grid_args2 = dict(value_range=(0, 1), pad_value=1.)
    make_grid_args2.update(make_grid_args)
    img = torchvision.utils.make_grid(imgs, nrow=nrow, **make_grid_args2).clamp(0, 1)
    return img

def scale_img(x):
    return x / torch.abs(x).flatten(-3).max(-1)[0][..., None, None, None] / 1.5 + 0.5

def tensor_img_to_numpy(x):
    return x.detach().permute(-2, -1, -3).cpu().numpy()

def imshow(img, scale=1, **kwargs):
    if isinstance(img, torch.Tensor):
        img = tensor_img_to_numpy(img)
    img = img.clip(0, 1)

    fig = px.imshow(img, **kwargs).update_layout(
        height=img.shape[0] * scale,
        width=img.shape[1] * scale,
        margin=dict(t=0, b=0, l=0, r=0),
        xaxis_showticklabels=False,
        yaxis_showticklabels=False,
    )
    return fig


test_total = 2
test_count = 0
for iter, data in enumerate(train_loader):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y_distorted, x_org, mask = data
    x_org = x_org.to(device)
    mask = mask.to(device)
    y_distorted = y_distorted.to(device)
    # Zero-filled image ("distorted" in NPPC terms)
    x_distorted = torch.tensor(abs(ifftc(torch.tensor(fftc(x_org[..., 0] + 1j * x_org[..., 1])) * mask[..., 0]))).to(device).float()[:, None, ...]
        
    with torch.no_grad():
        x_restored = single_MoDL(y_distorted.float()[0][None], mask=mask[0][None])
        x_restored = cplx.abs(x_restored.detach())[:, None, ...]
        w_mat = nppc_net(x_distorted[0][None], x_restored[0][None])
    
    display(Markdown('## Sample \# ...'))
    display(Markdown('### Original, Distorted, Restored:'))
    imshow(imgs_to_grid(torch.stack((cplx.abs(x_org)[0], x_distorted[0], x_restored[0].squeeze()))[None, :, 0]), scale=2).show()
    
    display(Markdown('### Principal directions:'))
    imgs = t_list[:, None, None, None, None] * w_mat + x_restored[0][None][None]
    imgs = torch.cat((scale_img(w_mat), imgs), dim=0)
    imgs = imgs.transpose(0, 1).contiguous()
    imshow(imgs_to_grid(imgs), scale=1.6).show()
    
    if test_count == test_total:
        break
    else:
        test_count += 1
    

		Step 0/3...
		Step 1/3...
		Step 2/3...


## Sample \# ...

### Original, Distorted, Restored:

RuntimeError: stack expects each tensor to be equal size, but got [372, 372] at entry 0 and [1, 372, 372] at entry 1

In [35]:
import matplotlib.pyplot as plt
# fig, ax = plt.subplots(1, 5)
# for c in range(5):
#     ax[c].imshow(w_mat[0, c, 0], cmap='inferno', origin='lower', vmin=-0.5, vmax=0.5)
# fig.set_size_inches(13, 5)

c = 0
alphas = np.linspace(-0.5, 0.5, 10)
fig, ax = plt.subplots(1, len(alphas))
for i, alpha in enumerate(alphas):
    ax[i].imshow(x_restored[0, 0] + alpha * w_mat[0, c, 0], cmap='inferno', origin='lower', vmin=0, vmax=1.5)
fig.set_size_inches(30, 7)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/1st pc attempt.png')

<IPython.core.display.Javascript object>

In [37]:
#todo
#   Make sure NPPCs are data-consistent.
#   Unrolled NPPC net?
#       DC over the whole set (not only \hat{x})

import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2)
ax[0].imshow(x_distorted[0, 0])
ax[1].imshow(x_restored[0, 0])

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fcf4419afb0>