In [89]:
import json
import operator
import subprocess
from typing import Optional

from torch.utils.data import Dataset

class LeanDataset(Dataset):
    def __init__(self, cwd = "."):
        self.cwd = cwd
        self.process = subprocess.Popen(
            ["lake", "exe", "lean_dataset"],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,  # Use text mode for input and output.
            cwd=self.cwd
        )
        self.initialMetadata = json.loads(self.read_output())

    def send_input(self, input_data: str):
        """
        Sends input to the child process's stdin.

        :param input_data: The input data to send to the process.
        """
        if self.process.stdin:
            self.process.stdin.write(input_data + "\n")
            self.process.stdin.flush()

    def read_output(self) -> str:
        """
        Reads the next line of output from the child process's stdout.

        :return: A line of output from stdout.
        """
        if self.process.stdout:
            return self.process.stdout.readline()
        return ""

    def read_error(self) -> str:
        """
        Reads the next line of output from the child process's stderr.

        :return: A line of output from stderr.
        """
        if self.process.stderr:
            return self.process.stderr.readline()
        return ""

    def wait_for_completion(self) -> int:
        """
        Waits for the process to terminate and returns the exit code.

        :return: The exit code of the process.
        """
        return self.process.wait()
        
    def __len__(self):
        return self.initialMetadata['len']

    def __getitem__(self, index):
        self.send_input(json.dumps({ "index" : index }))
        return json.loads(self.read_output())

In [94]:
# Spawns `lake exe lean_dataset` with this directory as its cwd (absolute, machine-specific path).
data = LeanDataset("/home/adam/Projects/lean_dataset")

In [95]:
# Item count, taken from the "len" field of the child's start-up metadata.
len(data)

10

In [96]:
# Request item 123 from the child process.
# NOTE(review): len(data) was 10 in the previous cell, so 123 is outside range(len(data)).
data[123]

{'res': {'index': 123}}