In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from PIL import Image
import matplotlib.pyplot as plt

from RelTR.models.backbone import Backbone, Joiner
from RelTR.models.position_encoding import PositionEmbeddingSine
from RelTR.models.transformer import Transformer
from RelTR.models.reltr import RelTR

import os
import pandas as pd
import json

In [2]:
CLASSES = [ 'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
                'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
                'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
                'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
                'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
                'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
                'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
                'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
                'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
                'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
                'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
                'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
                'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
                'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']

REL_CLASSES = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
                'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
                'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on',
                'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over',
                'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
                'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']

In [3]:
position_embedding = PositionEmbeddingSine(128, normalize=True)
backbone = Backbone('resnet50', False, False, False)
backbone = Joiner(backbone, position_embedding)
backbone.num_channels = 2048

transformer = Transformer(d_model=256, dropout=0.1, nhead=8, 
                          dim_feedforward=2048,
                          num_encoder_layers=6,
                          num_decoder_layers=6,
                          normalize_before=False,
                          return_intermediate_dec=True)

model = RelTR(backbone, transformer, num_classes=151, num_rel_classes = 51,
              num_entities=100, num_triplets=200)

# The checkpoint is pretrained on Visual Genome
ckpt = torch.hub.load_state_dict_from_url(
    url='https://cloud.tnt.uni-hannover.de/index.php/s/PB8xTKspKZF7fyK/download/checkpoint0149.pth',
    map_location='cpu', check_hash=True)
model.load_state_dict(ckpt['model'])
model.eval()



RelTR(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn_entity): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features

In [4]:
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
          (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

In [14]:
def get_sgg(file_name):
    im = Image.open(file_name)
    img = transform(im).unsqueeze(0)
    # propagate through the model
    outputs = model(img)

    # keep only predictions with >0.3 confidence
    probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
    probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
    probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
    keep = torch.logical_and(probas.max(-1).values > 0.3, torch.logical_and(probas_sub.max(-1).values > 0.3,
                                                                            probas_obj.max(-1).values > 0.3))
    topk = 10 # display up to 10 images
    keep_queries = torch.nonzero(keep, as_tuple=True)[0]
    indices = torch.argsort(-probas[keep_queries].max(-1)[0] * probas_sub[keep_queries].max(-1)[0] * probas_obj[keep_queries].max(-1)[0])[:topk]
    keep_queries = keep_queries[indices]

    with torch.no_grad():
        # propagate through the model

        sub, rel, obj = [], [], []
        for idx in keep_queries:
            sub.append(CLASSES[probas_sub[idx].argmax()])
            rel.append(REL_CLASSES[probas[idx].argmax()])
            obj.append(CLASSES[probas_obj[idx].argmax()])
            #print(CLASSES[probas_sub[idx].argmax()]+' '+REL_CLASSES[probas[idx].argmax()]+' '+CLASSES[probas_obj[idx].argmax()])
        return sub, rel, obj

In [6]:
data_frame = pd.read_csv('Datasets/Incidents/capLabel_All.csv')
data_frame = data_frame.drop_duplicates(subset=['filename'], keep='first')
data_frame

Unnamed: 0,filename,caption
0,036166_29fba09e.jpg,a dog standing on the side of a dirt road
1,074DF9CDDC33B3B774722C27D2B8074854000304.jpg,a white and black animal standing on the side ...
2,082001_bcd90588.jpg,a herd of sheep walking down a dirt road
3,1129772_a671f2ff.jpg,a herd of sheep walking down a dirt road
4,131590_023e94c7.jpg,two sheep walking down a dirt road
...,...,...
11559,X40vYG4ISI.jpg,a car parked on the side of a road next to a tree
11560,XZGwySqE0h.jpg,a bus is parked on the side of the road
11561,YJR1eMotsT.jpg,a fire truck is parked next to a pile of logs
11562,zlkcsgEw1e.jpg,a street scene with focus on a street light


In [13]:
anno_file = os.listdir('Datasets/Incidents/anno')
anno_file = [item.replace('.json', '') for item in anno_file]
anno_file
data_frame = data_frame[~data_frame['filename'].isin(anno_file)]
data_frame

Unnamed: 0,filename,caption
2064,uyChnNfDyx.jpg,a small white and black zebra walking down a s...
2065,v5THD5zGvQ.jpg,two animals crossing a road in the middle of t...
2066,VdSmYFjN41.jpg,a dog running across a road in the middle of t...
2067,VkT80Dj0f5.jpg,a herd of sheep standing on top of a road
2068,vQ62j5mowQ.jpg,a large animal standing on the side of a road
...,...,...
11559,X40vYG4ISI.jpg,a car parked on the side of a road next to a tree
11560,XZGwySqE0h.jpg,a bus is parked on the side of the road
11561,YJR1eMotsT.jpg,a fire truck is parked next to a pile of logs
11562,zlkcsgEw1e.jpg,a street scene with focus on a street light


In [15]:
root_dir = 'Datasets/Incidents/incidents_cleaned/'
object_annotation = []
for file_name, caption in data_frame.values[:7000]:
    try:
        sub, rel, obj = get_sgg(root_dir + file_name)
        obj_anno = {
            "file_name": file_name,
            "subject": sub,
            "object": obj,
            "relation": rel,
            "query": caption
        }
        jsonFile = f'Datasets/Incidents/anno/{file_name}.json'
        with open(jsonFile, "w") as outfile:
            json.dump(obj_anno, outfile)
        object_annotation.append(obj_anno)
    except:
        continue
    # break

In [40]:
anno_file = os.listdir('Datasets/Incidents/anno/')
len(anno_file)

9063

In [41]:
root = 'Datasets/Incidents/anno/'
err = []
for item in anno_file:
    file_data = open(root + item)
    data_item = json.load(file_data)
    if(len(data_item['subject']) == 0 or len(data_item['object']) == 0):
        err.append(item)

In [42]:
len(err)

682

In [43]:
err

['4s3pmsysNK.json',
 '5752642_ab14ab8c.json',
 '1635750_c8770905.json',
 '2058863541.json',
 '25FA4009FA6E43F0713EA2F3EBCF3C56CD94DB0F.json',
 'AGZqDogXNa.json',
 'CD6525BF9A9438BE0E46C11E4024BE16EA4B80A4.json',
 '5394016900.json',
 '2132BBF2F56A3DB33BB83D61AF750F6F4FEB42E8.json',
 'xdEF7gPhb6.json',
 'tnNFRdbwnD.json',
 'cVMv01vOg2.json',
 '3mt94s8g5L.json',
 'A5FECC16C871EC567F60D9B31F19FF5F252A55A9.json',
 '0B888F68DFD56CB1786342E45BAC9BC85C366409.json',
 'Qb5LiNY2jG.json',
 '3341692_44df6db3_original.json',
 '268512_ae8630c2.json',
 '4990998_1994ce92.json',
 'D06FFD0BA53A6A1FE0EE0AB329E9BC03B0DEB0C4.json',
 '1585758_32460d59.json',
 'F8391B32F38326297DA770726248F73EECF53638.json',
 '3864027_132af30b_original.json',
 '4858521_64b7d9cb.json',
 '5791276580.json',
 '097939_4128a6fb.json',
 '1789485_cbae34b4_original.json',
 '7B3BED1FD95691AED670E7B7B8D1E8C99BFF7856.json',
 '4236554_55bd40a9_original.json',
 '3247868_1902f6c1.json',
 'MC9A0abX9E.json',
 'Z5PP797bqH.json',
 '5DCCF2534F26

In [18]:
from sklearn.model_selection import train_test_split
train_files, valid_files = train_test_split(anno_file, test_size=0.2, random_state=42)


In [20]:
import shutil
root = 'Datasets/Incidents/anno/'
for item in train_files:
    src = root + item
    dst = root.replace('anno', 'train') 
    shutil.copy(src, dst)

for item in valid_files:
    src = root + item
    dst = root.replace('anno', 'val')
    shutil.copy(src, dst)

In [46]:

for item in os.listdir('Datasets/Incidents/val/'):
    if(item in err):
        rm_file = os.path.join('Datasets/Incidents/val/', item)
        print(rm_file)
        os.remove(rm_file)
        

Datasets/Incidents/val/25FA4009FA6E43F0713EA2F3EBCF3C56CD94DB0F.json
Datasets/Incidents/val/2132BBF2F56A3DB33BB83D61AF750F6F4FEB42E8.json
Datasets/Incidents/val/3mt94s8g5L.json
Datasets/Incidents/val/3341692_44df6db3_original.json
Datasets/Incidents/val/1585758_32460d59.json
Datasets/Incidents/val/5791276580.json
Datasets/Incidents/val/1789485_cbae34b4_original.json
Datasets/Incidents/val/5DCCF2534F26D5B4A3CA9C67209C9C476AF43F8C.json
Datasets/Incidents/val/8D51B40783C7D307A59516CCA9E98CA5089B30EC.json
Datasets/Incidents/val/355332_e48cb374.json
Datasets/Incidents/val/4024266_3136f635.json
Datasets/Incidents/val/7juqcsD5Vu.json
Datasets/Incidents/val/A4047A0B2E3DEFF3DCFEC79C3F43E0BC71E03B72.json
Datasets/Incidents/val/4793918_74049b42_original.json
Datasets/Incidents/val/C7651F617693C290429FB753AC9099A5B013B9A1.json
Datasets/Incidents/val/3841439_89d3039b_original.json
Datasets/Incidents/val/2853790_0f6987c2_original.json
Datasets/Incidents/val/5611980_08d62a45_original.json
Datasets/In

In [21]:
import sys
print(sys.getrecursionlimit()) # Prints 1000

3000
