# Настройка среды

In [1]:
import os
import sys
# Add the directory four levels above the current working directory to the
# import path — presumably so the project packages imported below
# (elements, utilities, cluster) can be resolved; verify against repo layout.
sys.path.append(os.path.abspath("../../../.."))

In [2]:
import optuna
import torch
import numpy
import json
import pickle
from typing import Literal, Union, Any
from belashovplot import TiledPlot

from elements.propagators import FurrierPropagation
from elements.modulators import PhaseModulator, AmplitudeModulator
from elements.composition import CompositeModel
from elements.detectors import ClassificationDetectors
from elements.simple import IntensityToAmplitude, AmplitudeToIntensity

from utilities.filters import Gaussian, Window
from utilities.datasets import Dataset
from utilities.losses import Normalizable, Normalization, LossLinearCombination
import cluster

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from cluster import epochs, SelectedGPUs
# GPUs that must not be used; objective() passes them to SelectedGPUs.exclude().
ExcludedGPUs = []
# Number of GPUs left for training.
# NOTE(review): assumes every excluded id is also present in SelectedGPUs.ids — confirm.
GPUCount = len(SelectedGPUs.ids) - len(ExcludedGPUs)
# Bare expression: shown as the cell output in the notebook.
GPUCount

8

In [4]:
import pandas
# Do not limit the number of columns shown when a DataFrame is displayed.
pandas.options.display.max_columns = None

# Класс-настройщик вариаций

In [5]:
class Variation:
    """Namespace of hyperparameter descriptors for an Optuna search.

    Every nested descriptor remembers a parameter name together with its
    search space. Calling a descriptor with a trial asks that trial to
    suggest a value and returns it. ``Variation['float']``-style lookup
    returns the matching descriptor class.
    """

    class Abstract:
        """Common base: stores the parameter name, leaves sampling to subclasses."""

        def __init__(self, name:str):
            self._name = name

        def __call__(self, trial:"optuna.trial.Trial"):
            raise NotImplementedError

    class Float(Abstract):
        """Float parameter suggested from the closed range ``[v0, v1]``.

        When ``v1`` is omitted the range collapses to ``(v0, v0)``, i.e. the
        parameter is effectively pinned to ``v0`` while still being suggested
        through the trial.
        """

        def __init__(self, name:str, v0:float, v1:Union[float, None]=None):
            super().__init__(name)
            if v1 is None:
                v1 = v0
            self._limits = (v0, v1)

        def __call__(self, trial:"optuna.trial.Trial"):
            return trial.suggest_float(self._name, *self._limits)

    class Int(Abstract):
        """Integer parameter suggested from ``[v0, v1]``; ``v1`` defaults to ``v0``."""

        def __init__(self, name:str, v0:int, v1:Union[int, None]=None):
            super().__init__(name)
            if v1 is None:
                v1 = v0
            self._limits = (v0, v1)

        def __call__(self, trial:"optuna.trial.Trial"):
            return trial.suggest_int(self._name, *self._limits)

    class Categorical(Abstract):
        """Categorical parameter with an optional redirection table.

        ``cats`` are the choices handed to the trial. If ``reds`` is given,
        the value returned by ``__call__`` is ``reds[k]`` where ``k`` is the
        position of the suggested choice in ``cats``; this lets the trial
        record plain labels while the caller receives arbitrary objects.
        """

        def __init__(self, name:str, cats:Union[list[str],str], reds:Union[list[Any],Any]=None):
            super().__init__(name)
            if isinstance(cats, str):
                # A lone string is ONE category; list(cats) would split it into characters.
                cats = [cats]
            elif not isinstance(cats, list):
                cats = list(cats)
            if reds is not None and not isinstance(reds, list):
                # Tuples are treated as a sequence of targets, anything else as a single target.
                reds = list(reds) if isinstance(reds, tuple) else [reds]
            self._cats = cats
            self._reds = reds

        def __call__(self, trial:"optuna.trial.Trial"):
            val = trial.suggest_categorical(self._name, self._cats)
            if self._reds is not None:
                index = self._cats.index(val)
                val = self._reds[index]
            return val

    class Proportion(Abstract):
        """Tuple of ``amount`` non-negative weights normalised to sum to one.

        Each weight is suggested as a float in ``[0, 1]`` under the name
        ``f'{name}_{suffix}'``, where the suffix is the index or, when
        ``nmap`` is given, ``nmap[index]``. The trial records the raw draws;
        the normalised values are what ``__call__`` returns.
        """

        def __init__(self, name:str, amount:int, nmap:Union[list[str], None]=None):
            super().__init__(name)
            self._amount = amount
            self._nmap = nmap

        def __call__(self, trial:"optuna.trial.Trial"):
            coefficients = []
            for i in range(self._amount):
                addition = f'{i}' if self._nmap is None else self._nmap[i]
                coefficients.append(trial.suggest_float(f'{self._name}_{addition}', 0.0, 1.0))
            integral = sum(coefficients)
            if integral == 0.0:
                # Every draw was exactly 0.0: normalising would divide by zero,
                # so fall back to equal shares.
                return tuple(1.0/self._amount for _ in coefficients)
            return tuple(coefficient/integral for coefficient in coefficients)

    def __class_getitem__(cls, key:Literal['float', 'int', 'cat', 'proportion']):
        """Return the descriptor class registered under ``key``.

        Raises:
            KeyError: if ``key`` is not one of the known kinds.
        """
        if key == 'float':      return Variation.Float
        if key == 'int':        return Variation.Int
        if key == 'cat':        return Variation.Categorical
        if key == 'proportion': return Variation.Proportion
        raise KeyError(key)

# Настройки вариаций

In [6]:
wavelength = 500.0E-9
Delta = 2*wavelength # Options: 0.5λ, 1.0λ, 2.0λ, 4.0λ, 8.0λ, 16.0λ

# System parameters
var_wavelength   = Variation.Float('wavelength', wavelength)
var_N            = Variation.Int('N', 500, 5000)
# 'length' and 'pixels' are not fixed here: objective() derives them from N on every trial.
# var_length       = Variation.Float('length', 2.5E-3*4)
# var_pixels       = Variation.Int('pixels', 5000) #5000
var_distance     = Variation.Float('distance', 1.0E-3, 500.0E-3)
var_masks_amount = Variation.Int('masks', 1)
var_detectors_norm   = Variation.Categorical('detectors_norm', ['none', 'integral', 'minmax'])

# Training parameters
var_epochs           = Variation.Int('epochs', 5, 7) #5
var_batch_per_device = Variation.Int('batch_per_device', 1, 8)
var_dataset_scale    = Variation.Float('dataset_scale', 0.1, 1.0)
var_loss_proportion  = Variation.Proportion('loss_proportion', 2, ['ce', 'mse'])
# The trial records the string label; the objective receives the object at the same position in the second list.
var_ce_norm          = Variation.Categorical('ce_norm',  ['minmax','max','softmax', 'none'], [Normalization.Minmax(), Normalization.Max(), Normalization.Softmax(), None])
var_mse_norm         = Variation.Categorical('mse_norm', ['minmax','max','softmax', 'none'], [Normalization.Minmax(), Normalization.Max(), Normalization.Softmax(), None])
var_optimizer        = Variation.Categorical('optimizer', ['Adam', 'SGD', 'RMSProp'], [torch.optim.Adam, torch.optim.SGD, torch.optim.RMSprop])
var_learning_rate    = Variation.Float('learning_rate', 1.0E-7, 2.0)

# Do not change
var_devices          = Variation.Int('devices', GPUCount)

# Определение функции оптимизации

In [7]:
def objective(trial:"optuna.trial.Trial"):
    """Optuna objective: sample a configuration, train it, return its best accuracy.

    The function draws every hyperparameter through the module-level
    ``var_*`` descriptors, assembles a ``CompositeModel`` from one shared
    propagation element, ``masks`` phase/amplitude modulator pairs and a
    ten-class detector layer, and trains it with ``cluster.epochs``.

    Side effects:
        * prints the sampled parameters;
        * stores ``accuracy``, ``accuracies_history``, ``loss_history`` and
          ``confusion_matrixes`` on the trial as JSON-encoded strings
          (read them back with ``json.loads``). The trained model itself is
          not persisted.

    Args:
        trial: the Optuna trial used to suggest parameters and to hold the
            user attributes.

    Returns:
        The highest accuracy, in percent, among the confusion matrices
        returned by ``cluster.epochs``.

    Raises:
        TypeError: if the suggested detectors normalization is unknown.
    """
    # --- sample the system parameters ---
    wavelength = var_wavelength(trial)
    N = var_N(trial)

    # length and pixels are tied to N, so their single-value ranges are built per trial.
    length = Variation.Float('length', N*Delta)(trial)
    pixels = Variation.Int('pixels', N)(trial)
    distance = var_distance(trial)
    masks_amount = var_masks_amount(trial)
    detectors_norm = var_detectors_norm(trial)

    # --- sample the training parameters ---
    epochs = var_epochs(trial)
    # Suggested only so that the device count is recorded in trial.params.
    var_devices(trial)
    batch_per_device = var_batch_per_device(trial)
    # NOTE(review): an earlier revision computed batch_per_device * devices but never
    # used it; the Dataset below receives batch_per_device — confirm this is intended.
    dataset_scale = var_dataset_scale(trial)
    loss_proportion = var_loss_proportion(trial)
    ce_norm = var_ce_norm(trial)
    mse_norm = var_mse_norm(trial)
    optimizer = var_optimizer(trial)
    learning_rate = var_learning_rate(trial)

    print(f"Текущие параметры эксперимента: {trial.params}")

    # --- assemble the model ---
    # The same propagation element is reused before the first mask and after every mask pair.
    propagation = FurrierPropagation(N, length, wavelength, 1.0, 0.0, distance, 0.4)
    phase_modulators = [PhaseModulator(N, length, pixels) for _ in range(masks_amount)]
    amplitude_modulators = [AmplitudeModulator(N, length, pixels) for _ in range(masks_amount)]
    elements = [propagation]
    for phase_modulator, amplitude_modulator in zip(phase_modulators, amplitude_modulators):
        elements += [phase_modulator, amplitude_modulator, propagation]
    spectral_filter = Window(centers=wavelength, sizes=300.0E-9)
    detectors_filter = Gaussian((length/50, length/50), (0,0))
    detectors = ClassificationDetectors(N, length, wavelength, 10, detectors_filter, spectral_filter)
    if detectors_norm not in ('none', 'integral', 'minmax'):
        raise TypeError(f"Incorrect detectors normalization {detectors_norm}")
    # The label doubles as the name of the normalization setter to invoke.
    getattr(detectors.normalization, detectors_norm)()
    model = CompositeModel(IntensityToAmplitude(), *elements, AmplitudeToIntensity(), detectors)

    # --- data and loss ---
    dataset = Dataset('MNIST', batch_per_device, N, N, torch.float32, threads=GPUCount, preload=10)
    dataset.padding(surface_ratio=dataset_scale)
    loss_function = LossLinearCombination(Normalizable.CrossEntropy(ce_norm), Normalizable.MeanSquareError(mse_norm))
    loss_function.proportions(*loss_proportion)

    # --- training ---
    torch.cuda.empty_cache()
    SelectedGPUs.exclude(*ExcludedGPUs)
    # Presumably (models, losses, confusion matrices) histories — verify against cluster.epochs.
    _mh, lh, cmh = cluster.epochs(epochs, 10, model, dataset, loss_function, optimizer, lr=learning_rate)

    # Accuracy in percent = diagonal (correct predictions) over the whole confusion matrix.
    accuracies = [100*numpy.sum(numpy.diagonal(confusion, 0)) / numpy.sum(confusion) for confusion in cmh]
    best_accuracy = max(accuracies)

    # --- persist the results on the trial as JSON strings ---
    trial.set_user_attr('accuracy', json.dumps(best_accuracy))
    trial.set_user_attr('accuracies_history', json.dumps(accuracies))
    trial.set_user_attr('loss_history', json.dumps(numpy.concatenate(lh).tolist()))
    trial.set_user_attr('confusion_matrixes', json.dumps(numpy.stack(cmh).tolist()))

    return best_accuracy
    
# Open the study named "Default" in the local SQLite file DNL.db, creating it if it
# does not exist yet (load_if_exists=True); larger objective values are better.
study = optuna.create_study(study_name="Default", storage="sqlite:///DNL.db", direction='maximize', load_if_exists=True)

[I 2024-10-07 12:29:23,254] Using an existing study with name 'Default' instead of creating a new one.


In [8]:
# Display the 20 trials with the largest objective value.
study.trials_dataframe().nlargest(20, 'value')

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_N,params_batch_per_device,params_ce_norm,params_dataset_scale,params_detectors_norm,params_devices,params_distance,params_epochs,params_learning_rate,params_length,params_loss_proportion_ce,params_loss_proportion_mse,params_masks,params_mse_norm,params_optimizer,params_pixels,params_wavelength,user_attrs_accuracies_history,user_attrs_accuracy,user_attrs_confusion_matrixes,user_attrs_loss_history,state
164,164,97.653333,2024-10-05 10:49:57.281176,2024-10-05 12:18:14.755372,0 days 01:28:17.474196,5000,3,none,0.135857,none,8,0.482149,5,0.087106,0.01,0.979408,0.589449,1,minmax,RMSProp,5000,5e-07,"[9.035, 94.06, 95.71333333333334, 97.178333333...",97.65333333333334,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2217464064.0, 25510354944.0, 8448890880.0, 24...",COMPLETE
151,151,97.583333,2024-10-04 09:31:43.417224,2024-10-04 11:44:33.119649,0 days 02:12:49.702425,5000,4,none,0.119024,none,8,0.478326,5,0.105052,0.01,0.293021,0.728026,1,minmax,RMSProp,5000,5e-07,"[9.035, 93.83666666666667, 96.015, 96.55333333...",97.58333333333331,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[955066880.0, 13313553408.0, 2981218304.0, 507...",COMPLETE
152,152,97.583333,2024-10-04 11:44:33.223791,2024-10-04 14:14:16.380811,0 days 02:29:43.157020,5000,4,none,0.117693,none,8,0.478427,5,0.095319,0.01,0.992287,0.716175,1,minmax,RMSProp,5000,5e-07,"[9.035, 93.83666666666667, 96.015, 96.55333333...",97.58333333333331,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[955066880.0, 13313553408.0, 2981218304.0, 507...",COMPLETE
161,161,97.44,2024-10-05 04:52:27.929562,2024-10-05 07:07:36.180469,0 days 02:15:08.250907,5000,3,none,0.142703,none,8,0.483557,5,0.091139,0.01,0.98778,0.607162,1,minmax,RMSProp,5000,5e-07,"[9.035, 91.99666666666667, 96.54, 97.121666666...",97.44,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2222863104.0, 27598155776.0, 9274755072.0, 24...",COMPLETE
160,160,97.421667,2024-10-05 03:23:20.986505,2024-10-05 04:52:27.798845,0 days 01:29:06.812340,5000,3,none,0.145098,none,8,0.479877,5,0.090086,0.01,0.974575,0.59888,1,minmax,RMSProp,5000,5e-07,"[9.035, 94.875, 96.08666666666667, 96.70333333...",97.42166666666668,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2188310528.0, 27997859840.0, 9448491008.0, 25...",COMPLETE
180,180,97.411667,2024-10-06 13:53:50.337082,2024-10-06 15:23:53.881853,0 days 01:30:03.544771,5000,3,none,0.136382,none,8,0.46939,5,0.122151,0.01,0.942358,0.588526,1,minmax,RMSProp,5000,5e-07,"[9.035, 94.02333333333333, 96.43833333333333, ...",97.41166666666666,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2188608000.0, 27549810688.0, 11313104896.0, 1...",COMPLETE
182,182,97.411667,2024-10-07 05:23:49.556601,2024-10-07 10:18:09.928436,0 days 04:54:20.371835,4790,3,none,0.136335,none,8,0.468183,6,0.107295,0.01,0.937367,0.522439,1,minmax,RMSProp,4790,5e-07,"[9.035, 94.02333333333333, 96.43833333333333, ...",97.41166666666666,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2188608000.0, 27549810688.0, 11313104896.0, 1...",COMPLETE
159,159,97.408333,2024-10-05 01:53:13.730805,2024-10-05 03:23:20.885968,0 days 01:30:07.155163,5000,3,none,0.146909,none,8,0.482887,5,0.083966,0.01,0.935809,0.598952,1,minmax,RMSProp,5000,5e-07,"[9.035, 92.155, 94.585, 96.68666666666667, 97....",97.40833333333332,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2137466368.0, 26430613504.0, 8646926336.0, 25...",COMPLETE
163,163,97.391667,2024-10-05 08:36:30.287793,2024-10-05 10:49:57.173049,0 days 02:13:26.885256,5000,3,none,0.138247,none,8,0.48322,5,0.101378,0.01,0.999821,0.606717,1,minmax,RMSProp,5000,5e-07,"[9.035, 93.915, 96.17666666666666, 97.17, 97.3...",97.39166666666668,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2234833664.0, 27626989568.0, 10069804032.0, 2...",COMPLETE
168,168,97.315,2024-10-05 17:34:59.708849,2024-10-05 19:03:16.830678,0 days 01:28:17.121829,5000,3,none,0.138075,none,8,0.476995,5,0.127077,0.01,0.892748,0.57215,1,minmax,RMSProp,5000,5e-07,"[9.035, 94.21666666666667, 96.37833333333333, ...",97.315,"[[[0.0, 0.0, 0.0, 0.0, 0.0, 740.375, 0.0, 0.0,...","[2161636864.0, 27020216320.0, 11846311936.0, 1...",COMPLETE


In [9]:
# Print the per-evaluation accuracy history of the 20 best trials, one line per trial.
acc_hists = study.trials_dataframe().nlargest(20, 'value')['user_attrs_accuracies_history']
for i, acc_hist in enumerate(acc_hists):
    # The history is stored on the trial as a JSON string.
    acc_hist = json.loads(acc_hist)
    if isinstance(acc_hist, list):
        summary = ', '.join(f"{round(acc,1)}" for acc in acc_hist)
    else:
        summary = f"{acc_hist}"
    print(f"model {i}: " + summary)

model 0: 9.0, 94.1, 95.7, 97.2, 97.7, 34.4
model 1: 9.0, 93.8, 96.0, 96.6, 97.6, 97.6, 6.7
model 2: 9.0, 93.8, 96.0, 96.6, 97.6, 97.6, 6.7
model 3: 9.0, 92.0, 96.5, 97.1, 97.4, 97.4, 17.9
model 4: 9.0, 94.9, 96.1, 96.7, 97.4, 24.6
model 5: 9.0, 94.0, 96.4, 97.1, 97.4, 31.2
model 6: 9.0, 94.0, 96.4, 97.1, 97.4, 31.2
model 7: 9.0, 92.2, 94.6, 96.7, 97.4, 9.7
model 8: 9.0, 93.9, 96.2, 97.2, 97.4, 97.4, 24.0
model 9: 9.0, 94.2, 96.4, 96.7, 97.3, 22.0
model 10: 9.0, 94.2, 96.4, 96.7, 97.3, 22.0
model 11: 9.0, 92.8, 96.2, 96.9, 97.2, 7.6
model 12: 9.0, 86.8, 96.0, 96.2, 97.2, 8.6
model 13: 9.0, 94.2, 96.1, 96.7, 97.2, 19.1
model 14: 9.0, 94.7, 96.2, 97.0, 97.1, 12.7
model 15: 9.0, 94.2, 96.0, 97.0, 97.1, 97.1, 19.2
model 16: 9.0, 94.2, 96.0, 97.0, 97.1, 97.1, 19.2
model 17: 9.0, 94.2, 96.0, 97.0, 97.1, 97.1, 19.2
model 18: 9.0, 95.2, 95.1, 95.7, 97.1, 97.1, 6.2
model 19: 9.0, 95.2, 95.1, 95.7, 97.1, 97.1, 6.2


# Оптимизация

In [None]:
# Run 10 trials of `objective` and record them in the study.
study.optimize(objective, n_trials=10)

Текущие параметры эксперимента: {'wavelength': 5e-07, 'N': 4781, 'length': 0.004781, 'pixels': 4781, 'distance': 0.4675132915029152, 'masks': 1, 'detectors_norm': 'none', 'epochs': 5, 'devices': 8, 'batch_per_device': 3, 'dataset_scale': 0.14309876062090318, 'loss_proportion_ce': 0.9447587344040843, 'loss_proportion_mse': 0.5904876795163942, 'ce_norm': 'none', 'mse_norm': 'minmax', 'optimizer': 'RMSProp', 'learning_rate': 0.1277567442966427}
Training main thread PID is: 3486570
  4%|█▊                                      | 110/2500 [00:58<11:14,  3.54it/s]

In [None]:
# Run a further batch of 10 trials on the same study.
study.optimize(objective, n_trials=10)

In [None]:
# Run a further batch of 10 trials on the same study.
study.optimize(objective, n_trials=10)