In [18]:
import os
# Root directory of the datasets, relative to the notebook's working directory.
# It is joined with a dataset name and a file name when counting lines, and
# passed as `dir_path` when loading a split through ICEWSDataset.
DATA_PATH = '../data'
import matplotlib.pyplot as plt
import numpy as np

# Names of the dataset sub-directories looked up under DATA_PATH.
datasets = ['ICEWS14', 'ICEWS05-15', 'ICEWS18']
# Base names (the '.txt' suffix is appended on use) of the files whose lines are
# counted for every dataset. Note this covers the three splits *and* the two
# id-mapping files, so it is not only a list of data splits.
sets = ['train', 'valid', 'test', 'relation2id', 'entity2id']

In [19]:
def count_lines(f):
    """Return the number of lines in the text file at path ``f``.

    The file is streamed line by line instead of being loaded with
    ``readlines()``, so memory use stays constant for large files.

    Args:
        f: Path of the file to read (opened in text mode).

    Returns:
        int: Number of lines; 0 for an empty file. A final line without a
        trailing newline still counts as one line.
    """
    # Use a separate name for the handle so the path argument is not shadowed.
    with open(f, 'r') as handle:
        return sum(1 for _ in handle)
    
def get_num_lines(dataset):
    """Count the lines of every file named in ``sets`` for one dataset.

    Args:
        dataset: Name of the dataset directory located under ``DATA_PATH``.

    Returns:
        dict: Maps each entry of ``sets`` to the line count of
        ``<DATA_PATH>/<dataset>/<entry>.txt``.
    """
    dataset_dir = os.path.join(DATA_PATH, dataset)
    return {
        name: count_lines(os.path.join(dataset_dir, name + '.txt'))
        for name in sets
    }

In [20]:
# Print, for every dataset, the line count of each file listed in `sets`.
# The loop index was never used, so iterate over the names directly.
for dataset in datasets:
    print(dataset, get_num_lines(dataset))

ICEWS14 {'train': 74845, 'valid': 8514, 'test': 7371, 'relation2id': 230, 'entity2id': 7128}
ICEWS05-15 {'train': 368868, 'valid': 46302, 'test': 46159, 'relation2id': 251, 'entity2id': 10488}
ICEWS18 {'train': 373018, 'valid': 45995, 'test': 49545, 'relation2id': 256, 'entity2id': 23033}


In [21]:
from icews_utils import ICEWSDataset
from collections import defaultdict

# Load the ICEWS14 'test' file and keep its `.data` attribute.
# NOTE(review): idx=[0, 7370] matches the 7371 lines counted for ICEWS14 test
# above, so it presumably selects the whole split as an inclusive row range —
# confirm against ICEWSDataset, and update it if the file changes.
test = ICEWSDataset(dir_path=DATA_PATH, dataset_name='ICEWS14', filename='test', idx=[0, 7370]).data

In [22]:
# Question templates used to turn test facts into yes/no questions.
# Keys are relation names, compared against field 1 of each test row; rows whose
# relation is not a key here produce no question. Each relation has three
# paraphrases, one of which is picked at random per fact and filled in with
# str.format(subject=..., object=..., date=...).
templates = {
    "Make_statement": [
        "Will {subject} make a statement about {object} on {date}?",
        "Is {subject} expected to make a statement about {object} on {date}?",
        "Do you think {subject} will make a statement about {object} on {date}?"
    ],
    "Make_an_appeal_or_request": [
        "Will {subject} make an appeal or request to {object} on {date}?",
        "Is {subject} expected to make an appeal or request to {object} on {date}?",
        "Do you think {subject} will make an appeal or request to {object} on {date}?"
    ],
    "Consult": [
        "Will {subject} consult {object} on {date}?",
        "Is {subject} expected to consult {object} on {date}?",
        "Do you think {subject} will consult {object} on {date}?"
    ],
    "Arrest,_detain,_or_charge_with_legal_action": [
        "Will {subject} arrest, detain, or charge {object} with legal action on {date}?",
        "Is {subject} expected to take any legal action against {object} on {date}?",
        "Do you think {subject} will arrest, detain, or charge {object} with legal action on {date}?"
    ]
}

In [11]:
# Collect the ICEWS14 entity names: the first whitespace-separated field of
# every non-blank line in entity2id.txt.
# The path is built from DATA_PATH (like the other cells) instead of being
# hard-coded, and the file is streamed rather than loaded with readlines().
entities = []
with open(os.path.join(DATA_PATH, 'ICEWS14', 'entity2id.txt'), 'r') as f:
    for line in f:
        fields = line.split()
        # A blank (or whitespace-only) line has no fields; indexing [0] on it
        # would raise IndexError, so skip it.
        if fields:
            entities.append(fields[0])

print(entities[:10])

['China', 'Iran', 'Citizen_(Nigeria)', 'Citizen_(India)', 'Barack_Obama', 'Japan', 'John_Kerry', 'South_Korea', 'Iraq', 'Government_(Nigeria)']


In [23]:
import random
import json
import os

# Build (or reload) the yes/no question set for test facts dated TARGET_DATE.
#
# Every matching fact yields two entries sharing the same template: a 'Yes'
# question with the true object and a 'No' question with a random other entity.
#
# The cache is read from and written to the SAME file. Previously the check used
# 'qa_pairs.json' while the dump went to 'qa_pairs1.json', so the cache was never
# populated and the questions were re-sampled on every run.
QA_PAIRS_PATH = 'qa_pairs.json'
TARGET_DATE = '2014-12-31'
QA_SEED = 0  # fixed seed so a regenerated question set is reproducible

qa_pairs = []

if os.path.exists(QA_PAIRS_PATH):
    with open(QA_PAIRS_PATH, 'r', encoding='utf-8') as f:
        qa_pairs = json.load(f)
else:
    # Private generator: reproducible without touching the global random state.
    rng = random.Random(QA_SEED)

    # All objects recorded for each (subject, relation, date) in the test data.
    # A corrupted object must avoid every one of them, not just the current
    # row's object, otherwise a 'No' question could describe a true fact.
    true_objects = {}
    for r in test:
        true_objects.setdefault((r[0], r[1], r[3]), set()).add(r[2])

    for r in test:
        if r[3] == TARGET_DATE and r[1] in templates:
            template = rng.choice(templates[r[1]])
            yes_qa_pair = {
                'question': template.format(subject=r[0], object=r[2], date=r[3]),
                'answer': 'Yes'
            }

            # Rejection sampling; this never terminates if every entity is a
            # known true object for the key (the original loop had the
            # equivalent limitation for a single-entity list).
            known_objects = true_objects[(r[0], r[1], r[3])]
            rand_entity = rng.choice(entities)
            while rand_entity in known_objects:
                rand_entity = rng.choice(entities)

            no_qa_pair = {
                'question': template.format(subject=r[0], object=rand_entity, date=r[3]),
                'answer': 'No'
            }

            qa_pairs.append(yes_qa_pair)
            qa_pairs.append(no_qa_pair)

    # ensure_ascii=False emits raw non-ASCII characters, so the encoding is set
    # explicitly; the context manager guarantees the file is flushed and closed.
    with open(QA_PAIRS_PATH, 'w', encoding='utf-8') as f:
        json.dump(qa_pairs, f, indent=4, ensure_ascii=False)


In [7]:
import jsonlines

def get_answers(f):
    """Read model responses from a JSON Lines file and map them to 'Yes'/'No'.

    Each record must have a 'response' field. After leading whitespace is
    dropped, the response is matched case-insensitively by prefix: 'yes...'
    becomes 'Yes' and 'no...' becomes 'No'. This is a plain prefix test, so
    e.g. 'Not sure' is also classified as 'No'.

    A response that matches neither prefix is printed for inspection and
    recorded as ``None``. Keeping a placeholder preserves one entry per record,
    so the result stays index-aligned with the question list; dropping the
    entry would shift every later answer by one position.

    Args:
        f: Path of the JSON Lines file to read.

    Returns:
        list: One item per record, each 'Yes', 'No' or ``None``.
    """
    answers = []
    with jsonlines.open(f) as reader:
        for obj in reader:
            response = obj['response']
            normalized = response.lstrip().lower()
            if normalized.startswith('yes'):
                answers.append('Yes')
            elif normalized.startswith('no'):
                answers.append('No')
            else:
                print(response)
                answers.append(None)
    return answers

In [8]:
def get_accuracy(answers, pairs=None):
    """Return the fraction of QA pairs whose expected answer equals the given one.

    Answers are compared position by position with the 'answer' field of each
    pair, so both sequences must have the same length and order.

    Args:
        answers: Sequence of predicted answers, one per QA pair.
        pairs: Sequence of dicts with an 'answer' key. Defaults to the
            notebook-level ``qa_pairs`` when omitted.

    Returns:
        float: Number of matching positions divided by the number of pairs.

    Raises:
        ValueError: If the lengths differ (the positional comparison would be
            misaligned) or if there are no pairs to score.
    """
    if pairs is None:
        pairs = qa_pairs

    if len(answers) != len(pairs):
        raise ValueError(
            'got %d answers for %d QA pairs; they must align one-to-one'
            % (len(answers), len(pairs))
        )
    if not pairs:
        raise ValueError('cannot compute accuracy over an empty QA set')

    correct = sum(1 for pair, answer in zip(pairs, answers) if pair['answer'] == answer)
    return correct / len(pairs)

In [14]:
# Score the responses of the 2023-04-26 16:28:53 run against the QA set.
answers = get_answers('test_result/output_2023-04-26-16-28-53.txt')
print('acc:', get_accuracy(answers))
# Ratio of 'Yes' to 'No' predictions; report inf instead of raising
# ZeroDivisionError when the model never answered 'No'.
yes_count, no_count = answers.count('Yes'), answers.count('No')
print('y/n ratio', yes_count / no_count if no_count else float('inf'))

acc: 0.5416666666666666
y/n ratio 0.3584905660377358


In [15]:
# Score the responses of the 2023-04-26 19:34:28 run against the QA set.
answers = get_answers('test_result/output_2023-04-26-19-34-28.txt')
print('acc:', get_accuracy(answers))
# Ratio of 'Yes' to 'No' predictions; report inf instead of raising
# ZeroDivisionError when the model never answered 'No'.
yes_count, no_count = answers.count('Yes'), answers.count('No')
print('y/n ratio', yes_count / no_count if no_count else float('inf'))

acc: 0.5069444444444444
y/n ratio 7.470588235294118
