### Imports...

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import random

%env CUDA_VISIBLE_DEVICES=0
import torch
import torchvision as tv
from torchvision import datasets, transforms

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

#import torch_dip_utils as utils
import utils
import math

env: CUDA_VISIBLE_DEVICES=0


### Set up Hyperparameters, network filter and I/O sizes, and waveform parameters

In [2]:
# Hyperparameters, network input/output sizes, and training-waveform settings

# --- optimiser ---
LR = 0.001       # learning rate
MOM = 0.9        # momentum
WD = 0.0001      # weight decay (l2 regularisation)
NUM_ITER = 100   # number of optimisation iterations per run

# --- network ---
Z_NUM = 32            # number of channels in the random input seed
NGF = 64              # number of filters per layer
nc = 1                # number of channels in the net I/O
ALEX_BATCH_SIZE = 1   # batch size of each gradient step

# --- training waveform: number of samples and number of periods ---
WAVE_SIZE = 1024
WAVE_PERIODS = 2

### Choose whether the problem is compressed sensing or DIP

In [3]:
# True  -> compressed sensing: only 64 linear measurements are observed
# False -> plain DIP: one measurement per waveform sample
compressed = False

num_measurements = 64 if compressed else WAVE_SIZE

### Use CUDA if Possible

In [4]:
# Detect whether a CUDA device can be used and report the result
CUDA = torch.cuda.is_available()
print(CUDA)

# Pick the tensor type that matches the execution target (CPU or GPU)
if not CUDA:
    dtype = torch.FloatTensor
    print("NO DEVICES")
else:
    dtype = torch.cuda.FloatTensor
    print(torch.cuda.device(0))

True
<torch.cuda.device object at 0x7faef80c9da0>


### Define the generator for the training and reference waveforms

In [5]:
def get_sinusoid(num_samples, num_periods, noisy=True, std=0.1, mean=0):
    """Sample a unit-amplitude sine wave, optionally with additive Gaussian noise.

    Args:
        num_samples: number of points in the returned waveform.
        num_periods: number of sine periods spanned by the waveform.
        noisy: if True, add noise drawn from ``np.random.randn`` scaled by
            ``std`` and shifted by ``mean``.
        std: standard deviation of the additive noise.
        mean: mean of the additive noise.

    Returns:
        1-D numpy array of length ``num_samples``.
    """
    sample_idx = np.arange(num_samples)
    # The sampling rate equals the number of samples, so the whole array
    # covers exactly ``num_periods`` periods.
    clean_wave = np.sin(2*np.pi * num_periods * sample_idx / num_samples)

    if not noisy:
        return clean_wave

    noise = (std * np.random.randn(num_samples)) + mean
    return clean_wave + noise

### Utility functions for normalising the noisy wave to the range [-1,1] and renormalising it back to its native range

In [6]:
def get_stats(x):
    """Return ``[mu, sigma]``: the midpoint and half-width of the range of ``x``.

    ``mu`` is the average of the minimum and maximum of ``x`` and ``sigma`` is
    half of their difference, so ``(x - mu) / sigma`` spans [-1, 1].
    """
    lowest, highest = np.min(x), np.max(x)
    midpoint = (lowest + highest) / 2.0
    half_width = (highest - lowest) / 2.0
    return [midpoint, half_width]

def normalise(x, mu, sigma):
    """Shift ``x`` by ``mu`` and divide by ``sigma``."""
    centred = x - mu
    return centred / sigma

def renormalise(x, mu, sigma):
    """Invert ``normalise``: scale ``x`` by ``sigma`` and shift it by ``mu``."""
    rescaled = x * sigma
    return rescaled + mu

### Define the MSE loss used for net training

In [7]:
# Mean-squared-error criterion, cast to the tensor type chosen for CPU or GPU execution
mse = nn.MSELoss()
mse = mse.type(dtype)

### Define the network architecture

In [8]:
class DCGAN(nn.Module):
    """1-D DCGAN-style generator followed by a linear measurement layer.

    Seven transposed convolutions upsample a seed of shape (N, nz, 1) to a
    waveform of shape (N, nc, 1024). Every hidden layer has ``ngf`` channels and
    is followed by batch norm and ReLU; the final layer uses tanh, so the
    generated waveform lies in [-1, 1].

    ``fc`` is a bias-free linear layer that plays the role of the measurement
    matrix A; it is only applied by ``measurements``.

    Args:
        nz: number of channels of the input seed.
        ngf: number of filters in every hidden layer.
        output_size: length of the generated waveform as seen by ``fc``. The
            convolution stack itself always produces length 1024 from a
            length-1 seed, so ``measurements`` needs ``output_size == 1024``
            for such seeds.
        nc: number of channels of the generated waveform.
        num_measurements: number of rows of the measurement matrix.
    """

    def __init__(self, nz, ngf=64, output_size=1024, nc=1, num_measurements=64):
        super(DCGAN, self).__init__()
        self.nc = nc
        self.output_size = output_size

        # Deconv layers: (in_channels, out_channels, kernel_size, stride, padding, bias=False)
        # Inputs: R^(N x Cin x Lin), outputs: R^(N x Cout x Lout)
        # with Lout = (Lin - 1)*stride - 2*padding + kernel_size.
        # Shapes below are (channels x length) for a length-1 seed.

        # LAYER 1: (nz x 1) -> (ngf x 16)
        self.conv1 = nn.ConvTranspose1d(nz, ngf, 16, 1, 0, bias=False)
        self.bn1 = nn.BatchNorm1d(ngf)

        # LAYER 2: (ngf x 16) -> (ngf x 32)
        self.conv2 = nn.ConvTranspose1d(ngf, ngf, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(ngf)

        # LAYER 3: (ngf x 32) -> (ngf x 64)
        self.conv3 = nn.ConvTranspose1d(ngf, ngf, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm1d(ngf)

        # LAYER 4: (ngf x 64) -> (ngf x 128)
        self.conv4 = nn.ConvTranspose1d(ngf, ngf, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm1d(ngf)

        # LAYER 5: (ngf x 128) -> (ngf x 256)
        self.conv5 = nn.ConvTranspose1d(ngf, ngf, 4, 2, 1, bias=False)
        self.bn5 = nn.BatchNorm1d(ngf)

        # LAYER 6: (ngf x 256) -> (ngf x 512)
        self.conv6 = nn.ConvTranspose1d(ngf, ngf, 4, 2, 1, bias=False)
        self.bn6 = nn.BatchNorm1d(ngf)

        # LAYER 7: (ngf x 512) -> generated wave G(z,w) of shape (nc x 1024)
        self.conv7 = nn.ConvTranspose1d(ngf, nc, 4, 2, 1, bias=False)

        # Measurement matrix A. Its entries are meant to be set by the caller
        # (e.g. random Gaussian) and it is not meant to be trained.
        self.fc = nn.Linear(output_size*nc, num_measurements, bias=False)

    def forward(self, x):
        """Map a seed of shape (N, nz, L) to a waveform with values in [-1, 1]."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        # torch.tanh replaces the deprecated F.tanh
        return torch.tanh(self.conv7(x))

    def measurements(self, x):
        """Return A * G(x): the generated waveform passed through ``fc``.

        Each sample of the batch is flattened to one row vector before the
        linear layer is applied. The activations are moved to whichever device
        holds the ``fc`` weights (these may live on the CPU while the rest of
        the net is on the GPU), and the result is returned on the device of
        ``x``.

        Returns:
            Tensor of shape (N, num_measurements).
        """
        wave = self.forward(x)
        # one row per batch element
        flat_wave = wave.view(wave.size(0), -1)

        # pass through the FC layer - returns A*wave
        meas = self.fc(flat_wave.to(self.fc.weight.device))

        return meas.to(x.device)

### Train the network while tracking loss vs. reference wave

In [9]:
# Repeat the fit `num_samples` times for every periodicity and record, for each
# run, the lowest MSE against the noise-free wave and the iteration at which it
# was reached.
num_samples = 30
period_list = [0.5, 1, 2, 4, 8, 16, 32, 64]

# 1e6 acts as a "not yet measured" sentinel that any real MSE will beat
mse_log = (1e6)*np.ones((len(period_list), num_samples))
iter_log = (1e6)*np.ones((len(period_list), num_samples))

for pd, num_periods in enumerate(period_list):
    for j in range(num_samples):

        print(pd, " ", j)

        # get a DCGAN that outputs waves of size WAVE_SIZE
        net = DCGAN(Z_NUM,NGF,WAVE_SIZE,nc,num_measurements) # initialize network
        # Freeze the measurement matrix. The flag has to be set on the weight
        # tensor: assigning `requires_grad` on the Module itself only creates an
        # unrelated attribute and autograd keeps computing the gradient.
        net.fc.weight.requires_grad = False

        if CUDA: # move network to GPU if available
            net.cuda()

        # Measurement matrix A: random Gaussian for compressed sensing, identity otherwise.
        # The assigned tensor lives on the CPU, so the fc layer is evaluated on the CPU.
        if compressed:
            net.fc.weight.data = (1/math.sqrt(1.0*num_measurements)) * torch.randn(num_measurements, WAVE_SIZE*nc)
        else:
            net.fc.weight.data = torch.eye(num_measurements)

        # optimise everything except the last parameter (the fc measurement matrix)
        allparams = list(net.parameters())[:-1]

        # random input seed of shape (batch, Z_NUM, 1), kept fixed for the whole run
        z = Variable(torch.zeros(ALEX_BATCH_SIZE*Z_NUM).type(dtype).view(ALEX_BATCH_SIZE,Z_NUM,1))
        z.data.normal_()

        # Define optimizer
        optim = torch.optim.RMSprop(allparams,lr=LR,momentum=MOM, weight_decay=WD)

        # noisy wave used for fitting, and its clean counterpart used only for scoring
        y0 = get_sinusoid(num_samples = WAVE_SIZE, num_periods = num_periods, noisy=True)
        y0_denoised = get_sinusoid(num_samples = WAVE_SIZE, num_periods = num_periods, noisy=False)

        MU, SIGMA = get_stats(y0)

        # normalise the noisy wave to [-1,1], the output range of the net
        y = torch.Tensor(y0)
        y = normalise(y, MU, SIGMA)
        y = Variable(y.type(dtype))

        # target measurements A*y, computed on the CPU where A lives
        measurements = Variable(torch.mm(y.cpu().data.view(ALEX_BATCH_SIZE,-1),net.fc.weight.data.permute(1,0)),requires_grad=False)

        if CUDA: # move measurements to GPU if possible
            measurements = measurements.cuda()

        for i in range(NUM_ITER):
            optim.zero_grad() # clears gradients of all optimized variables
            out = net(z) # produces wave (in form of data tensor) i.e. G(z,w)

            # A*G(z,w), reusing the forward pass above instead of running the
            # generator a second time through net.measurements(z)
            predicted = net.fc(out.view(ALEX_BATCH_SIZE,-1).cpu())
            if CUDA:
                predicted = predicted.cuda()

            loss = mse(predicted,measurements) # calculate loss between AG(z,w) and Ay

            # DCGAN output is in [-1,1]. Renormalise to the native range of the
            # noisy wave so it can be compared with the noise-free reference.
            wave = renormalise(out, MU, SIGMA).data[0].cpu().numpy()[0,:]

            cur_mse = np.mean((y0_denoised - wave)**2)

            # keep the best (lowest) MSE seen so far and when it occurred
            if (cur_mse <= mse_log[pd][j]):
                mse_log[pd][j] = cur_mse
                iter_log[pd][j] = i

            loss.backward()
            optim.step()

0   0




0   1
0   2
0   3
0   4
0   5
0   6
0   7
0   8
0   9
0   10
0   11
0   12
0   13
0   14
0   15
0   16
0   17
0   18
0   19
0   20
0   21
0   22
0   23
0   24
0   25
0   26
0   27
0   28
0   29
1   0
1   1
1   2
1   3
1   4
1   5
1   6
1   7
1   8
1   9
1   10
1   11
1   12
1   13
1   14
1   15
1   16
1   17
1   18
1   19
1   20
1   21
1   22
1   23
1   24
1   25
1   26
1   27
1   28
1   29
2   0
2   1
2   2
2   3
2   4
2   5
2   6
2   7
2   8
2   9
2   10
2   11
2   12
2   13
2   14
2   15
2   16
2   17
2   18
2   19
2   20
2   21
2   22
2   23
2   24
2   25
2   26
2   27
2   28
2   29
3   0
3   1
3   2
3   3
3   4
3   5
3   6
3   7
3   8
3   9
3   10
3   11
3   12
3   13
3   14
3   15
3   16
3   17
3   18
3   19
3   20
3   21
3   22
3   23
3   24
3   25
3   26
3   27
3   28
3   29
4   0
4   1
4   2
4   3
4   4
4   5
4   6
4   7
4   8
4   9
4   10
4   11
4   12
4   13
4   14
4   15
4   16
4   17
4   18
4   19
4   20
4   21
4   22
4   23
4   24
4   25
4   26
4   27
4   28
4   29
5   0


In [10]:
# Iteration index at which each run reached its lowest MSE against the
# noise-free wave (one row per entry of period_list, one column per repeated run)
print(iter_log)

[[38. 30. 32. 30. 33. 50. 47. 62. 37. 39. 48. 42. 48. 52. 33. 46. 41. 49.
  35. 36. 33. 30. 29. 37. 41. 27. 34. 42. 29. 29.]
 [61. 55. 79. 54. 47. 61. 51. 56. 53. 65. 47. 50. 94. 53. 57. 51. 66. 62.
  50. 97. 64. 69. 53. 59. 60. 58. 52. 47. 60. 43.]
 [61. 42. 63. 62. 54. 47. 54. 53. 54. 69. 57. 62. 61. 64. 56. 52. 57. 50.
  51. 70. 65. 62. 59. 53. 63. 51. 60. 69. 60. 54.]
 [72. 51. 43. 46. 46. 63. 56. 55. 70. 53. 56. 66. 49. 45. 59. 59. 69. 57.
  57. 61. 60. 45. 44. 57. 54. 61. 55. 60. 39. 60.]
 [53. 56. 63. 53. 55. 48. 72. 47. 65. 54. 50. 46. 44. 79. 79. 51. 66. 60.
  53. 54. 77. 88. 64. 52. 60. 78. 62. 60. 57. 69.]
 [52. 74. 48. 59. 60. 51. 60. 70. 65. 70. 64. 60. 59. 61. 39. 58. 81. 53.
  62. 47. 65. 64. 73. 67. 54. 51. 89. 37. 49. 67.]
 [54. 64. 64. 56. 39. 71. 55. 56. 52. 55. 52. 43. 39. 53. 52. 51. 58. 61.
  48. 38. 64. 61. 39. 59. 54. 46. 66. 67. 51. 49.]
 [55. 57. 50. 57. 71. 58. 33. 50. 39. 35. 70. 45. 55. 72. 40. 47. 50. 47.
  65. 50. 41. 55. 61. 57. 47. 50. 50. 51. 56. 40.]]

In [13]:
# Per-periodicity summaries over the repeated runs (axis 1), rounded for display
mean_mse = np.mean(mse_log, axis=1).round(5)
mean_iter = np.mean(iter_log, axis=1).round(3)
std_iter = np.std(iter_log, axis=1).round(3)

In [14]:
# Report the summary statistics, one labelled line per quantity
summary_rows = (
    ("Periodicities: ", period_list),
    ("Mean MSE: ", mean_mse),
    ("Mean best iteration: ", mean_iter),
    ("STD best iteration: ", std_iter),
)
for label, values in summary_rows:
    print(label, values)

Periodicities:  [0.5, 1, 2, 4, 8, 16, 32, 64]
Mean MSE:  [0.00132 0.00124 0.00136 0.00133 0.00131 0.00108 0.00088 0.00089]
Mean best iteration:  [38.633 59.133 57.833 55.6   60.5   60.3   53.9   51.8  ]
STD best iteration:  [ 8.404 12.23   6.558  8.2   11.114 11.181  8.677  9.772]
