<h2 style="text-align:center">Space Debris Detection Using a Detection Transformer with a Custom Backbone</h2>

In [4]:
import os
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor
from torch.utils.data import DataLoader
import torchvision

In [2]:
# Root folder of the dataset and the name of the COCO annotation file
# that is looked up inside each split directory.
ANNOTATION_FILE_NAME = "_annotations.coco.json"
dataset_path = "debris_det_dataset"

# Split directories: <dataset_path>/train and <dataset_path>/valid.
TRAIN_DIR, VAL_DIR = (
    os.path.join(dataset_path, split_name) for split_name in ("train", "valid")
)

In [3]:
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset that runs each sample through an image processor.

    The annotation file is expected at ``<image_dir_path>/<ANNOTATION_FILE_NAME>``.
    """

    def __init__(self, image_dir_path: str, image_processor, train: bool = True):
        """Set up the dataset for one split directory.

        Args:
            image_dir_path: Directory holding the images and the annotation file.
            image_processor: Callable applied to every sample in ``__getitem__``;
                it is called with ``images=``, ``annotations=`` and
                ``return_tensors="pt"`` and must return a mapping with the keys
                ``'pixel_values'`` and ``'labels'``.
            train: Currently unused; kept so existing callers that pass it
                keep working.
        """
        annot_file_path = os.path.join(image_dir_path, ANNOTATION_FILE_NAME)
        super().__init__(image_dir_path, annot_file_path)
        self.image_processor = image_processor

    def __getitem__(self, idx):
        """Return ``(pixel_values, target)`` for the sample at position ``idx``.

        ``pixel_values`` is the processed image with its leading batch axis
        removed; ``target`` is the first (and only) entry of the processor's
        ``'labels'`` output.
        """
        image, raw_annotations = super().__getitem__(idx)
        image_id = self.ids[idx]
        # Wrap the raw annotation list in the dict layout passed to the processor.
        coco_target = {'image_id': image_id, 'annotations': raw_annotations}
        encoding = self.image_processor(images=image, annotations=coco_target, return_tensors="pt")
        # Drop only the batch axis (presumably shape (1, C, H, W) for a single
        # image -- TODO confirm). A bare ``squeeze()`` would also remove any
        # other dimension that happens to have size 1.
        pixel_values = encoding['pixel_values'].squeeze(0)
        target = encoding['labels'][0]
        return pixel_values, target
    

# Load the pretrained image processor once and share it between both splits
# instead of loading the same checkpoint twice.
DETR_CHECKPOINT = "facebook/detr-resnet-50"
IMAGE_PROCESSOR = DetrImageProcessor.from_pretrained(DETR_CHECKPOINT)

TRAIN_DATASET = CocoDetection(TRAIN_DIR, IMAGE_PROCESSOR, train=True)
VAL_DATASET = CocoDetection(VAL_DIR, IMAGE_PROCESSOR, train=False)

print(f"Number of training samples: {len(TRAIN_DATASET)}")
print(f"Number of validation samples: {len(VAL_DATASET)}")

loading annotations into memory...
Done (t=0.11s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
Number of training samples: 20000
Number of validation samples: 2000


In [None]:
# Loaded once when this cell runs. Previously the processor was re-created
# from the checkpoint inside collate_fn, i.e. once for every batch.
COLLATE_IMAGE_PROCESSOR = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")


def collate_fn(batch):
    """Combine a list of ``(pixel_values, target)`` samples into one batch dict.

    Args:
        batch: Sequence of 2-item samples; item 0 is the image tensor and
            item 1 is its target.

    Returns:
        Dict with ``'pixel_values'`` and ``'pixel_mask'`` taken from the
        processor's ``pad`` output, and ``'labels'`` as the list of per-sample
        targets (left as a list, not stacked).
    """
    pixel_values = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    encoding = COLLATE_IMAGE_PROCESSOR.pad(pixel_values, return_tensors="pt")
    return {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }


TRAIN_DATALOADER = 