In [6]:
import json
import os

import jax
import numpy as np
import sbibm
import tqdm
from datasets import Dataset
from huggingface_hub import upload_file
from jax import numpy as jnp

In [11]:
# Hugging Face Hub repository id that the metadata file and the per-task
# datasets in this notebook are uploaded to.
repo_name = "aurelio-amerio/SBI-benchmarks"

In [3]:
from scoresbibm.tasks.sbibm_tasks import (
    LinearGaussian,
    BernoulliGLM,
    BernoulliGLMRaw,
    MixtureGaussian,
    TwoMoons,
    SLCP)
    
from scoresbibm.tasks.unstructured_tasks import (
    LotkaVolterraTask,
    SIRTask)


In [4]:
# Task classes imported from scoresbibm.tasks.sbibm_tasks; these are the ones
# iterated over when building the metadata and uploading datasets.
base_task_cls = [
    LinearGaussian,
    BernoulliGLM,
    BernoulliGLMRaw,
    MixtureGaussian,
    TwoMoons,
    SLCP,
]

# Task classes imported from scoresbibm.tasks.unstructured_tasks; the code
# that would process them is currently commented out in this notebook.
advanced_task_cls = [
    LotkaVolterraTask,
    SIRTask,
]

In [12]:
# Record the data / parameter dimensionality of every base task and write the
# result to a JSON file (uploaded to the Hub by a later cell).
metadata = {}

# Only the base tasks are processed below, so the bar is sized to them.
# Add len(advanced_task_cls) to `total` if the commented-out loop is re-enabled.
pbar = tqdm.tqdm(total=len(base_task_cls))
for task_cls in base_task_cls:
    task = task_cls()
    task_name = task.name
    dim_data = task.get_x_dim()
    dim_theta = task.get_theta_dim()

    metadata[task_name] = {"dim_data": dim_data, "dim_theta": dim_theta, "metadata": None}
    pbar.update(1)

# for task_cls in advanced_task_cls:
#     task = task_cls()
#     task_name = task.name
#     dim_data = task.get_x_dim()
#     dim_theta = task.get_theta_dim()

#     samples = task.get_data(10, jax.random.PRNGKey(0))
#     metadata_ = np.array(samples["metadata"]).tolist()

#     metadata[task_name] = {"dim_data": dim_data, "dim_theta": dim_theta, "metadata": metadata_}
#     pbar.update(1)

# Close the bar so its final state is flushed and it does not linger in the output.
pbar.close()

file_path = "metadata.json"
with open(file_path, 'w') as f:
    json.dump(metadata, f, indent=4)

  0%|          | 0/8 [00:00<?, ?it/s]

In [1]:
def get_task_data_base(task_cls, num_samples, num_observations=10):
    """Simulate data for a task and collect its per-observation reference material.

    Args:
        task_cls: Zero-argument callable returning a task object that provides
            ``get_data``, ``get_reference_posterior_samples``,
            ``get_true_parameters`` and ``get_observation``.
        num_samples: Forwarded unchanged to ``task.get_data``.
        num_observations: Number of observations to collect. Observation
            indices start at 1 and run up to and including this value. The
            default of 10 matches the previously hard-coded range.

    Returns:
        A tuple ``(data, reference_posteriors, true_parameters, observations)``
        where ``data`` is whatever ``task.get_data`` returned and the other
        three are lists with one entry per observation index, in index order.
    """
    task = task_cls()
    data = task.get_data(num_samples)

    reference_posteriors = []
    true_parameters = []
    observations = []
    # Observation indices are 1-based; keep the three lookups interleaved per
    # index so the call order on the task object is unchanged.
    for obs_idx in range(1, num_observations + 1):
        reference_posteriors.append(task.get_reference_posterior_samples(obs_idx))
        true_parameters.append(task.get_true_parameters(obs_idx))
        observations.append(task.get_observation(obs_idx))

    return data, reference_posteriors, true_parameters, observations

# def get_task_data_advanced(task_cls, num_samples):
#     task = task_cls()

#     data = task.get_data(num_samples,jax.random.PRNGKey(42))
    
#     # posterior_sampler = task.get_reference_sampler()
#     observation_generator = task.get_observation_generator("posterior") 

#     itr = observation_generator(jax.random.PRNGKey(42))

#     for i in range(1,11):
#         condition_mask, x_o, theta_o, meta_data, node_ids = next(itr)
#         observations.append(task.get_observation(i))

#     reference_posteriors = []
#     true_parameters = []
#     observations = []
#     for i in range(1,11):
#         # reference_posteriors.append(task.get_reference_posterior_samples(i))
#         true_parameters.append(task.get_true_parameters(i))
#         observations.append(task.get_observation(i))

#     return data, reference_posteriors, true_parameters, observations

In [13]:
# Publish the locally written metadata file to the dataset repository.
# The call stays the last expression of the cell so its result is displayed.
upload_file(
    repo_id=repo_name,
    repo_type="dataset",
    path_or_fileobj=file_path,
    path_in_repo="metadata.json",  # file name inside the repository
)

CommitInfo(commit_url='https://huggingface.co/datasets/aurelio-amerio/SBI-benchmarks/commit/19cbce6ebefb0ab96d743011465196869722c5d1', commit_message='Upload metadata.json with huggingface_hub', commit_description='', oid='19cbce6ebefb0ab96d743011465196869722c5d1', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/aurelio-amerio/SBI-benchmarks', endpoint='https://huggingface.co', repo_type='dataset', repo_id='aurelio-amerio/SBI-benchmarks'), pr_revision=None, pr_num=None)

In [23]:
def _to_float32_array(values):
    """Convert ``values`` to a NumPy array and cast it to ``np.float32``."""
    return np.array(values).astype(np.float32)


def upload_dataset(task_cls, repo_name: str, max_samples: int = int(1e6), num_val_samples: int = 1000):
    """Generate one task's data and push it to the Hub as three dataset splits.

    ``max_samples + num_val_samples`` samples are requested from
    ``get_task_data_base``. The first ``max_samples`` become the ``train``
    split and the remainder the ``validation`` split, both under the config
    named after the task. Reference posterior samples, observations and true
    parameters are pushed as the ``reference_posterior`` split of the config
    ``"<task name>_posterior"``. All arrays are cast to float32 before upload.

    Args:
        task_cls: Zero-argument callable returning a task object with a
            ``name`` attribute; also passed on to ``get_task_data_base``.
        repo_name: Hub repository id to push to.
        max_samples: Number of training samples (default matches the
            previously hard-coded 1e6).
        num_val_samples: Number of validation samples (default matches the
            previously hard-coded 1000).

    Raises:
        ValueError: If ``max_samples`` or ``num_val_samples`` is negative,
            since a negative bound would silently change what the slices select.
    """
    if max_samples < 0 or num_val_samples < 0:
        raise ValueError("max_samples and num_val_samples must be non-negative")

    # Instantiated here only to read the task name; get_task_data_base builds
    # its own instance from task_cls.
    task_name = task_cls().name

    data_dict, reference_posteriors, true_parameters, observations = get_task_data_base(
        task_cls, max_samples + num_val_samples
    )

    xs_all = data_dict["x"]
    thetas_all = data_dict["theta"]

    # Slice first, then convert, so each split is materialised independently.
    dataset_train = Dataset.from_dict(
        {
            "xs": _to_float32_array(xs_all[:max_samples]),
            "thetas": _to_float32_array(thetas_all[:max_samples]),
        }
    )
    dataset_val = Dataset.from_dict(
        {
            "xs": _to_float32_array(xs_all[max_samples:]),
            "thetas": _to_float32_array(thetas_all[max_samples:]),
        }
    )
    dataset_reference_posterior = Dataset.from_dict(
        {
            "reference_samples": _to_float32_array(reference_posteriors),
            "observations": _to_float32_array(observations),
            "true_parameters": _to_float32_array(true_parameters),
        }
    )

    dataset_train.push_to_hub(repo_name, config_name=task_name, split="train", private=False)
    dataset_val.push_to_hub(repo_name, config_name=task_name, split="validation", private=False)
    dataset_reference_posterior.push_to_hub(
        repo_name, config_name=f"{task_name}_posterior", split="reference_posterior", private=False
    )

    return

# upload data

In [24]:
# Generate and push the datasets of every base task, one task at a time.
for base_cls in base_task_cls:
    upload_dataset(base_cls, repo_name)


[A
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  3.13ba/s]
Processing Files (1 / 1): 100%|██████████| 81.3MB / 81.3MB, 6.43MB/s  
New Data Upload: 100%|██████████| 81.3MB / 81.3MB, 6.43MB/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:17<00:00, 17.74s/ shards]

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 459.20ba/s]
Processing Files (1 / 1): 100%|██████████|  117kB /  117kB,  195kB/s  
New Data Upload: 100%|██████████|  117kB /  117kB,  195kB/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.31s/ shards]

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 30.47ba/s]
Processing Files (1 / 1): 100%|██████████| 4.70MB / 4.70MB,  0.00B/s  
New Data Upload: |          |  0.00B /  0.00B,  0.00B/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.03 shards/s]

[A
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  2.88ba/s]
Processing Files (1 /

# gw dataset

In [None]:
import torch

In [3]:
# Directory containing the thetas_*.pt / xs_*.pt shard files loaded below.
# Set the GW_DATASET_DIR environment variable to point at another location;
# the default is the path this notebook was originally run with.
gw_dir = os.environ.get("GW_DATASET_DIR", "/home/aure/Documents/dataset(1)/dataset")

In [4]:
# Load the ten parameter shards (thetas_0.pt .. thetas_9.pt) and then the ten
# xs shards (xs_0.pt .. xs_9.pt), in the same order as before.
_theta_shards = [torch.load(f"{gw_dir}/thetas_{shard}.pt") for shard in range(10)]
_xs_shards = [torch.load(f"{gw_dir}/xs_{shard}.pt") for shard in range(10)]

# Bind the per-shard names that the concatenation cell expects; shard 0 keeps
# the un-numbered name.
(thetas, theta1, theta2, theta3, theta4,
 theta5, theta6, theta7, theta8, theta9) = _theta_shards
(xs_raw, xs_raw1, xs_raw2, xs_raw3, xs_raw4,
 xs_raw5, xs_raw6, xs_raw7, xs_raw8, xs_raw9) = _xs_shards

# Drop the temporary lists so they do not hold extra references to the shards.
del _theta_shards, _xs_shards

In [None]:
# NOTE: this cell rebinds `thetas` / `xs_raw` (shard 0 on entry, the full
# concatenated array on exit), so it is not idempotent. Re-run the shard
# loading cell before running this one again.
thetas = torch.cat([thetas, theta1, theta2, theta3, theta4, theta5, theta6, theta7, theta8, theta9], dim=0)
xs_raw = torch.cat([xs_raw, xs_raw1, xs_raw2, xs_raw3, xs_raw4, xs_raw5, xs_raw6, xs_raw7, xs_raw8, xs_raw9], dim=0)

# Create the JAX arrays directly on the CPU device. Without the default_device
# context, jnp.array would first place the full dataset on the default device
# (e.g. an accelerator) and only afterwards copy it to the CPU.
cpu_device = jax.devices("cpu")[0]
with jax.default_device(cpu_device):
    # device_put is kept so the resulting arrays are explicitly placed on the
    # CPU device, as in the original two-step version.
    thetas = jax.device_put(jnp.array(thetas.numpy()), cpu_device)
    xs_raw = jax.device_put(jnp.array(xs_raw.numpy()), cpu_device)



In [None]:
# Wrap the concatenated arrays in a Dataset with the same column names
# ("xs", "thetas") that upload_dataset uses for its train split.
dataset_train = Dataset.from_dict({"xs": xs_raw, "thetas": thetas})