[Reference](https://pub.towardsai.net/an-unique-way-of-visualising-confusion-matrix-sankey-chart-de8e4d09b9b)

In [1]:
# necessary imports
# !pip install plotly
import pandas as pd
import numpy as np
from plotly import graph_objects as go

# Toy 2x2 confusion matrix used throughout the notebook.
# Rows are the true classes and columns the predicted classes.
confusion_matrix = np.array([
    [10, 6],    # true class 1
    [2, 12],    # true class 2
])

confusion_matrix

array([[10,  6],
       [ 2, 12]])

In [2]:
# Wrap the matrix in a labelled dataframe: one row per true class and one column per
# predicted class. Row labels are sized from axis 0 and column labels from axis 1, so
# the labels always match the matrix dimensions.
df = pd.DataFrame(
    data=confusion_matrix,
    index=[f"True Class-{i+1}" for i in range(confusion_matrix.shape[0])],
    columns=[f"Predicted Class-{i+1}" for i in range(confusion_matrix.shape[1])],
)
df

Unnamed: 0,Predicted Class-1,Predicted Class-2
True Class-1,10,6
True Class-2,2,12


In [3]:
# Restructure the dataframe from wide to long form: one row per
# (true class, predicted class) pair, with the count in the last column.
df = (
    df
    .stack()         # move the column labels into an inner index level
    .reset_index()   # turn both index levels into ordinary columns
)
df

Unnamed: 0,level_0,level_1,0
0,True Class-1,Predicted Class-1,10
1,True Class-1,Predicted Class-2,6
2,True Class-2,Predicted Class-1,2
3,True Class-2,Predicted Class-2,12


In [4]:
# Rename the auto-generated columns: the true class becomes `source`, the predicted
# class becomes `target`, and the count becomes `value`.
# Reassign the result rather than using inplace=True, so the cell reads as a plain
# transformation of its input.
df = df.rename(columns={'level_0': 'source', 'level_1': 'target', 0: 'value'})
df

Unnamed: 0,source,target,value
0,True Class-1,Predicted Class-1,10
1,True Class-1,Predicted Class-2,6
2,True Class-2,Predicted Class-1,2
3,True Class-2,Predicted Class-2,12


In [5]:
# Colour each link by whether the prediction was right:
#   rgba(211,255,216,0.6) -> green, used for correct predictions
#   rgba(245,173,168,0.6) -> red, used for incorrect predictions
# A prediction is correct when the last word of the source label ("Class-N")
# equals the last word of the target label.
df["colour"] = [
    "rgba(211,255,216,0.6)" if src.split()[-1] == tgt.split()[-1] else "rgba(245,173,168,0.6)"
    for src, tgt in zip(df.source, df.target)
]
df

Unnamed: 0,source,target,value,colour
0,True Class-1,Predicted Class-1,10,"rgba(211,255,216,0.6)"
1,True Class-1,Predicted Class-2,6,"rgba(245,173,168,0.6)"
2,True Class-2,Predicted Class-1,2,"rgba(245,173,168,0.6)"
3,True Class-2,Predicted Class-2,12,"rgba(211,255,216,0.6)"


In [6]:
# Collect every distinct node label, in order of first appearance:
# all source labels first, then all target labels.
labels = (
    pd.concat([df["source"], df["target"]])
    .unique()
)
labels

array(['True Class-1', 'True Class-2', 'Predicted Class-1',
       'Predicted Class-2'], dtype=object)

In [7]:
# Map each label to its position in `labels`.
labels_indices = dict(zip(labels, range(len(labels))))
labels_indices

{'Predicted Class-1': 2,
 'Predicted Class-2': 3,
 'True Class-1': 0,
 'True Class-2': 1}

In [8]:
# Replace the text labels in `source` and `target` with their integer positions in
# `labels`. Series.map is used per column because DataFrame.applymap is deprecated in
# recent pandas releases. Looking the label up with [] keeps the original behaviour of
# raising KeyError for a label that is missing from `labels_indices`.
df["source"] = df["source"].map(lambda label: labels_indices[label])
df["target"] = df["target"].map(lambda label: labels_indices[label])
df

Unnamed: 0,source,target,value,colour
0,0,2,10,"rgba(211,255,216,0.6)"
1,0,3,6,"rgba(245,173,168,0.6)"
2,1,2,2,"rgba(245,173,168,0.6)"
3,1,3,12,"rgba(211,255,216,0.6)"


In [9]:
# Widen the displayed column width so long tooltip strings are not truncated.
# The full option name is used; pandas warns that abbreviated option names can break
# if similarly named options are added later.
pd.set_option("display.max_colwidth", 100)

# Build the hover text for every link. `source` and `target` hold integer positions
# into `labels` here, so look each label up and keep only its last word ("Class-N").
# Correctness is decided by comparing the class names directly, not by matching the
# colour string, so this cell keeps working if the colours are changed.
df["tooltip"] = [
    f"{value} {true_class} instances correctly classified as {predicted_class}"
    if true_class == predicted_class
    else f"{value} {true_class} instances misclassified as {predicted_class}"
    for value, true_class, predicted_class in zip(
        df["value"],
        (labels[i].split()[-1] for i in df["source"]),
        (labels[i].split()[-1] for i in df["target"]),
    )
]
df

Unnamed: 0,source,target,value,colour,tooltip
0,0,2,10,"rgba(211,255,216,0.6)",10 Class-1 instances correctly classified as Class-1
1,0,3,6,"rgba(245,173,168,0.6)",6 Class-1 instances misclassified as Class-2
2,1,2,2,"rgba(245,173,168,0.6)",2 Class-2 instances misclassified as Class-1
3,1,3,12,"rgba(211,255,216,0.6)",12 Class-2 instances correctly classified as Class-2


In [10]:
# Create a Sankey chart: true-class nodes on one side, predicted-class nodes on the
# other, with one link per confusion-matrix cell.
fig = go.Figure(data=[go.Sankey(

    node = dict(
      pad = 20,
      thickness = 20,
      line = dict(color = "black", width = 1.0),
      label = labels,

      # this template will be used to display text when hovering over nodes
      hovertemplate = "%{label} has total %{value:d} instances<extra></extra>"
    ),
    link = dict(
      source = df.source,    # integer positions into `labels`
      target = df.target,    # integer positions into `labels`
      value = df.value,      # link width = number of instances
      color = df.colour,     # green for correct, red for incorrect predictions
      customdata = df['tooltip'],

      # this template will be used to display text when hovering over the links
      hovertemplate = "%{customdata}<extra></extra>"
  ))])

fig.update_layout(title_text="Confusion Matrix Visualisation Using Sankey Diagram", font_size=13,
                  width=510, height=450)

# The keyword is `renderer`, not `render`: plotly silently ignores the unknown keyword,
# so the static renderer was never selected. The "jpg" renderer produces a static image
# and needs a static image export engine (e.g. kaleido) to be installed.
# Call fig.show() with no arguments to get the interactive figure with hover text.
fig.show(renderer="jpg")