In [1]:
import numpy as np
import json

# load and parse dataset

In [2]:
# Shell out to `file` to report the MIME type and charset of countries.json
# (presumably to confirm the encoding used when the file is opened below).
!file countries.json -I

countries.json: text/plain; charset=utf-8


In [3]:
# Read the whole file as UTF-8 text, then decode the JSON payload in one step.
with open('countries.json', encoding='utf-8') as countries_file:
    raw_data = json.loads(countries_file.read())

In [4]:
# Map each record's 'cca3' code to its official name so that the codes listed
# under 'borders' can be translated into country names below.
cca3_codes_to_country = {
    entry['cca3']: entry['name']['official'] for entry in raw_data
}

# Collect one dict per country, keeping only records whose name, capital,
# region and subregion are all non-empty.
countries = []
for entry in raw_data:
    name = entry['name']['official']
    capital = entry['capital']
    region = entry['region']
    subregion = entry['subregion']
    # Resolved before the emptiness check, so an unknown border code raises
    # KeyError even for a record that would otherwise be skipped.
    neighbors = [cca3_codes_to_country[code] for code in entry['borders']]

    required_fields = (name, capital, region, subregion)
    if any(len(field) == 0 for field in required_fields):
        print("skipping", name)
        continue

    countries.append({
        'name': name,
        'capital': capital,
        'subregion': subregion,
        'region': region,
        'neighbors': neighbors,
    })

skipping Antarctica
skipping Territory of the French Southern and Antarctic Lands
skipping Bouvet Island
skipping Heard Island and McDonald Islands
skipping Macao Special Administrative Region of the People's Republic of China
skipping United States Minor Outlying Islands


# generate splits

In [5]:
# Empty, independent sets that the split-generation cell fills with
# (head, relation, tail) triples: the three data splits first, then the two
# rule collections.
train, valid, test = set(), set(), set()
neighbor_rule, located_in_rule = set(), set()

In [6]:
import random

# A dedicated generator with a fixed seed (42) shuffles `countries` in place;
# the module-level random state is left untouched.
seeded_rng = random.Random(42)
seeded_rng.shuffle(countries)

In [7]:
n_countries = len(countries)
processed_countries = set()

# Walk the (shuffled) countries once and distribute their triples.
for country in countries:
    country_name = country['name']
    country_subregion = country['subregion']
    country_region = country['region']

    # Training data holds the two-step containment chain:
    # country -> subregion and subregion -> region.
    train.add((country_name, 'located_in', country_subregion))
    train.add((country_subregion, 'located_in', country_region))

    # Mark the country as seen before looking at its neighbors, so each border
    # pair is handled only by whichever of its two countries comes first.
    processed_countries.add(country_name)

    for neighbor in country['neighbors']:
        if neighbor in processed_countries:
            continue
        # Forward direction is a training fact; the reverse direction is held
        # out for evaluation and also recorded in neighbor_rule.
        reverse_triple = (neighbor, 'is_neighbor_of', country_name)
        train.add((country_name, 'is_neighbor_of', neighbor))
        valid.add(reverse_triple)
        neighbor_rule.add(reverse_triple)

    # The direct country -> region fact is held out and also recorded in
    # located_in_rule.
    region_triple = (country_name, 'located_in', country_region)
    valid.add(region_triple)
    located_in_rule.add(region_triple)

In [8]:
# Freeze the triple sets into lists in a canonical (sorted) order. Iterating a
# set of string tuples follows the strings' hashes, which CPython randomises
# per process unless PYTHONHASHSEED is fixed. Without the sort, the seeded
# shuffle below would start from a different order after every kernel restart
# and the valid/test split would not be reproducible.
train = sorted(train)
valid = sorted(valid)

# Seeded shuffle of the held-out triples before they are divided.
random.Random(42).shuffle(valid)

# First half stays validation, second half becomes test (test receives the
# extra triple when the count is odd).
# NOTE: this rebinds `valid`, so re-running the cell without re-running the
# generation cell halves the validation set again.
midpoint = len(valid) // 2
valid, test = valid[:midpoint], valid[midpoint:]

In [9]:
# Report the size of each split, one count per line: train, valid, test.
print(len(train), len(valid), len(test), sep="\n")

589
283
283


# check that splits are mutually exclusive

In [10]:
# Report every training triple that also occurs in another split, labelled
# with the name of the split it leaks into. No output means train is disjoint
# from both valid and test.
for triple in train:
    if triple in valid:
        print("valid", triple)
    if triple in test:
        # Label fixed: this branch detects an overlap with the test split
        # (it previously printed "valid").
        print("test", triple)

In [11]:
# Report every validation triple that also occurs in train or test, labelled
# with the name of the other split. No output means no overlap.
for triple in valid:
    for other_name, other_split in (("train", train), ("test", test)):
        if triple in other_split:
            print(other_name, triple)

In [12]:
# Report every test triple that also occurs in train or valid, labelled with
# the name of the other split. No output means no overlap.
for triple in test:
    for other_name, other_split in (("train", train), ("valid", valid)):
        if triple in other_split:
            print(other_name, triple)

# save splits and rule triples as .tsv

In [13]:
# Write the training triples to train.tsv, one tab-separated
# head/relation/tail row per line.
with open("train.tsv", "w", encoding='utf-8') as tsv_out:
    tsv_out.writelines("{}\t{}\t{}\n".format(*row) for row in train)

In [14]:
# Write the validation triples to valid.tsv, one tab-separated
# head/relation/tail row per line.
with open("valid.tsv", "w", encoding='utf-8') as tsv_out:
    tsv_out.writelines("{}\t{}\t{}\n".format(*row) for row in valid)

In [15]:
# Write the test triples to test.tsv, one tab-separated
# head/relation/tail row per line.
with open("test.tsv", "w", encoding='utf-8') as tsv_out:
    tsv_out.writelines("{}\t{}\t{}\n".format(*row) for row in test)

In [16]:
# Write the triples collected in neighbor_rule to neighbor_rule.tsv, one
# tab-separated head/relation/tail row per line.
with open("neighbor_rule.tsv", "w", encoding='utf-8') as tsv_out:
    tsv_out.writelines("{}\t{}\t{}\n".format(*row) for row in neighbor_rule)

In [17]:
# Write the triples collected in located_in_rule to located_in_rule.tsv, one
# tab-separated head/relation/tail row per line.
with open("located_in_rule.tsv", "w", encoding='utf-8') as tsv_out:
    tsv_out.writelines("{}\t{}\t{}\n".format(*row) for row in located_in_rule)