In [3]:
import ray
from zod import ZodFrames
import zod.constants as constants
import numpy as np

In [4]:
def id_to_car_points(zod_frames, frame_id):
    frame = zod_frames[frame_id]
    # extract oxts
    oxts = frame.oxts
    # get timestamp
    key_timestamp = frame.info.keyframe_time.timestamp()
    # get posses associated with frame timestamp
    current_pose = oxts.get_poses(key_timestamp)
    # transform the points to the car coordinate system
    transformed_poses = np.linalg.pinv(current_pose) @ oxts.poses
    points = transformed_poses[:, :3, -1]
    points = points[points[:, 0] > 0]
    return points

def euclidean_distance(coords):
    """
    Calculate the Euclidean distance between successive rows of a given array of coordinates.
    """
    diffs = np.diff(coords, axis=0)
    dists = np.sqrt(np.sum(diffs**2, axis=1))
    return dists

def get_points_at_distance(points, target_distances):
    dists = euclidean_distance(points)
    dists = np.insert(dists, 0, 0) # so that there is a dist for all points in points.
    accumulated_distances = np.cumsum(dists)
    
    interpolated_points = np.empty((len(target_distances), points.shape[1]))
    
    if max(target_distances) > accumulated_distances[-1]:
        raise ValueError("Target distance is larger than the accumulated distance")
    
    index = 0
    inter_idx = 0
    for target_distance in target_distances:
        # Increment index until we have passed the target distance
        while accumulated_distances[index] < target_distance:
            index += 1
        # If we reach this state, then index - 1 is the closest index before going over.
        # Check if the target distance is exactly at a point in the list
        if accumulated_distances[index - 1] == target_distance:
            interpolated_points[inter_idx] = points[index - 1]
            inter_idx += 1
        else:
            # Interpolate between the two nearest points
            p1 = points[index - 1]
            p2 = points[index]
            d1 = accumulated_distances[index - 1]
            d2 = accumulated_distances[index]
            t = (target_distance - d1) / (d2 - d1)
            interpolated_points[inter_idx] = p1 + t * (p2 - p1)
            inter_idx += 1
    return interpolated_points

def validate_and_categorise(zod_frames, id_set):
    # bad data
    too_short_ids = []
    corrupt_ids_oxts = []
    corrupt_ids_image = []

    # categorised data
    turns_right = []
    turns_left = []
    straights = []

    quarter = len(id_set)//4

    for frame_id in id_set:
        # Data validation
        # Test OXTS data and fetch it
        corrupt = False
        try:
            car_points = id_to_car_points(zod_frames, frame_id)
        except AssertionError:
            corrupt_ids_oxts.append(frame_id)
            corrupt = True
        # Test if image is corrupt
        try:
            zod_frames[frame_id].get_image()
        except Exception as e:
            corrupt_ids_image.append(frame_id)
            corrupt = True
        if corrupt:
            continue
        # Test OXTS distance >= 165m
        dists = euclidean_distance(car_points)
        if sum(dists) < 165:
            too_short_ids.append(frame_id)
            continue
        
        # Categorise OXTS [turn-left, turn-right, straight]
        points = id_to_car_points(zod_frames, frame_id)
        target_distances = [5, 10, 15, 20, 25, 30, 35, 40, 50, 60, 70, 80, 95, 110, 125, 145, 165]
        interpolated_points = get_points_at_distance(points, target_distances)
        
        threshold_in_m = 5
        points_required = 5
        points_to_left = interpolated_points[:,1][interpolated_points[:,1] >= threshold_in_m]
        points_to_right = interpolated_points[:,1][interpolated_points[:,1] <= -threshold_in_m]
        if len(points_to_right) >= points_required:
            turns_right.append(frame_id)
        elif len(points_to_left) >= points_required:
            turns_left.append(frame_id)
        else:
            straights.append(frame_id)
        
    return dict(
        too_short_ids=too_short_ids, 
        corrupt_ids_oxts=corrupt_ids_oxts, 
        corrupt_ids_image=corrupt_ids_image,
        turns_right=turns_right,
        turns_left=turns_left,
        straights=straights,)

In [5]:
zod_frames = ZodFrames(dataset_root="/mnt/ZOD", version='full')
training_frames_all = list(zod_frames.get_split(constants.TRAIN))
val_frames_all = list(zod_frames.get_split(constants.VAL))

# let's split this into chunks
n_sets = 8 # number of available cpus
# for train ids
set_length = len(training_frames_all)//n_sets
train_chunks = [training_frames_all[i:i + set_length] for i in range(0, len(training_frames_all), set_length)]
# for val ids
set_length = len(val_frames_all)//n_sets
val_chunks = [val_frames_all[i:i + set_length] for i in range(0, len(val_frames_all), set_length)]

# path = platform.data.config["path"]
# version = platform.data.config["version"]

# zod_frame = ZodFrames(path, version)

Loading infos: 0it [00:00, ?it/s]

In [6]:
ray.init(num_cpus=8)
remote_function = ray.remote(validate_and_categorise)
train_results = ray.get([remote_function.remote(zod_frames, chunk) for chunk in train_chunks])
ray.shutdown()

ray.init(num_cpus=8)
remote_function = ray.remote(validate_and_categorise)
val_results = ray.get([remote_function.remote(zod_frames, chunk) for chunk in val_chunks])
ray.shutdown()

2023-06-27 07:23:49,104	INFO worker.py:1627 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
2023-06-27 07:32:50,676	INFO worker.py:1627 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


In [7]:
from collections import defaultdict
from pathlib import Path

# summarise results
def summarise(summary_dict, results):
    for result in results:
        for key, value in result.items():
            summary_dict[key] += value
    return summary_dict

# write to files
def write_summary_to_file(summary_dict, subset):
    for key, value in summary_dict.items():
        filename = f"{subset}_{key}"
        path = Path(dir_path, filename).with_suffix(".txt")
        with open(path, "w") as file:
            file.write("\n".join(value))
            
dir_path = "../balanced_data"

summary_dict_train = defaultdict(list)
summary_dict_val = defaultdict(list)

write_summary_to_file(summarise(summary_dict_train, train_results), "train")
write_summary_to_file(summarise(summary_dict_val, val_results), "val")

In [8]:
import os
import random
from pathlib import Path

def filename_to_arr(filename):
    with open(Path(dir_path, filename), "r") as file:
        return file.read().splitlines()

dir_path = "../balanced_data/"
filenames = [x for x in os.listdir(dir_path) if x.split("_")[0] in ["train", "val"]]

for filename in filenames:
    content = filename_to_arr(filename)
    print(f"{len(content)} samples: {filename}")

5469 samples: val_straights.txt
13098 samples: train_turns_right.txt
0 samples: train_corrupt_ids_oxts.txt
1944 samples: val_too_short_ids.txt
0 samples: val_corrupt_ids_oxts.txt
0 samples: val_corrupt_ids_image.txt
47426 samples: train_straights.txt
0 samples: train_corrupt_ids_image.txt
1410 samples: val_turns_right.txt
11735 samples: train_turns_left.txt
1200 samples: val_turns_left.txt
17713 samples: train_too_short_ids.txt


In [9]:
categories = [x for x in filenames if "turns" in x or "straights" in x]
val = [x for x in categories if "val" in x]
train = [x for x in categories if "train" in x]
val, train

(['val_straights.txt', 'val_turns_right.txt', 'val_turns_left.txt'],
 ['train_turns_right.txt', 'train_straights.txt', 'train_turns_left.txt'])

In [None]:
def create_balanced_set(filenames):
    subset_str = filenames[0].split("_")[0]
    id_lists = []
    
    # get the ids from the files
    for filename in filenames:
        id_lists.append(filename_to_arr(filename))
    
    # get the lengths of the lists
    lens = [len(x) for x in id_lists]
    # get the size of the smallest set
    smallest_set_size = min(lens)
    
    # build the balanced dataset where each subset has the same size
    balanced_dataset = []
    for id_list in id_lists:
        balanced_dataset += (random.sample(id_list, smallest_set_size))

    # shuffle the data as they are appended category-wise.
    random.shuffle(balanced_dataset)

    # write to file
    if subset_str == "train":
        with open(Path(dir_path, f"balanced_{subset_str}_ids.txt"), "w") as f:
            f.write("\n".join(balanced_dataset))
    else:
        with open(Path(dir_path, f"balanced_{subset_str}_ids.txt"), "w") as f:
            f.write("\n".join(balanced_dataset))

create_balanced_set(val)
create_balanced_set(train)