In [None]:
# Project-local VQ-VAE model definition (restored from a checkpoint further down).
from models import VectorQuantizedVAE
import torch
# Print tensors with up to 10000 elements in full instead of summarising them.
torch.set_printoptions(threshold=10000)
# NOTE(review): random_split is not used in the cells shown here.
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import numpy as np
import os 

# Project-local dataset; yields (image, description) pairs consumed below.
from dataset import CustomDataset

# NOTE(review): star import; presumably supplies num_hiddens, num_residual_layers,
# num_residual_hiddens, num_embeddings, embedding_dim and commitment_cost used when
# loading the checkpoint. Explicit imports would make that dependency visible.
from constants import *

# Permit reduced-precision internal arithmetic for float32 matmuls (faster on
# supporting GPUs, at a small accuracy cost).
torch.set_float32_matmul_precision('medium')

In [None]:
def generate_samples(images, model, device=None):
    """Run ``model`` on ``images`` without tracking gradients and return its first output.

    Args:
        images: Input tensor batch; moved to ``device`` before the forward pass.
        model: Module whose forward output is unpacked here as a 3-tuple.
            NOTE(review): the final cell of this notebook unpacks ``model(...)`` as a
            2-tuple instead; confirm which matches ``VectorQuantizedVAE.forward``.
        device: Device to run on. When omitted, the device of the model's first
            parameter is used (CPU for a parameterless model), so the function does
            not depend on a module-level ``device`` variable defined elsewhere.

    Returns:
        The first element of ``model(images)`` (the reconstruction), computed under
        ``torch.no_grad()``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        x_tilde, _, _ = model(images.to(device))
    return x_tilde

# One sample per batch: the embedding-extraction loop parses only the first
# description of each batch, so it relies on this being 1.
bs = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Images are converted to tensors and reduced to a single grayscale channel.
dataset = CustomDataset(transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(),
]))

# Seed the shuffle so Restart & Run All reproduces the same sample order, and
# therefore the same embedding file numbering and the same preview image.
SHUFFLE_SEED = 0
data_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=bs,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# Restore the trained VQ-VAE. The architecture hyperparameters are presumably
# provided by `from constants import *` in the setup cell.
model = VectorQuantizedVAE.load_from_checkpoint(
    'weights/latest.ckpt',
    num_hiddens=num_hiddens,
    num_residual_layers=num_residual_layers,
    num_residual_hiddens=num_residual_hiddens,
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    commitment_cost=commitment_cost,
)
model = model.to(device)
# Inference mode; as the last expression this also displays the module.
model.eval()

In [None]:
# Map each mechanism-type folder under 'all_mechanisms' to an integer label.
# sorted() makes the mapping deterministic: os.listdir returns entries in an
# arbitrary, filesystem-dependent order, so without it the same folder could
# receive a different label on another machine.
folders = sorted(
    f for f in os.listdir('all_mechanisms')
    if os.path.isdir(os.path.join('all_mechanisms', f))
)
folder_dict = {folder_name: index for index, folder_name in enumerate(folders)}

print(folder_dict)

# np.savez does not create missing directories.
EMBEDDING_DIR = 'vq_embeddings'
os.makedirs(EMBEDDING_DIR, exist_ok=True)


def _parse_description(raw_description, label_map):
    """Convert one dataset description string into a 1-D float64 array.

    The mechanism type is the second backslash-separated segment of the text
    before the first '/'. The text after the first '/' (up to any second '/') is
    split on single spaces, empty tokens are dropped, and every token is parsed
    as a float. The 5th-from-last token is overwritten with ``label_map[mech_type]``.

    NOTE(review): this layout is inferred from the parsing logic alone, and the
    backslash separator looks Windows-path-specific; confirm against CustomDataset.
    """
    parts = raw_description.split('/')
    mech_type = parts[0].split('\\')[1]
    tokens = [tok for tok in parts[1].split(' ') if tok]
    # NOTE(review): magic index; presumably the slot reserved for the mechanism
    # type in the description format. Confirm.
    tokens[-5] = float(label_map[mech_type])
    return np.array([float(tok) for tok in tokens])


def _quantize(batch_images, vq_model):
    """Encode a batch and return the second output of ``vq_model._vq_vae`` as numpy.

    Runs encoder -> pre-VQ conv -> VQ layer under ``torch.no_grad()``. The result is
    reshaped to ``(dim0, dim1, -1)``, i.e. every dimension after the second is
    flattened.
    """
    with torch.no_grad():
        z = vq_model._encoder(batch_images)
        z = vq_model._pre_vq_conv(z)
        _, z = vq_model._vq_vae(z)
    z = z.cpu().numpy()
    return z.reshape(z.shape[0], z.shape[1], -1)


for batch_num, (images, descriptions) in enumerate(data_loader):
    # Only one description is parsed per batch, so a larger batch would pair
    # several embeddings with a single label. Fail loudly instead of mislabelling.
    if len(descriptions) != 1:
        raise ValueError(
            'expected one sample per batch (batch_size=1), got {}'.format(len(descriptions))
        )
    description = _parse_description(descriptions[0], folder_dict)
    z = _quantize(images.to(device), model)
    np.savez('{}/{}.npz'.format(EMBEDDING_DIR, batch_num), arr1=z, arr2=description)


In [None]:
# Save one input batch and its reconstruction as PNG grids for a quick visual
# check of reconstruction quality.
fixed_images, _ = next(iter(data_loader))
fixed_grid = make_grid(fixed_images, nrow=8, normalize=True)
save_image(fixed_grid, 'orig_image.png')

# Inference only: no_grad avoids building an autograd graph for this forward pass.
# NOTE(review): model(...) is unpacked as a 2-tuple here, while generate_samples
# unpacks it as a 3-tuple; confirm which matches VectorQuantizedVAE.forward.
with torch.no_grad():
    _, reconstruction = model(fixed_images.to(device))
grid = make_grid(reconstruction.cpu(), nrow=8, normalize=True)
save_image(grid, 'recon_images.png')