In [3]:
import webdataset as wds
# import jax
# import jax.numpy as jnp
import augmax
import matplotlib.pyplot as plt

# import grain.python as pygrain
from typing import Any, Dict, List, Tuple
# import numpy as np
from functools import partial
import tqdm 

import fsspec
import json

import os
from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel

from datasets import load_dataset, concatenate_datasets
from datasets.utils.file_utils import get_datasets_user_agent
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image
import cv2

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# User-agent string provided by the `datasets` library; fetch_single_image sends
# it as the "user-agent" header of every image request.
USER_AGENT = get_datasets_user_agent()


def fetch_single_image(image_url, timeout=None, retries=0):
    """Download one image and return it as a PIL image, or None on failure.

    Args:
        image_url: URL of the image to download.
        timeout: Seconds passed to ``urlopen``; ``None`` leaves it at the
            library default.
        retries: Number of extra attempts made after the first one fails.

    Returns:
        The ``PIL.Image`` opened from the downloaded bytes, or ``None`` when
        every attempt raised (or when ``retries`` is negative, in which case
        no attempt is made at all).
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    image = None  # bound up front so the return is valid even with zero attempts
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as response:
                # Read the whole body while the connection is open; PIL then
                # works from the in-memory copy.
                image = PIL.Image.open(io.BytesIO(response.read()))
            break
        except Exception:
            # Best-effort download: any failure counts as a miss and is retried.
            image = None
    return image

def denormalizeImage(x):
    """Rescale ``x`` by ``(x + 1.0) * 127.5``, so -1 maps to 0 and 1 maps to 255."""
    shifted = x + 1.0
    return shifted * 127.5

def plotImages(imgs, fig_size=(8, 8), dpi=100):
    """Show a batch of normalized images in a grid of subplots.

    Args:
        imgs: Array-like whose first axis indexes the images; each image is
            passed through ``denormalizeImage`` before display.
            NOTE(review): presumably (N, H, W, C) with values in [-1, 1] — confirm.
        fig_size: Used both as the matplotlib figure size and as the
            (rows, cols) of the subplot grid, so at most
            ``fig_size[0] * fig_size[1]`` images can be shown.
        dpi: Resolution passed to ``plt.figure``.
    """
    # Local import: the notebook's top-level numpy/jax imports are commented
    # out, so the previous `jnp` reference raised NameError.
    import numpy as np

    plt.figure(figsize=fig_size, dpi=dpi)
    imglen = imgs.shape[0]
    for i in range(imglen):
        plt.subplot(fig_size[0], fig_size[1], i + 1)
        pixels = denormalizeImage(np.asarray(imgs[i]))
        # Clip before the cast so values slightly outside [0, 255] do not wrap
        # around when converted to uint8.
        plt.imshow(np.clip(pixels, 0, 255).astype(np.uint8))
        plt.axis("off")
    plt.show()


# Filtering pipeline for various datasets

In [5]:
def dataMapper(map: Dict[str, Any]):
    """Build a function that renames a dataset's columns to ``url`` / ``caption``.

    Args:
        map: Mapping with the keys ``"url"`` and ``"caption"``; each value is
            the name of the source column to read.

    Returns:
        A function that takes a sample (or a batch) and returns a new dict
        holding only the ``url`` and ``caption`` entries.
    """
    def _map(sample) -> Dict[str, Any]:
        # Source column names are looked up on every call, url first.
        return {target: sample[map[target]] for target in ("url", "caption")}

    return _map

def imageFetcher():
    """Return a batch function that downloads every URL of a batch in parallel.

    The returned ``fetch_images(batch, num_threads, timeout=None, retries=0)``
    reads ``batch["url"]``, downloads each entry with ``fetch_single_image``
    on a pool of ``num_threads`` threads, stores the results (in input order)
    under ``batch["image"]`` and returns the same batch object.
    """
    def fetch_images(batch, num_threads, timeout=None, retries=0):
        download = partial(fetch_single_image, timeout=timeout, retries=retries)
        pool = ThreadPoolExecutor(max_workers=num_threads)
        with pool:
            images = list(pool.map(download, batch["url"]))
        batch["image"] = images
        return batch

    return fetch_images

def mapDataset(dataset, args, mapper=dataMapper, workers=16, batch_size=10000, should_remove_columns=True, fn_kwargs={}):
    """Apply a batched map to ``dataset`` using a function built by ``mapper``.

    Args:
        dataset: Object exposing ``column_names`` and ``map``.
        args: Positional arguments unpacked into ``mapper`` to build the
            function that is mapped.
        mapper: Factory returning the function to map.
        workers: Forwarded to ``dataset.map`` as ``num_proc``.
        batch_size: Forwarded to ``dataset.map`` as ``batch_size``.
        should_remove_columns: When true, all of the dataset's current columns
            are passed as ``remove_columns``; otherwise ``None`` is passed.
        fn_kwargs: Extra keyword arguments forwarded to the mapped function.
            The default dict is shared between calls; it is not mutated here.

    Returns:
        Whatever ``dataset.map`` returns.
    """
    columns_to_drop = dataset.column_names if should_remove_columns else None
    transform = mapper(*args)
    return dataset.map(
        transform,
        batched=True,
        batch_size=batch_size,
        remove_columns=columns_to_drop,
        num_proc=workers,
        fn_kwargs=fn_kwargs,
    )

In [5]:
# Source column names of the LAION table, keyed by the shared url/caption schema.
laionMap = dict(url="URL", caption="TEXT")

# Load the dataset and reduce its train split to the url/caption columns.
laion12m6 = load_dataset("dclure/laion-aesthetics-12m-umap")
laion12m6_fused = mapDataset(laion12m6["train"], (laionMap,))

Downloading readme: 100%|██████████| 4.01k/4.01k [00:00<00:00, 12.7MB/s]
Downloading data: 100%|██████████| 2.44G/2.44G [01:20<00:00, 30.2MB/s]
Generating train split: 100%|██████████| 12096809/12096809 [00:11<00:00, 1093969.94 examples/s]
Map (num_proc=16): 100%|██████████| 12096809/12096809 [00:49<00:00, 246633.91 examples/s] 


In [6]:
# This dataset is read through the same source column names as the LAION
# table, so the mapping object is reused as-is.
mscocoMap = laionMap

# Load the dataset and reduce its train split to the url/caption columns.
mscoco = load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT")
mscoco_fused = mapDataset(mscoco["train"], (mscocoMap,))

Downloading data: 100%|██████████| 18.3M/18.3M [00:01<00:00, 12.3MB/s]
Generating train split: 100%|██████████| 591753/591753 [00:00<00:00, 2517731.17 examples/s]
Map (num_proc=16): 100%|██████████| 591753/591753 [00:02<00:00, 262603.89 examples/s] 


In [7]:
# Merge the LAION and MS-COCO url/caption tables into one dataset.
# Each source is listed exactly once: listing mscoco_fused twice would duplicate
# every MS-COCO row, and the recorded row count (12,688,562 = 12,096,809 +
# 591,753) corresponds to a single copy of each table.
fused_data = concatenate_datasets([laion12m6_fused, mscoco_fused])

In [8]:
# Display the number of rows in the merged dataset.
len(fused_data)

12688562

In [None]:
# Visual spot check: download and display the first 20 rows of fused_data,
# using each caption as the plot title.
for index in range(20):
    sample = fused_data[index]
    img = fetch_single_image(sample['url'])
    if img is not None:
        plt.imshow(img)
        plt.title(sample['caption'])
        plt.show()
    else:
        # Download failed; report it and move on to the next row.
        print("Image is None")

In [9]:
# Write the merged url/caption table to a parquet file in the working directory.
# The cell displays the value returned by to_parquet (presumably the number of
# bytes written — TODO confirm).
fused_data.to_parquet("laion-aesthetics-12m+mscoco-2017.parquet")

Creating parquet from Arrow format: 100%|██████████| 12689/12689 [00:20<00:00, 622.45ba/s] 


2454709042

In [118]:
# Download the image behind every MS-COCO URL: 64 map workers, each fetching
# its batch on 64 threads. Existing columns are kept and an "image" column is added.
imaged_data = mapDataset(
    mscoco_fused,
    (),
    mapper=imageFetcher,
    batch_size=5000,
    workers=64,
    should_remove_columns=False,
    # A finite timeout keeps one unresponsive host from blocking a thread
    # indefinitely; the values match those used by the streaming mapper.
    fn_kwargs={"num_threads": 64, "timeout": 10, "retries": 2},
)

  self.pid = os.fork()


Map (num_proc=64):  19%|█▊        | 110000/591753 [01:47<03:12, 2506.78 examples/s]

# COYO-700M Processing

In [5]:
# Load the COYO-700M dataset (all splits) from the Hugging Face hub.
# The trailing comment keeps a disabled `num_proc=32` option for reference.
coyo700 = load_dataset("kakaobrain/coyo-700m")#, num_proc=32)

Downloading readme: 100%|██████████| 14.8k/14.8k [00:00<00:00, 25.9MB/s]
Downloading data:   0%|          | 0/128 [00:03<?, ?files/s]


In [9]:
# Inclusive [min, max] bounds per column, consumed by coyoFilter.
# A "word_count" bound of 0..100 is intentionally left out of both maps.
baseFilterMap = dict(
    clip_similarity_vitl14=dict(min=0.27, max=1000),
    aesthetic_score_laion_v2=dict(min=5.1, max=100),
    watermark_score=dict(min=0, max=0.4),
)

# Variant with a higher aesthetic-score floor, and different CLIP-similarity
# and watermark bounds.
heavyFilterMap = dict(
    clip_similarity_vitl14=dict(min=0.25, max=100),
    aesthetic_score_laion_v2=dict(min=6.05, max=100),
    watermark_score=dict(min=0, max=0.8),
)

def coyoFilter(filterMap):
    """Build a predicate that keeps samples whose columns lie within bounds.

    Args:
        filterMap: Mapping of column name to a dict with ``"min"`` and
            ``"max"`` entries.

    Returns:
        A function returning ``False`` as soon as one listed column of the
        sample is below its ``min`` or above its ``max``, and ``True`` otherwise.
    """
    def _filter(sample):
        violations = (
            sample[column] < bounds["min"] or sample[column] > bounds["max"]
            for column, bounds in filterMap.items()
        )
        # any() stops at the first out-of-range column.
        return not any(violations)

    return _filter
    

In [10]:
# Keep only the COYO rows that satisfy every bound in heavyFilterMap, using 64
# worker processes. The disabled line below is the same call with baseFilterMap.
# goodCoyo700 = coyo700.filter(coyoFilter(baseFilterMap), num_proc=64)
aestheticCoyo700 = coyo700.filter(coyoFilter(heavyFilterMap), num_proc=64)

Filter (num_proc=64): 100%|██████████| 746972269/746972269 [03:08<00:00, 3953092.18 examples/s]


In [12]:
# Display how many rows of the train split survived the filter.
len(aestheticCoyo700['train'])

1602504

In [None]:
# Visual spot check of the filtered COYO rows: download and display the first
# 100 images with their caption and aesthetic score.
# The loop reads from `aestheticCoyo700`, the filtered dataset built in this
# notebook; the name `coyo700Filtered` used previously is never defined here.
for i in range(0, 100):
    sample = aestheticCoyo700['train'][i]
    img = fetch_single_image(sample['url'])
    if img is None:
        # Download failed; report it and move on to the next row.
        print("Image is None")
        continue
    text = sample['text']
    plt.imshow(img)
    plt.title(text)
    print(f"Aesthetic score: {sample['aesthetic_score_laion_v2']}")
    plt.show()

In [14]:
# Reduce the filtered COYO train split to the shared url/caption schema;
# COYO stores the caption in its "text" column.
final_data = mapDataset(
    aestheticCoyo700["train"],
    (dict(url="url", caption="text"),),
)

Map (num_proc=16): 100%|██████████| 1602504/1602504 [00:27<00:00, 59103.40 examples/s] 


In [16]:
# Write the filtered COYO url/caption table to a parquet file in the working
# directory. The cell displays the value returned by to_parquet (presumably the
# number of bytes written — TODO confirm).
final_data.to_parquet("aestheticCoyo_0.25clip_6aesthetic.parquet")

Creating parquet from Arrow format:   9%|▉         | 152/1603 [00:00<00:01, 741.38ba/s]

Creating parquet from Arrow format: 100%|██████████| 1603/1603 [00:02<00:00, 748.05ba/s]


330153346

In [17]:
# Display the first record to check the url/caption schema.
final_data[0]

{'url': 'https://img3.goodfon.com/wallpaper/big/5/85/art-krenz-fallen-angel-angel.jpg',
 'caption': 'Picture girl, fiction, fire, wings, angel, red eyes, white hair, art, Angel Fall, Krenz'}

# Streaming Dataset Pipeline

In [None]:
def dataMapper(map: Dict[str, Any]):
    """Build a per-sample mapper that also downloads the sample's image.

    NOTE: this reuses the name of the url/caption-only ``dataMapper`` defined
    earlier in the notebook, so running this cell replaces that definition in
    the notebook namespace.

    Args:
        map: Mapping with the keys ``"url"`` and ``"caption"``; each value is
            the name of the source column to read.

    Returns:
        A function that takes one sample and returns a dict with ``url``,
        ``caption`` and ``image``; ``image`` is whatever ``fetch_single_image``
        returns for the URL (10 s timeout, 2 retries).
    """
    def _map(sample) -> Dict[str, Any]:
        url = sample[map["url"]]
        # The download happens before the caption column is read.
        result = {"url": url, "image": fetch_single_image(url, timeout=10, retries=2)}
        caption = sample[map["caption"]]
        return {"url": result["url"], "caption": caption, "image": result["image"]}

    return _map

In [24]:
# Source column names of the LAION table, keyed by the shared url/caption schema.
laionMap = dict(url="URL", caption="TEXT")

# Re-open the dataset in streaming mode; this rebinds `laion12m6`.
laion12m6 = load_dataset("dclure/laion-aesthetics-12m-umap", streaming=True)

In [26]:
# Attach the mapper built by dataMapper(laionMap) to the train split.
# NOTE(review): which dataMapper definition is used depends on which cell ran
# last — confirm it is the image-downloading variant.
data = laion12m6['train'].map(dataMapper(laionMap))

In [28]:
# Create an iterator over the mapped dataset and pull its first item.
dataset = iter(data)
batch = next(dataset)

In [33]:
# Pull 100 more items one at a time, with a progress bar to observe throughput.
# The recorded run was interrupted after 44 items at roughly 2.5 s per item.
for i in tqdm.tqdm(range(100)):
    batch = next(dataset)
    

 44%|████▍     | 44/100 [01:52<02:22,  2.55s/it]


KeyboardInterrupt: 

In [13]:
# Source column names of the LAION table, keyed by the shared url/caption schema
# (same values as the earlier laionMap definitions in this notebook).
laionMap = {
    "url": "URL",
    "caption": "TEXT",
}
# NOTE(review): if the image-downloading dataMapper is the one in scope, every
# sample already carries an "image" here and imageFetcher below downloads each
# URL a second time — confirm which variant is intended.
laion12m6Fused = laion12m6['train'].map(dataMapper(laionMap))
# mapDataset forwards num_proc and remove_columns to .map().
# NOTE(review): confirm that the map() of a streaming dataset accepts these
# arguments. No timeout is passed in fn_kwargs, so fetches use timeout=None.
imaged_data = mapDataset(laion12m6Fused, (), mapper=imageFetcher, batch_size=5000, should_remove_columns=False, fn_kwargs={"num_threads": 64})

In [15]:
# Pull only the first item from imaged_data and print its len(), then stop.
# NOTE(review): if the item is a dict this is its number of keys, not a batch
# size — confirm what is meant to be checked.
for batch in imaged_data:
    print(len(batch))
    break

In [1]:
# NOTE(review): this looks like an unfinished attribute access (possibly a
# tab-completion leftover). The recorded run raised NameError because
# `laion12m6` was not defined in that kernel. Consider deleting this cell.
laion12m6.m

NameError: name 'laion12m6' is not defined