## Import delle librerie

In [19]:
import json
import os
import torch
import random
import xml.etree.ElementTree as ET
import torchvision.transforms.functional as FT

import torch
from tqdm import tqdm
from pprint import PrettyPrinter

import json
import random
from collections import defaultdict

## Path

In [32]:
# path del dataset
base_dict = '/kaggle/input/our-xview-dataset'

# path della cartella contenente le immagini
img_dict = '/kaggle/input/our-xview-dataset/images'

# path dei file .txt da utilizzare per prelevare rispettivamente le immagini per il train, la validation e il test
# MANCANO -> servono per provare il codice
train_img_path = '/kaggle/input/our-xview-dataset/YOLO_cfg/train.txt' # file contenete i path delle immagini del dataset di train
val_img_path = '/kaggle/input/our-xview-dataset/YOLO_cfg/val.txt'
test_img_path = '/kaggle/input/our-xview-dataset/YOLO_cfg/test.txt'

# path di output
output_folder = '/kaggle/working/'

# path contenente le annotazioni in formato .json
coco_json_path = os.path.join(base_dict, 'COCO_annotations_new.json') 
new_coco_json_path = os.path.join(output_folder, 'mod_COCO_annotations.json') 

# path file per il training
train_image = os.path.join(output_folder, 'TRAIN_images.json')
train_bbox = os.path.join(output_folder, 'TRAIN_objects.json')

# path file per la validation
val_image = os.path.join(output_folder, 'VAL_images.json')
val_bbox = os.path.join(output_folder, 'VAL_objects.json')

# path file per il test
test_image = os.path.join(output_folder, 'TEST_images.json')
test_bbox = os.path.join(output_folder, 'TEST_objects.json')

checkpoint_path = './checkpoint_ssd300.pth.tar'

In [33]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# COCO Preprocessing

In [22]:
def process_custom_coco_json(input_path, output_path):
    """
    Funzione per processare un JSON COCO in formato personalizzato.
    """
    # Leggi il JSON dal file di input
    with open(input_path, 'r') as f:
        data = json.load(f)
 
    # Ottieni e correggi il formato delle categorie
    raw_categories = data.get('categories', [])
    categories = []
 
    for category in tqdm(raw_categories, desc="Processing Categories"):
        for id_str, name in category.items():
            try:
                categories.append({"id": int(id_str), "name": name})
            except ValueError:
                print(f"Errore nel parsing della categoria: {category}")
 
    # Trova la categoria "Aircraft" con ID 0
    aircraft_category = next((cat for cat in categories if cat['id'] == 0 and cat['name'] == "Aircraft"), None)
    if aircraft_category:
        aircraft_category['id'] = 11  # Cambia l'ID della categoria "Aircraft" a 11
 
    # Aggiungi la categoria "background" con ID 0 se non esiste
    if not any(cat['id'] == 0 for cat in categories):
        categories.append({"id": 0, "name": "background"})
 
    # Preprocessa le annotazioni in un dizionario per immagini
    image_annotations_dict = {}
    for annotation in tqdm(data.get('annotations', []), desc="Building Image Annotations Dictionary"):
        image_id = annotation['image_id']
        if image_id not in image_annotations_dict:
            image_annotations_dict[image_id] = []
        image_annotations_dict[image_id].append(annotation)
 
    # Lista di nuove annotazioni da aggiungere per immagini senza bbox
    new_annotations = []
 
    # Elenco di annotazioni da rimuovere
    annotations_to_remove = []
 
    for annotation in tqdm(data.get('annotations', []), desc="Processing Annotations"):
        if annotation['category_id'] == 0:  # Se è Aircraft
            annotation['category_id'] = 11
        # Converte il formato del bbox
        if isinstance(annotation['bbox'], str):
            annotation['bbox'] = json.loads(annotation['bbox'])
        x, y, width, height = annotation['bbox']
        xmin = x
        xmax = x + width
        ymin = y
        ymax = y + height
        # Verifica che xmin < xmax e ymin < ymax
        if xmin >= xmax or ymin >= ymax:
            annotations_to_remove.append(annotation['id'])
        else:
            annotation['bbox'] = [xmin, ymin, xmax, ymax]
 
    # Rimuovi le annotazioni non valide
    data['annotations'] = [ann for ann in data['annotations'] if ann['id'] not in annotations_to_remove]
 
    # Verifica se ci sono immagini senza annotazioni (usando il dizionario delle annotazioni)
    for image in tqdm(data.get('images', []), desc="Processing Images"):
        if image['id'] not in image_annotations_dict:  # Se l'immagine non ha annotazioni
            # Aggiungi la categoria "background"
            new_annotation = {
                'id': len(data['annotations']) + len(new_annotations),
                'image_id': image['id'],
                'category_id': 0,  # Categoria background con ID 0
                'area': image['width'] * image['height'],
                'bbox': [0.0, image['width'], 0.0, image['height']],  # Background con bbox che copre tutta l'immagine
                'iscrowd': 0
            }
            new_annotations.append(new_annotation)
 
    # Aggiungi le nuove annotazioni al JSON originale
    data['annotations'].extend(new_annotations)
 
    # Aggiorna le categorie nel JSON
    data['categories'] = categories
 
    # Scrivi il JSON modificato nel file di output
    with open(output_path, 'w') as f:
        json.dump(data, f, indent=4)

In [23]:
process_custom_coco_json(coco_json_path, new_coco_json_path)

Processing Categories: 100%|██████████| 11/11 [00:00<00:00, 108558.46it/s]
Building Image Annotations Dictionary: 100%|██████████| 670213/670213 [00:00<00:00, 946920.69it/s] 
Processing Annotations: 100%|██████████| 670213/670213 [00:03<00:00, 204842.94it/s]
Processing Images: 100%|██████████| 45891/45891 [00:00<00:00, 569362.03it/s]


# Splitting

In [24]:
def split_coco_and_check(coco_file, output_path, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    # Carica il file COCO
    with open(coco_file, 'r') as f:
        coco_data = json.load(f)

    images = coco_data['images']
    annotations = coco_data['annotations']
    categories = coco_data['categories']

    total_images = len(images)
    total_annotations = len(annotations)

    # Shuffle delle immagini per garantire casualità
    random.shuffle(images)

    # Calcola i numeri di immagini per train, val e test
    train_count = int(train_ratio * total_images)
    val_count = int(val_ratio * total_images)
    test_count = total_images - train_count - val_count

    train_images = images[:train_count]
    val_images = images[train_count:train_count + val_count]
    test_images = images[train_count + val_count:]

    # Crea set di ID immagini
    train_ids = {img['id'] for img in train_images}
    val_ids = {img['id'] for img in val_images}
    test_ids = {img['id'] for img in test_images}

    # Divide le annotazioni
    train_annotations = [ann for ann in annotations if ann['image_id'] in train_ids]
    val_annotations = [ann for ann in annotations if ann['image_id'] in val_ids]
    test_annotations = [ann for ann in annotations if ann['image_id'] in test_ids]

    train_bbox_count = len(train_annotations)
    val_bbox_count = len(val_annotations)
    test_bbox_count = len(test_annotations)

    # Salva i file di output
    def save_split(file_name, images_split):
        with open(file_name, 'w') as f:
            for img in images_split:
                f.write(f"{output_path}/{img['file_name']}\n")

    save_split('train.txt', train_images)
    save_split('val.txt', val_images)
    save_split('test.txt', test_images)

    # Controlla le proporzioni
    check_split_proportions(
        total_images, total_annotations,
        len(train_images), len(val_images), len(test_images),
        train_bbox_count, val_bbox_count, test_bbox_count,
        train_ratio, val_ratio, test_ratio,
        train_annotations, val_annotations, test_annotations,
        categories
    )

def check_split_proportions(total_images, total_annotations, train_count, val_count, test_count, 
                            train_bbox_count, val_bbox_count, test_bbox_count, 
                            train_ratio, val_ratio, test_ratio, 
                            train_annotations, val_annotations, test_annotations, categories):
    # Percentuali per immagini
    train_image_percentage = (train_count / total_images) * 100
    val_image_percentage = (val_count / total_images) * 100
    test_image_percentage = (test_count / total_images) * 100

    # Percentuali per bbox
    train_bbox_percentage = (train_bbox_count / total_annotations) * 100
    val_bbox_percentage = (val_bbox_count / total_annotations) * 100
    test_bbox_percentage = (test_bbox_count / total_annotations) * 100

    print(f"Totale immagini: {total_images}")
    print(f"Totale annotazioni (bbox): {total_annotations}")
    print(f"Train: {train_count} immagini ({train_image_percentage:.2f}%) ({train_bbox_count} bbox) ({train_bbox_percentage:.2f}%)")
    print(f"Val: {val_count} immagini ({val_image_percentage:.2f}%) ({val_bbox_count} bbox) ({val_bbox_percentage:.2f}%)")
    print(f"Test: {test_count} immagini ({test_image_percentage:.2f}%) ({test_bbox_count} bbox) ({test_bbox_percentage:.2f}%)")

    # Calcola il numero di annotazioni per categoria nei vari set
    category_count_train = defaultdict(int)
    category_count_val = defaultdict(int)
    category_count_test = defaultdict(int)

    for annotation in train_annotations:
        category_count_train[annotation['category_id']] += 1
    for annotation in val_annotations:
        category_count_val[annotation['category_id']] += 1
    for annotation in test_annotations:
        category_count_test[annotation['category_id']] += 1

    # Stampa le proporzioni per categoria
    print("\nProporzioni per categoria:")
    for category in categories:
        category_id = category['id']
        category_name = category['name']

        # Conta il numero di annotazioni per categoria in ogni set
        train_cat_count = category_count_train.get(category_id, 0)
        val_cat_count = category_count_val.get(category_id, 0)
        test_cat_count = category_count_test.get(category_id, 0)

        # Calcola la percentuale di annotazioni per categoria
        total_cat_annotations = train_cat_count + val_cat_count + test_cat_count
        if total_cat_annotations > 0:
            train_cat_percentage = (train_cat_count / total_cat_annotations) * 100
            val_cat_percentage = (val_cat_count / total_cat_annotations) * 100
            test_cat_percentage = (test_cat_count / total_cat_annotations) * 100
        else:
            train_cat_percentage = val_cat_percentage = test_cat_percentage = 0.0

        print(f"{category_name}:")
        print(f"  Train: {train_cat_count} annotazioni ({train_cat_percentage:.2f}%)")
        print(f"  Val: {val_cat_count} annotazioni ({val_cat_percentage:.2f}%)")
        print(f"  Test: {test_cat_count} annotazioni ({test_cat_percentage:.2f}%)")

In [25]:
split_coco_and_check(new_coco_json_path, img_dict, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)

Totale immagini: 45891
Totale annotazioni (bbox): 683897
Train: 27534 immagini (60.00%) (415411 bbox) (60.74%)
Val: 9178 immagini (20.00%) (137906 bbox) (20.16%)
Test: 9179 immagini (20.00%) (132014 bbox) (19.30%)

Proporzioni per categoria:
Aircraft:
  Train: 1038 annotazioni (60.77%)
  Val: 311 annotazioni (18.21%)
  Test: 359 annotazioni (21.02%)
Passenger Vehicle:
  Train: 138548 annotazioni (61.44%)
  Val: 44053 annotazioni (19.54%)
  Test: 42902 annotazioni (19.03%)
Truck:
  Train: 21198 annotazioni (61.47%)
  Val: 6679 annotazioni (19.37%)
  Test: 6608 annotazioni (19.16%)
Railway Vehicle:
  Train: 2696 annotazioni (63.69%)
  Val: 621 annotazioni (14.67%)
  Test: 916 annotazioni (21.64%)
Maritime Vessel:
  Train: 3853 annotazioni (60.88%)
  Val: 1417 annotazioni (22.39%)
  Test: 1059 annotazioni (16.73%)
Engineering Vehicle:
  Train: 3218 annotazioni (58.73%)
  Val: 1226 annotazioni (22.38%)
  Test: 1035 annotazioni (18.89%)
Building:
  Train: 231977 annotazioni (60.13%)
  Val: 

# Dataset Preprocessing

In [30]:
def parse_coco_annotation(annotation_data):
    """
    Converti le annotazioni COCO in una struttura utile.
    """
    boxes = []
    labels = []

    for ann in annotation_data:
        category_id = ann['category_id']  # Usa direttamente il category_id come etichetta
        bbox = ann['bbox']  # I bounding box sono già nel formato [xmin, ymin, xmax, ymax]
        xmin, ymin, xmax, ymax = bbox

        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(category_id)

    return {'boxes': boxes, 'labels': labels}

def create_coco_data_lists(coco_file, splits_path, output_folder):
    """
    Converte i dati COCO e split in liste per train, val e test.
    """
    with open(coco_file, 'r') as f:
        coco_data = json.load(f)

    # Prepara mappature da immagini e annotazioni
    images = {img['file_name']: img for img in coco_data['images']}  # Mappa file_name -> immagine
    annotations_by_image = defaultdict(list)
    for ann in coco_data['annotations']:
        annotations_by_image[ann['image_id']].append(ann)

    # Genera i dati per ciascuno split
    for split in ['train', 'val', 'test']:
        split_file = os.path.join(splits_path, f"{split}.txt")
        with open(split_file, 'r') as f:
            image_files = [line.strip() for line in f.readlines()]

        image_list = []
        objects_list = []

        for image_file in image_files:
            file_name = os.path.basename(image_file)  # Ottieni solo il nome del file
            if file_name not in images:
                continue

            image_info = images[file_name]
            image_id = image_info['id']
            annotations = annotations_by_image[image_id]
            objects = parse_coco_annotation(annotations)

            if not objects['boxes']:
                continue

            image_list.append(image_file)
            objects_list.append(objects)

        # Salva i risultati
        with open(os.path.join(output_folder, f"{split.upper()}_images.json"), 'w') as j:
            json.dump(image_list, j)
        with open(os.path.join(output_folder, f"{split.upper()}_objects.json"), 'w') as j:
            json.dump(objects_list, j)

In [31]:
create_coco_data_lists(new_coco_json_path, output_folder, output_folder)

## Dataloader

In [34]:
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as FT

class CustomDataset(Dataset):
    """
    A PyTorch Dataset class to be used in a DataLoader for batching.
    """

    def __init__(self, path_image, path_bbox):
        """
        :param path_image: Path to the JSON file containing image paths.
        :param path_bbox: Path to the JSON file containing bounding boxes and labels.
        """
        # Load data
        with open(path_image, 'r') as j:
            self.images = json.load(j)
        with open(path_bbox, 'r') as j:
            self.objects = json.load(j)

        # Ensure the lengths match
        assert len(self.images) == len(self.objects), "Images and annotations must have the same length."

    def __transform(self, image, boxes, labels):
        """
        Apply transformations to the image, boxes, and labels.
        :param image: A PIL Image.
        :param boxes: Bounding boxes as a tensor of dimensions (n_objects, 4).
        :param labels: Labels as a tensor of dimensions (n_objects).
        :return: Transformed image, boxes, and labels.
        """
        def resize(image, boxes, dims=(300, 300)):
            # Resize image
            new_image = FT.resize(image, dims)

            # Normalize bounding boxes
            old_dims = torch.FloatTensor([image.width, image.height, image.width, image.height]).unsqueeze(0)
            new_boxes = boxes / old_dims  # Percent coordinates

            return new_image, new_boxes

        # ImageNet normalization values
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # Resize image and normalize boxes
        image, boxes = resize(image, boxes)

        # Convert image to tensor
        image = FT.to_tensor(image)

        # Normalize image
        image = FT.normalize(image, mean=mean, std=std)

        return image, boxes, labels

    def __getitem__(self, idx):
        """
        Retrieve an image and its corresponding objects.
        :param idx: Index of the data point.
        :return: Transformed image, bounding boxes, and labels.
        """
        # Load image
        image = Image.open(self.images[idx]).convert('RGB')

        # Load objects
        objects = self.objects[idx]
        boxes = torch.FloatTensor(objects['boxes'])  # (n_objects, 4)
        labels = torch.LongTensor(objects['labels'])  # (n_objects)

        # Apply transformations
        image, boxes, labels = self.__transform(image, boxes, labels)

        return image, boxes, labels

    def __len__(self):
        """
        Total number of data points.
        :return: Length of the dataset.
        """
        return len(self.images)

    def collate_fn(self, batch):
        """
        Since each image may have a different number of objects, we need a collate function (to be passed to the DataLoader).

        This describes how to combine these tensors of different sizes. We use lists.

        :param batch: an iterable of N sets from __getitem__()
        :return: a tensor of images, lists of varying-size tensors of bounding boxes, labels.
        """

        images = list()
        boxes = list()
        labels = list()

        for b in batch:
            images.append(b[0])
            boxes.append(b[1])
            labels.append(b[2])

        images = torch.stack(images, dim=0)

        return images, boxes, labels  # tensor (N, 3, 300, 300), 2 lists of N tensors each

In [35]:
train_dataset = CustomDataset(train_image, train_bbox)
val_dataset = CustomDataset(val_image, val_bbox)
test_dataset = CustomDataset(test_image, test_bbox)

# Recupera un campione
image, boxes, labels = train_dataset[0]

# Dimensioni
print("Image shape:", image.shape)  # Torch tensor di dimensione (3, 300, 300)
print("Boxes:", boxes)             # Bounding box normalizzati
print("Labels:", labels)           # Etichette

Image shape: torch.Size([3, 300, 300])
Boxes: tensor([[0.2812, 0.9375, 0.3281, 0.9719],
        [0.0000, -0.0000, 0.1562, 0.2031],
        [0.7219, 0.9062, 0.9281, 1.0000],
        [0.2250, 0.0000, 0.5000, 0.1344],
        [0.5188, 0.8938, 0.6781, 1.0000],
        [0.8875, 0.9250, 1.0031, 0.9969],
        [-0.0000, 0.0000, 0.2656, 0.1094],
        [0.5625, 0.1187, 0.5875, 0.1531],
        [0.5125, 0.1094, 0.5281, 0.1406],
        [0.6438, 0.9312, 0.7719, 1.0000],
        [0.2531, 0.9750, 0.2781, 1.0000]])
Labels: tensor([1, 6, 6, 6, 6, 6, 6, 2, 1, 6, 1])


In [36]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, 
                                               collate_fn=train_dataset.collate_fn, num_workers=3, pin_memory=True)
val_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, 
                                             collate_fn=train_dataset.collate_fn, num_workers=3, pin_memory=True)
test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, 
                                              collate_fn=train_dataset.collate_fn, num_workers=3, pin_memory=True)

## Model

fare l'import da GitHub https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/model.py

## Training

In [None]:
def find_intersection(set_1, set_2):
    """
    Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.

    :param set_1: set 1, a tensor of dimensions (n1, 4)
    :param set_2: set 2, a tensor of dimensions (n2, 4)
    :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
    """

    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)


def find_jaccard_overlap(set_1, set_2):
    """
    Find the Jaccard Overlap (IoU) of every box combination between two sets of boxes that are in boundary coordinates.

    :param set_1: set 1, a tensor of dimensions (n1, 4)
    :param set_2: set 2, a tensor of dimensions (n2, 4)
    :return: Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
    """

    # Find intersections
    intersection = find_intersection(set_1, set_2)  # (n1, n2)

    # Find areas of each box in both sets
    areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)

    # Find the union
    # PyTorch auto-broadcasts singleton dimensions
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)

    return intersection / union  # (n1, n2)


def adjust_learning_rate(optimizer, scale):
    """
    Scale learning rate by a specified factor.

    :param optimizer: optimizer whose learning rate must be shrunk.
    :param scale: factor to multiply learning rate with.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * scale
    print("DECAYING learning rate.\n The new LR is %f\n" % (optimizer.param_groups[1]['lr'],))


def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.

    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy
    """
    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)


def save_checkpoint(epoch, model, optimizer):
    """
    Save model checkpoint.

    :param epoch: epoch number
    :param model: model
    :param optimizer: optimizer
    """
    state = {'epoch': epoch,
             'model': model,
             'optimizer': optimizer}
    filename = 'checkpoint_ssd300.pth.tar'
    torch.save(state, filename)

def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

In [None]:
class Trainer:
    def __init__(self, model, train_dataset, train_dataloader, criterion, optimizer, batch_size, num_workers, device, 
                 grad_clip=None, print_freq=10, iterations=120000, decay_lr_at=None, decay_lr_to=0.1, 
                 momentum=0.9, weight_decay=5e-4):
        """
        Initialize the Trainer.
        
        :param model: SSD300 model instance
        :param train_dataset: Dataset object
        :param criterion: Loss function
        :param optimizer: Optimizer
        :param batch_size: Training batch size
        :param num_workers: Number of data loading workers
        :param device: Device to use for training ('cuda' or 'cpu')
        :param grad_clip: Gradient clipping value (default: None)
        :param print_freq: Frequency of printing training progress
        :param iterations: Total number of training iterations
        :param decay_lr_at: Iterations to decay learning rate
        :param decay_lr_to: Learning rate decay factor
        :param momentum: Momentum for optimizer
        :param weight_decay: Weight decay for optimizer
        """
        self.model = model
        self.train_dataset = train_dataset
        self.criterion = criterion
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.device = device
        self.grad_clip = grad_clip
        self.print_freq = print_freq
        self.iterations = iterations
        self.decay_lr_at = decay_lr_at if decay_lr_at is not None else [80000, 100000]
        self.decay_lr_to = decay_lr_to
        self.momentum = momentum
        self.weight_decay = weight_decay

        # Prepare dataloader
        self.train_loader = train_dataloader

        # Calculate epochs and decay epochs
        self.epochs = iterations // (len(train_dataset) // 32)
        self.decay_epochs = [it // (len(train_dataset) // 32) for it in self.decay_lr_at]

    def adjust_learning_rate(self, epoch):
        """
        Adjust the learning rate at specific epochs.
        """
        if epoch in self.decay_epochs:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * self.decay_lr_to
            print(f"Learning rate adjusted to {param_group['lr']} at epoch {epoch}")

    def train_one_epoch(self, epoch):
        """
        Perform one epoch of training.
        """
        self.model.train()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        start = time.time()

        for i, (images, boxes, labels, _) in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            # Move to device
            images = images.to(self.device)
            boxes = [b.to(self.device) for b in boxes]
            labels = [l.to(self.device) for l in labels]

            # Forward pass
            predicted_locs, predicted_scores = self.model(images)

            # Compute loss
            loss = self.criterion(predicted_locs, predicted_scores, boxes, labels)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            if self.grad_clip is not None:
                clip_gradient(self.optimizer, self.grad_clip)

            # Update model parameters
            self.optimizer.step()

            # Update metrics
            losses.update(loss.item(), images.size(0))
            batch_time.update(time.time() - start)

            start = time.time()

            # Print status
            if i % self.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(self.train_loader),
                                                                      batch_time=batch_time,
                                                                      data_time=data_time,
                                                                      loss=losses))

        del predicted_locs, predicted_scores, images, boxes, labels

    def save_checkpoint(self, epoch):
        """
        Save model checkpoint.
        """
        torch.save({
            'epoch': epoch,
            'model': self.model,
            'optimizer': self.optimizer,
        }, f'checkpoint_epoch_{epoch}.pth')
        print(f"Checkpoint saved for epoch {epoch}.")

    def train(self, start_epoch=0):
        """
        Train the model across all epochs.
        """
        for epoch in range(start_epoch, self.epochs):
            self.adjust_learning_rate(epoch)
            self.train_one_epoch(epoch)
            self.save_checkpoint(epoch)


In [None]:
# Model parameters
n_classes = len(label_map())  # number of different types of objects

# Learning parameters
checkpoint = None  # path to model checkpoint, None if none
batch_size = 8  # batch size
iterations = 10  # number of iterations to train
workers = 4  # number of workers for loading data in the DataLoader
print_freq = 200  # print training status every __ batches
lr = 1e-3  # learning rate
decay_lr_at = [80000, 100000]  # decay learning rate after these many iterations
decay_lr_to = 0.1  # decay learning rate to this fraction of the existing learning rate
momentum = 0.9  # momentum
weight_decay = 5e-4  # weight decay
grad_clip = None  # clip if gradients are exploding, which may happen at larger batch sizes (sometimes at 32) - you will recognize it by a sorting error in the MuliBox loss calculation


In [None]:
criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy).to(device)

# Ottimizzatore
biases = [param for name, param in model.named_parameters() if param.requires_grad and name.endswith('.bias')]
not_biases = [param for name, param in model.named_parameters() if param.requires_grad and not name.endswith('.bias')]
optimizer = torch.optim.SGD(params=[{'params': biases, 'lr': 2 * lr}, {'params': not_biases}],
                            lr=lr, momentum=momentum, weight_decay=weight_decay)


In [None]:
# Creazione e avvio del trainer
trainer = Trainer(model, train_dataset, criterion, optimizer, batch_size, workers, device, grad_clip=grad_clip)

trainer.train()

## Testing sulle predizioni

In [None]:
class Evaluator:
    def __init__(self, model, test_dataset, batch_size, num_workers, device):
        """
        Initialize the Evaluator.
        
        :param model: Trained SSD model to be evaluated
        :param test_dataset: Dataset object for testing
        :param batch_size: Batch size for evaluation
        :param num_workers: Number of data loading workers
        :param device: Device to use for evaluation ('cuda' or 'cpu')
        """
        self.model = model.to(device)
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.device = device
        self.pp = PrettyPrinter()  # For printing APs nicely

        # Prepare dataloader
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=test_dataset.collate_fn,
            num_workers=num_workers,
            pin_memory=True
        )

    def evaluate(self):
        """
        Perform evaluation and compute mAP.
        """
        self.model.eval()

        # Lists to store detected and true boxes, labels, scores
        det_boxes = list()
        det_labels = list()
        det_scores = list()
        true_boxes = list()
        true_labels = list()
        true_difficulties = list()

        with torch.no_grad():
            for i, (images, boxes, labels, difficulties) in enumerate(tqdm(self.test_loader, desc='Evaluating')):
                images = images.to(self.device)

                # Forward pass
                predicted_locs, predicted_scores = self.model(images)

                # Detect objects
                det_boxes_batch, det_labels_batch, det_scores_batch = self.model.detect_objects(
                    predicted_locs, predicted_scores,
                    min_score=0.01, max_overlap=0.45, top_k=200
                )

                # Store this batch's results
                boxes = [b.to(self.device) for b in boxes]
                labels = [l.to(self.device) for l in labels]
                difficulties = [d.to(self.device) for d in difficulties]

                det_boxes.extend(det_boxes_batch)
                det_labels.extend(det_labels_batch)
                det_scores.extend(det_scores_batch)
                true_boxes.extend(boxes)
                true_labels.extend(labels)
                true_difficulties.extend(difficulties)

        # Calculate mAP
        APs, mAP = self.calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties)

        # Print AP for each class
        self.pp.pprint(APs)
        print('\nMean Average Precision (mAP): %.3f' % mAP)

    @staticmethod
    def calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties):
        """
        Calculate Mean Average Precision (mAP).
        Placeholder for an actual implementation.
        
        :param det_boxes: Detected boxes
        :param det_labels: Detected labels
        :param det_scores: Detected scores
        :param true_boxes: Ground truth boxes
        :param true_labels: Ground truth labels
        :param true_difficulties: Ground truth difficulties
        :return: APs and mAP
        """
        # Replace this with your actual mAP calculation logic
        APs = {f'class_{i}': 0.0 for i in range(1, 21)}  # Dummy values for each class
        mAP = 0.0  # Dummy value for mAP
        return APs, mAP


In [None]:
# Caricamento del modello
checkpoint = torch.load(checkpoint_path)
model = checkpoint['model']


# Creazione e avvio del valutatore
evaluator = Evaluator(model=model, test_dataset=test_dataset, batch_size=64, num_workers=4, device=device)
evaluator.evaluate()


In [None]:
def calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties):
    """
    Calculate the Mean Average Precision (mAP) of detected objects.

    See https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 for an explanation

    :param det_boxes: list of tensors, one tensor for each image containing detected objects' bounding boxes
    :param det_labels: list of tensors, one tensor for each image containing detected objects' labels
    :param det_scores: list of tensors, one tensor for each image containing detected objects' labels' scores
    :param true_boxes: list of tensors, one tensor for each image containing actual objects' bounding boxes
    :param true_labels: list of tensors, one tensor for each image containing actual objects' labels
    :param true_difficulties: list of tensors, one tensor for each image containing actual objects' difficulty (0 or 1)
    :return: list of average precisions for all classes, mean average precision (mAP)
    """
    assert len(det_boxes) == len(det_labels) == len(det_scores) == len(true_boxes) == len(true_labels)  # these are all lists of tensors of the same length, i.e. number of images
    n_classes = len(label_map)

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from
    true_images = list()
    for i in range(len(true_labels)):
        true_images.extend([i] * true_labels[i].size(0))
    true_images = torch.LongTensor(true_images).to(
        device)  # (n_objects), n_objects is the total no. of objects across all images
    true_boxes = torch.cat(true_boxes, dim=0)  # (n_objects, 4)
    true_labels = torch.cat(true_labels, dim=0)  # (n_objects)

    assert true_images.size(0) == true_boxes.size(0) == true_labels.size(0)

    # Store all detections in a single continuous tensor while keeping track of the image it is from
    det_images = list()
    for i in range(len(det_labels)):
        det_images.extend([i] * det_labels[i].size(0))
    det_images = torch.LongTensor(det_images).to(device)  # (n_detections)
    det_boxes = torch.cat(det_boxes, dim=0)  # (n_detections, 4)
    det_labels = torch.cat(det_labels, dim=0)  # (n_detections)
    det_scores = torch.cat(det_scores, dim=0)  # (n_detections)

    assert det_images.size(0) == det_boxes.size(0) == det_labels.size(0) == det_scores.size(0)

    # Calculate APs for each class (except background)
    average_precisions = torch.zeros((n_classes - 1), dtype=torch.float)  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class
        true_class_images = true_images[true_labels == c]  # (n_class_objects)
        true_class_boxes = true_boxes[true_labels == c]  # (n_class_objects, 4)
        n_easy_class_objects = (1 - true_class_difficulties).sum().item()  # ignore difficult objects

        # Keep track of which true objects with this class have already been 'detected'
        # So far, none
        true_class_boxes_detected = torch.zeros((true_class_difficulties.size(0)), dtype=torch.uint8).to(
            device)  # (n_class_objects)

        # Extract only detections with this class
        det_class_images = det_images[det_labels == c]  # (n_class_detections)
        det_class_boxes = det_boxes[det_labels == c]  # (n_class_detections, 4)
        det_class_scores = det_scores[det_labels == c]  # (n_class_detections)
        n_class_detections = det_class_boxes.size(0)
        if n_class_detections == 0:
            continue

        # Sort detections in decreasing order of confidence/scores
        det_class_scores, sort_ind = torch.sort(det_class_scores, dim=0, descending=True)  # (n_class_detections)
        det_class_images = det_class_images[sort_ind]  # (n_class_detections)
        det_class_boxes = det_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive
        true_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device)  # (n_class_detections)
        false_positives = torch.zeros((n_class_detections), dtype=torch.float).to(device)  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = det_class_boxes[d].unsqueeze(0)  # (1, 4)
            this_image = det_class_images[d]  # (), scalar

            # Find objects in the same image with this class, their difficulties, and whether they have been detected before
            object_boxes = true_class_boxes[true_class_images == this_image]  # (n_class_objects_in_img)
            # If no such object in this image, then the detection is a false positive
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class
            overlaps = find_jaccard_overlap(this_detection_box, object_boxes)  # (1, n_class_objects_in_img)
            max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0)  # (), () - scalars

            # 'ind' is the index of the object in these image-level tensors 'object_boxes', 'object_difficulties'
            # In the original class-level tensors 'true_class_boxes', etc., 'ind' corresponds to object with index...
            original_ind = torch.LongTensor(range(true_class_boxes.size(0)))[true_class_images == this_image][ind]
            # We need 'original_ind' to update 'true_class_boxes_detected'

            # If the maximum overlap is greater than the threshold of 0.5, it's a match
            if max_overlap.item() > 0.5:
                # If this object has already not been detected, it's a true positive
                if true_class_boxes_detected[original_ind] == 0:
                    true_positives[d] = 1
                    true_class_boxes_detected[original_ind] = 1  # this object has now been detected/accounted for
                # Otherwise, it's a false positive (since this object is already accounted for)
                else:
                    false_positives[d] = 1
            # Otherwise, the detection occurs in a different location than the actual object, and is a false positive
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores
        cumul_true_positives = torch.cumsum(true_positives, dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives, dim=0)  # (n_class_detections)
        cumul_precision = cumul_true_positives / (
                cumul_true_positives + cumul_false_positives + 1e-10)  # (n_class_detections)
        cumul_recall = cumul_true_positives / n_easy_class_objects  # (n_class_detections)

        # Find the mean of the maximum of the precisions corresponding to recalls above the threshold 't'
        recall_thresholds = torch.arange(start=0, end=1.1, step=.1).tolist()  # (11)
        precisions = torch.zeros((len(recall_thresholds)), dtype=torch.float).to(device)  # (11)
        for i, t in enumerate(recall_thresholds):
            recalls_above_t = cumul_recall >= t
            if recalls_above_t.any():
                precisions[i] = cumul_precision[recalls_above_t].max()
            else:
                precisions[i] = 0.
        average_precisions[c - 1] = precisions.mean()  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP)
    mean_average_precision = average_precisions.mean().item()

    # Keep class-wise average precisions in a dictionary
    average_precisions = {rev_label_map[c + 1]: v for c, v in enumerate(average_precisions.tolist())}

    return average_precisions, mean_average_precision
