In [7]:
import sys
sys.path.append("metrics")
sys.path.append("external")
sys.path.append("external/vqgan-taming")
from masked_model import MaskedModel, SequenceModelWrapper
from lowres_trainer import LowResTrainer

from vqgan import VQModel

from masked_transformer import Transformer

from muse_maskgit_pytorch import TransformerBlocks

import torch.nn as nn

import torch

import math

from masked_attn_mamba import MambaAttn
from vqganconfig import vqgan_config

from muse_maskgit_pytorch import MaskGit, MaskGitTransformer

from custom_datasets import ImageTextNameDataset

from torch.utils.data import Dataset, DataLoader, random_split

from tqdm import tqdm

from mambaCText import MambaCText
import torchvision

In [8]:

# Build the VQGAN from the shared config, then restore it from the
# pretrained checkpoint and move it onto the GPU.
VQGAN_CKPT_PATH = "models/pretrained-vqgan.ckpt"

vae = VQModel(**vqgan_config)
vae.init_from_ckpt(path = VQGAN_CKPT_PATH)

# Bind the return value to a throwaway name; presumably this keeps the
# cell from echoing the returned object as its output.
_ = vae.cuda()

Loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
Restored from  models/pretrained-vqgan.ckpt


In [9]:
# Sequence backbone that the MaskGit transformer wraps.
mamba_backbone = MambaAttn(token_size = 1024, depth = 12, d_state = 16)

transformer = MaskGitTransformer(
    actual_model = mamba_backbone,
    num_tokens = 16384,    # must be the same as the VQGAN codebook size
    seq_len = 16 * 16,     # must be equivalent to fmap_size ** 2 in vae
    dim = 1024,            # model dimension
)

# Full MaskGit model built around the transformer above.
base_maskgit = MaskGit(
    transformer = transformer,
    image_size = 256,
    cond_drop_prob = 0.25,    # conditional dropout, for classifier free guidance
    self_token_critic = True,
    no_mask_token_prob = 0.25,
)
base_maskgit = base_maskgit.cuda()

# Load the saved training checkpoint into the model.
base_maskgit.load("../results/mamba-attn-results/maskgit.1499999.pt")

In [10]:
# Validation data. Each batch from the loader unpacks as
# (images, texts, names) in the loops further down.
batch_size = 8
valid_ds = ImageTextNameDataset(
    folder = "../dataset/cc3m-valid-original",
    image_size = 256,
)
valid_dl = DataLoader(dataset = valid_ds, batch_size = batch_size)

In [11]:
# Display the number of samples in the validation dataset (sanity check
# before the long generation run below).
len(valid_ds)

13443

In [13]:
# Generate one image per validation caption with the MaskGit model, decode
# the generated token ids through the VQGAN, and save each result as
# "<name>.jpg" under ../valid-generated-mamba-attn/.
# The output directory must already exist.
with torch.no_grad():
    for _, texts, names in tqdm(valid_dl):
        # Size of this batch only. Kept in a local name so the notebook-level
        # `batch_size` used to build the DataLoader is not overwritten by the
        # (possibly smaller) final batch.
        num_items = len(texts)
        token_ids = base_maskgit.generate(texts = list(texts), fmap_size = 16, cond_scale = 4)
        images = vae.decode_from_ids(token_ids, batch_size = num_items, fmap_size = 16)
        for image, name in zip(images, names):
            # Hand save_image a path rather than a text-mode file object:
            # image data is binary, and no unclosed handle is left behind
            # for each of the thousands of files written.
            torchvision.utils.save_image(image, "../valid-generated-mamba-attn/" + name + ".jpg")

100%|██████████████████████| 1681/1681 [1:10:39<00:00,  2.52s/it]


In [6]:
# make a folder of resized 256x256 images

In [9]:
# Write every image yielded by the validation loader to
# ../dataset/cc3m-valid-256x256-images/<name>.jpg.
# The output directory must already exist.
for resized_images, texts, names in tqdm(valid_dl):
    # Iterate image/name pairs directly instead of rebinding the
    # notebook-level `batch_size` to the size of each batch.
    for image, name in zip(resized_images, names):
        # Hand save_image a path rather than a text-mode file object:
        # image data is binary, and no unclosed handle is left behind.
        torchvision.utils.save_image(image, "../dataset/cc3m-valid-256x256-images/" + name + ".jpg")

100%|█████████████████████████████████████████| 841/841 [01:52<00:00,  7.46it/s]


In [10]:
# make a folder of texts

In [12]:
# Write every caption yielded by the validation loader to
# ../dataset/cc3m-valid-256x256-texts/<name>.txt.
# The output directory must already exist.
for resized_images, texts, names in tqdm(valid_dl):
    # Iterate name/caption pairs directly instead of rebinding the
    # notebook-level `batch_size` to the size of each batch.
    for name, text in zip(names, texts):
        path = "../dataset/cc3m-valid-256x256-texts/" + name + ".txt"
        # Pin the encoding so captions with non-ASCII characters are written
        # the same way regardless of the machine's locale default.
        with open(path, "w", encoding = "utf-8") as f:
            f.write(text)

100%|█████████████████████████████████████████| 841/841 [02:15<00:00,  6.22it/s]
