In [30]:
import json
from typing import Iterable
from copy import deepcopy

import numpy as np

# from pprint import pprint


def load_dataset(path) -> list:
    """Read a JSON dataset file and return its parsed content.

    Args:
        path (str | Path): the location of the JSON file to read

    Returns:
        list: the parsed JSON document (expected to be a list of dataset
            elements; the content is returned as parsed, without validation)
    """
    # Read the raw bytes first; the json module detects the UTF encoding
    # of bytes input on its own, so no explicit text decoding is needed.
    with open(path, mode="rb") as stream:
        raw_content = stream.read()

    return json.loads(raw_content)


def split_dataset(
    dataset: Iterable,
    split: Iterable = (0.33, 0.33, 0.33),
    random_state: int = 53,
):
    """Split every tool's examples into train, test and validation parts.

    The tools are shuffled, then each tool's ``"dataset"`` sequence is
    shuffled and cut according to ``split``. All work happens on a deep copy,
    so the ``dataset`` passed in is neither reordered nor modified.

    Args:
        dataset (Iterable): the tools to process; it must be a mutable
            sequence (e.g. a list) whose elements are mappings with a
            ``"dataset"`` entry holding that tool's examples
        split (Iterable): the relative sizes of the train, test and
            validation parts, in that order; the values are normalized, so
            they do not have to sum to 1
        random_state (int): the seed given to NumPy's global random generator

    Returns:
        a shuffled deep copy of ``dataset`` in which each tool's
        ``"dataset"`` entry is replaced by a dict with the keys ``"train"``,
        ``"test"`` and ``"validation"``; each value is the array produced by
        ``np.split`` for that part
    """
    # NOTE: this seeds NumPy's *global* generator, so it also influences any
    # later np.random call made by the caller.
    np.random.seed(random_state)

    # Shuffle a private copy instead of the argument, so the caller's data
    # keeps its original order and content.
    splitted_dataset = deepcopy(dataset)
    np.random.shuffle(splitted_dataset)  # what about equipotent classes?

    # normalize the split
    ratios = np.asarray(list(split), dtype=float)
    normalized_split = ratios / ratios.sum()

    for tool in splitted_dataset:
        # The copy owns this sequence, so it can be shuffled in place.
        subdataset = tool["dataset"]
        subdataset_length = len(subdataset)

        # Each part size is truncated to an integer; the running sum gives
        # the positions at which the shuffled examples are cut.
        cut_points = (normalized_split * subdataset_length).astype(int).cumsum()

        # Truncation can leave the last boundary short of the end, which
        # used to drop the leftover examples silently. Pin it to the end so
        # the leftovers land in the third part. With fewer than three ratios
        # np.split already returns the leftovers as the final piece.
        if len(cut_points) >= 3:
            cut_points[-1] = subdataset_length

        # Shuffle the examples of this tool
        np.random.shuffle(subdataset)

        # Split the examples; only the first three pieces are used
        train, test, validation = np.split(subdataset, cut_points)[:3]
        tool["dataset"] = {
            "train": train,
            "test": test,
            "validation": validation,
        }

    return splitted_dataset


# Load the dataset from disk and split each tool's examples into
# train/test/validation parts using the default ratios and seed.
dataset = load_dataset("dataset_a.json")
split = split_dataset(dataset)
# Unused alternative, kept for reference: a single flat split of the whole
# list instead of one split per tool:
# return np.split(dataset, [int(split[0] * len(dataset)), int((split[0] + split[1]) * len(dataset))])

In [37]:
# Preview the first five "test" examples of the first tool in the split.
split[0]["dataset"]["test"][:5]

array([{'user_request': 'Is the toilet seat down?', 'command': "detect_object('toilet seat')", 'rouge_score': 0.33333332847222225},
       {'user_request': 'Is the dog in the yard?', 'command': "detect_object('dog')", 'rouge_score': 0.499999995138889},
       {'user_request': 'Is the mailbox empty?', 'command': "detect_object('mailbox')", 'rouge_score': 0.36363635900826446},
       {'user_request': 'Is the trash can full?', 'command': "detect_object('trash can')", 'rouge_score': 0.33333332847222225},
       {'user_request': 'Is the garage door closed?', 'command': "detect_object('garage door')", 'rouge_score': 0.33333332847222225}],
      dtype=object)

In [22]:
# Scratch check of np.split semantics: the list holds cut positions, so four
# positions yield five pieces, and a position past the end of the array
# (10 for an array of 8 elements) gives an empty trailing piece, not an error.
x = np.arange(8.0)
print(x)
np.split(x, [3, 5, 6, 10])

[0. 1. 2. 3. 4. 5. 6. 7.]


[array([0., 1., 2.]),
 array([3., 4.]),
 array([5.]),
 array([6., 7.]),
 array([], dtype=float64)]