In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib qt
import sys
sys.path.append('D:\\Box\\dev\\deep_bayesian_recon')
import fastMRI.data.transforms as transforms
import sigpy as sp
import sigpy.plot
import sigpy.mri
from tqdm.notebook import tqdm
from generator_models import calculate_kl
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = "TRUE"

In [2]:
class ConvBlock(nn.Module):
    """
    A 1D convolutional block made of two (Conv1d -> normalization -> activation -> dropout) stages.

    Each Conv1d uses kernel_size=3 with padding=1, so the signal length is preserved.
    """

    def __init__(self, in_chans, out_chans, drop_prob, norm_layer=nn.BatchNorm1d, norm_momentum=0.1,
                 activation=nn.ReLU):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
            drop_prob (float): Dropout probability.
            norm_layer (type): Normalization class, instantiated as
                ``norm_layer(out_chans, momentum=norm_momentum)``. It must accept
                (batch, channels, length) input, since it follows a Conv1d. The default is
                ``nn.BatchNorm1d``; the previous ``nn.BatchNorm2d`` default rejected 3D input.
            norm_momentum (float): Momentum passed to the normalization layers.
            activation (type): Activation class. ``nn.ReLU`` is instantiated in-place; any other
                class is instantiated with no arguments.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        def make_activation():
            # In-place ReLU saves memory; other activations are built with their defaults.
            return nn.ReLU(True) if activation == nn.ReLU else activation()

        self.layers = nn.Sequential(
            nn.Conv1d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
            norm_layer(out_chans, momentum=norm_momentum),
            make_activation(),
            nn.Dropout(drop_prob),
            nn.Conv1d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            norm_layer(out_chans, momentum=norm_momentum),
            make_activation(),
            nn.Dropout(drop_prob),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, length]

        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, length]
        """
        return self.layers(input)

    def __repr__(self):
        return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \
            f'drop_prob={self.drop_prob})'


class TransposeConvBlock(nn.Module):
    """
    A 1D up-sampling block: ConvTranspose1d (kernel 2, stride 2) -> normalization -> activation -> dropout.

    The transpose convolution doubles the signal length.
    """

    def __init__(self, in_chans, out_chans, drop_prob, norm_layer=nn.BatchNorm1d, norm_momentum=0.1,
                 activation=nn.ReLU):
        """
        Args:
            in_chans (int): Number of channels in the input.
            out_chans (int): Number of channels in the output.
            drop_prob (float): Dropout probability.
            norm_layer (type): Normalization class, instantiated as
                ``norm_layer(out_chans, momentum=norm_momentum)``. It must accept
                (batch, channels, length) input. The default is ``nn.BatchNorm1d``; the previous
                ``nn.BatchNorm2d`` default rejected 3D input.
            norm_momentum (float): Momentum passed to the normalization layer.
            activation (type): Activation class. ``nn.ReLU`` is instantiated in-place; any other
                class is instantiated with no arguments.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.drop_prob = drop_prob

        act = nn.ReLU(True) if activation == nn.ReLU else activation()
        self.layers = nn.Sequential(
            nn.ConvTranspose1d(in_chans, out_chans, kernel_size=2, stride=2, bias=False),
            norm_layer(out_chans, momentum=norm_momentum),
            act,
            nn.Dropout(drop_prob),
        )

    def forward(self, input):
        """
        Args:
            input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, length]

        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, 2 * length]
        """
        return self.layers(input)

    def __repr__(self):
        # Report the actual class name (this previously printed "ConvBlock").
        return f'TransposeConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})'


class UnetModel1d(nn.Module):
    """
    1D U-Net generator that decodes its own latent code. Adapted from the fastMRI GitHub U-Net.

    The network takes no input tensor. Each call draws ``z = z_mean + sqrt(|z_var|) * eps`` with
    ``eps ~ N(0, 1)``, or uses ``z_mean`` directly in max-likelihood mode. ``z`` is then passed
    through three stages:
      - an encoder of ``num_pool_layers`` ConvBlocks with average pooling;
      - a bottleneck ConvBlock;
      - a decoder of transpose-conv up-sampling with skip connections.

    This is based on:
        Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks
        for biomedical image segmentation. In International Conference on Medical image
        computing and computer-assisted intervention, pages 234–241. Springer, 2015.
    """

    def __init__(self, img_shape, latent_dim, in_chans, out_chans, chans, num_channels, num_pool_layers, drop_prob,
                 latent_shape=(0,),
                 num_samples=1,
                 norm_layer=nn.BatchNorm1d,
                 norm_momentum=0.1, max_likelihood=False,
                 dip_mc_mode=False,
                 dip_mc_dropout=False):
        """
        Args:
            img_shape (tuple): Signal shape, e.g. ``(256,)``.
            latent_dim (int): Number of channels of the latent code ``z``.
            in_chans (int): Number of channels in the input to the U-Net model. ``z`` is fed in
                directly, so this must equal ``latent_dim``.
            out_chans (int): Number of channels in the output of the U-Net model.
            chans (int): Number of output channels of the first convolution layer.
            num_channels (int): Stored as ``self.num_channels``; not used inside this class.
            num_pool_layers (int): Number of down-sampling and up-sampling layers.
            drop_prob (float): Dropout probability. Must be > 0 when ``dip_mc_dropout`` is True.
            latent_shape (tuple): Spatial shape of ``z``. ``(0,)`` means "same as ``img_shape``".
            num_samples (int): Number of latent samples drawn per ``forward`` call.
            norm_layer (type): Normalization class for the conv blocks. It must accept
                (batch, channels, length) input. Defaults to ``nn.BatchNorm1d``.
            norm_momentum (float): Momentum passed to the normalization layers.
            max_likelihood (bool): Freeze ``z_mean`` at a single N(0, 1) draw and use it as ``z``
                (no sampling).
            dip_mc_mode (bool): Freeze ``z_mean``/``z_var`` at their initial values.
            dip_mc_dropout (bool): Use a fixed random ``z_mean`` with ~zero variance, so sample
                variability comes from dropout only.
        """
        super().__init__()

        self.in_chans = in_chans
        self.out_chans = out_chans
        self.chans = chans
        self.num_pool_layers = num_pool_layers
        self.drop_prob = drop_prob
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        self.latent_shape = img_shape if list(latent_shape) == [0] else latent_shape
        z_shape = [self.latent_dim, *self.latent_shape]

        if dip_mc_dropout is False:
            # Gaussian latent q(z) = N(z_mean, z_var); trainable unless dip_mc_mode freezes it.
            z_grad = dip_mc_mode is False
            self.z_mean = nn.Parameter(torch.zeros(z_shape), requires_grad=z_grad)
            self.z_var = nn.Parameter(0.1 * torch.ones(z_shape), requires_grad=z_grad)
        else:
            # DIP with MC dropout: fixed random input with (numerically) zero variance.
            assert drop_prob > 0, f"For DIP with MC Dropout, drop_prob needs to be >0. Got {drop_prob}"
            self.z_mean = nn.Parameter(torch.normal(torch.zeros(z_shape), 1e-2 * torch.ones(z_shape)),
                                       requires_grad=False)
            self.z_var = nn.Parameter(1e-24 * torch.ones(z_shape), requires_grad=False)

        # Encoder: each level doubles the channel count.
        self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob, norm_layer=norm_layer,
                                                           norm_momentum=norm_momentum)])
        ch = chans
        for _ in range(num_pool_layers - 1):
            self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob, norm_layer=norm_layer,
                                                  norm_momentum=norm_momentum)]
            ch *= 2
        self.conv = ConvBlock(ch, ch * 2, drop_prob, norm_layer=norm_layer,
                              norm_momentum=norm_momentum)

        # Decoder: each level halves the channel count; the conv input is doubled by the skip concat.
        self.up_conv = nn.ModuleList()
        self.up_transpose_conv = nn.ModuleList()
        for _ in range(num_pool_layers - 1):
            self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch, drop_prob, norm_layer=norm_layer,
                                                          norm_momentum=norm_momentum)]
            self.up_conv += [ConvBlock(ch * 2, ch, drop_prob, norm_layer=norm_layer,
                                       norm_momentum=norm_momentum)]
            ch //= 2

        self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch, drop_prob, norm_layer=norm_layer,
                                                      norm_momentum=norm_momentum)]
        # The last level has no dropout and ends with a 1x1 projection to out_chans.
        self.up_conv += [
            nn.Sequential(
                ConvBlock(ch * 2, ch, drop_prob=0., norm_layer=norm_layer,
                          norm_momentum=norm_momentum),
                nn.Conv1d(ch, self.out_chans, kernel_size=1, stride=1),
            )]

        self.num_channels = num_channels

        self.num_samples = num_samples
        # Preallocated storage, refilled in place with N(0, 1) noise on every sampling call.
        self.rng = nn.Parameter(torch.zeros([self.num_samples, *z_shape]), requires_grad=False)

        self.max_likelihood = max_likelihood

        if self.max_likelihood:
            self.z_var.requires_grad = False
            self.z_mean.requires_grad = False
            # Initialize with a sample from a standard Normal distribution. (random_() was used
            # before; it draws large discrete-uniform integers, not a Normal sample.)
            with torch.no_grad():
                self.z_mean.normal_()

    def _sample_latent(self, noise):
        """Return ``z``: ``z_mean`` with a batch axis in max-likelihood mode, else a Gaussian sample.

        ``noise`` is refilled in place with N(0, 1) values (only when sampling). Its leading size
        sets the number of samples in the returned ``z``.
        """
        if self.max_likelihood:
            return self.z_mean.unsqueeze(0)
        # abs() keeps the std real if the optimizer pushes z_var below zero. This matches the
        # torch.sqrt(torch.abs(z_var)) used for the KL term in OptimizationModel.
        return self.z_mean + self.z_var.abs() ** 0.5 * noise.normal_()

    def _decode(self, z):
        """Run the U-Net encoder/bottleneck/decoder on ``z`` of shape [batch, in_chans, length]."""
        skips = []
        out = z

        # Apply down-sampling layers, keeping each pre-pooling activation for the skip connections.
        for layer in self.down_sample_layers:
            out = layer(out)
            skips.append(out)
            out = F.avg_pool1d(out, kernel_size=2, stride=2, padding=0)

        out = self.conv(out)

        # Apply up-sampling layers.
        for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
            skip = skips.pop()
            out = transpose_conv(out)

            # Reflect-pad one sample on the right when pooling dropped an odd trailing element.
            if out.shape[-1] != skip.shape[-1]:
                out = F.pad(out, [0, 1], "reflect")

            out = torch.cat([out, skip], dim=1)
            out = conv(out)

        return out

    def forward(self):
        """Draw ``num_samples`` latent samples (one in max-likelihood mode) and decode them.

        Returns:
            tuple: ``(z, output)``. ``z`` has shape [samples, latent_dim, *latent_shape] and
            ``output`` has shape [samples, out_chans, length].
        """
        z = self._sample_latent(self.rng)
        return z, self._decode(z)

    def infer(self):
        """Draw a single latent sample (uses only the first noise row) and decode it.

        Returns:
            tuple: ``(z, output)`` with a leading batch dimension of 1.
        """
        z = self._sample_latent(self.rng[0:1])
        return z, self._decode(z)

In [3]:
class ForwardModelCoilEstimated(nn.Module):
    """Coil-weighted 1D forward model: image -> coil images -> FFT -> masked measurements.

    ``noise_sigma`` passed to the constructor is not stored; the noise level is supplied on
    every ``forward`` call instead.
    """

    def __init__(self, noise_sigma, num_channels, img_shape, mask, maximum_likelihood=False,
                 no_noise_f_reg=False,
                 device='cpu',
                 n_mps=1):
        super().__init__()

        self.num_channels = num_channels
        self.img_shape = img_shape
        self.device = device
        self.mask = mask.to(device)
        self.n_mps = n_mps
        # Preallocated noise storage, refilled in place with N(0, 1) on each noisy forward call.
        noise_shape = [n_mps, num_channels, img_shape[0], 2]
        self.rng = torch.zeros(noise_shape).to(device)
        self.rng.requires_grad = False
        self.maximum_likelihood = maximum_likelihood
        self.no_noise_f_reg = no_noise_f_reg

    def forward(self, image, coil_est, noise_sigma):
        """Simulate masked measurements of ``image`` under the coil estimates ``coil_est``.

        Noise scaled by ``noise_sigma`` is added unless ``maximum_likelihood`` or
        ``no_noise_f_reg`` is set. The result is reshaped to [batch, num_channels, -1].
        """
        coil_images = transforms.complex_mul(coil_est, image.unsqueeze(2))
        kspace = transforms.fft1(coil_images)
        y = torch.sum(kspace, dim=1)  # Reduce Soft SENSE dim

        noiseless = self.maximum_likelihood or self.no_noise_f_reg
        if not noiseless:
            y = y + noise_sigma * self.rng.normal_()

        sampled = y[self.mask.expand_as(y)]
        return sampled.reshape(image.shape[0], self.num_channels, -1)
    

class ForwardModel(nn.Module):
    """Identity measurement model: replicate the signal over ``num_channels`` and add Gaussian noise.

    ``rng`` is a preallocated [1, num_channels, img_shape[0], 1] tensor that is refilled in place
    with N(0, 1) values on every call. Adding it to a ``data`` tensor whose channel axis is
    singleton also broadcasts the data over ``num_channels``.
    """

    def __init__(self, img_shape, num_channels, device, maximum_likelihood=False, no_noise_f_reg=False):
        super().__init__()
        self.img_shape = img_shape
        self.num_channels = num_channels
        noise_shape = (1, num_channels, img_shape[0], 1)
        self.rng = torch.zeros(noise_shape).to(device)
        self.rng.requires_grad = False
        self.maximum_likelihood = maximum_likelihood
        self.no_noise_f_reg = no_noise_f_reg

    def forward(self, data, noise_sigma):
        """Return ``data`` (+ noise) reshaped to [data.shape[0], num_channels, -1]."""
        noiseless = self.maximum_likelihood or self.no_noise_f_reg
        # The noise buffer is still drawn (scaled by 0.0) in the noiseless case. This keeps the
        # broadcast over channels and the RNG consumption identical across modes.
        scale = 0.0 if noiseless else noise_sigma
        y = data + scale * self.rng.normal_()
        return y.reshape(data.shape[0], self.num_channels, -1)

In [4]:
class OptimizationModel(nn.Module):
    """Objective that fits a generator ``g`` to observed data ``y`` through a fixed forward model ``f``.

    In maximum-likelihood mode the loss is the sum of squared residuals plus an optional L2
    penalty on the image. Otherwise the loss is the negative ELBO: a Gaussian log-likelihood with
    noise level ``exp(log_noise_sigma)``, minus the KL of the generator's latent distribution
    from N(0, 1), plus the L2 penalty.

    Latent variables: z (generator input)
    Parameters: log noise sigma, global log-scale of the image, generator weights
    Observed variables: y
    """

    def __init__(self, y_data, forward_model, generative_model, img_shape, latent_dim, noise_sigma, device,
                 num_samples=10, max_likelihood=False, coil_init=False, existing_mps=None, l2_reg=0,
                 n_mps=1,
                 noise_estimation=True):
        """
        Args:
            y_data (torch.Tensor): Observed measurements; expanded to the shape of the simulated data.
            forward_model (nn.Module): Called as ``f(x, noise_sigma)``; its parameters are frozen.
            generative_model (nn.Module): Called as ``g()``; must return ``(z, x)``.
            img_shape (tuple): Signal shape, e.g. ``(256,)``.
            latent_dim (int): Latent size used for the stored N(0, 1) prior tensors.
            noise_sigma (float or torch.Tensor): Initial noise standard deviation.
            device (str or torch.device): Device for ``y_data`` and the forward model.
            num_samples (int): Number of generator samples produced per forward pass.
            max_likelihood (bool): Use the plain SSE loss instead of the negative ELBO.
            coil_init, existing_mps: Stored unchanged; not used inside this class.
            l2_reg (float): Weight of the L2 penalty on the image estimate.
            n_mps (int): Number of image components per sample.
            noise_estimation (bool): Whether ``log_noise_sigma`` requires grad.
        """
        super().__init__()
        self.device = device
        self.img_shape = img_shape + (1,)
        # as_tensor accepts Python floats and existing tensors alike; torch.tensor(tensor) warns
        # about copy-construction. detach() makes the parameter start as an independent leaf.
        init_sigma = torch.as_tensor(noise_sigma).detach()
        self.log_noise_sigma = nn.Parameter(torch.log(init_sigma), requires_grad=noise_estimation)  # Noise standard deviation
        self.y_data = y_data.to(self.device)  # Store current observed data

        # Standard-normal prior for latent variable z (the KL below passes 0.0 / 1.0 directly)
        self.z_prior_mean = torch.zeros([latent_dim])
        self.z_prior_var = torch.ones([latent_dim])

        # Generative network
        self.g = generative_model

        # Computational parameters
        self.latent_dim = latent_dim  # Number of latent variables
        self.num_samples = num_samples  # Number of samples to approximate expectations
        self.N = np.prod(self.img_shape)  # Number of voxels in image

        # The forward model is fixed: evaluation mode, nothing trainable.
        self.f = forward_model.to(device)
        self.f.eval()
        self.f.device = device
        for param in self.f.parameters():
            param.requires_grad = False

        self.max_likelihood = max_likelihood
        self.coil_init = coil_init

        self.existing_mps = existing_mps

        self.l2_reg = l2_reg
        self.n_mps = n_mps

        self.sigma_eps = 1e-9

        # Learnable global gain on the image, parameterized in log space (starts at exp(0) = 1).
        self.scale_factor = nn.Parameter(torch.tensor(0.0))

    def forward(self):
        """Evaluate the objective once.

        Returns:
            tuple: ``(loss, log_likelihood, kl, mse)``.
              - In maximum-likelihood mode, ``mse`` is the total SSE, and ``log_likelihood`` and
                ``kl`` are zero placeholders.
              - Otherwise, ``mse`` is the SSE averaged over generator samples.
        """
        # Get image estimate
        _, x = self.g()
        x = x.reshape(self.num_samples, 1, self.n_mps, self.img_shape[0])
        x = x.permute(0, 2, 3, 1)  # -> (num_samples, n_mps, length, 1)

        # Get noise sigma from log-noise-sigma
        self.noise_sigma = torch.exp(self.log_noise_sigma)

        x = x * torch.exp(self.scale_factor)
        y = self.f(x, self.noise_sigma)

        if self.max_likelihood:
            mse = torch.sum((self.y_data.expand_as(y) - y) ** 2)
            loss = mse + self.l2_reg * torch.sum(x ** 2)
            log_likelihood = torch.tensor([0.], requires_grad=False)
            kl = torch.tensor([0.], requires_grad=False)
        else:
            # SSE averaged over the Monte-Carlo samples (dim 0)
            mse = torch.sum(torch.mean((self.y_data.expand_as(y) - y) ** 2, dim=0))

            log_sigma = self.log_noise_sigma + self.sigma_eps
            log_likelihood = -0.5 * torch.numel(y) * (
                torch.log(torch.tensor(2 * np.pi).to(self.device)) + 2 * log_sigma) \
                - 0.5 * torch.exp(-2 * log_sigma) * mse

            kl = calculate_kl(self.g.z_mean, torch.sqrt(torch.abs(self.g.z_var)), 0.0, 1.0)

            reg = self.l2_reg * torch.sum(x ** 2)

            # Generators with extra variational parameters can contribute their own entropy/KL term.
            if hasattr(self.g, 'calculate_elbo_entropy'):
                kl += self.g.calculate_elbo_entropy()

            elbo = log_likelihood - kl
            loss = -elbo + reg

        return loss, log_likelihood, kl, mse

In [5]:
# Generate synthetic data: a cosine core with linear ramps on both ends, shifted to be non-negative
img_shape = (256,)
num_channels = 8
num_samples = 1
n_mps = 1

t = np.linspace(-1, 1, img_shape[0])
right_ramp = (t > 0.5) & (t <= 1)
left_ramp = (t < -0.5) & (t >= -1)

x_in = 0.5 * np.cos(2 * np.pi * 1.5 * t).astype(np.float32)
x_in[right_ramp] = 2 * t[right_ramp] - 1
x_in[left_ramp] = 2 * t[left_ramp] + 1
x_in = x_in - x_in.min()

plt.figure()
plt.plot(x_in)
plt.tight_layout()
plt.savefig('x_input.png')

In [7]:
x_tensor = transforms.to_tensor(x_in)
# Shape the reference signal as (1, 1, N, 1). Non-in-place Tensor.resize is deprecated and warns;
# reshape yields the same values.
x_tensor = x_tensor.reshape(1, 1, img_shape[0], 1)



In [11]:
# Simulate noisy multi-channel measurements of the reference signal (seeded for reproducibility)
seed = 34433
torch.manual_seed(seed)
np.random.seed(seed=seed)

noise_sigma = 0.1
mask = torch.ones([img_shape[0], 2], dtype=torch.bool)  # all-True mask; ForwardModel below does not take one
device = 'cpu'

f = ForwardModel(img_shape=img_shape, num_channels=num_channels, device=device)
y_tensor = f(x_tensor, noise_sigma)

In [12]:
# Plot every channel of the noisy measurements (rows of y_tensor become plotted columns)
noisy_channels = y_tensor.squeeze().detach().cpu().numpy().T
plt.figure()
plt.plot(noisy_channels, '.')
plt.tight_layout()
plt.savefig(f'y_data_noise_{noise_sigma}.png')

In [13]:
# MLE estimator: each channel is the signal plus i.i.d. Gaussian noise, so the channel average is the ML estimate
x_inverse = y_tensor.mean(dim=1)
plt.figure()
plt.plot(x_inverse.T)

[<matplotlib.lines.Line2D at 0x1bb24ee3c50>]

# Per iteration error calculation

In [29]:
# Experiment configuration shared by every reconstruction mode
chans, num_pool_layers = [32, 4]
chans = int(chans)
num_pool_layers = int(num_pool_layers)
latent_dim = 16
in_chans = latent_dim  # the latent code z is fed straight into the U-Net
out_chans = 1
drop_prob = 0.

infer_num_samples = 128  # Monte-Carlo draws used to estimate the posterior mean

max_iter = 10000
calc_error_iter = 50  # evaluate the reconstruction error every `calc_error_iter` iterations

# Fall back to the CPU so the notebook still runs on machines without a GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
l2_reg = 0

weight_decay = 1e-2
lr = 1e-4

seed = 34433

# Standardize the measurements; estimates are mapped back with std_scale / mean_scale below
mean_scale = torch.mean(y_tensor)
y_tensor_scaled = y_tensor - mean_scale
std_scale = torch.std(y_tensor_scaled)
y_tensor_scaled = y_tensor_scaled / std_scale

noise_sigma_init = noise_sigma / std_scale

ref_data = x_tensor

# Switches for each reconstruction mode (an unknown mode now raises KeyError instead of
# silently reusing the previous mode's settings)
MODE_CONFIGS = {
    'dip': dict(maximum_likelihood=True, dip_mc_mode=False, dip_mc_dropout=False,
                noise_estimation=False, no_noise_f_reg=True, drop_prob=0.0),
    'dip_mc_inference': dict(maximum_likelihood=False, dip_mc_mode=True, dip_mc_dropout=False,
                             noise_estimation=False, no_noise_f_reg=True, drop_prob=0.0),
    'dip_mc_dropout': dict(maximum_likelihood=False, dip_mc_mode=False, dip_mc_dropout=True,
                           noise_estimation=False, no_noise_f_reg=True, drop_prob=0.1),
    'dip_mc_z_mc_dropout': dict(maximum_likelihood=False, dip_mc_mode=True, dip_mc_dropout=True,
                                noise_estimation=False, no_noise_f_reg=True, drop_prob=0.1),
    'dnlinv_no_dropout': dict(maximum_likelihood=False, dip_mc_mode=False, dip_mc_dropout=False,
                              noise_estimation=True, no_noise_f_reg=False, drop_prob=0.0),
    'dnlinv_dropout': dict(maximum_likelihood=False, dip_mc_mode=False, dip_mc_dropout=False,
                           noise_estimation=True, no_noise_f_reg=False, drop_prob=0.1),
}

modes = ['dip', 'dip_mc_inference', 'dip_mc_dropout', 'dip_mc_z_mc_dropout', 'dnlinv_no_dropout', 'dnlinv_dropout']
mode_error = {}
x_est = {}
mode_losses = {}
x_mc_est = {}

for mode in modes:
    print(f"Running estimation with {mode}")
    cfg = MODE_CONFIGS[mode]
    maximum_likelihood = cfg['maximum_likelihood']
    dip_mc_mode = cfg['dip_mc_mode']
    dip_mc_dropout = cfg['dip_mc_dropout']
    noise_estimation = cfg['noise_estimation']
    no_noise_f_reg = cfg['no_noise_f_reg']
    drop_prob = cfg['drop_prob']

    num_samples = 16 if maximum_likelihood is False else 1

    # Re-seed so every mode starts from the same random state
    np.random.seed(seed=seed)
    torch.manual_seed(seed)

    g = UnetModel1d(img_shape, latent_dim, in_chans, out_chans, chans, num_channels, num_pool_layers, drop_prob,
                    num_samples=num_samples,
                    norm_layer=nn.InstanceNorm1d,
                    norm_momentum=0.1, max_likelihood=maximum_likelihood,
                    dip_mc_mode=dip_mc_mode,
                    dip_mc_dropout=dip_mc_dropout)

    f = ForwardModel(img_shape=img_shape, num_channels=num_channels, device=device,
                     maximum_likelihood=maximum_likelihood, no_noise_f_reg=no_noise_f_reg)

    my_model = OptimizationModel(y_data=y_tensor_scaled, forward_model=f, generative_model=g, img_shape=img_shape,
                                 latent_dim=latent_dim, noise_sigma=noise_sigma_init, device=device,
                                 num_samples=num_samples, max_likelihood=maximum_likelihood, existing_mps=None,
                                 l2_reg=l2_reg, n_mps=n_mps, noise_estimation=noise_estimation)
    my_model = my_model.to(device)

    # log_noise_sigma is not passed to the optimizer (that entry was commented out), so it keeps
    # its initial value even when noise_estimation=True.
    optimizer = torch.optim.AdamW([{'params': my_model.g.parameters(), 'weight_decay': weight_decay},
                                   {'params': my_model.scale_factor}],
                                  lr=lr)

    # Run optimization
    err = []
    losses = np.zeros(max_iter)
    with tqdm(total=max_iter, desc='Optimize q(z)') as pbar:
        for i in range(max_iter):
            loss, log_likelihood, kl, mse = my_model.forward()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(iteration=i, loss=loss.item(), neg_log_likelihood=-log_likelihood.item(),
                             kl=kl.item(), sse=mse.item())
            pbar.update()

            # Every few iterations, estimate the reconstruction and record its error
            if (i % calc_error_iter) == 0:
                with torch.no_grad():
                    if maximum_likelihood:
                        z, x = my_model.g()
                        x = x.reshape(num_samples, 1, n_mps, img_shape[0])
                        x = x.permute(0, 2, 3, 1).contiguous()
                    else:
                        # Generate estimate from Monte-Carlo samples. The loop variable must not be
                        # `i`: reusing it clobbered the outer iteration counter, so losses[i] below
                        # was written to losses[infer_num_samples - 1] instead of the current slot.
                        n_samples = infer_num_samples
                        x = []
                        for _ in range(n_samples):
                            z, x_mc = my_model.g.infer()
                            x.append(x_mc.to('cpu'))

                        x = torch.cat(x, dim=0).reshape(n_samples, 1, n_mps, img_shape[0]).permute(0, 2, 3, 1).contiguous()
                    # Undo the learned gain and the data standardization
                    x *= torch.exp(my_model.scale_factor.cpu())
                    x *= std_scale
                    x += mean_scale
                    x_mean = torch.mean(x, dim=0)

                    error = x_mean - ref_data.reshape(img_shape[0], 1).type_as(x)
                    err.append(error.detach().cpu().numpy().flatten())

            losses[i] = loss.item()
    err = np.stack(err, axis=0)

    mode_error[mode] = err
    x_est[mode] = x_mean
    mode_losses[mode] = losses
    x_mc_est[mode] = x


Running estimation with dip


  del sys.path[0]


Optimize q(z):   0%|          | 0/10000 [00:00<?, ?it/s]



Running estimation with dip_mc_inference


Optimize q(z):   0%|          | 0/10000 [00:00<?, ?it/s]

Running estimation with dip_mc_dropout


Optimize q(z):   0%|          | 0/10000 [00:00<?, ?it/s]

Running estimation with dip_mc_z_mc_dropout


Optimize q(z):   0%|          | 0/10000 [00:00<?, ?it/s]

Running estimation with dnlinv_no_dropout


Optimize q(z):   0%|          | 0/10000 [00:00<?, ?it/s]

Running estimation with dnlinv_dropout


Optimize q(z):   0%|          | 0/10000 [00:00<?, ?it/s]

In [30]:
# Non-in-place Tensor.resize is deprecated; reshape gives the same flattened reference
reference = x_tensor.reshape(img_shape[0], 1).numpy().flatten()
x_gaussian_mean_est = x_inverse.T.detach().squeeze().numpy().flatten()
# NOTE(review): `mse` here is actually the RMSE, so this value is 10*log10(peak^2 / RMSE), not the
# textbook PSNR 10*log10(peak^2 / MSE). The PSNR-vs-iterations cell uses the same formula, so the
# comparison stays self-consistent; change both together if true PSNR values are wanted.
mse = np.sqrt(np.mean((x_gaussian_mean_est - reference) ** 2))
psnr_mean_gaussian = 10 * np.log10(np.max(reference ** 2) / mse)
print(f'PSNR MLE: {psnr_mean_gaussian}')

PSNR MLE: 30.924830436706543


In [31]:
# Non-in-place Tensor.resize is deprecated; reshape gives the same flattened reference
reference = x_tensor.reshape(img_shape[0], 1).numpy().flatten()

plt.figure()
for mode in modes:
    # NOTE(review): `mse` holds the per-evaluation RMSE, so `psnr` is 10*log10(peak^2 / RMSE).
    # That is the same formula as the MLE baseline cell, so the curves and the dashed MLE line
    # are comparable.
    mse = np.sqrt(np.mean(mode_error[mode] ** 2, axis=1))
    psnr = 10 * np.log10(np.max(reference ** 2) / mse)
    nrmse = mse / np.sqrt(np.mean(reference ** 2))  # `mse` is already an RMSE; no extra sqrt
    plt.plot(range(0, max_iter, calc_error_iter)[:len(psnr)], psnr)

plt.title('PSNR vs iterations')
plt.ylabel('PSNR (dB)')
plt.xlabel('iterations')
plt.axhline(psnr_mean_gaussian, linestyle='--', color='m', linewidth=1.0)
plt.legend(modes + ['MLE'])
plt.tight_layout()
plt.savefig(f'PSNR_iterations_{noise_sigma}.png')

In [32]:
# Reference as a numpy array so it combines cleanly with the numpy estimates below
# (non-in-place Tensor.resize is deprecated; reshape gives the same values)
ref = x_tensor.reshape(img_shape[0], 1).squeeze().numpy()
n_modes = len(modes)
plt.figure(figsize=(16, 9))
for idx, mode in enumerate(modes):
    est = x_est[mode].detach().cpu().numpy().squeeze()
    # Top row: estimate over the dashed reference
    plt.subplot(2, n_modes, idx + 1)
    plt.plot(est, linewidth=2.0)
    plt.title(mode)
    plt.plot(ref, '--', linewidth=1.0)
    # Bottom row: pointwise error
    plt.subplot(2, n_modes, idx + 1 + n_modes)
    plt.plot(ref - est)
    plt.title(f"{mode} error")
plt.tight_layout()
plt.savefig(f'reconstructions_{noise_sigma}.png')

