In [1]:
import sys
sys.path.insert(0, "/notebooks/pipenv")
from PIL import Image
import requests
import visual_genome.local as vg
import json
import copy
import subprocess

import numpy as np
import torch
import spacy
import nltk
from spacy_wordnet.wordnet_annotator import WordnetAnnotator 
from sentence_transformers import SentenceTransformer


In [2]:
# Fetch the NLTK WordNet corpus, load the large English spaCy model, and
# register the "spacy_wordnet" pipe right after the tagger so that tokens
# expose `token._.wordnet` (used by SimilarityManager.compare_cross_synsets).
nltk.download('wordnet')
nlp = spacy.load('en_core_web_lg')
nlp.add_pipe("spacy_wordnet", after='tagger', config={'lang': nlp.lang})

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


<spacy_wordnet.wordnet_annotator.WordnetAnnotator at 0x7f5a0effba30>

In [3]:
# Root directory of the local Visual Genome data; get_sc_graph derives the
# per-image directory and the synset file path from it.
VG_DATA = '/storage/vg_data'

In [4]:
def cosine_sim(x,y):
    """Return the cosine similarity of two vectors.

    Args:
        x: First vector (any 1-D array-like accepted by NumPy).
        y: Second vector, same length as ``x``.

    Returns:
        ``dot(x, y) / (|x| * |y|)``, or ``0.0`` when either vector has zero
        norm (the ratio is undefined there and would otherwise be NaN).
    """
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    # Guard the division: a zero vector has no direction to compare.
    if denom == 0:
        return 0.0
    return np.dot(x, y) / denom

def compare_cross_lists(l1, l2):
    """Tell whether at least one element of ``l1`` also occurs in ``l2``.

    Membership is decided with the ``in`` operator, element by element.
    The result is a NumPy boolean; it is false when ``l1`` is empty.
    """
    hits = []
    for item in l1:
        hits.append(item in l2)
    return np.any(hits)

class SimilarityManager:
    """Scores similarity between words, phrases and tuples of phrases.

    Three back-ends are selectable through :meth:`compare_triplet`:
    ``'spacy'`` (token-vector similarity), ``'wordnet'`` (shared WordNet
    synsets) and ``'bert'`` (sentence-transformer embeddings).
    """

    # Part-of-speech tags that `similarity` treats as content words.
    CONTENT_POS = ('NOUN', 'ADJ', 'ADV', 'VERB', 'PROPN')
    # Back-ends accepted by `compare_triplet`.
    METHODS = ('spacy', 'wordnet', 'bert')

    def __init__(self):
        """Bind the notebook-level spaCy pipeline and load the sentence encoder."""
        self.nlp = nlp
        self.similarity_model = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1')
        # Move to the GPU only when one exists, so construction does not
        # crash on CPU-only machines.
        if torch.cuda.is_available():
            self.similarity_model.cuda()

    def similarity(self, src, target):
        """Average, over the words of ``src``, of their best match in ``target``.

        Only content words of ``src`` (see ``CONTENT_POS``) are scored. When
        ``src`` has no content word at all (e.g. a lone preposition), every
        token is scored instead, so the result is never the mean of an empty
        list.

        Args:
            src: Text whose words are looked up.
            target: Text the words are matched against.

        Returns:
            Mean of the per-word maximum token similarity, or ``0.0`` when
            either text yields no tokens.
        """
        s1 = self.nlp(src)
        s2 = self.nlp(target)
        # Nothing to compare: max()/mean() over an empty sequence would fail.
        if len(s1) == 0 or len(s2) == 0:
            return 0.0
        words = [w for w in s1 if w.pos_ in self.CONTENT_POS]
        if not words:
            # No content words: fall back to all tokens rather than return NaN.
            words = list(s1)
        return np.mean([max(w.similarity(x) for x in s2) for w in words])

    def compare_cross_synsets(self, text1, text2):
        """Tell whether the two texts share at least one WordNet synset.

        Synsets of all tokens are pooled per text and the pools intersected.
        Tokens without synsets contribute nothing, so two function words can
        no longer "match" merely because both have an empty synset list.

        Returns:
            ``True`` when some synset occurs on both sides, else ``False``.
        """
        synsets1 = {s for tok in self.nlp(text1) for s in tok._.wordnet.synsets()}
        synsets2 = {s for tok in self.nlp(text2) for s in tok._.wordnet.synsets()}
        return not synsets1.isdisjoint(synsets2)

    def compare_triplet(self, t1, t2, method='spacy'):
        """Score how well tuple ``t1`` matches tuple ``t2``.

        Args:
            t1: Sequence of strings (1 to 3 elements in this notebook).
            t2: Sequence of strings to compare against.
            method: ``'spacy'``, ``'wordnet'`` or ``'bert'``.

        Returns:
            ``0.`` when the tuples differ in length. For ``'bert'`` the cosine
            similarity of the two space-joined, lower-cased tuples; otherwise
            the product of the element-wise scores.

        Raises:
            ValueError: if ``method`` is not one of ``METHODS``. Previously a
                message was printed and a perfect score of 1.0 was returned.
        """
        if method not in self.METHODS:
            raise ValueError("Unknown similarity method: {}".format(method))
        if len(t1) != len(t2):
            return 0.
        if method == 'bert':
            embs = self.similarity_model.encode([' '.join(t1).lower(), ' '.join(t2).lower()])
            return cosine_sim(*embs)
        compare = self.compare_cross_synsets if method == 'wordnet' else self.similarity
        sim = 1.
        for x, y in zip(t1, t2):
            sim *= compare(x, y)
        return sim

        
# Shared instance used by recall_triplet and produce_pair; constructing it
# loads the sentence-transformer model (see SimilarityManager.__init__).
smanager = SimilarityManager()

Downloading:   0%|          | 0.00/345 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/150 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/550 [00:00<?, ?B/s]

In [None]:
def triplet_from_rel(rel):
    """Turn a scene-graph relationship into a ``(subject, predicate, object)`` tuple.

    The first entry of each endpoint's ``names`` list is used as its label.
    """
    head = rel.subject.names[0]
    relation = rel.predicate
    tail = rel.object.names[0]
    return (head, relation, tail)

In [None]:
def get_sc_graph(id):
    """Load the scene graph of image ``id`` from the local data under ``VG_DATA``."""
    by_id_dir = VG_DATA + '/by-id/'
    synsets_path = VG_DATA + '/synsets.json'
    return vg.get_scene_graph(
        id,
        images=VG_DATA,
        image_data_dir=by_id_dir,
        synset_file=synsets_path,
    )
def freeze_dict(d):
    """Return the items of ``d`` as a tuple of ``(key, value)`` pairs, sorted by key."""
    pairs = []
    for key in sorted(d.keys()):
        pairs.append((key, d[key]))
    return tuple(pairs)


def rel_to_triplet(rel):
    """Reduce a relationship dict to ``(subject id, predicate, object id)``."""
    subject_id = rel['subject'].id
    predicate = rel['predicate']
    object_id = rel['object'].id
    return (subject_id, predicate, object_id)

In [None]:
# vg.add_attrs_to_scene_graphs(data_dir=VG_DATA)
# vg.save_scene_graphs_by_id(data_dir=VG_DATA, image_data_dir=VG_DATA+'/by-id/')

In [None]:
# Load the image-paragraph records. Later cells index the result by position
# and read each record's 'image_id' and 'paragraph' keys.
# The `with` block closes the file handle, which the bare open() call leaked.
with open('/storage/ipc_data/paragraphs_v1.json', 'r') as ipc_file:
    ipc_data = json.load(ipc_file)

In [None]:
# Number of records loaded from the paragraphs file.
len(ipc_data)

In [None]:
def recall_triplet(src, dst, **kwargs):
    """Return the best similarity between one tuple and any candidate tuple.

    Args:
        src: A single triplet.
        dst: A list of triplets to match ``src`` against.
        **kwargs: Forwarded to ``smanager.compare_triplet`` (e.g. ``method``).

    Returns:
        The highest score obtained over ``dst``, or ``0.0`` when ``dst`` is
        empty (nothing can recall ``src``).
    """
    # default=0.0: a bare max() raises ValueError on an empty candidate list,
    # which happens when the paragraph yields no tuples of the required length.
    return max((smanager.compare_triplet(src, x, **kwargs) for x in dst), default=0.0)

def recall_triplets(src, dst, **kwargs):
    """Score every triplet of ``src`` against the candidate list ``dst``.

    Args:
        src: A list of triplets.
        dst: A list of triplets.
        **kwargs: Forwarded to ``recall_triplet``.

    Returns:
        A list with one best-match score per element of ``src``, in order.
        The scores are returned un-averaged.
    """
    scores = []
    for triplet in src:
        scores.append(recall_triplet(triplet, dst, **kwargs))
    return scores

def recall_paragraph_sg(paragraph, sg, methods=('wordnet', 'wordnet', 'bert')):
    """Score how well a paragraph recalls the relationships of a scene graph.

    The paragraph is parsed into tuples with ``spice_get_triplets`` and the
    scene-graph relationships are converted with ``triplet_from_rel``. Tuples
    are then grouped by length (1, 2, 3); each scene-graph tuple is scored
    against the paragraph tuples of the same length.

    Args:
        paragraph: Paragraph text.
        sg: Scene graph exposing a ``relationships`` iterable.
        methods: Similarity method per tuple length; ``methods[k - 1]`` is
            used for tuples of length ``k``.

    Returns:
        A flat list of per-tuple scores, ordered by tuple length.
    """
    paragraph_tuples = spice_get_triplets(paragraph)
    graph_tuples = [triplet_from_rel(rel) for rel in sg.relationships]
    total_recall = []
    for size in (1, 2, 3):
        paragraph_side = [t for t in paragraph_tuples if len(t) == size]
        graph_side = [t for t in graph_tuples if len(t) == size]
        total_recall += recall_triplets(graph_side, paragraph_side, method=methods[size - 1])
    return total_recall

In [None]:
def produce_pair(ipc_num: int):
    """Print how similar each relationship subject is to a record's paragraph.

    Looks up record ``ipc_num`` in ``ipc_data``, loads the scene graph of its
    image, prints the paragraph, and then for every relationship prints the
    relationship, its subject names, and the similarity between the first
    subject name and the paragraph. Nothing is returned.
    """
    record = ipc_data[ipc_num]
    graph = get_sc_graph(record['image_id'])
    print("Paragraph is:")
    print(record['paragraph'])
    for rel in graph.relationships:
        print("Processing: "+str(rel))
        subject_names = rel.subject.names
        print("Subject is: {}".format(subject_names))
        score = smanager.similarity(subject_names[0], record['paragraph'])
        print("Similarity: {}".format(score))
        

In [None]:
def spice_get_triplets(text):
    """Parse ``text`` with the SPICE jar and return the tuples it extracts.

    ``text`` is written as the single reference caption of a one-entry SPICE
    input file, the SPICE jar is run on it, and the ``tuple`` field of every
    entry in the first result's ``ref_tuples`` is returned.

    Each call works in its own temporary directory, so a failed run can no
    longer be masked by an output file left behind by an earlier call.

    Args:
        text: The text to parse.

    Returns:
        A list of tuples (lists of strings as decoded from JSON).

    Raises:
        RuntimeError: if the SPICE process exits with a non-zero status.
        FileNotFoundError: if SPICE produced no output file.
    """
    import tempfile  # local import: only this helper needs it

    spice_jar = '/notebooks/spice_bin/SPICE-1.0/spice-1.0.jar'
    inp = {
        'image_id': 1,
        'test': "",
        'refs': [text],
    }
    with tempfile.TemporaryDirectory() as workdir:
        inp_fname = workdir + '/example.json'
        out_fname = workdir + '/example_output.json'
        # `with` guarantees the input is flushed and closed before SPICE reads it.
        with open(inp_fname, 'w') as inp_file:
            json.dump([inp], inp_file)
        # Argument list instead of a shell string: no shell, no quoting issues.
        proc = subprocess.run(
            ['java', '-Xmx8G', '-jar', spice_jar, inp_fname,
             '-detailed', '-silent', '-subset', '-out', out_fname],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        if proc.returncode != 0:
            raise RuntimeError(
                "SPICE exited with status {}: {}".format(
                    proc.returncode, proc.stderr.decode(errors='replace')))
        with open(out_fname, 'r') as out_file:
            outp = json.load(out_file)
    return [x['tuple'] for x in outp[0]['ref_tuples']]
    

In [None]:
# SPICE tuples of two toy sentences, compared against each other in the next cell.
rc1 = spice_get_triplets("The quick brown fox jumped over the blue fence")
rc2 = spice_get_triplets("a slow white fox jumped over the red fence")

In [None]:
# Best score of every tuple of the first sentence against the second one.
# The back-end is selected with `method`; compare_triplet has no
# `use_synsets` keyword, so passing it raised TypeError.
[recall_triplet(x,rc2, method='wordnet') for x in rc1]

In [None]:
# Score one record end to end: load the scene graph of its image and compute
# the per-relationship recall of its paragraph.
i = 8     # Problem in i=4
ipc = ipc_data[i]
sg = get_sc_graph(ipc['image_id'])
rc = recall_paragraph_sg(ipc['paragraph'],sg)

In [None]:
# Average of the per-relationship scores computed in the previous cell.
np.mean(rc)

In [None]:
# SPICE tuples of the paragraph at index 4.
ipc_triplets = spice_get_triplets(ipc_data[4]['paragraph'])

In [None]:
# Show the raw paragraph text of the record at index 4.
ipc_data[4]['paragraph']

In [None]:
# `rel_triplets` only ever existed as a local of recall_paragraph_sg, so this
# cell raised NameError on a fresh kernel. Build it here from the scene graph
# of the same record (index 4) that `ipc_triplets` was extracted from.
rel_triplets = list(map(triplet_from_rel, get_sc_graph(ipc_data[4]['image_id']).relationships))
# Per-relationship best score under each similarity back-end.
z1 = [recall_triplet(x,ipc_triplets, method='spacy') for x in rel_triplets]
z2 = [recall_triplet(x,ipc_triplets, method='wordnet') for x in rel_triplets]
z3 = [recall_triplet(x,ipc_triplets, method='bert') for x in rel_triplets]

In [None]:
# Mean score per back-end, in the order spacy, wordnet, bert.
np.mean(z1), np.mean(z2), np.mean(z3)

In [None]:
# Keep only the three-element tuples extracted from the paragraph.
[x for x in ipc_triplets if len(x)==3]

In [None]:
# Deep-copied attribute dicts of the relationships, so they can be mutated
# without touching the objects held by `sg`.
rels = copy.deepcopy([x.__dict__ for x in sg.relationships])

In [None]:
# Drop the 'id' entry from every relationship dict.
# pop() with a default keeps the cell re-runnable: `del r['id']` raised
# KeyError the second time the cell was executed.
for r in rels:
    r.pop('id', None)

In [None]:
# Number of distinct (subject id, predicate, object id) combinations.
len({rel_to_triplet(rel) for rel in rels})