### Importing Necessary Libraries for Data Preparation & Util Functions

In [4]:
######################################################################################################################
#                               Import Python Files for Sentence & Annotations Extraction                            #
#                                        Provided as a Simple API on Github                                          #
#                                https://github.com/BryanPlummer/flickr30k_entities                                  #
######################################################################################################################

# Uncomment the below line once in order for the code to run smoothly
!pip install imagesize

import os
os.chdir(r"/kaggle/input/vgrutils/Visual Grounding RefEx/Flickr30/")

import Utils.flickr30k_entities_utils
from Utils.flickr30k_entities_utils import get_sentence_data, get_annotations

import Utils.helper_functions
from Utils.helper_functions import *

import random 
import seaborn as sb

import pickle
import torch

import ast
import torch
import pickle
import torch.nn as nn
import torch.optim
import torch.utils.data.distributed
from torch.utils.data import Dataset,DataLoader
from torchvision.ops import box_iou,generalized_box_iou_loss





Collecting imagesize
  Using cached imagesize-1.4.1-py2.py3-none-any.whl (8.8 kB)
Installing collected packages: imagesize
Successfully installed imagesize-1.4.1
[0m

### Driver Functions

In [5]:
######################################################################################################################
#                                                                                                                    #
#                                                   Mapping Function                                                 #
#                                                                                                                    #
######################################################################################################################
"""
Mapping Function does the following,
    - takes list of Image names as i/p and fetch Sentences & Annotations (contains bounding boxes)of all those Images
    - passes those Sentences & Annotations to the func Phrase_Id_to_Bbox & gets Bounding Boxes for all phrases
    in every image.
    - also, passes those Sentences & Annotations to the func Phrase_Id_to_Phrases & extracts phrases contained in all
    images.
    
    A typical look of the outputs would be:
    
        _Image_Train_Phrase_Id_to_Bbox -----> {'image_id_1' : {'Phrase_id_1' : [Bbox1, Bbox2 ... Bboxn],
                                                              'Phrase_id_2' : [Bbox1, Bbox2 ... Bboxn],
                                                              'Phrase_id_3' : [Bbox1, Bbox2 ... Bboxn],
                                                              .
                                                              .
                                                              .
                                                              'Phrase_id_n' : [Bbox1, Bbox2 ... Bboxn]}
                                                              
                                                'image_id_2' : {'Phrase_id_1' : [Bbox1, Bbox2 ... Bboxn],
                                                              'Phrase_id_2' : [Bbox1, Bbox2 ... Bboxn],
                                                              'Phrase_id_3' : [Bbox1, Bbox2 ... Bboxn],
                                                              .
                                                              .
                                                              .
                                                              'Phrase_id_n' : [Bbox1, Bbox2 ... Bboxn]}
                                                              
                                                              
                                                              
                                                              .
                                                              .
                                                              .
                                                              .
                                                              .
                                                              .
                                                              
                                                              
                                                'image_id_n' : {'Phrase_id_1' : [Bbox1, Bbox2 ... Bboxn],
                                                              'Phrase_id_2' : [Bbox1, Bbox2 ... Bboxn],
                                                              'Phrase_id_3' : [Bbox1, Bbox2 ... Bboxn],
                                                              .
                                                              .
                                                              .
                                                              'Phrase_id_n' : [Bbox1, Bbox2 ... Bboxn]}
                                                              
                                                              }
                                                              
                                                              
        _Image_Train_Phrase_Id_to_Phrase -----> {'image_id_1' : {'Phrase_id_1' : [Phrase1, Phrase2.... Phrase_n],
                                                              'Phrase_id_2' : [Phrase1, Phrase2.... Phrase_n],
                                                              'Phrase_id_3' : [Phrase1, Phrase2.... Phrase_n],
                                                              .
                                                              .
                                                              .
                                                              'Phrase_id_n' : [Phrase1, Phrase2.... Phrase_n]}
                                                              
                                                'image_id_2' : {'Phrase_id_1' : [Phrase1, Phrase2.... Phrase_n],
                                                              'Phrase_id_2' : [Phrase1, Phrase2.... Phrase_n],
                                                              'Phrase_id_3' : [Phrase1, Phrase2.... Phrase_n],
                                                              .
                                                              .
                                                              .
                                                              'Phrase_id_n' : [Phrase1, Phrase2.... Phrase_n]}
                                                              
                                                              
                                                              
                                                              .
                                                              .
                                                              .
                                                              .
                                                              .
                                                              .
                                                              
                                                              
                                                'image_id_n' : {'Phrase_id_1' : [Phrase1, Phrase2.... Phrase_n],
                                                              'Phrase_id_2' : [Phrase1, Phrase2.... Phrase_n],
                                                              'Phrase_id_3' : [Phrase1, Phrase2.... Phrase_n],
                                                              .
                                                              .
                                                              .
                                                              'Phrase_id_n' : [Phrase1, Phrase2.... Phrase_n]}
                                                              
                                                              }
        

NOTE: Please alter any folder paths for Images, Sentences and Annotations (Phrase & Bounding Boxes) in Helper Function File


"""



from collections import defaultdict
def Mapping(_Image_Names, _paths_dict):
    _Phrase_Id_to_Bbox = defaultdict()
    _Phrase_Id_to_Phrase = defaultdict()

    for _img in tqdm(_Image_Names):
        _img_sentences_path, _img_annotations_path, _img_absolute_path = get_Paths(_img, _paths_dict)
        sents = get_sentence_data(_img_sentences_path)
        anns = get_annotations(_img_annotations_path)
        _Phrase_Id_to_Bbox[_img] = phrase_Id_to_Bbox(sents, anns)
        _Phrase_Id_to_Phrase[_img] = phrase_Id_to_Phrases(sents, anns)
        
        
    return _Phrase_Id_to_Bbox, _Phrase_Id_to_Phrase




### Driver Code

In [6]:
_paths_dict = {
                '_sentences_path' : r'/kaggle/input/vgrutils/Visual Grounding RefEx/Flickr30/Data/annotations/Sentences',
                '_annotations_path' : r'/kaggle/input/vgrutils/Visual Grounding RefEx/Flickr30/Data/annotations/Annotations',
                '_image_folder_path' : r'/kaggle/input/flickr30k/flickr30k_images'
                }
_train_len = 5000 #len(_trainimg)



In [7]:
"""
######################################################################################################################
#                                                                                                                    #
#                       Enter path for train, val & test split in their respective variables                         #
#                                                                                                                    #
######################################################################################################################


train.txt, val.txt and test.txt are text files that contains predefined splits, i.e each file contains the split it
belongs to.

train.txt contains all image names as strings, that should be used for training
val.txt contains all image names as strings, that should be used for validation
test.txt contains all image names as strings, that should be used for testing

"""

_trainimg = load_Splits('/kaggle/input/vgrutils/Visual Grounding RefEx/Flickr30/Data/Splits/train.txt')
_vlimg = load_Splits('/kaggle/input/vgrutils/Visual Grounding RefEx/Flickr30/Data/Splits/val.txt')
_tsimg = load_Splits('/kaggle/input/vgrutils/Visual Grounding RefEx/Flickr30/Data/Splits/test.txt')


In [8]:
"""
######################################################################################################################
#                                                                                                                    #
#                                                 Call to the Mapping Functions                                      #
#                                                                                                                    #
######################################################################################################################
"""

_fractional_trainimg = _trainimg[:_train_len]
_Image_Train_Phrase_Id_to_Bbox, _Image_Train_Phrase_Id_to_Phrase = Mapping(_fractional_trainimg, _paths_dict)
_Image_Val_Phrase_Id_to_Bbox, _Image_Val_Phrase_Id_to_Phrase = Mapping(_vlimg, _paths_dict)
_Image_Test_Phrase_Id_to_Bbox, _Image_Test_Phrase_Id_to_Phrase = Mapping(_tsimg, _paths_dict)



  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
import pandas as pd
def prepare_DataFrame(Phrase_Dict, Bbox_Dict):
    Final_DF = pd.DataFrame()
    for Image_Id in tqdm(Phrase_Dict.keys()):
        
        Phrase_DF = pd.DataFrame.from_dict(Phrase_Dict[Image_Id], orient = 'index')
        Phrase_DF = pd.DataFrame(Phrase_DF.stack(level=0)).reset_index().drop('level_1', axis = 1)

        Bbox_DF = pd.DataFrame.from_dict(Bbox_Dict[Image_Id], orient = 'index')
        Bbox_DF = pd.DataFrame(Bbox_DF.stack(level=0)).reset_index().drop('level_1', axis = 1)
        Bbox_DF = Bbox_DF.groupby(['level_0'])[0].apply(list)
        

        Merged_DF = pd.merge(Phrase_DF, Bbox_DF, on = 'level_0', how='inner')
        Merged_DF['Image_Id'] = Image_Id

        Final_DF = pd.concat([Final_DF, Merged_DF], axis = 0)

    Final_DF = Final_DF.rename(columns = {'level_0' : 'Phrase_Id', '0_x': 'Phrase', '0_y':'Bounding_Box'})
    Final_DF = Final_DF[['Image_Id', 'Phrase_Id', 'Phrase', 'Bounding_Box']]
    Final_DF.reset_index(drop = True, inplace = True)
    print("Local Function Called......")
    return Final_DF



"""***************************************************************************************************************"""


'***************************************************************************************************************'

In [10]:
"""
######################################################################################################################
#                                                                                                                    #
#                                                Converting to DataFrames.                                           #
#                                                                                                                    #
######################################################################################################################
"""

_Fractional_Train_Set_Pid_to_P = {img : _Image_Train_Phrase_Id_to_Phrase[img] for img in _trainimg[:_train_len]}
_Fractional_Train_Set_Pid_to_B = {img : _Image_Train_Phrase_Id_to_Bbox[img] for img in _trainimg[:_train_len]}

Train_Frame = prepare_DataFrame(_Fractional_Train_Set_Pid_to_P, _Fractional_Train_Set_Pid_to_B)
Test_Frame = prepare_DataFrame(_Image_Test_Phrase_Id_to_Phrase, _Image_Test_Phrase_Id_to_Bbox)
Val_Frame = prepare_DataFrame(_Image_Val_Phrase_Id_to_Phrase, _Image_Val_Phrase_Id_to_Bbox)


Train_Frame.Phrase = Train_Frame.Phrase.str.lower()
Val_Frame.Phrase = Val_Frame.Phrase.str.lower()
Test_Frame.Phrase = Test_Frame.Phrase.str.lower()


  0%|          | 0/5000 [00:00<?, ?it/s]

Local Function Called......


  0%|          | 0/1000 [00:00<?, ?it/s]

Local Function Called......


  0%|          | 0/1000 [00:00<?, ?it/s]

Local Function Called......


# Retrieval Code for ViT Embeddings

In [11]:
"""
######################################################################################################################
#                                                                                                                    #
#                    image_index corresponding to image_id is the index of its embedding                             #
#                                                                                                                    #
######################################################################################################################
"""


_v_train_Image_Indices = pd.DataFrame(_trainimg[:_train_len], columns = ['Image_Id']).reset_index().rename(columns = {'index':'image_index'})
_v_val_Image_Indices = pd.DataFrame(_vlimg[:_train_len], columns = ['Image_Id']).reset_index().rename(columns = {'index':'image_index'})
_v_test_Image_Indices = pd.DataFrame(_tsimg[:_train_len], columns = ['Image_Id']).reset_index().rename(columns = {'index':'image_index'})


Train_Frame = Train_Frame.merge(_v_train_Image_Indices, on = 'Image_Id', how='left')
Val_Frame = Val_Frame.merge(_v_val_Image_Indices, on = 'Image_Id', how='left')
Test_Frame = Test_Frame.merge(_v_test_Image_Indices, on = 'Image_Id', how='left')

Vision_Embeddings_train = torch.load('/kaggle/input/embeddings-7k/v_train_embeds.pt')
Vision_Embeddings_val = torch.load('/kaggle/input/embeddings-7k/v_val_embeds.pt')
Vision_Embeddings_test = torch.load('/kaggle/input/embeddings-7k/v_test_embeds.pt')


# Retrieval Code for BERT Embeddings

In [12]:
"""
######################################################################################################################
#                                                                                                                    #
#               image_index corresponding to unique Phrase is the index of its embedding                             #
#                                                                                                                    #
######################################################################################################################
"""
with open('/kaggle/input/embeddings-7k/_train_Phrase_to_Index_Map.pkl', 'rb') as fp:
    _train_Phrase_to_Index_Map = pickle.load(fp)

with open('/kaggle/input/embeddings-7k/_val_Phrase_to_Index_Map.pkl', 'rb') as fp:
    _val_Phrase_to_Index_Map = pickle.load(fp)
    
with open('/kaggle/input/embeddings-7k/_test_Phrase_to_Index_Map.pkl', 'rb') as fp:
    _test_Phrase_to_Index_Map = pickle.load(fp)
    
    
    
_t_train_Image_Indices = pd.DataFrame(_train_Phrase_to_Index_Map.items(), columns = ['Phrase', 'text_index'])
_t_val_Image_Indices = pd.DataFrame(_val_Phrase_to_Index_Map.items(), columns = ['Phrase', 'text_index'])
_t_test_Image_Indices = pd.DataFrame(_test_Phrase_to_Index_Map.items(), columns = ['Phrase', 'text_index'])

Train_Frame = Train_Frame.merge(_t_train_Image_Indices, on = 'Phrase', how = 'left')
Val_Frame = Val_Frame.merge(_t_val_Image_Indices, on = 'Phrase', how = 'left')
Test_Frame = Test_Frame.merge(_t_test_Image_Indices, on = 'Phrase', how = 'left')


Textual_Embeddings_train = torch.load('/kaggle/input/embeddings-7k/t_train_embeds.pt')
Textual_Embeddings_val = torch.load('/kaggle/input/embeddings-7k/t_val_embeds.pt')
Textual_Embeddings_test = torch.load('/kaggle/input/embeddings-7k/t_test_embeds.pt')

# Actual Dataframes Look Like

In [13]:
just_to_see = ['Image_Id', 'Phrase_Id', 'Phrase']
necessary_columns = ['image_index', 'text_index', 'Bounding_Box']
train = Train_Frame[just_to_see + necessary_columns] 
val = Val_Frame[just_to_see + necessary_columns] 
test = Test_Frame[just_to_see + necessary_columns] 

## Training Set

In [14]:
train.head(30)

Unnamed: 0,Image_Id,Phrase_Id,Phrase,image_index,text_index,Bounding_Box
0,3359636318,112630,two people,0,1657,"[[46, 182, 105, 333], [143, 165, 207, 333]]"
1,3359636318,112632,the video game shop,0,15491,"[[0, 54, 168, 307]]"
2,3359636318,112631,the mobile phone store,0,15492,"[[191, 0, 498, 230]]"
3,3359636318,112625,people,0,0,"[[46, 182, 105, 333], [143, 165, 207, 333], [2..."
4,3359636318,112625,a group of people,0,15493,"[[46, 182, 105, 333], [143, 165, 207, 333], [2..."
5,3359636318,112625,several people,0,1658,"[[46, 182, 105, 333], [143, 165, 207, 333], [2..."
6,3359636318,112627,some stores,0,1659,"[[191, 0, 498, 230], [1, 0, 190, 307]]"
7,3359636318,112626,a sidewalk,0,1660,"[[2, 212, 499, 333]]"
8,6959556104,262504,a series of spectators,1,15494,"[[5, 70, 103, 314], [120, 54, 206, 172], [197,..."
9,6959556104,262504,the crowd,1,1661,"[[5, 70, 103, 314], [120, 54, 206, 172], [197,..."


## Validation Set

In [15]:
val.head(30)

Unnamed: 0,Image_Id,Phrase_Id,Phrase,image_index,text_index,Bounding_Box
0,100652400,197,a construction worker,0,2825,"[[52, 44, 109, 202]]"
1,100652400,197,a man,0,616,"[[52, 44, 109, 202]]"
2,100652400,198,a hard hat,0,2826,"[[58, 43, 87, 65]]"
3,100652400,198,hard hat,0,617,"[[58, 43, 87, 65]]"
4,100652400,198,a blue hard hat,0,4767,"[[58, 43, 87, 65]]"
5,100652400,199,a caution vest,0,2829,"[[61, 68, 97, 118]]"
6,100652400,199,a reflective vest,0,2828,"[[61, 68, 97, 118]]"
7,100652400,199,bright vest,0,618,"[[61, 68, 97, 118]]"
8,100652400,199,orange safety vest,0,2827,"[[61, 68, 97, 118]]"
9,100652400,200,an intersection,0,619,"[[0, 89, 373, 499]]"


## Testing Set

In [16]:
test.head(30)

Unnamed: 0,Image_Id,Phrase_Id,Phrase,image_index,text_index,Bounding_Box
0,1016887272,547,a collage of one person,0,5381,"[[193, 369, 230, 453], [207, 303, 255, 383], [..."
1,1016887272,547,seven climbers,0,600,"[[193, 369, 230, 453], [207, 303, 255, 383], [..."
2,1016887272,547,a group of people,0,4644,"[[193, 369, 230, 453], [207, 303, 255, 383], [..."
3,1016887272,547,several climbers,0,599,"[[193, 369, 230, 453], [207, 303, 255, 383], [..."
4,1016887272,548,a rock face,0,2731,"[[0, 53, 332, 499]]"
5,1016887272,548,the rock,0,601,"[[0, 53, 332, 499]]"
6,1016887272,548,a rock climbing wall,0,4645,"[[0, 53, 332, 499]]"
7,1016887272,548,a cliff,0,603,"[[0, 53, 332, 499]]"
8,1016887272,548,a rock,0,602,"[[0, 53, 332, 499]]"
9,1016887272,549,another man,0,605,"[[73, 301, 180, 499]]"


# Prepare DataLoaders

In [17]:
num_hid_dims = 0
def _the_Collate(batch):
    batch_size = len(batch)
    #print(batch_size)
    image_index_tensor = []
    text_index_tensor = []
    image_emb_tensor = []
    phrase_emb_tensor = []
    bbox_tensor = []
    
    for idx, (im_idx, t_idx, im_emb, p_emb, bbox) in enumerate(batch):
        image_index_tensor.append(im_idx)
        text_index_tensor.append(t_idx)
        image_emb_tensor.append(im_emb)
        phrase_emb_tensor.append(p_emb)
        bbox_tensor.append(bbox[0])
        
        
    """pad = [[0, 0, 0, 0]] * num_hid_dims
    for index, _ in enumerate(bbox_tensor):
        temp_pad = pad
        temp_pad[:len(bbox_tensor[index])] = bbox_tensor[index]
        bbox_tensor[index] = torch.tensor(temp_pad)"""
    
    image_index_tensor = torch.tensor(image_index_tensor)
    text_index_tensor = torch.tensor(text_index_tensor)
    image_emb_tensor = torch.stack(image_emb_tensor)
    phrase_emb_tensor = torch.stack(phrase_emb_tensor)
    bbox_tensor = torch.tensor(bbox_tensor)
    #print(bbox_tensor)
    
    return image_index_tensor, text_index_tensor, image_emb_tensor, phrase_emb_tensor, bbox_tensor

In [18]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, img_emb,text_emb):
        self.dataframe = dataframe
        self.image_emb = img_emb
        self.text_emb = text_emb

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        image_index = self.dataframe['image_index'][index]
        text_index = self.dataframe['text_index'][index]
        image_embedding = self.image_emb[image_index]
        phrase_embedding = self.text_emb[text_index]
        bounding_boxes = self.dataframe['Bounding_Box'][index]
        return image_index, text_index, image_embedding, phrase_embedding, bounding_boxes

In [19]:

train_dataset = CustomDataset(train, img_emb=Vision_Embeddings_train, text_emb=Textual_Embeddings_train)
val_dataset = CustomDataset(val, img_emb=Vision_Embeddings_val, text_emb=Textual_Embeddings_val)
test_dataset = CustomDataset(test,img_emb=Vision_Embeddings_test,text_emb=Textual_Embeddings_test)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn= _the_Collate)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, collate_fn= _the_Collate)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn= _the_Collate)



In [20]:
"""for idx, (im_idx, t_idx, im_emb, p_emb, bbox) in enumerate(train_loader):
    print("Batch number : ", idx)
    print(im_idx)
    print(t_idx)
    print(im_emb)
    print(p_emb)
    print(bbox.size())
    break
"""

'for idx, (im_idx, t_idx, im_emb, p_emb, bbox) in enumerate(train_loader):\n    print("Batch number : ", idx)\n    print(im_idx)\n    print(t_idx)\n    print(im_emb)\n    print(p_emb)\n    print(bbox.size())\n    break\n'

In [21]:
class BaslineModel(nn.Module):

    def __init__(self,hidden_dim:int, img_emb_dim:int,word_emb_dim: int,bs:int):
        super().__init__()
        
        self.prediction_head = nn.Sequential(
            nn.Linear(img_emb_dim+word_emb_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,4)
        )


    def forward(self, img_emb, word_emb):

        embed_concat = torch.cat((img_emb,word_emb), dim=-1)

        pred = self.prediction_head(embed_concat).sigmoid()
        pred = pred.squeeze(1)
        return pred
 

In [22]:

import random
import numpy as np
import torch
import torch.nn.functional as F



class AverageMeter(object):
    """Computes and stores the average and current value"""
    
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val*n
        self.count += n
        self.avg = self.sum / self.count
def xywh2xyxy(x):  # Convert bounding box format from [x, y, w, h] to [x1, y1, x2, y2]
    y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape)
    y[:, 0] = (x[:, 0] - x[:, 2] / 2)
    y[:, 1] = (x[:, 1] - x[:, 3] / 2)
    y[:, 2] = (x[:, 0] + x[:, 2] / 2)
    y[:, 3] = (x[:, 1] + x[:, 3] / 2)
    return y

In [23]:
class Criterion(nn.Module):
    def __init__(self):
        super(Criterion, self).__init__()
        self.loss_weight = [3, 1]
        self.MSELoss = torch.nn.MSELoss(reduction='none')
    def forward(self, pred, gt, img_size=256):
        """
        :param pred:  (bs, 4)
        :param gt: (bs, 4)
        :return:
        """
        bs = pred.shape[0]
        gt = gt / img_size

        loss_bbox = F.l1_loss(pred, gt, reduction='none')
        loss_bbox = loss_bbox.sum() / bs

        loss_giou = 1 - torch.diag(generalized_box_iou_loss(
                                   self.box_cxcywh_to_xyxy(pred),
                                   self.box_cxcywh_to_xyxy(gt)))

        loss_giou = loss_giou.sum() / bs
        loss = 5 * loss_bbox + loss_giou * 2
        return loss, 5 * loss_bbox, loss_giou * 2
    
    def box_cxcywh_to_xyxy(self, x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)
    

In [24]:
import time
import logging
import numpy as np
from torch.autograd import Variable


def train_epoch(train_loader, model, optimizer, epoch, criterion=None, img_size=512):
    bs =32
    batch_time = AverageMeter()
    losses = AverageMeter()

    losses_bbox = AverageMeter()
    losses_giou = AverageMeter()

    acc = AverageMeter()
    miou = AverageMeter()

    model.train()
    end = time.time()

    for batch_idx, batch in enumerate(train_loader):
        imgs = batch[0]
        word_id = batch[1]
        img_emb = batch[2]
        word_emb =batch[3]
        bbox = batch[4]
        #bbox = torch.clamp(bbox, min=0, max=(512 - 1))
        image_emb = Variable(img_emb.unsqueeze(1))
        word_emb = Variable(word_emb.unsqueeze(1))
        bbox = Variable(bbox)

        norm_bbox = torch.zeros_like(bbox)

        norm_bbox[:, 0] = (bbox[:, 0] + bbox[:, 2]) / 2.0  # x_center
        norm_bbox[:, 1] = (bbox[:, 1] + bbox[:, 3]) / 2.0  # y_center
        norm_bbox[:, 2] = bbox[:, 2] - bbox[:, 0]   # w
        norm_bbox[:, 3] = bbox[:, 3] - bbox[:, 1]    # h
        #print(norm_bbox)

        # forward
        pred_box = model(image_emb, word_emb)
        #print(pred_box.size())# [bs, C, H, W]
        loss, loss_box, loss_giou = criterion(pred_box, norm_bbox, img_size=img_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # pred-box
        pred_bbox = pred_box.detach()
        pred_bbox = pred_bbox * img_size
        pred_box = xywh2xyxy(pred_bbox)

        losses.update(loss.item(), bs)
        losses_bbox.update(loss_box.item(), bs)
        losses_giou.update(loss_giou.item(), bs)

        target_bbox = bbox
        iou = box_iou(pred_box, target_bbox.data)
#         print("in here")
        
        accu = np.sum(np.array((iou.data.numpy() > 0.5), dtype=float)) / bs

        # metrics
        miou.update(torch.mean(iou).item(), image_emb.size(0))
        acc.update(accu, image_emb.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_idx%300)== 0 :
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                        'Loss_bbox {loss_box.val:.4f} ({loss_box.avg:.4f})\t' \
                        'Loss_giou {loss_giou.val:.4f} ({loss_giou.avg:.4f})\t' \
                        'Accu {acc.val:.4f} ({acc.avg:.4f})\t' \
                        'Mean_iu {miou.val:.4f} ({miou.avg:.4f})\t' \
                .format(epoch+1, batch_idx+1, len(train_loader),
                        batch_time=batch_time,
                        loss=losses,
                        loss_box=losses_bbox,
                        loss_giou=losses_giou,
                        acc=acc,
                        miou=miou)

            print(print_str)
            

def validate_epoch(val_loader, model, train_epoch, img_size=512):
    bs=32
    batch_time = AverageMeter()
    acc = AverageMeter()
    miou = AverageMeter()

    model.eval()
    end = time.time()

    for batch_idx,batch in enumerate(val_loader):
        imgs = batch[0]
        word_id = batch[1]
        img_emb = batch[2]
        word_emb =batch[3]
        bbox = batch[4]
        #bbox = torch.clamp(bbox, min=0, max=(512 - 1))
        
        image_emb = Variable(img_emb.unsqueeze(1))
        word_emb = Variable(word_emb.unsqueeze(1))
        bbox = Variable(bbox)

        norm_bbox = torch.zeros_like(bbox)

        norm_bbox[:, 0] = (bbox[:, 0] + bbox[:, 2]) / 2.0  # x_center
        norm_bbox[:, 1] = (bbox[:, 1] + bbox[:, 3]) / 2.0  # y_center
        norm_bbox[:, 2] = bbox[:, 2] - bbox[:, 0]   # w
        norm_bbox[:, 3] = bbox[:, 3] - bbox[:, 1]    # h

        with torch.no_grad():
            pred_box = model(image_emb, word_emb)  # [bs, C, H, W]
            

        pred_bbox = pred_box.detach()
        pred_bbox = pred_bbox * img_size
        pred_bbox = xywh2xyxy(pred_bbox)

        # constrain
        pred_bbox[pred_bbox < 0.0] = 0.0
        pred_bbox[pred_bbox > img_size-1] = img_size-1

        target_bbox = bbox
        # metrics
        iou = box_iou(pred_bbox, target_bbox.data)
        # accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / args.batch_size
        accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / bs

        acc.update(accu, bs)
        miou.update(torch.mean(iou).item(), bs)

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch_idx%100) == 0:
            print_str = 'Validate: [{0}/{1}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  ' \
                        'Acc {acc.val:.4f} ({acc.avg:.4f})  ' \
                        'Mean_iu {miou.val:.4f} ({miou.avg:.4f})  ' \
                .format(batch_idx+1, len(val_loader), batch_time=batch_time, acc=acc, miou=miou)

            print(print_str)
            
    print(f"Train_epoch {train_epoch+1}  Validate Result:  Acc {acc.avg}, MIoU {miou.avg}.")


    return acc.avg, miou.avg

def test_epoch(test_loader, model, img_size=512):
    bs = 32
    acc = AverageMeter()
    miou = AverageMeter()
    model.eval()

    for batch_idx, batch in enumerate(test_loader):
        imgs = batch[0]
        word_id = batch[1]
        img_emb = batch[2]
        word_emb =batch[3]
        bbox = batch[4]
        #bbox = torch.clamp(bbox, min=0, max=(512 - 1))
        image_emb = Variable(img_emb.unsqueeze(1))
        word_emb = Variable(word_emb.unsqueeze(1))
        bbox = Variable(bbox)
        norm_bbox = torch.zeros_like(bbox)

        norm_bbox[:, 0] = (bbox[:, 0] + bbox[:, 2]) / 2.0  # x_center
        norm_bbox[:, 1] = (bbox[:, 1] + bbox[:, 3]) / 2.0  # y_center
        norm_bbox[:, 2] = bbox[:, 2] - bbox[:, 0]   # w
        norm_bbox[:, 3] = bbox[:, 3] - bbox[:, 1]    # h

        with torch.no_grad():
            pred_box = model(image_emb, word_emb)  # [bs, C, H, W]

        pred_bbox = pred_box.detach()
        pred_bbox = pred_bbox * img_size
        pred_bbox = xywh2xyxy(pred_bbox)

        # constrain
        pred_bbox[pred_bbox < 0.0] = 0.0
        pred_bbox[pred_bbox > img_size-1] = img_size-1

        target_bbox = bbox
        # metrics
        iou = box_iou(pred_bbox, target_bbox.data)
        accu = np.sum(np.array((iou.data.cpu().numpy() > 0.5), dtype=float)) / bs

        acc.update(accu, bs)
        miou.update(torch.mean(iou).item(), bs)
    print(f"Test Result:  Acc {acc.avg}, MIoU {miou.avg}.")

In [25]:

import matplotlib as mpl

import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed


epochs = 7
hidden_dim = 32
img_emb_dim = 768
word_emb_dim = 768
bs = 32
model =BaslineModel(hidden_dim,img_emb_dim,word_emb_dim,bs)
optimizer = torch.optim.Adam(model.parameters(), lr=10e-4, weight_decay=10e-3)

# get criterion
criterion = Criterion()
best_accu = -float('Inf')

# train
for epoch in range(epochs):
    model.train()
    train_epoch(train_loader, model, optimizer, epoch, criterion, 512)
    model.eval()
    accu_new, miou_new = validate_epoch(val_loader, model, epoch, 512)

    is_best = accu_new > best_accu
    best_accu = max(accu_new, best_accu)
    # save the pth

print(f'Best Acc: {best_accu}.')



Epoch: [1][1/1844]	Time 0.172 (0.172)	Loss 66.2466 (66.2466)	Loss_bbox 4.2585 (4.2585)	Loss_giou 61.9881 (61.9881)	Accu 1.2188 (1.2188)	Mean_iu 0.1822 (0.1822)	
Epoch: [1][301/1844]	Time 0.007 (0.006)	Loss 64.6920 (64.9323)	Loss_bbox 2.9093 (3.0680)	Loss_giou 61.7826 (61.8643)	Accu 0.3750 (0.6301)	Mean_iu 0.1020 (0.1222)	
Epoch: [1][601/1844]	Time 0.005 (0.006)	Loss 64.7016 (64.8025)	Loss_bbox 2.9718 (2.9446)	Loss_giou 61.7298 (61.8579)	Accu 0.4062 (0.6292)	Mean_iu 0.0909 (0.1193)	
Epoch: [1][901/1844]	Time 0.005 (0.006)	Loss 64.4094 (64.7427)	Loss_bbox 2.7048 (2.8929)	Loss_giou 61.7047 (61.8497)	Accu 0.2812 (0.6230)	Mean_iu 0.0871 (0.1180)	
Epoch: [1][1201/1844]	Time 0.005 (0.005)	Loss 64.6159 (64.7007)	Loss_bbox 2.6530 (2.8477)	Loss_giou 61.9629 (61.8529)	Accu 0.9688 (0.6332)	Mean_iu 0.1397 (0.1183)	
Epoch: [1][1501/1844]	Time 0.007 (0.005)	Loss 64.7980 (64.6727)	Loss_bbox 3.0215 (2.8170)	Loss_giou 61.7766 (61.8556)	Accu 0.3125 (0.6404)	Mean_iu 0.0950 (0.1184)	
Epoch: [1][1801/1844]	

In [26]:
torch.save(model, '/kaggle/working/baseline_model_5ep.pth')

In [27]:
model.eval()
test_epoch(test_loader, model,512)

Test Result:  Acc 0.8851883561643835, MIoU 0.13118556348020083.
