In [1]:
%matplotlib qt

from sklearnex import patch_sklearn
patch_sklearn()
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("../data/qbo_data.csv",index_col="run_id")
df = df.sort_values("qbo_period_std")


Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


## Constants

In [2]:
OBS_MEAN = 27.475168565819082
OBS_STD = 3.9921132486687365

OBS_AMP_MEAN = 26.68333333333334
OBS_AMP_STD = 2.7605353748060457

CURRENT_CW = 35
CURRENT_BT = 0.0043

In [3]:
import numpy as np
X = df[["Bt","cw"]].to_numpy()
y_point_error = df[["qbo_period_std","qbo_amplitude_std"]].to_numpy()
y = df[["qbo_periods","qbo_amplitude_mean"]].to_numpy()

mean_y = np.mean(y,axis=0)
std_y = np.std(y,axis=0)
y_norm = (y - mean_y)/std_y
y_point_error = y_point_error/std_y

In [4]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)




In [5]:
from sklearn.gaussian_process import GaussianProcessRegressor

gpr_period = GaussianProcessRegressor(alpha=(y_point_error[:,0])**2)
gpr_amplitude = GaussianProcessRegressor(alpha=(y_point_error[:,1])**2)
gpr_period.fit(X_scaled,y_norm[:,0])
gpr_amplitude.fit(X_scaled,y_norm[:,1])


GaussianProcessRegressor(alpha=array([0.        , 0.03547776, 0.03204514, 0.02129178, 0.03661212,
       0.09994381, 0.04752571, 0.02339978, 0.06869182, 0.04361435,
       0.02058369, 0.0330511 , 0.029952  , 0.02990572, 0.03962432,
       0.04181666, 0.02632595, 0.03740381, 0.11036499, 0.05486589,
       0.11988828, 0.01287535, 0.07012233, 0.04252602, 0.01810179,
       0.02959773, 0.03893115, 0.03611464, 0.02273579, 0.05...
       0.03482906, 0.06321747, 0.124354  , 0.07802362, 0.42698006,
       0.05888453, 0.17278043, 0.05400357, 0.50627278, 0.11065289,
       0.19951574, 0.0018386 , 0.24543738, 0.11524   , 0.13536967,
       0.05538973, 0.58399116, 0.46842607, 0.11437035, 0.37871673,
       0.05733359, 0.45538299, 0.07037729, 0.04718549, 0.07573725,
       0.08398727, 0.0340759 , 0.04177656, 0.04106591, 0.60142414,
       0.14445963, 0.09313113, 0.0329861 , 0.06683739, 0.13460253]))

In [6]:
import numpy as np
cw = np.linspace(5,70,1000)
Bt = np.ones(len(cw))*0.0043
x_samples = np.array([Bt,cw]).T
x_samples = scaler.transform(x_samples)



In [7]:
samples,samples_std = gpr_period.predict(x_samples,return_std=True)
samples = samples*std_y[0] + mean_y[0]

In [8]:
# import matplotlib.pyplot as plt
# fig,(ax1,ax2)= plt.subplots(nrows=1,ncols=2,figsize=(10,6))

# ax1.plot(cw,samples)
# ax1.fill_between(
#     cw,
#     samples - 1.96*samples_std*std_y[0],
#     samples + 1.96*samples_std*std_y[0],
#     color="tab:blue",
#     alpha=0.5,
#     label=r"95% confidence interval",
# )
# ax1.axhline(OBS_MEAN,color='black',label="True Observation")
# ax1.fill_between(cw,OBS_MEAN-OBS_STD,OBS_MEAN+OBS_STD,color='black',alpha=0.3)
# ax1.set_xlabel("Cw tropics")
# ax1.set_ylabel("QBO Period (months)")
# ax1.errorbar(X[:,1],y[:,0],yerr=y_point_error[:,0]*std_y[0],fmt='o',label="MiMA runs")
# ax1.set_ylim(10,40)
# ax1.legend()

# ax2.plot(cw,implausability)
# ax2.axhline(np.percentile(implausability,20),c='g',label=f"20th percentile implausability cut off ")
# ax2.legend()
# ax2.set_ylabel("1D Implausability")
# ax2.set_xlabel("cw")

### 3D Plots?

In [11]:
cw = np.linspace(5,70,100)
Bt = np.linspace(0.001,0.007,100)
X_Bt,Y_cw = np.meshgrid(Bt,cw)
x_samples = np.vstack((X_Bt.ravel(),Y_cw.ravel())).T
x_samples = scaler.transform(x_samples)


In [74]:
X_Bt[0]

array([0.001     , 0.00106061, 0.00112121, 0.00118182, 0.00124242,
       0.00130303, 0.00136364, 0.00142424, 0.00148485, 0.00154545,
       0.00160606, 0.00166667, 0.00172727, 0.00178788, 0.00184848,
       0.00190909, 0.0019697 , 0.0020303 , 0.00209091, 0.00215152,
       0.00221212, 0.00227273, 0.00233333, 0.00239394, 0.00245455,
       0.00251515, 0.00257576, 0.00263636, 0.00269697, 0.00275758,
       0.00281818, 0.00287879, 0.00293939, 0.003     , 0.00306061,
       0.00312121, 0.00318182, 0.00324242, 0.00330303, 0.00336364,
       0.00342424, 0.00348485, 0.00354545, 0.00360606, 0.00366667,
       0.00372727, 0.00378788, 0.00384848, 0.00390909, 0.0039697 ,
       0.0040303 , 0.00409091, 0.00415152, 0.00421212, 0.00427273,
       0.00433333, 0.00439394, 0.00445455, 0.00451515, 0.00457576,
       0.00463636, 0.00469697, 0.00475758, 0.00481818, 0.00487879,
       0.00493939, 0.005     , 0.00506061, 0.00512121, 0.00518182,
       0.00524242, 0.00530303, 0.00536364, 0.00542424, 0.00548

array([[-1.73387883, -1.7295298 ],
       [-1.69884069, -1.7295298 ],
       [-1.66380256, -1.7295298 ],
       ...,
       [ 1.66482008,  1.73161175],
       [ 1.69985821,  1.73161175],
       [ 1.73489634,  1.73161175]])

In [13]:
period,period_std = gpr_period.predict(x_samples,return_std=True)
period = period*std_y[0] + mean_y[0]
period = period.reshape((100,100))
period_std = (period_std*std_y[0]).reshape((100,100))


amplitude,amplitude_std = gpr_amplitude.predict(x_samples,return_std=True)
amplitude = amplitude*std_y[1] + mean_y[1]
amplitude = amplitude.reshape((100,100))
amplitude_std = (amplitude_std*std_y[1]).reshape((100,100))

In [14]:
implausability_period = np.abs(period - OBS_MEAN)/np.sqrt(OBS_STD**2 + period_std**2)
implausability_amplitude = np.abs(amplitude - OBS_AMP_MEAN)/np.sqrt(OBS_AMP_STD**2 + amplitude_std**2)
implausability = np.sqrt(implausability_amplitude**2 + implausability_period**2)

In [15]:
imp_cutoff = implausability < np.percentile(implausability,20)

In [44]:
fig,(ax1,ax2,ax3) = plt.subplots(1,3,figsize=(14,6))
CS = ax1.contour(X_Bt,Y_cw,amplitude,levels=25)
ax1.scatter(X[:,0],X[:,1],)
ax1.scatter(CURRENT_BT,CURRENT_CW)
ax1.clabel(CS, CS.levels, inline=True,  fontsize=10)
ax1.set_ylabel("Cw")
ax1.set_xlabel("Bt")

CS2 = ax2.contour(X_Bt,Y_cw,period,levels=25)
ax2.scatter(X[:,0],X[:,1],)
ax2.scatter(CURRENT_BT,CURRENT_CW)
ax2.clabel(CS2, CS2.levels, inline=True,  fontsize=10)
ax2.set_ylabel("Cw")
ax2.set_xlabel("Bt")

ax3.contour(X_Bt,Y_cw,implausability,levels=25)
ax3.contourf(X_Bt,Y_cw,imp_cutoff,alpha=0.3,cmap='RdYlGn')
ax3.scatter(X[:,0],X[:,1])
ax3.scatter(CURRENT_BT,CURRENT_CW)
ax3.set_ylabel("Cw")
ax3.set_xlabel("Bt")

Text(0.5, 0, 'Bt')

In [38]:
from mpl_toolkits.mplot3d import axes3d
ax = plt.figure(figsize=(10,10)).add_subplot(projection='3d')
ax.plot_surface(X_Bt,Y_cw,implausability,cmap='viridis')
ax.set_xlabel("Bt")
ax.set_ylabel("Cw")
ax.set_zlabel("Implausability")

Text(0.5, 0, 'Implausability')

## Calculate Decision Boundary + Sampling

In [49]:
y,x=np.random.uniform(low=(5,0.001),high=(70,0.007))

In [63]:
period

array([[25.59373668, 25.59188864, 25.58845833, ..., 22.09159004,
        22.01381678, 21.94097443],
       [25.55345946, 25.55023023, 25.54542964, ..., 21.99618334,
        21.91717106, 21.84319174],
       [25.51132435, 25.50672095, 25.5005612 , ..., 21.90166711,
        21.8214926 , 21.74644851],
       ...,
       [ 5.28086901,  5.06217567,  4.88221605, ..., 22.73822318,
        22.67466159, 22.62224192],
       [ 5.31481347,  5.09524433,  4.91417321, ..., 22.80277583,
        22.73590373, 22.68015452],
       [ 5.38040687,  5.16047696,  4.9787606 , ..., 22.86723771,
        22.7972376 , 22.7383281 ]])

In [19]:
def get_new_samples(new_sample_space,X,Y,n_samples=100):
    """
    MCMC sampling.
    Probably not very
    Dimensions of new_sample_space,X,Y, must match
    """
    ## Calculating bounding box of new_sample_space.
    y_where,x_where = np.where(new_sample_space)
    x_min,x_max,y_min,y_max = X_Bt[0,np.min(x_where)],X_Bt[0,np.max(x_where)],Y_cw[np.min(y_where),0],Y_cw[np.max(y_where),0]
    new_samples = []
    while len(new_samples) < n_samples:
        # Sample random point in original sample space 
        x,y = np.random.uniform(low=(x_min,y_min),high=(x_max,y_max))
        row = np.argmin(np.abs(X-x),axis=1)[1]
        col = np.argmin(np.abs(Y-y),axis=0)[0]
        if new_sample_space[row,col]:
            new_samples.append((x,y))
        # Check sample point meets plausability criterion
    return np.array(new_samples )


(array([[0.001     , 0.00106061, 0.00112121, ..., 0.00687879, 0.00693939,
         0.007     ],
        [0.001     , 0.00106061, 0.00112121, ..., 0.00687879, 0.00693939,
         0.007     ],
        [0.001     , 0.00106061, 0.00112121, ..., 0.00687879, 0.00693939,
         0.007     ],
        ...,
        [0.001     , 0.00106061, 0.00112121, ..., 0.00687879, 0.00693939,
         0.007     ],
        [0.001     , 0.00106061, 0.00112121, ..., 0.00687879, 0.00693939,
         0.007     ],
        [0.001     , 0.00106061, 0.00112121, ..., 0.00687879, 0.00693939,
         0.007     ]]),
 array([[ 5.        ,  5.        ,  5.        , ...,  5.        ,
          5.        ,  5.        ],
        [ 5.65656566,  5.65656566,  5.65656566, ...,  5.65656566,
          5.65656566,  5.65656566],
        [ 6.31313131,  6.31313131,  6.31313131, ...,  6.31313131,
          6.31313131,  6.31313131],
        ...,
        [68.68686869, 68.68686869, 68.68686869, ..., 68.68686869,
         68.68686869, 68