In [1]:
%config Completer.use_jedi = False

import sys
from datetime import datetime, timedelta
from typing import Generator, Iterator, Optional, Tuple

import pandas as pd
import pyarrow as pa

### TorchData

In [2]:
from torchdata.datapipes.iter import FileLister

filelist = list(FileLister("./data", recursive=True))
filelist[:5]

['./data/dt=2023-01-01/a300c22cb3554cec95c68957f6ac326f-0.parquet',
 './data/dt=2023-01-02/a300c22cb3554cec95c68957f6ac326f-0.parquet',
 './data/dt=2023-01-03/a300c22cb3554cec95c68957f6ac326f-0.parquet',
 './data/dt=2023-01-04/a300c22cb3554cec95c68957f6ac326f-0.parquet',
 './data/dt=2023-01-05/a300c22cb3554cec95c68957f6ac326f-0.parquet']

# ParquetDataset

## Basic Usage

In [3]:
from pyarrow.parquet import ParquetDataset

dataset = ParquetDataset("./data", memory_map=True, use_legacy_dataset=False)
df = pd.read_parquet("./data")

file_rows = [frag.count_rows() for frag in dataset.fragments]

print("Pandas shape :", df.shape)
print("Pandas  size :", sys.getsizeof(df))
print("Pyarrow size :", sys.getsizeof(dataset))
print("files        :", dataset.files[:3])
print("fragments    :", dataset.fragments[:3])
print("files rows   :", file_rows)
print("column size  :", len(dataset.schema))

Pandas shape : (50000000, 2)
Pandas  size : 450000734
Pyarrow size : 64
files        : ['./data/dt=2023-01-01/a300c22cb3554cec95c68957f6ac326f-0.parquet', './data/dt=2023-01-02/a300c22cb3554cec95c68957f6ac326f-0.parquet', './data/dt=2023-01-03/a300c22cb3554cec95c68957f6ac326f-0.parquet']
fragments    : [<pyarrow.dataset.ParquetFileFragment path=./data/dt=2023-01-01/a300c22cb3554cec95c68957f6ac326f-0.parquet partition=[dt=2023-01-01]>, <pyarrow.dataset.ParquetFileFragment path=./data/dt=2023-01-02/a300c22cb3554cec95c68957f6ac326f-0.parquet partition=[dt=2023-01-02]>, <pyarrow.dataset.ParquetFileFragment path=./data/dt=2023-01-03/a300c22cb3554cec95c68957f6ac326f-0.parquet partition=[dt=2023-01-03]>]
files rows   : [8640000, 8640000, 8640000, 8640000, 8640000, 6800000]
column size  : 2


## Iteration

In [4]:
for frag in dataset.fragments:
    for batch in frag.to_batches():
        df = batch.to_pandas()
        row = batch.take(pa.array([0]))

        print("frag size   :", sys.getsizeof(frag))
        print("num rows    :", batch.num_rows)
        print("Pandas shape:", df.shape)
        display(row.to_pandas())
        break
    break

frag size   : 72
num rows    : 32768
Pandas shape: (32768, 1)


Unnamed: 0,idx
0,0


## Create Parquet Files

In [5]:
def create_data():
    df = pd.DataFrame({"idx": range(50000000)})
    dt = datetime(2023, 1, 1)
    df["dt"] = df["idx"].apply(
        lambda x: (dt + timedelta(milliseconds=x * 10)).date()
    )
    pa.parquet.write_to_dataset(
        pa.Table.from_pandas(df),
        root_path="data",
        partition_cols=["dt"],
        use_legacy_dataset=False,
    )


# create_data()

## Pytorch Dataset

In [None]:
import random
import tracemalloc
from bisect import bisect_right

from pyarrow.dataset import ParquetFileFragment
from pyarrow.lib import RecordBatch
from torch.utils.data import Dataset
import gc

class PyArrowDataset(Dataset):
    """
    Restriction
     - Don't shuffle in Dataloader. this is for efficiency to precess large dataset.
       If you need to shuffle, do it before this custom dataset. (like in SparkSQL)
       But the algorithm supports random access.
    """

    def __init__(self, source: str, seed: int = 123):
        self.source = source
        self.seed = seed

        # Pyarrow
        self.dataset = ParquetDataset(source, use_legacy_dataset=False)
        self.fragments: List[ParquetFileFragment] = self.dataset.fragments
        self._batches: Iterator[RecordBatch] = None
        self._batch: Optional[RecordBatch] = None
        self._df: pd.DataFrame = None

        # Indexing meta information to make search faster
        self._cumulative_n_rows: List[int] = []
        self._batch_idx: int = 0

        # Index
        self._fragment_idx = 0

        # Initialization
        self._init()

    def _init(self):
        random.seed(self.seed)
        # random.shuffle(self.fragments)

        self._cumulative_n_rows = [frag.count_rows() for frag in self.fragments]
        for i in range(1, len(self._cumulative_n_rows)):
            self._cumulative_n_rows[i] += self._cumulative_n_rows[i - 1]

    def _get_next(self, idx: int) -> Tuple[int, int]:
        print('_get_next 01', idx)
        def get_prev_cum_frag_size(_fragment_idx):
            if _fragment_idx >= 1:
                return self._cumulative_n_rows[_fragment_idx - 1]
            return 0

        # Calculate fragment idx
        fragment_idx = self._fragment_idx
        fragment_changed = False
        _prev_size = get_prev_cum_frag_size(fragment_idx)
        _cur_size = self._cumulative_n_rows[self._fragment_idx]
        if (idx < _prev_size) or (idx >= _cur_size):
            fragment_idx = bisect_right(self._cumulative_n_rows, idx)
            assert fragment_idx < len(self.fragments)
            # fragment_idx %= len(self.fragments)
            fragment_changed = self._fragment_idx != fragment_idx
            self._fragment_idx = fragment_idx
            self._batch_idx = 0
            
            if self._batches:
                self._batches.clear()
                
            del self._batches
            del self._batch
            del self._df
            self._batches = None
            self._batch = None
            self._df = None
        
        print('_get_next 02', idx)
        # Calculate batch idx
        _prev_size = get_prev_cum_frag_size(fragment_idx)
        batch_idx = idx - _prev_size
        batch_changed = batch_idx < self._batch_idx

        # Calculate batches of the fragment
        if self._batches is None or fragment_changed or batch_changed:
            if self._batches:
                self._batches.clear()
            
            self.batches = self.fragments[fragment_idx].to_batches()
            self._batch = None

        if self._batch is None:
            self._batch = next(self.batches)
            del self._df
            self._df = self._batch.to_pandas()
            self._batch_idx = 0
        
        print('_get_next 03', idx)
        need_to_load_data = False
        while True:
            print(
                "ITER:",
                f"{self._batch_idx} <= {batch_idx} < {self._batch_idx + self._batch.num_rows} | {sys.getsizeof(self._batch)}",
            )
            if (
                self._batch_idx
                <= batch_idx
                < self._batch_idx + self._batch.num_rows
            ):
                if need_to_load_data:
                    self._df = self._batch.to_pandas()
                break

            need_to_load_data = True
            self._batch_idx += self._batch.num_rows
            self._batch = next(self.batches)
        
        print('_get_next 04', idx)
        return fragment_idx, batch_idx - self._batch_idx
    
    def __del__(self):
        print('Deleted')
        if self.dataset:
            self.dataset.clear()
        
        if self.fragments:
            self.fragments.clearn
        del self.dataset
        del self.fragments
        del self._batches
        del self._batch
        del self._df


    def __len__(self):
        return self._cumulative_n_rows[-1]

    def __getitem__(self, idx):
        print('__getitem__', idx)
        fragment_idx, batch_idx = self._get_next(idx)

        row = self._df.iloc[batch_idx][["idx"]]
        row = row.fillna(0)
        row["fragment_idx"] = fragment_idx
        row["batch_idx"] = batch_idx
        return row, idx
    
    


tracemalloc.start()
dataset = PyArrowDataset("./data")
print(dataset[50000][0].idx)
print(dataset[0][0].idx)
print(dataset[500000][0].idx)

print('여기까지')
del dataset
print(tracemalloc.get_traced_memory())
print(gc.get_count())

__getitem__ 50000
_get_next 01 50000
_get_next 02 50000
_get_next 03 50000
ITER: 0 <= 50000 < 32768 | 266336
ITER: 32768 <= 50000 < 65536 | 266336
_get_next 04 50000
50000
__getitem__ 0
_get_next 01 0
_get_next 02 0
_get_next 03 0
ITER: 0 <= 0 < 32768 | 266336
_get_next 04 0
0
__getitem__ 500000
_get_next 01 500000
_get_next 02 500000
_get_next 03 500000
ITER: 0 <= 500000 < 32768 | 266336
ITER: 32768 <= 500000 < 65536 | 266336
ITER: 65536 <= 500000 < 98304 | 266336
ITER: 98304 <= 500000 < 131072 | 266336
ITER: 131072 <= 500000 < 163840 | 266336
ITER: 163840 <= 500000 < 196608 | 266336
ITER: 196608 <= 500000 < 229376 | 266336
ITER: 229376 <= 500000 < 262144 | 266336
ITER: 262144 <= 500000 < 294912 | 266336
ITER: 294912 <= 500000 < 327680 | 266336
ITER: 327680 <= 500000 < 360448 | 266336
ITER: 360448 <= 500000 < 393216 | 266336
ITER: 393216 <= 500000 < 425984 | 266336
ITER: 425984 <= 500000 < 458752 | 266336
ITER: 458752 <= 500000 < 491520 | 266336
ITER: 491520 <= 500000 < 524288 | 26633

Exception ignored in: <function PyArrowDataset.__del__ at 0x7f57c9cf9560>
Traceback (most recent call last):
  File "/tmp/ipykernel_9179/2190492011.py", line 123, in __del__
AttributeError: '_ParquetDatasetV2' object has no attribute 'clear'


In [None]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=64, shuffle=True)
data, labels = next(iter(loader))
a = data[:, 0] - 1
b = labels % 1000

a == b






# ParquetFile

## Row 갯수 

In [None]:
from pyarrow.parquet import ParquetFile

parquet_file = ParquetFile("./data/dt=20230101/userdata.parquet")

print("parquet_file size: ", sys.getsizeof(parquet_file))
parquet_file.metadata

In [25]:
from pyarrow.parquet import ParquetDataset

dataset = ParquetDataset("./data")

print("dataset size :", sys.getsizeof(dataset))

dataset.files

dataset size : 64


['./data/dt=20230101/userdata.parquet']