In [None]:
!pip install torchvision
!pip install diffusers



In [6]:
import os

# Must be set BEFORE importing diffusers so that any environment check diffusers
# performs at import time can see it; assigning it after the import cannot
# influence anything already decided during that import.
# NOTE(review): diffusers' import utilities appear to read `USE_FLAX` rather than
# `DIFFUSERS_NO_FLAX` — confirm which variable (if any) the installed version honors.
os.environ["DIFFUSERS_NO_FLAX"] = "1"

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10
from diffusers import AutoencoderKL

In [7]:
# Configuration for downloading CIFAR-10 and writing, per image, the raw tensor
# and its VAE latent under `latent_root`.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 8
vae_image_size = 256  # CIFAR images (32x32) are upsampled to this size before VAE encoding
subset_num = 1000     # set to None to process the full training split

data_dir = "./cifar10_data"
latent_root = "./cifar10_reg_vae"

# Output layout: <latent_root>/raw_image/*.pt and <latent_root>/vae_latent/*.pt
for out_dir in (latent_root, f"{latent_root}/raw_image", f"{latent_root}/vae_latent"):
    os.makedirs(out_dir, exist_ok=True)

# CIFAR-10 training split; ToTensor yields float image tensors (scaled to [0, 1]).
dataset = CIFAR10(root=data_dir, train=True, download=True, transform=transforms.ToTensor())

# Optionally restrict to the first `subset_num` images, preserving their order.
dataset = dataset if subset_num is None else Subset(dataset, list(range(subset_num)))

# shuffle=False: the encoding loop names files by a running counter, so the
# loader order must follow the dataset order.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


# Load the `stabilityai/sd-vae-ft-mse` AutoencoderKL checkpoint onto `device`
# and switch it to eval mode (nn.Module.eval() returns the module itself).
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device).eval()

# Encode every image in `loader`: for each one, save the raw tensor and its VAE
# latent as separate .pt files named by a running index (dataset order, since the
# loader is not shuffled).
LATENT_SEED = 0  # seeds latent_dist.sample() so re-running reproduces the same latents
# Generator lives on `device` so the sampler can draw noise directly there.
latent_generator = torch.Generator(device=device).manual_seed(LATENT_SEED)

idx = 0
with torch.no_grad():
    for imgs, _labels in loader:
        imgs = imgs.to(device)

        # Upsample the whole batch to the VAE input resolution in one call.
        vae_imgs = F.interpolate(
            imgs,
            size=(vae_image_size, vae_image_size),
            mode="bilinear",
            align_corners=False,
        )
        # Map [0, 1] -> [-1, 1] before encoding.
        # NOTE(review): the original note says this scaling follows "the paper" and
        # may need adjusting — confirm against the reference implementation.
        vae_imgs = vae_imgs * 2 - 1

        # Latents are stored exactly as sampled; no scaling factor is applied here.
        # TODO confirm the downstream consumer expects unscaled latents.
        latents = vae.encode(vae_imgs).latent_dist.sample(generator=latent_generator).cpu()

        for img, latent in zip(imgs.cpu(), latents):
            # .clone(): each row of a batch tensor is a view, and torch.save
            # serializes the whole underlying storage — without the copy every
            # file would contain the entire batch.
            torch.save(img.clone(), f"{latent_root}/raw_image/{idx:06d}.pt")
            torch.save(latent.clone(), f"{latent_root}/vae_latent/{idx:06d}.pt")
            idx += 1
            if idx % 100 == 0:
                print(f"{idx} images processed")





100 images processed
200 images processed
300 images processed
400 images processed
500 images processed
600 images processed
700 images processed
800 images processed
900 images processed
1000 images processed


In [8]:
# Zip the latent output folder and, when running in Colab, download it locally.
import os
import shutil

# Derive the folder name from `latent_root` instead of repeating the literal, so
# the archive always matches where the encoding cell wrote its files.
folder_to_download = os.path.basename(os.path.normpath(latent_root))
zip_file_name = f"{folder_to_download}.zip"

# Archive the contents of `latent_root`; the .zip is created in the current
# working directory as `zip_file_name`.
shutil.make_archive(folder_to_download, "zip", latent_root)

# google.colab only exists inside Colab; elsewhere just leave the archive on disk.
try:
    from google.colab import files
except ImportError:
    print(f"google.colab is unavailable; '{zip_file_name}' was written to {os.getcwd()}.")
else:
    files.download(zip_file_name)
    print(f"The folder '{folder_to_download}' has been zipped as '{zip_file_name}' and is ready for download.")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

The folder 'cifar10_reg_vae' has been zipped as 'cifar10_reg_vae.zip' and is ready for download.


In [None]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-xghpvsk0
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-xghpvsk0
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... done
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.8/44.8 kB 2.8 MB/s eta 0:00:00
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... done
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=46e28609088c83afe2137a3f4572c37555c20ae9c485b02ead3762f6bf500a07
  Stored in directory: /tmp/pip-ephem-wheel-cache-nm8_f26i/wheels/35/3e/df/3d24cbfb3b6a06f17

In [None]:
import torch
import clip

# Loading the model once fetches the pre-trained CLIP weights into CLIP's
# download cache (listed in the next cell); the model is kept on the CPU.
CLIP_ARCH = "ViT-B/32"
model, preprocess = clip.load(CLIP_ARCH, device=torch.device("cpu"))

100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 65.8MiB/s]


In [None]:
# Confirm the CLIP checkpoint landed in CLIP's download cache.
import os

# Expand "~" rather than hardcoding /root: the home directory is /root on Colab
# but not in general.
clip_cache_dir = os.path.expanduser("~/.cache/clip")
os.listdir(clip_cache_dir)

['ViT-B-32.pt']

In [None]:
# Download the cached CLIP ViT-B/32 checkpoint to the local machine (Colab only).
import os
from google.colab import files

# Expand "~" rather than hardcoding /root so the path follows the actual home dir.
clip_checkpoint_path = os.path.join(os.path.expanduser("~/.cache/clip"), "ViT-B-32.pt")
files.download(clip_checkpoint_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>