In [24]:
# imports 
import jax 
import jax.numpy as jnp 
import numpy as np 
from datasets import load_from_disk
from datasets import Dataset 
from collections.abc import Iterator 


def build_jax_dataloader(
        dataset:Dataset, 
        batch_size:int,
        prng_key:jax.Array, 
        drop_last:bool=True
) -> Iterator[np.ndarray]:
    """Yield shuffled batches of paired 'en'/'hi' sequences as NumPy arrays.

    Args:
        dataset: Object supporting ``len()`` and indexing with a list of row
            indices; the indexed result must map ``'en'`` and ``'hi'`` to one
            integer sequence per requested row. Both columns must convert to
            arrays of the same shape, otherwise ``np.stack`` raises.
        batch_size: Number of rows per batch; must be positive.
        prng_key: JAX PRNG key that determines the shuffle order.
        drop_last: If True, a trailing batch smaller than ``batch_size`` is
            skipped; otherwise it is yielded as a shorter final batch.

    Returns:
        An iterator of ``int32`` arrays of shape ``(rows, 2, seq_len)``, where
        index 0 along axis 1 holds the ``'en'`` rows and index 1 the ``'hi'``
        rows.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Validate before the generator is created so a bad batch size fails at
    # the call site, instead of surfacing as a ZeroDivisionError (0) or a
    # silently empty iterator (negative values) on first iteration.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    def _iter_batches() -> Iterator[np.ndarray]:
        dataset_len = len(dataset)

        # Shuffle the row indices with JAX so the order is reproducible from
        # the key; an integer argument permutes arange(dataset_len).
        shuffled_indices = jax.random.permutation(prng_key, dataset_len)
        # Convert to NumPy once up front so the Python loop below does not
        # trigger a device-to-host transfer on every slice.
        shuffled_indices = np.asarray(shuffled_indices)

        # Number of full batches, plus one partial batch when it is kept.
        num_batches = dataset_len // batch_size
        if not drop_last and dataset_len % batch_size != 0:
            num_batches += 1

        for start_idx in range(0, num_batches * batch_size, batch_size):
            # Slicing clamps at the end of the array, so the final partial
            # batch needs no explicit min().
            batch_indices = shuffled_indices[start_idx:start_idx + batch_size]

            # Query the dataset with a plain list of Python ints; the result
            # is a dictionary of lists keyed by column name.
            batch_data = dataset[batch_indices.tolist()]

            # Pack the Python lists into contiguous int32 arrays.
            en_batch = np.array(batch_data['en'], dtype=np.int32)
            hi_batch = np.array(batch_data['hi'], dtype=np.int32)

            # Stack along a new axis 1 -> shape (rows, 2, seq_len).
            yield np.stack([en_batch, hi_batch], axis=1)

    return _iter_batches()

# Smoke run: load the preprocessed splits and pull one shuffled training batch.
key = jax.random.PRNGKey(744)
batch_size = 64

processed_dataset = load_from_disk("../processed_data")
print(processed_dataset)

# No rows are read until next() is called on the returned iterator.
dataloader = build_jax_dataloader(
    processed_dataset['train'],
    batch_size,
    key,
)  # type: ignore
batch = next(dataloader)
print(batch.shape)

DatasetDict({
    train: Dataset({
        features: ['en', 'hi'],
        num_rows: 1652981
    })
    validation: Dataset({
        features: ['en', 'hi'],
        num_rows: 520
    })
    test: Dataset({
        features: ['en', 'hi'],
        num_rows: 2507
    })
})
(64, 2, 128)


In [27]:
# Row count of the 'train' split.
len(processed_dataset['train'])

1652981

In [29]:
# Pick rows 1, 2 and 3 of the 'train' split by index; the cell displays the
# resulting Dataset summary (features and row count), not the row values.
processed_dataset['train'].select([1,2,3])

Dataset({
    features: ['en', 'hi'],
    num_rows: 3
})

In [36]:
# Batched lookup with a list of row indices, then take the 'en' column.
# The cell displays (number of rows returned, length of the first row).
# The dataset is queried once and len() is used instead of calling the
# __len__ dunder directly.
sample_en_rows = processed_dataset['train'][[1, 5, 9]]['en']
len(sample_en_rows), len(sample_en_rows[0])

(3, 128)

In [48]:
# Toy check of np.stack: join two (3, 3) arrays along a new leading axis.
a = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [0, 0, 0],
])
b = np.array([
    [7, 8, 9],
    [10, 11, 12],
    [1, 1, 1],
])
# axis=0 is np.stack's default, so p[0] is a and p[1] is b.
p = np.stack((a, b))
p

array([[[ 1,  2,  3],
        [ 4,  5,  6],
        [ 0,  0,  0]],

       [[ 7,  8,  9],
        [10, 11, 12],
        [ 1,  1,  1]]])

In [50]:
# Undo the stack: index the leading axis to recover the two component arrays.
a = p[0]
b = p[1]
(a.shape, b.shape)

((3, 3), (3, 3))