In [307]:
%config Completer.use_jedi = False

import sys
from typing import Generator, Iterator, List, Optional, Tuple

import pandas as pd
import pyarrow as pa

### TorchData

In [226]:
from torchdata.datapipes.iter import FileLister

# Recursively walk ./data and materialize every yielded path into a list.
filelist = [*FileLister("./data", recursive=True)]
filelist

['./data/dt=20230101/userdata.parquet']

# ParquetDataset

## Basic Usage

In [229]:
from pyarrow.parquet import ParquetDataset

# Open the same directory lazily with pyarrow and eagerly with pandas.
dataset = ParquetDataset("./data", memory_map=True, use_legacy_dataset=False)
df = pd.read_parquet("./data")

# Row count of every fragment (one fragment per Parquet file).
file_rows = [fragment.count_rows() for fragment in dataset.fragments]

# Print one "label value" line per entry, in the listed order.
print(
    *(
        f"{label} {value}"
        for label, value in (
            ("Pandas shape :", df.shape),
            ("Pandas  size :", sys.getsizeof(df)),
            ("Pyarrow size :", sys.getsizeof(dataset)),
            ("files        :", dataset.files),
            ("fragments    :", dataset.fragments),
            ("files rows   :", file_rows),
            ("column size  :", len(dataset.schema)),
        )
    ),
    sep="\n",
)

Pandas shape : (1000, 14)
Pandas  size : 698858
Pyarrow size : 64
files        : ['./data/dt=20230101/userdata.parquet']
fragments    : [<pyarrow.dataset.ParquetFileFragment path=./data/dt=20230101/userdata.parquet partition=[dt=20230101]>]
files rows   : [1000]
column size  : 14


## Iteration

In [326]:
# Walk the dataset fragment by fragment and, inside each fragment, batch by batch.
for frag in dataset.fragments:
    for batch in frag.to_batches():
        # NOTE: rebinds `df`, replacing the frame loaded earlier with pd.read_parquet.
        df = batch.to_pandas()
        # Select row index 0 of this batch. The last `row` assigned here stays bound
        # after the loop finishes and is what the following cell converts to pandas.
        row = batch.take(pa.array([0]))

        print("frag size   :", sys.getsizeof(frag))
        print("num rows    :", batch.num_rows)
        print("Pandas shape:", df.shape)
        print(row["gender"])


frag size   : 72
num rows    : 1000
Pandas shape: (1000, 13)
[
  "Female"
]
frag size   : 72
num rows    : 1000
Pandas shape: (1000, 13)
[
  "Female"
]


In [446]:
# Convert the selected row to pandas, discard the two unused columns and
# replace a missing salary with 0 -- built as one chain, no in-place mutation.
r = (
    row.to_pandas()
    .drop(columns=['id', 'registration_dttm'])
    .fillna({'salary': 0})
)
r['salary'].values

array([49756.53])

In [474]:
# Display the cleaned frame `r` built in the previous cell.
r

Unnamed: 0,first_name,last_name,email,gender,ip_address,cc,country,birthdate,salary,title,comments
0,Amanda,Jordan,ajordan0@com.com,Female,1.197.201.2,6759521864920116,Indonesia,3/8/1971,49756.53,Internal Auditor,100.0


## Pytorch Dataset

In [517]:
import random
from bisect import bisect_right

from pyarrow.dataset import ParquetFileFragment
from pyarrow.lib import RecordBatch
from torch.utils.data import Dataset


class CustomPyarrowDataset(Dataset):
    """
    Map-style PyTorch dataset that reads rows lazily from a Parquet dataset.

    The fragments (files) are shuffled once with ``seed``. Inside a fragment the
    record batches are streamed forward, and only the current batch is kept in
    memory as a pandas DataFrame.

    Restriction
     - Don't shuffle in Dataloader. This is for efficiency when processing a large
       dataset: sequential access reuses the cached batch, while jumping backwards
       or into another fragment restarts that fragment from its first batch.
       If you need to shuffle, do it before this custom dataset. (like in SparkSQL)
       But the algorithm supports random access.
    """

    def __init__(self, source: str, seed: int = 123):
        """
        Args:
            source: Path handed to ``ParquetDataset``.
            seed: Seed used to shuffle the fragment order.
        """
        self.source = source
        self.seed = seed

        # Pyarrow
        self.dataset = ParquetDataset(source, use_legacy_dataset=False)
        # Work on a private list so shuffling never touches a list owned by pyarrow.
        self.fragments: List[ParquetFileFragment] = list(self.dataset.fragments)
        self._batches: Optional[Iterator[RecordBatch]] = None
        self._batch: Optional[RecordBatch] = None
        self._df: Optional[pd.DataFrame] = None

        # Indexing meta information to make search faster
        # _cumulative_n_rows[i] = number of rows in fragments[0..i] (inclusive).
        self._cumulative_n_rows: List[int] = []
        # Rows of the current fragment that precede the currently loaded batch.
        self._cumulative_batch_n: int = 0

        # Index of the fragment the batch cursor currently points into.
        self._fragment_idx = 0

        # Initialization
        self._init()

    def _init(self):
        """Shuffle the fragment order and rebuild the cumulative row-count index."""
        # NOTE: this seeds the process-wide `random` module state.
        random.seed(self.seed)
        random.shuffle(self.fragments)

        running_total = 0
        self._cumulative_n_rows = []
        for frag in self.fragments:
            running_total += frag.count_rows()
            self._cumulative_n_rows.append(running_total)

        self._fragment_idx = 0
        self._reset_batch_cursor()

    def _reset_batch_cursor(self):
        """Forget the open batch iterator and cached batch of the current fragment."""
        self._batches = None
        self._batch = None
        self._df = None
        self._cumulative_batch_n = 0

    def _load_next_batch(self):
        """Advance the batch iterator by one batch and cache it as a DataFrame."""
        batch = next(self._batches, None)
        if batch is None:
            # Raise explicitly: a StopIteration escaping __getitem__ could be
            # mistaken for normal end-of-iteration by the caller.
            raise RuntimeError(
                "fragment yielded fewer rows than count_rows() reported"
            )
        self._batch = batch
        self._df = batch.to_pandas()

    def _fragment_start(self, fragment_idx: int) -> int:
        """Return the global index of the first row of ``fragment_idx``."""
        if fragment_idx >= 1:
            return self._cumulative_n_rows[fragment_idx - 1]
        return 0

    def _get_next(self, idx: int) -> Tuple[int, int]:
        """
        Position the batch cursor on global row ``idx``.

        After the call ``self._df`` holds the batch that contains the row.

        Returns:
            ``(fragment_idx, batch_idx)`` where ``batch_idx`` is the row offset
            inside the cached batch.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} is out of range for dataset of length {len(self)}"
            )

        # Calculate fragment idx (only search when idx left the current fragment).
        fragment_idx = self._fragment_idx
        fragment_start = self._fragment_start(fragment_idx)
        if not fragment_start <= idx < self._cumulative_n_rows[fragment_idx]:
            # bisect_right also skips over fragments that contain zero rows.
            fragment_idx = bisect_right(self._cumulative_n_rows, idx)
            fragment_start = self._fragment_start(fragment_idx)
            self._fragment_idx = fragment_idx
            self._reset_batch_cursor()

        # Row offset inside the fragment.
        offset = idx - fragment_start

        # Batches can only be streamed forward, so a row located before the
        # cached batch requires restarting the fragment from its first batch.
        if self._batch is not None and offset < self._cumulative_batch_n:
            self._reset_batch_cursor()

        # Calculate batches of the fragment
        if self._batches is None:
            self._batches = iter(self.fragments[fragment_idx].to_batches())
        if self._batch is None:
            self._load_next_batch()

        # Skip forward until the cached batch contains the requested row.
        while offset - self._cumulative_batch_n >= self._batch.num_rows:
            self._cumulative_batch_n += self._batch.num_rows
            self._load_next_batch()

        return fragment_idx, offset - self._cumulative_batch_n

    def __len__(self):
        """Total number of rows over all fragments (0 when there are none)."""
        return self._cumulative_n_rows[-1] if self._cumulative_n_rows else 0

    def __getitem__(self, idx):
        """
        Return ``(values, idx)`` for global row ``idx``.

        ``values`` is a float64 array ``[id, salary, fragment_idx, batch_idx]``;
        a missing ``id``/``salary`` is replaced with 0.
        """
        fragment_idx, batch_idx = self._get_next(idx)
        # Cast before filling so the result dtype does not depend on pandas'
        # object-dtype downcasting rules.
        row = self._df.iloc[batch_idx][['id', 'salary']].astype("float64")
        row = row.fillna(0.0)
        row['fragment_idx'] = fragment_idx
        row['batch_idx'] = batch_idx
        return row.to_numpy(dtype="float64"), idx


# Build the custom dataset over the local Parquet directory and fetch sample 0.
# NOTE: this rebinds `dataset`, which previously held the raw ParquetDataset.
dataset = CustomPyarrowDataset(source="./data")
dataset[0]

(array([1.000000e+00, 4.975653e+04, 0.000000e+00, 0.000000e+00]), 0)

In [518]:
from torch.utils.data import DataLoader

# Pull one shuffled mini-batch to exercise random access through the dataset.
loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
data, labels = next(iter(loader))

# Column 0 of `data` is the `id` value; `labels` is the index that was requested.
# presumably ids are 1-based row numbers restarting every 1000 rows -- TODO confirm
a, b = data[:, 0] - 1, labels % 1000

a == b


tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

# ParquetFile

## Row 개수

In [480]:
from pyarrow.parquet import ParquetFile

# Open one Parquet file directly by its path.
parquet_file = ParquetFile("./data/dt=20230101/userdata.parquet")

# sys.getsizeof reports the size of the Python object itself only.
print("parquet_file size: ", sys.getsizeof(parquet_file))
# The metadata repr displayed as the cell output includes `num_rows`.
parquet_file.metadata

parquet_file size:  64


<pyarrow._parquet.FileMetaData object at 0x7f47ac166470>
  created_by: parquet-mr version 1.8.1 (build 4aba4dae7bb0d4edbcf7923ae1339f28fd3f7fcf)
  num_columns: 13
  num_rows: 1000
  num_row_groups: 1
  format_version: 1.0
  serialized_size: 1125

In [25]:
from pyarrow.parquet import ParquetDataset

# NOTE: rebinds `dataset` (earlier bound to the CustomPyarrowDataset instance)
# to a ParquetDataset created with default arguments.
dataset = ParquetDataset("./data")

# sys.getsizeof reports the size of the Python object itself only.
print("dataset size :", sys.getsizeof(dataset))

# Paths of the Parquet files that make up the dataset.
dataset.files

dataset size : 64


['./data/dt=20230101/userdata.parquet']