# Import Package

In [1]:
import os
import os
import argparse
import random
import math
import torch
import os
import cv2
import torch
import numpy as np
import albumentations as A
import torchvision.transforms as transforms
from torchvision import transforms as T
from torch.utils.data import Dataset
import torch.distributed as dist

from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from timm.scheduler.cosine_lr import CosineLRScheduler
from mmcv.utils import Config, DictAction
from datetime import datetime

from utils.train_api import save_model
from utils.logger import get_logger

from dataset.custom import CustomDataset
from utils.metrics import eval_depth, cropping_img
from models.VQVAE import VQVAE
import json

# Load Model and Dataset

In [2]:
# image_size = (640, 480)

image_size = 320

val_dataset = CustomDataset(
        data_path=[
            "/home/nil/manipulation/dataset/pick_apple_100_328/rgb/pick_apple_2/train/episode_0/"
        ],
        crop_size=(image_size, image_size),
        single_img=False)

val_data_loader = torch.utils.data.DataLoader(
        val_dataset, sampler=None,
        batch_size=1,
        num_workers=1,
    )

cfg_model = dict(
    image_size=image_size, # TODO
    num_resnet_blocks=2,
    downsample_ratio=32,
    num_tokens=128,
    codebook_dim=512,
    hidden_dim=16,
    use_norm=False,
    channels=1,
    train_objective='regression',
    max_value=10.,
    residul_type='v1',
    loss_type='mse',
)

vae = VQVAE(
    **cfg_model,
    ).cuda()

model_path = "/home/nil/manipulation/groundvla/libs/AiT/checkpoints/vae-final.pt"
ckpt = torch.load(model_path)['weights']
if 'module' in list(ckpt.keys())[0]:
    new_ckpt = {}
    for key in list(ckpt.keys()):
        ## remove module.
        new_key = key[7:]
        new_ckpt[new_key] = ckpt[key]
    vae.load_state_dict(
        new_ckpt,
    )
else:
    vae.load_state_dict(
        ckpt,
    )

# of test images: 118


In [3]:
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import os


In [4]:
vae.eval()
save_path = "/home/nil/manipulation/groundvla/libs/AiT/vae/tmp.json"
if os.path.exists(save_path):
    with open(save_path, "r") as file:
        final_codes = json.load(file)
else:
    final_codes = {}
    
with torch.no_grad():
    counter = 0
    for iter, (image, _, path) in tqdm(enumerate(val_data_loader)):
        
        codes = vae(
                img=image.cuda(),return_indices= True
        )

        codes = codes[0].flatten().detach().cpu().tolist()
        print(codes)
        print("len(codes):", len(codes))
        assert len(codes) == 100
        depth_string = "<DEPTH_START>"
        depth_string += "".join([f"<DEPTH_{num}>" for num in codes ])
        depth_string += "<DEPTH_END>"
        final_codes[path[0]] = depth_string  
    
    with open(save_path, "w") as file:
        json.dump(final_codes, file, indent=4)
print(counter)

8it [00:00, 25.90it/s]

[36, 36, 103, 50, 40, 120, 30, 11, 20, 85, 37, 33, 111, 3, 10, 109, 29, 103, 13, 29, 32, 32, 64, 12, 10, 109, 22, 11, 32, 85, 67, 32, 11, 22, 109, 40, 96, 73, 32, 8, 68, 81, 2, 26, 32, 109, 33, 89, 13, 11, 55, 38, 30, 99, 116, 60, 22, 2, 34, 69, 124, 59, 74, 42, 34, 21, 32, 81, 69, 90, 123, 44, 101, 114, 76, 67, 45, 96, 72, 43, 66, 69, 82, 12, 72, 60, 123, 41, 89, 87, 59, 50, 93, 70, 35, 123, 60, 120, 105, 3]
len(codes): 100
[97, 36, 103, 50, 40, 109, 30, 11, 20, 85, 37, 33, 111, 3, 10, 109, 29, 103, 13, 17, 32, 54, 64, 12, 10, 109, 22, 11, 32, 85, 67, 32, 11, 22, 109, 120, 96, 73, 32, 8, 68, 81, 2, 26, 32, 109, 33, 89, 13, 11, 55, 38, 73, 99, 116, 60, 22, 2, 34, 69, 124, 31, 74, 42, 34, 21, 32, 81, 69, 90, 123, 126, 48, 114, 76, 21, 45, 96, 72, 43, 126, 69, 82, 12, 72, 60, 100, 41, 89, 87, 59, 10, 93, 70, 50, 123, 80, 120, 103, 40]
len(codes): 100
[97, 36, 103, 50, 40, 120, 30, 11, 20, 20, 37, 33, 111, 3, 40, 109, 29, 103, 13, 29, 32, 54, 64, 12, 10, 109, 22, 11, 32, 85, 67, 32, 11, 2

28it [00:00, 65.48it/s]

[60, 50, 35, 10, 40, 109, 30, 11, 20, 85, 37, 25, 101, 3, 40, 109, 29, 3, 13, 17, 80, 29, 48, 58, 10, 109, 87, 11, 32, 85, 13, 96, 1, 1, 109, 120, 33, 73, 32, 8, 72, 68, 73, 90, 54, 109, 33, 89, 13, 11, 63, 10, 20, 99, 116, 111, 22, 2, 34, 69, 34, 94, 47, 97, 97, 21, 32, 81, 69, 90, 93, 101, 124, 36, 123, 67, 47, 96, 72, 43, 126, 15, 102, 12, 110, 60, 100, 41, 89, 87, 59, 10, 93, 70, 50, 59, 80, 120, 105, 40]
len(codes): 100
[60, 50, 35, 10, 40, 109, 30, 11, 20, 85, 37, 25, 101, 3, 40, 109, 29, 3, 13, 17, 80, 29, 48, 58, 10, 109, 87, 11, 32, 85, 13, 96, 1, 1, 109, 120, 33, 73, 32, 8, 72, 68, 73, 90, 54, 109, 33, 89, 13, 11, 63, 10, 20, 99, 116, 111, 22, 2, 34, 69, 34, 94, 47, 97, 97, 21, 32, 81, 69, 90, 93, 101, 124, 36, 123, 67, 47, 96, 72, 43, 126, 15, 102, 12, 110, 60, 100, 41, 89, 87, 59, 10, 93, 70, 50, 59, 80, 120, 105, 40]
len(codes): 100
[60, 50, 98, 10, 40, 109, 30, 11, 20, 85, 78, 25, 48, 10, 40, 109, 29, 103, 13, 17, 44, 2, 103, 98, 10, 109, 22, 11, 32, 85, 8, 96, 40, 1, 109

48it [00:00, 82.94it/s]

[60, 50, 35, 10, 40, 109, 30, 11, 20, 20, 43, 26, 48, 10, 40, 109, 29, 103, 13, 17, 110, 2, 105, 35, 10, 109, 22, 11, 32, 85, 8, 103, 109, 96, 28, 120, 1, 73, 32, 8, 34, 70, 13, 26, 54, 109, 2, 78, 13, 11, 19, 70, 109, 67, 116, 60, 22, 2, 34, 69, 36, 81, 27, 43, 34, 21, 32, 81, 69, 90, 123, 126, 101, 114, 76, 21, 47, 96, 26, 43, 66, 69, 82, 12, 72, 44, 100, 41, 89, 87, 59, 10, 93, 70, 50, 123, 80, 120, 105, 40]
len(codes): 100
[60, 50, 35, 10, 40, 109, 30, 11, 20, 20, 43, 26, 48, 10, 40, 109, 29, 103, 13, 17, 110, 2, 105, 35, 10, 109, 22, 11, 32, 85, 8, 103, 109, 96, 28, 120, 1, 73, 32, 8, 34, 70, 13, 26, 54, 109, 2, 78, 13, 11, 19, 70, 109, 67, 116, 60, 22, 2, 34, 69, 36, 81, 27, 43, 34, 21, 32, 81, 69, 90, 123, 126, 101, 114, 76, 21, 47, 96, 26, 43, 66, 69, 82, 12, 72, 44, 100, 41, 89, 87, 59, 10, 93, 70, 50, 123, 80, 120, 105, 40]
len(codes): 100
[60, 50, 35, 10, 40, 109, 30, 11, 20, 20, 43, 26, 48, 10, 40, 109, 29, 103, 13, 29, 110, 2, 105, 35, 10, 109, 22, 11, 32, 85, 8, 103, 109,

69it [00:00, 91.83it/s]

[58, 97, 103, 50, 40, 109, 30, 11, 20, 85, 70, 33, 111, 3, 40, 109, 29, 103, 13, 17, 13, 32, 64, 12, 10, 109, 22, 11, 32, 85, 32, 42, 29, 22, 109, 40, 33, 73, 32, 8, 100, 100, 33, 64, 112, 28, 33, 89, 13, 11, 122, 95, 99, 67, 55, 60, 22, 2, 34, 69, 26, 44, 100, 42, 34, 21, 32, 81, 69, 90, 31, 81, 97, 26, 76, 67, 47, 96, 26, 43, 90, 68, 90, 12, 72, 60, 100, 41, 89, 87, 89, 10, 93, 70, 50, 123, 80, 120, 105, 40]
len(codes): 100
[114, 114, 101, 103, 40, 109, 30, 11, 20, 85, 111, 85, 101, 3, 40, 120, 29, 103, 13, 17, 13, 22, 3, 98, 40, 109, 22, 11, 32, 85, 67, 114, 17, 10, 109, 40, 1, 73, 32, 8, 41, 37, 33, 64, 112, 28, 33, 78, 13, 11, 41, 122, 16, 99, 116, 60, 22, 2, 34, 69, 26, 36, 102, 87, 34, 21, 32, 81, 69, 90, 31, 37, 34, 26, 76, 67, 47, 96, 26, 43, 90, 69, 82, 12, 72, 60, 100, 41, 89, 87, 59, 10, 93, 70, 35, 31, 80, 120, 105, 40]
len(codes): 100
[114, 114, 101, 103, 40, 109, 30, 11, 20, 85, 111, 85, 101, 3, 40, 120, 29, 103, 13, 17, 13, 22, 3, 98, 40, 109, 22, 11, 32, 85, 67, 114, 1

91it [00:01, 96.50it/s]

[8, 67, 96, 96, 32, 54, 28, 11, 20, 85, 12, 99, 20, 40, 85, 50, 29, 103, 13, 17, 89, 98, 23, 79, 3, 85, 42, 11, 32, 85, 89, 103, 29, 103, 76, 29, 96, 30, 32, 8, 57, 87, 21, 105, 95, 96, 33, 78, 13, 11, 69, 23, 122, 57, 94, 58, 87, 2, 34, 70, 110, 110, 111, 22, 60, 21, 32, 81, 69, 102, 31, 78, 44, 72, 76, 114, 45, 96, 72, 43, 90, 71, 90, 12, 72, 60, 123, 41, 43, 87, 89, 50, 93, 7, 35, 123, 80, 120, 105, 40]
len(codes): 100
[8, 67, 96, 96, 32, 54, 28, 11, 20, 85, 12, 99, 20, 40, 85, 50, 29, 103, 13, 17, 89, 98, 23, 79, 3, 85, 42, 11, 32, 85, 89, 103, 29, 103, 76, 29, 96, 30, 32, 8, 57, 87, 21, 105, 95, 96, 33, 78, 13, 11, 69, 23, 122, 57, 94, 58, 87, 2, 34, 70, 110, 110, 111, 22, 60, 21, 32, 81, 69, 102, 31, 78, 44, 72, 76, 114, 45, 96, 72, 43, 90, 71, 90, 12, 72, 60, 123, 41, 43, 87, 89, 50, 93, 7, 35, 123, 80, 120, 105, 40]
len(codes): 100
[8, 67, 33, 3, 32, 10, 28, 11, 20, 85, 89, 38, 13, 48, 8, 10, 29, 103, 13, 17, 58, 22, 51, 89, 114, 8, 42, 11, 32, 85, 89, 105, 73, 54, 76, 28, 96, 

118it [00:01, 79.58it/s] 

[126, 17, 8, 40, 67, 11, 13, 54, 20, 85, 42, 70, 17, 56, 20, 50, 8, 35, 13, 29, 111, 50, 78, 71, 13, 1, 20, 120, 32, 85, 89, 35, 30, 120, 18, 76, 35, 8, 103, 8, 57, 87, 48, 36, 13, 103, 100, 17, 28, 32, 69, 117, 55, 68, 77, 75, 94, 96, 60, 117, 72, 83, 111, 42, 44, 58, 59, 34, 69, 31, 31, 78, 44, 26, 76, 67, 84, 87, 72, 12, 90, 71, 82, 12, 72, 60, 76, 117, 43, 12, 79, 50, 7, 69, 35, 123, 34, 120, 105, 10]
len(codes): 100
[126, 17, 8, 1, 67, 11, 13, 54, 20, 85, 42, 107, 17, 38, 33, 103, 8, 35, 13, 29, 111, 50, 111, 16, 13, 10, 20, 120, 32, 85, 89, 35, 28, 109, 79, 79, 101, 8, 103, 8, 57, 98, 48, 36, 73, 103, 100, 17, 28, 32, 69, 117, 55, 68, 77, 75, 116, 96, 34, 41, 72, 83, 111, 42, 44, 98, 52, 60, 69, 31, 31, 78, 44, 26, 76, 67, 55, 50, 72, 12, 90, 71, 90, 12, 72, 60, 76, 117, 43, 12, 79, 50, 7, 69, 35, 123, 34, 120, 105, 3]
len(codes): 100
[126, 29, 8, 40, 67, 32, 85, 54, 20, 85, 42, 107, 29, 38, 33, 112, 20, 35, 13, 29, 111, 50, 111, 16, 13, 103, 20, 109, 32, 85, 89, 35, 11, 30, 79, 


