# Optimization of instrument designs

In [None]:
from definitions import *
import numpy as np
import pandas as pd
import os 
from util import *
from instrument import *
import matplotlib.pyplot as plt

In [None]:
# Target range for an instrument's delta_range(), in m (10 nm .. 5000 nm).
target_interval = (10e-9, 5000e-9)
# Fitness weights: the logarithmic-scale overlap carries far more weight than
# the linear-scale overlap.
C_log = 0.1
C_linear = 0.0001

# Geometry constants, all in m. (L_max is a parameter of optimize_instrument.)
d_1 = 0.3  # extent of precession device 1; half of it is kept clear around L_1
d_2 = 0.3  # extent of precession device 2; half of it is kept clear around L_2
epsilon = 0.01  # additional safety margin between neighbouring components


def normalize_solution(instance):
    """Repair a candidate ``[L0, L_s, L_1, L_2]`` in place and return it.

    The candidate is adjusted so that:

    * ``L_1 >= L_2`` (the two precession-device distances are swapped if needed);
    * the device centres are at least ``(d_1 + d_2) / 2 + epsilon`` apart
      (up to the final rounding), by moving both devices apart symmetrically;
    * the sample position ``L_s`` does not exceed ``L_2 - d_2 / 2 - epsilon``;
    * every entry is rounded to 4 decimal places.

    Works on any mutable indexable sequence (a list or a numpy row view); the
    very same object is returned.
    """
    # Put the larger precession-device distance at index 2.
    if instance[3] > instance[2]:
        instance[2], instance[3] = instance[3], instance[2]

    # Push the devices apart, half the shortfall each, when they would overlap.
    required_gap = (d_1 + d_2) / 2 + epsilon
    current_gap = instance[2] - instance[3]
    if current_gap <= required_gap:
        shortfall = required_gap - current_gap
        instance[2] = instance[2] + shortfall / 2
        instance[3] = instance[3] - shortfall / 2

    # The sample must sit past the second precession device, with margin.
    sample_limit = instance[3] - d_2 / 2 - epsilon
    if sample_limit < instance[1]:
        instance[1] = sample_limit

    for idx, value in enumerate(instance):
        instance[idx] = round(value, 4)

    return instance

def optimize_instrument(type='wsp', PG=True, L_min = 0.37, L_max = 10):
    # If a PG is used, change permitted L0 range
    if PG:
        quality = 0.01
        # Parameters are L0, L_s, L_1, L_2
        param_space = [{'low': 2.0, 'high': 4.4}, {'low': L_min, 'high': L_max}, {'low': L_min + (d_1 + d_2) / 2, 'high': L_max - d_1/2}, {'low': L_min + d_2/2, 'high': L_max - (d_1 + d_2) / 2}]
        monochrom_name = 'PG'
    else:
        quality = 0.1
        param_space = [{'low': 8.0, 'high': 12.0}, {'low': L_min, 'high': L_max}, {'low': L_min + (d_1 + d_2) / 2, 'high': L_max - d_1/2}, {'low': L_min + d_2/2, 'high': L_max - (d_1 + d_2) / 2}]
        monochrom_name = 'VS'

    def instrument_from_solution(solution, type='wsp'):
        L0 = solution[0] * 1e-10
        L_s = solution[1]
        L_1 = solution[2]
        L_2 = solution[3]
        match type:
            case 'wsp':
                theta_0 = np.deg2rad(45)
                By_min = 0.1e-3
                By_max = 63e-3
            case 'iso':
                theta_0 = np.deg2rad(45)
                By_min = 0.1e-3
                By_max = 15e-3
            case 'foil':
                theta_0 = tune_foil(L0)
                By_min = 0.3e-3
                By_max = 30e-3
        instr = Instrument('', '', type, L0, quality * L0 / FWHM_factor, theta_0, By_min, By_max, L_s, L_1, L_2)
        return instr

    def fitness_func(solution):
        instr = instrument_from_solution(solution, type)
        delta_range = instr.delta_range()
        fitness = log_overlap_percentage(delta_range, target_interval) * C_log + overlap_percentage(delta_range, target_interval) * C_linear
        return fitness
    N = 50000
    N_genes = 4
    # Generate the random array
    population = np.zeros((N, N_genes))

    for i, param in enumerate(param_space):
        population[:, i] = np.random.uniform(low=param['low'], high=param['high'], size=N)
    fitnesses = np.zeros(N)
    for j in range(0,N):
        population[j,:] = normalize_solution(population[j,:])
        fitnesses[j] = fitness_func(population[j,:])
    # print(fitnesses)
    best_id = np.argmax(fitnesses)
    best_sol = population[best_id, :]
    best_instr = instrument_from_solution(best_sol, type)

    best_instr.name = f'{type.upper()} {monochrom_name}'
    return best_instr

instrs = []
# Optimise every precession-device type for each monochromator. The order
# (PG: foil, wsp, iso; then VS: foil, wsp, iso) fixes the row order of `instrs`.
for use_pg, mono_label in ((True, 'pyrolytic graphite'), (False, 'velocity selector')):
    # `instr_type` instead of `type`, so the builtin is not shadowed at module level.
    for instr_type in ('foil', 'wsp', 'iso'):
        print(f"==========Best {instr_type} instrument compatible with {mono_label} monochromator==========")
        best = optimize_instrument(type=instr_type, PG=use_pg, L_max=5.0)
        print(str(best))
        instrs.append(best)

In [None]:
# Echo each optimised instrument using its string representation.
for instr in instrs:
    print(str(instr))

In [None]:
def _delta_limits(instr):
    """Return ``(delta_min, delta_max)`` of *instr*: the tightest of its individual limits.

    The upper limit is the smallest of the B-field, envelope and sampling
    limits; the lower limit is the largest of the B-field and detector limits.
    """
    delta_max_field = instr.delta_max_B_field()
    delta_max_env = instr.delta_max_envelope()
    delta_max_sampling = instr.delta_max_sampling()
    delta_max = min(delta_max_sampling, delta_max_env, delta_max_field)
    delta_min_field = instr.delta_min_B_field()
    delta_min_detector = instr.delta_min_detector()
    delta_min = max(delta_min_detector, delta_min_field)
    return delta_min, delta_max


def _nm(x):
    """Convert a length in m to nm, rounded to 2 decimals."""
    return round(x * 1e9, 2)


# Table 0: geometry rows -- label, L0, DL, L1, L2, LS (LaTeX row format).
print("====TABLE 0======")
for instr in instrs:
    print(f"{instr.name} & {round(instr.L0 * 1e10,2)} & {round(instr.DL * FWHM_factor * 1e10,3)} & {round(instr.L_1, 3)} & {round(instr.L_2, 3)} & {round(instr.L_s, 3)} \\\\")

# Table 1: performance rows -- label, Q_max (1/Angstrom), delta_min, delta_max (nm).
print("====TABLE 1======")
for instr in instrs:
    delta_min, delta_max = _delta_limits(instr)
    Q_max = instr.Q_max()
    print(f"{instr.name} & {round(Q_max * 1e-10, 5)} & {_nm(delta_min)} & {_nm(delta_max)} \\\\")
    