## Easy Object Detection with Language Modeling: Simple Implementation of Pix2Seq model in PyTorch

![](https://raw.githubusercontent.com/moein-shariatnia/Pix2Seq/master/imgs/pix2seq%20-%20framework.png)

# Introduction

Object detection does not have to be a difficult task! I clearly remember the first time I implemented YOLO from scratch and it was a pain to understand how it works under the hood. For the beginners in computer vision applications, I believe that object detection is the hardest one to understand among classification, segmentation, and etc.

Once I first heard about the paper "[Pix2seq: A Language Modeling Framework for Object Detection](https://arxiv.org/abs/2109.10852)" , I got pretty damn excited and I was sure my next blog post will be about it; so, here I am writing this post and hoping that you'll like it and find the pix2seq model easy to understand and implement.

At the end of this tutorial, you'll learn to implement a simple model for object detection which produces the following results:

![link text](https://raw.githubusercontent.com/moein-shariatnia/Pix2Seq/master/imgs/results3.jpg)

## What's interesting about this paper

The idea is pretty simple: Reframe the object detection problem as a task of text (token) generation! We want the model to "tell us" what objects exist in the image and also the (x, y) coordinates of their bounding boxes (bboxes), all in a specific format in the generated sequence; just like text generation!

![](https://raw.githubusercontent.com/moein-shariatnia/Pix2Seq/master/imgs/pix2seq.png)

As you see, the object detection task is transformed to an image-captioning-ish task: describe the image in text (sequence) but this time tell us exactly where the objects are.

# Pix2Seq: Simple Implementation

## Needed Modules

The closest task to what Pix2Seq does is image-captioning. So, we are going to need an image encoder to convert an image into vectors of hidden representation and then a decoder to take the image representations and those of the previously generated tokens and predict the next token. We also need a tokenizer to convert object classes and coordinates into tokens that form their special vocabulary; just like the words in a natural language.

## My Simple Implementation of Pix2Seq

![](https://raw.githubusercontent.com/moein-shariatnia/Pix2Seq/master/imgs/pix2seq%20-%20framework.png)

You can see the high level pipeline of this project in the picture above. As you see, we need a dataset of images and their bboxes for which we will use Pascal VOC 2012 dataset. Next, we will write our own tokenizer from scratch to convert the bbox classes and coordinates into a sequence of tokens. Then, we will use DeiT [(from this paper)](https://arxiv.org/abs/2012.12877) as our image encoder and feed the image embeddings to a vanilla Transformer Decoder [(from this paper)](https://arxiv.org/abs/1706.03762?amp=1). The decoder's task is to predict the next token given the previous ones. The outputs of the decoder are given to the language modeling loss function.

# Installation

In [1]:
!pip install timm -q
!pip install transformers -q
!pip install huggingface_hub -q

# Imports

In [1]:
import gc
import os
import cv2
import math
import random
from glob import glob
import numpy as np
import pandas as pd
from functools import partial
from tqdm import tqdm
import matplotlib.pyplot as plt

import albumentations as A
import xml.etree.ElementTree as ET
from sklearn.model_selection import StratifiedGroupKFold

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import timm
from timm.models.layers import trunc_normal_

import transformers
from transformers import top_k_top_p_filtering
from transformers import get_linear_schedule_with_warmup

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed=42)

# Config

This will be where we store the most important variables in order to have a quick access to them.

In [41]:
class CFG:
    img_path = "/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES"
    xml_path =  "/home/achazhoor/Documents/workspace/pix_2_seq/data/ANNOTATIONS"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    max_len = 300
    img_size = 224
    num_bins = img_size
    
    batch_size = 4
    epochs = 100
    
    model_name = 'deit3_medium_patch16_224.fb_in22k_ft_in1k'
    num_patches = 196
    lr = 1e-5
    weight_decay = 1e-4

    generation_steps = 101
    l1_lambda = 1e-7             #kept high as the model is complex
    patience = 12

# Download and Extract data

In [42]:
IMG_FILES = glob(CFG.img_path + "/*.jpg")
XML_FILES = glob(CFG.xml_path + "/*.xml")
len(XML_FILES), len(IMG_FILES)

(1800, 1800)

## Process XML files and build the dataframe

In [43]:
class XMLParser:
    def __init__(self,xml_file):

        self.xml_file = xml_file
        self._root = ET.parse(self.xml_file).getroot()
        self._objects = self._root.findall("object")
        # path to the image file as describe in the xml file
        self.img_path = os.path.join(CFG.img_path, self._root.find('filename').text)
        # image id 
        self.image_id = self._root.find("filename").text
        # names of the classes contained in the xml file
        self.names = self._get_names()
        # coordinates of the bounding boxes
        self.boxes = self._get_bndbox()

    def parse_xml(self):
        """"Parse the xml file returning the root."""
    
        tree = ET.parse(self.xml_file)
        return tree.getroot()

    def _get_names(self):

        names = []
        for obj in self._objects:
            name = obj.find("name")
            names.append(name.text)

        return np.array(names)

    def _get_bndbox(self):

        boxes = []
        for obj in self._objects:
            coordinates = []
            bndbox = obj.find("bndbox")
            coordinates.append(np.int32(bndbox.find("xmin").text))
            coordinates.append(np.int32(np.float32(bndbox.find("ymin").text)))
            coordinates.append(np.int32(bndbox.find("xmax").text))
            coordinates.append(np.int32(bndbox.find("ymax").text))
            boxes.append(coordinates)

        return np.array(boxes)

def xml_files_to_df(xml_files):
    
    """"Return pandas dataframe from list of XML files."""
    
    names = []
    boxes = []
    image_id = []
    xml_path = []
    img_path = []
    for f in xml_files:
        xml = XMLParser(f)
        names.extend(xml.names)
        boxes.extend(xml.boxes)
        image_id.extend([xml.image_id] * len(xml.names))
        xml_path.extend([xml.xml_file] * len(xml.names))
        img_path.extend([xml.img_path] * len(xml.names))
    a = {"image_id": image_id,
         "names": names,
         "boxes": boxes,
         "xml_path":xml_path,
         "img_path":img_path}
    
    df = pd.DataFrame.from_dict(a, orient='index')
    df = df.transpose()
    
    df['xmin'] = -1
    df['ymin'] = -1
    df['xmax'] = -1
    df['ymax'] = -1

    df[['xmin','ymin','xmax','ymax']] = np.stack([df['boxes'][i] for i in range(len(df['boxes']))])

    df.drop(columns=['boxes'], inplace=True)
    df['xmin'] = df['xmin'].astype('float32')
    df['ymin'] = df['ymin'].astype('float32')
    df['xmax'] = df['xmax'].astype('float32')
    df['ymax'] = df['ymax'].astype('float32')
    
    df['id'] = df['image_id'].map(lambda x: x.split(".jpg")[0])
    
    return df

def build_df(xml_files):
    # parse xml files and create pandas dataframe
    df = xml_files_to_df(xml_files)
    

    classes = sorted(df['names'].unique())
    cls2id = {cls_name: i for i, cls_name in enumerate(classes)}
    df['label'] = df['names'].map(cls2id)
    
    # in this df, each object of a given image is in a separate row
    df = df[['id', 'label', 'xmin', 'ymin', 'xmax', 'ymax', 'img_path']]
    
    return df, classes

In [44]:
df, classes = build_df(XML_FILES)
cls2id = {cls_name: i for i, cls_name in enumerate(classes)}
id2cls = {i: cls_name for i, cls_name in enumerate(classes)}

print(len(classes))
df.head()

6


Unnamed: 0,id,label,xmin,ymin,xmax,ymax,img_path
0,patches_184,2,34.0,96.0,137.0,186.0,/home/achazhoor/Documents/workspace/pix_2_seq/...
1,patches_184,2,89.0,1.0,126.0,51.0,/home/achazhoor/Documents/workspace/pix_2_seq/...
2,patches_184,2,161.0,1.0,200.0,21.0,/home/achazhoor/Documents/workspace/pix_2_seq/...
3,patches_184,1,84.0,62.0,110.0,99.0,/home/achazhoor/Documents/workspace/pix_2_seq/...
4,patches_184,1,92.0,185.0,121.0,200.0,/home/achazhoor/Documents/workspace/pix_2_seq/...


In [45]:


df['img_path'] = df['img_path'].apply(lambda x: x if x.lower().endswith('.jpg') else f"{x}.jpg")
#train_df = train_df[train_df['ids'] != 'patches_211.jpg']
#print("THis is inside the trail", df.head(11))
# df.to_excel("filtered_dataframe.xlsx", index=False)

# print("DataFrame has been exported to 'filtered_dataframe.xlsx'")

# Initialize an empty list to store paths that don't exist
missing_paths = []

# Initialize an empty list to store paths that don't exist
missing_paths = []

# Loop through each image path in the DataFrame
for index, row in df.iterrows():
    img_path = row['img_path']
    if not os.path.exists(img_path):
        missing_paths.append(img_path)

# Print the list of missing paths and their total number
if missing_paths:
    print(f"The total number of missing image paths is {len(missing_paths)}.")
    print("The following image paths do not exist:")
    for path in missing_paths:
        print(path)
else:
    print("All image paths in the DataFrame exist.")

len(df)



All image paths in the DataFrame exist.


4189

## Split dataframe to train and validation sets

In [46]:
def split_df(df, n_folds=17, training_fold=0):
    mapping = df.groupby("id")['img_path'].agg(len).to_dict()
    df['stratify'] = df['id'].map(mapping)

    kfold = StratifiedGroupKFold(
        n_splits=n_folds, shuffle=True, random_state=42)

    for i, (_, val_idx) in enumerate(kfold.split(df, y=df['stratify'], groups=df['id'])):
        df.loc[val_idx, 'fold'] = i

    train_df = df[df['fold'] != training_fold].reset_index(drop=True)
    valid_df = df[df['fold'] == training_fold].reset_index(drop=True)

    return train_df, valid_df

In [47]:
train_df, valid_df = split_df(df)
print("Train size: ", train_df['id'].nunique())
print("Valid size: ", valid_df['id'].nunique())

Train size:  1694
Valid size:  106


# Building Dataset and Data Loaders

As I mentioned earlier, we will use VOC 2012 dataset with 17125 images and their corresponding objects from 20 classes. The paper uses COCO dataset which is an order of magnitude larger than VOC and they also pre-train the models on a much larger dataset before training on COCO. But, to stay simple, I'm gonna use this rather small VOC dataset.

```classes = [
  "aeroplane", "bicycle", "bird", "boat", "bottle", 
  "bus", "car", "cat", "chair", "cow", 
  "diningtable", "dog", "horse", "motorbike", "person"
  "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]```

In [48]:
import albumentations as A

def get_transform_train(size):
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.GaussianBlur(blur_limit=(3, 7), p=0.2),
        #A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.2),
        A.Resize(size, size),
        A.Normalize(),
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

def get_transform_valid(size):
    return A.Compose([
        A.Resize(size, size),
        A.Normalize(),
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})


We need a PyTorch dataset class that gives us an image and its bbox coordinates and classes in form of a sequence.

In [49]:
class VOCDataset(torch.utils.data.Dataset):
    def __init__(self, df, transforms=None, tokenizer=None):
        self.ids = df['id'].unique()
        self.df = df
        self.transforms = transforms
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        sample = self.df[self.df['id'] == self.ids[idx]]
        img_path = sample['img_path'].values[0]

        img = cv2.imread(img_path)[..., ::-1]
        labels = sample['label'].values
        bboxes = sample[['xmin', 'ymin', 'xmax', 'ymax']].values

        if self.transforms is not None:
            transformed = self.transforms(**{
                'image': img,
                'bboxes': bboxes,
                'labels': labels
            })
            img = transformed['image']
            bboxes = transformed['bboxes']
            labels = transformed['labels']

        img = torch.FloatTensor(img).permute(2, 0, 1)

        if self.tokenizer is not None:
            seqs = self.tokenizer(labels, bboxes)
            seqs = torch.LongTensor(seqs)
            return img, seqs

        return img, labels, bboxes

    def __len__(self):
        return len(self.ids)

As you see, most of the code here is what you expect from a simple dataset for classification but there are small differences too. We need a Tokenizer to convert our labels and bbox coordinates (x and y) to a sequence so that we can perform train our model for the language modeling task (predicting the next tokens conditioned on the previously seen tokens).

## Tokenizer

How are we going to convert these information into a sequence? Well, it's not that difficult. To represent an object in an image, we need 5 numbers: 4 coordinate numbers and 1 to indicate which class it belongs to. 
You actually need to know the coordinates of 2 points of a bounding box to be able to draw it in an image; in pascal format, we use the top left point and the bottom right point of the bbox as those 2 critical points and each point is represented by its x and y values → so, we will need 4 numbers overall to draw a bounding box. You can see alternative formats to represent a bounding box down below. Also, look at where the start of x and y axis is (the 0, 0 point).

![](https://albumentations.ai/docs/images/getting_started/augmenting_bboxes/bbox_example.jpg)

As you see in the dataset's code, we give the bbox coordinates and labels to our tokenizer and get a simple list of tokens out. The tokenizer needs to do the following tasks:
1. mark the start and end of the sequence w/ especial tokens (BOS and EOS tokens).
2. quantize the continuous value of coordinates (we can have x=34.7 as the coordinate of a point but we need discrete values like 34 as our tokens because we are finally doing a classification on a finite set of tokens)
3. encode the label of the objects into their corresponding tokens
4. randomize the order of objects in the final sequence (more on this below)

If you are familiar with NLP applications, these steps might sound familiar to you as they are also done when we are dealing with words in a natural language; we need to tokenize them and assign each word to its own discrete token, mark the start and end of the sequence, etc.
Regarding the number 4 in this list, this is what the paper does and there is an extensive ablation study on whether it is a good idea. What it says is that each time that we show the same image to the model (in different epochs), we randomize the order in which the objects appear in the corresponding sequence which we feed to the model (with one token shifted) and our loss function. For example, if there is a "person", a "car", and a "cat" in an image, the tokenizer and dataset will put these objects in random order in the sequence: 
- BOS, car_xmin, car_ymin, car_xmax, car_ymax, car_label, person_xmin, person_ymin, person_xmax, person_ymax, person_label, cat_xmin, cat_ymin, cat_xmax, cat_ymax, cat_label, EOS
- BOS, person_xmin, person_ymin, person_xmax, person_ymax, person_label, car_xmin, car_ymin, car_xmax, car_ymax, car_label, cat_xmin, cat_ymin, cat_xmax, cat_ymax, cat_label, EOS
- …

In [50]:
class Tokenizer:
    def __init__(self, num_classes: int, num_bins: int, width: int, height: int, max_len=500):
        self.num_classes = num_classes
        self.num_bins = num_bins
        self.width = width
        self.height = height
        self.max_len = max_len

        self.BOS_code = num_classes + num_bins
        self.EOS_code = self.BOS_code + 1
        self.PAD_code = self.EOS_code + 1

        self.vocab_size = num_classes + num_bins + 3

    def quantize(self, x: np.array):
        """
        x is a real number in [0, 1]
        """
        return (x * (self.num_bins - 1)).astype('int')
    
    def dequantize(self, x: np.array):
        """
        x is an integer between [0, num_bins-1]
        """
        return x.astype('float32') / (self.num_bins - 1)

    def __call__(self, labels: list, bboxes: list, shuffle=True):
        assert len(labels) == len(bboxes), "labels and bboxes must have the same length"
        bboxes = np.array(bboxes)
        labels = np.array(labels)
        labels += self.num_bins
        labels = labels.astype('int')[:self.max_len]

        bboxes[:, 0] = bboxes[:, 0] / self.width
        bboxes[:, 2] = bboxes[:, 2] / self.width
        bboxes[:, 1] = bboxes[:, 1] / self.height
        bboxes[:, 3] = bboxes[:, 3] / self.height

        bboxes = self.quantize(bboxes)[:self.max_len]

        if shuffle:
            rand_idxs = np.arange(0, len(bboxes))
            np.random.shuffle(rand_idxs)
            labels = labels[rand_idxs]
            bboxes = bboxes[rand_idxs]

        tokenized = [self.BOS_code]
        for label, bbox in zip(labels, bboxes):
            tokens = list(bbox)
            tokens.append(label)

            tokenized.extend(list(map(int, tokens)))
        tokenized.append(self.EOS_code)

        return tokenized    
    
    def decode(self, tokens: torch.tensor):
        """
        toekns: torch.LongTensor with shape [L]
        """
        mask = tokens != self.PAD_code
        tokens = tokens[mask]
        tokens = tokens[1:-1]
        assert len(tokens) % 5 == 0, "invalid tokens"

        labels = []
        bboxes = []
        for i in range(4, len(tokens)+1, 5):
            label = tokens[i]
            bbox = tokens[i-4: i]
            labels.append(int(label))
            bboxes.append([int(item) for item in bbox])
        labels = np.array(labels) - self.num_bins
        bboxes = np.array(bboxes)
        bboxes = self.dequantize(bboxes)
        
        bboxes[:, 0] = bboxes[:, 0] * self.width
        bboxes[:, 2] = bboxes[:, 2] * self.width
        bboxes[:, 1] = bboxes[:, 1] * self.height
        bboxes[:, 3] = bboxes[:, 3] * self.height
        
        return labels, bboxes

Another note on how to quantize the continuous values of coordinates: imagine that the image size is 224. You can have a bbox with these 4 coordinates (12.2, 35.8, 68.1, 120.5). 
You will need at least 224 tokens (num_bins) to be able to tokenize (quantize) these 4 numbers with a precision of 1 pixel (you will lose information below 1 pixel). As you see in the tokenizer code, to convert this bbox coordinates to their tokenized version, we need to do the following:
1. normalize the coordinates (make them between 0 and 1 by dividing them by the max value = 224)
2. do this: ```int(x * (num_bins-1))```

so, the converted version will be: (12, 35, 67, 119). Remember that int() function in Python does not round the number to the closest integer, but it will keep only the integer part of the number. As you see, we have lost some information on the exact position of the bbox but it is still a very good approximation. We can use a larger number of tokens (num of bins, as stated in the paper) and we will have a more precise location. Our tokenizer also has decode() function which we will use to convert sequences into bbox coordinates and labels.

In [51]:
tokenizer = Tokenizer(num_classes=len(classes), num_bins=CFG.num_bins,
                          width=CFG.img_size, height=CFG.img_size, max_len=CFG.max_len)
CFG.pad_idx = tokenizer.PAD_code

In [52]:
classes

['crazing',
 'inclusion',
 'patches',
 'pitted_surface',
 'rolled-in_scale',
 'scratches']

## Collate Function

Here, we will implement a custom collate_function to give to our PyTorch data loader. This function will take care of Padding for us: to make all the sequences the same length by adding PAD_IDX to the shorter ones in order to be able to build a batch with them. We are going to pad the sequence to a fixed max length of 300 tokens.

In [53]:
def collate_fn(batch, max_len, pad_idx):
    """
    if max_len:
        the sequences will all be padded to that length
    """
    image_batch, seq_batch = [], []
    for image, seq in batch:
        image_batch.append(image)
        seq_batch.append(seq)

    seq_batch = pad_sequence(
        seq_batch, padding_value=pad_idx, batch_first=True)
    if max_len:
        pad = torch.ones(seq_batch.size(0), max_len -
                         seq_batch.size(1)).fill_(pad_idx).long()
        seq_batch = torch.cat([seq_batch, pad], dim=1)
    image_batch = torch.stack(image_batch)
    return image_batch, seq_batch

In [54]:
def get_loaders(train_df, valid_df, tokenizer, img_size, batch_size, max_len, pad_idx, num_workers=2):

    train_ds = VOCDataset(train_df, transforms=get_transform_train(
        img_size), tokenizer=tokenizer)

    trainloader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
        num_workers=num_workers,
        pin_memory=True,
    )

    valid_ds = VOCDataset(valid_df, transforms=get_transform_valid(
        img_size), tokenizer=tokenizer)

    validloader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=partial(collate_fn, max_len=max_len, pad_idx=pad_idx),
        num_workers=2,
        pin_memory=True,
    )

    return trainloader, validloader

In [55]:
train_loader, valid_loader = get_loaders(
        train_df, valid_df, tokenizer, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)

# Models

Finally arrived to the coolest part for every deep learning lover: The Model 😍
Let's take a second look at the first image of this tutorial. First, we will need an encoder to take the input image and give us some embeddings (representations). The paper uses a ResNet50 (and also in other experiments uses ViT) but I decided to use DeiT. As the name suggests, this is a data efficient vision transformer and I thought it would be a good fit for our small dataset. Like ViT, it splits the image into patches and processes them like words in a sentence which again could be great for our task, as we will have a separate embedding for each of these patches and we can give them to our decoder in the next section to predict the target sequence (see it like translation from English to French, where our image is like a sentence in English and our target sequence containing the coordinates and labels of bboxes is like the equivalent sentence in French).
I will use timm library to implement a pre-trained DeiT model.

In [56]:
class Encoder(nn.Module):
    def __init__(self, model_name='deit3_medium_patch16_224.fb_in22k_ft_in1k', pretrained=False, out_dim=256):
        super().__init__()
        self.model = timm.create_model(
            model_name, num_classes=0, global_pool='', pretrained=pretrained)
        self.bottleneck = nn.AdaptiveAvgPool1d(out_dim)

    def forward(self, x):
        features = self.model(x)
        return self.bottleneck(features[:, 1:])

Resnext Encoder

The bottleneck layer is to reduce the number of features of these embeddings to that of the decoder. The paper used a decoder dim of 256 and that's the reason why I am reducing it here using Average Pooling. Also, the first token in this model relates to the CLS token and I am skipping it in the forward method (```features[:, 1:]```).

In [57]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, encoder_length, dim, num_heads, num_layers):
        super().__init__()
        self.dim = dim
        
        self.embedding = nn.Embedding(vocab_size, dim)
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, CFG.max_len-1, dim) * .02)
        self.decoder_pos_drop = nn.Dropout(p=0.05)
        
        decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=num_heads)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output = nn.Linear(dim, vocab_size)
        
        
        self.encoder_pos_embed = nn.Parameter(torch.randn(1, encoder_length, dim) * .02)
        self.encoder_pos_drop = nn.Dropout(p=0.05)
        
        self.init_weights()
        
    def init_weights(self):
        for name, p in self.named_parameters():
            if 'encoder_pos_embed' in name or 'decoder_pos_embed' in name: 
                print("skipping pos_embed...")
                continue
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
        trunc_normal_(self.encoder_pos_embed, std=.02)
        trunc_normal_(self.decoder_pos_embed, std=.02)
        
    
    def forward(self, encoder_out, tgt):
        """
        encoder_out: shape(N, L, D)
        tgt: shape(N, L)
        """
        
        tgt_mask, tgt_padding_mask = create_mask(tgt)
        tgt_embedding = self.embedding(tgt)
        tgt_embedding = self.decoder_pos_drop(
            tgt_embedding + self.decoder_pos_embed
        )
        
        encoder_out = self.encoder_pos_drop(
            encoder_out + self.encoder_pos_embed
        )
        
        encoder_out = encoder_out.transpose(0, 1)
        tgt_embedding = tgt_embedding.transpose(0, 1)
        
        preds = self.decoder(memory=encoder_out, 
                             tgt=tgt_embedding,
                             tgt_mask=tgt_mask, 
                             tgt_key_padding_mask=tgt_padding_mask)
        
        preds = preds.transpose(0, 1)
        return self.output(preds)
    
    def predict(self, encoder_out, tgt):
        length = tgt.size(1)
        padding = torch.ones(tgt.size(0), CFG.max_len-length-1).fill_(CFG.pad_idx).long().to(tgt.device)
        tgt = torch.cat([tgt, padding], dim=1)
        tgt_mask, tgt_padding_mask = create_mask(tgt)
        # is it necessary to multiply it by math.sqrt(d) ?
        tgt_embedding = self.embedding(tgt)
        tgt_embedding = self.decoder_pos_drop(
            tgt_embedding + self.decoder_pos_embed
        )
        
        encoder_out = self.encoder_pos_drop(
            encoder_out + self.encoder_pos_embed
        )
        
        encoder_out = encoder_out.transpose(0, 1)
        tgt_embedding = tgt_embedding.transpose(0, 1)
        
        preds = self.decoder(memory=encoder_out, 
                             tgt=tgt_embedding,
                             tgt_mask=tgt_mask, 
                             tgt_key_padding_mask=tgt_padding_mask)
        
        preds = preds.transpose(0, 1)
        return self.output(preds)[:, length-1, :]

Our decoder takes the patch embeddings of the input image and learns to predict the sequence containing bboxes. Here I am using PyTorch nn.TransformerDecoder module to implement a 6 layer decoder with a feature dimension of 256. We also need to add positional embeddings to the embeddings so that the model knows about each token's position in the sequence (I am adding positional embedding for both encoder tokens and decoder tokens. While we have to do this for the decoder, we might not need to add them to the encoder tokens as the DeiT model knows about the order of patches itself). I am doing this by those nn.Parameter modules which will learn 1 parameter per token position. Finally, we will use a nn.Linear layer to predict the next token from our vocabulary.
The ```create_mask()``` function (you will see its definition in tge next section named **Utils**) gives us two masks needed for training the decoder: one to tell the model to ignore the PAD tokens and do not incorporate them in its attention modules and another to mask the future tokens in order to make the decoder predict tokens only by looking at the current token and the previous ones.

The decoder's predict method takes the previously generated tokens, pads them to the max_length and predicts the next token for each sequence in the batch and returns those new tokens.

In [58]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, image, tgt):
        encoder_out = self.encoder(image)
        preds = self.decoder(encoder_out, tgt)
        return preds
    def predict(self, image, tgt):
        encoder_out = self.encoder(image)
        preds = self.decoder.predict(encoder_out, tgt)
        return preds

This is a simple class encapsulating the encoder and decoder. It also has a predict function which calls the predict function of Decoder to detect objects in an image.

In [59]:
encoder = Encoder(model_name=CFG.model_name, pretrained=True, out_dim=256)
decoder = Decoder(vocab_size=tokenizer.vocab_size,
                  encoder_length=CFG.num_patches, dim=256, num_heads=4, num_layers=2)
model = EncoderDecoder(encoder, decoder)
model.to(CFG.device);

skipping pos_embed...
skipping pos_embed...


# Train and Eval

Now let's see how we can train this model. Most of the following code is just standard PyTorch training boilerplate but there is a simple but important point in it. As mentioned earlier, we train the model like a language model (GPT for e.g.) and it works like this → the model needs to predict the next token only seeing the previous ones (tokens to the left). At the start, it only sees the BOS sentence and it needs to predict the next token, and so on and so forth. And this is achieved simply by this part:
1. ```y_input = y[:, :-1]```
2. ```y_expected = y[:, 1:]```
3. ```preds = model(x, y_input)```

In [60]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, criterion, logger=None):
    model.train()
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))

    l1_lambda = CFG.l1_lambda  # L1 regularization strength

    for x, y in tqdm_object:
        x, y = x.to(CFG.device, non_blocking=True), y.to(CFG.device, non_blocking=True)
        
        y_input = y[:, :-1]
        y_expected = y[:, 1:]
        
        preds = model(x, y_input)
        loss = criterion(preds.reshape(-1, preds.shape[-1]), y_expected.reshape(-1))

        # Calculate L1 regularization
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        total_loss = loss + l1_lambda * l1_norm

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        loss_meter.update(total_loss.item(), x.size(0))  # Update with total loss including L1 penalty

        lr = get_lr(optimizer)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=f"{lr:.6f}")
        if logger is not None:
            logger.log({"train_step_loss": loss_meter.avg, 'lr': lr})
    
    return loss_meter.avg


In [61]:
from sklearn.metrics import f1_score, recall_score, confusion_matrix

def valid_epoch(model, valid_loader, criterion):
    model.eval()
    loss_meter = AvgMeter()
    all_preds = []
    all_targets = []
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))

    with torch.no_grad():
        for x, y in tqdm_object:
            x, y = x.to(CFG.device, non_blocking=True), y.to(CFG.device, non_blocking=True)
            y_input = y[:, :-1]
            y_expected = y[:, 1:]

            preds = model(x, y_input)
            loss = criterion(preds.reshape(-1, preds.shape[-1]), y_expected.reshape(-1))
            loss_meter.update(loss.item(), x.size(0))

            predicted_classes = preds.argmax(dim=-1)
            all_preds.extend(predicted_classes.reshape(-1).cpu().numpy())
            all_targets.extend(y_expected.reshape(-1).cpu().numpy())

    # Calculate F1 score, sensitivity, and specificity
    f1 = f1_score(all_targets, all_preds, average='macro')
    #sensitivity = recall_score(all_targets, all_preds, average='macro')
    sensitivity = recall_score(all_targets, all_preds, average='macro', zero_division=1)


    # Specificity calculation
    cm = confusion_matrix(all_targets, all_preds)
    tn = cm.sum(axis=1) - np.diag(cm)  
    fp = cm.sum(axis=0) - np.diag(cm)
    epsilon = 1e-6  # Small value to avoid division by zero
    specificity_per_class = (tn + epsilon) / (tn + fp + epsilon)

    average_specificity = np.nanmean(specificity_per_class)

    return loss_meter.avg, f1, sensitivity, average_specificity


In [62]:
def train_eval(model, train_loader, valid_loader, criterion, optimizer, lr_scheduler, step, logger):
    best_loss = float('inf')
    epochs_since_improvement = 0  # Counter for epochs since last improvement
    patience = CFG.patience  # Set the patience for early stopping

    for epoch in range(CFG.epochs):
        print(f"Epoch {epoch + 1}")
        if logger is not None:
            logger.log({"Epoch": epoch + 1})

        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler if step == 'batch' else None, criterion, logger=logger)
        valid_loss, f1, sensitivity, average_specificity = valid_epoch(model, valid_loader, criterion)        
        print(f"Valid Loss: {valid_loss:.3f}, F1 Score: {f1:.3f}, Sensitivity: {sensitivity:.3f}, Specificity: {average_specificity:.3f}")
        if valid_loss < best_loss:
            best_loss = valid_loss
            epochs_since_improvement = 0
            torch.save(model.state_dict(), 'best_valid_loss.pth')
            print("Saved Best Model")
            if logger is not None:
                logger.save('best_valid_loss.pth')
        else:
            epochs_since_improvement += 1

        if logger is not None:
            logger.log({'train_loss': train_loss, 'valid_loss': valid_loss})

        if epochs_since_improvement >= patience:
            print("Early stopping triggered")
            break


## Utils

In [63]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=CFG.device))
            == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float(
        '-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(tgt):
    """
    tgt: shape(N, L)
    """
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    tgt_padding_mask = (tgt == CFG.pad_idx)

    return tgt_mask, tgt_padding_mask


class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0]*3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalCrossEntropy(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='none'):
        super(FocalCrossEntropy, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Number of classes
        num_classes = inputs.shape[1]

        # Softmax over the inputs
        softmax_p = F.softmax(inputs, dim=1)

        # Create one-hot encoding for targets
        targets_one_hot = F.one_hot(targets, num_classes).type_as(inputs)

        # Compute the focal loss components
        alpha_factor = torch.ones(targets_one_hot.shape).type_as(inputs) * self.alpha
        alpha_factor = torch.where(torch.eq(targets_one_hot, 1.), alpha_factor, 1. - alpha_factor)
        focal_weight = torch.where(torch.eq(targets_one_hot, 1.), 1. - softmax_p, softmax_p)
        focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)

        # Cross entropy loss
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        ce_loss = ce_loss.unsqueeze(1)
        ce_loss = ce_loss.repeat(1, num_classes)

        # Compute the final focal loss
        focal_loss = focal_weight * ce_loss

        # Reduce the loss based on the reduction parameter
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss

In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Focal Loss for multi-class classification.
        - alpha: Weighting factor for the rare classes.
        - gamma: Focusing parameter to emphasize hard examples.
        - reduction: Reduction type to apply to the final loss. Options: 'none', 'mean', 'sum'.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Forward pass for the focal loss.
        - inputs: Logits from the model (shape: [batch_size, num_classes]).
        - targets: Ground truth labels (shape: [batch_size]).
        """
        # Convert targets to one-hot encoding
        targets_one_hot = F.one_hot(targets, num_classes=inputs.size(1)).type_as(inputs)

        # Compute the softmax over the inputs
        softmax_probs = F.softmax(inputs, dim=1)

        # Compute the focal loss components
        focal_weight = torch.pow(1.0 - softmax_probs, self.gamma)
        focal_loss = -self.alpha * focal_weight * torch.log(softmax_probs)
        focal_loss = torch.sum(focal_loss * targets_one_hot, dim=1)

        # Apply reduction
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss


In [66]:
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

num_training_steps = CFG.epochs * (len(train_loader.dataset) // CFG.batch_size)
num_warmup_steps = int(0.05 * num_training_steps)
lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                               num_training_steps=num_training_steps,
                                               num_warmup_steps=num_warmup_steps)
criterion = nn.CrossEntropyLoss(ignore_index=CFG.pad_idx)
#criterion = FocalCrossEntropy(alpha=0.25, gamma=2.0, reduction='mean')
#criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='mean')
train_eval(model,
           train_loader,
           valid_loader,
           criterion,
           optimizer,
           lr_scheduler=lr_scheduler,
           step='batch',
           logger=None)

Epoch 1


100%|██████████| 424/424 [00:18<00:00, 23.04it/s, lr=0.000002, train_loss=5.9] 
100%|██████████| 27/27 [00:00<00:00, 97.61it/s] 


Valid Loss: 5.113, F1 Score: 0.002, Sensitivity: 0.033, Specificity: 0.878
Saved Best Model
Epoch 2


100%|██████████| 424/424 [00:18<00:00, 22.82it/s, lr=0.000004, train_loss=5.07]
100%|██████████| 27/27 [00:00<00:00, 94.81it/s]


Valid Loss: 4.661, F1 Score: 0.002, Sensitivity: 0.025, Specificity: 0.959
Saved Best Model
Epoch 3


100%|██████████| 424/424 [00:18<00:00, 22.55it/s, lr=0.000006, train_loss=4.81]
100%|██████████| 27/27 [00:00<00:00, 95.44it/s] 


Valid Loss: 4.485, F1 Score: 0.002, Sensitivity: 0.031, Specificity: 0.958
Saved Best Model
Epoch 4


100%|██████████| 424/424 [00:18<00:00, 22.51it/s, lr=0.000008, train_loss=4.63]
100%|██████████| 27/27 [00:00<00:00, 95.83it/s] 


Valid Loss: 4.333, F1 Score: 0.004, Sensitivity: 0.036, Specificity: 0.958
Saved Best Model
Epoch 5


100%|██████████| 424/424 [00:18<00:00, 23.06it/s, lr=0.000010, train_loss=4.48]
100%|██████████| 27/27 [00:00<00:00, 98.59it/s] 


Valid Loss: 4.177, F1 Score: 0.005, Sensitivity: 0.037, Specificity: 0.956
Saved Best Model
Epoch 6


100%|██████████| 424/424 [00:18<00:00, 22.94it/s, lr=0.000010, train_loss=4.25]
100%|██████████| 27/27 [00:00<00:00, 92.46it/s]


Valid Loss: 3.937, F1 Score: 0.005, Sensitivity: 0.044, Specificity: 0.952
Saved Best Model
Epoch 7


100%|██████████| 424/424 [00:18<00:00, 22.44it/s, lr=0.000010, train_loss=4.06]
100%|██████████| 27/27 [00:00<00:00, 94.04it/s]


Valid Loss: 3.836, F1 Score: 0.004, Sensitivity: 0.040, Specificity: 0.956
Saved Best Model
Epoch 8


100%|██████████| 424/424 [00:18<00:00, 22.45it/s, lr=0.000010, train_loss=3.98]
100%|██████████| 27/27 [00:00<00:00, 94.74it/s]


Valid Loss: 3.778, F1 Score: 0.005, Sensitivity: 0.046, Specificity: 0.949
Saved Best Model
Epoch 9


100%|██████████| 424/424 [00:18<00:00, 22.35it/s, lr=0.000010, train_loss=3.94]
100%|██████████| 27/27 [00:00<00:00, 93.45it/s]


Valid Loss: 3.744, F1 Score: 0.005, Sensitivity: 0.046, Specificity: 0.945
Saved Best Model
Epoch 10


100%|██████████| 424/424 [00:18<00:00, 22.40it/s, lr=0.000009, train_loss=3.89]
100%|██████████| 27/27 [00:00<00:00, 92.87it/s]


Valid Loss: 3.701, F1 Score: 0.006, Sensitivity: 0.047, Specificity: 0.946
Saved Best Model
Epoch 11


100%|██████████| 424/424 [00:18<00:00, 22.41it/s, lr=0.000009, train_loss=3.86]
100%|██████████| 27/27 [00:00<00:00, 93.75it/s]


Valid Loss: 3.683, F1 Score: 0.008, Sensitivity: 0.042, Specificity: 0.942
Saved Best Model
Epoch 12


100%|██████████| 424/424 [00:18<00:00, 22.38it/s, lr=0.000009, train_loss=3.83]
100%|██████████| 27/27 [00:00<00:00, 92.22it/s]


Valid Loss: 3.666, F1 Score: 0.008, Sensitivity: 0.046, Specificity: 0.942
Saved Best Model
Epoch 13


100%|██████████| 424/424 [00:19<00:00, 22.25it/s, lr=0.000009, train_loss=3.81]
100%|██████████| 27/27 [00:00<00:00, 93.53it/s]


Valid Loss: 3.657, F1 Score: 0.010, Sensitivity: 0.041, Specificity: 0.937
Saved Best Model
Epoch 14


100%|██████████| 424/424 [00:19<00:00, 22.27it/s, lr=0.000009, train_loss=3.79]
100%|██████████| 27/27 [00:00<00:00, 93.74it/s]


Valid Loss: 3.644, F1 Score: 0.009, Sensitivity: 0.042, Specificity: 0.940
Saved Best Model
Epoch 15


100%|██████████| 424/424 [00:18<00:00, 22.36it/s, lr=0.000009, train_loss=3.77]
100%|██████████| 27/27 [00:00<00:00, 88.29it/s]


Valid Loss: 3.626, F1 Score: 0.008, Sensitivity: 0.042, Specificity: 0.924
Saved Best Model
Epoch 16


100%|██████████| 424/424 [00:19<00:00, 22.26it/s, lr=0.000009, train_loss=3.76]
100%|██████████| 27/27 [00:00<00:00, 92.32it/s]


Valid Loss: 3.637, F1 Score: 0.010, Sensitivity: 0.051, Specificity: 0.916
Epoch 17


100%|██████████| 424/424 [00:19<00:00, 22.16it/s, lr=0.000009, train_loss=3.75]
100%|██████████| 27/27 [00:00<00:00, 92.53it/s]


Valid Loss: 3.614, F1 Score: 0.010, Sensitivity: 0.047, Specificity: 0.908
Saved Best Model
Epoch 18


100%|██████████| 424/424 [00:19<00:00, 22.23it/s, lr=0.000009, train_loss=3.72]
100%|██████████| 27/27 [00:00<00:00, 87.67it/s]


Valid Loss: 3.619, F1 Score: 0.011, Sensitivity: 0.047, Specificity: 0.916
Epoch 19


100%|██████████| 424/424 [00:19<00:00, 22.19it/s, lr=0.000009, train_loss=3.72]
100%|██████████| 27/27 [00:00<00:00, 92.86it/s]


Valid Loss: 3.634, F1 Score: 0.010, Sensitivity: 0.051, Specificity: 0.895
Epoch 20


100%|██████████| 424/424 [00:19<00:00, 22.17it/s, lr=0.000008, train_loss=3.71]
100%|██████████| 27/27 [00:00<00:00, 87.38it/s]


Valid Loss: 3.597, F1 Score: 0.011, Sensitivity: 0.048, Specificity: 0.901
Saved Best Model
Epoch 21


100%|██████████| 424/424 [00:19<00:00, 22.13it/s, lr=0.000008, train_loss=3.69]
100%|██████████| 27/27 [00:00<00:00, 93.24it/s]


Valid Loss: 3.614, F1 Score: 0.014, Sensitivity: 0.053, Specificity: 0.853
Epoch 22


100%|██████████| 424/424 [00:19<00:00, 22.20it/s, lr=0.000008, train_loss=3.68]
100%|██████████| 27/27 [00:00<00:00, 94.53it/s]


Valid Loss: 3.603, F1 Score: 0.015, Sensitivity: 0.055, Specificity: 0.839
Epoch 23


100%|██████████| 424/424 [00:19<00:00, 22.17it/s, lr=0.000008, train_loss=3.67]
100%|██████████| 27/27 [00:00<00:00, 92.05it/s]


Valid Loss: 3.627, F1 Score: 0.012, Sensitivity: 0.052, Specificity: 0.817
Epoch 24


100%|██████████| 424/424 [00:19<00:00, 22.30it/s, lr=0.000008, train_loss=3.65]
100%|██████████| 27/27 [00:00<00:00, 92.74it/s]


Valid Loss: 3.614, F1 Score: 0.012, Sensitivity: 0.052, Specificity: 0.813
Epoch 25


100%|██████████| 424/424 [00:18<00:00, 22.80it/s, lr=0.000008, train_loss=3.64]
100%|██████████| 27/27 [00:00<00:00, 96.61it/s] 


Valid Loss: 3.610, F1 Score: 0.013, Sensitivity: 0.055, Specificity: 0.775
Epoch 26


100%|██████████| 424/424 [00:17<00:00, 24.48it/s, lr=0.000008, train_loss=3.64]
100%|██████████| 27/27 [00:00<00:00, 103.36it/s]


Valid Loss: 3.617, F1 Score: 0.012, Sensitivity: 0.053, Specificity: 0.772
Epoch 27


100%|██████████| 424/424 [00:17<00:00, 24.52it/s, lr=0.000008, train_loss=3.62]
100%|██████████| 27/27 [00:00<00:00, 106.70it/s]


Valid Loss: 3.626, F1 Score: 0.013, Sensitivity: 0.054, Specificity: 0.743
Epoch 28


100%|██████████| 424/424 [00:17<00:00, 24.63it/s, lr=0.000008, train_loss=3.61]
100%|██████████| 27/27 [00:00<00:00, 102.15it/s]


Valid Loss: 3.626, F1 Score: 0.013, Sensitivity: 0.059, Specificity: 0.725
Epoch 29


100%|██████████| 424/424 [00:17<00:00, 24.65it/s, lr=0.000007, train_loss=3.6] 
100%|██████████| 27/27 [00:00<00:00, 106.53it/s]


Valid Loss: 3.631, F1 Score: 0.012, Sensitivity: 0.057, Specificity: 0.723
Epoch 30


100%|██████████| 424/424 [00:17<00:00, 24.61it/s, lr=0.000007, train_loss=3.59]
100%|██████████| 27/27 [00:00<00:00, 100.77it/s]


Valid Loss: 3.632, F1 Score: 0.012, Sensitivity: 0.055, Specificity: 0.654
Epoch 31


100%|██████████| 424/424 [00:17<00:00, 24.60it/s, lr=0.000007, train_loss=3.57]
100%|██████████| 27/27 [00:00<00:00, 99.92it/s] 


Valid Loss: 3.613, F1 Score: 0.013, Sensitivity: 0.053, Specificity: 0.694
Epoch 32


100%|██████████| 424/424 [00:17<00:00, 24.54it/s, lr=0.000007, train_loss=3.56]
100%|██████████| 27/27 [00:00<00:00, 100.35it/s]

Valid Loss: 3.637, F1 Score: 0.012, Sensitivity: 0.055, Specificity: 0.711
Early stopping triggered





The most common metric for object detection is Average Precision (AP) which you can read more about it [here](https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173). The paper gets an AP of 43 w/ ResNet50 backbone after training on a whole lot of data for many hours of training. I could get an AP of 26.4 on my validation set with this small model and short training time which was cool as this is a tutorial on how to implement this paper easily and I didn't aim to beat the SOTA with this!

![](https://raw.githubusercontent.com/moein-shariatnia/Pix2Seq/master/imgs/Screen%20Shot%202022-08-19%20at%202.41.29%20PM.png)

# Inference

Now let's take a look at how we can generate a detection sequence with this model for a test image.

The following generate() function shows the whole sequence generation pipeline → First, we will create a batch with shape (batch_size, 1) containing only a BOS token for each image in the batch. The model takes the images and these BOS tokens and then predicts the next token for each image. We take the model's predictions, perform softmax and argmax on it to get the predicted token and concatenate this newly predicted token with the previous batch_preds tensor which had BOS tokens. We then repeat this loop for max_len number of times.

In [67]:
def generate(model, x, tokenizer, max_len=50, top_k=0, top_p=1):
    x = x.to(CFG.device)
    batch_preds = torch.ones(x.size(0), 1).fill_(tokenizer.BOS_code).long().to(CFG.device)
    confs = []
    
    if top_k != 0 or top_p != 1:
        sample = lambda preds: torch.softmax(preds, dim=-1).multinomial(num_samples=1).view(-1, 1)
    else:
        sample = lambda preds: torch.softmax(preds, dim=-1).argmax(dim=-1).view(-1, 1)
        
    with torch.no_grad():
        for i in range(max_len):
            preds = model.predict(x, batch_preds)
            ## If top_k and top_p are set to default, the following line does nothing!
            preds = top_k_top_p_filtering(preds, top_k=top_k, top_p=top_p)
            if i % 4 == 0:
                confs_ = torch.softmax(preds, dim=-1).sort(axis=-1, descending=True)[0][:, 0].cpu()
                confs.append(confs_)
            preds = sample(preds)
            batch_preds = torch.cat([batch_preds, preds], dim=1)
    
    return batch_preds.cpu(), confs

We will also use this postprocess function to decode the predictions and get bbox coordinates and labels for each image.

In [68]:
def postprocess(batch_preds, batch_confs, tokenizer):
    EOS_idxs = (batch_preds == tokenizer.EOS_code).float().argmax(dim=-1)
    ## sanity check
    invalid_idxs = ((EOS_idxs - 1) % 5 != 0).nonzero().view(-1)
    EOS_idxs[invalid_idxs] = 0
    
    all_bboxes = []
    all_labels = []
    all_confs = []
    for i, EOS_idx in enumerate(EOS_idxs.tolist()):
        if EOS_idx == 0:
            all_bboxes.append(None)
            all_labels.append(None)
            all_confs.append(None)
            continue
        labels, bboxes = tokenizer.decode(batch_preds[i, :EOS_idx+1])
        confs = [round(batch_confs[j][i].item(), 3) for j in range(len(bboxes))]
        
        all_bboxes.append(bboxes)
        all_labels.append(labels)
        all_confs.append(confs)
        
    return all_bboxes, all_labels, all_confs

In [69]:
encoder = Encoder(model_name=CFG.model_name, pretrained=True, out_dim=256)
decoder = Decoder(vocab_size=tokenizer.vocab_size,
                  encoder_length=CFG.num_patches, dim=256, num_heads=2, num_layers=2)
model = EncoderDecoder(encoder, decoder)
model.to(CFG.device)

msg = model.load_state_dict(torch.load('/home/achazhoor/Documents/workspace/pix_2_seq/best_valid_loss.pth', map_location=CFG.device))
print(msg)
model.eval();

skipping pos_embed...
skipping pos_embed...
<All keys matched successfully>


In [70]:
print(valid_df.tail(20))

                     id  label   xmin   ymin   xmax   ymax  \
255         patches_110      2   11.0   94.0   47.0  137.0   
256         patches_110      2   61.0   68.0  140.0  165.0   
257   pitted_surface_41      3    2.0    1.0   68.0  199.0   
258   pitted_surface_41      3   77.0    1.0  200.0  199.0   
259         crazing_184      0    2.0   64.0  101.0  127.0   
260         crazing_184      0   58.0  133.0  133.0  191.0   
261         crazing_184      0  110.0   79.0  199.0  123.0   
262        scratches_34      5  135.0   37.0  160.0  198.0   
263  pitted_surface_224      3    1.0    3.0  177.0  200.0   
264          patches_23      2   90.0    1.0  157.0  198.0   
265          patches_44      2   70.0  100.0  128.0  159.0   
266          patches_65      2   55.0   71.0  138.0  148.0   
267         patches_114      2    1.0    1.0   88.0  113.0   
268         patches_114      2   58.0  120.0  102.0  189.0   
269         patches_114      2  131.0  125.0  182.0  198.0   
270     

In [71]:
img_paths = """crazing_184.jpg patches_23.jpg pitted_surface_178.jpg inclusion_206.jpg rolled-in_scale_134.jpg"""
img_paths = ["/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES" + "/" + path for path in img_paths.split(" ")]

In [72]:
img_paths

['/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/crazing_184.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/patches_23.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/pitted_surface_178.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/inclusion_206.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/rolled-in_scale_134.jpg']

In [73]:
img_paths

['/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/crazing_184.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/patches_23.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/pitted_surface_178.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/inclusion_206.jpg',
 '/home/achazhoor/Documents/workspace/pix_2_seq/data/IMAGES/rolled-in_scale_134.jpg']

In [74]:
class VOCDatasetTest(torch.utils.data.Dataset):
    def __init__(self, img_paths, size):
        self.img_paths = img_paths
        self.transforms = A.Compose([A.Resize(size, size), A.Normalize()])

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        #print("The image path is ", img_path)

        img = cv2.imread(img_path)[..., ::-1]

        if self.transforms is not None:
            img = self.transforms(image=img)['image']

        img = torch.FloatTensor(img).permute(2, 0, 1)

        return img

    def __len__(self):
        return len(self.img_paths)

In [75]:
test_dataset = VOCDatasetTest(img_paths, size=CFG.img_size)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=len(img_paths), shuffle=False, num_workers=0)

In [76]:
GT_COLOR = (0, 255, 0) # Green
PRED_COLOR = (255, 0, 0) # Red
TEXT_COLOR = (255, 255, 255) # White


def visualize_bbox(img, bbox, class_name, color, thickness=1):
    """Visualizes a single bounding box on the image"""
    bbox = [int(item) for item in bbox]
    x_min, y_min, x_max, y_max = bbox
   
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)    
    cv2.rectangle(img, (x_min, y_min), (x_min + text_width, y_min + int(text_height * 1.3)), color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min+ int(text_height * 1.3)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35, 
        color=TEXT_COLOR, 
        lineType=cv2.LINE_AA,
    )
    return img


def visualize(image, bboxes, category_ids, category_id_to_name, color=PRED_COLOR, show=True):
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
        class_name = category_id_to_name[category_id]
        img = visualize_bbox(img, bbox, class_name, color)
    if show:
        plt.figure(figsize=(12, 12))
        plt.axis('off')
        plt.imshow(img)
        plt.show()
    return img

In [77]:
all_bboxes = []
all_labels = []
all_confs = []

with torch.no_grad():
    for x in tqdm(test_loader):
        batch_preds, batch_confs = generate(
            model, x, tokenizer, max_len=CFG.generation_steps, top_k=0, top_p=1)
        bboxes, labels, confs = postprocess(
            batch_preds, batch_confs, tokenizer)
        all_bboxes.extend(bboxes)
        all_labels.extend(labels)
        all_confs.extend(confs)

#os.mkdir("results")
for i, (bboxes, labels, confs) in enumerate(zip(all_bboxes, all_labels, all_confs)):
    img_path = img_paths[i]
    img = cv2.imread(img_path)[..., ::-1]
    img = cv2.resize(img, (CFG.img_size, CFG.img_size))
    img = visualize(img, bboxes, labels, id2cls, show=False)

    cv2.imwrite("results/" + img_path.split("/")[-1], img[..., ::-1])

100%|██████████| 1/1 [00:01<00:00,  1.04s/it]


# Final Words


I hope you've enjoyed this tutorial and learned something new. As always, I will be glad to hear your comments on this tutorial or answer any questions you might have regarding the paper and model.
Have a nice day!