In [1]:
from math import ceil
from typing import List, Optional

import numpy as np
import matplotlib.pyplot as plt

from sqlitedict import SqliteDict

from mpgm.mpgm.sample_generation.gibbs_samplers import *
from mpgm.mpgm.sample_generation.graph_generators import *
from mpgm.mpgm.sample_generation.weight_assigners import *
from mpgm.mpgm.generating_samples import SampleParamsSave, SampleParamsWrapper

from mpgm.mpgm.models.TPGM import TPGM
from mpgm.mpgm.models.Model import Model

from mpgm.mpgm.model_fitters.prox_grad_fitters import Prox_Grad_Fitter

from mpgm.mpgm.fitting_models import FitParamsSave, FitParamsWrapper

from mpgm.mpgm.evaluation_functions import *



In [2]:
# fit_file_name = "fit_models.sqlite"
# samples_file_name = "samples.sqlite"
# SPS = SampleParamsWrapper.load_samples("PleaseWork", samples_file_name)
# FPS = FitParamsWrapper.load_fit("", fit_file_name)

In [3]:
def get_sample_name(experiment_name:str, batch_nr:int, sample_nr:int) -> str:
    """Build the storage key of one generated sample.

    The key has the form ``<experiment_name>_batch_<batch_nr>_sample_<sample_nr>``,
    e.g. ``get_sample_name("exp", 1, 2) == "exp_batch_1_sample_2"``.
    """
    parts = [experiment_name, "_batch_", str(batch_nr), "_sample_", str(sample_nr)]
    return "".join(parts)

def get_fit_sample_name(experiment_name:str, batch_nr:int, sample_nr:int) -> str:
    """Build the storage key of the model fitted on a sample.

    This is the sample's key (see ``get_sample_name``) prefixed with ``"fit_"``.
    """
    sample_key = get_sample_name(experiment_name, batch_nr, sample_nr)
    return "".join(("fit_", sample_key))

def generate_batch_samples_vary_seed(SPW:SampleParamsWrapper, experiment_name:str, batch_nr:int, samples_per_batch:int,
                                     random_seeds:List[int], samples_file_name:str):
    """Generate and save one batch of samples, one sample per random seed.

    For each index ``i`` the wrapper's ``random_seed`` is set to ``random_seeds[i]`` and
    ``SPW.generate_samples_and_save`` is called with the key
    ``get_sample_name(experiment_name, batch_nr, i)``. A progress line is printed after each sample.

    Args:
        SPW: sample-generation wrapper; its ``random_seed`` attribute is overwritten in place.
        experiment_name: prefix used in every sample key.
        batch_nr: batch index embedded in the sample keys.
        samples_per_batch: number of samples to generate; must equal ``len(random_seeds)``.
        random_seeds: one seed per sample, in generation order.
        samples_file_name: storage file passed through to ``generate_samples_and_save``.

    Raises:
        AssertionError: if ``samples_per_batch != len(random_seeds)``.
    """
    assert samples_per_batch == len(random_seeds), ("Must have the same number of random seeds as the number of "
                                                    "samples you want to generate")
    for idx, seed in enumerate(random_seeds):
        SPW.random_seed = seed
        key = get_sample_name(experiment_name, batch_nr, idx)
        SPW.generate_samples_and_save(key, samples_file_name)
        print(f"Sampling {key} finished.")

def vary_nr_samples_and_generate_samples(SPW:SampleParamsWrapper, experiment_name:str, samples_per_batch:int,
                                         samples_file_name:str, samples_numbers:List[int]):
    """Generate one batch of samples per entry of ``samples_numbers``.

    Batch ``i`` sets ``SPW.nr_samples = samples_numbers[i]`` and then generates ``samples_per_batch``
    samples via ``generate_batch_samples_vary_seed`` using the random seeds ``0 .. samples_per_batch - 1``.
    The same seeds are reused in every batch.

    Args:
        SPW: sample-generation wrapper. It is mutated in place: after the call its ``nr_samples`` is the
            last entry of ``samples_numbers``, and its ``random_seed`` is the last seed used (when at
            least one sample was generated).
        experiment_name: prefix used in every sample key.
        samples_per_batch: number of samples (and seeds) per batch.
        samples_file_name: storage file the samples are saved to.
        samples_numbers: value of ``nr_samples`` for each batch; its length is the number of batches.
    """
    # A concrete list (not a range) so the argument matches the callee's List[int] annotation.
    random_seeds = list(range(samples_per_batch))
    for batch_nr, nr_samples in enumerate(samples_numbers):
        SPW.nr_samples = nr_samples
        generate_batch_samples_vary_seed(SPW, experiment_name, batch_nr, samples_per_batch, random_seeds,
                                         samples_file_name)

In [None]:
def fit_all_batches_all_samples(FPW:FitParamsWrapper, fit_file_name:str, experiment_name:str, samples_per_batch:int,
                                nr_batches:int, theta_init:np.ndarray):
    """Fit a model to every sample of every batch and save each fit.

    Samples are visited batch by batch (``batch_nr`` outer, ``sample_nr`` inner). Each fit is stored under
    ``get_fit_sample_name(...)`` in ``fit_file_name`` and reads the sample stored under
    ``get_sample_name(...)``. Every fit starts from the same ``theta_init``.
    """
    index_pairs = [(b, s) for b in range(nr_batches) for s in range(samples_per_batch)]
    for batch_nr, sample_nr in index_pairs:
        samples_key = get_sample_name(experiment_name, batch_nr, sample_nr)
        fit_key = get_fit_sample_name(experiment_name, batch_nr, sample_nr)
        FPW.fit_model_and_save(fit_key, fit_file_name, samples_id=samples_key, theta_init=theta_init)

In [4]:
# Sampling configuration: a TPGM on a lattice graph, sampled with a Gibbs sampler.
# Single source of truth for the dimension; previously 20 was hard-coded in two places.
nr_variables = 20
SGW = SampleParamsWrapper(nr_variables=nr_variables,
                          nr_samples=10,  # placeholder: overwritten per batch by vary_nr_samples_and_generate_samples
                          random_seed=0,  # placeholder: overwritten per sample by generate_batch_samples_vary_seed
                          sample_init=np.zeros((nr_variables, )))

SGW.graph_generator = LatticeGraphGenerator(sparsity_level=0)
SGW.weight_assigner = Bimodal_Distr_Weight_Assigner(neg_mean=-0.1, threshold=1, std=0)
SGW.model = TPGM(R=10)
SGW.sampler = TPGMGibbsSampler(burn_in=200,
                               thinning_nr=50)

# n ** 2 log(p) is roughly 27.
# NOTE(review): unclear what `n` denotes here. With p = nr_variables = 20, log(p) ~= 3.0, so this
# corresponds to n ~= 3 rather than to any entry of samples_numbers. Confirm and reword.
experiment_name = "lattice_same_neg_weight_vary_nr_samples"
samples_per_batch = 5
samples_file_name = "samples.sqlite"
# One batch per entry: each batch draws this many samples, repeated samples_per_batch times with seeds 0..4.
samples_numbers = [10, 30, 100, 300, 600]
vary_nr_samples_and_generate_samples(SGW, experiment_name, samples_per_batch, samples_file_name, samples_numbers)


100%|██████████| 651/651 [00:01<00:00, 627.12it/s]
100%|██████████| 651/651 [00:00<00:00, 663.12it/s]
100%|██████████| 651/651 [00:00<00:00, 664.36it/s]
100%|██████████| 651/651 [00:00<00:00, 659.46it/s]
100%|██████████| 651/651 [00:01<00:00, 645.73it/s]
100%|██████████| 1651/1651 [00:02<00:00, 641.78it/s]
100%|██████████| 1651/1651 [00:02<00:00, 649.76it/s]
100%|██████████| 1651/1651 [00:02<00:00, 654.46it/s]
100%|██████████| 1651/1651 [00:02<00:00, 677.04it/s]
100%|██████████| 1651/1651 [00:02<00:00, 679.53it/s]
100%|██████████| 5151/5151 [00:07<00:00, 669.44it/s]
100%|██████████| 5151/5151 [00:07<00:00, 672.74it/s]
100%|██████████| 5151/5151 [00:07<00:00, 670.17it/s]
100%|██████████| 5151/5151 [00:07<00:00, 665.86it/s]
100%|██████████| 5151/5151 [00:08<00:00, 642.37it/s]
100%|██████████| 15151/15151 [00:22<00:00, 671.61it/s]
100%|██████████| 15151/15151 [00:22<00:00, 667.33it/s]
100%|██████████| 15151/15151 [00:22<00:00, 668.41it/s]
100%|██████████| 15151/15151 [00:22<00:00, 662.55i

Sampling lattice_same_neg_weight_vary_nr_samples_batch_0_sample_0 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_0_sample_1 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_0_sample_2 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_0_sample_3 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_0_sample_4 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_1_sample_0 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_1_sample_1 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_1_sample_2 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_1_sample_3 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_1_sample_4 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_2_sample_0 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_2_sample_1 finished.
Sampling lattice_same_neg_weight_vary_nr_samples_batch_2_sample_2 finished.
Sampling lat

In [None]:
# Fit one sample and reload the stored fit.
# `fit_file_name` and `fit_id` were previously undefined in any live cell (fit_file_name only appeared in
# commented-out code), so this cell raised NameError on a fresh kernel. Define them explicitly, using the
# same key convention as fit_all_batches_all_samples.
fit_file_name = "fit_models.sqlite"
batch_nr, sample_nr = 0, 0  # which generated sample to fit
sample_name = get_sample_name(experiment_name, batch_nr, sample_nr)
fit_id = get_fit_sample_name(experiment_name, batch_nr, sample_nr)

FPW = FitParamsWrapper(random_seed=2,
                       samples_file_name=samples_file_name)

FPW.model = TPGM(R=10)
FPW.fitter = Prox_Grad_Fitter(alpha=0.3, early_stop_criterion='likelihood')
FPW.fit_model_and_save(fit_id=fit_id, fit_file_name=fit_file_name, samples_id=sample_name, parallelize=False)

FPS = FitParamsWrapper.load_fit(fit_id, fit_file_name)