# Dataloader examples

In [None]:
import pickle

from datasets import load_dataset
from plaid.bridges.huggingface_bridge import (
    huggingface_dataset_to_plaid,
    huggingface_description_to_problem_definition,
)
from plaid.containers.sample import Sample

from plaid_bridges.dataloaders import (
    HeterogeneousCollater,
    HomogeneousCollater,
    PlaidDataLoader,
)

In [None]:
# Download only the first `n_samples` rows of VKI-LS59 so the example stays fast.
n_samples = 10
hf_dataset = load_dataset(
    "PLAID-datasets/VKI-LS59", split=f"all_samples[:{n_samples}]"
)

# Preview one raw sample.
# NOTE(review): `pickle.loads` can execute arbitrary code embedded in the
# payload -- only deserialize datasets from a source you trust.
sample = Sample.model_validate(pickle.loads(hf_dataset[0]["sample"]))

pb_def = huggingface_description_to_problem_definition(hf_dataset.info.description)
# Only `n_samples` rows were downloaded, so keep just the train ids that fall
# inside that slice; taking the first ids of the full split could point past
# the loaded rows. Assumes split ids are row indices -- TODO confirm.
ids = [i for i in pb_def.get_split("train") if i < len(hf_dataset)][:n_samples]

dataset, _ = huggingface_dataset_to_plaid(hf_dataset, ids=ids, processes_number=5)

In [None]:
print(dataset)

# Feature identifiers are taken from the first sample only.
all_feat_ids = dataset[0].get_all_features_identifiers()

# Bucket identifiers whose values contain "scalar" / "field".
# NOTE(review): this tests every value of the identifier, not one specific
# key (presumably the feature type) -- confirm against the identifier schema.
scalar_features = list(
    filter(lambda feat_id: "scalar" in feat_id.values(), all_feat_ids)
)
field_features = list(
    filter(lambda feat_id: "field" in feat_id.values(), all_feat_ids)
)

# Inputs: first scalar and first field; outputs: second field and second scalar.
# Requires at least two scalar and two field features (IndexError otherwise).
in_feature_identifiers = [scalar_features[0], field_features[0]]
out_feature_identifiers = [field_features[1], scalar_features[1]]

print(in_feature_identifiers)
print(out_feature_identifiers)

In [None]:
# Collater built from the input/output identifiers selected in the previous cell.
homogeneous_collater = HomogeneousCollater(
    in_feature_identifiers=in_feature_identifiers,
    out_feature_identifiers=out_feature_identifiers,
)
loader = PlaidDataLoader(
    dataset,
    collate_fn=homogeneous_collater,
    batch_size=2,
    shuffle=True,
)

# Pull one (shuffled) batch and show the first field entry with its shape.
batch = next(iter(loader))
first_field = batch[1][0]
print("fields =", first_field, " | >>>> tensor:", first_field.shape)

## Case with heterogeneous samples

In [None]:
# Download only the first `n_samples` rows of tensile2d so the example stays fast.
n_samples = 10
hf_dataset = load_dataset(
    "PLAID-datasets/tensile2d", split=f"all_samples[:{n_samples}]"
)

# Preview one raw sample.
# NOTE(review): `pickle.loads` can execute arbitrary code embedded in the
# payload -- only deserialize datasets from a source you trust.
sample = Sample.model_validate(pickle.loads(hf_dataset[0]["sample"]))

pb_def = huggingface_description_to_problem_definition(hf_dataset.info.description)
# Only `n_samples` rows were downloaded, so keep just the "train_500" ids that
# fall inside that slice; taking the first ids of the full split could point
# past the loaded rows. Assumes split ids are row indices -- TODO confirm.
ids = [i for i in pb_def.get_split("train_500") if i < len(hf_dataset)][:n_samples]

dataset, _ = huggingface_dataset_to_plaid(hf_dataset, ids=ids, processes_number=5)

In [None]:
print(dataset)

# Feature identifiers are taken from the first sample only.
all_feat_ids = dataset[0].get_all_features_identifiers()

# Bucket identifiers whose values contain "scalar" / "field".
# NOTE(review): this tests every value of the identifier, not one specific
# key (presumably the feature type) -- confirm against the identifier schema.
scalar_features = list(
    filter(lambda feat_id: "scalar" in feat_id.values(), all_feat_ids)
)
field_features = list(
    filter(lambda feat_id: "field" in feat_id.values(), all_feat_ids)
)

# Inputs: first scalar and first field; outputs: second field and second scalar.
# Requires at least two scalar and two field features (IndexError otherwise).
in_feature_identifiers = [scalar_features[0], field_features[0]]
out_feature_identifiers = [field_features[1], scalar_features[1]]

print(in_feature_identifiers)
print(out_feature_identifiers)

In [None]:
# Collater built from the input/output identifiers selected in the previous cell.
heterogeneous_collater = HeterogeneousCollater(
    in_feature_identifiers=in_feature_identifiers,
    out_feature_identifiers=out_feature_identifiers,
)
loader = PlaidDataLoader(
    dataset,
    collate_fn=heterogeneous_collater,
    batch_size=2,
    shuffle=True,
)

# Pull one (shuffled) batch; show the first scalar entry, then the first field
# entry together with its Python type.
batch = next(iter(loader))
first_scalars = batch[0][0]
print("scalars =", first_scalars)
first_fields = batch[1][0]
print("fields =", first_fields, " | >>>> list:", type(first_fields))