In [2]:
import torch
import json
import numpy as np
from fcts.train_procedure import train_with_LBFGS
from fcts.lbm_nmar import LBM_NMAR
from fcts.lbfgs import FullBatchLBFGS
from fcts.figures import groupes_politiques, pi_df, text_legend_row
from fcts.utils import reparametrized_expanded_params, init_random_params, save_objects_to_yaml, load_objects_from_yaml
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns 
import yaml

In [None]:
############## Run configuration ################
#%env PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0

nq = 3  # number of row (deputy) classes
nl = 5  # number of column (text) classes
device = 'mps'  # main device: 'mps', 'cuda' or 'cpu'
device2 = 'mps'  # secondary device: 'mps', 'cuda' or None


def backend_available(dev):
    """Return True when the torch device `dev` can be used on this machine.

    Args:
        dev: None (no device requested), a device string such as 'cpu',
            'cuda' or 'mps', or a `torch.device`.

    Returns:
        True for None and 'cpu'; for 'cuda' / 'mps' the result of the
        corresponding torch availability check; False for any other backend.
    """
    if dev is None:
        return True
    kind = torch.device(dev).type
    if kind == 'cuda':
        return torch.cuda.is_available()
    if kind == 'mps':
        return torch.backends.mps.is_available()
    return kind == 'cpu'


# Check the backend that was actually requested (not always MPS), and fall
# back only for the device that is missing.
if not backend_available(device):
    print(f"Device {device!r} is not available. Algorithm will use cpu")
    device, device2 = torch.device('cpu'), None
elif not backend_available(device2):
    print(f"Secondary device {device2!r} is not available; it will not be used")
    device2 = None

In [None]:
############## Loading the dataset ################

# votes: integer matrix, one row per person (deputy), one column per law/text.
votes = np.loadtxt("data_parliament/votes.txt", delimiter=";").astype(int)

# Context managers close the JSON files immediately instead of leaking the
# handles; JSON is UTF-8 by specification, so state the encoding explicitly
# rather than relying on the platform default.
with open('data_parliament/deputes.json', 'r', encoding='utf-8') as deputes_file:
    # deputes: family name, first name, political group
    deputes = json.load(deputes_file)
with open('data_parliament/texts.json', 'r', encoding='utf-8') as texts_file:
    # texts: requesting political group, title of the request, date,
    # type (vote type, majority type, name of the vote type)
    texts = json.load(texts_file)

n1, n2 = votes.shape
# Shape of the dataset (n1 rows x n2 columns):
print("number of rows (nb of persons): ", n1)
print("number of columns (nb of laws): ", n2)

In [None]:
############## Initialization ################
# Fix the RNG state so that "Restart & Run All" starts training from the same
# point every time. Set SEED = None to get a fresh random start on each run.
# NOTE(review): assumes init_random_params draws from numpy's global RNG
# and/or torch's default generator; confirm against fcts.utils.
SEED = 0
if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)

# Flat vector of all parameters, made a leaf tensor that requires grad so
# the optimizer can update it.
vector_of_parameters = torch.tensor(
    init_random_params(n1, n2, nq, nl),
    requires_grad=True,
    device=device,
    dtype=torch.float32,
)

In [None]:
############## Model creation ################
# Build the LBM_NMAR model from the initial parameter vector, the vote
# matrix and the problem dimensions (n1, n2, nq, nl).
model = LBM_NMAR(
    vector_of_parameters, votes, (n1, n2, nq, nl), device=device, device2=device2
)

In [None]:
# Train the model with L-BFGS. Interrupting the kernel stops training early
# and keeps the partially-optimized model for the cells below.
# Pre-set the outputs so that, after an interrupt, they are neither
# undefined nor silently left over from an earlier run.
success, loglike = False, None
try:
    success, loglike = train_with_LBFGS(model)
except KeyboardInterrupt:
    print("KeyboardInterrupt detected, stopping training")

In [None]:
# Parameters of the model, recovered from the concatenated variational and
# model parameter vectors via reparametrized_expanded_params.
# They are only exported below, so compute them without recording an autograd
# graph; the resulting tensors are then not attached to the training graph.
with torch.no_grad():
    (   nu_a,
        rho_a,
        nu_b,
        rho_b,
        nu_p,
        rho_p,
        nu_q,
        rho_q,
        tau_1,
        tau_2,
        mu_un,
        sigma_sq_a,
        sigma_sq_b,
        sigma_sq_p,
        sigma_sq_q,
        alpha_1,
        alpha_2,
        pi,
    ) = reparametrized_expanded_params(
        torch.cat((model.variationnal_params, model.model_params)),
        n1, n2, nq, nl, device,
    )

In [None]:
# Collect each trained parameter under its own name, then write them all
# to trained_parameters.yaml.
parameters_dict = dict(
    nu_a=nu_a,
    rho_a=rho_a,
    nu_b=nu_b,
    rho_b=rho_b,
    nu_p=nu_p,
    rho_p=rho_p,
    nu_q=nu_q,
    rho_q=rho_q,
    tau_1=tau_1,
    tau_2=tau_2,
    mu_un=mu_un,
    sigma_sq_a=sigma_sq_a,
    sigma_sq_b=sigma_sq_b,
    sigma_sq_p=sigma_sq_p,
    sigma_sq_q=sigma_sq_q,
    alpha_1=alpha_1,
    alpha_2=alpha_2,
    pi=pi,
)

save_objects_to_yaml(parameters_dict, 'trained_parameters.yaml')