In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from taming.models.vqgan import VQModel
from omegaconf import OmegaConf

# Seed PyTorch's global RNG before anything random happens. The crop/flip
# augmentations below are stochastic, and the DataLoader derives its worker
# seeds from the main-process RNG, so without this the sampled batch (and every
# statistic computed from it further down) changes on each Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Augmentation pipeline: random 256x256 crop covering 20-100% of the image,
# random horizontal flip, then conversion to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(256, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# One sub-directory of "data/train" per class; the labels are not used below.
dataset = datasets.ImageFolder("data/train", transform=transform)

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
)

# Pull a single batch for inspection and move it to the GPU.
# NOTE: `.cuda()` requires a CUDA-capable device.
sample, _ = next(iter(data_loader))
sample = sample.cuda()

# Show the batch shape as the cell output (expected: [8, 3, 256, 256]).
sample.shape

torch.Size([8, 3, 256, 256])


In [11]:
# Build the VQGAN tokenizer from the YAML config and restore its weights.
config = OmegaConf.load("config/vqgan.yaml").model
vqgan_params = config.params
vqgan = VQModel(
    ddconfig=vqgan_params.ddconfig,
    n_embed=vqgan_params.n_embed,
    embed_dim=vqgan_params.embed_dim,
    ckpt_path="pretrained_models/vae/vqgan_jax_strongaug.pt",
)

# The tokenizer is used for inference only: freeze every parameter, switch to
# eval mode, and move it to the GPU.
vqgan.requires_grad_(False)
vqgan.eval().cuda()

# encode() yields three values; the middle one is discarded here.
with torch.no_grad():
    z_q, _, token_tuple = vqgan.encode(sample)

# The last entry of the info tuple is kept as the token indices
# (presumably the codebook index chosen per spatial position -- confirm
# against VQModel.encode).
_, _, token_indices = token_tuple

# Flatten the spatial grid: [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C].
B, C, H, W = z_q.shape
z_q = z_q.flatten(start_dim=2).transpose(1, 2)

# One row of tokens per image, plus an independent int64 copy to use as targets.
token_indices = token_indices.reshape(B, -1)
gt_indices = token_indices.detach().clone().to(torch.long)

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
Strict load
Restored from pretrained_models/vae/vqgan_jax_strongaug.pt


In [None]:
# Sanity-check the tensor layouts, one shape per line:
#   sample        -> [B, C, H, W]
#   z_q           -> [B, N, D]
#   token_indices -> [B, N]
print(sample.shape, z_q.shape, token_indices.shape, sep="\n")

torch.Size([8, 3, 256, 256])
torch.Size([8, 256, 256])
torch.Size([8, 256])


In [13]:
# Standard deviation taken over every element of the flattened latents.
print(torch.std(z_q))

tensor(1.0460, device='cuda:0')
