In [2]:
#Standard imports
import os
import math
import torch
import random
import numpy as np
from PIL import Image
import torch.nn as nn
from tqdm import tqdm
from IPython import embed
from typing import Optional
import matplotlib.pyplot as plt

#PL imports
import torchmetrics
import pytorch_lightning as pl
from torchvision import transforms
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, random_split, RandomSampler

#For PL warnings
import warnings
warnings.filterwarnings("ignore")


In [5]:
class Propagation_Layer(torch.nn.Module):
    '''
    A class for implementing the angular spectrum method (ASM) of wavefront propagation 
    for neural networks. Attributes used during training are registered as buffers, per 
    the PyTorch Lightning documentation, so they move with the module across devices.

    ... 

    Attributes
    ----------
    params : dict
        The 'propagator' sub-dictionary of the configuration passed to __init__;
        supplies 'wavelength' and the 'pluto' DOE/SLM parameters.
    
    batch_size : int
        Batch size used during training (params['lightning']['batch_size']).
        Stored but not used by this class.

    distance : float, tensor
        Default propagation distance between layers. Plain attribute (not a
        buffer), so it is not moved by .to(device).

    wavelength : tensor
        Wavelength of illumination light.

    Nx : tensor
        Number of DOE/SLM 'pixel elements' in the x direction.
    
    Ny : tensor
        Number of DOE/SLM 'pixel elements' in the y direction.

    extent_x : tensor
        Length of DOE/SLM in the x direction.

    extent_y : tensor
        Length of DOE/SLM in the y direction. 
   
    pixel_pitch : tensor
        Pitch (distance between) of DOE/SLM pixels. Stored but not used by this class.

    kz : complex tensor (Nx, Ny)
        Axial wavenumber sqrt(k^2 - kx^2 - ky^2) on the centred frequency grid.
        Complex so that evanescent components (kx^2 + ky^2 > k^2) get an
        imaginary kz and decay, instead of producing NaN.


    Methods
    -------
    init_layer()
        Initializes the propagation plane and propagation kernel for the ASM.

    forward(wavefront, distance=None)
        Performs the ASM and returns the resulting wavefront.
    '''
    def __init__(self, params, distance):
        super().__init__()
        self.distance = distance
        self.params = params['propagator']
        self.batch_size = params['lightning']['batch_size']
        self.register_buffer('wavelength', torch.tensor(self.params['wavelength']))
        #DOE/SLM Parameters
        self.doe_params = self.params['pluto']
        self.register_buffer('Nx', torch.tensor(self.doe_params['Nx']))
        self.register_buffer('Ny', torch.tensor(self.doe_params['Ny']))
        self.register_buffer('extent_x', torch.tensor(self.doe_params['x_extent']))
        self.register_buffer('extent_y', torch.tensor(self.doe_params['y_extent']))
        self.register_buffer('pixel_pitch', torch.tensor(self.doe_params['pitch']))
        #Initialize the propagation layer. 
        self.init_layer()

    def init_layer(self):
        '''Build the spatial grid, the spatial-frequency grid and the ASM kernel kz.'''
        nx, ny = int(self.Nx), int(self.Ny)
        half_x = float(self.extent_x) / 2
        half_y = float(self.extent_y) / 2

        # Spatial sample coordinates across the DOE/SLM aperture.
        self.register_buffer('x', torch.linspace(-half_x, half_x, nx))
        self.register_buffer('y', torch.linspace(-half_y, half_y, ny))
        xx, yy = torch.meshgrid(self.x, self.y)
        self.register_buffer('xx', xx.clone())
        self.register_buffer('yy', yy.clone())

        # Spatial-frequency coordinates, ordered to match an fftshift-ed spectrum.
        # NOTE(review): linspace over +/- pi*floor(N/2)/(extent/2) with N points does
        # not reproduce the FFT bin spacing 2*pi/(N*dx); consider
        # 2*pi*fftshift(fftfreq(N, d=dx)) -- confirm against the intended grid.
        kx_max = math.pi * (nx // 2) / half_x
        ky_max = math.pi * (ny // 2) / half_y
        self.register_buffer('kx', torch.linspace(-kx_max, kx_max, nx))
        self.register_buffer('ky', torch.linspace(-ky_max, ky_max, ny))
        kxx, kyy = torch.meshgrid(self.kx, self.ky)
        self.register_buffer('kxx', kxx.clone())
        self.register_buffer('kyy', kyy.clone())

        # kz = sqrt(k^2 - kx^2 - ky^2). The square root is taken in the complex
        # domain: for evanescent frequencies the argument is negative, and a real
        # sqrt would yield NaN, which poisons the whole field after the inverse FFT.
        # A complex sqrt gives kz = +i*|.|, so exp(1j*kz*d) decays with distance.
        k = 2 * math.pi / float(self.wavelength)
        kz_sq = k ** 2 - self.kxx ** 2 - self.kyy ** 2
        self.register_buffer('kz', torch.sqrt(torch.complex(kz_sq, torch.zeros_like(kz_sq))))

    def forward(self, wavefront, distance = None):
        '''
        Propagate a wavefront by the angular spectrum method.

        Parameters
        ----------
        wavefront : float, tensor (Batch, Channel(2), Nx, Ny)
            Input wavefront to the layer. Channel 0 is the amplitude and channel 1
            the phase. The spatial dims must broadcast against kz (Nx, Ny).

        distance : float, tensor, optional
            Propagation distance for this call only. When None, self.distance is
            used. Passing a value does not change self.distance.

        Returns
        -------
        tensor (Batch, 2, Nx, Ny)
            Amplitude (channel 0) and wrapped phase from torch.angle (channel 1)
            of the propagated field.
        '''
        # Per-call override only; the layer's default distance is left untouched.
        dist = self.distance if distance is None else torch.as_tensor(distance)

        # Complex field E = A * exp(i * phi).
        field = wavefront[:, 0, :, :] * torch.exp(1j * wavefront[:, 1, :, :])

        # Angular spectrum, centred to line up with the kx/ky grid behind kz.
        # Only the two spatial dims are shifted; the batch dim is left alone.
        spectrum = torch.fft.fftshift(torch.fft.fft2(field), dim=(-2, -1))

        # Multiply by the free-space transfer function exp(i * kz * d).
        transfer = torch.exp(1j * self.kz * dist).to(spectrum.device)
        propagated = torch.fft.ifft2(torch.fft.ifftshift(spectrum * transfer, dim=(-2, -1)))

        return torch.stack((torch.abs(propagated), torch.angle(propagated)), dim=1)
 

In [None]:
class Phase_Optimization(pl.LightningModule):
    '''
    Lightning module holding a stack of diffractive layers (trainable phase,
    frozen unit amplitude), each paired with a Propagation_Layer.

    Configuration keys read here: params['propagator']['pluto']['Nx'/'Ny'],
    params['propagator']['distance'], params['propagator']['wavelength'],
    params['model']['num_layers'] and params['lightning']['batch_size'].
    Propagation_Layer reads further keys from the same dict.
    '''
    def __init__(self, params):
        # Must run first: nn.Module's parameter/submodule registries have to exist
        # before register_parameter or a ModuleList assignment.
        super().__init__()
        self.Nx = torch.tensor(params['propagator']['pluto']['Nx'])
        self.Ny = torch.tensor(params['propagator']['pluto']['Ny'])
        self.distance = torch.tensor(params['propagator']['distance'])
        self.wavelength = torch.tensor(params['propagator']['wavelength'])
        # Used by train_dataloader; same key Propagation_Layer reads.
        self.batch_size = params['lightning']['batch_size']

        #Layers
        self.init_diffractive_layers(params)
        self.init_propagation_layers(params) 

    def init_diffractive_layers(self, params):
        '''Create and register per-layer phase (trainable) and amplitude (frozen) parameters.'''
        num_layers = params['model']['num_layers']
        # numpy needs plain ints for shapes, not 0-dim tensors.
        shape = (1, 1, int(self.Nx), int(self.Ny))

        # Phases uniformly random in [0, 2*pi); float64 because they come from numpy.
        self.initial_phases = [
            torch.nn.Parameter(torch.from_numpy(2 * np.pi * np.random.rand(*shape)))
            for _ in range(num_layers)
        ]
        self.initial_amplitudes = [
            torch.nn.Parameter(torch.from_numpy(np.ones(shape)), requires_grad=False)
            for _ in range(num_layers)
        ]

        # Explicit names keep the state_dict keys as phase_i / amplitude_i.
        for i, (phase, amplitude) in enumerate(zip(self.initial_phases, self.initial_amplitudes)):
            self.register_parameter(f"phase_{i}", phase)
            self.register_parameter(f"amplitude_{i}", amplitude)

    def init_propagation_layers(self, params):
        '''Create one Propagation_Layer per diffractive layer.'''
        num_layers = params['model']['num_layers']
        # ModuleList registers the layers as submodules so their buffers (kz, ...)
        # follow .to(device) and appear in the state_dict.
        self.prop_layers = torch.nn.ModuleList(
            Propagation_Layer(params, self.distance) for _ in range(num_layers)
        )

    def configure_optimizers(self):
        # NOTE(review): no learning rate is set in this class; self.lr is used if
        # assigned elsewhere, otherwise Adam's default of 1e-3.
        optimizer = torch.optim.Adam(self.parameters(), lr=getattr(self, 'lr', 1e-3))
        return optimizer

    def criterion(self, x, y):
        '''Mean-squared-error loss between prediction x and target y.'''
        return torch.nn.functional.mse_loss(x, y)
    
    #======================================
    # Dataset Things
    #======================================
    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            # NOTE(review): `utils` is not imported in the visible cells; it
            # presumably refers to a project module -- confirm it is in scope.
            self.dataset = utils.custom_cgh_dataset()
    
    def train_dataloader(self):
        # Serves the dataset built in setup().
        return DataLoader(self.dataset, batch_size=self.batch_size)
    
    