In [2]:
import sys
sys.path.insert(0, "/notebooks/pipenv")
sys.path.insert(0, "/notebooks/nebula3_database")
sys.path.insert(0, "/notebooks/")
from PIL import Image
import requests
import visual_genome.local as vg
import json
import copy
import subprocess

import numpy as np
import torch
import spacy
import nltk
from spacy_wordnet.wordnet_annotator import WordnetAnnotator 
from sentence_transformers import SentenceTransformer
from database.arangodb import DatabaseConnector
from config import NEBULA_CONF


In [3]:
# Fetch the NLTK WordNet corpus (presumably needed by the spacy_wordnet pipe added below).
nltk.download('wordnet')
# Shared spaCy pipeline; SimilarityManager stores it as `self.nlp`.
nlp = spacy.load('en_core_web_lg')
# Register the WordNet annotator right after the tagger; tokens are later queried via `._.wordnet`.
nlp.add_pipe("spacy_wordnet", after='tagger', config={'lang': nlp.lang})

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


<spacy_wordnet.wordnet_annotator.WordnetAnnotator at 0x7f330c15dc60>

In [51]:
# Root directory of the local Visual Genome data (by-id scene graphs, synsets file).
VG_DATA = '/storage/vg_data'
# Collection that stores the IPC paragraph records together with their SPICE triplets.
IPC_COLLECTION = 'ipc_relations_spice'
# Collection that stores the per-image recall results.
RECALL_COLLECTION = 'ipc_recall_spice'

In [5]:
class PIPELINE:
    """Holds the database connection used by the rest of the notebook.

    Attributes set on construction:
        db_host: database host reported by ``NEBULA_CONF``.
        database: playground database name reported by ``NEBULA_CONF``.
        gdb: the ``DatabaseConnector`` instance.
        db: handle returned by ``gdb.connect_db(database)``.
    """

    def __init__(self):
        nebula_conf = NEBULA_CONF()
        self.db_host = nebula_conf.get_database_host()
        self.database = nebula_conf.get_playground_name()
        connector = DatabaseConnector()
        self.gdb = connector
        self.db = connector.connect_db(self.database)

# Shared connection holder; the database helpers below run their queries through `pipeline.db`.
pipeline = PIPELINE()

In [6]:
def cosine_sim(x,y):
    """Return the cosine similarity of two vectors.

    Args:
        x, y: array-likes of equal length.

    Returns:
        ``dot(x, y) / (|x| * |y|)``, or ``0.`` when either vector has zero
        norm (the ratio is undefined there and used to come out as ``nan``
        together with a NumPy RuntimeWarning).
    """
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    if denom == 0:
        return 0.
    return np.dot(x, y) / denom

def compare_cross_lists(l1, l2):
    """Tell whether at least one element of ``l1`` is also contained in ``l2``.

    Membership is tested with the ``in`` operator on ``l2``; the result is
    whatever ``np.any`` returns for the list of membership flags (false for
    an empty ``l1``).
    """
    hits = []
    for item in l1:
        hits.append(item in l2)
    return np.any(hits)

class SimilarityManager:
    """Scores textual similarity between words, phrases and tuples of strings.

    ``compare_triplet`` offers three scoring back-ends: spaCy token
    similarity (``'spacy'``), WordNet synset matching (``'wordnet'``) and
    sentence-transformer embeddings (``'bert'``).
    """

    # Part-of-speech tags kept by ``similarity`` when the source has several tokens.
    CONTENT_POS = ('NOUN', 'ADJ', 'ADV', 'VERB', 'PROPN')
    # Values accepted for the ``method`` argument of ``compare_triplet``.
    METHODS = ('spacy', 'wordnet', 'bert')

    def __init__(self):
        self.nlp = nlp
        self.similarity_model = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1')
        # Only move the model to the GPU when one exists, so the notebook
        # also runs on CPU-only hosts instead of failing here.
        if torch.cuda.is_available():
            self.similarity_model.cuda()

    def similarity(self, src, target):
        """Mean, over the kept tokens of ``src``, of the best token similarity to ``target``.

        Tokens of ``src`` whose POS tag is not in ``CONTENT_POS`` are skipped,
        unless ``src`` consists of a single token.

        Args:
            src: source text.
            target: text whose tokens the source tokens are matched against.

        Returns:
            The mean of the per-token maxima, or ``0.`` when ``target`` has no
            tokens or no token of ``src`` is kept.
        """
        src_doc = self.nlp(src)
        target_doc = self.nlp(target)
        if len(target_doc) == 0:
            # max() over an empty target would raise ValueError.
            return 0.
        best_scores = []
        for token in src_doc:
            if token.pos_ not in self.CONTENT_POS and len(src_doc) > 1:
                continue
            best_scores.append(max(token.similarity(other) for other in target_doc))
        if not best_scores:
            # np.mean([]) is nan (with a RuntimeWarning); report "no similarity" instead.
            return 0.
        return np.mean(best_scores)

    def compare_cross_synsets(self, text1, text2):
        """Tell whether some token of ``text1`` has the same synset list as some token of ``text2``.

        NOTE(review): whole per-token synset lists are compared, so two tokens
        that both have no synsets (empty lists) also count as a match --
        confirm this is intended.
        """
        t1 = self.nlp(text1)
        t2 = self.nlp(text2)
        synsets1 = [token._.wordnet.synsets() for token in t1]
        synsets2 = [token._.wordnet.synsets() for token in t2]
        return compare_cross_lists(synsets1, synsets2)

    def compare_triplet(self, t1, t2, method='spacy'):
        """Similarity score between two tuples of strings.

        Args:
            t1, t2: sequences of strings (e.g. subject / predicate / object).
            method: one of ``METHODS``.
                ``'bert'`` joins each tuple into one lower-cased sentence and
                returns the cosine similarity of the two embeddings;
                ``'wordnet'`` / ``'spacy'`` multiply the element-wise results
                of ``compare_cross_synsets`` / ``similarity``.

        Returns:
            ``0.`` when the tuples differ in length, otherwise the score.

        Raises:
            ValueError: if ``method`` is not one of ``METHODS``. An unknown
                method used to be reported with a print while the untouched
                initial score of 1.0 was returned as if it were a perfect match.
        """
        if method not in self.METHODS:
            raise ValueError("Unknown similarity method: {}".format(method))
        if len(t1) != len(t2):
            return 0.
        if method == 'bert':
            embs = self.similarity_model.encode([' '.join(t1).lower(), ' '.join(t2).lower()])
            return cosine_sim(*embs)
        pair_score = self.compare_cross_synsets if method == 'wordnet' else self.similarity
        sim = 1.
        for x, y in zip(t1, t2):
            sim *= pair_score(x, y)
        return sim

        
# Shared instance used by recall_triplet and produce_pair.
smanager = SimilarityManager()

In [7]:
def triplet_from_rel(rel):
    """Turn a relationship object into a ``(subject, predicate, object)`` tuple.

    The subject and object are represented by the first entry of their
    ``names`` lists; the predicate is taken as is.
    """
    head = rel.subject.names[0]
    relation = rel.predicate
    tail = rel.object.names[0]
    return head, relation, tail

In [8]:
def get_sc_graph(id):
    """Load the Visual Genome scene graph of image ``id`` from the local ``VG_DATA`` tree.

    Delegates to ``vg.get_scene_graph`` with the per-image directory and the
    synsets file that live under ``VG_DATA``.
    """
    by_id_dir = VG_DATA+'/by-id/'
    synsets_path = VG_DATA+'/synsets.json'
    return vg.get_scene_graph(
        id,
        images=VG_DATA,
        image_data_dir=by_id_dir,
        synset_file=synsets_path,
    )
def freeze_dict(d):
    """Return the items of ``d`` as a tuple of ``(key, value)`` pairs ordered by key."""
    ordered_keys = sorted(d.keys())
    return tuple((key, d[key]) for key in ordered_keys)
def rel_to_triplet(rel):
    """Return ``(subject id, predicate, object id)`` for a relationship mapping.

    ``rel`` is indexed by the keys 'subject', 'predicate' and 'object'; the
    subject and object values must expose an ``id`` attribute.
    """
    subject_id = rel['subject'].id
    predicate = rel['predicate']
    object_id = rel['object'].id
    return (subject_id, predicate, object_id)

In [9]:
# vg.add_attrs_to_scene_graphs(data_dir=VG_DATA)
# vg.save_scene_graphs_by_id(data_dir=VG_DATA, image_data_dir=VG_DATA+'/by-id/')

In [10]:
# Load the IPC paragraph records. The context manager closes the file handle
# that a bare open() call would leave dangling.
with open('/storage/ipc_data/paragraphs_v1.json', 'r') as ipc_file:
    ipc_data = json.load(ipc_file)

In [19]:
len(ipc_data), ipc_data[4]

(19561,
 {'url': 'https://cs.stanford.edu/people/rak248/VG_100K_2/2383120.jpg',
  'image_id': 2383120,
  'paragraph': 'A very clean and tidy a bathroom. Everything is a neat porcelain white. This bathroom is both retro and modern.'})

In [40]:
def recall_triplet(src, dst, **kwargs):
    """Best similarity between one triplet and any triplet of a list.

    Args:
        src: a single triplet.
        dst: a list of triplets.
        **kwargs: forwarded to ``smanager.compare_triplet`` (e.g. ``method``).

    Returns:
        ``0.`` if ``dst`` is empty, else the highest ``compare_triplet`` score.
    """
    if not dst:
        return 0.
    return max(smanager.compare_triplet(src, candidate, **kwargs) for candidate in dst)

def recall_triplets(src, dst, **kwargs):
    """Score every triplet of ``src`` against the list ``dst``.

    Args:
        src: a list of triplets.
        dst: a list of triplets.
        **kwargs: forwarded to ``recall_triplet``.

    Returns:
        A list with one ``recall_triplet`` score per element of ``src``, in
        order (the individual scores, not their mean).
    """
    scores = []
    for triplet in src:
        scores.append(recall_triplet(triplet, dst, **kwargs))
    return scores

def total_recall_triplets(src_triplets, dst_triplets, methods=('wordnet', 'wordnet', 'bert')):
    """Recall scores of ``src_triplets`` against ``dst_triplets``, matched by tuple length.

    Tuples are grouped by length (1, 2 and 3); each source tuple is scored
    only against destination tuples of the same length, using
    ``methods[length - 1]`` as the similarity method. Tuples of any other
    length are ignored.

    Returns:
        A flat list of scores: all length-1 sources first, then length-2,
        then length-3 (so not necessarily in the order of ``src_triplets``).
    """
    scores = []
    for size in (1, 2, 3):
        same_size_dst = [t for t in dst_triplets if len(t) == size]
        same_size_src = [t for t in src_triplets if len(t) == size]
        scores += recall_triplets(same_size_src, same_size_dst, method=methods[size - 1])
    return scores

def recall_paragraph_sg(paragraph, sg, **kwargs):
    """Recall of a scene graph's relationships against a paragraph's SPICE tuples.

    Args:
        paragraph: text handed to ``spice_get_triplets``.
        sg: scene graph whose ``relationships`` are converted with ``triplet_from_rel``.
        **kwargs: forwarded to ``total_recall_triplets`` (e.g. ``methods``).

    Returns:
        The list produced by ``total_recall_triplets`` with the scene-graph
        triplets as source and the paragraph tuples as destination.
    """
    paragraph_tuples = spice_get_triplets(paragraph)
    graph_triplets = [triplet_from_rel(rel) for rel in sg.relationships]
    return total_recall_triplets(graph_triplets, paragraph_tuples, **kwargs)

In [13]:
def produce_pair(ipc_num: int):
    """Print one IPC record's paragraph and a per-relationship similarity trace.

    Args:
        ipc_num: index into ``ipc_data``.

    For every relationship of the record's scene graph this prints the
    relationship, its subject names, and the ``smanager.similarity`` score
    between the first subject name and the paragraph. Nothing is returned.
    """
    record = ipc_data[ipc_num]
    scene_graph = get_sc_graph(record['image_id'])
    print("Paragraph is:")
    print(record['paragraph'])
    for relation in scene_graph.relationships:
        print("Processing: "+str(relation))
        subject_names = relation.subject.names
        print("Subject is: {}".format(subject_names))
        score = smanager.similarity(relation.subject.names[0], record['paragraph'])
        print("Similarity: {}".format(score))
        

In [14]:
def spice_get_triplets(text):
    """Parse ``text`` with the SPICE jar and return its reference tuples.

    The text is written as the single reference of a one-entry SPICE input
    file, the jar is run on it, and the ``ref_tuples`` of the first output
    entry are collected.

    Args:
        text: the caption / paragraph to parse.

    Returns:
        A list with the ``'tuple'`` value of every reference tuple SPICE reports.

    Raises:
        RuntimeError: if the java process exits with a non-zero status.
    """
    SPICE_FNAME = '/notebooks/SPICE-1.0/spice-1.0.jar'
    INP_FNAME = '/tmp/example.json'
    OUT_FNAME = '/tmp/example_output.json'
    inp = {
        'image_id': 1,
        'test': "",
        'refs': [text],
    }
    # Close (and therefore flush) the input file before java reads it.
    with open(INP_FNAME, 'w') as inp_file:
        json.dump([inp], inp_file)
    # Argument list instead of a formatted shell string: no shell involved,
    # no quoting issues with the paths.
    cmd = ['java', '-Xmx8G', '-jar', SPICE_FNAME, INP_FNAME,
           '-detailed', '-silent', '-subset', '-out', OUT_FNAME]
    proc = subprocess.run(cmd,
                          stdin=subprocess.DEVNULL,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
    if proc.returncode != 0:
        # Without this check a failed run would fall through to whatever
        # output file an earlier call left behind at OUT_FNAME.
        raise RuntimeError('SPICE exited with status {}: {}'.format(
            proc.returncode, proc.stderr.decode(errors='replace')))
    with open(OUT_FNAME, 'r') as out_file:
        outp = json.load(out_file)
    return [x['tuple'] for x in outp[0]['ref_tuples']]
    

In [48]:
def check_image_id_in_collection(id,collection=IPC_COLLECTION):
    """Look up the documents of ``collection`` whose ``image_id`` equals ``id``.

    Args:
        id: image id to search for.
        collection: name of the collection to query.

    Returns:
        A dict with the fields of all matching documents merged together
        (later documents overwrite earlier ones); empty -- and therefore
        falsy -- when nothing matches.
    """
    # Bind parameters instead of string formatting, so the collection name
    # and the id are passed as data rather than spliced into the query text.
    query = 'FOR doc IN @@collection FILTER doc.image_id == @image_id RETURN doc'
    cursor = pipeline.db.aql.execute(
        query, bind_vars={'@collection': collection, 'image_id': id})
    results = {}
    for doc in cursor:
        results.update(doc)
    return results

def process_ipc(ipc):
    """Attach the SPICE tuples of ``ipc['paragraph']`` to the record.

    The tuples are stored under ``ipc['triplets']`` on the record that was
    passed in (it is modified in place), and that same record is returned.
    """
    triplets = spice_get_triplets(ipc['paragraph'])
    ipc['triplets'] = triplets
    return ipc

def process_all_ipc(all_ipc):
    """Compute SPICE triplets for every IPC record and store new ones in ``IPC_COLLECTION``.

    Records whose ``image_id`` already has a document in the collection are
    skipped (with a log line). Each stored document is the record returned by
    ``process_ipc`` plus an ``'idx'`` field holding its position in
    ``all_ipc``. Note that ``process_ipc`` returns the input record itself,
    so ``'triplets'`` and ``'idx'`` are also added to the elements of ``all_ipc``.

    Args:
        all_ipc: sequence of records with at least ``'image_id'`` and ``'paragraph'``.
    """
    for i,ipc in enumerate(all_ipc):
        # The existence helper defined in this notebook is
        # check_image_id_in_collection; no visible cell defines `check_ipc_id`,
        # so calling it raises NameError on a fresh kernel.
        if check_image_id_in_collection(ipc['image_id'], IPC_COLLECTION):
            print('idx {}, image_id: {} already exists. Moving on.'.format(i,ipc['image_id']))
            continue
        rc_doc = process_ipc(ipc)
        rc_doc['idx'] = i
        # Pass the document as a bind parameter: formatting a Python dict repr
        # into the query only works while that repr happens to be valid AQL.
        pipeline.db.aql.execute(
            'INSERT @doc INTO @@collection',
            bind_vars={'doc': rc_doc, '@collection': IPC_COLLECTION})


In [30]:
# The same description, once split into three sentences and once as a single
# comma-joined sentence, to compare the tuples SPICE returns for each.
rc1 = spice_get_triplets("A man is wearing a hat and riding a skateboard. He is holding a guitar. He is on the sidewalk next to a street with several cars parked on it")
rc2 = spice_get_triplets("A man is wearing a hat and riding a skateboard, he is holding a guitar, he is on the sidewalk next to a street with several cars parked on it")

# rc2 = spice_get_triplets("a slow white fox jumped over the red fence")

In [31]:
rc1

[['car'],
 ['car', 'park on', 'sidewalk'],
 ['car', 'several'],
 ['guitar'],
 ['hat'],
 ['man'],
 ['man', 'ride', 'skateboard'],
 ['man', 'wear', 'hat'],
 ['sidewalk'],
 ['sidewalk', 'next to', 'street'],
 ['skateboard'],
 ['street'],
 ['street', 'with', 'car']]

In [32]:
rc2

[['car'],
 ['car', 'park on', 'man'],
 ['car', 'several'],
 ['guitar'],
 ['hat'],
 ['man'],
 ['man', 'hold', 'guitar'],
 ['man', 'on', 'sidewalk'],
 ['man', 'ride', 'skateboard'],
 ['man', 'wear', 'hat'],
 ['sidewalk'],
 ['sidewalk', 'next to', 'street'],
 ['skateboard'],
 ['street'],
 ['street', 'with', 'car']]

In [15]:
# Recall of a single sample: its scene-graph relationships vs. the SPICE tuples of its paragraph.
i = 4     # Problem in i=4
ipc = ipc_data[i]
sg = get_sc_graph(ipc['image_id'])
# `rc` receives the list of scores returned by recall_paragraph_sg.
rc = recall_paragraph_sg(ipc['paragraph'],sg)

In [28]:
process_all_ipc(ipc_data)    


idx 0, image_id: 2356347 already exists. Moving on.
idx 1, image_id: 2317429 already exists. Moving on.
idx 2, image_id: 2414610 already exists. Moving on.
idx 3, image_id: 2365091 already exists. Moving on.
idx 4, image_id: 2383120 already exists. Moving on.
idx 456, image_id: 2349250 already exists. Moving on.
idx 3035, image_id: 2340853 already exists. Moving on.
idx 3446, image_id: 2408653 already exists. Moving on.
idx 6675, image_id: 2412652 already exists. Moving on.
idx 7090, image_id: 2363167 already exists. Moving on.
idx 7152, image_id: 2377108 already exists. Moving on.
idx 8003, image_id: 2385924 already exists. Moving on.
idx 8307, image_id: 2414810 already exists. Moving on.
idx 8469, image_id: 2370496 already exists. Moving on.
idx 9207, image_id: 2327869 already exists. Moving on.


In [None]:
ipc_triplets = spice_get_triplets(ipc_data[4]['paragraph'])

In [None]:
ipc_data[4]['paragraph']

In [None]:
# Score every scene-graph triplet against the paragraph tuples, once per similarity method.
# NOTE(review): `rel_triplets` is not assigned at top level in any visible cell
# (it only exists as a local inside recall_paragraph_sg) -- define it before running this cell.
z1 = [recall_triplet(x,ipc_triplets, method='spacy') for x in rel_triplets]
z2 = [recall_triplet(x,ipc_triplets, method='wordnet') for x in rel_triplets]
z3 = [recall_triplet(x,ipc_triplets, method='bert') for x in rel_triplets]

In [None]:
np.mean(z1), np.mean(z2), np.mean(z3)

In [62]:
def process_recall(methods=('wordnet', 'wordnet', 'bert')):
    """Compute and store scene-graph recall for every document of ``IPC_COLLECTION``.

    For each stored IPC document that has no result in ``RECALL_COLLECTION``
    yet, the Visual Genome relationships of its image are scored against the
    document's own SPICE triplets and a recall document is inserted.

    Args:
        methods: similarity method per tuple length (1, 2, 3), forwarded to
            ``total_recall_triplets`` and recorded in the stored document.
    """
    all_ipc = list(pipeline.db.aql.execute(
        'FOR doc IN @@collection RETURN doc',
        bind_vars={'@collection': IPC_COLLECTION}))
    for ipc in all_ipc:
        if check_image_id_in_collection(ipc['image_id'],RECALL_COLLECTION):
            print('IPC index {}/Image id {} already exists. Moving on.'.format(ipc['idx'],ipc['image_id']))
            continue
        sg = get_sc_graph(ipc['image_id'])
        sg_triplets = [triplet_from_rel(rel) for rel in sg.relationships]
        # Score against THIS document's triplets (the global `rc[1]` was read
        # here before, giving every image the same reference tuples) and
        # honour the requested methods, which were previously never forwarded.
        scores = total_recall_triplets(sg_triplets, ipc['triplets'], methods=methods)
        # Plain floats: the scores may be NumPy scalars, which cannot be
        # serialized when the document is sent as a bind parameter.
        recall = [float(score) for score in scores]
        mean_recall = float(np.mean(recall)) if recall else 0.
        recall_doc = {
            'ipc_idx': ipc['idx'],
            'image_id': ipc['image_id'],
            'recall_methods': list(methods),
            'recall': recall,
            'recall_mean': mean_recall
        }
        print('Writing index {}/Image id {}.'.format(ipc['idx'],ipc['image_id']))
        # Bind parameter instead of formatting the dict repr into the query text.
        pipeline.db.aql.execute(
            'INSERT @doc INTO @@collection',
            bind_vars={'doc': recall_doc, '@collection': RECALL_COLLECTION})



In [63]:
process_recall()

IPC index 0/Image id 2356347 already exists. Moving on.
IPC index 1/Image id 2317429 already exists. Moving on.
IPC index 2/Image id 2414610 already exists. Moving on.
IPC index 3/Image id 2365091 already exists. Moving on.
IPC index 4/Image id 2383120 already exists. Moving on.
IPC index 5/Image id 2333990 already exists. Moving on.
IPC index 6/Image id 2388203 already exists. Moving on.
IPC index 7/Image id 2338364 already exists. Moving on.
IPC index 8/Image id 2410301 already exists. Moving on.
IPC index 9/Image id 2404368 already exists. Moving on.
IPC index 10/Image id 2388350 already exists. Moving on.
IPC index 11/Image id 2417454 already exists. Moving on.
IPC index 12/Image id 2401749 already exists. Moving on.
IPC index 13/Image id 2413148 already exists. Moving on.
IPC index 14/Image id 2401672 already exists. Moving on.
IPC index 15/Image id 2396483 already exists. Moving on.
IPC index 16/Image id 2354861 already exists. Moving on.
IPC index 17/Image id 2378834 already exi

In [60]:
sg = get_sc_graph(2392613)

In [61]:
len(sg.relationships)

0

In [43]:
# NOTE(review): `rc` is indexed here as a sequence of records with a 'paragraph'
# field, while the visible assignment of `rc` stores the score list returned by
# recall_paragraph_sg -- this cell presumably relied on an `rc` from a cell that
# is no longer shown; confirm before re-running.
recall_paragraph_sg(rc[1]['paragraph'],sg)

[0.72258914,
 0.34141386,
 0.34141386,
 0.72258914,
 0.72258914,
 0.72258914,
 0.72258914,
 0.72258914,
 0.72258914,
 0.72258914,
 0.34141386,
 0.34141386,
 0.34141386,
 0.34141386,
 0.34141386,
 0.7458614,
 0.7458614,
 0.7458614,
 0.7458614,
 0.34141386,
 0.34141386,
 0.34141386,
 0.7458614,
 0.34141386,
 0.34141386,
 0.34141386,
 0.34141386,
 0.7458614,
 0.72258914,
 0.34141386,
 0.5867421,
 0.7458614,
 0.7458614,
 0.34141386,
 0.34141386,
 0.34141386,
 0.72258914]