In [None]:
import pandas as pd
from mlxtend.frequent_patterns import apriori
from mlxtend.frequent_patterns import association_rules

import matplotlib.pyplot as plt
%matplotlib inline


import networkx as nx
import numpy as np

def draw_graph(rules, rules_to_show=5):
    """
    Draw association rules as a graph linking antecedents and consequents.

    "Rule nodes" are yellow and named "R<n>"; "item nodes" are green.
    Edge colors differ for each rule; edges go from the antecedent item(s)
    to the rule node and from the rule node to the consequent item(s).

    Parameters
    ----------
    rules : pandas.DataFrame
        Rules with iterable 'antecedents' and 'consequents' columns (such as
        the frame returned by mlxtend's association_rules). Rows are taken
        by position from the start of the frame.
    rules_to_show : int, default 5
        Maximum number of rules to draw from the initial part of `rules`.
        Values larger than len(rules) are clipped instead of raising
        IndexError; values <= 0 draw no rules.

    Notes
    -----
    Edge colors and the spring layout are drawn from a private
    RandomState seeded with 42, so the picture is reproducible and the
    global NumPy random state is left untouched.

    Shows the figure with plt.show() and returns None.

    @author: Claudio Sartori
    """
    n_rules = max(0, min(rules_to_show, len(rules)))

    # Private RNG: same stream as np.random.seed(42) followed by
    # np.random.rand(50), but without resetting the global NumPy RNG.
    rng = np.random.RandomState(42)
    # Draw at least 50 values (so the layout RNG state matches the classic
    # behaviour) and at least one color per rule (so >50 rules still work).
    rule_colors = rng.rand(max(50, n_rules))

    graph = nx.DiGraph()
    rule_nodes = set()  # names of the rule nodes, used to pick node colors
    for i in range(n_rules):
        rule_node = "R" + str(i)
        rule_nodes.add(rule_node)
        graph.add_node(rule_node)
        rule = rules.iloc[i]
        # add_edge also inserts any endpoint node not yet in the graph
        for item in rule['antecedents']:
            graph.add_edge(item, rule_node, color=rule_colors[i], weight=2)
        for item in rule['consequents']:
            graph.add_edge(rule_node, item, color=rule_colors[i], weight=2)

    # These lists follow the graph's own node/edge iteration order, which is
    # the order nx.draw pairs them with.
    node_colors = ['yellow' if node in rule_nodes else 'green' for node in graph]
    edges = graph.edges()
    edge_colors = [graph[u][v]['color'] for u, v in edges]
    edge_widths = [graph[u][v]['weight'] for u, v in edges]

    pos = nx.spring_layout(graph, k=16, scale=1, seed=rng)
    nx.draw(graph, pos, node_color=node_colors, edge_color=edge_colors,
            width=edge_widths, font_size=16, with_labels=False)
    # Place labels slightly above their nodes (new dict; `pos` is unchanged).
    label_pos = {node: (x, y + 0.07) for node, (x, y) in pos.items()}
    nx.draw_networkx_labels(graph, label_pos)
    plt.show()


In [None]:
# Read the raw transactions. The variable is named `url`, but it holds a
# local file path (pd.read_csv accepts either).
url = 'Online-Retail.csv'
df = pd.read_csv(filepath_or_buffer=url)
df.head(n=20)  # preview only the first 20 rows

In [None]:
# Actions:
# 1. filter the rows with `Country == 'France'`
# 2. group by `['InvoiceNo', 'Description']` computing a sum on `['Quantity']`
# 3. use `unstack` to move the items from rows to columns, filling the
#    items missing from an invoice with zero (`fill_value=0`)
# 4. store the result in the new dataframe `basket` (indexed by `InvoiceNo`,
#    one column per item description) and inspect its first rows
basket = (df[df['Country'] == 'France']
          .groupby(['InvoiceNo', 'Description'])['Quantity']
          .sum()
          .unstack(fill_value=0))
basket.head()

In [None]:
# There are a lot of zeros in the data, but we also need to make sure any
# positive values are converted to a 1 and anything less than or equal to 0
# is set to 0 (the 0/1 one-hot form that `apriori` works on).
def encode_units(x):
    """Return 1 if ``x`` is a strictly positive quantity, otherwise 0."""
    return 0 if x <= 0 else 1


# Vectorized equivalent of `basket.applymap(encode_units)`: `basket` was
# zero-filled when it was built, so `> 0` gives exactly the same 0/1 values,
# without a Python call per cell and without `applymap` (deprecated since
# pandas 2.1).
basket_sets = (basket > 0).astype(int)

Now that the data is structured properly, we can generate the frequent itemsets with a support of at least 7% (this threshold was chosen so that we get enough useful examples):

- generate the `frequent_itemsets` with `apriori`, setting `min_support=0.07` and `use_colnames=True`
- generate the `rules` with `association_rules` using `metric="lift"` and `min_threshold=1`
- show the rules

In [None]:
# Thresholds from the text above: keep itemsets with support >= 7%, and
# rules whose lift is at least 1.
MIN_SUPPORT = 0.07
MIN_LIFT = 1

frequent_itemsets = apriori(basket_sets, min_support=MIN_SUPPORT, use_colnames=True)

rules = association_rules(frequent_itemsets, metric="lift", min_threshold=MIN_LIFT)
# Show the rules themselves, as the text asks; use `rules.shape` for the count.
rules.head()

In [None]:
# In order to plot the rules, it is better to sort them according to some metrics.
# We sort on descending confidence (ties broken by descending support) and plot
# `'confidence'` and `'support'` against each rule's rank in that ordering.
sorted_rules = (rules
                .sort_values(by=['confidence', 'support'], ascending=False)
                .reset_index(drop=True))
ax = sorted_rules[['confidence', 'support']].plot(title='Association Rules')
ax.set(xlabel='Rule rank (by confidence, then support)', ylabel='Metric value');

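Below is a scatter plot of support against confidence, obtained using `plot.scatter`, in which the dot size encodes a third metric, the lift. The size grows exponentially with the lift (as $1.8^{\text{lift}}$), so it is not strictly proportional to it.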

In [None]:
# Dot size is LIFT_SIZE_BASE ** lift, so it grows exponentially (not
# proportionally) with lift; 1.8 is chosen empirically to obtain the best
# graphical effect.
LIFT_SIZE_BASE = 1.8
s = (LIFT_SIZE_BASE ** rules['lift']).to_numpy()
rules.plot.scatter(x='support',
                   y='confidence',
                   title='Association Rules (dot size grows exponentially with Lift)',
                   s=s);

In [None]:
# Visualise the first 10 rules of `sorted_rules` (highest confidence, then
# support) as an antecedent -> rule -> consequent graph on a square canvas.
plt.figure(figsize=(10, 10))
draw_graph(sorted_rules, rules_to_show=10)