In [53]:
# pip install "notebook>=7.0" "anywidget>=0.9.13"
# pip install plotly


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=1414b58c-d851-4542-a845-1dccbe508e47' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>

In [54]:
# import necessary libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [55]:
import pandas as pd
import numpy as np

# Dimensions of the synthetic trade dataset: one row per
# (country, year, category) combination.
countries = ["Mexico", "China", "India", "Germany", "Brazil"]
years = list(range(2018, 2023))
categories = ["Parts", "Light", "Heavy"]

# Seed NumPy's global RNG before any draw so the rows are reproducible.
np.random.seed(42)

# Within each row the imports draw is evaluated before the exports draw, and
# rows are generated country -> year -> category; keeping that order is what
# makes the seeded values repeatable.
data = [
    [
        country,
        year,
        cat,
        np.random.randint(50, 500) * 1e6,  # imports: 50-499 million, stored in base units
        np.random.randint(40, 450) * 1e6,  # exports: 40-449 million, stored in base units
    ]
    for country in countries
    for year in years
    for cat in categories
]

df = pd.DataFrame(
    data, columns=["Country", "Year", "Category", "Imports", "Exports"]
).assign(
    # Positive balance means exports exceed imports.
    Balance=lambda frame: frame["Exports"] - frame["Imports"],
    # Year is stored as text so plots treat it as a label, not a number.
    Year=lambda frame: frame["Year"].astype(str),
)
df

Unnamed: 0,Country,Year,Category,Imports,Exports,Balance
0,Mexico,2018,Parts,152000000.0,388000000.0,236000000.0
1,Mexico,2018,Light,320000000.0,146000000.0,-174000000.0
2,Mexico,2018,Heavy,121000000.0,228000000.0,107000000.0
3,Mexico,2019,Parts,70000000.0,142000000.0,72000000.0
4,Mexico,2019,Light,171000000.0,254000000.0,83000000.0
...,...,...,...,...,...,...
70,Brazil,2021,Light,221000000.0,399000000.0,178000000.0
71,Brazil,2021,Heavy,263000000.0,74000000.0,-189000000.0
72,Brazil,2022,Parts,498000000.0,266000000.0,-232000000.0
73,Brazil,2022,Light,150000000.0,170000000.0,20000000.0


Helper functions for plotting

In [56]:
def plot_line(df, x, y, x_label, y_label, legend_label, title, **kwargs):
    """
    Draw a Plotly Express line chart with labelled axes and a centered title.

    Parameters:
    - df: DataFrame containing the data to plot.
    - x: Column name for the x-axis.
    - y: Column name, or list of column names, for the y-axis.
    - x_label: Title shown on the x-axis.
    - y_label: Title shown on the y-axis.
    - legend_label: Title shown above the legend.
    - title: Chart title text.
    - **kwargs: Passed through unchanged to px.line.
    Returns:
    - fig: Plotly figure object.
    Notes:
    - No validation is done here; any error for bad columns or arguments
      comes from px.line itself.
    Example:
    plot_line(df, x="Year", y=["Imports", "Exports", "Balance"], x_label="Year", y_label="Value", legend_label="Measure", title="Trade Data Over Years")
    """
    # Title block: horizontally centered, anchored at the top, 20pt font.
    centered_title = dict(
        text=title,
        x=0.5,
        xanchor="center",
        yanchor="top",
        font=dict(size=20),
    )

    # update_layout returns the figure it was called on, so the chain can be
    # returned directly.
    return px.line(df, x=x, y=y, **kwargs).update_layout(
        xaxis_title=x_label,
        yaxis_title=y_label,
        legend_title=legend_label,
        title=centered_title,
    )

In [57]:
# Total each trade measure per year (summing over countries and categories).
agg_df = df.groupby("Year", as_index=False)[["Imports", "Exports", "Balance"]].sum()

# One line per measure, with Year along the x-axis.
fig = plot_line(
    agg_df,
    x="Year",
    y=["Imports", "Exports", "Balance"],
    x_label="Year",
    y_label="Value",
    legend_label="Measure",
    title="Trade Data Over Years",
)
fig.show()

In [58]:
def plot_scatter(df, x, y, x_label, y_label, legend_label, title, **kwargs):
    """
    Draw a Plotly Express scatter plot with labelled axes and a centered title.

    Parameters:
    - df: DataFrame containing the data to plot.
    - x: Column name for the x-axis.
    - y: Column name for the y-axis.
    - x_label: Title shown on the x-axis.
    - y_label: Title shown on the y-axis.
    - legend_label: Title shown above the legend.
    - title: Chart title text.
    - **kwargs: Passed through unchanged to px.scatter.
    Returns:
    - fig: Plotly figure object.
    Notes:
    - No validation is done here; any error for bad columns or arguments
      comes from px.scatter itself.
    Example:
    plot_scatter(df, x="Exports", y="Imports", x_label="Exports", y_label="Imports", legend_label="Measure", title="Imports vs. Exports")
    """
    scatter_fig = px.scatter(df, x=x, y=y, **kwargs)

    # Title block: horizontally centered, anchored at the top, 20pt font.
    title_spec = dict(
        text=title,
        x=0.5,
        xanchor="center",
        yanchor="top",
        font=dict(size=20),
    )
    scatter_fig.update_layout(
        xaxis_title=x_label,
        yaxis_title=y_label,
        legend_title=legend_label,
        title=title_spec,
    )
    return scatter_fig

In [59]:
# Per-year totals of each trade measure (summed over countries and categories).
# Recomputed here so the cell can run without the line-chart cell.
agg_df = df.groupby("Year")[["Imports", "Exports", "Balance"]].sum().reset_index()

# Each point is one year. The axis labels name the columns actually plotted
# (Exports on x, Imports on y); hover_name is forwarded to px.scatter so
# hovering a point shows which year it is.
fig = plot_scatter(agg_df,
          x="Exports",
          y="Imports",
          x_label="Exports",
          y_label="Imports",
          legend_label="Measure",
          title="Imports vs. Exports by Year",
          hover_name="Year")

fig.show()

In [60]:
def plot_dumbell(df, x1, x2, y, x_label, y_label, legend_label, title, **kwargs):
    """
    Plot a dumbbell chart using Plotly graph objects.

    For every row of ``df`` a grey horizontal segment joins the values in
    columns ``x1`` and ``x2`` at that row's ``y`` position. The two ends are
    then drawn as marker traces (blue for ``x1``, pink for ``x2``) whose
    legend entries are the column names.

    Parameters:
        - df: DataFrame containing the data to plot, one dumbbell per row.
        - x1: Column name for the first end of each dumbbell (blue markers).
        - x2: Column name for the second end of each dumbbell (pink markers).
        - y: Column name giving the y position of each dumbbell.
        - x_label: Label for x-axis.
        - y_label: Label for y-axis.
        - legend_label: Label for the legend.
        - title: Chart title text.
        - **kwargs: Extra layout properties. They are forwarded to
          ``fig.update_layout`` after the defaults below, so they can
          override them (e.g. ``height=400``).
    Returns:
        - fig: Plotly figure object.
    Raises:
        - Columns are read with ``df[...]``; a missing column surfaces as the
          error raised by that lookup. Invalid layout properties in
          ``**kwargs`` are rejected by ``fig.update_layout``.
    Example:
        plot_dumbell(df, "Exports", "Imports", "Country", x_label="Value", y_label="Country", legend_label="", title="Exports vs. Imports by Country")

    """
    fig = go.Figure()

    # Connectors are added first so the end markers are drawn on top of them.
    # zip over the three columns avoids building a Series per row (iterrows).
    for start, end, position in zip(df[x1], df[x2], df[y]):
        fig.add_trace(go.Scatter(
            x=[start, end],
            y=[position, position],
            mode="lines",
            line=dict(color="grey", width=2),
            showlegend=False
        ))

    # One marker trace per end. The legend name is the column being plotted,
    # so it stays correct whichever columns the caller passes.
    for column, color in ((x1, "blue"), (x2, "pink")):
        fig.add_trace(go.Scatter(
            x=df[column],
            y=df[y],
            mode="markers",
            name=column,
            marker=dict(color=color, size=12)
        ))

    # Axis labels, legend title and a centered 20pt chart title.
    fig.update_layout(
        xaxis_title=x_label,
        yaxis_title=y_label,
        legend_title=legend_label,
        title={
            'text': title,
            'x': 0.5,  # center title
            'xanchor': 'center',
            'yanchor': 'top',
            'font': {'size': 20}
        }
    )

    # Caller-supplied layout overrides go last so they win over the defaults.
    if kwargs:
        fig.update_layout(**kwargs)

    return fig


# Total each trade measure per country (summing over years and categories),
# then draw one dumbbell per country joining its Exports and Imports totals.
agg_df = df.groupby("Country")[["Imports", "Exports", "Balance"]].sum().reset_index()
fig = plot_dumbell(agg_df, "Exports", "Imports", "Country",  x_label="Value", y_label="Country", legend_label="", title="Exports vs. Imports by Country")
fig.show()
   

In [61]:
# Load the gapminder sample dataset that ships with Plotly Express.
# NOTE(review): this rebinds `df`, which the earlier cells use for the
# synthetic trade data. Re-running those cells after this one would make them
# read this frame instead; a distinct name would avoid that — TODO confirm
# with the notebook's author before renaming, since later cells read `df`.
import plotly.express as px
df = px.data.gapminder()
df

Unnamed: 0,country,continent,year,lifeExp,pop,gdpPercap,iso_alpha,iso_num
0,Afghanistan,Asia,1952,28.801,8425333,779.445314,AFG,4
1,Afghanistan,Asia,1957,30.332,9240934,820.853030,AFG,4
2,Afghanistan,Asia,1962,31.997,10267083,853.100710,AFG,4
3,Afghanistan,Asia,1967,34.020,11537966,836.197138,AFG,4
4,Afghanistan,Asia,1972,36.088,13079460,739.981106,AFG,4
...,...,...,...,...,...,...,...,...
1699,Zimbabwe,Africa,1987,62.351,9216418,706.157306,ZWE,716
1700,Zimbabwe,Africa,1992,60.377,10704340,693.420786,ZWE,716
1701,Zimbabwe,Africa,1997,46.809,11404948,792.449960,ZWE,716
1702,Zimbabwe,Africa,2002,39.989,11926563,672.038623,ZWE,716


In [62]:
# Animated bubble chart: one frame per year, one panel per continent.
fig = px.scatter(
    df,
    x="gdpPercap",
    y="lifeExp",
    # Bubble size, colour and hover title come from these columns.
    size="pop",
    size_max=45,
    color="continent",
    hover_name="country",
    facet_col="continent",
    # animation_group keeps each country the same object across year frames.
    animation_frame="year",
    animation_group="country",
    # Log-scaled x-axis; fixed ranges keep the axes steady between frames.
    log_x=True,
    range_x=[100, 100000],
    range_y=[25, 90],
)
fig.show()

In [72]:
# Static scatter, one panel per continent, with a log-scaled x-axis.
# The previous size_max argument was dropped: it only scales markers when a
# `size=` column is mapped, and none is here, so the markers are unchanged.
fig = px.scatter(
    df,
    x="gdpPercap",
    y="lifeExp",
    facet_col="continent",
    log_x=True,
    # Readable axis titles in place of the raw column names.
    labels={"gdpPercap": "GDP per capita", "lifeExp": "Life expectancy"},
    title="Life expectancy vs. GDP per capita by continent",
)
fig.show()