# Training linear nets

In [7]:
from __future__ import print_function
import matplotlib.pyplot as plt
#%matplotlib notebook

import os

import warnings
warnings.filterwarnings('ignore')

from include import *
from PIL import Image
import PIL

import numpy as np
import torch
import torch.optim
from torch.autograd import Variable

GPU = True
if GPU == True:
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    dtype = torch.cuda.FloatTensor
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print("num GPUs",torch.cuda.device_count())
else:
    dtype = torch.FloatTensor

num GPUs 0


In [1]:
def rep_error_deep_decoder(img_np,net,net_input,convert2ycbcr=False):
    '''
    mse obtained by representing img_np with the deep decoder
    '''
    output_depth = img_np.shape[0]
    if output_depth == 3 and convert2ycbcr:
        img = rgb2ycbcr(img_np)
    else:
        img = img_np
    img_var = np_to_var(img).type(dtype)
    
    rnd = 500
    numit = 2500
    rn = 0.005
    mse_n, mse_t, ni, net = fit( num_channels=num_channels,
                        reg_noise_std=rn,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = rnd,
                        num_iter=numit,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                               )
    out_img = net(ni.type(dtype)).data.cpu().numpy()[0]
    if output_depth == 3 and convert2ycbcr:
        out_img = ycbcr2rgb(out_img)
    return psnr(out_img,img_np), out_img, num_param(net)


def myimgshow(plt,img):
    if(img.shape[0] == 1):
        plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='none')
    else:
        plt.imshow(np.clip(img.transpose(1, 2, 0),0,1),interpolation='none')    
        
def comparison(img_np,net,net_input,convert2ycbcr=False):
    # compute representations
    psnrv, out_img_np, nparms = rep_error_deep_decoder(img_np,net=net,net_input=net_input,convert2ycbcr=convert2ycbcr)
    nchannels = img_np.shape[0]
    
    print("Compression factor: ", np.prod( img_np.shape ) / nparms )
    # plot results
    fig = plt.figure(figsize = (15,15)) # create a 5 x 5 figure 
    
    ax1 = fig.add_subplot(131)
    myimgshow(ax1,img_np) 
    ax1.set_title('Original image')
    ax1.axis('off')
    
    ax2 = fig.add_subplot(132)
    myimgshow(ax2,out_img_np)
    ax2.set_title( "Deep-Decoder representation, PSNR: %.2f" % psnrv )
    ax2.axis('off')
    #save_np_img(img_np,"exp_comp_orig.png")
    #save_np_img(out_img_np,"exp_comp_dd.png")
    
    plt.axis('off')
    fig.show()
    
def plot_kernels(tensor):
    if not len(tensor.shape)==4:
        raise Exception("assumes a 4D tensor")
    num_kernels = tensor.shape[0]
    fig = plt.figure(figsize=(tensor.shape[0],tensor.shape[1]))
    for i in range(tensor.shape[0]):
        for j in range(tensor.shape[1]):
            ax1 = fig.add_subplot(tensor.shape[0],tensor.shape[1],1+i*tensor.shape[0]+j)
            ax1.imshow(tensor[i][j])
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
    
def apply_until(net_input,net,n = 100):
    # applies function by funtion of a network
    for i,fun in enumerate(net):
        if i>=n:
            break
        if i==0:
            out = fun(net_input.type(dtype))
        else:
            out = fun(out)
    print(i, "last func. applied:", net[i-1])
    if n == 0:
        return net_input
    else:
        return out

def plot_tensor(out,nrows=8):
    imgs = [img for img in out.data.cpu().numpy()[0]]
    fig = plot_image_grid(imgs,nrows=nrows)
    plt.show()

## Individual batch norm implementation

In [3]:
def pure_batch_norm(X, gamma, beta, eps = 1e-5):
    if len(X.shape) not in [4]:
        raise ValueError('only supports 2dconv')
    # extract the dimensions
    N, C, H, W = X.shape
    # mini-batch mean
    mean = torch.mean(X, axis=(0, 2, 3))
    # mini-batch variance
    variance = torch.mean((X - mean.reshape((1, C, 1, 1))) ** 2, axis=(0, 2, 3))
    # normalize
    X_hat = (X - mean.reshape((1, C, 1, 1))) * 1.0 / torch.sqrt(variance.reshape((1, C, 1, 1)) + eps)
    # scale and shift
    return gamma.reshape((1, C, 1, 1)) * X_hat + beta.reshape((1, C, 1, 1))



class ChannelNormalization(torch.nn.Module):
    def __init__(self, D_in,mode="BN"):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        self.mode = mode
        requires_grad=True
        super(ChannelNormalization, self).__init__()
        #self.gamma = Variable(torch.ones(D_in),requires_grad=requires_grad).type(dtype)
        #self.beta = Variable(torch.zeros(D_in),requires_grad=requires_grad).type(dtype)
        self.gamma = nn.Parameter(torch.ones(D_in))
        self.beta = nn.Parameter(torch.zeros(D_in))

        
    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        
        #return (x - torch.mean(x))/(torch.std(x) + 0.00000001 ) *self.gamma[0] + self.beta[0]
        
        #return pure_batch_norm(x, self.gamma, self.beta, eps = 1e-5)
        
        xx = Variable(torch.zeros(x.shape)).type(dtype)
        for i,(g,b) in enumerate(zip(self.gamma,self.beta)):
            #print(i, x[0][i].shape)
            if self.mode == "BN":
                #xx[:,i] = (x[:,i] - torch.mean(x[:,i]))/(torch.std(x[:,i]) + 0.00000001 ) *g + b
                #print(torch.mean(x[0][i]))
                xx[0][i] = (x[0][i] - torch.mean(x[0][i]))/torch.sqrt( torch.var(x[0][i]) + 0.00001 ) * g + b
                #xx[0][i] = (x[0][i] - torch.mean(x[0]))/(torch.std(x[0]) + 0.00000001 ) *g + b
            elif self.mode == "mult":
                xx[:,i] = x[:,i]*g + b
            elif self.mode == "non-learned":
                xx[0][i] = (x[0][i] - torch.mean(x[0][i]))/torch.sqrt( torch.var(x[0][i]) + 0.00001 )
            elif self.mode == "center":
                xx[:,i] = (x[:,i] - torch.mean(x[:,i]))
            elif self.mode == "normalize+bias":
                xx[0][i] = (x[0][i])/torch.sqrt( torch.var(x[0][i]) + 0.00001 ) * g + b
            elif self.mode == "only+bias":
                xx[:,i] = x[:,i] + b
            elif self.mode == "onlycenter+bias":
                xx[:,i] = (x[:,i] - torch.mean(x[:,i]))*g + b
            elif self.mode == "almostcenter":
                xx[:,i] = ( (x[:,i] - torch.mean(x[:,i])) + torch.mean(x[:,i]) / torch.from_numpy( np.sqrt( np.array([np.prod(x[0][i].shape)]) ) ).float().type(dtype)) /torch.sqrt( torch.var(x[0][i]) + 0.00001 ) + b
            elif self.mode == "center+scale":
                #xx[:,i] =  (x[:,i] - torch.mean(x[:,i])) / (torch.max( x[:,i] - torch.mean(x[:,i]) )  - torch.min( x[:,i] - torch.mean(x[:,i]) ) )
                xx[:,i] =  (x[:,i] - torch.mean(x[:,i])) / (torch.max( torch.abs(x[:,i] - torch.mean(x[:,i]) ) ) + 0.00001)
            elif self.mode == "center+mean_scale":
                xx[:,i] =  (x[:,i] - torch.mean(x[:,i])) / torch.mean(torch.abs( x[:,i] - torch.mean(x[:,i]) ) )
            elif self.mode == "noise":
                #   noise = Variable(ins.data.new(ins.size()).normal_(mean, stddev))
                #no = torch.norm(x[0][i].data)
                #sigma = 0.1*no
                #noise = Variable( torch.randn( x[0][i].size() ) * sigma ).type(dtype)
                #noise = Variable( x[0][i].data.new(x[0][i].size()).normal_(0, sigma))
                xx[0][i] = x[0][i] / ( torch.norm(x[0][i])* np.prod(x[0][i].shape) + 0.00001  ) + b
                #xx[0][i] = x[0][i] / torch.norm( x[0][i] )  + noise + b
                #xx[0][i] = (x[0][i] - torch.mean(x[0][i])) / torch.sqrt( torch.var(x[0][i]) + 0.00001 ) + b
            #elif self.mode == "center_then_normalize": # stupid; is the same as BN
            #    center =  (x[0][i] - torch.mean(x[0][i]))
            #    xx[0][i] = center / torch.sqrt( torch.var(center) + 0.00001 ) + b
            elif self.mode == "scale":
                z =  x[0][i]
                xx[0][i] = z / ( torch.norm(z) + 0.00001 )*g + b
            else:
                raise ValueError('not an option for channel normalization.')
        return xx

In [4]:
#m = ChannelNormalization(4)
#m.parameters()
#print('m', list(m.parameters()))

def init_normal(m):
    if type(m) == nn.Conv2d:
        nn.init.xavier_uniform(m.weight)

# use the modules apply function to recursively apply the initialization
#rand_net.apply(init_normal)

## Decoder with channel normalization

In [5]:
def conv(in_f, out_f, kernel_size, stride=1, pad='zero',bias=False):
    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0
  
    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = filter(lambda x: x is not None, [padder, convolver])
    return nn.Sequential(*layers)

def decnet(
        num_output_channels=3, 
        num_channels_up=[128]*5, 
        filter_size_up=1,
        act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 
        mode = "mult",
        ):
    
    num_channels_up = num_channels_up + [num_channels_up[-1]]
    n_scales = len(num_channels_up) 
    
    model = nn.Sequential()

    
    for i in range(len(num_channels_up)-1):
        model.add(conv( num_channels_up[i], num_channels_up[i+1],  filter_size_up, 1, pad='reflection'))        
        #model.add(act_fun)
        #model.add(nn.BatchNorm2d( num_channels_up[i+1] ,affine=True))
        model.add(ChannelNormalization(num_channels_up[i+1],mode=mode))
        #model.add(nn.BatchNorm2d( num_channels_up[i+1] ,affine=True,track_running_stats=False,momentum=0))                 
    
    model.add(conv( num_channels_up[-1], num_output_channels, 1, pad='reflection',bias=True))
    model.add(nn.Sigmoid())
    
    return model

In [6]:
path = './test_data/'
img_name = "phantom256"
img_path = path + img_name + ".png"
img_pil = Image.open(img_path)
img_np = pil_to_np(img_pil)
img_np = img_np / np.max(img_np)
img_var = np_to_var(img_np).type(dtype)
print(img_var.shape)

RuntimeError: cuda runtime error (38) : no CUDA-capable device is detected at /opt/conda/conda-bld/pytorch_1544199946412/work/aten/src/THC/THCGeneral.cpp:51

# Experiment: Effect of normalization on norm of gradients

The experiments shows that for this simple setup of only one convlayer and a linear problem, normalization efficiently avoids vanishing gradients.

In [None]:
num_channels = [1]*10
output_depth = 1
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=51,mode="center+scale").type(dtype)

net.apply(init_normal)

mse_n, mse_t, ni, net, out_grads = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=1000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        output_gradients=True,
                               )

for i,g in enumerate(out_grads):
    plt.semilogy(g,label=str(i)) 
plt.legend()
plt.show()

In [None]:
num_channels = [1]*10
output_depth = 1
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=51,mode="center+mean_scale").type(dtype)

net.apply(init_normal)

mse_n, mse_t, ni, net, out_grads = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=1000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        output_gradients=True,
                               )

for i,g in enumerate(out_grads):
    plt.semilogy(g,label=str(i)) 
plt.legend()
plt.show()

In [None]:
for m in net.modules():
    if isinstance(m, ChannelNormalization):
        print(m.beta.data,m.gamma.data)
#        print(torch.norm(m.weight.data).cpu())#p='fro')
        #print(m.weights.data)

In [None]:
print(net)

In [None]:
num_channels = [1]*10
output_depth = 1
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=51,mode="BN").type(dtype)

net.apply(init_normal)

mse_n, mse_t, ni, net, out_grads = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=1000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        output_gradients=True,
                               )

for i,g in enumerate(out_grads):
    plt.semilogy(g,label=str(i)) 
plt.legend()
plt.show()

In [None]:
num_channels = [1]*5
output_depth = 1
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=51,mode="only+bias").type(dtype)
net.apply(init_normal)

mse_n, mse_t, ni, net, out_grads = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=1000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        output_gradients=True,
                               )

In [None]:
for i,g in enumerate(out_grads):
    plt.semilogy(g,label=str(i)) 
plt.legend()
plt.show()

In [None]:
num_channels = [1]*10
output_depth = 1
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=51,mode="BN").type(dtype)

net.apply(init_normal)

mse_n, mse_t, ni, net, out_grads = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=1000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        output_gradients=True,
                               )

for i,g in enumerate(out_grads):
    plt.semilogy(g,label=str(i)) 
plt.legend()
plt.show()

In [None]:
for i,g in enumerate(out_grads):
    plt.semilogy(g,label=str(i)) 
plt.legend()
plt.show()

## Linear net, with batch norm

In [None]:
num_channels = [8]*15
output_depth = img_np.shape[0] # number of output channels
net = decodernw(output_depth,num_channels_up=num_channels,upsample_mode='none',act_fun=None,filter_size_up=9).type(dtype)

width = img_np.shape[1]
height = img_np.shape[2]
shape = [1,num_channels[0], width, height]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
comparison(img_np,net,net_input)

## Linear net, without batch norm

In [None]:
num_channels = [8]*15
output_depth = img_np.shape[0] # number of output channels
net = decodernw(output_depth,num_channels_up=num_channels,upsample_mode='none',bn=False,act_fun=None,filter_size_up=9).type(dtype)

width = img_np.shape[1]
height = img_np.shape[2]
shape = [1,num_channels[0], width, height]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
comparison(img_np,net,net_input)

## SGD

In [None]:
num_channels = [8]*15
output_depth = img_np.shape[0] # number of output channels
net = decodernw(output_depth,num_channels_up=num_channels,upsample_mode='none',bn=False,act_fun=None,filter_size_up=9).type(dtype)

width = img_np.shape[1]
height = img_np.shape[2]
shape = [1,num_channels[0], width, height]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
comparison(img_np,net,net_input)

In [None]:
num_channels = [8]*15
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

In [None]:
output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=9,mode="BN").type(dtype)

mse_n, mse_t, ni, net = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=100,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        #output_gradient=True,
                               )

In [None]:
#print(list(net.parameters()))

In [None]:
#for m in net.modules():
#    if isinstance(m, ChannelNormalization):
#        print("parameters:")
#        print(m.gamma)
#        print(m.beta)
        #print(m.weights.data)

# Test with different BN variants

In [None]:
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10

num_channels = [8]*15
output_depth = img_np.shape[0] # number of output channels

## Regular BN

In [None]:
output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=9,mode="BN").type(dtype)

comparison(img_np,net,net_input)

## Reparameterization

In [None]:
num_channels = [8]*15
output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=9,mode="mult").type(dtype)

comparison(img_np,net,net_input)

## Center and only learn bias

In [None]:
net = decnet(output_depth,num_channels,filter_size_up=9,mode="center+bias").type(dtype)
comparison(img_np,net,net_input)

## Only center and bias

In [None]:
num_channels = [8]*15
net = decnet(output_depth,num_channels,filter_size_up=9,mode="onlycenter+bias").type(dtype)
comparison(img_np,net,net_input)

## Normalize and bias

In [None]:
net = decnet(output_depth,num_channels,filter_size_up=9,mode="normalize+bias").type(dtype)
comparison(img_np,net,net_input)

In [None]:
#for m in net.modules():
#    if isinstance(m, nn.Conv2d):
#        print(m.weight.data.shape)
#        print(torch.norm(m.weight.data).cpu())#p='fro')
        #print(m.weights.data)

In [None]:
net = decnet(output_depth,num_channels,filter_size_up=9,mode="only+bias").type(dtype)
comparison(img_np,net,net_input)

## Scale

In [None]:
net = decnet(output_depth,num_channels,filter_size_up=9,mode="scale").type(dtype)
comparison(img_np,net,net_input)

# Single channel

In [None]:
num_channels = [1]*10
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
net = decnet(output_depth,num_channels,filter_size_up=81,mode="BN").type(dtype)
comparison(img_np,net,net_input)

In [None]:
num_channels = [1]*10
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
net = decnet(output_depth,num_channels,filter_size_up=51,mode="only+bias").type(dtype)
comparison(img_np,net,net_input)

In [None]:
num_channels = [1]*1
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
net = decnet(output_depth,num_channels,filter_size_up=51,mode="only+bias").type(dtype)
comparison(img_np,net,net_input)

In [None]:
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        #print(m.weight.data.shape)
        print(torch.norm(m.weight.data).cpu())#p='fro')
        #print(m.weights.data)

In [None]:
output_depth = img_np.shape[0] # number of output channels
net = decnet(output_depth,num_channels,filter_size_up=51,mode="BN").type(dtype)

mse_n, mse_t, ni, net = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=1000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='SGD',
                        output_gradients=True,
                               )

print(net)
plt.semilogy(mse_n)
plt.show()

In [None]:
print(net)
for p in net.parameters():
    print(len(p.shape))

## Different initialization

In [None]:
print(net)

for m in net.modules():
    if isinstance(m, nn.Conv2d):
        print(m.weight.data)

## Norms of gradients

In [None]:
p = [x for x in net.parameters() ]
img_var = np_to_var(img_np).type(dtype)
mse = torch.nn.MSELoss()
out = net(net_input.type(dtype))
loss = mse(out, img_var)
loss.backward()
for p in list(filter(lambda p: p.grad is not None, net.parameters())):
    print(p.grad.data.norm(2).item())

In [None]:
print(net)

In [None]:
num_channels = [1]*2
shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
net = decnet(output_depth,num_channels,filter_size_up=51,mode="BN").type(dtype)
comparison(img_np,net,net_input)

In [None]:
num_channels = [1]*30
output_depth = img_np.shape[0] # number of output channels
net = decodernw(output_depth,num_channels_up=num_channels,upsample_mode='none',act_fun=None,filter_size_up=19).type(dtype)

width = img_np.shape[1]
height = img_np.shape[2]
shape = [1,num_channels[0], width, height]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
#comparison(img_np,net,net_input)

mse_n, mse_t, ni, net = fit( num_channels=num_channels,
                        reg_noise_std=0.0,
                        net_input=net_input.type(dtype),        
                        reg_noise_decayevery = 500,
                        num_iter=20000,
                        #LR=0.005,
                        LR=0.05,
                        img_noisy_var=img_var,
                        net=net,
                        img_clean_var=img_var,
                        find_best=False,
                        OPTIMIZER='adam',
                               )

In [None]:
num_channels = [4]*1
net = decnet(3,num_channels).type(dtype)
print(net)


shape = [1,num_channels[0], img_np.shape[1], img_np.shape[2]]
print("shape: ", shape)
net_input = Variable(torch.zeros(shape))
net_input.data.uniform_()
net_input.data *= 1./10
net(net_input.type(dtype))

In [None]:
A = np.array( [[[[1,1],[1,1]], [[2,2],[2,2]] ] , [[[1,1],[1,1]], [[2,2],[2,2]] ] ] )
print(A.shape)
print(A)
c = np.array([1,2])
B = c * A 
print(B.shape)
print(A[:,:,:]*c)

print(A[0,0])

In [None]:
A = np.array( [ [[[1,1],[1,1]], [[2,2],[2,2]]] , [[[1,1],[1,1]], [[2,2],[2,2]]] ] )
print(A.shape,A[0,0])
c = np.array([1,2])

for i,(a,b) in enumerate(zip(c,c)):
    print("a: ", A[0,i])
    A[:,i] = A[:,i]*a + b
    
print(A)

In [None]:
class Model1(nn.Module):
    def __init__(self):
        super(Model1, self).__init__()
        self.multp = Variable(torch.rand(1), requires_grad=True)

class Model2(nn.Module):
    def __init__(self):
        super(Model2, self).__init__()
        self.multp = nn.Parameter(torch.rand(1)) # requires_grad is True by default for Parameter

m1 = Model1()
m2 = Model2()

print('m1', list(m1.parameters()))
print('m2', list(m2.parameters()))

In [None]:
def