<div style="text-align: center; padding-top: 30px; padding-bottom: 10px;">

<h1 style="font-size: 2.8em; font-weight: 600; margin-bottom: 0.2em;">
Global Neural Network Model
</h1>

<p style="font-size: 1.2em; color: gray; font-style: italic; margin-top: 0;">
This notebook visualises the main results from the paper.
</p>

</div>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import os
from models.global_model.model_functions.helper_functions.prepare_data import Prepare
from utils import create_pred_input
from scipy.interpolate import griddata
from models import MultivariateModelGlobal as Model       
import yaml


################################################## User Input ##################################################
# Number of best-ranked models (lowest holdout loss, see section 2) to load, plot and average.
n_models = 5
###############################################################################################################

# Fail fast on a bad setting instead of deep inside the model-loading / plotting cells.
if not isinstance(n_models, int) or n_models < 1:
    raise ValueError(f"n_models must be a positive integer, got {n_models!r}")

# Load the model configuration (used below for cfg["data_source"] and to build each model).
# yaml.safe_load (not yaml.load) prevents the file from constructing arbitrary Python objects;
# an explicit encoding keeps the read independent of the platform's default locale.
with open('../results/paper/weights/global_model/config.yaml', "r", encoding="utf-8") as file:
    cfg = yaml.safe_load(file)
print(cfg)



## 1. Data preparation

In [None]:

# Load the raw data and let `Prepare` return the growth target, the climate inputs and the
# summary statistics (means / standard deviations) that the prediction grid below is built from.
data = pd.read_excel('../data/MainData.xlsx')
growth, precip, temp, stats = Prepare(data, data_source=cfg["data_source"])

# Inputs keyed by position and handed to the model factory: 0 -> temperature, 1 -> precipitation.
x_train = dict(((0, temp), (1, precip)))

# Temperature/precipitation grid on which every model is evaluated for plotting.
# The data's mean/std are forwarded so the grid matches the model inputs' scale
# (presumably standardised inside create_pred_input -- TODO confirm).
grid_kwargs = dict(
    mc=False,
    mean_T=stats["mean_temp"],
    std_T=stats["std_temp"],
    mean_P=stats["mean_precip"],
    std_P=stats["std_precip"],
)
pred_input, T, P = create_pred_input(**grid_kwargs)


## 2. Model overview: top models ranked by holdout loss

In [None]:

# Position, within each results entry, of the holdout loss used to rank models
# (it is the value reported as "Holdout" below).
HOLDOUT_IDX = 2

# results.npy holds a 0-d object array wrapping a mapping {node: metrics}; reading it
# requires allow_pickle=True.
# NOTE(review): unpickling can execute arbitrary code -- only load result files from a trusted source.
results = dict(np.load('../results/paper/weights/global_model/results.npy',
                       allow_pickle=True).item())

# Discard nodes without results (stored as None) and keep the n_models with the lowest holdout loss.
results = {node: metrics for node, metrics in results.items() if metrics is not None}
if not results:
    raise ValueError("results.npy contains no evaluated models (all entries are None)")
top_models = sorted(results, key=lambda node: results[node][HOLDOUT_IDX])[:n_models]

# Report the holdout loss of every selected model.
print(f'Top models {top_models}, Holdout: {[results[model][HOLDOUT_IDX] for model in top_models]}')

## 3. Predictions on the temperature–precipitation grid

In [None]:

# --- Build surfaces for each of the top-n models ---------------------
SURFACE_OPACITY = 0.85  # shared by the per-model surfaces and the mean surface

model_surfaces = []  # one plotly Surface per selected model
model_grids = []     # the raw prediction grids as numpy arrays, used for the ensemble mean
for node in top_models:  # top_models already holds at most n_models entries

    # instantiate the model for this node and load its trained parameters
    factory = Model(node, cfg, x_train, growth)
    factory.Depth = len(node)  # depth is set from the length of the node identifier
    model = factory.get_model()

    weight_file = f'../results/paper/weights/global_model/parameters/{node}.weights.h5'
    model.load_params(weight_file)

    # predict on the grid and reshape the output to the shape of the T grid
    Growth = model.model_visual.predict([pred_input]).reshape(T.shape)
    model_grids.append(Growth)

    model_surfaces.append(go.Surface(
        x=T, y=P / 1000, z=Growth,  # precipitation converted from millimetres to metres
        colorscale='Cividis',
        opacity=SURFACE_OPACITY,
        showscale=False,
        name=f'Model {node}',
    ))

# Without this guard an empty selection yields np.mean([]) -> NaN and a confusing reshape error.
if not model_grids:
    raise ValueError("No models were selected -- check n_models and results.npy")

# Ensemble prediction: element-wise mean of the per-model grids (same shape as T).
z = np.mean(model_grids, axis=0)

mean_surface = go.Surface(
    x=T, y=P / 1000, z=z,  # precipitation in metres, matching the per-model surfaces
    colorscale='Cividis',
    opacity=SURFACE_OPACITY,
    showscale=False,
    name='mean_surface',
)




## 4. Data visualization

In [None]:

# Plot only the ensemble-mean surface (the individual model surfaces are kept in model_surfaces).
plot_data = [mean_surface]
fig = go.Figure(data=plot_data)

## 3-d surface plot
# 3-D scene: axis labels, a fixed z-range and a fixed camera so the figure is reproducible.
scene_layout = dict(
    xaxis_title='Temperature (°C)',
    yaxis_title='Precipitation (m)',
    zaxis=dict(title=dict(text="Δ ln(GDP)"), range=[-0.3, 0.3]),
    camera=dict(eye=dict(x=1.738, y=-1.780, z=0.589)),
)
legend_layout = dict(
    bgcolor='rgba(255,255,255,0.7)',
    bordercolor='black',
    borderwidth=1,
)
no_margins = dict(l=0, r=0, b=0, t=0)

fig.update_layout(
    autosize=True,
    margin=no_margins,
    scene=scene_layout,
    legend=legend_layout,
    font=dict(size=10),
)
fig.show()
