In [1]:
import sys
sys.path.insert(0, "/notebooks/pipenv")
sys.path.insert(0, "/notebooks/nebula3_database")
sys.path.insert(0, "/notebooks/")
from PIL import Image
import requests
import visual_genome.local as vg
import json
import copy
import subprocess

import numpy as np
import torch
import spacy
import nltk
import openai
from spacy_wordnet.wordnet_annotator import WordnetAnnotator 
from sentence_transformers import SentenceTransformer
from database.arangodb import DatabaseConnector
from config import NEBULA_CONF


In [2]:
nltk.download('wordnet')
nlp = spacy.load('en_core_web_lg')
nlp.add_pipe("spacy_wordnet", after='tagger', config={'lang': nlp.lang})

with open('/storage/keys/openai.key','r') as f:
    OPENAI_API_KEY = f.readline().strip()
openai.api_key = OPENAI_API_KEY

VG_DATA = '/storage/vg_data'
IPC_COLLECTION = 'ipc_relations_spice'
RECALL_COLLECTION = 'ipc_recall_spice'
GLOBAL_TOKENS_COLLECTION = 's3_global_tokens'

class PIPELINE:
    def __init__(self):
        config = NEBULA_CONF()
        self.db_host = config.get_database_host()
        self.database = config.get_playground_name()
        self.gdb = DatabaseConnector()
        self.db = self.gdb.connect_db(self.database)

pipeline = PIPELINE()
def get_sc_graph(id):
    return vg.get_scene_graph(id, images=VG_DATA,
                    image_data_dir=VG_DATA+'/by-id/',
                    synset_file=VG_DATA+'/synsets.json')

def get_all_s3_ids():
    results = {}
    query = 'FOR doc IN {} RETURN doc.image_id'.format(GLOBAL_TOKENS_COLLECTION)
    cursor = pipeline.db.aql.execute(query)
    return [doc for doc in cursor]

s3_ids = get_all_s3_ids()

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [30]:
class GTBaseGenerator:
    def __init__(self):
        self.pipeline = PIPELINE()
        self.ipc_data = json.load(open('/storage/ipc_data/paragraphs_v1.json','r'))
        self.global_captioner = 'blip'
        self.global_tagger = 'blip'
        self.places_source = 'blip'
        self.global_prompt1 = '''Caption of image: {}
This image is taking place in: {}
Tags: This image is about {}
Describe this image in detail:'''

    def get_image_id_from_collection(self, id,collection=GLOBAL_TOKENS_COLLECTION):
        results = {}
        query = 'FOR doc IN {} FILTER doc.image_id == {} RETURN doc'.format(collection,id)
        cursor = self.pipeline.db.aql.execute(query)
        for doc in cursor:
            results.update(doc)
        return results
    
    def get_structure(self, id):
        sg = get_sc_graph(id)
        global_doc = self.get_image_id_from_collection(id)
        if not global_doc:
            print("Couldn't find global tokens for id {}".format(id))
            return
        rc_doc = {
            'image_id': id,
            'url': sg.image.url            
        }
        for (k,v) in global_doc.items():
            if k.startswith('global'):
                rc_doc[k]=copy.copy(v)
        rois = []
        for obj in sg.objects:
            obj_dic = {
                'GT': list(zip(obj.names,[1.0]*len(obj.names)))
            }
            attr_dic = {
                'GT': list(zip(obj.attributes,[1.0]*len(obj.attributes)))
            }
            obj_doc = {                
                'objects': obj.names,
                'attributes': obj.attributes,
                'bbox': [obj.x, obj.y, obj.x+obj.width, obj.y+obj.height]              
                }
            rois.append(obj_doc)
        rc_doc['rois']=rois

        return rc_doc

    def get_prompt(self, id, include_answer=False):
        base_doc = self.get_structure(id)
        if base_doc == None:
            return
        caption = base_doc['global_captions'][self.global_captioner]
        all_objects = sorted(base_doc['global_objects'][self.global_tagger], key=lambda x: -float(x[1]))
        all_persons = sorted(base_doc['global_persons'][self.global_tagger], key=lambda x: -float(x[1]))
        all_places = sorted(base_doc['global_scenes'][self.places_source], key=lambda x: -float(x[1]))
        print("Caption: {}".format(caption))
        print("Objects: ")
        print(all_objects[:5])
        print("Places:")
        print(all_places[:5])
        print("Persons:")
        print(all_persons[:5])
        objects = '; '.join([x[0] for x in all_objects[:5]])
        personds = '; '.join([x[0] for x in all_persons[:5]])
        places = ' or '.join([x[0] for x in all_places[:3]])
        prompt_before_answer = self.global_prompt1.format(caption,places,objects)
        if include_answer:
            [answer] = [x['paragraph'] for x in self.ipc_data if x['image_id']==id]
            final_prompt = prompt_before_answer+" "+answer
        else:
            final_prompt = prompt_before_answer
        return final_prompt
        
base_gen = GTBaseGenerator()

In [27]:

def generate_gpt_prompt(ids, pgen=GTBaseGenerator()):
    rc = []
    for id in ids:
        rc.append(pgen.get_prompt(id,include_answer=True))
    return '\n'.join(rc)

In [18]:
s3_train, s3_test = np.split(np.array(s3_ids),[900])

In [28]:
rc = generate_gpt_prompt(np.random.choice(s3_train,5),pgen=base_gen)

KeyError: 1

In [25]:
s3_ids

[2351166,
 2346948,
 2383274,
 2414127,
 2378181,
 2370172,
 2370402,
 2351205,
 2385221,
 2387230,
 2383900,
 2373856,
 2343339,
 2413602,
 2322828,
 2377053,
 2364991,
 2367278,
 2389004,
 2395295,
 2381539,
 2406992,
 2328592,
 2363140,
 2351511,
 2385486,
 2316581,
 2415236,
 2346053,
 2323000,
 2360872,
 2407010,
 491,
 2413077,
 2379310,
 2415895,
 2370676,
 2382046,
 2357443,
 2374678,
 2351195,
 2357874,
 2405882,
 2362911,
 2325024,
 2412952,
 2369572,
 2374159,
 2325549,
 2354120,
 2413090,
 2349723,
 2322127,
 2404015,
 2365433,
 2400155,
 2350325,
 2348403,
 2323344,
 2328295,
 2352527,
 2373934,
 2383741,
 2403903,
 2375861,
 2398452,
 2359084,
 2325956,
 2376974,
 2368460,
 2388891,
 2365683,
 2370808,
 2352491,
 2353625,
 2349328,
 1256,
 2396757,
 2356999,
 2404544,
 2393538,
 2402973,
 2392432,
 387,
 2372845,
 2345047,
 2386436,
 2413100,
 2377779,
 2374836,
 2350810,
 2361606,
 2361372,
 2372024,
 2406617,
 2345652,
 2408865,
 2360280,
 2367052,
 2364956,
 2319587,
 

In [31]:
id = 2348389
rc = base_gen.get_prompt(id, include_answer=True)
print(rc)

KeyError: 1

In [None]:
rc = base_gen.get_structure(id)
rc['url']

In [None]:
rc = vg.get_all_region_descriptions(data_dir=VG_DATA)


In [None]:
list(zip([],[1.0]*0))