In [None]:
# !git clone https://github.com/Kemsekov/kemsekov_torch
# !cd kemsekov_torch && git pull

In [None]:
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from kemsekov_torch.train import split_dataset

import os

# Folder with the images used to train the VQ-VAE (ImageFolder layout: one sub-folder per class).
images_folder = 'train_images_folder'

TRAIN_IMAGE_SIZE = (128, 128)
# The shorter image side is resized to RESIZE_MARGIN * (largest crop side), so RandomCrop always
# fits (also for non-square crop sizes) and still has room to pick a random window.
RESIZE_MARGIN = 1.12
TEST_SIZE = 0.05
BATCH_SIZE = 16
# don't ask for more DataLoader workers than the machine has CPUs
NUM_WORKERS = min(16, os.cpu_count() or 1)

# Interpolation used by Resize. Bilinear is torchvision's default and is what this pipeline
# effectively used before: the NEAREST mode previously declared here was never passed to Resize.
interpolation = T.InterpolationMode.BILINEAR

# simple image augmentations
# No channel-trimming Lambda: ImageFolder's default loader already converts every image to RGB,
# and a lambda inside the transform cannot be pickled for DataLoader workers started with "spawn".
tr = T.Compose([
    T.ToTensor(),
    T.Resize(int(RESIZE_MARGIN * max(TRAIN_IMAGE_SIZE)), interpolation=interpolation, antialias=True),
    T.RandomCrop(TRAIN_IMAGE_SIZE),
    T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
])
dataset = ImageFolder(images_folder, transform=tr)

# Split the dataset into train and eval parts.
# NOTE(review): both splits presumably come from this single `dataset`, so the eval split also gets
# random augmentations, which makes eval metrics noisy — TODO confirm how split_dataset splits.
train_dataset, test_dataset, train_loader, test_loader = split_dataset(
    dataset,
    test_size=TEST_SIZE,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
)
len(train_dataset), len(test_dataset)

In [None]:
import matplotlib.pyplot as plt
from random import randint

# Preview a grid of random training samples (augmentations included), each titled with its class.
PREVIEW_ROWS, PREVIEW_COLS = 4, 4

# squeeze=False keeps `axes` 2-D even for a 1x1 grid, so `axes.flat` always works
fig, axes = plt.subplots(PREVIEW_ROWS, PREVIEW_COLS, figsize=(5, 5), squeeze=False)
to_pil = T.ToPILImage()

for ax in axes.flat:
    index = randint(0, len(dataset) - 1)       # random index into the dataset
    image, label = dataset[index][:2]          # (image tensor, class index)
    ax.imshow(to_pil(image))
    ax.set_title(dataset.classes[label], fontsize=6)
    ax.axis("off")                             # hide axes for a clean view

fig.suptitle("Random augmented training samples")
plt.tight_layout()
plt.show()

In [None]:
import torch
from sklearn.metrics import r2_score

def vqvae2_loss(x,x_rec,z,z_q,beta=0.25):
    """
    Computes the training loss for VQ-VAE-2 outputs.

        loss = MSE(x, x_rec) + beta * mean_over_levels( MSE(z_l, stopgrad(z_q_l)) )

    x: original input
    x_rec: reconstruction of x (same shape as x)
    z: list of encoder embeddings, one per hierarchy level
    z_q: list of quantized embeddings, aligned element-wise with `z`
    beta: weight of the commitment term relative to the reconstruction term

    Returns a scalar tensor. With empty `z`/`z_q` only the reconstruction loss is returned.
    Raises ValueError if `z` and `z_q` have different lengths.

    NOTE(review): there is no codebook (embedding) loss term here; presumably the quantizer
    updates its codebook by itself (e.g. EMA) — TODO confirm.
    """
    if len(z) != len(z_q):
        raise ValueError(f"vqvae2_loss: got {len(z)} latents but {len(z_q)} quantized latents")

    mse = lambda a, b: ((a - b) ** 2).mean()

    # general mse reconstruction loss
    # TODO: add perceptual loss
    rec_loss = mse(x, x_rec)

    if len(z) == 0:
        return rec_loss

    # commitment loss: pull encoder outputs toward their (frozen) quantized values,
    # averaged exactly once over hierarchy levels
    commitment_loss = sum(mse(z_, z_q_.detach()) for z_, z_q_ in zip(z, z_q)) / len(z)

    return rec_loss + beta * commitment_loss

# r2 score that will be used in metrics
def r2(x,y):
    """
    R^2 (coefficient of determination) score used in training metrics.

    x: ground-truth values (tensor), or a list/tuple of tensors
    y: predictions with the same number of elements as `x`, or a list/tuple aligned with `x`

    Tensors are detached, moved to CPU and flattened, so the score is computed over all
    elements at once. For lists/tuples the per-pair scores are averaged.

    Raises ValueError if list inputs differ in length or are empty.
    """
    if isinstance(x, (list, tuple)):
        if len(x) != len(y):
            raise ValueError(f"r2: got {len(x)} targets but {len(y)} predictions")
        if len(x) == 0:
            raise ValueError("r2: cannot average over empty lists")
        return sum(r2(a, b) for a, b in zip(x, y)) / len(x)

    def _to_cpu_flat(t):
        t = t.detach().cpu().flatten()
        # numpy (used by sklearn) has no bfloat16 dtype, so reduced-precision tensors produced
        # under mixed precision are upcast to float32 before scoring
        if t.dtype in (torch.bfloat16, torch.float16):
            t = t.float()
        return t

    return r2_score(_to_cpu_flat(x), _to_cpu_flat(y))

In [None]:
from kemsekov_torch.vqvae.vqvae2 import VQVAE2Scale3
import torch

# one codebook size per hierarchy level of the 3-level model
codebook_size = [512, 512, 512]

# 3 - levels hierarchical vqvae
vqvae = VQVAE2Scale3(
    in_channels=3,
    latent_dim=16,
    num_residual_layers=5,
    embedding_dim=96,
    codebook_size=list(codebook_size),  # pass a copy so the config list above can't be mutated by the model
    compression_ratio=4,
    dimensions=2, # we work with 2-dim images
).eval()
# NOTE(review): the model is left in eval mode here; presumably `train` switches it back
# to train mode — TODO confirm.

# Sanity check: run one sample through the model and list the output shapes.
# no_grad: this forward pass is inspection only, so no autograd graph is needed.
with torch.no_grad():
    s = dataset[0][0][None, :]
    outputs = vqvae(s)
[[b.shape for b in v] if isinstance(v, list) else v.shape for v in outputs]

In [None]:
from kemsekov_torch.train import train
import os
import warnings
warnings.filterwarnings("ignore") # to ignore pil rbg channels warnings

def vqvae2_compute_loss_and_metric(model,batch):
    """
    Per-batch callback for `train`: returns (loss, metrics).

    model: 3-level VQ-VAE whose forward returns (raw reconstruction, latents, quantized latents, <ignored>)
           and which exposes `quantizer_bottom`, `quantizer_mid` and `quantizer_top`.
    batch: DataLoader batch whose first element is the image tensor (labels are ignored).

    Metrics, in this order:
      rec_r2                         R^2 of the sigmoid reconstruction vs the input
      z_bottom_r2/z_mid_r2/z_top_r2  R^2 between latents and quantized latents per level
      usage_bottom/usage_mid/usage_top  each level quantizer's get_codebook_usage()
    """
    # NOTE(review): re-enables gradients on every parameter on each call; unclear why it's
    # needed (possibly to undo freezing elsewhere) — kept as-is.
    model.requires_grad_(True)
    images = batch[0]

    raw_reconstruction, latents, quantized, _ = model(images)
    # the decoder output is squashed into [0, 1] to match ToTensor() image range
    reconstruction = raw_reconstruction.sigmoid()

    loss = vqvae2_loss(images, reconstruction, latents, quantized)

    # codebook usage, queried bottom -> mid -> top
    quantizers = (model.quantizer_bottom, model.quantizer_mid, model.quantizer_top)
    usage_bottom, usage_mid, usage_top = (q.get_codebook_usage() for q in quantizers)

    latent_keys = ("z_bottom_r2", "z_mid_r2", "z_top_r2")
    metrics = {"rec_r2": r2(images, reconstruction)}
    metrics.update({key: r2(latents[level], quantized[level]) for level, key in enumerate(latent_keys)})
    metrics.update(usage_bottom=usage_bottom, usage_mid=usage_mid, usage_top=usage_top)
    return loss, metrics

# AdamW with default hyper-parameters over all VQ-VAE parameters
optim = torch.optim.AdamW(params=vqvae.parameters())
# NOTE(review): T_max is one epoch's worth of batches. CosineAnnealingLR is periodic: if `train`
# steps the scheduler per batch, the LR hits its minimum at the end of every epoch and then rises
# again; if it steps per epoch, the period is len(train_loader) epochs. Confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=len(train_loader))

# where to save training checkpoints; checkpoints are loaded from its 'last' sub-directory
vqvae2_path = 'runs/vqvae2-16-4x'

# bf16 mixed precision + TorchInductor compilation, passed through as `accelerate_args`
accelerate_config = {
    'mixed_precision': 'bf16',
    'dynamo_backend': 'inductor',
}

vqvae = train(
    model=vqvae,
    optimizer=optim,
    scheduler=scheduler,
    compute_loss_and_metric=vqvae2_compute_loss_and_metric,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=1000,
    gradient_clipping_max_norm=1,
    save_results_dir=vqvae2_path,
    save_on_metric_improve=['rec_r2'],
    load_checkpoint_dir=os.path.join(vqvae2_path, 'last'),
    accelerate_args=accelerate_config,
)