- This notebook contains the implementation of a Sankey plot on brand segmentation data.
- `num_segments` should be set to the desired number of segments one wishes to see in the Sankey plot. The maximum value is 8, which is equal to `tot_num_segments`. The default is 6 (all the (E)-cig segments).
- Below is the mapping of the different segments to their definitions:

| brand/type      | brandcat      | brand_nonbrand_cat   | nonbrand_cat        |
|:----------------|:----------|:-----------|:-----------|
|only brand        |  segment1 |  segment4  |  segment0  |
|brand+Other brands|  segment2 |  segment5  |  segment0  |
|Non-brand         |  segment3 |  segment6  |  segment0  |

**NOTE:** In the data, `segment0` here is shown as `segment7` and unobserved events are mapped to `segment8`, mostly for visualization considerations.

In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import cufflinks as cf
# Switch cufflinks to offline mode. The attribute has to be *called*; a bare
# `cf.go_offline` reference is evaluated and discarded, so it did nothing.
cf.go_offline()
# Initialise plotly's notebook mode so figures render inline in the notebook.
init_notebook_mode(connected=True)


plotly.graph_objs.YAxis is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.layout.YAxis
  - plotly.graph_objs.layout.scene.YAxis



plotly.graph_objs.XAxis is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.layout.XAxis
  - plotly.graph_objs.layout.scene.XAxis




In [2]:
# Load the segment table from CSV (relative path: the file must sit next to
# the notebook). Presumably one row per member and one column per month --
# the column names are overwritten in the next cell; TODO confirm the schema.
df = pd.read_csv("customer_segment_monthly.csv")

In [17]:
# The CSV was exported from a Spark DataFrame; relabel the columns positionally
# using the Spark column order. Note the month columns are in string order
# ('2018-10' sorts before '2018-5'), not chronological order.
spark_month_columns = [
    '2018-10', '2018-11', '2018-12',
    '2018-5', '2018-6', '2018-7', '2018-8', '2018-9',
    '2019-1', '2019-2',
]
df.columns = ['memberid'] + spark_month_columns

In [54]:
# Total number of segments in the data. Used to label the rows/columns of every
# month-to-month transition table (segment1 .. segment<tot_num_segments>), and
# it is the upper bound for `num_segments`.
tot_num_segments = 8

In [55]:
# Parse every month header once so that un-padded names such as '2018-5'
# become zero-padded 'YYYY-MM' labels; the first (id) column is left untouched.
month_stamps = [pd.to_datetime(col) for col in df.columns[1:]]
df.columns = [df.columns[0]] + [str(stamp)[:7] for stamp in month_stamps]

# Month labels in chronological order.
months = [str(stamp)[:7] for stamp in np.sort(month_stamps)]

# One string-typed segment series per month, keyed in chronological order.
df_month = {month: df[month].astype(str) for month in months}

In [56]:
# Build one transition table per pair of consecutive months, keyed by step
# index. Rows are the segment in the earlier month, columns the segment in the
# later month; both axes are relabelled segment1 .. segment<tot_num_segments>.
segment_names = ['segment{}'.format(n) for n in range(1, 1 + tot_num_segments)]
month_keys = list(df_month.keys())

transition_matrix = {}
for step, (earlier, later) in enumerate(zip(month_keys[:-1], month_keys[1:])):
    counts = pd.crosstab(df_month[earlier], df_month[later])
    # Positional relabel: raises if a month does not contain every segment.
    counts.columns = segment_names
    counts.index = counts.columns
    transition_matrix[step] = counts

In [57]:
# Number of month-to-month transitions (one transition table per step).
num_steps = len(transition_matrix)

In [58]:
# Number of segments to draw in the Sankey diagram. Only the first
# `num_segments` rows/columns of each transition table are plotted, so this
# must not exceed `tot_num_segments`.
num_segments = 6

In [59]:
# Node labels: the plotted segment names repeated once per month
# (num_steps transitions -> num_steps + 1 columns of nodes).
plotted_segment_names = ['segment{}'.format(n) for n in range(1, num_segments + 1)]
label_lst = plotted_segment_names * (num_steps + 1)

In [60]:
# Source node index of every link. Each of the num_segments * num_steps source
# nodes emits one link per target segment, so its index is repeated
# num_segments times in a row.
num_transitions = num_segments * num_steps
source_lst = [
    node
    for node in range(num_transitions)
    for _ in range(num_segments)
]

In [61]:
# Target node index of every link. For each step, the target nodes are the
# num_segments nodes of the following month; that run of indices is repeated
# once per source segment.
target_iterlst = [step * num_segments for step in range(1, num_steps + 1)]
target_lst = [
    node
    for first_node in target_iterlst
    for _ in range(num_segments)
    for node in range(first_node, first_node + num_segments)
]

In [62]:
# Flatten the transition counts into one value per Sankey link.
#
# The link order has to match source_lst / target_lst, which enumerate links
# source-major: for each source segment (outer), every target segment (inner).
# In each transition table the rows are the earlier month (source) and the
# columns the later month (target), so the row index must be the OUTER loop.
# Iterating columns outermost would attach the transposed count to every link.
values_lst = []
for k in range(num_steps):
    step_counts = transition_matrix[k]
    values_lst += [
        step_counts.iloc[i, j]
        for i in range(num_segments)   # source segment (row)
        for j in range(num_segments)   # target segment (column)
    ]

In [63]:
# Sankey trace: nodes are the per-month segments, links carry the
# month-to-month transition counts.
data = {
    'type': 'sankey',
    'node': {
        'pad': 15,
        'thickness': 20,
        'line': {'color': "black", 'width': 0.5},
        'label': label_lst,
    },
    'link': {
        'source': source_lst,
        'target': target_lst,
        'value': values_lst,
    },
}

layout = {'title': "brand Segments", 'font': {'size': 10}}

# Plain-dict figure; rendered inline with plotly's validation switched off.
fig = {'data': [data], 'layout': layout}
iplot(fig, validate=False)

In [None]:
from plotly import offline

# Export the figure to an HTML file. The file name is derived from
# `num_segments` so it always matches the number of segments actually plotted;
# the previous hard-coded "8segments" name was wrong for the default of 6.
offline.plot(fig, filename='sankey_{}segments.html'.format(num_segments))