In [1]:
import pdb

from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

In [2]:
from typing import List, Union
from glob import glob
from os import path
from time import time
import torch.nn as nn
from transformers.image_transforms import center_to_corners_format

In [3]:
import wandb
from wandb import AlertLevel

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
from laion_dataloader import make_train_dataset

In [6]:
# Prefer the GPU; fall back to CPU so the notebook still runs (slowly) on a
# machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
torch.cuda.is_available()

True

In [8]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdiana15kapatsyn[0m ([33mteam__1[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [10]:
# Directory handed to make_train_dataset as `data_dir`.
data_dir = '/shared_drive/user-files/laion_dataset_200M/laion200m-data'

# Passed to make_train_dataset as `resolution` and reused as both entries of
# every row of `target_sizes` when boxes are rescaled.
IMAGE_RESOLUTION = 256

In [11]:
# Build the dataset with the project-local helper from laion_dataloader.
# NOTE(review): `seed` / `buffer_size` presumably configure a shuffle buffer and
# `resolution` the output image size — confirm in laion_dataloader.
train_dataset = make_train_dataset(data_dir=data_dir,
                                   seed=42, buffer_size=100, resolution=IMAGE_RESOLUTION)

Resolving data files:   0%|          | 0/16254 [00:00<?, ?it/s]

In [2]:
# Scratch arithmetic: 1331200 is 128 * 10400 (batch size times the resume step
# used in the run loop). NOTE(review): the meaning of 7868 is not recorded — confirm.
1331200/7868

169.19166243009659

In [12]:
# Rough dataset size: 16254 data files (the count reported while resolving data
# files) times 7800. NOTE(review): 7800 is presumably images per file — confirm.
16254*7800

126781200

In [13]:
# The "no_timm" revision of the checkpoint avoids the timm dependency.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

# Load the detector and place it on the target device in a single expression.
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50", revision="no_timm"
).to(device)

# DataLoader: option 1

In [None]:
batch_size = 128

In [None]:
def collate_fn(train_dataset):
    """Collate a list of examples into one processor batch placed on `device`.

    Despite its name, `train_dataset` is the list of examples the DataLoader
    hands to the collate function for a single batch. Each example's 'jpg'
    entry is converted to RGB, the images are encoded by `processor` with
    resizing disabled, and the encoded batch is moved to `device`.
    """
    rgb_images = []
    for example in train_dataset:
        rgb_images.append(example['jpg'].convert("RGB"))

    encoded = processor(images=rgb_images, return_tensors="pt", do_resize=False)
    return encoded.to(device)

# Default (in-process) loading: the collate function above moves each batch to
# `device` itself.
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [None]:
target_sizes = torch.tensor([[IMAGE_RESOLUTION, IMAGE_RESOLUTION]]*batch_size)

In [None]:
# Inference only: no_grad avoids building an autograd graph for every forward
# pass, which would otherwise cost GPU memory for nothing.
with torch.no_grad():
    for step, batch in tqdm(enumerate(dataloader)):
        outputs = model(**batch)

        # `target_sizes` is built for a full batch; trim it to the actual batch
        # so a shorter final batch does not fail the target-size length check.
        batch_target_sizes = target_sizes[:len(outputs.logits)]

        results = processor.post_process_object_detection(outputs,
                                                          target_sizes=batch_target_sizes,
                                                          threshold=0.9)

0it [00:00, ?it/s]

KeyboardInterrupt: 

batch_size = 128, ~1.15it/s

In [None]:
# %debug

# DataLoader: option 2

In [None]:
batch_size = 128
num_workers = 6

In [None]:
target_sizes = torch.tensor([[IMAGE_RESOLUTION, IMAGE_RESOLUTION]]*batch_size).to(device)

In [None]:
def collate_fn(train_dataset):
    """Collate a list of examples into `(encoded_batch, urls)`.

    Despite its name, `train_dataset` is the list of examples that make up one
    batch. Each example's 'jpg' entry is converted to RGB and encoded by
    `processor` with resizing disabled; the 'url' entries are returned next to
    the encoded batch, in the same order.
    """
    examples = list(train_dataset)
    images = [example['jpg'].convert("RGB") for example in examples]
    urls = [example['url'] for example in examples]

    encoded = processor(images=images, return_tensors="pt", do_resize=False)
    return encoded, urls

# Worker processes decode and encode the images; the batches they return are
# pinned in host memory before the loop moves them to the GPU.
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [None]:
def save_results_to_parquet(results, urls, model, step,
                            output_dir='/shared_drive/user-files/laion_dataset_200M/laion200m-od-labels'):
    """Write one batch of detections to `<output_dir>/<step>_batch.parquet`.

    Args:
        results: one dict per image with "scores", "labels" and "boxes"
            tensors (the output of the post-processing step).
        urls: image URLs, aligned with `results` by position.
        model: detection model; `model.config.id2label` maps label ids to names.
        step: batch index, used in the output file name.
        output_dir: directory that receives the parquet file. Defaults to the
            directory this function has always written to.

    One row is written per detection with the columns url, label, score
    (rounded to 2 decimals) and the four box corner coordinates. Images
    without detections contribute no rows.
    """
    columns = ['url', 'label', 'score', 'top_left_x', 'top_left_y', 'bottom_right_x', 'bottom_right_y']
    id2label = model.config.id2label

    rows = []
    for i, result_per_image in enumerate(results):
        url = urls[i]

        # Convert each tensor to plain Python values once per image rather
        # than calling .cpu()/.item() for every single detection.
        scores = result_per_image["scores"].detach().cpu().tolist()
        labels = result_per_image["labels"].detach().cpu().tolist()
        boxes = result_per_image["boxes"].detach().cpu().tolist()

        for score, label, box in zip(scores, labels, boxes):
            top_left_x, top_left_y, bottom_right_x, bottom_right_y = box
            rows.append([url, id2label[label], np.round(score, 2),
                         top_left_x, top_left_y, bottom_right_x, bottom_right_y])

    df = pd.DataFrame(rows, columns=columns)
    df.to_parquet(f'{output_dir}/{step}_batch.parquet')

In [None]:
# Inference only: no_grad avoids building an autograd graph for every forward
# pass, which would otherwise cost GPU memory for nothing.
with torch.no_grad():
    for step, batch in tqdm(enumerate(dataloader)):
        # The collate function returns (encoded batch, urls).
        batch_tensors, batch_urls = batch
        batch_tensors['pixel_values'] = batch_tensors['pixel_values'].to(device)
        batch_tensors['pixel_mask'] = batch_tensors['pixel_mask'].to(device)

        outputs = model(**batch_tensors)

        # `target_sizes` is built for a full batch; trim it to the actual batch
        # so a shorter final batch does not fail the target-size length check.
        batch_target_sizes = target_sizes[:len(batch_urls)]

        results = processor.post_process_object_detection(outputs,
                                                          target_sizes=batch_target_sizes,
                                                          threshold=0.9)

        save_results_to_parquet(results=results,
                                urls=batch_urls,
                                model=model,
                                step=step)

batch_size = 128, num_workers = 6 -- ~5it/s \
batch_size = 128, num_workers = 12 -- ~5it/s


batch_size = 128, num_workers = 6, pin_memory=True -- ~6it/s. 10GB GPU memory\
batch_size = 128, num_workers = 6, pin_memory=True -- ~4-6it/s. 10GB GPU memory (**with labels saving**)\
batch_size = 256, num_workers = 6, pin_memory=True -- ~3it/s. 20GB GPU memory

128 * 6 (or 256 * 3) = 768 images/sec \
**125 000 000 images / 768 images/sec ≈ 45 hours**

**with labels saving**\
128 * 5 = 640 images/sec\
125 000 000 images / 640 images/sec ≈ 54 hours - **~2.3 days**

# DataLoader: option 3

In [14]:
# Used as both entries of every row of `target_sizes` (box rescaling).
IMAGE_RESOLUTION=256
# Images per batch; also the number of rows in `target_sizes`.
BATCH_SIZE=128
# Worker processes used by the DataLoader.
NUM_WORKERS=12

In [15]:
# One [height, width] row per image of a full batch, kept on `device`.
# It has exactly BATCH_SIZE rows, so a shorter batch needs a matching slice.
target_sizes = torch.tensor([[IMAGE_RESOLUTION, IMAGE_RESOLUTION]]*BATCH_SIZE).to(device)

In [16]:
def save_results_to_parquet(results, urls, model, step,
                            output_dir='/mnt/disks/disk-big2/laion200m-od-labels'):
    """Write one batch of detections to `<output_dir>/<step>_batch.parquet`.

    Args:
        results: one dict per image with "scores", "labels" and "boxes"
            tensors (the output of the post-processing step).
        urls: image URLs, aligned with `results` by position.
        model: detection model; `model.config.id2label` maps label ids to names.
        step: batch index, used in the output file name.
        output_dir: directory that receives the parquet file. Defaults to the
            directory this function has always written to.

    One row is written per detection with the columns url, label, score
    (rounded to 2 decimals) and the four box corner coordinates. Images
    without detections contribute no rows.
    """
    columns = ['url', 'label', 'score', 'top_left_x', 'top_left_y', 'bottom_right_x', 'bottom_right_y']
    id2label = model.config.id2label

    rows = []
    for i, result_per_image in enumerate(results):
        url = urls[i]

        # Convert each tensor to plain Python values once per image rather
        # than calling .cpu()/.item() for every single detection.
        scores = result_per_image["scores"].detach().cpu().tolist()
        labels = result_per_image["labels"].detach().cpu().tolist()
        boxes = result_per_image["boxes"].detach().cpu().tolist()

        for score, label, box in zip(scores, labels, boxes):
            top_left_x, top_left_y, bottom_right_x, bottom_right_y = box
            rows.append([url, id2label[label], np.round(score, 2),
                         top_left_x, top_left_y, bottom_right_x, bottom_right_y])

    df = pd.DataFrame(rows, columns=columns)
    df.to_parquet(f'{output_dir}/{step}_batch.parquet')

In [17]:
def post_process_object_detection(
        outputs, threshold: float = 0.5, target_sizes = None
    ):
    """Turn raw detector outputs into per-image scores, labels and boxes.

    Args:
        outputs: object exposing `logits` and `pred_boxes`, both batched along
            the first dimension.
        threshold: a detection is kept only if its score is strictly greater
            than this value.
        target_sizes: optional list of (height, width) pairs, or a tensor with
            one (height, width) row per image. When given, boxes are scaled
            from relative to absolute coordinates.

    Returns:
        One dict per image with the keys "scores", "labels" and "boxes",
        holding only the detections above `threshold`. Boxes are in
        [x0, y0, x1, y1] format.

    Raises:
        ValueError: if `target_sizes` is given and its length differs from the
            batch dimension of the logits.
    """
    logits, pred_boxes = outputs.logits, outputs.pred_boxes

    rescale = target_sizes is not None
    if rescale and len(logits) != len(target_sizes):
        raise ValueError(
            "Make sure that you pass in as many target sizes as the batch dimension of the logits"
        )

    # The last class is dropped before taking the best score per query
    # (in DETR that slot is the "no object" class).
    class_probs = nn.functional.softmax(logits, -1)[..., :-1]
    scores, labels = class_probs.max(-1)
    keep = scores > threshold

    # Centre format -> [x0, y0, x1, y1] corner format.
    boxes = center_to_corners_format(pred_boxes)

    if rescale:
        if isinstance(target_sizes, List):
            heights = torch.Tensor([size[0] for size in target_sizes])
            widths = torch.Tensor([size[1] for size in target_sizes])
        else:
            heights, widths = target_sizes.unbind(1)

        # One (w, h, w, h) factor per image, broadcast over that image's boxes.
        scale = torch.stack([widths, heights, widths, heights], dim=1).to(boxes.device)
        boxes = boxes * scale[:, None, :]

    return [
        {"scores": image_scores[image_keep],
         "labels": image_labels[image_keep],
         "boxes": image_boxes[image_keep]}
        for image_scores, image_labels, image_boxes, image_keep in zip(scores, labels, boxes, keep)
    ]

# Run

In [18]:
# Start a Weights & Biases run and keep its handle in `run`.
run = wandb.init(
    # Set the project where this run will be logged
    project="object-detector",
    # Track hyperparameters and run metadata
    # (currently empty: no hyperparameters are recorded for this run)
    config={
    },
)

In [19]:
def collate_fn(train_dataset):
    """Collate a list of examples into `(encoded_batch, urls)`.

    Despite its name, `train_dataset` is the list of examples that make up one
    batch. Each example's 'jpg' entry is converted to RGB and encoded by
    `processor` with resizing disabled; the 'url' entries are returned next to
    the encoded batch, in the same order.
    """
    pairs = [(example['jpg'].convert("RGB"), example['url']) for example in train_dataset]
    images = [image for image, _ in pairs]
    urls = [url for _, url in pairs]

    return processor(images=images, return_tensors="pt", do_resize=False), urls

In [20]:
# Worker processes decode and encode the images; the batches they return are
# pinned in host memory before the run loop moves them to the GPU.
dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [1]:
# Images covered by the steps the run loop skips: batch size 128 times the
# resume threshold of 10400 steps.
128*10400

1331200

In [21]:
t = time()

# Batches with step <= RESUME_AFTER_STEP are skipped (presumably already
# labelled by an earlier run — confirm). They are still loaded by the
# DataLoader, just not run through the model.
RESUME_AFTER_STEP = 10400

# Steps whose processing raised, so they can be inspected or re-run later.
failed_steps = []

# Inference only: no_grad avoids building an autograd graph for every forward
# pass, which would otherwise cost GPU memory for nothing.
with torch.no_grad():
    for step, batch in tqdm(enumerate(dataloader)):
        if step <= RESUME_AFTER_STEP:
            continue

        try:
            # The collate function returns (encoded batch, urls).
            batch_tensors, batch_urls = batch

            batch_tensors['pixel_values'] = batch_tensors['pixel_values'].to(device)
            batch_tensors['pixel_mask'] = batch_tensors['pixel_mask'].to(device)

            outputs = model(**batch_tensors)

            # `target_sizes` is built for a full batch; trim it so a shorter
            # final batch passes the length check in post-processing.
            results = post_process_object_detection(outputs,
                                                    target_sizes=target_sizes[:len(batch_urls)],
                                                    threshold=0.9)

            save_results_to_parquet(results=results,
                                    urls=batch_urls,
                                    model=model,
                                    step=step)
        #     wandb.log({"n_batch": step})
        except Exception as exc:
            # Catching `Exception` (not a bare except) lets Ctrl-C / kernel
            # interrupts stop the run. A failing batch is recorded and the run
            # continues instead of blocking in an interactive debugger.
            failed_steps.append(step)
            print(f"Problem with batch {step}: {exc!r}")
        #     wandb.alert(title=f"Batch Warning!",
        #                 text=f"Problem with batch {step}",
        #                 level=AlertLevel.WARN)

print(time() - t)
# wandb.alert(title=f"Run finished!",
#             text = f"Objects successfully detected in all {step+1} batches in {np.round((time() - t)/3600, 2)} hours !!! :)",
#                    level=AlertLevel.INFO)
# wandb.finish()

0it [00:00, ?it/s]

Exception in thread Thread-7 (_pin_memory_loop):
Traceback (most recent call last):
  File "/shared_drive/user-files/miniconda3/envs/controlnet-scalinglaws/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
    self.run()
  File "/shared_drive/user-files/miniconda3/envs/controlnet-scalinglaws/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 761, in run_closure
    _threading_Thread_run(self)
  File "/shared_drive/user-files/miniconda3/envs/controlnet-scalinglaws/lib/python3.11/threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
  File "/shared_drive/user-files/miniconda3/envs/controlnet-scalinglaws/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 53, in _pin_memory_loop
    do_one_step()
  File "/shared_drive/user-files/miniconda3/envs/controlnet-scalinglaws/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 30, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^

KeyboardInterrupt: 

In [39]:
step

171