In [1]:
import random

from IPython.display import Markdown

import numpy as np
import tqdm.notebook as tqdm
import plotly.graph_objects as go
import plotly.express as px
import torch
import torch.nn as nn
import torchvision

In [2]:
# Run on the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still executes end to end on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Number of optimisation steps for each of the two training stages.
restoration_n_steps = 3000
nppc_n_steps = 3000
batch_size = 256

# Second-moment loss weight and warm-up length (in steps) used by the NPPC
# training cell.
second_moment_loss_lambda = 1e0
second_moment_loss_grace = 500

# Binary mask over a 1x28x28 image: 1 marks the pixels that get zeroed out in
# the distorted input (the first 20 rows), 0 marks the pixels that are kept.
mask = torch.zeros((1, 28, 28)).to(device)
mask[:, :20, :] = 1.

# Number of principal directions predicted per image.
n_dirs = 5

random_seed = 42

In [3]:
# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and, when present, all CUDA devices).
for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    seed_fn(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

# NOTE(review): cuDNN benchmark mode picks convolution algorithms by timing
# them, which may make GPU runs not bit-for-bit repeatable despite the seeds
# above - confirm whether exact reproducibility is required.
torch.backends.cudnn.benchmark = True

## Auxiliary functions

In [4]:
def sample_to_width(x, width=1580, padding_size=2):
    """Pick evenly spaced samples from a batch so a single row fits in ``width``.

    Args:
        x: Batch of images indexed along the first axis; the last axis is
            taken as the image width in pixels.
        width: Available width in pixels.
        padding_size: Padding in pixels assumed around every image.

    Returns:
        ``x`` indexed at evenly spaced positions between the first and the
        last sample; at most as many samples as fit into ``width``.
    """
    n_total = x.shape[0]
    cell_width = x.shape[-1] + padding_size
    n_fitting = (width - padding_size) // cell_width
    n_samples = min(n_fitting, n_total)
    indices = np.linspace(0, n_total - 1, n_samples).astype(int)
    return x[indices]

def imgs_to_grid(imgs, nrows=None, **make_grid_args):
    """Tile a batch of images into one grid image.

    Args:
        imgs: Tensor of images. A 5-D tensor ``(R, K, C, H, W)`` is flattened
            to ``R * K`` images and laid out with ``K`` images per grid row;
            any other input is passed to ``make_grid`` as is.
        nrows: Value forwarded to ``torchvision.utils.make_grid`` as ``nrow``
            (the number of images per grid row) for non-5-D input. When
            ``None``, ``ceil(sqrt(N))`` is used, where ``N`` is the size of
            the first axis. Ignored for 5-D input, whose layout is fixed by
            its second axis.
        **make_grid_args: Extra keyword arguments for ``make_grid``; they
            override the defaults ``value_range=(0, 1)`` and ``pad_value=1.``.

    Returns:
        The grid image as a CPU tensor clamped to ``[0, 1]``.
    """
    imgs = imgs.detach().cpu()
    if imgs.ndim == 5:
        nrow = imgs.shape[1]
        imgs = imgs.reshape(imgs.shape[0] * imgs.shape[1], *imgs.shape[2:])
    elif nrows is None:
        nrow = int(np.ceil(imgs.shape[0] ** 0.5))
    else:
        # Previously this case left `nrow` unassigned and raised a NameError.
        nrow = nrows

    make_grid_args2 = dict(value_range=(0, 1), pad_value=1.)
    make_grid_args2.update(make_grid_args)
    img = torchvision.utils.make_grid(imgs, nrow=nrow, **make_grid_args2).clamp(0, 1)
    return img

def scale_img(x):
    """Rescale signed images for display, centred on mid-grey.

    Every image (the last three axes) is divided by its own maximum absolute
    value, shrunk by a factor of 1.5 and shifted by 0.5, so a value of zero
    maps to 0.5 and the extremes map to ``0.5 +/- 1/1.5``.

    Args:
        x: Tensor with at least three axes; the last three form one image.

    Returns:
        Tensor of the same shape as ``x``. An image that is identically zero
        maps to a constant 0.5 instead of NaN.
    """
    peak = torch.abs(x).flatten(-3).max(-1)[0][..., None, None, None]
    # Guard against dividing by zero for all-zero images.
    peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    return x / peak / 1.5 + 0.5

def tensor_img_to_numpy(x):
    """Convert an image tensor to a NumPy array with the channel axis last.

    The third-from-last axis is moved to the end (e.g. ``(C, H, W)`` becomes
    ``(H, W, C)``); the tensor is detached from autograd and copied to the CPU.
    """
    channels_last = x.detach().permute(-2, -1, -3)
    return channels_last.cpu().numpy()

def imshow(img, scale=1, **kwargs):
    """Build a plotly figure that shows an image without margins or tick labels.

    Args:
        img: Image as a tensor (converted with ``tensor_img_to_numpy``) or as
            an array-like object exposing ``clip`` and ``shape``.
        scale: Figure pixels per image pixel.
        **kwargs: Forwarded to ``plotly.express.imshow``.

    Returns:
        The plotly figure; the caller decides when to ``show()`` it.
    """
    if isinstance(img, torch.Tensor):
        img = tensor_img_to_numpy(img)
    img = img.clip(0, 1)

    n_rows, n_cols = img.shape[0], img.shape[1]
    layout = dict(
        height=n_rows * scale,
        width=n_cols * scale,
        margin=dict(t=0, b=0, l=0, r=0),
        xaxis_showticklabels=False,
        yaxis_showticklabels=False,
    )
    return px.imshow(img, **kwargs).update_layout(**layout)


class LoopLoader():
    """Iterate over ``dataloader`` again and again until ``size`` items were yielded.

    Args:
        dataloader: Any re-iterable object (it is iterated from the start
            every time it is exhausted).
        size: Total number of items to yield; also reported by ``len()``.

    Raises:
        ValueError: During iteration, when ``size`` is positive but a full
            pass over ``dataloader`` yields nothing. Without this check the
            outer loop would spin forever.
    """

    def __init__(self, dataloader, size):
        self.dataloader = dataloader
        self.size = size

    def __len__(self):
        return self.size

    def __iter__(self):
        i = 0
        while i < self.size:
            n_before_pass = i
            for x in self.dataloader:
                if i >= self.size:
                    break
                yield x
                i += 1
            # The loop condition guarantees i < size when a pass starts, so an
            # unchanged counter means the dataloader produced no items at all.
            if i == n_before_pass:
                raise ValueError('LoopLoader: the wrapped dataloader yielded no items')

## Dataset

In [5]:
# MNIST train / test splits, each sample converted with ToTensor.
to_tensor = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.MNIST(root='./', download=True, train=True, transform=to_tensor)
test_set = torchvision.datasets.MNIST(root='./', train=False, transform=to_tensor)

# Keep one shuffled test batch around for the qualitative plots further down.
dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True)
test_batch = next(iter(dataloader))

## Network

In [6]:
class UNet(nn.Module):
    """Convolutional encoder / bottleneck / decoder network with skip connections.

    The encoder stores the output of every block; each decoder block receives
    its input concatenated (along the channel axis) with the most recent
    encoder output that has not been consumed yet.

    The input is zero-padded by 2 pixels on every side before the first
    convolution and the output is passed through ``nn.ZeroPad2d(-2)`` at the
    end (negative padding, i.e. a crop back by 2 pixels per side).

    NOTE(review): the skip connections presumably require the padded height
    and width to be divisible by ``2 ** (len(channels_list) - 1)`` so that the
    pooled and upsampled resolutions line up - confirm before using other
    input sizes.
    """

    def __init__(
            self,
            in_channels=1,
            out_channels=1,
            channels_list=(32, 64, 128, 256),
            bottleneck_channels=512,
            min_channels_decoder=64,
            n_groups=8,
        ):
        """Build the layers.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by the final 1x1
                convolution.
            channels_list: Output width of each encoder level. The first
                level keeps the resolution; every later level starts with a
                ``MaxPool2d(2)``.
            bottleneck_channels: Width of the two bottleneck convolutions.
            min_channels_decoder: Lower bound on the output width of each
                decoder block (``max(channels_list[i], min_channels_decoder)``).
            n_groups: Number of groups passed to every ``nn.GroupNorm``.
        """

        super().__init__()
        # `ch` tracks the channel count of the tensor flowing through the
        # network while the layers are being declared.
        ch = in_channels

        ## Encoder
        ## =======
        self.encoder_blocks = nn.ModuleList([])
        # Channel count of every encoder block output, in creation order; the
        # decoder pops from the end to size its concatenated inputs.
        ch_hidden_list = []

        # Stem: pad by 2 pixels per side, then a plain convolution without
        # normalisation or activation.
        layers = []
        layers.append(nn.ZeroPad2d(2))
        ch_ = channels_list[0]
        layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
        ch = ch_
        self.encoder_blocks.append(nn.Sequential(*layers))
        ch_hidden_list.append(ch)

        for i_level in range(len(channels_list)):
            ch_ = channels_list[i_level]
            # Every level except the first halves the resolution.
            downsample = i_level != 0

            layers = []
            if downsample:
                layers.append(nn.MaxPool2d(2))
            layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
            ch = ch_
            layers.append(nn.GroupNorm(n_groups, ch))
            layers.append(nn.LeakyReLU(0.1))
            self.encoder_blocks.append(nn.Sequential(*layers))
            ch_hidden_list.append(ch)

        ## Bottleneck
        ## ==========
        # Two conv / GroupNorm / LeakyReLU stages at the lowest resolution.
        ch_ = bottleneck_channels
        layers = []
        layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
        ch = ch_
        layers.append(nn.GroupNorm(n_groups, ch))
        layers.append(nn.LeakyReLU(0.1))
        layers.append(nn.Conv2d(ch, ch, 3, padding=1))
        layers.append(nn.GroupNorm(n_groups, ch))
        layers.append(nn.LeakyReLU(0.1))
        self.bottleneck = nn.Sequential(*layers)

        ## Decoder
        ## =======
        self.decoder_blocks = nn.ModuleList([])
        for i_level in reversed(range(len(channels_list))):
            ch_ = max(channels_list[i_level], min_channels_decoder)
            # Despite its name, this flag selects *up*sampling here: it undoes
            # the pooling done by the encoder level with the same index.
            downsample = i_level != 0
            # Input width = current width + width of the skip connection that
            # forward() concatenates in front of this block.
            ch = ch + ch_hidden_list.pop()
            layers = []

            layers.append(nn.Conv2d(ch, ch_, 3, padding=1))
            ch = ch_
            layers.append(nn.GroupNorm(n_groups, ch))
            layers.append(nn.LeakyReLU(0.1))
            if downsample:
                layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            self.decoder_blocks.append(nn.Sequential(*layers))

        # Output head: concatenated with the stem output, projected to
        # `out_channels` with a 1x1 convolution and passed through
        # ZeroPad2d(-2), mirroring the ZeroPad2d(2) of the stem.
        ch = ch + ch_hidden_list.pop()
        # NOTE(review): `ch_` is assigned here but not used afterwards.
        ch_ = channels_list[0]
        layers = []
        layers.append(nn.Conv2d(ch, out_channels, 1))
        layers.append(nn.ZeroPad2d(-2))
        self.decoder_blocks.append(nn.Sequential(*layers))

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last decoder block."""
        # Outputs of the encoder blocks, consumed last-in-first-out below.
        h = []
        for block in self.encoder_blocks:
            x = block(x)
            h.append(x)

        x = self.bottleneck(x)
        for block in self.decoder_blocks:
            # Skip connection: concatenate along the channel axis.
            x = torch.cat((x, h.pop()), dim=1)
            x = block(x)
        return x

def gram_schmidt(x):
    """Orthogonalise the vectors along axis 1 with the Gram-Schmidt procedure.

    Axes from index 2 onwards are flattened into one vector per
    ``(batch, direction)`` entry. Each vector has its projections onto the
    previously processed vectors removed. The returned vectors are *not*
    normalised. The unit vectors used for the projections are detached, so
    gradients do not flow through them.

    Args:
        x: Tensor of shape ``(batch, n_vectors, ...)``.

    Returns:
        Tensor with the same shape as ``x`` holding the orthogonalised vectors.
    """
    original_shape = x.shape
    flat = x.flatten(2)

    orthogonal = []
    unit_basis = []
    for vec in flat.unbind(dim=1):
        # Remove the components along every direction found so far.
        for basis_vec in unit_basis:
            vec = vec - basis_vec * torch.sum(vec * basis_vec, dim=-1, keepdim=True)
        orthogonal.append(vec)

        # Constant (gradient-free) unit vector used for later projections.
        vec_const = vec.detach()
        unit_basis.append(vec_const / vec_const.norm(dim=-1, keepdim=True))

    return torch.stack(orthogonal, dim=1).view(*original_shape)


class RestorationWrapper(nn.Module):
    """Wraps a network so that it only fills in the masked part of its input.

    The input is normalised as ``(x - 0.5) / 0.2`` before the wrapped network
    and the network output is mapped back with ``* 0.2 + 0.5``. The result is
    multiplied by ``mask`` and added to the untouched input.

    Args:
        net: The wrapped network.
        mask: Tensor broadcastable to the network output; the prediction is
            multiplied by it before being added to the input.
    """

    def __init__(self, net, mask):
        super().__init__()

        self.net = net
        self.mask = mask

    def forward(self, x):
        normalized = (x - 0.5) / 0.2
        prediction = self.net(normalized) * 0.2 + 0.5
        return x + prediction * self.mask


class PCWrapper(nn.Module):
    """Wraps a network so that it outputs ``n_dirs`` orthogonalised directions.

    The distorted and restored images are concatenated along the channel axis,
    normalised as ``(x - 0.5) / 0.2`` and passed through the wrapped network.
    Its output is scaled by 0.2, split into ``n_dirs`` groups of channels,
    multiplied by ``mask`` and orthogonalised with ``gram_schmidt``.

    Args:
        net: The wrapped network; its output channel count is split evenly
            into ``n_dirs`` directions.
        n_dirs: Number of directions to output per input.
        mask: Tensor multiplied onto every predicted direction.
    """

    def __init__(self, net, n_dirs, mask):
        super().__init__()

        self.net = net
        self.n_dirs = n_dirs
        self.mask = mask

    def forward(self, x_distorted, x_restored):
        net_input = torch.cat((x_distorted, x_restored), dim=1)
        net_input = (net_input - 0.5) / 0.2
        raw = self.net(net_input) * 0.2

        # Split the channel axis into (direction, channel) and fold the
        # direction axis into the batch axis before applying the mask.
        channels_per_dir = raw.shape[1] // self.n_dirs
        per_dir = raw.unflatten(1, (self.n_dirs, channels_per_dir)).flatten(0, 1)
        per_dir = per_dir * self.mask

        # Restore the separate (batch, direction) axes.
        n_batch = per_dir.shape[0] // self.n_dirs
        w_mat = per_dir.unflatten(0, (n_batch, self.n_dirs))

        return gram_schmidt(w_mat)

# Both networks are trained with the same Adam settings.
adam_args = dict(lr=1e-4, betas=(0.9, 0.999))

# Restoration network: predicts the masked part of the image.
restoration_net = RestorationWrapper(UNet(), mask=mask).to(device)
restoration_net.train()
restoration_optimizer = torch.optim.Adam(restoration_net.parameters(), **adam_args)
restoration_step = 0

# NPPC network: takes the distorted and the restored image (1 + 1 channels)
# and outputs one single-channel image per direction.
nppc_net = PCWrapper(UNet(in_channels=1 + 1, out_channels=1 * n_dirs), n_dirs=n_dirs, mask=mask).to(device)
nppc_net.train()
nppc_optimizer = torch.optim.Adam(nppc_net.parameters(), **adam_args)
nppc_step = 0

## Train restoration model

In [7]:
# Training uses the MNIST training split; this rebinds `dataloader`, which so
# far wrapped the test split.
dataloader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)

restoration_objective_log = []
for batch in tqdm.tqdm(LoopLoader(dataloader, restoration_n_steps)):
    x_org = batch[0].to(device)
    # Zero out the masked region; the network has to fill it back in.
    x_distorted = x_org * (1 - mask)

    x_restored = restoration_net(x_distorted)
    err = x_org - x_restored
    # Mean squared error over all pixels of the batch.
    objective = err.pow(2).flatten(1).mean()

    restoration_optimizer.zero_grad()
    objective.backward()
    restoration_optimizer.step()
    restoration_step += 1

    # Log once every 100 steps; the plotting cell spaces the logged points
    # 100 steps apart on the x axis. (The former `if restoration_step % 100:`
    # logged every step that was NOT a multiple of 100.)
    if restoration_step % 100 == 0:
        restoration_objective_log.append(objective.detach().item())


  0%|          | 0/3000 [00:00<?, ?it/s]

In [8]:
# Training curve of the restoration network.
curve = go.Scatter(
    mode='lines',
    x=np.arange(len(restoration_objective_log)) * 100,
    y=restoration_objective_log,
)
curve_layout = go.Layout(
    yaxis_title='Objective',
    xaxis_title='step',
    height=400,
    width=550,
    margin=dict(t=0, b=20, l=20, r=0),
)
go.Figure(data=[curve], layout=curve_layout).show()

# Restore the held-out test batch without tracking gradients.
x_org = test_batch[0].to(device)
x_distorted = x_org * (1 - mask)
with torch.no_grad():
    x_restored = restoration_net(x_distorted)
err = x_org - x_restored

# One row of evenly spaced samples per panel.
for panel_title, panel_imgs in (
    ('### Original image:', x_org),
    ('### Distorted image:', x_distorted),
    ('### Restored image:', x_restored),
    ('### Error:', scale_img(err)),
):
    display(Markdown(panel_title))
    imshow(imgs_to_grid(sample_to_width(panel_imgs, width=780)[None]), scale=2).show()

### Original image:

### Distorted image:

### Restored image:

### Error:

## Train NPPC model

In [9]:
nppc_objective_log = []
# NOTE: `dataloader` is the training-set loader created in the restoration
# training cell.
for batch in tqdm.tqdm(LoopLoader(dataloader, nppc_n_steps)):
    x_org = batch[0].to(device)
    x_distorted = x_org * (1 - mask)
    # The restoration network is only evaluated here, not trained.
    with torch.no_grad():
        x_restored = restoration_net(x_distorted)

    w_mat = nppc_net(x_distorted, x_restored)

    # Split every predicted direction into its norm and its unit vector.
    w_mat_ = w_mat.flatten(2)
    w_norms = w_mat_.norm(dim=2)
    w_hat_mat = w_mat_ / w_norms[:, :, None]

    err = (x_org - x_restored).flatten(1)

    ## Normalizing by the error's norm
    ## -------------------------------
    err_norm = err.norm(dim=1)
    err = err / err_norm[:, None]
    w_norms = w_norms / err_norm[:, None]

    ## W hat loss
    ## ----------
    # Fraction of the (unit-norm) error that is NOT captured by the directions.
    err_proj = torch.einsum('bki,bi->bk', w_hat_mat, err)
    reconst_err = 1 - err_proj.pow(2).sum(dim=1)

    ## W norms loss
    ## ------------
    # Squared norms are regressed onto the squared projections of the error;
    # the projections are detached so this term only trains the norms.
    second_moment_mse = (w_norms.pow(2) - err_proj.detach().pow(2)).pow(2)

    # Warm-up factor for the second-moment term: 1e-6 during the first half of
    # the grace period, then a linear ramp that reaches 1 at
    # `second_moment_loss_grace` steps. It scales the configured
    # `second_moment_loss_lambda`. A separate local name is used so the
    # configured value is neither overwritten nor multiplied by itself.
    second_moment_ramp = -1 + 2 * nppc_step / second_moment_loss_grace
    second_moment_ramp = max(min(second_moment_ramp, 1), 1e-6)
    second_moment_weight = second_moment_ramp * second_moment_loss_lambda
    objective = reconst_err.mean() + second_moment_weight * second_moment_mse.mean()

    nppc_optimizer.zero_grad()
    objective.backward()
    nppc_optimizer.step()
    nppc_step += 1

    # Log once every 100 steps. (The former `if nppc_step % 100:` logged every
    # step that was NOT a multiple of 100.)
    if nppc_step % 100 == 0:
        nppc_objective_log.append(objective.detach().item())

  0%|          | 0/3000 [00:00<?, ?it/s]

In [10]:
## Plots
# Training curve of the NPPC network. This cell used to plot
# `restoration_objective_log`, i.e. the curve of the other model.
go.Figure(data=[go.Scatter(mode='lines', x=np.arange(len(nppc_objective_log)) * 100, y=nppc_objective_log)],
    layout=go.Layout(yaxis_title='Objective', xaxis_title='step', height=400, width=550, margin=dict(t=0, b=20, l=20, r=0)),
).show()

# Restore the held-out test batch and predict its directions without
# tracking gradients.
x_org = test_batch[0].to(device)
x_distorted = x_org * (1 - mask)
with torch.no_grad():
    x_restored = restoration_net(x_distorted)
    w_mat = nppc_net(x_distorted, x_restored)
err = x_org - x_restored

display(Markdown('### Original image:'))
imshow(imgs_to_grid(sample_to_width(x_org, width=780)[None]), scale=2).show()

display(Markdown('### Distorted image:'))
imshow(imgs_to_grid(sample_to_width(x_distorted, width=780)[None]), scale=2).show()

display(Markdown('### Restored image:'))
imshow(imgs_to_grid(sample_to_width(x_restored, width=780)[None]), scale=2).show()

display(Markdown('### Error:'))
imshow(imgs_to_grid(sample_to_width(scale_img(err), width=780)[None]), scale=2).show()

# One grid row per direction, one column per sampled image.
display(Markdown('### PCs:'))
imshow(imgs_to_grid(sample_to_width(scale_img(w_mat), width=780).transpose(0, 1).contiguous()), scale=2).show()

### Original image:

### Distorted image:

### Restored image:

### Error:

### PCs:

## Results

In [14]:
# Ten test-set indices drawn from a dedicated, fixed-seed generator so the
# selection does not depend on the global random state.
samples_list = np.random.RandomState(1).randint(0, len(test_set), 10)
# Multipliers applied to every direction when moving away from the restored image.
t_list = torch.linspace(-3, 3, 21).to(device)

for i in samples_list:
    x_org = test_set[i][0][None].to(device)
    x_distorted = x_org * (1 - mask)
    with torch.no_grad():
        x_restored = restoration_net(x_distorted)
        w_mat = nppc_net(x_distorted, x_restored)
    err = x_org - x_restored

    # Raw string: '\#' is not a valid escape sequence in a normal literal.
    # The rendered text is unchanged.
    display(Markdown(r'## Sample \# ' + f'{i}'))
    display(Markdown('### Original, Distorted, Restored:'))
    imshow(imgs_to_grid(torch.stack((x_org, x_distorted, x_restored))[None, :, 0]), scale=2).show()

    display(Markdown('### Principal directions:'))
    # x_restored + t * w for every multiplier t and every direction w.
    imgs = t_list[:, None, None, None, None] * w_mat + x_restored[None]
    # Prepend the (rescaled) direction itself as the first image of each row.
    imgs = torch.cat((scale_img(w_mat), imgs), dim=0)
    # One grid row per direction.
    imgs = imgs.transpose(0, 1).contiguous()
    imshow(imgs_to_grid(imgs), scale=2).show()

## Sample \# 235

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 5192

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 905

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 7813

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 2895

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 5056

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 144

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 4225

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 7751

### Original, Distorted, Restored:

### Principal direstions:

## Sample \# 3462

### Original, Distorted, Restored:

### Principal direstions: