<a href="https://colab.research.google.com/github/Nilanshrajput/Vqa_detr/blob/master/DETR_Vqa_pytlightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Object Detection with DETR - a minimal implementation

In this notebook we show a demo of DETR (Detection Transformer), which differs slightly from the baseline model in the paper.

We show how to define the model, load pretrained weights and visualize bounding box and class predictions.

Let's start with some common imports.

In [4]:
# Install the third-party libraries imported below (`transformers`, `pytorch_lightning`).
# NOTE(review): versions are unpinned; the import cell uses `transformers.tokenization_bert`,
# which presumably only exists in older transformers releases — confirm and pin.
!pip install transformers
!pip install pytorch-lightning

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [1]:
import time

# `tqdm.tqdm_notebook` is deprecated (tqdm's own warning asks for `tqdm.notebook.tqdm`),
# so import the widget progress bar from its supported location. The old name is kept
# as an alias so anything referring to `tqdm_notebook` keeps working.
from tqdm.notebook import tqdm as tqdm_notebook

# Smoke test: check that the notebook progress-bar widget renders.
example_iter = [1,2,3,4,5]
for rec in tqdm_notebook(example_iter):
    time.sleep(.1)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [6]:
from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import json

import torch
import torch.utils.data as data
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
import os
import json
import tqdm

import logging
from argparse import Namespace

from functools import lru_cache

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable

from transformers.tokenization_bert import BertTokenizer
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel, BertModel
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup


from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer


from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from typing import List, Dict
import pdb
import os

## DETR
Here is a minimal implementation of DETR:

In [7]:
class DETRdemo(nn.Module):
    """
    Compact DETR (Detection Transformer) built from a ResNet-50 and nn.Transformer.

    Differences from the DETR described in the paper:
    * the spatial positional encoding is learned (the paper uses sine encodings)
    * the positional encoding is added once, at the transformer input
      (the paper injects it inside the attention layers)
    * boxes are predicted by a single linear layer (the paper uses an MLP)

    The positional encodings and the object queries are tiled over the batch
    dimension in ``forward``, so inputs with a leading batch dimension of any
    size are accepted. The row/column embeddings have 50 entries each, so the
    backbone feature map may be at most 50 cells high and 50 cells wide.

    The upstream demo reports ~40 AP on COCO val5k at ~28 FPS on a Tesla V100;
    those numbers were not re-measured for this variant.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # ResNet-50 trunk; the classification head is never used, so drop it
        self.backbone = resnet50()
        del self.backbone.fc

        # 1x1 convolution taking the 2048 backbone channels down to hidden_dim
        self.conv = nn.Conv2d(2048, hidden_dim, kernel_size=1)

        # stock PyTorch encoder-decoder transformer
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
        )

        # prediction heads; the class head has one extra output besides num_classes
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # 100 learned object queries fed to the transformer decoder
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # learned spatial encodings: half of the channels encode the row,
        # the other half the column
        half_dim = hidden_dim // 2
        self.row_embed = nn.Parameter(torch.rand(50, half_dim))
        self.col_embed = nn.Parameter(torch.rand(50, half_dim))

    def forward(self, inputs):
        """Run detection on ``inputs`` and return a dict with four tensors.

        Returns:
            dict with keys ``'pred_logits'`` (class scores per query),
            ``'pred_boxes'`` (4 sigmoid-squashed box values per query),
            ``'decoder_out'`` (transformer output, batch first) and
            ``'res_out'`` (backbone features after the 1x1 projection).
        """
        # ResNet-50 stages up to (not including) the average pool
        feats = inputs
        for stage in ("conv1", "bn1", "relu", "maxpool",
                      "layer1", "layer2", "layer3", "layer4"):
            feats = getattr(self.backbone, stage)(feats)

        # 2048 -> hidden_dim feature planes for the transformer
        projected = self.conv(feats)
        batch, _, height, width = projected.shape

        # positional encoding of shape (batch, height*width, hidden_dim):
        # column half and row half concatenated on the channel axis
        col_part = self.col_embed[:width][None, None].expand(batch, height, -1, -1)
        row_part = self.row_embed[:height][None, :, None].expand(batch, -1, width, -1)
        pos = torch.cat([col_part, row_part], dim=-1).flatten(1, 2)

        # nn.Transformer expects (sequence, batch, hidden_dim) for source and target
        src = pos.permute(1, 0, 2) + 0.1 * projected.flatten(2).permute(2, 0, 1)
        queries = self.query_pos.unsqueeze(1).repeat(1, batch, 1)
        decoded = self.transformer(src, queries).transpose(0, 1)

        # project the transformer output to class scores and boxes
        return {
            'pred_logits': self.linear_class(decoded),
            'pred_boxes': self.linear_bbox(decoded).sigmoid(),
            'decoder_out': decoded,
            'res_out': projected,
        }

In [29]:
class VQA_DETR(LightningModule):
    """
    Visual question answering model built on top of the DETR demo.

    The question token ids go through a ``BertModel`` configured as a decoder
    (built from a config, so its weights are not pretrained), which receives the
    DETR transformer output (``decoder_out``) as ``encoder_hidden_states``. The
    BERT output sequence is mean- and max-pooled over the token axis, the two
    pooled vectors are concatenated, and a linear layer scores ``num_ans`` answers.

    Args:
        hparams: namespace providing ``lr``, ``num_warmup_steps``, ``epochs``,
            ``accumulate_grad_batches``, ``batch_size``, ``val_batch_size`` and
            ``num_workers``. An optional ``data_root`` attribute overrides
            ``DEFAULT_DATA_ROOT`` as the dataset directory.
        num_ans: number of answer classes scored by the classifier.
        ans_to_index: answer-string -> class-index mapping handed to the datasets.
        hidden_size: BERT hidden size. NOTE(review): ``DETRdemo`` is built with
            its default ``hidden_dim`` (256), so this presumably has to stay 256
            for the cross-attention shapes to match — confirm.
        num_attention_heads: attention heads of the BERT decoder.
        num_hidden_layers: layers of the BERT decoder.

    NOTE(review): ``validation_end`` is the epoch-level validation hook name of
    early PyTorch Lightning releases; confirm it matches the installed version.
    """

    # Dataset directory used when ``hparams`` has no ``data_root`` attribute.
    DEFAULT_DATA_ROOT = '/home/ubuntu/vqa/vqa_data'

    def __init__(self,hparams,num_ans,ans_to_index,hidden_size=256, num_attention_heads = 8, num_hidden_layers = 6):
        super().__init__()

        self.hparams = hparams
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        self.ans_to_index = ans_to_index
        self.bert_decoder_config = BertConfig(
            is_decoder=True,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
        )
        self.bert_decoder = BertModel(config=self.bert_decoder_config)

        # DETR with the published demo weights (downloaded on construction).
        self.detr = DETRdemo(num_classes=91)
        state_dict = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
            map_location='cpu', check_hash=True)
        self.detr.load_state_dict(state_dict)
        del state_dict

        # mean-pool and max-pool vectors are concatenated, hence hidden_size*2
        self.classifier  = nn.Linear(hidden_size*2,num_ans)

        self.drop_out = nn.Dropout(p=0.2)
        # Explicit class axis, and no `.cuda()`: the layer has no parameters, and
        # forcing CUDA here made construction fail on CPU-only machines.
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self,img, q_ids):
        """Score every answer for a batch of images and question token ids.

        Returns:
            dict with ``'logits'`` (raw classifier scores) and ``'nll'``
            (negative log-softmax of those scores over the answer axis).

        NOTE(review): no attention mask is passed to the BERT decoder, so padded
        question positions are attended to and pooled like real tokens.
        """
        img_ecs = self.detr(img)['decoder_out'].flatten(2)
        o1,_ = self.bert_decoder(input_ids = q_ids, encoder_hidden_states = img_ecs)

        # pool over the token axis
        mean_pool = torch.mean(o1,1)
        max_pool,_ = torch.max(o1,1)
        cat = torch.cat((mean_pool, max_pool),1)

        bo = self.drop_out(cat)
        output = self.classifier(bo)

        nll = -self.log_softmax(output)

        return {'logits':output,'nll':nll}

    def _loss_and_f1(self, batch):
        """Run the model on one ``(image, question, answer)`` batch; return ``(loss, f1)``."""
        im,q,a  = batch
        outputs = self(im, q["ids"])
        loss = self.loss_fn(outputs['nll'], a)
        f1 = self.metric_f1(outputs['logits'], a)
        return loss, f1

    def training_step(self, batch, batch_idx):
        """Compute the training loss and macro-F1 for one batch."""
        loss, f1 = self._loss_and_f1(batch)
        tensorboard_logs = {'train_loss': loss,'train_f1_score': f1}

        return {'loss': loss, 'log': tensorboard_logs,"progress_bar": {'train_loss': loss,'train_f1':f1}}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and macro-F1 for one batch."""
        loss, f1 = self._loss_and_f1(batch)
        tensorboard_logs = {'val_loss': loss,'val_f1_score': f1}

        return {'val_loss': loss, 'log': tensorboard_logs,"progress_bar": {'val_loss': loss,'val_f1':f1}}

    def validation_end(self, outputs: List[dict]):
        """Average the per-batch validation loss and F1 returned by ``validation_step``."""
        loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        f1 = torch.stack([x['progress_bar']['val_f1'] for x in outputs]).mean()
        return {"val_loss": loss,"val_f1":f1}

    def loss_fn(self, nll, targets):
        """Answer-count weighted loss.

        ``targets`` holds, per example, a count for every answer; each answer's
        negative log-likelihood is weighted by count / 10, summed over answers
        and averaged over the batch.
        """
        return (nll * targets / 10).sum(dim=1).mean()

    def metric_f1(self, preds, y):
        """Macro F1 between the arg-max prediction and the arg-max of the answer counts.

        Tensors are moved to the CPU before the NumPy conversion (calling
        ``.numpy()`` on a CUDA tensor raises); the score is returned as a float32
        tensor on the device of ``preds``.
        """
        _, max_preds = preds.max(dim = -1) # index of the highest score
        _, y = y.max(dim= -1)              # index of the most frequent answer
        shape = max_preds.shape[0]
        f1 = f1_score(
            y.detach().cpu().view(shape).numpy(),
            max_preds.detach().cpu().view(shape).numpy(),
            average='macro',
        )
        return torch.tensor(f1, dtype=torch.float32, device=preds.device)

    def total_steps(self):
        """Number of optimizer steps for the whole run; computed once per instance.

        The value is cached on the instance instead of via ``functools.lru_cache``,
        which would keep every model it was called on alive in a global cache.
        """
        if getattr(self, "_cached_total_steps", None) is None:
            batches_per_epoch = len(self.train_dataloader())
            self._cached_total_steps = (
                batches_per_epoch // self.hparams.accumulate_grad_batches * self.hparams.epochs
            )
        return self._cached_total_steps

    def configure_optimizers(self):
        """AdamW with linear warmup/decay; parameters whose name matches ``no_decay`` get no weight decay."""
        param_optimizer = list(self.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]

        optimizer = AdamW(optimizer_parameters, lr=self.hparams.lr)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.num_warmup_steps,
            num_training_steps=self.total_steps(),
        )

        # the scheduler is stepped after every optimizer step, not every epoch
        return [optimizer],  [{"scheduler": scheduler, "interval": "step"}]

    def prepare_data(self):
        """Build the train and val datasets from ``hparams.data_root`` (or the default root)."""
        data_root = getattr(self.hparams, 'data_root', self.DEFAULT_DATA_ROOT)
        self.val_dataset  = VQA(root=data_root, answer_to_index=self.ans_to_index,split= 'val', tokenizer=self.tokenizer, max_len=15 )
        self.train_dataset  = VQA(root=data_root, answer_to_index=self.ans_to_index,split= 'train', tokenizer=self.tokenizer, max_len=15 )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        loader = DataLoader(self.train_dataset, batch_size = self.hparams.batch_size,num_workers=self.hparams.num_workers, shuffle= True)
        return loader

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        loader = DataLoader(self.val_dataset, batch_size = self.hparams.val_batch_size,num_workers=self.hparams.num_workers, shuffle= False)
        return loader


In [1]:
"""!wget --header="Host: s3.amazonaws.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip" -c -O 'v2_Annotations_Train_mscoco.zip'
#!unzip -q v2_Annotations_Train_mscoco.zip -d /home/ubuntu/vqa/vqa_data 
!wget --header="Host: images.cocodataset.org" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "http://images.cocodataset.org/zips/train2014.zip" -c -O 'train2014.zip'
!unzip -q train2014.zip -d /home/ubuntu/vqa/vqa_data 
!rm -r train2014.zip
!wget --header="Host: s3.amazonaws.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip" -c -O 'v2_Questions_Train_mscoco.zip'
!unzip q v2_Questions_Train_mscoco.zip -d /home/ubuntu/vqa/vqa_data 
!wget --header="Host: s3.amazonaws.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip" -c -O 'v2_Annotations_Val_mscoco.zip'
!unzip -q v2_Annotations_Val_mscoco.zip -d /home/ubuntu/vqa/vqa_data 
!wget --header="Host: images.cocodataset.org" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "http://images.cocodataset.org/zips/val2014.zip" -c -O 'val2014.zip'
!unzip -q val2014.zip -d /home/ubuntu/vqa/vqa_data 
!rm -r val2014.zip
!wget --header="Host: s3.amazonaws.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip" -c -O 'v2_Questions_Val_mscoco.zip'
!unzip -q v2_Questions_Val_mscoco.zip -d /home/ubuntu/vqa/vqa_data 
!wget --header="Host: s3.amazonaws.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip" -c -O 'v2_Questions_Test_mscoco.zip'
!unzip -q v2_Questions_Test_mscoco.zip -d /home/ubuntu/vqa/vqa_data 
#!wget --header="Host: images.cocodataset.org" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "http://images.cocodataset.org/zips/test2015.zip" -c -O 'test2015.zip'
#!unzip -q test2015.zip"""

--2020-06-12 20:56:28--  https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.104.125
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.104.125|:443... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.

--2020-06-12 20:56:29--  http://images.cocodataset.org/zips/train2014.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.238.219
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.238.219|:80... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.



In [12]:
# Download the VQA v2 training questions archive (`-c` resumes/skips a finished download)
# and unpack it into the dataset directory read by `_load_dataset`.
!wget --header="Host: s3.amazonaws.com" --header="User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.138 Safari/537.36" --header="Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9" --header="Accept-Language: en-GB,en-US;q=0.9,en;q=0.8" "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip" -c -O 'v2_Questions_Train_mscoco.zip'
!unzip v2_Questions_Train_mscoco.zip -d /home/ubuntu/vqa/vqa_data 


--2020-06-12 21:07:43--  https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.134.221
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.134.221|:443... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.

Archive:  v2_Questions_Train_mscoco.zip
  inflating: /home/ubuntu/vqa/vqa_data/v2_OpenEnded_mscoco_train2014_questions.json  


In [13]:

def assert_eq(real, expected):
    """Assert that ``real == expected``; the failure message shows both values."""
    values_match = real == expected
    assert values_match, "%s (true) vs %s (expected)" % (real, expected)

def _create_entry(question, answer):
    answer.pop("image_id")
    answer.pop("question_id")
    entry = {
        "question_id": question["question_id"],
        "image_id": question["image_id"],
        "question": question["question"],
        "answer": [a['answer'] for a in answer['answers']],
    }
    return entry

def _load_dataset(dataroot, name):
    """Load entries
    dataroot: root path of dataset
    name: 'train', 'val', 'trainval', 'minsval'
    """
    if name == 'train' or name == 'val':
        question_path = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % name)
        questions = sorted(json.load(open(question_path))["questions"], key=lambda x: x["question_id"])
        answer_path = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % name)
        answers = json.load(open(answer_path, "rb"))["annotations"]
        answers = sorted(answers, key=lambda x: x["question_id"])

    elif name  == 'trainval':
        question_path_train = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % 'train')
        questions_train = sorted(json.load(open(question_path_train))["questions"], key=lambda x: x["question_id"])
        answer_path_train = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % 'train')
        answers_train = json.load(open(answer_path_train, "rb"))["annotations"]
        answers_train = sorted(answers_train, key=lambda x: x["question_id"])

        question_path_val = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % 'val')
        questions_val = sorted(json.load(open(question_path_val))["questions"], key=lambda x: x["question_id"])
        answer_path_val = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % 'val')
        answers_val = json.load(open(answer_path_val, "rb"))["annotations"]
        answers_val = sorted(answers_val, key=lambda x: x["question_id"])
        questions = questions_train + questions_val[:-3000]
        answers = answers_train + answers_val[:-3000]

    elif name == 'minval':
        question_path_val = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2014_questions.json" % 'val')
        questions_val = sorted(json.load(open(question_path_val))["questions"], key=lambda x: x["question_id"])
        answer_path_val = os.path.join(dataroot, "v2_mscoco_%s2014_annotations.json" % 'val')
        answers_val = json.load(open(answer_path_val, "rb"))["annotations"]
        answers_val = sorted(answers_val, key=lambda x: x["question_id"])        
        questions = questions_val[-3000:]
        answers = answers_val[-3000:]

    elif name == 'test':
        question_path_test = os.path.join(dataroot, "v2_OpenEnded_mscoco_%s2015_questions.json" % 'test')
        questions_test = sorted(json.load(open(question_path_test))["questions"], key=lambda x: x["question_id"])
        questions = questions_test
    else:
        assert False, "data split is not recognized."

    if 'test' in name:
        entries = []
        for question in questions:
            entries.append(question)
    else:
        assert_eq(len(questions), len(answers))
        entries = []
        for question, answer in zip(questions, answers):
            assert_eq(question["question_id"], answer["question_id"])
            assert_eq(question["image_id"], answer["image_id"])
            entries.append(_create_entry(question, answer))
    return entries

In [19]:
# Load the training-split entries; the cells below build the answer vocabulary from them.
entries = _load_dataset(dataroot='/home/ubuntu/vqa/vqa_data',name='train')


In [20]:
# Compile the list of every distinct answer string in the training entries.
# The list is sorted so that the answer -> index mapping built from it is the same
# on every kernel restart: iterating a set of strings yields a different order from
# one interpreter run to the next (string hash randomisation), which would silently
# re-label the classifier outputs of a previously trained model.
all_answers = sorted({answer for entry in entries for answer in entry['answer']})


In [21]:
# Map each answer string to its position in `all_answers`.
answer_to_index = {answer: i for i, answer in enumerate(all_answers)}


In [22]:
class VQA(data.Dataset):
    """ VQA dataset, open-ended

    Each item is ``(image, question, answer)``:
    * ``image``: the transformed image tensor (resized to 800x800, normalised)
    * ``question``: dict of long tensors ``ids``, ``mask`` and ``token_type_ids``,
      each exactly ``max_len`` long
    * ``answer``: long tensor of per-answer counts indexed by ``answer_to_index``

    Args:
        root: directory holding the question/annotation JSON files and the image folders
        answer_to_index: answer-string -> index mapping
        tokenizer: object with an ``encode_plus`` method returning ``input_ids``,
            ``attention_mask`` and ``token_type_ids``
        split: 'train', 'val', 'trainval', 'minval' or 'test'
        max_len: fixed length questions are truncated/padded to
    """

    # Image folders (relative to `root`) that hold the images of each split.
    _IMAGE_DIRS = {
        'train': ('train2014',),
        'val': ('val2014',),
        'trainval': ('train2014', 'val2014'),
        'minval': ('val2014',),
        'test': ('test2015',),
    }

    def __init__(self, root, answer_to_index, tokenizer ,split = 'train', max_len = 20):
        super(VQA, self).__init__()

        self.root = root
        self.answer_to_index = answer_to_index
        self.split = split
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.entries = self._load_dataset( self.root, self.split)

        # standard PyTorch mean-std input image normalization
        self.transform = T.Compose([
            T.Resize(size=(800,800)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.id_to_image_fname = self._find_iamges()

    def assert_eq(self,real, expected):
        """Assert that ``real == expected``; the failure message shows both values."""
        assert real == expected, "%s (true) vs %s (expected)" % (real, expected)

    def _create_entry(self,question, answer):
        """Merge a question record and its annotation record; neither input is modified."""
        return {
            "question_id": question["question_id"],
            "image_id": question["image_id"],
            "question": question["question"],
            "answer": [a['answer'] for a in answer['answers']],
        }

    @staticmethod
    def _read_sorted(dataroot, filename, key):
        """Read ``dataroot/filename`` as JSON and return its ``key`` list sorted by question id."""
        # binary mode: json.load detects the encoding; the file is closed after parsing
        with open(os.path.join(dataroot, filename), "rb") as handle:
            records = json.load(handle)[key]
        return sorted(records, key=lambda record: record["question_id"])

    def _load_dataset(self,dataroot, name):
        """Load entries
        dataroot: root path of dataset
        name: 'train', 'val', 'trainval', 'minval' or 'test'

        'trainval' is train plus val without its last 3000 records, 'minval' is
        the last 3000 val records, 'test' returns the raw question dicts (there
        are no annotations). Raises AssertionError for an unknown split or when
        questions and annotations do not line up.
        """
        def questions_of(split, year="2014"):
            return self._read_sorted(
                dataroot, "v2_OpenEnded_mscoco_%s%s_questions.json" % (split, year), "questions")

        def answers_of(split):
            return self._read_sorted(
                dataroot, "v2_mscoco_%s2014_annotations.json" % split, "annotations")

        if name == 'train' or name == 'val':
            questions = questions_of(name)
            answers = answers_of(name)
        elif name == 'trainval':
            questions = questions_of('train') + questions_of('val')[:-3000]
            answers = answers_of('train') + answers_of('val')[:-3000]
        elif name == 'minval':
            questions = questions_of('val')[-3000:]
            answers = answers_of('val')[-3000:]
        elif name == 'test':
            return questions_of('test', year="2015")
        else:
            # raised explicitly so the check also holds when asserts are disabled (-O)
            raise AssertionError("data split is not recognized.")

        # Use this class's own helpers: the previous version called module-level
        # functions of the same names, so the class only worked when another
        # notebook cell had defined them first.
        self.assert_eq(len(questions), len(answers))
        entries = []
        for question, answer in zip(questions, answers):
            self.assert_eq(question["question_id"], answer["question_id"])
            self.assert_eq(question["image_id"], answer["image_id"])
            entries.append(self._create_entry(question, answer))
        return entries

    def _encode_question(self, question):
        """ Turn a question into fixed-length id / mask / token-type tensors """
        inputs = self.tokenizer.encode_plus(
            question,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            )

        # Clip to max_len in case the tokenizer did not truncate: over-long
        # sequences would otherwise skip the padding below and produce tensors
        # of differing lengths that cannot be batched.
        ids = list(inputs["input_ids"])[:self.max_len]
        mask = list(inputs["attention_mask"])[:self.max_len]
        token_type_ids = list(inputs["token_type_ids"])[:self.max_len]

        padding_length = self.max_len - len(ids)
        ids += ([0]*padding_length)
        mask += ([0]*padding_length)
        token_type_ids += ([0]*padding_length)

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
        }

    def _encode_ansfor_bert(self, answers):
        """Placeholder; not implemented."""
        pass

    def _encode_answers(self, answers):
        """ Turn a list of answers into a vector of answer counts """
        # answer vec will be a vector of answer counts to determine which answers will contribute to the loss.
        # this should be multiplied with 0.1 * negative log-likelihoods that a model produces and then summed up
        # to get the loss that is weighted by how many humans gave that answer
        answer_vec = torch.zeros(len(self.answer_to_index),dtype=torch.long)
        for answer in answers:
            index = self.answer_to_index.get(answer)
            if index is not None:  # answers missing from the vocabulary are ignored
                answer_vec[index] += 1
        return answer_vec

    def _find_images(self):
        """Map image id -> file path for every ``.jpg`` in the split's image folder(s).

        The id is the integer after the last underscore of the file name.
        Splits without an entry in ``_IMAGE_DIRS`` fall back to ``<split>2014``.
        """
        folders = self._IMAGE_DIRS.get(self.split, ('%s2014' % self.split,))
        id_to_filename = {}
        for folder in folders:
            imgs_folder = os.path.join(self.root, folder)
            for filename in os.listdir(imgs_folder):
                if not filename.endswith('.jpg'):
                    continue
                id_and_extension = filename.split('_')[-1]
                image_id = int(id_and_extension.split('.')[0])
                id_to_filename[image_id] = os.path.join(imgs_folder, filename)
        return id_to_filename

    # original (misspelled) name, kept so existing callers keep working
    _find_iamges = _find_images

    def _load_image(self, image_id):
        """ Load an image as an (H, W, 3) RGB array """
        img_path = self.id_to_image_fname[image_id]
        # convert('RGB') replicates grayscale images over three channels and
        # normalises other modes; the file handle is closed on exit.
        with Image.open(img_path) as handle:
            return np.asarray(handle.convert('RGB'))

    def __getitem__(self, item):
        """Return ``(image, question, answer)`` for entry ``item``."""
        entry  = self.entries[item]
        image_id = entry['image_id']

        img = Image.fromarray(self._load_image(image_id))
        img = self.transform(img)
        q = self._encode_question(entry['question'])
        # test-split entries carry no answers: encode them as an all-zero vector
        a = self._encode_answers(entry.get('answer', []))

        return img, q, a

    def __len__(self):
        return len(self.entries)

    @staticmethod
    def collate_fn(batch):
        """The collate_fn method to be used by the
        PyTorch data loader.

        Stacks images and answers along a new batch axis; questions are dicts of
        tensors, so they are stacked key by key.
        """
        # Unzip the batch
        imgs, qs, answers = list(zip(*batch))

        imgs = torch.stack(imgs)
        q = {key: torch.stack([question[key] for question in qs]) for key in qs[0]}
        a = torch.stack(answers)

        return imgs, q, a

In [33]:
# Training configuration read by VQA_DETR: loader batch sizes and worker count,
# learning rate, warmup steps, and the epoch / gradient-accumulation values used
# to compute the scheduler's total step count.
hparams = Namespace(**{
    "batch_size": 1,
    "val_batch_size": 1,
    "num_warmup_steps": 100,
    "epochs": 20,
    "lr": 3e-5,
    "accumulate_grad_batches": 1,
    "num_workers": 8,
})

In [24]:
# Number of answer classes = size of the answer vocabulary.
len_ans = len(answer_to_index)

In [34]:
# Build the model; the constructor also fetches the bert-base-uncased tokenizer
# and downloads the DETR demo weights.
vqa_detr = VQA_DETR(num_ans=len_ans,ans_to_index=answer_to_index,hparams=hparams)

In [35]:
#trainer = Trainer(gpus=4, max_epochs=20,log_gpu_memory=True)
# Smoke test only: fast_dev_run pushes a single batch through the loops.
# NOTE(review): the recorded run of this cell ended with
# "OSError: [Errno 12] Cannot allocate memory"; presumably related to the
# DataLoader worker processes (hparams.num_workers) — confirm.
trainer = Trainer(fast_dev_run=True)
trainer.fit(vqa_detr)

Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
GPU available: True, used: False
No environment variable for node rank defined. Set as 0.

    | Name                                                         | Type                    | Params
-----------------------------------------------------------------------------------------------------
0   | bert_decoder                                                 | BertModel               | 20 M  
1   | bert_decoder.embeddings                                      | BertEmbeddings          | 7 M   
2   | bert_decoder.embeddings.word_embeddings                      | Embedding               | 7 M   
3   | bert_decoder.embeddings.position_embeddings                  | Embedding               | 131 K 
4   | bert_decoder.embeddings.token_type_embeddings                | Embedding               | 512   
5   | bert_decoder.embeddings.LayerNorm                            | LayerNorm               | 512   
6

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

OSError: [Errno 12] Cannot allocate memory

In [None]:
# Start tensorboard.
# NOTE(review): the Trainer above is created without any logger / save-dir argument;
# confirm that `kaggle/working/lightning_logs` is where its logs are actually written.
%reload_ext tensorboard
%tensorboard --logdir kaggle/working/lightning_logs 

In [None]:
# Second name for the same model object (no copy is made), used by the inference cells below.
model_save= vqa_detr

In [None]:
# Switch to evaluation mode and move the model to the CPU for single-example inference.
model_save.eval().to('cpu')

In [None]:
# Fetch one (image, question, answer) example.
# NOTE(review): `vqa_data` is not defined anywhere in the visible notebook; presumably a
# VQA dataset instance created in a cell that no longer exists — define it before this cell.
img,q,a=vqa_data[24]

In [None]:

# Run the model on the single example; unsqueeze(0) adds the leading batch axis
# to both the image and the question ids.
out = model_save(img.unsqueeze(0),q['ids'].unsqueeze(0))


In [None]:
# Highest logit and its index along the answer axis.
# Note: this rebinds `a` (previously the example's answer vector) to the max logit value.
a,index = out['logits'].max(dim = -1)

In [None]:
# Show the question as word-piece tokens. `bert_tokenizer` was never defined in this
# notebook; the tokenizer that encoded the questions is the one owned by the model.
vqa_detr.tokenizer.convert_ids_to_tokens(q['ids'].tolist())

In [None]:
# Look up the predicted answer string (int() requires `index` to hold exactly one element).
all_answers[int(index)]

In [None]:
# Move the first axis of `img` last before handing it to matplotlib.
# NOTE(review): if `img` comes from the VQA dataset it has already been mean/std
# normalised, so the displayed colours will presumably look wrong — confirm.
plt.imshow(img.permute(1,2,0))

In [None]:
# Busy-wait forever: execution never gets past this cell until the kernel is interrupted.
# NOTE(review): presumably a manual stop / keep-alive; it blocks "Run All" and keeps a CPU core busy.
while 1:
    continue

In [None]:
import torch
from torch import nn

In [None]:
# Sanity check of nn.CrossEntropyLoss on random data: 3 samples, 5 classes.
loss = nn.CrossEntropyLoss()
# named `logits` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session
logits = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)  # random class indices in [0, 5)
output = loss(logits, target)
print(output)
output.backward()

In [None]:
# Display the installed PyTorch version.
torch.__version__

In [None]:
# Check whether the loss value computed above takes part in autograd.
output.requires_grad

In [None]:
# Scratch check: a two-element 1-D array has shape (2,).
x = np.array([1, 2])
print(x.shape)

In [None]:
# Scratch check of adding a trailing axis and repeating it three times:
# (2,) -> (2, 1) -> (2, 3).
z=np.expand_dims(x, axis=-1)
np.repeat(z,3,axis=-1).shape

In [None]:
# Print the shapes of the first batch only.
# NOTE(review): `data_loader` is not defined anywhere in the visible notebook —
# create it (e.g. from a VQA dataset) before running this cell.
for im,q,a in data_loader:
    print(im.shape)
    print(q['ids'].shape)
    print(a.shape)
    break

In [None]:
# Size of the answer vocabulary.
len(answer_to_index)

In [None]:
# Fetch another example.
# NOTE(review): `vqa_data` is not defined anywhere in the visible notebook.
img,q,a=vqa_data[58]

In [None]:
# Max value and its index along dim 1 after adding a leading axis;
# rebinds `a` to the max value and `b` to the index.
a,b=torch.max(a.unsqueeze(0),dim = 1)

In [None]:
# NOTE(review): `b` holds the indices returned by torch.max, which are integer (long)
# tensors; PyTorch only lets floating-point tensors require gradients, so this
# assignment is expected to raise — confirm and drop the cell.
b.requires_grad = True

In [None]:
# Scratch check: shape of `a` reshaped to a single column.
a.unsqueeze(0).view(-1,1).shape


In [None]:
# Small BERT configurations (hidden size 256, 8 heads, 6 layers): a plain one for the
# encoder and one flagged as decoder, combined into a single encoder-decoder config.
bert_config = BertConfig(hidden_size=256, num_attention_heads=8, num_hidden_layers=6)
bert_decoder_config = BertConfig(is_decoder=True, hidden_size=256, num_attention_heads=8, num_hidden_layers=6)

enc_dec_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config= bert_config, decoder_config= bert_decoder_config)


In [None]:
# Build an encoder-decoder model from the config above (constructed from a config,
# not from a pretrained checkpoint).
model = EncoderDecoderModel(config= enc_dec_config)

In [None]:
# Experiment: feed flattened image features as decoder input embeddings.
# NOTE(review): `h` is not defined anywhere in the visible notebook; presumably a
# backbone feature tensor left over from a deleted cell — confirm.
outputs = model(input_ids = q['ids'].unsqueeze(0), decoder_inputs_embeds = h[0].unsqueeze(0).flatten(2).permute(2, 0, 1))

In [None]:
# Inspect the shape of the third element of the model output.
outputs[2].shape

In [None]:
# Experiment: a stand-alone BERT decoder attending to image features passed as
# encoder_hidden_states.
# NOTE(review): `h` is not defined anywhere in the visible notebook.
bert_decoder = BertModel(config  = bert_decoder_config)
outputs = bert_decoder(input_ids = q['ids'].unsqueeze(0), encoder_hidden_states = h[0].unsqueeze(0).flatten(2).permute(0,2,1))

In [None]:
# Same call as the second line of the previous cell, repeated on its own.
outputs = bert_decoder(input_ids = q['ids'].unsqueeze(0), encoder_hidden_states = h[0].unsqueeze(0).flatten(2).permute(0,2,1))

In [None]:
# Busy-wait forever: execution never gets past this cell until the kernel is interrupted.
# NOTE(review): presumably a manual stop / keep-alive; it blocks "Run All".
while 1 :
    continue

In [None]:
# IPython help syntax: `??` displays the source of BertModel. Interactive exploration
# only; it has no effect on the notebook's results.
BertModel??

In [None]:
 # Shape of the feature tensor in the layout passed as encoder_hidden_states above.
 # NOTE(review): `h` is not defined anywhere in the visible notebook.
 h[0].unsqueeze(0).flatten(2).permute(0, 2, 1).shape

In [None]:
# Shape of the question ids after adding the leading batch axis.
q['ids'].unsqueeze(0).shape