In [4]:
import torch
import torch.nn as nn
import torch.nn.init as init

In [19]:
import cmath
import math

In [5]:
class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

In [20]:
def createFourierMatrix(k,n):
    i = cmath.sqrt(-1)
    val = cmath.exp(-2*cmath.pi*i/n)
    p = (k-1)/2
    q = (k+1)/2
    F = torch.zeros(n*n,k*k)
    F = F.type(torch.complex64)

    f = torch.zeros(n,1);
    f = f.type(torch.complex64)
    f_u = torch.zeros(n*n,1);
    f_u = f_u.type(torch.complex64)
    for u in range(n):
        index = torch.arange(u*n,(u+1)*n)
        f_u[u*n:(u+1)*n]=val**u
        f[u]=val**u;

    f_v = f.repeat(n,1);
    for u in range(k):
        for v in range(k):
            a=0
            b=0
            if(u<=p):
                a = n-p+u;
            else:
                a = u-p;


            if(v<=p):
                b = n-p+v;
            else:
                b = v-p;

            F[:,(u*k+v)]=((torch.pow(f_u,(a)))*(torch.pow(f_v,(b)))).flatten();

    return F

In [22]:
def zeroPad2DMatrix(layer_wt,n):
    k = layer_wt.size()[3]
    p = (k-1)/2
    q = (k+1)/2
    I = torch.eye(n);
    ind1 = torch.arange(0,p)
    ind2 = torch.arange(p,k)
    ind3 = torch.arange(k,n)
    indices = torch.cat((ind2,ind3,ind1))
    indices=indices.type(torch.int64)
    perm = I[indices];
    perm_mat = perm.unsqueeze(0).unsqueeze(0)
    pad_left = 0
    pad_right = n - k
    pad_top = 0
    pad_bottom = n - k
    # Apply padding
    padded_wt = torch.nn.functional.pad(layer_wt, (pad_left, pad_right, pad_top, pad_bottom))
    perm_mat_tr = torch.transpose(perm_mat,2,3)
    padded_final = torch.matmul(perm_mat,torch.matmul(padded_wt,perm_mat_tr))
    return padded_final

In [23]:
def computeLayerLipschitzFourier(layer_wt,n):
    layer_wt_padded = zeroPad2DMatrix(layer_wt,n)
    layer_pf=torch.fft.fft2(layer_wt_padded)
    layer_fperm = torch.permute(layer_pf,(2,3,0,1))
    sing = torch.linalg.svdvals(layer_fperm)
    lip = torch.max(torch.abs(sing))
    return lip

In [6]:
path = "/home/kunallab/sayan/DnCNN/DnCNN/TrainingCodes/dncnn_pytorch/models/DnCNN_sigma25/model.pth"

In [7]:
model = torch.load(path)



In [8]:
net = DnCNN()

init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight


In [9]:
print(net)

DnCNN(
  (dncnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): BatchNorm2d(64, eps=0.0001, momentum=0.95, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (6): BatchNorm2d(64, eps=0.0001, momentum=0.95, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (9): BatchNorm2d(64, eps=0.0001, momentum=0.95, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(64, eps=0.0001, momentum=0.95, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): Conv2d(64, 64, kernel_size=(3,

In [27]:
lip =1 
for layer in net.dncnn:
    if isinstance(layer,nn.Conv2d):
        #print(layer.in_channels)
        lip = lip*computeLayerLipschitzFourier(layer.weight,40)

print(lip)

tensor(51170.0820, grad_fn=<MulBackward0>)


In [11]:
net.dncnn[0]==nn.conv2D()

AttributeError: module 'torch.nn' has no attribute 'conv2D'