<h2 style="text-align:center">Space Debris Detection using detection transformer with custom backbone</h3>

In [1]:
import os
import torch
import torch.nn as nn
import torchvision
from transformers import DetrForObjectDetection, DetrImageProcessor, DetrConfig
from transformers.models.detr.modeling_detr import DetrSinePositionEmbedding
from torch.utils.data import DataLoader
import timm

from squeezenet_backbone import SqueezeNetBackbone
import pytorch_lightning as pl
from pytorch_lightning import Trainer

torch.set_float32_matmul_precision('medium')

In [2]:
torch.get_float32_matmul_precision()

'medium'

In [3]:
# Add this new class to your notebook cell


class CustomBackboneJoiner(nn.Module):
    def __init__(self, backbone: nn.Module, d_model: int):
        super().__init__()
        self.backbone = backbone
        self.position_embedding = DetrSinePositionEmbedding(embedding_dim=d_model // 2, normalize=True)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # Get the feature map list from our SqueezeNet backbone
        # SqueezeNet returns features_list, None
        features_list, _ = self.backbone(pixel_values, pixel_mask)
        
        # Now, calculate positional embeddings for the feature map
        pos_embeddings_list = []
        for feat_tensor, mask in features_list:
            # The position embedding module expects a specific dict format

            pos_embeddings_list.append(self.position_embedding(feat_tensor, mask))
            
        return features_list, pos_embeddings_list

In [4]:
dataset_path = "debris_det_dataset"

ANNOTATION_FILE_NAME = "_annotations.coco.json"
TRAIN_DIR = os.path.join(dataset_path, "train")
VAL_DIR = os.path.join(dataset_path, "valid")

In [5]:
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, image_dir_path:str, image_processor, train:bool=True):
        annot_file_path = os.path.join(image_dir_path, ANNOTATION_FILE_NAME)
        super(CocoDetection, self).__init__(image_dir_path, annot_file_path)
        self.image_processor = image_processor

    def __getitem__(self, idx):
        images, annotations = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        annotations = {'image_id': image_id, 'annotations': annotations}
        encoding = self.image_processor(images=images, annotations=annotations, return_tensors="pt")
        pixel_values = encoding['pixel_values'].squeeze()
        target = encoding['labels'][0]
        return pixel_values, target
    

TRAIN_DATASET = CocoDetection(TRAIN_DIR, DetrImageProcessor.from_pretrained("facebook/detr-resnet-50"), train=True)
VAL_DATASET = CocoDetection(VAL_DIR, DetrImageProcessor.from_pretrained("facebook/detr-resnet-50"), train=False)

print(f"Number of training samples: {len(TRAIN_DATASET)}")
print(f"Number of validation samples: {len(VAL_DATASET)}")

loading annotations into memory...
Done (t=0.13s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
Number of training samples: 20000
Number of validation samples: 2000


In [6]:
def collate_fn(batch):
    pixel_values = [item[0] for item in batch]
    encoding = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50").pad(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    return {'pixel_values': encoding['pixel_values'], 'pixel_mask': encoding['pixel_mask'], 'labels': labels}


TRAIN_DATALOADER = DataLoader(dataset=TRAIN_DATASET, collate_fn=collate_fn, batch_size=4, shuffle=True, num_workers=4)
VAL_DATALOADER = DataLoader(dataset=VAL_DATASET, collate_fn=collate_fn, batch_size=4, num_workers=4)


In [None]:
class Detr(pl.LightningModule):

    def __init__(self, lr, lr_backbone, weight_decay, backbone_name:str =None, set_pretrained_backbone:bool=True, use_timm_backbone:bool=True):
        super().__init__()

        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay
        self.use_timm_backbone = use_timm_backbone

        if backbone_name=="squeezenet":
            backbone = SqueezeNetBackbone(pretrained=set_pretrained_backbone)
            print("Using SqueezeNet backbone")

            self.model = DetrForObjectDetection(DetrConfig(num_labels=1, num_queries=10))


            d_model = self.model.config.d_model

            backbone_joiner = CustomBackboneJoiner(backbone, d_model)

            self.model.model.backbone = backbone_joiner
            in_channels = backbone.num_channels
            
            self.model.model.input_projection = torch.nn.Conv2d(in_channels=in_channels, out_channels=d_model, kernel_size=1)
            

        elif backbone_name in timm.list_models():
            print(f"Using {backbone_name} backbone from timm")

            config = DetrConfig(
                backbone=backbone_name,
                use_timm_backbone=use_timm_backbone,
                use_pretrained_backbone=set_pretrained_backbone,
                num_queries=10,
                num_labels=1, # COCO class count + 1 for "no object"
                lr_backbone=lr_backbone,
                lr=lr,
                weight_decay=weight_decay,
            )

            self.model = DetrForObjectDetection(config)
        else:
            print("Using default DETR ResNet-50 backbone")

            self.model = DetrForObjectDetection.from_pretrained(
                "facebook/detr-resnet-50",
                num_labels=2, # COCO class count + 1 for "no object"
                lr_backbone=lr_backbone,
                lr=lr,
                weight_decay=weight_decay,
            )


    def forward(self, pixel_values, pixel_mask): 
        return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

    def common_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        loss = outputs.loss
        loss_dict = outputs.loss_dict

        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step, and the average across the epoch
        self.log("training_loss", loss)
        for k,v in loss_dict.items():
            self.log("train_" + k, v.item())

        return loss

    def validation_step(self, batch, batch_idx):
        loss, loss_dict = self.common_step(batch, batch_idx)     
        self.log("validation/loss", loss)
        for k, v in loss_dict.items():
            self.log("validation_" + k, v.item())
            
        return loss

    def configure_optimizers(self):
        # DETR authors decided to use different learning rate for backbone
        # you can learn more about it here: 
        # - https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/main.py#L22-L23
        # - https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/main.py#L131-L139
        param_dicts = [
            {
                "params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        return torch.optim.AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)

    def train_dataloader(self):
        return TRAIN_DATALOADER

    def val_dataloader(self):
        return VAL_DATALOADER

In [9]:
model = Detr(lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4, backbone_name="squeezenet", set_pretrained_backbone=True, use_timm_backbone=False)

Using SqueezeNet backbone


In [None]:
batch = next(iter(TRAIN_DATALOADER))

# outputs = model(pixel_values=batch['pixel_values'], pixel_mask=batch['pixel_mask'])
outputs = model(pixel_values=batch['pixel_values'], pixel_mask=batch['pixel_mask'])

In [11]:
outputs.logits.shape

torch.Size([4, 10, 2])

In [12]:
max_epochs = 50

trainer = Trainer(
    devices=1, 
    accelerator="gpu", 
    max_epochs=max_epochs,
    gradient_clip_val=0.1, 
    accumulate_grad_batches=8, 
    log_every_n_steps=5,
    val_check_interval=0.25,
)

trainer.fit(model)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                   | Params | Mode 
---------------------------------------------------------
0 | model | DetrForObjectDetection | 18.4 M | train
---------------------------------------------------------
18.4 M    Trainable params
0         Non-trainable params
18.4 M    Total params
73.412    Total estimated model params size (MB)
248       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\USER X\Desktop\debrisClassifier\venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
c:\Users\USER X\Desktop\debrisClassifier\venv\Lib\site-packages\pytorch_lightning\utilities\data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
c:\Users\USER X\Desktop\debrisClassifier\venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
