### Imports

In [None]:
import os
import sys
import random
import json

In [None]:
import time

In [None]:
import numpy as np

In [None]:
import torch
from torch.utils.data import Dataset

In [None]:
from torch import nn

In [None]:
from torch.nn import functional

In [None]:
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

In [None]:
from tqdm import tqdm

In [None]:
from datetime import datetime

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

plt.style.use('dark_background')
%matplotlib inline
# %config InlineBackend.figure_format = 'retina'

#### `svetlanna`

In [None]:
from svetlanna.units import ureg

In [None]:
from svetlanna import SimulationParameters
from svetlanna.parameters import ConstrainedParameter

In [None]:
from svetlanna import Wavefront
from svetlanna.transforms import ToWavefront

In [None]:
from svetlanna.elements import FreeSpace, Aperture, RectangularAperture, DiffractiveLayer
from svetlanna.setup import LinearOpticalSetup
from svetlanna.detector import Detector, DetectorProcessorClf

In [None]:
from svetlanna.visualization import show_stepwise_forward

In [None]:
from svetlanna.clerk import Clerk

#### `src`

In [None]:
# dataset of wavefronts
from src.wf_datasets import DatasetOfWavefronts

In [None]:
# training and evaluation loops
from src.clf_loops import onn_train_clf, onn_validate_clf

# Optical Neural Network

In this example notebook we will try to implement a simple architecture of an optical neural network from the article [[1]](https://www.science.org/doi/10.1126/science.aat8084).

### <span style="color:red"> Select the folder with results to load (TODO) </span>

In [None]:
# collect the names of all experiment sub-folders stored in the results directory

DIR_RESULTS = 'results'

filepathes = [
    os.fsdecode(entry)
    for entry in os.listdir(DIR_RESULTS)
    if os.path.isdir(os.path.join(DIR_RESULTS, os.fsdecode(entry)))  # keep directories only
]

# one folder name per line, in alphabetical order
print(*sorted(filepathes), sep='\n')

In [None]:
SELECTED_EXP = 'exp_20-06-2025_14-48'  # TODO: select experiment folder from the list above!

In [None]:
RESULTS_FOLDER = f'{DIR_RESULTS}/{SELECTED_EXP}'  

In [None]:
# read the experiment conditions saved next to the selected results;
# the encoding is given explicitly so the read does not depend on the OS locale
with open(f'{RESULTS_FOLDER}/conditions.json', encoding='utf-8') as json_file:
    LOADED_VARIABLES = json.load(json_file)

In [None]:
LOADED_VARIABLES

# 1. Simulation parameters


First of all, we need to specify the simulation parameters for our task: they include the wavelength $\lambda$ and a numerical mesh (in our case it corresponds to a neuron size).

**<span style="color:red">Sources to use:</span>**
[[1]](https://www.science.org/doi/10.1126/science.aat8084) and its [Supplementary Material](https://www.science.org/doi/suppl/10.1126/science.aat8084/suppl_file/aat8084-lin-sm-rev-3.pdf), [[2]](https://ieeexplore.ieee.org/abstract/document/8732486) 

In [None]:
c_const = 299_792_458  # [m / s]

In [None]:
working_wavelength = LOADED_VARIABLES['wavelength']  # [m]

In [None]:
# neuron size (square)
neuron_size = LOADED_VARIABLES['neuron_size']  # [m]
NEURON_SIZE = neuron_size

In [None]:
print('Specified parameters:')
# both values are kept in [m]; the 1e6 factor converts them to micrometres for display
print(f'lambda = {working_wavelength * 1e6:.3f} um')
print(f'neuron size = {neuron_size * 1e6:.3f} um')

In [None]:
# an actual zone where weights will be updated during a training process
ALL_SIZE = LOADED_VARIABLES['mesh_size']  # for example (100, 100) neurons
USE_APERTURES = LOADED_VARIABLES['use_apertures']

In [None]:
# with apertures the working zone is the saved aperture size,
# otherwise the whole mesh is used
DETECTOR_SIZE = LOADED_VARIABLES['aperture_size'] if USE_APERTURES else ALL_SIZE

In [None]:
# number of neurons in simulation
x_layer_nodes = ALL_SIZE[1]
y_layer_nodes = ALL_SIZE[0]

In [None]:
# calculate physical size of each layer in [m]
x_layer_size_m = x_layer_nodes * neuron_size  # [m]
y_layer_size_m = y_layer_nodes * neuron_size

In [None]:
print(f'Layer size (in neurons): {x_layer_nodes} x {y_layer_nodes} = {x_layer_nodes * y_layer_nodes}')
print(f'Layer size (in cm): {x_layer_size_m * 1e2} x {y_layer_size_m * 1e2}')

In [None]:
# simulation parameters shared by every element in the rest of the notebook

simulation_axes = {
    # spatial axes are centred at zero and span the physical layer size
    'W': torch.linspace(-x_layer_size_m / 2, x_layer_size_m / 2, x_layer_nodes),
    'H': torch.linspace(-y_layer_size_m / 2, y_layer_size_m / 2, y_layer_nodes),
    # a single wavelength: monochromatic simulation
    'wavelength': working_wavelength,
}

# `SimulationParameters` is a custom object from the `svetlanna` library
SIM_PARAMS = SimulationParameters(axes=simulation_axes)

# 2. Dataset preparation

## 2.1. [MNIST Dataset](https://www.kaggle.com/datasets/hojjatk/mnist-dataset)

Here we load a dataset of images, but we need to transform them into Wavefronts in order to use them for DNN training!

In [None]:
# initialize a directory for a dataset
MNIST_DATA_FOLDER = './data'  # folder to store data

In [None]:
# TRAIN (images)
mnist_train_ds = torchvision.datasets.MNIST(
    root=MNIST_DATA_FOLDER,
    train=True,  # for train dataset
    download=True,
)
# TEST (images)
mnist_test_ds = torchvision.datasets.MNIST(
    root=MNIST_DATA_FOLDER,
    train=False,  # for test dataset
    download=True,
)

## 2.2. Create Train and Test datasets of wavefronts

From [[2]](https://ieeexplore.ieee.org/abstract/document/8732486):

> Input objects were encoded in amplitude channel (MNIST) of the input plane and were illuminated with a uniform plane wave at a wavelength of $\lambda$ to match the conditions introduced in [[1]](https://www.science.org/doi/10.1126/science.aat8084) for all-optical classification.

So, we need to do an amplitude modulation of each image from the dataset!

**<span style="color:red">Comment:</span>**
We will see later what "amplitude modulation" means!

In [None]:
# select modulation type
MODULATION_TYPE = 'amp'  # using ONLY amplitude to encode each picture in a Wavefront!

### 2.2.1. Transformations of images to Wavefronts

In [None]:
# target image shape for transforms.Resize: half of the detector zone along each axis
resize_y, resize_x = int(DETECTOR_SIZE[0] / 2), int(DETECTOR_SIZE[1] / 2)

# leading paddings (top along OY, left along OX) for transforms.Pad
pad_top = int((y_layer_nodes - resize_y) / 2)
pad_left = int((x_layer_nodes - resize_x) / 2)

# trailing paddings take whatever remains so the padded image fills the full mesh
pad_bottom = y_layer_nodes - (pad_top + resize_y)
pad_right = x_layer_nodes - (pad_left + resize_x)

In [None]:
# full pipeline: image -> tensor -> resized -> zero-padded to the mesh size -> Wavefront
image_transform_for_ds = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(resize_y, resize_x), interpolation=InterpolationMode.NEAREST),
    # padding order is (left, top, right, bottom); zeros are added to match sizes
    transforms.Pad(padding=(pad_left, pad_top, pad_right, pad_bottom), fill=0),
    # the modulation type is selected by MODULATION_TYPE
    ToWavefront(modulation_type=MODULATION_TYPE),
])

### 2.2.2. Create Dataset objects for train and test datasets

In [None]:
# TRAIN dataset of WAVEFRONTS
mnist_wf_train_ds = DatasetOfWavefronts(
    init_ds=mnist_train_ds,  # dataset of images
    transformations=image_transform_for_ds,  # image transformation
    sim_params=SIM_PARAMS,  # simulation parameters
)
# TEST dataset of WAVEFRONTS
mnist_wf_test_ds = DatasetOfWavefronts(
    init_ds=mnist_test_ds,  # dataset of images
    transformations=image_transform_for_ds,  # image transformation
    sim_params=SIM_PARAMS,  # simulation parameters
)

# 3. Optical network

In [None]:
NUM_OF_DIFF_LAYERS = LOADED_VARIABLES['num_diff_layers']  # number of diffractive layers
FREE_SPACE_DISTANCE = LOADED_VARIABLES['free_space_distance']  # [m] - distance between diffractive layers

In [None]:
print(f'Distance between layers is {FREE_SPACE_DISTANCE * 1e2:.3f} cm')

## 3.1. Architecture

### 3.1.1. Elements


In [None]:
MAX_PHASE = 2 * np.pi

In [None]:
FREESPACE_METHOD = 'AS'  # we use an angular spectrum method

In [None]:
INIT_PHASES = torch.ones(NUM_OF_DIFF_LAYERS) * np.pi

#### Functions that return single elements for further architecture

In [None]:
# TAKE A LOOK! CODE HERE IS READY
def get_const_phase_layer(
    sim_params: SimulationParameters,
    value: float,
    max_phase=2 * torch.pi
):
    """
    Builds a DiffractiveLayer whose phase mask is filled with one constant value.

    Parameters
    ----------
    sim_params : SimulationParameters
        Simulation parameters; the sizes of its 'W' and 'H' axes set the mask shape.
    value : float
        Phase value written to every node of the mask.
    max_phase : float
        Upper bound of the constrained mask parameter (the lower bound is 0).

    Returns
    -------
    DiffractiveLayer
        Layer whose mask is a ConstrainedParameter, i.e. the phases are learnable.
    """
    nodes_along_w, nodes_along_h = sim_params.axes_size(axs=('W', 'H'))

    # rows follow the 'H' axis and columns follow the 'W' axis
    flat_phase = torch.ones(nodes_along_h, nodes_along_w) * value

    # the constrained parameter keeps the phases within [0, max_phase]
    bounded_phase = ConstrainedParameter(
        flat_phase,
        min_value=0,
        max_value=max_phase
    )

    return DiffractiveLayer(simulation_parameters=sim_params, mask=bounded_phase)

### <span style="color:red">3.1.2. List of Elements (!TODO!)</span>

Function to construct a list of elements to reproduce an architecture from [the extended article](https://ieeexplore.ieee.org/abstract/document/8732486):

In [None]:
# TODO: copy your function!!! or add it to src/ folder as a script and import the function!

In [None]:
def get_elements_list(
    num_layers,
    simulation_parameters,
    freespace_method,
    phase_values,
    apertures=False,
    aperture_size=(100, 100)
):
    """
    Placeholder for the function that builds the list of elements of the optical setup.

    NOTE: the body is intentionally left as an exercise. `elements_list` is never
    assigned here, so calling the function as-is raises NameError until the
    implementation is copied in.

    Parameters (as they are passed by the callers in this notebook)
    ----------
    num_layers
        Receives NUM_OF_DIFF_LAYERS (number of diffractive layers).
    simulation_parameters
        Receives the SimulationParameters object.
    freespace_method
        Receives FREESPACE_METHOD (e.g. 'AS').
    phase_values
        Receives INIT_PHASES - presumably one initial phase per diffractive layer;
        TODO confirm against your implementation.
    apertures
        Receives the flag that tells whether apertures are used.
    aperture_size
        Receives DETECTOR_SIZE - presumably in neurons; TODO confirm.

    Returns
    -------
    The list of elements; the callers pass it to LinearOpticalSetup(elements=...).
    """
    # TODO: Copy your function!

    return elements_list

In [None]:
architecture_elements_list = get_elements_list(
    num_layers=NUM_OF_DIFF_LAYERS,
    simulation_parameters=SIM_PARAMS,
    freespace_method=FREESPACE_METHOD,
    phase_values=INIT_PHASES,
    apertures=USE_APERTURES,
    aperture_size=DETECTOR_SIZE
)

In [None]:
print(f'Number of elements in the system (including Detector): {len(architecture_elements_list)}')

### 3.1.3. Compose `LinearOpticalSetup`

In [None]:
def get_setup(simulation_parameters, apertures=False):
    """
    Returns an optical setup. Recreates all elements.

    Parameters
    ----------
    simulation_parameters
        Simulation parameters forwarded to `get_elements_list`.
    apertures : bool
        Forwarded to `get_elements_list` as its `apertures` flag.

    Returns
    -------
    LinearOpticalSetup
        Setup composed of a freshly created list of elements.

    Notes
    -----
    The number of layers, the free-space method, the initial phases and the
    aperture size are read from the notebook-level constants NUM_OF_DIFF_LAYERS,
    FREESPACE_METHOD, INIT_PHASES and DETECTOR_SIZE.
    """
    elements_list = get_elements_list(
        num_layers=NUM_OF_DIFF_LAYERS,
        # use the argument instead of the global SIM_PARAMS so that the
        # function honours whatever the caller passes
        simulation_parameters=simulation_parameters,
        freespace_method=FREESPACE_METHOD,
        phase_values=INIT_PHASES,
        apertures=apertures,
        aperture_size=DETECTOR_SIZE
    )  # recreate a list of elements

    return LinearOpticalSetup(elements=elements_list)

## 3.2. Detector processor

In [None]:
NUMBER_OF_CLASSES = 10  # TODO: how many classes do we have?

### 3.2.1. Detector mask

In [None]:
detector_segment_size = LOADED_VARIABLES['detector_segment_size']  # in neurons (int)
detector_segment_size_m = detector_segment_size * NEURON_SIZE  # in [m]

In [None]:
ZONES_ORDER = LOADED_VARIABLES['segments_order']  # TODO: specify the order

In [None]:
# the detector mask is stored in the selected experiment folder (RESULTS_FOLDER),
# next to conditions.json
DETECTOR_MASK_LOADED = torch.load(f'{RESULTS_FOLDER}/detector_mask.pt')

### 3.2.2. Detector processor

In [None]:
# create a DetectorProcessorClf object that uses the loaded detector mask
DETECTOR_PROCESSOR = DetectorProcessorClf(
    simulation_parameters=SIM_PARAMS,
    num_classes=NUMBER_OF_CLASSES,
    segmented_detector=DETECTOR_MASK_LOADED,
)

#### To visualize detector zones (for further use)

In [None]:
ZONES_HIGHLIGHT_COLOR = 'w'
ZONES_LW = 0.5
selected_detector_mask = DETECTOR_PROCESSOR.segmented_detector.clone().detach()

In [None]:
def get_zones_patches(detector_mask):
    """
    Returns a list of patches to draw zones in final visualisation.

    For every class index in range(NUMBER_OF_CLASSES) the bounding box of the
    mask pixels equal to that index is turned into an unfilled rectangle
    (line width ZONES_LW, edge colour ZONES_HIGHLIGHT_COLOR).

    Parameters
    ----------
    detector_mask : torch.Tensor
        2D mask (rows, columns) whose values are class indices.

    Returns
    -------
    list
        One `patches.Rectangle` per class that is present in the mask.
        Classes without any pixel in the mask are skipped instead of raising.
    """
    # rectangles are widened by half a pixel on each side so that the border
    # encloses the pixels (assumes the default imshow pixel-centred coordinates)
    delta = 0.5

    zones_patches = []

    for ind_class in range(NUMBER_OF_CLASSES):
        idx_y, idx_x = (detector_mask == ind_class).nonzero(as_tuple=True)

        if idx_y.numel() == 0:
            # this class has no zone in the mask: nothing to draw
            continue

        # explicit min/max: do not rely on the ordering of the nonzero indices
        x_min, x_max = idx_x.min().item(), idx_x.max().item()
        y_min, y_max = idx_y.min().item(), idx_y.max().item()

        zone_rect = patches.Rectangle(
            (x_min - delta, y_min - delta),
            x_max - x_min + 2 * delta, y_max - y_min + 2 * delta,
            linewidth=ZONES_LW,
            edgecolor=ZONES_HIGHLIGHT_COLOR,
            facecolor='none'
        )

        zones_patches.append(zone_rect)

    return zones_patches

# 4. Necessary stuff

In [None]:
DEVICE = 'cpu'

In [None]:
train_bs = LOADED_VARIABLES['train_batch_size']  # a batch size for training set
val_bs = LOADED_VARIABLES['val_batch_size']

test_bs = 10

#### Train/Validation split

In [None]:
train_val_split_seed = LOADED_VARIABLES['train_val_seed']

In [None]:
# mnist_wf_train_ds
train_wf_ds, val_wf_ds = torch.utils.data.random_split(
    dataset=mnist_wf_train_ds,
    lengths=[55000, 5000],  # sizes from the article
    generator=torch.Generator().manual_seed(train_val_split_seed)  # for reproducibility
)

#### Loaders

In [None]:
train_wf_loader = torch.utils.data.DataLoader(
    train_wf_ds,
    batch_size=train_bs,
    shuffle=True,
    # num_workers=2,
    drop_last=False,
)

val_wf_loader = torch.utils.data.DataLoader(
    val_wf_ds,
    batch_size=val_bs,
    shuffle=False,
    # num_workers=2,
    drop_last=False,
)

test_wf_loader = torch.utils.data.DataLoader(
    mnist_wf_test_ds,
    batch_size=test_bs,
    shuffle=False,
    # num_workers=2,
    drop_last=False,
)  # data loader for a test MNIST data

#### Loss

In [None]:
loss_func_clf = nn.CrossEntropyLoss()
loss_func_name = 'CE loss'

# 5. Load model weights and estimate performance

## 5.1. Loading of saved results

### 5.1.1. Learning curves

In [None]:
losses_data = np.genfromtxt(
    f'{RESULTS_FOLDER}/training_curves.csv',
    delimiter=','
)

In [None]:
NUM_EPOCHS = LOADED_VARIABLES['number_of_epochs']
(train_epochs_losses, val_epochs_losses, train_epochs_acc, val_epochs_acc) = losses_data[1:, :].T

In [None]:
# learning curves: loss on the left panel, accuracy on the right one
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

epochs_axis = range(1, NUM_EPOCHS + 1)
panels = (
    (axs[0], loss_func_name, np.array(train_epochs_losses), np.array(val_epochs_losses)),
    (axs[1], 'Accuracy', train_epochs_acc, val_epochs_acc),
)

for panel_ax, y_label, train_curve, val_curve in panels:
    panel_ax.plot(epochs_axis, train_curve, label='train')
    panel_ax.plot(epochs_axis, val_curve, linestyle='dashed', label='validation')

    panel_ax.set_ylabel(y_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.legend()

plt.show()

### 5.1.2. Weights of a model

In [None]:
# init setup to load weights
optical_setup_loaded = get_setup(SIM_PARAMS, LOADED_VARIABLES['use_apertures'])
# LOAD WEIGHTS for the model from the selected experiment folder;
# map_location lets weights saved on another device be loaded on DEVICE
optical_setup_loaded.net.load_state_dict(
    torch.load(f'{RESULTS_FOLDER}/optical_net.pth', map_location=DEVICE)
)

### 5.1.3. Trained phase masks visualization

In [None]:
n_cols = NUM_OF_DIFF_LAYERS  # number of columns for DiffractiveLayer's masks visualization
n_rows = 1

# plot trained phase masks
# squeeze=False always gives a 2D array of axes, so the indexing below also
# works for a single diffractive layer (otherwise a bare Axes is returned)
fig, axs = plt.subplots(
    n_rows, n_cols,
    figsize=(n_cols * 3, n_rows * 3.2),
    squeeze=False,
)
ind_diff_layer = 0

cmap = 'gist_stern'  # 'gist_stern' 'rainbow'

for ind_layer, layer in enumerate(optical_setup_loaded.net):
    if not isinstance(layer, DiffractiveLayer):
        continue  # masks are plotted for Diffractive layers only

    ax_this = axs[ind_diff_layer // n_cols][ind_diff_layer % n_cols]

    ax_this.set_title(f'{ind_diff_layer + 1}. DiffractiveLayer')

    # detach from the autograd graph before plotting
    trained_mask = layer.mask.detach()

    ax_this.imshow(
        trained_mask, cmap=cmap,
        vmin=0, vmax=MAX_PHASE
    )
    ind_diff_layer += 1

    # select only a part within apertures!
    # x_frame = (x_layer_nodes - DETECTOR_SIZE[1]) / 2
    # y_frame = (y_layer_nodes - DETECTOR_SIZE[0]) / 2
    # ax_this.set_xlim([x_frame, x_layer_nodes - x_frame])
    # ax_this.set_ylim([y_frame, y_layer_nodes - y_frame])

plt.show()

## 5.2. Calculate metrics on test set for the loaded model

Checking if the loaded model works correctly!

In [None]:
# evaluate the loaded model on the TEST set
test_losses_1, _, test_accuracy_1 = onn_validate_clf(
    optical_setup_loaded.net,  # optical network with loaded weights
    test_wf_loader,  # dataloader of the test set
    DETECTOR_PROCESSOR,  # detector processor created from the loaded detector mask
    loss_func_clf,
    device=DEVICE,
    show_process=True,
)  # evaluate the model

print(
    'Results after training on TEST set:\n' + 
    f'\t{loss_func_name} : {np.mean(test_losses_1):.6f}\n' +
    f'\tAccuracy : {(test_accuracy_1 * 100):>0.1f} %'
)

## 5.3. Example of classification (propagation through the setup)

### 5.3.1. Select a sample to propagate

In [None]:
# pick a test sample and show its intensity and phase
# indices of samples that are known to look good:
# '1' - 3214, good
# '4' - 6152, good
# '6' - 123, good
# '8' - 128, good
# '0' - 3, good
ind_test = 123
cmap = 'hot'

test_wavefront, test_target = mnist_wf_test_ds[ind_test]

fig, axs = plt.subplots(1, 2, figsize=(2 * 3, 3))
intensity_ax, phase_ax = axs[0], axs[1]

intensity_ax.set_title(f'intensity (id={ind_test})')
intensity_ax.imshow(test_wavefront.intensity, cmap=cmap)

phase_ax.set_title('phase')
# the colour scale is fixed to the [0, 2*pi] range
phase_ax.imshow(test_wavefront.phase, cmap=cmap, vmin=0, vmax=2 * torch.pi)

plt.show()

In [None]:
# propagation of the example through the trained network
setup_scheme, test_wavefronts = optical_setup_loaded.stepwise_forward(test_wavefront)

### 5.3.2. Detector picture (enlarged)

In [None]:
# create a figure with subplots
fig, ax_this = plt.subplots(1, 1, figsize=(3, 3.2))

# Detector output (not a wavefront!)
ax_this.set_title('Detector Intensity')
ax_this.imshow(
    test_wavefronts[-1].detach().numpy(), cmap='hot',
    # vmin=0, vmax=1  # uncomment to make the same limits
)

for zone in get_zones_patches(DETECTOR_MASK_LOADED):
    # add zone's patches to the axis
    ax_this.add_patch(zone)

# select only a part within apertures! uncomment if needed
# x_frame = (x_layer_nodes - DETECTOR_SIZE[1]) / 2
# y_frame = (y_layer_nodes - DETECTOR_SIZE[0]) / 2

# plt.axis([x_frame, x_layer_nodes - x_frame, y_layer_nodes - y_frame, y_frame])

plt.show()

In [None]:
# probabilities of each class for the propagated example
# Comment: forward() method is from DetectorProcessorClf
#          p_i = I(detector_i) / sum_j(I(detector_j))
# Comment: It's another output than for batch_forward, that was used during training!
detector_picture = test_wavefronts[-1]
test_probas = DETECTOR_PROCESSOR.forward(detector_picture)

# sanity check: the probabilities must sum up to one
total_probability = test_probas.sum().item()
assert np.isclose(total_probability, 1)

for label, prob in enumerate(test_probas[0]):
    print(f'{label} : {prob * 100:.2f}%')