## Initialization

In [None]:
# libraries
import re
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import plotly.graph_objects as go

from thefuzz import fuzz

In [None]:
# Alliance dyad table. A yearly variant of the same data lives at
# ./version4.1_csv/alliance_v4.1_by_dyad_yearly.csv if per-year rows are needed.
alliance = pd.read_csv(filepath_or_buffer="./version4.1_csv/alliance_v4.1_by_dyad.csv")

## Helper functions

In [None]:
def is_float(x: str):
    """
    Determine if the string `x` can be converted to float
    """
    try:
        float(x)
        return True
    except ValueError:
        return False

def has_char(x: str):
    """
    Return True if `x` contains at least one ASCII letter (a-z or A-Z).

    Non-ASCII letters such as "é" do not count.
    """
    return any('a' <= ch <= 'z' or 'A' <= ch <= 'Z' for ch in x)

def str_best_match(s: str, strings):
    """
    Find the best match of string `s` in `strings`.

    Matching is case-insensitive. Only entries that contain `s`, or are
    contained in `s`, are considered candidates; among those the one with the
    highest ``fuzz.ratio`` wins (on a score tie, the later candidate wins).

    :param s: the string to look up.
    :param strings: iterable of candidate strings (typically a list).
    :return: index into `strings` of the first entry equal (case-insensitively)
        to the best candidate, or ``None`` when no entry overlaps `s`.
    :raises TypeError: if `s` is not a ``str``.
    """
    # explicit check rather than `assert`, which is stripped under `python -O`
    if not isinstance(s, str):
        raise TypeError(f"s must be a str, got {type(s).__name__}")
    s = s.lower()
    lowered = [ss.lower() for ss in strings]
    candidates = [ss for ss in lowered if s in ss or ss in s]
    if not candidates:
        return None
    if len(candidates) == 1:
        # nothing to rank, so skip the fuzzy scoring
        return lowered.index(candidates[0])
    scored = [(fuzz.ratio(s, ss), ss) for ss in candidates]
    # max() keeps the first maximum it sees; iterating in reverse therefore
    # selects the LAST best-scoring candidate, i.e. the same tie-breaking as
    # the stable `sorted(...)[-1]` this replaces.
    _, best = max(reversed(scored), key=lambda pair: pair[0])
    return lowered.index(best)


## Alliance Map Class

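Create a map class that wraps a Plotly `go.Figure` (it does not subclass `Basemap`), with helper methods to look up country coordinates and to draw alliance links between countries.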

In [None]:

class MyBasemap(object):
    """
    Plotly world map for drawing alliances between countries.

    Holds a ``go.Figure`` in ``self.fig`` and a table of country attributes
    (``longitude``, ``latitude``, ``BGN_name``, ``BGN_proper``) read from
    ``./version4.1_csv/cow.txt``. Country names from the alliance data are
    matched fuzzily (via ``str_best_match``) against the BGN names to find
    their coordinates.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: the base class is `object`, so any extra args raise TypeError here.
        super().__init__(*args, **kwargs)
        self.fig = go.Figure()
        # load country attributes such as longitude, latitude, and BGN names
        self.countries_data = pd.read_csv("./version4.1_csv/cow.txt", sep=";", skiprows=28)
        # the raw file pads the name columns with spaces
        for column in ("BGN_proper", "BGN_name"):
            self.countries_data[column] = self.countries_data[column].apply(lambda x: x.strip(" "))
        self.decide_countries_on_map()
        # Both name lists are positionally aligned with the (filtered) rows of
        # self.countries_data; best_match_country relies on that alignment.
        self.BGN_name = self.countries_data.BGN_name.tolist()
        self.BGN_proper = self.countries_data.BGN_proper.tolist()
        self.countries_name = self.BGN_name + self.BGN_proper

    def decide_countries_on_map(self):
        """
        Decide which countries can be labelled on the map.

        Keeps only the rows of ``self.countries_data`` whose ``BGN_proper``
        contains at least one ASCII letter, and stores the lowercased
        ``BGN_proper`` and ``BGN_name`` of those rows in
        ``self.countries_on_map`` (deduplicated, in row order).
        """
        data = self.countries_data
        # positional boolean mask: correct whatever the frame's index labels are
        keep = [has_char(proper.lower()) for proper in data["BGN_proper"]]
        kept = data.loc[keep]

        names = []
        for proper, bgn_name in zip(kept["BGN_proper"], kept["BGN_name"]):
            names.append(proper.lower())
            names.append(bgn_name.lower())
        # dict.fromkeys dedupes while preserving insertion order. A set's order
        # depends on the per-process string-hash seed, which made fuzzy-match
        # tie-breaking (and thus the chosen coordinates) vary between runs.
        self.countries_on_map = list(dict.fromkeys(names))
        self.countries_data = kept

    def best_match_country(self, country: str):
        """
        Find the best match of `country` in ``self.countries_name``.

        :return: ``(row_position, matched_name)``, where ``row_position`` is a
            positional index into ``self.countries_data``, or ``None`` when no
            name overlaps `country`.
        """
        match_idx = str_best_match(country.lower(), self.countries_name)
        if match_idx is None:
            return None
        name = self.countries_name[match_idx]
        # countries_name is BGN_name followed by BGN_proper; fold indices from
        # the second half back onto row positions.
        if match_idx >= len(self.BGN_name):
            match_idx -= len(self.BGN_name)
        return match_idx, name

    def label_countries(self):
        """
        Add a marker for every country in ``self.countries_data`` at its
        longitude/latitude, with ``BGN_proper`` shown as hover text.
        """
        country_label_obj = self.scatter_points(self.countries_data, 'longitude', 'latitude', 'BGN_proper')
        self.fig.add_trace(country_label_obj)

    @staticmethod
    def scatter_points(df, lon_key, lat_key, text_key):
        """
        Build a ``Scattergeo`` marker trace from columns of `df`.

        :param df: table holding the points.
        :param lon_key: column with longitudes.
        :param lat_key: column with latitudes.
        :param text_key: column used as hover text.
        """
        graph_object = go.Scattergeo(
            lon = df[lon_key],
            lat = df[lat_key],
            hoverinfo = 'text',
            text = df[text_key],
            mode = 'markers',
            marker = dict(
                size = 2,
                color = 'rgb(255, 0, 0)',
                line = dict(
                    width = 3,
                    color = 'rgba(68, 68, 68, 0)'
                )
            )
        )
        return graph_object

    def get_country_row(self, country):
        """
        Return the ``self.countries_data`` row for `country`, or ``None``.

        The row is only returned when the match against
        ``self.countries_on_map`` and the match against
        ``self.countries_name`` agree on the same (lowercased) name.
        """
        ct = country.lower()
        match_i = str_best_match(ct, self.countries_on_map)
        if match_i is None:
            return None
        match = self.best_match_country(ct)
        if match is None:
            return None
        match_idx, name = match
        if name.lower() != self.countries_on_map[match_i]:
            return None
        return self.countries_data.iloc[match_idx]

    def _coordinates(self, country):
        """Return ``(longitude, latitude)`` of `country`, or ``(None, None)`` if not found."""
        row = self.get_country_row(country)
        if row is None:
            return None, None
        return row.longitude.item(), row.latitude.item()

    def get_longitude(self, country):
        """Longitude of `country`, or ``None`` if it cannot be matched."""
        return self._coordinates(country)[0]

    def get_latitude(self, country):
        """Latitude of `country`, or ``None`` if it cannot be matched."""
        return self._coordinates(country)[1]

    def connect_countries(self, year, key="dyad_st_year", data=None):
        """
        Draw one line per alliance dyad whose `key` column equals `year`.

        Each dyad is drawn once, in the colour of the first matching type:
        red for ``defense == 1``; otherwise green for ``neutrality == 1``;
        otherwise blue for ``nonaggression == 1``; otherwise orange for
        ``entente == 1``. Dyads where either state cannot be located are dropped.

        :param year: value selected in column `key`.
        :param key: column of the alliance table to filter on.
        :param data: alliance table; defaults to the module-level ``alliance``.
        """
        if data is None:
            data = alliance
        df = data.loc[data[key] == year].copy()

        # Resolve each distinct state name once: the fuzzy lookup is expensive
        # and the same states recur across many dyads.
        names = pd.unique(pd.concat([df['state_name1'], df['state_name2']]))
        lon_of, lat_of = {}, {}
        for name in names:
            lon_of[name], lat_of[name] = self._coordinates(name)
        df['start_lon'] = df['state_name1'].map(lon_of)
        df['end_lon'] = df['state_name2'].map(lon_of)
        df['start_lat'] = df['state_name1'].map(lat_of)
        df['end_lat'] = df['state_name2'].map(lat_of)
        df = df.loc[df['start_lon'].notnull() & df['end_lon'].notnull()]

        no_defense = df['defense'] == 0
        no_neutrality = df['neutrality'] == 0
        no_nonaggression = df['nonaggression'] == 0
        self._connect(df.loc[df['defense'] == 1], 'rgb(255,0,0)')
        self._connect(df.loc[no_defense & (df['neutrality'] == 1)], 'rgb(0, 255, 0)')
        self._connect(df.loc[no_defense & no_neutrality & (df['nonaggression'] == 1)], 'rgb(0, 0, 255)')
        # `== 1` must be parenthesised: `&` binds tighter than `==`.
        self._connect(
            df.loc[no_defense & no_neutrality & no_nonaggression & (df['entente'] == 1)],
            'rgb(255,165,0)'
        )

    def _connect(self, df, color='#0099C6'):
        """
        Add one ``Scattergeo`` line trace joining each row's start and end point.

        Coordinates are laid out as ``start, end, NaN`` triples so that a single
        trace draws many separate segments. A trace is added even when `df` is
        empty.
        """
        # NOTE(review): an early return for empty `df` was deliberately disabled
        # in the original; presumably so every year's figure has the same number
        # of traces for the animation frames — confirm before changing.
        n_rows = len(df)
        lons = np.empty(3 * n_rows)
        lons[0::3] = df['start_lon']
        lons[1::3] = df['end_lon']
        lons[2::3] = np.nan
        lats = np.empty(3 * n_rows)
        lats[0::3] = df['start_lat']
        lats[1::3] = df['end_lat']
        lats[2::3] = np.nan

        self.fig.add_trace(
            go.Scattergeo(
                lon = lons,
                lat = lats,
                mode = 'lines',
                hoverinfo = 'none',
                line = dict(width=1,color=color),
                opacity = 0.5
            )
        )

In [None]:
# Single-year preview: alliance links for `year` plus country markers.
year = 1949
alliance_map = MyBasemap()
alliance_map.connect_countries(year=year)
alliance_map.label_countries()

In [None]:
# Style the preview map (Mercator, country borders, clipped latitudes) and show it.
alliance_map.fig.update_geos(
    projection_type="mercator",
    showcountries=True,
    countrycolor="Black",
    landcolor="rgb(243,243,243)",
    lataxis_range=[-30, 86],
).update_layout(
    title=str(year),
    width=900,
    height=600,
    margin={"r": 0, "t": 30, "l": 0, "b": 0},
    showlegend=False,
)
alliance_map.fig.show()

In [None]:
def save_to_img(year, out_dir="./plotly2"):
    """
    Plot the alliance map of `year` and save it to ``<out_dir>/<year>.png``.

    :param year: year whose alliances are drawn; also used as title and file name.
    :param out_dir: output folder, created if it does not exist yet.
    """
    alliance_map = MyBasemap()
    alliance_map.connect_countries(year)
    alliance_map.label_countries()
    alliance_map.fig.update_geos(
        projection_type="mercator",
        showcountries=True,
        countrycolor="Black",
        landcolor="rgb(243,243,243)",
        lataxis_range=[-30,86],
    )
    alliance_map.fig.update_layout(
        title=str(year),
        width=900,
        height=600,
        margin={"r":0,"t":30,"l":0,"b":0},
        showlegend=False,
    )
    out_path = Path(out_dir)
    # write_image does not create missing folders
    out_path.mkdir(parents=True, exist_ok=True)
    alliance_map.fig.write_image(str(out_path / f"{year}.png"), scale=1)

# for year in range(1816, 2013):
#     save_to_img(year)

In [None]:
# Traces of the first animated year (1816); used as the animation's initial data.
first_year = 1816
alliance_map0 = MyBasemap()
alliance_map0.connect_countries(year=first_year)
alliance_map0.label_countries()
initial_data = alliance_map0.fig.data

In [None]:
def _build_year_figure(yr):
    """Build the alliance map of year `yr` (links + country markers) and return its figure."""
    year_map = MyBasemap()
    year_map.connect_countries(yr)
    year_map.label_countries()
    return year_map.fig


# One figure per year, 1816-2012. Building each figure in a helper keeps the
# loop variables local, so this cell no longer overwrites the notebook-level
# `year` (or leaves a stray `map_` behind).
frames0 = [_build_year_figure(yr) for yr in range(1816, 2013)]

In [None]:
# Wrap each year's traces in a named animation frame (name = the year as text).
frames = [
    go.Frame(data=list(year_fig.data), name=str(yr))
    for yr, year_fig in zip(range(1816, 2013), frames0)
]

In [None]:
# Assemble the animated figure as a plain dict: the initial traces, one frame
# per year, and a layout holding the Play/Pause buttons and the year slider.
fig_dict = {
    "data": list(initial_data),
    "layout": {},
    "frames": frames
}

fig_dict["layout"]["hovermode"] = "closest"

# create buttons "Play" and "Pause" for the animation
fig_dict["layout"]["updatemenus"] = [
    {
        "buttons": [
            {
                "args": [
                    None,
                    {
                        "frame": {"duration": 500, "redraw": True},
                        "fromcurrent": True,
                        "transition": {"duration": 300, "easing": "quadratic-in-out"}
                    }
                ],
                "label": "Play",
                "method": "animate"
            },
            {
                "args": [
                    [None],
                    {
                        "frame": {"duration": 0, "redraw": False},
                        "mode": "immediate",
                        "transition": {"duration": 0}
                    }
                ],
                "label": "Pause",
                "method": "animate"
            }
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 3},
        "showactive": False,
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top"
    }
]

# Create the year slider with one step per frame. Steps are derived from the
# frame names (instead of re-listing the year range), so the slider always
# matches the frames that were built. The comprehension also keeps this cell
# from overwriting the notebook-level `year`.
sliders_dict = {
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "Year:",
        "visible": True,
        "xanchor": "right"
    },
    "transition": {"duration": 300, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 0},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": [
        {
            "args": [
                [frame.name],
                {
                    "frame": {"duration": 300, "redraw": True},
                    "mode": "immediate",
                    "transition": {"duration": 300}
                }
            ],
            "label": frame.name,
            "method": "animate"
        }
        for frame in frames
    ]
}

fig_dict["layout"]["sliders"] = [sliders_dict]

In [None]:
# Build the animated figure from the assembled dict, style the map, and display it.
fig = go.Figure(fig_dict)
fig.update_geos(
    projection_type="mercator",
    showcountries=True,
    countrycolor="Black",
    landcolor="rgb(250,250,250)",
    lataxis_range=[-30, 86],
).update_layout(
    width=900,
    height=700,
    margin={"r": 0, "t": 0, "l": 0, "b": 0},
    showlegend=False,
)
fig.show()

In [None]:
# Export the animated figure as a standalone HTML page. Create the output
# folder first, since write_html does not create missing folders.
Path("./plotly").mkdir(parents=True, exist_ok=True)
fig.write_html("./plotly/all_fig.html")