In [1]:
import sdmx
import pandas as  pd 
import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.graph_objects as go
from plotly.io import write_image
from iso3166 import countries_by_alpha3, countries


In [3]:
# Shared SDMX client; imf_data() uses this module-level object to issue the request.
client = sdmx.Client()
# Full request URL, assembled from adjacent string literals (implicit concatenation).
url = (
    # Base URL
    "http://dataservices.imf.org/REST/SDMX_XML.svc/CompactData/"
    # Data flow ID and key
    # NOTE(review): presumably DOT = Direction of Trade, M = monthly, empty slot = all
    # reporting areas, TMG_CIF_USD = the import indicator — confirm against the IMF
    # dataflow definition. US, CN and B0 are the counterpart codes that imf_data()
    # later labels United States, China and European Union.
    "DOT/M..TMG_CIF_USD.US+CN+B0"
    # Query parameters, including format
    "?startPeriod=2000&format=sdmx-2.1"
)

def get_highest_source(row):
    """Pick the largest of the B0 / CN / US values in a row.

    Parameters
    ----------
    row : mapping-like
        Anything supporting ``row['B0']``, ``row['CN']`` and ``row['US']``
        (a pandas Series row or a plain dict). Missing values (NaN/None)
        are ignored.

    Returns
    -------
    tuple
        ``(partner_code, percent, amount)`` where ``percent`` is the leading
        partner's share of the sum of the non-missing values (rounded to one
        decimal) and ``amount`` is its value rounded to one decimal.
        ``(None, None, None)`` when all three values are missing. ``percent``
        is ``None`` when the non-missing values sum to zero, because the share
        is undefined in that case.
    """
    sources = {'B0': row['B0'], 'CN': row['CN'], 'US': row['US']}
    # Remove NaN values
    valid_sources = {k: v for k, v in sources.items() if pd.notna(v)}

    if not valid_sources:
        # Always return three items: the caller unpacks the result into three columns.
        return None, None, None

    total = sum(valid_sources.values())
    # On ties, max() keeps the first key in insertion order (B0, then CN, then US).
    max_key = max(valid_sources, key=valid_sources.get)
    max_value = valid_sources[max_key]

    # A zero total would raise ZeroDivisionError for plain floats (or yield NaN with a
    # warning for numpy floats); report the share as missing instead.
    percent = round((max_value / total) * 100, 1) if total != 0 else None

    return max_key, percent, round(max_value, 1)



In [6]:
def _pct_change_vs_previous(importer_subset):
    """Return one '% change vs. previous row' value per row of a single importer's frame.

    For each row after the first, the current top partner's amount is compared with
    that same partner's value in the previous row: the rounded 'Amount (USD Millions)'
    when the top partner is unchanged, otherwise the previous row's raw column for the
    new top partner. The result is NaN for the first row, when the current amount is
    missing, or when the previous value is missing or zero.
    """
    missing = float('nan')
    rows = importer_subset.to_dict('records')
    pct_changes = []

    for i, curr_row in enumerate(rows):
        if i == 0:
            pct_changes.append(missing)
            continue

        prev_row = rows[i - 1]
        curr_partner = curr_row['Import Partner']
        curr_amount = curr_row['Amount (USD Millions)']

        if pd.isna(curr_amount):
            pct_changes.append(missing)
            continue

        if curr_partner == prev_row['Import Partner']:
            prev_amount = prev_row['Amount (USD Millions)']
        else:
            # Top partner switched: compare against what the new leader had last period.
            prev_amount = prev_row.get(curr_partner, None)

        if pd.isna(prev_amount) or prev_amount == 0:
            pct_changes.append(missing)
        else:
            change = ((curr_amount - prev_amount) / prev_amount) * 100
            pct_changes.append(round(change, 2))

    return pct_changes


def imf_data(url):
    """Download an SDMX dataset and reshape it into one row per importer and period.

    Parameters
    ----------
    url : str
        Request URL passed to the module-level ``client``.

    Returns
    -------
    pandas.DataFrame
        Columns: ``Importer_Code``, ``TIME_PERIOD``, one column per counterpart area,
        ``Import Partner`` (leading partner, with B0/US/CN replaced by full names),
        ``Percent``, ``Amount (USD Millions)``, ``Year``, ``Month``,
        ``% Change Prev. Period``, ``source_cat`` (1 China, 2 European Union,
        3 United States) and ``share_bin`` (0: <=25, 1: <=50, 2: <=75, 3: >75).
        Rows with no usable percentage get a missing ``share_bin`` rather than
        being placed in the top bin.
    """
    message = client.get(url=url)
    raw = sdmx.to_pandas(message.data[0])

    # Long -> wide: one column per counterpart area, keyed by importer and period.
    wide = (
        raw.reset_index()
        .pivot_table(
            index=['REF_AREA', 'TIME_PERIOD'],
            columns='COUNTERPART_AREA',
            values='value',
            aggfunc='first')
        .reset_index()
    )
    wide.columns.name = None

    # Only REF_AREA is still a column after the pivot; the other keys are harmless no-ops.
    trade = wide.rename(columns={
        'FREQ': 'Data Granularity',
        'REF_AREA': 'Importer_Code',
        'INDICATOR': 'Indicator',
        'COUNTERPART_AREA': 'Exporter_Code',
    })

    trade[['Import Partner', 'Percent', 'Amount (USD Millions)']] = trade.apply(
        lambda row: pd.Series(get_highest_source(row)), axis=1)

    # Split Year and Month for convenience
    trade[['Year', 'Month']] = trade['TIME_PERIOD'].str.split('-', expand=True)

    # % change is computed per importer so the first row of one importer is never
    # compared with the last row of another. Frames are collected and concatenated
    # once instead of growing a DataFrame inside the loop.
    per_importer = [
        subset.assign(**{'% Change Prev. Period': _pct_change_vs_previous(subset)})
        for _, subset in trade.groupby('Importer_Code', sort=False)
    ]
    result = pd.concat(per_importer, ignore_index=True)

    # Map partner codes to full names; unknown or missing codes are left unchanged.
    result['Import Partner'] = result['Import Partner'].map({
        'B0': 'European Union',
        'US': 'United States',
        'CN': 'China'
    }).fillna(result['Import Partner'])

    source_mapping = {
        'China': 1,
        'European Union': 2,
        'United States': 3
    }
    result['source_cat'] = result['Import Partner'].map(source_mapping)

    # Right-closed bins: (-inf, 25] -> 0, (25, 50] -> 1, (50, 75] -> 2, (75, inf) -> 3.
    # pd.cut leaves missing percentages missing; the earlier if/elif chain put NaN in
    # bin 3 because every `<=` comparison with NaN is False.
    result['share_bin'] = pd.cut(
        pd.to_numeric(result['Percent']),
        bins=[float('-inf'), 25, 50, 75, float('inf')],
        labels=False,
    )

    return result

In [7]:
# Fetch and reshape the series once; later cells read this module-level frame.
data = imf_data(url)

xml.Reader got no structure=… argument for StructureSpecificTimeSeriesData


In [8]:
# Preview the first rows to sanity-check the derived columns.
data.head()

Unnamed: 0,Importer_Code,TIME_PERIOD,B0,CN,US,Import Partner,Percent,Amount (USD Millions),Year,Month,% Change Prev. Period,source_cat,share_bin
0,1C_080,2000-01,3614.345317,449.412672,1642.461369,European Union,63.3,3614.3,2000,1,,2,2
1,1C_080,2000-02,3885.422689,540.175471,1853.89488,European Union,61.9,3885.4,2000,2,7.5,2,2
2,1C_080,2000-03,4661.890084,634.667641,2186.011737,European Union,62.3,4661.9,2000,3,19.99,2,2
3,1C_080,2000-04,3616.6508,507.603242,1736.763672,European Union,61.7,3616.7,2000,4,-22.42,2,2
4,1C_080,2000-05,4310.082377,612.838257,1754.597474,European Union,64.5,4310.1,2000,5,19.17,2,2


In [9]:
# Step 3: Define color grid (3 sources × 4 bins)
# Keys are (source_cat, share_bin) pairs as produced by imf_data():
#   source_cat: 1 = China, 2 = European Union, 3 = United States
#   share_bin:  0 = share <= 25, 1 = <= 50, 2 = <= 75, 3 = above 75
# Within each hue the shade darkens as the share bin increases.
color_grid = {
    (1, 0): '#ffcccc', (1, 1): '#ff9999', (1, 2): '#ff6666', (1, 3): '#cc0000',  # Red shades
    (2, 0): '#ffffcc', (2, 1): '#ffff99', (2, 2): '#ffff66', (2, 3): '#cccc00',  # Yellow shades
    (3, 0): '#ccccff', (3, 1): '#9999ff', (3, 2): '#6666ff', (3, 3): '#0000cc'   # Blue shades
}


In [None]:
# ISO3 to country name lookup dictionary
iso3_to_country = {code: country.name for code, country in countries_by_alpha3.items()}
# ISO2 to ISO3 mapping (the choropleth below is drawn with locationmode='ISO-3')
iso2_to_iso3 = {country.alpha2: country.alpha3 for country in countries}

# Every distinct period in the data, sorted as strings, becomes one slider position.
year_month_options = sorted(data['TIME_PERIOD'].unique())

# Period selector; update_map_px() reads slider.value to decide what to draw.
slider = widgets.SelectionSlider(
    options=year_month_options,
    description='Period:',
    orientation='horizontal',
    layout={'width': '90%'},
    style={'description_width': 'initial'}
)

# Output area the figure is rendered into, so each redraw replaces the previous one.
map_output = widgets.Output()

def update_map_px(change=None):
    """Redraw the choropleth for the period currently selected on the slider.

    Parameters
    ----------
    change : dict, optional
        Change notification passed by ``slider.observe``; it is not inspected,
        the selected period is read from ``slider.value`` instead. Defaults to
        ``None`` so the function can also be called directly for the first draw.

    Side effects
    ------------
    Renders into ``map_output`` and overwrites ``plot_name.json`` with the figure
    on every call. Reads the module-level ``data``, ``slider``, ``color_grid``,
    ``iso2_to_iso3`` and ``iso3_to_country``.
    """
    with map_output:
        year_month = slider.value

        # The slider options come from data['TIME_PERIOD'], so an exact match selects
        # the period without casting the whole Year/Month columns on every event.
        period_rows = data[data['TIME_PERIOD'] == year_month].copy()

        # Convert ISO2 codes to ISO3 for choropleth
        period_rows['ISO3_from_ISO2'] = period_rows['Importer_Code'].map(iso2_to_iso3)
        # Importer codes with no ISO-3166 entry cannot be placed on the map; drop them
        # rather than sending NaN locations to plotly.
        period_rows = period_rows.dropna(subset=['ISO3_from_ISO2'])

        if period_rows.empty:
            clear_output(wait=True)
            print(f"No data available for {year_month}")
            return

        # Unknown or missing (source_cat, share_bin) pairs fall back to neutral gray.
        period_rows['color_key'] = [
            color_grid.get(key, '#cccccc')
            for key in zip(period_rows['source_cat'], period_rows['share_bin'])
        ]

        # Map ISO3 codes to country names
        period_rows['CountryName'] = (
            period_rows['ISO3_from_ISO2']
            .map(iso3_to_country)
            .fillna(period_rows['Importer_Code'])
        )

        hover_columns = [
            'CountryName',
            'Import Partner',
            'Percent',
            'Amount (USD Millions)',
            '% Change Prev. Period',
        ]

        fig = go.Figure()
        # One trace per fill colour: each trace uses a flat single-colour scale with a
        # constant z, so a fixed colour can be drawn per country.
        for color, color_data in period_rows.groupby('color_key', sort=False):
            fig.add_trace(go.Choropleth(
                locations=color_data['ISO3_from_ISO2'],
                z=[1] * len(color_data),
                locationmode='ISO-3',
                colorscale=[[0, color], [1, color]],
                showscale=False,
                hovertemplate=(
                    '<b>%{customdata[0]}</b><br>' +
                    'Top Import Source: %{customdata[1]}<br>' +
                    'Percent: %{customdata[2]}%<br>' +
                    'Amount: $%{customdata[3]}M<br>' +
                    'Change vs. Prev Period: %{customdata[4]}%<br>' +
                    '<extra></extra>'
                ),
                customdata=color_data[hover_columns].values,
                marker_line_color='black',
                marker_line_width=0.5,
                name=''
            ))

        fig.update_layout(
            title=f'Import Data - {year_month}',
            geo=dict(
                projection_type='robinson',
                showframe=False,
                showcoastlines=True,
                showland=True,
                landcolor='black',
                showocean=True,
                oceancolor='lightgray',
                bgcolor='white'
            ),
            height=600,
            margin=dict(t=50, b=0, l=0, r=0),
            showlegend=False
        )

        clear_output(wait=True)
        fig.show()
        # Overwritten on every redraw, so the file always holds the last period shown.
        fig.write_json("plot_name.json")

# Redraw whenever the slider's value changes.
slider.observe(update_map_px, names='value')
display(slider)
display(map_output)
# Draw once for the initial slider position; observe() only fires on changes.
update_map_px()


SelectionSlider(description='Period:', layout=Layout(width='90%'), options=('2000-01', '2000-02', '2000-03', '…

Output()

In [13]:
# Export the processed frame to data.json; orient='records' writes a list of row objects.
data.to_json("data.json", orient='records')
