In [None]:
from mssm.models import *
from mssmViz.sim import *
from mssmViz.plot import *
import pickle
import copy
import os
import time
from src.utils import GAMLSSGENSMOOTHFamily,llk_gamm_fun,init_lambda

# Figure widths are specified in centimetres and divided by 2.54 (cm per inch),
# since they are later passed to matplotlib's ``figsize``.
size_conv = 2.54
single_width = 6/size_conv
double_width = 12/size_conv
full_width = 19/size_conv

should_plot = True # Whether plots should be generated to visually inspect simulation results

# Every directory the simulations write into, relative to the working directory.
RESULT_DIRS = (
    "results",
    "results/sim",
    "results/data",
    "results/plots",
    "results/data/sim1",
    "results/data/sim2",
    "results/data/sim3",
    "results/data/sim1/plots",
    "results/data/sim3/plots",
    "results/data/sim4",
    "results/data/sim5",
)


def make_result_dirs(root="."):
    """Create all result directories below ``root``.

    Each directory is created independently with ``exist_ok=True``, so a
    partially existing tree is completed instead of the first already-existing
    directory aborting the creation of all remaining ones. Genuine errors
    (e.g. missing permissions) propagate as ``OSError`` rather than being
    reported as "already exists".

    :param root: Directory below which the ``results`` tree is created.
    """
    for rel_dir in RESULT_DIRS:
        os.makedirs(os.path.join(root, rel_dir), exist_ok=True)


make_result_dirs()

In [None]:
############################# Simulation 3 #############################
# For every family, simulate `n_sim` data sets and fit the same two-parameter
# (mean + scale) model twice: once with GAMMLSS ("efs") and once with GSMM using
# the qEFS update. For both fits the MSE of the linear predictors, the number of
# outer iterations, the CPU time and a failure flag are recorded, and the
# in-progress results are pickled after every simulated data set.
n_sim = 100

sim_fams = [GAUMLSS([Identity(),LOG()]),GAMMALS([LOG(),LOG()])]
mod_fams = [GAUMLSS([Identity(),LOGb(-0.01)]),GAMMALS([LOG(),LOGb(-0.01)])]
init_fams = [Gaussian(),Gamma()]
fam_names = ["GAULS", "GAMMALS"]

for fam_name, mod_fam, init_fam, sim_fam in zip(fam_names,mod_fams,init_fams,sim_fams):

    # Set up storage for current sim. Second axis indexes the model
    # (0 = efs model, 1 = qEFS model); third axis of eta_mses indexes the
    # linear predictor (0 = mean, 1 = scale).
    eta_mses = np.zeros((n_sim,2,2))
    n_lam_updt = np.zeros((n_sim,2))
    Failures = np.zeros((n_sim,2))
    timings = np.zeros((n_sim,2))

    mod_fam.init_lambda = init_lambda

    iterator = tqdm(range(n_sim),desc="Simulating",leave=True)
    for sim_i in iterator:

        sim_dat = sim12(5000,c=0,seed=sim_i,family=sim_fam,n_ranef=20)
        sim_dat.to_csv(f'./results/data/sim3/sim_size:{n_sim}_fam:{fam_name}_set:{sim_i}.csv',index=False)

        # Formula for the mean: intercept, smooths of x0 and x1, and an
        # `fs` term of x0 with rf="x4" (presumably per-level random smooths).
        sim_formula_m = Formula(lhs("y"),
                            [i(),f(["x0"]),f(["x1"]),fs(["x0"],rf="x4")],
                            data=sim_dat)

        # Formula for the scale parameter: intercept plus smooths of x2 and x3.
        sim_formula_sd = Formula(lhs("y"),
                            [i(),f(["x2"]),f(["x3"])],
                            data=sim_dat)

        # Collect both formulas
        sim_formulas = [sim_formula_m,sim_formula_sd]

        sim_i_failed = [False, False]

        # Reset the timing variables for this data set. Without this, a fit
        # that raises before its timers are set would either trigger a
        # NameError when the lists below are built (first iteration) or
        # silently carry over the timings of the previous data set.
        start_time1 = end_time1 = np.nan
        start_time_gsmm2 = end_time_gsmm2 = np.nan

        ############################# Fit model with efs #############################
        try:
            efs_model = GAMMLSS(copy.deepcopy(sim_formulas),copy.deepcopy(mod_fam))
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                start_time1 = time.process_time()
                efs_model.fit(seed=sim_i,max_outer=200,max_inner=500,min_inner=500,method="LU/Chol",progress_bar=False,extend_lambda=False,should_keep_drop=False,repara=True,n_cores=1,prefit_grad=False)
                end_time1 = time.process_time()
        except Exception:
            # Any fitting error counts as a failure for this data set;
            # KeyboardInterrupt/SystemExit still propagate so the run can be stopped.
            sim_i_failed[0] = True
            efs_model = None

        ############################# Fit models via likelihood/gradient only #############################
        try:
            gsmm_fam2 = GAMLSSGENSMOOTHFamily(2,copy.deepcopy(mod_fam.links),llk_gamm_fun,copy.deepcopy(mod_fam))
            gsmm_fam2.init_lambda = init_lambda
            # Start all coefficients of both formulas at the same small constant.
            gsmm_fam2.init_coef = lambda models: np.array([1e-4 for _ in range(models[0].formula.n_coef + models[1].formula.n_coef)]).reshape(-1,1)

            gsmm_model2 = GSMM(formulas=copy.deepcopy(sim_formulas),family=gsmm_fam2)

            # Fit with qEFS update without initialization
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # Extra keyword arguments forwarded to `fit` below.
                bfgs_opt={"gtol":1e-9,
                              "ftol":1e-9,
                              "maxcor":30,
                              "maxls":200,
                              "maxfun":1e7}

                start_time_gsmm2 = time.process_time()
                gsmm_model2.fit(init_coef=None,method='qEFS',extend_lambda=False,
                                control_lambda=False,max_outer=200,max_inner=500,min_inner=500,
                                seed=sim_i,qEFSH='SR1',overwrite_coef=False,qEFS_init_converge=False,prefit_grad=True,
                                progress_bar=False,repara=True,n_cores=1,**bfgs_opt)
                end_time_gsmm2 = time.process_time()

            # Get linear predictors for gsmm_model: one model matrix times
            # coefficient block per formula.
            split_coef = np.split(gsmm_model2.overall_coef,gsmm_model2.coef_split_idx)
            Xs = gsmm_model2.get_mmat()
            gsmm_model2.overall_preds = [Xs[xi]@split_coef[xi].reshape(-1,1) for xi in range(len(Xs))]

            gsmm_model2.info.eps = 0 # Not used but set to None, which messes up loop below
        except Exception:
            # Same policy as for the efs model above.
            sim_i_failed[1] = True
            gsmm_model2 = None

        ######################################## Collect MSEs ####################################
        models = [efs_model,gsmm_model2]
        start_times = [start_time1,start_time_gsmm2]
        end_times = [end_time1,end_time_gsmm2]

        if should_plot:
            # One row of axes per model: 3 panels for the mean, 2 for the scale.
            fig = plt.figure(figsize=(full_width,1.25*single_width),layout='constrained')
            axs = fig.subplots(2,5,gridspec_kw=dict(wspace=0.05,hspace=0.01))

        for mi, model in enumerate(models):

            if sim_i_failed[mi]:
                print(f"Model {mi+1} failed at {sim_i}")
                Failures[sim_i,mi] = 1
                n_lam_updt[sim_i,mi] = np.nan
                eta_mses[sim_i,mi] = np.nan
                continue

            if should_plot:
                plot(model,dist_par=0,axs=axs[mi][0:3])
                plot(model,dist_par=1,axs=axs[mi][3:])

            # Not converged but not failed outright
            if model.info.code > 0:
                Failures[sim_i,mi] = 1

            n_lam_updt[sim_i,mi] = model.info.iter
            timings[sim_i,mi] = end_times[mi] - start_times[mi]

            pred_diff_mean = model.overall_preds[0].flatten() - sim_dat["eta_mean"].values
            # The scale predictor is passed through `links[1].fi` (presumably the
            # inverse of the model's LOGb link) and then logged before being
            # compared against the simulated "eta_scale" column.
            pred_diff_scale = np.log(mod_fam.links[1].fi(model.overall_preds[1])).flatten() - sim_dat["eta_scale"].values

            eta_mses[sim_i,mi,0] = np.dot(pred_diff_mean,pred_diff_mean)/len(pred_diff_mean)
            eta_mses[sim_i,mi,1] = np.dot(pred_diff_scale,pred_diff_scale)/len(pred_diff_scale)

        if should_plot:
            plt.savefig(f"./results/data/sim3/plots/sim_size:{n_sim}_fam:{fam_name}_set:{sim_i}.pdf", format="pdf", bbox_inches='tight')
            plt.close(fig)

        # Show running mean/sd of the mean-predictor MSE per model in the progress bar.
        iterator.set_description_str(desc=f"MSE. (mean): {[(float(np.round(m,decimals=4)),float(np.round(sd,decimals=2))) for m,sd in zip(np.mean(eta_mses[:(sim_i+1),:,0],axis=0),np.std(eta_mses[:(sim_i+1),:,0],axis=0))]}", refresh=True)
        ###################################### Save in progress results ######################################
        res = {"eta_mses":eta_mses,
               "n_lam_updt":n_lam_updt,
               "timings":timings,
               "Failures":Failures
               }

        with open(f'./results/sim/sim_3_size:{n_sim}_fam:{fam_name}.pickle', 'wb') as file:
            pickle.dump(res,file, protocol=pickle.HIGHEST_PROTOCOL)

    iterator.close()