In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Dataset roots, relative to the notebook's working directory.
# CocoDetection uses each folder both as the image root and as the location
# of its "annotations.json" file.
train_dir = '../data/55_all/'
val_dir = '../data/56_all/'

In [None]:
import torchvision
import os

class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset that yields processor-encoded samples.

    Args:
        img_folder: Directory used as the image root; it must also contain
            an ``annotations.json`` file in COCO format.
        processor: Callable invoked as
            ``processor(images=..., annotations=..., return_tensors="pt")``
            whose result exposes ``"pixel_values"`` and ``"labels"``.
        train: Accepted for backward compatibility. Every split reads the
            same annotation file name, so the flag has no effect.
    """

    def __init__(self, img_folder, processor, train=True):
        # Train and validation folders both store their annotations under
        # the same file name, so no branching on `train` is needed.
        ann_file = os.path.join(img_folder, "annotations.json")
        super().__init__(img_folder, ann_file)
        self.processor = processor

    def __getitem__(self, idx):
        """Return ``(pixel_values, target)`` for the sample at position ``idx``."""
        # read in PIL image and target in COCO format
        # feel free to add data augmentation here before passing them to the next step
        img, target = super().__getitem__(idx)

        # The processor expects the annotations wrapped together with the image id.
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")

        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse any other axis that happens to have size 1.
        pixel_values = encoding["pixel_values"].squeeze(0)
        target = encoding["labels"][0]  # single image -> take its label dict

        return pixel_values, target

In [None]:
from transformers import DetrImageProcessor

# Shared preprocessing object, loaded from the "facebook/detr-resnet-50" hub entry.
# It is used by the datasets, the collate function and the post-processing cells.
image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

In [None]:
# Build both splits with the same processor; only the folder differs.
train_dataset = CocoDetection(img_folder=train_dir, processor=image_processor)
val_dataset = CocoDetection(img_folder=val_dir, processor=image_processor, train=False)


In [None]:
# Sanity check: report how many samples each split contains.
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(val_dataset))

In [None]:
# Display the first training sample: the (pixel_values, target) pair
# returned by CocoDetection.__getitem__.
train_dataset[0]

In [None]:
# import torch

# def collate_fn(batch):
#     data = {}
#     data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
#     data["labels"] = [x["labels"] for x in batch]
#     if "pixel_mask" in batch[0]:
#         data["pixel_mask"] = torch.stack([x["pixel_mask"] for x in batch])
#     return data

def collate_fn(batch):
  """Merge a list of ``(pixel_values, target)`` samples into one batch dict.

  The images are handed to ``image_processor.pad`` (with
  ``return_tensors="pt"``), and its ``pixel_values`` / ``pixel_mask``
  entries are forwarded. The per-sample targets are kept as a plain list
  under ``labels``.
  """
  images = []
  targets = []
  for sample in batch:
    images.append(sample[0])
    targets.append(sample[1])

  padded = image_processor.pad(images, return_tensors="pt")

  return {
      'pixel_values': padded['pixel_values'],
      'pixel_mask': padded['pixel_mask'],
      'labels': targets,
  }

In [None]:
# Category table in COCO style. Ids are consecutive and follow the order of
# the names below; every entry shares the "objects" supercategory.
categories_full = [
    {"id": category_id, "name": category_name, "supercategory": "objects"}
    for category_id, category_name in enumerate((
        "biker",
        "car",
        "pedestrian",
        "trafficlight",
        "trafficlight-Green",
        "trafficlight-GreenLeft",
        "trafficlight-Red",
        "trafficlight-RedLeft",
        "trafficlight-Yellow",
        "trafficlight-YellowLeft",
        "truck",
        "Arret",
    ))
]
# Plain list of class names, in id order.
categories = [entry['name'] for entry in categories_full]
categories

In [None]:
import numpy as np
import os
from PIL import Image, ImageDraw

In [None]:
# Visual sanity check: draw the ground-truth boxes of one random training image.
image_ids = train_dataset.coco.getImgIds()
# let's pick a random image
image_id = image_ids[np.random.randint(0, len(image_ids))]
print('Image n°{}'.format(image_id))
# `image` first holds the COCO image record, then is rebound to the opened PIL image.
image = train_dataset.coco.loadImgs(image_id)[0]
image = Image.open(os.path.join(train_dir, image['file_name']))

annotations = train_dataset.coco.imgToAnns[image_id]
draw = ImageDraw.Draw(image, "RGBA")

# Category id <-> name lookups built from the dataset's own category table.
# NOTE: `id2label` / `label2id` are notebook globals that later cells rely on
# (model construction, plotting, model config).
cats = train_dataset.coco.cats
id2label = {k: v['name'] for k,v in cats.items()}
label2id = {v['name']: k for k,v in cats.items()}

for annotation in annotations:
  box = annotation['bbox']
  class_idx = annotation['category_id']
  # bbox is unpacked as (x, y, width, height) and converted to corner form for drawing.
  x,y,w,h = tuple(box)
  draw.rectangle((x,y,x+w,y+h), outline='red', width=1)
  draw.text((x, y), id2label[class_idx], fill='white')

# Last expression of the cell: display the annotated image.
image

### Train

In [None]:
# Hyperparameters shared by the dataloaders, the model and the checkpoint loading cell.
batch_size = 8  # training batch size
val_batch_size = 32  # validation batch size
lr = 1e-4  # learning rate for parameters whose name does not contain "backbone"
weight_decay = 1e-4  # AdamW weight decay
lr_backbone = 1e-5  # learning rate for parameters whose name contains "backbone"
# Step budget derived from the batch size; only the (commented-out) Trainer call uses it.
max_steps = 100000 // batch_size


In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
from torch.utils.data import DataLoader

# Both loaders use collate_fn; only the training loader shuffles.
# NOTE: these globals are what Detr.train_dataloader() / Detr.val_dataloader() return.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=val_batch_size)
# Pull one training batch for the inspection and smoke-test cells below.
batch = next(iter(train_dataloader))

In [None]:
# collate_fn produces the keys 'pixel_values', 'pixel_mask' and 'labels'.
batch.keys()

In [None]:
# Unpack the first training sample (un-batched) for inspection.
pixel_values, target = train_dataset[0]

In [None]:
# Shape of a single processed image, after the batch dimension was removed by the dataset.
pixel_values.shape

In [None]:
# Show the processor-encoded label dict for that sample.
print(target)

In [None]:
from dataclasses import make_dataclass
import pytorch_lightning as pl
from transformers import DetrForObjectDetection
import torch

class Detr(pl.LightningModule):
    """Lightning wrapper used to fine-tune ``DetrForObjectDetection``.

    The class depends on notebook-level globals: ``id2label`` (its length
    sets the number of classes of the new head) and the
    ``train_dataloader`` / ``val_dataloader`` objects returned by the
    dataloader hooks.

    Args:
        lr: Learning rate for parameters whose name does not contain
            ``"backbone"``.
        lr_backbone: Learning rate for parameters whose name contains
            ``"backbone"``.
        weight_decay: Weight decay passed to AdamW.
    """

    def __init__(self, lr, lr_backbone, weight_decay):
        super().__init__()
        # replace COCO classification head with custom head
        # we specify the "no_timm" variant here to not rely on the timm library
        # for the convolutional backbone
        self.model = DetrForObjectDetection.from_pretrained(
            "facebook/detr-resnet-50",
            revision="no_timm",
            num_labels=len(id2label),
            ignore_mismatched_sizes=True,
        )
        # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay

    def forward(self, pixel_values, pixel_mask):
        """Run the wrapped model without labels and return its outputs."""
        return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

    def common_step(self, batch, batch_idx):
        """Compute the loss for one batch.

        Returns:
            Tuple ``(loss, loss_dict)`` taken from the model outputs.
        """
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        # Labels arrive as a plain list of dicts, so each tensor is moved to
        # the module's device explicitly.
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        return outputs.loss, outputs.loss_dict

    def training_step(self, batch, batch_idx):
        """Log the total loss and each loss component, then return the total loss."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        # Pass the batch size explicitly: the batch dict holds tensors with
        # different leading sizes (images vs. per-image labels), which makes
        # automatic inference ambiguous.
        batch_size = batch["pixel_values"].shape[0]
        self.log("training_loss", loss, batch_size=batch_size)
        for k, v in loss_dict.items():
            self.log("train_" + k, v.item(), batch_size=batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and each loss component, then return the total loss."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        # Explicit batch size, for the same reason as in training_step.
        batch_size = batch["pixel_values"].shape[0]
        self.log("validation_loss", loss, batch_size=batch_size)
        for k, v in loss_dict.items():
            self.log("validation_" + k, v.item(), batch_size=batch_size)

        return loss

    def configure_optimizers(self):
        """Build AdamW with a separate learning rate for backbone parameters."""
        # Parameters are split by name: anything containing "backbone" gets
        # `lr_backbone`, everything else falls back to the optimizer-wide `lr`.
        param_dicts = [
            {"params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=self.lr,
                                      weight_decay=self.weight_decay)

        return optimizer

    def train_dataloader(self):
        """Return the notebook-level training dataloader."""
        return train_dataloader

    def val_dataloader(self):
        """Return the notebook-level validation dataloader."""
        return val_dataloader

In [None]:
# Smoke test: build the model and push the previously fetched training batch through it.
# NOTE(review): the values here repeat the literals of `lr`, `lr_backbone` and
# `weight_decay` from the hyperparameter cell instead of reusing those names.
model = Detr(lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4)

outputs = model(pixel_values=batch['pixel_values'], pixel_mask=batch['pixel_mask'])

In [None]:
# Shape of the classification logits from the smoke-test forward pass.
# presumably (batch, num_queries, num_classes + 1) — confirm against the DETR docs
outputs.logits.shape

In [None]:
# from pytorch_lightning import Trainer

# trainer = Trainer(max_steps=max_steps, gradient_clip_val=0.1)
# trainer.fit(model)

### Eval

In [None]:
# Restore the fine-tuned weights from a Lightning checkpoint and move the module to `device`.
# The constructor arguments are passed again because Detr.__init__ requires them.
# NOTE(review): the checkpoint path is hard-coded to one specific run.
model = Detr.load_from_checkpoint('./lightning_logs/version_5/checkpoints/epoch=240-step=25305.ckpt', lr=lr, lr_backbone=lr_backbone, weight_decay=weight_decay).to(device)
# Switch to evaluation mode; the trailing semicolon hides the cell's output echo.
model.eval();

In [None]:
def convert_to_xywh(boxes):
    """Convert corner-format boxes to ``(x, y, width, height)``.

    Args:
        boxes: 2-D tensor whose columns are ``xmin, ymin, xmax, ymax``.

    Returns:
        Tensor with one row per input box, with columns
        ``xmin, ymin, xmax - xmin, ymax - ymin``.
    """
    left, top, right, bottom = boxes.unbind(1)
    width = right - left
    height = bottom - top
    return torch.stack((left, top, width, height), dim=1)

def prepare_for_coco_detection(predictions):
    """Flatten per-image predictions into a list of COCO result records.

    Args:
        predictions: Mapping from image id to a prediction dict with
            ``"boxes"`` (corner format), ``"scores"`` and ``"labels"``
            entries. Empty prediction dicts are skipped.

    Returns:
        List of dicts with the keys ``image_id``, ``category_id``,
        ``bbox`` (converted with ``convert_to_xywh``) and ``score``,
        one per box.
    """
    coco_results = []
    for original_id, prediction in predictions.items():
        if len(prediction) == 0:
            continue

        xywh_boxes = convert_to_xywh(prediction["boxes"]).tolist()
        scores = prediction["scores"].tolist()
        labels = prediction["labels"].tolist()

        for k, box in enumerate(xywh_boxes):
            record = {
                "image_id": original_id,
                "category_id": labels[k],
                "bbox": box,
                "score": scores[k],
            }
            coco_results.append(record)
    return coco_results

In [None]:
from coco_eval import CocoEvaluator
from tqdm.notebook import tqdm

import numpy as np

# COCO bounding-box evaluation of the fine-tuned model over the whole validation loader.
# initialize evaluator with ground truth (gt)
evaluator = CocoEvaluator(coco_gt=val_dataset.coco, iou_types=["bbox"])
model = model.to(device)

print("Running evaluation...")
# NOTE: this loop rebinds the notebook globals `batch`, `pixel_values`, `outputs`, `labels` and `results`.
for idx, batch in enumerate(tqdm(val_dataloader)):
    # get the inputs
    pixel_values = batch["pixel_values"].to(device)
    pixel_mask = batch["pixel_mask"].to(device)
    labels = [{k: v.to(device) for k, v in t.items()} for t in batch["labels"]] # these are in DETR format, resized + normalized

    # forward pass (no labels are passed, so only predictions are produced)
    with torch.no_grad():
      outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

    # turn into a list of dictionaries (one item for each example in the batch)
    # the stored original image sizes are handed to the post-processing step;
    # threshold=0 is passed so that no score cut-off is applied before the metric
    orig_target_sizes = torch.stack([target["orig_size"] for target in labels], dim=0)
    results = image_processor.post_process_object_detection(outputs, target_sizes=orig_target_sizes, threshold=0)

    # provide to metric
    # metric expects a list of dictionaries, each item
    # containing image_id, category_id, bbox and score keys
    predictions = {target['image_id'].item(): output for target, output in zip(labels, results)}
    predictions = prepare_for_coco_detection(predictions)
    evaluator.update(predictions)

# Finalize: gather, accumulate and print the summary metrics.
evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()

In [None]:
# Pick one validation sample by dataset position for a qualitative check.
# The `image_id` stored in `target` identifies the image, so the original file
# can be looked up again later.
test_idx = 546
pixel_values, target = val_dataset[test_idx]

In [None]:
# Re-add a leading batch dimension and move the image to `device` for the forward pass.
# NOTE: rebinding `pixel_values` makes this cell non-idempotent; re-running it adds another dimension.
pixel_values = pixel_values.unsqueeze(0).to(device)
print(pixel_values.shape)

In [None]:
# Single-image inference with gradients disabled; no pixel mask is supplied.
with torch.no_grad():
  # forward pass to get class logits and bounding boxes
  outputs = model(pixel_values=pixel_values, pixel_mask=None)
print("Outputs:", outputs.keys())

In [None]:
# Store the dataset's label maps on the wrapped model's config.
# presumably so the model saved at the end of the notebook carries the class names — confirm
model.model.config.id2label = id2label
model.model.config.label2id = label2id


In [None]:
import matplotlib.pyplot as plt

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def plot_results(pil_img, scores, labels, boxes):
    """Show an image with its detections drawn on top.

    Args:
        pil_img: Image passed to ``plt.imshow``.
        scores: Per-detection scores; must support ``.tolist()``.
        labels: Per-detection class ids (keys of the global ``id2label``);
            must support ``.tolist()``.
        boxes: Per-detection ``(xmin, ymin, xmax, ymax)`` boxes; must
            support ``.tolist()``.

    Every detection is drawn. Box colours cycle through the global
    ``COLORS`` palette, so the number of detections is not capped by the
    palette size.
    """
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for position, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Wrap around the palette instead of truncating the detections to a
        # pre-multiplied colour list.
        c = COLORS[position % len(COLORS)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

In [None]:
# Reload the original (unprocessed) validation image that belongs to `target`.
# load image based on ID
image_id = target['image_id'].item()
image = val_dataset.coco.loadImgs(image_id)[0]
image = Image.open(os.path.join(val_dir, image['file_name']))

# postprocess model outputs
# PIL reports size as (width, height); the target size is passed as (height, width).
# Only detections scoring above the 0.7 threshold are kept for plotting.
width, height = image.size
postprocessed_outputs = image_processor.post_process_object_detection(outputs,
                                                                target_sizes=[(height, width)],
                                                                threshold=0.7)
# A single image was passed, so take the first (only) result.
results = postprocessed_outputs[0]
plot_results(image, results['scores'], results['labels'], results['boxes'])

In [None]:
from transformers import Trainer, TrainingArguments
# Export the fine-tuned Hugging Face model (the `.model` inside the Lightning wrapper).
# The Trainer is constructed only to call save_model(); no training is run here.
# presumably the files land in `model_save_path` (the configured output_dir) — confirm
model_save_path = "ckpts/detr-simple-ft-240-epochs"
trainer = Trainer(model=model.model, processing_class=image_processor, data_collator=collate_fn, args=TrainingArguments(output_dir=model_save_path))
trainer.save_model()