In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import Kitti
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split


import random
%matplotlib inline

In [None]:
# One shared seed for every RNG this notebook imports (PyTorch, NumPy, stdlib random)
seed_number = 4
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed_number)

In [None]:
# Device configuration: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Load the KITTI-format datasets from local folders (download=False, so the
# files must already exist under each root)
kitti_real_dataset = Kitti(root='data/data_real/', transform=ToTensor(), download=False)
kitti_syn_dataset = Kitti(root='data/data_syn_DALLE3/', transform=ToTensor(), download=False)

# Index of the sample that is shown from both datasets
image_index = 1

# Note that label_real and label_syn are the same
image_real, label_real = kitti_real_dataset[image_index]
image_syn, label_syn = kitti_syn_dataset[image_index]

# Box/label colour per object type (keys are lower-case); any other type is
# drawn in gray with its raw type string as the label
GT_TYPE_COLORS = {
    'pedestrian': 'red',
    'car': 'green',
    'van': 'blue',
    'truck': 'orange',
    'cyclist': 'orange',
}


def plot_ground_truth(image, annotations, title=None):
    """Show ``image`` with its ground-truth boxes and type labels drawn on top.

    Args:
        image: image tensor; it is permuted with (1, 2, 0) before ``imshow``,
            i.e. it is expected channels-first.
        annotations: iterable of dicts holding a "type" entry and a "bbox"
            sequence whose first four values are x1, y1, x2, y2.
        title: optional axes title (no title is set when None).

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image.permute(1, 2, 0))
    if title is not None:
        ax.set_title(title)

    for annotation in annotations:
        x1, y1, x2, y2 = annotation["bbox"][:4]
        object_type = annotation["type"]
        # Types are matched case-insensitively for both datasets; previously
        # only the synthetic-image loop accepted the lower-case "car" spelling.
        type_key = str(object_type).lower()
        if type_key in GT_TYPE_COLORS:
            color, text = GT_TYPE_COLORS[type_key], type_key
        else:
            # For the other classes, we use gray
            color, text = 'gray', object_type

        ax.text((x1 + x2) / 2, y1, text, ha='center', va='bottom',
                transform=ax.transData, fontsize=8, color=color)
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, edgecolor=color, linewidth=1))
    return fig, ax


# Display the real image with its bounding boxes
fig_real, ax = plot_ground_truth(image_real, label_real, title='Real image: ground truth')
plt.show()

# Display the synthetic image with its bounding boxes
fig_syn, ax = plot_ground_truth(image_syn, label_syn, title='Synthetic image: ground truth')
plt.show()


In [None]:
# Class names looked up by the integer label a detection carries
# (name = COCO_INSTANCE_CATEGORY_NAMES[label]). Ten entries per row; the
# comment on each row gives the index range it covers.
# 'N/A' entries are placeholders - presumably label ids the model never
# emits; confirm against the torchvision COCO label map.
COCO_INSTANCE_CATEGORY_NAMES = [
    # 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]

In [None]:
import torch
import torchvision
from torchvision.datasets import Kitti
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

CONFIDENCE_THRESHOLD = 0.5
IOU_THRESHOLD = 0.5


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained model (note that this is not really ideal, as the model has more classes based on COCO)
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; the original call is kept so the cell still runs on the installed
# version - switch once the minimum torchvision version is known.
model = torchvision.models.detection.ssd300_vgg16(pretrained=True)
model.to(device)
model.eval()


def detect_objects(detector, image):
    """Run ``detector`` on a single image and return its prediction.

    The image is moved to ``device`` and passed as a one-element batch under
    ``torch.no_grad()``; the first (only) element of the output is returned.
    """
    with torch.no_grad():
        return detector([image.to(device)])[0]


def plot_detections(image, prediction, score_threshold=None, title=None):
    """Show ``image`` with every detection scoring above the threshold.

    Args:
        image: image tensor; moved to the CPU and permuted with (1, 2, 0)
            before ``imshow``.
        prediction: mapping with 'boxes', 'scores' and 'labels' tensors; each
            box unpacks to x1, y1, x2, y2 and each label indexes
            ``COCO_INSTANCE_CATEGORY_NAMES``.
        score_threshold: detections are drawn only when score > threshold.
            Defaults to ``CONFIDENCE_THRESHOLD`` (read at call time).
        title: optional axes title.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.

    Side effects:
        Prints the class name of every detection that is drawn.
    """
    if score_threshold is None:
        score_threshold = CONFIDENCE_THRESHOLD

    boxes = prediction['boxes'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()

    fig, ax = plt.subplots(1)
    ax.imshow(image.cpu().permute(1, 2, 0))
    if title is not None:
        ax.set_title(title)

    for box, score, label in zip(boxes, scores, labels):
        if score > score_threshold:
            x1, y1, x2, y2 = box
            class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
            ax.text((x1 + x2) / 2, y1, class_name, ha='center', va='bottom',
                    transform=ax.transData, fontsize=8, color='red')
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       fill=False, edgecolor='red', linewidth=1))
            print(class_name)
    return fig, ax


# Real image: predict and display. Both images now use CONFIDENCE_THRESHOLD;
# the real-image loop previously hard-coded 0.5.
image_real, _ = kitti_real_dataset[image_index]
prediction = detect_objects(model, image_real)
fig, ax = plot_detections(image_real, prediction, title='Real image: SSD300 detections')
plt.show()


# Synthetic image: predict and display
image_syn, _ = kitti_syn_dataset[image_index]
prediction = detect_objects(model, image_syn)
fig, ax = plot_detections(image_syn, prediction, title='Synthetic image: SSD300 detections')
plt.show()

# Keep the synthetic-image detections available as NumPy arrays under the
# same module-level names the cell has always left behind
boxes = prediction['boxes'].cpu().numpy()
scores = prediction['scores'].cpu().numpy()
labels = prediction['labels'].cpu().numpy()

In [None]:
# Bare expression: the notebook renders the repr of `model` as this cell's output
model

In [None]:
class BoundingBox:
    """Plain container for a rectangle given by two corner points and a category.

    Attributes:
        x1, y1: coordinates of the first corner.
        x2, y2: coordinates of the second corner.
        cat: category tag of the box; defaults to "dc"
            (presumably "don't care" - confirm with the author).
    """

    def __init__(self, x1, y1, x2, y2, cat = "dc"):
        self.x1, self.y1 = x1, y1
        self.x2, self.y2 = x2, y2
        self.cat = cat

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Args:
        box1, box2: objects exposing ``x1``, ``y1``, ``x2``, ``y2`` attributes,
            where (x1, y1) is the smaller-coordinate corner and (x2, y2) the
            larger-coordinate corner.

    Returns:
        intersection area / union area. When the union has no area (for
        example both boxes are degenerate points or lines) the result is 0.0
        instead of a division-by-zero error.
    """
    # Determine the coordinates of the intersection rectangle
    x_left = max(box1.x1, box2.x1)
    y_top = max(box1.y1, box2.y1)
    x_right = min(box1.x2, box2.x2)
    y_bottom = min(box1.y2, box2.y2)

    # Clamp at 0 so disjoint boxes contribute no (negative) overlap
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)

    # Calculate areas of the input boxes
    box1_area = (box1.x2 - box1.x1) * (box1.y2 - box1.y1)
    box2_area = (box2.x2 - box2.x1) * (box2.y2 - box2.y1)

    # Calculate union area
    union_area = box1_area + box2_area - intersection_area

    # Zero-area union: nothing can overlap, and dividing would fail
    if union_area <= 0:
        return 0.0

    return intersection_area / union_area

def count_undetected_objects(ground_truths, predictions, iou_threshold, size_gt = 0):
    """Count ground-truth boxes that no prediction matches, split by box size.

    A ground-truth box counts as detected when at least one prediction has
    ``calculate_iou(gt, pred) >= iou_threshold``.

    Args:
        ground_truths: iterable of boxes with ``x1``, ``y1``, ``x2``, ``y2``.
        predictions: iterable of predicted boxes (iterated once per
            ground-truth box, so it must be re-iterable, e.g. a list).
        iou_threshold: minimum IoU for a prediction to match.
        size_gt: minimum area for a box to be treated as "large". Either a
            plain number, or a mapping whose ``"dc"`` entry holds the number.
            The default 0 previously raised TypeError because it was always
            subscripted with ``"dc"``.

    Returns:
        Tuple ``(undetected_count, count_smaller, undetected_count_smaller)``:
        undetected large boxes, number of boxes below the area threshold, and
        how many of those small boxes were undetected.
    """
    from numbers import Number

    size_is_number = isinstance(size_gt, Number)

    undetected_count = 0
    count_smaller = 0
    undetected_count_smaller = 0

    for gt_box in ground_truths:
        # Resolved per box (not hoisted) so an empty ground-truth list never
        # touches size_gt, exactly as before
        min_area = size_gt if size_is_number else size_gt["dc"]
        area = (gt_box.x2 - gt_box.x1) * (gt_box.y2 - gt_box.y1)

        detected = any(
            calculate_iou(gt_box, pred_box) >= iou_threshold
            for pred_box in predictions
        )

        if area >= min_area:
            if not detected:
                undetected_count += 1
        else:
            count_smaller += 1
            if not detected:
                undetected_count_smaller += 1

    return undetected_count, count_smaller, undetected_count_smaller

