### This file analyzes USA's trade volume with different countries, identifying the top partners for exports and imports.
### Then there are visualizations for exploring USA's exports to the top 10 countries against the tariffs imposed by those countries over the years.
### Similarly, there are visualizations for exploring USA's imports from the top 10 countries against the tariffs imposed by the US over the years.

In [33]:
## First read the combined primary-secondary data csv file into a dataframe

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
import math
import nbformat

# Load the combined primary/secondary trade dataset into a DataFrame
DATA_PATH = './../../data/processed/combined_primary_secondary.csv'
combined_df = pd.read_csv(DATA_PATH)
# Preview the first five rows to confirm the load worked
print(combined_df.head(5))

       country category  year  import_value  export_value  \
0  Afghanistan    Parts  2008         10294      26954442   
1  Afghanistan    Parts  2009        161658     102667946   
2  Afghanistan    Parts  2010         26552      87621159   
3  Afghanistan    Parts  2011          5928      51232064   
4  Afghanistan    Parts  2012          2500      84963182   

   mfn_by_us_simple_avg  mfn_by_us_weighted_avg  mfn_on_us_simple_avg  \
0                  3.31                    1.64                 7.680   
1                  2.83                    1.42                 7.911   
2                  2.44                    1.15                 8.142   
3                  3.82                    1.36                 8.373   
4                  3.69                    1.52                 8.604   

   mfn_on_us_weighted_avg           gdp  gdp_2015_adj  
0                   5.480  1.010930e+10  1.106040e+10  
1                   5.715  1.241615e+10  1.342627e+10  
2                   5.950 

In [9]:
## Restrict the frame to the identifier, trade-value and simple-average tariff columns
keep_cols = [
    'country', 'category', 'year', 'import_value', 'export_value',
    'mfn_by_us_simple_avg', 'mfn_on_us_simple_avg',
]
combined_df = combined_df.loc[:, keep_cols].copy()

# Missing-value count per column before imputation
nan_counts = combined_df.isnull().sum()

# Impute every numeric column with that column's own mean
# NOTE(review): 'year' is numeric too, so it is included in the imputation
numeric_cols = combined_df.select_dtypes(include=[np.number]).columns.tolist()
fill_values = {column: combined_df[column].mean() for column in numeric_cols}
combined_df = combined_df.fillna(value=fill_values)

# Missing-value count per column after imputation
nan_counts_after = combined_df.isnull().sum()

In [41]:
# Canonicalise column labels (lowercase, single-space separated) so later lookups don't raise KeyError
def _normalize_cols(df):
    """Return ``df`` with every column label lower-cased and space-separated.

    Each label is lower-cased, stripped of surrounding whitespace, has its
    underscores turned into spaces, and has runs of whitespace collapsed to a
    single space. The result of ``DataFrame.rename`` is returned; ``df`` itself
    keeps its original labels.
    """
    def _canonical(label):
        words = label.lower().strip().replace('_', ' ').split()
        return ' '.join(words)

    return df.rename(columns={label: _canonical(label) for label in df.columns})

combined_df = _normalize_cols(combined_df)

# Normalized label -> label used by the rest of the notebook.
# The two MFN simple-average columns get short names; the trade-value
# columns map to themselves because their normalized form is already final.
canonical_names = {
    'mfn by us simple avg': 'mfn_by_us',
    'mfn on us simple avg': 'mfn_on_us',
    'import value': 'import value',
    'export value': 'export value',
}
# Only rename the columns that are actually present in the frame
col_map = {old: new for old, new in canonical_names.items() if old in combined_df.columns}
combined_df = combined_df.rename(columns=col_map)
combined_df.head(5)

Unnamed: 0,country,category,year,import value,export value,mfn_by_us,mfn by us weighted avg,mfn_on_us,mfn on us weighted avg,gdp,gdp 2015 adj
0,Afghanistan,Parts,2008,10294,26954442,3.31,1.64,7.68,5.48,10109300000.0,11060400000.0
1,Afghanistan,Parts,2009,161658,102667946,2.83,1.42,7.911,5.715,12416150000.0,13426270000.0
2,Afghanistan,Parts,2010,26552,87621159,2.44,1.15,8.142,5.95,15856670000.0,15354610000.0
3,Afghanistan,Parts,2011,5928,51232064,3.82,1.36,8.373,6.185,17805100000.0,15420080000.0
4,Afghanistan,Parts,2012,2500,84963182,3.69,1.52,8.604,6.42,19907330000.0,17386490000.0


##### Function to create a horizontal bar chart for countries on Y axis and their total export/import value on X axis

In [18]:
def plot_trade_value_by_country(df, value_col, country_col='country', top_n=20,
                                show_values=False, value_format=',.0f'):
    """Plot a horizontal bar chart of the top N countries by summed value.

    Args:
        df: DataFrame containing the country and value columns. It is not
            modified; all cleaning happens on an internal copy.
        value_col: name of the value column. Values are converted to strings,
            "," thousands separators are removed, and the result is coerced to
            numeric; rows that cannot be coerced are dropped.
        country_col: name of the country/name column shown on the y-axis.
        top_n: how many countries to show (default 20).
        show_values: if True, display each bar's numeric value next to the bar.
        value_format: d3 format string for the value text (default no decimals).
    Returns:
        Plotly figure object.
    Raises:
        ValueError: if ``value_col`` or ``country_col`` is not a column of ``df``.
    """

    # Validate columns
    if country_col not in df.columns or value_col not in df.columns:
        raise ValueError(f"DataFrame must contain columns: {value_col} and {country_col}. Available columns: {df.columns.tolist()}")

    # Copy BEFORE cleaning so the caller's DataFrame is never altered
    data = df.copy()
    data[value_col] = pd.to_numeric(
        data[value_col].astype(str).str.replace(',', '', regex=False),
        errors='coerce',
    )
    data = data.dropna(subset=[value_col])

    # Aggregate per country and keep the top N totals
    agg = data.groupby(country_col)[value_col].sum().nlargest(top_n).reset_index()

    # Set figure height so labels don't overlap (approx 30 px per row + padding)
    height = max(400, 30 * len(agg) + 120)

    # Create horizontal bar chart
    fig = px.bar(agg, x=value_col, y=country_col, orientation='h',
                 title=f'Aggregate {value_col} by Country (Top {len(agg)})',
                 height=height)

    if show_values:
        # Bars are horizontal, so the value lives on x; format it with the d3 spec
        fig.update_traces(
            texttemplate=f'%{{x:{value_format}}}',
            textposition='outside',
            cliponaxis=False,  # keep labels of the longest bars from being clipped
        )

    # 'total ascending' on a horizontal bar chart puts the largest bar at the top
    fig.update_layout(
        yaxis=dict(
            categoryorder='total ascending'
        ),
        xaxis=dict(
            showgrid=False
        )
    )

    return fig

In [19]:
# Bar chart of the 20 countries with the largest total export value
fig = plot_trade_value_by_country(
    df=combined_df,
    value_col='export value',
    country_col='country',
    top_n=20,
    show_values=True,
)
fig.show()

In [20]:
# Bar chart of the 20 countries with the largest total import value
fig = plot_trade_value_by_country(
    df=combined_df,
    value_col='import value',
    country_col='country',
    top_n=20,
    show_values=True,
)
fig.show()

#### Explore relationship between export value and tariff for top N countries

In [24]:
# Function to create a scatter plot with animation on year to see correlation between tariff and export value for top 20 countries

def plot_animated_scatter(df, x_col, y_col, size_col, color_col, animation_col, title):
    """Build an animated bubble chart of export value (x) against tariff rate (y).

    The input frame is copied, ``x_col`` (after removing "," separators) and
    ``y_col`` are coerced to numbers, and rows missing x, y, colour or
    animation values are dropped. One animation frame is produced per distinct
    ``animation_col`` value, and that value is carried in ``custom_data`` so
    the hover text can show it. Axis labels and hover text are worded for
    exports; the x-axis is logarithmic and the y-axis is fixed to 0-15.

    Returns the Plotly figure.
    """
    # Private copy: the caller's frame is left as it was
    data = df.copy()

    # Numeric coercion; anything unparseable becomes NaN and is dropped below
    x_as_text = data[x_col].astype(str).str.replace(',', '')
    data[x_col] = pd.to_numeric(x_as_text, errors='coerce')
    data[y_col] = pd.to_numeric(data[y_col], errors='coerce')
    required = [x_col, y_col, color_col, animation_col]
    data = data.dropna(subset=required)

    axis_labels = {x_col: "Export value", y_col: "Tariff Rate", color_col: "Country", animation_col: "Year"}

    # custom_data carries the animation value (year) into the hover template
    fig = px.scatter(
        data,
        x=x_col,
        y=y_col,
        size=size_col,
        color=color_col,
        hover_name=color_col,
        custom_data=[animation_col],
        animation_frame=animation_col,
        animation_group=color_col,
        labels=axis_labels,
        title=title,
        size_max=60,
        template="plotly_white"
    )

    # One hover template shared by the base traces and every frame's traces
    hover_template = "".join((
        "<b>%{hovertext}</b><br><br>",
        "Export Value: $%{x:,.0f}<br>",
        "Tariff Rate: %{y:.2f}%<br>",
        "Year: %{customdata[0]}<extra></extra>",
    ))
    fig.update_traces(hovertemplate=hover_template)

    # Frames hold their own trace copies, so set the template there as well
    for frame in (getattr(fig, 'frames', None) or ()):
        for frame_trace in frame.data:
            frame_trace.hovertemplate = hover_template

    # Fixed tariff scale with a tick every 5 units
    fig.update_yaxes(range=[0, 15], tick0=0, dtick=5)

    # Log-scaled export axis plus axis and legend titles
    fig.update_layout(
        xaxis=dict(type='log', title='Export Value (log scale)'),
        yaxis=dict(title='Tariff Rate (%)'),
        legend_title_text='Country'
    )

    return fig

In [26]:
# Build a country/year-level frame with export value and the tariff applied to US goods.
# combined_df carries a 'category' column, so a country can have several rows per year.
# Collapse those rows first (total export value, mean tariff); otherwise the per-year
# "top 20" below would rank category rows and could list the same country more than once.
export_tariff_df = (
    combined_df[['country', 'year', 'export value', 'mfn_on_us']]
    .groupby(['country', 'year'], as_index=False)
    .agg({'export value': 'sum', 'mfn_on_us': 'mean'})
)

# Sort by year and export value, both descending
export_tariff_df = export_tariff_df.sort_values(by=['year', 'export value'], ascending=[False, False])

# Keep only the top 20 countries by export value for each year
export_tariff_df = export_tariff_df.groupby('year').head(20).reset_index(drop=True)

# Plot the animated scatter using the new dataframe
if __name__ == "__main__":

    fig = plot_animated_scatter(
        export_tariff_df,
        x_col='export value',
        y_col='mfn_on_us',
        size_col='export value',
        color_col='country',
        animation_col='year',
        title='Correlation between Export Value and Tariff Rate (Top 20 Countries by Export Value)'
    )

    fig.show()

In [27]:
# Function to create a scatter plot with animation on year to see correlation between tariff and import value for top 20 countries

def plot_animated_scatter(df, x_col, y_col, size_col, color_col, title, animation_col=None):
    """Create a bubble chart of import value (x) against tariff rate (y).

    Args:
        df: source DataFrame; it is copied, never modified.
        x_col: column for the x-axis. "," separators are removed and the
            values are coerced to numeric.
        y_col: column for the y-axis; coerced to numeric.
        size_col: column that drives the marker size.
        color_col: column used for colour, hover name and animation grouping.
        title: figure title.
        animation_col: column whose distinct values become animation frames
            (labelled "Year"). With the default ``None`` a static,
            non-animated scatter is produced and the hover text has no year line.
    Returns:
        Plotly figure object with a log-scaled x-axis.
    """
    animated = animation_col is not None

    # Defensive copy and type coercion
    df = df.copy()
    df[x_col] = pd.to_numeric(df[x_col].astype(str).str.replace(',', '', regex=False), errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')

    # Only require the animation column when one was actually given;
    # passing None in the subset would raise KeyError
    required = [x_col, y_col, color_col]
    if animated:
        required.append(animation_col)
    df = df.dropna(subset=required)

    labels = {x_col: "Import Value", y_col: "Tariff Rate", color_col: "Country"}
    hover_lines = [
        "<b>%{hovertext}</b><br><br>",
        "Import Value: $%{x:,.0f}<br>",
        "Tariff Rate: %{y:.2f}%",
    ]
    animation_kwargs = {}
    if animated:
        labels[animation_col] = "Year"
        # custom_data passes the animation frame (year) into the hovertemplate
        animation_kwargs = dict(
            animation_frame=animation_col,
            animation_group=color_col,
            custom_data=[animation_col],
        )
        hover_lines.append("<br>Year: %{customdata[0]}")
    hover_template = "".join(hover_lines) + "<extra></extra>"

    fig = px.scatter(
        df,
        x=x_col,
        y=y_col,
        size=size_col,
        color=color_col,
        size_max=60,
        title=title,
        labels=labels,
        hover_name=color_col,
        template="plotly_white",
        **animation_kwargs
    )

    # Apply to base traces
    fig.update_traces(hovertemplate=hover_template)

    # Frames hold their own trace copies, so the template must be set there too
    for frame in (getattr(fig, 'frames', None) or ()):
        for trace in frame.data:
            trace.hovertemplate = hover_template

    # Log x-axis and axis/legend titles
    fig.update_layout(
        xaxis=dict(type='log', title='Import Value (log scale)'),
        yaxis=dict(title='Tariff Rate (%)'),
        legend_title_text='Country'
    )

    return fig

In [28]:
# Build a country/year-level frame with import value and the tariff applied by the US.
# combined_df carries a 'category' column, so a country can have several rows per year.
# Collapse those rows first (total import value, mean tariff); otherwise the per-year
# "top 20" below would rank category rows and could list the same country more than once.
import_tariff_df = (
    combined_df[['country', 'year', 'import value', 'mfn_by_us']]
    .groupby(['country', 'year'], as_index=False)
    .agg({'import value': 'sum', 'mfn_by_us': 'mean'})
)

# Sort by year and import value, both descending, then keep the top 20 countries per year
import_tariff_df = import_tariff_df.sort_values(by=['year', 'import value'], ascending=[False, False])
import_tariff_df = import_tariff_df.groupby('year').head(20).reset_index(drop=True)

if __name__ == "__main__":

    fig = plot_animated_scatter(
        import_tariff_df,
        x_col='import value',
        y_col='mfn_by_us',
        size_col='import value',
        color_col='country',
        animation_col='year',
        title='Correlation between Import Value and Tariff Rate (Top 20 Countries by Import Value)'
    )

    fig.show()

#### Canada and Mexico are the top 2 partners for exports as well as imports. Let us filter them out and visualize the trend between tariff and import/export value for the next top 10 partners.

In [29]:
# Start from the country/category/year rows with export value and the tariff applied to US goods
export_tariff_df1 = combined_df[['country', 'category', 'year', 'export value', 'mfn_on_us']].copy()

# Exclude Canada and Mexico so the remaining partners can be compared
filtered_df = export_tariff_df1[~export_tariff_df1['country'].isin(['Canada', 'Mexico'])]

# Keep only rows where category is 'Passenger Vehicles'
filtered_df = filtered_df[filtered_df['category'] == 'Passenger Vehicles']

# Collapse to one row per country: total export value and mean tariff
filtered_df = filtered_df.groupby(['country']).agg({'export value': 'sum', 'mfn_on_us': 'mean'}).reset_index()

# Rank by total export value and keep the 10 largest.
# (A per-country groupby().head(10) would keep every row here, because each
# country has exactly one row after the aggregation above.)
filtered_df = filtered_df.sort_values('export value', ascending=False).head(10)

In [31]:
# Names of the first (up to) ten countries in filtered_df, in the order they appear there
top_10_countries = filtered_df['country'].unique().tolist()[:10]

# Passenger-vehicle rows of combined_df for those countries, cut down to the plotting columns
is_top_partner = combined_df['country'].isin(top_10_countries)
is_passenger_vehicle = combined_df['category'] == 'Passenger Vehicles'
top_countries_df = combined_df.loc[
    is_top_partner & is_passenger_vehicle,
    ['country', 'year', 'export value', 'mfn_on_us'],
].copy()

# Give the tariff column a display-friendly name
top_countries_df = top_countries_df.rename(columns={'mfn_on_us': 'Tariff on US'})

top_countries_df.head(10)

Unnamed: 0,country,year,export value,Tariff on US
222,Australia,2008,564502516,4.61
225,Australia,2009,271049394,3.85
228,Australia,2010,471495141,3.08
231,Australia,2011,735422102,3.05
234,Australia,2012,1081054868,3.01
237,Australia,2013,1388250828,3.03
240,Australia,2014,1847630380,3.02
243,Australia,2015,1649200775,2.79
246,Australia,2016,1281662454,2.79
249,Australia,2017,1608810594,2.78


#### Plot one subplot per country (10 in total), each with 2 y-axes: one for export value and the other for tariff rate.

In [35]:
# plot 10 subplots for each country with 2 y axis - One for export value and the other for tariff rate. 
# Each subplot should represent one country.

def plot_export_tariff_subplots(df, countries, Y1, Y2, cols=3):
    """Draw one dual-axis subplot per country showing ``Y1`` and ``Y2`` by year.

    Args:
        df: DataFrame with 'country' and 'year' columns plus the ``Y1`` and
            ``Y2`` columns. It is not modified.
        countries: sequence of country names; one subplot each, in this order.
        Y1: column drawn in blue on the primary (left) y-axis. Its name is also
            used as that axis' title, the trace name and the legend label.
        Y2: column drawn in red on the secondary (right) y-axis, which is
            titled 'Tariff Rate (%)'.
        cols: number of subplot columns in the grid (default 3).
    Returns:
        Plotly figure object.
    Raises:
        ValueError: if ``countries`` is empty or ``cols`` is less than 1.
    """
    # Accept any sequence (list, tuple, Index, ...) and work on a plain list
    countries = list(countries)
    if not countries:
        raise ValueError('No countries to plot. Ensure countries list is populated')
    if cols < 1:
        raise ValueError('cols must be at least 1')

    # grid layout
    rows = math.ceil(len(countries) / cols)

    # every subplot gets a secondary y-axis
    specs = [[{"secondary_y": True} for _ in range(cols)] for _ in range(rows)]

    # pad the titles so unused grid cells get an empty title
    subplot_titles = countries + [""] * (rows * cols - len(countries))

    fig = make_subplots(rows=rows, cols=cols, subplot_titles=subplot_titles, shared_xaxes=False, specs=specs)

    for i, country in enumerate(countries):
        row, col = divmod(i, cols)
        row += 1
        col += 1

        # Copy the slice, make year numeric, THEN sort so the line is drawn in time order
        country_data = df[df['country'] == country].copy()
        country_data['year'] = pd.to_numeric(country_data['year'], errors='coerce')
        country_data = country_data.sort_values('year')

        # Y1 on the primary y-axis (per-country legend entries are hidden)
        fig.add_trace(
            go.Scatter(
                x=country_data['year'],
                y=country_data[Y1],
                mode='lines+markers',
                name=Y1,  # use the metric name in hover text rather than a literal 'Y1'
                line=dict(color='blue'),
                showlegend=False
            ),
            row=row, col=col, secondary_y=False
        )

        # Y2 (tariff rate) on the secondary y-axis (per-country legend entries are hidden)
        fig.add_trace(
            go.Scatter(
                x=country_data['year'],
                y=country_data[Y2],
                mode='lines+markers',
                name='Tariff Rate',
                line=dict(color='red'),
                showlegend=False
            ),
            row=row, col=col, secondary_y=True
        )
        # y-axis titles for this subplot
        fig.update_yaxes(title_text=Y1, row=row, col=col, secondary_y=False)
        fig.update_yaxes(title_text='Tariff Rate (%)', row=row, col=col, secondary_y=True)

    # Two empty traces supply a single legend entry per metric/colour,
    # instead of one entry per country and metric.
    fig.add_trace(
        go.Scatter(x=[None], y=[None], mode='lines', line=dict(color='blue'), name=Y1),
        row=1, col=1, secondary_y=False
    )
    fig.add_trace(
        go.Scatter(x=[None], y=[None], mode='lines', line=dict(color='red'), name='Tariff Rate (%)'),
        row=1, col=1, secondary_y=True
    )
    # Layout adjustments: 300 px per grid row, fixed width, tight margins
    fig.update_layout(
        height=300 * rows,
        width=1200,
        title_text=f'{Y1} and Tariff Rate by Country (Top {len(countries)})',
        showlegend=True,
        margin=dict(t=80, l=50, r=50, b=50),
    )
    return fig

In [36]:
# Export value vs. tariff grid for the passenger-vehicle partners listed in top_10_countries
fig = plot_export_tariff_subplots(
    df=top_countries_df,
    countries=top_10_countries,
    Y1='export value',
    Y2='Tariff on US',
    cols=3,
)
fig.show()

#### Create a similar grid plot for imports into USA against tariff imposed by USA on top 10 trading partners

In [42]:
# Country/category/year rows with import value and the tariff applied by the US
import_tariff_df1 = combined_df.loc[:, ['country', 'category', 'year', 'import value', 'mfn_by_us']].copy()

# Drop Canada and Mexico and keep only the 'Passenger Vehicles' category
is_excluded_partner = import_tariff_df1['country'].isin(['Canada', 'Mexico'])
is_passenger_vehicle = import_tariff_df1['category'] == 'Passenger Vehicles'
filtered_df1 = import_tariff_df1[~is_excluded_partner & is_passenger_vehicle]

# One row per country (total import value, mean tariff), ranked by import value, top 10 kept
filtered_df1 = (
    filtered_df1
    .groupby(['country'])
    .agg({'import value': 'sum', 'mfn_by_us': 'mean'})
    .reset_index()
    .sort_values('import value', ascending=False)
    .head(10)
)

filtered_df1.head(10)

Unnamed: 0,country,import value,mfn_by_us
41,Japan,606181344605,3.790667
29,Germany,337932199932,3.650667
72,South Korea,266349252198,3.837333
83,United Kingdom,104492803463,3.725333
40,Italy,43824255789,3.914
70,Slovakia,41361033699,3.841333
74,Sweden,31231492582,3.132667
71,South Africa,23354453521,3.263333
5,Belgium,17834679622,3.736667
3,Austria,16992116599,3.406


In [45]:
# Names of the first (up to) ten countries in filtered_df1, in the order they appear there
top_10_import = filtered_df1['country'].unique().tolist()[:10]

# Passenger-vehicle rows of combined_df for those countries, cut down to the plotting columns
is_top_import_partner = combined_df['country'].isin(top_10_import)
is_passenger_vehicle = combined_df['category'] == 'Passenger Vehicles'
top_import_countries_df = combined_df.loc[
    is_top_import_partner & is_passenger_vehicle,
    ['country', 'year', 'import value', 'mfn_by_us'],
].copy()

# Display-friendly column names for the plot
top_import_countries_df = top_import_countries_df.rename(columns={
    'mfn_by_us': 'Tariff by US',
    'import value': 'Import Value',
})


## Reuse the dual-axis subplot function to draw import value against the US tariff rate for these countries
fig = plot_export_tariff_subplots(
    df=top_import_countries_df,
    countries=top_10_import,
    Y1='Import Value',
    Y2='Tariff by US',
    cols=3,
)
fig.show()