In [1]:
!pip install -q evaluate bert_score

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
from PIL import Image
import torch

from transformers.modeling_outputs import BaseModelOutput
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from huggingface_hub import PyTorchModelHubMixin

# Load Test Predictions

In [14]:
from datasets import load_dataset

# Each Hub dataset stores one model's test-set predictions in a single "train" split.
raw_results = load_dataset(path="MehdiJmlkh/SmolVLM-Results", split="train")
fine_tune_results = load_dataset(path="MehdiJmlkh/SmolVLM-FT-Results", split="train")
smol_driver_results = load_dataset(path="MehdiJmlkh/SmolDriver-Results", split="train")

# Load Metrics

In [34]:
import evaluate
# Metric objects used by the Evaluate class below.
bleu = evaluate.load("bleu")            # single aggregate BLEU score
bertscore = evaluate.load("bertscore")  # per-example precision / recall / F1

In [38]:
class Evaluate:
    def __init__(self, results) -> None:
        self.preds = list(results["prediction"])
        self.refs = list(results["answer"])
        self.questions = list(results["question"])

        self.bleu_result = bleu.compute(predictions=self.preds, references=self.refs)
        self.bert_result = bertscore.compute(predictions=self.preds, references=self.refs, lang="en")

    def get_avg_scores(self):
        scores = {
              "BLEU": self.bleu_result["bleu"],
              "BERTScore_P": sum(self.bert_result["precision"]) / len(self.bert_result["precision"]),
              "BERTScore_R": sum(self.bert_result["recall"]) / len(self.bert_result["recall"]),
              "BERTScore_F1": sum(self.bert_result["f1"]) / len(self.bert_result["f1"])
        }
        return scores

    def print_lowest_bert(self, score_type, k=5):
        lowest_scores, lowest_indexes = self.__get_lowest_k(self.bert_result[score_type], k)

        for idx, score in zip(lowest_indexes, lowest_scores):
            print(f"{score_type} score:", score)
            print("Index:" ,idx)
            print("Question:", self.questions[idx])
            print(f"Label:", self.refs[idx])
            print(f"Answer:", self.preds[idx])
            print("-" * 50)

    def __get_lowest_k(self, scores, k=5):
        scores = np.array(scores)
        lowest_indices = scores.argsort()[:k]
        lowest_scores = scores[lowest_indices]
        return list(lowest_scores), list(lowest_indices)

In [12]:
# Scores for the MehdiJmlkh/SmolVLM-Results predictions.
raw_eval = Evaluate(results=raw_results)
raw_eval.get_avg_scores()

{'BLEU': 0.08854275399540497,
 'BERTScore_P': 0.8701746625111515,
 'BERTScore_R': 0.9164209641012034,
 'BERTScore_F1': 0.8921452242628972}

In [15]:
# Scores for the MehdiJmlkh/SmolVLM-FT-Results predictions.
fine_tune_eval = Evaluate(results=fine_tune_results)
fine_tune_eval.get_avg_scores()

{'BLEU': 0.185222515450203,
 'BERTScore_P': 0.9430369230141317,
 'BERTScore_R': 0.9381485501626381,
 'BERTScore_F1': 0.9402770299212377}

In [39]:
# Scores for the MehdiJmlkh/SmolDriver-Results predictions.
smol_driver_eval = Evaluate(results=smol_driver_results)
smol_driver_eval.get_avg_scores()

{'BLEU': 0.3536571369006748,
 'BERTScore_P': 0.971374602828707,
 'BERTScore_R': 0.9586761032728325,
 'BERTScore_F1': 0.9647547295667175}

In [40]:
# Five SmolDriver examples with the lowest BERTScore precision.
smol_driver_eval.print_lowest_bert(score_type="precision", k=5)

precision score: 0.7873366475105286
Index: 719
Question: What is the traffic signal that the ego vehicle should pay attention to?
Label: None.
Answer: There is one pedestrian to the ego vehicle to the front left of the ego vehicle to the front left of the ego vehicle to the front left of the ego vehicle
--------------------------------------------------
precision score: 0.8320485949516296
Index: 698
Question: In this scenario, what are dangerous actions to take for the ego vehicle?
Label: Back up, brake suddenly.
Answer: Accelerate and go ahead, turn left, and turn right are dangerous actions to take for the ego vehicle in this scenario.
--------------------------------------------------
precision score: 0.8441762924194336
Index: 718
Question: Are there motorcycles without riders to the front left of the ego car?
Label: No.
Answer: There are many pedestrians to the front left of the ego car.
--------------------------------------------------
precision score: 0.8468424081802368
Index: 7

In [41]:
# Five SmolDriver examples with the lowest BERTScore recall.
smol_driver_eval.print_lowest_bert(score_type="recall", k=5)

recall score: 0.7433418035507202
Index: 452
Question: Identify all the traffic elements in the front view, categorize them, determine their status, and predict the bounding box around each one. The output should be a list formatted as (c, s, x1, y1, x2, y2), where c represents the category, s denotes the status, and x1, y1, x2, y2 are the offsets of the top-left and bottom-right corners of the box relative to the center point.
Label: There are three traffic elements in the front view. The information of these traffic elements is [(traffic light, unknown, 881.77, 462.69, 889.48, 477.31), (traffic light, red, 692.64, 442.39, 699.31, 459.03), (traffic light, red, 684.74, 406.2, 705.7, 414.91)].
Answer: None.
--------------------------------------------------
recall score: 0.7548848390579224
Index: 661
Question: Identify all the traffic elements in the front view, categorize them, determine their status, and predict the bounding box around each one. The output should be a list formatted as

In [42]:
# Five SmolDriver examples with the lowest BERTScore F1.
smol_driver_eval.print_lowest_bert(score_type="f1", k=5)

f1 score: 0.8231935501098633
Index: 661
Question: Identify all the traffic elements in the front view, categorize them, determine their status, and predict the bounding box around each one. The output should be a list formatted as (c, s, x1, y1, x2, y2), where c represents the category, s denotes the status, and x1, y1, x2, y2 are the offsets of the top-left and bottom-right corners of the box relative to the center point.
Label: There are many traffic elements in the front view. The information of these traffic elements is [(traffic light, red, 1128.53, 360.19, 1141.75, 394.88), (traffic light, unknown, 1128.4, 407.96, 1142.1, 422.68), (traffic light, red, 691.54, 372.1, 705.42, 407.05), (traffic light, red, 523.53, 364.09, 539.36, 406.31), (traffic light, unknown, 546.41, 365.55, 562.14, 402.49), (traffic light, unknown, 881.31, 444.06, 887.69, 457.19), (traffic light, red, 866.6, 431.5, 871.81, 442.97), (traffic light, unknown, 800.74, 443.22, 806.9, 449.49), (traffic light, unknown