In [1]:
import numpy as np
import torch
from torchvision.transforms import v2

## Light Source

In [None]:
class resource:
    """Light source description: frequency, phase and amplitude.

    The three values are stored exactly as given; no unit conversion or
    validation is performed.
    """

    def __init__(self, freq, phase, amplitude):
        # Keep the source parameters verbatim on the instance.
        self.freq, self.phase, self.amplitude = freq, phase, amplitude




## Distances

In [None]:
class distance:
    """Distances describing the holographic set-up.

    Attributes:
        z1: source-to-sample distance (may be an array).
        z2: sample-to-sensor distance (may be an array).
        Dz: spacing between measurement heights, used by the multi-height
            phase-retrieval approach.
    """

    def __init__(self, z1, z2, Dz):
        for attr_name, attr_value in (("z1", z1), ("z2", z2), ("Dz", Dz)):
            setattr(self, attr_name, attr_value)

## Free Space

In [None]:
class free_space:
    """Angular-spectrum transfer function for propagation in a homogeneous medium.

    Parameters
    ----------
    n : refractive index of the medium (used as a multiplier in the exponent).
    freq : source frequency (see the NOTE in ``transfer`` on how it is used).
    numX, numY : number of samples along axis 0 and axis 1 of the data that
        the transfer function will multiply; ``transfer`` returns an array of
        shape ``(numX, numY)``.
    z : propagation distance.
    flag : sign of the propagation direction (this notebook uses +1 and -1).
    """

    def __init__(self, n, freq, numX, numY, z, flag):
        self.n = n
        self.freq = freq     # source frequency
        self.numX = numX     # samples along axis 0 of the data (fx direction)
        self.numY = numY     # samples along axis 1 of the data (fy direction)
        self.z = z           # propagation distance
        self.flag = flag     # +1 / -1: propagation direction

    def transfer(self):
        """Return the centred transfer function ``H`` with shape ``(numX, numY)``.

        Entries with ``fx**2 + fy**2 > 1/Lambda**2`` (evanescent components)
        are set to 0 instead of producing NaN from the square root of a
        negative number. A NaN would otherwise spread through the whole
        inverse FFT.
        """
        # NOTE(review): 2*pi*freq is an angular frequency, not a wavelength
        # (a wavelength would normally be c/freq). Kept as-is; confirm units.
        Lambda = 2 * np.pi * self.freq
        # NOTE(review): the grid spans +-1/Lambda rather than being derived
        # from the sensor pixel pitch (e.g. np.fft.fftfreq); confirm intent.
        fx = np.linspace(-1 / Lambda, 1 / Lambda, self.numX)
        fy = np.linspace(-1 / Lambda, 1 / Lambda, self.numY)
        # 'ij' indexing yields shape (numX, numY), matching callers that do
        # `numX, numY = data.shape`; the default 'xy' would be transposed.
        grid_x, grid_y = np.meshgrid(fx, fy, indexing="ij")
        fz_sq = 1 / Lambda**2 - grid_x**2 - grid_y**2
        propagating = fz_sq >= 0
        fz = np.sqrt(np.where(propagating, fz_sq, 0.0))
        phase = 2 * np.pi * self.n * self.flag * self.z * fz
        return np.where(propagating, np.exp(1j * phase), 0.0)


## Back Propagation

In [None]:
def Back_Propagation(z, freq, hologram, return_complex=False):
    """Propagate a 2-D field by ``z`` with the angular-spectrum method (flag=+1).

    Parameters
    ----------
    z : propagation distance.
    freq : source frequency, forwarded to ``free_space``.
    hologram : 2-D array (real or complex) to propagate.
    return_complex : if False (default, original behaviour) return the
        amplitude ``|field|``; if True return the complex propagated field so
        its phase can be used (e.g. for phase retrieval).

    Returns
    -------
    ndarray: the amplitude (float) or the complex propagated field.
    """
    numX, numY = hologram.shape
    H = free_space(1, freq, numX, numY, z, 1).transfer()
    # fftshift centres the spectrum so it lines up with the centred H.
    spectrum = np.fft.fftshift(np.fft.fft2(hologram))
    field = np.fft.ifft2(np.fft.ifftshift(H * spectrum))
    return field if return_complex else np.abs(field)


## Forward Propagation

In [None]:
def Forward_Propagation(z, freq, hologram, return_complex=False):
    """Propagate a 2-D field by ``z`` with the angular-spectrum method (flag=-1).

    Parameters
    ----------
    z : propagation distance.
    freq : source frequency, forwarded to ``free_space``.
    hologram : 2-D array (real or complex) to propagate.
    return_complex : if False (default, original behaviour) return the
        amplitude ``|field|``; if True return the complex propagated field so
        its phase can be used (e.g. for phase retrieval).

    Returns
    -------
    ndarray: the amplitude (float) or the complex propagated field.
    """
    numX, numY = hologram.shape
    H = free_space(1, freq, numX, numY, z, -1).transfer()
    # fftshift centres the spectrum so it lines up with the centred H.
    spectrum = np.fft.fftshift(np.fft.fft2(hologram))
    field = np.fft.ifft2(np.fft.ifftshift(H * spectrum))
    return field if return_complex else np.abs(field)

In [None]:
def amplitude_update(w1, w2, hologram, frame):
    """Weighted sum of a measured hologram and a computed (propagated) one.

    Returns ``w1 * hologram + w2 * frame``. ``w1`` weights the measured
    hologram and ``w2`` weights the amplitude obtained by propagation. It is
    plain arithmetic, so NumPy arrays are combined element-wise.
    """
    measured_part = w1 * hologram
    computed_part = w2 * frame
    return measured_part + computed_part


In [None]:
def _propagate_field(z, freq, field, flag):
    """Return the complex field after angular-spectrum propagation by ``z``.

    This is the same computation as Forward_Propagation/Back_Propagation but
    without ``np.abs``, so the phase that ``iterative`` needs is preserved.
    ``flag`` is -1 for the Forward_Propagation direction and +1 for the
    Back_Propagation direction.
    """
    numX, numY = field.shape
    H = free_space(1, freq, numX, numY, z, flag).transfer()
    return np.fft.ifft2(np.fft.ifftshift(H * np.fft.fftshift(np.fft.fft2(field))))


def iterative(level, z, freq, hologram, w1, w2, E, treshold):
    '''Multi-height iterative phase retrieval.

    level : number of measured heights; hologram[0] .. hologram[level-1]
            are used, with consecutive heights ``z`` apart.
    z : distance between consecutive heights.
    freq : source frequency, or an iterable of frequencies.
    hologram : sequence/array of measured holograms, one per height.
    w1 : weight of the measured hologram in the amplitude update.
    w2 : weight of the propagated amplitude in the amplitude update.
    E : error threshold; iteration stops once the largest per-pixel phase
        change between two iterations is below E.
    treshold : maximum number of iterations.

    Returns the retrieved phase at height 0 (the scalar 0 if no iteration
    ran, i.e. treshold <= 0).

    Raises ValueError if level < 2 or no frequency is given.
    '''
    if level < 2:
        raise ValueError("iterative phase retrieval needs at least two heights (level >= 2)")
    freqs = np.atleast_1d(freq)
    if freqs.size == 0:
        raise ValueError("at least one frequency is required")

    frame = np.asarray(hologram[0], dtype=complex)
    phase = 0
    error = np.inf
    remaining = treshold
    while error >= E and remaining > 0:
        # NOTE(review): with several frequencies the frame is propagated by z
        # once per frequency against the same measured plane; confirm this is
        # the intended multi-wavelength scheme.
        # Forward sweep: height 0 -> level-1, imposing each measured amplitude.
        for i in range(level - 1):
            for f in freqs:
                F = _propagate_field(z, f, frame, -1)
                frame = amplitude_update(w1, w2, hologram[i + 1], np.abs(F)) * np.exp(1j * np.angle(F))
        # Backward sweep: height level-1 -> 0.
        for i in range(level - 1, 0, -1):
            for f in freqs:
                B = _propagate_field(z, f, frame, 1)
                frame = amplitude_update(w1, w2, hologram[i - 1], np.abs(B)) * np.exp(1j * np.angle(B))

        new_phase = np.angle(B)
        # Wrapped per-pixel phase change, reduced to a scalar so the loop
        # condition is well defined (an array would be ambiguous).
        error = np.max(np.abs(np.angle(np.exp(1j * (new_phase - phase)))))
        phase = new_phase
        remaining -= 1

    return phase


## Super Resolution

In [None]:
# Degradation pipelines for multi-frame super-resolution: each one shifts,
# blurs, down-samples, rotates and flips the high-resolution image `img`
# into a low-resolution sample. One pipeline is built per rotation angle.
seq = 9    # number of digitally shifted LR samples (not used in this cell yet)
step = 2   # maximum shift, in pixels
scale = 5  # down-sampling factor

# NOTE(review): `img` is not defined anywhere earlier in this notebook. It
# must hold the high-resolution image (last two dims = height, width)
# before this cell runs, otherwise a fresh-kernel run fails here.
h, w = img.shape[-2:]

angle = [0, 30, 60, 90, 120, 180]
lr_transforms = {}
for t in angle:
    lr_transforms[t] = v2.Compose([
        # `translate` is a fraction of the image width/height, not pixels;
        # passing `step` directly is rejected (must lie in [0, 1]).
        v2.RandomAffine(degrees=0, translate=(step / w, step / h)),
        v2.GaussianBlur(5, sigma=np.sqrt(2)),
        # Resize needs integer output sizes.
        v2.Resize((h // scale, w // scale)),
        # NOTE(review): RandomRotation(t) draws a random angle in [-t, t];
        # use RandomRotation((t, t)) if a fixed rotation by t was intended.
        v2.RandomRotation(t),
        v2.RandomHorizontalFlip(p=0.5),
        v2.RandomVerticalFlip(p=0.5),
    ])

# Kept for compatibility: the original loop left only the last pipeline
# bound to `transforms`.
transforms = lr_transforms[angle[-1]]
