In [None]:
# Enable IPython's autoreload extension for this kernel.
%load_ext autoreload

# Mode 2: reload all modules before each cell execution, so edits to imported
# modules are picked up without restarting the kernel.
%autoreload 2

In [None]:
import os

import pandas as pd

from varying_aspect_ratio_dataset import create_df_from_parquets, assign_to_buckets

# Upper bound on the number of source parquet files to read; it is also the
# only parameter encoded in the cache file name below.
max_files = 101

cache = f"laion_aesthetics_{max_files}.parquet"

# Directory holding the source parquet files.
# Alternative location used previously: "/hdd/data/finetune_SD/laion_aesthetics"
path = "../../../finetune_SD/laion_aesthetics_2/"

# NOTE(review): the cache name depends on max_files only. Changing the size or
# bucket parameters below will silently reuse a cache built with the old
# values; delete the cache file by hand after changing them.
if not os.path.exists(cache):
    # Cache miss: build the dataframe, assign buckets, then persist the result.
    # Previously used (larger) configuration, kept for reference:
    #   create_df_from_parquets(path, min_width=128, min_height=128, max_files=max_files)
    #   assign_to_buckets(df, bucket_step_size=64, max_width=768, max_height=768, min_bucket_count=64)
    df = assign_to_buckets(
        create_df_from_parquets(path, min_width=64, min_height=64, max_files=max_files),
        bucket_step_size=64,
        max_width=128,
        max_height=128,
        min_bucket_count=64,
    )
    df.to_parquet(cache)
else:
    # Cache hit: reuse the previously bucketed dataframe.
    df = pd.read_parquet(cache)

In [None]:
import numpy as np

# Mean absolute difference between each row's own width/height ratio and the
# width/height ratio of the bucket it was assigned to.
image_ratio = df.width / df.height
bucket_ratio = df.bucket_width / df.bucket_height
np.abs(image_ratio - bucket_ratio).mean()
# observed: average ratio diff 0.14

In [None]:
# Product of the mean width excess and the mean height excess over the bucket.
mean_width_excess = (df.width - df.bucket_width).mean()
mean_height_excess = (df.height - df.bucket_height).mean()
mean_width_excess * mean_height_excess
# on average around 27 pixels lost when cropping
# NOTE(review): a product of two means is not the mean of the per-image
# product, nor the per-image area difference (width*height minus
# bucket_width*bucket_height) -- confirm which quantity is intended.

In [None]:
from transformers import CLIPTokenizer

# Local checkpoint directory. Alternative value used previously:
#   model_path = "runwayml/stable-diffusion-v1-5"
model_path = "../../../mus2vid/models/stable-diffusion-v1-5"

# Only the CLIP tokenizer is loaded in this cell, from the checkpoint's
# "tokenizer" subfolder; no other model component is created here.
tokenizer = CLIPTokenizer.from_pretrained(
    model_path,
    subfolder="tokenizer",
    device="cpu",
)

In [None]:
from varying_aspect_ratio_dataset import BucketBatchSampler, BucketDataset, BucketSampler

from torch.utils.data import DataLoader

# Switch between the two sampling strategies under test.
use_batch_sampler = True
batch_size = 2

# DataLoader options that are the same for both strategies.
loader_options = dict(
    shuffle=False,
    num_workers=16,
    drop_last=False,
    pin_memory=True,
)

if use_batch_sampler:
    # The batch sampler is given the batch size itself, so the DataLoader is
    # left at batch_size=1 and receives the sampler via `batch_sampler`.
    bucket_batch_sampler = BucketBatchSampler(df["bucket"], batch_size=batch_size)
    loader_options.update(batch_size=1, batch_sampler=bucket_batch_sampler)
else:
    # Plain sampler: here the DataLoader is given the real batch size.
    # The variable keeps the name `bucket_batch_sampler` because later cells
    # iterate it under that name.
    bucket_batch_sampler = BucketSampler(df["bucket"], batch_size=batch_size)
    loader_options.update(batch_size=batch_size, sampler=bucket_batch_sampler)

bucket_dataset = BucketDataset(df, tokenizer)

dataloader = DataLoader(bucket_dataset, **loader_options)
    

In [None]:
import random
# Time a single dataset lookup at a random index.
# random.randint is inclusive on both ends, so indices 0..1000 are drawn; this
# assumes the dataset holds at least 1001 items -- TODO confirm.
%timeit bucket_dataset[random.randint(0, 1000)]
# observed: 5ms per img

In [None]:
from tqdm import tqdm

def run(dl, steps, verbose=False):
    """Consume at most ``steps`` batches from ``dl`` and send each one to the GPU.

    Used as the body of the dataloader throughput benchmark.

    Args:
        dl: Iterable of batches. Each batch is indexed with ``"pixel_values"``
            and the resulting object must provide ``.cuda(non_blocking=...)``.
        steps: Maximum number of batches to fetch and transfer. ``0`` fetches
            nothing.
        verbose: Show a tqdm progress bar when True.

    Returns:
        None. The transferred tensors are discarded.
    """
    from itertools import islice

    # islice stops after exactly `steps` items without pulling an extra batch.
    # The previous `if i == steps: break` check ran after the transfer and
    # therefore processed steps + 1 batches.
    for batch in tqdm(islice(dl, steps), total=steps, disable=not verbose):
        # non_blocking=True only requests an asynchronous host-to-device copy;
        # the call may return before the copy has finished.
        batch["pixel_values"].cuda(non_blocking=True)

In [None]:
# Distinct bucket widths present in the bucketed dataframe.
df["bucket_width"].unique()

In [None]:
# Fetch one batch from a fresh iterator over the dataloader.
out = next(iter(dataloader))

In [None]:
# Shape of the "pixel_values" entry of the batch fetched above.
out["pixel_values"].shape

In [None]:
# Benchmark: time `run` on the dataloader with steps=1000.
%timeit run(dataloader, 1000)

In [None]:
# Repeat of the same benchmark (steps=1000) for a second measurement.
%timeit run(dataloader, 1000)

In [None]:
# with uncompressing (100 steps)
# 1 worker - 3.32 s, 30.44 it/s
# 2 worker - 2.34 s, 43.66 it/s
# 4 worker - 1.8 s, 58.11 it/s
# 8 worker - 1.3 s, 86.4 it/s
# 16 worker - 1.4 s, 88.4 it/s


# 1000 steps
# 8 - 12.9 - 78.8 it/s
# 16 - 10.5 - 93 it/s
# 16 - 8.7 - 118 it/s


# 1000 steps - img to cuda
# 16 - 9.97, 8.5, 7.4, 7.3, 7.3 - mean=8.1

# 1000 steps - img to cuda - pin memory
# 16 - 8.2, 9.5, 8.0, 9.2, 7.8 - mean=8.5
# .cuda(non_blocking=True) - 7.0, 7.4, 7.8, 7.0, 6.6, 6.6 = 6.8

# best setting for batch size 1: 16 worker, pin_memory=True, .cuda(non_blocking=True)
# with best setting using %timeit
# bs 1: 4.23 s - 238 imgs/s == 4.2 ms
# bs 2: 9.52 s
# bs 4: 15.7 s
# bs 8: 27.9 s - 286 images per second == 3.5 ms

In [None]:
# when loading from .tar files... :

# 0 worker total: 0.14, 91.0, 4.5 s/it
# 1 worker - 85.7 for its, 4.23 s/it
# 4 worker - 354 for completion, 17.7 s/it


# new sampler
# 0 worker - 78.8, 3.9 s/it   # bs==2: 149, 7.45 s/it
# 1 worker - 82.4, 4.1 s/it   # bs==2: 159, 7.9 s/it
# 4 worker - 256 for completion, 12.8 s/it

In [None]:
# Iterate the sampler directly (bypassing the DataLoader) to inspect what it yields.
iterator = iter(bucket_batch_sampler)

In [None]:
# Shape of the next item from the sampler.
# Assumes the sampler yields objects that expose `.shape` -- TODO confirm.
next(iterator).shape

In [None]:
# First dataset item. Despite the name, later cells index it by key
# ("pixel_values"), so it is a mapping rather than a bare image.
img = bucket_dataset[0]

In [None]:
# Shape of the "pixel_values" entry of that single dataset item.
img["pixel_values"].shape

In [None]:
# Display the whole dataset item.
img

In [None]:
# Rebind `iterator` to a fresh iterator over the dataloader
# (it previously pointed at the sampler iterator).
iterator = iter(dataloader)

In [None]:
# Pull the next batch from the dataloader iterator.
batch = next(iterator)

In [None]:
# Shape of the "input_ids" entry of the batch.
batch["input_ids"].shape

In [None]:
# Keep the batch's "pixel_values" entry for visual inspection.
imgs = batch["pixel_values"]

In [None]:
import torchvision

# Render the batch entry at index 1 as a PIL image; this needs at least two
# items in the batch.
# NOTE(review): if pixel_values are normalised outside [0, 1], the preview
# colours may be off -- confirm against the dataset's transform.
to_pil = torchvision.transforms.ToPILImage()
to_pil(imgs[1])