In [1]:
import numpy as np
import json

# load and parse dataset

In [2]:
# Shell out to `file` to report the dataset's MIME type and charset; the
# reported charset is what the later open() call passes as `encoding`.
!file countries.json -I

countries.json: text/plain; charset=utf-8


In [3]:
# Read the dataset file as UTF-8 text and decode the JSON document in one go.
with open('countries.json', encoding='utf-8') as dataset_file:
    raw_data = json.loads(dataset_file.read())

In [4]:
# First pass: map every `cca3` code to the official country name so that the
# codes listed under `borders` can be turned into names in the second pass.
cca3_codes_to_country = {
    record['cca3']: record['name']['official'] for record in raw_data
}

# Second pass: keep one flat dict per country, dropping any entry whose name,
# capital, region or subregion is empty.
countries = []
for record in raw_data:
    official_name = record['name']['official']
    capital, region, subregion = (
        record[key] for key in ('capital', 'region', 'subregion')
    )
    # Border codes are resolved before the emptiness check, so an unknown code
    # raises KeyError even for an entry that would be skipped.
    border_names = [cca3_codes_to_country[code] for code in record['borders']]

    required_fields = (official_name, capital, region, subregion)
    if any(len(value) == 0 for value in required_fields):
        print("skipping", official_name)
        continue

    # NOTE(review): `border_names` may mention countries that were skipped
    # above; confirm whether such neighbors should be kept.
    countries.append({
        'name': official_name,
        'capital': capital,
        'subregion': subregion,
        'region': region,
        'neighbors': border_names,
    })

skipping Antarctica
skipping Territory of the French Southern and Antarctic Lands
skipping Bouvet Island
skipping Heard Island and McDonald Islands
skipping Macao Special Administrative Region of the People's Republic of China
skipping United States Minor Outlying Islands


In [5]:
# Display a single parsed record to eyeball the layout of the entries.
countries[5]

{'name': 'Republic of Albania',
 'capital': 'Tirana',
 'subregion': 'Southern Europe',
 'region': 'Europe',
 'neighbors': ['Montenegro',
  'Hellenic Republic',
  'Republic of Macedonia',
  'Republic of Kosovo']}

# generate splits

In [6]:
# One set of (subject, relation, object) triples per split. Sets are used so
# that adding the same triple twice is harmless.
train, valid, test = set(), set(), set()

# Separate collections for the reversed-neighbor triples and the
# country -> region triples that the split-generation loop records.
neighbor_rule, located_in_rule = set(), set()

In [7]:
import random

# A dedicated generator with a fixed seed keeps the country order -- and thus
# which half of the list feeds valid vs. test -- repeatable.
# NOTE: the shuffle is in place, so re-running this cell without rebuilding
# `countries` first shuffles an already-shuffled list and yields another order.
seeded_rng = random.Random(42)
seeded_rng.shuffle(countries)

In [8]:
n_countries = len(countries)
processed_countries = set()

for i, country in enumerate(countries):
    name = country['name']
    subregion = country['subregion']
    region = country['region']

    # Countries in the first half of the shuffled list feed `valid`, the rest
    # feed `test`.
    held_out = valid if i < n_countries / 2 else test

    # Direct facts always go to the training split.
    train.add((name, 'located_in', subregion))
    train.add((subregion, 'located_in', region))

    processed_countries.add(name)

    for neighbor in country['neighbors']:
        # Each border is handled once, from whichever endpoint is seen first.
        if neighbor in processed_countries:
            continue

        train.add((name, 'is_neighbor_of', neighbor))

        # The reversed direction is held out and also recorded for the
        # neighbor rule.
        reversed_triple = (neighbor, 'is_neighbor_of', name)
        held_out.add(reversed_triple)
        neighbor_rule.add(reversed_triple)

    # country -> region is never in train directly; train only has
    # country -> subregion and subregion -> region.
    region_triple = (name, 'located_in', region)
    held_out.add(region_triple)
    located_in_rule.add(region_triple)

In [9]:
# `train` is a set, so this shows 50 triples in whatever order the set happens
# to iterate -- an arbitrary sample rather than a stable "last 50".
list(train)[-50:]

[('Republic of Chad', 'located_in', 'Middle Africa'),
 ('Central African Republic', 'is_neighbor_of', 'Republic of the Sudan'),
 ('Swiss Confederation', 'is_neighbor_of', 'Federal Republic of Germany'),
 ('Kingdom of Thailand', 'located_in', 'South-Eastern Asia'),
 ('Republic of Cuba', 'located_in', 'Caribbean'),
 ('Republic of Angola', 'is_neighbor_of', 'Republic of Namibia'),
 ('Republic of Guatemala', 'is_neighbor_of', 'Republic of El Salvador'),
 ('Sahrawi Arab Democratic Republic',
  'is_neighbor_of',
  'Islamic Republic of Mauritania'),
 ('Swiss Confederation', 'is_neighbor_of', 'French Republic'),
 ('Italian Republic', 'is_neighbor_of', 'French Republic'),
 ('Polynesia', 'located_in', 'Oceania'),
 ('State of Eritrea', 'located_in', 'Eastern Africa'),
 ('Republic of Nicaragua', 'is_neighbor_of', 'Republic of Honduras'),
 ('Republic of Mali', 'is_neighbor_of', 'Republic of Niger'),
 ('Guiana', 'is_neighbor_of', 'Republic of Suriname'),
 ('Territory of the Wallis and Futuna Islands

In [10]:
# Print the number of triples in each split, one per line: train, valid, test.
for split in (train, valid, test):
    print(len(split))

589
351
215


# check that splits are mutually exclusive

In [11]:
# Report every training triple that also appears in a held-out split.
# No output means `train` is disjoint from both `valid` and `test`.
for triple in train:
    if triple in valid:
        print("valid", triple)
    if triple in test:
        # The label names the split the triple was found in.
        print("test", triple)

In [12]:
# Report every validation triple that also appears in `train` or `test`,
# prefixed with the name of the split it collides with. Silence means disjoint.
for triple in valid:
    for label, other_split in (("train", train), ("test", test)):
        if triple in other_split:
            print(label, triple)

In [13]:
# Report every test triple that also appears in `train` or `valid`,
# prefixed with the name of the split it collides with. Silence means disjoint.
for triple in test:
    for label, other_split in (("train", train), ("valid", valid)):
        if triple in other_split:
            print(label, triple)

# save splits as .tsv

In [14]:
# Write the training triples as tab-separated lines (subject, relation, object).
# A set iterates in an arbitrary order that can change between kernel runs
# (string hashing is randomized per process), so sort first to keep the
# exported file stable across re-runs.
with open("train.tsv", "w", encoding='utf-8') as f:
    for triple in sorted(train):
        f.write("{}\t{}\t{}\n".format(*triple))

In [15]:
# Write the validation triples as tab-separated lines (subject, relation, object).
# A set iterates in an arbitrary order that can change between kernel runs
# (string hashing is randomized per process), so sort first to keep the
# exported file stable across re-runs.
with open("valid.tsv", "w", encoding='utf-8') as f:
    for triple in sorted(valid):
        f.write("{}\t{}\t{}\n".format(*triple))

In [16]:
# Write the test triples as tab-separated lines (subject, relation, object).
# A set iterates in an arbitrary order that can change between kernel runs
# (string hashing is randomized per process), so sort first to keep the
# exported file stable across re-runs.
with open("test.tsv", "w", encoding='utf-8') as f:
    for triple in sorted(test):
        f.write("{}\t{}\t{}\n".format(*triple))

In [17]:
# Write the reversed-neighbor triples as tab-separated lines.
# A set iterates in an arbitrary order that can change between kernel runs
# (string hashing is randomized per process), so sort first to keep the
# exported file stable across re-runs.
with open("neighbor_rule.tsv", "w", encoding='utf-8') as f:
    for triple in sorted(neighbor_rule):
        f.write("{}\t{}\t{}\n".format(*triple))

In [18]:
# Write the country -> region triples as tab-separated lines.
# A set iterates in an arbitrary order that can change between kernel runs
# (string hashing is randomized per process), so sort first to keep the
# exported file stable across re-runs.
with open("located_in_rule.tsv", "w", encoding='utf-8') as f:
    for triple in sorted(located_in_rule):
        f.write("{}\t{}\t{}\n".format(*triple))