In [1]:
import os
# Set JAX_PLATFORMS to "cpu" here, in the first cell, so the value is already
# in the environment when `jax` is imported in the following cell.
os.environ["JAX_PLATFORMS"] = "cpu"



In [4]:
import cv2
import jax
import numpy as np
import jax.numpy as jnp
import grain.python as pygrain
from typing import Dict, Any, Callable, List, Optional
import struct as st

In [12]:
def unpack_dict_of_byte_arrays(packed_data):
    """Unpack a ``{str: bytes}`` mapping from a length-prefixed binary buffer.

    The buffer is a back-to-back sequence of entries, each laid out as::

        <key length: 'I'> <key: UTF-8 bytes> <value length: 'I'> <value bytes>

    Both length fields are read with struct format ``'I'`` (native byte order
    and size), so the buffer must have been written with the same format.

    Args:
        packed_data: Bytes-like buffer holding zero or more entries. Slices of
            it must support ``.decode`` (e.g. ``bytes`` or ``bytearray``).

    Returns:
        Dict mapping each decoded key to its value slice. If a key occurs more
        than once, the last occurrence wins. An empty buffer gives ``{}``.

    Raises:
        struct.error: If fewer bytes remain than a length field requires.
        ValueError: If a declared key or value length runs past the end of the
            buffer (truncated or corrupt record).
        UnicodeDecodeError: If a key is not valid UTF-8.
    """
    length_field = st.Struct('I')  # compiled once instead of once per field
    total = len(packed_data)
    unpacked_dict = {}
    offset = 0
    while offset < total:
        # Key: length prefix followed by the UTF-8 encoded name.
        (key_length,) = length_field.unpack_from(packed_data, offset)
        offset += length_field.size
        key_end = offset + key_length
        if key_end > total:
            # Slicing would silently return a shortened key; fail loudly.
            raise ValueError(
                f"truncated data: key of {key_length} bytes at offset {offset} "
                f"exceeds buffer of {total} bytes")
        key = packed_data[offset:key_end].decode('utf-8')
        offset = key_end

        # Value: length prefix followed by the raw payload.
        (byte_array_length,) = length_field.unpack_from(packed_data, offset)
        offset += length_field.size
        value_end = offset + byte_array_length
        if value_end > total:
            # Same guard for the payload, which would otherwise come back short.
            raise ValueError(
                f"truncated data: value of {byte_array_length} bytes for key "
                f"{key!r} at offset {offset} exceeds buffer of {total} bytes")
        unpacked_dict[key] = packed_data[offset:value_end]
        offset = value_end
    return unpacked_dict

def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
    """Swap the channel order from BGR to RGB, then resize to a square.

    Args:
        image: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2RGB``.
        image_scale: Side length, in pixels, of the square output.
        method: OpenCV interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The converted image resized to ``(image_scale, image_scale)``.
    """
    target_size = (image_scale, image_scale)
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(rgb, target_size, interpolation=method)

def get_source(path: str = "/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco") -> Any:
    """
    Build an ArrayRecord data source over the shard files found in ``path``.

    Args:
        path: Directory to scan. Every entry whose file name contains
            ``'array_record'`` is treated as a shard.

    Returns:
        A ``pygrain.ArrayRecordDataSource`` built from the matching files,
        ordered by path.
    """
    # os.listdir makes no ordering guarantee; sort so the shard order, and
    # with it the index -> record mapping, is identical on every run.
    records = sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if 'array_record' in name
    )
    return pygrain.ArrayRecordDataSource(records)

class GCSTransform(pygrain.MapTransform):
    """
    Map transform that turns one packed record into an image/caption pair.

    Each record is unpacked with ``unpack_dict_of_byte_arrays``; its ``'jpg'``
    entry is decoded and passed through ``image_augmenter``, and its ``'text'``
    entry is decoded as UTF-8.

    Args:
        image_scale: Side length of the square output image.
        method: OpenCV interpolation flag used when resizing.
    """
    def __init__(self, image_scale: int = 256, method=cv2.INTER_AREA):
        super().__init__()
        self.image_scale = image_scale
        self.method = method

    def map(self, element: bytes) -> Dict[str, Any]:
        """Decode a single packed record.

        ``MapTransform`` declares ``map`` as its per-element hook, so the
        decoding logic lives here rather than only in ``__call__``.

        Args:
            element: Raw packed record containing ``'jpg'`` and ``'text'``
                entries.

        Returns:
            ``{'image': <resized RGB array>, 'text': <caption str>}``.

        Raises:
            KeyError: If the record has no ``'jpg'`` or ``'text'`` entry.
        """
        fields = unpack_dict_of_byte_arrays(element)
        # View the encoded bytes as a uint8 vector without copying them first.
        encoded = np.frombuffer(fields['jpg'], dtype=np.uint8)
        # NOTE(review): cv2.imdecode is documented to return an empty result
        # for undecodable input, and IMREAD_UNCHANGED may yield a
        # single-channel image; confirm whether such records occur and need
        # explicit handling before the BGR->RGB conversion.
        image = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
        image = image_augmenter(image, self.image_scale, self.method)
        return {
            'image': image,
            'text': fields['text'].decode('utf-8'),
        }

    def __call__(self, data: bytes) -> Dict[str, Any]:
        """Call-style alias for ``map``, kept for existing direct callers."""
        return self.map(data)

In [14]:
# NOTE(review): `records_source` is first assigned in a cell further down
# (`records_source = get_source()`), so on a fresh kernel this cell fails
# unless that cell is run first.
records_source.paths

['/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00000.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00001.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00002.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00003.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00004.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00005.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00006.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00007.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00008.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00009.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00010.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00011.array_record',
 '/home/mrwhite0racle/gcs_mount/datasets/laion12m+mscoco/00012.a

In [13]:
# NOTE(review): relies on `records_source`, which is assigned in a later cell;
# run that cell first (or move it above) for Restart & Run All to work.
len(records_source)

10766521

In [None]:
records_source = get_source()
transforms = [
    GCSTransform(),
]

sampler = pygrain.IndexSampler(
    num_records=len(records_source),

loader = pygrain.DataLoader(
    records_source,
    