In [1]:
import torch
from model import VQModel
from model.vqgan.image_tokenizer import ImageTokenizer

# Let PyTorch use faster, reduced-precision internal formats (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')

model = VQModel()
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only point this at a checkpoint from a trusted source.
sd = torch.load('../weights/vqgan.ckpt', map_location='cpu')['state_dict']

# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of discarding the result, so a wrong or partial checkpoint is noticed
# here rather than showing up later as garbage tokens.
load_result = model.load_state_dict(sd, strict=False)
missing_keys = list(getattr(load_result, 'missing_keys', []))
unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
if missing_keys or unexpected_keys:
    print(f"load_state_dict: {len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys")
    if missing_keys:
        print(f"  missing (first 5): {missing_keys[:5]}")
    if unexpected_keys:
        print(f"  unexpected (first 5): {unexpected_keys[:5]}")

# Cast the weights to bfloat16 and move the model to the GPU before handing it to the tokenizer.
model.bfloat16().cuda()

tokenizer = ImageTokenizer(model)
# Release cached GPU blocks that are no longer in use after setup.
torch.cuda.empty_cache()

Working with z of shape (1, 256, 32, 32) = 262144 dimensions.


In [2]:
import datasets
from torch.utils.data import DataLoader

# Open the train split in streaming mode so it is iterated lazily rather than
# loaded up front.
# NOTE(review): trust_remote_code=True allows the dataset's own loading script to
# run -- only use with a dataset source you trust.
dataset = datasets.load_dataset(
    '/home/andrew264/datasets/imagenet-1k',
    split='train',
    streaming=True,
    trust_remote_code=True,
)

def transform(examples):
    """
    Batched map function: tokenizes every image in `examples['image']` with one
    tokenizer call and returns the result wrapped in a one-element list, so the
    whole input batch becomes a single output row under the 'image' key.
    """
    batch_tokens = tokenizer.img_tokens_from_pil(examples['image'])
    return {'image': [batch_tokens]}

# Tokenize lazily in groups of 4 examples; the 'label' column is dropped because
# only the image tokens are kept.
dataset = dataset.map(
    transform,
    batched=True,
    batch_size=4,
    remove_columns=['label'],
)
# num_workers=0 keeps iteration (and therefore the tokenizer calls) in the main process.
dataloader = DataLoader(
    dataset,
    batch_size=25,
    num_workers=0,
)

In [3]:
import os
import numpy as np


def write_datafile(f_name: str, toks: np.ndarray, out_dir: str = '../data/imagenet/'):
    """
    Saves token data as a .bin file

    The array's raw bytes are written as-is (no header), so the reader must know
    the dtype used by the writer.

    Args:
        f_name: file name of the shard, relative to `out_dir`.
        toks: token array to dump.
        out_dir: destination directory; created if it does not exist yet.
    """
    # Create the output directory up front so the first write on a fresh checkout
    # does not fail with FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    print(f"writing {len(toks):,} tokens to {f_name}")
    with open(os.path.join(out_dir, f_name), 'wb') as f:
        # tofile streams the buffer (in C order) straight to disk, avoiding the
        # full in-memory copy that tobytes() would create for a large shard.
        toks.tofile(f)

In [None]:
import tqdm

images_per_shard = int(1e5)
per_image_size = 1024
shard_capacity = per_image_size * images_per_shard  # tokens per shard file

# One reusable staging buffer; `index` is the number of valid tokens currently in it.
all_np_tokens = np.empty(shard_capacity, dtype=np.uint16)
index = 0
# Running shard number. Deriving it from `index // shard_capacity` only ever
# yields 0 or 1 (index never exceeds the capacity), which made every full shard
# overwrite the previous one under the same file name.
shard_index = 0

for batch in tqdm.tqdm(dataloader):
    # Flatten the whole batch into a 1-D token stream.
    img_tokens = batch['image'].view(-1)
    # NOTE(review): the uint16 cast assumes every token id is < 65536 -- confirm
    # against the tokenizer's codebook size.
    img_tokens = img_tokens.cpu().numpy().astype(np.uint16)
    batch_size = len(img_tokens)

    # Flush before the buffer would overflow; the current batch then starts the next shard.
    if index + batch_size > len(all_np_tokens):
        write_datafile(f"tokens_shard_{shard_index}.bin", all_np_tokens[:index])
        shard_index += 1
        # The write above is synchronous, so the buffer can be reused without reallocating.
        index = 0

    all_np_tokens[index:index+batch_size] = img_tokens
    index += batch_size

# Write whatever is left in the buffer as the final (possibly partial) shard.
if index > 0:
    write_datafile(f"tokens_shard_{shard_index}.bin", all_np_tokens[:index])

print("Token writing process completed.")

0it [00:00, ?it/s]

In [None]:
# Read a shard back from the directory write_datafile() writes to
# ('../data/imagenet/') and view it as one row of 1024 uint16 tokens per image.
with open('../data/imagenet/tokens_shard_1.bin', 'rb') as f:
    data = f.read()
    # dtype and row width must match what the writer cell used (uint16, 1024 tokens per image).
    tokens = np.frombuffer(data, dtype=np.uint16).reshape(-1, 1024)
    print(tokens.shape)

In [None]:
# Show one stored token row (the 100th from the end of the shard) as a tensor.
preview_row = tokens[-100]
torch.tensor(preview_row)

In [None]:
# Round-trip check: feed one stored token row back through the tokenizer's decoder.
# A leading batch dimension is added and the ids are cast to int32 on the GPU;
# presumably the call returns one image per batch entry, hence the [0] -- confirm
# against ImageTokenizer.
preview_toks = torch.tensor(tokens[-100]).unsqueeze(0)
preview_toks = preview_toks.int().cuda()
tokenizer.pil_from_img_toks(preview_toks)[0]