In [1]:
from collections import defaultdict, Counter
from collections import defaultdict as ddict
import itertools
import json
import pickle
import random

import dill
import networkx as nx
import numpy as np
import penman
from torch_geometric.utils import to_networkx
from transformers import AutoTokenizer

from amr.inspection import get_type_paths_for_relation

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset identifiers; each one is also the sub-directory name under data/.
datasets = "risec japflow mscorpus".split()

In [3]:
# Load, per dataset, the pickled AMR content (amr_train.pkl) and the dill-serialized
# preprocessed relation data, then the shared AMR relation vocabulary, its inverse
# (id -> relation), and the BERT tokenizer used later to decode token ids.
# NOTE(review): pickle/dill execute arbitrary code on load — only open trusted files.
amr_content, preprocessed_data = {}, {}

for dataset in datasets:
    amr_path = f"data/{dataset}/amr_train.pkl"
    with open(amr_path, "rb") as f:
        amr_content[dataset] = pickle.load(f)

    prep_path = f"data/{dataset}/data_amr.dill"
    with open(prep_path, "rb") as f:
        preprocessed_data[dataset] = dill.load(f)

with open("data/amr_rel2id.json") as f:
    amr_rel2id = json.load(f)

amr_id2rel = dict(zip(amr_rel2id.values(), amr_rel2id.keys()))

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [7]:
# Peek at the first risec instance: a list of per-sentence {"graph", "text"} dicts,
# where "graph" can be None (e.g. a whitespace-only trailing segment).
first_risec_instance = amr_content["risec"][0]
print(first_risec_instance)

[{'graph': <Graph object (top=z1) at 4384311328>, 'text': '1) In a saucepan over low heat, stir together the half-and-half and sugar.'}, {'graph': <Graph object (top=z1) at 4502514560>, 'text': ' 2) Whisk in egg yolks and cook until light custard forms; do not boil.'}, {'graph': <Graph object (top=z1) at 5729410736>, 'text': ' 3) Remove from heat and cool, then place in refrigerator and chill overnight.'}, {'graph': <Graph object (top=z1) at 5729506736>, 'text': ' 4) Into the chilled mixture, stir in vanilla, cream, dark rum, scotch, and vanilla ice cream. 5) Serve cold.'}, {'graph': None, 'text': ' '}]


In [8]:


# For each dataset, report how many sentence entries have no AMR graph
# (entries whose "graph" is None) out of all sentence entries.
for dataset in datasets:
    sentences = list(itertools.chain.from_iterable(amr_content[dataset]))
    total = len(sentences)
    missing_amrs = sum(1 for sentence in sentences if sentence["graph"] is None)
    print(f"{dataset}: {missing_amrs}/{total} sentences missing AMRs")
    

risec: 122/1006 sentences missing AMRs
japflow: 0/2628 sentences missing AMRs
mscorpus: 66/1308 sentences missing AMRs


In [11]:
# Sanity-check the argument masks on one explicitly chosen preprocessed relation.
# Bind the example explicitly instead of relying on a leftover `instance` global:
# on a fresh kernel that name holds the last AMR sentence list from the
# missing-AMR loop (a list), so `instance["tokens"]` would raise TypeError.
EXAMPLE_DATASET = "risec"
EXAMPLE_INDEX = 0  # NOTE(review): saved output ("stir" / "low heat") came from an unknown relation — confirm index
example_rel = preprocessed_data[EXAMPLE_DATASET]["train"]["rels"][EXAMPLE_INDEX]


def decode_masked_span(rel, mask_key):
    """Decode the tokens of ``rel`` selected by the mask stored under ``mask_key``.

    Args:
        rel: a preprocessed relation dict with a "tokens" entry and a mask entry.
        mask_key: key of the mask to apply (e.g. "arg1_ids", "arg2_ids").

    Returns:
        The tokenizer's decoded string for the masked tokens.

    Assumes ``rel["tokens"]`` and ``rel[mask_key]`` are equal-length sequences, with
    non-zero mask entries marking the argument span — TODO confirm against preprocessing.
    """
    token_ids = np.array(rel["tokens"])
    mask = np.array(rel[mask_key]).astype(bool)
    return tokenizer.decode(token_ids[mask])


print(decode_masked_span(example_rel, "arg1_ids"))
print(decode_masked_span(example_rel, "arg2_ids"))


stir
low heat


In [35]:
# Relation-label frequencies over each dataset's "train" relations.
label_counts = {
    dataset: Counter(rel["label"] for rel in preprocessed_data[dataset]["train"]["rels"])
    for dataset in datasets
}

print(json.dumps(label_counts, indent=4))

{
    "risec": {
        "ArgM_LOC": 231,
        "ArgM_MNR": 227,
        "Arg_PPT": 2219,
        "ArgM_TMP": 440,
        "Arg_DIR": 42,
        "Arg_GOL": 349,
        "ArgM_INT": 26,
        "ArgM_PRP": 73,
        "Arg_PAG": 10,
        "Arg_PRD": 39,
        "ArgM_SIM": 33
    },
    "japflow": {
        "targ": 5458,
        "dest": 1759,
        "other-mod": 2426,
        "f-eq": 980,
        "agent": 600,
        "t-eq": 267,
        "v-tm": 553,
        "t-comp": 557,
        "a-eq": 222,
        "f-part-of": 660,
        "f-comp": 267,
        "t-part-of": 190,
        "f-set": 19
    },
    "mscorpus": {
        "Information_Of": 3138,
        "Coref_Of": 177,
        "Recipe_Precursor": 674,
        "Number_Of": 2127,
        "Next_Operation": 2251,
        "Recipe_Target": 271,
        "Apparatus_Of": 358,
        "Solvent_Material": 355,
        "Participant_Material": 1462,
        "Condition_Of": 1386,
        "Type_Of": 131
    }
}


In [23]:
# Per dataset, collect the AMR type paths of every "train" relation grouped by
# relation label; the extra "ALL" bucket pools the paths of every label.
dataset_to_paths_info = {}

for dataset in datasets:
    paths_by_label = defaultdict(list)
    for rel in preprocessed_data[dataset]["train"]["rels"]:
        label = rel["label"]
        rel_paths = get_type_paths_for_relation(rel["amr_data"])
        for bucket in (label, "ALL"):
            paths_by_label[bucket].extend(rel_paths)
    dataset_to_paths_info[dataset] = dict(paths_by_label)


In [27]:
# Frequency of each complete path (converted to a hashable tuple), per dataset and label.
atomic_counters = {
    dataset: {
        label: Counter(map(tuple, paths))
        for label, paths in dataset_to_paths_info[dataset].items()
    }
    for dataset in datasets
}

In [28]:
# Inspect the pooled "ALL" bucket for risec: number of paths, then full path frequencies.
label = "ALL"
risec_all_paths = dataset_to_paths_info["risec"][label]

print(len(risec_all_paths))
atomic_counters["risec"][label].most_common()

5204


[((':arg1',), 531),
 ((':arg2',), 297),
 ((':op2', ':op1'), 155),
 ((':location',), 132),
 ((':arg1', ':op1'), 132),
 ((':arg1', ':op2'), 132),
 (('STAR', 'STAR'), 118),
 ((':li', 'STAR', 'STAR'), 115),
 ((':time',), 105),
 ((':time', ':op1'), 96),
 ((':arg1', ':mod'), 87),
 ((':arg1', ':op3'), 83),
 ((':li',), 75),
 ((':location', ':mod'), 69),
 ((':arg2', ':mod'), 63),
 ((':arg2', ':arg1'), 60),
 ((':arg1', ':op4'), 58),
 ((':arg1', ':arg1'), 58),
 ((':duration', ':op1', ':unit'), 56),
 ((':purpose',), 54),
 ((':duration', ':op1', ':quant'), 52),
 ((':li', 'STAR', 'STAR', ':op2'), 47),
 ((':manner',), 46),
 ((':op1', 'STAR', 'STAR'), 44),
 ((':arg2', ':quant'), 43),
 ((':arg2', ':op1'), 42),
 ((':duration', ':unit'), 42),
 ((':li', ':arg1'), 41),
 (('STAR', 'STAR', ':op2'), 40),
 ((':duration', ':quant'), 39),
 ((':arg1', ':consist'), 38),
 ((':arg2', ':op2'), 34),
 ((':arg1', ':op5'), 31),
 ((':duration', ':op2'), 30),
 ((':li', ':arg2', ':quant'), 29),
 ((':time', ':op1', ':arg1'),

In [29]:
# Frequency of individual path elements (e.g. ':arg1', 'STAR'), flattened across
# all paths, per dataset and label.
unigram_counters = {}

for dataset in datasets:
    unigram_counters[dataset] = {
        label: Counter(unit for path in paths for unit in path)
        for label, paths in dataset_to_paths_info[dataset].items()
    }

In [31]:
# Path-element frequencies pooled over all risec relation labels.
risec_unigrams = unigram_counters["risec"]["ALL"]
risec_unigrams.most_common()

[(':arg1', 2426),
 (':op1', 1517),
 (':op2', 1051),
 (':arg2', 1007),
 ('STAR', 968),
 (':mod', 562),
 (':li', 544),
 (':quant', 517),
 (':duration', 511),
 (':location', 466),
 (':time', 464),
 (':unit', 288),
 (':op3', 252),
 (':manner', 157),
 (':purpose', 151),
 (':consist', 129),
 (':op4', 123),
 (':arg0', 96),
 (':part', 67),
 (':arg3', 58),
 (':op5', 55),
 (':instrument', 50),
 (':op6', 36),
 (':destination', 35),
 (':arg4', 27),
 (':source', 26),
 (':accompanier', 22),
 (':op7', 19),
 (':degree', 19),
 (':mode', 18),
 (':name', 13),
 (':direction', 10),
 (':frequency', 10),
 (':domain', 9),
 (':condition', 9),
 (':polarity', 8),
 (':op8', 6),
 (':op9', 4),
 (':ord', 4),
 (':extent', 4),
 (':op11', 3),
 (':op10', 2)]

In [18]:
# Buckets (relation labels plus the pooled "ALL" bucket) with unigram counters for risec.
# NOTE(review): the saved output below came from execution In [18], before the
# path cells (In [23]+) added the "ALL" bucket, so it is stale; a re-run also lists "ALL".
risec_unigram_counters = unigram_counters["risec"]
risec_unigram_counters.keys()

dict_keys(['ArgM_LOC', 'ArgM_MNR', 'Arg_PPT', 'ArgM_TMP', 'Arg_DIR', 'Arg_GOL', 'ArgM_INT', 'ArgM_PRP', 'Arg_PAG', 'Arg_PRD', 'ArgM_SIM'])

-