In [6]:
import pandas as pd
import numpy as np
from probe_lm import data_utils
from probe_lm import plotting_utils
import pickle

### 1. Load country names and iso3 codes from the kaggle/wiki table

We then shorten the list by keeping only those synonyms that will be useful for us.

In [7]:
country_aliases = pd.read_csv('../data/country_aliases.csv')

In [8]:
country_aliases.head()

Unnamed: 0,iso3,Alias,AliasDescription
0,,Abkhazia,"common, English"
1,,Republic of Abkhazia,"official, English"
2,,Aphsny Axwynthkharra,"official, Abkhaz"
3,,Respublika Abkhaziya,"official, Russian"
4,,Autonomous Republic of Abkhazia,"Internationally recognized, English"


In [9]:
# keepterms = ['English', 'official', 'alternative', 'initialism', 'common']
keep = country_aliases['AliasDescription'].str.contains('English') | (country_aliases['AliasDescription'] == 'official') | (country_aliases['AliasDescription'] == 'alternative') | (country_aliases['AliasDescription'] == 'initialism') | (country_aliases['AliasDescription'] == 'common') 
country_aliases_new = country_aliases[keep]

In [10]:
country_aliases_new = country_aliases_new.drop(index=[0, 1, 4, 5])

In [11]:
exclude = ['archaic', 'former', 'Esperanto', 'geographical']

In [12]:
i = 400
country_aliases_new.iloc[i :i+10, :]

Unnamed: 0,iso3,Alias,AliasDescription
906,ZWE,Zimbabwe,common
907,ZWE,Rhodesia or Republic of Rhodesia,"former names, English"
908,ZWE,Republic of Zimbabwe,"official, English"
909,ZWE,Southern Rhodesia,"former, English"


In [13]:
drop = country_aliases['AliasDescription'].str.contains('archaic') | (country_aliases['AliasDescription'].str.contains('former')) | (country_aliases['AliasDescription'].str.contains('Esperanto')) | (country_aliases['AliasDescription'].str.contains('geographical'))

In [14]:
country_aliases_new = country_aliases_new[~ drop]

  country_aliases_new = country_aliases_new[~ drop]


In [15]:
country_aliases_new.iloc[:-20, :]

Unnamed: 0,iso3,Alias,AliasDescription
6,AFG,Islamic Republic of Afghanistan,"official, English"
9,ALB,Albania,"common, English"
10,ALB,Republic of Albania,"official, English"
14,DZA,Algeria,"common, English"
15,DZA,People's Democratic Republic of Algeria,"official, English"
...,...,...,...
849,ARE,The Emirates,"colloquial, English"
851,GBR,United Kingdom of Great Britain and Northern I...,"official, English"
852,GBR,Britain,alternative
853,GBR,Great Britain,alternative


In [16]:
country_aliases_new.loc[874, 'Alias'] = 'U.S.A.'

In [17]:
country_aliases_new.iloc[:325]

Unnamed: 0,iso3,Alias,AliasDescription
6,AFG,Islamic Republic of Afghanistan,"official, English"
9,ALB,Albania,"common, English"
10,ALB,Republic of Albania,"official, English"
14,DZA,Algeria,"common, English"
15,DZA,People's Democratic Republic of Algeria,"official, English"
...,...,...,...
870,USA,US,initialism
871,USA,U.S.,initialism
872,USA,USA,initialism
873,USA,Usa,initialism


### 2. Load ground truth and fill in iso3 codes by matvhing with table

In [18]:
ground_truth_file= "../data/english-speaking-population-data.txt"
ground_truth = data_utils.read_population_data(ground_truth_file)

In [19]:
ground_truth_df = pd.DataFrame(ground_truth.items(), columns=['country names', 'population'])
ground_truth_df[['iso3']] = None

In [20]:
ground_truth_df.head()

Unnamed: 0,country names,population,iso3
0,the United States,316107532,
1,India,128539090,
2,Pakistan,115044691,
3,Nigeria,103198040,
4,the Philippines,64025890,


In [21]:
for i, row1 in ground_truth_df.iterrows():
    for j, row2 in country_aliases_new.iterrows():
        if row2['Alias'] in row1['country names']:
            ground_truth_df.loc[i, 'iso3'] = row2['iso3']

In [22]:
unmatched = ground_truth_df[ground_truth_df['iso3'].isna()]
unmatched_ix = unmatched.index

In [23]:
manual_match = {'Mexico': 'MEX', 'Bahamas': 'BHS', 'the United States':'USA', 'the United Kingdom': 'GBR', 'Kenya':'KEN'}

In [24]:
for i, row1 in ground_truth_df.iterrows():
    country = row1['country names']
    if country in manual_match.keys():
        ground_truth_df.loc[i, 'iso3'] = manual_match[country]

### 3. Propagate alternative name list by matching on iso3

In [25]:
def add_article(word):
    requires_article = ['Republic', 'United', 'Principality', 'US', 'U.S.A.', 'USA', 'Usa', 'U.S.'
                        'Federation', 'Dominion', 'Islands', 'UK', 'U.K.', 'FRG', 'GDR', 'Commonwealth', 
                        'Coast', 'Federation' ,'RF', 'Kingdom', 'Netherlands', 'State', 'Confederation', 
                        'Realm', 'CSR', 'SR', 'SSR', 'RS', 'Grand Duchy of Luxembourg', 'Nation']
    for candidate in requires_article:
        if (candidate in word and not 'the' in word.lower()) or word == 'United Kingdom of Great Britain and Northern Ireland' or word == 'Kingdom of the Netherlands' or word == 'U.S.': 
            article = 'the '
            break
        else:
            article = ''
    return article
    

In [26]:
ground_truth_df_new = ground_truth_df.__deepcopy__()

for i, row1 in ground_truth_df.iterrows():
    # initialise empty list
    ground_truth_df_new.loc[i, 'country names'] = [ground_truth_df_new.loc[i, 'country names']]
    for j, row2 in country_aliases_new.iterrows():
        if row1['iso3'] == row2['iso3'] and not row2['Alias'] in row1['country names']:
            article = add_article(row2['Alias'])
            ground_truth_df_new.loc[i, 'country names'].append(article + row2['Alias'])

In [27]:
ground_truth_df_new

Unnamed: 0,country names,population,iso3
0,"[the United States, the United States of Ameri...",316107532,USA
1,"[India, the Republic of India]",128539090,IND
2,"[Pakistan, the Islamic Republic of Pakistan, t...",115044691,PAK
3,[Nigeria],103198040,
4,"[the Philippines, Republic of the Philippines,...",64025890,PHL
...,...,...,...
122,"[Andorra, the Principality of Andorra, Princip...",17869,AND
123,[Anguilla],12000,
124,[Nauru],11600,
125,[the Cook Islands],4000,


In [28]:
try:
    ground_truth_df_new.iloc[6]['country names'].remove('GDR') 
except:
    print('already removed!')
try:
    ground_truth_df_new.iloc[6]['country names'].remove('the German Democratic Republic')
except:
    print('already removed!')
try:
    ground_truth_df_new.iloc[100]['country names'].remove('the United States of America')
    ground_truth_df_new.iloc[100]['country names'].remove('America')
    ground_truth_df_new.iloc[100]['country names'].remove('the States')
    ground_truth_df_new.iloc[100]['country names'].remove('the US')
    ground_truth_df_new.iloc[100]['country names'].remove('the USA')
    ground_truth_df_new.iloc[100]['country names'].remove('the Usa')
    ground_truth_df_new.iloc[100]['country names'].remove('the U.S.A.')
except:
    print('already removed!')
try:
    ground_truth_df_new.iloc[111]['country names'].remove('the U.S.')
    ground_truth_df_new.iloc[111]['country names'].remove('the United States of America')
    ground_truth_df_new.iloc[111]['country names'].remove('the States')
    ground_truth_df_new.iloc[111]['country names'].remove('the US')
    ground_truth_df_new.iloc[111]['country names'].remove('the USA')
    ground_truth_df_new.iloc[111]['country names'].remove('the Usa')
    ground_truth_df_new.iloc[111]['country names'].remove('the U.S.A.')
    ground_truth_df_new.iloc[111]['country names'].remove('the U.S.')

except:
    print('already removed!')

already removed!
already removed!


In [29]:
ground_truth_df_new.iloc[111]['country names']


['American Samoa']

In [30]:
for i in range(120, 127):
    print(ground_truth_df_new.iloc[i]['country names'])

['the British Virgin Islands']
['Palau', 'the Republic of Palau']
['Andorra', 'the Principality of Andorra', 'Principality of the Valleys of Andorra']
['Anguilla']
['Nauru']
['the Cook Islands']
['Montserrat']


In [31]:
ground_truth_df_new.drop(columns='iso3',inplace=True)

In [32]:
alternative_names_dict = {}
for i, row in ground_truth_df_new.iterrows():
    alternative_names_dict[row['country names'][0]] = row['country names'][1:]


In [33]:
alternative_names_dict

{'the United States': ['the United States of America',
  'America',
  'the States',
  'the US',
  'the U.S.',
  'the USA',
  'the Usa',
  'the U.S.A.'],
 'India': ['the Republic of India'],
 'Pakistan': ['the Islamic Republic of Pakistan',
  'the Federation of Pakistan',
  'the Dominion of Pakistan'],
 'Nigeria': [],
 'the Philippines': ['Republic of the Philippines', 'the Philippine Islands'],
 'the United Kingdom': ['the United Kingdom of Great Britain and Northern Ireland',
  'Britain',
  'Great Britain',
  'the UK',
  'the U.K.'],
 'Germany': ['the Federal Republic of Germany',
  'the FRG',
  'Former East Germany',
  'the GDR'],
 'Uganda': [],
 'Canada': ['the Dominion of Canada'],
 'Egypt': ['the Arab Republic of Egypt'],
 'France': ['the French Republic'],
 'Australia': ['the Commonwealth of Australia'],
 'Bangladesh': ["the People's Republic of Bangladesh"],
 'Ghana': ['the Gold Coast'],
 'Russia': ['the Russian Federation', 'the RF'],
 'Thailand': ['the Kingdom of Thailand'],
 

In [148]:
with open('../data/country_synonyms.p', 'wb') as handle:
    pickle.dump(alternative_names_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [1]:
%load_ext autoreload
%autoreload 2

from probe_lm import get_aggregated_likelihoods
from probe_lm.data_utils import read_population_data
from numpy import testing
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel
import probe_lm
from probe_lm import data_utils, compute_population_probs
from tqdm import tqdm

In [34]:
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prompt = 'I live in'

population_data, synonym_dict = read_population_data(filename="../data/english-speaking-population-data.txt", return_synonym_dict=True, synonym_dict_path='../data/country_synonyms.p')

predictions = compute_population_probs(gpt2_model, gpt2_tokenizer, prompt, population_data)

# when there are no synonyns nothing should change through this
dummy_synonyms_dict = {country: [] for country in population_data.keys()}
predictions_disambiguated = compute_population_probs(gpt2_model, gpt2_tokenizer, prompt, population_data, synonym_dict=synonym_dict)


100%|██████████| 127/127 [00:06<00:00, 20.75it/s]
100%|██████████| 251/251 [00:20<00:00, 12.05it/s]


AssertionError: 

In [40]:
assert (np.array([*predictions_disambiguated.values()]) >= np.array([*predictions.values()])).all()

In [35]:
predictions_disambiguated

{'the United States': 0.009669359615675957,
 'India': 0.0017266347763369991,
 'Pakistan': 0.0006112819298197305,
 'Nigeria': 0.0002793741589962357,
 'the Philippines': 0.00037076029226258277,
 'the United Kingdom': 0.0020306195341048424,
 'Germany': 0.0022493334212433185,
 'Uganda': 0.0001534984179904594,
 'Canada': 0.005756016694246861,
 'Egypt': 0.00024447195556501076,
 'France': 0.002898739036330509,
 'Australia': 0.002847974747801562,
 'Bangladesh': 0.0001512151862773792,
 'Ghana': 0.00015000855738998224,
 'Russia': 0.0007050073398280817,
 'Thailand': 0.0005335020412936055,
 'Italy': 0.0009410956229972008,
 'South Africa': 0.0008327632376641836,
 'Mexico': 0.001243778263396684,
 'Malaysia': 8.349579978406532e-05,
 'Netherlands': 0.0003066289678144462,
 'Poland': 0.0006925579031625839,
 'Sri Lanka': 4.275190805909585e-05,
 'Turkey': 0.0004469416145755299,
 'Zimbabwe': 7.326093982771543e-05,
 'Iraq': 0.00023130205191145556,
 'Brazil': 0.0005863760358695959,
 'Spain': 0.00060516597192

In [4]:
predictions

{'the United States': 0.002329109645529914,
 'India': 0.0017265477344167,
 'Pakistan': 0.0006112716485496883,
 'Nigeria': 0.0002793741589962357,
 'the Philippines': 0.00037059945390708894,
 'the United Kingdom': 0.0001415507714281924,
 'Germany': 0.0022489926527092673,
 'Uganda': 0.0001534984179904594,
 'Canada': 0.005755988014946759,
 'Egypt': 0.000244471462892844,
 'France': 0.00289840264804514,
 'Australia': 0.00284666112876884,
 'Bangladesh': 0.00015121206290489127,
 'Ghana': 0.00013454415971797377,
 'Russia': 0.0007024162220219028,
 'Thailand': 0.0005334901928809319,
 'Italy': 0.0009410956229972008,
 'South Africa': 0.0008327632376641836,
 'Mexico': 0.0012422580980241446,
 'Malaysia': 8.349576504766157e-05,
 'Netherlands': 2.966499486299454e-05,
 'Poland': 0.0006923867788828625,
 'Sri Lanka': 4.27519076270264e-05,
 'Turkey': 0.0004468236852482321,
 'Zimbabwe': 7.321906850335309e-05,
 'Iraq': 0.00023130205191145556,
 'Brazil': 0.0005863760203920109,
 'Spain': 0.0006051419307729444,