# This notebook explores the KITTI 2D object detection dataset. #

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution, so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

### Imports  ###

In [None]:
#import standard packages
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
#from IPython.display import Image, display

In [None]:
# Put the project root (the parent of the notebook's working directory) on
# sys.path so that project packages such as `datasets` can be imported.
ROOT = os.path.dirname(os.getcwd())
# Guard the append: re-running this cell must not pile up duplicate entries.
if ROOT not in sys.path:
    sys.path.append(ROOT)
print("Project root: ",ROOT)

In [None]:
#import project-local packages
from datasets.kitti.object2d.parser import Parser

### Display a few examples from the dataset ###

In [None]:
# Locations of the KITTI object2d training images and labels, expressed
# relative to the notebook's working directory.
_training_dir = Path("..") / "datasets" / "kitti" / "object2d" / "training"
dataset_image_path = _training_dir / "image_2"
dataset_label_path = _training_dir / "label_2"

In [None]:
# Show six sample frames from the dataset in a 3x2 grid, titled by file name.
sample_names = ("000000.png", "000010.png", "000020.png", "000030.png", "000040.png", "000050.png")
images = [dataset_image_path / name for name in sample_names]

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 10))
axes = axes.flatten()

# Walk panels and image paths in lockstep instead of indexing the axes array.
for ax, img_path in zip(axes, images):
    img = np.array(Image.open(img_path))
    ax.imshow(img)
    ax.set_title(img_path.name)
    ax.axis('off')
plt.tight_layout()
plt.show()


In [None]:
# Parse two sample label files and print every parsed entry of each.
label_1_file = "../datasets/kitti/object2d/training/label_2/000000.txt"
label_2_file = "../datasets/kitti/object2d/training/label_2/000010.txt"
label_1_parser, label_2_parser = Parser(label_1_file), Parser(label_2_file)
label_1_list, label_2_list = label_1_parser.parse_results, label_2_parser.parse_results

# One (heading, entries) pair per file keeps the two printouts symmetrical.
for heading, parsed in (
    ("label_1 after parsing:\n", label_1_list),
    ("\nlabel_2 after parsing:\n", label_2_list),
):
    print(heading)
    for entry in parsed:
        print(entry)

### Display bounding boxes ###

In [None]:
from visualizations.draw_boxes import Box,Box_kitti_obj2d

# Take the first two sample images ...
img1, img2 = images[0], images[1]

# ... and the first parsed label of each of the two parsed label files.
label1, label2 = label_1_list[0], label_2_list[0]

In [None]:

# Wrap each (image, label) pair in a Box_kitti_obj2d object for drawing.
box1_image, box2_image = (
    Box_kitti_obj2d(image, annotation)
    for image, annotation in ((img1, label1), (img2, label2))
)


In [None]:
# Draw the two annotated samples side by side, titled with the label type.
box_images = [box1_image, box2_image]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(30, 20))
axes = axes.flatten()

for ax, img in zip(axes, box_images):
    ax.imshow(img.draw())
    ax.set_title(img.type+" (single label)", fontsize = 30)
    ax.axis('off')

plt.tight_layout()
plt.show()

### Display DontCare class ###

In [None]:
# Draw the last four entries of label_2_list, one per panel.
# NOTE(review): these are expected to be the DontCare regions of the sample;
# verify against the label file, nothing here filters by type.
dontcare_labels = label_2_list[-4:]
box_images = [Box_kitti_obj2d(img2, dontcare_label) for dontcare_label in dontcare_labels]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(30, 20))
axes = axes.flatten()

for ax, img in zip(axes, box_images):
    ax.imshow(img.draw())
    ax.set_title(img.type, fontsize = 30)
    ax.axis('off')

plt.tight_layout()
plt.show()

### Check dataset attributes


In [None]:
# Notebook-wide list of conclusion strings: the check cells below append one
# entry each, and the final cell prints them all.
conclusions = []

In [None]:
# Collect every image file and every label file of the training split.
data_path_img = Path(ROOT, "datasets", "kitti", "object2d", "training", "image_2")
data_path_lab = Path(ROOT, "datasets", "kitti", "object2d", "training", "label_2")

# Path.iterdir() yields entries in arbitrary order. Sorting makes both lists
# reproducible, so any index-based pairing of the two lists is deterministic.
image_paths = sorted(img for img in data_path_img.iterdir() if img.is_file())
label_paths = sorted(label for label in data_path_lab.iterdir() if label.is_file())

In [None]:
# Check whether every image in the dataset has the same shape.
from collections import Counter

# Tally each (height, width, number of bands) combination that occurs.
shapes = Counter()
for img_path in image_paths:
    with Image.open(img_path) as im:
        shape_key = (im.height, im.width, len(im.getbands()))
    shapes[shape_key] += 1

# Exactly one distinct key means the dataset is uniform.
conc = (
    "All images are with the same shape"
    if len(shapes) == 1
    else "There are images of different shapes"
)
conclusions.append(conc)
print(conc)

In [None]:
# Bar chart of how many images exist per (H, W, C) resolution.
resolutions = list(map(str, shapes))
counts = [shapes[shape_key] for shape_key in shapes]

plt.figure(figsize=(8, 4))
plt.bar(resolutions, counts)
plt.title("Image resolution distribution")
plt.xlabel("Image resolution (H, W, C)")
plt.ylabel("Count")
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()


In [None]:
# Check that all images / all labels share a single file type (suffix).

# Count files per suffix directly. This avoids binding a loop variable named
# `type`, which would shadow the builtin for the rest of the notebook.
img_types = Counter(img_p.suffix for img_p in image_paths)
label_types = Counter(lab_p.suffix for lab_p in label_paths)

if len(img_types)>1:
    conc_img = f"There are images of different types: {img_types} "
else:
    conc_img = f"There is a single image type: {img_types}"
conclusions.append(conc_img)

if len(label_types)>1:
    conc_lbl = f"There are label files of different types: {label_types} "
else:
    conc_lbl = f"There is a single label files type: {label_types}"
conclusions.append(conc_lbl)

print(f"image types: {img_types}")
print(f"label types: {label_types}")


In [None]:
# Check that every image has a label file with the same stem and vice versa.
# Comparing sets of stems keeps the result independent of directory listing
# order and of the two lists having the same length.
image_stems = {path.stem for path in image_paths}
label_stems = {path.stem for path in label_paths}

images_without_label = sorted(image_stems - label_stems)
labels_without_image = sorted(label_stems - image_stems)

for img_name in images_without_label:
    print(f"image name: {img_name} | label name: <missing>")
for label_name in labels_without_image:
    print(f"image name: <missing> | label name: {label_name}")

# Number of samples lacking a counterpart; 0 means every pair is complete.
mismatch_counter = len(images_without_label) + len(labels_without_image)
if not mismatch_counter:
    conc = "All image-label couples were found with no errors."
else:
    conc = "There are image-label couples errors (missing image or label)"

print(conc)
conclusions.append(conc)

In [None]:
# Check that every bounding box lies inside the image it belongs to.
bbox_dimension_flag = 0

# Each label file must be checked against its own image, so look the image up
# by file stem instead of relying on a variable left over from another cell.
image_by_stem = {path.stem: path for path in image_paths}

for label_path in label_paths:
    image_file = image_by_stem.get(label_path.stem)
    if image_file is None:
        print(f"No image found for label file {label_path.name}; bbox check skipped")
        continue
    with Image.open(image_file) as im:
        img_h,img_w = im.height, im.width
    for label in Parser(label_path).label_list:
        lx,ly,rx,ry = label.bbox.lx, label.bbox.ly, label.bbox.rx, label.bbox.ry
        # Out of range means past the right/bottom edge or before the origin.
        beyond_max = lx > img_w or rx > img_w or ly > img_h or ry > img_h
        below_zero = min(lx, ly, rx, ry) < 0
        if beyond_max or below_zero:
            print("The following bbox dimension exceeds image dimension")
            print(f"image dim: {img_w,img_h} | {label.bbox.start_point} -> {label.bbox.end_point} ")
            bbox_dimension_flag = 1

if bbox_dimension_flag == 0:
    conc = "All bounding boxes are within image range"
else:
    conc = "There are bounding boxes in labels that exceed image dimension."
conclusions.append(conc)
print(conc)

In [None]:
# Count the labelled objects of each class across all label files.
class_counter = Counter(
    label.type
    for label_path in label_paths
    for label in Parser(label_path).label_list
)



In [None]:
# Print the raw per-class counts, then show them as a bar chart.
print(class_counter)

class_resolutions = list(map(str, class_counter))
class_counts = [class_counter[cls_key] for cls_key in class_counter]

plt.figure(figsize=(8, 4))
plt.bar(class_resolutions, class_counts)
plt.title("Classes distribution")
plt.xlabel("Class type")
plt.ylabel("Count")
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()

In [None]:
# Summarise the class balance: class count plus the most and least frequent.
n_classes = len(class_counter)
largest = class_counter.most_common(1)
smallest = class_counter.most_common()[-1]
conc = f"There are {n_classes} classes on the dataset with max class {largest} and min class {smallest}"
conclusions.append(conc)

### Check bounding box statistics

In [None]:
# Gather the (height, width) of every bounding box, grouped by class type.
# Both values are rounded to two decimals before being stored.
hw_dict = {}
for label_path in label_paths:
    for label in Parser(label_path).label_list:
        height = round(label.bbox.ry - label.bbox.ly, 2)
        width = round(label.bbox.rx - label.bbox.lx, 2)
        hw_dict.setdefault(label.type, []).append((height, width))

# Convert each per-class list of pairs into an (N, 2) float32 array.
hw_dict = {
    cls_name: np.asarray(pairs, dtype=np.float32)
    for cls_name, pairs in hw_dict.items()
}

In [None]:
# Show which class types have collected bounding-box sizes.
print(hw_dict.keys())

In [None]:
# Plot the per-class bounding-box height/width distribution, one panel per class.
from visualizations.bbox_statistics import plot_bbox_hw_distribution

# 3x3 grid with shared x and y axes.
# NOTE: the grid holds at most nine classes; a tenth would raise IndexError.
fig, axes = plt.subplots(3, 3,figsize=(20, 20),sharex=True,sharey=True)

axes = axes.flatten()

for idx, (cls_name, hw) in enumerate(hw_dict.items()):
    plot_bbox_hw_distribution(
        hw,
        method='hexbin',
        log_scale=True,
        ax=axes[idx],
        title=f"{cls_name} (n={len(hw)})"
    )

# The title previously contained a mis-encoded multiplication sign.
fig.suptitle(
    "Bounding Box Height × Width Distribution per Class",
    fontsize=16
)

# NOTE(review): the labels say "normalized", but the values passed in are
# coordinate differences taken straight from the labels. Confirm whether
# plot_bbox_hw_distribution normalizes them.
fig.supxlabel("Width (normalized)")
fig.supylabel("Height (normalized)")

plt.tight_layout()
plt.show()

In [None]:
# Record the takeaway from the bounding-box size statistics.
conc = "Consider bounding box area distribution as a factor with small effect to be considered."
conclusions.append(conc)

### Occlusions

In [None]:
# Find every occlusion level in the dataset and keep a few example labels
# per level for display.
occlusion_levels = set()
occlusion_samples = {}
occlusion_num = 4  # maximum number of examples kept per occlusion level

for label_path in label_paths:
    # Take at most one example per label file so the examples come from
    # different images, but keep scanning so that no level is missed.
    took_sample = False
    for label in Parser(label_path).label_list:
        level = label.occluded
        occlusion_levels.add(level)
        # setdefault makes sure the first label of a new level is a candidate too.
        bucket = occlusion_samples.setdefault(level, [])
        if not took_sample and len(bucket) < occlusion_num:
            bucket.append([label_path.stem, label.label_dict])
            took_sample = True

print("Occlusion levels: ",occlusion_levels)
print("Occlusion samples:\n",occlusion_samples)

In [None]:
# Replace each sample's file stem with the full path of its image.
for samples in occlusion_samples.values():
    for sample in samples:
        # Only convert entries that are still bare stems. This keeps the cell
        # safe to re-run once the entries have already become paths.
        if isinstance(sample[0], str):
            sample[0] = dataset_image_path / (sample[0] + ".png")

print(occlusion_samples)

In [None]:
# Display the collected examples, one figure per occlusion level.
for lvl in occlusion_levels:
    box_images = [Box_kitti_obj2d(sample[0], sample[1]) for sample in occlusion_samples[lvl]]
    if not box_images:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No samples collected for occlusion level {lvl}")
        continue

    # Round up so that an odd number of samples still gets enough panels.
    rows = (len(box_images) + 1) // 2
    fig, axes = plt.subplots(rows, 2, figsize=(30, 20))
    axes = axes.flatten()
    fig.suptitle(f"Occlusion level: {lvl}", fontsize=40)
    for idx, ax in enumerate(axes):
        if idx < len(box_images):
            img = box_images[idx]
            ax.imshow(img.draw())
            ax.set_title(img.type, fontsize = 30)
        # Also blanks the spare panel when the sample count is odd.
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Check whether occluded == -1 occurs only on labels of type "DontCare".
# any() stops at the first counter-example instead of parsing every file.
dontcare_only = not any(
    label.occluded == -1 and label.type != "DontCare"
    for label_path in label_paths
    for label in Parser(label_path).label_list
)

if dontcare_only:
    conc = "Only DontCare class can have an occluded level = -1"
else:
    conc = "Classes other than DontCare can also have an occluded level = -1"
# Print the outcome in both cases, like the other check cells do.
print(conc)
conclusions.append(conc)


### Truncation investigation

In [None]:
# Tally how often each truncation value occurs over all labels.
truncation_counter = Counter(
    label.truncated
    for label_path in label_paths
    for label in Parser(label_path).label_list
)

most_common_truncation = truncation_counter.most_common(4)
conc = f"The most common truncations: {most_common_truncation}"
conclusions.append(conc)
print(conc)

# Same counts again, ordered by truncation value.
trunc_sorted = sorted(truncation_counter.items())
truncation = dict(trunc_sorted)
print(truncation)
# TODO: check correlation between truncation and occlusions

In [None]:
# Check whether truncated == -1 occurs only on DontCare labels; the type of
# every label that breaks the rule is printed.
dontcare_only_truncated = True
for label_path in label_paths:
    for label in Parser(label_path).label_list:
        if label.truncated != -1.0 or label.type.lower() == "dontcare":
            continue
        dontcare_only_truncated = False
        print(label.type)

if dontcare_only_truncated:
    conc = "Only dontcare class gave truncated = -1"
    conclusions.append(conc)
    print(conc)
    # Widen the list to five entries.
    # NOTE(review): presumably so four levels remain after -1 is filtered out
    # by the sampling cell; confirm.
    most_common_truncation = truncation_counter.most_common(5)

In [None]:
# Pick truncation examples: the first label seen for each common truncation
# level, plus up to four labels whose truncation lies strictly between 0.2
# and 0.8.
common_truncation_levels = [level for level, _ in most_common_truncation if level != -1]
used = []
unique_samples = 0
truncation_samples = []

for label_path in label_paths:
    for label in Parser(label_path).label_list:
        level = label.truncated
        if level in common_truncation_levels and level not in used:
            truncation_samples.append([label_path, label])
            used.append(level)
        elif unique_samples < 4 and 0.2 < level < 0.8:
            truncation_samples.append([label_path, label])
            unique_samples += 1

In [None]:
# Pair every sampled label with the image of the same stem and draw its box.
truncation_pairs = [
    [dataset_image_path / (label_path.stem + ".png"), label]
    for label_path, label in truncation_samples
]

box_truncated_images = [Box_kitti_obj2d(img, label.label_dict) for img, label in truncation_pairs]

# Display
if box_truncated_images:
    # Round up so that an odd number of samples still gets enough panels.
    rows = (len(box_truncated_images) + 1) // 2
    fig, axes = plt.subplots(rows, 2, figsize=(30, 20))
    axes = axes.flatten()
    fig.suptitle("Truncation samples", fontsize=40)
    for idx, ax in enumerate(axes):
        if idx < len(box_truncated_images):
            img = box_truncated_images[idx]
            ax.imshow(img.draw())
            ax.set_title(f"{img.type}, truncation = {truncation_pairs[idx][1].truncated}", fontsize = 30)
        # Also blanks the spare panel when the sample count is odd.
        ax.axis('off')

    plt.tight_layout()
    plt.show()
else:
    # plt.subplots cannot build a grid with zero rows.
    print("No truncation samples to display")


In [None]:
# Print the header followed by every recorded conclusion, one per line.
print("Conclusions:", *conclusions, sep="\n")