In [4]:
import pandas as pd
import numpy as np
import pymc3 as pm

# Leer el archivo CSV con pandas y cargar datos

data = pd.read_csv('Datos_Bart_C.csv', delimiter=';')

print(data)

# Definición de variables predictoras y respuesta

x = data[['total_credits_1', 'total_credits_2', 'total_courses_1', 'course_approved_1', 't_gpa_1']].values
y = data['Permanencia_term_3'].values

def bart_model(x,y, num_trees=50, tree_depth=3, alpha=0.95):

    with pm.Model() as bart_model:

    # Priors para los hiperparámetros
        sigma = pm.HalfCauchy('sigma', beta=10, testval=1.0)
        mu = pm.Normal('mu', mu=0, sigma=10, shape=num_trees)
        tau = pm.Gamma('tau', alpha=alpha, beta=alpha, shape=num_trees)
        p = pm.Beta('p', alpha=alpha, beta=alpha, shape=num_trees)

    # Árboles de regresión
    trees = []
    for i in range(num_trees):
        tree = pm.glm.forest.BARTRegressionTree('tree_{}'.format(i), x, y, mu[i], tau[i], p[i], alpha=tree_depth)
        trees.append(tree)

    # Modelo BART
    bart = pm.Deterministic('bart', sum(trees))
    
    # Verosimilitud
    y_obs = pm.Normal('y_obs', mu=bart, sigma=sigma, observed=y)

    return bart_model
    

      id                      student_id  year  total_credits_1  \
0      1  05cff88fd871c9e19bce4d8df60231     0               29   
1      2  080b659e7f6e033017ad4bda83b361     0               29   
2      3  0dee01f418108ad66efc4a0e0ec69d     0               29   
3      4  0fe5ab4f3df38493e8cc71afe37e45     0               29   
4      5  105b0c6e99b923e322c5f2a8b8cb72     0               29   
..   ...                             ...   ...              ...   
281  282  2f4c1ec05fd0ee144b9dc736aad229     1               28   
282  283  61b2c03d79b0f48730bd4beddedcd3     1               26   
283  284  a00331cf6a64c2d04fd928bf0628bc     1               28   
284  285  a97dc96828578f10ff08084d7eb19c     1               28   
285  286  fdac9c59f0f6b4e24f5414644745be     1               28   

     total_credits_2  total_courses_1  total_courses_2  course_approved_1  \
0                 35                5                7                  3   
1                 31                5    

In [5]:
print(bart_model)

<function bart_model at 0x7fc670313eb0>
