In [33]:
import numpy as np
import torch.nn as nn
import non_local

In [34]:
# initializing DnCNN batch normalization 
def init_BN(m, ksize=3, b_min=0.025, kernelsize=None):
    """Initialise a batch-norm layer in place (DnCNN-style initialisation).

    The scale (``weight``) is drawn from a zero-mean normal distribution with
    standard deviation ``sqrt(2 / (ksize**2 * m.num_features))``, made
    non-negative, and every non-zero value whose magnitude is at most
    ``b_min`` is raised to ``b_min``.  The shift (``bias``) is zeroed and the
    running-statistics momentum is set to ``0.001``.

    Args:
        m: Module exposing ``num_features``, ``weight``, ``bias`` and
            ``momentum`` (e.g. ``nn.BatchNorm2d``).
        ksize: Kernel size used to compute the standard deviation.
        b_min: Lower bound applied to the magnitude of the scale values.
        kernelsize: Optional alias for ``ksize``; when given it takes
            precedence.  Accepted because callers in this file pass the
            kernel size under this keyword.
    """
    if kernelsize is not None:
        ksize = kernelsize

    n = ksize ** 2 * m.num_features
    weight = m.weight.data
    weight.normal_(0, np.sqrt(2. / n))
    # Stay in torch: an in-place abs works on any device, whereas np.abs would
    # route the tensor through NumPy (impossible for CUDA tensors).
    weight.abs_()
    # Taking the absolute value first and then lifting small positive values
    # gives the same result as clamping to +/-b_min and then taking abs.
    weight[(weight > 0) & (weight <= b_min)] = b_min
    m.bias.data.zero_()
    m.momentum = 0.001

## CNN implementation
def cnn(cnn_opt):
    """Assemble a plain convolutional stack described by an options dict.

    Keys read from ``cnn_opt``:
        kernel: square kernel size (default 3); padding is ``(kernel-1)//2``.
        bn: whether hidden layers use ``nn.BatchNorm2d`` (default True).
            Hidden convolutions carry a bias only when this is falsy.
        depth: number of convolutions (default 0, i.e. an empty network).
        features: channel count of the hidden layers.
        nplanes_out: channel count of the last convolution.
        nplanes_in: channel count of the input.

    Returns:
        An ``nn.Sequential`` made of ``depth - 1`` groups of
        (conv, batch-norm or empty ``nn.Sequential``, ReLU) followed by one
        biased convolution.  The result carries ``nplanes_in`` and
        ``nplanes_out`` attributes; with ``depth == 0`` the network is empty
        and ``nplanes_out`` equals ``nplanes_in``.
    """
    kernel = cnn_opt.get("kernel", 3)
    use_bn = cnn_opt.get("bn", True)
    depth = cnn_opt.get("depth", 0)
    width = cnn_opt.get("features")
    planes_in = cnn_opt.get("nplanes_in")
    # An empty network passes its input through, so it keeps the channel count.
    planes_out = planes_in if depth == 0 else cnn_opt.get("nplanes_out")

    pad = (kernel - 1) // 2
    # A single activation module is reused by every hidden group.
    activation = nn.ReLU(inplace=True)

    modules = []
    current = planes_in
    for _ in range(depth - 1):
        modules.append(nn.Conv2d(current, width, kernel, 1, pad, bias=not use_bn))
        modules.append(nn.BatchNorm2d(width) if use_bn else nn.Sequential())
        modules.append(activation)
        current = width

    if depth > 0:
        modules.append(nn.Conv2d(current, planes_out, kernel, 1, pad, bias=True))

    net = nn.Sequential(*modules)
    net.nplanes_out = planes_out
    net.nplanes_in = planes_in
    return net

In [35]:
# This is the N3Block, which contains the embedding network and the KNN selection rule:

#This takes a 2D image as an input
class N3Block(nn.Module):
    """Non-local block: an embedding CNN followed by N3 neighbour aggregation.

    The input is embedded by ``embedcnn``; the embedding (used both as query
    and as database) together with the untouched input is handed to
    ``non_local.N3Aggregation2D``.  An optional second CNN with a single
    output channel produces ``log_temp`` for the aggregation.

    Attributes:
        nplanes_in: number of input channels.
        nplanes_out: ``(k + 1) * nplanes_in``.
        embedcnn: embedding network built by ``cnn``.
        tempcnn: temperature network built by ``cnn``, or ``None``.
        n3aggregation: the ``non_local.N3Aggregation2D`` module.
    """

    def __init__(self, nplanes_in, k, patchsize=10, stride=5, nl_match_window=15,
                 temp_opt=None, embedcnn_opt=None):
        """Build the embedding / temperature networks and the aggregation.

        Args:
            nplanes_in: number of input channels.
            k: neighbour count forwarded to ``N3Aggregation2D``.
            patchsize: patch size forwarded to ``N3Aggregation2D``.
            stride: patch stride forwarded to ``N3Aggregation2D``.
            nl_match_window: window argument forwarded to
                ``non_local.index_neighbours``.
            temp_opt: options dict; its ``"external_temp"`` entry enables the
                temperature CNN.  Forwarded as-is to ``N3Aggregation2D``.
                ``None`` means an empty dict.
            embedcnn_opt: options dict for ``cnn``.  It is copied, so the
                caller's dict is not modified.  ``None`` means an empty dict.
        """
        super().__init__()

        # None defaults avoid one dict instance being shared by every call.
        temp_opt = {} if temp_opt is None else temp_opt
        # Work on a copy: "nplanes_in" is written into the options below.
        embedcnn_opt = {} if embedcnn_opt is None else dict(embedcnn_opt)

        self.patchsize = patchsize
        self.stride = stride
        self.nplanes_in = nplanes_in
        self.nplanes_out = (k + 1) * nplanes_in
        self.k = k

        # Embedding network.
        embedcnn_opt["nplanes_in"] = nplanes_in
        self.embedcnn = cnn(embedcnn_opt)

        # Optional temperature network: same layout, one output channel.
        if temp_opt.get("external_temp"):
            tempcnn_opt = dict(embedcnn_opt)
            tempcnn_opt["nplanes_out"] = 1
            self.tempcnn = cnn(tempcnn_opt)
        else:
            self.tempcnn = None

        # Relaxed continuous KNN - non-local processing.
        indexer = lambda xe_patch, ye_patch: non_local.index_neighbours(
            xe_patch, ye_patch, nl_match_window, exclude_self=True)
        self.n3aggregation = non_local.N3Aggregation2D(
            indexing=indexer, k=k, patchsize=patchsize, stride=stride, temp_opt=temp_opt)

        # Must run last: the sub-networks have to be registered before their
        # batch-norm layers can be found and initialised.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every ``nn.BatchNorm2d`` found in this module."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Kernel size is passed positionally so the call does not
                # depend on the keyword name used by init_BN.
                init_BN(m, 3, b_min=0.025)

    def forward(self, x):
        """Embed ``x`` and aggregate its neighbours.

        Args:
            x: input tensor; fed to the 2-D convolutions of ``embedcnn``.

        Returns:
            The output of ``self.n3aggregation`` (presumably with
            ``nplanes_out`` channels - confirm against ``non_local``).
        """
        # The same embedding serves as query and as database.
        x_embedded = self.embedcnn(x)
        y_embedded = x_embedded

        log_temp = self.tempcnn(x) if self.tempcnn is not None else None

        # The raw input is what gets aggregated; embeddings only drive matching.
        return self.n3aggregation(x, x_embedded, y_embedded, log_temp=log_temp)

In [36]:
#Local Network - DnCNN
class DnCNN(nn.Module):
    """Local processing network: a constant-width stack of convolutions.

    Layout: ``conv1 -> ReLU -> (conv, [batch-norm], ReLU) * (depth - 2)
    -> conv2``, with an optional additive shortcut from input to output on
    the leading channels.

    Attributes:
        nplanes_in: number of input channels.
        nplanes_out: number of output channels.
        residual: whether ``forward`` adds the input to the output.
        nplanes_residual: optional cap on the number of channels that receive
            the shortcut; ``None`` (the value set here) means ``nplanes_in``.
        kernelsize: kernel size, reused for batch-norm initialisation.
    """

    def __init__(self, nplanes_in, nplanes_out, features, kernel, depth, residual, bn):
        """Create the layers and initialise their parameters.

        Args:
            nplanes_in: number of input channels.
            nplanes_out: number of output channels.
            features: channel count of all hidden layers.
            kernel: square kernel size; padding is ``kernel // 2``.
            depth: total number of convolutions (``depth - 2`` hidden ones).
            residual: add the input to the output in ``forward``.
            bn: use ``nn.BatchNorm2d`` in the hidden layers.
        """
        super().__init__()

        self.residual = residual
        self.nplanes_out = nplanes_out
        self.nplanes_in = nplanes_in
        self.kernelsize = kernel
        self.nplanes_residual = None

        self.conv1 = nn.Conv2d(nplanes_in, features, kernel_size=kernel, stride=1,
                               padding=kernel // 2, bias=True)
        # NOTE: bn1 is registered (and initialised) but forward() does not
        # apply it; it is kept so the module's parameter set is unchanged.
        self.bn1 = nn.BatchNorm2d(features) if bn else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

        layers = []
        for _ in range(depth - 2):
            layers += [nn.Conv2d(features, features, kernel_size=kernel, stride=1,
                                 padding=kernel // 2, bias=False),
                       nn.BatchNorm2d(features) if bn else nn.Sequential(),
                       self.relu]
        self.layer1 = nn.Sequential(*layers)

        self.conv2 = nn.Conv2d(features, nplanes_out, kernel_size=kernel, stride=1,
                               padding=kernel // 2, bias=True)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise conv weights (normal, std ``sqrt(2 / (kh*kw*out))``) and batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                # Kernel size is passed positionally so the call does not
                # depend on the keyword name used by init_BN.
                init_BN(m, self.kernelsize, b_min=0.025)

    def forward(self, x):
        """Apply conv1, ReLU, the hidden layers and conv2, then the shortcut.

        Args:
            x: input tensor with ``nplanes_in`` channels on dimension 1.

        Returns:
            Tensor with ``nplanes_out`` channels.  When ``residual`` is set,
            the first ``min(nplanes_in, nplanes_out, nplanes_residual)``
            channels additionally contain the corresponding input channels.
        """
        shortcut = x
        x = self.conv2(self.layer1(self.relu(self.conv1(x))))
        nplanes_residual = self.nplanes_residual or self.nplanes_in
        if self.residual:
            # Only the channels present on both sides can be added.
            nshortcut = min(self.nplanes_in, self.nplanes_out, nplanes_residual)
            x[:, :nshortcut, :, :] = x[:, :nshortcut, :, :] + shortcut[:, :nshortcut, :, :]
        return x

In [37]:
#Final architecture, which combines Local(DnCNN) and Non-local(N3Block) processing to denoise image
class N3Net(nn.Module):
    """Full network alternating local (DnCNN) and non-local (N3Block) stages.

    ``nblocks`` DnCNN stages are created; each one except the last is
    followed by an N3Block.  The stages are stored in ``self.blocks`` (DnCNNs)
    and ``self.nls`` (N3Blocks).
    """

    def __init__(self, nplanes_in, nplanes_out, nplanes_interm, nblocks, block_opt, nl_opt, residual=False):
        """Create the stages.

        Args:
            nplanes_in: channel count of the network input.
            nplanes_out: channel count of the network output.
            nplanes_interm: output channel count of every DnCNN but the last.
            nblocks: number of DnCNN stages.
            block_opt: keyword arguments forwarded to every ``DnCNN``.
            nl_opt: keyword arguments forwarded to every ``N3Block``.
            residual: add the network input to the output in ``forward``.
        """
        super().__init__()
        self.residual = residual
        self.nblocks = nblocks
        self.nplanes_out = nplanes_out
        self.nplanes_in = nplanes_in

        local_stages = []
        nonlocal_stages = []
        width = nplanes_in
        for _ in range(nblocks - 1):
            local_stages.append(DnCNN(width, nplanes_interm, **block_opt))
            nonlocal_stage = N3Block(nplanes_interm, **nl_opt)
            nonlocal_stages.append(nonlocal_stage)
            # The next DnCNN consumes whatever the non-local stage emits.
            width = nonlocal_stage.nplanes_out

        # Final local stage maps to the requested output channels.
        local_stages.append(DnCNN(width, nplanes_out, **block_opt))

        self.nls = nn.Sequential(*nonlocal_stages)
        self.blocks = nn.Sequential(*local_stages)

    def forward(self, x):
        """Run all stages in order and optionally add the input back.

        Args:
            x: input tensor with ``nplanes_in`` channels on dimension 1.

        Returns:
            Output of the last DnCNN; when ``residual`` is set, its first
            ``min(nplanes_in, nplanes_out)`` channels also contain the
            corresponding input channels.
        """
        skip = x
        # zip stops at the shorter container, so the last DnCNN is left over.
        for local_stage, nonlocal_stage in zip(self.blocks, self.nls):
            x = nonlocal_stage(local_stage(x))

        x = self.blocks[-1](x)

        if not self.residual:
            return x

        shared = min(self.nplanes_in, self.nplanes_out)
        x[:, :shared, :, :] = x[:, :shared, :, :] + skip[:, :shared, :, :]
        return x