In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

%matplotlib inline

### Constants

In [2]:
# Bounds on the number of distinct values a feature may have (from the problem
# statement).
min_cat, max_cat = 15, 40

# Figure settings.
# Palette of fill colours, picked from https://colorscheme.ru/html-colors.html
colorscheme = [
    'rgb(0,0,0)',
    'rgb(128,128,128)',
    'rgb(139, 69, 19)',
    'rgb(255,0,255)',
    'rgb(128, 0, 128)',
    'rgb(255, 0, 0)',
    'rgb(128, 0, 0)',
    'rgb(255, 215, 0)',
    'rgb(128, 128, 0)',
    'rgb(0, 255, 0)',
    'rgb(0, 128, 0)',
    'rgb(123, 104, 238)',
    'rgb(0, 128, 128)',
    'rgb(0, 0, 255)',
    'rgb(0, 0, 128)',
    'rgb(255, 165, 0)',
    'rgb(255, 20, 147)',
]
# Multiplier used when sizing the figure.
min_block_size = 300

# Origin and extent of the canvas passed to squarify.
x, y = 0., 0.
width, height = 100., 100.

## Data of Superstore Sales

Load the example Superstore data and find a feature with more than 15 (and fewer than 40) distinct categories.

In [3]:
# Read the 'Orders' sheet of the sample workbook.
# Pass the path (not an open file object) so that pandas opens and closes the
# file itself and no file handle is left dangling.
data = pd.read_excel('Sample - Superstore.xls', sheet_name='Orders')
data.columns

Index(['Row ID', 'Order ID', 'Order Date', 'Ship Date', 'Ship Mode',
       'Customer ID', 'Customer Name', 'Segment', 'Country', 'City', 'State',
       'Postal Code', 'Region', 'Product ID', 'Category', 'Sub-Category',
       'Product Name', 'Sales', 'Quantity', 'Discount', 'Profit'],
      dtype='object')

In [4]:
# Summarise the frame: per-column non-null counts and dtypes, plus memory usage.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9994 entries, 0 to 9993
Data columns (total 21 columns):
Row ID           9994 non-null int64
Order ID         9994 non-null object
Order Date       9994 non-null datetime64[ns]
Ship Date        9994 non-null datetime64[ns]
Ship Mode        9994 non-null object
Customer ID      9994 non-null object
Customer Name    9994 non-null object
Segment          9994 non-null object
Country          9994 non-null object
City             9994 non-null object
State            9994 non-null object
Postal Code      9994 non-null int64
Region           9994 non-null object
Product ID       9994 non-null object
Category         9994 non-null object
Sub-Category     9994 non-null object
Product Name     9994 non-null object
Sales            9994 non-null float64
Quantity         9994 non-null int64
Discount         9994 non-null float64
Profit           9994 non-null float64
dtypes: datetime64[ns](2), float64(3), int64(3), object(13)
memory usage: 1.6+ MB

In [5]:
# Print every column whose number of distinct values lies strictly between
# min_cat and max_cat, together with those values.
for c in data.columns:
    c_unique_list = data[c].unique()
    if min_cat < len(c_unique_list) < max_cat:
        # Reuse the array computed above rather than calling unique() again.
        print(c, 'unique values are:\n', c_unique_list)

Sub-Category unique values are:
 ['Bookcases' 'Chairs' 'Labels' 'Tables' 'Storage' 'Furnishings' 'Art'
 'Phones' 'Binders' 'Appliances' 'Paper' 'Accessories' 'Envelopes'
 'Fasteners' 'Supplies' 'Machines' 'Copiers']


OK! Next, prepare the data for the squarified treemap layout.

In [6]:
# Number of rows per sub-category (count of non-null 'Category' entries).
sub_cat_cnt = data.groupby('Sub-Category')['Category'].count()
# np.sort returns a sorted copy. Sorting `sub_cat_cnt.values` in place would
# reorder the Series' own data underneath its index labels (and fails when
# pandas hands back a read-only array under Copy-on-Write).
values_int = np.sort(sub_cat_cnt.values)
# Largest first, converted to Python floats for the layout code.
values = [float(v) for v in values_int[::-1]]

## Squarified Treemap Layout

In [25]:
# Implements the algorithm from Bruls, Huizing and van Wijk, "Squarified Treemaps".
# Based on https://github.com/laserson/squarify and
# https://plot.ly/python/treemaps/


def layoutrow(sizes, x, y, dx, dy):
    """Lay `sizes` out as one vertical strip of rectangles starting at (x, y).

    The strip's thickness is sum(sizes) / dy, and each rectangle's height is
    its size divided by that thickness, so every rectangle's area equals its
    size. Rectangles are stacked in increasing y.

    `dx` is accepted but not used by this function.

    Returns a list of dicts with keys "x", "y", "dx", "dy", one per size.
    """
    strip_width = sum(sizes) / dy
    placed = []
    cursor_y = y
    for area in sizes:
        rect_height = area / strip_width
        placed.append(
            {"x": x, "y": cursor_y, "dx": strip_width, "dy": rect_height}
        )
        cursor_y += rect_height
    return placed


def layoutcol(sizes, x, y, dx, dy):
    """Lay `sizes` out as one horizontal strip of rectangles starting at (x, y).

    The strip's thickness is sum(sizes) / dx, and each rectangle's width is
    its size divided by that thickness, so every rectangle's area equals its
    size. Rectangles are placed side by side in increasing x.

    `dy` is accepted but not used by this function.

    Returns a list of dicts with keys "x", "y", "dx", "dy", one per size.
    """
    strip_height = sum(sizes) / dx
    placed = []
    cursor_x = x
    for area in sizes:
        rect_width = area / strip_height
        placed.append(
            {"x": cursor_x, "y": y, "dx": rect_width, "dy": strip_height}
        )
        cursor_x += rect_width
    return placed


def layout(sizes, x, y, dx, dy):
    """Lay `sizes` out in a single strip inside the area (x, y, dx, dy).

    Delegates to `layoutrow` when the area is at least as wide as it is tall
    (dx >= dy) and to `layoutcol` otherwise, and returns that function's list
    of rectangle dicts.
    """
    place_strip = layoutrow if dx >= dy else layoutcol
    return place_strip(sizes, x, y, dx, dy)


def left_layout(sizes, x, y, dx, dy):
    """Return the area that remains after one strip holding `sizes` is placed.

    For dx >= dy a strip of thickness sum(sizes) / dy is cut off along x;
    otherwise a strip of thickness sum(sizes) / dx is cut off along y.

    Returns the remaining area as a tuple (x, y, dx, dy).
    """
    total = sum(sizes)
    # Both thicknesses are evaluated up front, whichever branch is taken.
    strip_w, strip_h = total / dy, total / dx
    if dx >= dy:
        return (x + strip_w, y, dx - strip_w, dy)
    return (x, y + strip_h, dx, dy - strip_h)


def worst_ratio(sizes, x, y, dx, dy):
    """Return the largest aspect ratio among the rectangles of one strip.

    The strip is the one `layout` produces for `sizes` in the area
    (x, y, dx, dy). Each rectangle's aspect ratio is taken as the larger of
    dx / dy and dy / dx. Raises ValueError (from `max`) when the strip is
    empty.
    """
    aspect_ratios = (
        max(r["dx"] / r["dy"], r["dy"] / r["dx"])
        for r in layout(sizes, x, y, dx, dy)
    )
    return max(aspect_ratios)


def squarify(sizes, x, y, dx, dy):
    """Compute a squarified-treemap rectangle for every entry of `sizes`.

    Fills the area with origin (x, y) and extent (dx, dy) strip by strip:
    entries are added to the current strip for as long as doing so does not
    increase the strip's worst aspect ratio; the strip is then laid out and
    the procedure continues in the area that is left over.

    Args:
        sizes: sliceable sequence of areas. The caller in this notebook passes
            values sorted in descending order and scaled by `normalize_sizes`
            so that they sum to dx * dy.
        x, y: origin of the area to fill.
        dx, dy: width and height of the area to fill.

    Returns:
        A list of dicts with keys "x", "y", "dx", "dy", in the same order as
        `sizes`. An empty `sizes` yields an empty list.
    """
    # Iterative rather than recursive, so that inputs producing many strips
    # cannot exhaust the interpreter's recursion limit.
    rects = []
    while len(sizes) > 0:
        if len(sizes) == 1:
            # A single remaining entry takes whatever area is left.
            rects.extend(layout(sizes, x, y, dx, dy))
            break

        # Grow the strip while the worst aspect ratio does not get worse.
        # The ratio of the accepted prefix is carried over between iterations
        # instead of being recomputed.
        i = 1
        current_ratio = worst_ratio(sizes[:1], x, y, dx, dy)
        while i < len(sizes):
            candidate_ratio = worst_ratio(sizes[:(i + 1)], x, y, dx, dy)
            if current_ratio >= candidate_ratio:
                i += 1
                current_ratio = candidate_ratio
            else:
                break

        current = sizes[:i]
        remaining = sizes[i:]
        leftover = left_layout(current, x, y, dx, dy)
        rects.extend(layout(current, x, y, dx, dy))

        # Continue with the rest of the entries in the leftover area.
        sizes = remaining
        x, y, dx, dy = leftover
    return rects


def normalize_sizes(sizes, dx, dy):
    """Rescale `sizes` so that they sum to the area dx * dy.

    Each entry is multiplied by dx * dy and divided by sum(sizes), which keeps
    the proportions between entries. Returns the rescaled values as a new
    list.
    """
    grand_total = sum(sizes)
    canvas_area = dx * dy
    return [entry * canvas_area / grand_total for entry in sizes]

In [26]:
# Scale the values to the canvas area and compute one rectangle per value.
normed = normalize_sizes(values, width, height)
rects = squarify(normed, x, y, width, height)


shapes = []
annotations = []
counter = 0  # not used in this cell

# One colour per rectangle. The palette is cycled so that rectangles beyond
# len(colorscheme) are still drawn instead of being silently dropped by zip().
palette = [colorscheme[k % len(colorscheme)] for k in range(len(rects))]

for r, val, color in zip(rects, values, palette):
    # Filled rectangle with an outline of the same colour.
    shapes.append(
        dict(
            type = 'rect',
            x0 = r['x'],
            y0 = r['y'],
            x1 = r['x']+r['dx'],
            y1 = r['y']+r['dy'],
            line = dict(
                        color=color,
                        width = 5
                    ),
            fillcolor = color
        )
    )
    # Label at the rectangle's centre; a string, matching the trace text below.
    annotations.append(
        dict(
            x = r['x']+(r['dx']/2),
            y = r['y']+(r['dy']/2),
            text = str(int(val)),
            showarrow = False
        )
    )


fig = go.Figure()

# Text trace positioned at the rectangle centres, used for hover text.
fig.add_trace(
        go.Scatter(
            x = [ r['x']+(r['dx']/2) for r in rects ],
            y = [ r['y']+(r['dy']/2) for r in rects ],
            text = [str(int(v)) for v in values],
            mode = 'text',
    )
)

font = dict(
            color="white",
            size=15
)

# Square figure whose side grows with the logarithm of the rectangle count.
fig_side = np.log(len(rects))*min_block_size

fig.update_layout(
    height=fig_side,
    width=fig_side,
    xaxis=dict(showgrid=False,zeroline=False),
    yaxis=dict(showgrid=False,zeroline=False),
    shapes=shapes,
    font=font,
    annotations=annotations
)

fig.show()