In [None]:
import os
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

from PIL import Image
import pickle
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRTrainDataTransform, SimCLREvalDataTransform
import pytorch_lightning as pl

import tensorflow as tf
tf.executing_eagerly()

from waymo_open_dataset.utils import frame_utils
from waymo_open_dataset import dataset_pb2 as open_dataset

In [None]:
def get_encoder(model):
    """Load a pretrained checkpoint by name and return its encoder.

    Args:
        model (str): Name of the pretrained model. Must be a key of the
            internal ``WEIGHTS`` table; only ``'simclr'`` is listed today.

    Returns:
        The ``encoder`` attribute of the loaded checkpoint. ``freeze()`` is
        called on the parent model and ``eval()`` on the encoder before it is
        returned.

    Raises:
        KeyError: If ``model`` is not a key of ``WEIGHTS``.
        NotImplementedError: If ``model`` is listed in ``WEIGHTS`` but has no
            loading branch below.
    """
    WEIGHTS = {
        'simclr': 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
    }
    # Validate up front so a typo produces a readable message rather than a
    # bare KeyError (the exception type is kept for existing callers).
    if model not in WEIGHTS:
        raise KeyError(
            "unknown model %r; expected one of %s" % (model, sorted(WEIGHTS))
        )
    weight_path = WEIGHTS[model]

    if model == 'simclr':
        simclr = SimCLR.load_from_checkpoint(weight_path)
        simclr.freeze()
        encoder = simclr.encoder
        encoder.eval()
        return encoder

    # Reached only when WEIGHTS gains an entry without a matching branch;
    # fail loudly instead of hitting an unbound `encoder` variable.
    raise NotImplementedError("no loader implemented for model %r" % (model,))

In [None]:
class WaymoAD(Dataset):
    """Map-style dataset over the frames stored in Waymo TFRecord files."""

    def __init__(self, tfrecords, transform=None):
        """
        Args:
            tfrecords (string or list of strings) : TFRecord file path(s) to read
            transform (callable, optional) : Optional transform per sample
        """
        records = tf.data.TFRecordDataset(tfrecords, compression_type='')
        # A TFRecordDataset can only be iterated, not indexed, so every record
        # is parsed once here to make __getitem__/__len__ possible.
        # NOTE(review): this keeps all parsed frames in memory — confirm that
        # is acceptable for the number of segments being loaded.
        self.frames = []
        for record in records:
            frame = open_dataset.Frame()
            frame.ParseFromString(bytearray(record.numpy()))
            self.frames.append(frame)
        self.transform = transform

    def __len__(self):
        """Return the number of parsed frames."""
        return len(self.frames)

    def __getitem__(self, idx):
        """Return the frame at ``idx``, passed through ``transform`` if set.

        NOTE(review): the previous code never defined what a sample is;
        returning the parsed Frame is an assumption — confirm against what
        the transform / downstream model expects.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.frames[idx]
        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
# Batch size and worker-process count handed to the DataLoader below.
BATCH_SIZE = 32
NUM_WORKERS = 8

# NOTE(review): this call cannot run as written on a fresh kernel —
#   * WaymoAD.__init__ is declared as (tfrecords, transform=None), but three
#     arguments are passed here;
#   * `pklfile`, `root_dir` and `simclr_encoder` are not assigned in any cell
#     visible in this notebook.
# A pickle path suggests WaymoDataset (defined in a later cell) may have been
# intended — TODO confirm and define the inputs explicitly above this line.
dataset = WaymoAD(pklfile, root_dir, transform=simclr_encoder)
# shuffle=True draws samples in random order, which requires the dataset to
# implement __len__.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

def create_img_dataset(tfrecords=None):
    """Parse every record of the given TFRecord files and print the total.

    Each record is decoded into a Waymo ``Frame`` message. The decoded frames
    are not stored or returned; the function only reports how many records
    were read.

    Args:
        tfrecords: TFRecord path(s) passed to ``tf.data.TFRecordDataset``.
            When ``None``, two hard-coded segment files are used.
    """
    default_segments = [
        '/data/krishna/waymo/segment-10017090168044687777_6380_000_6400_000_with_camera_labels.tfrecord',
        '/data/krishna/waymo/segment-10023947602400723454_1120_000_1140_000_with_camera_labels.tfrecord'
    ]
    segment_paths = default_segments if tfrecords is None else tfrecords
    records = tf.data.TFRecordDataset(segment_paths, compression_type='')

    # Stays 0 when the dataset yields nothing; otherwise ends at the 1-based
    # position of the last record.
    num_parsed = 0
    for num_parsed, serialized in enumerate(records, start=1):
        parsed = open_dataset.Frame()
        parsed.ParseFromString(bytearray(serialized.numpy()))
    print("PROCESSED %s SAMPLES ACROSS TFRECORDS" % num_parsed)
    

In [None]:
# Open a 25x20 figure; the plt.subplot calls made later in this cell draw into it.
plt.figure(figsize=(25, 20))

def image_show(data, name, layout, cmap=None):
    """Draw one JPEG-encoded image into a subplot of the current figure.

    Args:
        data: JPEG bytes, decoded with ``tf.image.decode_jpeg``.
        name: Text used as the subplot title.
        layout: Sequence unpacked into ``plt.subplot`` to select the cell.
        cmap: Optional colormap forwarded to ``plt.imshow``.
    """
    plt.subplot(*layout)
    decoded = tf.image.decode_jpeg(data)
    plt.imshow(decoded, cmap=cmap)
    plt.title(name)
    # Hide grid lines and axes so only the image and its title remain.
    plt.grid(False)
    plt.axis('off')

# Draw each image of `frame` into its own cell of a 3x3 grid, titled with the
# name of the camera enum value.
# NOTE(review): `frame` is not assigned at notebook scope in the visible cells
# (create_img_dataset only binds it as a local) — this loop presumably relies
# on leftover kernel state; TODO define `frame` explicitly before this cell.
for index, image in enumerate(frame.images):
  image_show(image.image, open_dataset.CameraName.Name.Name(image.name),
             [3, 3, index+1])

In [None]:
class WaymoDataset(Dataset):
    """Dataset of 'center_camera_feed' images loaded from a pickle file."""

    def __init__(self, datafile, transform=None):
        """
        Args:
            datafile (string) : path to a pickle file containing an iterable
                of records, each indexable by 'center_camera_feed'
            transform (callable, optional) : applied to a PIL image built
                from the sample; when omitted the raw array is returned
        """
        self.datafile = datafile
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file — only point this at pickle dumps you created or trust.
        with open(self.datafile, 'rb') as f:
            records = pickle.load(f)
        self.data = [record['center_camera_feed'] for record in records]
        self.transform = transform

    def __getitem__(self, index):
        """Return sample ``index``.

        With a transform: the array is cast to uint8, wrapped in a PIL image
        and passed through the transform. Without one: the array is cast to
        float32 and its last axis is moved to the front (presumably
        HWC -> CHW — TODO confirm the stored layout).
        """
        image = self.data[index]
        if self.transform:
            # Build the PIL image directly; the float32/transposed copy is
            # only needed on the no-transform path below.
            return self.transform(Image.fromarray(image.astype(np.uint8)))
        return image.astype(np.float32).transpose(2, 0, 1)

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data)

In [None]:
# Load the SimCLR checkpoint and keep a handle on its encoder, switched to
# eval mode, for the feature-extraction loop.
# NOTE(review): this repeats what get_encoder('simclr') does, except that it
# passes strict=False and never calls freeze() — confirm whether the
# difference is intentional or whether the helper should be used instead.
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
# strict=False presumably tolerates key mismatches between the checkpoint and
# the model definition — TODO confirm it is actually needed.
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr_resnet50 = simclr.encoder
simclr_resnet50.eval()



In [None]:
# Run every batch through the encoder and stack the features along axis 0.
# The previous loop stored the first batch twice (once on initialisation and
# again via the unconditional append) and re-copied the whole array on every
# iteration; collecting chunks and concatenating once avoids both.
feature_chunks = []
with torch.no_grad():  # inference only, so no autograd graph is needed
    for batch in dataloader:
        print(batch.shape)
        # The encoder output is indexed with [0] to get the feature tensor.
        feats = simclr_resnet50(batch)[0].detach().cpu().numpy()
        feature_chunks.append(feats)

# Stays None when the dataloader yields no batches, as before.
allfeats = np.concatenate(feature_chunks, axis=0) if feature_chunks else None
