In [162]:
# Centered hierarchical normal model used as the example input:
#   y ~ normal(theta, sigma)   (per-observation scales sigma are data)
#   theta ~ normal(mu, tau)
# Every Stan `vector` declaration needs an explicit size, so sigma is declared
# as vector<lower=0>[n].
stan_simple = """
data {
    int n;
    vector[n] y;
    vector<lower=0>[n] sigma;
}
parameters {
    vector[n] theta;
    real mu;
    real<lower=0> tau;
}
model {
    y ~ normal(theta, sigma);
    theta ~ normal(mu, tau);
}
"""

In [247]:
def get_parameter_names_and_types(stan_code):
    """Map each parameter declared in the Stan ``parameters`` block to its type.

    Parameters
    ----------
    stan_code : str
        Full Stan program text.

    Returns
    -------
    dict
        ``{name: type}`` in declaration order, e.g.
        ``{"theta": "vector[n]", "mu": "real", "tau": "real<lower=0>"}``.
        The type is everything before the identifier, so constraints that
        contain spaces (``real<lower=0, upper=1>``) are kept intact.

    Raises
    ------
    ValueError
        If SOURCE has no ``parameters { ... }`` block, or a declaration line
        does not have the form ``<type> <name>;``.
    """
    match = re.search(r"parameters\s+{([\S\s]*?)\}", stan_code)
    if match is None:
        raise ValueError("no parameters block found in Stan code")

    name_to_type = {}
    for line in match.group(1).splitlines():
        # Drop `//` comments and skip blank lines instead of assuming the first
        # line is the empty remainder of the `parameters {` line.
        decl = line.split("//", 1)[0].strip().rstrip(";").strip()
        if not decl:
            continue
        # Split on the LAST whitespace so multi-token types survive.
        param_type, name = decl.rsplit(None, 1)
        name_to_type[name] = param_type.strip()
    return name_to_type

def get_parameter_dists(parameter_names, stan_code):
    """Find the sampling distribution (``name ~ dist;``) of each parameter.

    Parameters
    ----------
    parameter_names : container of str
        Names to look for; anything supporting ``in`` works (list, set,
        ``dict.keys()``).
    stan_code : str
        Full Stan program text.

    Returns
    -------
    dict
        ``{name: dist}`` where ``dist`` is the right-hand side of the
        sampling statement without the trailing ``;``, e.g.
        ``{"theta": "normal(mu, tau)"}``. Only left-hand sides that are
        exactly a listed name are included. If a name is sampled more than
        once, the last statement wins.

    Raises
    ------
    ValueError
        If SOURCE has no ``model { ... }`` block.
    """
    match = re.search(r"model\s+{([\S\s]*?)\}", stan_code)
    if match is None:
        raise ValueError("no model block found in Stan code")

    param_to_dist = {}
    for line in match.group(1).splitlines():
        statement = line.split("//", 1)[0].strip()
        lhs, tilde, rhs = statement.partition("~")
        # Blank lines, `target += ...` and other non-sampling statements have
        # no `~`; skip them instead of failing a two-way unpack.
        if not tilde:
            continue
        param = lhs.strip()
        if param in parameter_names:
            param_to_dist[param] = rhs.strip().rstrip(";").strip()
    return param_to_dist

def is_normal(dist):
    """Return True if ``dist`` is a plain univariate ``normal(...)`` call.

    The distribution name must be exactly ``normal``. Names such as
    ``lognormal``, ``skew_normal``, ``multi_normal`` and ``normal_id_glm``
    are rejected, because the ``mu + z * sigma`` rewrite used for
    non-centering here only handles the univariate normal.

    Parameters
    ----------
    dist : str
        Right-hand side of a sampling statement, e.g. ``"normal(mu, tau)"``.
    """
    head = dist.strip()
    if not head.startswith("normal"):
        return False
    # Allow optional whitespace between the name and its argument list.
    return head[len("normal"):].lstrip().startswith("(")

def get_mu_sigma(dist):
    """Extract the location and scale arguments of a ``normal(mu, sigma)`` call.

    Arguments are split on top-level commas only, so nested calls and index
    expressions work: ``normal(a, g(b, c))`` gives ``("a", "g(b, c)")``.
    Whitespace after the comma is optional.

    Parameters
    ----------
    dist : str
        Distribution text such as ``"normal(mu, tau)"``.

    Returns
    -------
    tuple of str
        ``(mu, sigma)``, each stripped of surrounding whitespace.

    Raises
    ------
    ValueError
        If ``dist`` has no ``normal(`` call, its parentheses are unbalanced,
        or it does not have exactly two arguments.
    """
    # \b keeps "lognormal(" / "skew_normal(" from matching.
    match = re.search(r"\bnormal\s*\(", dist)
    if match is None:
        raise ValueError("not a normal distribution: %r" % (dist,))

    args, current, depth = [], [], 0
    for ch in dist[match.end():]:
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            if depth == 0:
                break  # closing paren of normal(...)
            depth -= 1
        elif ch == "," and depth == 0:
            args.append("".join(current).strip())
            current = []
            continue
        current.append(ch)
    else:
        raise ValueError("unbalanced parentheses in %r" % (dist,))
    args.append("".join(current).strip())

    if len(args) != 2:
        raise ValueError("expected normal(mu, sigma), got %r" % (dist,))
    return args[0], args[1]

def create_transformed_parameters_block(name_to_type, param_to_dist):
    # create a new transformed parameters block
    transformed_params_lines = []
    for param_name, dist in param_to_dist.items():
        mu, sigma = get_mu_sigma(dist)
        transformed_param_line = ("  " 
                                  + name_to_type[param_name] 
                                  + " " + param_name 
                                  + " = " 
                                  + mu 
                                  + " + " 
                                  + param_name 
                                  + "_std * " 
                                  + sigma 
                                  + ";\n")
        transformed_params_lines.append(transformed_param_line)
    return "transformed parameters {\n" + "\n".join(transformed_params_lines) + "}"
    
    
def modify_parameters_block(param_names, stan_code):
    """Return the ``parameters`` block with selected parameters renamed to ``<name>_std``.

    Only the declared identifier (the word immediately before ``;``) is
    renamed, and only on word boundaries. So ``mu`` leaves ``mu2`` alone, and a
    one-letter name like ``a`` does not rewrite ``parameters`` or ``real``.
    The declared type, including any constraint, is kept as-is.

    Parameters
    ----------
    param_names : iterable of str
        Parameters to rename.
    stan_code : str
        Full Stan program text (previously ignored in favour of the global
        ``stan_simple``).

    Returns
    -------
    str
        The rewritten block, from ``parameters`` through its closing ``}``.

    Raises
    ------
    ValueError
        If SOURCE has no ``parameters { ... }`` block.
    """
    match = re.search(r"(parameters\s+{[\S\s]*?\})", stan_code)
    if match is None:
        raise ValueError("no parameters block found in Stan code")

    param_block = match.group(1)
    for param in param_names:
        param_block = re.sub(
            r"\b" + re.escape(param) + r"(\s*;)",
            # Bind `param` now; a callable avoids backslash handling in the replacement.
            lambda m, name=param: name + "_std" + m.group(1),
            param_block,
        )
    return param_block


def modify_model_block(param_names, stan_code):
    """Rewrite the model block so each listed parameter gets a standard-normal prior.

    A sampling statement whose left-hand side is exactly one of
    ``param_names`` (e.g. ``theta ~ normal(mu, tau);``) is replaced by
    ``theta_std ~ normal(0, 1);`` at the same indentation. Every other line
    is kept unchanged and appears exactly once.

    Parameters
    ----------
    param_names : iterable of str
        Parameters being non-centered.
    stan_code : str
        Full Stan program text.

    Returns
    -------
    str
        The rewritten ``model { ... }`` block.

    Raises
    ------
    ValueError
        If SOURCE has no ``model { ... }`` block.
    """
    match = re.search(r"model\s+{([\S\s]*?)\}", stan_code)
    if match is None:
        raise ValueError("no model block found in Stan code")

    targets = set(param_names)
    new_lines = []
    for line in match.group(1).splitlines():
        lhs, tilde, _ = line.partition("~")
        name = lhs.strip()
        # Exact match on the left-hand side, so `mu` does not hit `mu_x ~ ...`.
        if tilde and name in targets:
            indent = line[: len(line) - len(line.lstrip())]
            new_lines.append(indent + name + "_std ~ normal(0, 1);")
        else:
            new_lines.append(line)
    return "model {" + "\n".join(new_lines) + "\n}"
    
    
def make_non_centered(stan_code):
    """Rewrite normally distributed parameters of a Stan program into non-centered form.

    Each parameter ``p`` with a sampling statement ``p ~ normal(mu, sigma)`` is
    rewritten in three places:

    * the parameters block declares ``p_std`` instead of ``p``;
    * a new transformed parameters block defines ``p = mu + p_std * sigma``;
    * the model block samples ``p_std ~ normal(0, 1)`` instead.

    Parameters
    ----------
    stan_code : str
        Full Stan program text.

    Returns
    -------
    tuple of str
        ``(transformed_params_block, param_block, model_block)``. The blocks are
        not yet reassembled into a full program.
    """
    name_to_type = get_parameter_names_and_types(stan_code)
    param_to_dist = get_parameter_dists(name_to_type.keys(), stan_code)
    norm_params = {
        param_name: dist
        for param_name, dist in param_to_dist.items()
        if is_normal(dist)
    }

    transformed_params_block = create_transformed_parameters_block(name_to_type, norm_params)
    param_block = modify_parameters_block(norm_params.keys(), stan_code)
    model_block = modify_model_block(norm_params.keys(), stan_code)

    # TODO: compose the blocks back together into a full Stan program.
    # NOTE(review): transformed parameters follow model-block statement order. If one
    # non-centered parameter's mu/sigma refers to another, confirm the referenced one
    # is defined first.
    # NOTE(review): p_std keeps p's declared type, including any <lower=...> constraint;
    # verify that constraint is still valid for the standardized variable.
    return transformed_params_block, param_block, model_block

In [248]:
# Rewrite the example model; results are the (transformed parameters, parameters, model) blocks.
(t, p, m) = make_non_centered(stan_code=stan_simple)