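We precompute the image tokens (VQGAN codebook indices) and the T5 text embeddings once, before training the sequence model, so they do not have to be recomputed during training. Each image's 2-D grid of tokens is also flattened into a 1-D sequence.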

In [1]:
import sys
from pathlib import Path
sys.path.append("../external")
sys.path.append("..")

from tqdm import tqdm

In [2]:
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
import torch

In [3]:
from muse_maskgit_pytorch.t5 import t5_encode_text, DEFAULT_T5_NAME

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def encode_text(texts):
  """Return T5 embeddings for a batch of caption strings.

  Args:
    texts: the captions for one batch (as produced by the DataLoader).

  Returns:
    Whatever ``t5_encode_text`` returns for ``texts`` when using the default
    T5 model ``DEFAULT_T5_NAME``.
  """
  # The captions are the first argument; the model name comes second.
  # (Previously only the model name was passed, so the captions were ignored.)
  return t5_encode_text(texts, DEFAULT_T5_NAME)

In [5]:
from vae import VQGanVAE

In [6]:
from datasets import ImageTextNameDataset

In [7]:
def precompute(dataset, vae, t5_encode_fn, save_to, batch_size):
  """Encode every (image, caption) pair in ``dataset`` and save one ``.pt`` file per sample.

  For each batch, ``vae.encode`` is run on the images and the middle element of
  its 3-tuple result is taken as the image token indices. These are flattened
  from a 2-D grid to a 1-D sequence per image. ``t5_encode_fn`` is applied to
  the captions. Each sample is written as ``[image_tokens, text_embedding]`` to
  ``<save_to>/<file_name>.pt``.

  Args:
    dataset: yields ``(image, text, file_name)`` triples; presumably an
      ``ImageTextNameDataset`` (TODO confirm for other callers).
    vae: object whose ``encode(images)`` returns a 3-tuple whose middle item is
      a rank-3 ``(batch, h, w)`` tensor of token indices.
    t5_encode_fn: callable mapping a batch of captions to a tensor of
      embeddings whose first dimension is the batch.
    save_to: output directory; created (with parents) if it does not exist.
    batch_size: DataLoader batch size. The final batch may be smaller.
  """
  save_dir = Path(save_to)
  save_dir.mkdir(parents = True, exist_ok = True)
  dataloader = DataLoader(dataset, batch_size = batch_size)

  for images, texts, file_names in tqdm(dataloader):
    with torch.no_grad():
      _, indices, _ = vae.encode(images)
      text_embeds = t5_encode_fn(texts)
    # einops rejects a repeated axis name on the input side ("b d d"), so the
    # two spatial axes need distinct names before being merged.
    indices = rearrange(indices, "b h w -> b (h w)")

    # Iterate over the actual batch contents: the last batch can be shorter
    # than batch_size, which made range(batch_size) raise IndexError.
    for tokens, embed, name in zip(indices, text_embeds, file_names):
      # torch.save serializes a view's entire underlying storage, so each
      # per-sample slice is copied out (onto CPU) to avoid writing the whole
      # batch into every file.
      torch.save(
        [tokens.cpu().clone(), embed.cpu().clone()],
        save_dir / f"{name}.pt",
      )

In [8]:
# VQGAN image tokenizer. The hyperparameters below are assumed to match the
# checkpoint that is loaded into it afterwards (TODO confirm).
vae = VQGanVAE(
  dim = 128,
  codebook_size = 8192,
)

In [9]:
# Pretrained VAE weights, given as a path relative to this notebook.
VAE_CHECKPOINT = "../models/vae.199999-256x256.ema.pt"
vae.load(VAE_CHECKPOINT)

In [10]:
# CC3M image/caption pairs, at 256x256 resolution.
CC3M_ROOT = "../../cc3m/"
dataset = ImageTextNameDataset(CC3M_ROOT, image_size = 256)

KeyboardInterrupt: 

In [None]:
# Write one .pt file per sample into ./cc3m-precomputed.
precompute(
  dataset,
  vae,
  encode_text,
  "cc3m-precomputed",
  batch_size = 32,
)