## 修改内容

### 函数调试

In [3]:
from btree import *

In [7]:
def get_nodes_cost(tree) -> (dict,dict):
    """Collect the non-zero costs of every node in *tree*.

    A node whose ``left`` is ``None`` is treated as a leaf and its cost is its
    own ``value``. Any other node is treated as internal and its cost is the sum
    of ``value`` over ``node.leaves``. Zero (falsy) costs are not recorded.

    Returns:
        ``(nonleaf_cost, leaf_cost)``: two dicts mapping ``node.position`` to
        the cost of internal nodes and of leaves respectively.
    """
    inner, outer = {}, {}
    for node in tree:
        if node.left is not None:
            # Internal node: its cost is the total value of the leaves below it.
            total = sum(leaf.value for leaf in node.leaves)
            if total:
                inner[node.position] = total
        elif node.value:
            outer[node.position] = node.value
    return inner, outer

# Exercise get_nodes_cost on the old tree.
# NOTE: relies on `old`, which is built in the data cell further down (In [15]).
print(old)
print(old.position_tree)
get_nodes_cost(old)


                        ______________________________0
                       /                               \
               _______0__                               0
              /          \
          ___0___         0______________________
         /       \       /                       \
     ___0        _0     0                 ________0
    /    \      /  \                     /         \
  _0      35   40   0                ___0___        0
 /  \                               /       \
10   20                          __0        _0
                                /   \      /  \
                              _0     50   20   10
                             /  \
                            10   5


                          ________________________________________1
                         /                                         \
                ________2___                                        3
               /            \
           ___4___          _5____________

({1: 200,
  2: 200,
  4: 105,
  5: 95,
  8: 65,
  9: 40,
  11: 95,
  16: 30,
  22: 95,
  44: 65,
  45: 30,
  88: 15},
 {17: 35, 18: 40, 32: 10, 33: 20, 89: 50, 90: 20, 91: 10, 176: 10, 177: 5})

In [8]:
# Total value over all leaves of the old tree.
sum(node.value for node in old.leaves)

200

In [3]:
class BTree(BTree):
    """BTree extended with the one-level expansion used by the tree search."""

    def new_tree_by_positions(self,positions):
        """Return a deep copy of this tree with its last layer expanded at *positions*.

        Args:
            positions: positions of nodes in the last layer that should each get
                a left and right child.

        Returns:
            The expanded copy, with ``_max_depth`` incremented by one. The copy
            is the only tree mutated here.
        """
        targets = set(positions)  # O(1) membership checks
        tree = deepcopy(self)
        # Snapshot the layer first so that children added below are never revisited.
        for node in list(tree.last_layer):
            if node.position in targets:
                # The node belongs to the copy, so expand it through the copy, not self.
                tree.add_left_right_to_node(node)
        tree._max_depth += 1  # the expanded copy is one level deeper
        return tree

In [15]:
# First example: the old tree in level order, where positions[p - 1] holds the
# value of node p and None marks an absent node.
n = 8
positions = [None] * (2 ** n)
pos = [1,
       2, 3,
       4, 5,
       8, 9, 10, 11,
       16, 17, 18, 19, 22, 23,
       32, 33, 44, 45,
       88, 89, 90, 91,
       176, 177]
value = [0,
         0, 0,
         0, 0,
         0, 0, 0, 0,
         0, 35, 40, 0, 0, 0,
         10, 20, 0, 0,
         0, 50, 20, 10,
         10, 5]
for p, v in zip(pos, value):
    positions[p - 1] = v
old = BTree.list_to_tree(positions)
# Per-position cost tables of the old tree (internal nodes, leaves).
nonleaf_cost, leaf_cost = get_nodes_cost(old)

In [16]:
# Target shape of the new tree. Presumably leaves[d] is the number of leaves at
# depth d; leaves_to_nonleaves turns it into per-depth non-leaf counts.
leaves = [0,0,2,0,3,2,3,26]
nonleaves = leaves_to_nonleaves(leaves)

In [17]:
# Search state: a single root that has not been charged anything yet.
root = BTree(0)
root.cost = 0
# Initial bound: the summed value of every non-zero old leaf; no solution yet.
min_cost, optimal = sum(leaf_cost.values()), None
tmp_trees = [root]  # frontier of partial trees still to explore

In [18]:
# One iteration of the search loop, run by hand for debugging.
tree = tmp_trees[0] # take the front tree (lowest cost - depth after sorting)
tmp_trees = tmp_trees[1:]
k = tree.max_depth # depth index of the last layer
ak = nonleaves[k] # number of last-layer nodes to expand (must be non-zero)
assert ak,"展开数不能为0"

# Classify last-layer positions into three kinds:
# sep = costly non-leaf in the old tree, not_sep = costly old leaf, whatever = zero cost
sep,not_sep,whatever = [],[],[]
for node in tree.last_layer:
    if node.position in nonleaf_cost:
        sep.append(node.position)
    elif node.position in leaf_cost:
        not_sep.append(node.position)
    else:
        whatever.append(node.position)

# Generate candidate trees
news = []
if ak < len(sep): # few expansions: pick a subset of sep; unpicked sep nodes add their cost
    for choice in choose(sep,ak):
        new = tree.new_tree_by_positions(choice)
        new.cost += sum(nonleaf_cost[i] for i in sep if i not in choice)
        news.append(new)
elif len(sep) <= ak <= len(sep)+len(whatever): # moderate: all sep plus some whatever, no extra cost
    choice = sep + whatever[:ak-len(sep)]
    news = [tree.new_tree_by_positions(choice)]
else: # many: all sep and whatever plus a subset of not_sep; picked leaves add their cost
    for choice in choose(not_sep,ak-len(sep)-len(whatever)):
        new = tree.new_tree_by_positions(sep+whatever+choice)
        new.cost += sum(leaf_cost[i] for i in choice)
        news.append(new)
# Handle the new trees
if nonleaves[k+1]==0: # the tree is fully expanded
    # NOTE(review): `n` here is the global from the data cell (old tree has < 2**n
    # positions); the guard only skips layers where no old non-leaf can exist.
    if k+2<n: # merge the remaining nodes
        for new in news:
            new.cost += sum(nonleaf_cost.get(node.position,0) for node in new.last_layer)
    for new in news:
        # NOTE(review): a complete tree whose cost equals the initial bound is never
        # recorded, so `optimal` can stay None.
        if new.cost < min_cost:
            min_cost = new.cost
            optimal = new
            tmp_trees = [tree for tree in tmp_trees if tree.cost<= min_cost]
else: # non-leaf nodes remain
    tmp_trees.extend(news) # add the new candidates
    tmp_trees.sort(key=lambda x:x.cost-x.depth) # sort the frontier

In [13]:
# Scratch check: dict.get returns the given default (0) for a missing key.
a = {1:2,2:3}
a.get(4,0)

0

### 测试

In [19]:
# End-to-end run of the search loop on the first example.
n = 8
positions = [None for i in range(2**n)] # old tree, level order: positions[p-1] = value of node p
pos = [1,
       2,3,
       4,5,
       8,9,10,11,
       16,17,18,19,22,23,
       32,33,44,45,
       88,89,90,91,
       176,177]
value = [0,
        0,0,
        0,0,
        0,0,0,0,
        0,35,40,0,0,0,
        10,20,0,0,
        0,50,20,10,
        10,5]
for p,v in zip(pos,value):
    positions[p-1] = v
old = BTree.list_to_tree(positions)
nonleaf_cost,leaf_cost = get_nodes_cost(old) # cost tables of the old tree
leaves = [0,0,2,0,3,2,3,26]
nonleaves = leaves_to_nonleaves(leaves)
root = BTree(0)
root.cost = 0 # nothing charged yet
min_cost = sum(leaf_cost.values()) # initial bound on the cost
optimal = None # best complete tree so far
tmp_trees = [root] # frontier of partial trees
while len(tmp_trees):
    tree = tmp_trees[0] # front tree (lowest cost - depth)
    tmp_trees = tmp_trees[1:]
    k = tree.max_depth # depth index of the last layer
    ak = nonleaves[k] # number of nodes to expand (must be non-zero)
    assert ak,"展开数不能为0"
    # Classify last-layer positions (costly old non-leaf / costly old leaf / zero cost)
    sep,not_sep,whatever = [],[],[]
    for node in tree.last_layer:
        if node.position in nonleaf_cost:
            sep.append(node.position)
        elif node.position in leaf_cost:
            not_sep.append(node.position)
        else:
            whatever.append(node.position)
    # Generate candidate trees
    news = []
    if ak < len(sep): # few expansions: unpicked sep nodes add their cost
        for choice in choose(sep,ak):
            new = tree.new_tree_by_positions(choice)
            new.cost += sum(nonleaf_cost[i] for i in sep if i not in choice)
            news.append(new)
    elif ak <= len(sep)+len(whatever): # moderate: no extra cost
        choice = sep + whatever[:ak-len(sep)]
        news = [tree.new_tree_by_positions(choice)]
    else: # many: the picked not_sep leaves add their cost
        for choice in choose(not_sep,ak-len(sep)-len(whatever)):
            new = tree.new_tree_by_positions(sep+whatever+choice)
            new.cost += sum(leaf_cost[i] for i in choice)
            news.append(new)
    # Handle the new trees
    if k+1 >= len(nonleaves) or nonleaves[k+1]==0: # the tree is complete
        # Charge merging for final-layer nodes that are non-leaves in the old tree;
        # .get(..., 0) makes this a no-op wherever there is nothing to merge.
        for new in news:
            new.cost += sum(nonleaf_cost.get(node.position,0) for node in new.last_layer)
        for new in news:
            # Accept the first complete tree even if it only matches the initial
            # bound; a strict `<` alone could leave `optimal` as None.
            if optimal is None or new.cost < min_cost:
                min_cost = new.cost
                optimal = new
                tmp_trees = [t for t in tmp_trees if t.cost<= min_cost]
    else: # the tree is not complete yet
        tmp_trees.extend(news)
        tmp_trees.sort(key=lambda x:x.cost-x.depth)

### 函数汇总

In [None]:
from btree import *

In [None]:
class BTree(BTree):
    """BTree extended with the one-level expansion used by the tree search."""

    def new_tree_by_positions(self,positions):
        """Return a deep copy of this tree with its last layer expanded at *positions*.

        Args:
            positions: positions of nodes in the last layer that should each get
                a left and right child.

        Returns:
            The expanded copy, with ``_max_depth`` incremented by one. The copy
            is the only tree mutated here.
        """
        targets = set(positions)  # O(1) membership checks
        tree = deepcopy(self)
        # Snapshot the layer first so that children added below are never revisited.
        for node in list(tree.last_layer):
            if node.position in targets:
                # The node belongs to the copy, so expand it through the copy, not self.
                tree.add_left_right_to_node(node)
        tree._max_depth += 1  # the expanded copy is one level deeper
        return tree

In [None]:
def get_nodes_cost(tree) -> (dict,dict):
    """Return ``(nonleaf_cost, leaf_cost)`` for *tree*, skipping zero costs.

    Nodes with ``left is None`` count as leaves and cost their own ``value``.
    Every other node costs the sum of ``value`` over its ``leaves``. Both dicts
    are keyed by ``node.position``.
    """
    nonleaf_cost, leaf_cost = {}, {}
    for node in tree:
        is_leaf = node.left is None
        cost = node.value if is_leaf else sum(leaf.value for leaf in node.leaves)
        if not cost:
            continue  # only non-zero costs are recorded
        (leaf_cost if is_leaf else nonleaf_cost)[node.position] = cost
    return nonleaf_cost, leaf_cost

In [None]:
def get_operations(old,new) -> ("sep","com"):
    """Describe how *new* differs from *old* as ``(split, merged)`` position sets.

    ``split`` holds positions that are leaves in *old* but internal nodes in
    *new*. ``merged`` holds positions that are internal in *old* but leaves
    in *new*.
    """
    def positions_of(nodes):
        return {node.position for node in nodes}

    old_leaf_pos = positions_of(old.leaves)
    new_leaf_pos = positions_of(new.leaves)
    old_inner_pos = positions_of(old) - old_leaf_pos
    new_inner_pos = positions_of(new) - new_leaf_pos
    split = new_inner_pos & old_leaf_pos
    merged = old_inner_pos & new_leaf_pos
    return split, merged

In [22]:
# Inputs for main() on the first example: level-order values of the old tree,
# where positions[p - 1] holds the value of node p and None marks an absent node.
n = 8
positions = [None] * (2 ** n)
pos = [1,
       2, 3,
       4, 5,
       8, 9, 10, 11,
       16, 17, 18, 19, 22, 23,
       32, 33, 44, 45,
       88, 89, 90, 91,
       176, 177]
value = [0,
         0, 0,
         0, 0,
         0, 0, 0, 0,
         0, 35, 40, 0, 0, 0,
         10, 20, 0, 0,
         0, 50, 20, 10,
         10, 5]
for p, v in zip(pos, value):
    positions[p - 1] = v
# Target leaf counts of the new tree.
leaves = [0, 0, 2, 0, 3, 2, 3, 26]

def main(positions,leaves):
    """Branch-and-bound search for the cheapest new tree with the given leaf profile.

    Args:
        positions: level-order values of the old tree (``positions[p - 1]`` is
            the value of node ``p``; ``None`` marks an absent node), passed to
            ``BTree.list_to_tree``.
        leaves: target leaf counts of the new tree, converted with
            ``leaves_to_nonleaves``.

    Returns:
        ``(optimal, min_cost)``: the cheapest complete tree produced by
        ``next_level`` expansions, and its cost.
    """
    import heapq
    from itertools import count

    old = BTree.list_to_tree(positions)
    nonleaf_cost,leaf_cost = get_nodes_cost(old) # cost tables of the old tree
    nonleaves = leaves_to_nonleaves(leaves)
    root = BTree(0)
    root.cost = 0 # nothing charged yet
    min_cost = sum(leaf_cost.values()) # initial bound on the cost
    optimal = None # best complete tree so far
    # Frontier as a heap of (cost - depth, insertion order, tree). The counter
    # keeps ties in insertion order, matching a stable re-sort of a list, and
    # ensures trees themselves are never compared. The root is the only initial
    # entry, so its priority does not matter.
    order = count()
    frontier = [(0, next(order), root)]
    while frontier:
        _, _, tree = heapq.heappop(frontier) # lowest cost - depth first
        news,is_end = next_level(tree,nonleaves,nonleaf_cost,leaf_cost)
        if is_end: # the candidates are complete trees
            for new in news:
                # Accept the first complete tree even if it only matches the
                # initial bound; a strict `<` alone could leave optimal as None.
                if optimal is None or new.cost < min_cost:
                    min_cost,optimal = new.cost,new
                    # Drop partial trees that already cost more than the new bound.
                    frontier = [entry for entry in frontier if entry[2].cost <= min_cost]
                    heapq.heapify(frontier)
        else: # the candidates still need expanding
            for new in news:
                heapq.heappush(frontier, (new.cost - new.depth, next(order), new))
    return optimal,min_cost

In [None]:
def next_level(tree,nonleaves,nonleaf_cost,leaf_cost):
    """Expand the last layer of *tree* by one level and return the candidate trees.

    Args:
        tree: partial new tree carrying a ``cost`` attribute; ``tree.max_depth``
            selects the entry of *nonleaves* to use.
        nonleaves: number of nodes to expand at each depth.
        nonleaf_cost: old-tree internal position -> summed leaf value (non-zero only).
        leaf_cost: old-tree leaf position -> value (non-zero only).

    Returns:
        ``(news, is_end)``: candidate trees one level deeper with ``cost``
        updated, and whether those trees are complete.

    Raises:
        ValueError: if ``nonleaves[tree.max_depth]`` is 0.
    """
    news = []
    k,n = tree.max_depth,len(nonleaves)
    ak = nonleaves[k] # number of last-layer nodes to expand
    if not ak:
        raise ValueError("展开数不能为0")
    # Classify last-layer positions:
    # sep = costly non-leaf in the old tree, not_sep = costly old leaf, whatever = zero cost
    sep,not_sep,whatever = [],[],[]
    for node in tree.last_layer:
        if node.position in nonleaf_cost:
            sep.append(node.position)
        elif node.position in leaf_cost:
            not_sep.append(node.position)
        else:
            whatever.append(node.position)
    # Generate candidate trees
    if ak < len(sep): # few expansions: unpicked sep nodes add their merge cost
        for choice in choose(sep,ak):
            new = tree.new_tree_by_positions(choice)
            new.cost += sum(nonleaf_cost[i] for i in sep if i not in choice)
            news.append(new)
    elif ak <= len(sep)+len(whatever): # moderate: all sep plus some whatever, no extra cost
        choice = sep + whatever[:ak-len(sep)]
        news = [tree.new_tree_by_positions(choice)]
    else: # many: all sep and whatever plus a subset of not_sep; picked leaves add their cost
        for choice in choose(not_sep,ak-len(sep)-len(whatever)):
            new = tree.new_tree_by_positions(sep+whatever+choice)
            new.cost += sum(leaf_cost[i] for i in choice)
            news.append(new)
    # The tree is complete when nothing is expanded at the next depth (or nonleaves ends).
    is_end = k+1 >= n or nonleaves[k+1] == 0
    if is_end:
        # Always charge the merge cost of final-layer nodes that are non-leaves in the
        # old tree, because the old tree may be deeper than the new one; .get(..., 0)
        # makes this a no-op for nodes with nothing to merge.
        for new in news:
            new.cost += sum(nonleaf_cost.get(node.position,0) for node in new.last_layer)
    return news,is_end

### 调试

In [48]:
# Second example: the old tree in level order, where positions[p - 1] holds the
# value of node p and None marks an absent node.
n = 8
positions = [None] * (2 ** n - 1)
pos = [1,
       2, 3,
       4, 5,
       8, 9, 10, 11,
       16, 17, 22, 23,
       32, 33, 34, 35, 46, 47,
       68, 69, 70, 71,
       136, 137, 138, 139, 142, 143]
value = [0,
         0, 10,
         0, 0,
         0, 20, 5, 0,
         0, 0, 7, 0,
         10, 20, 0, 0, 0, 6,
         0, 0, 5, 0,
         10, 6, 3, 10, 7, 2]
for p, v in zip(pos, value):
    positions[p - 1] = v
old = BTree.list_to_tree(positions)

In [49]:
# Target non-leaf counts per depth, converted to leaf counts for main().
nonleaves = [1,2,4,2,3,3,2,0]
leaves = nonleaves2leaves(nonleaves)
optimal,cost = main(positions,leaves)
print(optimal.position_tree) # positions of the new tree
print(old.position_tree) # positions of the old tree
print(get_operations(old,optimal),cost) # (split, merged) position sets and the total cost
print(old)


                                                      __________________________1________
                                                     /                                   \
                                                  __2___                               ___3___
                                                 /      \                             /       \
           _____________________________________4       _5___                       _6        _7
          /                                      \     /     \                     /  \      /  \
     ____8_________________________               9   10     _11_________         12   13   14   15
    /                              \                        /            \
  _16                   ____________17___                  22         ____23
 /   \                 /                 \                           /      \
32    33         _____34____             _35                       _46       47
                /     

In [50]:
# Same old tree, different target non-leaf counts per depth.
nonleaves = [1,2,3,2,1,2,1,0]
leaves = nonleaves2leaves(nonleaves)
optimal,cost = main(positions,leaves)
print(optimal.position_tree) # positions of the new tree
print(old.position_tree) # positions of the old tree
print(get_operations(old,optimal),cost) # (split, merged) position sets and the total cost
print(old)


                                        ______________1________
                                       /                       \
                                    __2___                   ___3
                                   /      \                 /    \
     _____________________________4       _5___           _6      7
    /                              \     /     \         /  \
  _8_________________               9   10     _11      12   13
 /                   \                        /   \
16                ____17___                  22    23
                 /         \
           _____34         _35
          /       \       /   \
        _68_       69    70    71
       /    \
     136    137


                                                              ____________________1
                                                             /                     \
                                                          __2___                    3
                        