# Настройка среды

In [None]:
import os
import sys
# Make the directory four levels above the notebook's working directory importable
# (presumably the repository root), so the project packages imported below
# (`elements`, `utilities`, `cluster`) can be resolved.
sys.path.append(os.path.abspath(os.path.join(*[os.pardir] * 4)))

In [None]:
import optuna
import torch
import numpy
import json
import pickle
from typing import Literal, Union, Any
from belashovplot import TiledPlot

from elements.propagators import FurrierPropagation
from elements.modulators import PhaseModulator, AmplitudeModulator
from elements.composition import CompositeModel
from elements.detectors import ClassificationDetectors
from elements.simple import IntensityToAmplitude, AmplitudeToIntensity

from utilities.filters import Gaussian, Window
from utilities.datasets import Dataset
from utilities.losses import Normalizable, Normalization, LossLinearCombination
import cluster

In [None]:
from cluster import epochs, SelectedGPUs
# Presumably removes the GPU with id 1 from the selected device set — see `cluster.SelectedGPUs` to confirm.
SelectedGPUs.exclude(1)
# Number of GPUs still selected; used below to pin the 'devices' variation, which scales the batch size.
GPUCount = len(SelectedGPUs.ids)
# Last expression of the cell: the notebook displays the remaining GPU ids.
SelectedGPUs.ids

# Класс-настройщик вариаций

In [None]:
class Variation:
    """Hyperparameter "variations": named search-space entries resolved against an Optuna trial.

    Each nested class is constructed once with an Optuna parameter name and its
    search range or choice set. It is then called with a trial to obtain that
    trial's value, which Optuna also records in ``trial.params``.

    ``Variation['float' | 'int' | 'cat' | 'proportion']`` returns the matching class.
    """

    class Abstract:
        """Base variation: remembers the Optuna parameter name."""

        def __init__(self, name: str):
            self._name = name

        def __call__(self, trial: "optuna.trial.Trial"):
            """Return the value *trial* suggests for this parameter."""
            raise NotImplementedError

    class Float(Abstract):
        """Float parameter drawn from ``[v0, v1]``; omitting *v1* pins it to ``v0``.

        Pass ``log=True`` to sample on a logarithmic scale (Optuna then requires ``v0 > 0``).
        """

        def __init__(self, name: str, v0: float, v1: Union[float, None] = None, log: bool = False):
            super().__init__(name)
            if v1 is None:
                v1 = v0
            self._limits = (v0, v1)
            self._log = log

        def __call__(self, trial: "optuna.trial.Trial") -> float:
            return trial.suggest_float(self._name, *self._limits, log=self._log)

    class Int(Abstract):
        """Integer parameter drawn from ``[v0, v1]``; omitting *v1* pins it to ``v0``.

        Pass ``log=True`` to sample on a logarithmic scale (Optuna then requires ``v0 > 0``).
        """

        def __init__(self, name: str, v0: int, v1: Union[int, None] = None, log: bool = False):
            super().__init__(name)
            if v1 is None:
                v1 = v0
            self._limits = (v0, v1)
            self._log = log

        def __call__(self, trial: "optuna.trial.Trial") -> int:
            return trial.suggest_int(self._name, *self._limits, log=self._log)

    class Categorical(Abstract):
        """Categorical parameter.

        *cats* are the labels handed to Optuna; a single string is one category.
        If *reds* ("redirects") is given, the entry at the same position as the
        chosen label is returned instead of the label. This is useful for values
        Optuna cannot store directly, such as classes or objects. A non-list
        *reds* is treated as a single redirect.

        Raises:
            ValueError: if *reds* and *cats* have different lengths.
        """

        def __init__(self, name: str, cats: Union[list[str], str], reds: Union[list[Any], Any] = None):
            super().__init__(name)
            # list('abc') would split a lone string into ['a', 'b', 'c']; wrap it instead.
            if isinstance(cats, str):
                cats = [cats]
            elif not isinstance(cats, list):
                cats = list(cats)
            if reds is not None:
                if not isinstance(reds, list):
                    reds = [reds]
                if len(reds) != len(cats):
                    raise ValueError(
                        f"Variation '{name}': {len(reds)} redirects for {len(cats)} categories"
                    )
            self._cats = cats
            self._reds = reds

        def __call__(self, trial: "optuna.trial.Trial"):
            val = trial.suggest_categorical(self._name, self._cats)
            if self._reds is not None:
                val = self._reds[self._cats.index(val)]
            return val

    class Proportion(Abstract):
        """Tuple of *amount* weights that sum to 1.

        Each raw weight is drawn from [0, 1] under the name ``{name}_{suffix}``.
        The suffix is ``nmap[i]`` when *nmap* is given, otherwise the index ``i``.
        If every raw weight is 0, equal proportions are returned instead of
        dividing by zero.

        Raises:
            ValueError: if *nmap* has fewer than *amount* entries.
        """

        def __init__(self, name: str, amount: int, nmap: list[str] = None):
            super().__init__(name)
            if nmap is not None and len(nmap) < amount:
                raise ValueError(f"Variation '{name}': nmap has {len(nmap)} names for {amount} weights")
            self._amount = amount
            self._nmap = nmap

        def __call__(self, trial: "optuna.trial.Trial") -> tuple:
            suffixes = range(self._amount) if self._nmap is None else self._nmap[:self._amount]
            coefficients = [trial.suggest_float(f'{self._name}_{suffix}', 0.0, 1.0) for suffix in suffixes]
            integral = sum(coefficients)
            if integral == 0:
                # All weights hit exactly 0: fall back to a uniform split.
                return tuple(1.0 / len(coefficients) for _ in coefficients)
            return tuple(coefficient / integral for coefficient in coefficients)

    def __class_getitem__(cls, key: Literal['float', 'int', 'cat', 'proportion']):
        """Map a short kind name to the corresponding variation class.

        Raises:
            KeyError: for an unknown kind.
        """
        kinds = {
            'float': cls.Float,
            'int': cls.Int,
            'cat': cls.Categorical,
            'proportion': cls.Proportion,
        }
        try:
            return kinds[key]
        except KeyError:
            raise KeyError(f"Unknown variation kind {key!r}; expected one of {sorted(kinds)}") from None

# Настройки вариаций

In [None]:
# --- Optical system parameters ---------------------------------------------
# With only v0 given, a variation is pinned to that constant; v0 and v1 give a search range.
var_wavelength = Variation.Float('wavelength', v0=500.0E-9)
var_N = Variation.Int('N', v0=5000)
var_length = Variation.Float('length', v0=2.5E-3)
var_pixels = Variation.Int('pixels', v0=5000)
var_distance = Variation.Float('distance', v0=1.0E-3, v1=500.0E-3)
var_masks_amount = Variation.Int('masks', v0=1)
var_detectors_norm = Variation.Categorical('detectors_norm', cats=['none', 'integral', 'minmax'])

# --- Training parameters ----------------------------------------------------
var_epochs = Variation.Int('epochs', v0=5)
var_batch_per_device = Variation.Int('batch_per_device', v0=2, v1=8)
var_dataset_scale = Variation.Float('dataset_scale', v0=0.1, v1=1.0)
var_loss_proportion = Variation.Proportion('loss_proportion', 2, ['ce', 'mse'])

# Both loss terms choose among the same normalizations; each variation gets its own label
# list and its own freshly constructed Normalization objects.
_NORMALIZATION_NAMES = ('minmax', 'max', 'softmax')


def _normalization_choices():
    """Return new Normalization instances, positionally matching _NORMALIZATION_NAMES."""
    return [Normalization.Minmax(), Normalization.Max(), Normalization.Softmax()]


var_ce_norm = Variation.Categorical('ce_norm', list(_NORMALIZATION_NAMES), _normalization_choices())
var_mse_norm = Variation.Categorical('mse_norm', list(_NORMALIZATION_NAMES), _normalization_choices())
var_optimizer = Variation.Categorical(
    'optimizer',
    cats=['Adam', 'SGD', 'RMSProp'],
    reds=[torch.optim.Adam, torch.optim.SGD, torch.optim.RMSprop],
)
# NOTE(review): this range is sampled on a linear scale (suggest_float is called without
# log=True), so about 99% of draws land above 0.1. A log-scale search would cover the
# 1e-7..10 range evenly. Consider switching before starting a fresh study.
var_learning_rate = Variation.Float('learning_rate', v0=1.0E-7, v1=10.0)

# --- Do not change: fixed to the number of GPUs selected above ----------------
var_devices = Variation.Int('devices', v0=GPUCount)

# Определение целевой функции оптимизации

In [None]:
def _build_model(N, length, wavelength, pixels, distance, masks_amount, detectors_norm):
    """Assemble the diffractive model for one trial.

    Layout: propagation, then ``masks_amount`` times (phase mask, amplitude mask,
    propagation). This sequence is wrapped between IntensityToAmplitude /
    AmplitudeToIntensity and followed by ClassificationDetectors. The literal 10
    passed to the detectors is presumably the number of MNIST classes.

    Raises:
        TypeError: if *detectors_norm* is not 'none', 'integral' or 'minmax'.
    """
    # A single propagation layer instance is reused between every pair of masks.
    propagation = FurrierPropagation(N, length, wavelength, 1.0, 0.0, distance, 0.4)
    phase_modulators = [PhaseModulator(N, length, pixels) for _ in range(masks_amount)]
    amplitude_modulators = [AmplitudeModulator(N, length, pixels) for _ in range(masks_amount)]
    elements = [propagation]
    for phase_modulator, amplitude_modulator in zip(phase_modulators, amplitude_modulators):
        elements += [phase_modulator, amplitude_modulator, propagation]

    spectral_filter = Window(centers=wavelength, sizes=300.0E-9)
    detectors_filter = Gaussian((length/10, length/10), (0, 0))
    detectors = ClassificationDetectors(N, length, wavelength, 10, detectors_filter, spectral_filter)
    # Each supported mode is a same-named method on detectors.normalization.
    if detectors_norm not in ('none', 'integral', 'minmax'):
        raise TypeError(f"Incorrect detectors normalization {detectors_norm}")
    getattr(detectors.normalization, detectors_norm)()

    return CompositeModel(IntensityToAmplitude(), *elements, AmplitudeToIntensity(), detectors)


def _select_best(mh, cmh):
    """Pick the evaluation with the highest accuracy.

    Accuracy of each confusion matrix in *cmh* is 100 * (diagonal sum) / (total sum).
    Returns ``(accuracies, best_accuracy, best_model)``, where *best_model* is taken
    from the model history *mh*.

    Raises:
        ValueError: if *cmh* is empty.
    """
    accuracies = [float(100*numpy.sum(numpy.diagonal(confusion, 0)) / numpy.sum(confusion)) for confusion in cmh]
    if not accuracies:
        raise ValueError("cluster.epochs returned no confusion matrices")
    # First index holding the maximum, same tie-breaking as max().
    best_index = max(range(len(accuracies)), key=accuracies.__getitem__)
    # Align the two histories at their ends. The offset is:
    #   0  if every evaluation has a matching model,
    #   +1 if mh additionally starts with the initial model,
    #   -1 if cmh starts with a pre-training evaluation.
    # The clamp keeps the index in range; a plain `best_index - 1` would wrap to the
    # *last* model when best_index == 0.
    # TODO(review): confirm the history layout against cluster.epochs.
    model_index = min(max(best_index + len(mh) - len(cmh), 0), len(mh) - 1)
    return accuracies, accuracies[best_index], mh[model_index]


def objective(trial: "optuna.trial.Trial") -> float:
    """Optuna objective: train one D2NN configuration and return its best accuracy (percent).

    Every hyperparameter is drawn from the module-level ``var_*`` variations.
    The model is trained with ``cluster.epochs`` on MNIST. The best model and the
    histories are attached to *trial* as string user attributes: JSON for numbers
    and histories, a hex-encoded pickle for the model.
    """
    # Optical system
    wavelength = var_wavelength(trial)
    N = var_N(trial)
    length = var_length(trial)
    pixels = var_pixels(trial)
    distance = var_distance(trial)
    masks_amount = var_masks_amount(trial)
    detectors_norm = var_detectors_norm(trial)

    # Training
    epochs_count = var_epochs(trial)
    devices = var_devices(trial)
    # Global batch = per-device batch times the number of devices.
    batch_size = var_batch_per_device(trial) * devices
    dataset_scale = var_dataset_scale(trial)
    loss_proportion = var_loss_proportion(trial)
    ce_norm = var_ce_norm(trial)
    mse_norm = var_mse_norm(trial)
    optimizer = var_optimizer(trial)
    learning_rate = var_learning_rate(trial)

    print(f"Текущие параметры эксперимента: {trial.params}")

    model = _build_model(N, length, wavelength, pixels, distance, masks_amount, detectors_norm)

    dataset = Dataset('MNIST', batch_size, N, N, torch.float32, threads=8, preload=10)
    dataset.padding(surface_ratio=dataset_scale)
    loss_function = LossLinearCombination(Normalizable.CrossEntropy(ce_norm), Normalizable.MeanSquareError(mse_norm))
    loss_function.proportions(*loss_proportion)

    mh, lh, cmh = cluster.epochs(epochs_count, 10, model, dataset, loss_function, optimizer, lr=learning_rate)

    accuracies, best_accuracy, best_model = _select_best(mh, cmh)

    # Serialize everything first, so a failure leaves the trial's attributes untouched.
    # NOTE(review): the pickled model is written into the SQLite storage and can be large;
    # only unpickle it from databases you trust.
    attributes = {
        'model': pickle.dumps(best_model).hex(),
        'accuracy': json.dumps(best_accuracy),
        'accuracies_history': json.dumps(accuracies),
        'loss_history': json.dumps(numpy.concatenate(lh).tolist()),
        'confusion_matrixes': json.dumps(numpy.stack(cmh).tolist()),
    }
    for key, value in attributes.items():
        trial.set_user_attr(key, value)

    return best_accuracy
    
# Persistent study stored in the local SQLite file D2NN.db. With load_if_exists=True,
# re-running this cell reopens the existing "Default" study instead of failing.
study = optuna.create_study(
    storage="sqlite:///D2NN.db",
    study_name="Default",
    direction='maximize',
    load_if_exists=True,
)

# Оптимизация

In [None]:
# Run 10 more trials of the persisted study; repeat this cell to keep extending it.
study.optimize(func=objective, n_trials=10)

In [None]:
# Run 10 more trials of the persisted study; repeat this cell to keep extending it.
study.optimize(func=objective, n_trials=10)

In [None]:
# Run 10 more trials of the persisted study; repeat this cell to keep extending it.
study.optimize(func=objective, n_trials=10)