In [None]:
!pip install --quiet pytorch-lightning>=1.4

In [None]:
## Credits to https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial12/Autoregressive_Image_Modeling.html

import os
import math
import numpy as np
import matplotlib.pyplot as plt
# Default colormap for images drawn without an explicit cmap (most plots below pass cmap='gray').
plt.set_cmap('cividis')
%matplotlib inline


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

In [None]:
DATASET_PATH = '../data/'

# Seed the global RNGs through Lightning for reproducibility.
pl.seed_everything(42)

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def discretize(sample):
    """Map a float image in [0, 1] (as produced by ToTensor) to integer levels 0..255.

    Rounds before casting: in float arithmetic k/255*255 can come out slightly
    below k, and a plain cast to long would then truncate it to k - 1.
    """
    return (sample * 255).round().to(torch.long)

transform = transforms.Compose([transforms.ToTensor(), discretize])

# Only the digit 7 is modelled: keep just that class from the MNIST training split.
main_data = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
train_dataset = [pair for pair in main_data if pair[1] == 7]

# 80/20 train/validation split, re-seeded right before the split so it is repeatable.
n_train = int(0.8 * len(train_dataset))
pl.seed_everything(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [n_train, len(train_dataset) - n_train])

# Held-out 7s from the MNIST test split.
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)
test_set = [pair for pair in test_set if pair[1] == 7]

loader_opts = dict(batch_size=128, num_workers=2)
train_loader = data.DataLoader(train_set, shuffle=True, drop_last=True, pin_memory=True, **loader_opts)
val_loader = data.DataLoader(val_set, shuffle=False, drop_last=False, **loader_opts)
test_loader = data.DataLoader(test_set, shuffle=False, drop_last=False, **loader_opts)

In [None]:
class MaskedConvolution(nn.Module):
    """2-D convolution whose kernel is multiplied element-wise by a fixed 0/1 mask.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        mask: (kH, kW) tensor; its shape also defines the kernel size.
        **kwargs: forwarded to nn.Conv2d (e.g. ``dilation``).
    """

    def __init__(self, c_in, c_out, mask, **kwargs):
        super().__init__()
        kernel_size = tuple(mask.shape[:2])
        dil = kwargs.get('dilation', 1)
        # Padding of dilation*(k-1)//2 keeps H and W unchanged for odd kernel sizes.
        padding = tuple(dil * (k - 1) // 2 for k in kernel_size)

        self.conv = nn.Conv2d(c_in, c_out, kernel_size, padding=padding, **kwargs)

        # Registered as a (1, 1, kH, kW) buffer so it moves with the module and
        # broadcasts over the (c_out, c_in) weight dimensions.
        self.register_buffer('mask', mask[None, None])

    def forward(self, x):
        # Zero the masked kernel entries in place before every call.
        self.conv.weight.data *= self.mask
        return self.conv(x)


In [None]:
class VerticalStackConvolution(MaskedConvolution):
    """k x k masked convolution that sees only the centre row and the rows above it.

    With ``mask_center=True`` the centre row is hidden too, so each output
    depends only on rows strictly above it.
    """

    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        centre = kernel_size // 2
        # Every row from this index downwards is zeroed out.
        first_hidden_row = centre if mask_center else centre + 1
        mask = torch.ones(kernel_size, kernel_size)
        mask[first_hidden_row:, :] = 0
        super().__init__(c_in, c_out, mask, **kwargs)

class HorizontalStackConvolution(MaskedConvolution):
    """1 x k masked convolution along the current row that sees only pixels to the left.

    The centre pixel is included unless ``mask_center=True``.
    """

    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):
        centre = kernel_size // 2
        # Every column from this index rightwards is zeroed out.
        first_hidden_col = centre if mask_center else centre + 1
        mask = torch.ones(1, kernel_size)
        mask[0, first_hidden_col:] = 0
        super().__init__(c_in, c_out, mask, **kwargs)


In [None]:
class GatedMaskedConv(nn.Module):
    """One gated PixelCNN layer acting on a vertical and a horizontal stack.

    Each stack goes through a masked convolution producing ``2*c_in`` channels,
    which are split into a value half and a gate half and combined as
    ``tanh(value) * sigmoid(gate)``. The horizontal stack also receives the
    vertical features through a 1x1 convolution and keeps a residual connection.
    Extra keyword arguments (e.g. ``dilation``) are forwarded to both masked convolutions.
    """

    def __init__(self, c_in, **kwargs):
        super().__init__()
        doubled = 2 * c_in
        self.conv_vert = VerticalStackConvolution(c_in, c_out=doubled, **kwargs)
        self.conv_horiz = HorizontalStackConvolution(c_in, c_out=doubled, **kwargs)
        self.conv_vert_to_horiz = nn.Conv2d(doubled, doubled, kernel_size=1, padding=0)
        self.conv_horiz_1x1 = nn.Conv2d(c_in, c_in, kernel_size=1, padding=0)

    @staticmethod
    def _gate(features):
        # Split channels into (value, gate) halves and apply the gated activation.
        value, gate = features.chunk(2, dim=1)
        return torch.tanh(value) * torch.sigmoid(gate)

    def forward(self, v_stack, h_stack):
        # Vertical stack: masked conv followed by the gate.
        v_feat = self.conv_vert(v_stack)
        v_out = self._gate(v_feat)

        # Horizontal stack: its own masked conv plus the (pre-gate) vertical features.
        h_feat = self.conv_horiz(h_stack) + self.conv_vert_to_horiz(v_feat)
        h_out = h_stack + self.conv_horiz_1x1(self._gate(h_feat))

        return v_out, h_out

In [None]:
class PixelCNN(pl.LightningModule):
    """Gated PixelCNN over discrete images with 256 intensity levels.

    ``forward`` takes an integer image batch of shape (B, c_in, H, W) with values
    in 0..255 and returns logits of shape (B, 256, c_in, H, W): one 256-way
    categorical distribution per pixel and channel. Causality comes from the
    masked vertical/horizontal stack convolutions.

    Args:
        c_in: number of image channels.
        c_hidden: number of feature channels in the gated layers.
    """

    def __init__(self, c_in, c_hidden):
        super().__init__()
        self.save_hyperparameters()
        # Debug buffers filled by ``sample``; kept as attributes for callers that read them.
        self.pred_list = []
        self.prob_list = torch.zeros((28*28, 256))
        self.nll_list = []
        # The first layers mask the centre pixel, so no prediction can see its own target.
        self.conv_vstack = VerticalStackConvolution(c_in, c_hidden, mask_center=True)
        self.conv_hstack = HorizontalStackConvolution(c_in, c_hidden, mask_center=True)
        # Interleaved dilations (2, 4, 2) widen the receptive field.
        self.conv_layers = nn.ModuleList([
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=2),
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=4),
            GatedMaskedConv(c_hidden),
            GatedMaskedConv(c_hidden, dilation=2),
            GatedMaskedConv(c_hidden)
        ])
        self.conv_out = nn.Conv2d(c_hidden, c_in * 256, kernel_size=1, padding=0)

        # Example batch for Lightning (e.g. the model summary). Built from the constructor
        # arguments rather than the global ``train_set``, so the model can be created
        # without that variable existing (e.g. just to load a checkpoint).
        self.example_input_array = torch.zeros((1, c_in, 28, 28), dtype=torch.long)

    def forward(self, x):
        """Return per-pixel logits of shape (B, 256, C, H, W) for the image batch ``x``."""
        # Rescale 0..255 intensities to [-1, 1].
        x = (x.float() / 255.0) * 2 - 1
        v_stack = self.conv_vstack(x)
        h_stack = self.conv_hstack(x)
        for layer in self.conv_layers:
            v_stack, h_stack = layer(v_stack, h_stack)
        out = self.conv_out(F.elu(h_stack))

        # (B, C*256, H, W) -> (B, 256, C, H, W): class dimension second, as cross_entropy expects.
        out = out.reshape(out.shape[0], 256, out.shape[1]//256, out.shape[2], out.shape[3])
        return out

    def calc_likelihood(self, x):
        """Mean bits-per-dimension of the integer image batch ``x`` under the model."""
        pred = self.forward(x)
        nll = F.cross_entropy(pred, x, reduction='none')
        # Per-image mean NLL in nats, converted to bits by multiplying with log2(e).
        bpd = nll.mean(dim=[1,2,3]) * np.log2(np.exp(1))
        return bpd.mean()

    @torch.no_grad()
    def sample(self, img_shape, img=None):
        """Autoregressively fill every entry of ``img`` that equals -1.

        Args:
            img_shape: (B, C, H, W) shape of the images to generate.
            img: optional long tensor of that shape. Entries != -1 are kept as
                conditioning; entries == -1 are sampled. When omitted, a fresh
                all-(-1) tensor is created on the model's device.

        Returns:
            (pred_list, prob_list, img): the logits of every forward pass made by
            this call, the sampling distribution of each generated entry (shape
            (steps, 256) for B == 1, (steps, B, 256) otherwise; unused rows stay 0),
            and the completed image.
        """
        # Local import: the notebook's own ``tqdm`` import lives in a later cell.
        from tqdm import tqdm

        batch, channels, height, width = img_shape
        if img is None:
            img = torch.full(tuple(img_shape), -1, dtype=torch.long, device=self.device)

        # Reset the debug buffers so repeated calls do not accumulate tensors.
        self.pred_list = []
        step_shape = (256,) if batch == 1 else (batch, 256)
        self.prob_list = torch.zeros((channels * height * width,) + step_shape)

        counter = 0
        for h in tqdm(range(height), leave=False):
            for w in range(width):
                for c in range(channels):
                    if (img[:,c,h,w] != -1).all().item():
                        continue

                    # The masks stop rows below h from influencing row h, so crop them away.
                    pred = self.forward(img[:,:,:h+1,:])
                    self.pred_list.append(pred)
                    probs = F.softmax(pred[:,:,c,h,w], dim=-1)
                    self.prob_list[counter] = probs.reshape(step_shape).cpu()
                    counter += 1
                    img[:,c,h,w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        return self.pred_list, self.prob_list, img

    def configure_optimizers(self):
        """Adam, with ReduceLROnPlateau driven by the logged ``val_bpd``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # ``verbose`` is deprecated in recent PyTorch releases; False was the default anyway.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        return [optimizer], [{'scheduler': scheduler, 'monitor':'val_bpd'}]

    def training_step(self, batch, batch_idx):
        # batch[0] holds the images; labels are ignored.
        loss = self.calc_likelihood(batch[0])
        self.log('train_bpd', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calc_likelihood(batch[0])
        self.log('val_bpd', loss)

    def test_step(self, batch, batch_idx):
        loss = self.calc_likelihood(batch[0])
        self.log('test_bpd', loss)

# Denoising Model

In [None]:
from tqdm import tqdm
import time
from torchvision.utils import save_image

class test_model_1():
    """Denoise an image by gradient descent on its pixels under a PixelCNN prior.

    The pretrained PixelCNN loaded from ``saved_model.pt`` is frozen; the only
    quantity optimised by ``training_loop`` is the image estimate itself, with

        total_loss = MSE(curr_est, x) / (sigma_w * sigma**2) + log_loss(curr_est)

    where ``log_loss`` comes from ``compute_likelihood``.
    """

    def __init__(self):
        torch.manual_seed(142)
        self.ARmodel = PixelCNN(c_in=1, c_hidden=64).to(device)
        # map_location lets a checkpoint written on GPU be loaded on a CPU-only runtime.
        self.ARmodel.load_state_dict(torch.load('saved_model.pt', map_location=device))
        # The prior is fixed. Freezing it stops backward() from computing gradients for
        # every PixelCNN weight on each step (the optimiser never zeroes those, so they
        # would also keep accumulating).
        self.ARmodel.requires_grad_(False)
        self.trainable_params = []
        # Leftover from a LightningModule version of this class; not used here.
        self.automatic_optimization = False

    def forward(self, curr_est, x):
        """Return ``(mse_loss, log_loss)`` for the current estimate.

        mse_loss is the mean squared difference between ``curr_est`` and the noisy
        observation ``x``; log_loss is the prior term from ``compute_likelihood``.
        """
        log_loss = self.compute_likelihood(curr_est)
        mse_loss = ((curr_est-x)**2).mean()
        return mse_loss, log_loss

    def compute_likelihood(self, x):
        """Return -log of the interpolated PixelCNN probability of image ``x``.

        ``x`` is one continuous-valued image on the 0-255 scale, shape (1, H, W).
        NOTE(review): this is -log(mean over pixels of q), not the per-pixel NLL
        -mean(log q); confirm which objective is intended.
        """
        # (1, H, W) -> (1, 1, H, W) batch; PixelCNN logits (1, 256, 1, H, W) -> (256, H, W).
        logits = self.ARmodel(x.unsqueeze(0))[0, :, 0]
        # softmax is the numerically stable form of exp(l) / sum(exp(l)); result (H, W, 256).
        probs = F.softmax(logits.permute((1, 2, 0)), dim=-1)
        return -1.0*torch.log(self.interpolate(probs, x))

    def training_loop(self, x, sigma, max_iterations=10000, lr=0.01, sigma_w=30.0):

        '''
        Optimise an image estimate with Adam to denoise ``x``.

        Args:
            x: noisy observation on the 0-255 scale, e.g. shape (1, 28, 28).
            sigma: standard deviation of the added noise; scales the data term.
            max_iterations: number of Adam steps.
            lr: Adam learning rate.
            sigma_w: extra factor dividing the data term.

        Returns:
            (curr_est, log_loss_list, mse_loss_list, total_loss_list)

        Choose whichever curr_est initialisation you wish to have and comment out the rest.
        '''
        curr_est = x.detach().clone().requires_grad_() # The noisy image as input
        #curr_est = torch.zeros_like(x).requires_grad_() # Blank image as input
        #curr_est = torch.randn(x.shape).requires_grad_() # Random noise as input
        log_loss_list = []
        total_loss_list = []
        mse_loss_list = []
        optimizer = torch.optim.Adam([curr_est], lr=lr)
        for i in range(max_iterations):

            if i%500 == 0 and i!=0:
                self.save_curr_est(i, curr_est)
            start = time.time()
            optimizer.zero_grad()
            mse_loss, log_loss = self.forward(curr_est, x)
            total_loss = mse_loss/(sigma_w*sigma*sigma) + log_loss      ### The main loss equation
            total_loss.backward()
            optimizer.step()
            end = time.time()
            print('[INFO] epoch ' + str(i) + ': MSEloss = ' + str(mse_loss.item()) + ' | Log loss = ' + str(log_loss.item()))
            print(end-start)
            log_loss_list.append(log_loss.item())
            mse_loss_list.append(mse_loss.item())
            total_loss_list.append(total_loss.item())
        return curr_est, log_loss_list, mse_loss_list, total_loss_list

    def print_trainable_params(self):
        """Print the names of the prior's parameters that still require gradients."""
        print("Trainable Parameters:")
        # This class is not an nn.Module; the parameters live on the wrapped PixelCNN.
        for name, param in self.ARmodel.named_parameters():
            if param.requires_grad:
                print(name)

    def configure_optimizers(self):
        # Leftover Lightning hook; the optimiser is created inside training_loop.
        pass

    def save_curr_est(self, epoch, curr_est, out_dir='/content'):
        """Save the current estimate as ``<out_dir>/epoch<epoch>.png``."""
        print('Saving image at epoch: ' + str(epoch))
        # save_image expects intensities in [0, 1]; curr_est is on the 0-255 scale.
        img = (curr_est.detach() / 255.0).clamp(0.0, 1.0)
        save_image(img, os.path.join(out_dir, 'epoch'+str(epoch)+'.png'))

    def interpolate(self, dist, x):
        """Mean probability of continuous pixel values under per-pixel categoricals.

        Each value of ``x`` is clamped into (0, n_levels - 1) and its probability is
        linearly interpolated between the two neighbouring integer levels, which keeps
        the result differentiable with respect to ``x``.

        Args:
            dist: (..., n_levels) probabilities, one distribution per element of ``x``.
            x: tensor with as many elements as ``dist`` has distributions.

        Returns:
            Scalar tensor: the interpolated probability averaged over all pixels.
        """
        n_levels = dist.shape[-1]
        x_ = torch.clamp(x.reshape(-1), 0.0001, n_levels-1-0.0001)
        dist_ = dist.reshape((-1, n_levels))
        # The integer neighbours are only used as indices; no gradient flows through them.
        ceil_q = torch.ceil(x_).detach().long()
        floor_q = torch.floor(x_).detach().long()
        # alpha = ceil - x lies in [0, 1): the closer x is to the lower level,
        # the more weight q_floor receives.
        alpha = ceil_q - x_

        rows = torch.arange(x_.shape[0], device=x_.device)
        q_floor = dist_[rows, floor_q]
        q_ceil = dist_[rows, ceil_q]
        q_y = alpha * q_floor + (1 - alpha) * q_ceil
        return q_y.mean()


In [None]:
denoising_model = test_model_1()
# Stack the held-out 7s into one float32 tensor of shape (N, 28, 28) on the 0-255 scale.
# (Despite the name, these images come from the MNIST *test* split.)
training_set = torch.stack([torch.squeeze(img) for img, _ in test_set]).float()
pl.seed_everything(142)

i = 1 #1, 10, 100 the images chosen for testing.
sigma = 50.0   ### Standard deviation of noise added

x = training_set[i] + torch.randn((28,28))*sigma
# torch.clamp keeps x a torch tensor explicitly (np.clip only worked via dispatch to Tensor.clip).
x = torch.clamp(x, 0.0, 255.0)     ### Make sure that the pixel values are within bounds

plt.imshow(x.squeeze(), cmap='gray')
plt.axis('off')
plt.show()
plt.imshow(training_set[i].squeeze(), cmap='gray')
plt.axis('off')
plt.show()

# The prior lives on `device`, so its inputs must be moved there as well.
denoising_model.compute_likelihood(training_set[i].unsqueeze(0).to(device)), denoising_model.compute_likelihood(x.unsqueeze(0).to(device))

In [None]:
# Fresh model (its constructor re-seeds torch), then a short optimisation run on the noisy image.
denoising_model = test_model_1()
x = x.to(device)
output = denoising_model.training_loop(x.unsqueeze(0), sigma, max_iterations=5, lr=0.1, sigma_w=30.0)

In [None]:
# Detach the optimised estimate, clip it to valid intensities, and compare it with the
# clean and noisy images on a common 0-255 scale.
out = output[0].cpu().detach().numpy()
out[0] = np.clip(out[0], 0, 255)
# x was moved to `device` for optimisation; imshow needs a CPU tensor.
panels = [('denoised estimate', out[0]), ('clean', training_set[i]), ('noisy input', x.cpu())]
for title, image in panels:
    plt.imshow(image, vmin=0, vmax=255, cmap='gray')
    plt.title(title)
    plt.show()

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# output = (curr_est, log_loss_list, mse_loss_list, total_loss_list)
for panel, series, title in zip(ax, output[1:4], ('log loss', 'MSE loss', 'total loss')):
    panel.plot(series)
    panel.set_title(title)
plt.show()

In [None]:
# Pixel-wise MSE between the clean image and the clipped estimate, both on the 0-255 scale.
# The numpy estimate is converted explicitly so the result is a torch tensor (used by torch.log10).
mse_clean = ((training_set[i] - torch.from_numpy(out[0])) ** 2).mean()

In [None]:
### PSNR calculation: 10 * log10(MAX^2 / MSE), with MAX = 255 for 8-bit intensities

psnr = 10 * torch.log10(255 ** 2 / mse_clean)
print(psnr.item())

In [None]:
### SSIM calculation

from skimage.metrics import structural_similarity as compare_ssim

A = training_set[i].cpu().numpy()
B = out[0]

# Both images are float arrays on the 0-255 scale. Without data_range, skimage either
# rejects float input or assumes a [-1, 1] range (version dependent), which gives a wrong SSIM.
(score, diff) = compare_ssim(A, B, data_range=255, full=True)
print(score)