<a href="https://colab.research.google.com/github/Nilanshrajput/Vqa_detr/blob/master/DETR_Vqa_pytlightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Object Detection with DETR - a minimal implementation

In this notebook we show a demo of DETR (Detection Transformer), with slight differences from the baseline model in the paper.

We show how to define the model, load pretrained weights and visualize bounding box and class predictions.

Let's start with some common imports.

In [0]:
# The import cell below uses `transformers.tokenization_bert`, a module path that
# only exists in transformers releases before 4.0, so the dependency is capped.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# TODO(review): pin the exact transformers version this notebook was developed with.
%pip install -q "transformers<4"



In [0]:
from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import json

import torch
import torch.utils.data as data
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
import os
import json
import tqdm

import logging

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable

from transformers.tokenization_bert import BertTokenizer
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel, BertModel
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
import torchtext

import pdb
import os

In [0]:
class DETRdemo(nn.Module):
    """
    Minimal DETR (Detection Transformer) implementation.

    Differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)

    The upstream demo docstring reports ~40 AP on COCO val5k and ~28 FPS on a
    Tesla V100 (figures carried over from that docstring, not re-measured here).

    ``forward`` accepts a batch of images: the positional encoding is broadcast
    over the batch dimension and the object queries are repeated per image.

    Args:
        num_classes: number of object classes; the class head predicts
            ``num_classes + 1`` logits (one extra "no object" slot).
        hidden_dim: transformer model width.
        nheads: number of attention heads.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # NOTE: attribute names below are the keys of the pretrained state dict
        # loaded elsewhere in this notebook; do not rename them.

        # create ResNet-50 backbone (classification head is not used)
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer: 2048 backbone channels -> hidden_dim
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """Run detection on a batch of images.

        Args:
            inputs: image batch fed to the ResNet stem; indexed below as
                ``(bs, channels, height, width)`` after the backbone.

        Returns:
            dict with
            * ``pred_logits``: class logits, ``(bs, 100, num_classes + 1)``
            * ``pred_boxes``: sigmoid-normalised boxes, ``(bs, 100, 4)``
            * ``decoder_out``: raw transformer decoder output, ``(bs, 100, hidden_dim)``
            * ``res_out``: projected backbone feature map, ``(bs, hidden_dim, H, W)``

        Raises:
            ValueError: if the backbone feature map is larger than the learned
                positional-embedding tables.
        """
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to hidden_dim feature planes for the transformer
        h = self.conv(x)
        bb_ot = h  # returned unchanged as 'res_out'

        bs, _, H, W = h.shape
        max_h, max_w = self.row_embed.shape[0], self.col_embed.shape[0]
        if H > max_h or W > max_w:
            # Slicing the embedding tables would otherwise silently come up
            # short and fail later with an opaque shape mismatch.
            raise ValueError(
                "feature map of size %dx%d exceeds the %dx%d positional "
                "embedding tables" % (H, W, max_h, max_w))

        # construct positional encodings: position (i, j) gets
        # [col_embed[j], row_embed[i]]; result is (H*W, 1, hidden_dim) and is
        # broadcast over the batch dimension when added below.
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # nn.Transformer expects (sequence, batch, feature) tensors
        src = pos + 0.1 * h.flatten(2).permute(2, 0, 1)     # (H*W, bs, hidden_dim)
        tgt = self.query_pos.unsqueeze(1).repeat(1, bs, 1)  # (100, bs, hidden_dim)

        # propagate through the transformer, then put the batch dimension first
        h = self.transformer(src, tgt).transpose(0, 1)      # (bs, 100, hidden_dim)

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid(),
                'decoder_out': h,
                'res_out': bb_ot}

In [0]:
class VQA_DETR(nn.Module):
    """VQA classifier: a BERT decoder over the question that cross-attends to DETR features.

    The image is encoded by a pretrained ``DETRdemo``; its decoder output (one
    vector per object query) is passed to a randomly initialised BERT decoder as
    ``encoder_hidden_states``. The decoder's token states are mean- and
    max-pooled, concatenated and classified over the answer vocabulary.

    Args:
        num_ans: number of answer classes.
        hidden_size: BERT hidden size. The DETR features are produced with
            ``DETRdemo``'s default width (256), so this presumably has to stay
            at 256 for cross-attention to work (TODO confirm).
        num_attention_heads: number of BERT attention heads.
        num_hidden_layers: number of BERT layers.
    """
    def __init__(self,num_ans, hidden_size=256, num_attention_heads = 8, num_hidden_layers = 6):
        super().__init__()

        # `add_cross_attention` is stored as a plain config attribute by releases
        # that do not know it; later releases are understood to require it in
        # addition to `is_decoder` before cross-attention layers are built
        # (TODO confirm against the installed transformers version).
        self.bert__decoder_config = BertConfig(
            is_decoder=True,
            add_cross_attention=True,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
        )
        self.bert_decoder = BertModel(config=self.bert__decoder_config)

        # Pretrained DETR demo weights (91 COCO classes); downloaded on first use.
        self.detr = DETRdemo(num_classes=91)
        state_dict = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
            map_location='cpu', check_hash=True)
        self.detr.load_state_dict(state_dict)

        # mean-pool and max-pool are concatenated, hence 2 * hidden_size inputs
        self.classifier  = nn.Linear(hidden_size*2,num_ans)

        self.drop_out = nn.Dropout(p=0.2)
        # No `.cuda()` here: the module is moved with the rest of the model by
        # `model.to(device)`, and construction must also work without a GPU.
        # The softmax runs over the answer dimension (last dim of the logits).
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self,img, q_ids):
        """Predict answer logits for a batch of (image, question) pairs.

        Args:
            img: image batch accepted by ``DETRdemo.forward``.
            q_ids: question token ids, ``(bs, seq_len)``.

        Returns:
            dict with ``logits`` of shape ``(bs, num_ans)`` and ``nll``, the
            element-wise negative log-softmax of those logits (same shape).

        NOTE(review): no attention mask is passed to the decoder, so padded
        question positions take part in attention and in the pooling below.
        """
        # (bs, num_queries, hidden) object-query features from the DETR decoder
        img_ecs = self.detr(img)['decoder_out']

        # Index instead of tuple-unpacking so this works whether the model
        # returns a plain tuple or a richer output object.
        o1 = self.bert_decoder(input_ids = q_ids, encoder_hidden_states = img_ecs)[0]

        # pool over the token dimension and concatenate both summaries
        mean_pool = torch.mean(o1,1)
        max_pool,_ = torch.max(o1,1)
        cat = torch.cat((mean_pool, max_pool),1)

        bo = self.drop_out(cat)
        output = self.classifier(bo)

        nll = -self.log_softmax(output)

        return {'logits':output,
                'nll':nll}




    

In [0]:

def assert_eq(real, expected):
    """Raise ``AssertionError`` unless ``real == expected``.

    Implemented with an explicit ``raise`` rather than an ``assert`` statement so
    the data-consistency check still runs when Python is started with ``-O``.

    Args:
        real: the value that was observed.
        expected: the value it must be equal to.

    Raises:
        AssertionError: with both values in the message when they differ.
    """
    if not real == expected:
        raise AssertionError("%s (true) vs %s (expected)" % (real, expected))

def _create_entry(question, answer):
    answer.pop("image_id")
    answer.pop("question_id")
    entry = {
        "question_id": question["question_id"],
        "image_id": question["image_id"],
        "question": question["question"],
        "answer": [a['answer'] for a in answer['answers']],
    }
    return entry

def _load_dataset(dataroot, name):
    """Load entries
    dataroot: root path of dataset
    name: 'train', 'val', 'trainval', 'minsval'
    """
    if name == 'train' or name == 'val':
        question_path = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % name)
        questions = sorted(json.load(open(question_path))["questions"], key=lambda x: x["question_id"])
        answer_path = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % name)
        answers = json.load(open(answer_path, "rb"))["annotations"]
        answers = sorted(answers, key=lambda x: x["question_id"])

    elif name  == 'trainval':
        question_path_train = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % 'train')
        questions_train = sorted(json.load(open(question_path_train))["questions"], key=lambda x: x["question_id"])
        answer_path_train = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % 'train')
        answers_train = json.load(open(answer_path_train, "rb"))["annotations"]
        answers_train = sorted(answers_train, key=lambda x: x["question_id"])

        question_path_val = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % 'val')
        questions_val = sorted(json.load(open(question_path_val))["questions"], key=lambda x: x["question_id"])
        answer_path_val = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % 'val')
        answers_val = json.load(open(answer_path_val, "rb"))["annotations"]
        answers_val = sorted(answers_val, key=lambda x: x["question_id"])
        questions = questions_train + questions_val[:-3000]
        answers = answers_train + answers_val[:-3000]

    elif name == 'minval':
        question_path_val = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % 'val')
        questions_val = sorted(json.load(open(question_path_val))["questions"], key=lambda x: x["question_id"])
        answer_path_val = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % 'val')
        answers_val = json.load(open(answer_path_val, "rb"))["annotations"]
        answers_val = sorted(answers_val, key=lambda x: x["question_id"])        
        questions = questions_val[-3000:]
        answers = answers_val[-3000:]

    elif name == 'test':
        question_path_test = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2015_questions.json" % 'test')
        questions_test = sorted(json.load(open(question_path_test))["questions"], key=lambda x: x["question_id"])
        questions = questions_test
    else:
        assert False, "data split is not recognized."

    if 'test' in name:
        entries = []
        for question in questions:
            entries.append(question)
    else:
        assert_eq(len(questions), len(answers))
        entries = []
        for question, answer in zip(questions, answers):
            assert_eq(question["question_id"], answer["question_id"])
            assert_eq(question["image_id"], answer["image_id"])
            entries.append(_create_entry(question, answer))
    return entries

In [0]:
# Load the training split as a list of entries (question text plus the list of
# human answers) from the JSON files under /content/ -- presumably the Colab
# working directory; adjust when running elsewhere.
entries = _load_dataset(dataroot='/content/',name='train')

In [0]:
# Compile the answer vocabulary: every distinct answer string in the entries.
# The result is sorted because the iteration order of a set of strings changes
# between interpreter runs, which would give each answer a different class
# index after every kernel restart.
all_answers = sorted({answer_text for entry in entries for answer_text in entry['answer']})


In [0]:
# Class index of an answer = its position in `all_answers`.
answer_to_index = {answer_text: position for position, answer_text in enumerate(all_answers)}


In [0]:
# Build the model with one output class per distinct training answer.
# Constructing VQA_DETR also fetches the pretrained DETR demo weights.
vqa_detr = VQA_DETR(num_ans=len(answer_to_index))

In [0]:
class VQA(data.Dataset):
    """VQA v2 open-ended dataset yielding ``(image, question, answer)`` samples.

    * ``image``: normalised float tensor of an 800x800 resized RGB image
    * ``question``: dict of fixed-length long tensors ``ids``, ``mask`` and
      ``token_type_ids`` (length ``max_len``)
    * ``answer``: 0-dim long tensor, class index of the most frequent answer

    Args:
        root: directory holding the VQA JSON files and the COCO image folders.
        answer_to_index: mapping from answer string to class index.
        tokenizer: object exposing ``encode_plus`` (e.g. a BERT tokenizer).
        split: ``'train'``, ``'val'``, ``'trainval'``, ``'minval'`` or ``'test'``.
            NOTE(review): ``'test'`` entries carry no answers, so indexing such
            a dataset raises ``KeyError`` in ``__getitem__``.
        max_len: number of token ids per encoded question.
    """

    # COCO image folders (relative to ``root``) that hold the images of a split.
    _IMAGE_FOLDERS = {
        'train': ('train2014',),
        'val': ('val2014',),
        'trainval': ('train2014', 'val2014'),
        'minval': ('val2014',),
        'test': ('test2015',),
    }

    def __init__(self, root, answer_to_index, tokenizer ,split = 'train', max_len = 20):
        super(VQA, self).__init__()

        self.root = root
        self.answer_to_index = answer_to_index
        self.split = split
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.entries = self._load_dataset(self.root, self.split)

        # standard PyTorch mean-std input image normalization
        self.transform = T.Compose([
            T.Resize(size=(800,800)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.id_to_image_fname = self._find_images()

    def assert_eq(self, real, expected):
        """Raise ``AssertionError`` unless ``real == expected`` (also under ``-O``)."""
        if not real == expected:
            raise AssertionError("%s (true) vs %s (expected)" % (real, expected))

    def _create_entry(self, question, answer):
        """Merge a question record and its annotation record; inputs are not modified."""
        return {
            "question_id": question["question_id"],
            "image_id": question["image_id"],
            "question": question["question"],
            "answer": [a['answer'] for a in answer['answers']],
        }

    @staticmethod
    def _read_sorted(path, key):
        """Return the list under ``key`` of the JSON file ``path``, sorted by ``question_id``."""
        # ``with`` closes the handle; binary mode lets ``json`` detect the encoding.
        with open(path, "rb") as handle:
            records = json.load(handle)[key]
        return sorted(records, key=lambda record: record["question_id"])

    def _load_questions(self, dataroot, split, year="2014"):
        """Load the sorted question list of one split."""
        filename = "v2_OpenEnded_mscoco_%s%s_questions.json" % (split, year)
        return self._read_sorted(os.path.join(dataroot, filename), "questions")

    def _load_annotations(self, dataroot, split):
        """Load the sorted annotation list of one split."""
        filename = "v2_mscoco_%s2014_annotations.json" % split
        return self._read_sorted(os.path.join(dataroot, filename), "annotations")

    def _load_dataset(self, dataroot, name):
        """Load the entries of a data split.

        Args:
            dataroot: directory containing the question/annotation JSON files.
            name: ``'train'``, ``'val'``, ``'trainval'`` (train + val without
                its last 3000 questions), ``'minval'`` (last 3000 val
                questions) or ``'test'`` (2015 test questions, no answers).

        Returns:
            list of raw question dicts for ``'test'``, otherwise a list of
            entries built by ``_create_entry``.

        Raises:
            ValueError: if ``name`` is not a known split.
            AssertionError: if questions and annotations do not line up.
        """
        if name == 'train' or name == 'val':
            questions = self._load_questions(dataroot, name)
            answers = self._load_annotations(dataroot, name)

        elif name == 'trainval':
            questions_val = self._load_questions(dataroot, 'val')
            answers_val = self._load_annotations(dataroot, 'val')
            # the last 3000 val questions are held out as 'minval'
            questions = self._load_questions(dataroot, 'train') + questions_val[:-3000]
            answers = self._load_annotations(dataroot, 'train') + answers_val[:-3000]

        elif name == 'minval':
            questions = self._load_questions(dataroot, 'val')[-3000:]
            answers = self._load_annotations(dataroot, 'val')[-3000:]

        elif name == 'test':
            return self._load_questions(dataroot, 'test', year="2015")

        else:
            raise ValueError("data split is not recognized.")

        # Use this class's own helpers so the dataset does not depend on
        # same-named module-level functions happening to exist.
        self.assert_eq(len(questions), len(answers))
        entries = []
        for question, answer in zip(questions, answers):
            self.assert_eq(question["question_id"], answer["question_id"])
            self.assert_eq(question["image_id"], answer["image_id"])
            entries.append(self._create_entry(question, answer))
        return entries

    def _encode_question(self, question):
        """Tokenize a question into fixed-length id, mask and token-type tensors.

        Returns:
            dict with long tensors ``ids``, ``mask`` and ``token_type_ids``,
            each of length ``max_len`` (right-padded with zeros).
        """
        inputs = self.tokenizer.encode_plus(
            question,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
        )

        encoded = {}
        for out_key, in_key in (("ids", "input_ids"),
                                ("mask", "attention_mask"),
                                ("token_type_ids", "token_type_ids")):
            # Guard: cut anything the tokenizer did not truncate itself so every
            # sample has exactly max_len entries and batches can be stacked.
            values = list(inputs[in_key])[:self.max_len]
            values += [0] * (self.max_len - len(values))
            encoded[out_key] = torch.tensor(values, dtype=torch.long)
        return encoded

    def _encode_ansfor_bert(self, answers):
        """Placeholder; not implemented."""
        pass

    def _encode_answers(self, answers):
        """Return the class index of the most frequent known answer.

        Answers missing from ``answer_to_index`` are ignored. Ties go to the
        lowest class index, and index 0 is returned when no answer is known.

        Returns:
            0-dim long tensor holding the class index.
        """
        # Count votes in a small dict instead of allocating a vocabulary-sized
        # tensor for every sample.
        votes = {}
        for answer in answers:
            index = self.answer_to_index.get(answer)
            if index is not None:
                votes[index] = votes.get(index, 0) + 1

        if votes:
            best = min(votes, key=lambda index: (-votes[index], index))
        else:
            best = 0
        return torch.tensor(best, dtype=torch.long)

    def _find_images(self):
        """Map COCO image ids to file paths for the folders of the current split.

        File names are expected to end in ``_<numeric id>.jpg``.
        """
        id_to_filename = {}
        # unknown splits fall back to the training folder
        for folder in self._IMAGE_FOLDERS.get(self.split, ('train2014',)):
            imgs_folder = os.path.join(self.root, folder)
            for filename in os.listdir(imgs_folder):
                if not filename.endswith('.jpg'):
                    continue
                id_and_extension = filename.split('_')[-1]
                image_id = int(id_and_extension.split('.')[0])
                id_to_filename[image_id] = os.path.join(imgs_folder, filename)
        return id_to_filename

    # Backward-compatible alias for the original (misspelled) method name.
    _find_iamges = _find_images

    def _load_image(self, image_id):
        """Load an image as an ``(H, W, 3)`` uint8 array.

        Grayscale and other non-RGB images are converted to RGB so the
        three-channel normalisation always applies.
        """
        img_path = self.id_to_image_fname[image_id]
        # ``convert`` returns an independent image, so the file can be closed.
        with Image.open(img_path) as img:
            return np.asarray(img.convert('RGB'))

    def __getitem__(self, item):
        """Return ``(image tensor, encoded question dict, answer index)`` for one entry."""
        entry = self.entries[item]
        answer_index = self._encode_answers(entry['answer'])

        img = Image.fromarray(self._load_image(entry['image_id']))
        img = self.transform(img)
        question = self._encode_question(entry['question'])

        return img, question, answer_index

    def __len__(self):
        return len(self.entries)

    @staticmethod
    def collate_fn(batch):
        """Collate ``(image, question, answer)`` samples into a batch.

        Returns:
            ``(imgs, q, a)`` where ``imgs`` and ``a`` are stacked tensors and
            ``q`` is a dict with one stacked tensor per question field.
        """
        imgs, qs, answers = zip(*batch)

        imgs = torch.stack(imgs)
        # questions are dicts of tensors: stack field by field
        q = {key: torch.stack([question[key] for question in qs]) for key in qs[0]}
        a = torch.stack(answers)

        return imgs, q, a

In [0]:
# Pretrained lower-casing BERT word-piece tokenizer; only its vocabulary is used
# to turn questions into token ids (fetched on first use).
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [0]:
# Training dataset rooted at /content; every question is encoded to 15 token ids.
vqa_data = VQA(root='/content', answer_to_index=answer_to_index,split= 'train', tokenizer=bert_tokenizer, max_len=15 )

In [0]:
# Shuffled batches of 4 samples. No collate_fn is passed, so PyTorch's default
# collation is used (VQA.collate_fn is not involved here).
# NOTE(review): the batch size is hard-coded and independent of the `bs`
# variable defined in a later cell.
data_loader = DataLoader(vqa_data, batch_size = 4, shuffle= True)

In [0]:

def loss_fn(outputs, targets):
    """Mean cross-entropy between answer logits and target class indices.

    Args:
        outputs: float logits of shape ``(batch, num_classes)``.
        targets: long tensor of class indices, shape ``(batch,)``.

    Returns:
        Scalar loss tensor. It is part of the autograd graph whenever
        ``outputs`` is, so no gradient flag has to be set by hand.
    """
    return nn.CrossEntropyLoss()(outputs, targets)


def train_fn(data_loader, model, optimizer, device, scheduler):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: iterable of ``(images, question, answers)`` batches where
            ``question`` is a mapping with an ``"ids"`` tensor.
        model: called as ``model(images, ids)``; must return a mapping with a
            ``'logits'`` entry.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch, or ``None``
            to train without one.

    Returns:
        Sum (not mean) of the per-batch losses over the epoch.
    """
    model.train()
    loss_v = 0
    progress = tqdm.tqdm(data_loader, total=len(data_loader))
    for im, q, a in progress:
        ids = q["ids"].to(device, dtype=torch.long)
        targets = a.to(device, dtype=torch.long)
        im = im.to(device)

        optimizer.zero_grad()
        outputs = model(im, ids)

        # the logits already live on the model's device
        loss = loss_fn(outputs['logits'], targets)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        batch_loss = loss.item()
        loss_v += batch_loss
        # show the running loss on the progress bar instead of printing a line
        # per step, which floods the cell output
        progress.set_postfix(loss=batch_loss)

    return loss_v


In [0]:
# Training hyper-parameters.
# NOTE(review): the DataLoader created earlier is built with batch_size=4, not
# with `bs`; keep the two values in sync when changing either.
bs = 16
epochs = 10

In [0]:

# Use the GPU when there is one; fall back to CPU instead of crashing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vqa_detr
# Fine-tune every parameter, including the pretrained DETR weights.
for param in model.parameters():
    param.requires_grad = True
model.to(device)

# Apply weight decay to everything except parameters whose name matches one of
# the `no_decay` patterns.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

# The scheduler is stepped once per batch, so its horizon is the real number of
# batches per epoch times the number of epochs. Deriving it from a separate
# batch-size constant that differs from the DataLoader's would drive the
# learning rate to zero long before training ends.
num_train_steps = len(data_loader) * epochs
optimizer = AdamW(optimizer_parameters, lr=3e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps
)

for epoch in range(epochs):
    loss = train_fn(data_loader, model, optimizer, device, scheduler)

    print(f'The loss for epoch {epoch} = {loss}')

    