In [19]:
%matplotlib inline

from __future__ import print_function
from sklearn import svm
import os
import gensim
import json
import numpy as np
import matplotlib.pyplot as plt
import random
import requests
from IPython.display import HTML

In [2]:
MODEL = 'GoogleNews-vectors-negative300.bin'
if not os.path.isfile(MODEL):
  with open(MODEL, 'wb') as fout:
    path = 'https://s3.amazonaws.com/dl4j-distribution/%s.gz' % MODEL
    curl = subprocess.Popen(['curl', path],
                           stdout=subprocess.PIPE).stdout
    zcat = subprocess.Popen(['zcat'],
                          stdin=curl,
                          stdout=fout
                         )
    zcat.wait()

In [3]:
model = gensim.models.KeyedVectors.load_word2vec_format(MODEL, binary=True)
model.most_similar(positive=['Germany', 'India', 'Brazil'])

[(u'Argentina', 0.6674915552139282),
 (u'South_Africa', 0.6387202739715576),
 (u'Portugal', 0.6178219318389893),
 (u'China', 0.6163622736930847),
 (u'Poland', 0.6154980063438416),
 (u'Europe', 0.6070772409439087),
 (u'Japan', 0.604107677936554),
 (u'Uruguay', 0.6010408997535706),
 (u'South_Korea', 0.5988144874572754),
 (u'United_States', 0.5964299440383911)]

In [4]:
COUNTRIES = {'Canada': 'ca', 'Turkmenistan': 'tm', 'Vatican': 'va', 'Lithuania': 'lt', 'Cambodia': 'kh',
             'Ethiopia': 'et', 'Aruba': 'aw', 'Swaziland': 'sz', 'Argentina': 'ar', 'Bolivia': 'bo', 'Cameroon': 'cm',
             'Ghana': 'gh', 'Japan': 'jp', 'Slovenia': 'si', 'Guatemala': 'gt', 'Kuwait': 'kw', 'Jordan': 'jo',
             'UAE': 'ae', 'Spain': 'es', 'Western_Sahara': 'eh', 'Liberia': 'lr', 'Maldives': 'mv', 'East_Timor': 'tl',
             'Pakistan': 'pk', 'Oman': 'om', 'Tanzania': 'tz', 'Zambia': 'zm', 'North_Korea': 'kp', 'Albania': 'al',
             'Gabon': 'ga', 'Finland': 'fi', 'Monaco': 'mc', 'Samoa': 'ws', 'Yemen': 'ye', 'Jamaica': 'jm',
             'Greenland': 'gl', 'England': 'uk', 'Ivory_Coast': 'ci', 'Guam': 'gu', 'Uruguay': 'uy', 'India': 'in',
             'Azerbaijan': 'az', 'Solomon_Islands': 'sb', 'Kenya': 'ke', 'Tajikistan': 'tj', 'Turkey': 'tr',
             'Afghanistan': 'af', 'Bangladesh': 'bd', 'Mauritania': 'mr', 'Mongolia': 'mn',
             'France': 'fr', 'Bermuda': 'bm', 'Slovakia': 'sk', 'Somalia': 'so', 'Peru': 'pe', 'Laos': 'la',
             'Norway': 'no', 'Czech_Republic': 'cz', 'Benin': 'bj', 'Cuba': 'cu', 'South_Africa': 'za',
             'Montenegro': 'me', 'Togo': 'tg', 'China': 'cn', 'Armenia': 'am', 'Ukraine': 'ua', 'Bahrain': 'bh',
             'Tonga': 'to', 'French_Guiana': 'gf', 'Libya': 'ly', 'Indonesia': 'id', 'Mauritius': 'mu', 'Sweden': 'se',
             'Belarus': 'by', 'Equatorial_Guinea': 'gq', 'Mali': 'ml', 'Russia': 'ru', 'Bulgaria': 'bg', 'Papua': 'pg',
             'Romania': 'ro', 'Angola': 'ao', 'Chad': 'td', 'Cyprus': 'cy', 'Puerto_Rico': 'pr', 'Malaysia': 'my',
             'Austria': 'at', 'Vietnam': 'vn', 'Mozambique': 'mz', 'Uganda': 'ug', 'Hungary': 'hu', 'Niger': 'ne',
             'Brazil': 'br', 'Dominican_Republic': 'do', 'Guinea': 'gn', 'Panama': 'pa', 'Qatar': 'qa',
             'Luxembourg': 'lu', 'Bahamas': 'bs', 'Ireland': 'ie', 'Nigeria': 'ng', 'Ecuador': 'ec', 'Brunei': 'bn',
             'Australia': 'au', 'Iran': 'ir', 'Algeria': 'dz', 'Svalbard': 'sj', 'Chile': 'cl', 'Belgium': 'be',
             'Thailand': 'th', 'Haiti': 'ht', 'Belize': 'bz', 'Georgia': 'ge', 'Gambia': 'gm', 'Poland': 'pl',
             'Moldova': 'md', 'Morocco': 'ma', 'Croatia': 'hr', 'Switzerland': 'ch', 'Iraq': 'iq', 'Sierra_Leone': 'sl',
             'Portugal': 'pt', 'Estonia': 'ee', 'Kosovo': 'xk', 'Lebanon': 'lb', 'America': 'us', 'Uzbekistan': 'uz',
             'Tunisia': 'tn', 'Djibouti': 'dj', 'Rwanda': 'rw', 'Saudi_Arabia': 'sa', 'Colombia': 'co', 'Burundi': 'bi',
             'Sri_Lanka': 'lk', 'Taiwan': 'tw', 'Fiji': 'fj', 'Barbados': 'bb', 'Madagascar': 'mg', 'Italy': 'it',
             'Virgin_Islands': 'vi', 'Bhutan': 'bt', 'Sudan': 'sd', 'Nepal': 'np', 'Malta': 'mt', 'Malawi': 'mw',
             'Netherlands': 'nl', 'Suriname': 'sr', 'Lesotho': 'ls', 'Venezuela': 've', 'South_Korea': 'kr',
             'Israel': 'il', 'Iceland': 'is', 'Burkina_Faso': 'bf', 'Senegal': 'sn', 'El_Salvador': 'sv',
             'Zimbabwe': 'zw', 'Germany': 'de', 'Denmark': 'dk', 'Kazakhstan': 'kz', 'Philippines': 'ph',
             'Eritrea': 'er', 'Kyrgyzstan': 'kg', 'Bosnia': 'ba', 'New_Zealand': 'nz', 'Macedonia': 'mk',
             'Latvia': 'lv', 'Guyana': 'gy', 'Syria': 'sy', 'Gaza_Strip': 'ps', 'Honduras': 'hn', 'Myanmar': 'mm',
             'Mexico': 'mx', 'Egypt': 'eg', 'Nicaragua': 'ni', 'Singapore': 'sg', 'Serbia': 'rs', 'Botswana': 'bw',
             'Antarctica': 'aq', 'Congo': 'cd', 'Greece': 'gr', 'Paraguay': 'py', 'Namibia': 'na', 'Costa_Rica': 'cr',
             'Comoros': 'km', 'Cayman_Islands': 'ky'}

In [5]:
positive = list(COUNTRIES.keys())
random.shuffle(positive)
positive = positive[:40]

In [6]:
negative = random.sample(model.vocab.keys(), 5000)
negative[:4]

[u'Gush_Shalom_Peace', u'La_Rabassa', u'coagulase', u'RIS_LeaderBoard']

In [7]:
labelled = [(p, 1) for p in positive] + [(n, 0) for n in negative]
random.shuffle(labelled)
X = np.asarray([model[w] for w, l in labelled])
y = np.asarray([l for w, l in labelled])
X.shape, y.shape

((5040, 300), (5040,))

In [8]:
TRAINING_FRACTION = 0.3
cut_off = int(TRAINING_FRACTION * len(labelled))
clf = svm.SVC(kernel='linear')
clf.fit(X[:cut_off], y[:cut_off]) 

SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)

In [9]:
res = clf.predict(X[cut_off:])

missed = [country for (pred, truth, country) in 
 zip(res, y[cut_off:], labelled[cut_off:]) if pred != truth]

100 - 100 * float(len(missed)) / len(res), missed

(99.91496598639456, [('Iraq', 1), ('Iran', 1), ('Gaza_Strip', 1)])

In [10]:
all_predictions = clf.predict(model.syn0)

In [11]:
res = []
for word, pred in zip(model.index2word, all_predictions):
  if pred:
    res.append(word)
    if len(res) == 150:
      break
random.sample(res, 10)

[u'British_Columbia',
 u'Argentina',
 u'Belgium',
 u'South_Dakota',
 u'Venezuela',
 u'New_Zealand',
 u'Japan',
 u'Swedish',
 u'Florida',
 u'Idaho']

In [12]:
url = 'https://docs.google.com/spreadsheets/d/1jqYEIrvgGKc_FE7R5-zqYjwLfOY2ptVFH74sdw7VbOg/pub?gid=0&single=true&output=csv'
country_to_cc = dict(x.split(',') for x in requests.get(url).text.splitlines())
country_to_cc['United_States']

u'us'

In [13]:
countries = list(country_to_cc.keys())
country_to_idx = {country: idx for idx, country in enumerate(countries)}
country_vecs = np.asarray([model[c] for c in countries])

In [14]:
dists = np.dot(country_vecs, country_vecs[country_to_idx['Canada']])
for idx in reversed(np.argsort(dists)[-10:]):
  print(countries[idx], dists[idx])

Canada 7.54402
New_Zealand 3.96197
Finland 3.93924
Puerto_Rico 3.83815
Jamaica 3.81029
Sweden 3.80428
Slovakia 3.70387
Australia 3.6711
Bahamas 3.62404
United_States 3.53743


In [15]:
def rank_countries(term, topn=10):
  if not term in model:
    return []
  vec = model[term]
  dists = np.dot(country_vecs, vec)
  return [(countries[idx], float(dists[idx])) 
          for idx in reversed(np.argsort(dists)[-topn:])]

In [16]:
rank_countries('hockey')

[(u'Canada', 2.5576062202453613),
 (u'Slovakia', 2.456810474395752),
 (u'Finland', 2.244586706161499),
 (u'Sweden', 2.1060357093811035),
 (u'Czech_Republic', 2.088719129562378),
 (u'Latvia', 2.033060312271118),
 (u'Pakistan', 1.8923059701919556),
 (u'Norway', 1.7905339002609253),
 (u'Belarus', 1.7010831832885742),
 (u'Greenland', 1.6949326992034912)]

In [17]:
rank_countries('Canada', topn=0)

[(u'Canada', 7.5440239906311035),
 (u'New_Zealand', 3.9619698524475098),
 (u'Finland', 3.9392404556274414),
 (u'Puerto_Rico', 3.8381450176239014),
 (u'Jamaica', 3.810293436050415),
 (u'Sweden', 3.80427885055542),
 (u'Slovakia', 3.703874111175537),
 (u'Australia', 3.6711010932922363),
 (u'Bahamas', 3.6240415573120117),
 (u'United_States', 3.5374338626861572),
 (u'Barbados', 3.470252275466919),
 (u'Norway', 3.4603371620178223),
 (u'Mexico', 3.426602840423584),
 (u'Argentina', 3.4216275215148926),
 (u'Bermuda', 3.381308078765869),
 (u'Guyana', 3.3389341831207275),
 (u'Colombia', 3.3358325958251953),
 (u'Dominican_Republic', 3.253561019897461),
 (u'Latvia', 3.2421295642852783),
 (u'Chile', 3.229321002960205),
 (u'Switzerland', 3.2072014808654785),
 (u'Netherlands', 3.195124864578247),
 (u'Suriname', 3.181366205215454),
 (u'Costa_Rica', 3.15801739692688),
 (u'Belize', 3.1510772705078125),
 (u'Czech_Republic', 3.143127918243408),
 (u'France', 3.128284454345703),
 (u'Iceland', 3.1238260269165

In [24]:
def GChart(term):
  data = rank_countries(term, topn=0)
  data_by_cc = [[country_to_cc[country], val] for country, val in data]
  data_js = json.dumps([('Country', term)] + data_by_cc)

  code = """
  <script type="text/javascript" src="https://www.gstatic.com/charts/loader.js"></script>
  <div id="regions_div" style="width: 640px; height: 320px;"></div>
  <script type="text/Javascript">
    google.charts.load('upcoming', {'packages':['geochart']});
    google.charts.setOnLoadCallback(drawRegionsMap);
    function drawRegionsMap() {
      var chart = new google.visualization.GeoChart(
          document.getElementById('regions_div'));
      chart.draw(google.visualization.arrayToDataTable(%s), {});
    }
  </script>
  """

  return HTML(code % data_js)

GChart('coffee')

In [None]:
HTML("""
  <div id="a_map" style="width: 640px; height: 320px;"></div>
  <script type="text/Javascript">
    google.charts.load('upcoming', {'packages':['geochart']});
    google.charts.setOnLoadCallback(drawRegionsMap);
    function drawRegionsMap() {
      var chart = new google.visualization.GeoChart(
          document.getElementById('a_map'));
      var data = [['Header', 'Value'], ['ca', 100], ['us', 50]]
      chart.draw(google.visualization.arrayToDataTable(data), {});
    }
  </script>
  """)

In [None]:
data[0]

In [163]:
data_js

'[["bd", 2.4695568084716797], ["lk", 2.2264552116394043], ["np", 2.1970415115356445], ["bt", 2.181537628173828], ["ke", 2.0634021759033203], ["kp", 1.769707202911377], ["mm", 1.7306537628173828], ["et", 1.5849881172180176], ["pk", 1.5726646184921265], ["th", 1.5576823949813843]]'