In [1]:
import os

files = []
for d, s, f in os.walk('ccgbank'):
    for file in f:
        if file.endswith("auto"):
            files.append(f"{d}/{file}")

In [2]:
import re

def parse(text):
    tokens = []
    pattern = re.compile(r"<L (.*?)>")
    matches = pattern.findall(text)

    for match in matches:
        parts = match.split()
        if len(parts) >= 4:
            word = parts[3]
            category = parts[0]
            cat2 = parts[1]
            tokens.append((word, category, cat2))
    
    return tokens

In [3]:
import tqdm
from collections import defaultdict

tags = defaultdict(list)
words = defaultdict(lambda: defaultdict(int))

for file in tqdm.tqdm(files):
    with open(file) as f:
        for line in f.readlines():
            sentence = parse(line)
            for word, c1, c2 in sentence:
                if word not in tags[(c1, c2)]:
                    tags[(c1, c2)].append(word) # sets aren't hashable
                for word2, c12, c22 in sentence:
                    if word != word2:
                        words[word][word2] += 1
                        words[word2][word] += 1

100%|██████████████████████████████████████| 2312/2312 [00:23<00:00, 100.25it/s]


In [4]:
import pickle
pickle.dump(dict(tags), open("tags", "wb"))
pickle.dump(dict(words), open("words", "wb"))

In [5]:
all_sents = []
for file in tqdm.tqdm(files):
    with open(file) as f:
        for line in f.readlines():
            all_sents.append(parse(line))

100%|█████████████████████████████████████| 2312/2312 [00:01<00:00, 1836.41it/s]


In [6]:
pickle.dump(all_sents, open("all_sents", "wb"))

In [7]:
import random
random.shuffle(all_sents)

In [10]:
sents = []
sents = [[i, ' '.join(j[0] for j in all_sents[int(i)])] for i in sents]
for i in range(1000):
    sent = []
    while len(sent) < 5 or len(sent) > 50:
        sent = random.choice(all_sents)
    sents.append([all_sents.index(sent), ' '.join(i[0] for i in sent)])

In [11]:
import csv
with open("devset.csv", "w") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['Sentence ID', 'Original Sentence', 'New Sentence'])
    writer.writeheader()

    for x in tqdm.tqdm(sents):
        sent_id, sent = x
        writer.writerow({'Sentence ID': sent_id, 'Original Sentence': sent})

100%|███████████████████████████████████| 1000/1000 [00:00<00:00, 105252.30it/s]


In [12]:
def frequency(word):
    return sum(words[word].values())

In [13]:
from llr import llr_2x2
k22a = 0
for k, v in words.items():
    k22a += sum(v.values())
k22a //= 2
def get_llr(w1, w2):
    k11 = words[w1][w2]
    k12 = sum(words[w1].values()) - k11
    k21 = sum(words[w2].values()) - k11
    k22 = k22a - k12 - k21 - k11
    return llr_2x2(k11, k12, k21, k22)    

def get_sentence_llr(sent):
    total = 0
    for token1 in sent:
        for token2 in sent:
            total += get_llr(token1, token2)
    return total

In [14]:
def replace(word, c1, c2):
    if word in ["n't", "'s", "'m", "to", "%"]:
        return False
    if c1 in [r'(S[dcl]\NP)/(S[b]\NP)', 'NP[nb]/N', r'(NP\NP)/NP', 'conj', ',', '.',  r'(S[b]\NP)/(S[adj]\NP)', r'(S[dcl]\NP)/(S[ng]\NP)', r'(NP\NP)/(NP\NP)', r'(S[dcl]\NP)/(S[pss]\NP)']:
        return False
    if c2 in ['$', 'DT', 'IN', 'CC', 'PRP', 'CD', 'MD', 'NNP']:
        return False
    return True

def get_replacements(sent_id):
    sent = all_sents[sent_id]
    # print("got sent", sent)
    repl = []
    for word, c1, c2 in sent:
        if not replace(word, c1, c2):
            continue
        else:
            w = None
            while not w:# or frequency(w) < 1000 or len(w) < 3:
                w = random.choice(tags[(c1, c2)])
            # print("replacing", word, c1, c2, 'with', w)
            # print(w, end = ' ') # frequency(w), 
            repl.append(w)
    return repl

def do_replacements(sent_id, repls):
    sent = all_sents[sent_id]
    # print("got sent", sent)
    new_sent = []
    i = 0
    for word, c1, c2 in sent:
        if not replace(word, c1, c2):
            new_sent.append(word)
        else:
            new_sent.append(repls[i])
            i += 1
    return new_sent

In [35]:
import csv
def clean(i):
    return i.replace(" ,", ",").replace(" .", ".").replace(" :", ":")
with open("devset_wrong.csv", "w") as devset_wrong:
    writer = csv.DictWriter(devset_wrong, fieldnames=['S1_ID', 'S2_ID', 'Original Sentence', 'New Sentence'])
    writer.writeheader()

    for _ in tqdm.tqdm(range(25)):
        s = []
        while len(s) < 5:
            s = random.choice(all_sents)
        s1i = all_sents.index(s)
        
        s2 = random.choice([i for i in all_sents if len(i) == len(s)])
        s2i = all_sents.index(s2)
        # print(' '.join(i[0] for i in s))
        # print(' '.join(do_replacements(s2i, get_replacements(s2i))))
        writer.writerow({
            'S1_ID': s1i,
            'S2_ID': s2i,
            'Original Sentence': clean(' '.join(i[0] for i in s)),
            'New Sentence': clean(' '.join(do_replacements(s2i, get_replacements(s2i))))
        })

100%|███████████████████████████████████████████| 25/25 [00:00<00:00, 76.92it/s]


In [16]:
do_replacements(184490, get_replacements(184490))

IndexError: list index out of range

In [17]:
import csv

rows = []
with open("devset.csv") as devset:
    reader = csv.DictReader(devset)
    for row in tqdm.tqdm(reader, total=len(sents)):
        sent_id = int(row["Sentence ID"])
        orig = row["Original Sentence"]

        replacement_candidates = [get_replacements(sent_id) for _ in range(10)]
        best_candidate = max(replacement_candidates, key=get_sentence_llr)
        # for i in replacement_candidates:
            # print(' '.join(do_replacements(sent_id, i)))
            # print("llr", get_sentence_llr(i))
        new_sentence = ' '.join(do_replacements(sent_id, best_candidate))
        # input()
        rows.append({"Sentence ID": sent_id, "Original Sentence": orig, "New Sentence": new_sentence})

with open("devset.csv", "w") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['Sentence ID', 'Original Sentence', 'New Sentence'])
    writer.writeheader()

    for row in tqdm.tqdm(rows):
        writer.writerow(row)


100%|███████████████████████████████████████| 1000/1000 [00:23<00:00, 41.99it/s]
100%|███████████████████████████████████| 1000/1000 [00:00<00:00, 267255.26it/s]


In [19]:
with open("devset_new/orig.csv") as f:
    reader = csv.DictReader(f)
    rows = []
    i = 0
    for row in tqdm.tqdm(reader):
        rows.append(row)
        if len(rows) == 50:
            with open(f"devset_new/orig_{i}.csv", "w") as f2:
                writer = csv.DictWriter(f2, fieldnames=['Sentence ID', 'Original Sentence', 'New Sentence'])
                writer.writeheader()

                for row in rows:
                    writer.writerow(row)
            i += 1
            rows = []

1000it [00:00, 39098.98it/s]


In [51]:
def parse_individual_tikz(tikz):
    # print(tikz)
    lines = tikz.split(r"\depedge[edge style={red!60!}, edge below]")
    lines = [i for i in lines if i.startswith("{")]
    # print(lines)
    lines = [[j.strip("{") for j in i.split("\n")[0].strip().split("}")[:-2]] for i in lines]
    # print(lines, len(lines))
    # lines = {a:b for a, b in lines}
    # print(lines)
    return lines
    
def parse_full_tikz(tikz_file):
    with open(tikz_file, "r") as f:
        for tikz in f.read().split(r"\begin{dependency}")[1:]:
            yield parse_individual_tikz(tikz)

def compute_uuas(ground, test):
    n_correct = 0
    n_total = 0
    for (k, v) in test:
        n_total += 1
        if [k, v] in ground:
            n_correct += 1
    print(test, n_correct, n_total)
    return n_correct / n_total

for i in range(1, 4):
    file = f"/Users/simonchervenak/Documents/GitHub/structural-probes/example/results/test{i}/demo.tikz"
    orig, *new = parse_full_tikz(file)

    print(orig)
    x = [compute_uuas(orig, i) for i in new]
    print(f"Average UUAS for test {i}:", sum(x) / len(x))

[['10', '12'], ['2', '4'], ['7', '8'], ['3', '4'], ['4', '18'], ['8', '10'], ['14', '15'], ['5', '6'], ['11', '12'], ['16', '18'], ['4', '6'], ['10', '13'], ['9', '10'], ['15', '16'], ['12', '15'], ['1', '4'], ['16', '17']]
[['11', '12'], ['16', '18'], ['3', '4'], ['14', '15'], ['8', '10'], ['2', '4'], ['5', '6'], ['6', '7'], ['9', '10'], ['15', '18'], ['1', '4'], ['10', '12'], ['16', '17'], ['10', '13'], ['4', '6'], ['6', '8'], ['13', '14']] 13 17
[['7', '8'], ['11', '12'], ['2', '4'], ['8', '10'], ['3', '4'], ['9', '10'], ['10', '12'], ['10', '13'], ['16', '18'], ['14', '15'], ['1', '4'], ['16', '17'], ['15', '16'], ['4', '18'], ['5', '6'], ['12', '15'], ['4', '6']] 17 17
[['7', '8'], ['3', '4'], ['8', '10'], ['14', '15'], ['2', '4'], ['16', '18'], ['9', '10'], ['15', '16'], ['6', '8'], ['1', '4'], ['5', '6'], ['11', '12'], ['4', '6'], ['10', '12'], ['16', '17'], ['10', '13'], ['4', '18']] 16 17
[['14', '15'], ['8', '10'], ['7', '8'], ['16', '18'], ['4', '6'], ['5', '6'], ['3', '4'],

In [39]:
from llr import llr_2x2

In [40]:
k22a = 0
for k, v in words.items():
    k22a += sum(v.values())
k22a //= 2
def get_llr(w1, w2):
    k11 = words[w1][w2]
    k12 = sum(words[w1].values()) - k11
    k21 = sum(words[w2].values()) - k11
    k22 = k22a - k12 - k21 - k11
    return llr_2x2(k11, k12, k21, k22)

In [47]:
get_llr('created', 'doors')

4.6881918320432305

In [None]:
distance = []
for i in range(1, 4):
    file = f"/Users/simonchervenak/Documents/GitHub/structural-probes/example/results/test{i}/demo.tikz"
    orig, *new = parse_full_tikz(file)

    test = f"/Users/simonchervenak/Documents/GitHub/structural-probes/test{i}"
    orig_sent, *new_sents = open(test).read().split("\n")
    orig_sent = orig_sent.split()
    new_sents = [new_sent.split() for new_sent in new_sents]
    
    for n in new:
        

In [None]:
doc = nlp("The chef who ran to the stores is out of food.")
displacy.serve(doc, style="dep", auto_select_port=True)


Using the 'dep' visualizer
Serving on http://0.0.0.0:5001 ...



In [16]:
doc = nlp("The Eritrean rubric was indicated to match buying inactivation for skin beasts.")
displacy.serve(doc, style="dep", auto_select_port=True)


Using the 'dep' visualizer
Serving on http://0.0.0.0:5001 ...

Shutting down server on port 5001.
