In [2]:
import ast
import json
import pickle
import re
import unicodedata

import pandas as pd
import spacy
import torch
from tqdm import tqdm
from transformers import BertTokenizer, FlavaModel

# One checkpoint name shared by the FLAVA model and its tokenizer, so the wordpieces used
# to build the dependency graphs come from the same vocabulary as the model.
FLAVA_CHECKPOINT = "facebook/flava-full"
# spaCy pipeline that supplies the dependency parse (arcs + labels).
SPACY_MODEL = "en_core_web_sm"

# NOTE(review): `model` is not used by the graph-building cells in this notebook; consider
# dropping it if nothing else needs it.
model = FlavaModel.from_pretrained(FLAVA_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(FLAVA_CHECKPOINT)
nlp = spacy.load(SPACY_MODEL)

`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The value `text_config["id2label"]` will be overriden.
`multimodal_config_dict` is provided which will be used to initialize `FlavaMultimodalConfig`. The value `multimodal_config["id2label"]` will be overriden.
`image_codebook_config_dict` is provided which will be used to initialize `FlavaImageCodebookConfig`. The value `image_codebook_config["id2label"]` will be overriden.


In [31]:

def get_graph(sentence, debug=False):
    inputs = tokenizer(sentence)
    input_ids = inputs["input_ids"]
    debug and print(input_ids)
    raw_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    debug and print(raw_tokens)
    tokens = list(filter(lambda x: x != "[CLS]" and x != "[SEP]", raw_tokens))
    debug and print(tokens)
    edges = nlp(sentence)
    for word in edges:
        debug and print(f"{word.text} [{word.i}] -- {word.dep_} --> {word.head.text} [{word.head.i}]")

    class Graph:
        def __init__(self, n):
            self.n = n
            self.mat = [[None for i in range(n)] for j in range(n)]

        def add_edge(self, u, v, edge_type):
            self.mat[u][v] = edge_type
        
        def get_01_graph(self):
            return [[1 if self.mat[i][j] is not None else 0 for j in range(self.n)] for i in range(self.n)]
        

    graph = Graph(len(tokens) + 2) # +2 for [CLS] and [SEP]
    sentence_arr = [token.text for token in edges]
    sentence_to_token_idxs = []
    i = 0
    for word in edges:
        token_idxs = [i]
        i += 1
        while i < len(tokens) and "##" in tokens[i]:
            token_idxs.append(i)
            i += 1
        sentence_to_token_idxs.append(token_idxs)

    debug and print("\n")
    for word in edges:
        debug and print(f"{word.text} {[tokens[i] for i in sentence_to_token_idxs[word.i]]}")
        from_nodes = sentence_to_token_idxs[word.i]
        to_nodes = sentence_to_token_idxs[word.head.i]
        for from_node in from_nodes:
            for to_node in to_nodes:
                graph.add_edge(from_node + 1, to_node + 1, word.dep_)

    # assert first column and first row is always 0
    # assert last column and last row is always 0
    for i in range(len(graph.mat)):
        assert graph.mat[i][0] is None
        assert graph.mat[0][i] is None
        assert graph.mat[i][-1] is None
        assert graph.mat[-1][i] is None

    # print 01 matrix
    for row in graph.get_01_graph():
        debug and print(row)

    # print type matrix
    for row in graph.mat:
        debug and print(row)

    return (graph.get_01_graph(), graph.mat, raw_tokens)

# Smoke test: parse one caption verbosely and save its (adjacency, edge types, tokens)
# triple for offline inspection.
sentence = "This is an image of a bent 15 mph and crosswalk sign"
result = get_graph(sentence, debug=True)

payload = pickle.dumps(result)
with open("test_graph.pkl", "wb") as out_file:
    out_file.write(payload)


[101, 2023, 2003, 2019, 3746, 1997, 1037, 6260, 2321, 5601, 1998, 2892, 17122, 3696, 102]
['[CLS]', 'this', 'is', 'an', 'image', 'of', 'a', 'bent', '15', 'mph', 'and', 'cross', '##walk', 'sign', '[SEP]']
['this', 'is', 'an', 'image', 'of', 'a', 'bent', '15', 'mph', 'and', 'cross', '##walk', 'sign']
This [0] -- nsubj --> is [1]
is [1] -- ROOT --> is [1]
an [2] -- det --> image [3]
image [3] -- attr --> is [1]
of [4] -- prep --> image [3]
a [5] -- det --> mph [8]
bent [6] -- amod --> mph [8]
15 [7] -- nummod --> mph [8]
mph [8] -- pobj --> of [4]
and [9] -- cc --> image [3]
crosswalk [10] -- compound --> sign [11]
sign [11] -- conj --> image [3]


This ['this']
is ['is']
an ['an']
image ['image']
of ['of']
a ['a']
bent ['bent']
15 ['15']
mph ['mph']
and ['and']
crosswalk ['cross', '##walk']
sign ['sign']
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0

In [32]:
def get_sentences(tsv_path):
    """Read an ARO negative-caption TSV and flatten it into a single list of sentences.

    For every row, the ``title`` caption is emitted first, followed by each entry of
    ``neg_caption`` (a list stored as a Python-literal string), preserving file order.

    Args:
        tsv_path: path or file-like object accepted by ``pd.read_csv``.

    Returns:
        list of caption strings.
    """
    df = pd.read_csv(tsv_path, sep="\t")
    sentences = []
    # Walk the two columns directly: DataFrame.iterrows builds a Series per row, which is
    # slow on large files and can coerce dtypes.
    for title, neg_caption in zip(df["title"], df["neg_caption"]):
        sentences.append(title)
        # literal_eval parses the stringified list without executing arbitrary code.
        sentences.extend(ast.literal_eval(neg_caption))
    return sentences

# Literal (old, new) substitutions applied in this exact order: contraction and typo fixes
# for specific captions, then newline and punctuation tidying.
_CAPTION_REPLACEMENTS = (
    (" shes ", " she's "),
    (" whats ", " what's "),
    (" thats ", " that's "),
    (" hes ", " he's "),
    (" wont ", " will not "),
    (" cannot ", " can not "),
    (" cant ", " can not "),
    ("Doesnt", "Does not"),
    # NOTE(review): no word boundaries, so this also rewrites e.g. "theresa".
    ("theres", "there is"),
    (" dont ", " do not "),
    (" dont.", " do not."),
    (" flower po0t", " flower pot"),
    ("the girl id licking", "the girl is licking"),
    (" 15mph ", " 15 mph "),
    (" 11am", " 11 am"),
    # NOTE(review): removing "\n" without a space joins words split across lines.
    ("\n", ""),
    (",  ", ", "),
    (" .", "."),
)
# Any run of two or more spaces.
_MULTI_SPACE = re.compile(r" {2,}")


def clean_sentence(sentence):
    """Normalise one caption: fix known contractions/typos, tidy spacing, strip the ends.

    Args:
        sentence: raw caption text.

    Returns:
        the cleaned caption.
    """
    for old, new in _CAPTION_REPLACEMENTS:
        sentence = sentence.replace(old, new)
    # Collapse every run of spaces in one pass; repeated pairwise replacement would leave
    # runs of five or more spaces partly intact.
    sentence = _MULTI_SPACE.sub(" ", sentence)
    return sentence.strip()

# Gather every title and negative caption from both ARO splits (train first, then valid)
# and normalise each one with clean_sentence.
sentences = [
    clean_sentence(raw_sentence)
    for split_path in ("../aro_dataset/train_neg_clip.tsv", "../aro_dataset/valid_neg_clip.tsv")
    for raw_sentence in get_sentences(split_path)
]

print(len(sentences))
    


286147


In [33]:
import multiprocessing

# Number of worker processes used to parse captions in parallel.
N_WORKERS = 12


def process_sentence(sentence):
    """Build the dependency graph for one sentence, returning None if parsing fails.

    Best-effort: a failing caption is reported with its exception instead of aborting the
    whole pool. KeyboardInterrupt/SystemExit are deliberately not caught.
    """
    try:
        return get_graph(sentence, debug=False)
    except Exception as exc:
        print("Error:", sentence, f"({type(exc).__name__}: {exc})")
        return None


with multiprocessing.Pool(N_WORKERS) as pool:
    # imap keeps input order; the chunksize batches work to limit inter-process overhead.
    graphs = list(
        tqdm(pool.imap(process_sentence, sentences, chunksize=256), total=len(sentences))
    )

# Graphs keyed by cleaned sentence; duplicate sentences collapse into a single entry.
results = dict(zip(sentences, graphs))

# Worker prints may not reach the notebook, so report failures from the parent too.
n_failed = sum(graph is None for graph in graphs)
print(f"{n_failed} / {len(sentences)} sentences failed to parse")

with open("graph.pkl", "wb") as f:
    pickle.dump(results, f)

In [34]:
# Spot-check the first few captions: each sentence, then its (adjacency, edge types,
# tokens) tuple, then a blank gap.
for preview_sentence, preview_graph in zip(sentences[:5], graphs[:5]):
    print(preview_sentence, preview_graph, "\n\n", sep="\n")

A woman marking a cake with the back of a chef's knife.
([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], [[None, None, None, None, None, None, No

In [None]:
# tsv_path = "../aro_dataset/valid_neg_clip.tsv"

# def fix_line1(line1):
#     line1 = line1.strip()
#     if line1[0] == '"':
#         line1 = line1[1:]
#     if line1[-1] == '"':
#         line1 = line1[:-1]
#     return line1

# def fix_line3(line3):
#     line3 = line3.replace('"[""', "['").replace('""]"', "']")
#     line3 = line3.replace('"[', "[").replace(']"', "]")
#     line3 = line3.replace('"', '\\"')
#     line3 = line3.replace("\\'", "'")
#     line3 = line3.replace("['", '["').replace("']", '"]')
#     line3 = line3.replace('"", ""', '", "').replace("', '", '", "')
#     return line3

# def get_sentences(tsv_path):
#     with open(tsv_path, "r") as f:
#         lines = f.readlines()[1:]
#     lines = list(map(lambda x: x.strip().split("\t"), lines))

#     sentences = []
#     for line in lines:
#         sentences.append(line[1])
#         if len(line) < 4:
#             print(f"line {line} has no line3")
#         # print(line[3], fix_line3(line[3]))
#         sentences += json.loads(fix_line3(line[3]))
#     return sentences

# print(get_sentences(tsv_path)[:10])

# with open("aro_sentences.txt", "w") as f:
#     for sentence in get_sentences(tsv_path):
#         f.write(sentence + "\n")