In [1]:
import pdb

from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from laion_dataloader import make_train_dataset

In [4]:
device = torch.device('cuda')

In [5]:
torch.cuda.is_available()

True

In [6]:
data_dir = '/shared_drive/user-files/laion_dataset_200M/laion200m-data'

IMAGE_RESOLUTION = 256

In [7]:
train_dataset = make_train_dataset(data_dir=data_dir,
                                   seed=42, buffer_size=100, resolution=IMAGE_RESOLUTION)

Resolving data files:   0%|          | 0/16254 [00:00<?, ?it/s]

In [8]:
# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

model = model.to(device)

# DataLoader: option 1

In [None]:
batch_size = 128

In [37]:
def collate_fn(train_dataset):
        images = [example['jpg'].convert("RGB") for example in train_dataset]
        return processor(images=images, return_tensors="pt", do_resize=False).to(device)

dataloader = DataLoader(train_dataset, batch_size=batch_size,
                        collate_fn=collate_fn)

In [38]:
target_sizes = torch.tensor([[IMAGE_RESOLUTION, IMAGE_RESOLUTION]]*batch_size)

In [39]:
for step, batch in tqdm(enumerate(dataloader)):
    outputs = model(**batch)

    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)

0it [00:00, ?it/s]

KeyboardInterrupt: 

batch_size = 128, ~1.15it/s

In [12]:
# %debug

# DataLoader: option 2

In [9]:
batch_size = 128
num_workers = 6

In [10]:
target_sizes = torch.tensor([[IMAGE_RESOLUTION, IMAGE_RESOLUTION]]*batch_size).to(device)

In [11]:
def collate_fn(train_dataset): 
    images = []
    urls = []
    
    for example in train_dataset:
        images.append(example['jpg'].convert("RGB"))
        urls.append(example['url'])
        
    return processor(images=images, return_tensors="pt", do_resize=False), urls

dataloader = DataLoader(train_dataset, 
                        batch_size=batch_size,
                        num_workers=num_workers,
                        pin_memory=True,
                        collate_fn=collate_fn)

In [12]:
def save_results_to_parquet(results, urls, model, step):
    rows = []
    columns = ['url', 'label', 'score', 'top_left_x', 'top_left_y', 'bottom_right_x', 'bottom_right_y']
    for i, result_per_image in enumerate(results):
        for score, label, box in zip(result_per_image["scores"], result_per_image["labels"], result_per_image["boxes"]):
            url = urls[i]
            box = box.detach().cpu()
            top_left_x, top_left_y, bottom_right_x, bottom_right_y = box[0].item(), box[1].item(), box[2].item(), box[3].item()
            label_name = model.config.id2label[label.item()]
            score = np.round(score.detach().cpu().item(), 2)
    
            row = [url, label_name, score, top_left_x, top_left_y, bottom_right_x, bottom_right_y]
            rows.append(row)
        
    df = pd.DataFrame(rows, columns=columns)
    df.to_parquet(f'/shared_drive/user-files/laion_dataset_200M/laion200m-od-labels/{step}_batch.parquet')

In [14]:
for step, batch in tqdm(enumerate(dataloader)):
    batch_urls = batch[1]
    batch_tensors = batch[0]
    batch_tensors['pixel_values'] = batch_tensors['pixel_values'].to(device)
    batch_tensors['pixel_mask'] = batch_tensors['pixel_mask'].to(device)

    outputs = model(**batch_tensors)

    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)

    save_results_to_parquet(results=results, 
                                 urls=batch_urls, 
                                 model=model, 
                                 step=step)
    # pdb.set_trace()

batch_size = 128, num_workers = 6 -- ~5it/s \
batch_size = 128, num_workers = 12 -- ~5it/s


batch_size = 128, num_workers = 6, pin_memory=True -- ~6it/s. 10GB GPU memory\
batch_size = 128, num_workers = 6, pin_memory=True -- ~4-6it/s. 10GB GPU memory (**with labels saving**)\
batch_size = 256, num_workers = 6, pin_memory=True -- ~3it/s. 20GB GPU memory

128 * 6 (or 256 * 3) = 768 images/sec \
**125 000 000 images/ 45 hours**

**with labels saving**\
128 * 5 = 640 images/sec\
125 000 000 images/ 54 hours - **2.5 days**