In [76]:
import sys
from pathlib import Path
from shutil import rmtree

import json
from tqdm import tqdm, trange

import torch

In [73]:
sys.path.append("..")
sys.path.append("../muse_rewrite")

In [74]:
from vae import VQGanVAE
from masked_transformer import MaskGitTransformer, MaskGit, TokenCritic
import matplotlib.pyplot as plt
import torchvision.transforms as T

from torch.utils.data import Dataset, DataLoader, random_split

In [75]:
vae = VQGanVAE(
    dim = 128,
    codebook_size = 8192
)

# this checkpoint was trained with image size upto 256
vae.load("../muse_rewrite/results-backup/vae.300000.ema.pt", is_ema=True)

In [13]:
transformer = MaskGitTransformer(
    num_tokens = 8192,       # must be same as codebook size above
    seq_len = 16*16,            # must be equivalent to fmap_size ** 2 in vae
    dim = 2048,                # model dimension
    depth = 24,                # depth
    dim_head = 64,            # attention head dimension
    heads = 8,                # attention heads,
    ff_mult = 2,              # feedforward expansion factor
)

base_maskgit = MaskGit(
    vae = vae,                 # vqgan vae
    transformer = transformer, # transformer
    image_size = 256,          # image size
    cond_drop_prob = 0.25,     # conditional dropout, for classifier free guidance
    self_token_critic = True,
    no_mask_token_prob = 0.25,
).cuda()

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


In [15]:
base_maskgit.load("../muse_rewrite/results-maskgit/maskgit.299999.pt")

In [68]:
def yes_or_no(question):
  answer = input(f'{question} (y/n)')
  return answer.lower() in ['yes', 'y']

In [77]:
class ImageTextDataset(Dataset):
  def __init__(
      self,
      folder,
      token_folder,
      annotations_path,
      image_size,
      tokenizer = None,
      exts = ['jpg', 'jpeg', 'png']
  ):
    super().__init__()
    self.image_size = image_size
    self.tokenizer = tokenizer
    self.token_folder = Path(token_folder)

    if yes_or_no("Do you want to clear token folder and recompute tokens? (yes/no)"):
      rmtree(str(self.token_folder))
      self.token_folder.mkdir(parents = True, exist_ok = True)
      self.should_compute_tokens = True
    else:
      self.should_compute_tokens = False
      
    
    self.transform = T.Compose([
      T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
      T.Resize((image_size, image_size)),
      # image might not fit the text after flip and crop
      # T.RandomHorizontalFlip(),
      # T.CenterCrop(image_size),
      T.ToTensor()
    ])

    image_paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

    image_annotations = json.load(open(annotations_path))["annotations"]

    image_annotations_keyed = dict()
    for annotation in image_annotations:
      image_id = annotation["image_id"]
      caption = annotation["caption"]
      if image_id not in image_annotations_keyed:
        image_annotations_keyed[image_id] = []
      image_annotations_keyed[image_id].append(caption)

    self.data = []
    for path in tqdm(image_paths):
      image_id = path.name.split(".")[0]
      encoded_path = f'{token_folder}/{image_id}.pt'
      for text in image_annotations_keyed[int(image_id)]:
        image_data = {
          "path": path,
          "encoded_path": encoded_path,
          "texts": text,
        }
        self.data.append(image_data)

      if self.should_compute_tokens:
          with torch.no_grad():
              _, indices, _ = self.tokenizer.encode(self.transform(Image.open(path)).unsqueeze(0).cuda())
              torch.save(indices[0], encoded_path)

    print(f'Found {len(self.data)} training samples at {folder}')

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    path = self.data[index]["path"]
    texts = self.data[index]["texts"]
    encoded_img = torch.load(self.data[index]["encoded_path"])
    img = Image.open(path)
    return encoded_img, self.transform(img), texts

In [78]:
validation_folder = "../COCO-Captions/val2017"
token_folder = "../"

In [80]:
ds = ImageTextDataset(validation_folder, 256)

TypeError: __init__() missing 2 required positional arguments: 'annotations_path' and 'image_size'