In [None]:
import tensorflow as tf
import numpy as np
import os

# Path to the Waymo dataset (update as needed)
data_path = "/content/drive/MyDrive/SEM 3/248 Intelligent Autonomous Systems/Project/Data"

# Function to parse the TFRecord files
def parse_tfrecord_fn(example):
    """Decode one serialized example into a dict of fixed-length tensors.

    Args:
        example: A single serialized example, as yielded by a
            ``tf.data.TFRecordDataset``.

    Returns:
        The dict produced by ``tf.io.parse_single_example``: one tensor per
        feature key listed below, with the declared shape and dtype.
    """
    # NOTE(review): the public Waymo motion tf.Example schema appears to use
    # keys such as 'state/current/x' with per-agent shapes (e.g. [128, 1] and
    # [128, 80]) -- confirm these keys/shapes match the files actually used.
    fixed = tf.io.FixedLenFeature
    int64, float32 = tf.int64, tf.float32

    # (feature key, shape, dtype) for every feature that is extracted.
    layout = (
        ('state/id', [], int64),
        ('state/type', [], int64),
        ('state/current_valid', [128], int64),
        ('state/future_valid', [80], int64),
        ('state/current_x', [128], float32),
        ('state/current_y', [128], float32),
        ('state/future_x', [80], float32),
        ('state/future_y', [80], float32),
    )
    spec = {key: fixed(shape, dtype) for key, shape, dtype in layout}
    return tf.io.parse_single_example(example, spec)

# Load the dataset
def load_tfrecord_files(tfrecord_files):
    """Build a dataset of parsed records from TFRecord files.

    Args:
        tfrecord_files: File path(s) handed straight to
            ``tf.data.TFRecordDataset``.

    Returns:
        A ``tf.data.Dataset`` in which every raw record has been mapped
        through ``parse_tfrecord_fn``.
    """
    # The files are opened as GZIP-compressed TFRecords.
    # NOTE(review): confirm the files on disk really are gzip-compressed.
    return tf.data.TFRecordDataset(
        tfrecord_files,
        compression_type='GZIP',
    ).map(parse_tfrecord_fn)

# Function to preprocess the data
def preprocess_data(dataset):
    """Print the valid current and future (x, y) points of every record.

    Args:
        dataset: Iterable of mappings whose values expose ``.numpy()`` for
            the ``state/{current,future}_{x,y,valid}`` keys.

    Returns:
        None. The trajectories are only written to standard output.
    """
    def masked_xy(sample, x_key, y_key, valid_key):
        # Keep only the entries whose validity flag is non-zero, then pair
        # x with y along a new last axis.
        keep = sample[valid_key].numpy().astype(bool)
        xs = sample[x_key].numpy()
        ys = sample[y_key].numpy()
        return np.stack([xs[keep], ys[keep]], axis=-1)

    for sample in dataset:
        # Build both trajectories before printing either one.
        observed = masked_xy(
            sample, 'state/current_x', 'state/current_y', 'state/current_valid'
        )
        upcoming = masked_xy(
            sample, 'state/future_x', 'state/future_y', 'state/future_valid'
        )

        print("Current Trajectory:", observed)
        print("Future Trajectory:", upcoming)

# Example usage
if __name__ == "__main__":
    # Collect every TFRecord file in the dataset directory. os.listdir
    # returns names in arbitrary order, so sort the paths to make the order
    # in which records are processed reproducible between runs.
    tfrecord_files = sorted(
        os.path.join(data_path, name)
        for name in os.listdir(data_path)
        if name.endswith('.tfrecord')
    )

    # An empty match would otherwise produce no output at all; say so
    # explicitly so a wrong path or file suffix is easy to spot.
    if not tfrecord_files:
        print(f"No .tfrecord files found in {data_path}")

    # Load and preprocess the dataset
    dataset = load_tfrecord_files(tfrecord_files)
    preprocess_data(dataset)
