In [3]:
from SGG_Benchmark_main.demo.demo_model import SGG_Model
import cv2
import numpy as np
from transformers import AutoTokenizer, AutoModel




import torch
import torch.nn.functional as F

def get_pre(source):
    """Run the scene-graph model on one image and return its raw predictions.

    Args:
        source: Path of the image file; it is handed to ``cv2.imread``.

    Returns:
        Whatever ``SGG_Model.predict`` returns for the resized image.
        (``fix`` in this notebook reads the keys ``bbox``, ``bbox_scores``,
        ``bbox_labels``, ``rel_pairs`` and ``rel_labels`` from it.)

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``source``.
    """
    config_path = "output_motifs/config.yml"
    dict_path = "VG_dicts.json"
    weights_path = "output_motifs/best_model_epoch_4.pth"

    img = cv2.imread(source)
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, so fail here with a message that names the path.
    if img is None:
        raise FileNotFoundError("Could not read image: {!r}".format(source))

    # Fixed 1024x1024 frame; get_box_pos defaults to the same size.
    img = cv2.resize(img, (1024, 1024))

    # NOTE(review): the model is rebuilt on every call, which presumably
    # reloads the weights each time - consider caching if this gets slow.
    model = SGG_Model(config_path, dict_path, weights_path, rel_conf=0.01, box_conf=0.1, show_fps=False)

    return model.predict(img)

def SGG(img_path):
    """Scene-graph generation for one image.

    Feeds the image at ``img_path`` through ``get_pre`` and hands the raw
    predictions to ``deal_SGG_result``, returning the processed result.
    """
    raw_predictions = get_pre(img_path)
    return deal_SGG_result(raw_predictions)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def get_box_pos(bbox,orig_size=(1024,1024)):
    """Name the cell of a 3x3 grid that contains the centre of ``bbox``.

    Args:
        bbox: Four numbers. The centre is computed as
            ``((bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2)``, i.e. the values
            are treated as two opposite corners (x1, y1, x2, y2).
            NOTE(review): an earlier note on this function said "xywh";
            confirm the format the model actually emits.
        orig_size: ``(width, height)`` of the image the box refers to.

    Returns:
        One of the labels below, laid out as in the image::

            top left corner    | top    | top right corner
            left               | center | right
            lower left corner  | lower  | lower right corner

        A coordinate lying exactly on a one-third or two-thirds boundary
        belongs to the middle band. ``None`` is returned when a coordinate
        cannot be ordered against the boundaries (e.g. NaN).
    """
    grid = (
        ('top left corner', 'top', 'top right corner'),
        ('left', 'center', 'right'),
        ('lower left corner', 'lower', 'lower right corner'),
    )

    def band(value, extent):
        # 0 / 1 / 2 for the first / middle / last third of the extent.
        if value < extent / 3:
            return 0
        if value <= 2 * extent / 3:
            return 1
        if value > 2 * extent / 3:
            return 2
        return None  # only reachable for unordered values such as NaN

    col = band((bbox[0] + bbox[2]) / 2, orig_size[0])
    row = band((bbox[1] + bbox[3]) / 2, orig_size[1])
    if col is None or row is None:
        return None
    return grid[row][col]


def fix(predictions,probs=0.15,max_rels=50):
    """Drop low-confidence boxes and keep only relations between kept boxes.

    Args:
        predictions: Mapping with the keys ``bbox_scores``, ``bbox``,
            ``bbox_labels``, ``rel_pairs`` and ``rel_labels``.
        probs: A box is kept only when its score is strictly greater than
            this threshold.
        max_rels: Only the first ``max_rels`` entries of ``rel_pairs`` are
            considered.

    Returns:
        ``(bbox, bbox_labels, rel_pairs, rel_labels)`` as lists. Each entry
        of ``rel_pairs`` is ``[subject, object]`` expressed as positions in
        the returned ``bbox`` / ``bbox_labels`` lists.
    """
    keep = np.array(predictions['bbox_scores']) > probs

    bbox = np.array(predictions['bbox'])[keep]
    bbox_labels = np.array(predictions['bbox_labels'])[keep]

    # The mask may remove boxes from the middle of the list, so relation
    # indices have to be translated from the unfiltered numbering to the
    # filtered one. When the kept boxes form a prefix this map is the
    # identity. Assumes rel_pairs index into the unfiltered box list -
    # TODO confirm against the model's output format.
    new_index = {int(old): new for new, old in enumerate(np.flatnonzero(keep))}

    rel_pairs = []
    rel_labels = []

    for i, pair in enumerate(predictions['rel_pairs'][:max_rels]):
        subject = new_index.get(int(pair[0]))
        target = new_index.get(int(pair[1]))
        if subject is None or target is None:
            continue  # at least one endpoint was filtered out
        rel_pairs.append([subject, target])
        rel_labels.append(predictions['rel_labels'][i])

    return list(bbox), list(bbox_labels), rel_pairs, rel_labels

def deal_SGG_result(predictions):
    """Regroup filtered predictions into one summary block per object label.

    The predictions are first passed through ``fix``. The result is a list
    with one single-key dict per distinct label (in order of first
    appearance)::

        {label: {'number': <how many boxes carry the label>,
                 'location': [<get_box_pos(...) for each such box>],
                 'relationship': [[subject, predicate, object], ...]}}

    ``relationship`` holds every triple in which the label appears as
    subject or as object.
    """
    boxes, labels, pairs, predicates = fix(predictions)

    # Count boxes and collect their grid positions, keyed by label.
    counts = {}
    positions = {}
    for box, label in zip(boxes, labels):
        counts[label] = counts.get(label, 0) + 1
        positions.setdefault(label, []).append(get_box_pos(box))

    # Resolve the index pairs into [subject label, predicate, object label].
    triples = [
        [labels[pair[0]], predicate, labels[pair[1]]]
        for pair, predicate in zip(pairs, predicates)
    ]

    blocks = []
    for label, count in counts.items():
        involved = [t for t in triples if t[0] == label or t[2] == label]
        blocks.append({
            label: {
                'number': count,
                'location': positions[label],
                'relationship': involved,
            }
        })
    return blocks


In [5]:
def mean_pooling(model_output, attention_mask):
    """Mask-weighted mean of the token embeddings along dimension 1.

    ``model_output[0]`` is taken as the token embeddings. ``attention_mask``
    is broadcast to the same size so that positions with a zero mask do not
    contribute to the sum; the divisor is clamped to at least 1e-9 to avoid
    dividing by zero when a row's mask is all zeros.
    """
    embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_sum = torch.sum(embeddings * mask, 1)
    mask_total = torch.clamp(mask.sum(1), min=1e-9)
    return masked_sum / mask_total

# Tokenizer and encoder are loaded from the same checkpoint; both are used as
# module-level globals by Accurate_VQA.
EMBED_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_NAME)
model = AutoModel.from_pretrained(EMBED_MODEL_NAME)

def get_embed(sentences,model,tokenizer):
    """Encode ``sentences`` into L2-normalised sentence embeddings.

    The text is tokenised (padded, truncated, returned as PyTorch tensors),
    run through ``model`` without gradient tracking, mean-pooled with the
    attention mask and finally normalised along dimension 1.
    """
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        outputs = model(**batch)

    pooled = mean_pooling(outputs, batch['attention_mask'])
    return F.normalize(pooled, p=2, dim=1)



In [6]:
def compute_cosine_simi(vector,vector_list):
    """Cosine similarity between the question embedding and the SGG embeddings.

    vector: embedded vector of the question
    vector_list: embedded vectors of the SGG result

    Both arguments are passed to ``F.cosine_similarity``, which compares
    them along dimension 1.
    """
    similarities = F.cosine_similarity(vector, vector_list, dim=1)
    return similarities

In [7]:
def query_part(query,model,tokenizer):
    """Turn the user's question into an embedding via ``get_embed``."""
    query_embedding = get_embed(query, model, tokenizer)
    return query_embedding

def img_part(img_path,model,tokenizer):
    """Build the scene-graph blocks for an image and embed their labels.

    Returns ``(SGG_result, sgg_embed)``: the list produced by ``SGG`` and the
    ``get_embed`` output for the first key of each block, in the same order.
    """
    scene_blocks = SGG(img_path)

    # Each block is keyed by its object label; that label is what is embedded.
    block_labels = [list(block.keys())[0] for block in scene_blocks]
    label_embeddings = get_embed(block_labels, model, tokenizer)

    return scene_blocks, label_embeddings




In [8]:
def retrieval(query_embed,SGG_result,SGG_embed,top_k=4):
    """Return the scene-graph blocks most similar to the query.

    Args:
        query_embed: Embedding of the question.
        SGG_result: Sequence of blocks, aligned index-for-index with
            ``SGG_embed``.
        SGG_embed: Embeddings of the blocks.
        top_k: Maximum number of blocks to return (default 4).

    Returns:
        Up to ``top_k`` items of ``SGG_result``, ordered by descending
        cosine similarity to the query.
    """
    simi = compute_cosine_simi(query_embed, SGG_embed)
    _, order = torch.sort(simi, descending=True)

    # Slicing already copes with fewer than top_k candidates; a negative
    # top_k would otherwise slice from the end, so clamp it to zero.
    limit = max(int(top_k), 0)
    return [SGG_result[i] for i in order[:limit].tolist()]
    

In [9]:
def get_prompt(query,retrieval_result):
    """Build the LLM prompt from the question and the retrieved blocks.

    Args:
        query: The user's question; inserted verbatim at the end.
        retrieval_result: Retrieved scene-graph information; inserted as
            ``str(retrieval_result)``.

    Returns:
        The complete prompt string.
    """
    # Adjacent literals are joined by the parser; every fragment carries its
    # own leading space so that no two words run together at a line break.
    template = (
        "The key information is represented by a list, each item is a dictionary"
        " containing three items: number represents the quantity of the object, location"
        " represents the position of the object, and relationship is a list, each item represents its relationship with other objects"
        ". Now please follow the following requirements: Please integrate all the information provided to you and use logical reasoning to answer the question as briefly as possible. If the information provided to you is not relevant to the question, please answer 'I don't know',Pretend that the message you received is a picture"
        " and based on the information:{} in the picture,please answer the following"
        " questions:{}"
    )

    return template.format(str(retrieval_result), query)

In [10]:
from openai import OpenAI
import json
import os

def get_response(prompt, api_key=None):
    """Send ``prompt`` to the ``qwen-max`` chat model and return the completion.

    Args:
        prompt: Text sent as the user message.
        api_key: API key for the endpoint. When omitted, the
            ``DASHSCOPE_API_KEY`` environment variable is used, so the
            secret never has to be written into the notebook.

    Returns:
        The object returned by ``client.chat.completions.create``.

    Raises:
        ValueError: If no API key was passed and the environment variable is
            unset or empty.
    """
    api_key = api_key or os.environ.get("DASHSCOPE_API_KEY")
    if not api_key:
        raise ValueError(
            "No API key: pass api_key=... or set the DASHSCOPE_API_KEY environment variable"
        )

    client = OpenAI(
        api_key=api_key,
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    completion = client.chat.completions.create(
        model="qwen-max",
        messages=[
            {'role': 'system', 'content': 'You are an assistant specialized in processing images,I will give you some information as image input, please pretend that you have received the image'},
            {'role': 'user', 'content': prompt}],
        seed=2  # fixed seed, as in the request this notebook has always sent
        )
    return completion

In [11]:
def Accurate_VQA(query,img):
    """Answer a question about an image via scene-graph retrieval and an LLM.

    input: query: Your Question
    img: Your image path

    Uses the module-level ``model`` and ``tokenizer`` for all embeddings and
    returns the text content of the first choice in the LLM response.
    """
    question_embed = query_part(query, model, tokenizer)
    scene_blocks, scene_embed = img_part(img, model, tokenizer)

    relevant_blocks = retrieval(question_embed, scene_blocks, scene_embed)
    completion = get_response(get_prompt(query, relevant_blocks))

    return completion.choices[0].message.content

In [12]:
# Demo query. os.path.join yields 'JPEGImages\\598.jpg' on Windows and a
# forward-slash path elsewhere, so the cell is not tied to one platform.
Accurate_VQA("How many man?", os.path.join('JPEGImages', '598.jpg'))

loading word vectors from .\glove.6B.200d.pt
loading word vectors from .\glove.6B.200d.pt


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


'There is 1 man in the picture.'