<div style="text-align: center; padding-top: 30px; padding-bottom: 10px;">

<h1 style="font-size: 2.8em; font-weight: 600; margin-bottom: 0.2em;">
Global Neural Network Model
</h1>

<p style="font-size: 1.2em; color: gray; font-style: italic; margin-top: 0;">
This notebook visualises the main results from the paper.
</p>

</div>

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import os
from models.global_model.model_functions.helper_functions.prepare_data import Prepare
from utils import create_pred_input
from scipy.interpolate import griddata
from models import MultivariateModelGlobal as Model       


################################################## User Input ##################################################
n_models = 5 
lr = 0.001                      # Learning rate
min_delta = 1e-6               # Tolerance for optimization
patience = 50                   # Patience for early stopping
verbose = 2                     # Verbosity mode for optimization
###############################################################################################################


2026-02-10 09:57:25.068227: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 09:57:25.301806: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 09:57:26.005250: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 09:57:26.690910: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770717447.349447   15849 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770717447.53

## 1. Data preperation

In [2]:

#load data and create prediction grid
data=pd.read_excel('../data/MainData.xlsx')
growth, precip, temp = Prepare(data, n_countries=196, time_periods=63)
x_train = {0:temp, 1:precip}

#summary statics for standardisation
mean_temp=np.nanmean(data["TempPopWeight"])
std_temp=np.nanstd(data["TempPopWeight"])
mean_precip=np.nanmean(data["PrecipPopWeight"])
std_precip=np.nanstd(data["PrecipPopWeight"])

pred_input, T, P= create_pred_input(mc=False, mean_T=mean_temp, std_T=std_temp, mean_P=mean_precip, std_P=std_precip)


## 2. Model overview

In [4]:

results = dict(np.load(f'../results/paper/weights/global_model/results.npy', 
                       allow_pickle=True).item())

results={k: v for k,v in results.items() if v is not None}
top_models = sorted(results, key=lambda node: results[node][2])[0:n_models]

#print the corresponding holdout loss for all  the top models
print(f'Top models {top_models}, Holdout: {[results[model][2] for model in top_models]}')

Top models [(32, 2), (16, 2), (8, 2), (2, 2), (4, 4)], Holdout: [np.float64(0.004287374671548605), np.float64(0.0042893365025520325), np.float64(0.004292081110179424), np.float64(0.00429602712392807), np.float64(0.004297249484807253)]


## 3. Prediction on Temperature-Precipitation grid

In [6]:

# --- Build surfaces for each of the top-n models ---------------------
model_surfaces = []
for idx, node in enumerate(top_models[0:n_models], 1):
    
    # instantiate and load your model
    factory = Model(node, x_train, growth, dropout=0, country_trends=False, dynamic_model=False, within_transform=True, add_fe=False)
    factory.Depth=len(node)
    model=factory.get_model()
    
    #load parameters
    weight_file = f'../results/paper/weights/global_model/parameters/{node}.weights.h5'
    model.load_params(weight_file)

    #make prediciions on grid
    pred_flat = model.model_visual.predict([pred_input]).reshape(-1,)
    Growth = pred_flat.reshape(T.shape)

    opacity = 0.3
    surf = go.Surface(
        x=T, y=P/1000, z=Growth, #ensure that the surfaces are meassured in meters instead of milimeters
        colorscale='Cividis',
        opacity=0.85,
        showscale=False,
        name=f'Model {node}'
    )
    model_surfaces.append(surf)
 

#calculate the average surface
z=np.mean([surf.z for surf in model_surfaces], axis=0).reshape(T.shape)

mean_surface = go.Surface(
        x=T, y=P/1000, z=z,
        colorscale='Cividis',
        opacity=0.85,
        showscale=False,
        name='mean_surface'
    )




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 960ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 93ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 154ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 104ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 90ms/step


## 4. Data visualization

In [7]:

plot_data=[mean_surface]
fig= go.Figure(data=plot_data)

## 3-d surface plot
fig.update_layout(
        autosize=True,
       
        margin=dict(
            l=0,
            r=0,
            b=0,
            t=0,
        ),

            scene=dict(
                xaxis_title='Temperature (°C)',
                yaxis_title='Precipitation (m)',
                zaxis=dict(title=dict(text="Δ ln(GDP)"),range=[-0.3, 0.3]),
                camera=dict(eye=dict(x=1.738, y=-1.780, z=0.589))
                
            ),
            
            legend=dict(
                bgcolor='rgba(255,255,255,0.7)',
                bordercolor='black',
                borderwidth=1
            ),
            font=dict(
            size=10
        )
        )
fig.show()
