# Prepare Data

In [3]:
# Import libraries
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Read the processed export and import trade tables from CSV
export_clean, import_clean = (
    pd.read_csv("../data/Processed/Export_Trade_Cleaned.csv"),
    pd.read_csv("../data/Processed/Import_Trade_Cleaned.csv"),
)

# Clean and reshape
def clean_and_melt(df, value_name, id_col="Product"):
    """Convert a wide trade table into long (tidy) form.

    Column names are stripped of surrounding whitespace, every column other
    than ``id_col`` has ``$`` and ``,`` characters removed and is cast to
    float, and the table is then melted to one row per (id, year) pair.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide table holding the identifier column ``id_col`` plus one column
        per year. After stripping, each year label must be convertible to
        ``int``.
    value_name : str
        Name of the melted value column (e.g. ``"Export"``).
    id_col : str, default "Product"
        Identifier column that is kept as-is. It may sit at any position.

    Returns
    -------
    pandas.DataFrame
        Columns ``[id_col, "Year", value_name]`` with ``Year`` as int and
        ``value_name`` as float. The input frame is not modified.

    Raises
    ------
    KeyError
        If ``id_col`` is not a column of ``df`` (after stripping names).
    ValueError
        If a cell or a year label cannot be converted to a number.
    """
    df = df.copy()
    df.columns = df.columns.str.strip()

    if id_col not in df.columns:
        raise KeyError(f"identifier column {id_col!r} not found in {list(df.columns)}")

    # Select value columns by name rather than by position, so the identifier
    # column does not have to be the first one in the file.
    value_cols = [col for col in df.columns if col != id_col]

    # Remove $ and thousands separators, then convert to float. The pattern is
    # a raw string so that "\$" is not treated as an (invalid) string escape.
    for col in value_cols:
        df[col] = df[col].replace(r"[\$,]", "", regex=True).astype(float)

    # Reshape: one row per (identifier, year)
    df_long = df.melt(id_vars=[id_col], var_name="Year", value_name=value_name)
    df_long["Year"] = df_long["Year"].astype(int)
    return df_long

# Reshape both tables to long form; each value column is named after its flow
export_long, import_long = (
    clean_and_melt(wide_table, value_name=flow)
    for wide_table, flow in ((export_clean, "Export"), (import_clean, "Import"))
)

# U.S. Export vs Import Trade (Top 10 Exported Products, 2009–2024)

In [4]:
# Merge exports and imports into one row per (Product, Year). pd.merge
# defaults to an inner join, so only pairs present in both tables are kept.
merged = pd.merge(export_long, import_long, on=["Product", "Year"])
# Flag the aggregate row(s) so they can be excluded from the product ranking
merged["Is_All_Merchandise"] = merged["Product"].str.contains("All Merchandise")

# Rank products by export value within each year (1 = largest export).
# method="first" breaks ties by row order, so ranks are unique within a year.
ranked = merged[~merged["Is_All_Merchandise"]].copy()
ranked["Rank"] = ranked.groupby("Year")["Export"].rank(method="first", ascending=False)

# Keep the top 10 products per year, ordered by year and then rank.
# This filters on the per-year rank instead of groupby("Year").apply(...):
# letting apply() operate on the grouping column is deprecated (pandas 2.2),
# and once the column is excluded, reset_index(drop=True) would discard
# "Year", which the animation below needs. Rows whose rank is NaN (missing
# export value) are left out.
top_products = (
    ranked[ranked["Rank"] <= 10]
    .sort_values(["Year", "Rank"])
    .reset_index(drop=True)
)

# Bubble size is proportional to export value (divisor only rescales it)
top_products["Size"] = top_products["Export"] / 1000  # You can adjust divisor to scale

# Animated bubble chart: one frame per year, one bubble per top-10 product,
# imports on x, exports on y, bubble area driven by the "Size" column
fig = px.scatter(
    data_frame=top_products,
    x="Import",
    y="Export",
    size="Size",
    color="Product",
    hover_name="Product",
    hover_data={"Rank": True, "Export": True, "Import": True},
    animation_frame="Year",
    title="U.S. Export vs Import Trade (Top 10 Exported Products, 2009–2024)",
)

# Largest value on either axis; used for the diagonal and the axis limits
max_val = top_products[["Export", "Import"]].max().max()

# Dashed y = x reference line: points on it have exports equal to imports
balance_line = go.Scatter(
    x=[0, max_val],
    y=[0, max_val],
    mode="lines",
    line={"color": "black", "dash": "dash"},
    name="Trade Balance = 0",
)
fig.add_trace(balance_line)

# Same fixed range on both axes (5% headroom) so every frame shares one scale
axis_range = [0, max_val * 1.05]
fig.update_layout(
    height=700,
    legend_title="Product Category",
    xaxis_title="Import Value (Million USD)",
    yaxis_title="Export Value (Million USD)",
    xaxis_range=axis_range,
    yaxis_range=axis_range,
)

# Render inline and save a standalone HTML copy
fig.show()
fig.write_html("../img/top10_trade_bubble.html")

# U.S. Export vs Import: All Merchandise (2009–2024)

In [5]:
# Rebuild the merged table and keep only the aggregate "All Merchandise" rows
merged = pd.merge(export_long, import_long, on=["Product", "Year"])
merged["Is_All_Merchandise"] = merged["Product"].str.contains("All Merchandise")
# Constant "Size" column -> every bubble is drawn at the same size
all_merch = merged.loc[merged["Is_All_Merchandise"]].assign(Size=40)

# Animated scatter with one labelled gray bubble per yearly frame
bubble_labels = ["All Merchandise" for _ in range(len(all_merch))]
fig = px.scatter(
    data_frame=all_merch,
    x="Import",
    y="Export",
    size="Size",
    text=bubble_labels,
    hover_name="Product",
    color_discrete_sequence=["gray"],
    animation_frame="Year",
    title="U.S. Export vs Import: All Merchandise (2009–2024)",
)

# Largest value on either axis; used for the diagonal and the axis limits
max_val = all_merch[["Export", "Import"]].max().max()

# Dashed y = x reference line: points on it have exports equal to imports
balance_line = go.Scatter(
    x=[0, max_val],
    y=[0, max_val],
    mode="lines",
    line={"color": "black", "dash": "dash"},
    name="Trade Balance = 0",
)
fig.add_trace(balance_line)

# Same fixed range on both axes (5% headroom); legend hidden
axis_range = [0, max_val * 1.05]
fig.update_layout(
    height=650,
    showlegend=False,
    xaxis_title="Import Value (Million USD)",
    yaxis_title="Export Value (Million USD)",
    xaxis_range=axis_range,
    yaxis_range=axis_range,
)

# Render inline and save a standalone HTML copy
fig.show()
fig.write_html("../img/All_Merchandise_bubble.html")

In [6]:
# Restrict to the aggregate ("All Merchandise") rows
all_merch = merged.loc[merged["Is_All_Merchandise"]].copy()

# Static line chart of exports and imports over time
fig_all = go.Figure()

# One trace per flow: (value column, legend name, line color).
# The column name also serves as the label inside the hover text.
for flow_col, flow_label, flow_color in (
    ("Export", "Exports", "green"),
    ("Import", "Imports", "red"),
):
    fig_all.add_trace(
        go.Scatter(
            x=all_merch["Year"],
            y=all_merch[flow_col],
            mode="lines+markers",
            name=flow_label,
            line={"color": flow_color},
            marker={"size": 6},
            hovertemplate="Year: %{x}<br>" + flow_col + ": %{y:,.0f}M<extra></extra>",
        )
    )

# Titles, size and theme
fig_all.update_layout(
    template="plotly_white",
    height=500,
    title="U.S. Trade Summary for All Merchandise (2009–2024)",
    xaxis_title="Year",
    yaxis_title="Trade Value (Million USD)",
    legend_title="Flow",
)

# Render inline and save a standalone HTML copy
fig_all.show()
fig_all.write_html("../img/Import_export_all_trend.html")

In [7]:
# Restrict to the aggregate ("All Merchandise") rows
all_merch = merged.loc[merged["Is_All_Merchandise"]].copy()

# Line chart of exports and imports over time (annotated further below)
fig_all = go.Figure()

# One trace per flow: (value column, legend name, line color).
# The column name also serves as the label inside the hover text.
for flow_col, flow_label, flow_color in (
    ("Export", "Exports", "green"),
    ("Import", "Imports", "red"),
):
    fig_all.add_trace(
        go.Scatter(
            x=all_merch["Year"],
            y=all_merch[flow_col],
            mode="lines+markers",
            name=flow_label,
            line={"color": flow_color},
            marker={"size": 6},
            hovertemplate="Year: %{x}<br>" + flow_col + ": %{y:,.0f}M<extra></extra>",
        )
    )

# Titles, size and theme; grid lines are turned off on both axes
fig_all.update_layout(
    template="plotly_white",
    height=500,
    title="U.S. Trade Summary for All Merchandise (2009–2024)",
    xaxis_title="Year",
    yaxis_title="Trade Value (Million USD)",
    legend_title="Flow",
    xaxis_showgrid=False,
    yaxis_showgrid=False,
)



# Highest import value: upper end of the guide lines and the reference
# height for positioning the event labels
import_peak = all_merch["Import"].max()

# Vertical dashed guide line at each event year, from 0 up to import_peak
fig_all.update_layout(
    shapes=[
        {
            "type": "line",
            "x0": event_year,
            "x1": event_year,
            "y0": 0,
            "y1": import_peak,
            "line": {"color": "lightgray", "dash": "dash", "width": 1},
        }
        for event_year in (2010, 2018, 2020, 2022)
    ]
)

# Event labels in a monospace font.
# Each entry: (year, label height as a fraction of import_peak, label text)
event_labels = (
    (2010, 0.85, "Post-recession recovery"),
    (2018, 0.85, "U.S.–China trade war begins"),
    (2020, 0.65, "COVID-19 disrupts global trade"),
    (2022, 1.05, "Supply chain crisis & rebound"),  # sits above the guide line
)
annotations = [
    {
        "x": event_year,
        "y": import_peak * height_frac,
        "text": label_text,
        "showarrow": False,
        "font": {"family": "Courier New, monospace", "size": 12, "color": "black"},
    }
    for event_year, height_frac, label_text in event_labels
]

fig_all.update_layout(annotations=annotations)

# Render inline and save a standalone HTML copy
fig_all.show()
fig_all.write_html("../img/Import_export_all_trend_annotated.html")


In [11]:
import plotly.graph_objects as go

# Restrict to the aggregate ("All Merchandise") rows
all_merch = merged.loc[merged["Is_All_Merchandise"]].copy()

# Empty base figure; the animation frames are attached further down
fig_all = go.Figure()

# Starting view: only the earliest year present in the data
initial = all_merch[all_merch["Year"] <= all_merch["Year"].min()]

# One trace per flow: (value column, legend name, line color).
# The column name also serves as the label inside the hover text.
for flow_col, flow_label, flow_color in (
    ("Export", "Exports", "green"),
    ("Import", "Imports", "red"),
):
    fig_all.add_trace(
        go.Scatter(
            x=initial["Year"],
            y=initial[flow_col],
            mode="lines+markers",
            name=flow_label,
            line={"color": flow_color},
            marker={"size": 6},
            hovertemplate="Year: %{x}<br>" + flow_col + ": %{y:,.0f}M<extra></extra>",
        )
    )

# Layout customization
# Highest import value: upper end of the vertical guide lines
import_peak = all_merch["Import"].max()
# Highest value of either flow: basis for the fixed y-axis range
value_peak = max(all_merch["Export"].max(), import_peak)

fig_all.update_layout(
    title="U.S. Trade Summary for All Merchandise (2009–2024)",
    xaxis_title="Year",
    yaxis_title="Trade Value (Million USD)",
    height=500,
    legend_title="Flow",
    template="plotly_white",
    xaxis=dict(
        showgrid=False,
        range=[2009, 2024]
    ),
    # Pin the y-axis as well as the x-axis. The figure starts with a single
    # year of data and grows frame by frame, so a fixed range keeps the scale
    # stable during playback. The 10% headroom leaves room for the event
    # label placed at 1.05 x the import peak.
    yaxis=dict(
        showgrid=False,
        range=[0, value_peak * 1.1]
    ),
    # Vertical dashed guide line at each event year, from 0 up to import_peak
    shapes=[
        dict(type="line", x0=event_year, x1=event_year, y0=0, y1=import_peak,
             line=dict(color="lightgray", dash="dash", width=1))
        for event_year in (2010, 2018, 2020, 2022)
    ]
)

# Event labels keyed by the year in which they first appear.
# Each entry: (year, label height as a fraction of the highest import value, text)
import_peak = all_merch["Import"].max()
annotations_by_year = {
    event_year: [
        {
            "x": event_year,
            "y": import_peak * height_frac,
            "text": label_text,
            "showarrow": False,
            "font": {"family": "Courier New, monospace", "size": 12, "color": "black"},
        }
    ]
    for event_year, height_frac, label_text in (
        (2010, 0.85, "Post-recession recovery"),
        (2018, 0.85, "U.S.–China trade war begins"),
        (2020, 0.65, "COVID-19 disrupts global trade"),
        (2022, 1.05, "Supply chain crisis & rebound"),
    )
}

# Build one animation frame per year; each frame shows all data up to and
# including that year, plus every event label whose year has been reached
years = sorted(all_merch["Year"].unique())
frames = []

for yr in years:
    current = all_merch[all_merch["Year"] <= yr]

    # Cumulative labels, in chronological order of their event year
    annotations = [
        note
        for event_year in sorted(annotations_by_year)
        if event_year <= yr
        for note in annotations_by_year[event_year]
    ]

    # Same two traces (and order) as the base figure: exports, then imports
    frame_traces = [
        go.Scatter(
            x=current["Year"],
            y=current[flow_col],
            mode="lines+markers",
            name=flow_label,
            line={"color": flow_color},
            marker={"size": 6},
        )
        for flow_col, flow_label, flow_color in (
            ("Export", "Exports", "green"),
            ("Import", "Imports", "red"),
        )
    ]

    frames.append(
        go.Frame(
            name=str(yr),
            data=frame_traces,
            layout=go.Layout(annotations=annotations),
        )
    )


# Play button: steps through the frames from the current one
# (500 ms per frame, 200 ms transition)
play_button = {
    "label": "Play",
    "method": "animate",
    "args": [
        None,
        {
            "frame": {"duration": 500, "redraw": True},
            "fromcurrent": True,
            "transition": {"duration": 200},
        },
    ],
}

# Slider: one step per year, jumping straight to the frame named after it
slider_steps = [
    {
        "method": "animate",
        "args": [
            [str(yr)],
            {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
        ],
        "label": str(yr),
    }
    for yr in years
]

fig_all.update_layout(
    updatemenus=[{"type": "buttons", "showactive": False, "buttons": [play_button]}],
    sliders=[
        {
            "steps": slider_steps,
            "transition": {"duration": 0},
            "x": 0,
            "y": 0,
            "len": 1.0,
        }
    ],
)

# Attach the prepared frames to the figure
fig_all.frames = frames

# Render inline and save a standalone HTML copy
fig_all.show()
fig_all.write_html("../img/Import_export_all_trend_animated.html")
