In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc

# from google.colab import drive
# from tqdm import test

# drive.mount('/content/drive')

In [None]:
class DownSampleBlock(nn.Module):
    """Encoder stage: two 3x3 conv + LeakyReLU layers followed by 2x2 max pooling.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Both convolutions use the same geometry; padding=1 keeps H and W.
        conv_geometry = dict(kernel_size=3, padding=1)
        stages = [
            nn.Conv2d(in_channels, out_channels, **conv_geometry),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, **conv_geometry),
            nn.LeakyReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(features, down_sampled)``.

        ``features`` is the full-resolution output of the double convolution
        and ``down_sampled`` is the same tensor after 2x2 max pooling.
        """
        skip = self.double_conv(x)
        return skip, self.maxpool(skip)

class UpSampleBlock(nn.Module):
    """Decoder stage: 2x up-sampling, double conv, optional skip concat, fusing conv.

    Args:
        in_channels: channels of the tensor being up-sampled.
        out_channels: channels produced by the double convolution and by the
            final fusing convolution.
        add_channels: channels of the skip tensor concatenated in ``forward``.
            Must be 0 when ``forward`` is called without a skip tensor, because
            the fusing convolution expects ``add_channels + out_channels``
            input channels.
        bilinear: if True (default, the previous fixed behaviour) up-sample
            with parameter-free bilinear interpolation; otherwise use a learned
            2x2 stride-2 transposed convolution that keeps ``in_channels``.
    """

    def __init__(self, in_channels, out_channels, add_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True)
        )
        self.feat = nn.Sequential(
            nn.Conv2d(add_channels + out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x1, x2=None):
        """Up-sample ``x1``, optionally fuse it with the skip tensor ``x2``.

        Args:
            x1: tensor to up-sample, with ``in_channels`` channels.
            x2: optional skip tensor with ``add_channels`` channels. ``x1`` is
                padded to the spatial size of ``x2`` before concatenation.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        x1 = self.double_conv(self.up(x1))

        if x2 is None:
            merged = x1
        else:
            # Plain Python ints: F.pad takes a sequence of ints, so there is no
            # reason to wrap the size differences in tensors.
            diff_y = x2.size(2) - x1.size(2)
            diff_x = x2.size(3) - x1.size(3)

            # Pad order is (left, right, top, bottom); the extra pixel of an odd
            # difference goes to the right / bottom.
            x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                            diff_y // 2, diff_y - diff_y // 2])

            merged = torch.cat([x2, x1], dim=1)

        return self.feat(merged)

class PooledSkip(nn.Module):
    """Skip connection that replaces a feature map by its tiled per-channel mean.

    Args:
        output_spatial_size: side length of the square map returned by ``forward``.
    """

    def __init__(self, output_spatial_size):
        super().__init__()
        self.output_spatial_size = output_spatial_size

    def forward(self, x):
        """Average ``x`` over its two spatial axes and tile the result.

        The output has the batch and channel sizes of ``x`` and a spatial size
        of ``output_spatial_size x output_spatial_size``.
        """
        side = self.output_spatial_size
        # Global average pooling; keepdim leaves a 1x1 spatial map to tile.
        channel_means = x.mean(dim=(2, 3), keepdim=True)
        return channel_means.repeat(1, 1, side, side)


In [None]:
def epoch_train(loader, clf, criterion, opt):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        loader: sized iterable yielding ``(images, labels)`` batches.
        clf: model to train; it is switched to training mode.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        opt: optimizer holding the parameters of ``clf``.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the mean of the per-batch losses as a
        Python float. ``accuracy`` is the fraction of samples whose arg-max over
        the last output axis equals the label, returned as a 0-dim tensor.

    Raises:
        ZeroDivisionError: if ``loader`` has no batches.
    """
    clf.train(True)

    # Send batches to wherever the model lives instead of hard-coding CUDA, so
    # the loop also works on CPU-only machines. A parameter-free model keeps
    # the previous behaviour (default CUDA device).
    first_param = next(clf.parameters(), None)
    target = first_param.device if first_param is not None else torch.device('cuda')

    loss = 0.0
    accuracy = 0.0
    n_observations = 0
    for images, labels in loader:
        inputs, labels = images.to(target), labels.to(target)
        opt.zero_grad()

        outputs = clf(inputs)
        loss_i = criterion(outputs, labels)
        loss_i.backward()
        opt.step()

        loss += loss_i.item()
        accuracy += torch.sum(labels == outputs.argmax(dim=-1))
        n_observations += labels.shape[0]

    loss /= len(loader)
    accuracy /= n_observations

    return loss, accuracy

def epoch_test(loader, clf, criterion):
    """Evaluate ``clf`` on ``loader`` without updating it.

    Args:
        loader: sized iterable yielding ``(images, labels)`` batches.
        clf: model to evaluate; it is switched to evaluation mode.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the mean of the per-batch losses as a
        Python float. ``accuracy`` is the fraction of samples whose arg-max over
        the last output axis equals the label, returned as a 0-dim tensor.

    Raises:
        ZeroDivisionError: if ``loader`` has no batches.
    """
    clf.eval()

    # Send batches to wherever the model lives instead of hard-coding CUDA.
    # A parameter-free model keeps the previous behaviour (default CUDA device).
    first_param = next(clf.parameters(), None)
    target = first_param.device if first_param is not None else torch.device('cuda')

    loss = 0.0
    accuracy = 0.0
    n_observations = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for images, labels in loader:
            inputs, labels = images.to(target), labels.to(target)

            outputs = clf(inputs)
            loss_i = criterion(outputs, labels)

            loss += loss_i.item()
            accuracy += torch.sum(labels == outputs.argmax(dim=-1))
            n_observations += labels.shape[0]

    loss /= len(loader)
    accuracy /= n_observations

    return loss, accuracy

def train(train_loader, test_loader, clf, criterion, opt, n_epochs=50):
    """Train ``clf`` for ``n_epochs`` epochs, printing metrics after each one.

    Every epoch runs ``epoch_train`` on ``train_loader`` followed by
    ``epoch_test`` on ``test_loader`` and prints both losses and accuracies.

    Args:
        train_loader: batches used for optimisation.
        test_loader: batches used for evaluation.
        clf: model being trained.
        criterion: loss callable passed to both epoch functions.
        opt: optimizer passed to ``epoch_train``.
        n_epochs: number of epochs to run.
    """
    # Resolve the progress bar locally so the function does not rely on a
    # notebook-level `tqdm` name; without tqdm installed, iterate silently.
    try:
        from tqdm.auto import tqdm
    except ImportError:
        epochs = range(n_epochs)
    else:
        epochs = tqdm(range(n_epochs))

    for epoch in epochs:
        train_loss, train_acc = epoch_train(train_loader, clf, criterion, opt)
        test_loss, test_acc = epoch_test(test_loader, clf, criterion)

        print(f'[Epoch {epoch + 1}] train loss: {train_loss:.3f}; train acc: {train_acc:.2f}; ' +
              f'test loss: {test_loss:.3f}; test acc: {test_acc:.2f}')

In [None]:
class DeblurNN(nn.Module):
    """Shared encoder with two decoder heads: blur kernels and per-pixel masks.

    The encoder (``down1`` .. ``down5`` plus ``feat``) is shared by both heads.

    * Mask head (``up1`` .. ``up5``, ``masks_end``): up-samples the bottleneck
      with full-resolution skip connections and ends in a softmax over the
      ``K`` channels.
    * Kernel head (``kernel_up1`` .. ``kernel_up5`` [``kernel_up6``],
      ``kernels_end``): starts from the spatially averaged bottleneck, uses
      globally pooled skips (``PooledSkip``) and yields ``K`` kernels of
      ``blur_kernel_size x blur_kernel_size``.

    Args:
        K: number of kernels and of mask channels.
        blur_kernel_size: side of the predicted kernels. The kernel head grows
            1 -> 32 (or 64) by repeated 2x up-sampling and the first layer of
            ``kernels_end`` adds one pixel, so only 33 and 65 are possible.
        bilinear: kept for signature compatibility. It is not forwarded, so
            every up-sampling block uses ``UpSampleBlock``'s default
            interpolation.
        no_softmax: if True, the kernels go through ``leaky_relu`` instead of
            being softmax-normalised over their spatial entries.

    Raises:
        ValueError: if ``blur_kernel_size`` is neither 33 nor 65.
    """

    def __init__(self, K=9, blur_kernel_size=33, bilinear=False,
                 no_softmax=False):
        super(DeblurNN, self).__init__()

        # Fail at construction time; previously any other size only surfaced
        # as a reshape error inside forward().
        if blur_kernel_size not in (33, 65):
            raise ValueError(
                f'blur_kernel_size must be 33 or 65, got {blur_kernel_size}')

        self.no_softmax = no_softmax
        if no_softmax:
            print('Softmax is not being used')

        self.inc_rgb = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

        self.inc_gray = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.blur_kernel_size = blur_kernel_size
        self.K = K

        # Encoder: skip features carry 64, 128, 256, 512 and 1024 channels.
        self.down1 = DownSampleBlock(64, 64)
        self.down2 = DownSampleBlock(64, 128)
        self.down3 = DownSampleBlock(128, 256)
        self.down4 = DownSampleBlock(256, 512)
        self.down5 = DownSampleBlock(512, 1024)
        self.feat = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
        )

        # Mask head. UpSampleBlock takes (in, out, add); `add_channels` must
        # equal the channel count of the encoder skip fused at that level and
        # `out_channels` must equal the next block's `in_channels`.
        self.up1 = UpSampleBlock(in_channels=1024, out_channels=512, add_channels=1024)
        self.up2 = UpSampleBlock(in_channels=512, out_channels=256, add_channels=512)
        self.up3 = UpSampleBlock(in_channels=256, out_channels=128, add_channels=256)
        self.up4 = UpSampleBlock(in_channels=128, out_channels=64, add_channels=128)
        self.up5 = UpSampleBlock(in_channels=64, out_channels=64, add_channels=64)

        self.masks_end = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, K, kernel_size=3, padding=1),
            nn.Softmax(dim=1),
        )

        # Pooled skips for the kernel head: sizes follow its 2, 4, ..., 32 growth.
        self.feat5_gap = PooledSkip(2)
        self.feat4_gap = PooledSkip(4)
        self.feat3_gap = PooledSkip(8)
        self.feat2_gap = PooledSkip(16)
        self.feat1_gap = PooledSkip(32)

        # Kernel head, same (in, out, add) convention as the mask head.
        self.kernel_up1 = UpSampleBlock(in_channels=1024, out_channels=512, add_channels=1024)
        self.kernel_up2 = UpSampleBlock(in_channels=512, out_channels=256, add_channels=512)
        self.kernel_up3 = UpSampleBlock(in_channels=256, out_channels=256, add_channels=256)
        self.kernel_up4 = UpSampleBlock(in_channels=256, out_channels=128, add_channels=128)
        self.kernel_up5 = UpSampleBlock(in_channels=128, out_channels=64, add_channels=64)
        if self.blur_kernel_size == 65:
            # Extra 32 -> 64 stage without a skip tensor, hence add_channels=0.
            self.kernel_up6 = UpSampleBlock(in_channels=64, out_channels=64, add_channels=0)

        self.kernels_end = nn.Sequential(
            # kernel_size=2 with padding=1 grows the map by one pixel (32 -> 33, 64 -> 65).
            nn.Conv2d(64, 64, kernel_size=2, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, K, kernel_size=3, padding=1)
        )
        self.kernel_softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Predict kernels and masks for a batch of images.

        Args:
            x: image batch with 3 channels (routed through ``inc_rgb``) or
                1 channel (routed through ``inc_gray``).

        Returns:
            ``(k, masks)`` where ``k`` has shape
            ``(N, K, blur_kernel_size, blur_kernel_size)`` and ``masks`` has
            ``K`` channels at the spatial size of ``x``, softmax-normalised
            over the channel axis.
        """
        # Encoder
        if x.shape[1] == 3:
            x1 = self.inc_rgb(x)
        else:
            x1 = self.inc_gray(x)
        x1_feat, x2 = self.down1(x1)
        x2_feat, x3 = self.down2(x2)
        x3_feat, x4 = self.down3(x3)
        x4_feat, x5 = self.down4(x4)
        x5_feat, x6 = self.down5(x5)
        x6_feat = self.feat(x6)

        # Kernel head: global context only, so its output size does not depend
        # on the spatial size of the input.
        feat6_gap = x6_feat.mean((2, 3), keepdim=True)
        feat5_gap = self.feat5_gap(x5_feat)
        feat4_gap = self.feat4_gap(x4_feat)
        feat3_gap = self.feat3_gap(x3_feat)
        feat2_gap = self.feat2_gap(x2_feat)
        feat1_gap = self.feat1_gap(x1_feat)

        k1 = self.kernel_up1(feat6_gap, feat5_gap)
        k2 = self.kernel_up2(k1, feat4_gap)
        k3 = self.kernel_up3(k2, feat3_gap)
        k4 = self.kernel_up4(k3, feat2_gap)
        k5 = self.kernel_up5(k4, feat1_gap)

        if self.blur_kernel_size == 65:
            k6 = self.kernel_up6(k5)
            k = self.kernels_end(k6)
        else:
            k = self.kernels_end(k5)

        # Only the batch size is needed. Unpacking the shape into a local named
        # `F` would shadow torch.nn.functional and break the leaky_relu call.
        batch = k.shape[0]
        k = k.view(batch, self.K, self.blur_kernel_size * self.blur_kernel_size)

        if self.no_softmax:
            k = F.leaky_relu(k)
        else:
            # Normalise each kernel over its flattened spatial entries.
            k = self.kernel_softmax(k)

        k = k.view(batch, self.K, self.blur_kernel_size, self.blur_kernel_size)

        # Mask head (decoder with full-resolution skips)
        x7 = self.up1(x6_feat, x5_feat)
        x8 = self.up2(x7, x4_feat)
        x9 = self.up3(x8, x3_feat)
        x10 = self.up4(x9, x2_feat)
        x11 = self.up5(x10, x1_feat)
        masks = self.masks_end(x11)

        return k, masks

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def loss(kernel, coeff, k_reference, img_input, img_ref, norma):
    """Compute the summed per-pixel reblur loss and kernel loss.

    For every pixel, the basis kernels are mixed with that pixel's
    coefficients, the mixed kernel is applied to the K x K neighbourhood of the
    zero-padded sharp image, and the reblurred value is compared with the
    observed blurry pixel. The mixed kernel is also compared with
    ``k_reference`` through a matrix norm.

    The result buffers are allocated as ``(1, 3, H, W)``, so a single
    3-channel image is assumed.

    Args:
        kernel: basis kernels, indexed as ``(1, B, K, K)`` with odd ``K``. All
            computation happens on this tensor's device.
        coeff: per-pixel mixing coefficients, indexed as ``(1, B, H, W)``.
        k_reference: reference kernel; must broadcast against ``(1, 1, K, K)``.
        img_input: observed blurry image, indexed as ``(1, 3, H, W)``.
        img_ref: sharp image, indexed as ``(1, 3, H, W)``.
        norma: ``ord`` argument forwarded to ``torch.linalg.matrix_norm``.

    Returns:
        ``(reblur_loss, kernel_loss)`` as 0-dim tensors. The per-pixel kernel
        norm is stored in all three channel slots before summation, as before.

    Raises:
        ValueError: if the kernel side ``K`` is even.
    """
    dev = kernel.device

    # Constants
    gamma = 2.2
    a = 50
    weight = 1  # 1 pixel per kernel

    # Kernel geometry comes from the kernels themselves instead of a fixed 33.
    ksize = kernel.shape[-1]
    if ksize % 2 == 0:
        raise ValueError(f'kernel side must be odd, got {ksize}')
    half = ksize // 2

    n_i = (0.1 ** 0.5) * torch.rand(1).to(dev)  # Noise

    coeff = coeff.to(dev)
    k_reference = k_reference.to(dev)
    v_i_gt = img_input.to(dev)  # Initially blurred image
    height, width = img_ref.shape[2], img_ref.shape[3]

    # Sharp image, zero-padded by `half` so every pixel has a full K x K patch.
    u_i = torch.zeros((img_ref.shape[0], img_ref.shape[1],
                       height + 2 * half, width + 2 * half)).to(dev)
    u_i[:, :, half:half + height, half:half + width] = img_ref.to(dev)

    # Losses
    Loss_reblur = torch.zeros((1, 3, height, width)).to(dev)
    Loss_kernel = torch.zeros((1, 3, height, width)).to(dev)

    # The saturation term depends only on the observed image: compute it once.
    # NOTE(review): R is derived from the observed image and then multiplied
    # with the gamma-corrected reblur; confirm this is the intended sensor
    # model (kept exactly as it was).
    R_all = v_i_gt - 1 / a * torch.log(1 + torch.exp(a * (v_i_gt - 1)))

    # Visit every pixel of the image. Bounding the loop by the unpadded size
    # while indexing the padded image left the last K-1 rows and columns
    # untouched.
    for row in range(height):
        for col in range(width):
            # Mixed kernel for this pixel: (1, K, K)
            Ker = torch.sum(kernel * coeff[:, :, row, col].unsqueeze(-1).unsqueeze(-1), dim=1)

            # Patch of the padded image centred on (row, col).
            u_nn = u_i[:, :, row:row + ksize, col:col + ksize]

            # Depthwise weights (3, 1, K, K) with groups=3 blur each colour
            # channel on its own; a dense (3, 3, K, K) weight would sum the
            # three channels into every output channel.
            depthwise = Ker.unsqueeze(1).repeat(3, 1, 1, 1)
            blurred = F.conv2d(u_nn, depthwise, groups=3)[0, :, 0, 0]

            v_pixel = R_all[:, :, row, col] * (blurred + n_i) ** (1 / gamma)
            Loss_reblur[:, :, row, col] = weight * (v_pixel - v_i_gt[:, :, row, col]) ** 2

            # Kernel loss: matrix norm of the difference to the reference kernel.
            Loss_kernel[:, :, row, col] = weight * torch.linalg.matrix_norm(
                Ker.unsqueeze(0) - k_reference, ord=norma)

    return torch.sum(Loss_reblur), torch.sum(Loss_kernel)


In [2]:
import torch

# Scratch check: picking one spatial location of a 4-D tensor keeps the two
# leading axes, so the printed value has shape (1, 1).
a = torch.rand(1, 1, 2, 3)
print(a[..., 1, 1])

tensor([[0.1442]])
