In [None]:
!pip install transformers
!pip install sentence-transformers
!pip install wandb
!pip install happytransformer
!pip install Levenshtein==0.20.3
!pip install ijson
!pip install spacy
!pip install spacy-transformers
!pip install lemminflect

In [None]:
!pip install inflect==5.6.2

In [None]:
import requests
from google.colab import drive
import json
import sys
import csv
import os, errno
import base64
import time
import re
import logging
import copy
from itertools import combinations, chain, product, groupby
from datetime import datetime
from timeit import default_timer as timer
import datetime as dt
from collections import OrderedDict, Counter
from json import JSONEncoder

import pickle
import spacy
import ijson
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

from Levenshtein import distance as lev_dis
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional
from scipy.special import softmax
from happytransformer import HappyTextToText, TTSettings
from transformers import DebertaForSequenceClassification, DebertaTokenizerFast
from sentence_transformers import SentenceTransformer
from torch import Tensor
import inflect
from lemminflect import getLemma, getInflection

import gensim.downloader as api

In [None]:
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.corpus import wordnet as wn

In [None]:
!nvidia-smi

In [None]:
#use drive if your files are there
# drive.mount('', force_remount=True)

In [None]:
# inflect engine
inflect_eng = inflect.engine()

In [None]:
# logger
logger = None

In [None]:
# Placeholder token and operation-name groups.
# NOTE(review): none of these three names is used in the part of the notebook visible here — confirm their consumers.
mask_token = '<mask>'
semantic = ['rel', 'relation']
structural = ['query']

# Populate the paths below before running the notebook ('...' is a placeholder prefix).

# Path to the Visual Genome scene graphs JSON
sg_ds_path = '.../Visual-Genome/scene_graphs.json'
# Path to the Visual Genome images directory
path_to_images = '.../Visual-Genome/VG_100K'
# Path to the Visual Genome object -> synset mapping
synset_mapping_path = '.../visual_genome_object_synsets.json'
# Path to a text file of mass nouns (read line by line further down)
mass_nouns_path = '.../mass_nouns.txt'

# Pre-processed file produced from the VG files to ease the pipeline:
# maps "<subject>_<object>" keys to their relations.
subj_obj_2_rels_path = ".../subject_object_to_relations.json"


### Inner Modules

Global Utils

In [None]:
def check_timing(f, *args, **kwargs):
    """
    Run ``f(*args, **kwargs)`` and log how long the call took.

    The elapsed time is written through the module-level ``logger`` at DEBUG level.
    :param f: callable to execute (its ``__name__`` is used in the log message)
    :return: whatever ``f`` returned
    """
    started_at = timer()
    result = f(*args, **kwargs)
    elapsed = timer() - started_at
    logger.debug(f"-- Function {f.__name__} Time taken: {elapsed} sec")
    return result


def image_id_to_scene_graph(path):
    """
    Stream the scene-graph JSON array at ``path`` and index its items by image id.

    Items without an "image_id" key are ignored. When an id occurs more than once the
    first item wins and a notice is printed for each later duplicate.
    :param path: path to a JSON file whose top level is an array of scene graphs
    :return: dict mapping image_id -> scene graph item
    """
    began = timer()
    mapping = dict()
    with open(path) as sg_file:
        # ijson yields the array elements one at a time instead of loading the whole file
        for entry in ijson.items(sg_file, "item"):
            if "image_id" not in entry:
                continue
            key = entry["image_id"]
            if key in mapping:
                print(f"{key} is already mapped")
                continue
            mapping[key] = entry
    print(f"Time taken generating scene graph mapping: {timer() - began} sec")
    return mapping


def search_replace(q, old, new, inject_det=None, replace_with_det=[]):
    """
    Replace the first token-level occurrence of ``old`` inside ``q`` with ``new``.

    All three strings are split with the `tokenizers` ``Whitespace`` pre-tokenizer and the
    search is done on whole tokens, so "dog" does not match inside "hotdog".

    :param q: the sentence/question to edit
    :param old: phrase to look for (may span several tokens)
    :param new: replacement phrase (may be empty, which deletes ``old``)
    :param inject_det: optional determiner (e.g. 'the') to place right before the match.
        Nothing is injected when the match is the first token, or when the token before
        the match already equals the determiner.
    :param replace_with_det: tokens that, when found right before the match, are replaced
        by ``inject_det`` instead of being kept (e.g. ['a']). Only read, never mutated.
    :return: the edited string; ``None`` when ``old`` is not found; ``q`` unchanged when
        ``old`` contains no tokens.
    """
    from itertools import zip_longest
    from tokenizers.pre_tokenizers import Whitespace

    def join_w_offsets(t_w_o):
        # Rebuild the text from (token, (start, end)) pairs: a single space is emitted only
        # between tokens whose offsets are exactly one character apart.
        ret = []
        for (t1, (s1, e1)), (t2, (s2, e2)) in zip_longest(t_w_o, t_w_o[1:], fillvalue=(None, (-1, -1))):
            ret.append(t1)
            if e1 + 1 == s2:
                ret.append(' ')

        return "".join(ret)

    def align_offset(new_start, t_w_o):
        # Shift every token so that the first one starts at `new_start`.
        ret = []
        if len(t_w_o) > 0:
            d = new_start - t_w_o[0][1][0]
            for t1, (s, e) in t_w_o:
                ret.append((t1, (s + d, e + d)))

        return ret

    def inject_at_index_(tokens_with_offset, word, i):
        # Mutates `tokens_with_offset` in place; returns True only when a new token was
        # inserted at index i (so the caller must move its match index by one).
        if i > 0:
            one_token_before = tokens_with_offset[i - 1]
            one_word_before = one_token_before[0]

            # replace the preceding token with the determiner
            if one_word_before in replace_with_det:
                gap = len(word) - len(one_word_before)
                one_token_before_start = one_token_before[1][0]
                det_token = (word, (one_token_before_start, one_token_before_start + len(word)))
                tokens_with_offset[i - 1] = det_token

                # shift everything after the replaced token by the length difference
                for j in range(i, len(tokens_with_offset)):
                    new_offset = tuple(v + gap for v in tokens_with_offset[j][1])
                    tokens_with_offset[j] = (tokens_with_offset[j][0], new_offset)

            # re-read: the preceding token may just have been replaced above
            one_token_before = tokens_with_offset[i - 1]
            one_word_before = one_token_before[0]

            # inject the determiner unless it is already there
            if one_word_before.lower() != word:
                one_token_before_end = one_token_before[1][1]
                added_gap = 1 + len(word)  # the determiner plus one separating space
                det_token = (word, (one_token_before_end + 1, one_token_before_end + added_gap))
                tokens_with_offset.insert(i, det_token)

                for j in range(i + 1, len(tokens_with_offset)):
                    new_offset = tuple(v + added_gap for v in tokens_with_offset[j][1])
                    tokens_with_offset[j] = (tokens_with_offset[j][0], new_offset)

                return True

        return False

    def search_replace_tokenized(tokens_with_offset, old_t_w_o, new_t_w_o):
        old_t = [x[0] for x in old_t_w_o]
        window_size = len(old_t)
        for i in range(len(tokens_with_offset) - window_size + 1):
            window_tokens = [x[0] for x in tokens_with_offset[i: i + window_size]]

            if window_tokens == old_t:
                if isinstance(inject_det, str):
                    is_suc = inject_at_index_(tokens_with_offset, inject_det, i)
                    # an inserted determiner pushes the match one position to the right
                    i += int(is_suc)

                bef = tokens_with_offset[:i]
                between = tokens_with_offset[i: i + window_size]
                rest = tokens_with_offset[i + window_size:]

                start_new = between[0][1][0]
                new = align_offset(start_new, new_t_w_o)

                # Re-anchor the tail only when there is one: the match may be the last
                # token of the sentence (previously an IndexError on rest[0]).
                if rest:
                    gap_after = rest[0][1][0] - between[-1][1][1]
                    if len(new) > 0:
                        start_rest = new[-1][1][1] + gap_after
                    elif bef:
                        start_rest = bef[-1][1][1] + gap_after
                    else:
                        # match deleted at the very beginning: the tail takes its place
                        start_rest = start_new
                    rest = align_offset(start_rest, rest)

                return bef + new + rest

        return None

    tokens_with_offset = Whitespace().pre_tokenize_str(q)
    old_t_w_o = Whitespace().pre_tokenize_str(old)
    new_t_w_o = Whitespace().pre_tokenize_str(new)
    if len(old_t_w_o) > 0:
        new_tokens_with_offset = search_replace_tokenized(tokens_with_offset, old_t_w_o, new_t_w_o)

        if new_tokens_with_offset is not None:
            return join_w_offsets(new_tokens_with_offset)
        else:  # failed to find
            return None

    # will return the same question if there is nothing to replace or search
    return q

In [None]:
search_replace("What is the color of dog running on the grass?", "dog", "cat", inject_det='the')

In [None]:
search_replace("Who is wearing a shirt?", "shirt", "boarder", inject_det='the', replace_with_det=['a'])

Helpers

In [None]:
class ObjectDescription:
    """
    A phrase naming an object together with its reference id.

    phrase: textual description of the object
    id: reference id attached to the object
    is_answer: whether the phrase was taken from the answer of a question
    """

    def __init__(self, phrase, id, is_answer=False):
        # attribute order is kept as-is: it determines the key order of __dict__ dumps
        self.id = id
        self.phrase = phrase
        self.is_answer = is_answer

    def __str__(self):
        return self.phrase

    # the printable representation is the phrase as well
    __repr__ = __str__

class SemanticEqOutput:
    """
    Result of a semantic-equivalence check.

    result: a single bool or one bool per compared pair
    source_embeds / target_embeds: optional embeddings of the compared phrases
    scores: optional similarity scores (a list, or any object exposing ``tolist()``)
    """

    def __init__(self, result: Union[bool, List[bool]], source_embeds=None, target_embeds=None, scores=None):
        self.result = result
        self.source_embeds = source_embeds
        self.target_embeds = target_embeds
        self.scores = scores

    def __bool__(self):
        # `result` may legitimately be a plain bool (see the annotation); all() would
        # raise TypeError on it, so handle it explicitly.
        if isinstance(self.result, bool):
            return self.result
        return all(self.result)

    def __str__(self):
        scores = self.scores
        # tensors / arrays are not JSON serialisable; convert them to plain lists first
        if hasattr(scores, "tolist"):
            scores = scores.tolist()
        return json.dumps({"result": self.result, "scores": scores})


class DumpingStorage:
    """
    Insertion-ordered key/value store with a soft size limit.

    Whenever the number of entries exceeds ``buffer_size`` the earliest inserted
    entries are dropped (see ``stabilize``). ``stabilized`` counts how many times
    that happened.
    """

    def __init__(self, buffer_size=3000):
        self.buffer_size = buffer_size
        self.storage = OrderedDict()
        self.stabilized = 0

    def __contains__(self, item):
        return item in self.storage

    def add(self, item, val):
        """Store ``val`` under ``item`` and trim the store if it grew past the limit."""
        self.storage[item] = val
        self.stabilize()

    def get(self, item):
        """Return the value stored under ``item`` or None when it is absent."""
        return self.storage.get(item)

    def stabilize(self, keep_portion=0.75):
        """Keep only the most recently inserted ``keep_portion`` of the entries when over the limit."""
        if len(self.storage) <= self.buffer_size:
            return

        entries = list(self.storage.items())
        n_dropped = int((1 - keep_portion) * len(entries))
        self.storage = OrderedDict(entries[n_dropped:])  # the earliest entries are released
        self.stabilized += 1

In [None]:
class AugmentQuestion:
    def __init__(self, q, failure_type, failure_reason, origin_id,image_id, generation_type, q_gec=None, gec_dist=None, is_lemmatised_aug=None,failure_sub_type=None, properties=None):
        """
        q: The final question
        failure_type: failure type between [1,5]
        failure_reason: Describing the failure in plain language
        origin_id: id of the object the question was created from (in case of "from_question" it is a question id, in case of "pre_defined" it is an image id)
        image_id: id of the image the question refers to
        generation_type: how the question was generated, options are: "from_question"/"pre_defined"
        q_gec: in case the grammar error correction (GEC) layer has been applied, the question before that layer
        gec_dist: levenshtein distance before and after the GEC layer
        is_lemmatised_aug: if lemmatization was used while augmenting (stored as ``is_lemmatized_aug``)
        failure_sub_type: optional finer-grained failure type
        properties: additional info if needed; a fresh dict is created per instance when omitted
        """
        self.q = q
        self.q_gec = q_gec
        self.failure_type = failure_type
        self.failure_sub_type = failure_sub_type
        self.failure_reason = failure_reason
        self.gec_dist = gec_dist
        self.is_lemmatized_aug = is_lemmatised_aug
        self.generation_type = generation_type
        self.origin_id = origin_id
        self.image_id = image_id
        # A `{}` default would be one dict shared by every instance, so mutating the
        # properties of one question would leak into all others.
        self.properties = {} if properties is None else properties

    def __repr__(self):
        return self.q

    def __str__(self):
        return self.q

In [None]:
# Sentence transformers are mostly used while scanning scene graphs for triplets and comparing
# them with a given triplet, so the same phrase/sentence gets encoded many times.
# Instead of re-encoding, embeddings are kept in RAM to save GPU/CPU work.
class SentenceTransformerWithMemory:
    """Wraps a SentenceTransformer and keeps computed embeddings in RAM instead of encoding them each time."""

    def __init__(self, model_name):
        self.memory = DumpingStorage()
        self.model = SentenceTransformer(model_name)

    def encode(self, sentences: List[str], *args, **kwargs) -> Tensor:
        """
        Encode ``sentences``, running the model only on those not held in memory.

        Extra positional/keyword arguments are forwarded to the wrapped model's ``encode``.
        Note that the memory is keyed by the sentence text only, not by these arguments.
        :return: a tensor with one row per sentence, in input order
        """
        # split the positions into "needs encoding" and "already in memory"
        missing_positions = []
        cached_positions = []
        for pos, sentence in enumerate(sentences):
            if sentence in self.memory:
                cached_positions.append(pos)
            else:
                missing_positions.append(pos)
        missing_sentences = [sentences[pos] for pos in missing_positions]

        fresh = []
        if len(missing_sentences) > 0:
            encoded = self.model.encode(missing_sentences, *args, **kwargs)
            if not isinstance(encoded, Tensor):
                encoded = torch.tensor(encoded)
            fresh = encoded.tolist()

        # remember the new embeddings
        for sentence, vector in zip(missing_sentences, fresh):
            self.memory.add(sentence, vector)

        # assemble the output in the original order
        output = [None] * len(sentences)
        for pos, vector in zip(missing_positions, fresh):
            output[pos] = vector
        for pos in cached_positions:
            output[pos] = self.memory.get(sentences[pos])

        # an entry can be missing if the memory dropped it while the new ones were added;
        # in that case simply encode everything from scratch
        if any(vector is None for vector in output):
            return self.model.encode(sentences, *args, **kwargs)

        return torch.tensor(output)

### Helpful Maps
Lookup tables loaded into memory: (image id -> scene graph), (subject_object -> relations), (object name -> synset)

In [None]:
# Read the mass-noun list (one noun per line) into a lower-cased set.
# The `with` block closes the file handle, which the bare open(...) call left open.
with open(mass_nouns_path) as mass_nouns_file:
    mass_nouns = set([x.strip().lower() for x in mass_nouns_file.readlines()])

In [None]:
mass_nouns.update(['floor','street', 'scene'])

In [None]:
list(mass_nouns)[:30]

In [None]:
with open(synset_mapping_path) as nf:
  synset_mapping = json.load(nf)

In [None]:
synset_mapping['man']

In [None]:
synset_mapping['people']

In [None]:
##subject_object_relation mapping - load into memory (Note: This is ~300k length hash table)
s_o_2_relations = {}
with open(subj_obj_2_rels_path) as f:
    s_o_2_relations = json.load(f)

In [None]:
s_o_2_relations['shoes_man']

In [None]:
s_o_2_relations['building_sign']

In [None]:
imgid_to_scene_graph = image_id_to_scene_graph(sg_ds_path)

In [None]:
len(s_o_2_relations), len(imgid_to_scene_graph)

### Models

In [None]:
!python -m spacy download en_core_web_trf

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
# Semantic similarity model: a fine-tuned sentence transformer
# (per the original note, fine-tuned from paraphrase-mpnet-base-v2).

# Path to the fine-tuned model of the accommodation metric — must be filled in before running.
path_to_ft_sen_trans = ''



# sen_transformer_name = 'paraphrase-mpnet-base-v2'
sen_transformer = SentenceTransformerWithMemory(path_to_ft_sen_trans)

# MNLI model (natural language inference classifier) and its tokenizer
mnli_model_name = "microsoft/deberta-base-mnli"
mnli_tokenizer = DebertaTokenizerFast.from_pretrained(mnli_model_name)
mnli_model = DebertaForSequenceClassification.from_pretrained(mnli_model_name)

# Grammar error correction model
grammar_transformer_name = "vennify/t5-base-grammar-correction"
happy_tt = HappyTextToText("T5", grammar_transformer_name)

# Small spaCy pipeline; the dependency parser and NER components are disabled to improve performance
mini_nlp = spacy.load("en_core_web_sm", disable=['parser', 'ner'])  # improve performance

# Transformer-based spaCy pipeline; the dependency parser and NER components are disabled here too
nlp = spacy.load("en_core_web_trf", disable=['parser', 'ner'])  # improve performance

# Word vectors used for object-object similarity
glove_model = api.load("glove-wiki-gigaword-300")

### Hyper Parameters

In [None]:
batch_size=32

## Methods

### Utils Methods

#### General utils functions

In [None]:
def merge_same_objects_in_similar_set(similar_set, delta=10):
  """
  Collapse objects of a similar set that were annotated several times at (almost) the same place.

  A similar set is a list of objects with the same meaning; VG annotators sometimes
  annotate the same object more than once. The objects are clustered greedily:

  while objects remain:
  1. pop the first remaining object, it becomes the representative of a new cluster
  2. walk over the remaining objects; an object whose (x, y) lies within +-delta of the
     current cluster average (on both axes) joins the cluster and is removed from the
     remaining objects
  3. keep the representative of the cluster

  The input is never modified: the function works on (and returns) deep copies.
  :param similar_set: list of dicts carrying at least 'x' and 'y'
  :param delta: half-size of the square window used to decide two objects are the same
  :return: list with one representative per cluster (a copy of the input when it has at most one object)
  """
  remaining = copy.deepcopy(similar_set)

  if len(similar_set) <= 1:
    return remaining

  representatives = []
  while len(remaining) > 0:
    leader = remaining.pop(0)
    members = [leader]
    absorbed = set()

    for idx, candidate in enumerate(remaining):
      cand_x, cand_y = candidate['x'], candidate['y']

      # the cluster average is recomputed every time because the cluster may have just grown
      center_x = np.average([member['x'] for member in members])
      center_y = np.average([member['y'] for member in members])

      if abs(center_x - cand_x) <= delta and abs(center_y - cand_y) <= delta:
        absorbed.add(idx)
        members.append(candidate)

    remaining = [obj for idx, obj in enumerate(remaining) if idx not in absorbed]
    representatives.append(leader)  # the first object of the cluster stands for all of them

  return representatives

In [None]:
sim_set = [{'synsets': ['laptop.n.01'], 'h': 97, 'object_id': 1904014, 'names': ['laptop'], 'w': 117, 'attributes': ['pink'], 'y': 47, 'x': 266},{'synsets': ['laptop.n.01'], 'h': 97, 'object_id': 1904014, 'names': ['laptop'], 'w': 117, 'attributes': ['pink'], 'y': 40, 'x': 266},{"synsets": ["laptop.n.01"], "h": 97, "object_id": 1904014, "names": ["laptop"], "w": 117, "attributes": ["pink"], "y": 77, "x": 266}, {"synsets": ["laptop.n.01"], "h": 105, "object_id": 2056970, "names": ["laptop"], "w": 138, "attributes": ["pink"], "y": 73, "x": 266}, {"synsets": ["laptop.n.01"], "h": 109, "object_id": 2002723, "names": ["laptop"], "w": 119, "y": 75, "x": 268}, {'synsets': ['laptop.n.01'], 'h': 97, 'object_id': 1904014, 'names': ['laptop'], 'w': 117, 'attributes': ['pink'], 'y': 10, 'x': 20}]
merge_same_objects_in_similar_set(sim_set)

In [None]:
def is_mass_noun(noun):
  """
  Tell whether ``noun`` (case-insensitive) is listed in the global ``mass_nouns`` set.
  Raises an Exception when ``noun`` is not a string.
  """
  global mass_nouns

  if isinstance(noun, str):
    return noun.lower() in mass_nouns

  raise Exception(f"{noun} argument is not string")

In [None]:
def transform_present_progressive(rel):
    """
    Put the verbs of a relation phrase into their -ing (VBG) form, e.g. 'play on' -> 'playing on'.

    Every token is replaced by its lemma; tokens tagged as VERB are additionally inflected
    when an inflection is available. Tokens are joined back with single spaces.
    """

    def to_progressive(token):
        lemma = token.lemma_
        if token.pos_ != "VERB":  # only verbs are inflected, prepositions etc. are kept as lemma
            return lemma

        forms = getInflection(token.lemma_, tag='VBG')
        return forms[0] if len(forms) > 0 else lemma

    return " ".join([to_progressive(token) for token in mini_nlp(rel)])

In [None]:
transform_present_progressive('play on')

In [None]:
class ObjectEncoder(JSONEncoder):
  """JSON encoder that serialises otherwise unsupported objects through their attribute dict."""

  def default(self, o):
    # called by json only for objects it cannot serialise natively
    attributes = o.__dict__
    return attributes

In [None]:
def extract_names_from_VG_obj(obj):
  """
  Extract the names of an object in Visual Genome object representation.

  'names' takes precedence over 'name'. A single string is wrapped into a list.
  Returns None when the object carries neither key.
  """
  names = None
  for key in ('names', 'name'):
    if key in obj:
      names = obj[key]
      break

  return [names] if isinstance(names, str) else names

In [None]:
def extract_scene_graph_from_question(q_obj):
    """
    Look up the scene graph of the image a question refers to.

    :param q_obj: question object with an "imageId" entry (converted with int())
    :return: the scene graph from ``imgid_to_scene_graph`` or None when the image is unknown
    """
    return imgid_to_scene_graph.get(int(q_obj["imageId"]))

In [None]:
def from_object_id_to_string(objects, ids):
  """
  Translate object ids into object names.

  For every id the first object with that "object_id" is used; its name is "name" when
  present, otherwise the first entry of a non-empty "names" list.
  :return: list of names aligned with ``ids``, or None when ``ids`` is empty or at least
           one id could not be translated
  """
  resolved = [None] * len(ids)

  for slot, wanted_id in enumerate(ids):
    match = next((obj for obj in objects if obj["object_id"] == wanted_id), None)
    if match is None:
      continue

    if "name" in match:
      resolved[slot] = match["name"]
    elif "names" in match and len(match["names"]) > 0:
      resolved[slot] = match["names"][0]

  if len(ids) > 0 and all([name is not None for name in resolved]):
    return resolved

  return None

In [None]:
def from_object_id_to_object_obj(objects, ids):
  """
  Translate object ids into the objects themselves.

  For every id the first object with that "object_id" is used.
  :return: list of objects aligned with ``ids``, or None when ``ids`` is empty or at least
           one id is unknown
  """
  found = [None] * len(ids)

  for slot, wanted_id in enumerate(ids):
    for obj in objects:
      if obj["object_id"] == wanted_id:
        found[slot] = obj
        break

  if len(ids) > 0 and all([obj is not None for obj in found]):
    return found

  return None

In [None]:
def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each (``n`` below 1 is treated as 1)."""
    step = max(1, n)
    yield from (lst[begin:begin + step] for begin in range(0, len(lst), step))

def extract_similar_words_from_batch(criterions_ans, batch):
    """
    Collect the batch members accepted by at least one criterion.

    :param criterions_ans: one sequence of truthy/falsy flags per criterion, aligned with ``batch``
    :param batch: the compared items
    :return: set of items of ``batch`` whose flag is truthy in any of the answers
    """
    return {batch[position]
            for answer in criterions_ans
            for position, accepted in enumerate(answer)
            if accepted}

In [None]:
def inv_map(my_map):
  """Invert a mapping: every value is mapped to the list of keys that pointed to it (in iteration order)."""
  inverted = {}
  for key, value in my_map.items():
    inverted.setdefault(value, []).append(key)

  return inverted

In [None]:
def extract_triplets_from_scene_graph(scene_graph):
  """
  List the (subject name, predicate, object name) triplets of a scene graph.

  Predicates are lower-cased. Relationships whose subject or object cannot be translated
  into a name are skipped; relationships that raise while being processed are logged at
  DEBUG level and skipped.
  """
  relationships = scene_graph["relationships"]
  objects = scene_graph["objects"]
  found = []

  for relationship in relationships:
      try:
          predicate = str.lower(relationship["predicate"])
          subject_id = relationship["subject_id"]
          object_id = relationship["object_id"]
          names = from_object_id_to_string(objects, [subject_id, object_id])
          if names is None or len(names) != 2:
              continue

          subject_name, object_name = names
          found.append((subject_name, predicate, object_name))

      except Exception as e:
          logger.debug(e)
          continue

  return found

In [None]:
def generate_subjects_2_pair(scene_graph):
  """
  Map every subject name of the scene graph to its (relation, object name) pairs:
  e.g (dog, running on, grass), (dog, chasing, ball)

  s1 -> [(r1,o1), (r2, o2),...]

  Relations are grouped by subject id first. A subject is left out when its own id or one
  of its object ids cannot be translated into a name. The result is keyed by subject
  *name*, so of several subject ids sharing a name the last processed one wins.
  Errors are logged at DEBUG level and re-raised.
  """
  relationships = scene_graph["relationships"]
  objects = scene_graph["objects"]

  # subject id -> [(predicate, object id), ...]
  pairs_by_subject_id = {}
  for relationship in relationships:
      predicate = str.lower(relationship["predicate"])
      subject_id = relationship["subject_id"]
      object_id = relationship["object_id"]
      pairs_by_subject_id.setdefault(subject_id, []).append((predicate, object_id))

  subj2pairs = {}
  for subject_id, pairs in pairs_by_subject_id.items():
    try:
        predicates = [pair[0] for pair in pairs]
        # translate the subject and all of its objects in one call
        wanted_ids = [subject_id] + [pair[1] for pair in pairs]
        names = from_object_id_to_string(objects, wanted_ids)
        if names is not None and len(names) == len(wanted_ids):
            subject_name, object_names = names[0], names[1:]
            subj2pairs[subject_name] = list(zip(predicates, object_names))

    except Exception as e:
        logger.debug(e)
        raise

  return subj2pairs

In [None]:
def extract_triplets_from_semantic_program(q_obj):
    """
    Build (subject, relation, object) triplets out of the semantic program of a question.

    Sem_op structure: {"operation": string, "dependencies": number[], "argument": string}
    Only "relate" / "verify rel" operations create triplets; a "select" operation is used
    to fill the still-open side of the most recently created triplet.
    :param q_obj: question object with a "semantic" list of sem_ops and, when a relation
                  argument starts with "_", an "answer" entry
    :return: list of 3-tuples mixing ObjectDescription instances, the relation string and
             possibly an unresolved "FILL" placeholder; None when anything raised on the way
    """

    def remove_id(s, reg=r"\((\d+)\)"):
        """
        Split between a string and a reference id:
        e.g "chair (3605777)", "shopping bag (xxxxx)", "animal,eating from,s (2323212)"
        :param s: the argument string
        :param reg: pattern whose first group captures the id
        :return: (string without the first id occurrence, stripped; id as int, or -1 when there is none)
        """
        match = re.search(reg, s)
        p = re.compile(reg)
        rest = p.sub('', s, 1).strip()
        if match is not None:
            return rest, int(match.groups()[0])
        else:
            return rest, -1

    relations = []
    # True while relations[0] still holds a "FILL" placeholder waiting for its object
    fill = False
    try:
        for sem_op in reversed(q_obj["semantic"]):  # reverse iterating helpful to control compositional relations
            if fill and sem_op["operation"] == "select" and len(relations) > 0:
                argument = sem_op['argument']
                obj, obj_visual_genome_id = remove_id(argument)  # "chair (3605777)", "shopping bag (xxxxx)"
                relations[0] = tuple(
                    [ObjectDescription(obj, obj_visual_genome_id) if x == "FILL" else x for x in
                     relations[0]])  # replace
                fill = False
            elif sem_op["operation"] == "relate" or sem_op["operation"] == "verify rel":
                argument = sem_op['argument']
                a1, rel, a2 = argument.split(',')  # e.g "animal,eating from,s (2323212)"
                subject_or_object = q_obj["answer"] if a1 == "_" else a1  # replacing _ with answer
                # a2 is "s" or "o" followed by an optional "(id)"; the id is attached to the a1 object
                a2, obj_visual_genome_id = remove_id(a2)

                fill_obj = ObjectDescription(subject_or_object, obj_visual_genome_id, a1 == "_")

                if fill and len(relations) > 0:  # compositional
                    # the open side of the previously created triplet is this relation's object
                    relations[0] = tuple([fill_obj if x == "FILL" else x for x in relations[0]])  # replace
                    fill = False

                # "s": a1 is the subject, the object is still open; "o": the other way round
                if a2 == "s":
                    relations.insert(0, (fill_obj, rel, "FILL"))
                    fill = True
                elif a2 == "o":
                    relations.insert(0, ("FILL", rel, fill_obj))
                    fill = True
                else:
                    raise Exception(f"Unknown secondary argument={a2}: {json.dumps(q_obj)}")
    except Exception as e:
        # Any failure (missing keys, malformed argument, unknown secondary argument) is
        # swallowed on purpose and reported to the caller as None.
        # logger.debug(f"Msg: {e}\nFailed to extract relation on data-point : {q_obj}")
        return None

    return relations

In [None]:
def compute_dist_over_relations(sub: str, obj: str, normalised=False) -> dict:
    """
    Relation distribution of a (subject, object) pair, read from the pre-processed
    mapping "<subject>_<object>" -> [(relation, number), ...].
    :param sub: subject name
    :param obj: object name
    :param normalised: when True the numbers are divided by their sum
    :return: dict relation -> number (empty when the pair is unknown)
    """
    global s_o_2_relations

    entries = s_o_2_relations.get("_".join((sub, obj))) or []
    dist = dict(tuple(entry) for entry in entries)

    if not normalised:
        return dist

    total = sum(dist.values())
    return {relation: count / total for relation, count in dist.items()}

In [None]:
def classify_nli(premise: Union[List[str], str], hypothesis: Union[List[str], str], auto_extend=True):
    """
    Wrapper for the NLI classifier (global ``mnli_model`` / ``mnli_tokenizer``).
    :param premise: one premise or a list of premises
    :param hypothesis: one hypothesis or a list of hypotheses, paired with the premises
    :param auto_extend: currently unused, kept for interface compatibility
    :return: list with the predicted label index per pair
             (per the original note: 0 - contradiction, 1 - neutral, 2 - entailment)
    """
    global mnli_model
    global mnli_tokenizer

    if isinstance(premise, str):
        premise = [premise]

    if isinstance(hypothesis, str):
        hypothesis = [hypothesis]

    mnli_model.eval()
    inputs = mnli_tokenizer(premise, hypothesis, return_tensors='pt', padding=True)
    # pure inference: without no_grad an autograd graph is recorded for every call,
    # which only costs memory and time
    with torch.no_grad():
        outputs = mnli_model(**inputs)  # logits: (batch_size, num_labels)
    return torch.argmax(outputs.logits, dim=-1).tolist()  # (batch_size,)

In [None]:
def batch_equal(target: Union[str, List[str]], source: Union[str, List[str]], auto_extension=True):
    """
    Pairwise exact-match comparison of two phrases / phrase lists.

    With ``auto_extension`` the shorter list is padded by repeating its last element so
    both sides have the same length; without it the comparison stops at the shorter one.
    :return: SemanticEqOutput whose ``result`` holds one bool per compared pair
    """
    targets = [target] if isinstance(target, str) else target
    sources = [source] if isinstance(source, str) else source

    if auto_extension:
        shortfall = len(sources) - len(targets)
        if shortfall > 0:
            targets = targets + [targets[-1]] * shortfall
        elif shortfall < 0:
            sources = sources + [sources[-1]] * (-shortfall)

    return SemanticEqOutput(result=[t == s for t, s in zip(targets, sources)])

#### Split Object Similar Set

In [None]:
def split_object_similar_set(similar_set):
    """
    Split the given similar_set into two groups - pronouns or people/person, and others
    e.g objects named girl, they, can, dog, cat
    interrogative = [girl, they]
    other = [can, dog, cat]

    An object is "interrogative" when one of its names is a pronoun, or when its first
    synset lies under one of the person-like synsets, i.e. the lowest common hypernym of
    the two is the person-like synset itself.
    :return: (interrogative, other), both lists keeping the input order
    """
    person_like_synsets = [wn.synset(name) for name in ['person.n.01', 'homo.n.02', 'people.n.01']]
    pronouns = set(['he','she','him','i','you','we','they'])

    interrogative, other = [], []

    for candidate in similar_set:
      names = set(extract_names_from_VG_obj(candidate))
      named_by_pronoun = any([name.lower() in pronouns for name in names])

      under_person_synset = False
      if 'synsets' in candidate and len(candidate['synsets']) > 0:
        candidate_synset = wn.synset(candidate['synsets'][0])

        for person_synset in person_like_synsets:
          shared = candidate_synset.lowest_common_hypernyms(person_synset)
          if len(shared) > 0:
            under_person_synset = under_person_synset or shared[0].name() == person_synset.name()

      bucket = interrogative if (named_by_pronoun or under_person_synset) else other
      bucket.append(candidate)

    return interrogative, other

In [None]:
sgc = imgid_to_scene_graph[2377379]
stt = from_object_id_to_object_obj(sgc['objects'], [565924, 565935])
print(split_object_similar_set(stt))
stt2 = [{
  'synsets': ['water.n.01'],
  'h': 65,
  'object_id': 565924,
  'names': ['water'],
  'w': 39,
  'attributes': ['drops', 'droplets', 'drop'],
  'y': 144,
  'x': 61
  },{
  'synsets': ['food.n.01'],
  'h': 74,
  'object_id': 565935,
  'names': ['food'],
  'w': 39,
  'y': 132,
  'x': 57
  },{
  'synsets': ['girl.n.01'],
  'h': 74,
  'object_id': 565935,
  'names': ['girl'],
  'w': 39,
  'y': 132,
  'x': 57},
  {
  'synsets': ['child.n.01'],
  'h': 74,
  'object_id': 565935,
  'names': ['kid'],
  'w': 39,
  'y': 132,
  'x': 57},
  {
  'synsets': ['man.n.01'],
  'h': 74,
  'object_id': 565935,
  'names': ['man'],
  'w': 39,
  'y': 132,
  'x': 57},
  {
  'synsets': ['man.n.01'],
  'h': 74,
  'object_id': 565935,
  'names': ['man'],
  'w': 39,
  'y': 132,
  'x': 57},
  {
  'synsets': [],
  'h': 74,
  'object_id': 565935,
  'names': ['he'],
  'w': 39,
  'y': 132,
  'x': 57},
  ]

pack = split_object_similar_set(stt2)
assert len(pack[0]) == 5 and len(pack[1]) == 2
pack

#### Similarity functions

In [None]:
def bert_sentence_similarity_batched(s1: Union[str, List[str]], s2: Union[str, List[str]], auto_extension=True,
                                     threshold=0.83, return_scores=False, return_embeds=False) -> SemanticEqOutput:
    """
    Phrase similarity with Sentence BERT
    (see the paper "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks").

    Both sides are encoded in a single call of the global ``sen_transformer`` with
    normalised embeddings, so the row-wise dot product is the cosine similarity.
    Returns a boolean vector v, where v[i] = cosine_sim(s1[i], s2[i]) > threshold.
    :param s1: one phrase or a list of phrases
    :param s2: one phrase or a list of phrases
    :param auto_extension: If True, the shorter side is replaced by its last embedding
    repeated to the length of the longer side, otherwise broadcasting does the work.
    :param threshold: similarity above which two phrases count as equivalent
    :param return_scores: attach the similarity tensor as ``scores``
    :param return_embeds: attach the (possibly extended) embeddings of s1 / s2 as
    ``source_embeds`` / ``target_embeds``
    :return: SemanticEqOutput with ``result`` = list of bools
    """
    global sen_transformer

    left = [s1] if isinstance(s1, str) else s1
    right = [s2] if isinstance(s2, str) else s2
    n_left, n_right = len(left), len(right)

    combined = left + right
    embeddings = sen_transformer.encode(combined, batch_size=len(combined), show_progress_bar=False,
                                        normalize_embeddings=True,
                                        convert_to_tensor=True)  # (len(s1)+len(s2), dim)
    left_embs, right_embs = torch.split(embeddings, [n_left, n_right])  # ((len(s1), dim), (len(s2), dim))

    # This differs from broadcasting: broadcasting needs a dimension of size 1, whereas
    # here the two sizes may simply be different.
    if auto_extension:
        if n_left < n_right:
            left_embs = left_embs[-1, :].repeat(n_right, 1)
        elif n_right < n_left:
            right_embs = right_embs[-1, :].repeat(n_left, 1)

    similarity = torch.sum(left_embs * right_embs, dim=-1)  # row-wise dot product
    outcome = SemanticEqOutput(result=(similarity > threshold).tolist())

    if return_scores:
        outcome.scores = similarity

    if return_embeds:
        outcome.source_embeds = left_embs
        outcome.target_embeds = right_embs

    return outcome

In [None]:
# a = bert_sentence_similarity_batched(['reins are on head', 'pane in window', 'knob on door', 'frame around window'], 'baseball cap on wall', return_scores=True)

In [None]:
# a.scores

In [None]:
def semantically_equivalent(target: Union[str, List[str]], source: Union[str, List[str]], criteria=None,
                            criteria_answers=False, **kwargs) \
        -> Union[SemanticEqOutput, List[SemanticEqOutput]]:
    """
    Check whether target is semantically equivalent to source according to the given criteria.

    Supported criteria are "eq" (exact match) and "bert" (sentence-embedding similarity);
    any other criterion evaluates to False. It is better to order the criteria by
    increasing computation cost (e.g ["eq", "bert"]).
    :param target: one phrase or a list of phrases
    :param source: one phrase or a list of phrases
    :param criteria: list of criterion names, defaults to ["eq"]
    :param criteria_answers: if True, every criterion is evaluated and the list of their results is returned
    :param kwargs: forwarded to the "bert" criterion only
    :return: with ``criteria_answers`` the list of all results; otherwise the first truthy
             result, or the result of the last criterion when none is truthy
             (None for an empty criteria list)
    """
    if criteria is None:
        criteria = ['eq']

    def evaluate(criterion, **options):
        if criterion == "eq":
            return batch_equal(target, source)
        if criterion == 'bert':
            return bert_sentence_similarity_batched(target, source, **options)
        # TODO: a batched "synset" criterion is not available yet
        return False

    if criteria_answers:
        return [evaluate(criterion, **kwargs) for criterion in criteria]

    outcome = None
    for criterion in criteria:
        outcome = evaluate(criterion, **kwargs)
        # stop at the first criterion under which everything is equivalent, for efficiency
        if outcome:
            break

    return outcome

In [None]:
def holding_in_scene_graph_batched(scene_graph, subj: str, rel: str, obj: str, debug=True):
    """
    A.K.A "Scene Graph Validator"

    Given a triplet <subj, rel, obj> and a scene graph,
    search the triplet within the graph according to the equivalence criteria
    ("eq" and "bert"); return True if found else False.

    ``subj`` and/or ``obj`` may be None: the searched phrase and the phrases built from the
    graph then leave out that side ("{r} {o}", "{s} {r}" or just "{r}").
    The relationships are processed in batches of the global ``batch_size``.
    :param scene_graph: dict with "relationships" and "objects"
    :param subj: subject name or None
    :param rel: relation
    :param obj: object name or None
    :param debug: log the matching phrases and their scores at DEBUG level
    :return: True as soon as one batch contains an equivalent phrase, otherwise False
    """
    def select_phrase_kind(s, r, o):
      # the shape of the phrase depends on which sides of the *searched* triplet are given
      if subj is not None and obj is not None:
        return f"{s} {r} {o}"
      elif subj is None and obj is not None:
        return f"{r} {o}"
      elif obj is None and subj is not None:
        return f"{s} {r}"
      else:
        return f"{r}"

    global batch_size

    start = timer()

    relations = scene_graph["relationships"]
    objects = scene_graph["objects"]
    searching_phrase = select_phrase_kind(subj, rel, obj)

    for batch in chunks(relations, batch_size):
        # Start from scratch for every batch: accumulating across batches would re-score
        # all earlier phrases again (they are already known not to match).
        phrase_batch = []

        for rel_obj in batch:
            try:
                r_tag = str.lower(rel_obj["predicate"])
                s_tag_id = rel_obj["subject_id"]
                o_tag_id = rel_obj["object_id"]
                s_o = from_object_id_to_string(objects, [s_tag_id, o_tag_id])
                if s_o is not None and len(s_o) == 2:
                    s, o = s_o
                    phrase_batch.append(select_phrase_kind(s, r_tag, o))

            except Exception as e:
                logger.debug(e)
                continue

        # nothing of this batch could be turned into a phrase: nothing to compare
        if len(phrase_batch) == 0:
            continue

        # demand here for "context-based" semantic instead of "context-free" equivalence of relations.
        phrase_criterions = semantically_equivalent(phrase_batch, searching_phrase, criteria=["eq","bert"],
                                                    criteria_answers=True, **{"return_scores": True})

        criteria_results = [sem_out.result for sem_out in phrase_criterions]
        if any([any(res) for res in criteria_results]):
            iden_phrases = extract_similar_words_from_batch(criteria_results, phrase_batch)
            if debug:
              logger.debug([list(zip(sem_out.result, phrase_batch, sem_out.scores.tolist())) for sem_out in phrase_criterions if sem_out.scores is not None])
              logger.debug(
                  f"Found Equally Phrases: {iden_phrases}, similar to : {searching_phrase}, Timing: {timer() - start}s")
            return True

    return False

In [None]:
def is_obj_holding_in_scene_graph_by_model(scene_graph, obj: str):
    """
    Check whether an object name is present in a scene graph, using the
    semantic-equivalence model instead of exact string matching.

    Every scene-graph object contributes one label: its 'name' if present,
    otherwise the first entry of its 'names' list. The labels are compared with
    `obj` in chunks of the module-level `batch_size` through
    `semantically_equivalent` (criteria "eq" and "bert").

    :param scene_graph: scene graph dict with an "objects" list
    :param obj: the object name to search for
    :return: True as soon as any criterion reports a match in a batch, else False
    """

    global batch_size

    start = timer()

    # Objects with neither key, or with an empty 'names' list, carry no label
    # to compare and are skipped (an empty list used to raise IndexError).
    labels = []
    for entry in scene_graph["objects"]:
        if 'name' in entry:
            labels.append(entry['name'])
        elif entry.get('names'):
            labels.append(entry['names'][0])

    for batch in chunks(labels, batch_size):

        phrase_criterions = semantically_equivalent(batch, obj, criteria=["eq", "bert"],
                                                    criteria_answers=True)
        criteria_results = [sem_out.result for sem_out in phrase_criterions]
        if any(any(res) for res in criteria_results):
            iden_phrases = extract_similar_words_from_batch(criteria_results, batch)
            logger.debug(
                f"Found Equally Phrases: {iden_phrases}, similar to : {obj}, Timing: {timer() - start}s")
            return True

    return False

In [None]:
def compute_objects_similarity_word2vec(w1: str, w2:str):
  """
  Return the embedding similarity of two words as reported by the module-level `glove_model`.

  NOTE(review): `glove_model` is presumably a gensim KeyedVectors instance, whose
  `similarity` is a cosine similarity; in that case the value can be negative,
  i.e. it lies in [-1, 1] rather than [0, 1] - confirm against the loaded model.
  An out-of-vocabulary word presumably raises KeyError (callers catch it).
  """
  global glove_model

  similarity = glove_model.similarity(w1, w2)
  return similarity

In [None]:
def are_synsets_equal(synset1: str, synset2: str, threshold=0.5):
  """
  Decide whether two WordNet synset names denote (roughly) the same thing.

  The checks run from cheapest to most expensive:
  1. identical synset names;
  2. the two synsets share at least one lemma name;
  3. the embedding similarity of their first (representative) lemma names
     is at least `threshold`.

  A KeyError raised by the similarity lookup is treated as "not similar".

  :param synset1: synset name, e.g. 'car.n.01'
  :param synset2: synset name
  :param threshold: minimal similarity for the representative lemmas to count as equal
  :return: True if any check succeeds, otherwise False
  """
  if synset1 == synset2:
    return True

  lemmas_a = wn.synset(synset1).lemma_names()
  lemmas_b = wn.synset(synset2).lemma_names()

  # any shared lemma name is enough
  if not set(lemmas_a).isdisjoint(lemmas_b):
    return True

  if len(lemmas_a) == 0 or len(lemmas_b) == 0:
    return False

  # compare only the representatives (first lemma of each synset)
  try:
    similarity = compute_objects_similarity_word2vec(lemmas_a[0], lemmas_b[0])
  except KeyError:
    return False

  if similarity >= threshold:
    return True

  return False

In [None]:
# Smoke test on a presumably unrelated pair; the result depends on the loaded embedding model.
are_synsets_equal('laptop.n.01', 'bag.n.01')

In [None]:
# Smoke test on a specific/general pair (car vs. vehicle).
are_synsets_equal('car.n.01', 'vehicle.n.01')

In [None]:
# Smoke test on a closely related pair (computer vs. laptop).
are_synsets_equal('computer.n.01', 'laptop.n.01')

In [None]:
def is_obj_holding_in_scene_graph_by_synsets(scene_graph, obj: str, obj_synset: str):
    """
    Check whether an object is present in a scene graph by comparing synsets.

    For every scene-graph object that has a non-empty 'synsets' list, only the
    first synset is compared with `obj_synset` (the first one is usually the
    main intention of the region) using `are_synsets_equal`.

    :param scene_graph: scene graph dict with an "objects" list
    :param obj: the object name (currently not used by the comparison)
    :param obj_synset: synset name of the object; when None nothing can match
    :return: True if some object's first synset is considered equal, else False
    """
    for entry in scene_graph["objects"]:
      entry_synsets = entry.get('synsets')

      # nothing to compare without synsets on both sides
      if entry_synsets is None or len(entry_synsets) == 0 or obj_synset is None:
        continue

      if are_synsets_equal(entry_synsets[0], obj_synset):
        return True

    return False

In [None]:
# Pick one scene graph to try the helpers on (presumably keyed by image id); the bare name displays it.
sg_check = imgid_to_scene_graph[2377378]
sg_check

In [None]:
# Does the sample scene graph contain something whose first synset matches 'laptop.n.01'?
is_obj_holding_in_scene_graph_by_synsets(sg_check, 'laptop', 'laptop.n.01')

In [None]:
def compute_similar_sets_by_predicate(elements, predicate):
  """
  General algorithm to group elements into similar sets.

  Two elements end up in the same set when the predicate holds for them
  directly or through a chain of other elements (a~b and b~c puts a, b and c
  together), so the result is well defined even for predicates that are not
  transitive, such as a similarity threshold.

  :param elements: indexable sequence of elements
  :param predicate: callable (ele1, ele2) -> truthy when the two elements are similar;
      it is evaluated once for every unordered pair of elements
  :return: list of lists of elements, one inner list per similar set
  """

  #we assign each element to its own similar set, labelled by its own index
  set_label = {i: i for i in range(len(elements))}

  #for every combination we compute if it is similar
  for i1, i2 in combinations(range(len(elements)), 2):

    if not predicate(elements[i1], elements[i2]):
      continue

    # Merge the sets the two elements CURRENTLY belong to. Using the raw
    # indices here would miss elements that were already merged elsewhere.
    low, big = sorted([set_label[i1], set_label[i2]])
    if low == big:
      continue

    #lowest label is the representive
    for element_num, label in set_label.items():
      if label == big:
        set_label[element_num] = low

  similar_sets2elements = inv_map(set_label) #inverse: set label -> element indices

  return [[elements[i] for i in indices] for indices in similar_sets2elements.values()]

In [None]:
# Quick check with an equality predicate: identical letters are expected to land in the same group.
compute_similar_sets_by_predicate(['a','b','a','b','c'], lambda x,y: x==y)

In [None]:
def are_two_objects_similar(o1, o2):
  """
  Similarity predicate for two scene-graph objects: compares the first
  synset of each object with `are_synsets_equal`.

  Both objects must carry a non-empty 'synsets' list.
  """
  primary_synsets = [scene_object['synsets'][0] for scene_object in (o1, o2)]

  return are_synsets_equal(*primary_synsets)

In [None]:
def are_objects_comparable(o1, o2):
  """Return True when both objects have a non-empty 'synsets' entry, i.e. can be compared by synset."""
  pair = (o1, o2)

  if not all('synsets' in scene_object for scene_object in pair):
    return False

  return all(len(scene_object['synsets']) > 0 for scene_object in pair)

In [None]:
def compute_object_similar_groups(scene_graph):
    """
    Group the objects of a scene graph into lists of similar objects.

    Example of an object from a scene graph:
    {
      "synsets":["apple.n.01"],
      "h":119,"object_id":1023988,
      "names":["apple"],
      "w":126,
      "attributes":["shiny","fresh","whole","uneaten","fruit","red","unpeeled"],
      "y":233,"x":339}

    Only objects with exactly one synset are considered: several synsets mean
    several meanings for the same object, which makes it ambiguous which one
    to compare. Two objects belong to the same group when
    `are_two_objects_similar` holds for them.

    :param scene_graph: scene graph dict with an "objects" list
    :return: list of lists of objects
    """
    single_synset_objects = []
    for candidate in scene_graph["objects"]:
        if 'synsets' in candidate and len(candidate['synsets']) == 1:
            single_synset_objects.append(candidate)

    return compute_similar_sets_by_predicate(single_synset_objects, are_two_objects_similar)

In [None]:
def compute_triplet_similar_groups(scene_graph):
  """
  Group the relationships (triplets) of a scene graph in two ways:
  1. relationships sharing the same (case-insensitive) predicate and the same subject id
  2. relationships sharing the same (case-insensitive) predicate and the same object id

  :param scene_graph: scene graph dict with a "relationships" list
  :return: tuple (groups by subject, groups by object), each a list of lists
  """
  relationships = scene_graph["relationships"]

  def same_predicate_and(id_key):
    # build a pair predicate: equal lower-cased relation and equal value under `id_key`
    def pair_predicate(r1_obj, r2_obj):
      relation_1 = str(r1_obj['predicate']).lower()
      relation_2 = str(r2_obj['predicate']).lower()
      id_1 = r1_obj[id_key]
      id_2 = r2_obj[id_key]
      return relation_1 == relation_2 and id_1 == id_2
    return pair_predicate

  by_subject = compute_similar_sets_by_predicate(relationships, same_predicate_and('subject_id'))
  by_object = compute_similar_sets_by_predicate(relationships, same_predicate_and('object_id'))

  return by_subject, by_object

In [None]:
def is_plausible_pair(n1, n2):
  """
  Check whether (n1, n2) is a plausible pair in Visual Genome, i.e. whether the
  key "n1_n2" exists in the module-level `s_o_2_relations` map.
  e.g. n1=dog, n2=grass: True if "dog_grass" has some relation recorded for it
  (could be "running on", for example).
  """
  pair_key = "{}_{}".format(n1, n2)
  return pair_key in s_o_2_relations


### Augmentation through lemmatization and grammar error correction layer

In [None]:
def fix_grammar(s):
    """
    Run grammar correction on `s` with the module-level `happy_tt` model
    (HuggingFace model name: vennify/t5-base-grammar-correction).

    :param s: text to correct
    :return: list [corrected text, Levenshtein distance between `s` and the corrected text]
    """
    global happy_tt
    settings = TTSettings(num_beams=5, min_length=1)
    corrected = happy_tt.generate_text(f"grammar: {s}", args=settings).text
    return [corrected, lev_dis(s, corrected)]

In [None]:
#to check if the augmentation was used via lemmatization, check aug stats before and after

def augment(q_obj, exist_phrase, aug_phrase, augment_stats=None, inject_det=None, replace_with_det=None):
    """
    Replace `exist_phrase` with `aug_phrase` inside the question text of `q_obj`.

    The phrase can span multiple words. The replacement is first attempted on
    the raw question. If that fails, both the question and the phrase are
    lemmatised with the module-level `nlp` pipeline and the replacement is
    attempted again, so a phrase that appears with a different suffix
    (past/future/plural) can still be matched.

    :param q_obj: question object; its "question" entry holds the question text
    :param exist_phrase: the phrase within the question to search for
    :param aug_phrase: the phrase to replace with
    :param augment_stats: optional mutable counter mapping; its "original", "lemma"
        or "failed" entry is incremented according to the outcome
    :param inject_det: forwarded to search_replace (raw attempt only)
    :param replace_with_det: forwarded to search_replace (raw attempt only);
        defaults to an empty list
    :return: (augmented question string, True) on success;
        (q_obj, False) when the phrase could not be replaced
    """

    if replace_with_det is None:
        replace_with_det = []

    def _count(outcome):
        # `is not None` instead of truthiness: an empty dict/Counter must be updated too
        if augment_stats is not None:
            augment_stats[outcome] = augment_stats.get(outcome, 0) + 1

    q_str = q_obj["question"]

    q_a1 = search_replace(q_str, exist_phrase, aug_phrase, inject_det=inject_det, replace_with_det=replace_with_det)

    if q_a1 is not None:  # non lemma
        logger.debug(f"Successfully augmented from: {q_str} --> {q_a1}")
        _count("original")
        return q_a1, True

    # lemmatisation of the question
    q_lemma = " ".join([x.lemma_ for x in nlp(q_str)])
    exist_phrase_lemma = " ".join([x.lemma_ for x in nlp(exist_phrase)])
    # NOTE(review): the determiner options are not forwarded on this path - confirm this is intended
    q_a2 = search_replace(q_lemma, exist_phrase_lemma, aug_phrase)

    if q_a2 is not None:
        logger.debug(f"Successfully (Through Lemma) augmented from: {q_str} --> {q_a2}")
        _count("lemma")
        return q_a2, True

    logger.debug(
        f"Could not augment: q: {q_obj} \n Found Phrase: [{exist_phrase}] which is not correctly formed within:\n {q_str}, ")
    _count("failed")
    return q_obj, False

### Selecting Relations Methods

In [None]:
def k_most_freq_rand_selection(id, scene_graph, subj: str, rel: str, obj: str, trials=20, top=5):
    """
    Randomly pick a replacement relation among the most frequent relations
    recorded for (subj, obj).

    Relations are ranked by frequency; each trial samples uniformly from the
    first `top` entries, and the sampling window grows by `top` after every
    `top` trials. A sampled relation is accepted when neither the relation nor
    the full phrase is semantically equivalent to the original, and the
    `check_timing(holding_in_scene_graph_batched, ...)` call returns a falsy value.

    NOTE(review): `sg_ds_path` is passed to check_timing ahead of the scene
    graph; check_timing is defined elsewhere - confirm this matches its signature.

    :param id: identifier of the data point, used only in debug log messages
    :param scene_graph: scene graph the new triplet must not hold in
    :param subj: subject of the original triplet
    :param rel: relation of the original triplet
    :param obj: object of the original triplet
    :param trials: maximal number of sampling attempts
    :param top: size (and growth step) of the sampling window
    :return: (relation, frequency) of the accepted candidate, or None
    """
    relation_freqs = compute_dist_over_relations(subj, obj, normalised=True)
    ranked = sorted(relation_freqs.items(), key=lambda item: item[1], reverse=True)

    if len(ranked) == 0:
        logger.debug(f"Empty Freq dist")
        return None

    orig_phrase = f"{subj} {rel} {obj}"
    window = top

    for attempt in range(trials):
        # each top cycle we increase the possibilities to sample from
        if attempt > 0 and attempt % top == 0:
            window += top

        pick = np.random.randint(0, min(len(ranked), window))
        aug_rel, _ = ranked[pick]
        aug_phrase = f"{subj} {aug_rel} {obj}"

        if semantically_equivalent(aug_rel, rel, criteria=['eq', 'bert']):
            continue
        if semantically_equivalent(aug_phrase, orig_phrase, criteria=['eq', 'bert']):
            continue
        if check_timing(holding_in_scene_graph_batched, sg_ds_path, scene_graph, subj, aug_rel, obj):
            continue

        return ranked[pick]  # exit-point

    logger.debug(f"Process of choosing relations failed for data point id: {id}")
    logger.debug(f"Freq dist: {json.dumps(ranked)}")

In [None]:
def choose_by_entailment(id, scene_graph, subj: str, rel: str, obj: str, random_candidates=True):
    """
    Choose a relation by utilizing NLI classifier, predicting each relation candidate such that:
    premise = relation presupposition, hypothesis = candidate presupposition.
    Filter any entailed presuppositions and return the first candidate which
    is not semantic similar to the premise and not found in the scene graph.
    :param id: identifier of the data point, used only in debug log messages
    :param scene_graph: scene graph the chosen triplet must NOT hold in
    :param subj: subject of the original triplet
    :param rel: relation of the original triplet
    :param obj: object of the original triplet
    :param random_candidates: if True the candidates are tried in uniformly random order,
        otherwise the order is sampled with softmax(similarity score) weights
    :return: (relation, nli_label, similarity_score) of the first suitable candidate;
        None (implicitly) when no candidate is found
    """

    # candidate relations come from the relations recorded for this (subj, obj) pair
    freq_dist = compute_dist_over_relations(subj, obj, normalised=True)

    if len(freq_dist) > 0:
        rels = list(freq_dist.keys())
        presupp_cands = [f"{subj} {r} {obj}" for r in rels]
        presupp = f"{subj} {rel} {obj}"

        # the original phrase is the premise of every (premise, candidate) pair
        nli_labels = classify_nli([presupp] * len(presupp_cands), presupp_cands)  # size=len(cands)
        cands = [(x, nli_labels[i]) for i, x in enumerate(presupp_cands) if
                 nli_labels[i] in [0, 1]]  # allow contradiction and neutral

        sem_eq_input = [x[0] for x in cands]  # take only presuppositions
        if len(sem_eq_input) > 0:
            output = semantically_equivalent([presupp], sem_eq_input, criteria=['eq', 'bert'],
                                             **{"return_scores": True})
            # NOTE(review): this branch relies on the truthiness of the object returned by
            # semantically_equivalent, which is defined outside this cell - confirm its rule.
            if not output:  # not all semantic equivalent (= exist a member which is semantic equivalent)
                sim_scores = output.scores.tolist()
                # to avoid any neutral which can be understood through accommodation
                is_eq = output.result
                # keep only candidates not flagged as equivalent; each becomes [phrase, nli_label, score]
                cands = [[*x, sim_scores[i]] for i, x in enumerate(cands) if not is_eq[i]]

                # select according to the dist created by soft-max on all candidate scores
                # the higher your score the more chance you will get selected.
                # or just select randomly

                p = None if random_candidates else softmax([x[-1] for x in cands])
                
                # a random permutation of the candidate indices (sampling without replacement)
                for i in np.random.choice(len(cands), len(cands), p=p, replace=False):
                    cand, nli_label, score = cands[i]
                    split = cand.split()
                    # NOTE(review): dropping the first and last token recovers the relation
                    # only if subj and obj are single words - confirm for multi-word nouns
                    r = " ".join(split[1:-1])
                    if not holding_in_scene_graph_batched(scene_graph, subj, r, obj):
                        return r, nli_label, score

                logger.debug(f"Process of choosing relations failed for data point id: {id}")
                logger.debug(f"Freq dist: {json.dumps(list(freq_dist.keys()))}")
            else:
                logger.debug(f"There is a triplet in which is semantic equivalent to: {presupp}")
                logger.debug(
                    f"Presupp: {presupp}, Cands: {[[*x, output.scores.tolist()[i]] for i, x in enumerate(cands) if output.scores is not None]}")
        else:
            logger.debug(f"Presupp: {presupp}, All cands: {presupp_cands} classified as entailed.")
    else:
        logger.debug(f"Empty Freq dist for subj={subj}, obj={obj}")

In [None]:
def choose_relation(method: str, *args, **kwargs):
    """
    Factory method bridging to the relation choosing strategies.

    :param method: "k_most_freq_rand_selection" or "entailment"
    :param args: positional arguments forwarded to the chosen strategy
    :param kwargs: keyword arguments forwarded to the chosen strategy
    :return: whatever the chosen strategy returns; None for an unknown method
    """
    if method not in ("k_most_freq_rand_selection", "entailment"):
        return None

    strategy = k_most_freq_rand_selection if method == "k_most_freq_rand_selection" else choose_by_entailment
    return strategy(*args, **kwargs)

### Generating questions

#### Generating Maps

In [None]:
def r_s_or_o_2_relations(s_o_2_relations):
  """
  Generate two maps from the extracted "subject_object -> relations" occurrence map:
  1. "relation_object" -> subjects
  2. "subject_relation" -> objects

  Every key of the input must split into exactly two parts on '_'.

  :param s_o_2_relations: e.g. {"shade_sidewalk": [["on", 1], ["on top of", 1]]}
  :return: tuple (relation_object -> list of subjects, subject_relation -> list of objects)
  """
  subjects_by_rel_obj = {}
  objects_by_subj_rel = {}

  for s_o, relations_occs in s_o_2_relations.items():
    # only the relation names matter here, not their occurrence counts
    rels = set([occ[0] for occ in relations_occs])

    subj, obj = s_o.split('_')

    for rel in rels:
      subjects_by_rel_obj.setdefault(f"{rel}_{obj}", set()).add(subj)
      objects_by_subj_rel.setdefault(f"{subj}_{rel}", set()).add(obj)

  rel_obj_2_subjects = {key: list(members) for key, members in subjects_by_rel_obj.items()}
  subj_rel_2_objects = {key: list(members) for key, members in objects_by_subj_rel.items()}

  return rel_obj_2_subjects, subj_rel_2_objects

Here we will populate helpful maps:
1. relation_object => subjects = [s1,s2,s3,..]
2. subject_relation => objects = [o1,o2,o3,...]

Additionally, their keys as lists (used for sampling):

r_o = list of 1st map keys
s_r = list of 2nd map keys

In [None]:
# Derive the two lookup maps from the global "subject_object -> relations" map,
# and keep their keys as lists so they can be shuffled and scanned when sampling.
r_o_2_subjects, s_r_2_objects = r_s_or_o_2_relations(s_o_2_relations)
s_r = list(s_r_2_objects.keys())
r_o = list(r_o_2_subjects.keys())

In [None]:
# Number of distinct "subject_relation" and "relation_object" keys.
len(s_r), len(r_o)

In [None]:
# def check_shuffle_timing():
#   start = timer()
#   np.random.shuffle(s_r) # in-place, O(n)
#   print(f"timing: {timer() - start}")

# check_shuffle_timing()

In [None]:
def search_plausible_relation_for_object(obj: str):
  """
  Return a random relation that appears with `obj` as its object, or None.

  Side effect: shuffles the module-level `r_o` key list in place.

  :param obj: object noun to look up
  :return: a relation string, or None if no "relation_object" key ends with `obj`
  """
  np.random.shuffle(r_o)
  for pair in r_o:
    # keys are "relation_object": split on the LAST '_' only, so a relation
    # that itself contains '_' does not break the unpacking
    r, o = pair.rsplit('_', 1)
    if o == obj:
      return r

  return None

def search_plausible_relation_for_subject(obj: str):
  """
  Return a random relation that appears with the given noun as its subject, or None.

  Side effect: shuffles the module-level `s_r` key list in place.

  :param obj: the SUBJECT noun to look up (the parameter name is historical)
  :return: a relation string, or None if no "subject_relation" key starts with the noun
  """
  np.random.shuffle(s_r)
  for pair in s_r:
    # keys are "subject_relation": split on the FIRST '_' only, so a relation
    # that itself contains '_' does not break the unpacking
    s, r = pair.split('_', 1)
    if s == obj:
      return r

  return None

In [None]:
def collect_relations_for_object(obj: str):
  """
  Collect every relation recorded with `obj` as its object.

  :param obj: object noun to look up
  :return: list of relations, in the current order of the module-level `r_o` list
  """
  # keys are "relation_object": split on the LAST '_' only, so a relation
  # that itself contains '_' does not break the unpacking
  split_pairs = (pair.rsplit('_', 1) for pair in r_o)
  return [r for r, o in split_pairs if o == obj]


def collect_relations_for_subject(subj: str):
  """
  Collect every relation recorded with `subj` as its subject.

  :param subj: subject noun to look up
  :return: list of relations, in the current order of the module-level `s_r` list
  """
  # keys are "subject_relation": split on the FIRST '_' only, so a relation
  # that itself contains '_' does not break the unpacking
  split_pairs = (pair.split('_', 1) for pair in s_r)
  return [r for s, r in split_pairs if s == subj]


#### Sampling Methods

In [None]:
def sample_embedded_relation_randomly(scene_graph, subj: str, obj: str):
    """
    Given subject and object, sample a relation embedded between them
    from the plausible relations of VG and validate that the resulting
    triplet isn't holding in the scene graph.

    :param scene_graph: scene graph the sampled triplet must not hold in
    :param subj: subject noun
    :param obj: object noun
    :return: the first relation, in random order, whose triplet does not hold
        in the scene graph; None if there is no such relation
    """
    relation_freqs = compute_dist_over_relations(subj, obj, normalised=True)

    if len(relation_freqs) == 0:
        logger.debug(f"Empty Freq dist for subj={subj}, obj={obj}")
        return None

    candidates = list(relation_freqs.keys())
    np.random.shuffle(candidates) # we choose randomly

    for candidate in candidates:
        if holding_in_scene_graph_batched(scene_graph, subj, candidate, obj, debug=False):
            continue
        return candidate

    logger.debug(f"Process of sample_embedded_relation_randomly failed for subj={subj}, obj={obj}")
    return None

In [None]:
def sample_noun_and_validate_triplets(scene_graph, triplet, subj=True):
  """
  Sample a replacement subject or object for a triplet from the Visual Genome
  maps and validate that the resulting triplet does not hold in the scene graph.

  :param scene_graph: scene graph the new triplet must not hold in
  :param triplet: (subject, relation, object)
  :param subj: switch - True samples a new subject for (relation, object),
      False samples a new object for (subject, relation)
  :return: the sampled noun; None if the key is not in the map or no
      candidate was found that does not hold in the scene graph
  """
  # Unpack into names that do NOT shadow the `subj` switch: shadowing it made
  # the function always sample a subject, even when called with subj=False.
  s, r, o = triplet

  if subj:
    candidates_map, key = r_o_2_subjects, f"{r}_{o}"
  else:
    candidates_map, key = s_r_2_objects, f"{s}_{r}"

  if key not in candidates_map:
    return None

  # shuffle a copy so the shared map keeps its order
  cands = list(candidates_map[key])
  np.random.shuffle(cands)

  for cand in cands:
    if subj:
      holds = holding_in_scene_graph_batched(scene_graph, cand, r, o)
    else:
      holds = holding_in_scene_graph_batched(scene_graph, s, r, cand)

    # a candidate that is not holding in the scene graph will trigger the failure
    if not holds:
      return cand

  return None


In [None]:
def sample_plausible_pair(scene_graph, subj=True):
  """
  Sample a plausible pair of a noun and a relation from all Visual Genome pairs,
  such that the noun has a known synset and does NOT hold in the scene graph.

  Side effect: shuffles the module-level `s_r` or `r_o` key list in place.

  :param scene_graph: scene graph the sampled noun must be absent from
  :param subj: True samples a "subject_relation" key, False a "relation_object" key
  :return: the sampled key string, or None if no key qualifies
  """
  global synset_mapping

  pairs_str = s_r if subj else r_o

  np.random.shuffle(pairs_str) # in-place, O(n)

  for pair in pairs_str:
    # split at the noun's side only, so a relation that itself contains '_'
    # does not break the unpacking
    if subj:
      noun, rel = pair.split('_', 1)
    else:
      rel, noun = pair.rsplit('_', 1)

    noun_synset = synset_mapping.get(noun)
    if noun_synset is None:
      continue

    if not is_obj_holding_in_scene_graph_by_synsets(scene_graph, noun, noun_synset):
      return pair

  return None

In [None]:
def sample_plausible_pair_from_scene_graph(scene_graph):
  """
  Walk over the triplets of the scene graph and look for
  - a (subject, relation') pair and
  - a (relation', object) pair
  whose noun comes from the scene graph, whose relation is plausible for that
  noun, and which is NOT holding in the scene graph.

  e.g
  (s,r,o) => (s,r') => check over scene-graph (s,r') ~! (s,r'')
  (s,r,o) => (r', o) => check over scene-graph (r', o) ~! (r'', o)

  :param scene_graph: scene graph to sample from and validate against
  :return: tuple (subj_pair, obj_pair); each is a tuple or None when nothing was found
  """
  subj_pair = None
  obj_pair = None

  for s, _, o in extract_triplets_from_scene_graph(scene_graph):

    #subj - the missing object slot is passed as None
    if subj_pair is None:
      for candidate in collect_relations_for_subject(s):
        if not holding_in_scene_graph_batched(scene_graph, s, candidate, None):
          subj_pair = (s, candidate)
          break

    #obj - the missing subject slot is passed as None
    if obj_pair is None:
      for candidate in collect_relations_for_object(o):
        if not holding_in_scene_graph_batched(scene_graph, None, candidate, o):
          obj_pair = (candidate, o)
          break

    # stop as soon as both kinds of pair were found
    if not (subj_pair is None or obj_pair is None):
      break

  return subj_pair, obj_pair

#### Types generation

##### Type 1
Definitive reference to a non-existent object failure.

From question:

1. (man, to the left of, pedestrians)
  "What is the man to the left of the pedestrians wearing?" 
  -> (woman, to the left of, pedestrians)
  "What is the woman to the left of the pedestrians wearing?"

Pre-Defined:


1. Who is [REL] the [OBJ]? - (“Who is wearing the jacket?”) -> 

  We sample with respect to two conditions:

  [VALIDITY]: from the plausible relationships (in VG) sample “wearing_o”. 

  [FAILURE]: check if “o” in scene graph.                          

  **Warning**: if we check if anything is wearing an object in the scene graph and the object does exist in the scene, then the type of failure changes to “nothing is wearing an object” instead of ‘non-existent object’.

2. What is the [SUBJ] [REL]? - (“What is the man throwing?”) ->

  [VALIDITY]: from the plausible relationships (in VG) sample “s_wearing”. 
  [FAILURE]: check if “s” in scene graph.


3. Where is the [SUBJ/OBJ]? - sample object that doesn’t hold in scene graph (“Where is the jacket?”)

In [None]:
def gen_type_1_from_questions(q_obj, id, triplet, aug_stats, subj_has_multiple_presuppositions=False):
  """
  Generate Type 1 questions (reference to a non-existent object) from an existing question.

  The subject of the triplet, and then its object, is replaced in the question
  text by a noun sampled with `sample_noun_and_validate_triplets`; the augmented
  text is then passed through grammar correction.

  :param q_obj: question object; "question" and "imageId" entries are read
  :param id: id of the original question, stored as origin_id
  :param triplet: (subject, relation, object) extracted from the question
  :param aug_stats: augmentation counters; its 'lemma' entry is read before and after
      `augment` to detect whether the lemmatised path was used
  :param subj_has_multiple_presuppositions: when truthy the subject is not augmented
  :return: list of AugmentQuestion objects (subject variant first), possibly empty
  """
  from_questions = []

  scene_graph = extract_scene_graph_from_question(q_obj)
  if not scene_graph:
    return from_questions

  # subject first, then object; the subject is skipped when it is flagged
  roles = ([] if subj_has_multiple_presuppositions else [True]) + [False]

  for is_subj in roles:
    cand = sample_noun_and_validate_triplets(scene_graph, triplet, subj=is_subj)
    if cand is None:
      continue

    s, r, o = triplet
    exist_phrase = s if is_subj else o

    lemmas_before = aug_stats['lemma']
    q_new, is_suc = augment(q_obj, exist_phrase, cand, aug_stats, inject_det="the", replace_with_det=['a'])
    if not is_suc:
      continue

    q_new_gec, lev_dist = fix_grammar(q_new)

    # NOTE(review): `q` receives the grammar-corrected text and `q_gec` the uncorrected one - confirm the naming
    from_questions.append(AugmentQuestion(
        q=q_new_gec,
        failure_type=1,
        failure_reason=f"No {cand} in image",
        origin_id=id,
        image_id=q_obj['imageId'],
        generation_type="from_questions",
        q_gec=q_new,
        gec_dist=lev_dist,
        is_lemmatised_aug=(aug_stats['lemma'] - lemmas_before == 1),
        properties={"is_subj_aug": is_subj, "triplet_sampled": (cand, r, o) if is_subj else (s, r, cand)}
    ))

  return from_questions
  


In [None]:
def gen_type_1_pre_defined(scene_graph):
  """
  Build the pre-defined Type 1 questions (reference to a non-existent object)
  for one scene graph.

  Two pairs are sampled with `sample_plausible_pair`: a "relation_object" pair
  and a "subject_relation" pair. Every sampled pair yields a relation question
  followed by a "Where is the <noun>?" question, so up to four questions are
  returned.

  :param scene_graph: scene graph dict; its "image_id" is used as both origin and image id
  :return: list of AugmentQuestion objects, possibly empty
  """
  global synset_mapping

  pre_defined = []

  image_id = scene_graph["image_id"]

  def where_question(noun, noun_synset, is_subj):
    #pre-defined 3
    return AugmentQuestion(
        q=f"Where is the {noun}?",
        failure_type=1,
        failure_sub_type=3,
        failure_reason=f"No {noun} in image",
        origin_id=image_id,
        image_id=image_id,
        generation_type="pre_defined",
        properties={"is_subj_aug": is_subj, "noun_sampled": noun, "noun_synset": noun_synset}
    )

  #pre-defined 1
  pair_str = sample_plausible_pair(scene_graph, subj=False)
  if pair_str:
    # split on the last '_' only, so a relation that itself contains '_' does not break the unpacking
    r, o = pair_str.rsplit('_', 1)

    #we know it has synset since it is pre-condition for "sample_plausible_pair"
    noun_synset = synset_mapping[o]

    pp_r = transform_present_progressive(r)

    pre_defined.append(AugmentQuestion(
        q=f"Who is {pp_r} the {o}?",
        failure_type=1,
        failure_sub_type=1,
        failure_reason=f"No {o} in image",
        origin_id=image_id,
        image_id=image_id,
        generation_type="pre_defined",
        properties={"is_subj_aug":False, "pair_sampled": (r, o), "noun_synset": noun_synset}
        ))
    pre_defined.append(where_question(o, noun_synset, False))

  #pre-defined 2
  pair_str = sample_plausible_pair(scene_graph, subj=True)
  if pair_str:
    # split on the first '_' only, for the same reason
    s, r = pair_str.split('_', 1)

    #we know it has synset since it is pre-condition for "sample_plausible_pair"
    noun_synset = synset_mapping[s]

    pre_defined.append(AugmentQuestion(
        q=f"What is the {s} {r}?",
        failure_type=1,
        failure_sub_type=2,
        failure_reason=f"No {s} in image",
        origin_id=image_id,
        image_id=image_id,
        generation_type="pre_defined",
        properties={"is_subj_aug":True, "pair_sampled": (s, r), "noun_synset": noun_synset}
        ))
    pre_defined.append(where_question(s, noun_synset, True))

  return pre_defined


##### Type 2

Definite singular reference to an object when the scene contains multiple instances of this object 

From question:
 1. "What is the man to the left of the pedestrians wearing?" -> "What is the woman to the left of the pedestrians wearing?" or 
  "What is the man to the left of the tunnel wearing?", where there are men or tunnels in the scene.

Pre-Defined:
  1. Where is the [SUBJ/OBJ]? Search for an object in the scene that has multiple appearances and refer to it.
  2. What is the [SUBJ] [REL]? additionally search for that object, a plausible relation.
  3. Who is [REL] the [OBJ]?

Notes for implementation:

- for pre-defined we must filter non-countable nouns

- for from_question the problem is that the triplet might resolve the failure.
e.g. if we have multiple women in the scene and we refer to "woman to the left of the pedestrians",
the description becomes more specific and can potentially resolve the ambiguity.
Another example - the triplet is (man, to the left of, pedestrians): we first search the scene for another object or subject that holds this relation and, if there is one, count the number of its appearances.
However, if we count it in general, this can be misleading, since the subject or object can appear multiple times but not under the same relation.
e.g. triplets from the scene graph: (woman, to the left of, pedestrians), (woman, to the right of, pedestrians), (woman, wearing, pants)
here we have multiple women in the scene, but not to the left of the pedestrians.

Solution: count objects/subjects that only appear within a certain triplet.
  e.g. (woman, to the left of, pedestrians)
  

In [None]:
def select_representive(similar_set):
  """
  Choose the representative name of a similar set: the most frequent name among
  the names of all its objects, returned in singular form when
  `inflect_eng.singular_noun` yields one.

  Another possibility could be the majority from the canonical names
  (e.g trousers.n.01 => trousers) of the representative synset.

  :param similar_set: non-empty list of scene-graph objects that carry names
  :return: the representative noun as a string
  """
  #majority over names
  name_counts = Counter(name for member in similar_set for name in extract_names_from_VG_obj(member))
  representive = name_counts.most_common(1)[0][0]

  # singular_noun gives a non-string value when no singular form is produced
  singular = inflect_eng.singular_noun(representive)
  if isinstance(singular, str):
    return singular

  return representive

In [None]:
def gen_type_2_from_questions(q_obj, id, triplet, aug_stats):
  """
  Placeholder for generating Type 2 questions from existing questions.

  Not implemented for now, due to its complexity and the belief that very few
  examples can be generated this way; always returns an empty list.
  """
  no_questions = []
  return no_questions

In [None]:
def gen_type_2_pre_defined(scene_graph):
  """
  Build the pre-defined Type 2 questions (singular reference to an object that
  appears multiple times) for one scene graph.

  The similar-object groups of the scene graph are scanned for a countable
  representative noun with more than one appearance after merging duplicates.
  From such groups up to three questions are built: "Where is the <noun>?",
  "What is the <noun> <relation>?" and "Who is <relation> the <noun>?".

  :param scene_graph: scene graph dict; its "image_id" is used as both origin and image id
  :return: list of AugmentQuestion objects, possibly empty
  """

  pre_defined = []

  image_id = scene_graph["image_id"]

  #list of lists
  similar_objects = compute_object_similar_groups(scene_graph)

  obj = None
  rel_obj = None
  subj_rel = None

  for similar_set in similar_objects:
    if obj and rel_obj and subj_rel:
      break

    rep = select_representive(similar_set)

    squeezed_set = merge_same_objects_in_similar_set(similar_set, delta=7)

    if not is_mass_noun(rep) and len(squeezed_set) > 1: # number of appearances > 1

      if obj is None:
        obj = {"candidate": rep, "similar_set": similar_set}

      # Pairs are kept as tuples: joining them with '_' and splitting again
      # breaks as soon as the noun or the relation contains '_'.
      if subj_rel is None:
        rel = search_plausible_relation_for_subject(rep)
        if rel:
          subj_rel = {"candidate": (rep, rel), "similar_set": similar_set}

      if rel_obj is None:
        rel = search_plausible_relation_for_object(rep)
        if rel:
          rel_obj = {"candidate": (rel, rep), "similar_set": similar_set}

  #pre-defined 1
  if obj:
    o = obj['candidate']

    q = AugmentQuestion(
        q=f"Where is the {o}?",
        failure_type=2,
        failure_sub_type=1,
        failure_reason=f"There are multiple {o} in image",
        origin_id=image_id,
        image_id=image_id,
        generation_type="pre_defined",
        properties={"noun_sampled": o ,"similar_set":obj['similar_set']}
    )
    pre_defined.append(q)


  #pre-defined 2
  #e.g (knife, on top of) -> "What is the knife on top of?"
  if subj_rel:
    s, r = subj_rel['candidate']

    q = AugmentQuestion(
        q=f"What is the {s} {r}?",
        failure_type=2,
        failure_sub_type=2,
        failure_reason=f"There are multiple {s} in image",
        origin_id=image_id,
        image_id=image_id,
        generation_type="pre_defined",
        properties={"is_subj_aug": True, "pair_sampled": (s, r), "similar_set": subj_rel['similar_set']}
        )
    pre_defined.append(q)


  #pre-defined 3
  #(on top of, desk) -> "Who is on top of the desk?"
  if rel_obj:
    r, o = rel_obj['candidate']

    q = AugmentQuestion(
        q=f"Who is {r} the {o}?",
        failure_type=2,
        failure_sub_type=3,
        failure_reason=f"There are multiple {o} in image",
        origin_id=image_id,
        image_id=image_id,
        generation_type="pre_defined",
        properties={"is_subj_aug": False, "pair_sampled": (r, o), "similar_set": rel_obj['similar_set']}
        )
    pre_defined.append(q)

  return pre_defined

##### Type 3

Definite plural reference to an object when the scene contains only one instance   

From question:
1. "What is the man to the left of the pedestrians wearing?" -> 
"What are the men to the left of the pedestrians wearing?", where there is only one man in the scene.
 
Pre-Defined:
1. Where is the [SUBJ/OBJ]?
2. What is the [SUBJ] [REL]?
3. Who is [REL] the [OBJ]?

Notes for implementation:

- pre-defined - we must filter non-countable nouns

- from_question:
  If we have a triplet from the question we need to sample an object from the scene
  that has only one occurrence (in total) in the scene.
  However, the sampled object might not hold in the triplet in the image.
  For example:
  we have (man, to the left of, pedestrians) and we select 'woman' because it appeared only once
  in the scene; although there is only one woman in the scene, it might not be to the left of the pedestrians.
  
  Solution: count objects/subjects that only appear within a certain triplet.
  e.g. (woman, to the left of, pedestrians)

In [None]:
def gen_type_3_from_questions(q_obj, id, triplet, aug_stats):
  """
  Type 3 generation from an existing question - intentionally not implemented.

  The implementation is skipped for now because of its complexity and the belief
  that very few examples could be generated this way. All arguments are ignored.

  :return: always a new empty list
  """
  no_questions = []
  return no_questions

In [None]:
def gen_type_3_pre_defined(scene_graph):
  """
  Generate the pre-defined Type 3 questions: a plural reference to a noun that
  has a single instance in the scene.

  The similar-object groups of the scene graph are scanned for groups of size 1
  whose representative is not a mass noun. The first matches are used to fill
  (up to) three templates:
    1. "Where are the [OBJ-plural]?"
    2. "What are the [SUBJ-plural] [REL]?"
    3. "Who is [REL] the [OBJ-plural]?"

  :param scene_graph: scene graph object; "image_id" is read here, the rest is
                      consumed by the helper functions
  :return: list of AugmentQuestion (0 to 3 items)
  """

  pre_defined = []

  id = scene_graph["image_id"]

  #list of lists
  similar_objects = compute_object_similar_groups(scene_graph)

  obj = None
  # The sampled pairs are kept as tuples. Joining them into one "a_b" string and
  # splitting it back later breaks as soon as a name contains an underscore.
  subj_rel = None  # (subject, relation)
  rel_obj = None   # (relation, object)

  for similar_set in similar_objects:
    if obj and rel_obj and subj_rel:
      break  # all three templates already have a candidate

    rep = select_representive(similar_set)

    # only countable nouns with exactly one instance are useful here
    if is_mass_noun(rep) or len(similar_set) != 1:
      continue

    #not needed to save similar set since it has length of 1
    if obj is None:
      obj = rep

    if subj_rel is None:
      rel = search_plausible_relation_for_subject(rep)
      if rel is not None:
        subj_rel = (rep, rel)

    if rel_obj is None:
      rel = search_plausible_relation_for_object(rep)
      if rel is not None:
        rel_obj = (rel, rep)

  #pre-defined 1
  if obj:
    obj_plural = inflect_eng.plural_noun(obj)
    q = AugmentQuestion(
        q=f"Where are the {obj_plural}?",
        failure_type=3,
        failure_sub_type=1,
        failure_reason=f"There is only one {obj} in image",
        origin_id=id,
        image_id=id,
        generation_type="pre_defined",
        properties={"noun_sampled": obj}
    )
    pre_defined.append(q)

  #pre-defined 2
  #e.g (knives, on top of the) -> What are the knives on top of ?
  if subj_rel:
    s, r = subj_rel
    s_plural = inflect_eng.plural_noun(s)

    q = AugmentQuestion(
        q=f"What are the {s_plural} {r}?",
        failure_type=3,
        failure_sub_type=2,
        failure_reason=f"There is only one {s} in image",
        origin_id=id,
        image_id=id,
        generation_type="pre_defined",
        properties={"is_subj_aug": True, "pair_sampled": (s, r)}
        )
    pre_defined.append(q)

  #pre-defined 3
  #e.g (on top of, desks) -> "Who is on top of the desks?"
  if rel_obj:
    r, o = rel_obj
    o_plural = inflect_eng.plural_noun(o)
    q = AugmentQuestion(
        q=f"Who is {r} the {o_plural}?",
        failure_type=3,
        failure_sub_type=3,
        failure_reason=f"There is only one {o} in image",
        origin_id=id,
        image_id=id,
        generation_type="pre_defined",
        properties={"is_subj_aug": False, "pair_sampled": (r, o)}
        )
    pre_defined.append(q)

  return pre_defined

##### Type 4

Definite complex reference to an object that exists in the scene but the instance does not hold in the relation expressed in the description.  
From questions:  
1. "What is the man to the left of the pedestrians wearing?" -> "What is the man to the right of the pedestrians wearing?"  

Pre-Defined:  
1. Where is the [SUBJ] [REL] the [OBJ]? - sample a triplet from the scene graph and substitute the relation, e.g. (man, wearing, jacket) -> (“Where is the man throwing the jacket?”). The subject and the object are swapped when the swapped pair is a plausible pair in Visual Genome.
2. What does the [S1] [R1] the [O1] [R2]? - get compositional triplets e.g [(s1,r1,o1), (s1,r2,o2)] from the scene graph and choose new relation to trigger triplet failure 
3. What is the [SUBJ] [REL]? - sample triplet from the scene graph, substitute relation e.g (man,wearing,jacket) -> ("What is the man throwing?")
4. Who is [REL] the [OBJ]? - sample triplet from the scene graph, substitute relation e.g (man,wearing,jacket) -> ("Who is throwing the jacket?")


In [None]:
def gen_type_4_from_questions(q_obj, id, triplet, aug_stats):
  """
  Generate a Type 4 question out of an existing question by replacing the
  relation of its presupposition triplet.

  :param q_obj: question object; its 'imageId' is stored on the new question
  :param id: id of the original question (stored as origin_id)
  :param triplet: (subject, relation, object) extracted from the question
  :param aug_stats: augmentation counters; 'lemma' is read before and after
                    augment() to detect a lemmatised augmentation
  :return: list holding the single generated AugmentQuestion, or an empty list
           when any step (scene graph, relation choice, augmentation) fails
  """
  subj, rel, obj = triplet

  scene_graph = extract_scene_graph_from_question(q_obj)
  if scene_graph is None:
    return []

  package = choose_relation("entailment", id, scene_graph, subj, rel, obj)
  if package is None:
    return []

  # only the chosen relation is used; the other two values are ignored here
  aug_rel, _nli_label, _sim_score = package

  lemma_count_before = aug_stats['lemma']
  q_new, is_suc = augment(q_obj, rel, aug_rel, aug_stats)
  if not is_suc:
    return []

  # NOTE(review): the distance returned by fix_grammar is dropped, and the
  # corrected text goes to `q` while the uncorrected one goes to `q_gec` -
  # confirm this is the intended assignment.
  q_new_gec, _lev_dist = fix_grammar(q_new)

  was_lemmatised = (aug_stats['lemma'] - lemma_count_before == 1)

  question = AugmentQuestion(
      q=q_new_gec,
      failure_type=4,
      failure_reason=f"No {subj} {aug_rel} {obj} in image",
      q_gec=q_new,
      is_lemmatised_aug=was_lemmatised,
      origin_id=id,
      image_id=q_obj['imageId'],
      generation_type="from_questions",
      properties={"triplet_sampled": (subj, aug_rel, obj)}
  )

  return [question]


In [None]:
def compute_compositional_pair_iterator(pair_lst):
  """
  Build the list of ((rel, obj), (rel, obj)) combinations of pair_lst while
  avoiding combinations that would lead to duplicate questions.

  Every pair x is matched with the other pairs in order, but a partner is kept
  only if its relation differs from x's relation and from the relation of every
  partner already kept for x.

  e.g man -> [(carries, backpack), (wears, sneakers), (wears, backpack), (at, crosswalk)]
  Pairing (carries, backpack) with both (wears, sneakers) and (wears, backpack)
  would, for the same augmented relation (say "opening"), produce the same question twice:
  "What does the man opening the backpack wears?"
  so only the first "wears" partner is kept.

  :param pair_lst: iterable of (relation, object) pairs of one subject
  :return: list of (pair, partner) tuples
  """
  combs = []

  for anchor_idx, anchor in enumerate(pair_lst):
    # relations that may no longer be used as a partner of this anchor
    taken_rels = [anchor[0]]

    for partner_idx, partner in enumerate(pair_lst):
      if partner_idx == anchor_idx:
        continue
      if partner[0] in taken_rels:
        continue

      taken_rels.append(partner[0])
      combs.append((anchor, partner))

  return combs

In [None]:
# Sanity check with the (relation, object) pairs of a single subject: each pair
# is combined with at most one partner per distinct relation, and never with a
# partner that shares its own relation.
compute_compositional_pair_iterator([('carries', 'backpack'), ('wears', 'sneakers'), ('wears', 'backpack'), ('at', 'crosswalk')])

In [None]:
def gen_type_4_pre_defined(scene_graph):
  """
  Generate the pre-defined Type 4 questions for one scene graph: questions about
  a (subject, relation, object) combination that is not taken from the image.

  Sub-types, in the order they are produced:
    1. "Where is the [SUBJ] [REL] the [OBJ]?" - subject and object of a scene
       triplet are swapped (if is_plausible_pair accepts the swap) and joined by
       a relation from sample_embedded_relation_randomly.
    3. "What is the [SUBJ] [REL]?" and
    4. "Who is [REL] the [OBJ]?" - built from the two pairs returned by
       sample_plausible_pair_from_scene_graph.
    2. "What does the [S1] [R1'] the [O1] [R2]?" - two triplets sharing a
       subject, where R1 is replaced by the relation returned by choose_relation.

  :param scene_graph: scene graph object; "image_id" is read here, the rest is
                      consumed by the helper functions
  :return: list of AugmentQuestion
  """

  pre_defined = []

  id = scene_graph["image_id"]

  #pre-defined
  scene_graph_triplets = extract_triplets_from_scene_graph(scene_graph)

  np.random.shuffle(scene_graph_triplets)

  logger.debug(f"starting pre-defind 1")
  subs_triplet = None
  for s, r, o in scene_graph_triplets:

    if is_plausible_pair(o, s): #subject and object substitution

      aug_rel = sample_embedded_relation_randomly(scene_graph, o, s)

      if aug_rel is not None:
        subs_triplet = (o, aug_rel, s)
        # Only one substituted triplet is used, so stop at the first hit instead
        # of sampling a relation for every remaining plausible triplet.
        break

  if subs_triplet:
    s, r, o = subs_triplet
    #pre-defined 1
    q = AugmentQuestion(
        q=f"Where is the {s} {r} the {o}?",
        failure_type=4,
        failure_sub_type=1,
        failure_reason=f"No {s} {r} {o} in image",
        origin_id=id,
        image_id=id,
        generation_type="pre_defined",
        properties={"triplet_sampled": (s, r, o)}
        )
    pre_defined.append(q)

  logger.debug(f"starting pre-defind 3 and 4")
  subj_pair, obj_pair = sample_plausible_pair_from_scene_graph(scene_graph)

  if subj_pair:

    s, r = subj_pair

    #pre-defined 3
    q = AugmentQuestion(
        q=f"What is the {s} {r}?",
        failure_type=4,
        failure_sub_type=3,
        failure_reason=f"No {s} {r} in image",
        origin_id=id,
        image_id=id,
        generation_type="pre_defined",
        properties={"is_subj_aug": True, "pair_sampled": (s, r)}
        )
    pre_defined.append(q)

  if obj_pair:

    r, o = obj_pair

    pp_r = transform_present_progressive(r)

    #pre-defined 4
    # NOTE(review): this (relation, object) pair is stored under
    # "triplet_sampled", while sub-type 3 stores its pair under "pair_sampled" -
    # confirm which key downstream consumers expect.
    q = AugmentQuestion(
      q=f"Who is {pp_r} the {o}?",
      failure_type=4,
      failure_sub_type=4,
      failure_reason=f"Nothing is {r} {o} in image",
      origin_id=id,
      image_id=id,
      generation_type="pre_defined",
      properties={"is_subj_aug": False, "triplet_sampled": (r, o)}
    )
    pre_defined.append(q)

  logger.debug(f"starting pre-defind 2")
  #pre-defined 2
  subj2pairs = generate_subjects_2_pair(scene_graph) # s-> [(r1,o1),(r2,o2),..]

  # Merge identical pairs to prevent duplicates. dict.fromkeys keeps the first
  # occurrence of every pair in its original position; a set would make the
  # order depend on string hashing and differ between interpreter runs.
  subj2pairs = {k: list(dict.fromkeys(v)) for k, v in subj2pairs.items()}

  for s, pair_lst in subj2pairs.items():
    if len(pair_lst) <= 1:
      continue

    combs = compute_compositional_pair_iterator(pair_lst) #all pairs in pair_list

    for (r1, o1), (r2, o2) in combs:

      package = choose_relation("entailment", id, scene_graph, s, r1, o1)
      if package is None:
        continue

      # only the chosen relation is used; the other two values are ignored here
      aug_rel, _nli_label, _sim_score = package

      new_triplet = (s, aug_rel, o1)

      q = AugmentQuestion(
        q=f"What does the {s} {aug_rel} the {o1} {r2}?",
        failure_type=4,
        failure_sub_type=2,
        failure_reason=f"No {s} {aug_rel} {o1} in image",
        origin_id=id,
        image_id=id,
        generation_type="pre_defined",
        properties={"triplet_sampled": new_triplet, "chain_triplet": (s,r2,o2), "subst_rel": r1}
      )

      pre_defined.append(q)

    break # we only want to try generate from 1 pair list to avoid heavy computations

  return pre_defined


##### Type 5
Definite complex reference to an object or subject that exists in the scene:
  the instance does hold in the relation expressed, but the relation refers to multiple options.
  We only use pre-defined questions here:
  1. What does the [SUBJ] [REL]? - given a scene graph, search for triplets
  with multiple object options => (biker, wearing, jacket/pants/shoes) -> “What does the biker wear?”
  or (biker, has, jacket/pants/shoes)
  2. Who is [REL] the [OBJ]? - given a scene graph, search for triplets with multiple
  subject options => (biker/boy/woman, wearing, jacket) -> “Who is wearing the jacket?”

In [None]:
def gen_type_5_from_questions(q_obj, id, triplet, aug_stats):
  """
  Type 5 generation from an existing question - not supported.

  Type 5 questions are produced only by the pre-defined generator, so all
  arguments are ignored.

  :return: always a new empty list
  """
  no_questions = []
  return no_questions

In [None]:
def gen_type_5_pre_defined(scene_graph):
  """
  Generate the pre-defined Type 5 questions: the relation asked about holds in
  the image, but for more than one distinct object (sub-type 1) or subject
  (sub-type 2), so the question has several valid answers.

  :param scene_graph: scene graph object; "image_id" and "objects" are read
                      here, the triplet groups come from
                      compute_triplet_similar_groups
  :return: list of AugmentQuestion
  """
  id = scene_graph["image_id"]

  def generate_multiple_object_question(objects, subj_groups):
    """
    Emit at most one "What does the [SUBJ] [REL]?" question per group of
    triplets that share subject and relation.
    """
    qs = []

    for group in subj_groups:

      #needs at least 2 objects fitting (subj, rel, *), where * is a unique object
      if len(group) <= 1:
        continue

      some_triplet = group[0] # all subjects and relations in group are the same
      s_id = some_triplet['subject_id']
      r = str(some_triplet['predicate']).lower()

      objects_ids = [x["object_id"] for x in group]

      #extract objects
      objects_in_group = from_object_id_to_object_obj(objects, objects_ids)

      logger.debug(f"objects-in-group: {objects_in_group}")

      for o1, o2 in combinations(objects_in_group, 2):

        # if two objects are not similar, it means the triplet holds for atleast two unique objects
        # thus, we can generate a question with multiple answers based on that idea
        if not are_objects_comparable(o1, o2) or are_two_objects_similar(o1, o2):
          continue

        #pre-defined 1
        nouns_list = from_object_id_to_string(objects, [o1['object_id'], o2['object_id'], s_id])

        logger.debug(f"failure nouns_list = (obj1,obj2, subj): {nouns_list}")

        if nouns_list is None or len(nouns_list) != 3:
          continue

        o1_name, o2_name, subj_name = nouns_list

        qs.append(AugmentQuestion(
            q=f"What does the {subj_name} {r}?",
            failure_type=5,
            failure_sub_type=1,
            failure_reason=f"The question refer to atleast two objects: {[o1_name, o2_name]} in the image",
            origin_id=id,
            image_id=id,
            generation_type="pre_defined",
            properties={"is_subj_aug": True, "pair_sampled": (subj_name, r), "similar_set": group}
            ))

        # the group contains triplets with the same rel and subj but with different objects, since the generated question is asking
        # about the object, if we won't break the loop we will generate duplicate questions
        break

    return qs

  def generate_multiple_subject_question(objects, obj_groups):
    """
    When generating a multiple subject question, we need to pay attention to which kind of subjects
    hold in the same triplet.
    If there are multiple pronouns or people/person, the question must have "interrogative pronoun"
    such as "Who, Which, Whom, What, Whose". Else, use "What, Which"
    e.g (can, lying on, grass), (dog, lying on, grass), (cat, lying on, grass), (girl, lying on, grass), (they, lying on, grass)
    1. "What is lying on the grass?" - fits for (can, lying on, grass), (dog, lying on, grass), (cat, lying on, grass)
    2. "Who is lying on the grass?" - fits for (girl, lying on, grass), (they, lying on, grass)

    At most one question is emitted per (group, subject kind).
    """
    qs = []

    for group in obj_groups:

      #needs at least 2 subjects fitting (*, rel, obj), where * is a unique subject
      if len(group) <= 1:
        continue

      some_triplet = group[0] # all objects and relations in group are the same
      o_id = some_triplet['object_id']
      r = str(some_triplet['predicate']).lower()

      subjects_ids = [x["subject_id"] for x in group]

      #extract subjects
      subjects_in_group = from_object_id_to_object_obj(objects, subjects_ids)

      logger.debug(f"subjects-in-group: {subjects_in_group}")

      #split - explanation in the docstring
      inter, other = split_object_similar_set(subjects_in_group)

      for subjects, is_interrogative in zip([inter, other], [True, False]):

        if len(subjects) <= 1:
          continue

        for s1, s2 in combinations(subjects, 2):

          # if two subjects are not similar, it means the triplet holds for atleast two unique subject
          # thus, we can generate a question with multiple answers based on that idea
          if not are_objects_comparable(s1, s2) or are_two_objects_similar(s1, s2):
            continue

          #pre-defined 2
          nouns_list = from_object_id_to_string(objects, [s1['object_id'], s2['object_id'], o_id])

          logger.debug(f"failure nouns_list = (subj1,sub2,obj): {nouns_list}")

          if nouns_list is None or len(nouns_list) != 3:
            continue

          s1_name, s2_name, obj_name = nouns_list

          WH = 'Who' if is_interrogative else 'What'

          qs.append(AugmentQuestion(
              q=f"{WH} is {r} the {obj_name}?",
              failure_type=5,
              failure_sub_type=2,
              failure_reason=f"The question refer to atleast two subjects: {[s1_name, s2_name]} in the image",
              origin_id=id,
              image_id=id,
              generation_type="pre_defined",
              properties={"is_subj_aug": False, "pair_sampled": (r, obj_name), "similar_set": group}
              ))

          # the group contains triplets with the same rel and obj but with different subjects, since the generated question is asking
          # about the subject, if we won't break the loop we will generate duplicate questions
          break

    return qs

  pre_defined = []

  subj_groups, obj_groups = compute_triplet_similar_groups(scene_graph)

  objects = scene_graph["objects"]

  pre_defined.extend(generate_multiple_object_question(objects, subj_groups))
  pre_defined.extend(generate_multiple_subject_question(objects, obj_groups))

  return pre_defined

#### Factory Pipe

In [None]:
def generate_predefined(scene_graph) -> List[AugmentQuestion]:
  """
  Run the pre-defined generators of failure types 1 to 5, in that order, on one
  scene graph and return all generated questions in a single list.

  The type 4 generator is invoked through check_timing.
  """
  # one entry per failure type; lambdas keep every generator call lazy so that
  # each one runs right after its own log line
  generators = (
      lambda: gen_type_1_pre_defined(scene_graph),
      lambda: gen_type_2_pre_defined(scene_graph),
      lambda: gen_type_3_pre_defined(scene_graph),
      lambda: check_timing(gen_type_4_pre_defined, scene_graph),
      lambda: gen_type_5_pre_defined(scene_graph),
  )

  questions = []
  for type_no, generate in enumerate(generators, start=1):
    logger.debug(f'Generating Type {type_no}...')
    questions += generate()

  return questions

In [None]:
def generate_from_questions(q_obj, id, triplet, aug_stats, subj_has_multiple_presuppositions=False) -> List[AugmentQuestion]:
  """
  Run the from-questions generators of failure types 1 to 5, in that order, on
  one (question, triplet) combination and return all generated questions.

  subj_has_multiple_presuppositions is forwarded to the type 1 generator only.
  """
  common_args = (q_obj, id, triplet, aug_stats)

  # one entry per failure type; lambdas keep every generator call lazy so that
  # each one runs right after its own log line
  generators = (
      lambda: gen_type_1_from_questions(*common_args, subj_has_multiple_presuppositions=subj_has_multiple_presuppositions),
      lambda: gen_type_2_from_questions(*common_args),
      lambda: gen_type_3_from_questions(*common_args),
      lambda: gen_type_4_from_questions(*common_args),
      lambda: gen_type_5_from_questions(*common_args),
  )

  questions = []
  for type_no, generate in enumerate(generators, start=1):
    logger.debug(f'Generating Type {type_no}...')
    questions += generate()

  return questions


In [None]:
"""
        self.q = q
        self.q_gec = q_gec
        self.failure_type = failure_type
        self.failure_reason = failure_reason
        self.gec_dist = gec_dist
        self.is_lemmatized_aug = is_lemmatised_aug
        self.generation_type = generation_type
        self.origin_id = origin_id
        self.properties = properties
"""

def generation_summary(questions: List[AugmentQuestion]):
  """
  Write a summary of the generated questions to the module logger (debug level):
  total count, failure-type distribution, per-type sub-type distribution, one
  line per question and the values collected for every property key.

  The sub-type distribution of a type is logged only when that type has at
  least one question and all of its questions carry a sub-type.

  :param questions: generated questions to summarise
  """
  logger.debug("--Generation Summary--")
  logger.debug(f"Total Generated: {len(questions)}")
  logger.debug(f"Types Distribution: {Counter([q.failure_type for q in questions])}")

  for i in range(1,6):
    qs_type_i = [q for q in questions if q.failure_type == i]
    # all() is required here: testing the list of flags itself is True for any
    # non-empty list, even when some questions have no sub-type
    if qs_type_i and all(q.failure_sub_type is not None for q in qs_type_i):
      logger.debug(f"Type {i} Sub-Types Distribution: {Counter([q.failure_sub_type for q in qs_type_i])}")

  for q in questions:
    logger.debug(f"Type=[{q.failure_type}][{q.failure_sub_type if q.failure_sub_type else -1}] Failure=[{q.failure_reason}], {q.q}")

  unique_properties_keys = sorted(set(chain(*[q.properties.keys() for q in questions])))

  # property key -> list of the values of every question that has this key
  props_summary = {
      key: [q.properties[key] for q in questions if key in q.properties]
      for key in unique_properties_keys
  }

  logger.debug(f"Properties Distribution: {props_summary}")

## Pipeline

#### Pre Defined Pipeline

In [None]:
def predefined_pipeline(path_to_data: str, resume_obj: dict = None):
    """
    Pipeline to iterate over scene-graphs and generate pre-defined type of questions from each.

    For every scene graph the questions are written to `<data_folder><image_id>.json`;
    afterwards the numpy random state and the stats file are re-saved so that an
    interrupted run can be resumed. Any exception is logged and re-raised.

    :param path_to_data: path to the json file that holds the scene graph objects
    :param resume_obj: encapsulates all that is necessary to resume the pipeline; a dict
                       with the keys 'data_folder', 'log_path' (None -> a new log file is
                       created), 'stats_path', 'folder_path' and 'random_state_path'.
                       None starts a fresh run in a new time-stamped folder.
    """
    global logger

    #TODO: put your base folder path, where it will create all files
    base_folder = f''
    logger_file_mode = 'a'

    if resume_obj is None:

        dt_now = str(dt.datetime.now())
        folder = base_folder + dt_now
        data_folder = folder + '/data/'
        os.makedirs(data_folder, exist_ok=True)

        # create an empty log file
        log_path = f'{folder}/output.log'
        with open(log_path, 'w'):
            pass

        # set seeds
        np.random.seed(42)

        #setting stats
        stats_path = f'{folder}/stats.json'
        stats = {
            "iterated": 0, #total scene graph iterated
            "generated": 0, #total questions generated
            "types_dist": Counter({k: 0 for k in range(1, 6)}),
            "sub_type_dist": {k: Counter() for k in range(1, 6)},
            "complete_ids": [] # finished scene graph ids
        }

    else:
        data_folder = resume_obj['data_folder']
        log_path = resume_obj['log_path']
        stats_path = resume_obj['stats_path']
        folder = resume_obj['folder_path']

        if log_path is None:
            logger_file_mode = 'w'
            dt_now = str(dt.datetime.now())
            log_path = f'{folder}/output-{dt_now}.log'

        with open(stats_path) as f:
            stats = json.load(f)

        def restore_key(key):
            # JSON object keys are always strings; turn numeric ones back into ints
            try:
                return int(key)
            except (TypeError, ValueError):
                return key

        # json.load returns plain dicts with string keys. Without rebuilding the
        # Counters, `.update` would overwrite the counts instead of adding to
        # them and the integer type lookups below would raise KeyError.
        stats["types_dist"] = Counter({restore_key(k): v for k, v in stats["types_dist"].items()})
        stats["sub_type_dist"] = {
            restore_key(t): Counter({restore_key(st): n for st, n in dist.items()})
            for t, dist in stats["sub_type_dist"].items()
        }
        for t in range(1, 6):
            stats["sub_type_dist"].setdefault(t, Counter())

        # NOTE: pickle must only be used on a state file written by this pipeline;
        # never point random_state_path at an untrusted file.
        random_state_path = resume_obj['random_state_path']
        with open(random_state_path, 'rb') as f:
            numpy_random_state = pickle.load(f)
            np.random.set_state(numpy_random_state)

    logger = logging.getLogger('Generator')
    logger.setLevel("DEBUG")
    # close the handlers of a previous run before dropping them, otherwise their
    # log files stay open
    for old_handler in list(logger.handlers):
        old_handler.close()
    logger.handlers.clear()
    output_file_handler = logging.FileHandler(log_path, mode=logger_file_mode, encoding=None, delay=False)
    logger.addHandler(output_file_handler)

    logger.debug(f"data_folder=[{data_folder}], log_path=[{log_path}], stats_path=[{stats_path}]")
    logger.debug(f"Iterated=[{stats['iterated']}], Completed: [{len(stats['complete_ids'])}], Generated: [{stats['generated']}]")

    # set copy of the finished ids: constant-time membership test in the loop
    completed_ids = set(stats['complete_ids'])

    logger.debug("Begin Pre-Defined Questions Generation Pipeline")
    with open(path_to_data) as f:
        items = ijson.items(f, "item")
        for scene_graph_obj in tqdm(items):

            id = scene_graph_obj['image_id']

            if id in completed_ids: #skipping ids if resumed
                continue

            stats['iterated'] += 1
            logger.debug("=" * 80)

            try:

                logger.debug(f'Generating for image-id={id}')
                start = timer()
                questions = generate_predefined(scene_graph_obj)
                logger.debug(f"generation time taken: [{timer() - start}] seconds")

                generation_summary(questions)

                # emit
                with open(data_folder+f"{id}.json", 'w') as nf:
                    q_id_range = range(stats['generated'], stats['generated']+len(questions))
                    #dict-format: {num_of_successfully_questions}->question object
                    json.dump(dict(zip(q_id_range, questions)), nf, cls=ObjectEncoder)
                    stats['generated'] += len(questions)

                #mark finished
                stats['complete_ids'].append(id)
                completed_ids.add(id)

                # update stats
                stats["types_dist"].update(Counter([q.failure_type for q in questions]))

                for failure_type in range(1, 6):
                    qs_of_type = [q for q in questions if q.failure_type == failure_type]
                    # count sub-types only when every question of the type has one
                    if qs_of_type and all(q.failure_sub_type is not None for q in qs_of_type):
                        stats["sub_type_dist"][failure_type].update(Counter([q.failure_sub_type for q in qs_of_type]))

                # save state after stochastic methods
                numpy_random_state = np.random.get_state()
                with open(folder+'/random_state.pkl', 'wb') as bf:
                    pickle.dump(numpy_random_state, bf)

                #update stats
                with open(stats_path, 'w') as xf:
                    json.dump(stats, xf)

            except Exception as e:
                # the run stops here; the id is not marked as complete, so a resumed
                # run will try this scene graph again
                logger.debug(f"Raised exception {e} on scene graph object id {id}, aborting...", exc_info=True)
                raise

            logger.debug(f"Iterated=[{stats['iterated']}], Completed: [{len(stats['complete_ids'])}], Generated: [{stats['generated']}]")
            gen_ratio = stats['generated']/len(stats['complete_ids']) if len(stats['complete_ids']) > 0 else 0
            logger.debug(f"Generation Ratio(generated/iterated): {gen_ratio}")




In [None]:
# Path to the Visual Genome scene graphs json (placeholder - fill in before running)
train_sg_path = '.../train_gqa_vg_scene_graphs.json'
# Fresh run: no resume_obj is passed, so the pipeline creates a new time-stamped
# output folder and seeds numpy with 42.
predefined_pipeline(path_to_data=train_sg_path)

### From Questions Pipeline

In [None]:
def from_questions_pipeline(path_to_data: str, resume_obj: dict = None):
    """
    Pipeline to iterate over questions and generate from_questions type of questions from each.

    Only questions whose structural/semantic types are listed in the module-level
    `structural` and `semantic` lists are used. Every generated question is written to
    `<data_folder><num_generated>_<question_id>.json`; afterwards the numpy random
    state and the stats file are re-saved so that an interrupted run can be resumed.
    Any exception is logged and re-raised.

    :param path_to_data: path to where the json of the data exists (questions objects)
    :param resume_obj: encapsulates all that is necessary to resume the pipeline; a dict
                       with the keys 'data_folder', 'log_path' (None -> a new log file is
                       created), 'stats_path', 'folder_path' and 'random_state_path'.
                       None starts a fresh run in a new time-stamped folder.
    """
    global logger

    #TODO: put your base folder path, where it will create all files
    base_folder = f''

    if resume_obj is None:

        dt_now = str(dt.datetime.now())
        folder = base_folder + dt_now
        data_folder = folder + '/data/'
        os.makedirs(data_folder, exist_ok=True)

        # create an empty log file
        log_path = f'{folder}/output.log'
        with open(log_path, 'w'):
            pass

        # set seeds
        np.random.seed(42)

        #setting stats
        stats_path = f'{folder}/stats.json'
        stats = {
            "augment_stats": {"lemma": 0, "original": 0, "failed": 0}, #distribution of augmentation stats of 'generated' questions
            "iterated": 0, #num question iterated from the dataset that fit to the condition of query-rel
            "generated": 0, # total generated questions
            "types_dist": Counter(),
            "questions_extract_relations_fail": 0, #num questions iterated from the dataset and failed on relations extraction
            "complete_ids": [] #num questions iterated from the dataset and were used for augmentation
        }

    else:
        data_folder = resume_obj['data_folder']
        log_path = resume_obj['log_path']
        stats_path = resume_obj['stats_path']
        folder = resume_obj['folder_path']

        # same fallback as the pre-defined pipeline: no log path -> new log file
        if log_path is None:
            dt_now = str(dt.datetime.now())
            log_path = f'{folder}/output-{dt_now}.log'

        with open(stats_path) as f:
            stats = json.load(f)

        # json.load returns a plain dict with string keys. Rebuild the Counter with
        # int keys, otherwise `.update` overwrites the counts instead of adding to
        # them and the file ends up with both "4" and 4 as keys.
        stats["types_dist"] = Counter({int(k): v for k, v in stats["types_dist"].items()})

        # NOTE: pickle must only be used on a state file written by this pipeline;
        # never point random_state_path at an untrusted file.
        random_state_path = resume_obj['random_state_path']
        with open(random_state_path, 'rb') as f:
            numpy_random_state = pickle.load(f)
            np.random.set_state(numpy_random_state)

    logger = logging.getLogger('Generator')
    logger.setLevel("DEBUG")
    # close the handlers of a previous run before dropping them, otherwise their
    # log files stay open
    for old_handler in list(logger.handlers):
        old_handler.close()
    logger.handlers.clear()
    output_file_handler = logging.FileHandler(log_path)
    logger.addHandler(output_file_handler)

    logger.debug(f"data_folder=[{data_folder}], log_path=[{log_path}], stats_path=[{stats_path}]")
    logger.debug(f"Iterated=[{stats['iterated']}], Completed: [{len(stats['complete_ids'])}], Generated: [{stats['generated']}]")

    # set copy of the finished ids: constant-time membership test in the loop
    completed_ids = set(stats['complete_ids'])

    logger.debug("Begin From Questions Generation Pipeline")
    with open(path_to_data) as f:
        qs = ijson.kvitems(f, '')

        for id, q_obj in tqdm(qs):

            if id in completed_ids: #skipping ids if resumed
                continue

            is_relevant = 'types' in q_obj and 'structural' in q_obj['types'] and 'semantic' in q_obj['types'] \
                    and q_obj['types']['structural'] in structural and q_obj['types']['semantic'] in semantic
            if not is_relevant:
                continue

            stats['iterated'] += 1
            logger.debug("=" * 80)
            try:
                presupp_triplets = extract_triplets_from_semantic_program(q_obj)
                if presupp_triplets is None or (len(presupp_triplets) > 0 and len(presupp_triplets[0]) != 3):
                    logger.debug(f"Relation presupposition extraction process failed on data point id: {id}")
                    stats['questions_extract_relations_fail'] += 1
                elif len(presupp_triplets) > 0 and len(presupp_triplets[0]) == 3:  # hard validation!

                    triplets = [[str(x) for x in t] for t in presupp_triplets]

                    logger.debug(f'Generating for question-id={id}')
                    start = timer()

                    subj_has_multiple_presuppositions = len(triplets) > 1
                    questions = list(chain(*[generate_from_questions(q_obj, id, triplet, stats['augment_stats'],subj_has_multiple_presuppositions=subj_has_multiple_presuppositions) for triplet in triplets]))
                    logger.debug(f"generation time taken: [{timer() - start}] seconds")

                    generation_summary(questions)

                    # emit
                    for q in questions:
                        #name-format: {num_generated_questions}_{origin_id}
                        with open(data_folder+f"{stats['generated']}_{id}.json", 'w') as nf:
                            json.dump(q, nf, cls=ObjectEncoder)

                        stats['generated'] += 1

                    #mark finished
                    stats['complete_ids'].append(id)
                    completed_ids.add(id)

                    # update stats
                    stats["types_dist"].update(Counter([q.failure_type for q in questions]))

                    # save state after stochastic methods
                    numpy_random_state = np.random.get_state()
                    with open(folder+'/random_state.pkl', 'wb') as bf:
                        pickle.dump(numpy_random_state, bf)

                    with open(stats_path, 'w') as xf:
                        json.dump(stats, xf)

            except Exception as e:
                # the run stops here; the id is not marked as complete, so a resumed
                # run will try this question again
                logger.debug(f"Raised exception {e} on question object id {id}, aborting...", exc_info=True)
                raise

            logger.debug(f"Iterated=[{stats['iterated']}], Completed: [{len(stats['complete_ids'])}], Generated: [{stats['generated']}]")
            gen_ratio = stats['generated']/len(stats['complete_ids']) if len(stats['complete_ids']) > 0 else 0
            logger.debug(f"Generation Ratio(generated/iterated): {gen_ratio}, Augmentation-Stats: [{stats['augment_stats']}]")
            logger.debug(f"Extracting-relations-fails: {stats['questions_extract_relations_fail']}")


In [None]:
# Path to the GQA train balanced questions json (placeholder - fill in before running)
path_to_data = '.../train_balanced_questions.json'
# Fresh run: no resume_obj is passed, so the pipeline creates a new time-stamped
# output folder and seeds numpy with 42.
from_questions_pipeline(path_to_data=path_to_data)