In [58]:
import io
import os
from typing import List, Dict, Optional

from PIL import Image
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import community as community_louvain
from mpl_toolkits.basemap import Basemap

DATA_PATH = 'data/global_missile_trade_2000_2023.csv'
PREVIEW_SEED = 42

source_df = pd.read_csv(DATA_PATH)
# A fixed random_state makes this preview show the same rows on every Restart & Run All.
source_df.sample(10, random_state=PREVIEW_SEED)

Unnamed: 0,id,trade_id,target,source,order_date,orderYrEst,quantity,delivery_date,designation,description,category,source_lat,source_lng,target_lat,target_lng
212,301453,33409,Singapore,United States,2000-01-01 00:00:00+00:00,True,60.0,2001-01-01 00:00:00+00:00,AIM-9M,SRAAM,Missiles,38.893651,-77.17063,,
1087,236081,57819,Germany,United States,2000-01-01 00:00:00+00:00,True,500.0,2004-01-01 00:00:00+00:00,Paveway,guided bomb,Missiles,38.893651,-77.17063,51.247366,10.298068
1751,241603,64660,Qatar,United States,2017-01-01 00:00:00+00:00,True,200.0,2023-01-01 00:00:00+00:00,AGM-65,ASM,Missiles,38.893651,-77.17063,25.32375,51.183576
1529,248186,62422,South Korea,United States,2017-01-01 00:00:00+00:00,False,290.0,2018-01-01 00:00:00+00:00,JDAM,guided bomb,Missiles,38.893651,-77.17063,36.447642,127.822588
2049,313682,67073,South Korea,France,2023-01-01 00:00:00+00:00,False,,,Mistral SAM,portable SAM,Missiles,44.349371,-1.328556,36.447642,127.822588
808,268310,53683,Thailand,Ukraine,2011-01-01 00:00:00+00:00,True,1500.0,2017-01-01 00:00:00+00:00,R-2,anti-tank missile,Missiles,49.215255,31.19776,15.08695,101.007552
1256,229547,59710,unknown recipient(s),Poland,2016-01-01 00:00:00+00:00,True,50.0,2017-01-01 00:00:00+00:00,Warmate,loitering munition,Missiles,52.211251,19.294992,,
1275,262183,59897,South Korea,United States,2016-01-01 00:00:00+00:00,True,17.0,2021-01-01 00:00:00+00:00,SM-2MR,SAM,Missiles,38.893651,-77.17063,36.447642,127.822588
680,262855,51430,Algeria,South Africa,2007-01-01 00:00:00+00:00,True,50.0,2008-01-01 00:00:00+00:00,Raptor-2,ASM,Missiles,-29.071408,24.993023,28.452596,2.596583
217,241074,33452,UAE,United States,2006-01-01 00:00:00+00:00,False,96.0,2015-01-01 00:00:00+00:00,ESSM,SAM,Missiles,38.893651,-77.17063,23.875788,54.210894


In [59]:
# Work on an independent (deep) copy so the cleaning steps that follow
# never modify the raw source_df.
edge_df = source_df.copy()
edge_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2121 entries, 0 to 2120
Data columns (total 15 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   id             2121 non-null   int64  
 1   trade_id       2121 non-null   int64  
 2   target         2121 non-null   object 
 3   source         2121 non-null   object 
 4   order_date     2121 non-null   object 
 5   orderYrEst     2118 non-null   object 
 6   quantity       1994 non-null   float64
 7   delivery_date  1698 non-null   object 
 8   designation    2121 non-null   object 
 9   description    2121 non-null   object 
 10  category       2121 non-null   object 
 11  source_lat     2101 non-null   float64
 12  source_lng     2101 non-null   float64
 13  target_lat     2014 non-null   float64
 14  target_lng     2014 non-null   float64
dtypes: float64(5), int64(2), object(8)
memory usage: 248.7+ KB


In [60]:
# Keep only rows that have every field the network and plotting code reads.
# A blanket dropna() would also discard orders that simply have no
# delivery_date yet (only 1698 of 2121 values are non-null in the info()
# output above), even though nothing downstream uses that column.
REQUIRED_COLUMNS = [
    'source', 'target', 'order_date', 'quantity',
    'source_lat', 'source_lng', 'target_lat', 'target_lng',
]
edge_df = edge_df.dropna(subset=REQUIRED_COLUMNS)
assert edge_df[REQUIRED_COLUMNS].notna().all().all(), 'required columns still contain NaN'

In [61]:
def set_node_positions(network: nx.DiGraph) -> dict:
    """Build a node -> (longitude, latitude) mapping from edge attributes.

    Each edge contributes coordinates for both endpoints: the source node takes
    the edge's ``source_lng``/``source_lat`` and the target node takes
    ``target_lng``/``target_lat``. When a node appears on several edges, the
    coordinates from the last edge visited are kept.
    """
    coordinates = {}
    for origin, destination, attrs in network.edges(data=True):
        # Tuples are (x, y) = (lon, lat), the order the map projection expects.
        coordinates.update({
            origin: (attrs['source_lng'], attrs['source_lat']),
            destination: (attrs['target_lng'], attrs['target_lat']),
        })
    return coordinates


def set_node_degrees(network: nx.DiGraph) -> nx.DiGraph:
    """Attach ``in_degree`` and ``out_degree`` node attributes.

    The graph is modified in place and also returned for convenience. For
    multigraphs every parallel edge counts towards the degree.
    """
    degree_views = {
        'in_degree': network.in_degree(),
        'out_degree': network.out_degree(),
    }
    for attribute, view in degree_views.items():
        nx.set_node_attributes(network, dict(view), attribute)
    return network

def compute_louvain_communities(
    input_network: nx.Graph,
    resolution: float = 1.3,
    random_state: int = 1,
) -> Dict:
    """Partition the network's nodes into Louvain communities.

    python-louvain's ``best_partition`` expects a simple undirected graph. The
    trade network is a MultiDiGraph (one edge per trade), so it is collapsed
    first into an ``nx.Graph`` with one edge per node pair. That edge's
    ``weight`` is the sum of the ``weight`` attributes of all edges between the
    pair, in either direction; an edge without ``weight`` counts as 1.

    Collapsing explicitly avoids two problems:
    - ``MultiDiGraph.to_undirected()`` overwrites reciprocal edges that share
      a key.
    - python-louvain counts parallel edges in degrees but only once per
      neighbour.

    Args:
        input_network: Any networkx graph; it is not modified.
        resolution: Passed through to ``best_partition``.
        random_state: Seed passed to ``best_partition`` for reproducible results.

    Returns:
        Dict mapping each node to its integer community id.
    """
    collapsed = nx.Graph()
    # Keep isolated nodes so every node receives a community id.
    collapsed.add_nodes_from(input_network.nodes())
    for u, v, w in input_network.edges(data="weight", default=1):
        if collapsed.has_edge(u, v):
            collapsed[u][v]["weight"] += w
        else:
            collapsed.add_edge(u, v, weight=w)
    return community_louvain.best_partition(
        collapsed, weight="weight", resolution=resolution, random_state=random_state
    )


def create_network(edge_df: pd.DataFrame) -> nx.DiGraph:
    """Build a directed multigraph with one edge per row of ``edge_df``.

    Each edge runs ``source`` -> ``target`` and carries every column of its row
    as edge attributes. Nodes are then annotated with ``in_degree``,
    ``out_degree`` and ``louvain_community``.
    """
    graph = nx.MultiDiGraph()
    for _, record in edge_df.iterrows():
        attrs = record.to_dict()
        graph.add_edge(attrs['source'], attrs['target'], **attrs)

    graph = set_node_degrees(graph)
    communities = compute_louvain_communities(graph)
    nx.set_node_attributes(graph, communities, "louvain_community")
    return graph

# The full-period network fixes one map coordinate per country, shared by every yearly frame.
network = create_network(edge_df)
positions = set_node_positions(network)

# Every node must have a finite coordinate, otherwise drawing fails or silently drops it.
assert set(positions) == set(network.nodes), 'some nodes have no map coordinate'
assert np.isfinite(list(positions.values())).all(), 'node coordinates contain NaN/inf'

In [62]:
def plot_network_with_basemap(
    network: nx.DiGraph,
    positions: dict,
    year: str,
    output_dir: str = 'output/map_animation_frames',
) -> tuple:
    """Draw a trade network on a dark world map and save it as a PNG frame.

    Args:
        network: Graph whose edges carry a numeric ``quantity`` attribute (used
            for edge width). Every node must have an entry in ``positions``.
        positions: Mapping node -> (longitude, latitude).
        year: Label shown in the title and used in the file name
            ``trade_network_{year}.png``. Any f-string-formattable value works.
        output_dir: Directory for the PNG frame; created if it does not exist.

    Returns:
        The ``(fig, ax)`` pair. The caller is responsible for closing ``fig``.
    """
    background = '#292929'
    waters = '#404040'

    fig, ax = plt.subplots(figsize=(36, 18))
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)

    # Mercator view from 60°S to 75°N. The longitude range runs past 180° (to 195°),
    # presumably so nodes near the antimeridian stay on the map.
    m = Basemap(
        projection='merc',
        llcrnrlat=-60,
        urcrnrlat=75,
        llcrnrlon=-175,
        urcrnrlon=195,
        resolution='i',
        ax=ax,
    )
    m.drawcountries(color="#707070")
    m.drawmapboundary(fill_color=waters)
    m.fillcontinents(color=background, lake_color=waters)

    # Convert (lon, lat) into the map's projected x/y coordinates.
    projected_positions = {node: m(lon, lat) for node, (lon, lat) in positions.items()}

    nx.draw_networkx_nodes(network, projected_positions, node_size=2, node_color='darkgrey', ax=ax)

    # The fifth root compresses the quantity range so very large orders do not swamp the map.
    edge_widths = [d['quantity'] ** 0.20 for _, _, d in network.edges(data=True)]
    nx.draw_networkx_edges(
        network,
        projected_positions,
        arrowstyle='->',
        arrowsize=40,
        edge_color='#e44bb4',
        ax=ax,
        alpha=0.3,
        width=edge_widths,
        connectionstyle='arc3,rad=0.3',
    )
    nx.draw_networkx_labels(
        network, projected_positions, font_size=10, font_family='sans-serif', ax=ax, font_color="white"
    )

    # Title in the top-right corner, attribution in the bottom-right corner.
    ax.text(0.99, 0.95, f'World Arms Trade Network: {year}', verticalalignment='bottom',
            horizontalalignment='right', transform=ax.transAxes, color='white', fontsize=30,
            bbox=dict(facecolor='black', alpha=0.5, boxstyle='round,pad=0.5'))
    ax.text(0.99, 0.01, '© github/geometrein', verticalalignment='bottom',
            horizontalalignment='right', transform=ax.transAxes, color='white', fontsize=20)

    # Save the figure without padding and white border.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'trade_network_{year}.png'),
                facecolor=fig.get_facecolor(), bbox_inches='tight', pad_inches=0)
    return fig, ax


def convert_fig_to_pil(fig: plt.Figure) -> Image.Image:
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    No ``bbox_inches='tight'`` is used here. Every figure therefore renders at
    the same pixel size, which lets the frames be blended pixel-wise later.
    """
    buf = io.BytesIO()
    # Pass the facecolor explicitly. Matplotlib < 3.3 saves with a white
    # background by default, which would frame the dark map in white.
    fig.savefig(buf, format='png', facecolor=fig.get_facecolor())
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so the pixels are available independently of later use of buf.
    img.load()
    return img


def fade_images(images: List[Image.Image], fade_duration: int = 10) -> List[Image.Image]:
    """Build a cross-faded frame sequence from a list of keyframes.

    For each consecutive pair (A, B) the output contains A followed by
    ``fade_duration - 1`` linear blends towards B. The last keyframe is appended
    once at the end, so every keyframe appears exactly once.

    Args:
        images: Keyframes of identical size and mode; they are blended pixel-wise.
        fade_duration: Number of steps per transition. Values < 1 disable blending.

    Returns:
        The list of frames, or an empty list when ``images`` is empty.
    """
    if not images:
        return []
    if fade_duration < 1:
        return list(images)

    frames = []
    for start, end in zip(images, images[1:]):
        start_px = np.asarray(start, dtype=np.float64)
        end_px = np.asarray(end, dtype=np.float64)
        # Stop before alpha == 1: the end keyframe is emitted as the next pair's
        # start (or as the final frame), which avoids duplicating it.
        for step in range(fade_duration):
            alpha = step / fade_duration
            blended = (1 - alpha) * start_px + alpha * end_px
            # Round rather than truncate, so float error cannot turn e.g. 255 into 254.
            frames.append(Image.fromarray(np.clip(np.rint(blended), 0, 255).astype(np.uint8)))
    frames.append(images[-1])
    return frames


def animate(
    edge_df: pd.DataFrame,
    fade_duration: int = 4,
    positions: Optional[dict] = None,
    output_path: str = 'output/animation/animation.gif',
    frame_duration_ms: int = 100,
) -> None:
    """Render one map per order year and stitch the maps into a cross-faded GIF.

    Args:
        edge_df: Trade rows with ``source``, ``target``, ``quantity``, coordinate
            columns and an ``order_date`` whose string form starts with a 4-digit year.
        fade_duration: Number of blend steps between consecutive yearly frames.
        positions: Mapping node -> (lon, lat). Defaults to the coordinates derived
            from all of ``edge_df``, so every frame shares the same layout.
        output_path: Destination of the GIF; its directory is created if missing.
        frame_duration_ms: Display time of each GIF frame in milliseconds.

    Raises:
        ValueError: If ``edge_df`` has no rows.
    """
    if edge_df.empty:
        raise ValueError('edge_df has no rows to animate')
    if positions is None:
        positions = set_node_positions(create_network(edge_df))

    # Group on the year prefix rather than the full timestamp. Two orders from
    # the same year then always share one frame instead of producing duplicate
    # frames (and overwriting the same trade_network_{year}.png).
    order_years = edge_df['order_date'].astype(str).str[:4]

    frames = []
    for year, year_df in edge_df.groupby(order_years):
        year_network = create_network(year_df)
        fig, _ = plot_network_with_basemap(year_network, positions, year)
        try:
            frames.append(convert_fig_to_pil(fig))
        finally:
            # Always release the figure; many open 36x18 figures exhaust memory quickly.
            plt.close(fig)

    frames_with_fade = fade_images(frames, fade_duration)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    frames_with_fade[0].save(
        output_path,
        save_all=True,
        append_images=frames_with_fade[1:],
        duration=frame_duration_ms,
        loop=0,
    )

animate(edge_df, positions=positions)