This repository has been archived by the owner on Jan 13, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tree_plt.py
executable file
·62 lines (50 loc) · 1.74 KB
/
tree_plt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import numpy as np
def get_height(node):
    """Return how many levels the binary tree rooted at *node* has.

    An empty (falsy) node contributes 0 levels; otherwise the node adds one
    level on top of the taller of its ``left`` and ``right`` subtrees.
    """
    if node:
        return max(get_height(node.left), get_height(node.right)) + 1
    return 0
def get_node_count(node):
    """Return the total number of nodes in the binary tree rooted at *node*.

    An empty (falsy) node counts as 0; otherwise the count is this node plus
    the counts of its ``left`` and ``right`` subtrees.
    """
    if node:
        left_total = get_node_count(node.left)
        right_total = get_node_count(node.right)
        return left_total + right_total + 1
    return 0
def get_fontsize(count):
    """Pick a label font size for a tree with *count* nodes.

    Fewer than 10 nodes -> 30, fewer than 20 -> 20, anything larger -> 16,
    so labels shrink as the drawing gets more crowded.
    """
    for upper_bound, size in ((10, 30), (20, 20)):
        if count < upper_bound:
            return size
    return 16
def show_node(node, ax, height, index, font_size=12):
    """Draw the subtree rooted at *node* onto *ax* using an in-order layout.

    Each node occupies a 100x100 cell whose column is its in-order position
    (``index``) and whose row is ``height``; children are drawn one row lower
    (``height - 1``). A line joins each node's centre to its children's
    centres. A node is drawn as a radius-50 circle that is black when
    ``node.is_black_node()`` is true and red otherwise, with ``node.val``
    written in the middle.

    Args:
        node: Subtree root exposing ``left``, ``right``, ``val`` and
            ``is_black_node()`` (presumably a red-black tree node -- confirm
            against the tree implementation). Falsy means "empty".
        ax: Matplotlib axes to draw on.
        height: Row number for *node*.
        index: In-order column to assign to the left-most node of this
            subtree.
        font_size: Font size used for the node labels.

    Returns:
        ``(x, y, next_index)`` -- the centre of *node* and the first column
        not used by this subtree -- or ``None`` when *node* is empty.
    """
    if not node:
        return None

    child_x, child_y = None, None
    if node.left:
        child_x, child_y, index = show_node(
            node.left, ax, height - 1, index, font_size)

    x = 100 * index - 50
    y = 100 * height - 50

    # Explicit None check: a coordinate's truthiness is not a reliable
    # "has left child" signal.
    if child_x is not None:
        # Draw on the axes we were given (not pyplot's implicit current axes),
        # and below the node circles (patch zorder 1) so edges do not cross
        # the discs.
        ax.plot((child_x, x), (child_y, y), linewidth=2.0, color='b', zorder=0)

    is_black = node.is_black_node()
    circle_color = "black" if is_black else 'r'
    text_color = "beige" if is_black else 'black'
    ax.add_artist(plt.Circle((x, y), 50, color=circle_color))
    ax.add_artist(plt.Text(x, y, node.val, color=text_color, fontsize=font_size,
                           horizontalalignment="center", verticalalignment="center"))
    index += 1

    if node.right:
        child_x, child_y, index = show_node(
            node.right, ax, height - 1, index, font_size)
        ax.plot((child_x, x), (child_y, y), linewidth=2.0, color='b', zorder=0)

    return x, y, index
def save_tree(tree, index):
    """Render *tree* and save it to ``pics/output{index}.png``, then print the path.

    The data area is sized so every node gets a 100-unit column (one per
    node) and a 100-unit row (one per level), plus a 100-unit margin. The
    figure is 10 inches wide, with its height scaled to keep that aspect
    ratio. The figure is closed afterwards so repeated calls do not
    accumulate open pyplot figures.

    Args:
        tree: Root node of the tree; may be empty/None.
        index: Value substituted into the output file name.

    Note:
        The ``pics`` directory is assumed to already exist; this function
        does not create it.
    """
    path = "pics/output{}.png".format(index)
    fig, ax = plt.subplots()
    try:
        fig.set_facecolor('gray')
        height = get_height(tree)
        node_count = get_node_count(tree)
        h = height * 100 + 100
        w = 100 * node_count + 100
        ax.set_ylim(0, h)
        ax.set_xlim(0, w)
        ax.axis('off')
        show_node(tree, ax, height, 1, get_fontsize(node_count))
        # Width fixed at 10 in; height keeps the h:w data aspect ratio.
        fig.set_size_inches(10, h / (w / 10))
        fig.savefig(path)
    finally:
        # pyplot holds a reference to every figure until it is closed;
        # without this each call leaks one figure.
        plt.close(fig)
    print(path)