In [15]:
import requests
from itertools import chain
from collections import namedtuple
import logging

from lxml import etree
from pprint import pprint
from IPython.core.display import display, HTML

from graphviz import Digraph

In [26]:
# Parse the catalogue XSD; iterating the resulting root element yields the
# schema's top-level components.
with open('src/battlescribe/catelogue_schema.xsd', 'rb') as fh:
    catelogue_schema = etree.fromstring(fh.read())

# Element names that element_search() leaves out of the graph.
avoids = ('roster', 'gameSystem', 'comment')

# Index of every named top-level definition (complex/simple types, attribute
# groups and groups), keyed by its ``name`` attribute, for find_type().
types_lookup = {}
for definition in catelogue_schema:
    if definition.tag in (
            '{http://www.w3.org/2001/XMLSchema}complexType',
            '{http://www.w3.org/2001/XMLSchema}simpleType',
            '{http://www.w3.org/2001/XMLSchema}attributeGroup',
            '{http://www.w3.org/2001/XMLSchema}group'):
        definition_name = definition.attrib['name']
        # All four kinds share one dictionary, so a repeated name would
        # silently replace an earlier definition. Raise explicitly instead of
        # asserting, because assertions are removed when Python runs with -O.
        if definition_name in types_lookup:
            raise ValueError(
                'duplicate schema definition name: {!r}'.format(
                    definition_name))
        types_lookup[definition_name] = definition

        
def find_type(typename, default):
    """Return the schema definition registered under ``typename``.

    ``typename`` is a reference such as ``prefix:Name``. The part after the
    first colon is used as the key into ``types_lookup``. A name without a
    prefix is looked up as-is rather than raising ``IndexError``.

    Returns ``default`` when no definition is registered under that name.
    """
    parts = typename.split(':')
    # Keep the original choice of segment for prefixed names; fall back to
    # the whole string when there is no colon at all.
    local_name = parts[1] if len(parts) > 1 else parts[0]
    return types_lookup.get(local_name, default)


def sequence_search(dot, parent, sequence, level, found_types):
    """Graph every child of a sequence underneath ``parent``.

    ``xs:element`` children are passed to ``element_search`` and ``xs:group``
    children are resolved through their ``ref`` attribute with
    ``complex_element_search``. Both are called with ``level + 1``. Any other
    child is ignored.

    ``dot``, ``parent`` and ``found_types`` are forwarded unchanged.
    """
    for child in sequence:
        # Compare with ==, not ``in``: the parenthesised string is not a
        # tuple, so ``in`` was a substring test that raised TypeError for
        # children whose tag is not a string (lxml comment nodes) and accepted
        # partial tag names.
        if child.tag == '{http://www.w3.org/2001/XMLSchema}element':
            element_search(dot, parent, child, level + 1, found_types)
        elif child.tag == '{http://www.w3.org/2001/XMLSchema}group':
            complex_element_search(
                dot, parent, level + 1, found_types, child.attrib['ref'])
            

def complex_element_search(dot, element, level, found_types, typename):
    """Graph the child elements contributed by the definition ``typename``.

    The definition is resolved with ``find_type``; an unknown name contributes
    nothing. For each direct child of the definition:

    * an ``xs:sequence`` is walked with ``sequence_search``;
    * an ``xs:complexContent`` is scanned for ``xs:extension`` children. For
      each one the ``base`` definition is searched first (recursively), and
      then the extension's own ``xs:sequence`` children are walked.

    Everything found is attached to ``element``. ``dot``, ``level`` and
    ``found_types`` are forwarded unchanged.
    """
    xs = '{http://www.w3.org/2001/XMLSchema}'
    # Tags are compared with == because the previous ``in ('...')`` form was a
    # substring test on a plain string and raised TypeError for children whose
    # tag is not a string (lxml comment nodes).
    for part in find_type(typename, []):
        if part.tag == xs + 'sequence':
            sequence_search(dot, element, part, level, found_types)
        elif part.tag == xs + 'complexContent':
            for content in part:
                if content.tag != xs + 'extension':
                    continue
                # Inherited children first, then the extension's additions.
                complex_element_search(
                    dot, element, level, found_types, content.attrib['base'])
                for addition in content:
                    if addition.tag == xs + 'sequence':
                        sequence_search(
                            dot, element, addition, level, found_types)
            

def element_search(dot, parent, element, level, found_types):
    """Add ``element`` to the graph and descend into its declared type.

    Elements whose name is listed in ``avoids`` are skipped entirely.
    Otherwise a node labelled by ``node_label`` is created. When ``parent`` is
    given, an edge from the parent's name is added, labelled ``min..max`` from
    the element's ``minOccurs``/``maxOccurs`` (each defaulting to ``1``, with
    ``unbounded`` shown as ``*``).

    ``found_types`` is a set of ``(name, type)`` pairs already expanded. A
    pair is expanded at most once, via ``complex_element_search``.
    """
    attrs = element.attrib
    name = attrs['name']
    if name in avoids:
        return

    dot.node(name, node_label(element))

    if parent is not None:
        max_occurs = attrs.get('maxOccurs', '1')
        upper_bound = '*' if max_occurs == 'unbounded' else max_occurs
        multiplicity = '{}..{}'.format(attrs.get('minOccurs', '1'), upper_bound)
        dot.edge(parent.attrib.get('name'), attrs.get('name'), label=multiplicity)

    type_name = attrs.get('type')
    if not type_name:
        return

    visited_key = (name, type_name)
    if visited_key in found_types:
        return

    found_types.add(visited_key)
    complex_element_search(dot, element, level, found_types, type_name)

def get_attributes(element_name):
    """Collect the ``xs:attribute`` declarations that apply to a definition.

    ``element_name`` is resolved with ``find_type``. The result is a list, in
    this order, of:

    1. every named ``xs:attribute`` found anywhere below the definition;
    2. the attributes of each ``xs:attributeGroup`` it references by ``ref``
       (resolved recursively);
    3. the attributes of the ``base`` of its first
       ``xs:complexContent/xs:extension``, if any (resolved recursively).

    Returns an empty list when the name is not a registered definition.
    """
    ns = {'xs': 'http://www.w3.org/2001/XMLSchema'}
    definition = find_type(element_name, None)
    # Test against None explicitly: an lxml element without children is
    # falsy, so a plain truth test confuses "empty element" with "not found".
    if definition is None:
        return []

    attributes = definition.findall('.//xs:attribute[@name]', namespaces=ns)
    for group in definition.findall(
            './/xs:attributeGroup[@ref]', namespaces=ns):
        attributes.extend(get_attributes(group.attrib['ref']))

    extension = definition.find(
        './/xs:complexContent/xs:extension[@base]', namespaces=ns)
    # ``is not None`` so that a childless <xs:extension base="..."/> still
    # contributes the attributes inherited from its base type.
    if extension is not None:
        attributes.extend(get_attributes(extension.attrib['base']))
    return attributes
        
def node_label(element):
    """Build the HTML-like table label used for an element's graph node.

    The first row holds the element name in bold. The second row lists one
    line per attribute returned by ``get_attributes`` for the element's type,
    formatted as ``<type> <name> =<default>``. The name is bold when the
    attribute has ``use="required"``, and the ``=`` is always emitted,
    followed by the default value when one is declared.

    An element without a ``type`` attribute gets an empty attribute row
    instead of raising ``KeyError``. An attribute without a ``type`` is shown
    with an empty type, and an unprefixed type is shown as-is.
    """
    type_name = element.attrib.get('type')
    declarations = get_attributes(type_name) if type_name else []

    rows = []
    for declaration in declarations:
        attrs = declaration.attrib
        if attrs.get('use') == 'required':
            name_cell = "<b>{}</b>".format(attrs['name'])
        else:
            name_cell = attrs['name']
        rows.append("{} {} {}".format(
            # Drop the namespace prefix from the attribute's type reference.
            attrs.get('type', '').split(':')[-1],
            name_cell,
            '=' + attrs.get('default', '')))

    return "<<TABLE>" + \
        "<TR><TD><b>{}</b></TD></TR><TR><TD>".format(element.attrib['name']) + \
        "<br/>".join(rows) + \
        "</TD></TR></TABLE>>"

def make_gv():
    """Build and return the graph of the schema's top-level elements.

    Creates a ``Digraph`` with plaintext nodes, then runs ``element_search``
    (with no parent, at level 0) on every direct ``xs:element`` child of
    ``catelogue_schema``. All searches share one set of visited types.
    """
    graph = Digraph(comment='BS Schema', node_attr={'shape': 'plaintext'})
    visited = set()
    element_tag = '{http://www.w3.org/2001/XMLSchema}element'
    roots = (node for node in catelogue_schema if node.tag == element_tag)
    for root in roots:
        element_search(graph, None, root, 0, visited)
    return graph


# Build the schema graph; `dot` is rendered in the next cell.
# The commented-out try/except is a disabled debugging aid: enabling it drops
# into pdb.post_mortem() when make_gv() raises.
# try:
dot = make_gv()
# except:
#     import pdb
#     pdb.post_mortem()



In [27]:
# Select PNG as the output format for rendering.
dot.format = 'png'
# Render the graph using 'schema.gv' as the file name; the recorded cell
# output shows the resulting image name, 'schema.gv.png'.
# NOTE(review): view=True presumably opens the rendered file in the default
# viewer — confirm against the graphviz documentation.
dot.render('schema.gv', view=True) 

'schema.gv.png'