In [1]:
import sys
sys.path.insert(0, "/notebooks/pipenv")
sys.path.insert(0, "/notebooks/nebula3_database")
sys.path.insert(0, "/notebooks/")
from PIL import Image
import requests
import visual_genome.local as vg
import json
import copy
import subprocess

import numpy as np
import torch
import spacy
import nltk
import openai
from spacy_wordnet.wordnet_annotator import WordnetAnnotator 
from sentence_transformers import SentenceTransformer
from database.arangodb import DatabaseConnector
from config import NEBULA_CONF


In [2]:
# Shared resources for the cells below: spaCy pipeline with WordNet, OpenAI key,
# and the data-root / ArangoDB collection names.
nltk.download('wordnet')  # skips the download when the corpus is already up to date
nlp = spacy.load('en_core_web_lg')
# NOTE(review): the "spacy_wordnet" factory is presumably registered by the
# spacy_wordnet import in the first cell — keep that import even though it looks unused.
nlp.add_pipe("spacy_wordnet", after='tagger', config={'lang': nlp.lang})

# The key is read from a file (first line, whitespace-stripped) instead of being hardcoded.
with open('/storage/keys/openai.key','r') as f:
    OPENAI_API_KEY = f.readline().strip()
# NOTE(review): module-level `openai.api_key` is the pre-1.0 openai client interface — confirm installed version.
openai.api_key = OPENAI_API_KEY

VG_DATA = '/storage/vg_data'  # root of the local Visual Genome files used by get_sc_graph
IPC_COLLECTION = 'ipc_relations_spice'  # not referenced in the visible cells
RECALL_COLLECTION = 'ipc_recall_spice'  # not referenced in the visible cells
GLOBAL_TOKENS_COLLECTION = 's3_global_tokens'  # default collection queried by GTBaseGenerator

class PIPELINE:
    """Thin holder for a connection to the configured playground database.

    Attributes:
        db_host: database host reported by NEBULA_CONF.
        database: playground database name reported by NEBULA_CONF.
        gdb: the DatabaseConnector used to open the connection.
        db: the connected database handle returned by connect_db.
    """

    def __init__(self):
        cfg = NEBULA_CONF()
        # Read host first, then playground name (same order as the config is queried elsewhere).
        self.db_host = cfg.get_database_host()
        self.database = cfg.get_playground_name()
        connector = DatabaseConnector()
        self.gdb = connector
        self.db = connector.connect_db(self.database)

# Module-level database connection. NOTE(review): this global is not referenced by the other visible cells.
pipeline = PIPELINE()
def get_sc_graph(id):
    """Return the Visual Genome scene graph for image `id` from the local VG_DATA tree.

    `images` is pointed at VG_DATA itself, per-image data at VG_DATA/by-id/,
    and synsets at VG_DATA/synsets.json.
    """
    by_id_dir = f'{VG_DATA}/by-id/'
    synsets_path = f'{VG_DATA}/synsets.json'
    return vg.get_scene_graph(
        id,
        images=VG_DATA,
        image_data_dir=by_id_dir,
        synset_file=synsets_path,
    )


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [56]:
class GTBaseGenerator:
    """Build image-description prompts from Visual Genome data plus stored global tokens.

    Combines the VG scene graph (image URL, per-object names/attributes/boxes)
    with the 'global*' fields of the image's document in the global-tokens
    collection, and can append the matching paragraph from the IPC data file.
    """

    def __init__(self, pipeline=None, ipc_data_path='/storage/ipc_data/paragraphs_v1.json'):
        """Set up the DB connection and load the IPC paragraphs.

        Args:
            pipeline: an existing PIPELINE to reuse; a new one is created when None.
            ipc_data_path: JSON file with a list of records; get_prompt reads
                their 'image_id' and 'paragraph' keys.
        """
        self.pipeline = pipeline if pipeline is not None else PIPELINE()
        # Context manager so the file handle is closed instead of leaking.
        with open(ipc_data_path, 'r') as ipc_file:
            self.ipc_data = json.load(ipc_file)
        self.global_captioner = 'blip'
        self.global_tagger = 'blip'
        self.places_source = 'blip'
        self.global_prompt1 = '''Caption of image: {}
This image is taking place in: {}
Tags: This image is about {}
Describe this image in detail:'''

    def get_image_id_from_collection(self, id, collection=GLOBAL_TOKENS_COLLECTION):
        """Return the contents of the document(s) whose image_id equals `id`.

        Despite the name, this returns document data, not an id. All matching
        documents are merged into one dict (later ones overwrite earlier keys);
        an empty dict means nothing matched.

        NOTE(review): `id` is sent with its Python type, so pass the same type
        as the stored image_id (an int in this notebook).
        """
        # Bind parameters instead of string formatting: string ids no longer
        # produce invalid AQL, and the values cannot alter the statement.
        query = 'FOR doc IN @@collection FILTER doc.image_id == @image_id RETURN doc'
        cursor = self.pipeline.db.aql.execute(
            query, bind_vars={'@collection': collection, 'image_id': id})
        results = {}
        for doc in cursor:
            results.update(doc)
        return results

    def get_structure(self, id):
        """Build the base record for image `id`.

        Returns a dict with 'image_id', 'url' (from the scene graph), a shallow
        copy of every 'global*' field of the global-tokens document, and
        'rois': one dict per scene-graph object with its 'objects' (names),
        'attributes' and 'bbox' as [x1, y1, x2, y2]. Prints a message and
        returns None when no global-tokens document exists for `id`.
        """
        sg = get_sc_graph(id)
        global_doc = self.get_image_id_from_collection(id)
        if not global_doc:
            print("Couldn't find global tokens for id {}".format(id))
            return None
        rc_doc = {
            'image_id': id,
            'url': sg.image.url,
        }
        rc_doc.update({k: copy.copy(v) for k, v in global_doc.items() if k.startswith('global')})
        rc_doc['rois'] = [
            {
                'objects': obj.names,
                'attributes': obj.attributes,
                # Objects expose x, y, width, height; store the box in corner form.
                'bbox': [obj.x, obj.y, obj.x + obj.width, obj.y + obj.height],
            }
            for obj in sg.objects
        ]
        return rc_doc

    @staticmethod
    def _by_score(entries):
        """Sort [name, score] entries by descending float(score); ties keep input order."""
        return sorted(entries, key=lambda entry: -float(entry[1]))

    def get_prompt(self, id, include_answer=False, n_objects=5, n_places=3):
        """Fill `global_prompt1` for image `id`.

        Prints the caption and the top-5 objects, places and persons for
        inspection. The prompt uses the caption, the top `n_places` places
        joined by ' or ', and the top `n_objects` objects joined by '; '.
        Persons are printed but not used in the prompt.

        Args:
            id: Visual Genome image id.
            include_answer: if True, append the matching IPC paragraph after
                the prompt, separated by a single space.
            n_objects: number of highest-scoring objects used as tags.
            n_places: number of highest-scoring places used in the prompt.

        Returns:
            The prompt string, or None when get_structure finds no global tokens.

        Raises:
            ValueError: if include_answer is True and `id` does not match
                exactly one record in `ipc_data`.
        """
        base_doc = self.get_structure(id)
        if base_doc is None:
            return None
        caption = base_doc['global_captions'][self.global_captioner]
        all_objects = self._by_score(base_doc['global_objects'][self.global_tagger])
        all_persons = self._by_score(base_doc['global_persons'][self.global_tagger])
        all_places = self._by_score(base_doc['global_scenes'][self.places_source])
        print("Caption: {}".format(caption))
        print("Objects: ")
        print(all_objects[:5])
        print("Places:")
        print(all_places[:5])
        print("Persons:")
        print(all_persons[:5])
        objects = '; '.join(entry[0] for entry in all_objects[:n_objects])
        places = ' or '.join(entry[0] for entry in all_places[:n_places])
        prompt = self.global_prompt1.format(caption, places, objects)
        if not include_answer:
            return prompt
        answers = [x['paragraph'] for x in self.ipc_data if x['image_id'] == id]
        if len(answers) != 1:
            raise ValueError(
                'Expected exactly one IPC paragraph for image_id {}, found {}'.format(id, len(answers)))
        return prompt + " " + answers[0]


In [57]:
# Constructing the generator opens a database connection and loads the IPC paragraphs JSON.
base_gen = GTBaseGenerator()

In [58]:
# NOTE(review): `id` shadows the builtin id(); it is kept because the later
# get_structure cell reads this same global.
id = 2348389
rc = base_gen.get_prompt(id, include_answer=True)  # prompt string, or None if no global tokens were found
print(rc)  # print() so the multi-line prompt shows real line breaks instead of a repr

Caption: a person riding a horse in a field of grass and trees
Objects: 
[['horsewoman', '0.29960856'], ['riding', '0.29438442'], ['horseback', '0.2918341'], ['horseman', '0.29157773'], ['clydesdale', '0.2803259']]
Places:
[['pasture', '0.2662959'], ['tree farm', '0.22102167'], ['forest path', '0.21471532'], ['vegetation', '0.21412508'], ['racecourse', '0.21341668']]
Persons:
[['Woman wears a red-and-white', '0.34333515'], ['A red and white jacket', '0.33720917'], ['Red and white jacket', '0.33632693'], ['Woman has on a large', '0.32926083'], ['A red and white', '0.32675546']]
Caption of image: a person riding a horse in a field of grass and trees
This image is taking place in: pasture or tree farm or forest path
Tags: This image is about horsewoman; riding; horseback; horseman; clydesdale
Describe this image in detail: In this image a person wearing a red top and black pants is riding a brown and white horse.  The person appears to be young.  The horse is walking slowly with one hoof 

In [19]:
# NOTE(review): `rc` is rebound here to a dict (it held the prompt string above);
# get_structure returns None when global tokens are missing, which would make the next line fail.
rc = base_gen.get_structure(id)
rc['url']

'https://cs.stanford.edu/people/rak248/VG_100K/2348389.jpg'

In [37]:
" or ".join(("gil", "dan", "moshe"))  # scratch check of the ' or ' separator used for places

'gil or dan or moshe'

In [None]:
# NOTE(review): presumably loads region descriptions for every image under VG_DATA (may be slow / memory heavy);
# the result rebinds the `rc` global used by earlier cells.
rc = vg.get_all_region_descriptions(data_dir=VG_DATA)


In [None]:
[*zip([], [1.0] * 0)]  # scratch check: the zip(names, [1.0]*len(names)) pattern yields [] for no names