In [None]:
import inspect

import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import numpy as np
import cv2
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

In [None]:
# download a COCO val2017 sample image and save it locally as twocats.jpg
!wget -O "twocats.jpg" "http://images.cocodataset.org/val2017/000000039769.jpg"

In [None]:
# load the sample image; convert('RGB') guarantees exactly 3 channels so the later
# cv2.COLOR_RGB2BGR conversions also work for grayscale / RGBA / palette files
image = Image.open('twocats.jpg').convert('RGB') # 640 x 480 x 3

In [None]:
# show the downloaded image without axis ticks/frame
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_axis_off()
plt.show()

In [None]:
image.size[::-1]  # (height, width) -- PIL's .size is (width, height)

In [None]:
# load the DETR model pretrained for the COCO dataset
CHECKPOINT = "facebook/detr-resnet-50"
REVISION = "no_timm"  # NOTE(review): presumably the checkpoint variant that avoids the `timm` dependency
processor = DetrImageProcessor.from_pretrained(CHECKPOINT, revision=REVISION)
model = DetrForObjectDetection.from_pretrained(CHECKPOINT, revision=REVISION)
# Inference only: eval() disables dropout, and freezing the parameters stops autograd
# from recording a graph during the manual backbone/encoder/decoder calls below.
# The trailing ';' suppresses the (very long) model repr.
model.eval().requires_grad_(False);

In [None]:
# class index -> COCO category name, used below to label detections
model.config.id2label

In [None]:
# display the module hierarchy of the loaded model (backbone, encoder, decoder, heads)
model

In [None]:
# resize/normalise the image into model inputs; for this 640 x 480 image:
#   'pixel_values': 1 x 3 x 800 x 1066, 'pixel_mask': 1 x 800 x 1066
inputs = processor(
    images=image,
    return_tensors="pt",
)

In [None]:
# report the shape of every tensor the processor produced
for name, tensor in inputs.items():
    print(f"{name}: {tensor.shape}")

In [None]:
# value range of the pixel mask (presumably all ones here, since a single image needs no padding -- verify)
inputs['pixel_mask'].min(), inputs['pixel_mask'].max()

In [None]:
# full forward pass -> 'logits', 'pred_boxes', 'last_hidden_state', 'encoder_last_hidden_state'
# logits: 1 x 100 x 92, pred_boxes: 1 x 100 x 4, last_hidden_state: 1 x 100 x 256,
# encoder_last_hidden_state: 1 x 850 x 256 (850 = 25 * 34 feature-map positions)
with torch.no_grad():  # inference only; do not build an autograd graph
    outputs = model(**inputs)
for key, value in outputs.items():
    # inner string uses double quotes: reusing the f-string's own quote character
    # is a SyntaxError before Python 3.12
    print(f'{key}: {value.shape if hasattr(value, "shape") else value}')

In [None]:
# predicted boxes, 1 x 100 x 4; presumably normalised (cx, cy, w, h) -- the format convert_boxes consumes below
outputs['pred_boxes']

In [None]:
# backbone & pixel mask:
# the backbone returns (feature_map, mask) pairs together with the matching
# position encodings (object_queries_list); only the last pair is used here
features, object_queries_list = model.model.backbone(
    pixel_values=inputs["pixel_values"],
    pixel_mask=inputs["pixel_mask"],
)
feature_map, mask = features[-1]
# map the feature map to the transformer's hidden size (compare the two shapes below)
projected_feature_map = model.model.input_projection(feature_map)

In [None]:
# shape of the last backbone feature map (before projection)
feature_map.shape

In [None]:
# shape after input_projection -- the channel count the transformer consumes
projected_feature_map.shape

In [None]:
# Re-derive the sinusoidal position encoding that the backbone computes internally
# and compare it with `object_queries_list[-1]`.
# The encoding is built from `mask` -- the mask returned with the last feature map
# (25 x 34 here) -- not from the full-resolution `inputs['pixel_mask']`
# (800 x 1066), which would produce a tensor of the wrong spatial size.
position_embedding = model.model.backbone.position_embedding
y_embed = mask.cumsum(1, dtype=torch.float32)
x_embed = mask.cumsum(2, dtype=torch.float32)
if getattr(position_embedding, "normalize", True):
    # scale cumulative positions into (0, scale]; scale is 6.283185307179586 (2*pi) here
    y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * position_embedding.scale
    x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * position_embedding.scale
# embedding_dim is 128 channels per axis, temperature is 10000
dim_t = torch.arange(position_embedding.embedding_dim, dtype=torch.int64, device=mask.device).float()
dim_t = position_embedding.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / position_embedding.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
# interleave sin (even channels) and cos (odd channels)
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # 1 x 256 x H x W

reference_pos = object_queries_list[-1]
assert pos.shape == reference_pos.shape, (pos.shape, reference_pos.shape)
(pos - reference_pos).abs().max()  # ~0 when the derivation matches the backbone

In [None]:
# encoder: flatten the H x W grid into a sequence of H*W tokens (B x HW x C)
flattened_features = projected_feature_map.flatten(2).transpose(1, 2)
flattened_mask = mask.flatten(1)  # B x HW
object_queries = object_queries_list[-1].flatten(2).transpose(1, 2)  # position encodings, B x HW x C
encoder_outputs = model.model.encoder(
    inputs_embeds=flattened_features,
    attention_mask=flattened_mask,
    object_queries=object_queries,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
)
hidden_states = encoder_outputs.last_hidden_state  # 1 x 850 x 256

In [None]:
# encoder output: one 256-d vector per feature-map position (1 x 850 x 256)
hidden_states.shape

In [None]:
# decoder
# model.model.query_position_embeddings is Embedding(100, 256): one learned
# position embedding per object query, broadcast over the batch
batch_size = len(hidden_states)
query_position_embeddings = model.model.query_position_embeddings.weight[None].repeat(batch_size, 1, 1)  # 1 x 100 x 256
queries = query_position_embeddings.new_zeros(query_position_embeddings.shape)  # 1 x 100 x 256, all zeros
decoder_outputs = model.model.decoder(
    inputs_embeds=queries,
    attention_mask=None,
    object_queries=object_queries,                           # encoder-side position encodings
    query_position_embeddings=query_position_embeddings,     # query-side position embeddings
    encoder_hidden_states=hidden_states,
    encoder_attention_mask=flattened_mask,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
)
sequence_output = decoder_outputs.last_hidden_state  # 1 x 100 x 256

In [None]:
# learned query position embeddings after broadcasting over the batch (1 x 100 x 256)
query_position_embeddings.shape

In [None]:
# decoder output: one 256-d vector per object query (1 x 100 x 256)
decoder_outputs['last_hidden_state'].shape

In [None]:
 model.bbox_predictor

In [None]:
# print the source of the box-regression head's forward() (inspect is imported in the first cell)
print(inspect.getsource(model.bbox_predictor.forward))

In [None]:
# post-processing
# turn the decoder output into class probabilities and normalised boxes
# (convert outputs to the COCO-style box format), then keep confident detections
logits = model.class_labels_classifier(sequence_output)  # 1 x 100 x 92
boxes = model.bbox_predictor(sequence_output).sigmoid()  # 1 x 100 x 4
probabilities = logits.softmax(dim=-1)  # 1 x 100 x 92
# the last logit is DETR's "no object" class; drop it before taking the best class
scores, labels = probabilities[..., :-1].max(dim=-1)  # 1 x 100, 1 x 100

# only keep detections with score > threshold
threshold = 0.9
selection = scores > threshold
scores, labels, boxes = scores[selection], labels[selection], boxes[selection]
target_sizes = torch.tensor(image.size[::-1])  # (height, width)

def convert_boxes(x, width, height):
    """Convert normalised (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1).

    x: tensor whose last dimension holds (cx, cy, w, h) in [0, 1] units.
    width, height: image size used to scale the x and y coordinates.
    Returns a tensor with the same leading dimensions and a last dimension of 4.
    """
    cx, cy, bw, bh = x.unbind(-1)
    half_w = 0.5 * bw
    half_h = 0.5 * bh
    x0 = (cx - half_w) * width
    y0 = (cy - half_h) * height
    x1 = (cx + half_w) * width
    y1 = (cy + half_h) * height
    return torch.stack([x0, y0, x1, y1], dim=-1)

# get boxes: scale normalised (cx, cy, w, h) into pixel (x0, y0, x1, y1)
boxes = convert_boxes(boxes, width=target_sizes[1], height=target_sizes[0])
print(boxes.int())

In [None]:
# visualization: draw each kept detection as a green box labelled "<class> <score>"
# on a BGR copy of the image (OpenCV's channel order)
FONT = cv2.FONT_HERSHEY_SIMPLEX  # named constant instead of the magic number 0
BOX_COLOR = (0, 255, 0)  # BGR green
disp = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
for score, label, box in zip(scores, labels, boxes):
    x0, y0, x1, y1 = (int(v) for v in box.tolist())
    cv2.rectangle(disp, (x0, y0), (x1, y1), BOX_COLOR, 1)
    label_text = f"{model.config.id2label[label.item()]} {score.item():.2f}"
    cv2.putText(disp, label_text, (x0 + 4, y0 + 16), FONT, 0.8, BOX_COLOR, 1)

cv2_imshow(disp)

In [None]:
#-----------------------------------------------------------------------------------

In [None]:
# print the decoder's forward() source, e.g. to see how it collects hidden_states
print(inspect.getsource(model.model.decoder.forward))

In [None]:
# Decoder with probes: re-run the decoder keeping every layer's hidden state, then apply the detection heads to each one

In [None]:
# Rebuild the decoder inputs (same values as in the decoder cell above):
# model.model.query_position_embeddings is Embedding(100, 256), broadcast over the batch,
# and the decoder input queries start as zeros.
batch_size = len(hidden_states)
query_position_embeddings = model.model.query_position_embeddings.weight[None].repeat(batch_size, 1, 1)  # 1 x 100 x 256
queries = query_position_embeddings.new_zeros(query_position_embeddings.shape)  # 1 x 100 x 256, all zeros

In [None]:
# sanity check of the rebuilt decoder inputs
batch_size, query_position_embeddings.shape, queries.shape

In [None]:
# same decoder call as before, but also return the hidden state of every layer
decoder_outputs = model.model.decoder(
    inputs_embeds=queries,
    query_position_embeddings=query_position_embeddings,
    attention_mask=None,
    encoder_hidden_states=hidden_states,
    encoder_attention_mask=flattened_mask,
    object_queries=object_queries,
    output_hidden_states=True,
    output_attentions=None,
    return_dict=None,
)

In [None]:
# fields present on the decoder output (hidden_states is included because output_hidden_states=True)
decoder_outputs.keys()

In [None]:
# stack the per-layer hidden states into one tensor: (number of entries) x 1 x 100 x 256
sequence_outputs = torch.stack(decoder_outputs["hidden_states"])
print(sequence_outputs.shape)

In [None]:
# find the best detection indices: rank the final decoder output's queries by
# their highest class probability (excluding the trailing "no object" class)
sequence_output = sequence_outputs[-1]
logits = model.class_labels_classifier(sequence_output)  # 1 x 100 x 92
probabilities = torch.nn.functional.softmax(logits, -1)  # 1 x 100 x 92
scores, labels = probabilities[..., :-1].max(-1)  # 1 x 100, 1 x 100
topK_instance = 5
# take .indices rather than unpacking into `_`, which would overwrite IPython's
# last-output variable
indices = scores.topk(topK_instance, sorted=True).indices  # 1 x topK_instance, best first
print(indices)

In [None]:
# apply head to the probes: run the detection heads on every decoder hidden state
# and draw the same top-K queries (picked from the final output) at each depth, so
# the images show how those detections evolve through the decoder.
# NOTE(review): in Hugging Face's DetrDecoder the final `layernorm` is applied only to
# the last entry of `hidden_states`; earlier entries are stored before it, while the
# heads were trained on normalised features. Normalise them here before probing --
# confirm against the source printed by `inspect.getsource(model.model.decoder.forward)`.
decoder_layernorm = model.model.decoder.layernorm
last_probe = len(sequence_outputs) - 1
top_indices = indices[0]  # query indices, best-scoring first (batch size 1)
target_sizes = torch.tensor(image.size[::-1])  # (height, width)
# BGR colors, one per rank; cycled if topK_instance exceeds the list length
colors = [(0, 0, 255), (0, 255, 255), (0, 255, 0), (255, 255, 0), (255, 0, 0),
          (255, 255, 255), (180, 180, 180), (80, 255, 180), (80, 80, 255), (255, 80, 80)]

disps = []
for i, sequence_output in enumerate(sequence_outputs):
    if i < last_probe:
        sequence_output = decoder_layernorm(sequence_output)

    # post-processing, same as for the final decoder output above
    logits = model.class_labels_classifier(sequence_output)  # 1 x 100 x 92
    boxes = model.bbox_predictor(sequence_output).sigmoid()  # 1 x 100 x 4
    probabilities = torch.nn.functional.softmax(logits, -1)  # 1 x 100 x 92
    scores, labels = probabilities[..., :-1].max(-1)  # 1 x 100, 1 x 100

    # keep the tracked top-K queries (not a score threshold); direct indexing keeps
    # rank order, so colors[k] always marks the k-th best final detection
    scores = scores[0, top_indices]
    labels = labels[0, top_indices]
    boxes = convert_boxes(boxes[0, top_indices], target_sizes[1], target_sizes[0])

    # visualization
    disp = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    for rank, (score, label, box) in enumerate(zip(scores, labels, boxes)):
        color = colors[rank % len(colors)]
        x0, y0, x1, y1 = (int(v) for v in box.tolist())
        cv2.rectangle(disp, (x0, y0), (x1, y1), color, 1)
        label_text = f"{model.config.id2label[label.item()]} {score.item():.2f}"
        cv2.putText(disp, label_text, (x0 + 4, y0 + 16), cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 1)

    disps.append(disp)

In [None]:
# visualize the detection progress through the decoder: one image per entry of
# decoder_outputs.hidden_states, in order (the last one is the final decoder output)
for i, disp in enumerate(disps):
    print(f"decoder hidden state {i}")
    cv2_imshow(disp)