In [None]:
# Imports
import numpy as np 
%matplotlib inline
import matplotlib.pyplot as plt 
# from scipy.interpolate import interp1d, interp2d 
# import scipy.io as sio 
from scipy.signal import max_len_seq 
# from imageio import imread
import torch
# print(torch.__file__)
# print(torch.version.cuda)
# print(torch.cuda.is_available())
# # import torchvision
import torchvision.transforms as transforms
import PIL.Image as Image
# # from skimage.transform import resize
import matplotlib.colors as pltc
import math
# import pickle

In [None]:
def greyscale(x):
    """Convert a channels-first RGB image to luma using ITU-R BT.601 weights.

    Args:
        x: array/tensor indexed as ``x[channel, row, col]`` whose first three
           channels are R, G, B (any further channels are ignored).

    Returns:
        ``0.299 * R + 0.587 * G + 0.114 * B`` with shape ``x.shape[1:]``.
    """
    w_red, w_green, w_blue = 0.299, 0.587, 0.114
    red, green, blue = (x[channel, :, :] for channel in range(3))
    return w_red * red + w_green * green + w_blue * blue

In [None]:
def psnr(img1, img2, data_range=255.0):
    """Peak signal-to-noise ratio (in dB) between two equally shaped images.

    Args:
        img1, img2: array-likes of identical shape.
        data_range: the peak (maximum possible) pixel value. Defaults to 255 for
            8-bit images; pass 1.0 for images scaled to [0, 1], otherwise the
            result is inflated by 20*log10(255) ~= 48 dB.

    Returns:
        ``20 * log10(data_range / sqrt(MSE))``, or ``inf`` when the images are
        identical.

    Raises:
        ValueError: if the shapes differ (NumPy would otherwise broadcast silently
            and return a meaningless value).
    """
    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    if img1.shape != img2.shape:
        raise ValueError("psnr: shape mismatch %s vs %s" % (img1.shape, img2.shape))
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(data_range / math.sqrt(mse))

In [None]:
def color_range(tens):
    """Normalise an image for display: divide by its maximum, then clip to [0, 1].

    Args:
        tens: a torch tensor (on any device, with or without grad) or a
            numpy array / array-like.

    Returns:
        numpy array with the same shape as ``tens``. An all-zero input is
        returned as zeros instead of NaNs (0/0).
    """
    if hasattr(tens, "detach"):
        # .cpu() makes CUDA tensors convertible; it is a no-op for CPU tensors.
        out = tens.detach().cpu().numpy()
    else:
        out = np.asarray(tens)
    peak = np.amax(out)
    if peak != 0:
        out = out / peak
    return np.clip(out, 0, 1)  # Converting to correct color range

In [None]:
# --- Geometry of the simulated camera --------------------------------------------
# Lengths appear to share one unit (50e4 <-> "50cm" below suggests micrometres) — confirm.
numScenePix = 256  # scene is numScenePix x numScenePix pixels
numMaskPix = 63  # mask samples per axis (MLS length 2**6 - 1)
unitMask = 30  # mask sample pitch
numSensorPix = 512  # sensor is numSensorPix x numSensorPix pixels
unitSensor = 8  # sensor pixel pitch
d = 4000  # distance parameter used by z2a / a2z (presumably mask-sensor spacing — confirm)

midScenePix = numScenePix // 2
midSensorPix = numSensorPix // 2 - 1


def z2a(z):
    """Map a depth ``z`` to the coordinate scale factor ``a = 1 - d / z``."""
    return 1 - d / z


def a2z(a):
    """Inverse of ``z2a``: recover the depth ``z = d / (1 - a)``."""
    return d / (1 - a)


def crop_sensor(x):
    """Cut the numSensorPix x numSensorPix window that starts at (midScenePix, midScenePix)."""
    first = midScenePix
    last = midScenePix + numSensorPix
    return x[first:last, first:last]


def crop_img(x):
    """Cut the numScenePix x numScenePix window that starts at (midSensorPix, midSensorPix)."""
    first = midSensorPix
    last = midSensorPix + numScenePix
    return x[first:last, first:last]


depthList = z2a(np.array([50e4]))  # a single depth, 50cm away

# Pixel-centre coordinates of the sensor and the mask, symmetric about 0.
locSensor = np.linspace(-numSensorPix*unitSensor/2 + unitSensor/2,
                        numSensorPix*unitSensor/2 - unitSensor/2,
                        numSensorPix)
locMask = np.linspace(-numMaskPix*unitMask/2 + unitMask/2,
                      numMaskPix*unitMask/2 - unitMask/2,
                      numMaskPix)

# Function needed for Generating PSF
def generateInterpMatrix(depthList, locSensor, locMask, unit_mask=None):
    """Build linear-interpolation matrices that sample the mask at scaled sensor positions.

    For each depth factor ``a = depthList[di]`` the sensor coordinates are scaled to
    ``a * locSensor`` and linearly interpolated onto the ``locMask`` grid. Row ``i`` of
    ``interpMatrix[:, :, di]`` therefore holds the weights that mix mask samples into
    sensor sample ``i``. Sensor samples outside ``[locMask[0], locMask[-1]]`` get an
    all-zero row.

    Args:
        depthList: depth scale factors (see ``z2a``); flattened to 1-D.
        locSensor: 1-D array of sensor sample coordinates.
        locMask: 1-D, increasing, uniformly spaced array of mask sample coordinates.
        unit_mask: mask sample spacing. Defaults to the spacing of ``locMask``,
            which equals ``unitMask`` for the grid built with np.linspace in this notebook.

    Returns:
        Array of shape (len(locSensor), len(locMask), number of depths).

    Raises:
        ValueError: if ``unit_mask`` is omitted and ``locMask`` has fewer than 2 samples.
    """
    locSensor = np.asarray(locSensor)
    locMask = np.asarray(locMask)
    depths = np.asarray(depthList).ravel()
    num_sensor, num_mask = len(locSensor), len(locMask)

    if unit_mask is None:
        if num_mask < 2:
            raise ValueError("unit_mask is required when locMask has fewer than 2 samples")
        # Uniform grid: the endpoint span divided by the number of gaps is the pitch.
        unit_mask = (locMask[-1] - locMask[0]) / (num_mask - 1)

    # Sizes now come from the arguments instead of the numSensorPix/numMaskPix globals.
    rows = np.arange(num_sensor)
    interpMatrix = np.zeros([num_sensor, num_mask, depths.size])
    for di, depth in enumerate(depths):
        locInterp = depth * locSensor
        # Bracketing mask samples lo <= locInterp < hi (clamped to the grid).
        hi = np.minimum(num_mask - 1, np.searchsorted(locMask, locInterp, 'right'))
        lo = np.maximum(0, hi - 1)
        interpMatrix[rows, lo, di] = 1 - (locInterp - locMask[lo]) / unit_mask
        interpMatrix[rows, hi, di] = 1 - (locMask[hi] - locInterp) / unit_mask
        # Sensor samples that project outside the mask see nothing.
        outside = (locInterp < locMask[0]) | (locInterp > locMask[-1])
        interpMatrix[outside, :, di] = 0
    return interpMatrix

In [None]:
# create mask pattern
# Separable binary mask: the outer product of one maximum-length sequence (MLS) with itself.
# max_len_seq(nbits) yields 2**nbits - 1 samples, hence nbits = log2(numMaskPix + 1).
maskVec = max_len_seq(int(np.log2(numMaskPix + 1)))[0].reshape(numMaskPix, 1)
maskPattern = np.outer(maskVec, maskVec)

# PSF for the single depth in depthList: the mask interpolated onto the sensor grid
# along both axes.
interpMatrix = generateInterpMatrix(depthList, locSensor, locMask)[..., 0]
psf = interpMatrix @ maskPattern @ interpMatrix.T  # The Point Spread Function

plt.figure(1)
plt.subplot(121)
plt.imshow(maskPattern)
plt.title('mask pattern')
plt.axis('off')
plt.subplot(122)
plt.imshow(psf)
plt.title('psf')
plt.axis('off');

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# read image
# Grayscale ('L') test scene as a 2-D uint8 array.
# NOTE(review): the forward model assumes the image is numScenePix x numScenePix; fft2(s=...)
# pads or crops anything else silently — confirm.
im = Image.open('Images/Jpeg/InNature.jpg').convert('L')  # Original Image
im = np.array(im)
print(im.shape)

# forward model
# Zero-pad the PSF and the scene to the full linear-convolution size, so the FFT product
# equals linear (not circular) convolution.
fullLength = numSensorPix + numScenePix - 1
psf_fft = np.fft.fft2(psf, s=[fullLength,fullLength])
img_fft = np.fft.fft2(im, s=[fullLength,fullLength])
y_fft = psf_fft * img_fft # Ax
# Keep the numSensorPix x numSensorPix window of the full convolution (crop_sensor).
y = crop_sensor(np.fft.ifft2(y_fft).real) # sensor measurements

# add gaussian noise
# Additive zero-mean Gaussian noise (unseeded). std also sets lambda_snr in the inverse cell.
std = 0.2
y = y + np.random.normal(loc=0, scale=std, size=(numSensorPix,numSensorPix)) # + mu
# im = im - 1
# Draws into figure 1, the same figure number used by the mask/PSF cell.
plt.figure(1)
plt.subplot(231);plt.imshow(psf); plt.title('psf'); plt.axis('off');
plt.subplot(232);plt.imshow(im, cmap='gray'); plt.title('image'); plt.axis('off'); # Displaying original image
plt.subplot(233);plt.imshow(y); plt.title('sensor measurements'); plt.axis('off');
# Bottom row: log-magnitude 2-D Fourier spectra with zero frequency centred (fftshift);
# log(0) gives -inf wherever a spectrum vanishes.
plt.subplot(234);plt.imshow(np.log(np.fft.fftshift(np.abs(psf_fft)))); plt.title('psf fft'); plt.axis('off'); # spectrum of the PSF
plt.subplot(235);plt.imshow(np.log(np.fft.fftshift(np.abs(img_fft)))); plt.title('img fft'); plt.axis('off'); # spectrum of the scene
plt.subplot(236);plt.imshow(np.log(np.fft.fftshift(np.abs(y_fft)))); plt.title('measurements fft'); plt.axis('off');

In [None]:
# inverse
# Wiener-style deconvolution:
#   xhat_fft = conj(H) * Y / (|H|^2 + lambda_snr),  H = psf_fft, Y = FFT of zero-padded y.
# lambda_snr plays the role of the constant K in the textbook Wiener filter
# (noise-to-signal power ratio, i.e. 1/SNR). Here it is set heuristically from the noise std.
lambda_snr = np.sqrt(std) # regularization parameter for noise reduction ---- > Fine tune this
xhat_fft = (np.conjugate(psf_fft) * np.fft.fft2(y,s=[fullLength,fullLength] )) / (np.abs(psf_fft)**2 + lambda_snr )
# fftshift, then crop_img keeps the numScenePix x numScenePix window where the estimate lands
# (y was cut at an offset by crop_sensor). NOTE(review): placement reasoning — confirm.
# NOTE(review): a later cell rebinds psf_fft to a torch tensor; re-running this cell after that is untested.
xhat = crop_img(np.fft.fftshift(np.fft.ifft2(xhat_fft).real))

plt.figure(2)
plt.subplot(121);plt.imshow(im, cmap='gray');plt.title('original image'); plt.axis('off');
plt.subplot(122);plt.imshow(xhat, cmap='gray');plt.title('reconstructed image');plt.axis('off');

# Gradient Descent

- Each iteration computes $x_k = x_{k-1} - \eta\, A^T(Ax_{k-1} - y)$, where $k$ is the iteration index and $\eta$ is the step size (learning rate).

In [None]:
# crop_fft = lambda x: x[0:numScenePix, 0:numScenePix]
# A = lambda x: np.fft.ifft2(psf_fft * np.fft.fft2(x, s=np.array([fullLength,fullLength]),norm='ortho')).real
# AT = lambda y: crop_fft(np.fft.ifft2(np.conjugate(psf_fft) * np.fft.fft2(y,norm='ortho'))).real # Doesn't return an image

In [None]:
# yn = A(im)
# plt.subplot(121); plt.imshow(yn); plt.title('yn'); plt.axis('off')
# plt.subplot(122); plt.imshow(A(im)); plt.title('A(im)'); plt.axis('off') # Sensor Measurements.

In [None]:
# # Loss Function
# def loss(z): # Z is the reconstructed image
#   _in = A(z) - yn
#   _in = torch.from_numpy(_in)
#   return torch.linalg.norm(_in)**2 # ||AG(z) - y||2^2

# print(loss(im)) 

In [None]:
# Reconstruction Error
# def recon_err(z):
#   _in = z - im
#   _in = torch.from_numpy(_in)
#   return torch.linalg.norm(_in)**2 

# print(recon_err(xhat))

In [None]:
# Gradient
# def grad(z, yn):
#   return AT(A(z) - yn) # A^T(A(z) - y)
  # return torch.from_numpy(AT(A(G(z)) - yn))

In [None]:
# Our goal was to make bad reconstructions better, 
# So let's make up a "bad reconstruction", and try and make it better
# init = np.random.normal(loc=0, scale=0.35, size=(256,256))
# init = np.random.rand(256,256)
# init = np.random.rand(64,64)
# plt.imshow(init); plt.axis('off');

In [None]:
# Example of how the Inital Image 'improves' each iteration (here is showing just 1)
# test_z = init - (0.01 * grad(init, yn))
# plt.imshow(test_z); plt.axis('off');

In [None]:
# def gd(sensor_measurements, learning_rate=0.01, max_iterations=100, tol=2000): # Starting with a simple 5 iterations
#   it = 1; relative_cost = float('inf'); 
#   loss_hist = np.zeros(max_iterations) # 
#   z_hist = np.zeros((max_iterations, im.shape[0], im.shape[1])) # Contains all previous reconstructions
#   # print(z_hist.shape)
#   # print(max_iterations)
#   z_hist[0][:][:] = init
#   loss_hist[0] = loss(z_hist[0])

#   while (it < max_iterations or relative_cost < tol):
#     # print(it)
#     z_k_1 = z_hist[it-1][:][:]
#     z_k = z_k_1 - (learning_rate * grad(z_k_1, sensor_measurements)) # x_k = x_{k-1} - grad(x_{k-1})
#     z_hist[it][:][:] = z_k
#     curr_loss = loss(z_hist[it]) # L
#     loss_hist[it] = curr_loss
#     relative_cost = recon_err(z_k) # compared to Tolerance
#     it += 1
#   return loss_hist, z_hist

In [None]:
# l_h, z_h = gd(yn)

In [None]:
# print(loss(z_h[0]))
# plt.plot(l_h)

In [None]:
# for i in range(0,100,5):
#   plt.imshow(z_h[i]) 
#   plt.pause(0.1)

In [None]:
# plt.imshow(z_h[99] - z_h[0]); plt.colorbar();
# print(z_h[0].min())
# print(z_h[0].max())

In [None]:
# print(l_h)

In [None]:
# plt.subplot(121); plt.imshow(z_h[0]); plt.axis('off');
# plt.subplot(122); plt.imshow(z_h[99]); plt.axis('off');
# print(loss(A(z_h[99])))

In [None]:
# img = z_h[99]
# plt.imsave('out.jpg', img)

# Load in a Training Dataset  

In [None]:
# Load in a Pretrained Model
import torch.hub
import torch.nn as nn
import torch.nn.functional as F

### Re-write Functions to use Torch instead of Numpy

In [None]:
# Possibly need to change these values ...
# Torch versions of the forward model A and an adjoint-style AT.
crop_fft = lambda x: x[0:numScenePix, 0:numScenePix]
# as_tensor is idempotent: re-running this cell leaves an existing tensor as it is instead of
# re-wrapping it (torch.tensor(tensor) warns and copies).
psf_fft = torch.as_tensor(psf_fft)
# fft2 uses norm='ortho' (1/sqrt(N) scaling) while ifft2 uses the default 'backward' norm,
# so A(x) is the linear PSF convolution scaled by 1/fullLength.
A = lambda x: torch.fft.ifft2(psf_fft * torch.fft.fft2(x, s=([fullLength,fullLength]), norm='ortho')).real
# NOTE(review): AT mixes fft2(norm='ortho') with a default-norm ifft2 and keeps the top-left
# numScenePix window. In the visible notebook it is referenced only by commented-out code.
AT = lambda y: crop_fft(torch.fft.ifft2(torch.conj(psf_fft) * torch.fft.fft2(y, norm='ortho'))).real

##### Writing A(x) for RGB Images

In [None]:
def A(x):
    """Torch forward model: FFT-convolve each channel with the PSF.

    Args:
        x: 2-D tensor (H, W), or 3-D channels-last tensor (H, W, C). Callers pass
           ``image.permute(1, 2, 0)``.

    Returns:
        Real tensor of shape (fullLength, fullLength) for 2-D input, or
        (fullLength, fullLength, C) for 3-D input. The 3-D result is allocated with
        torch.zeros, i.e. the default dtype on the default device.
    """
    def _conv2d(channel):
        # fft2(norm='ortho') with a default-norm ifft2: the full linear convolution with
        # the PSF, scaled by 1/fullLength.
        spectrum = psf_fft * torch.fft.fft2(channel, s=(fullLength, fullLength), norm='ortho')
        return torch.fft.ifft2(spectrum).real

    if x.dim() == 3:
        # Size comes from fullLength and the real channel count (previously hard-coded 767 and 3).
        num_channels = x.shape[2]
        out = torch.zeros(fullLength, fullLength, num_channels)
        for c in range(num_channels):
            out[:, :, c] = _conv2d(x[:, :, c])
        return out
    return _conv2d(x)

In [None]:
# img = Image.open('Images/Jpeg/InNature.jpg')  # Test Image
# img = np.array(img)
# img = torch.from_numpy(img)
# print(img.shape);
# # plt.imshow(img[:,:,0]); plt.axis('off'); 

# fig1, ax = plt.subplots(1,2,figsize=(10,5))
# one_dim = A(torch.from_numpy(im))
# three_dim = A(img)
# three_dim_disp = three_dim.numpy(); three_dim_disp = three_dim_disp/np.amax(three_dim_disp); three_dim_disp = np.clip(three_dim_disp, 0, 1); # Converting to correct color 
# ax[0].imshow(one_dim); ax[0].set_title('Grey A(x)');
# ax[1].imshow(three_dim_disp); ax[1].set_title('RGB A(x)');

In [None]:
# crop = lambda x: x[383:639, 383:639]
# xhat_fft = (np.conjugate(psf_fft) * np.fft.fft2(one_dim,s=[fullLength,fullLength] )) / (np.abs(psf_fft)**2 + lambda_snr )
# xhat = crop(np.fft.fftshift(np.fft.ifft2(xhat_fft).real))
# plt.subplot(121); plt.imshow(xhat, cmap='gray'); plt.title('Gray A(x)'); plt.axis('off'); 

# # --------------------------------------------------------
# # RGB Version
# xhat_fft_rgb = torch.zeros(767,767,3);
# xhat_rgb = torch.zeros(256,256,3);
# for i in range(0,3):
#     xhat_fft_rgb[:,:,i] = (torch.conj(psf_fft) * torch.fft.fft2(three_dim[:,:,i],s=[fullLength,fullLength] )) / (torch.abs(psf_fft)**2)
#     xhat_rgb[:,:,i] = crop(torch.fft.fftshift(torch.fft.ifft2(xhat_fft_rgb[:,:,i]).real))

# xhat_rgb = color_range(xhat_rgb) # Color Correction 
# plt.subplot(122); plt.imshow(xhat_rgb); plt.title('RGB A(x)'); plt.axis('off'); 

## Defining the Generative Model (DCGAN / PGAN)

### Load a Pretrained GAN


Option 1: DCGAN trained on FashionGen \\
Option 2: Progressive GAN trained on celebAHQ-256

In [None]:
# Run on the GPU whenever CUDA is available.
use_gpu = torch.cuda.is_available()
print(use_gpu)
# Alternative: DCGAN trained on FashionGen
# model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=use_gpu)
# Pretrained Progressive GAN (celebAHQ-256) from the pytorch_GAN_zoo hub; downloads on first use.
model = torch.hub.load(
    'facebookresearch/pytorch_GAN_zoo:hub',
    'PGAN',
    model_name='celebAHQ-256',
    pretrained=True,
    useGPU=use_gpu,
)

Option 3: StyleGan2 trained on FFHQ 
  - Requires `ffhq.pkl`

In [None]:
# torch.enable_grad()

# with open('D:/UCR/Research/Zips/ffhq.pkl', 'rb') as f:
#     netG = pickle.load(f)['G_ema']

# torch.set_grad_enabled(True)
# # Force Values to Floating Point in order to preserve CPU usage
# import functools
# netG.forward = functools.partial(netG.forward, force_fp32=True)
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
# print(netG)

In [None]:
# nc = 1 # Number of channels in the training images. For color images this is 3
nz = 512  # length of the latent vector z fed to the generator (the DCGAN setup used nz = 120)

# Number of GPUs to use; 0 forces CPU mode.
ngpu = 1
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# When we want to reclaim memory
# !pip install GPUtil
import gc
import GPUtil
# First let Python's garbage collector drop unreferenced objects (and the tensors they hold),
gc.collect()
# then return PyTorch's cached-but-unused CUDA memory to the driver.
torch.cuda.empty_cache() # And then empty it

### Define Generator(z)
- The generator network only needs to be defined here when loading locally trained weights; the PGAN loaded through `torch.hub` already contains its generator (`model.netG`).

In [None]:
# Define Generator for Self-Trained DCGAN
# class Generator(nn.Module):
#     def __init__(self, ngpu):
#         super(Generator, self).__init__()
#         self.ngpu = ngpu
#         self.main = nn.Sequential(
#             # input is Z, going into a convolution
#             nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
#             nn.BatchNorm2d(ngf * 8),
#             nn.ReLU(True),
#             # state size. (ngf*8) x 4 x 4
#             nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ngf * 4),
#             nn.ReLU(True),
#             # state size. (ngf*4) x 8 x 8
#             nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ngf * 2),
#             nn.ReLU(True),
#             # state size. (ngf*2) x 16 x 16
#             nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ngf),
#             nn.ReLU(True),
#             # state size. (ngf) x 32 x 32
#             nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
#             nn.Tanh()
#             # state size. (nc) x 64 x 64
#         )


#     def forward(self, input):
#         return self.main(input)

In [None]:
# Define Generator for CIFAR-100 DCGAN
# class Generator(nn.Module):
#     def __init__(self, ngpu, nc=1, nz=100, ngf=64):
#         super(Generator, self).__init__()
#         self.ngpu = ngpu
#         self.main = nn.Sequential(
#             # input is Z, going into a convolution
#             nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
#             nn.BatchNorm2d(ngf * 8),
#             nn.ReLU(True),
#             # state size. (ngf*8) x 4 x 4
#             nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ngf * 4),
#             nn.ReLU(True),
#             # state size. (ngf*4) x 8 x 8
#             nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ngf * 2),
#             nn.ReLU(True),
#             # state size. (ngf*2) x 16 x 16
#             nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ngf),
#             nn.ReLU(True),
#             nn.ConvTranspose2d(    ngf,      nc, kernel_size=1, stride=1, padding=0, bias=False),
#             nn.Tanh()
#         )

#     def forward(self, input):
#         if input.is_cuda and self.ngpu > 1:
#             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
#         else:
#             output = self.main(input)
#         return output

In [None]:
# netG = Generator(ngpu).to(device)
# # netG.load_state_dict(torch.load('D:/UCR/Research/DCGAN/generator.pth')) # CelebA Self-Trained
# # netG.load_state_dict(torch.load('D:/UCR/Research/lineCam/netG_epoch_299.pth', map_location=torch.device('cpu'))) # CIFAR-Pretrained
# netG.eval()

### Define G(z) 

In [None]:
 # Size of the Model 
# Print per-GPU load and memory utilisation (e.g. VRAM taken by the loaded model).
GPUtil.showUtilization()

In [None]:
def G(z):
    """Generate images from latent vectors with the pretrained generator.

    Args:
        z: latent tensor. Every call in this notebook passes shape [N, nz]
           (e.g. ``torch.randn(1, nz)``).

    Returns:
        The batch returned by ``model.netG``. Callers take ``[0]`` to get one image,
        (3, 256, 256) for the celebAHQ-256 PGAN per the shape comments at the call sites.
    """
    generator = model.netG  # use `netG` here instead for a locally defined DCGAN
    generated = generator(z)
    return generated

# ---------------------------------------------------------
# Testing out the G(z) Function
# with torch.no_grad():
#   # i = G(torch.randn([1,nz]).cuda())[0] # ([3, 256, 256])
#   i = G(torch.randn([1,nz]))[0] # ([3, 256, 256])
#   print(i.shape) # Confirm correct shape
#   id = color_range(i.cpu())
#   plt.figure(figsize=(7,7))
#   plt.imshow(id.transpose(1,2,0)); plt.axis('off');
# del i
# GPUtil.showUtilization()

In [None]:
# Want to confirm torch.no_grad() blocks VRAM consumption
# tests = torch.zeros(100,3,256,256)
# with torch.no_grad():
#   for i in range(0,100):
#     tests[i,:,:,:] = G(torch.randn([1,nz]))[0]
# GPUtil.showUtilization()

### Sub-Sampling

#### Archived Sampling Methods

In [None]:
# def Slices(dim1, dim2, slc, col=False):
#       w_sigma = torch.zeros((3,dim1,dim2))
#   for i in range(0,3):
#     w_sigma[i,0:dim1:slc] = 1
#   if col == True:
#         w_sigma = torch.transpose(w_sigma, 1, 2)
#   return w_sigma

# def even_prct_chunks(dim1, dim2, percentage=0.1, num_chunks=2, col=False):
#   w_sigma = torch.zeros((3,dim1, dim2))
#   num_rows = int(w_sigma.shape[1] * percentage) # 153
#   if(num_rows < num_chunks) or (num_chunks <= 1): 
#     print("Error, Percentage or num_chunks is too Low!")

#   id = dim1 / num_chunks # <---- 
#   chk_sz = num_rows / num_chunks
#   for i in range(0, num_chunks):
#     for j in range(0,3):
#       start = int(id*i)
#       end = int(start + chk_sz)
#       # if j==0:
#       #   print(start,end)
#       w_sigma[j][start:end] = 1
#   if col == True:
#         w_sigma = torch.transpose(w_sigma, 1, 2)
#   return w_sigma

# Random Distribution - Chooses a % of Random Rows 
# def rand_dst(dim1, dim2, percentage, col=False):
#   w_sigma = torch.zeros((3,dim1,dim2))
#   num_rows = int(w_sigma.shape[1] * percentage)

#   for i in range(0, num_rows):
#     idx = rand_roll(w_sigma.shape[1] - 1)
#     if w_sigma[:,idx].sum() == 0:
#       for j in range(0,3):
#         w_sigma[j,idx] = 1
#   if col == True:
#         w_sigma = torch.transpose(w_sigma, 1, 2)
#   return w_sigma

#### Random Sampling Methods 

In [None]:
import random

# Random Roll - Helper Function to give value between 0 and x 
def rand_roll(x):
    """Return a uniformly random integer from the closed range [0, x]."""
    return random.randrange(x + 1)

# Roll Between - Helper Function to give value between a start and end
def roll_btwn(strt, end):
    """Return a uniformly random integer from the closed range [strt, end]."""
    stop = end + 1  # randrange excludes its upper bound
    return random.randrange(strt, stop)
  
# Random Pixels - Chooses a % of Random Pixels
def rand_pxl(dim1, dim2, prct, HER=False, her_lo=120, her_hi=632):
    """Random pixel-sampling mask of shape (dim1, dim2, 3).

    Draws ``int(dim1 * dim2 * prct)`` pixel positions uniformly *with replacement*,
    as before. Duplicates make the realised fraction of ones lower than ``prct``.
    Each drawn position is set to 1 in all three channels.

    Args:
        dim1, dim2: mask height and width.
        prct: fraction of dim1 * dim2 used as the number of draws.
        HER: if True, rows and columns are drawn from [her_lo, her_hi] (inclusive)
            instead of the whole mask.
        her_lo, her_hi: inclusive window bounds used when HER is True. The defaults
            are the previously hard-coded 120 and 632 (a 513 x 513 window). They must
            lie inside the mask, otherwise an IndexError is raised.

    Returns:
        Float tensor of zeros and ones with shape (dim1, dim2, 3).
    """
    w_sigma = torch.zeros((dim1, dim2, 3))
    num_pixels = int((dim1 * dim2) * prct)
    if num_pixels <= 0:
        return w_sigma

    if HER:
        rows = torch.randint(her_lo, her_hi + 1, (num_pixels,))
        cols = torch.randint(her_lo, her_hi + 1, (num_pixels,))
    else:
        rows = torch.randint(0, dim1, (num_pixels,))
        cols = torch.randint(0, dim2, (num_pixels,))
    # One vectorised scatter instead of ~3 * num_pixels Python-level tensor writes.
    w_sigma[rows, cols, :] = 1
    return w_sigma


# Preview a 25%-density random sampling mask restricted to the HER window.
test2 = rand_pxl(767, 767, 0.25, HER=True)

plt.figure(figsize=(15, 15))
plt.subplot(121)
plt.imshow(test2.cpu())
plt.axis('off')
plt.title('Random Pixels');

### Define Parameter Z

- The low-dimensional latent vector $z$ must have the shape the generator expects (here `[1, nz]`).

In [None]:
# z = torch.nn.Parameter(torch.randn(1,nz).cuda()) # Randomly initialize a value for 'z'

### Create a New Test Image

In [None]:
# Testing Random Image Generation ---------------------
# Sample z on `device` (chosen in the device cell) instead of calling .cuda(), so this cell
# also runs on CPU-only machines.
with torch.no_grad():
    test_image = G(torch.randn(1, nz, device=device))[0]  # ([3,256,256])
ti = color_range(test_image.cpu())
plt.subplot(121); plt.imshow(ti.transpose(1,2,0)); plt.axis('off'); plt.title("Test (Gen'd) Image");

# Testing A(x) -------------------------
# A takes a channels-last (H, W, C) CPU tensor.
test = A(test_image.cpu().permute(1,2,0))
test_sm = color_range(test.cpu())
plt.subplot(122); plt.imshow(test_sm); plt.axis('off'); plt.title('Sensor Measurements')

# Init yn: full (fullLength x fullLength x 3) measurements of the test image.
yn = test.detach(); # (767,767,3)
GPUtil.showUtilization()

In [None]:
# plt.imsave('TestImage.png', ti.transpose(1,2,0))
# Reference image in channels-last (H, W, C) layout on the CPU; recon_err_net compares against it.
original_im = test_image.permute(1, 2, 0).cpu()

### OR Load it from Memory

In [None]:
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
celeb = Image.open('D:/UCR/Research/TestImage.png').convert('RGB')   # Test Image from CelebAHQ-256 Dataset
plt.imshow(celeb); plt.axis('off');
# ToTensor turns the PIL image into a float (C, H, W) tensor scaled to [0, 1].
transform = transforms.ToTensor()
test_image = transform(celeb); print(test_image.shape); # ([3,256,256])

In [None]:
# Load the saved tensor
# Show the loaded image and simulate its sensor measurements, without building an autograd graph.
# NOTE(review): this overrides original_im and yn from the "Create a New Test Image" cell, but
# not `ti`, which the PSNR cell later compares against — confirm which reference is intended.
with torch.no_grad():
    plt.figure(figsize=(9,9))
    # Reference image in (H, W, C) layout for recon_err_net.
    original_im = test_image.cpu().permute(1,2,0)
    li = color_range(test_image.cpu()) # (3,256,256) # li = original loaded image, color corrected
    # plt.subplot(121); plt.imshow(li.transpose(1,2,0)); plt.axis('off'); plt.title("Loaded Test (Gen'd) Image");
    plt.subplot(121); plt.imshow(test_image.cpu().permute(1,2,0)); plt.axis('off'); plt.title("Loaded Test (Gen'd) Image");

    # Testing A(x) ---------------------
    load_sm = A(test_image.cpu().permute(1,2,0))
    disp_sm = color_range(load_sm.cpu())
    plt.subplot(122); plt.imshow(disp_sm); plt.axis('off'); plt.title('Loaded Sensor Measurements')

    # Init yn: full measurements of the loaded image (these become the target).
    yn = load_sm.detach(); # [767,767,3]
    GPUtil.showUtilization();

In [None]:
# Save the colour-normalised sensor measurements (disp_sm) as an image file.
plt.imsave('SM.png', disp_sm)

In [None]:
# Lets see if we can sub sample the Sensor Measurements of the Diffuser
# Sampling mask: 25% random pixel draws restricted to the HER window (rand_pxl).
# w_sigma = torch.ones((767,767,3))
w_sigma = rand_pxl(767,767,0.25,True)

# print(w_sigma.shape) # Confirm correct shape
plt.subplot(121); plt.imshow(w_sigma[:,:,0].cpu(), cmap='Accent'); plt.axis('off'); plt.title('sampling matrix');

# Create sensor_measurements_sparse
# Element-wise masking: unsampled pixels become 0. loss_net compares against y_sparse.
# print(yn.shape, w_sigma.shape)
y_sparse = w_sigma * yn;
y_s_d = color_range(y_sparse.cpu())
plt.subplot(122); plt.imshow(y_s_d); plt.axis('off'); plt.title('y_sparse');
GPUtil.showUtilization()

In [None]:
# Save the sparse measurements and channel 0 of the sampling mask.
# NOTE(review): 'SpareSM.png' looks like a typo for 'SparseSM.png'; left unchanged.
plt.imsave('SpareSM.png', y_s_d)
plt.imsave('Sampling.png', w_sigma[:,:,0].cpu(), cmap='Accent')

In [None]:
# -- Loss Function --
def loss_net(z):
    """Masked data-fidelity loss ``||w_sigma * A(G(z)) - y_sparse||_2^2``.

    Generates an image from ``z``, simulates its full sensor measurements with ``A``,
    keeps only the pixels selected by the global sampling mask ``w_sigma`` and returns
    the squared Frobenius norm of the mismatch with ``y_sparse``.

    Args:
        z: latent tensor of shape [1, nz].

    Returns:
        0-d tensor (differentiable w.r.t. ``z``).
    """
    generated = G(z)[0]  # (3, H, W); (3, 256, 256) for the celebAHQ-256 PGAN
    predicted_sm = A(generated.cpu().permute(1, 2, 0))  # (fullLength, fullLength, 3)
    residual = w_sigma * predicted_sm - y_sparse
    return torch.linalg.norm(residual) ** 2

# -- Reconstruction Error --
def recon_err_net(z):
    """Mean squared error between G(z) and the reference image ``original_im``.

    Note that this is a *mean* over all pixels and channels, not the squared L2 norm.

    Args:
        z: latent tensor of shape [1, nz].

    Returns:
        0-d tensor (differentiable w.r.t. ``z``).
    """
    generated_hwc = G(z)[0].cpu().permute(1, 2, 0)  # channels-last, like original_im
    squared_err = (generated_hwc - original_im) ** 2
    return squared_err.mean()

# -- Random Initialization --
# -- Initalize a better 'z' by minimizing loss --
# def rand_init():
#       num_samples = 10
#       sample_list = torch.zeros(num_samples,1,nz) # List of randomly initialized samples
#       loss_list = torch.zeros(num_samples) # List of loss values

#       for it in range(0,10):
#         if(it % 5 == 0):
#           GPUtil.showUtilization()
#           print(it)
#         sample = torch.randn(1,nz) # Generate a random sample
#         curr_loss = loss_net(sample) # Calculate loss w/ sample
#         sample_list[it] = sample  # Add sample to list
#         loss_list[it] = curr_loss # Min loss corresponds to values at this index
          
#       min_loss = torch.min(loss_list)
#       min_loss_idx = torch.argmin(loss_list)
#       print(min_loss, "Index " + str(min_loss_idx))
#       new_z = sample_list[min_loss_idx]  
#       return new_z
    
with torch.no_grad():
    # -- Sanity-check both objectives on random latents (no autograd graph).
    # Latents are created on `device` rather than via .cuda(), so CPU-only runs also work.
    Loss_test = loss_net(torch.randn(1, nz, device=device))
    print(Loss_test)

    re_test = recon_err_net(torch.randn(1, nz, device=device))
    print(re_test)

# Latent vector to optimise, randomly initialised on the selected device.
z = torch.nn.Parameter(torch.randn(1, nz, device=device), requires_grad=True)

# Define an Optimizer


In [None]:
import torch.optim as optim
import cv2
import os
# Optimiser for the warm-start stage (recon_init); it updates the global latent z.
# optimizer = optim.SGD([z], lr=0.001) # Stochastic Gradient Descent
optimizer = optim.Adam([z], lr=0.1) # Adam Optimizer
# optimizer = optim.RMSprop([z], lr=0.00001) # RMSprop Optimizer # Lower sampling rate
# MultiplicativeLR multiplies the current learning rate by lmda(epoch) = 0.1 on each scheduler.step().
lmda = lambda x: 0.1
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmda)
# recon_init: total warm-start iterations, split into epochs of max_iters iterations each.
total_iterations = 2000 # Use 10000 ?
max_iters = 500

In [None]:
def recon_init(z):
    """Warm-start the latent vector by minimising the reconstruction error.

    Runs ``total_iterations`` steps of the global ``optimizer`` on ``recon_err_net(z)``,
    grouped into epochs of ``max_iters`` steps. The global LR ``scheduler`` is stepped
    once at the end of each epoch. Previously the scheduler step (and a dead
    ``epoch += 1``) sat after the epoch loop, so the learning rate never decayed
    during training.

    NOTE(review): ``optimizer``/``scheduler`` are globals built over the notebook-level
    ``z``; passing a different tensor does not change what the optimizer updates.

    Args:
        z: latent ``torch.nn.Parameter`` of shape [1, nz].

    Returns:
        (z, r_hist, images):
            z: the same object that was passed in.
            r_hist: [total_iterations] tensor of the reconstruction error at each step,
                measured before that step's update.
            images: [3, 256, 256, total_iterations] tensor of G(z) after each step.
    """
    r_hist = torch.zeros(total_iterations)
    images = torch.zeros(3, 256, 256, total_iterations)
    tol = True  # early-stop flag, kept for the commented-out tolerance check
    it = 0
    num_epochs = math.ceil(total_iterations / max_iters)

    for epoch in range(1, num_epochs + 1):
        print("Starting Epoch %d" % epoch)
        epoch_end = min(epoch * max_iters, total_iterations)
        while it < epoch_end and tol:
            if it % 100 == 0:
                print("Iteration: %d" % it)
            optimizer.zero_grad()

            R = recon_err_net(z)  # error of G(z) against the reference image
            R.backward()
            optimizer.step()

            with torch.no_grad():
                images[:, :, :, it] = G(z)[0].cpu()
                r_hist[it] = R.detach()
            # if(r_hist[it] < 1000):
            #     tol = False
            #     print("Iteration: %d" % it)
            #     r_hist[it:] = R
            it += 1
            del R

        # Decay the learning rate once per epoch.
        scheduler.step()

    print('Done!')
    return z, r_hist, images

In [None]:
def loss_init(z, loss_tol=1000):
    """Refine the latent vector by minimising the masked measurement loss.

    Runs up to ``max_iterations`` steps of the global ``optimizer`` on ``loss_net(z)``.
    It stops early once the loss falls below ``loss_tol``; the remaining history
    entries are then filled with the final values so the plotted curves stay flat.

    NOTE(review): reads the globals ``optimizer`` and ``max_iterations``; the
    optimizer must have been built over this same ``z``.

    Args:
        z: latent ``torch.nn.Parameter`` of shape [1, nz].
        loss_tol: early-stopping threshold on the loss (default 1000, as before).

    Returns:
        (z, loss_history, rec_history, img_list):
            z: the same object that was passed in.
            loss_history: [max_iterations] tensor of measurement losses (detached).
            rec_history: [max_iterations] tensor of reconstruction errors against
                ``original_im``.
            img_list: [3, 256, 256, max_iterations] tensor of G(z) after each step.
    """
    loss_history = torch.zeros(max_iterations)
    rec_history = torch.zeros(max_iterations)
    img_list = torch.zeros(3, 256, 256, max_iterations)
    tol = True
    it = 0

    while it < max_iterations and tol:
        if it % 100 == 0:
            print("Iteration: %d" % it)
            # GPUtil.showUtilization()

        # Clear gradients *before* backward. recon_init leaves the gradient of its last
        # step in z.grad, and that gradient would otherwise be added to this first step.
        optimizer.zero_grad()
        L = loss_net(z)  # masked mismatch with the sparse sensor measurements
        L.backward()
        optimizer.step()

        with torch.no_grad():
            img_list[:, :, :, it] = G(z)[0].cpu()
            R = recon_err_net(z)  # Reconstruction Error
        # Store detached values so the history tensors do not join the autograd graph.
        L = L.detach()
        rec_history[it] = R
        loss_history[it] = L
        if loss_history[it] < loss_tol:
            tol = False
            print("Iteration: %d" % it)
            loss_history[it:] = L  # Fix Loss and Rec Graphs
            rec_history[it:] = R
        it += 1
        del L, R

    print('Done!')
    return z, loss_history, rec_history, img_list

# Currently 12:05 Execution Time via COLAB
# Currently 3:53/7:33 Execution Time via LOCAL MCHN

In [None]:
# Stage 1: warm-start z by minimising recon_err_net (error against the reference image original_im).
# NOTE(review): this stage optimises against the ground-truth image, not the measurements.
opt_z, r_hist, images = recon_init(z)
# epochs = 2;
# Optimized_Z = torch.zeros(epochs,1,nz)

# for e in range(0,int(epochs)):
#     Optimized_Z[e], = recon_init(z)

# min_z = torch.min(Optimized_Z,0)
# min_z_idx = torch.argmin(Optimized_Z, 0)
# Opt_z = Optimized_Z[min_z_idx]
# r_hist = R_History[min_z_idx]
# images = Image_List[min_z_idx,:,:,:,:]

#       print(min_loss, "Index " + str(min_loss_idx))
#       new_z = sample_list[min_loss_idx]  
#       return new_z

In [None]:
# Warm-start reconstruction error per iteration, on a log scale.
plt.semilogy(r_hist.detach().numpy())
plt.title("Reconstruction Error History")

In [86]:
# Save every warm-start iterate as a PNG frame (later assembled into a video/GIF).
os.makedirs('Images/GIF_Images', exist_ok=True)  # cv2.imwrite does not create directories
for i in range(images.shape[-1]):  # one frame per stored iteration (previously hard-coded 2000)
    img = images[:,:,:,i]
    disp_img = color_range(img)  # (3, 256, 256), values in [0, 1]
    frame_lo, frame_hi = disp_img.min(), disp_img.max()
    # Min-max stretch to 0-255. A constant frame would divide 0/0, so it is written as black.
    if frame_hi > frame_lo:
        disp_img = 255 * (disp_img - frame_lo) / (frame_hi - frame_lo)
    else:
        disp_img = np.zeros_like(disp_img)
    disp_img = np.round(disp_img).astype(np.uint8)  # PNG output expects 8-bit (or 16-bit) data
    bgr = cv2.cvtColor(np.ascontiguousarray(disp_img.transpose(1,2,0)), cv2.COLOR_RGB2BGR)
    if not cv2.imwrite('Images/GIF_Images/image' + "0" + str(i) + '.png', bgr):
        raise OSError('Could not write frame %d' % i)

In [None]:
# Stage 2: refine the warm-started latent against the sparse sensor measurements.
max_iterations = 2000  # Use 10000 ?
optimizer = optim.RMSprop([opt_z], lr=0.001)  # RMSprop Optimizer over the warm-started latent
final_z, loss_history, rec_history, img_list = loss_init(opt_z)

In [None]:
# print(loss_history[0:].detach().numpy())
# Left: measurement loss (log scale). Right: reconstruction error against the reference image.
plt.figure(figsize=(15, 5))
plt.subplot(121)
plt.semilogy(loss_history.detach().numpy())
plt.title("Loss History")
plt.subplot(122)
plt.plot(rec_history.detach().numpy())
plt.title("Reconstruction Error History")
print(loss_history[-1])  # Get last loss
# plt.savefig('D:/UCR/Research/Images/figs/history' + str(10) + '%' + 'test5' + '.png')


#### Observing Change over the First ~200 Iterations

- The first ~200 iterations have the greatest change, so let's examine those for now. 

In [None]:
# Save the first 200 images, because they show the greatest change.

# for i in range(0,200,5):
#     img = img_list[:,:,:,i]
#     disp_img = color_range(img); display = color_range(img);
#     disp_img = 255 * (disp_img - disp_img.min()) / (disp_img.max() - disp_img.min()); disp_img = np.array(disp_img, np.float32);
    # cv2.imwrite('Images/GIF_Images/image' + str(i) + '.png', cv2.cvtColor(disp_img.transpose(1,2,0), cv2.COLOR_RGB2BGR))
    # plt.imshow(img_list[:,:,:,i].detach().numpy().transpose(1,2,0)); plt.axis('off'); plt.title('Iteration '+ str(i))
    # plt.pause(0.1)

# Can save the img_list using first_100 = np.zeros(3,256,256,100); first_100[:,:,;,j] = img_list[:,:,:,i]; np.save(first_100.detach().numpy());

In [None]:
# Then save more intermediate images, since the change is not as great

# Show every 100th refinement iterate; the range follows the stored history length
# (previously hard-coded to 2000).
for i in range(0, img_list.shape[-1], 100):
    # Named `frame`, not `display`, so IPython's display() function is not shadowed.
    frame = color_range(img_list[:,:,:,i])  # (3, 256, 256) in [0, 1]
    plt.imshow(frame.transpose(1,2,0)); plt.axis('off'); plt.title('Iteration '+ str(i))
    plt.pause(0.1)

In [None]:
# Display Images

# for i in range(0,2000,200):
#     plt.imshow(img_list[:,:,:,i].detach().numpy().transpose(1,2,0)); plt.axis('off'); plt.title('Iteration '+ str(i))
#     plt.pause(0.1)

In [None]:
# Optimised reconstruction vs. reference image. No autograd graph is needed here.
with torch.no_grad():
    recon = G(z)[0]  # Optimized Reconstructed Image, (3, 256, 256)
recon_norm = color_range(recon.cpu()) # Reconstruction converted to Color Range
# start = img_list[:,:,:,0]; start = color_range(start); # First Iteration
start = images[:,:,:,0]; start = color_range(start); # First Iteration
# NOTE(review): `ti` is set by the "Create a New Test Image" cell. If the test image was loaded
# from disk instead, `ti` still holds the generated image (the loaded one is `li`) — confirm.
# Both images are float32 in [0, 1], so the PSNR peak value is 1.
nl = cv2.PSNR(recon_norm, ti, R=1) # Noise Level (PSNR in dB)
# psnr() assumes a 0-255 peak by default; rescale so both PSNR values use the same range.
nl2 = psnr(255 * recon_norm, 255 * ti)
print(nl, nl2)

plt.figure(figsize=(15,15))
plt.subplot(131); plt.imshow(ti.transpose(1,2,0)); plt.axis('off'); plt.title('Original')
plt.subplot(132); plt.imshow(recon_norm.transpose(1,2,0)); plt.axis('off'); plt.title('Optimized PSNR: %f' % nl);
plt.subplot(133); plt.imshow(start.transpose(1,2,0)); plt.axis('off'); plt.title('Iteration '+ str(0));

# plt.savefig('D:/UCR/Research/Images/figs/fig' + str(10) + '%' + 'test5' + '.png')
# plt.savefig('D:/UCR/Research/Images/figs/fig' + str(25) + '_col' + '.png')

# print(recon.max(), recon.min())

In [None]:
# Save the optimised reconstruction and the first warm-start iterate (both (3,H,W) -> (H,W,3)).
plt.imsave('OptImg.png', recon_norm.transpose(1,2,0))
plt.imsave('It0.png', start.transpose(1,2,0))

In [None]:
# diff = recon_norm.transpose(1,2,0) - ti.transpose(1,2,0)
# plt.imshow(diff)
# Simulate the full sensor measurements of the optimised reconstruction.
recon_A_gt = A(recon.cpu().detach().permute(1,2,0))
recon_A = color_range(recon_A_gt)
# print(recon_A.shape)
# print(disp_sm.shape)
plt.figure(figsize=(15,15))
plt.subplot(131); plt.imshow(disp_sm); plt.axis('off'); plt.title('Original SM');
plt.subplot(132); plt.imshow(recon_A); plt.axis('off'); plt.title('Optimized SM');
# Each input was normalised separately, so the difference lies in [-1, 1]; imshow clips RGB floats to [0, 1].
plt.subplot(133); plt.imshow(disp_sm - recon_A); plt.axis('off'); plt.title('Diff');

In [None]:
# Original Image
# Unregularised inverse filter applied per channel to the full (uncropped) measurements.
# A(x) is the full linear convolution, so after fftshift the deconvolved image starts at
# index fullLength // 2 (383 for fullLength = 767). `crop` cuts that window; it was
# previously defined only in commented-out code, so this cell failed on a fresh kernel.
crop_start = fullLength // 2
crop = lambda x: x[crop_start:crop_start + numScenePix, crop_start:crop_start + numScenePix]


def _inverse_filter_rgb(meas):
    """Deconvolve every channel of a (fullLength, fullLength, C) measurement tensor.

    The spectrum must stay complex. Copying it into a real tensor discards the
    imaginary part, which symmetrises the reconstruction.
    NOTE(review): no regularisation, so frequencies where |psf_fft| ~ 0 are amplified.
    """
    out = torch.zeros(numScenePix, numScenePix, meas.shape[2])
    for c in range(meas.shape[2]):
        spectrum = (torch.conj(psf_fft) * torch.fft.fft2(meas[:, :, c], s=[fullLength, fullLength])) / (torch.abs(psf_fft)**2)
        out[:, :, c] = crop(torch.fft.fftshift(torch.fft.ifft2(spectrum).real))
    return out


image_rgb = color_range(_inverse_filter_rgb(yn)) # Color Correction
plt.subplot(121); plt.imshow(image_rgb); plt.title('Original'); plt.axis('off');

# Optimized
recon_rgb = color_range(_inverse_filter_rgb(recon_A_gt)) # Color Correction
plt.subplot(122); plt.imshow(recon_rgb); plt.title('Recon'); plt.axis('off');

In [88]:
from PIL import Image
import glob
import cv2

# Assemble the saved PNG frames into an MP4, in the order the files were written (mtime).
frame_sz = (256,256)  # (width, height); must match the frames' size
frame_files = sorted(glob.glob("D:/UCR/Research/Images/GIF_Images/transforms/*.png"), key=os.path.getmtime)
# 'mp4v' is accepted by the MP4 container. With 'DIVX', OpenCV's FFmpeg backend may warn
# that the tag is unsupported for MP4 and fall back.
out = cv2.VideoWriter('transforms.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 60, frame_sz)
if not out.isOpened():
    raise RuntimeError("Could not open transforms.mp4 for writing")
try:
    for filename in frame_files:
        print(filename)
        img = cv2.imread(filename)
        if img is None:  # unreadable file: skip it rather than passing None to write()
            print("Skipping unreadable frame:", filename)
            continue
        out.write(img)
finally:
    # Always finalise the container, even if writing a frame fails.
    out.release()
# fp_in = "D:/UCR/Research/Images/GIF_Images/transforms/image*.png"
# fp_out = "D:/UCR/Research/transforms.gif"

# imgs = [Image.open(f) for f in glob.glob(fp_in)]
# img = imgs[0]
# img.save(fp=fp_out, format='GIF', append_images=imgs, duration=500, save_all=True, loop=1)

D:/UCR/Research/Images/GIF_Images/transforms\image00.png
D:/UCR/Research/Images/GIF_Images/transforms\image01.png
D:/UCR/Research/Images/GIF_Images/transforms\image02.png
D:/UCR/Research/Images/GIF_Images/transforms\image03.png
D:/UCR/Research/Images/GIF_Images/transforms\image04.png
D:/UCR/Research/Images/GIF_Images/transforms\image05.png
D:/UCR/Research/Images/GIF_Images/transforms\image06.png
D:/UCR/Research/Images/GIF_Images/transforms\image07.png
D:/UCR/Research/Images/GIF_Images/transforms\image08.png
D:/UCR/Research/Images/GIF_Images/transforms\image09.png
D:/UCR/Research/Images/GIF_Images/transforms\image010.png
D:/UCR/Research/Images/GIF_Images/transforms\image011.png
D:/UCR/Research/Images/GIF_Images/transforms\image012.png
D:/UCR/Research/Images/GIF_Images/transforms\image013.png
D:/UCR/Research/Images/GIF_Images/transforms\image014.png
D:/UCR/Research/Images/GIF_Images/transforms\image015.png
D:/UCR/Research/Images/GIF_Images/transforms\image016.png
D:/UCR/Research/Images/G

In [None]:
# np.set_printoptions(threshold=np.inf)
# sm = disp_sm
# print(sm.shape)
# plt.imshow(sm[121:632,121:632,:])

# rows, cols = np.nonzero(sm[:,:,0] > 0.0)
# row2, col2 = np.nonzero(sm[:,:,1] > 0.0)
# row3, col3 = np.nonzero(sm[:,:,2] > 0.0)
# nrg = np.zeros((767,767,3))

# # with open('Energy.txt', 'a') as fp:
# #     fp.write(str(sm))

# for r, c in zip(rows, cols):
#     nrg[r][c][0] = 1
# for r, c in zip(row2, col2):
#     nrg[r][c][1] = 1
# for r, c in zip(row3, col3):
#     nrg[r][c][2] = 1
# plt.imshow(nrg[121:632,121:632,:]); plt.colorbar();

11 - 4 - 21
#### TODO
- ~~PSNR Statistics? - Possible built in functions for this~~ \\
- ~~Save the first 100 Images (since they show the most dramatic change)~~ \\
- Make forward model more realistic \\
  - Different types of sampling patterns \\
  - Reconstruction Side, **try different initializations for (z). Solve this problem by drawing different values for (z)** 
  - Does reconstruction quality degrade as the number of measurements decreases?


Reconstruction error compares the generated image with the ground-truth image; the measurement loss is what finds an image that matches the observations.


#### Next Steps 

How many measurements are needed to recover a single image?
See what is out there in terms of video generative models, or some other ways to reconstruct video. 



Consider not just a line sensor but a distributed sensor, and check whether video reconstruction works for it.

PULSE
Demarcus 
ILO

