In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import xmltodict
import json
from collections import defaultdict
import plotly.express as px
 
# wget http://www.image-net.org/api/xml/structure_released.xml
# Open in binary mode so the XML parser decodes according to the document's own
# encoding declaration rather than the platform's default text encoding.
with open('structure_released.xml', 'rb') as f:
    struct = xmltodict.parse(f)
# Top-level <synset> element; used below as the root of the synset tree.
root = struct['ImageNetStructure']['synset']

In [3]:
# wget http://www.image-net.org/api/xml/ReleaseStatus.xml
# Binary mode: let the XML parser honour the file's declared encoding instead of
# decoding with the platform default.
with open('ReleaseStatus.xml', 'rb') as f:
    status = xmltodict.parse(f)

In [4]:
# Release whose per-synset image counts are used throughout the notebook.
RELEASE_VERSION = 'winter11'

status_list = status['ReleaseStatus']['images']['synsetInfos']['synset']
# xmltodict yields a single dict (not a list) when there is only one <synset> element.
if not isinstance(status_list, list):
    status_list = [status_list]

# wnid -> number of images for the selected release (later duplicates win).
wnid2num = {
    entry['@wnid']: int(entry['@numImages'])
    for entry in status_list
    if entry['@version'] == RELEASE_VERSION
}

In [5]:
class Node(object):
    """One synset of the ImageNet hierarchy, built from an xmltodict ``<synset>`` mapping.

    Children are built recursively from the nested ``synset`` entries. A child
    is kept only when it has images of its own or still has at least one kept
    descendant, so an image-less intermediate synset does not hide a populated
    subtree.
    """

    def __init__(self, node, img_counts=None):
        """
        Args:
            node: xmltodict mapping for a ``<synset>`` element; must contain the
                ``@wnid``, ``@gloss`` and ``@words`` keys.
            img_counts: Optional mapping wnid -> image count. When omitted, the
                module-level ``wnid2num`` mapping is used.
        """
        self._node = node
        # Evaluated lazily: the global is only looked up when no mapping is given.
        self._img_counts = wnid2num if img_counts is None else img_counts
        self.wnid = node['@wnid']
        self.gloss = node['@gloss']
        self.name = node['@words']
        self.numImg = self._img_counts.get(self.wnid, 0)
        self._set_children_node()

    def _set_children_node(self):
        """Populate ``children`` and ``num_children`` from the nested ``synset`` entries."""
        self.children = []
        raw_children = self._node.get('synset', [])
        # xmltodict returns a single dict rather than a list when there is one child.
        if not isinstance(raw_children, list):
            raw_children = [raw_children]
        for raw_child in raw_children:
            child = type(self)(raw_child, self._img_counts)
            # Keep the child if it has images or if any of its descendants survived pruning.
            if child.numImg > 0 or child.children:
                self.children.append(child)
        self.num_children = len(self.children)

    def is_leaf(self):
        """Return True when no children survived pruning."""
        return self.num_children == 0

    def __repr__(self):
        return self.name

In [6]:
class Tree(object):
    """Query helpers over a tree of nodes exposing ``name`` and ``children``.

    All traversals are iterative (explicit stack) and visit nodes in the same
    pre-order as a recursive depth-first walk, so deep trees cannot hit
    Python's recursion limit.
    """

    def __init__(self, root):
        # Root of the whole tree; the default starting point for every query.
        self.root = root

    def find_node(self, name, node=None):
        """Return the first node, in pre-order, whose ``name`` equals ``name``.

        Args:
            name: Exact name string to look for.
            node: Subtree root to search from; defaults to ``self.root``.

        Returns:
            The matching node, or ``None`` if no node in the subtree matches.
        """
        start = self.root if node is None else node
        stack = [start]
        while stack:
            current = stack.pop()
            if current.name == name:
                return current
            # Reverse so the leftmost child is popped (visited) first.
            stack.extend(reversed(current.children))
        return None

    def find_all_children(self, node=None):
        """Return ``node`` followed by all of its descendants, in pre-order.

        Note: despite the name, the starting node itself is the first element.

        Args:
            node: Subtree root; defaults to ``self.root``.
        """
        start = self.root if node is None else node
        collected = []
        stack = [start]
        while stack:
            current = stack.pop()
            collected.append(current)
            stack.extend(reversed(current.children))
        return collected

    def find_parents(self, children):
        """Return the parent of each node in ``children``, in the same order.

        Parents are found by walking the whole tree from ``self.root``; the
        root's parent is ``None``. Nodes are matched by identity.

        Raises:
            KeyError: if a node in ``children`` is not reachable from the root.
        """
        # Identity set: O(1) membership instead of scanning the list per visited node.
        wanted = {id(c) for c in children}
        parent_of = {}
        stack = [(self.root, None)]
        while stack:
            current, parent = stack.pop()
            if id(current) in wanted:
                parent_of[id(current)] = parent
            stack.extend((child, current) for child in reversed(current.children))
        return [parent_of[id(c)] for c in children]

    def find_max_depth(self, node=None):
        """Return the number of levels in the subtree rooted at ``node``.

        A lone node has depth 1.

        Args:
            node: Subtree root; defaults to ``self.root``.
        """
        start = self.root if node is None else node
        deepest = 0
        stack = [(start, 1)]
        while stack:
            current, depth = stack.pop()
            deepest = max(deepest, depth)
            stack.extend((child, depth + 1) for child in current.children)
        return deepest

In [7]:
# Materialize the synset hierarchy once from the parsed XML root, then wrap it
# in the query helper; root_node is the very same object the Tree holds.
tree = Tree(Node(root))
root_node = tree.root

In [23]:
# Visualize one synset's subtree (browse synset names at http://image-net.org/explore).
SYNSET_NAME = 'beetle'
node_interested = tree.find_node(SYNSET_NAME)
if node_interested is None:
    raise ValueError('synset {!r} not found in the tree'.format(SYNSET_NAME))
children = tree.find_all_children(node_interested)  # subtree root first, then descendants
parents = tree.find_parents(children)

# Summary statistics shown in the figure title.
n_children = len(children)
max_d = tree.find_max_depth(node_interested)
img_counts = [c.numImg for c in children]
minImg = min(img_counts)
maxImg = max(img_counts)
totalImg = sum(img_counts)
avgImg = totalImg // n_children

# Plotly needs every parent to reference a sector inside the figure (or '' for
# the chart root). The subtree root's real parent lies outside the subtree, so
# it maps to ''. Path-based ids keep synsets that share a name under different
# parents from colliding. The list is pre-order, so parents get their paths first.
path_of = {}
parent_paths = []
for node, parent in zip(children, parents):
    parent_path = path_of.get(id(parent), '')
    path_of[id(node)] = parent_path + '/' + str(node) if parent_path else str(node)
    parent_paths.append(parent_path)

fig = px.treemap(
    ids=[path_of[id(c)] for c in children],
    names=[str(c) for c in children],
    parents=parent_paths,
    values=img_counts,  # numeric, not strings
    title="Total nodes: {}, max depth: {}, total/max/min/avgImg: {}/{}/{}/{}".format(
        n_children, max_d, totalImg, maxImg, minImg, avgImg),
)
fig.show()

In [9]:
from plotly import io as pio

# Persist the interactive treemap as a self-contained HTML page, then open it
# in the default browser.
output_path = 'index.html'
pio.write_html(fig, file=output_path, auto_open=True)