In [6]:
import numpy as np
import yaml
import os
import torch
import torchvision
import torchvision.transforms as transforms
from models.VQ_VAE import VQVAE

## Load the model

In [None]:
# Read the experiment configuration and rebuild the trained VQ-VAE from it.
# NOTE(review): the config path is an absolute, machine-specific location.
config_path = '/home/wenhao/VQ_Selection/VQ_Selection_Real_Data/configs_codes_fixed_ratio/Cifar10_ori_16.yaml'
# Use a context manager so the config file handle is closed right after parsing
# instead of being left open for the lifetime of the kernel.
with open(config_path, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# The checkpoint is expected next to the other training artifacts.
weights_path = os.path.join(config['train_configs']['save_dir'], 'model.pth')

# Model hyper-parameters are taken from the same config that was used to build it.
vq_params = config['model_configs']['vq_params']
backbone_configs = config['model_configs']['backbone_configs']

# Instantiate the architecture described by the config.
model = VQVAE(vq_params, backbone_configs)

# Restore the checkpoint onto the CPU; strict=True makes any missing or
# unexpected key in the state dict raise instead of being silently ignored.
model.load_state_dict(torch.load(weights_path, map_location='cpu', weights_only=True), strict=True)

## Load the dataset

In [26]:
# Build the CIFAR-10 datasets and the train / val / test data loaders.
batch_size = 64
num_workers = 24

# Convert images to tensors, then rescale every channel as (x - 0.5) / 0.5.
transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                            ])
# download=False: the dataset must already be present under the root directory.
train_dataset = torchvision.datasets.CIFAR10(root='/data/zwh', train=True, download=False, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='/data/zwh', train=False, download=False, transform=transform)

# 80 / 20 train / validation split of the official training set.
# A seeded generator keeps the split identical across kernel restarts, so the
# exported train / val files always describe the same images.
# NOTE(review): this cannot reproduce a split made elsewhere with another seed.
split_seed = 0
num_train = int(0.8 * len(train_dataset))
num_val = len(train_dataset) - num_train
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



## Get the latents for the dataset

In [None]:
def get_batch_latents(save_path, split='train', loader=None, latent_hw=(8, 8)):
    """Encode one dataset split with the VQ-VAE and save its latents as a .npy file.

    Every batch of the chosen loader is passed through ``model.get_latents``;
    the per-batch results are reshaped to ``(batch, 1, H, W)``, concatenated
    along the batch axis and written to ``<save_path>/<split>_latents.npy``.

    Args:
        save_path: Directory that receives the output file (created if missing).
        split: One of ``'train'``, ``'val'`` or ``'test'``. It selects the
            notebook-level loader to iterate (unless ``loader`` is given) and
            is used as the prefix of the output file name.
        loader: Optional iterable of ``(data, target)`` batches. When given it
            is used instead of the loader looked up from ``split``.
        latent_hw: ``(H, W)`` of the latent grid produced for one image.
            Defaults to ``(8, 8)``, the value previously hard-coded for
            32x32 CIFAR-10 inputs.

    Raises:
        ValueError: If ``loader`` is None and ``split`` is not a known split.

    Note:
        Rows are stored in iteration order, so a shuffling loader yields an
        array that is not aligned with the underlying dataset order.
    """
    if loader is None:
        # Pick the loader that matches the requested split. Previously the
        # training loader was always used, so the 'val' and 'test' files
        # contained training latents.
        split_loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
        if split not in split_loaders:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {sorted(split_loaders)}"
            )
        loader = split_loaders[split]

    latent_h, latent_w = latent_hw
    min_encodings_list = []
    # Pure inference: disable autograd so no graph is kept for each batch.
    with torch.no_grad():
        for data, _ in loader:
            data = data.to(device)
            # Per the original author's note, get_latents returns the flattened
            # quantized latents with shape (batch_size * H * W, 1) — TODO confirm
            # against the model implementation.
            latents = model.get_latents(data)
            # Un-flatten to (batch_size, H, W, 1) ...
            min_encodings = latents.view(data.shape[0], latent_h, latent_w, 1)
            # ... then move the trailing axis forward: (batch_size, 1, H, W).
            min_encodings = min_encodings.permute(0, 3, 1, 2)
            min_encodings_list.append(min_encodings.cpu().numpy())
    min_encodings = np.concatenate(min_encodings_list, axis=0)

    os.makedirs(save_path, exist_ok=True)
    np.save(os.path.join(save_path, f'{split}_latents.npy'), min_encodings)



# Switch the model to evaluation mode and move it to the selected device.
model.eval()
model.to(device)

# One output directory per dataset split.
train_save_path = './train_latents'
test_save_path = './test_latents'
val_save_path = './val_latents'

# Run the extraction once per split name, in train -> val -> test order.
for split_name, split_dir in (
    ('train', train_save_path),
    ('val', val_save_path),
    ('test', test_save_path),
):
    get_batch_latents(split_dir, split=split_name)