In [1]:
# Import libraries with Plotly
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots as sp
from plotly.offline import init_notebook_mode

# Initialize Plotly in notebook mode
init_notebook_mode(connected=True)

# Set up for better display
pd.set_option('display.max_columns', None)

In [2]:
# Load processed data
df = pd.read_csv('../data/processed/final_dataset.csv')

df.rename(columns={
    'GDP_current_US$': 'gdp',
    'Military_expenditure_%_of_GDP': 'mil_exp_pct',
    'Population_total': 'population'
}, inplace=True)

In [3]:
# Step 2.1: Check the basic statistics
print("BASIC STATISTICS OF RELATIONSHIPS:")
print(f"Total unique country-supplier pairs: {len(relationship_table)}")
print(f"Total TIV across all relationships: {relationship_table['tiv_total_order'].sum():,.0f}")
print(f"Average TIV per relationship: {relationship_table['tiv_total_order'].mean():,.0f}")
print(f"Median TIV per relationship: {relationship_table['tiv_total_order'].median():,.0f}")

# Step 2.2: See the range of relationship sizes
print("\nRELATIONSHIP SIZE DISTRIBUTION:")
print(relationship_table['tiv_total_order'].describe())

# Step 2.3: Check which countries have the most supplier relationships
country_relationships = relationship_table.groupby('country').agg({
    'supplier_country': 'count',
    'tiv_total_order': 'sum'
}).rename(columns={'supplier_country': 'number_of_suppliers'})

print("\nCOUNTRIES WITH MOST SUPPLIER RELATIONSHIPS:")
print(country_relationships.sort_values('number_of_suppliers', ascending=False).head(10))

# Step 2: Filter our dataset to ONLY include these top 10 importers
df_top10 = df[df['country'].isin(top_10_importers)].copy()

print(f"\nData Overview:")
print(f"Total records in dataset: {len(df)}")
print(f"Records for top 10 importers: {len(df_top10)}")
print(f"Percentage of data we're analyzing: {len(df_top10)/len(df)*100:.1f}%")


BASIC STATISTICS OF RELATIONSHIPS:


NameError: name 'relationship_table' is not defined

In [None]:
# Step 3.1: Set a reasonable threshold to filter out very small relationships
# Let's analyze the distribution to choose a good threshold
tiv_threshold = relationship_table['tiv_total_order'].quantile(0.75)  # Top 25% of relationships
print(f"75th percentile TIV value: {tiv_threshold:,.0f}")

# Alternatively, set a fixed threshold based on the data scale
if tiv_threshold < 100:
    tiv_threshold = 100  # Minimum meaningful threshold

print(f"Using TIV threshold: {tiv_threshold:,.0f}")

# Step 3.2: Filter the relationships
significant_relationships = relationship_table[relationship_table['tiv_total_order'] >= tiv_threshold].copy()

print(f"\nAfter filtering: {len(significant_relationships)} significant relationships")
print(f"Covering {significant_relationships['tiv_total_order'].sum():,.0f} TIV "
      f"({significant_relationships['tiv_total_order'].sum() / relationship_table['tiv_total_order'].sum() * 100:.1f}% of total)")



# Step 3: Now create the relationship table for ONLY top 10 importers
print("\n" + "="*50)
print("CREATING SUPPLIER-RECIPIENT TABLE FOR TOP 10 IMPORTERS")
print("="*50)

relationship_table = df_top10.groupby(['country', 'supplier_country']).agg({
    'tiv_total_order': 'sum',
    'units_ordered': 'sum',
    'weapon_designation': 'count',
    'order_year': 'min'
}).reset_index()

relationship_table = relationship_table.rename(columns={
    'weapon_designation': 'number_of_deals',
    'order_year': 'first_order_year'
})

print(f"Total supplier-recipient relationships among top 10 importers: {len(relationship_table)}")





# Step 3.3: Sort by importance
significant_relationships = significant_relationships.sort_values('tiv_total_order', ascending=False)

print("\nTOP 20 MOST SIGNIFICANT ARMS TRADE RELATIONSHIPS:")
top_20_display = significant_relationships.head(20)[['country', 'supplier_country', 'tiv_total_order', 'number_of_deals']]
print(top_20_display.to_string(index=False))


In [None]:
# Step 4.1: Add additional metrics to our relationship table
significant_relationships['avg_deal_size'] = significant_relationships['tiv_total_order'] / significant_relationships['number_of_deals']

# Step 4.2: Categorize relationship strength
def categorize_relationship_strength(tiv):
    if tiv >= 1000:
        return "Very Strong"
    elif tiv >= 500:
        return "Strong"
    elif tiv >= 100:
        return "Moderate"
    else:
        return "Weak"

significant_relationships['relationship_strength'] = significant_relationships['tiv_total_order'].apply(categorize_relationship_strength)

# Step 4.3: Display the enhanced table
print("ENHANCED RELATIONSHIP TABLE:")
enhanced_display = significant_relationships[[
    'country', 'supplier_country', 'tiv_total_order', 'number_of_deals', 
    'avg_deal_size', 'relationship_strength'
]].head(15)

print(enhanced_display.to_string(index=False, float_format='%.0f'))


# Step 4: Identify ALL unique suppliers to these top 10 importers
all_suppliers_to_top10 = relationship_table['supplier_country'].unique()
print(f"Number of unique supplier countries to top 10 importers: {len(all_suppliers_to_top10)}")
print(f"\nSupplier countries to top 10 importers:")
print(sorted(all_suppliers_to_top10))

In [None]:
# Step 5.1: See the distribution of suppliers
supplier_summary = relationship_table.groupby('supplier_country').agg({
    'country': 'count',  # How many top10 importers they supply
    'tiv_total_order': 'sum'
}).rename(columns={'country': 'number_of_top10_recipients'}).sort_values('tiv_total_order', ascending=False)

print("SUPPLIERS TO TOP 10 IMPORTERS (Ranked by Total TIV):")
print(supplier_summary.head(15))

# Step 5.2: Analyze each top 10 importer's supplier portfolio
print("\n" + "="*60)
print("SUPPLIER PORTFOLIOS OF TOP 10 IMPORTERS")
print("="*60)

for importer in top_10_importers:
    importer_relationships = relationship_table[relationship_table['country'] == importer]
    if len(importer_relationships) > 0:
        total_tiv = importer_relationships['tiv_total_order'].sum()
        num_suppliers = len(importer_relationships)
        main_supplier = importer_relationships.nlargest(1, 'tiv_total_order')['supplier_country'].iloc[0]
        
        print(f"\n{importer}:")
        print(f"  - Total suppliers: {num_suppliers}")
        print(f"  - Total TIV: {total_tiv:,.0f}")
        print(f"  - Main supplier: {main_supplier}")
        print(f"  - All suppliers: {', '.join(importer_relationships['supplier_country'].tolist())}")

In [None]:
# Step 6.1: Filter for significant relationships (adjust threshold as needed)
tiv_threshold = 50  # Lower threshold since we're only looking at top 10 importers

significant_relationships = relationship_table[relationship_table['tiv_total_order'] >= tiv_threshold].copy()
significant_relationships = significant_relationships.sort_values('tiv_total_order', ascending=False)

print(f"Significant relationships (TIV ≥ {tiv_threshold}): {len(significant_relationships)}")

# Step 6.2: Display the key relationships
print("\nKEY ARMS TRADE RELATIONSHIPS TO TOP 10 IMPORTERS:")
key_display = significant_relationships[['country', 'supplier_country', 'tiv_total_order', 'number_of_deals']]
print(key_display.to_string(index=False))

# Step 6.3: Which suppliers are most important to the top 10?
supplier_importance = significant_relationships.groupby('supplier_country').agg({
    'tiv_total_order': 'sum',
    'country': 'count',
    'number_of_deals': 'sum'
}).rename(columns={
    'country': 'top10_recipients_served',
    'tiv_total_order': 'total_tiv_to_top10'
}).sort_values('total_tiv_to_top10', ascending=False)

print("\nMOST IMPORTANT SUPPLIERS TO TOP 10 IMPORTERS:")
print(supplier_importance.head(10))

In [None]:
# ======================================================================

In [None]:
# Step 1.1: Use our filtered relationship data for top 10 importers
print("Preparing data for Sankey diagram...")

# We'll use the significant_relationships we created earlier
sankey_df = significant_relationships.copy()

print(f"Number of flows to visualize: {len(sankey_df)}")
print(f"Total TIV represented: {sankey_df['tiv_total_order'].sum():,.0f}")

# Step 1.2: Get all unique nodes (suppliers + top 10 importers)
all_suppliers = sankey_df['supplier_country'].unique()
all_importers = sankey_df['country'].unique()  # These are already top 10 importers

print(f"Unique supplier countries: {len(all_suppliers)}")
print(f"Top 10 importer countries: {len(all_importers)}")

# Create combined node list (suppliers first, then importers)
all_nodes = list(all_suppliers) + list(all_importers)
print(f"Total nodes in Sankey: {len(all_nodes)}")

# Create mapping from country name to index
node_indices = {node: i for i, node in enumerate(all_nodes)}

In [None]:
# Step 2.1: Prepare source, target, and value arrays
sources = []
targets = []
values = []
hover_info = []

print("\nCreating link data...")
for _, row in sankey_df.iterrows():
    source_idx = node_indices[row['supplier_country']]
    target_idx = node_indices[row['country']]
    
    sources.append(source_idx)
    targets.append(target_idx)
    values.append(row['tiv_total_order'])
    
    # Create informative hover text
    hover_text = (f"<b>{row['supplier_country']} → {row['country']}</b><br>"
                  f"Total TIV: {row['tiv_total_order']:,.0f}<br>"
                  f"Number of deals: {row['number_of_deals']}<br>"
                  f"Avg deal size: {row['tiv_total_order']/row['number_of_deals']:,.0f}")
    hover_info.append(hover_text)

print(f"Created {len(sources)} links")

In [None]:
# Step 3.1: Create node colors (suppliers vs importers)
def get_node_colors(nodes, suppliers_list, importers_list):
    node_colors = []
    for node in nodes:
        if node in suppliers_list:
            # Different colors for major suppliers
            if node == 'USA':
                node_colors.append('rgba(31, 119, 180, 0.8)')  # Blue for USA
            elif node == 'Russia':
                node_colors.append('rgba(214, 39, 40, 0.8)')   # Red for Russia
            elif node == 'France':
                node_colors.append('rgba(44, 160, 44, 0.8)')   # Green for France
            elif node == 'China':
                node_colors.append('rgba(255, 0, 0, 0.8)')     # Bright Red for China
            else:
                node_colors.append('rgba(148, 103, 189, 0.8)') # Purple for other suppliers
        else:
            # Importers get orange shades
            node_colors.append('rgba(255, 127, 14, 0.8)')      # Orange for importers
    return node_colors

# Step 3.2: Create link colors (based on source supplier)
def get_link_colors(sources, nodes):
    link_colors = []
    color_map = {
        'USA': 'rgba(31, 119, 180, 0.4)',
        'Russia': 'rgba(214, 39, 40, 0.4)',
        'France': 'rgba(44, 160, 44, 0.4)',
        'China': 'rgba(255, 0, 0, 0.4)',
        'United Kingdom': 'rgba(148, 103, 189, 0.4)',
        'Germany': 'rgba(140, 86, 75, 0.4)',
    }
    
    for source_idx in sources:
        source_node = nodes[source_idx]
        link_colors.append(color_map.get(source_node, 'rgba(128, 128, 128, 0.4)'))
    
    return link_colors

node_colors = get_node_colors(all_nodes, all_suppliers, all_importers)
link_colors = get_link_colors(sources, all_nodes)

In [None]:
# Step 4: Create the Sankey Diagram (CORRECTED)

print("Creating Sankey diagram...")

fig = go.Figure(data=[go.Sankey(
    # Node configuration
    node=dict(
        pad=20,
        thickness=25,
        line=dict(color="black", width=0.8),
        label=all_nodes,
        color=node_colors,
        hoverinfo='none'  # We'll handle hover via links
    ),
    
    # Link configuration - CORRECTED HOVERINFO
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors,
        customdata=hover_info,  # Store hover text here
        hovertemplate='%{customdata}<extra></extra>',  # Correct way to show custom hover
        line=dict(width=0.1)
    )
)])

# Step 4.2: Update layout
fig.update_layout(
    title=dict(
        text="Arms Trade Flow: Suppliers to Top 10 Importers (2020-2023)<br><sub>TIV Values Represented by Flow Width</sub>",
        x=0.5,
        font=dict(size=16, family="Arial")
    ),
    font=dict(size=10),
    height=800,
    width=1200,
    margin=dict(l=80, r=80, t=100, b=50),
    paper_bgcolor='white'
)

print("Displaying Sankey diagram...")
fig.show()

In [None]:
# SIMPLER WORKING VERSION - Minimum required parameters
print("Creating simplified Sankey diagram...")

fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=all_nodes,
        color=node_colors,
    ),
    link=dict(
        source=sources,
        target=targets,
        value=values,
        color=link_colors,
        # Remove complex hover settings for now
    )
)])

# Add hover effects through Plotly's built-in functionality
fig.update_traces(
    selector=dict(type='sankey'),
    hoverinfo='all'  # This should work for basic hover
)

fig.update_layout(
    title="Arms Trade Network: Suppliers to Top 10 Importers (2020-2023)",
    height=700,
    font=dict(size=10)
)

print("Displaying diagram...")
fig.show()