In [1]:

from itertools import chain

try:  # Python 3.3+: the ABCs live in collections.abc (the aliases left collections in 3.10)
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable

import pylab
from matplotlib.patches import Circle, Ellipse

#--------------------------------------------------------------------
# keyword arguments shared by every text() call: centred horizontally, anchored on the baseline
alignment = dict(horizontalalignment='center', verticalalignment='baseline')

#--------------------------------------------------------------------
def get_labels(data, fill="number"):
    """
    to get a dict of labels for the regions of a Venn diagram of data

    Every region is identified by a key of N characters, one per group: character i
    is '1' when the region lies inside data[i] and '0' when it lies outside of it
    (for 3 groups, '101' means "in group 0 and group 2 but not in group 1").

    input
      data: iterable of N iterables of hashable items, one per group; each group is
            turned into a set exactly once, so one-shot iterators are accepted
      fill = ["number"|"logic"|"both"], fill with number, logic label, or both
    return
      labels: a dict mapping each of the 2**N - 1 keys to its label:
              the item count (int) for "number", the key itself for "logic",
              and the string "<key>: <count>" for "both"
    raises
      ValueError: if fill is not one of the three accepted values
    example:
    In [12]: get_labels([range(10), range(5,15), range(3,8)], fill="both")
    Out[12]:
    {'001': '001: 0',
     '010': '010: 5',
     '011': '011: 0',
     '100': '100: 3',
     '101': '101: 2',
     '110': '110: 2',
     '111': '111: 3'}
    """
    # reject a bad option before doing any set arithmetic
    if fill not in ("number", "logic", "both"):
        raise ValueError("invalid value for fill")

    sets_data = [set(group) for group in data]   # sets for separate groups
    N = len(sets_data)
    # Build the union from the sets made above rather than from data again:
    # iterating data a second time yields nothing for one-shot iterators.
    s_all = set().union(*sets_data)

    set_collections = {}
    for n in range(1, 2**N):
        # binary form of n, left-padded with zeros to one character per group
        key = format(n, 'b').zfill(N)
        region = set(s_all)
        for members, flag in zip(sets_data, key):
            if flag == '1':
                region &= members    # must belong to this group
            else:
                region -= members    # must not belong to this group
        set_collections[key] = region

    if fill == "number":
        labels = {k: len(set_collections[k]) for k in set_collections}
    elif fill == "logic":
        labels = {k: k for k in set_collections}
    else:  # fill == "both"
        labels = {k: ("%s: %d" % (k, len(set_collections[k]))) for k in set_collections}

    return labels



def venn4(data=None, names=None, fill="number", show_names=True, show_plot=True, **kwds):
    """
    draw a four-set Venn diagram made of four overlapping ellipses

    input
      data: sequence of exactly 4 iterables, one per set
      names: sequence of 4 set names; anything else (including None) falls back to
             "set 1" .. "set 4"
      fill = ["number"|"logic"|"both"], label style, forwarded to get_labels()
      show_names: when true, write each set name next to its ellipse
      show_plot: when true, call pylab.show() once the figure is built
      kwds: optional settings; only these keys are read, others are ignored
        figsize: sequence of length 2, default (10, 10)
        colors: sized iterable of at least 4 colours, default ['r', 'g', 'b', 'c']
    raises
      ValueError: if data is None or does not hold exactly 4 sets
    note
      All coordinates are hard coded. The ellipses are drawn in the order
      set 1, set 3, set 4, set 2, so the first four colours are applied to the sets
      in that order; the legend entries are ordered the same way.
    """

    if (data is None) or len(data) != 4:
        raise ValueError("length of data should be 4!")
    if (names is None) or (len(names) != 4):
        names = ("set 1", "set 2", "set 3", "set 4")

    labels = get_labels(data, fill=fill)

    # set figure size
    if 'figsize' in kwds and len(kwds['figsize']) == 2:
        # if 'figsize' is in kwds, and it is a list or tuple with length of 2
        figsize = kwds['figsize']
    else: # default figure size
        figsize = (10, 10)

    # set colors for different Circles or ellipses
    if 'colors' in kwds and isinstance(kwds['colors'], Iterable) and len(kwds['colors']) >= 4:
        colors = kwds['colors']
    else:
        colors = ['r', 'g', 'b', 'c']

    # draw ellipse, the coordinates are hard coded in the rest of the function
    fig = pylab.figure(figsize=figsize)   # set figure size
    ax = fig.gca()
    width, height = 170, 110  # width and height of the ellipses
    # (centre, rotation in degrees) of each ellipse, in drawing order
    ellipse_specs = [
        ((170, 170), -45),
        ((200, 200), -45),
        ((200, 200), -135),
        ((230, 170), -135),
    ]
    # angle is passed by keyword: recent matplotlib releases reject it as a
    # positional argument
    patches = [
        Ellipse(center, width, height, angle=angle, color=color, alpha=0.5)
        for (center, angle), color in zip(ellipse_specs, colors)
    ]
    for e in patches:
        ax.add_patch(e)
    ax.set_xlim(80, 320); ax.set_ylim(80, 320)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_aspect("equal")

    ### draw text: (x, y, region key), grouped by the number of sets involved
    label_positions = [
        # 1
        (120, 200, '1000'),
        (280, 200, '0100'),
        (155, 250, '0010'),
        (245, 250, '0001'),
        # 2
        (200, 115, '1100'),
        (140, 225, '1010'),
        (145, 155, '1001'),
        (255, 155, '0110'),
        (260, 225, '0101'),
        (200, 240, '0011'),
        # 3
        (235, 205, '0111'),
        (165, 205, '1011'),
        (225, 135, '1101'),
        (175, 135, '1110'),
        # 4
        (200, 175, '1111'),
    ]
    for x, y, key in label_positions:
        ax.text(x, y, labels[key], **alignment)

    # names of different groups
    if show_names:
        name_positions = [(110, 110), (290, 110), (130, 275), (270, 275)]
        for (x, y), name in zip(name_positions, names):
            ax.text(x, y, name, fontsize=16, **alignment)

    # The handles are passed explicitly so each name is tied to its own ellipse
    # instead of relying on the order in which legend() discovers the artists.
    leg_names = [names[0], names[2], names[3], names[1]]
    leg = ax.legend(patches, leg_names, loc='best', fancybox=True)
    leg.get_frame().set_alpha(0.5)

    if show_plot:
        pylab.show()



## Test

In [22]:
# quick visual check of venn4() on four small overlapping ranges
venn4(
    data=[range(10), range(5, 10), range(3, 10), range(4, 16)],
    names=["aaaa", "bbbb", "cccc", "dddd"],
    figsize=(12, 12),
)

## Test function

In [None]:
#--------------------------------------------------------------------
def test():
    """ a test function to show basic usage of get_labels() and venn4()

    The examples were written for a generic venn() dispatcher, which is not defined
    in the visible part of this notebook, so calling it raised NameError. The two-
    and three-set examples therefore print their region labels, and only the
    four-set example is drawn.
    NOTE(review): restore the venn(...) calls if that dispatcher is brought in.
    """

    # three sets: region key together with its count
    print(get_labels([range(10), range(5,15), range(3,8)], fill="both"))
    # two sets: plain counts, then region keys only
    print(get_labels([range(10), range(5,15)]))
    print(get_labels([range(10), range(5,15)], fill="logic"))
    # four sets: the full diagram
    venn4([range(10), range(5,15), range(3,8), range(4,9)], ["aaaa", "bbbb", "cccc", "dddd"], figsize=(12,12))

#--------------------------------------------------------------------
if __name__ == '__main__':
    test()

## My data


In [2]:
# Location of the catalog tags file analysed below; change it here to use another run.
CATALOG_TAGS_PATH = "/media/dan/34D5D1CE642D7E36/2013076_Hanfling_Bernd/Stacks/Stacks_analyses_V3/Cru_gold_hyb_common_catalog_test/test3_cru_gib_comm_hyb/batch_1.catalog.tags.tsv"

# (prefix of an entry in the 8th whitespace-separated column, sample name),
# listed in the order the samples are handed to venn4().
SAMPLE_PREFIXES = [('1_', "MY5"), ('2_', "POLEN"), ('192_', "RM31"), ('3_', "SWED")]

# sample name -> numbers (1-based line numbers) of the tags that mention that sample
tags_by_sample = {name: [] for prefix, name in SAMPLE_PREFIXES}

line_count = 0
with open(CATALOG_TAGS_PATH, "r") as catalog:   # the file is closed even if parsing fails
    for line_count, line in enumerate(catalog, 1):   # one tag per line, numbered from 1
        fields = line.strip().split()
        if len(fields) < 8:
            # blank or truncated line: it has no 8th column to inspect, but it
            # still counts towards the numbering
            continue
        for loc in fields[7].split(','):   ## Isolate and split the relevant column
            for prefix, name in SAMPLE_PREFIXES:
                if loc.startswith(prefix):
                    tags_by_sample[name].append(line_count)

# Names used by the original version of this cell, kept for any later cell that reads them.
seqNumber = line_count + 1   # the original counter ended one past the last line
Sample_0 = tags_by_sample["MY5"]
Sample_1 = tags_by_sample["POLEN"]
Sample_2 = tags_by_sample["RM31"]
Sample_3 = tags_by_sample["SWED"]

sample_names = [name for prefix, name in SAMPLE_PREFIXES]
venn4([set(tags_by_sample[name]) for name in sample_names], sample_names, figsize=(12,12))
    