# Torch common examples

In [1]:
import pickle

from datasets import load_dataset
from plaid.bridges.huggingface_bridge import (
    huggingface_dataset_to_plaid,
    huggingface_description_to_problem_definition,
)
from plaid.containers.sample import Sample

from plaid_bridges.common.base import BaseRegressionDataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Number of samples used in this example. The same value drives both the
# Hugging Face split slice and the truncation of the training ids below, so
# the two can no longer drift apart.
N_SAMPLES = 10
# Value forwarded to `processes_number` when converting to a plaid dataset.
N_PROCESSES = 5

hf_dataset = load_dataset(
    "PLAID-datasets/VKI-LS59", split=f"all_samples[:{N_SAMPLES}]"
)

# SECURITY: pickle.loads can execute arbitrary code while deserializing.
# Only run this against a dataset repository you trust.
sample = Sample.model_validate(pickle.loads(hf_dataset[0]["sample"]))

pb_def = huggingface_description_to_problem_definition(hf_dataset.info.description)
# NOTE(review): assumes the first N_SAMPLES ids of the "train" split all refer
# to samples present in the truncated `hf_dataset` -- confirm.
ids = pb_def.get_split("train")[:N_SAMPLES]

dataset, _ = huggingface_dataset_to_plaid(
    hf_dataset, ids=ids, processes_number=N_PROCESSES
)

Converting huggingface dataset to plaid dataset...


100%|██████████| 10/10 [00:00<00:00, 73.15it/s]


In [3]:
print(dataset)
all_feat_ids = dataset[0].get_all_features_identifiers()

# Classify identifiers by their "type" entry. Testing membership in
# `f.values()` would also match an identifier whose name/base/zone happens to
# be the string "scalar" or "field", and would compare the string against
# non-string values (e.g. the numeric "time" entry).
scalar_features = [f for f in all_feat_ids if f.get("type") == "scalar"]
field_features = [f for f in all_feat_ids if f.get("type") == "field"]

# Mixed scalar/field inputs and outputs for the regression example.
# Requires at least two scalar and two field features in the sample.
in_feature_identifiers = [scalar_features[0], field_features[0]]
out_feature_identifiers = [field_features[1], scalar_features[1]]

print(in_feature_identifiers)
print(out_feature_identifiers)

Dataset(10 samples, 8 scalars, 0 time_series, 8 fields)
[{'type': 'scalar', 'name': np.str_('Pr')}, {'type': 'field', 'name': 'sdf', 'base_name': 'Base_2_2', 'zone_name': 'Zone', 'location': 'Vertex', 'time': np.float64(0.0)}]
[{'type': 'field', 'name': 'rou', 'base_name': 'Base_2_2', 'zone_name': 'Zone', 'location': 'Vertex', 'time': np.float64(0.0)}, {'type': 'scalar', 'name': np.str_('Q')}]


In [4]:

# Wrap the plaid dataset together with the chosen input/output feature
# identifiers, then print a summary of what the wrapper holds.
reg_dataset = BaseRegressionDataset(
    dataset=dataset,
    in_feature_identifiers=in_feature_identifiers,
    out_feature_identifiers=out_feature_identifiers,
)
reg_dataset.show_details()

RegressionDataset (10 sample, 2 input features, 2) output features)
Input features : ['Pr (scalar)', 'sdf (field)']
Output features: ['rou (field)', 'Q (scalar)']


In [5]:
# Print the same entry through two access paths: item indexing on the dataset
# and the `in_features` attribute. The two printed arrays are expected to match.
# NOTE(review): presumably reg_dataset[1] is an (inputs, outputs) pair for
# sample 1 and index [0][1] picks its second input feature -- confirm against
# BaseRegressionDataset.__getitem__.
indexed_item = reg_dataset[1]
print(indexed_item[0][1])

stored_inputs = reg_dataset.in_features
print(stored_inputs[1][1])

[0.50928488 0.50928525 0.50928687 ... 1.1507115  1.15067122 1.15063128]
[0.50928488 0.50928525 0.50928687 ... 1.1507115  1.15067122 1.15063128]
