In [None]:
import tensorflow.keras as k
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import neural_net_project_auxiliary as nn
DM = nn.DataManipulator
import plotly.graph_objs as go
from ipywidgets import interact, interactive, FloatSlider, VBox
import ipywidgets as widgets
from IPython.display import display, clear_output
%matplotlib widget

In [None]:
# Locate the trained-model directory and load the spectrum-predicting networks.
# NOTE(review): this walks the *entire* filesystem from "/", which can be very
# slow; consider searching from a configurable project root instead.
os.chdir("/")
found_dir_path = None
for root, dirs, files in os.walk('.'):
    if 'models' in dirs:
        # Keep overwriting: the last match in walk order wins.
        found_dir_path = os.path.join(root, 'models')
if found_dir_path is None:
    raise FileNotFoundError("Could not find a directory named 'models' under '/'.")
os.chdir(found_dir_path)

# model_array[ss, tt]: ss indexes the spectrum type ('abs', 'fl') and tt the
# temperature ('77K', '290K'), the same layout used for the statistics arrays.
# The object array is filled element-wise instead of via np.array([...]) so
# numpy never tries to interpret the model objects themselves as sequences.
model_array = np.empty((2, 2), dtype=object)
for ss, spec in enumerate(['abs', 'fl']):
    for tt, temp in enumerate(['77K', '290K']):
        model_path = f'final_{spec}_spectra_from_params_7_pigs_dis_120_{temp}'
        # Absorption models are required (load_model raises if missing).
        # Fluorescence models are loaded only when present; otherwise the slot
        # stays None. NOTE(review): the fluorescence model names are assumed
        # to follow the absorption naming pattern -- confirm.
        if spec == 'abs' or os.path.exists(model_path):
            model_array[ss, tt] = tf.keras.models.load_model(model_path, compile=False)

# Keep the original per-model names available to later cells.
model_abs_77K, model_abs_290K = model_array[0, 0], model_array[0, 1]

In [None]:
# Locate the directory holding the normalisation statistics and load them.
# NOTE(review): this walks the whole filesystem from "/" -- slow; consider a
# configurable project root.
os.chdir("/")
# Reset first: otherwise a failed search would silently reuse the path found
# by an earlier cell and chdir into the wrong directory.
found_dir_path = None
for root, dirs, files in os.walk('.'):
    if 'statistical_props' in dirs:
        # Last match in walk order wins.
        found_dir_path = os.path.join(root, 'statistical_props')
if found_dir_path is None:
    raise FileNotFoundError("Could not find a directory named 'statistical_props' under '/'.")
os.chdir(found_dir_path)

# Object arrays indexed [ss, tt]: ss over ('abs', 'fl'), tt over ('77K', '290K').
# Each entry holds the mean / std array used to normalise the network inputs
# (params) or de-normalise its outputs (spectra) for that spectrum type and
# temperature.
par_means = np.empty((2, 2), dtype=object)
par_stds = np.empty((2, 2), dtype=object)
spec_means = np.empty((2, 2), dtype=object)
spec_stds = np.empty((2, 2), dtype=object)
for ss, spec in enumerate(['abs', 'fl']):
    for tt, temp in enumerate(['77K', '290K']):
        par_means[ss, tt] = np.load(f'params_mean_7pigs_dis120_{spec}_{temp}.npy')
        spec_means[ss, tt] = np.load(f'spectra_mean_7pigs_dis120_{spec}_{temp}.npy')
        par_stds[ss, tt] = np.load(f'params_std_7pigs_dis120_{spec}_{temp}.npy')
        spec_stds[ss, tt] = np.load(f'spectra_std_7pigs_dis120_{spec}_{temp}.npy')

In [None]:
# Locate the data directory (the parent of 'statistical_props') and load the
# frequency axes, the FMO Hamiltonian and the pigment dipole moments.
os.chdir("/")
# Reset first so a failed search cannot reuse a path found by an earlier cell.
found_dir_path = None
for root, dirs, files in os.walk('.'):
    if 'statistical_props' in dirs:
        # Last match in walk order wins.
        found_dir_path = os.path.join(root, 'statistical_props')
if found_dir_path is None:
    raise FileNotFoundError("Could not find a directory named 'statistical_props' under '/'.")
os.chdir(found_dir_path)
os.chdir("../")

freq_axis_abs = np.load('abs_frequency_axis.npy')
# NOTE(review): the fluorescence axis is read from the *absorption* axis file;
# confirm this is intended.
freq_axis_fl = np.load('abs_frequency_axis.npy')

# Two independent copies of the same Hamiltonian: the slider callbacks
# overwrite their diagonals in place, so they must not share memory.
Ham_abs = np.load('FMO_Ham.npy')
Ham_fl = Ham_abs.copy()

dipole_moments = np.load('FMO_3eni_dmoms.npy')
# Pairwise dot products of the dipole vectors (assuming dipole_moments is
# (n_pigments, 3) -- TODO confirm), scaled so the largest |entry| is 1.
site_dip_factor = dipole_moments @ dipole_moments.T
site_dip_factor_max = np.amax(np.abs(site_dip_factor))
site_dip_factor /= site_dip_factor_max
number_of_pigments = 7


def make_sliders(n_sliders=7):
    """Create one slider per pigment, each driving one diagonal Hamiltonian entry.

    Parameters
    ----------
    n_sliders : int, optional
        Number of sliders to create; the default 7 matches ``number_of_pigments``.

    Returns
    -------
    list of widgets.FloatSlider
        Sliders labelled "S1", "S2", ...; each starts at 1.0, spans
        [0.0, 2.0] in steps of 0.1, and reports values while being dragged
        (``continuous_update=True``).
    """
    return [
        widgets.FloatSlider(
            value=1.0,
            min=0.0,
            max=2.0,
            step=0.1,
            description=f"S{i + 1}",
            continuous_update=True,
        )
        for i in range(n_sliders)
    ]

# Absorption figure: a single line trace over the absorption frequency axis,
# initialised to zeros and overwritten by the absorption slider callback.
fig1 = go.FigureWidget(
    data=[
        go.Scatter(
            mode="lines",
            x=freq_axis_abs,
            y=np.zeros_like(freq_axis_abs),
        )
    ],
    layout=go.Layout(title=dict(text="Absorption spectra")),
)

# One slider per diagonal element of the absorption Hamiltonian.
sliders1 = make_sliders()

def _hamiltonian_to_params(ham):
    """Build the network input vector from a site-basis Hamiltonian.

    Parameters
    ----------
    ham : ndarray
        Square matrix, treated as symmetric by ``np.linalg.eigh``.

    Returns
    -------
    ndarray
        1-D concatenation of, in order:
        - the exciton energies (eigenvalues of ``ham``);
        - the upper triangle, including the diagonal, of
          ``gamma_aamm = (SS**2).T @ (SS**2)``;
        - the diagonal of ``SS.T @ site_dip_factor @ SS``;
        - the flattened eigenvector matrix ``SS``.
        Presumably this is the feature order the models were trained on, so
        do not reorder it.
    """
    exciton_ens, SS = np.linalg.eigh(ham)
    gamma_aamm = (SS**2).T @ (SS**2)
    ex_dip_factor = np.diag(SS.T @ site_dip_factor @ SS)
    return np.concatenate((exciton_ens,
                           gamma_aamm[np.triu_indices_from(gamma_aamm)].flatten(),
                           ex_dip_factor,
                           SS.flatten()))


def _predict_spectrum(params, ss, tt):
    """Normalise ``params``, run the (ss, tt) network and de-normalise its output.

    ``ss`` indexes the spectrum type (0 = 'abs', 1 = 'fl') and ``tt`` the
    temperature (0 = '77K', 1 = '290K'). These match the indexing of
    ``model_array`` and of the mean/std arrays.

    Returns None when no model is loaded for that slot: either
    ``model_array`` is too small or the entry is None.
    """
    if ss >= model_array.shape[0] or tt >= model_array.shape[1]:
        return None
    model = model_array[ss, tt]
    if model is None:
        return None
    params_norm = DM.stat_normalise(params, mean=par_means[ss, tt], std=par_stds[ss, tt])
    spec_norm = model.predict(np.array([params_norm]), verbose=0)[0]
    return DM.stat_denormalise(spec_norm, mean=spec_means[ss, tt], std=spec_stds[ss, tt])


def update_fig1(change=None):
    """Slider callback: apply the absorption site energies and redraw ``fig1``.

    Writes the ``sliders1`` values onto the diagonal of ``Ham_abs`` in place,
    predicts the 290 K absorption spectrum and replaces the trace in ``fig1``.
    ``change`` is the event passed by ``observe``; it is unused.
    """
    Ham_abs[np.diag_indices_from(Ham_abs)] = [s.value for s in sliders1]
    spectrum_290 = _predict_spectrum(_hamiltonian_to_params(Ham_abs), 0, 1)
    # Only the 290 K spectrum is shown. A 77 K prediction would be
    # _predict_spectrum(params, 0, 0); it is not computed, to avoid an unused
    # model call on every slider move.
    with fig1.batch_update():
        fig1.data[0].y = spectrum_290

for s in sliders1:
    s.observe(update_fig1, names="value")

update_fig1()


fig2 = go.FigureWidget(
    data=[go.Scatter(x=freq_axis_fl, y=np.zeros_like(freq_axis_fl), mode="lines")],
    layout=go.Layout(title="Fluorescence spectra")
)

sliders2 = make_sliders()

def update_fig2(change=None):
    """Slider callback: apply the fluorescence site energies and redraw ``fig2``.

    Writes the ``sliders2`` values onto the diagonal of ``Ham_fl`` in place
    and plots the predicted 290 K fluorescence spectrum. If no 290 K
    fluorescence model is loaded, the figure title reports it instead of
    raising.
    """
    Ham_fl[np.diag_indices_from(Ham_fl)] = [s.value for s in sliders2]
    # Diagonalise Ham_fl (not Ham_abs) so the fluorescence sliders take effect.
    spectrum_290 = _predict_spectrum(_hamiltonian_to_params(Ham_fl), 1, 1)
    with fig2.batch_update():
        if spectrum_290 is None:
            fig2.layout.title.text = "Fluorescence spectra (no 290 K fluorescence model loaded)"
        else:
            fig2.data[0].y = spectrum_290

for s in sliders2:
    s.observe(update_fig2, names="value")

update_fig2()

# Side-by-side panels: absorption sliders + figure on the left,
# fluorescence sliders + figure on the right.
display(
    widgets.HBox(children=[
        widgets.VBox(children=[*sliders1, fig1]),
        widgets.VBox(children=[*sliders2, fig2]),
    ])
)