# In [30]:
import pandas as pd, re
import plotly.graph_objects as go, plotly.colors as pc
import json

# -----------------------------------------------------------------------------
# 0. CONFIGURATION — reorder stages and rename as needed
# -----------------------------------------------------------------------------
# Left-to-right column order of the Sankey; each entry is a column of `df`.
# Every stage listed here needs a PALETTES entry (it is indexed directly when
# colouring nodes); STAGE_LABELS is optional and falls back to the stage name.
STAGES = [
    'Final',
    'Year',
    'Direction',
    'Sub',
    'Super',
    'currency_region',
    'Currency',
    # 'sender_region',
    # 'receiver_region',
]

# Caption written underneath each stage column.
STAGE_LABELS = dict(
    Super='Message Type',
    Direction='Flow Direction',
    Year='Settlement Year',
    Sub='Category Code',
    Final='Total Counts',
    currency_region='Currency Region',
    Currency='Currency Code',
    # sender_region='Sender Region',
    # receiver_region='Receiver Region',
)

# Display text for individual node values; values not listed are shown as-is.
NODE_LABELS = {
    '1xx':       'Category 1xx',
    '2xx':       'Category 2xx',
    'NFM':       'Not Financial Msg',
    'reporting': 'Reporting Msgs',
    'GPI-Alert': 'GPI Alerts',
    'MT':        'MT Super',
    'MX':        'MX Super',
    'MT count':  'Total MT Rows',
    'MX count':  'Total MX Rows',
    'Incoming':  'Incoming Flow',
    'Outgoing':  'Outgoing Flow',
    # further value -> label mappings can be added here
}

# Colour list per stage; nodes cycle through their stage's list in order.
PALETTES = dict(
    Year=pc.qualitative.Plotly,
    Direction=pc.qualitative.Set1,
    Super=["#A6CEE3", "#1F78B4"],
    Sub=pc.qualitative.Set2,
    Final=["#33A02C", "#FB9A99"],
    currency_region=pc.qualitative.Pastel1,
    Currency=pc.qualitative.Pastel2,
    # sender_region=pc.qualitative.Vivid,
    # receiver_region=pc.qualitative.Dark24,
)

# -----------------------------------------------------------------------------
# 1. LOAD & CLEAN TRANSACTION DATA
# -----------------------------------------------------------------------------
df = pd.read_csv('C:/Users/ArunPriyanka/Downloads/data_iso.csv').rename(columns=str.strip)

# Year → string year; missing, unparsable or zero years become "null".
df['Year'] = pd.to_numeric(df['Year'], errors='coerce').apply(
    lambda y: "null" if pd.isna(y) or y == 0 else str(int(y))
)

# Direction comes from the lower-case `direction` column.
df['Direction'] = df['direction'].fillna("null").str.strip()

# Amount: the "NFM" marker counts as 0; other non-numeric values become 0 too.
df['Amount'] = df['Amount'].apply(
    lambda v: 0
    if isinstance(v, str) and v.strip().upper() == "NFM"
    else pd.to_numeric(v, errors='coerce')
).fillna(0)

df['count'] = pd.to_numeric(df['count'], errors='coerce').fillna(0)

# Normalise the identifier columns to stripped strings with a "null" placeholder.
for c in ('MT', 'sender', 'Receiver'):
    df[c] = df[c].fillna("null").astype(str).str.strip()

# Drop rows with no usable information, then give every row a unit weight.
df = df.query("not (Year=='null' and Direction=='null' and MT=='null' and count==0)")
df = df.assign(Flow=1)
# -----------------------------------------------------------------------------
# 2. LOAD COUNTRIES JSON & BUILD LOOKUPS
# -----------------------------------------------------------------------------
with open('C:/Users/ArunPriyanka/Downloads/restcountries.json', 'r', encoding='utf-8') as f:
    countries = json.load(f)

# country code (cca2, upper-cased) → subregion
cca2_to_region = {
    country.get('cca2', '').upper(): country.get('subregion', 'Unknown')
    for country in countries if 'cca2' in country
}


def _currency_region(ccy, users):
    """Choose a single subregion for currency code ``ccy``.

    ``users`` is a list of ``(cca2, subregion)`` pairs, one for every country
    in the JSON that lists ``ccy`` among its currencies.

    ISO 4217 codes normally begin with the issuing country's two-letter code
    (USD → US, GBP → GB). When that country is among the users, its
    subregion wins. Otherwise (EUR, supranational X-codes, ...) the most
    frequent subregion is used, with ties going to the country listed first.
    The result is deterministic and independent of file order quirks.
    """
    issuer = ccy[:2].upper()
    for cca2, region in users:
        if cca2 == issuer:
            return region
    regions = [region for _, region in users]
    # max() keeps the first element with the highest count → stable tie-break
    return max(regions, key=regions.count)


# currency code → subregion (despite the name, the value is the subregion).
# Shared currencies (USD, EUR, XOF, AUD, ...) appear under many countries; a
# plain dict comprehension would silently keep whichever country came last.
_currency_users = {}
for country in countries:
    for code in (country.get('currencies') or {}):
        _currency_users.setdefault(code, []).append(
            (country.get('cca2', '').upper(), country.get('subregion', 'Unknown'))
        )

currency_to_continent = {
    code: _currency_region(code, users)
    for code, users in _currency_users.items()
}

# -----------------------------------------------------------------------------
# 3. EXTRACT REGION FIELDS
# -----------------------------------------------------------------------------
def _lookup_region(key, table):
    """Return ``table[key]`` after stripping/upper-casing ``key``.

    Missing (NaN/None) keys and keys not present in ``table`` give 'Unknown'.
    """
    if pd.isna(key):
        return 'Unknown'
    return table.get(str(key).strip().upper(), 'Unknown')


# Currency region. Unlike MT/sender/Receiver, the Currency column is not
# cleaned upstream, so each code is normalised here before the lookup;
# otherwise padded or lower-case codes would silently fall into 'Unknown'.
df['currency_region'] = df['Currency'].map(
    lambda code: _lookup_region(code, currency_to_continent)
)

# Sender/receiver BIC → characters 5-6 (2-letter country code) → region
df['sender_region'] = df['sender'].str[4:6].map(
    lambda cc: _lookup_region(cc, cca2_to_region)
)
df['receiver_region'] = df['Receiver'].str[4:6].map(
    lambda cc: _lookup_region(cc, cca2_to_region)
)

# -----------------------------------------------------------------------------
# 4. DERIVE SUPER, SUB, FINAL
# -----------------------------------------------------------------------------
# A message type made of letters followed by a dot (e.g. "pacs.008", "camt.053")
# is treated as MX; everything else, including the "null" placeholder, is MT.
df['Super'] = (
    df['MT']
      .str.match(r'^[A-Za-z]+\.', na=False)
      .map(lambda is_mx: 'MX' if is_mx else 'MT')
)

def calc_sub(r):
    """Classify one message row into its Sankey "Sub" category.

    ``r`` is a row (pandas Series or mapping) providing string fields 'MT'
    (message type), 'Super' ('MT' or anything else → treated as MX) and
    'sender'.

    Returns one of "GPI-Alert", "reporting", "NFM", "1xx", "2xx",
    "other" (unrecognised MT type) or "MX-other" (unrecognised MX type).
    """
    mt, sup, snd = r['MT'].strip(), r['Super'], r['sender'].upper()

    # Anything sent by TRCKCHZZ is a GPI alert, whatever its message type.
    if snd.startswith("TRCKCHZZ"):
        return "GPI-Alert"

    if sup == "MT":
        key = mt.lower()
        if key == "103plus":
            return "1xx"
        if key == "202cov":
            return "2xx"
        try:
            n = int(mt)
        except ValueError:
            # Non-numeric MT type, e.g. the "null" placeholder.
            return "other"
        if n in {900, 910, 940, 942, 950, 991}:
            return "reporting"
        if n in {192, 195, 196, 198, 199, 292, 295, 296, 298, 299,
                 935, 992, 995, 996, 998, 999}:
            return "NFM"
        if n in {101, 103, 110, 111, 190, 191}:
            return "1xx"
        if n in {200, 202, 204, 210, 290, 291}:
            return "2xx"
        # Fall back to the category digit. Only categories 1 and 2 have their
        # own bucket; categories 3-9 (e.g. MT 300, MT 941) are "other" rather
        # than being mislabelled as "2xx".
        category = str(n)[0]
        if category == "1":
            return "1xx"
        if category == "2":
            return "2xx"
        return "other"

    # MX (ISO 20022-style identifiers)
    key = mt.lower()
    if key.startswith("camt."):
        return "reporting"
    if key.startswith("pacs.009"):
        return "2xx"
    if key.startswith(("pacs.002", "pacs.004", "pacs.008")):
        return "1xx"
    return "MX-other"

# Per-row category via calc_sub, then the terminal "Final" bucket per family.
df = df.assign(
    Sub=lambda d: d.apply(calc_sub, axis=1),
    Final=lambda d: d['Super'].map({'MT': 'MT count', 'MX': 'MX count'}),
)

# -----------------------------------------------------------------------------
# 5. BUILD NODES & INDEX
# -----------------------------------------------------------------------------
def get_vals(stage):
    """Return the distinct non-NaN values of column ``stage`` in display order.

    Year: numeric years ascending, followed by "null" if present.
    Final: "MT count" before "MX count" (only those present).
    Any other stage: plain sorted order.
    """
    vals = df[stage].dropna().unique().tolist()
    if stage == 'Year':
        nums = sorted(int(v) for v in vals if str(v).isdigit())
        out = [str(n) for n in nums]
        if "null" in vals:
            out.append("null")
        return out
    if stage == 'Final':
        return [x for x in ("MT count", "MX count") if x in vals]
    return sorted(vals)


# Nodes are identified by their (stage, value) pair, not the bare value: the
# same value can occur in several stages (Year and Direction both use "null";
# region stages can share "Unknown"). Keying by the bare value merged such
# nodes, which produced self-loops and orphaned nodes in the diagram.
stage_vals = {s: get_vals(s) for s in STAGES}
node_keys = [(s, v) for s in STAGES for v in stage_vals[s]]
nodes = [v for _, v in node_keys]                    # raw value per node index
idx = {key: i for i, key in enumerate(node_keys)}    # (stage, value) → node index
# value → stage; kept for compatibility only. Ambiguous (last stage wins) when
# a value appears in more than one stage — prefer node_keys.
node_to_stage = {v: s for s, v in node_keys}

# -----------------------------------------------------------------------------
# 6. AGGREGATE LINKS
# -----------------------------------------------------------------------------
# Per stage: value → node index (used to translate grouped values into links).
stage_index = {s: {v: idx[(s, v)] for v in stage_vals[s]} for s in STAGES}

links = pd.concat([
    df.groupby([a, b])['Flow'].sum().reset_index()
      .assign(source=lambda g: g[a].map(stage_index[a]),
              target=lambda g: g[b].map(stage_index[b]))
      [['source', 'target', 'Flow']]
    for a, b in zip(STAGES, STAGES[1:])
], ignore_index=True)

src, tgt, vals = links.source, links.target, links.Flow

# -----------------------------------------------------------------------------
# 7. HOVER TEXT
# -----------------------------------------------------------------------------
# Plotly hover labels use "<br>" for line breaks; a plain "\n" is rendered as
# whitespace, so the multi-line hover texts used to collapse onto one line.
def mk(a, b, sep="<br>"):
    """Build hover text for every value of stage ``a``.

    Returns ``{value_of_a: text}``. The text holds the value, its total row
    count, and its breakdown over stage ``b`` (count and percentage share),
    with lines joined by ``sep``.
    """
    g = df.groupby([a, b]).size().reset_index(name='flow')
    out = {}
    for n in df[a].unique():
        part = g[g[a] == n]
        total = part['flow'].sum()
        breakdown = sep.join(
            f"{v}: {c} ({c / total * 100:.1f}%)"
            for v, c in zip(part[b], part['flow'])
        )
        out[n] = f"{n}{sep}Total: {total}{sep}" + breakdown
    return out


# hover text keyed by (stage, value), matching node_keys
hover_map = {}
for a, b in zip(STAGES, STAGES[1:]):
    hover_map.update({(a, n): text for n, text in mk(a, b).items()})

# "Final" nodes: total row count and summed Amount per message family. Only
# families actually present are handled (no KeyError if e.g. MX is absent).
agg = df.groupby('Super').agg(count=('Flow', 'sum'), Amount=('Amount', 'sum'))
for sup, cnt, amt in zip(agg.index, agg['count'], agg['Amount']):
    hover_map[('Final', f"{sup} count")] = (
        f"{sup} count<br>Total Count: {cnt}, Total Amount: {amt}"
    )

node_display = [NODE_LABELS.get(n, n) for n in nodes]
node_hover = [
    hover_map.get((s, n), f"{node_display[i]}<br>Total: {int((df[s] == n).sum())}")
    for i, (s, n) in enumerate(node_keys)
]

# -----------------------------------------------------------------------------
# 8. COLORS & DRAW
# -----------------------------------------------------------------------------
# Each node's colour = its position within its stage, cycling through that
# stage's palette.
node_color = {
    (s, v): PALETTES[s][j % len(PALETTES[s])]
    for s in STAGES
    for j, v in enumerate(stage_vals[s])
}

fig = go.Figure(go.Sankey(
    node=dict(
        label=node_display,
        color=[node_color[k] for k in node_keys],
        pad=15, thickness=10,
        customdata=node_hover,
        hovertemplate='%{customdata}<extra></extra>',
    ),
    link=dict(
        source=src, target=tgt, value=vals,
        # each link takes the colour of its source node
        color=[node_color[node_keys[i]] for i in src],
        hovertemplate="Flow: %{value}<extra></extra>",
    ),
))

# Stage captions under each column, spread evenly across the paper width
# (guarded so a single-stage configuration does not divide by zero).
xs = [i / max(len(STAGES) - 1, 1) for i in range(len(STAGES))]

# The captions sit just below the plot area, so the bottom margin must leave
# room for them; at y=-0.1 with a 10px margin they fell outside the figure.
fig.update_layout(
    title="Sankey with MT/MX, Year, Direction, Category, Currency & Regions",
    font_size=10,
    width=1500,
    height=1000,
    margin=dict(l=10, r=10, t=30, b=50),
    annotations=[
        dict(
            x=x, y=0, yshift=-8, yanchor='top',
            xref='paper', yref='paper',
            text=STAGE_LABELS.get(s, s),
            showarrow=False, font_size=12,
        )
        for x, s in zip(xs, STAGES)
    ],
)

# Plain inline rendering; disabled here in favour of the scrollable wrapper below.
#fig.show()

# `display` is only an implicit global inside IPython/Jupyter; import it
# explicitly so the cell does not raise NameError when run elsewhere.
from IPython.display import HTML, display

# 1) Turn the Plotly figure into an HTML fragment (no full-page wrapper);
#    plotly.js itself is loaded from the CDN.
html_frag = fig.to_html(full_html=False, include_plotlyjs='cdn')

# 2) Wrap it in a horizontally scrollable DIV
scrollable = f"""
<div style="width:100%; overflow-x:auto; border:1px solid #ddd; padding:5px;">
  {html_frag}
</div>
"""

# 3) Display it
display(HTML(scrollable))