In [1]:
import webdataset as wds
import jax
import jax.numpy as jnp
import augmax
import matplotlib.pyplot as plt

import grain.python as pygrain
from typing import Any, Dict, List, Tuple
import numpy as np
from functools import partial
import tqdm 

import fsspec
import json

import os
from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel

from datasets import load_dataset, concatenate_datasets, Dataset
from datasets.utils.file_utils import get_datasets_user_agent
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image
import cv2

In [2]:
USER_AGENT = get_datasets_user_agent()


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image

denormalizeImage = lambda x: (x + 1.0) * 127.5

def plotImages(imgs, fig_size=(8, 8), dpi=100):
    fig = plt.figure(figsize=fig_size, dpi=dpi)
    imglen = imgs.shape[0]
    for i in range(imglen):
        plt.subplot(fig_size[0], fig_size[1], i + 1)
        plt.imshow(jnp.astype(denormalizeImage(imgs[i, :, :, :]), jnp.uint8))
        plt.axis("off")
    plt.show()


# Filtering pipeline for various datasets

In [3]:
def dataMapper(map: Dict[str, Any]):
    def _map(sample) -> Dict[str, Any]:
        return {
            "url": sample[map["url"]],
            "caption": sample[map["caption"]],
        }
    return _map

def imageFetcher():
    def fetch_images(batch, num_threads, timeout=None, retries=0):
        fetch_single_image_with_args = partial(fetch_single_image, timeout=timeout, retries=retries)
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            batch["image"] = list(executor.map(fetch_single_image_with_args, batch["url"]))
        return batch
    return fetch_images

def mapDataset(dataset, args, mapper=dataMapper, workers=16, batch_size=10000, should_remove_columns=True, fn_kwargs={}):
    if should_remove_columns:
        remove_columns = dataset.column_names
    else:
        remove_columns = None
    return dataset.map(mapper(*args), batched=True, batch_size=batch_size, remove_columns=remove_columns, num_proc=workers, fn_kwargs=fn_kwargs) 

In [None]:
laion12m6 = load_dataset("dclure/laion-aesthetics-12m-umap")
laion12m6_fused = laion12m6['train']
laionMap = {
    "url": "URL",
    "caption": "TEXT",
}
laion12m6_fused = mapDataset(laion12m6_fused, (laionMap, ))

In [4]:
mscoco = load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="all")
mscoco_fused = mscoco
mscocoMap = {
    "url": "URL",
    "caption": "TEXT",
}
mscoco_fused = mapDataset(mscoco_fused, (mscocoMap, ))

In [9]:
mscoco

Dataset({
    features: ['URL', 'TEXT'],
    num_rows: 591753
})

In [12]:
fused_data = concatenate_datasets([mscoco_fused, mscoco_fused, mscoco_fused, laion12m6_fused, mscoco_fused, mscoco_fused])

In [13]:
len(fused_data)

15055574

In [None]:
for i in range(0, 20):
    sample = fused_data[i]
    img = fetch_single_image(sample['url'])
    if img is None:
        print("Image is None")
        continue
    text = sample['caption']
    plt.imshow(img)
    plt.title(text)
    # print(f"Aesthetic score: {sample['aesthetic_score_laion_v2']}")
    plt.show()

In [14]:
fused_data.to_parquet("laion-aesthetics-12m+mscoco-2017.parquet")

Creating parquet from Arrow format:   2%|▏         | 370/15056 [00:00<00:07, 1845.71ba/s]

Creating parquet from Arrow format: 100%|██████████| 15056/15056 [00:18<00:00, 826.17ba/s] 


2730439590

In [118]:
imaged_data = mapDataset(mscoco_fused, (), mapper=imageFetcher, batch_size=5000, workers=64, should_remove_columns=False, fn_kwargs={"num_threads": 64})

  self.pid = os.fork()


Map (num_proc=64):  19%|█▊        | 110000/591753 [01:47<03:12, 2506.78 examples/s]

# COYO-700M Processing

In [4]:
coyo700 = load_dataset("kakaobrain/coyo-700m", num_proc=64)

Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/352 [00:00<?, ?it/s]

In [7]:
baseFilterMap = {
    # "word_count": {"min": 0, "max": 100},
    "clip_similarity_vitl14": {"min": 0.27, "max": 1000},
    "aesthetic_score_laion_v2": {"min": 5.1, "max": 100},
    "watermark_score": {"min": 0, "max": 0.4},
}

heavyFilterMap = {
    # "word_count": {"min": 0, "max": 100},
    "clip_similarity_vitl14": {"min": 0.26, "max": 100},
    "aesthetic_score_laion_v2": {"min": 5.0, "max": 100},
    "watermark_score": {"min": 0, "max": 0.8},
    "width": {"min":512, "max":99999},
    "height": {"min":512, "max":99999},
}

def coyoFilter(filterMap):
    def _filter(sample):
        for key, value in filterMap.items():
            if sample[key] < value["min"] or sample[key] > value["max"]:
                return False
        return True
    return _filter
    

In [8]:
# goodCoyo700 = coyo700.filter(coyoFilter(baseFilterMap), num_proc=64)
aestheticCoyo700 = coyo700.filter(coyoFilter(heavyFilterMap), num_proc=64)

Filter (num_proc=64): 100%|██████████| 746972269/746972269 [03:33<00:00, 3505937.43 examples/s]


In [17]:
len(aestheticCoyo700['train'])

16586129

In [None]:
for i in range(0, 100):
    sample = coyo700Filtered['train'][i]
    img = fetch_single_image(sample['url'])
    if img is None:
        print("Image is None")
        continue
    text = sample['text']
    plt.imshow(img)
    plt.title(text)
    print(f"Aesthetic score: {sample['aesthetic_score_laion_v2']}")
    plt.show()

In [19]:
final_data = mapDataset(aestheticCoyo700['train'], ({
    "url":"url",
    "caption":"text"
    }, ))

Map (num_proc=16): 100%|██████████| 16586129/16586129 [01:15<00:00, 218835.67 examples/s]


In [20]:
final_data.to_parquet("aestheticCoyo_0.26_clip_5.5aesthetic_256plus.parquet")

Creating parquet from Arrow format: 100%|██████████| 16587/16587 [00:22<00:00, 729.06ba/s]


3343770169

In [17]:
final_data[0]

{'url': 'https://img3.goodfon.com/wallpaper/big/5/85/art-krenz-fallen-angel-angel.jpg',
 'caption': 'Picture girl, fiction, fire, wings, angel, red eyes, white hair, art, Angel Fall, Krenz'}

# Data Loading Experiments

In [8]:
laion12m6 = load_dataset("dclure/laion-aesthetics-12m-umap", split="all")
# laion12m6_fused = laion12m6
# laionMap = {
#     "url": "URL",
#     "caption": "TEXT",
# }
# laion12m6_fused = mapDataset(laion12m6_fused, (laionMap, ))

Downloading readme: 100%|██████████| 4.01k/4.01k [00:00<00:00, 22.5MB/s]
Downloading data: 100%|██████████| 2.44G/2.44G [00:44<00:00, 54.9MB/s]
Generating train split: 100%|██████████| 12096809/12096809 [00:09<00:00, 1313947.82 examples/s]


In [10]:
import multiprocessing
import threading
from multiprocessing import Queue
# from arrayqueues.shared_arrays import ArrayQueue
# from faster_fifo import Queue
import time
import albumentations as A
import queue

USER_AGENT = get_datasets_user_agent()

data_queue = Queue(16*2000)
error_queue = Queue(16*2000)


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image

def map_sample(
    url, caption, 
    image_shape=(256, 256),
    upscale_interpolation=cv2.INTER_LANCZOS4,
    downscale_interpolation=cv2.INTER_AREA,
):
    try:
        image = fetch_single_image(url, timeout=15, retries=3)  # Assuming fetch_single_image is defined elsewhere
        if image is None:
            return
        
        image = np.array(image)
        original_height, original_width = image.shape[:2]
        # check if the image is too small
        if min(original_height, original_width) < min(image_shape):
            return
        # check if wrong aspect ratio
        if max(original_height, original_width) / min(original_height, original_width) > 2:
            return
        # check if the variance is too low
        if np.std(image) < 1e-4:
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        downscale = max(original_width, original_height) > max(image_shape)
        interpolation = downscale_interpolation if downscale else upscale_interpolation
        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
        image = A.pad(
            image,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        data_queue.put({
            "url": url,
            "caption": caption,
            "image": image
        })
    except Exception as e:
        error_queue.put({
            "url": url,
            "caption": caption,
            "error": str(e)
        })
        
def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
    
def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
    shard_len = len(dataset) // num_workers
    print(f"Local Shard lengths: {shard_len}")
    with multiprocessing.Pool(num_workers) as pool:
        iteration = 0
        while True:
            # Repeat forever
            dataset = dataset.shuffle(seed=iteration)
            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
            pool.map(map_batch_fn, shards)
            iteration += 1
            
class ImageBatchIterator:
    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
        self.thread = threading.Thread(target=loader, args=(dataset))
        self.thread.start()
        
    def __iter__(self):
        return self
    
    def __next__(self):
        def fetcher(_):
            return data_queue.get()
        with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
            batch = list(executor.map(fetcher, range(self.batch_size)))
        return batch
    
    def __del__(self):
        self.thread.join()
        
    def __len__(self):
        return len(self.dataset) // self.batch_size
    
def default_collate(batch):
    urls = [sample["url"] for sample in batch]
    captions = [sample["caption"] for sample in batch]
    images = np.stack([sample["image"] for sample in batch], axis=0)
    return {
        "url": urls,
        "caption": captions,
        "image": images,
    }
    
def dataMapper(map: Dict[str, Any]):
    def _map(sample) -> Dict[str, Any]:
        return {
            "url": sample[map["url"]],
            "caption": sample[map["caption"]],
        }
    return _map

class OnlineStreamingDataLoader():
    def __init__(
        self, 
        dataset, 
        batch_size=64, 
        num_workers=16, 
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper, 
        pre_map_def={
            "url": "URL",
            "caption": "TEXT",
        },
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
    ):
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
            else:
                print("Concatenating multiple datasets")
                dataset = concatenate_datasets(dataset)
        dataset = dataset.map(pre_map_maker(pre_map_def))
        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
        self.collate_fn = collate_fn
        
        # Launch a thread to load batches in the background
        self.batch_queue = queue.Queue(prefetch)
        
        def batch_loader():
            for batch in self.iterator:
                self.batch_queue.put(batch)
        
        self.loader_thread = threading.Thread(target=batch_loader)
        self.loader_thread.start()
        
    def __iter__(self):
        return self
    
    def __next__(self):
        return self.collate_fn(self.batch_queue.get())
        # return self.collate_fn(next(self.iterator))
        
    def __len__(self):
        return len(self.dataset) // self.batch_size
    

In [None]:
from flaxdiff.data.online_loader import OnlineStreamingDataLoader

In [11]:
dataloader = OnlineStreamingDataLoader("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", batch_size=16, num_workers=16, default_split="train")

Loading dataset from path


Dataset length: 591753


In [8]:
dataloader.batch_queue.qsize()

0

In [9]:
data_queue.qsize()

NameError: name 'data_queue' is not defined

In [20]:
error_queue.qsize()

0

In [12]:
for i in tqdm.tqdm(range(0, 2000)):
    batch = next(dataloader)

  0%|          | 0/2000 [00:00<?, ?it/s]

Exception in thread Thread-11:
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/mrwhite0racle/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
TypeError: parallel_image_loader() got multiple values for argument 'num_threads'


In [12]:
def parallel_loading(dataset):
    dataset.map(map_batch_fn, num_proc=64, batched=True, batch_size=64, fn_kwargs={"num_threads": 64})
    
thread = threading.Thread(target=parallel_loading, args=(mscoco_fused,))
thread.start()



Map (num_proc=64):  36%|███▌      | 214080/591753 [01:54<32:53, 191.37 examples/s] 

14974

In [13]:
import torch
from torch.utils.data import Dataset, DataLoader
from concurrent.futures import ThreadPoolExecutor
import aiohttp
from io import BytesIO
import asyncio
from PIL import Image


In [14]:
class URLDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    async def fetch_image(self, url):
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                image_data = await response.read()
                image = Image.open(BytesIO(image_data))
                return image
    
    def __getitem__(self, index):
        data = self.data[index]
        url, caption = data['url'], data['caption']
        loop = asyncio.get_event_loop()
        image = loop.run_until_complete(self.fetch_image(url))
        # Preprocess image and return along with the caption
        image = image.resize((256, 256))  # Example resize
        return image, caption
    
    def __len__(self):
        return len(self.data)

# Example usage
dataset = URLDataset(mscoco_fused)
data_loader = DataLoader(dataset, batch_size=256, num_workers=8, prefetch_factor=2)

In [15]:
for i in tqdm.tqdm(data_loader):
    pass

  0%|          | 0/2312 [00:15<?, ?it/s]


KeyboardInterrupt: 

In [11]:
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        url = self.dataset[idx]['url']
        caption = self.dataset[idx]['caption']
        image = fetch_single_image(url)  # Assuming fetch_single_image is defined elsewhere
        return {
            "url": url,
            "caption": caption,
            "image": image
        }

def collate_fn(batch):
    # Custom collation logic if needed
    print(batch)
    # urls = [item["url"] for item in batch]
    # fetch_single_image_with_args = partial(fetch_single_image, timeout=10, retries=3)
    # with ThreadPoolExecutor(max_workers=len(batch)) as executor:
    #     images = list(executor.map(fetch_single_image_with_args, urls))
    
    # return {
    #     "url": urls,
    #     "caption": [item["caption"] for item in batch],
    #     "image": images
    # }
    
# Assuming mscoco_fused is your dataset
dataset = CustomDataset(mscoco_fused)
data_loader = DataLoader(dataset, batch_size=512, num_workers=8, collate_fn=collate_fn, prefetch_factor=100)

In [12]:
for i in tqdm.tqdm(data_loader):
    # print(i)
    # break
    pass

  0%|          | 0/1156 [00:30<?, ?it/s]


KeyboardInterrupt: 

[{'url': 'http://images.cocodataset.org/train2017/000000461898.jpg', 'caption': 'A close up of a smiling woman with glasses, seen only from the waist up shows that she is manipulating something like a ribbon of diaphanous   fabric,  ', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x636 at 0x7FA2443BD750>}, {'url': 'http://images.cocodataset.org/train2017/000000461898.jpg', 'caption': 'A smiling woman wearing glasses uses a Wii controller.', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x636 at 0x7FA2443BE1D0>}, {'url': 'http://images.cocodataset.org/train2017/000000461898.jpg', 'caption': 'a woman is swinging around a video game controller', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x636 at 0x7FA2757A5930>}, {'url': 'http://images.cocodataset.org/train2017/000000461898.jpg', 'caption': 'A smiling woman in motion, holding wii controls. ', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x636 at 0x7FA275

In [40]:
queue.qsize()

1686

In [1]:
!pip install arrayqueues

Defaulting to user installation because normal site-packages is not writeable
Collecting arrayqueues
  Downloading arrayqueues-1.4.1-py3-none-any.whl.metadata (3.4 kB)
Downloading arrayqueues-1.4.1-py3-none-any.whl (6.4 kB)
[0mInstalling collected packages: arrayqueues
Successfully installed arrayqueues-1.4.1


In [None]:
with multiprocessing.Manager() as manager:
img_queue = manager.Queue()
process = multiprocessing.Process(target=parallel_image_loader, args=(mscoco_fused, img_queue, 8))
process.start()
process.join()

Map (num_proc=8):   0%|          | 0/591753 [00:00<?, ? examples/s]

Map (num_proc=8):   1%|          | 5504/591753 [00:08<10:17, 948.91 examples/s]

KeyboardInterrupt: 

Map (num_proc=8):   6%|▋         | 37568/591753 [00:36<07:50, 1176.97 examples/s]

: 

In [6]:
import multiprocessing
from multiprocessing import shared_memory
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from datasets import Dataset
import threading

def create_shared_array(shape, dtype):
    """Create a shared numpy array."""
    nbytes = np.prod(shape) * np.dtype(dtype).itemsize
    shm = shared_memory.SharedMemory(create=True, size=nbytes)
    array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    return shm, array

def map_fn(url, caption, shared_array, shared_index, lock, shape, dtype):
    image = fetch_single_image(url)  # Assuming fetch_single_image is defined elsewhere
    with lock:
        index = shared_index.value
        shared_array[index] = np.frombuffer(image, dtype=dtype).reshape(shape)  # Store image in shared memory
        shared_index.value += 1  # Move to the next index
        # Save additional info (url, caption) if necessary

def map_batch_fn(batch, shared_array, shared_index, lock, shape, dtype, num_threads=64):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(
            map_fn, 
            batch["url"], 
            batch['caption'], 
            [shared_array] * len(batch["url"]), 
            [shared_index] * len(batch["url"]), 
            [lock] * len(batch["url"]), 
            [shape] * len(batch["url"]), 
            [dtype] * len(batch["url"])
        )

def parallel_image_loader(dataset: Dataset, shared_array, shared_index, lock, shape, dtype, num_workers: int = 8):
    batch_len = len(dataset) // num_workers
    batches = [dataset[i * batch_len:(i + 1) * batch_len] for i in range(num_workers)]
    with multiprocessing.Pool(num_workers) as pool:
        pool.starmap(
            map_batch_fn, 
            [(batch, shared_array, shared_index, lock, shape, dtype) for batch in batches]
        )

class ImageBatchIterator:
    def __init__(self, dataset: Dataset, num_workers: int = 8, batch_size: int = 64, image_shape=(224, 224, 3), dtype=np.uint8):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.image_shape = image_shape
        self.dtype = dtype
        
        # Create shared memory array
        self.shm, self.shared_array = create_shared_array((len(dataset),) + image_shape, dtype)
        self.shared_index = multiprocessing.Value('i', 0)  # Shared index counter
        self.lock = multiprocessing.Lock()  # Lock for safe indexing
        
        self.thread = threading.Thread(target=parallel_image_loader, args=(
            dataset, self.shared_array, self.shared_index, self.lock, image_shape, dtype, num_workers))
        self.thread.start()
        
    def __iter__(self):
        return self
    
    def __next__(self):
        if self.shared_index.value < self.batch_size:
            raise StopIteration
        
        batch_start = max(0, self.shared_index.value - self.batch_size)
        batch_end = self.shared_index.value
        batch = self.shared_array[batch_start:batch_end]
        return batch
    
    def __del__(self):
        self.thread.join()
        self.shm.close()
        self.shm.unlink()  # Free shared memory when done
        
    def __len__(self):
        return len(self.dataset) // self.batch_size

# Example usage:
dataset = ImageBatchIterator(mscoco_fused, num_workers=16, batch_size=64, image_shape=(224, 224, 3))
for i in tqdm.tqdm(range(0, 100)):
    batch = next(dataset)

  0%|          | 0/100 [00:00<?, ?it/s]


StopIteration: 

In [None]:
for i in tqdm.tqdm(range(0, 100)):
    batch = next(dataset)