This notebook helps prepare a dataset for training a flow matching model.

We convert images to VAE latents, compute CLIP embeddings, and save both locally for training.

In [3]:
# Root folder of the image dataset. It is walked recursively and the name of
# each directory that contains files is used as the class name.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
images_folder='/home/vlad/Documents/TreesDataset'

# Output root: one sub-folder per class is created here and filled with
# "<index>.pt" files holding the image path, VAE latents and CLIP embedding.
save_latents_dir = './latents_tree'

# Size passed to the resize transform applied before VAE encoding.
image_size=256

# When True, a second entry per image is saved as "<index>_flip.pt", produced
# by running the preprocessed image through the flip transform.
add_flip = True

# When True, the preprocessed image is converted to grayscale *before* it is
# encoded by the VAE (the CLIP embedding is computed from the file path and is
# not affected by this flag).
vae_latents_grayscale = False

In [4]:
import os
import torchvision.transforms as T
from kemsekov_torch.utils import PadToMultiple
import tqdm
from vae import VaeEncoder
from clip_emb import *

# Device on which both encoders are created.
device='cuda'
clip = CLIPEmbedder(device=device)
vae = VaeEncoder(device=device)

# Deterministic base preprocessing: PIL image -> float tensor -> resize.
# With an int size, torchvision resizes the smaller edge to `image_size` and
# keeps the aspect ratio.
# No random augmentation belongs here: the un-flipped sample must be exactly
# the source image, otherwise the "<index>.pt" / "<index>_flip.pt" pair is not
# a reliable original/mirror pair.
# NOTE(review): NEAREST interpolation aliases when downscaling; consider
# BILINEAR/BICUBIC with antialias=True if image quality matters.
tr = T.Compose([
    T.ToTensor(),
    T.Resize(image_size,T.InterpolationMode.NEAREST),
])

# p=1.0 makes the horizontal flip unconditional, so the "_flip" sample is
# always the mirror image (the default p=0.5 would flip only half the time).
flip = T.RandomHorizontalFlip(p=1.0)

# Optional grayscale conversion applied to the tensor before VAE encoding;
# identity when the option is off.
# NOTE(review): T.Grayscale() defaults to a single output channel — confirm
# that VaeEncoder accepts 1-channel input, or use num_output_channels=3.
grayscale = T.Grayscale() if vae_latents_grayscale else lambda x:x

# Presumably pads spatial dims up to a multiple of 64 for the VAE — confirm
# against kemsekov_torch.utils.PadToMultiple.
pad = PadToMultiple(64)

# File extensions (lower-case, with leading dot) treated as images.
image_formats=['.png','.jpg','.jpeg','.webp']
def save(clip, device, pad, save_folder, i, f, im):
    """Encode one image and write its training record to disk.

    The record is a dict with the keys ``im_path``, ``vae_latents`` and
    ``clip_emb`` and is written to ``<save_folder>/<i>.pt`` via ``torch.save``.

    Args:
        clip: embedder exposing ``image_to_embedding(path)``.
        device: unused here; kept so existing call sites keep working.
        pad: callable applied to ``im`` before VAE encoding.
        save_folder: directory that receives the ``.pt`` file.
        i: sample identifier used as the file stem (int or str).
        f: path of the source image; stored in the record and passed to CLIP.
        im: preprocessed image handed to ``pad`` and then the VAE.

    Note:
        The VAE is the module-level ``vae`` object, not a parameter.
        The CLIP embedding is computed from the file at ``f``, not from ``im``.
    """
    # Both tensors are stored in half precision.
    record = {
        "im_path": f,
        'vae_latents': vae.encode(pad(im)).half(),
        'clip_emb': clip.image_to_embedding(f).half(),
    }
    target = os.path.join(save_folder, str(i) + ".pt")
    torch.save(record, target)

# Walk the dataset tree. Every directory that directly contains at least one
# image becomes a class (named after the directory) with its own output folder.
image_exts = tuple(image_formats)
for dir_path,subf,files in os.walk(images_folder):
    # str.endswith accepts a tuple of suffixes. The previous filter tested the
    # truthiness of a non-empty list, which is always True, so non-image files
    # were never excluded. Matching is case-insensitive (".JPG" is accepted).
    # Sorting makes the index -> file mapping reproducible across runs.
    files = sorted(
        os.path.join(dir_path,f)
        for f in files
        if f.lower().endswith(image_exts)
    )
    if len(files)==0: continue

    # normpath strips a trailing separator so the class name is never empty.
    class_ = os.path.basename(os.path.normpath(dir_path))
    save_folder = os.path.join(save_latents_dir,class_)
    os.makedirs(save_folder,exist_ok=True)

    data = tqdm.tqdm(enumerate(files),total=len(files),desc=class_)
    for i,f in data:
        # Close the file handle as soon as the pixels are converted.
        with Image.open(f) as img:
            im = tr(img.convert("RGB"))
        im = grayscale(im)
        save(clip, device, pad, save_folder, i, f, im)
        if add_flip:
            # Second record for the same file, passed through the
            # module-level `flip` transform and stored as "<i>_flip.pt".
            save(clip, device, pad, save_folder, str(i)+"_flip", f, flip(im))

  from .autonotebook import tqdm as notebook_tqdm
Rotten_birch: 100%|██████████| 10/10 [00:01<00:00,  7.91it/s]
Spruce: 100%|██████████| 38/38 [00:03<00:00, 11.95it/s]
Mystical: 100%|██████████| 40/40 [00:05<00:00,  7.62it/s]
Ash: 100%|██████████| 8/8 [00:00<00:00, 12.60it/s]
unsorted: 100%|██████████| 129/129 [00:12<00:00, 10.13it/s]
Dead_birch: 100%|██████████| 29/29 [00:02<00:00, 11.27it/s]
Birch: 100%|██████████| 114/114 [00:10<00:00, 11.05it/s]
Dead_pine: 100%|██████████| 5/5 [00:00<00:00,  8.82it/s]
Willow: 100%|██████████| 2/2 [00:00<00:00, 18.06it/s]
Normal: 100%|██████████| 73/73 [00:07<00:00,  9.44it/s]
Aspen: 100%|██████████| 37/37 [00:04<00:00,  8.78it/s]
pine: 100%|██████████| 80/80 [00:07<00:00, 11.13it/s]
Oak: 100%|██████████| 9/9 [00:01<00:00,  4.96it/s]
Rotten_spruce: 100%|██████████| 7/7 [00:00<00:00,  7.90it/s]
Dead_spruce: 100%|██████████| 16/16 [00:02<00:00,  7.86it/s]
Dead_aspen: 100%|██████████| 29/29 [00:03<00:00,  8.56it/s]
Fir: 100%|██████████| 47/47 [00:05<00