# Optical Neural Network

The code below is a PyTorch framework for simulating an optical neural network and training its parameters with gradient descent.

Optical processing steps:

1. Input data → DMD (digital micromirror device) light pattern; how the data is encoded is up to you.
2. First propagation: light diffracts through free space.
3. Phase modulation: multiply the field by exp(i*φ), where φ is learnable.
4. Second propagation: the modulated light travels to the detector.
5. Intensity detection: the detector measures |E|², so the phase information is lost.
6. Detector readout: the intensities are binned into a grid.
7. Linear processing: a conventional neural-network layer maps the detector bins to the quantity we are trying to predict. For classification, give this layer one output per class and take the arg-max of its outputs.
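As a rough illustration of steps 5 to 7, and of where the arg-max goes, the sketch below uses a random complex array as a stand-in for the field reaching the detector. The 64×64 camera, the 8×8 grid and the 10 classes are arbitrary choices for the example, not values taken from the framework.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the complex field E2 at the detector: batch of 4, 64x64 camera pixels.
E2 = torch.randn(4, 64, 64, dtype=torch.complex64)

# Step 5: the detector records intensity |E|^2; the phase of E2 is discarded.
intensity = E2.abs() ** 2                        # (4, 64, 64), real and >= 0

# Step 6: average camera pixels into an 8x8 grid of detector bins.
bins = F.adaptive_avg_pool2d(intensity, (8, 8))  # (4, 8, 8)
features = bins.flatten(start_dim=1)             # (4, 64)

# Step 7: linear readout with one output (logit) per class, here 10 classes.
readout = nn.Linear(8 * 8, 10)
logits = readout(features)                       # (4, 10)

# For classification the arg-max is applied to the logits; it does not replace the layer.
predicted_class = logits.argmax(dim=1)           # (4,), integers in 0..9
```

During training one would normally pass `logits` to a loss such as `F.cross_entropy` and use the arg-max only at prediction time, since the arg-max itself provides no useful gradient.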

The key component is the phase-modulation layer, which shifts the phase of the light at each pixel of the SLM (spatial light modulator). These per-pixel phase shifts, each between 0 and 2π, are the weights we optimize, by gradient descent or other methods.
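A quick numerical check of why phase alone can do useful work: the sketch below, using a random field and a random mask, shows that the SLM leaves the local intensity unchanged, that phases differing by 2π are equivalent, and that the mask changes the detected intensity only after the light propagates. A plain FFT (far field) is an assumed stand-in for the free-space propagation used in the model.

```python
import math
import torch

torch.manual_seed(0)
E1 = torch.randn(32, 32, dtype=torch.complex64)  # field arriving at the SLM
phi = 2 * math.pi * torch.rand(32, 32)            # one phase value per SLM pixel

M = E1 * torch.exp(1j * phi)                      # phase-only modulation

# 1) Right after the SLM the intensity is unchanged: |E1 * exp(i*phi)| = |E1|.
same_intensity = torch.allclose(M.abs(), E1.abs(), atol=1e-5)

# 2) exp(i*phi) is 2*pi-periodic, so [0, 2*pi) already covers every possible setting.
M_shifted = E1 * torch.exp(1j * (phi + 2 * math.pi))
periodic = torch.allclose(M, M_shifted, atol=1e-4)

# 3) The mask only shows up in the intensity after propagation. A plain FFT
#    (far field) is used here as a simple stand-in for free-space propagation.
far_before = torch.fft.fft2(E1).abs() ** 2
far_after = torch.fft.fft2(M).abs() ** 2
changed = not torch.allclose(far_before, far_after)
```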

In other words, the SLM pixel array is the phase mask: `SLM_pixels_array = phase_mask = [φ_1, φ_2, ..., φ_N]`, with every φ_i in [0, 2π).
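When a trained mask is sent to real hardware, the unconstrained learned values usually have to be wrapped into [0, 2π) and converted to the SLM's drive levels. The helper below, `phase_to_gray_levels`, is a hypothetical sketch that assumes an 8-bit SLM whose phase response is linear over 0 to 2π; real devices typically need a measured lookup table instead.

```python
import math
import torch


def phase_to_gray_levels(phi, levels=256):
    """Hypothetical export step (assumes a linear 0..2*pi phase response):
    wrap a learned phase mask into [0, 2*pi) and quantize it to SLM drive values."""
    wrapped = torch.remainder(phi, 2 * math.pi)            # now in [0, 2*pi)
    gray = torch.round(wrapped / (2 * math.pi) * levels)   # 0..levels
    return (gray % levels).to(torch.int64)                 # `levels` wraps to 0, since 2*pi == 0


# Negative and > 2*pi values are wrapped before quantization.
phi = torch.tensor([[0.0, math.pi], [-math.pi / 2, 2 * math.pi + 0.1]])
gray = phase_to_gray_levels(phi)  # [[0, 128], [192, 4]]
```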

Gradient descent, Adam, or any other optimizer searches for the SLM phase shifts that best predict the target from the detector readout. This is much like computer-generated holography, where the same SLM pixels are given a phase mask that produces the best possible hologram at a chosen plane (instead of at the detector).
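To make the hologram analogy concrete, the sketch below optimizes an SLM phase mask with Adam so that the far-field intensity matches a target with two bright spots. It assumes uniform illumination and a lens-focal-plane (Fraunhofer) geometry, so a single FFT replaces the angular-spectrum propagation used in the network; the grid size, learning rate and step count are illustrative.

```python
import math
import torch

torch.manual_seed(0)
N = 32  # SLM / hologram grid size (illustrative)

# Target intensity in the hologram plane: two equally bright spots.
target = torch.zeros(N, N)
target[8, 8] = 1.0
target[20, 24] = 1.0
target = target / target.sum()

# Uniform illumination on the SLM; only the phase mask phi is learnable.
illumination = torch.ones(N, N, dtype=torch.complex64)
phi = torch.nn.Parameter(2 * math.pi * torch.rand(N, N))
optimizer = torch.optim.Adam([phi], lr=0.1)


def far_field_intensity(phase):
    # Assumed geometry: hologram plane = focal plane of a lens, so a centred
    # 2D FFT of the modulated field gives the hologram-plane field.
    field = torch.fft.fftshift(torch.fft.fft2(illumination * torch.exp(1j * phase)))
    intensity = field.abs() ** 2
    return intensity / intensity.sum()  # normalise so only the pattern shape is compared


losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = torch.sum((far_field_intensity(phi) - target) ** 2)
    loss.backward()   # gradients flow through exp, FFT and |.|^2 back to the real phases
    optimizer.step()
    losses.append(loss.item())
```

Replacing the FFT with a propagate, modulate, propagate chain and the intensity target with a task loss on the binned detector readout gives, in essence, the training problem solved by the network below.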

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class OpticalNeuralNetwork(nn.Module):
    """
    Optical Neural Network that simulates light propagation and phase modulation.
    
    The network performs computation through simulated optical physics:
    1. Input → SLM (Spatial Light Modulator)
    2. Light propagates through free space (diffracts)  
    3. Programmable phase modulation (learnable)
    4. Light propagates again to detector
    5. Intensity measurement → final linear layer
    """
    
    def __init__(self, *, slm_size, border_size, z_distance,
                 wavelength, dx, dy, grid_shape=(8,8), output_dim=1):
        super().__init__()
        
        # Device setup
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        # Physical parameters
        self.slm_size, self.border = slm_size, border_size  # SLM dimensions and border
        self.pad = border_size                              # Padding for propagation
        self.z   = z_distance                               # Propagation distance (meters)
        self.wl  = wavelength                               # Light wavelength (meters)
        self.dx, self.dy = dx, dy                           # Pixel pitch (meters)
        
        # Calculate effective size including padding
        eff_x = slm_size[0] + 2*self.pad
        eff_y = slm_size[1] + 2*self.pad
        
        # Wave number: k = 2π/λ
        k = 2*np.pi/self.wl
        
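        # Angular spatial-frequency grids (radians per meter): fftfreq returns cycles
        # per meter and the factor 2π converts them to k_x, k_y. Each (k_x, k_y) pair
        # is one plane wave of the angular spectrum, i.e. one propagation direction.
        # float64 keeps the phase k_z*z accurate; it is large (k*z is about 1e5 rad
        # for z = 1 cm at a 500 nm wavelength).
        kx = 2*np.pi*torch.fft.fftshift(torch.fft.fftfreq(eff_x, d=self.dx, dtype=torch.float64)).to(self.device)
        ky = 2*np.pi*torch.fft.fftshift(torch.fft.fftfreq(eff_y, d=self.dy, dtype=torch.float64)).to(self.device)
        
        # Create 2D frequency meshgrid
        KX, KY = torch.meshgrid(kx, ky, indexing='ij')
        
        # z-component of the wave vector: k_z = √(k² - k_x² - k_y²)
        # The complex square root makes k_z imaginary for evanescent waves
        # (k_x² + k_y² > k²), so exp(i*k_z*z) damps them; clamping k_z to 0 would
        # instead let them pass through unattenuated.
        KZ = torch.sqrt((k**2 - KX**2 - KY**2).to(torch.complex128))
        
        # Angular Spectrum transfer function: H(k_x,k_y) = exp(i * k_z * z)
        # No paraxial approximation is made, unlike the Fresnel method.
        # Cast to complex64 so it matches the single-precision fields in forward().
        self.register_buffer("propagation_kernel", torch.exp(1j*KZ*self.z).to(torch.complex64))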
        
        # LEARNABLE PARAMETERS:
        # 1. Phase modulation - this is the main "weight" of the optical network
        #    One phase shift per SLM pixel; values are unconstrained, but exp(i*φ) is
        #    2π-periodic, so φ and φ + 2π give the same modulation
        self.modulation = nn.Parameter(torch.zeros(self.slm_size, device=self.device))
        
        # 2. Final calibration layer - traditional neural network readout
        self.grid_rows, self.grid_cols = grid_shape
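        # Created on self.device so it matches the pooled detector readings in forward()
        self.calib = nn.Linear(self.grid_rows*self.grid_cols, output_dim, device=self.device)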

    def _propagate(self, field):
        """
        Simulates light propagation through free space using Angular Spectrum Method.
        
        Angular Spectrum Method:
        1. Decomposes the light field into plane waves (angular spectrum)
        2. Each plane wave propagates with transfer function H(kx,ky) = exp(i*kz*z)
        3. where kz = √(k² - kx² - ky²) is the exact z-component of wave vector
        
        This is more accurate than Fresnel approximation, especially for:
        - Short propagation distances
        - Large numerical apertures  
        - Wide-angle diffraction
        
        Steps:
        1. Add padding to avoid edge effects
        2. FFT to frequency domain (decompose into plane waves)
        3. Multiply by Angular Spectrum transfer function
        4. IFFT back to spatial domain
        5. Remove padding
        """
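        # Zero-pad so that light reaching the window edge is not wrapped around by the
        # FFT's periodic boundary; outside the SLM aperture there is no light, so zeros
        # (rather than reflected copies of the field) are the physical choice
        if self.pad > 0:
            field = F.pad(field, (self.pad,)*4)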
        
        # Transform to frequency domain - this gives us the angular spectrum
        # Each frequency component represents a plane wave traveling at angle θ
        # where sin(θ) = kx/k (for x-direction)
        freq = torch.fft.fftshift(torch.fft.fft2(field, dim=(-2,-1)), dim=(-2,-1))
        
        # Apply Angular Spectrum propagation - each plane wave accumulates phase k_z*z
        # Steeper waves (higher spatial frequencies) have a smaller k_z, so they slip
        # out of phase with the on-axis wave as z grows; that relative phase is diffraction
        out = freq * self.propagation_kernel
        
        # Transform back to spatial domain
        efld = torch.fft.ifft2(torch.fft.ifftshift(out, dim=(-2,-1)), dim=(-2,-1))
        
        # Remove padding to get back to original SLM size
        if self.pad > 0:
            x0, y0 = self.pad, self.pad
            x1, y1 = x0 + self.slm_size[0], y0 + self.slm_size[1]
            efld = efld[..., x0:x1, y0:y1]
        
        return efld

    def forward(self, x):
        """
        Forward pass through the optical neural network.
        
        Args:
            x: Input tensor that gets mapped to the SLM (by external function)
        
        Returns:
            Processed output after optical computation and linear readout
        
        Optical Processing Steps:
        1. Input data → DMD input light pattern (the encoding is a design choice)
        2. First propagation: light diffracts through free space
        3. Phase modulation: multiply by exp(i*φ) where φ is learnable
        4. Second propagation: modulated light travels to detector
        5. Intensity detection: |E|² (phase information is lost)
        6. Detector readout: bin intensities into grid
        7. Linear processing: a conventional neural-network layer maps the detector
           bins to the prediction target (for classification, take the arg-max of its outputs)
        """
        x = x.to(self.device)
        
        # Step 1: Map input to SLM (assumes external map_to_slm function)
        # This creates the initial complex-valued light field
        slm = map_to_slm(x, self.slm_size, self.border)
        
        # Step 2: First propagation - input light diffracts
        # Light spreads out and creates interference patterns
        E1 = self._propagate(slm)
        
        # Step 3: PHASE MODULATION - The key optical computation!
        # Multiply by exp(i*φ) where φ is the learned phase pattern
        # This is like a programmable hologram or spatial light modulator
        # 
        # Physical meaning: E_out = E_in * exp(i*φ)
        # - Magnitude |E| unchanged
        # - Phase arg(E) shifted by φ at each pixel
        # - Creates focusing, defocusing, beam steering, etc.
        M = E1 * torch.exp(1j*self.modulation)
        
        # Step 4: Second propagation - modulated light travels to detector
        # The phase modulation creates new interference patterns
        E2 = self._propagate(M)
        
        # Step 5: Intensity detection - photodetectors measure |E|²
        # Phase information is lost here (as in real optical systems)
        intensity = E2.abs()**2
        
        # Step 6: Detector array - bin intensity into discrete measurements
        # This simulates a CCD camera or photodiode array
        pooled = F.adaptive_avg_pool2d(intensity, (self.grid_rows, self.grid_cols))
        pooled = pooled.reshape(x.size(0), -1)
        
        # Step 7: Final linear processing - traditional neural network
        # Maps detector readings to desired outputs
        return self.calib(pooled)
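
`forward` calls a `map_to_slm(x, slm_size, border)` function that is not defined in this notebook. Below is one possible, hypothetical version. It assumes `x` is a batch of non-negative feature vectors of shape `(batch, n_features)` and that `border` is a dark margin around the active area, and it returns a complex field of shape `(batch, *slm_size)`, which is what `_propagate` expects.

```python
import math
import torch
import torch.nn.functional as F


def map_to_slm(x, slm_size, border):
    """Hypothetical encoder (not part of the original framework): lay each feature
    vector out as an amplitude pattern in the SLM's active area, with a dark
    margin of `border` pixels on every side."""
    batch, n_features = x.shape
    rows, cols = slm_size
    active_r, active_c = rows - 2 * border, cols - 2 * border
    if active_r <= 0 or active_c <= 0:
        raise ValueError("border leaves no active SLM area")
    # Place the features on the smallest square grid that holds them (zero-filled).
    side = math.ceil(math.sqrt(n_features))
    tiles = torch.zeros(batch, side * side, dtype=x.dtype, device=x.device)
    tiles[:, :n_features] = x
    tiles = tiles.reshape(batch, 1, side, side)
    # Nearest-neighbour upsampling: each feature drives a block of identical pixels.
    pattern = F.interpolate(tiles, size=(active_r, active_c), mode="nearest").squeeze(1)
    # Amplitudes must be non-negative: a DMD or amplitude SLM cannot emit negative light.
    pattern = pattern.clamp(min=0.0)
    # Embed the active pattern in a dark frame and return it as a complex field.
    field = F.pad(pattern, (border, border, border, border))
    return field.to(torch.complex64)
```

With this definition in scope, a model such as `OpticalNeuralNetwork(slm_size=(64, 64), border_size=16, z_distance=0.05, wavelength=532e-9, dx=8e-6, dy=8e-6, output_dim=10)` can be called on a `(batch, n_features)` tensor; these parameter values are illustrative only.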