In [1]:
from collections import Counter, defaultdict
from functools import reduce
import numpy as np
# implement PMI baseline for unsupervised parsing tree generation
train = None
val = None
test = None
linux_data_path = "/home/zijiao/work/data/mscoco/train_caps.txt"
mac_data_path = "/Users/zijiaoyang/Documents/data/mscoco/train_caps.txt"
file_path = "/Users/zijiaoyang/Documents/data/mscoco"

In [2]:
# 413915 captions so image is supposed to be 82783 
with open(mac_data_path, 'r') as f:
    words_doc = []
    bigram_doc = []
    sentences = []
    for line in f:
        #sentence = tokenizer(line.strip())
        #sentence = ['<s>'] + sentence + ['</s>']
        sentence = line.strip().lower().split()
        sentences.append(sentence)
        bigram_doc.extend(list(zip(sentence, sentence[1:])))
        words_doc.extend(sentence)

In [3]:
class Vocabulary(object):
    """Simple vocabulary wrapper."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

import pickle
with open("../data/mscoco/vocab.pkl", 'rb') as f:
    vocab = pickle.load(f)

In [4]:
# compute dicts
word2count = defaultdict(lambda: 0)
bigram2count = defaultdict(lambda: 0)
for word in words_doc:
    if vocab(word) == '<unk>':
        word2count[vocab('<unk>')] += 1
    word2count[vocab(word)] += 1
for w1, w2 in bigram_doc:
    if vocab(w1) == '<unk>' or vocab(w2) == '<unk>':
        word2count[vocab('<unk>'), vocab('<unk>')] += 1
    bigram2count[(vocab(w1), vocab(w2))] += 1   

In [5]:
# Compute probs
total_wdcounts = reduce(lambda a, b: a+b, list(word2count.values()))
p_uni = {word: count/total_wdcounts for word, count in word2count.items()}
total_bicounts = reduce(lambda a, b: a+b, list(bigram2count.values()))
p_bi = {bigram: count/total_bicounts for bigram, count in bigram2count.items()}

In [6]:
def pmi(word1, word2, p_uni=p_uni, p_bi=p_bi, smooth=.7):
    """
    Compute Negtive pointwise mutual information
    # add 1 smoothing
    """
    word1, word2 = vocab(word1), vocab(word2)
    return np.minimum(np.log(p_bi.get((word1, word2), 0)+smooth/(p_uni.get(word1, 0) + smooth) * (p_uni.get(word2, 0) +smooth)), 0)

def parse(distance, left, right):
    """
    Compute the paring boundary based on given syntactic distance
    
    Input: distances computed for a sentence,
    left and right are boundaries
    :return: boundaries
    """
    if left == right:
        return []
    #print(left, right)
    p = left + np.argmax(distance[left: right])
    return [(left, right)] + parse(distance, left, p) + parse(distance, p+1, right)

    

In [7]:
# Compute spans for tree in data_path
def compute_npmi(data_path, sm=.7):
    with open(data_path) as f:
        sent_distances = []
        for line in f:
            # original code line.strip().lower().split() we used tokenizer here
            #sentence = tokenizer(line.strip().lower())
            sentence = line.strip().lower().split()
            bis = zip(sentence, sentence[1:])
            # Compute negative pointwise mutual info
            dist = [pmi(word1, word2,smooth=sm) for word1, word2 in bis]
            sent_distances.append(dist)
            #sent_distances.append((sentence, dist))
    return sent_distances



In [21]:
# generate test dists
import os
import numpy as np
import pickle

data_path = '/Users/zijiaoyang/Documents/data/mscoco/'
bras = []
for sm in np.linspace(0.1, 5, num=20):
    sent_distances = compute_npmi(os.path.join(data_path, 'test_caps.txt'), sm=sm)
    brackets = [parse(dis, 0, len(dis)-1) for dis in sent_distances]
    bras.append((sm, brackets))

In [22]:
# TODO: compute f1 score for pmi baseline
# TODO: solve possible OOV problem, partly solved
# TODO: make data preprocssing same as original code, so fair compare can be made: DONE

In [23]:
import argparse
import os

#from evaluation import test_trees
#from vocab import Vocabulary

def extract_spans(tree):
    answer = list()
    stack = list()
    items = tree.split()
    curr_index = 0
    for item in items:
        if item == ')':
            pos = -1
            right_margin = stack[pos][1]
            left_margin = None
            while stack[pos] != '(':
                left_margin = stack[pos][0]
                pos -= 1
            assert left_margin is not None
            assert right_margin is not None
            stack = stack[:pos] + [(left_margin, right_margin)]
            answer.append((left_margin, right_margin))
        elif item == '(':
            stack.append(item)
        else:
            stack.append((curr_index, curr_index))
            curr_index += 1
    return answer


def extract_statistics(gold_tree_spans, produced_tree_spans):
    gold_tree_spans = set(gold_tree_spans)
    produced_tree_spans = set(produced_tree_spans)
    precision_cnt = sum(list(map(lambda span: 1.0 if span in gold_tree_spans else 0.0, produced_tree_spans)))
    recall_cnt = sum(list(map(lambda span: 1.0 if span in produced_tree_spans else 0.0, gold_tree_spans)))
    precision_denom = len(produced_tree_spans)
    recall_denom = len(gold_tree_spans)
    return precision_cnt, precision_denom, recall_cnt, recall_denom


def f1_score(produced_trees, gold_trees):
    gold_trees = list(map(lambda tree: extract_spans(tree), gold_trees))
    #produced_trees = list(map(lambda tree: extract_spans(tree), produced_trees))
    # TODO: get spans from pmi baseline, $$DONE
    assert len(produced_trees) == len(gold_trees)
    precision_cnt, precision_denom, recall_cnt, recall_denom = 0, 0, 0, 0
    for i, item in enumerate(produced_trees):
        pc, pd, rc, rd = extract_statistics(gold_trees[i], item)
        precision_cnt += pc
        precision_denom += pd
        recall_cnt += rc
        recall_denom += rd
    precision = float(precision_cnt) / precision_denom * 100.0
    recall = float(recall_cnt) / recall_denom * 100.0
    f1 = 2 * precision * recall / (precision + recall)
    return f1, precision, recall



# parser = argparse.ArgumentParser()
# parser.add_argument('--candidate', type=str, required=False,
#                     help='model path to evaluate')
# parser.add_argument('--produced_path', required=True, default='./',
#                     help='the path to produced_tree_spans')
# args = parser.parse_args()
# TODO: change path: Done
#ground_truth = [line.strip() for line in open(
#    os.path.join('/home/zijiao/work/data/mscoco/', 'test_ground-truth.txt'))]
ground_truth = [line.strip() for line in open(
    os.path.join('/Users/zijiaoyang/Documents/data/mscoco/', 'test_ground-truth.txt'))]

# import pickle
# with open(args.produced_path, 'rb') as f:
#     trees = pickle.load(f)
#trees = [line.strip() for line in open(os.path.join('/home/zijiao/work/VGSNLextend/trees.txt'))]
for sm, trees in bras:
    f1, precision, recall =  f1_score(trees, ground_truth)
#print('Model:', args.candidate)
    print(f'sm is {sm:.2f}')
    print(f'F1 score: {f1:.2f}, precision: {precision:.2f}, recall: {recall:.2f}')
# TODO: check if it works, it worked......
# TODO: generate tree file for test:DONE
# ! change vocab to default

sm is 0.10
F1 score: 32.20, precision: 29.29, recall: 35.75
sm is 0.36
F1 score: 32.23, precision: 29.31, recall: 35.78
sm is 0.62
F1 score: 32.23, precision: 29.32, recall: 35.78
sm is 0.87
F1 score: 32.34, precision: 29.42, recall: 35.90
sm is 1.13
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 1.39
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 1.65
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 1.91
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 2.16
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 2.42
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 2.68
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 2.94
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 3.19
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 3.45
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 3.71
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 3.97
F1 score: 34.68, precision: 31.55, recall: 38.50
sm is 4.23
F1 score: 34.68, precision: 3

In [65]:
with open('/Users/zijiaoyang/Documents/data/mscoco/test_caps.txt', 'r') as f:
    sents = []
    for line in f:
        sents.append(line.strip().lower())

In [66]:
sents[1]

'man riding a motor bike on a dirt road on the countryside .'