### Imports

In [None]:
import os
import sys
import random

In [None]:
import time

In [None]:
import numpy as np

In [None]:
import torch
from torch.utils.data import Dataset

In [None]:
from torch import nn

In [None]:
from torch.nn import functional

In [None]:
import torchvision
import torchvision.transforms as transforms

In [None]:
from torchvision.transforms import InterpolationMode

In [None]:
# our library
from svetlanna import SimulationParameters
from svetlanna.parameters import ConstrainedParameter

In [None]:
# our library
from svetlanna import Wavefront
from svetlanna import elements
from svetlanna.setup import LinearOpticalSetup
from svetlanna.detector import Detector, DetectorProcessorClf

In [None]:
from svetlanna.transforms import ToWavefront

In [None]:
# datasets of wavefronts
from src.wf_datasets import DatasetOfWavefronts
from src.wf_datasets import WavefrontsDatasetSimple

In [None]:
from tqdm import tqdm

In [None]:
from datetime import datetime

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# dark matplotlib theme for every figure in this notebook
plt.style.use('dark_background')
# IPython magic: render figures inline in the notebook output
%matplotlib inline

# 1. Simulation Parameters

## 1.1 Defining simulation parameters

In [None]:
c_const = 299_792_458  # speed of light [m / s]
working_frequency = 0.4 * 1e12  # [Hz]
working_wavelength = c_const / working_frequency  # [m]

# size of a (square) neuron [m]
neuron_size = 0.53 * working_wavelength

# detector size in neurons: (rows / H, columns / W)
DETECTOR_SIZE = (1024, 1024)

# number of neurons per layer along each axis: same size as the detector
y_layer_nodes, x_layer_nodes = DETECTOR_SIZE

# physical size of each layer [m]
x_layer_size_m = x_layer_nodes * neuron_size
y_layer_size_m = y_layer_nodes * neuron_size

In [None]:
# human-readable summary of the simulation grid
print('lambda = {:.3f} um'.format(working_wavelength * 1e6))
print('neuron size = {:.3f} um'.format(neuron_size * 1e6))
print('Layer size (in neurons): {} x {} = {}'.format(
    x_layer_nodes, y_layer_nodes, x_layer_nodes * y_layer_nodes
))
print('Layer size (in cm): {} x {}'.format(x_layer_size_m * 1e2, y_layer_size_m * 1e2))

## 1.2 Creation of the grid (i.e. numerical mesh)

In [None]:
# simulation parameters for the rest of the notebook
SIM_PARAMS = SimulationParameters(
    axes={
        'W': torch.linspace(-x_layer_size_m / 2, x_layer_size_m / 2, x_layer_nodes),
        'H': torch.linspace(-y_layer_size_m / 2, y_layer_size_m / 2, y_layer_nodes),
        'wavelength': working_wavelength,  # only one wavelength!
    }
)

# 2. Dataset preparation (Data Engineer)

## 2.1. [MNIST Dataset](https://www.kaggle.com/datasets/hojjatk/mnist-dataset): loading and conversion to wavefronts

In [None]:
# initialize a directory for a dataset
MNIST_DATA_FOLDER = './data'  # folder (relative to the notebook) where MNIST is stored; datasets below use download=False

### 2.1.1. Load Train and Test datasets of images

In [None]:
# TRAIN (images): MNIST train split read from MNIST_DATA_FOLDER (no download)
mnist_train_ds = torchvision.datasets.MNIST(MNIST_DATA_FOLDER, train=True, download=False)

print(f'Train data: {len(mnist_train_ds)}')

In [None]:
# TEST (images): MNIST test split read from MNIST_DATA_FOLDER (no download)
mnist_test_ds = torchvision.datasets.MNIST(MNIST_DATA_FOLDER, train=False, download=False)

print(f'Test data : {len(mnist_test_ds)}')

### 2.1.2. Detector

 `DetectorProcessor` in our library is used to process the information collected on the detector. For example, for the current task `DetectorProcessor` must return only 10 values (one value per class).

 

 Let's define the detector mask, which splits the detector into zones, one per class:

In [None]:
import src.detector_segmentation as detector_segmentation

In [None]:
number_of_classes = 10

detector_segment_size = 22 * working_wavelength  # [m]

# side of a (square) detector segment in neurons
x_segment_nodes = y_segment_nodes = int(detector_segment_size / neuron_size)

# detector zone (an aperture in the middle of the detector): 9 segment widths per side
y_boundary_nodes, x_boundary_nodes = 9 * y_segment_nodes, 9 * x_segment_nodes

# Mask used to generate a target image for each digit: zeros everywhere
# except the zone responsible for the label
DETECTOR_MASK = detector_segmentation.squares_mnist(
    y_boundary_nodes,
    x_boundary_nodes,
    SIM_PARAMS,
)
# Target image: zeros are everywhere except the necessary zone responsible for the label!

### 2.1.3. Conversion images to wavefronts

In [None]:
# select modulation type
MODULATION_TYPE = 'amp'  # using ONLY amplitude to encode each picture in a Wavefront!

# shape of a digit after transforms.Resize: a third of the detector along each axis
resize_y, resize_x = int(DETECTOR_SIZE[0] / 3), int(DETECTOR_SIZE[1] / 3)

# zero padding that centres the resized digit inside the layer
# (when the remainder is odd, the extra row/column goes to the bottom/right)
pad_top = int((y_layer_nodes - resize_y) / 2)
pad_left = int((x_layer_nodes - resize_x) / 2)
pad_bottom = y_layer_nodes - resize_y - pad_top
pad_right = x_layer_nodes - resize_x - pad_left

# image -> tensor -> resized digit -> padded to the layer size -> Wavefront
image_transform_for_ds = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(resize_y, resize_x), interpolation=InterpolationMode.NEAREST),
    # Pad expects (left, top, right, bottom)
    transforms.Pad(padding=(pad_left, pad_top, pad_right, pad_bottom), fill=0),
    ToWavefront(modulation_type=MODULATION_TYPE),  # <- select modulation type!!!
])

**<span style="color:red">Comment.</span>** Here `dataset[i]` returns a pair: a `Wavefront` in which a digit is encoded, and a target detector image (because `target='detector'`). The target image is generated from `DETECTOR_MASK` (defined in 2.1.2): it is zero everywhere except the zone responsible for the label (a digit from 0 to 9). During training we use the MSE loss between the detector output and this target image.

In [None]:
# TRAIN and TEST datasets of WAVEFRONTS: both splits share the same image
# transformation, simulation parameters and detector-image targets
mnist_wf_train_ds, mnist_wf_test_ds = (
    DatasetOfWavefronts(
        init_ds=images_ds,  # dataset of images
        transformations=image_transform_for_ds,  # image transformation
        sim_params=SIM_PARAMS,  # simulation parameters
        target='detector',
        detector_mask=DETECTOR_MASK,
    )
    for images_ds in (mnist_train_ds, mnist_test_ds)
)

# 3. Optical network

Let's create an optical neural network consisting of 512 diffractive layers, of which only 5 are trained: 2 at the beginning, 2 at the end and 1 in the middle.

> The distance between neighbouring layers is set to $40\lambda$.

## 3.1 Neural network parameters

In [None]:
NUM_OF_DIFF_LAYERS = 5  # number of diffractive layers that will be trained

# [m] - distance between diffractive layers
FREE_SPACE_DISTANCE = 40 * working_wavelength
print(f'Distance between layers is {FREE_SPACE_DISTANCE * 1e2:.3f} cm')

# phase-mask range upper bound and initial value
MAX_PHASE, INIT_PHASE = 2 * torch.pi, torch.pi

FREESPACE_METHOD = 'AS'  # we use another method in contrast to [2]!!!

## 3.2 Architecture

**<span style="color:red">Comment:</span>**
Here we use the default `ConstrainedParameter`, which uses a sigmoid function to restrict the parameter range.

**<span style="color:red">Comment:</span>** Setup ends with `Detector` that returns an output tensor of intensities for each input `Wavefront`.

In [None]:
def set_setup(
    total_number_of_layers: int,
    number_of_layers_at_the_beginning: int,
):
    """
    Build the optical setup: 5 trainable diffractive layers interleaved with
    `total_number_of_layers` untrained (constant) diffractive layers.

    Layout (every diffractive layer is followed by the same free-space propagation):
        FreeSpace
        2 x trainable layer
        `number_of_layers_at_the_beginning` x untrained layer
        1 x trainable layer (middle)
        (`total_number_of_layers` - `number_of_layers_at_the_beginning`) x untrained layer
        2 x trainable layer
        Detector (intensity)

    Each trainable position gets its OWN DiffractiveLayer with its own
    ConstrainedParameter, so the 5 trained masks are independent. Reusing one
    instance would tie all 5 positions to a single shared mask.

    Parameters
    ----------
        total_number_of_layers : int
            Total number of untrained diffractive layers in the setup.
        number_of_layers_at_the_beginning : int
            How many of the untrained layers are placed before the middle trainable layer.

    Returns
    -------
        LinearOpticalSetup
            Setup built from the module-level SIM_PARAMS, FREE_SPACE_DISTANCE,
            FREESPACE_METHOD, MAX_PHASE and INIT_PHASE (read at call time).
    """
    free_space = elements.FreeSpace(
        simulation_parameters=SIM_PARAMS,
        distance=FREE_SPACE_DISTANCE,
        method=FREESPACE_METHOD
    )

    x_nodes, y_nodes = SIM_PARAMS.axes_size(axs=('W', 'H'))
    const_mask = torch.ones(size=(y_nodes, x_nodes)) * INIT_PHASE

    def make_trainable_layer():
        # fresh layer + cloned initial mask per call, so no storage is shared
        # between trainable positions (or with the untrained layer)
        return elements.DiffractiveLayer(
            simulation_parameters=SIM_PARAMS,
            mask=ConstrainedParameter(
                const_mask.clone(),
                min_value=0,
                max_value=MAX_PHASE
            ),
        )

    # Constant layers are all identical, so one shared instance is enough
    # (and avoids allocating a full mask for each of the many untrained layers).
    untrained_diffractive_layer = elements.DiffractiveLayer(
        simulation_parameters=SIM_PARAMS,
        mask=const_mask.clone(),  # plain tensor: NOT a ConstrainedParameter
    )

    elements_list = [free_space]

    def append_layer(layer):
        # every diffractive layer is followed by free-space propagation
        elements_list.extend([layer, free_space])

    for _ in range(2):
        append_layer(make_trainable_layer())

    for _ in range(number_of_layers_at_the_beginning):
        append_layer(untrained_diffractive_layer)

    append_layer(make_trainable_layer())

    for _ in range(total_number_of_layers - number_of_layers_at_the_beginning):
        append_layer(untrained_diffractive_layer)

    for _ in range(2):
        append_layer(make_trainable_layer())

    # add Detector in the end of the system!
    elements_list.append(
        Detector(
            simulation_parameters=SIM_PARAMS,
            func='intensity'  # detector that returns intensity
        )
    )

    return LinearOpticalSetup(elements=elements_list)

In [None]:
# 507 untrained layers in total; 253 of them go before the middle trainable layer
NUM_OF_DIFF_LAYERS_NO_TRAIN, NUM_OF_DIFF_LAYERS_BEGINNING = 507, 253

In [None]:
# the optical network that will be trained below
optical_setup = set_setup(NUM_OF_DIFF_LAYERS_NO_TRAIN, NUM_OF_DIFF_LAYERS_BEGINNING)

## 3.3 Detector processor

In [None]:
CALCULATE_ACCURACIES = True

# DetectorProcessorClf reduces a detector image to per-class values;
# it is only needed (and only created) when accuracies are calculated
detector_processor = (
    DetectorProcessorClf(
        simulation_parameters=SIM_PARAMS,
        num_classes=number_of_classes,
        segmented_detector=DETECTOR_MASK,
    )
    if CALCULATE_ACCURACIES
    else None
)

## 3.4 Stuff for training

In [None]:
# batch sizes for the training / validation sets
train_bs, val_bs = 128, 64

LR = 1e-3  # learning rate

loss_func_clf = nn.MSELoss(reduction='mean')  # 'mean' is the default reduction
loss_func_name = 'MSE'

In [None]:
def get_adam_optimizer(net, lr=None):
    """
    Create an Adam optimizer over all parameters of `net`.

    Parameters
    ----------
        net : torch.nn.Module
            Network whose parameters will be optimized.
        lr : float, optional
            Learning rate. When None (default), the notebook-level `LR` is used,
            looked up at call time.

    Returns
    -------
        torch.optim.Adam
            Optimizer bound to `net.parameters()`.
    """
    if lr is None:
        lr = LR
    return torch.optim.Adam(
        params=net.parameters(),  # NETWORK PARAMETERS!
        lr=lr
    )

In [None]:
# split the wavefront version of the MNIST train set into train / validation parts
train_wf_ds, val_wf_ds = torch.utils.data.random_split(
    mnist_wf_train_ds,
    [55000, 5000],  # sizes from the article
    generator=torch.Generator().manual_seed(178),  # for reproducibility
)

# loaders (num_workers left at the DataLoader default)
train_wf_loader = torch.utils.data.DataLoader(
    train_wf_ds, batch_size=train_bs, shuffle=True, drop_last=False
)
val_wf_loader = torch.utils.data.DataLoader(
    val_wf_ds, batch_size=val_bs, shuffle=False, drop_last=False
)
# data loader for the MNIST test data
test_wf_loader = torch.utils.data.DataLoader(
    mnist_wf_test_ds, batch_size=10, shuffle=False, drop_last=False
)

## 3.5 Training and evaluation loops

In [None]:
def onn_train_mse(
    optical_net, wavefronts_dataloader,
    detector_processor_clf,  # DETECTOR PROCESSOR needed for accuracies only!
    loss_func, optimizer,
    device='cuda', show_process=False
):
    """
    Train `optical_net` for one epoch (classification via target detector images).

    The loss is computed between the network's detector output and the target
    detector image of each sample. Accuracy is computed only when the global
    `CALCULATE_ACCURACIES` is True AND `detector_processor_clf` is not None.
    It is computed outside the autograd graph.

    Parameters
    ----------
        optical_net : torch.nn.Module
            Neural Network composed of Elements.
        wavefronts_dataloader : torch.utils.data.DataLoader
            A loader (by batches) for the train dataset of wavefronts.
        detector_processor_clf : DetectorProcessorClf or None
            A processor of a detector image for a classification task, that returns `probabilities` of classes.
        loss_func :
            Loss between detector outputs and target detector images (e.g. MSE).
        optimizer: torch.optim.Optimizer
            Optimizer stepping the parameters of `optical_net`.
        device : str or torch.device
            Device to compute on.
        show_process : bool
            Flag to show (or not) a progress bar.

    Returns
    -------
        batches_losses : list[float]
            Losses for each batch in an epoch.
        batches_accuracies : list[float]
            Accuracies for each batch in an epoch (0. when accuracies are not computed).
        epoch_accuracy : float
            Accuracy for an epoch (0. when not computed or the loader is empty).
    """
    optical_net.train()  # activate 'train' mode of a model
    compute_accuracy = CALCULATE_ACCURACIES and detector_processor_clf is not None

    batches_losses = []  # to store loss for each batch
    batches_accuracies = []  # to store accuracy for each batch

    correct_preds = 0
    size = 0

    for batch_wavefronts, batch_targets in tqdm(
        wavefronts_dataloader,
        total=len(wavefronts_dataloader),
        desc='train', position=0,
        leave=True, disable=not show_process
    ):  # go by batches
        # batch_wavefronts - input wavefronts, batch_targets - target detector images
        batch_size = batch_wavefronts.size()[0]

        batch_wavefronts = batch_wavefronts.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()

        # forward of an optical network
        detector_output = optical_net(batch_wavefronts)

        # calculate loss for a batch
        loss = loss_func(detector_output, batch_targets)

        loss.backward()
        optimizer.step()

        batches_losses.append(loss.item())

        if compute_accuracy:
            # metric only: no autograd graph is needed for it
            with torch.no_grad():
                batch_labels = detector_processor_clf.batch_forward(batch_targets).argmax(1)
                batch_probas = detector_processor_clf.batch_forward(detector_output.detach())

                batch_correct_preds = (
                    batch_probas.argmax(1) == batch_labels
                ).type(torch.float).sum().item()

            correct_preds += batch_correct_preds
            size += batch_size
            batches_accuracies.append(batch_correct_preds / batch_size)
        else:
            batches_accuracies.append(0.)

    # guard against an empty loader (or accuracies switched off)
    epoch_accuracy = correct_preds / size if size > 0 else 0.

    return batches_losses, batches_accuracies, epoch_accuracy

In [None]:
def onn_validate_mse(
    optical_net, wavefronts_dataloader,
    detector_processor_clf,  # DETECTOR PROCESSOR NEEDED!
    loss_func,
    device='cuda', show_process=False
    ):
    """
    Evaluate `optical_net` on a dataset (classification via target detector images).

    The network is switched to eval mode and every batch (forward pass, loss and
    accuracy) is processed under `torch.no_grad()`. Accuracy is computed only
    when the global `CALCULATE_ACCURACIES` is True AND `detector_processor_clf`
    is not None.

    Parameters
    ----------
        optical_net : torch.nn.Module
            Neural Network composed of Elements.
        wavefronts_dataloader : torch.utils.data.DataLoader
            A loader (by batches) for the validation or test dataset of wavefronts.
        detector_processor_clf : DetectorProcessorClf or None
            A processor of a detector image for a classification task, that returns `probabilities` of classes.
        loss_func :
            Loss between detector outputs and target detector images (e.g. MSE).
        device : str or torch.device
            Device to compute on.
        show_process : bool
            Flag to show (or not) a progress bar.

    Returns
    -------
        batches_losses : list[float]
            Losses for each batch.
        batches_accuracies : list[float]
            Accuracies for each batch (0. when accuracies are not computed).
        epoch_accuracy : float
            Accuracy over the whole loader (0. when not computed or the loader is empty).
    """
    optical_net.eval()  # activate 'eval' mode of a model
    compute_accuracy = CALCULATE_ACCURACIES and detector_processor_clf is not None

    batches_losses = []  # to store loss for each batch
    batches_accuracies = []  # to store accuracy for each batch

    correct_preds = 0
    size = 0

    with torch.no_grad():
        for batch_wavefronts, batch_targets in tqdm(
            wavefronts_dataloader,
            total=len(wavefronts_dataloader),
            desc='validation', position=0,
            leave=True, disable=not show_process
        ):  # go by batches
            # batch_wavefronts - input wavefronts, batch_targets - target detector images
            batch_size = batch_wavefronts.size()[0]

            batch_wavefronts = batch_wavefronts.to(device)
            batch_targets = batch_targets.to(device)

            detector_outputs = optical_net(batch_wavefronts)
            # calculate loss for a batch
            loss = loss_func(detector_outputs, batch_targets)
            batches_losses.append(loss.item())

            if compute_accuracy:
                # process detector images into class predictions
                batch_labels = detector_processor_clf.batch_forward(batch_targets).argmax(1)
                batch_probas = detector_processor_clf.batch_forward(detector_outputs)

                batch_correct_preds = (
                    batch_probas.argmax(1) == batch_labels
                ).type(torch.float).sum().item()

                correct_preds += batch_correct_preds
                size += batch_size
                batches_accuracies.append(batch_correct_preds / batch_size)
            else:
                batches_accuracies.append(0.)

    # guard against an empty loader (or accuracies switched off)
    epoch_accuracy = correct_preds / size if size > 0 else 0.

    return batches_losses, batches_accuracies, epoch_accuracy

# 4. Training of the optical network

## 4.1. Before training: transferring objects to the GPU

In [None]:
# prefer the GPU when one is visible to torch
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

DEVICE

In [None]:
# move the network and the simulation parameters to the compute device
optical_setup.net = optical_setup.net.to(DEVICE)
SIM_PARAMS = SIM_PARAMS.to(DEVICE)
# detector_processor is None when CALCULATE_ACCURACIES is False: nothing to move then
if detector_processor is not None:
    detector_processor = detector_processor.to(DEVICE)

## 4.2. Training

In [None]:
# number of epochs / print info on each n-th epoch
n_epochs, print_each = 2, 2

scheduler = None  # optional scheduler for lr tuning during training

# link the optimizer to the (re)created net
optimizer_clf = get_adam_optimizer(optical_setup.net)

In [None]:
# per-epoch history of mean losses and accuracies (plotted in 4.3)
train_epochs_losses, val_epochs_losses = [], []
train_epochs_acc, val_epochs_acc = [], []

torch.manual_seed(98)  # seeds the global torch RNG (used e.g. by the shuffled train loader)

for epoch in range(n_epochs):
    # report on the first epoch, on every `print_each`-th epoch and on the last one
    show_progress = (
        epoch == 0
        or (epoch + 1) % print_each == 0
        or epoch == n_epochs - 1
    )
    if show_progress:
        print(f'Epoch #{epoch + 1}: ', end='')

    # --- TRAIN ---
    start_train_time = time.time()
    train_losses, _, train_accuracy = onn_train_mse(
        optical_setup.net,
        train_wf_loader,
        detector_processor,
        loss_func_clf,
        optimizer_clf,
        device=DEVICE,
        show_process=show_progress,
    )
    mean_train_loss = np.mean(train_losses)

    if show_progress:
        print('Training results')
        print(f'\t{loss_func_name} : {mean_train_loss:.6f}')
        if CALCULATE_ACCURACIES:
            print(f'\tAccuracy : {(train_accuracy*100):>0.1f} %')
        print(f'\t------------   {time.time() - start_train_time:.2f} s')

    # --- VALIDATION ---
    start_val_time = time.time()
    val_losses, _, val_accuracy = onn_validate_mse(
        optical_setup.net,
        val_wf_loader,
        detector_processor,
        loss_func_clf,
        device=DEVICE,
        show_process=show_progress,
    )
    mean_val_loss = np.mean(val_losses)

    if show_progress:
        print('Validation results')
        print(f'\t{loss_func_name} : {mean_val_loss:.6f}')
        if CALCULATE_ACCURACIES:
            print(f'\tAccuracy : {(val_accuracy*100):>0.1f} %')
        print(f'\t------------   {time.time() - start_val_time:.2f} s')

    if scheduler:
        scheduler.step(mean_val_loss)

    # store epoch results
    train_epochs_losses.append(mean_train_loss)
    val_epochs_losses.append(mean_val_loss)
    train_epochs_acc.append(train_accuracy)
    val_epochs_acc.append(val_accuracy)

## 4.3. Learning curves (MSELoss and Accuracy)

In [None]:
# learning curves: loss (scaled by 1e3) on the left, accuracy on the right
fig, axs = plt.subplots(1, 2, figsize=(10, 3))

for ax, train_curve, val_curve, y_label in (
    (axs[0], np.array(train_epochs_losses) * 1e3, np.array(val_epochs_losses) * 1e3,
     loss_func_name + r' $\times 10^3$'),
    (axs[1], train_epochs_acc, val_epochs_acc, 'Accuracy'),
):
    ax.plot(range(1, n_epochs + 1), train_curve, label='train')
    ax.plot(range(1, n_epochs + 1), val_curve, linestyle='dashed', label='validation')
    ax.set_ylabel(y_label)
    ax.set_xlabel('Epoch')
    ax.legend()

plt.show()

In [None]:
# CSV header and table (one row per epoch) of the learning curves
# TODO: make with PANDAS?
all_lasses_header = ','.join(
    [f'{loss_func_name.split()[0]}_{part}' for part in ('train', 'val')]
    + ['accuracy_train', 'accuracy_val']
)
all_losses_array = np.column_stack(
    [train_epochs_losses, val_epochs_losses, train_epochs_acc, val_epochs_acc]
)

## 4.4. Trained phase masks

In [None]:
# Positions (in optical_setup.net) of the trainable diffractive layers to visualize,
# in setup order. A list (not a set) keeps the subplot order aligned with the
# "DiffractiveLayer {count}" titles, which are assigned in setup order.
target_indices = [1, 3, 511, 1021, 1023]

# one row of subplots, one column per visualized layer
n_cols = len(target_indices)
n_rows = 1

# squeeze=False keeps `axs` 2-D even for a single column
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 5.2, n_rows * 4.6), squeeze=False
)

cmap = 'rainbow'  # colormap for the phase masks

# visible frame (the whole layer when the layer and detector sizes match)
x_frame = (x_layer_nodes - DETECTOR_SIZE[1]) / 2
y_frame = (y_layer_nodes - DETECTOR_SIZE[0]) / 2

count = 1
# Iterate without moving the network: `Module.to` works in place, so calling
# `.to('cpu')` here would leave the trained network on the CPU for later cells.
for ind_layer, layer in enumerate(optical_setup.net):
    if ind_layer not in target_indices or not isinstance(layer, elements.DiffractiveLayer):
        continue

    ax_this = axs[0, target_indices.index(ind_layer)]

    # title with the running number of the trainable layer
    ax_this.set_title(f'DiffractiveLayer {count}')
    count += 1

    # detached CPU copy of the mask, for plotting only
    mask_to_visualize = layer.mask.detach().cpu()

    im = ax_this.imshow(
        mask_to_visualize, cmap=cmap,
        vmin=0, vmax=MAX_PHASE
    )
    ax_this.set_xlim([x_frame, x_layer_nodes - x_frame])
    # imshow draws row 0 at the top; keep that orientation (larger y at the bottom)
    ax_this.set_ylim([y_layer_nodes - y_frame, y_frame])

    cbar = fig.colorbar(im, ax=ax_this, orientation='vertical', fraction=0.046, pad=0.04)
    cbar.set_label('Mask Value')

plt.show()

## 4.5 Saving the results

#### Paths

In [None]:
RESULTS_FOLDER = f'models/reproduced_results/MNIST_MSE_Ozcan_2018-2020_GPU_{512}_DL_{DETECTOR_SIZE[0]}x{DETECTOR_SIZE[1]}_grid'

# create the folder (and parents); no-op if it already exists, no check-then-create race
os.makedirs(RESULTS_FOLDER, exist_ok=True)

In [None]:
# output files inside RESULTS_FOLDER: model weights and learning curves
model_filepath = '/'.join((RESULTS_FOLDER, 'optical_setup_net_gpu.pth'))
losses_filepath = '/'.join((RESULTS_FOLDER, 'training_curves_gpu.csv'))

#### Saving model weights and learning curves

In [None]:
# save the network's state_dict (not the module object itself)
torch.save(obj=optical_setup.net.state_dict(), f=model_filepath)

In [None]:
# save the learning curves as CSV; comments='' writes the header without a '# ' prefix
np.savetxt(
    fname=losses_filepath,
    X=all_losses_array,
    delimiter=',',
    header=all_lasses_header,
    comments='',
)

# 5. Load saved weights for the model

In [None]:
# folder and weights file written in section 4.5
RESULTS_FOLDER = (
    'models/reproduced_results/MNIST_MSE_Ozcan_2018-2020_GPU_512_DL_'
    f'{DETECTOR_SIZE[0]}x{DETECTOR_SIZE[1]}_grid'
)

load_model_filepath = RESULTS_FOLDER + '/optical_setup_net_gpu.pth'

In [None]:
# setup to load weights (same architecture as the trained one)
ozcan_optical_setup_loaded = set_setup(
    total_number_of_layers=NUM_OF_DIFF_LAYERS_NO_TRAIN,
    number_of_layers_at_the_beginning=NUM_OF_DIFF_LAYERS_BEGINNING
)

# LOAD WEIGHTS: map_location lets weights saved from a GPU load on a CPU-only machine
ozcan_optical_setup_loaded.net.load_state_dict(
    torch.load(load_model_filepath, map_location=DEVICE)
)

## 5.1. Calculate metrics on test set for the loaded model

Checking if the loaded model works correctly!

In [None]:
test_losses_1, _, test_accuracy_1 = onn_validate_mse(
    ozcan_optical_setup_loaded.net.to(DEVICE),  # loaded network on the same device as the data
    test_wf_loader,  # dataloader of the test set
    detector_processor,  # detector processor
    loss_func_clf,
    device=DEVICE,
    show_process=True,
)  # evaluate the model

print(
    'Results after training on TEST set:\n' +
    f'\t{loss_func_name} : {np.mean(test_losses_1):.6f}\n' +
    f'\tAccuracy : {(test_accuracy_1 * 100):>0.1f} %'
)