In [None]:
import pickle
import numpy as np
import onnxruntime as ort
from pathlib import Path
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch
import PIL
# --- Configuration -----------------------------------------------------------
TOKENIZER_PATH = 'quant3.onnx'
# NOTE(review): not referenced by the cells below; `preprocess` resizes to (160, 256),
# i.e. (height, width) for torchvision — presumably the same size written as
# (width, height) here. Confirm before relying on it.
INPUT_SHAPE = (256, 160)

# Wheel speeds are clipped to WHEEL_SPEEDS_RANGE and bucketed against
# WHEEL_SPEEDS_VOCAB_SIZE evenly spaced bin edges (see `tokenize_wheel_speed`).
WHEEL_SPEEDS_RANGE = [-150, 150]
WHEEL_SPEEDS_VOCAB_SIZE = 512
WHEEL_SPEED_BINS = np.linspace(WHEEL_SPEEDS_RANGE[0], WHEEL_SPEEDS_RANGE[1], WHEEL_SPEEDS_VOCAB_SIZE)

options = ort.SessionOptions()
# Request CUDA first and list CPU explicitly as the fallback provider so the
# notebook also runs on machines without a CUDA-enabled onnxruntime build.
tokenizer_session = ort.InferenceSession(
    TOKENIZER_PATH, options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

pickle_dir = './pickle_data'
# os.listdir returns entries in arbitrary order and includes plain files; keep only
# subdirectories and sort them so the file list (and therefore which files get
# processed below) is reproducible across machines and runs.
data_dirs = sorted(
    entry for entry in os.listdir(pickle_dir)
    if os.path.isdir(os.path.join(pickle_dir, entry))
)


In [None]:
# Build the ordered list of observation pickles. Within each data directory the files
# are listed in iteration order (iter_0.pkl, iter_1.pkl, ...) rather than the lexical
# order os.listdir would give.
# NOTE(review): assumes the iter_*.pkl files in each directory are numbered
# contiguously from 0 — confirm against the data writer.
obs_files = []
for sub_dir in data_dirs:  # `sub_dir`, not `dir`, to avoid shadowing the builtin
    sub_path = os.path.join(pickle_dir, sub_dir)
    # Count only iter_*.pkl files so unrelated entries in the directory do not make
    # us emit paths to iteration files that do not exist.
    n_iters = sum(
        1 for name in os.listdir(sub_path)
        if name.startswith('iter_') and name.endswith('.pkl')
    )
    obs_files.extend(os.path.join(sub_path, f'iter_{it}.pkl') for it in range(n_iters))

In [54]:
def preprocess(img, target_image_size=256, map_dalle=True, output_size=(160, 256)):
    """Resize a PIL image to the tokenizer's input size and return it as a batched tensor.

    Args:
        img: PIL image to preprocess.
        target_image_size: minimum allowed length of the image's shorter side; smaller
            images are rejected.
        map_dalle: currently unused (the DALL-E pixel-mapping step is disabled); kept so
            existing callers keep working.
        output_size: size passed to ``TF.resize`` as (height, width). Defaults to
            (160, 256), the size this function always used.

    Returns:
        Tensor of shape (1, C, H, W) produced by ``TF.to_tensor`` plus a leading batch
        dimension (values in [0, 1] for 8-bit images).

    Raises:
        ValueError: if the image's shorter side is smaller than ``target_image_size``.
    """
    shorter_side = min(img.size)
    if shorter_side < target_image_size:
        raise ValueError(f'min dim for image {shorter_side} < {target_image_size}')

    # The image is resized to a fixed size, so aspect ratio is NOT preserved. (The
    # previous aspect-preserving size computation was overwritten before use and has
    # been removed; behavior is unchanged.)
    img = TF.resize(img, output_size, interpolation=T.InterpolationMode.LANCZOS)
    return TF.to_tensor(img).unsqueeze(0)

In [82]:
def preprocess_vqgan(x):
    """Rescale values from the [0, 1] range to [-1, 1] (computes 2*x - 1 elementwise).

    Works on anything supporting scalar arithmetic: floats, NumPy arrays, torch tensors.
    """
    doubled = x * 2.0
    return doubled - 1.0

def custom_to_pil(x):
    """Convert a channels-first (C, H, W) tensor with values in [-1, 1] to an RGB PIL image.

    The tensor is detached and moved to CPU, clamped to [-1, 1], mapped to [0, 1], permuted
    to (H, W, C), scaled to 0..255 and truncated to uint8. The resulting image is converted
    to RGB if PIL did not already infer that mode.
    """
    pixels = torch.clamp(x.detach().cpu(), -1., 1.)
    pixels = ((pixels + 1.) / 2.).permute(1, 2, 0).numpy()
    pil_img = Image.fromarray((255 * pixels).astype(np.uint8))
    return pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")

def reconstruct_with_vqgan(img):
    """Encode a preprocessed image batch into a single row of tokenizer outputs.

    Despite the name, nothing is decoded or reconstructed here. The image is converted with
    ``np.array``, fed to ``tokenizer_session`` under the input name ``'input'``, and the
    third model output (index 2) is flattened to shape (1, -1).

    NOTE(review): output index 2 is presumably the codebook indices — confirm against the
    exported ONNX graph's output list.
    """
    outputs = tokenizer_session.run(None, {'input': np.array(img)})
    token_ids = outputs[2]
    return token_ids.reshape(1, -1)

In [83]:
def tokenize_wheel_speed(speed):
    """Map wheel speed value(s) to integer bin indices.

    Values are clipped to ``WHEEL_SPEEDS_RANGE``. Then ``np.digitize(..., right=True)`` returns,
    for each value v, the index i with ``WHEEL_SPEED_BINS[i-1] < v <= WHEEL_SPEED_BINS[i]``.
    Because the bins span exactly the clip range, indices lie in [0, len(WHEEL_SPEED_BINS) - 1].
    """
    low, high = WHEEL_SPEEDS_RANGE
    clipped = np.clip(speed, low, high)
    return np.digitize(clipped, bins=WHEEL_SPEED_BINS, right=True)
def tokenize_frame(image):
    """Tokenize one camera frame.

    Steps:
      1. The array is wrapped with ``Image.fromarray``.
      2. It is resized and batched by ``preprocess`` (minimum side 256, DALL-E mapping off).
      3. It is rescaled to [-1, 1] by ``preprocess_vqgan``.
      4. It is encoded by ``reconstruct_with_vqgan``, which returns a (1, num_tokens) array.
    """
    pil_frame = Image.fromarray(image)
    batch = preprocess(pil_frame, target_image_size=256, map_dalle=False)
    return reconstruct_with_vqgan(preprocess_vqgan(batch))
def tokenize_actions(action):
    """Encode a two-component discrete action as a single integer index.

    Both components are shifted by +1 and combined as
    ``3 * (action[1] + 1) + (action[0] + 1)``. Assuming each component is in {-1, 0, 1}
    (TODO confirm against the data collector), the index lies in 0..8, e.g. (0, 0) -> 4.

    Returns:
        A 0-d NumPy array holding the index.
    """
    ws_shifted = action[0] + 1
    ad_shifted = action[1] + 1
    return np.array(ws_shifted + 3 * ad_shifted)
    

In [96]:
# Only tokenize the first few files while checking the pipeline; set to None for all files.
MAX_FILES = 10

# One row per observation: [action token | image tokens | wheel-speed tokens (fl, fr)].
token_rows = []
files_to_process = obs_files if MAX_FILES is None else obs_files[:MAX_FILES]
for dfile in files_to_process:
    # SECURITY: pickle.load can execute arbitrary code — only load pickles from trusted sources.
    with open(dfile, 'rb') as f:
        data = pickle.load(f)
    img_tokens = tokenize_frame(data['obs']['cameras']['driver'])
    wheel_speeds = data['obs']['carState']['wheelSpeeds']
    wheel_speeds = np.array([wheel_speeds["fl"], wheel_speeds["fr"]])
    wheel_speeds_tokens = np.expand_dims(tokenize_wheel_speed(wheel_speeds), 0)
    action_token = tokenize_actions(data['action']).reshape(-1, 1)
    token_rows.append(np.concatenate([action_token, img_tokens, wheel_speeds_tokens], axis=1))

# Concatenate once instead of re-copying a growing array on every iteration (quadratic).
# There is no float64 zeros seed row anymore, so the token ids keep the dtype of the
# token arrays instead of being upcast to float. The empty fallback keeps the original
# 163-column width (1 action + image tokens + 2 wheel speeds in the original layout).
token_data = (
    np.concatenate(token_rows, axis=0) if token_rows else np.empty((0, 163), dtype=np.int64)
)
print(token_data)

[[  4. 873. 595. ... 596. 256. 256.]
 [  4. 873. 595. ... 596. 256. 256.]
 [  4. 873. 595. ... 596. 256. 256.]
 ...
 [  4. 873. 595. ... 596. 256. 256.]
 [  4. 873. 595. ... 596. 256. 256.]
 [  4. 873. 595. ... 596. 256. 256.]]
