In [1]:
import webdataset as wds
import jax
import jax.numpy as jnp
import augmax
import matplotlib.pyplot as plt

import grain.python as pygrain
from typing import Any, Dict, List, Tuple
import numpy as np
from functools import partial
import tqdm 

import fsspec
import json

import os
from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel

from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
from datasets.utils.file_utils import get_datasets_user_agent
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image
import cv2

There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.


In [2]:
USER_AGENT = get_datasets_user_agent()


def fetch_single_image(image_url, timeout=None, retries=0):
    """Download ``image_url`` and decode it with PIL.

    Args:
        image_url: URL of the image to fetch.
        timeout: Per-attempt timeout in seconds passed to ``urlopen``
            (``None`` uses the global socket default).
        retries: Number of extra attempts after the first failure
            (negative values are treated as 0).

    Returns:
        A fully loaded ``PIL.Image.Image``, or ``None`` if every attempt failed
        (invalid URL, network/HTTP error, or undecodable bytes).
    """
    # `import urllib` alone does not guarantee the `urllib.request` submodule is loaded.
    import urllib.request

    for _ in range(max(retries, 0) + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
                # PIL decodes lazily; force decoding now so corrupt or truncated
                # downloads are retried here instead of failing later in the pipeline.
                image.load()
            return image
        except Exception:
            # Best effort: any failure just moves on to the next attempt.
            continue
    return None

def denormalizeImage(x):
    """Map pixel values from [-1, 1] to [0, 255] (affine only, no clipping)."""
    return (x + 1.0) * 127.5


def plotImages(imgs, fig_size=(8, 8), dpi=100, grid=None):
    """Display a batch of [-1, 1]-normalized images in a subplot grid.

    Args:
        imgs: Batch-first array indexed as ``imgs[i, :, :, :]``
            (presumably (N, H, W, C) — TODO confirm channel layout).
        fig_size: Figure size in inches. When ``grid`` is not given it is also
            used as the (rows, cols) grid, as before.
        dpi: Figure resolution.
        grid: Optional (rows, cols). If the chosen grid has fewer cells than
            images, a near-square grid large enough for all of them is used.
    """
    import math

    num_images = imgs.shape[0]
    rows, cols = grid if grid is not None else fig_size
    rows, cols = int(rows), int(cols)
    if rows * cols < num_images:
        # Otherwise add_subplot fails once the index exceeds rows * cols.
        cols = math.ceil(math.sqrt(num_images))
        rows = math.ceil(num_images / cols)

    fig = plt.figure(figsize=fig_size, dpi=dpi)
    for i in range(num_images):
        ax = fig.add_subplot(rows, cols, i + 1)
        # Clip before the uint8 cast: values outside [0, 255] would otherwise wrap.
        pixels = jnp.clip(denormalizeImage(imgs[i, :, :, :]), 0, 255)
        ax.imshow(jnp.astype(pixels, jnp.uint8))
        ax.axis("off")
    plt.show()


# Filtering pipeline for various datasets

In [3]:
def dataMapper(map: Dict[str, Any]):
    """Create a row/batch mapper that projects a sample onto the canonical schema.

    ``map`` names, for each canonical key ("url" and "caption"), the source
    column to read. The returned function builds a new dict containing only those
    two keys; other columns are dropped. Lookups happen on every call.
    """
    def _map(sample) -> Dict[str, Any]:
        return {target: sample[map[target]] for target in ("url", "caption")}
    return _map

def datasetFilter(filterMap):
    """Build a predicate that keeps samples whose fields lie inside inclusive ranges.

    Args:
        filterMap: ``{column: {"min": lo, "max": hi}}``. Either bound may be
            omitted to leave that side unbounded.

    Returns:
        ``_filter(sample) -> bool``. It returns True when every listed column is
        within its range. Samples where a listed column is missing or holds an
        uncomparable value (e.g. ``None``) are rejected.
    """
    def _filter(sample):
        for key, bounds in filterMap.items():
            lower = bounds.get("min")
            upper = bounds.get("max")
            try:
                value = sample[key]
                if (lower is not None and value < lower) or (upper is not None and value > upper):
                    return False
            except (KeyError, TypeError):
                # Missing column or None/uncomparable value: the sample does not pass.
                # Narrow catch so interrupts and spec errors are not silently swallowed.
                return False
        return True
    return _filter
    

def imageFetcher():
    """Return a batched ``Dataset.map`` function that downloads images concurrently.

    The returned ``fetch_images(batch, num_threads, timeout=None, retries=0)``
    fetches every entry of ``batch["url"]`` with ``fetch_single_image`` on a
    thread pool. It stores the results (PIL image or None, in URL order) under
    ``batch["image"]`` and returns the same batch object.
    """
    def fetch_images(batch, num_threads, timeout=None, retries=0):
        fetch_one = partial(fetch_single_image, timeout=timeout, retries=retries)
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            images = [img for img in pool.map(fetch_one, batch["url"])]
        batch["image"] = images
        return batch
    return fetch_images

def mapDataset(dataset, args, mapper=dataMapper, workers=16, batch_size=10000, should_remove_columns=True, fn_kwargs=None):
    """Apply ``mapper(*args)`` to ``dataset`` with ``Dataset.map`` in batched mode.

    Args:
        dataset: Object exposing ``.map`` and ``.column_names`` (e.g. ``datasets.Dataset``).
        args: Positional arguments for ``mapper``; ``mapper(*args)`` must return
            the per-batch function.
        mapper: Factory producing the batch function (default ``dataMapper``).
        workers: Forwarded as ``num_proc`` (None = single process).
        batch_size: Rows per batch handed to the function.
        should_remove_columns: If True, drop all original columns so only the
            function's output columns remain.
        fn_kwargs: Extra keyword arguments for the batch function (default: none).

    Returns:
        Whatever ``dataset.map`` returns (the mapped dataset).
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if fn_kwargs is None:
        fn_kwargs = {}
    remove_columns = dataset.column_names if should_remove_columns else None
    return dataset.map(
        mapper(*args),
        batched=True,
        batch_size=batch_size,
        remove_columns=remove_columns,
        num_proc=workers,
        fn_kwargs=fn_kwargs,
    )

In [4]:
# LAION-aesthetics 12M (UMAP) stores its URL and caption in upper-case columns.
laionMap = {
    "url": "URL",
    "caption": "TEXT",
}
laion12m6 = load_dataset("dclure/laion-aesthetics-12m-umap")
laion12m6_fused = mapDataset(laion12m6["train"], (laionMap,))

Map (num_proc=16):   0%|          | 0/12096809 [00:00<?, ? examples/s]

In [5]:
# MS-COCO 2017 caption URLs (URL/TEXT columns), all splits combined.
mscocoMap = {
    "url": "URL",
    "caption": "TEXT",
}
mscoco = load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="all")
mscoco_fused = mapDataset(mscoco, (mscocoMap,))

Map (num_proc=16):   0%|          | 0/591753 [00:00<?, ? examples/s]

In [6]:
# MS-COCO appears five times (3 copies before LAION, 2 after), presumably to
# upweight it against the much larger LAION-aesthetics set.
fused_data = concatenate_datasets([mscoco_fused] * 3 + [laion12m6_fused] + [mscoco_fused] * 2)

In [7]:
# Total number of rows in the fused (MS-COCO x5 + LAION) dataset.
len(fused_data)

15055574

In [8]:
# Seeded shuffle for a reproducible order.
# NOTE: this rebinds `fused_data`, so re-running only this cell shuffles the
# already-shuffled data again; re-run the concatenation cell first.
fused_data = fused_data.shuffle(seed=42)

In [None]:
# Visual spot-check: download and show the first 20 shuffled samples with captions.
for idx in range(20):
    row = fused_data[idx]
    picture = fetch_single_image(row["url"])
    if picture is None:
        print("Image is None")
        continue
    fig, ax = plt.subplots()
    ax.imshow(picture)
    ax.set_title(row["caption"])
    plt.show()

In [9]:
# Persist the shuffled LAION-aesthetics-12M + MS-COCO mix (url/caption only) to GCS.
fused_data.save_to_disk("gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017")

Saving the dataset (0/6 shards):   0%|          | 0/15055574 [00:00<?, ? examples/s]

In [12]:
# Reload the saved dataset from GCS (presumably as a round-trip check).
test = load_from_disk("gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017")

In [13]:
# Only displays a shuffled view: the result is not assigned (so `test` is unchanged) and no seed is set.
test.shuffle()

Dataset({
    features: ['url', 'caption'],
    num_rows: 15055574
})

In [None]:
# Download MS-COCO images into an added "image" column, keeping the original columns.
# NOTE: 64 worker processes x 64 threads each = up to 4096 concurrent requests.
imaged_data = mapDataset(mscoco_fused, (), mapper=imageFetcher, batch_size=5000, workers=64, should_remove_columns=False, fn_kwargs={"num_threads": 64})

# COYO-700M Processing

In [14]:
# Load COYO-700M with no split argument, so this is a DatasetDict; the 'train' split is used below.
coyo700 = load_dataset("kakaobrain/coyo-700m", num_proc=64)

Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/352 [00:00<?, ?it/s]

In [22]:
# Looser COYO filter (kept for reference; its use below is commented out).
baseFilterMap = {
    "clip_similarity_vitl14": {"min": 0.27, "max": 1000},
    "aesthetic_score_laion_v2": {"min": 5.1, "max": 100},
    "watermark_score": {"min": 0, "max": 0.4},
}

# Stricter aesthetic threshold plus a 256px minimum on both sides.
heavyFilterMap = {
    "clip_similarity_vitl14": {"min": 0.26, "max": 100},
    "aesthetic_score_laion_v2": {"min": 5.4, "max": 100},
    "watermark_score": {"min": 0, "max": 0.8},
    "width": {"min": 256, "max": 99999},
    "height": {"min": 256, "max": 99999},
}

def coyoFilter(filterMap):
    """Return a predicate keeping a COYO row only if every listed column lies in [min, max].

    Columns are checked in ``filterMap`` order. A missing column raises KeyError
    (it is not treated as a rejection).
    """
    def _filter(sample):
        out_of_range = any(
            sample[column] < limits["min"] or sample[column] > limits["max"]
            for column, limits in filterMap.items()
        )
        return not out_of_range
    return _filter
    

In [23]:
# goodCoyo700 = coyo700.filter(coyoFilter(baseFilterMap), num_proc=64)
# Keep COYO rows with aesthetic score >= 5.4, watermark score <= 0.8 and >= 256px sides.
# NOTE: the saved progress output reports num_proc=64, so it predates the switch to 120.
aestheticCoyo700 = coyo700.filter(coyoFilter(heavyFilterMap), num_proc=120)

Filter (num_proc=64):   0%|          | 0/746972269 [00:00<?, ? examples/s]

In [26]:
# Number of COYO train rows that passed the heavy filter.
len(aestheticCoyo700['train'])

24638115

In [None]:
# Visual spot-check of the first 10 filtered COYO rows, with their aesthetic scores.
for idx in range(10):
    row = aestheticCoyo700["train"][idx]
    picture = fetch_single_image(row["url"])
    if picture is None:
        print("Image is None")
        continue
    fig, ax = plt.subplots()
    ax.imshow(picture)
    ax.set_title(row["text"])
    print(f"Aesthetic score: {row['aesthetic_score_laion_v2']}")
    plt.show()

In [31]:
# Project the filtered COYO rows onto the shared url/caption schema (single process).
final_data = mapDataset(
    aestheticCoyo700["train"],
    (dict(url="url", caption="text"),),
    batch_size=1000000,
    workers=None,
)

Map:   0%|          | 0/24638115 [00:00<?, ? examples/s]

In [None]:
# Save the filtered COYO subset (url/caption only) to GCS.
final_data.save_to_disk("gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M")

In [None]:
# Inspect one mapped row.
final_data[0]

# LAION Datasets

In [3]:
# LAION-2B-en aesthetic subset, train split.
laion_aesthetic = load_dataset("laion/laion2B-en-aesthetic", split="train")

Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/22 [00:00<?, ?it/s]

In [4]:
# Peek at the first five rows (a dict of column lists; columns include URL and TEXT).
laion_aesthetic[:5]

{'URL': ['https://images.pexels.com/photos/1464610/pexels-photo-1464610.jpeg?auto=compress&amp;cs=tinysrgb&amp;h=350',
  'https://dspncdn.com/a1/media/236x/ec/51/59/ec515909d46c49b5f9a1c98db3a50c83.jpg',
  'http://images.singletracks.com/blog/wp-content/uploads/2014/10/empire_link-enhanced92719.jpg',
  'https://us.123rf.com/450wm/yupiramos/yupiramos1909/yupiramos190942486/129795457-recently-married-couple-characters-vector-illustration-design.jpg?ver=6',
  'https://us.123rf.com/450wm/capacitorphoto/capacitorphoto1509/capacitorphoto150900170/45946146-gegrillter-lachs-und-tomaten-zitrone-rosmarin-auf-dem-h%C3%B6lzernen-hintergrund-.jpg?ver=6'],
 'TEXT': ['Cafe Latte in Round Red Cup and Saucer',
  'Stunning Adventure Photography by Stevin Tuchiwsky',
  'Trail: Empire Link, Park City, Utah. Rider: The man himself, Chips Chippendale of Singletrack Magazine. Photo: Jeff.',
  'recently married couple characters vector illustration design',
  'Grilled salmon and tomato, lemon, rosemary on the

In [4]:
# LAION-400M metadata, train split.
laion_400m = load_dataset("laion/laion400m", split="train")

Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/128 [00:00<?, ?it/s]

In [5]:
# Peek at the first five rows; columns include url, NSFW, similarity, LICENSE and caption.
laion_400m[:5]

{'url': ['http://t0.gstatic.com/images?q=tbn:ANd9GcTnX7EwHrzccCd3Ki1mmjgocoPMPB_aGKw4g9PrghYZX1ojZiuS',
  'https://careers.cfainstitute.org/getasset/4794ad7b-a4a8-4fc7-b135-a92a187b3d86/',
  'http://img.beckett.com/images/items/custom/marketplace/66045141/migrated.jpg',
  'https://ae01.alicdn.com/kf/HTB1LWfYsr1YBuNjSszhq6AUsFXaW/high-waist-sleeveless-mini-soft-jeans-dress-frilled-women-ruffles-casual-summer-sundress-short-denim-beach.jpg_3-74x74.jpg',
  'https://images.wolfgangsvault.com/images/catalog/thumb/JRM09062-UV.jpg'],
 'NSFW': ['UNLIKELY', 'UNLIKELY', 'UNLIKELY', 'UNLIKELY', 'UNLIKELY'],
 'similarity': [0.30712828040122986,
  0.35018008947372437,
  0.3508765399456024,
  0.359369695186615,
  0.3070274293422699],
 'LICENSE': ['?', '?', '?', '?', '?'],
 'caption': ['bedroom minimalist home interior storage for kids bedroom design',
  'InterOcean Capital Group, LLC logo',
  '2001 Absolute Memorabilia #190 Jay Gibbons RPM RC',
  'high waist sleeveless mini soft jeans dress frilled 

In [10]:
# LAION-2B-en-aesthetic filter: both sides >= 256px, CLIP similarity >= 0.27,
# watermark probability <= 0.6, aesthetic score >= 4.2. Rows missing any of these
# columns (or holding None) are rejected by datasetFilter.
# NOTE: rebinds the shared name `heavyFilterMap` and overwrites `laion_aesthetic`.
heavyFilterMap = {
    "WIDTH": {"min":256, "max":99999},
    "HEIGHT": {"min":256, "max":99999},
    "similarity": {"min": 0.27, "max": 1000},
    "pwatermark": {"min": 0, "max": 0.6},
    "aesthetic": {"min": 4.2, "max": 100},
}

laion_aesthetic = laion_aesthetic.filter(datasetFilter(heavyFilterMap), num_proc=32)

Filter (num_proc=32):   0%|          | 0/51869119 [00:00<?, ? examples/s]

In [6]:
# LAION-400M filter: both original sides >= 256px and CLIP similarity >= 0.27.
# NOTE: rebinds the shared name `heavyFilterMap` and overwrites `laion_400m`.
heavyFilterMap = {
    "original_width": {"min":256, "max":99999},
    "original_height": {"min":256, "max":99999},
    "similarity": {"min": 0.27, "max": 1000},
}

laion_400m = laion_400m.filter(datasetFilter(heavyFilterMap), num_proc=32)

Filter (num_proc=32):   0%|          | 0/361020613 [00:00<?, ? examples/s]

In [None]:
# Save the filtered LAION-aesthetic subset locally (the path encodes the 4.2 aesthetic threshold).
laion_aesthetic.save_to_disk("./datasets/laion2B-en-aesthetic-4.2_37M")

In [7]:
# Rows remaining in LAION-400M after filtering.
len(laion_400m)

185733069

In [9]:
# Save the filtered LAION-400M subset locally.
# NOTE(review): "laiion" looks like a typo; kept as-is because existing data may already live at this path.
laion_400m.save_to_disk("./datasets/laiion400m-185M")

Saving the dataset (0/80 shards):   0%|          | 0/185733069 [00:00<?, ? examples/s]

# CC12M and CC3M

In [7]:
# Conceptual Captions 12M: rename image_url/caption to the shared url/caption schema.
# The load must run in this cell: without it, `cc12m` is undefined on a fresh kernel,
# and re-running would try to remap an already-mapped dataset.
cc12m = load_dataset("google-research-datasets/conceptual_12m", split="all")
cc12mMap = {
    "url": "image_url",
    "caption": "caption",
}
cc12m = mapDataset(cc12m, (cc12mMap, ), batch_size=1000000, workers=None)

Map:   0%|          | 0/12423374 [00:00<?, ? examples/s]

In [8]:
# Save the url/caption-only CC12M dataset to GCS.
cc12m.save_to_disk("gs://flaxdiff-datasets-regional/datasets/cc12m")

Saving the dataset (0/6 shards):   0%|          | 0/12423374 [00:00<?, ? examples/s]

In [9]:
# Conceptual Captions 3M (train + validation): rename image_url/caption to url/caption.
cc3mMap = {
    "url": "image_url",
    "caption": "caption",
}
cc3m = mapDataset(
    load_dataset("google-research-datasets/conceptual_captions", split="all"),
    (cc3mMap,),
    batch_size=1000000,
    workers=None,
)

Downloading readme:   0%|          | 0.00/14.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/187M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/187M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.77M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3318333 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/15840 [00:00<?, ? examples/s]

Map:   0%|          | 0/3334173 [00:00<?, ? examples/s]

In [10]:
# Save the url/caption-only CC3M dataset to GCS.
cc3m.save_to_disk("gs://flaxdiff-datasets-regional/datasets/cc3m")

Saving the dataset (0/2 shards):   0%|          | 0/3334173 [00:00<?, ? examples/s]

# Synthetic Datasets (Playground and Leonardo)

In [12]:
# Playground "liked" images: rename url/prompt to the shared url/caption schema.
# The load must run here; otherwise `playground` is undefined on a fresh kernel.
playground = load_dataset("bigdata-pw/playground-liked", split="all")
playgroundMap = {
    "url": "url",
    "caption": "prompt",
}
# NOTE: `final_data` is reused by several cells in this notebook; this rebinds it.
final_data = mapDataset(playground, (playgroundMap,),  batch_size=1000000, workers=None)

final_data.save_to_disk("gs://flaxdiff-datasets-regional/datasets/playground-liked")

Map:   0%|          | 0/14381152 [00:00<?, ? examples/s]

Saving the dataset (0/12 shards):   0%|          | 0/14381152 [00:00<?, ? examples/s]

In [5]:
# Non-streaming load of the full Leonardo dataset.
# NOTE(review): this load failed (DatasetGenerationError in the saved output). Also,
# leonardoMap's source columns ("image_url"/"caption") differ from the "url"/"prompt"
# fields read by the streaming cell below, and leonardoMap is not used in this notebook.
leonardo = load_dataset("bigdata-pw/leonardo", split="all", num_proc=64)
leonardoMap = {
    "url": "image_url",
    "caption": "caption",
}


Resolving data files:   0%|          | 0/958 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/957610978 [00:00<?, ? examples/s]

DatasetGenerationError: An error occurred while generating the dataset

In [18]:
# Keep only Leonardo generations that received at least one like.
heavyFilterMap = {
    "like_count": {"min": 1, "max": 100000000},
}

def leonardoFilter(filterMap):
    """Return a predicate keeping a row only if every listed column lies in [min, max].

    Columns are checked in ``filterMap`` order. A missing column raises KeyError.
    """
    def _filter(sample):
        return not any(
            sample[column] < limits["min"] or sample[column] > limits["max"]
            for column, limits in filterMap.items()
        )
    return _filter
    

In [None]:
# Filter the non-streaming Leonardo dataset down to liked rows.
# NOTE(review): depends on the non-streaming `leonardo` load, which failed in the saved run.
leonardoLiked = leonardo.filter(leonardoFilter(heavyFilterMap), num_proc=120)

In [None]:
# Project liked Leonardo rows onto url/caption. The caption source is "prompt":
# the streamed Leonardo rows are read as item['url'] / item['prompt'], and this
# cell previously asked for a "text" column.
final_data = mapDataset(leonardoLiked, ({
    "url":"url",
    "caption":"prompt"
    },),  batch_size=1000000, workers=None)

final_data.save_to_disk("gs://flaxdiff-datasets-regional/datasets/leonardo-liked")

In [19]:
leonardo = load_dataset("bigdata-pw/leonardo", split='train', streaming=True)
# Seeded so the streamed sample is reproducible (matching fused_data.shuffle(seed=42)).
# NOTE: despite the name, this takes up to 600M rows.
leonardo_100m = leonardo.shuffle(seed=42).take(600_000_000)

filtered_leonardo_iterator = leonardo_100m.filter(leonardoFilter(heavyFilterMap))
filtered_leonardo = []

from torch.utils.data import DataLoader

def leonardo_collate_fn(batch):
    """Keep only each streamed row's url and prompt (renamed to caption)."""
    return [{"url": item['url'], "caption": item['prompt']} for item in batch]

# Distinct name so this does not collide with the unrelated `collate_fn` defined later.
loader = DataLoader(filtered_leonardo_iterator, batch_size=100000, num_workers=64, persistent_workers=True, collate_fn=leonardo_collate_fn)

for batch in tqdm.tqdm(loader):
    filtered_leonardo.extend(batch)

Resolving data files:   0%|          | 0/958 [00:00<?, ?it/s]

In [None]:
# Number of liked Leonardo rows collected from the stream.
len(filtered_leonardo)

605231

In [15]:
# Build a datasets.Dataset from the collected url/caption dicts.
# NOTE: a later cell rebinds `Dataset` to torch.utils.data.Dataset; re-running this
# cell after that would fail.
data = Dataset.from_list(filtered_leonardo)

In [17]:
# Save the liked Leonardo subset to GCS.
data.save_to_disk("gs://flaxdiff-datasets-regional/datasets/leonardo-liked-600k")

Saving the dataset (0/1 shards):   0%|          | 0/605231 [00:00<?, ? examples/s]

# Verifications

# Data Loading Experiments

In [5]:
import albumentations as A
from flaxdiff.data.online_loader import OnlineStreamingDataLoader, dataMapper, \
        default_collate, load_dataset, concatenate_datasets, \
        ImageBatchIterator, default_image_processor, load_from_disk
import cv2

import threading
import queue

def default_image_processor(
    image, image_shape, 
    min_image_shape=(128, 128),
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    max_aspect_ratio=2.4,
    min_std=1e-5,
):
    """Validate an image and letterbox it to ``image_shape``.

    The image is rejected when it is not a 3-channel array, when its shorter side
    is below ``min(min_image_shape)``, when its long/short side ratio exceeds
    ``max_aspect_ratio``, or when its pixel standard deviation is below ``min_std``.
    Accepted images are rescaled with ``A.longest_max_size`` to ``max(image_shape)``
    and padded with white to at least ``image_shape``.

    Args:
        image: Anything ``np.array`` can convert (e.g. a PIL image).
        image_shape: Target (height, width).
        min_image_shape: Minimum acceptable size; only its smaller entry is used.
        upscale_interpolation / downscale_interpolation: cv2 interpolation flags,
            chosen by whether the longer side exceeds ``max(image_shape)``.
        max_aspect_ratio: Largest allowed long/short side ratio (default 2.4).
        min_std: Minimum pixel standard deviation; lower values count as (near-)constant.

    Returns:
        ``(processed_image_or_None, original_height, original_width)``. The sizes
        are 0 when the input is not 3-channel or processing raised.
    """
    try:
        image = np.array(image)
        if len(image.shape) != 3 or image.shape[2] != 3:
            return None, 0, 0
        original_height, original_width = image.shape[:2]
        short_side = min(original_height, original_width)
        long_side = max(original_height, original_width)
        # Too small.
        if short_side < min(min_image_shape):
            return None, original_height, original_width
        # Too elongated.
        if long_side / short_side > max_aspect_ratio:
            return None, original_height, original_width
        # (Near-)constant image.
        if np.std(image) < min_std:
            return None, original_height, original_width

        downscale = long_side > max(image_shape)
        interpolation = downscale_interpolation if downscale else upscale_interpolation

        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
        image = A.pad(
            image,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        return image, original_height, original_width
    except Exception:
        # Best effort: any conversion/resize failure just drops the sample.
        return None, 0, 0

def default_feature_extractor(sample, verbose=False):
    """Normalize a raw dataset row to ``{"url": ..., "caption": ...}``.

    The URL is taken from "url", "URL" or "image_url", and the caption from
    "caption", "CAPTION", "txt", "TEXT" or "text"; the first key present wins.
    A missing field is reported on stdout and returned as None.

    Args:
        sample: Mapping-like row supporting ``in``, ``[]`` and ``.keys()``.
        verbose: If True, print the extracted url and caption for every sample.
            Off by default, because this runs once per sample inside the loader.

    Returns:
        Dict with "url" and "caption" keys.
    """
    def _first_present(keys, what):
        # Return the value of the first candidate key present in the sample.
        for key in keys:
            if key in sample:
                return sample[key]
        print(f"No {what} found in sample, skipping", sample.keys())
        return None

    url = _first_present(("url", "URL", "image_url"), "url")
    caption = _first_present(("caption", "CAPTION", "txt", "TEXT", "text"), "caption")

    if verbose:
        print("url", url, "caption", caption)

    return {
        "url": url,
        "caption": caption,
    }
    

class OnlineStreamingDataLoader():
    """Stream collated image/caption batches, prefetched by a background thread.

    ``dataset`` may be a dataset object, a path/hub name, or a list of either.
    Strings containing "gs://" are read with ``load_from_disk`` and other strings
    with ``load_dataset(..., split=default_split)``. Lists are concatenated and
    shuffled with seed 0. The result is sharded by process index/count and fed to
    ``ImageBatchIterator``. A daemon thread applies ``collate_fn`` to each batch
    and keeps up to ``prefetch`` batches queued.

    ``pre_map_maker`` and ``pre_map_def`` are accepted but currently unused.

    Iteration raises StopIteration once the underlying iterator is exhausted or
    the loader thread fails, instead of blocking forever.
    """

    # Queue sentinel marking the end of the stream.
    _END_OF_STREAM = object()

    def __init__(
        self,
        dataset,
        batch_size=64,
        image_shape=(256, 256),
        min_image_shape=(128, 128),
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def={
            "url": "URL",
            "caption": "TEXT",
        },
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
        timeout=15,
        retries=3,
        image_processor=default_image_processor,
        upscale_interpolation=cv2.INTER_CUBIC,
        downscale_interpolation=cv2.INTER_AREA,
        feature_extractor=default_feature_extractor,
    ):
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            if "gs://" in dataset:
                dataset = load_from_disk(dataset_path)
            else:
                dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if not dataset:
                raise ValueError("dataset list is empty")
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
                    dataset_path, split=default_split) for dataset_path in dataset]
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
            dataset = dataset.shuffle(seed=0)
        self.dataset = dataset.shard(
            num_shards=global_process_count, index=global_process_index)
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
                                           min_image_shape=min_image_shape,
                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
                                           timeout=timeout, retries=retries, image_processor=image_processor,
                                           upscale_interpolation=upscale_interpolation,
                                           downscale_interpolation=downscale_interpolation,
                                           feature_extractor=feature_extractor)
        self.batch_size = batch_size

        # Bounded queue: the loader blocks once `prefetch` batches are waiting.
        self.batch_queue = queue.Queue(prefetch)
        # Daemon so an unfinished loader cannot keep the interpreter alive at exit.
        self.loader_thread = threading.Thread(
            target=self._load_batches, args=(collate_fn,), daemon=True)
        self.loader_thread.start()

    def _load_batches(self, collate_fn):
        """Collate batches from ``self.iterator`` into the queue, then enqueue the end sentinel."""
        try:
            for batch in self.iterator:
                try:
                    self.batch_queue.put(collate_fn(batch))
                except Exception as e:
                    print("Error collating batch", e)
        except Exception as e:
            # Without this the thread would die silently and __next__ would wait forever.
            print("Error loading batches", e)
        finally:
            self.batch_queue.put(self._END_OF_STREAM)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.batch_queue.get()
        if item is self._END_OF_STREAM:
            # Put it back so later next() calls also stop instead of blocking.
            self.batch_queue.put(item)
            raise StopIteration
        return item

    def __len__(self):
        return len(self.dataset)

In [6]:
# Debug-sized loader over LAION-2B-en-aesthetic: one 128x128 image per batch,
# sharded across JAX processes, with at most one batch prefetched.
dataloader = OnlineStreamingDataLoader(
            "laion/laion2B-en-aesthetic", 
            batch_size=1,
            num_workers=2,
            num_threads=16,
            image_shape=(128,128),
            global_process_count=jax.process_count(),
            global_process_index=jax.process_index(),
            prefetch=1,
            default_split="train",
            feature_extractor=default_feature_extractor,
        )

Loading dataset from path


Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/22 [00:00<?, ?it/s]

Dataset length: 51869119
Local Shard lengths: 25934559


  self.pid = os.fork()


In [7]:
# Pull one batch from the loader's queue.
next(dataloader)

In [None]:
# MS-COCO loader, 16 samples per batch.
# NOTE: replaces `dataloader` without stopping the previous instance's background thread.
dataloader = OnlineStreamingDataLoader("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", batch_size=16, num_workers=16, default_split="train")

In [None]:
# Number of prefetched batches currently waiting in the loader's queue.
dataloader.batch_queue.qsize()

In [None]:
# NOTE(review): `data_queue` is not defined anywhere in this notebook; this cell raises NameError.
data_queue.qsize()

In [None]:
# NOTE(review): `error_queue` is not defined anywhere in this notebook; this cell raises NameError.
error_queue.qsize()

In [None]:
# Rough throughput check: pull 2000 batches and watch the tqdm rate.
for i in tqdm.tqdm(range(0, 2000)):
    batch = next(dataloader)

In [None]:
# Run a batched Dataset.map with map_batch_fn on a background thread.
# NOTE(review): map_batch_fn (defined in a later cell) also requires shared_array,
# shared_index, lock, shape and dtype, which are not passed here, so this map raises TypeError.
def parallel_loading(dataset):
    dataset.map(map_batch_fn, num_proc=64, batched=True, batch_size=64, fn_kwargs={"num_threads": 64})
    
thread = threading.Thread(target=parallel_loading, args=(mscoco_fused,))
thread.start()

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from concurrent.futures import ThreadPoolExecutor
import aiohttp
from io import BytesIO
import asyncio
from PIL import Image


In [None]:
class URLDataset(torch.utils.data.Dataset):
    """Map-style torch dataset that downloads each row's image on access.

    ``data[index]`` must return a mapping with "url" and "caption" keys.

    ``__getitem__`` returns ``(image, caption)``, where ``image`` is a
    (256, 256, 3) uint8 numpy array. Converting to RGB gives every sample the same
    shape, and torch's default collate can batch numpy arrays but not PIL images.
    The base class is named explicitly because ``Dataset`` is rebound between
    torch and datasets in this notebook.
    NOTE(review): a failed download raises and aborts DataLoader iteration.
    """

    def __init__(self, data):
        self.data = data

    async def fetch_image(self, url):
        """Download ``url`` with aiohttp and open it with PIL; raises on an HTTP error status."""
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                response.raise_for_status()
                image_data = await response.read()
                return Image.open(BytesIO(image_data))

    def __getitem__(self, index):
        data = self.data[index]
        url, caption = data['url'], data['caption']
        # asyncio.run creates and closes its own event loop; calling
        # asyncio.get_event_loop() with no running loop is deprecated.
        image = asyncio.run(self.fetch_image(url))
        image = image.convert("RGB").resize((256, 256))  # Example resize
        return np.asarray(image), caption

    def __len__(self):
        return len(self.data)

# Example usage
dataset = URLDataset(mscoco_fused)
data_loader = DataLoader(dataset, batch_size=256, num_workers=8, prefetch_factor=2)

In [None]:
# Iterate the whole URLDataset loader once (throughput check).
for i in tqdm.tqdm(data_loader):
    pass

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style torch dataset that yields ``{"url", "caption", "image"}`` per row.

    ``image`` is whatever ``fetch_single_image`` returns: a PIL image, or None
    when the download failed. The base class is named explicitly because
    ``Dataset`` is rebound between torch and datasets in this notebook.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]  # index once rather than once per field
        url = row['url']
        caption = row['caption']
        image = fetch_single_image(url)
        return {
            "url": url,
            "caption": caption,
            "image": image,
        }


def collate_fn(batch):
    """Collate CustomDataset items into a dict of parallel lists.

    Images stay in a plain list because they may be None or differ in size.
    """
    return {
        "url": [item["url"] for item in batch],
        "caption": [item["caption"] for item in batch],
        "image": [item["image"] for item in batch],
    }

# Assuming mscoco_fused is your dataset
dataset = CustomDataset(mscoco_fused)
# NOTE: prefetch_factor=100 buffers up to 8 workers x 100 batches x 512 rows of downloaded images.
data_loader = DataLoader(dataset, batch_size=512, num_workers=8, collate_fn=collate_fn, prefetch_factor=100)

In [None]:
# Iterate the CustomDataset loader once (throughput check).
for i in tqdm.tqdm(data_loader):
    # print(i)
    # break
    pass

In [None]:
# NOTE(review): `queue` is the stdlib module imported earlier, not a Queue instance,
# so this raises AttributeError; possibly meant dataloader.batch_queue.qsize().
queue.qsize()

In [None]:
!pip install arrayqueues

In [None]:
import multiprocessing

# Run parallel_image_loader in a child process, feeding a Manager-backed queue.
# NOTE(review): parallel_image_loader (defined in a later cell) takes
# (dataset, shared_array, shared_index, lock, shape, dtype, num_workers); the args
# passed here do not match, so the child process will fail until they agree.
with multiprocessing.Manager() as manager:
    img_queue = manager.Queue()
    process = multiprocessing.Process(target=parallel_image_loader, args=(mscoco_fused, img_queue, 8))
    process.start()
    process.join()

In [None]:
import multiprocessing
from multiprocessing import shared_memory
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from datasets import Dataset
import threading

def create_shared_array(shape, dtype):
    """Allocate a new shared-memory segment and view it as a numpy array.

    Args:
        shape: Array shape (tuple of ints).
        dtype: Anything ``np.dtype`` accepts.

    Returns:
        ``(shm, array)``: the ``SharedMemory`` handle and an uninitialized array of
        the given shape/dtype backed by it. Drop every view of ``array`` before
        calling ``shm.close()`` (otherwise close raises BufferError), then ``shm.unlink()``.
    """
    import math

    # math.prod uses Python ints, so large shapes cannot overflow the way np.prod
    # can on platforms where numpy's default integer is 32-bit.
    nbytes = math.prod(shape) * np.dtype(dtype).itemsize
    # SharedMemory rejects size 0, so back empty arrays with a 1-byte segment.
    shm = shared_memory.SharedMemory(create=True, size=max(nbytes, 1))
    array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    return shm, array

def map_fn(url, caption, shared_array, shared_index, lock, shape, dtype):
    """Fetch one image, resize it to ``shape`` and append it as the next row of ``shared_array``.

    Args:
        url: Image URL passed to ``fetch_single_image``.
        caption: Currently unused (kept for the call signature used by ``map_batch_fn``).
        shared_array: Destination array; rows are filled in completion order.
        shared_index: ``multiprocessing.Value`` holding the next free row.
        lock: Lock guarding ``shared_index`` and the row write.
        shape: Per-image shape, read here as (height, width, channels); assumes 3 channels.
        dtype: Element dtype of the stored pixels.

    Returns:
        True if a row was written; False if the download failed or the array is full.
    """
    image = fetch_single_image(url)
    if image is None:
        return False
    height, width = shape[0], shape[1]
    # fetch_single_image returns a PIL image, which np.frombuffer cannot read, and
    # whose size is arbitrary: convert to RGB and resize (PIL takes (width, height)).
    pixels = np.asarray(image.convert("RGB").resize((width, height)), dtype=dtype).reshape(shape)
    with lock:
        index = shared_index.value
        if index >= len(shared_array):
            return False
        shared_array[index] = pixels
        shared_index.value = index + 1  # publish only after the row is written
    return True
        # Save additional info (url, caption) if necessary

def map_batch_fn(batch, shared_array, shared_index, lock, shape, dtype, num_threads=64):
    """Run ``map_fn`` for every (url, caption) pair of ``batch`` on a thread pool.

    Every task is waited on. A summary is printed if any raised, so failures are
    not lost the way they are when ``executor.map`` results are never consumed.

    Args:
        batch: Mapping with parallel "url" and "caption" lists.
        shared_array, shared_index, lock, shape, dtype: Forwarded to ``map_fn``.
        num_threads: Thread-pool size.
    """
    urls = batch["url"]
    captions = batch["caption"]
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [
            executor.submit(map_fn, url, caption, shared_array, shared_index, lock, shape, dtype)
            for url, caption in zip(urls, captions)
        ]
    # The executor has shut down (wait=True), so every future is done here.
    errors = [f.exception() for f in futures if f.exception() is not None]
    if errors:
        print(f"map_batch_fn: {len(errors)}/{len(futures)} samples raised; first error: {errors[0]!r}")

def parallel_image_loader(dataset: Dataset, shared_array, shared_index, lock, shape, dtype, num_workers: int = 8):
    """Split ``dataset`` into up to ``num_workers`` contiguous chunks and load them concurrently.

    Each chunk (``dataset[start:end]``) is handed to ``map_batch_fn``. A thread
    pool is used rather than a process pool for two reasons: ``shared_index`` (a
    ``multiprocessing.Value``) and ``lock`` cannot be pickled into pool tasks, and
    a pickled copy of ``shared_array`` would not write back to the caller's array.
    """
    from multiprocessing.pool import ThreadPool

    total = len(dataset)
    if total == 0:
        return
    workers = max(num_workers, 1)
    # Ceiling division so the trailing len(dataset) % num_workers rows are not dropped.
    chunk = -(-total // workers)
    batches = [dataset[start:start + chunk] for start in range(0, total, chunk)]
    with ThreadPool(workers) as pool:
        pool.starmap(
            map_batch_fn,
            [(batch, shared_array, shared_index, lock, shape, dtype) for batch in batches],
        )

class ImageBatchIterator:
    """Yield consecutive fixed-size image batches that a background loader writes to shared memory.

    A daemon thread runs ``parallel_image_loader``, which fills rows of
    ``self.shared_array`` and advances ``self.shared_index``. ``__next__`` returns
    the next ``batch_size`` unread rows in order, waiting for the loader if needed.
    It raises StopIteration once the loader has finished and fewer than
    ``batch_size`` unread rows remain. Returned batches are views into shared memory.

    NOTE: this class shadows the ``ImageBatchIterator`` imported from
    flaxdiff.data.online_loader earlier in the notebook.
    NOTE(review): the shared array has one row per dataset item, so it needs
    len(dataset) * prod(image_shape) * itemsize bytes up front.
    """

    def __init__(self, dataset: Dataset, num_workers: int = 8, batch_size: int = 64, image_shape=(224, 224, 3), dtype=np.uint8, poll_interval: float = 0.05):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.image_shape = image_shape
        self.dtype = dtype
        self.poll_interval = poll_interval  # seconds between progress checks in __next__
        self._cursor = 0  # first row not yet returned by __next__

        # Create shared memory array
        self.shm, self.shared_array = create_shared_array((len(dataset),) + image_shape, dtype)
        self.shared_index = multiprocessing.Value('i', 0)  # number of rows written so far
        self.lock = multiprocessing.Lock()  # guards row writes + index increments

        # Daemon so an unfinished loader cannot keep the interpreter alive.
        self.thread = threading.Thread(target=parallel_image_loader, args=(
            dataset, self.shared_array, self.shared_index, self.lock, image_shape, dtype, num_workers),
            daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        end = self._cursor + self.batch_size
        # Wait until enough rows are written, or the loader has finished.
        while self.shared_index.value < end and self.thread.is_alive():
            self.thread.join(self.poll_interval)
        if self.shared_index.value < end:
            raise StopIteration
        batch = self.shared_array[self._cursor:end]
        self._cursor = end
        return batch

    def __del__(self):
        thread = getattr(self, "thread", None)
        if thread is not None:
            # The loader must stop writing before the segment is unmapped.
            thread.join()
        shm = getattr(self, "shm", None)
        if shm is None:
            return
        # Drop our view so the buffer export can be released before close().
        self.shared_array = None
        try:
            shm.close()
        except BufferError:
            # A batch returned by __next__ is still referenced; leave the mapping alive for it.
            pass
        shm.unlink()  # Free shared memory when done

    def __len__(self):
        return len(self.dataset) // self.batch_size

# Example usage:
# NOTE(review): allocates a shared uint8 array of len(mscoco_fused) x 224 x 224 x 3
# (about 89 GB if mscoco_fused has the 591,753 rows shown by its map progress bar).
# Also rebinds the global name `dataset`.
dataset = ImageBatchIterator(mscoco_fused, num_workers=16, batch_size=64, image_shape=(224, 224, 3))
for i in tqdm.tqdm(range(0, 100)):
    batch = next(dataset)

In [None]:
# Pull another 100 batches from the same iterator.
for i in tqdm.tqdm(range(0, 100)):
    batch = next(dataset)