In [None]:



import os
import sys
import time
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt 
from collections import defaultdict 
from types import SimpleNamespace




In [None]:

sys.path.append("../")

from dd_package.data.dyslexia_data import DyslexiaData
from dd_package.data.preprocess import preprocess_data

from dd_package.models.regression_estimators import RegressionEstimators
from dd_package.models.classification_estimators import ClassificationEstimators

from dd_package.common.utils import save_a_dict, load_a_dict, print_the_evaluated_results




In [None]:


# Project directories. The root defaults to the original location on the author's
# machine and can be overridden with the DD_ROOT environment variable, so the notebook
# runs elsewhere without editing four separate absolute paths.
dd_root = Path(os.environ.get("DD_ROOT", "/home/soroosh/Programmes/DD"))

configs = SimpleNamespace(
    models_path=dd_root / "Models",
    results_path=dd_root / "Results",
    figures_path=dd_root / "Figures",
    params_path=dd_root / "Params",
    n_repeats=10,
    n_splits=5,
)

# Experiment selection
estimator_name = "SV_cls"
data_name = "DD_demo"
to_shuffle = True
learning_method = "classification"

# Run identifier, e.g. "DD_demo-SV_cls-True"
specifier = f"{data_name}-{estimator_name}-{to_shuffle}"
configs.specifier = specifier
configs.data_name = data_name
configs.name_wb = f"{data_name}: {specifier}"
configs.learning_method = learning_method
# configs.project = "DD_test"
# configs.group = "debug"
# configs.group = "debug"




In [None]:

# Dataset loader rooted at the local datasets folder.
# NOTE(review): `n_repeats` is hard-coded to 5 here while `configs.n_repeats` is 10;
# confirm which value DyslexiaData is meant to receive.
datasets_dir = "../../datasets/"
dd = DyslexiaData(path=datasets_dir, n_repeats=5)

# Load the demographic datasets; the returned object is kept as `demos`.
demos = dd.get_demo_datasets()




In [None]:

# Combined demographic table across classes (index kept exactly as returned by
# `concat_classes_demo`; no reset_index is applied).
demo = dd.concat_classes_demo()

demo



In [None]:


# Feature/target frames from the demographic table; "Sex" and "Grade" are passed as
# categorical features (presumably one-hot encoded, per the method name) and
# "SubjectID" as an indicator column.
categorical_cols = ["Sex", "Grade"]
indicator_cols = ["SubjectID"]
x_org, y_org = dd.get_onehot_features_targets(
    data=demo,
    c_features=categorical_cols,
    indicators=indicator_cols,
)


In [None]:


# Feature matrix for the clustering and classification cells.
# NOTE(review): reconstructed - this cell previously read `x , x_df = dd.get_`, an
# incomplete expression that raises AttributeError and leaves `x` undefined for every
# later cell. Later cells index `x` positionally (`x[i, :]`), so it must be a 2-D
# array; the tabular view is kept as `x_df`.
# TODO(review): confirm whether the features should also be scaled here (e.g. with
# `preprocess_data(x=..., pp="mm")`) and that `x_org` holds only numeric columns.
x_df = x_org
x = np.asarray(x_org)


In [None]:
import jax
import jax.numpy as jnp
from copy import deepcopy

In [None]:

# Initial centroids: three hand-picked rows of `x`, stacked into shape (3, n_features).
# NOTE(review): rows 1, 52 and 111 are presumably chosen from different groups - confirm.
init_rows = [1, 52, 111]
centroids = jnp.concatenate([x[row, :].reshape(1, -1) for row in init_rows], axis=0)



In [None]:
def compute_euclidean(datapoint, centroid,):
    """Squared Euclidean distance between ``datapoint`` and ``centroid``, reduced over axis 1.

    Despite the name, no square root is taken: this is the *squared* distance, which
    is what the gradient/Hessian optimality checks differentiate.

    Args:
        datapoint: array broadcastable against ``centroid``, e.g. a row of shape
            ``(n_features,)`` against ``(n_centroids, n_features)`` centroids, or a
            ``(1, n_features)`` mean against a single ``(n_features,)`` centroid.
        centroid: array broadcastable against ``datapoint``. The broadcast difference
            must be at least 2-D because the reduction is over axis 1.

    Returns:
        ``sum((datapoint - centroid) ** 2, axis=1)``.
    """
    # A single array is its own only pytree leaf, so the former
    # `jax.tree_leaves(arr)[0]` wrapper was a no-op. That top-level alias is also
    # deprecated/removed in recent JAX releases.
    return jnp.sum(jnp.power(datapoint - centroid, 2), axis=1)

def is_pos_def(x):
    """Check whether the symmetric matrix ``x`` is strictly positive definite.

    Intended for Hessians, which are symmetric, so the symmetric eigensolver
    ``eigvalsh`` is used. It returns real eigenvalues, whereas the general ``eigvals``
    returns complex values that would have to be ordered against 0.

    Args:
        x: square matrix of shape ``(n, n)``; assumed symmetric.

    Returns:
        A one-element list holding a boolean scalar array (True iff every eigenvalue
        is > 0). The list wrapper is kept because callers index the result with ``[0]``.
    """
    return [jnp.all(jnp.linalg.eigvalsh(x) > 0)]




In [None]:
# Alias of the feature matrix for the clustering cells.
# NOTE(review): `data` is rebound further down to the output of
# `dd.get_stratified_train_test_splits`, so any cell reading `data` as the feature
# matrix depends on execution order.
data = x

In [None]:
# Problem sizes for the clustering loop, which iterates over the rows of `x`.
# The size is read from `x` directly: `data` is rebound later in the notebook to the
# train/test splits, so `data.shape` breaks if this cell is re-run after that.
n_nodes = x.shape[0]
n_clusters = centroids.shape[0]
distance_fn = compute_euclidean  # squared Euclidean distance; differentiated in the loop

In [None]:

# Stratified K-fold splitter, passed as `cv` to the estimator wrapper below.
cv = dd.get_stratified_kfold_cv(n_splits=configs.n_splits, to_shuffle=to_shuffle)



In [None]:


# Target labels (the "Group" column), used for the ARI of the clustering and by the
# classifiers. Targets are deliberately left unscaled; scaling, if any, would apply
# only to the features (cf. `preprocess_data(..., pp='mm')`).
y = y_org["Group"].values



In [None]:

from sklearn import metrics

# Gradient-descent K-means.
# Each outer iteration (1) assigns every node to its nearest centroid, then (2) moves
# each centroid one gradient step towards the mean of its members. It stops when the
# mean gradient norm is <= tol_g (FONC), every checked Hessian passes `is_pos_def`
# (SONC), and the assignment is unchanged from the previous iteration. It also stops
# after `n_iters` iterations.
f_iter = True
n_iter = 0
n_iters = 400       # hard cap on the number of outer iterations
tol_m = 1e-2        # NOTE(review): not used by the loop below
tol_g = 1e-3        # FONC tolerance on the mean gradient L2 norm
step_size = 2e-1    # gradient-descent learning rate for the centroid update
c_iter = 0          # NOTE(review): not used by the loop below
clusters = jnp.zeros([n_nodes]) + jnp.inf  # "unassigned" sentinel before the first pass

grads_sums = []        # NOTE(review): not filled by the loop below
aris_history = []      # ARI against `y` after each iteration
grads_history = []     # mean (over clusters) gradient L2 norm per iteration
hessians_history = []  # fraction of checked clusters whose Hessian passed `is_pos_def`

# Distances from every row of the first argument to every centroid:
# maps `distance_fn` over rows, giving shape (n_rows, n_clusters).
all_distances = jax.vmap(
    lambda point, cents: distance_fn(datapoint=point, centroid=cents),
    in_axes=(0, None),
)

while f_iter:
    # Snapshot the assignment *before* reassigning, so the stability check below
    # compares two consecutive iterations. Copying it after the assignment step made
    # that check always true. JAX arrays are immutable, so no deep copy is needed.
    previous_clusters = clusters

    # Cluster assignment: nearest centroid for every node, vectorised over nodes.
    # Kept in the sentinel's float dtype, as the former per-node `.at[i].set` did.
    clusters = jnp.argmin(all_distances(x, centroids), axis=1).astype(clusters.dtype)

    # Cluster update
    tmp_grads, tmp_hess = [], []
    for k in range(n_clusters):
        members = jnp.where(clusters == k)[0]
        if members.size == 0:
            # An empty cluster has no mean: jnp.mean would give NaN and permanently
            # poison the centroid. Leave the centroid unchanged this iteration.
            continue

        # Mean of the cluster's members, shape (1, n_features)
        cluster_data = jnp.mean(x[members, :], axis=0).reshape(1, -1)

        # Gradient of the distance w.r.t. this centroid, shape (1, n_features)
        grads = jax.jacfwd(distance_fn, argnums=1)(cluster_data, centroids[k, :])

        # Gradient-descent step on the centroid
        centroids = centroids.at[k].set(centroids[k, :] - step_size * grads[0])

        # FONC (first-order necessary condition): the gradient norm should vanish
        tmp_grads.append(jnp.linalg.norm(grads))

        # SONC: Hessian of the distance w.r.t. the updated centroid,
        # shape (1, n_features, n_features)
        hessian_mat = jax.hessian(distance_fn, argnums=1)(cluster_data, centroids[k, :])
        tmp_hess.append(1 if jnp.all(is_pos_def(hessian_mat[0])[0]) else 0)

    ave_l2_norms_grads = jnp.asarray(tmp_grads).mean()
    grads_history.append(ave_l2_norms_grads)

    ave_semi_pos_def_check = jnp.asarray(tmp_hess).mean()
    hessians_history.append(ave_semi_pos_def_check)

    ari = metrics.adjusted_rand_score(y, clusters)
    aris_history.append(ari)

    print(
        f"n_iter = {n_iter} ari={ari:.3f} f_iter={f_iter} ave_l2_norms_grads = {ave_l2_norms_grads:.3f}"
    )

    # FONC and SONC hold on average over the (non-empty) clusters ...
    if ave_semi_pos_def_check == 1. and ave_l2_norms_grads <= tol_g:
        print("ave_l2_norms_grads:", ave_l2_norms_grads)
        print("An optimum has found! stoped by FONC and SONC")
        print("ARI:", ari)

        # ... and the assignment is stable across two consecutive iterations
        if jnp.all(previous_clusters == clusters):
            f_iter = False
            print("Converge by two consequitive cluster recovery results conincidence")
            print(f"n_iter {n_iter} ARI {ari}")
            break

    n_iter += 1

    if n_iter >= n_iters:
        f_iter = False
       






In [None]:


# Stratified train/test splits for the supervised estimators.
# NOTE(review): this rebinds `data` (earlier an alias of `x`) and passes
# `configs.n_repeats` as `n_splits`; confirm both are intended.
data = dd.get_stratified_train_test_splits(
    x=x,
    y=y,
    labels=y_org.Group.values,
    n_splits=configs.n_repeats,
    to_shuffle=to_shuffle,
)



In [None]:


# Classification estimator wrapper. It keeps the name `reg_est` because every
# following cell refers to it, even though it wraps a classifier.
reg_est = ClassificationEstimators(
    x=x,
    y=y,
    cv=cv,
    data=data,
    estimator_name=estimator_name,
    configs=configs,
)




In [None]:


# Set up the estimator and parameter space used for tuning (implemented in
# dd_package's ClassificationEstimators).
reg_est.instantiate_tuning_estimator_and_parameters()



In [None]:


# Run the hyper-parameter search (presumably over the `cv` splitter passed at construction).
reg_est.tune_hyper_parameters()





In [None]:


# Build the final estimator for the train/test evaluation (presumably from the tuned parameters).
reg_est.instantiate_train_test_estimator()




In [None]:

# Fit and evaluate the tuned estimator on the train/test splits in `data`.
reg_est.train_test_tuned_estimator()





In [None]:

# Disabled in this run: uncomment to call `save_params_results`
# (presumably persists the tuned parameters and evaluation results).
# reg_est.save_params_results()




In [None]:

# Report the evaluation results of the current run.
reg_est.print_results()




In [None]:



# Reload previously saved evaluation results and print their summary.
# The directory comes from `configs.results_path` instead of a second hard-coded copy of
# the path; the trailing "/" reproduces the original save_path string exactly.
# NOTE(review): this loads the "L_cls" run, not the current `configs.specifier`
# ("DD_demo-SV_cls-True"); presumably a deliberate comparison - confirm.
results_name = "DD_demo-L_cls-True_TEST"
res = load_a_dict(name=results_name, save_path=f"{configs.results_path}/")

print_the_evaluated_results(results=res, learning_method="classification")


In [None]:


# Show the stored predicted probabilities for every entry of the loaded results
# (presumably one entry per split/repeat).
for split_result in res.values():
    print("probs:", split_result["y_pred_prob"])
    

In [None]:
to_exclude_at_risk = False

# Rebuild the concatenated demographic + phonological table. The return value of
# `get_demo_datasets` is not needed here, so it is not bound. Binding it to `_` would
# clobber IPython's last-output variable.
dd.get_demo_datasets()
demo_phono = dd.concat_classes_demo()

# Optionally drop the at-risk class (Group == 2). The former
# `if ... == 1: to_exclude_at_risk = True` self-assignment was redundant for a bool flag.
if to_exclude_at_risk:
    demo_phono = demo_phono.loc[demo_phono.Group != 2]

# NOTE(review): this cell duplicates the later one that sets to_exclude_at_risk = True;
# consider a single helper parameterised by the flag.
df_data_to_use = demo_phono.loc[
    :, ['Group', 'SubjectID', 'Sound_detection', 'Sound_change', 'Reading_speed']
]
c_features = None
indicators = ['SubjectID', ]
targets = ["Group", "Reading_speed", ]

In [None]:
# Inspect the selected columns (Group, SubjectID, phonological scores, Reading_speed).
df_data_to_use

In [None]:
to_exclude_at_risk = True

# Rebuild the concatenated demographic + phonological table. The return value of
# `get_demo_datasets` is not needed here, so it is not bound. Binding it to `_` would
# clobber IPython's last-output variable.
dd.get_demo_datasets()
demo_phono = dd.concat_classes_demo()

# Optionally drop the at-risk class (Group == 2). The former
# `if ... == 1: to_exclude_at_risk = True` self-assignment was redundant for a bool flag.
if to_exclude_at_risk:
    demo_phono = demo_phono.loc[demo_phono.Group != 2]

# NOTE(review): this cell duplicates the earlier one that sets to_exclude_at_risk = False;
# consider a single helper parameterised by the flag.
df_data_to_use = demo_phono.loc[
    :, ['Group', 'SubjectID', 'Sound_detection', 'Sound_change', 'Reading_speed']
]
c_features = None
indicators = ['SubjectID', ]
targets = ["Group", "Reading_speed", ]

In [None]:
# Inspect the selected columns after optionally excluding the at-risk class.
df_data_to_use

In [None]:
# Distinct Group labels present in the current selection
# (label 2 should be absent when the at-risk class was excluded).
set(df_data_to_use["Group"])