In [1]:
%matplotlib notebook
import os, sys
import logging
import random
import h5py
import shutil
import time
import argparse
import numpy as np
import sigpy.plot as pl
import torch
import sigpy as sp
import torchvision
from torch import optim
from tensorboardX import SummaryWriter
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib
# import custom libraries
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx
from utils.resnet2p1d import generate_model
from utils.flare_utils import roll
# import custom classes
from utils.datasets import SliceData
from subsample_fastmri import MaskFunc
from MoDL_single import UnrolledModel
import argparse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
%load_ext autoreload
%autoreload 0



In [2]:
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

In [3]:
class DataTransform:
    """
    Data Transformer for training unrolled reconstruction models.
    """

    def __init__(self, mask_func, args, use_seed=False):
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.rng = np.random.RandomState()

    # def __call__(self, kspace, target, slice):
    #     im_lowres = abs(sp.ifft(sp.resize(sp.resize(kspace,(640,24)),(640,372))))
    #     magnitude_vals = im_lowres.reshape(-1)
    #     k = int(round(0.05 * magnitude_vals.shape[0]))
    #     scale = magnitude_vals[magnitude_vals.argsort()[::-1][k]]
    #     kspace = kspace/scale
    #     target = target/scale
    #     # Convert everything from numpy arrays to tensors
    #     kspace_torch = cplx.to_tensor(kspace).float()   
    #     target_torch = cplx.to_tensor(target).float()   
    #     mask_slice = np.ones((640,372))
    #     mk1 = self.mask_func((1,1,372,2))[0,0,:,0]
    #     knee_masks = mask_slice*mk1
    #     mask_torch = torch.tensor(knee_masks[...,None]).float()
    #     mask2 = sp.mri.poisson((640,372), 4, calib=(70, 56), dtype=float, crop_corner=True, return_density=False, seed=0, max_attempts=6, tol=0.1)
    #     mask_torch = torch.stack([torch.tensor(mask2).float(),torch.tensor(mask2).float()],dim=2)
    #     kspace_torch = kspace_torch*mask_torch
    # 
    #     return kspace_torch,target_torch,mask_torch
    
    def normalization(self, image: np.ndarray):
        norm_percentile = 95
        vals = image.reshape(-1)
        n_taken = int(round((1 - norm_percentile * 1e-2) * vals.shape[0]))
        scale = vals[vals.argsort()[::-1][n_taken]]
        return scale
    
    def __call__(self, kspace, target):
        """
        The forward model.
        
        :param kspace: 
        :param target: 
        :return: 
        """
        # Normalize
        scale = self.normalization(target)
        kspace /= scale
        target /= scale
        
        # Convert to torch tensors
        kspace_torch = cplx.to_tensor(kspace).float()   
        target_torch = cplx.to_tensor(target).float() 
        
        # k-space masking
        mask = sp.mri.poisson(
            img_shape=(372, 372),
            accel=6, #4,
            calib=(56, 56),
            dtype=float, crop_corner=True, return_density=False, seed=0, max_attempts=6, tol=0.1
        )
        mask_torch = torch.stack([torch.tensor(mask).float(),torch.tensor(mask).float()],dim=2)
        kspace_masked = kspace_torch * mask_torch
        
        return kspace_masked, target_torch, mask_torch

In [4]:
def create_datasets(args, frozen_mask: torch.Tensor = None):
    """
    
    :param args: 
    :param frozen_mask:    Use a pre-determined sampling mask. If left None, generating a new one.
    :return: 
    """
    if frozen_mask is None:
        # Generate k-t undersampling masks
        train_mask = MaskFunc([0.08],[4])
    else:
        train_mask = torch.clone(mask)
    
    train_data = SliceData(
        root=str(args.data_path),
        transform=DataTransform(train_mask, args),
        sample_rate=1
    )
    return train_data

def create_data_loaders(args, frozen_mask: torch.Tensor = None):
    train_data = create_datasets(args, frozen_mask)
    print(len(train_data))
#     print(train_data[0])

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader

def build_optim(args, params):
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    return optimizer

In [5]:
#Hyper parameters
params = Namespace()
params.data_path = '/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train' #"../../single_channel_data/train/"
params.batch_size = 2 #4
params.num_grad_steps = 3 #4
params.num_cg_steps = 8
params.share_weights = True
params.modl_lamda = 0.05
params.lr = 0.0001
params.weight_decay = 0
params.lr_step_size = 500
params.lr_gamma = 0.5
params.epoch = 21

In [6]:
train_loader = create_data_loaders(params, frozen_mask = None)

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6003008.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000421.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000508.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000531.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_203_6000942.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_209_6001397.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002033.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002392.h5.npy...
	/mnt/c/Users/along/brain_multicoil

In [7]:
single_MoDL = UnrolledModel(params).to(device)
optimizer = build_optim(params, single_MoDL.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, params.lr_step_size, params.lr_gamma)
criterion = nn.MSELoss()

shared weights


In [None]:
from pathlib import Path
import matplotlib.pyplot as plt

for epoch in range(params.epoch):
    print(f'Epoch {epoch}')
    single_MoDL.train()
    avg_loss = 0.

    for iter, data in enumerate(train_loader):
        print(f'{iter=}...')
        
        try:
            # if True:
            input,target,mask = data
            input = input.to(device)
            target = target.to(device)
            mask = mask.to(device)
            
            # fig, ax = plt.subplots(1, 2)
            # ax[0].imshow(10 * np.log10(np.maximum(abs(input)[0, ..., 0], 1e-8)), cmap='inferno')
            # ax[1].imshow(target[0, ..., 0], cmap='inferno')
            # plt.show()
            # raise ValueError
        except:
            continue
        # print(f'\tDone data prep. {input.shape=} {target.shape=}')
        im_out = single_MoDL(input.float(),mask=mask)
        loss = criterion(im_out,target)
        # print(f'\tFound loss')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'\tDone backprop')
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        print(f'\tInstant loss: {loss.item()}. \t\tAvg loss: {avg_loss}')
        if iter % 20 == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{params.epoch:3d}] '
                f'Iter = [{iter:4d}/{len(train_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g}'
            )
    #Saving the model
    exp_dir = '/home/alon_granek/PythonProjects/NPPC/alon/checkpoints2' #"L2_checkpoints_poisson_x4/"
    Path(exp_dir).mkdir(exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'params': params,
            'model': single_MoDL.state_dict(),
            'optimizer': optimizer.state_dict(),
            'exp_dir': exp_dir
        },
        f=os.path.join(exp_dir, 'model_%d.pt'%(epoch))
    )

In [7]:
single_MoDL = UnrolledModel(params).to(device)
single_MoDL.load_state_dict(
    torch.load(
        '/home/alon_granek/PythonProjects/NPPC/alon/checkpoints2/model_0.pt'
    )['model']
)
single_MoDL.eval()

shared weights


UnrolledModel(
  (resnets): ModuleList(
    (0-2): 3 x UNet(
      (inc): inconv(
        (conv): double_conv(
          (conv): Sequential(
            (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (5): ReLU(inplace=True)
          )
        )
      )
      (down1): down(
        (mpconv): Sequential(
          (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (1): double_conv(
            (conv): Sequential(
              (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
           

In [9]:
for iter, data in enumerate(train_loader):
    print(f'{iter=}...')
    
    # if True:
    input,target,mask = data
    input = input.to(device)
    target = target.to(device)
    mask = mask.to(device)
    
    # im_out = single_MoDL(input.float(),mask=mask)
    im_out, images = single_MoDL(input.float(), mask=mask, return_steps=True)
    
    break


sample = 0

recon_image = cplx.abs(im_out.detach())
gt_image = cplx.abs(target.detach())

from alon.fastmri_preprocess import ifftc, fftc
import matplotlib.pyplot as plt

in_kspace = input.detach()
in_image = torch.tensor(abs(ifftc(torch.tensor(fftc(gt_image)) * mask[..., 0]))).to(device)  #abs(ifftc(in_kspace[..., 0]) + 1j * ifftc(in_kspace[..., 1]))

iter=0...
		Step 0/3...
		Step 1/3...
		Step 2/3...


In [22]:
fig, ax = plt.subplots(1, 3)
ax[0].set_title('Target')
ax[0].imshow(torch.flipud(gt_image[sample]), cmap='Greys_r', vmin=0, vmax=1.2)
ax[1].set_title('6x accelerated zero-filled')
ax[1].imshow(torch.flipud(in_image[sample]), cmap='Greys_r', vmin=0, vmax=1.2)
ax[2].set_title('Reconstructed')
ax[2].imshow(torch.flipud(recon_image[sample]), cmap='Greys_r', vmin=0, vmax=1.2)
fig.set_size_inches(13, 5)
plt.tight_layout()
fig.show()

<IPython.core.display.Javascript object>

In [20]:
""" Attempt in rough skull stripping """

from alon.brain_masking import BrainMasker

mask = BrainMasker()(recon_image[sample])[0]


In [21]:

fig, ax = plt.subplots()
ax.imshow(mask)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f8586cbdf30>

In [11]:
import matplotlib.pyplot as plt
sample = 0

fig, ax = plt.subplots(1, 2 + len(images) + 1)
ax[0].set_title('Target')
ax[0].imshow(torch.flipud(gt_image[sample]), cmap='Greys_r', vmin=0, vmax=1.2)
ax[1].set_title('6x accelerated zero-filled')
ax[1].imshow(torch.flipud(in_image[sample]), cmap='Greys_r', vmin=0, vmax=1.2)
for i in range(2, 2 + len(images)):
    ax[i].set_title(f'Recon step {i - 2}')
    ax[i].imshow(torch.flipud(cplx.abs(images[i - 2][sample].detach())), cmap='Greys_r', vmin=0, vmax=1.2)
ax[-1].imshow(torch.flipud(cplx.abs(im_out[sample].detach())), cmap='Greys_r', vmin=0, vmax=1.2)
fig.set_size_inches(30, 7)
plt.tight_layout()
plt.savefig('/home/alon_granek/PythonProjects/NPPC/MoDL recon steps 3.png')
fig.show()

<IPython.core.display.Javascript object>

# Connecting MoDL to NPPC

**Naive approach:** Use the whole unrolled network for PC estimation. The PC subspace is expected to be data consistent, as there is no variance on the projection of k-space one-hot vectors corresponding to the measured k's.

In [8]:
class UNet(nn.Module):
    def __init__(
            self,
            in_channels=1,
            out_channels=1,
            channels_list=(32, 64, 128, 256),
            bottleneck_channels=512,
            min_channels_decoder=64,
            n_groups=8,
        ):

        super().__init__()
        ch = in_channels

        ## Encoder
        ## =======
        self.encoder_blocks = nn.ModuleList([])
        ch_hidden_list = []

        layers = []
        layers.append(nn.ZeroPad2d(2))
        ch_ = channels_list[0]
        layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
        ch = ch_
        self.encoder_blocks.append(nn.Sequential(*layers))
        ch_hidden_list.append(ch)

        for i_level in range(len(channels_list)):
            ch_ = channels_list[i_level]
            downsample = i_level != 0

            layers = []
            if downsample:
                layers.append(nn.MaxPool2d(2))
            layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
            ch = ch_
            layers.append(nn.GroupNorm(n_groups, ch))
            layers.append(nn.LeakyReLU(0.1))
            self.encoder_blocks.append(nn.Sequential(*layers))
            ch_hidden_list.append(ch)

        ## Bottleneck
        ## ==========
        ch_ = bottleneck_channels
        layers = []
        layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
        ch = ch_
        layers.append(nn.GroupNorm(n_groups, ch))
        layers.append(nn.LeakyReLU(0.1))
        layers.append(nn.Conv2d(ch, ch, 3, padding=1))
        layers.append(nn.GroupNorm(n_groups, ch))
        layers.append(nn.LeakyReLU(0.1))
        self.bottleneck = nn.Sequential(*layers)

        ## Decoder
        ## =======
        self.decoder_blocks = nn.ModuleList([])
        for i_level in reversed(range(len(channels_list))):
            ch_ = max(channels_list[i_level], min_channels_decoder)
            downsample = i_level != 0
            ch = ch + ch_hidden_list.pop()
            layers = []

            layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
            ch = ch_
            layers.append(nn.GroupNorm(n_groups, ch))
            layers.append(nn.LeakyReLU(0.1))
            if downsample:
                layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            self.decoder_blocks.append(nn.Sequential(*layers))

        ch = ch + ch_hidden_list.pop()
        ch_ = channels_list[0]
        layers = []
        layers.append(nn.Conv2d(ch, out_channels, 1))
        layers.append(nn.ZeroPad2d(-2))
        self.decoder_blocks.append(nn.Sequential(*layers))

    def forward(self, x):
        h = []
        for block in self.encoder_blocks:
            x = block(x)
            h.append(x)

        x = self.bottleneck(x)
        for block in self.decoder_blocks:
            x = torch.cat((x, h.pop()), dim=1)
            x = block(x)
        return x

In [9]:
from nppc import PCWrapper, networks

n_dirs = 5

# base_net = networks.ResUNet(
#     in_channels=2, #pre_out_channels + self.x_shape[0],
#     out_channels=1 + n_dirs, #self.x_shape[0] * n_dirs,
#     channels_list=(64, 64, 128, 128, 256, 256),
#     bottleneck_channels=512,
#     downsample_list=(False, True, True, True, True, True),
#     attn_list=(False, False, False, False, True, False),
#     n_blocks=2,
#     n_groups=8,
#     attn_heads=1,
# )

# base_net = networks.UNet(
#     in_channels=2, #pre_out_channels + self.x_shape[0],
#     out_channels=1, #self.x_shape[0] * n_dirs,
#     channels_list=(32, 64, 128),
#     bottleneck_channels=256,
#     n_blocks=1,
#     n_blocks_bottleneck=2,
#     min_channels_decoder=64,
# )


from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.transforms import SenseModel_single
from alon.MoDLsinglechannel.demo_modl_singlechannel.MoDL_single import Operator


# Initialize frozen sampling mask
for iter, data in enumerate(train_loader):
    mask = data[-1].to(device)
    mask_ifft = torch.fft.ifft2(torch.fft.ifftshift(mask, axis=(1, 2)), axis=(1, 2))
    # Sense = Operator(SenseModel_single(weights=mask))
    # delta_1d = abs(np.arange(mask.shape[1], dtype=float) - mask.shape[1] / 2.0) <= 1
    # delta = delta_1d[None, :] * delta_1d[:, None]
    # mask_ifft = Sense.adjoint(np.ones_like(data[0])) #delta[None, ..., None])
    break

nppc_net = PCWrapper(UNet(in_channels=1 + 1, out_channels=1 * n_dirs), n_dirs=n_dirs) #, mask_ifft=mask_ifft) #, mask=1) #mask)
# nppc_net = PCWrapper(base_net, n_dirs=n_dirs)

#nppc_net.__dict__.update({'ddp', Namespace(size=1)})
nppc_net.__setattr__('ddp', Namespace(size=1))

nppc_net.to(device)
nppc_net.train()
nppc_optimizer = torch.optim.Adam(
    nppc_net.parameters(), lr=1e-4, #lr=1e-5, #lr=1e-4,
    betas=(0.9, 0.999)
)
nppc_step = 0



restoration_n_steps = 4000 #1000 #3000
nppc_n_steps = 3000 #3000
batch_size = 8 #2 #64 #128 #64 #128       #256

second_moment_loss_lambda = 1e0 #0.2 #1e0
second_moment_loss_grace = 500

## DAPA 2: Run on the last step in MoDL
**Rationale**
1. This is where **x** enters the true solution space
2. This is the smallest problem where uncertainty still exists (we trust the previous steps in MoDL)
3. It is a much easier problem

In [10]:
from alon.brain_masking import BrainMasker

In [38]:
for iter, data in enumerate(train_loader):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
        
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    
    # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
    mmse = cplx.abs(x_recon)[:, None, ...]
    problem_input = cplx.abs(x_intermed[-1])[:, None, ...]
    w_mat_orig = nppc_net(problem_input, mmse)
    
    """ Test - defining the error only on the masked image """
    # mask = BrainMasker()(mmse[0, 0])[0]#.flatten()
    # err *= mask[None]
    
    # w_mat = w_mat_orig * torch.tensor(BrainMasker()(mmse[0, 0])[0][None, None, None], dtype=torch.float32)
    w_mat = w_mat_orig
    """"""
    
    
    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_true = cplx.abs(x_true.detach())[:, None, ...]
    err = (x_true - mmse).flatten(1)
    

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    objective = reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean()

    nppc_optimizer.zero_grad()
    objective.backward()
    nppc_optimizer.step()
    nppc_step += 1
    
    
    print(f'Loss: {objective.detach().item()}')
    
    torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net3.pth')
        

		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.825690746307373
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.904462993144989
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8065555095672607
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8409633040428162
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7346203327178955
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.752144455909729
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8610484600067139
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7396346926689148
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8439711332321167
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8512775897979736
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.735094428062439
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.8334081172943115
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7661327123641968
		Step 0/3...
		Step 1/3...
		Step 2/3...
Loss: 0.7712322473526001
		Step 0/3...
		Step 1/3...


KeyboardInterrupt: 

## DAPA 2.1: On last MoDL step, but enforcing data consistency

It is hard not to train the posterior of the absolute value clean image given the absolute value sampling-noised image. That's why I do so in DAPA 2.

However, this degenerates the problem to a simple denoising problem, and so nothing makes the predicted PC subspace data-consistent, while it should be just that.

* Not entirely though. Since we train on the last MoDL iteration, the k-space interpolation is almost fully done. There is simply larger uncertainty where not measured vs where measured.

DAPA 2.1 attempts to continue learning on the absolute-values inverse problem, but ensure data-consistency on the PC subspace. That is in the hope that the only denoising done here is on the sampling noise.
 
We look for an inverse problem of the form

$$
\underset{\mathbf x}{\arg \min} \|\mathbf{MF}(\hat{\mathbf x} + \mathbf W \boldsymbol \alpha) - \mathbf y\|^2 + \lambda R(\hat{\mathbf x} + \mathbf W \boldsymbol \alpha)
$$

$\hat{\mathbf x}\quad$      MMSE estimate
$\mathbf W\quad$            Predicted PCs
$\boldsymbol \alpha\quad$   Coordinates with respect to PCs. Let any such coordinates.

Specifically for DC,
$$
\|\mathbf{MF}(\hat{\mathbf x} + \mathbf W \boldsymbol \alpha) - \mathbf y\|^2 \leq \|\mathbf{MF}\hat{\mathbf x} - \mathbf y\|^2 + \|\mathbf{MF}\mathbf W \boldsymbol \alpha\|^2
$$
Where the first term is already minimized by MoDL and is independent of PCs. We will attempt to add a $\|\mathbf{MW}\|^2$ loss.


**A major change is that we freeze the mask to properly train the NPPC for this inverse problem**

In [10]:
train_loader_given_mask = create_data_loaders(params, frozen_mask=mask)

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6003008.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000421.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000508.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000531.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_203_6000942.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_209_6001397.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002033.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002392.h5.npy...
	/mnt/c/Users/along/brain_multicoil

In [11]:
repeats = 4
batch_objectives = torch.zeros(repeats, dtype=torch.float32)


def get_batch_loss(data):
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
    
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
        print('\tReconstructed')
        
    # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
    mmse = cplx.abs(x_recon)[:, None, ...]
    problem_input = cplx.abs(x_intermed[-2])[:, None, ...]
    w_mat = nppc_net(problem_input, mmse)
    
    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_true = cplx.abs(x_true.detach())[:, None, ...]
    err = (x_true - mmse).flatten(1)
    
    print('\tCalculated error')

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    ## (Alon, ShimronLab) Data-consistency loss
    ## ----------------------------------------
    # print('\tCalculating DC loss')
    # MW = torch.einsum('cije, edij -> ec', mask_ifft.float(), w_mat[:, :, 0].float()) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
    MFW = torch.einsum(
        'cije, edij -> ec',
        mask, abs(torch.fft.fftshift(torch.fft.fft2(w_mat[:, :, 0], axis=(1, 2)), axis=(1, 2)))
    ) / (2 * err.size()[1])           # c: Re/Im component, i,j: image dims, e: example, d: PC directions
    dc_loss = torch.linalg.norm(MFW) / (MFW.shape[0] * MFW.shape[1])
    dc_loss_lambda = 5e0 #1e0
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    
    return reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean() + dc_loss_lambda * dc_loss



for rep in range(20): #12):
    start = repeats * rep
    losses = list()
    for iter, data in enumerate(train_loader_given_mask, start=start):
        print(f'Iter {iter}...')
        losses.append(get_batch_loss(data))
        if iter - start == repeats - 1:
            break
    
    objective = torch.stack(losses).mean()
    
    # enum_lim = islice(enumerate(train_loader_given_mask, start=repeats * rep), repeats)
    # objective = torch.mean(torch.tensor([get_batch_loss(data) for iter, data in enum_lim]))
    nppc_optimizer.zero_grad()
    objective.backward() #retain_graph=True)
    nppc_optimizer.step()
    nppc_step += 1

    print(f'Loss: {objective.detach().item()}')

    torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed2test.pth')


Iter 0...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 1...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 2...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 3...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Loss: 9.691699028015137
Iter 4...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 5...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 6...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 7...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Loss: 5.855975151062012
Iter 8...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 9...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 10...
		Step 0/3...
		Step 1/3...
		Step 2/3...
	Reconstructed
	Calculated error
Iter 11...
		Ste

KeyboardInterrupt: 

In [12]:
nppc_net = torch.load(
    # f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net3.pth'
    #f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed.pth'
    f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net_DCed2test.pth'
)

for iter, data in enumerate(train_loader_given_mask):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y, x_true, mask = data
    x_true = x_true.to(device)
    mask = mask.to(device)
    y = y.to(device)
        
    with torch.no_grad():
        x_recon, x_intermed = single_MoDL(y.float(), mask=mask, return_steps=True)
    
    # Input to NPPC: The n-1-th iteration solution as input, and n-th iteration solution as the MMSE
    mmse = cplx.abs(x_recon)[:, None, ...]
    problem_input = cplx.abs(x_intermed[-1])[:, None, ...]
    w_mat = nppc_net(problem_input, mmse)
    
    break

		Step 0/3...
		Step 1/3...
		Step 2/3...


In [28]:
from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.transforms import SenseModel_single

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
# ax[0].imshow(cplx.abs(x_intermed[0])[1])
ax[0].imshow(problem_input[1, 0])
ax[1].imshow(mmse[1, 0])

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f0d53714220>

In [23]:
""" Analysis of PCs in k-space - as they should all be 0 where not measured """
from alon.fastmri_preprocess import fftc, ifftc
import matplotlib.pyplot as plt

c = 4
example = 1

y_c = fftc(w_mat[example, c, 0].detach().numpy())
log_y_c = 10 * np.log10(abs(y_c))
m = mask[0, ..., example].detach().numpy().astype(bool)

# plt.figure()
# plt.hist([log_y_c[m], log_y_c[~m]], histtype='step', label=['measured', 'not measured'], density=True, bins=50)
# plt.legend()


fig, ax = plt.subplots(1, 2)
ax[0].imshow(ifftc(y_c * (1 - m)).real, vmin=-0.1, vmax=0.1, cmap='RdBu')
ax[0].set_title('zeroed')
ax[1].imshow(w_mat[example, c, 0].detach().numpy(), vmin=-0.1, vmax=0.1, cmap='RdBu')
ax[1].set_title('raw')

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'raw')

In [16]:
mask.shape

torch.Size([2, 372, 372, 2])

In [31]:
import matplotlib.pyplot as plt
# fig, ax = plt.subplots(1, 5)
# for c in range(5):
#     ax[c].imshow(w_mat[0, c, 0], cmap='inferno', origin='lower', vmin=-0.5, vmax=0.5)
# fig.set_size_inches(13, 5)

c = 0
alphas = np.linspace(-0.1, 0.1, 5)
fig, ax = plt.subplots(1, len(alphas))
for i, alpha in enumerate(alphas):
    ax[i].imshow(mmse.detach()[0, 0] + alpha * w_mat.detach()[0, c, 0], cmap='Greys_r', origin='lower', vmin=0, vmax=1.5)
fig.set_size_inches(30, 7)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp.png')

<IPython.core.display.Javascript object>

In [41]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# List of image file paths

# Initial image index
current_image_index = 0

c = 3
alphas = np.linspace(-1, 1, 4)

# Function to display the image at the current index
def display_image():
    img = mmse.detach()[0, 0] + alphas[current_image_index] * w_mat.detach()[0, c, 0] #mpimg.imread(image_files[current_image_index])
    ax.imshow(img, cmap='Greys_r', vmin=0, vmax=1.5, origin='lower')
    ax.set_title(f"Alpha {round(alphas[current_image_index], 2)}")
    fig.canvas.draw()

# Function to handle key press events
def on_key(event):
    global current_image_index
    if event.key == 'up':
        current_image_index = (current_image_index + 1) % alphas.size #len(image_files)
    elif event.key == 'down':
        current_image_index = (current_image_index - 1) % alphas.size #len(image_files)
    display_image()

# Create a figure and axis
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.2)
fig.set_size_inches(10, 10)
fig.tight_layout()

# Display the initial image
display_image()

# Connect the key press event to the handler
fig.canvas.mpl_connect('key_press_event', on_key)

# Display the plot
plt.show()


<IPython.core.display.Javascript object>

In [16]:
""" Showing all PCs """

import matplotlib.pyplot as plt

example = 1

fig, ax = plt.subplots(1, 5)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    # ax[c].imshow(w_mat.detach()[example, c, 0] - w_mat.detach()[example, c, 0].mean(), vmin=-0.2, vmax=0.2, cmap='RdBu')
    ax[c].imshow(w_mat.detach()[example, c, 0], vmin=-0.2, vmax=0.2, cmap='RdBu')
fig.set_size_inches(20, 6)
# fig.set_size_inches(10, 5)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/temp2.png')

<IPython.core.display.Javascript object>

In [15]:
from alon.fastmri_preprocess import fftc, ifftc

example = 1

fig, ax = plt.subplots(1, 5)
for c in range(5):
    ax[c].set_title(f'PC {c + 1}')
    ax[c].imshow(10 * np.log10(abs(fftc(w_mat.detach()[example, c, 0]))), cmap='magma')
fig.set_size_inches(20, 6)
# fig.set_size_inches(10, 5)
fig.tight_layout()
# fig.savefig('/home/alon_granek/PythonProjects/NPPC/pc kspace.png')

<IPython.core.display.Javascript object>

In [31]:
# Orthogonallity test

# for c in range(5):
#     w = w_mat[:, c]
#     dot = torch.einsum('ecij, ecij -> ec', mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0], w_mat[:, :, 0])
#     norm_w = torch.einsum('ecij, ecij -> ec', w_mat[:, :, 0], w_mat[:, :, 0])
#     norm_m = torch.einsum('ecij, ecij -> ec', mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0], mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0])
#     print(
#         f'{dot=}\n\n{norm_w=}\n\n{norm_m=}'
#     )
    
# mask_ifft.shape, w_mat.shape

from alon.MoDLsinglechannel.demo_modl_singlechannel.utils.flare_utils import torch_fft2c

c = 0
ex = 0

fig, ax = plt.subplots()
# ax.imshow(10 * np.log10(abs(torch.fft.fftshift(torch.fft.fft2(w_mat[ex, c, 0])).detach())))
# ax.imshow(10 * np.log10(abs(torch.fft.fftshift(torch.fft.fft2(mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0])).detach()))[ex, 0])
ax.imshow(10 * np.log10(abs(torch.fft.fft2(mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0])).detach())[ex, 0])
# ax.imshow(mask_ifft[0].real.permute(2, 0, 1)[:, None, None][:, :, 0].detach()[ex, 0], vmin=-1, vmax=1)


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f08fae609d0>

In [18]:
# Animation of adding or removing the PC
import matplotlib.pyplot as plt
import matplotlib.animation as animation

c = 2
example = 1

# Create a sample 3D array (e.g., 10 slices of 2D images)
alphas = np.linspace(-1, 1, 50)
image_series = mmse.detach()[example, 0][None] + alphas[:, None, None] * (w_mat.detach().numpy()[example, c, 0] - w_mat.detach().numpy()[example, c, 0].mean())[None]
# mask = BrainMasker()(mmse.detach()[0, 0].numpy())

# Set up the figure and axis
fig, ax = plt.subplots()
im = ax.imshow(image_series[0], cmap='Greys_r', vmin=0, vmax=1.5)
fig.set_size_inches(10, 10)

def update(frame):
    # Update the image data
    ax.set_title(round(alphas[frame], 2))
    im.set_array(image_series[frame]) # * mask)
    return [im]

# Create animation
ani = animation.FuncAnimation(fig, update, frames=alphas.size, repeat=True, interval=30)

# Display the animation
# plt.show()
FFwriter = animation.FFMpegWriter(fps=10)
ani.save('/home/alon_granek/PythonProjects/NPPC/ani3.avi', writer=FFwriter)

<IPython.core.display.Javascript object>

INFO:matplotlib.animation:Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
INFO:matplotlib.animation:MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s 1000x1000 -pix_fmt rgba -r 10 -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y /home/alon_granek/PythonProjects/NPPC/ani3.avi


**Test of data consistency**

In [34]:
from alon.fastmri_preprocess import fftc

c = 4 #0
example = 1
# alphas = np.linspace(-1, 1, 100)
# image_series = mmse.detach()[example, 0][None] + alphas[:, None, None] * w_mat.detach().numpy()[example, c, 0][None]

# Sample
sigma = 2 #0.05
n_samples = 200 #100
samples = sigma * np.random.normal(0, 1, [n_dirs, n_samples])
# sample_images = mmse.detach()[example, 0] + alphas[:, None] * w_mat.detach().numpy()[example, :, 0]
sample_images = mmse.detach()[example, 0] + np.einsum('cs, cij -> sij', samples, w_mat.detach().numpy()[example, :, 0])

# Check data consistency
y_nppc_samples = fftc(sample_images)
# dc = np.linalg.norm(mask[0][None, ..., example] * y_nppc_samples - (y[0] + 1j * y[1]).detach().numpy()[..., example], axis=(1, 2))
dc = np.linalg.norm(mask[0][None, ..., example] * y_nppc_samples - (y[0] + 1j * y[1]).detach().numpy()[..., example], axis=(1, 2))


In [35]:
dc_mmse = np.linalg.norm(mask[0][..., example] * mmse.detach()[example, 0] - (y[0] + 1j * y[1]).detach().numpy()[..., example])

plt.figure()
plt.hist(dc, bins=50)
plt.axvline(dc_mmse, color='black')

<IPython.core.display.Javascript object>

<matplotlib.lines.Line2D at 0x7f8652f366b0>

In [66]:
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)
ax[0].imshow(sample_images[dc.argmin()], vmin=0)
ax[1].imshow(x_true[example, ..., 0], vmin=0)
fig.set_size_inches(13, 8)

<IPython.core.display.Javascript object>

In [48]:
c = 3


# Create a sample 3D array (e.g., 10 slices of 2D images)
alphas = np.linspace(-0.5, 0.5, 10)
image_series = mmse.detach()[0, 0][None] + alphas[:, None, None] * w_mat.detach().numpy()[0, c, 0][None]
mask = BrainMasker()(mmse.detach()[0, 0].numpy())

fig, ax = plt.subplots()
im = ax.imshow(image_series[0], cmap='Greys_r', vmin=0, vmax=1.5)

def update(frame):
    # Update the image data
    ax.set_title(round(alphas[frame], 2))
    im.set_array(image_series[frame]) # * mask)
    return [im]

for i in range(image_series.shape[0]):
    update(i)

fig.set_size_inches(10, 10)
plt.show()

<IPython.core.display.Javascript object>

In [13]:
fig.savefig('/home/alon_granek/PythonProjects/NPPC/PCs 2.png')

# DAPA 2.1.1: Posterior-to-posterior mapping via a point-to-posterior mapping

In [ ]:
#todo Train network per step. Inference: Get mean, PCs per step, then sample a sphere, get total mean, covariance, then advance. Check if converges.


## DAPA 1: Run directly on the abs of input zero-filled and output

In [11]:
# from nppc import NPPCTrainer

# trainer = NPPCTrainer(
#     model=nppc_net,
#     batch_size=16,
#     max_chunk_size=8,
#     output_folder='./results/celeba_inpainting_eyes/nppc/',
#     max_benchmark_samples=256,
# )
# trainer.train(
#     n_steps=100, #20000,
#     log_every=20,
#     benchmark_every=None,
# )
from alon.fastmri_preprocess import ifftc, fftc

params_nppc = params
params_nppc.batch_size = 8
train_loader = create_data_loaders(params)


from IPython.display import display, clear_output

# fig, ax = plt.subplots()


nppc_objective_log = []
for iter, data in enumerate(train_loader):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y_distorted, x_org, mask = data
    x_org = x_org.to(device)
    mask = mask.to(device)
    y_distorted = y_distorted.to(device)
    # Zero-filled image ("distorted" in NPPC terms)
    x_distorted = torch.tensor(abs(ifftc(torch.tensor(fftc(x_org[..., 0] + 1j * x_org[..., 1])) * mask[..., 0]))).to(device).float()[:, None, ...]
    # x_distorted = abs(ifftc(y_distorted[..., 0] + 1j * y_distorted[..., 1]))
    
    with torch.no_grad():
        x_restored = single_MoDL(y_distorted.float(), mask=mask)
        # x_distorted = x_distorted[None]
        
        # Moving the "sample in batch" axis to be the first
        # x_distorted = torch.permute(x_distorted, [3, 0, 1, 2])
        # x_restored = torch.permute(x_distorted, [3, 0, 1, 2])
        # 
        # print(x_distorted.shape)
        #w_mat = nppc_net(x_distorted.permute([3, 0, 1, 2]), x_restored.permute([3, 0, 1, 2]))
        
    # zero-filling of distorted    
    
    
    # x_distorted = cplx.abs(x_distorted.detach())[:, None, ...]
    x_restored = cplx.abs(x_restored.detach())[:, None, ...]
    print(f'{x_restored.shape=}\t\t{x_distorted.shape=}')
    # break
    w_mat = nppc_net(x_distorted, x_restored)

    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]
    
    x_org = cplx.abs(x_org.detach())[:, None, ...]
    err = (x_org - x_restored).flatten(1)

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]
    
    ## W hat loss
    ## ----------
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)
    
    ## W norms loss
    ## ------------
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)
    
    second_moment_loss_lambda = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_loss_lambda = max(min(second_moment_loss_lambda, 1) ,1e-6)
    second_moment_loss_lambda *= second_moment_loss_lambda
    objective = reconst_err.mean() + second_moment_loss_lambda * second_moment_mse.mean()

    nppc_optimizer.zero_grad()
    objective.backward()
    nppc_optimizer.step()
    nppc_step += 1
    
    
    print(f'Loss: {objective.detach().item()}')
    
    
    if True: #nppc_step % 10 == 0:
        nppc_objective_log.append(objective.detach().item())

        torch.save(nppc_net, f'/home/alon_granek/PythonProjects/NPPC/modl_nppc_net.pth')

        # ax.clear()
        # ax.plot(nppc_objective_log)
        # display(fig)

	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_200_6002641.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_201_6003008.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000421.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000508.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_202_6000531.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_203_6000942.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXFLAIR_209_6001397.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002033.h5.npy...
	/mnt/c/Users/along/brain_multicoil_train_batch_0/multicoil_train/file_brain_AXT1POST_200_6002392.h5.npy...
	/mnt/c/Users/along/brain_multicoil


KeyboardInterrupt



In [22]:
import plotly.express as px
# from nppc.auxil import imgs_to_grid
from IPython.core.display import Markdown

# samples_list = np.random.RandomState(1).randint(0, len(test_set), 10)
t_list = torch.linspace(-3, 3, 21).to(device)


def imgs_to_grid(imgs, nrows=None, **make_grid_args):
    imgs = imgs.detach().cpu()
    if imgs.ndim == 5:
        nrow = imgs.shape[1]
        imgs = imgs.reshape(imgs.shape[0] * imgs.shape[1], imgs.shape[2], imgs.shape[3], imgs.shape[4])
    elif nrows is None:
        nrow = int(np.ceil(imgs.shape[0] ** 0.5))

    make_grid_args2 = dict(value_range=(0, 1), pad_value=1.)
    make_grid_args2.update(make_grid_args)
    img = torchvision.utils.make_grid(imgs, nrow=nrow, **make_grid_args2).clamp(0, 1)
    return img

def scale_img(x):
    return x / torch.abs(x).flatten(-3).max(-1)[0][..., None, None, None] / 1.5 + 0.5

def tensor_img_to_numpy(x):
    return x.detach().permute(-2, -1, -3).cpu().numpy()

def imshow(img, scale=1, **kwargs):
    if isinstance(img, torch.Tensor):
        img = tensor_img_to_numpy(img)
    img = img.clip(0, 1)

    fig = px.imshow(img, **kwargs).update_layout(
        height=img.shape[0] * scale,
        width=img.shape[1] * scale,
        margin=dict(t=0, b=0, l=0, r=0),
        xaxis_showticklabels=False,
        yaxis_showticklabels=False,
    )
    return fig


test_total = 2
test_count = 0
for iter, data in enumerate(train_loader):
    
    # Obtain subsampled k-space, target image, k-space mask respectively
    y_distorted, x_org, mask = data
    x_org = x_org.to(device)
    mask = mask.to(device)
    y_distorted = y_distorted.to(device)
    # Zero-filled image ("distorted" in NPPC terms)
    x_distorted = torch.tensor(abs(ifftc(torch.tensor(fftc(x_org[..., 0] + 1j * x_org[..., 1])) * mask[..., 0]))).to(device).float()[:, None, ...]
        
    with torch.no_grad():
        x_restored = single_MoDL(y_distorted.float()[0][None], mask=mask[0][None])
        x_restored = cplx.abs(x_restored.detach())[:, None, ...]
        w_mat = nppc_net(x_distorted[0][None], x_restored[0][None])
    
    display(Markdown('## Sample \# ...'))
    display(Markdown('### Original, Distorted, Restored:'))
    imshow(imgs_to_grid(torch.stack((cplx.abs(x_org)[0], x_distorted[0], x_restored[0].squeeze()))[None, :, 0]), scale=2).show()
    
    display(Markdown('### Principal directions:'))
    imgs = t_list[:, None, None, None, None] * w_mat + x_restored[0][None][None]
    imgs = torch.cat((scale_img(w_mat), imgs), dim=0)
    imgs = imgs.transpose(0, 1).contiguous()
    imshow(imgs_to_grid(imgs), scale=1.6).show()
    
    if test_count == test_total:
        break
    else:
        test_count += 1
    

		Step 0/3...
		Step 1/3...
		Step 2/3...


## Sample \# ...

### Original, Distorted, Restored:

RuntimeError: stack expects each tensor to be equal size, but got [372, 372] at entry 0 and [1, 372, 372] at entry 1

In [35]:
import matplotlib.pyplot as plt
# fig, ax = plt.subplots(1, 5)
# for c in range(5):
#     ax[c].imshow(w_mat[0, c, 0], cmap='inferno', origin='lower', vmin=-0.5, vmax=0.5)
# fig.set_size_inches(13, 5)

c = 0
alphas = np.linspace(-0.5, 0.5, 10)
fig, ax = plt.subplots(1, len(alphas))
for i, alpha in enumerate(alphas):
    ax[i].imshow(x_restored[0, 0] + alpha * w_mat[0, c, 0], cmap='inferno', origin='lower', vmin=0, vmax=1.5)
fig.set_size_inches(30, 7)
fig.tight_layout()
fig.savefig('/home/alon_granek/PythonProjects/NPPC/1st pc attempt.png')

<IPython.core.display.Javascript object>

In [37]:
#todo
#   Make sure NPPCs are data-consistent.
#   Unrolled NPPC net?
#       DC over the whole set (not only \hat{x})

import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2)
ax[0].imshow(x_distorted[0, 0])
ax[1].imshow(x_restored[0, 0])

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fcf4419afb0>