In [2]:
import torch
import torch.nn as nn

# Load the .pth checkpoint onto the CPU and print two tensors stored under its 'hl_agent' entry.
file_path = './weight/mkbl_weights_ep24.pth'
checkpoint = torch.load(file_path, map_location='cpu')

# Tip: iterate checkpoint['state_dict']['ll_agent'].keys() to list the other agent's tensor names.
hl_agent_state = checkpoint['state_dict']['hl_agent']
for tensor_name in ('policy.net.p.0.dc_leaves', 'policy.net.p.0.dc_inner_nodes.weight'):
    print(hl_agent_state[tensor_name])


tensor([[ -0.4930,   0.1683,  -0.0398,  ...,   0.1388,  -0.2933,   1.4138],
        [ -0.3318,   0.5863,   0.9878,  ...,  -0.0278,   1.0599,   0.4682],
        [  0.3920,   1.8801,   1.3467,  ...,   1.2708,   0.0658,   0.0992],
        ...,
        [-15.0019, -13.9263,  -0.5682,  ..., -13.7954,  -4.4415, -16.2678],
        [  0.4264,   0.7387, -16.2612,  ..., -12.0022,  -0.7318,   1.7525],
        [-13.7561,   0.2019,   0.8381,  ...,   0.5595,   1.0081, -17.2911]])
tensor([[-1.6636e-01,  4.4661e+00,  1.0705e+00,  ...,  5.1678e-02,
         -1.0906e-01,  2.3444e-01],
        [ 7.6785e-01, -2.4005e+00, -2.0293e+00,  ..., -6.2359e-02,
          2.2932e-02, -9.6292e-01],
        [-1.2131e-01, -2.2562e-01,  4.9384e-01,  ..., -5.1090e-02,
         -4.4676e-02,  2.3678e-01],
        ...,
        [-1.7346e+00, -3.5562e+00, -2.3660e+00,  ..., -2.9540e-02,
         -3.3692e-03,  1.8367e+00],
        [ 7.0553e-01, -5.9258e+00,  9.9596e-01,  ...,  5.2658e-02,
          2.7574e-02, -7.7244e-01],
  

In [3]:
class VQCDTPredictor(nn.Module):
    """Differentiable decision-tree policy head.

    With ``hp.feature_learning_depth >= 0`` this is a cascading decision tree (CDT): a
    feature-learning (FL) tree whose leaves hold linear feature combinations feeds a
    decision (DC) tree. With a negative depth only the decision tree is used (soft
    decision tree, SDT) directly on the raw input. Every inner node routes with
    ``sigmoid(beta * w . [1, x])``; every decision leaf holds logits over ``output_dim``
    outputs, and ``forward`` returns the leaf-probability-weighted softmax over them.
    """

    def __init__(self, hp, input_dim, output_dim):
        super(VQCDTPredictor, self).__init__()

        # Routing non-linearity and per-leaf output distribution.
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)
        if hp.feature_learning_depth >= 0:
            self.num_intermediate_variables = hp.num_intermediate_variables
        else:
            # Without an FL tree the decision tree consumes the raw input directly.
            self.num_intermediate_variables = input_dim
        self.feature_learning_depth = hp.feature_learning_depth
        self.decision_depth = hp.decision_depth
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.greatest_path_probability = hp.greatest_path_probability

        self.beta_fl = hp.beta_fl
        self.beta_dc = hp.beta_dc

        self.device = hp.device

        # Build the feature-learning and decision modules.
        self.feature_learning_init()
        self.decision_init()

        if self.greatest_path_probability:
            print('use best path')

        # Index of the most probable decision leaf from the last forward pass
        # (only updated when greatest_path_probability is enabled).
        self.max_leaf_idx = None

        self.if_smooth = hp.if_smooth

        self.tree_name = hp.tree_name
        self.if_save = hp.if_save
        self.forward_num = 0  # forward-call counter that drives periodic saving
        # Target of the periodic save in forward(). It used to be commented out
        # (os.path.join(os.environ["EXP_DIR"], f"cdt_model/{tree_name}.pth")), which made
        # if_save=True crash with AttributeError; hp.model_name can override the default.
        self.model_name = getattr(hp, 'model_name', f"{self.tree_name}.pth")

        self.if_discrete = getattr(hp, 'if_discrete', False)

    def feature_learning_init(self):
        """Create the FL tree parameters (skipped when feature_learning_depth < 0)."""
        if self.feature_learning_depth < 0:  # no feature-learning tree needed
            print('use SDT')
            return
        print('use CDT')
        self.num_fl_inner_nodes = 2 ** self.feature_learning_depth - 1
        self.num_fl_leaves = self.num_fl_inner_nodes + 1
        self.fl_inner_nodes = nn.Linear(self.input_dim + 1, self.num_fl_inner_nodes, bias=False)
        # Coefficients of the feature combinations: num_intermediate_variables per leaf.
        fl_leaf_weights = torch.randn(self.num_fl_leaves * self.num_intermediate_variables, self.input_dim)
        self.fl_leaf_weights = nn.Parameter(fl_leaf_weights)

        # Temperature term.
        if self.beta_fl is True or self.beta_fl == 1:  # learnable, one value per inner node
            self.beta_fl = nn.Parameter(torch.randn(self.num_fl_inner_nodes))
        elif self.beta_fl is False or self.beta_fl == 0:  # fixed temperature of 1
            self.beta_fl = torch.ones(1).to(self.device)
        else:  # explicit value passed in
            self.beta_fl = torch.tensor(self.beta_fl).to(self.device)

    @staticmethod
    def _leaf_probabilities(branch_prob, depth):
        """Convert per-node branch probabilities into leaf-reaching probabilities.

        Args:
            branch_prob: (N, num_inner_nodes, 2) tensor; ``[..., 0]`` is the probability of
                the first child and ``[..., 1]`` of the second, nodes in breadth-first order.
            depth: tree depth (the tree has ``2 ** depth`` leaves).

        Returns:
            (N, 2 ** depth) tensor of path probabilities, leaves ordered left to right.
        """
        n = branch_prob.shape[0]
        mu = branch_prob.new_ones(n, 1, 1)
        begin_idx, end_idx = 0, 1
        for layer_idx in range(depth):
            # Split each current path in two and weight by this layer's branch probabilities.
            mu = mu.view(n, -1, 1).repeat(1, 1, 2) * branch_prob[:, begin_idx:end_idx, :]
            begin_idx = end_idx  # nodes of the next layer start here
            end_idx = begin_idx + 2 ** (layer_idx + 1)
        return mu.view(n, 2 ** depth)

    def feature_learning_forward(self):
        """
        Forward the tree for feature learning.
        Return the probabilities for reaching each leaf, shape (batch_size, num_fl_leaves),
        or None when there is no feature-learning tree.
        """
        if self.feature_learning_depth < 0:
            return None
        path_prob = self.sigmoid(self.beta_fl * self.fl_inner_nodes(self.aug_data))
        branch_prob = torch.stack((path_prob, 1 - path_prob), dim=2)
        return self._leaf_probabilities(branch_prob, self.feature_learning_depth)

    def decision_init(self):
        """Create the decision tree parameters and its temperature term."""
        self.num_dc_inner_nodes = 2 ** self.decision_depth - 1
        self.num_dc_leaves = self.num_dc_inner_nodes + 1
        self.dc_inner_nodes = nn.Linear(self.num_intermediate_variables + 1, self.num_dc_inner_nodes, bias=False)

        # Trainable (num_dc_leaves, output_dim) leaf logits.
        self.dc_leaves = nn.Parameter(torch.randn(self.num_dc_leaves, self.output_dim))

        # Temperature term.
        if self.beta_dc is True or self.beta_dc == 1:  # learnable, one value per inner node
            self.beta_dc = nn.Parameter(torch.randn(self.num_dc_inner_nodes))
        elif self.beta_dc is False or self.beta_dc == 0:  # fixed temperature of 1
            self.beta_dc = torch.ones(1).to(self.device)
        else:  # explicit value passed in
            self.beta_dc = torch.tensor(self.beta_dc).to(self.device)

    def _decision_branch_probs(self):
        """Build the decision-tree inputs and return soft branch probabilities (N, inner, 2).

        Sets ``self.features``: (batch_size * num_fl_leaves, num_intermediate_variables)
        for a CDT, or the raw (batch_size, input_dim) data for an SDT.
        """
        if self.feature_learning_depth >= 0:
            self.intermediate_features_construct()
        else:
            self.features = self.data
        aug_features = self._data_augment_(self.features)
        path_prob = self.sigmoid(self.beta_dc * self.dc_inner_nodes(aug_features))
        return torch.stack((path_prob, 1 - path_prob), dim=2)

    def decision_forward(self):
        """
        Forward the differentiable decision tree (soft routing).
        Returns (batch_size * num_fl_leaves, num_dc_leaves) leaf probabilities
        ((batch_size, num_dc_leaves) for an SDT).
        """
        return self._leaf_probabilities(self._decision_branch_probs(), self.decision_depth)

    def discrete_decision_forward(self):
        """Forward the decision tree with hard (0/1) routing; same output shape as decision_forward.

        Each node sends the sample to its first child when p >= 0.5 and to the second
        otherwise, so exactly one leaf is reached. (Thresholding p > 0.5 and 1 - p > 0.5
        independently zeroed both branches at ties and produced all-zero outputs.)
        """
        branch_prob = self._decision_branch_probs()
        go_first = (branch_prob[:, :, :1] >= 0.5).to(branch_prob.dtype)
        hard_prob = torch.cat((go_first, 1 - go_first), dim=2)
        return self._leaf_probabilities(hard_prob, self.decision_depth)

    def intermediate_features_construct(self):
        """
        Construct the intermediate features for decision making, with learned feature
        combinations from the feature learning module.
        Sets self.features to (batch_size * num_fl_leaves, num_intermediate_variables).
        """
        # (num_fl_leaves * num_intermediate_variables, batch_size)
        features = self.fl_leaf_weights.view(-1, self.input_dim) @ self.data.transpose(0, 1)
        self.features = (features.contiguous()
                         .view(self.num_fl_leaves, self.num_intermediate_variables, -1)
                         .permute(2, 0, 1)
                         .contiguous()
                         .view(-1, self.num_intermediate_variables))

    def decision_leaves(self, p):
        """Mix the per-leaf output distributions with leaf probabilities p: (batch, num_dc_leaves)."""
        if self.if_smooth:
            distribution_per_leaf = self.softmax(self.dc_leaves / (self.output_dim) ** 0.5)
        else:
            distribution_per_leaf = self.softmax(self.dc_leaves)  # output distribution of each leaf
        # sum over leaves of P(leaf) * leaf distribution -> (batch_size, output_dim)
        return torch.mm(p, distribution_per_leaf)

    def _greatest_path_prediction(self, leaf_probs):
        """Predict with only the most probable decision leaf of each sample (one-hot path)."""
        _, ids_dc = torch.max(leaf_probs, 1)
        self.max_leaf_idx_dc = ids_dc
        self.max_leaf_idx = ids_dc
        one_hot_path_probability_dc = torch.zeros_like(leaf_probs)
        one_hot_path_probability_dc.scatter_(1, ids_dc.view(-1, 1), 1.)
        return self.decision_leaves(one_hot_path_probability_dc)

    def forward(self, data):
        """Return (batch_size, output_dim) output probabilities for data of shape (batch_size, ...).

        With greatest_path_probability the prediction uses only the most probable path
        (recorded in max_leaf_idx / max_leaf_idx_dc); otherwise it is the soft average over leaves.
        """
        if self.if_save:
            self.forward_num = self.forward_num + 1
            if self.forward_num >= 100000:
                self.forward_num = 0
                self.save_model(self.model_name)
        self.data = data
        self.batch_size = data.size()[0]

        dc_forward = self.discrete_decision_forward if self.if_discrete else self.decision_forward

        if self.feature_learning_depth >= 0:
            self.aug_data = self._data_augment_(data)
            fl_probs = self.feature_learning_forward()  # (batch_size, num_fl_leaves)
            # (batch_size, num_fl_leaves, num_dc_leaves): decision-leaf probabilities per FL leaf
            dc_probs = dc_forward().view(self.batch_size, self.num_fl_leaves, -1)
            _mu = torch.bmm(fl_probs.unsqueeze(1), dc_probs).squeeze(1)  # (batch_size, num_dc_leaves)
            output = self.decision_leaves(_mu)
            if not self.greatest_path_probability:
                return output

            # Follow the most probable FL leaf of every sample.
            _, ids = torch.max(fl_probs, 1)
            self.max_leaf_idx_fl = ids
            self.max_feature_vector = self.fl_leaf_weights.view(
                self.num_fl_leaves, self.num_intermediate_variables, self.input_dim)[ids]
            batch_idx = torch.arange(self.batch_size, device=ids.device)
            # Per-sample selection -> (batch_size, num_intermediate_variables).
            self.max_feature_value = self.features.view(
                -1, self.num_fl_leaves, self.num_intermediate_variables)[batch_idx, ids, :]
            one_dc_probs = dc_probs[batch_idx, ids, :]
        else:
            dc_probs = dc_forward()  # (batch_size, num_dc_leaves)
            output = self.decision_leaves(dc_probs)
            if not self.greatest_path_probability:
                return output
            one_dc_probs = dc_probs

        return self._greatest_path_prediction(one_dc_probs)

    def _data_augment_(self, input):
        """Prepend a constant-1 bias column: (batch, d) -> (batch, d + 1)."""
        batch_size = input.size()[0]
        input = input.view(batch_size, -1)
        # Created on the input's own device/dtype so CPU inputs work even if hp.device is CUDA.
        bias = input.new_ones(batch_size, 1)
        return torch.cat((bias, input), 1)

    def save_model(self, model_path):
        torch.save(self.state_dict(), model_path)

    def load_model(self, model_path):
        self.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.eval()

    def load_checkpoint(self, file_path, agent_key='hl_agent', prefix='policy.net.p.0.'):
        """Copy the decision-tree weights out of a training checkpoint.

        Reads ``checkpoint['state_dict'][agent_key][prefix + 'dc_leaves']`` and
        ``[prefix + 'dc_inner_nodes.weight']``. Values are copied in place, so the parameters
        keep their device.

        Raises:
            KeyError: if the expected entries are missing.
            ValueError: if a stored tensor's shape does not match the model.
        """
        checkpoint = torch.load(file_path, map_location='cpu')
        agent_state = checkpoint['state_dict'][agent_key]
        targets = {'dc_leaves': self.dc_leaves, 'dc_inner_nodes.weight': self.dc_inner_nodes.weight}
        with torch.no_grad():
            for name, param in targets.items():
                value = agent_state[prefix + name]
                if value.shape != param.shape:
                    raise ValueError(f"{prefix + name}: checkpoint shape {tuple(value.shape)} "
                                     f"does not match model shape {tuple(param.shape)}")
                param.copy_(value)

In [38]:
# 绘制代码
import torch
import torch.nn as nn
from torch.utils import data
import numpy as np
import copy
import matplotlib as mpl
# from spirl.configs import local



def get_binary_index(tree):
    """
    Label every node of a complete binary tree with a binary string, breadth-first.

    Node ``i`` at depth ``d`` is labelled with ``bin(i)`` left-padded with zeros to
    ``d + 1`` characters, so for ``tree.max_depth == 2``::

        0          ->  '0'
        1 2        ->  '00' '01'
        3 4 5 6    ->  '000' '001' '010' '011'

    A child's label therefore starts with its parent's label.

    Args:
        tree: any object with an integer ``max_depth`` attribute.

    Returns:
        numpy array of the labels, root first.
    """
    labels_per_depth = [
        [bin(position)[2:].zfill(depth + 1) for position in range(2 ** depth)]
        for depth in range(tree.max_depth + 1)
    ]
    return np.concatenate(labels_per_depth)

def path_from_prediction(tree, idx):
    """
    Build the root-to-leaf decision path that ends at leaf ``idx``.

    Each node on the path is given twice: as the zero-padded binary label used by
    get_binary_index, and as its breadth-first integer index.

    Args:
        tree: object with an integer ``max_depth`` attribute.
        idx: leaf position among the leaves (anything ``int()`` accepts).

    Returns:
        (binary_labels, int_indices), both ordered from the root down to the leaf.
    """
    node = int(idx)
    binary_path = []
    int_path = []
    level = tree.max_depth + 1
    while level > 0:
        # Prepend so the finished lists read top to bottom.
        binary_path.insert(0, bin(node)[2:].zfill(level))
        int_path.insert(0, 2 ** (level - 1) - 1 + node)
        node = int(node / 2)  # parent position one level up
        level -= 1
    return binary_path, int_path

def draw_tree(original_tree, input_img=None, show_correlation=False, DrawTree=None, savepath='',
              kernel_range=(0.0, 7.0), leaf_range=(0.0, 0.4)):
    '''
    Draw one of the two trees of a CDT/SDT model as a grid of heat maps.

    Inner nodes show their split weights (the leading bias column is dropped). Leaves show
    a row-wise softmax of their leaf parameters. Arrows connect parents to children.
    The layout constants below need tuning per environment (e.g. CartPole vs. LunarLander-v2).

    Args:
        original_tree: VQCDTPredictor-like module. Drawing attributes are attached to a
            shallow copy, so the caller's object is not modified.
        input_img: optional 1-D vector or 3-D (H, W, C) array drawn top-left as the input.
        show_correlation: unused; kept for backward compatibility.
        DrawTree: 'FL' for the feature-learning tree or 'DM' for the decision-making tree.
        savepath: if non-empty, save the figure there and close it. Missing parent
            directories are created. Otherwise call plt.show().
        kernel_range: (vmin, vmax) color limits for inner-node weights. The default is the
            range used for the mkbl/mlsh checkpoints (calvin used (0.0, 18.0)). None uses
            the data's own min/max.
        leaf_range: (vmin, vmax) color limits for leaf probabilities. The default is the
            mkbl/mlsh range (calvin used (0.0, 0.8)). None uses the data's own min/max.

    Raises:
        ValueError: if DrawTree is not 'FL' or 'DM'.
    '''
    import itertools
    import os

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec
    from matplotlib.patches import ConnectionPatch

    # Whole-tree layout (a zoomed-in view used aspect 1 for both node kinds).
    aspect_inners = 0.6  # aspect ratio of inner-node images
    aspect_leaves = 0.6  # aspect ratio of leaf images
    arrow_color = '#262626'
    kernel_cmap = 'Oranges'  # alternatives tried: OrRd, Reds, YlOrBr
    leaf_cmap = 'Greens'     # alternative tried: YlGn

    tree = copy.copy(original_tree)  # shallow copy; only drawing attributes are added to it
    if DrawTree == 'FL':  # draw the feature learning tree
        tree.inner_node_num = tree.num_fl_inner_nodes
        tree.max_depth = tree.feature_learning_depth
        tree.leaf_num = tree.num_fl_leaves
        inner_nodes_name = 'fl_inner_nodes.weight'
        leaf_nodes_name = 'fl_leaf_weights'
        input_shape = (tree.input_dim,)
    elif DrawTree == 'DM':  # draw the decision making tree
        tree.inner_node_num = tree.num_dc_inner_nodes
        tree.max_depth = tree.decision_depth
        tree.leaf_num = tree.num_dc_leaves
        inner_nodes_name = 'dc_inner_nodes.weight'
        leaf_nodes_name = 'dc_leaves'
        input_shape = (tree.num_intermediate_variables,)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError(f"DrawTree must be 'FL' or 'DM', got {DrawTree!r}")

    def _add_arrow(ax_parent, ax_child, xyA, xyB, color='black', linestyle=None):
        '''Draw an arrow between two axes (points given in each axes' data coordinates).'''
        con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA='data', coordsB='data',
                              axesA=ax_child, axesB=ax_parent, arrowstyle='<-,head_length=0.8,head_width=0.4',
                              # max_depth - 2 would be <= 0 (invisible) for trees of depth <= 2
                              color=color, linewidth=max(tree.max_depth - 2, 1), linestyle=linestyle)
        ax_child.add_artist(con)

    def _vector_to_image(vec):
        '''Stack a 1-D vector into a square image (every row is the vector) for imshow.'''
        return np.ones((vec.shape[0], 1)) @ [vec]

    def _softmax_rows(matrix):
        '''Row-wise softmax; subtracting each row's max avoids overflow in exp.'''
        exps = np.exp(matrix - np.amax(matrix, axis=1, keepdims=True))
        return exps / np.sum(exps, axis=1, keepdims=True)

    def _add_colorbar(cmap, vmin, vmax, rect):
        '''Add a standalone colorbar at rect = [left, bottom, width, height] (figure fraction).'''
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax))
        sm.set_array([])
        plt.colorbar(sm, ticks=np.linspace(vmin, vmax, 5), cax=fig.add_axes(rect))

    state = tree.state_dict()
    inner_nodes = state[inner_nodes_name]
    leaf_nodes = state[leaf_nodes_name]
    binary_indices = get_binary_index(tree)
    inner_indices = binary_indices[:tree.inner_node_num]
    leaf_indices = binary_indices[tree.inner_node_num:]

    if DrawTree == 'FL':  # each FL leaf holds num_intermediate_variables weight vectors over the inputs
        leaf_nodes = leaf_nodes.view(tree.leaf_num, tree.num_intermediate_variables, tree.input_dim)

    # Column 0 of every inner-node weight row multiplies the constant-1 bias input, so drop it.
    kernels = {node_idx: node_value.cpu().numpy().reshape(input_shape)
               for node_idx, node_value in zip(inner_indices, inner_nodes[:, 1:])}
    leaves = {leaf_idx: np.array([leaf_dist.cpu().numpy()])
              for leaf_idx, leaf_dist in zip(leaf_indices, leaf_nodes)}
    n_leaves = tree.leaf_num
    assert len(leaves) == n_leaves

    def _leaf_image(leaf):
        '''Image drawn for one leaf: row-wise softmax of its parameters.'''
        if len(leaf.shape) > 2:  # multi-dimensional leaf, e.g. FL leaves (intermediate features)
            image = leaf.squeeze(0)
        else:
            image = np.ones((tree.output_dim, 1)) @ leaf
        return _softmax_rows(image)

    leaf_images = {key: _leaf_image(leaf) for key, leaf in leaves.items()}

    if kernel_range is None:
        kernel_range = (min(np.min(k) for k in kernels.values()), max(np.max(k) for k in kernels.values()))
    if leaf_range is None:
        leaf_range = (min(np.min(v) for v in leaf_images.values()), max(np.max(v) for v in leaf_images.values()))
    kernel_min_val, kernel_max_val = kernel_range
    leaf_min_val, leaf_max_val = leaf_range

    fig = plt.figure(figsize=(2 * n_leaves, n_leaves / 2), facecolor='white')  # sized for cartpole
    # One grid row per depth; the leaf row is half height.
    gs = GridSpec(tree.max_depth + 1, n_leaves * 2, height_ratios=[1] * tree.max_depth + [0.5])

    # Horizontal grid centre of every node, breadth-first (same order as the sorted keys below).
    gcx = [list(np.arange(1, 2 ** (i + 1), 2) * (2 ** (tree.max_depth + 1) // 2 ** (i + 1)))
           for i in range(tree.max_depth + 1)]
    gcx = list(itertools.chain.from_iterable(gcx))
    axes = {}

    imshow_args = {'origin': 'upper', 'interpolation': 'None'}

    _add_colorbar(kernel_cmap, kernel_min_val, kernel_max_val, [0.01, 0.4, 0.03, 0.2])

    # Draw inner nodes.
    for pos, key in enumerate(sorted(kernels, key=lambda x: (len(x), x))):
        ax = plt.subplot(gs[len(key) - 1, gcx[pos] - 2:gcx[pos] + 2])
        axes[key] = ax
        kernel_image = kernels[key]
        if len(kernel_image.shape) == 3:  # 2-D image (H, W, C)
            ax.imshow(kernel_image.squeeze(), vmin=kernel_min_val, vmax=kernel_max_val,
                      cmap=kernel_cmap, **imshow_args)
        elif len(kernel_image.shape) == 1:
            ax.imshow(_vector_to_image(kernel_image), vmin=kernel_min_val, vmax=kernel_max_val,
                      cmap=kernel_cmap, **imshow_args)
        # Keep the frame but hide ticks and tick labels.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect(aspect=aspect_inners)

    # Draw leaves.
    for pos, key in enumerate(sorted(leaves, key=lambda x: (len(x), x))):
        col = gcx[len(kernels) + pos]
        ax = plt.subplot(gs[len(key) - 1, col - 1:col + 1])
        axes[key] = ax
        ax.imshow(leaf_images[key], vmin=leaf_min_val, vmax=leaf_max_val, cmap=leaf_cmap, **imshow_args)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_aspect(aspect=aspect_leaves)
        if DrawTree != 'FL':  # label decision leaves with their highest-scoring output index
            ax.set_title('{}'.format(np.argmax(leaves[key])), y=-.5)

    # Arrows from each node's bottom-centre to its children's top-centre (origin='upper').
    for key in sorted(axes, key=lambda x: (len(x), x)):
        children_keys = [k for k in axes if len(k) == len(key) + 1 and k.startswith(key)]
        for child_key in children_keys:
            p_rows, p_cols = axes[key].get_images()[0].get_array().shape
            c_rows, c_cols = axes[child_key].get_images()[0].get_array().shape
            parent_position = (p_cols // 2, p_rows - 1)
            child_position = (c_cols // 2, 0)
            _add_arrow(axes[key], axes[child_key], child_position, parent_position,
                       color=arrow_color, linestyle=None)

    # Optional input image in the top-left corner.
    if input_img is not None:
        ax = plt.subplot(gs[0, 0:4])  # placement chosen for lunarlander
        img_min_val = np.min(input_img)
        img_max_val = np.max(input_img)
        if len(input_img.shape) == 3:  # 2-D image (H, W, C)
            ax.imshow(input_img.squeeze(), clim=(0.0, 1.0), vmin=img_min_val, vmax=img_max_val,
                      cmap=leaf_cmap, **imshow_args)
        elif len(input_img.shape) == 1:
            ax.imshow(_vector_to_image(input_img), vmin=img_min_val, vmax=img_max_val,
                      cmap=leaf_cmap, **imshow_args)
        ax.axis('off')
        ax.set_title('input')
        _add_colorbar(leaf_cmap, img_min_val, img_max_val, [0.01, 0.7, 0.03, 0.2])

    _add_colorbar(leaf_cmap, leaf_min_val, leaf_max_val, [0.01, 0.1, 0.03, 0.2])

    if savepath:
        save_dir = os.path.dirname(savepath)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(savepath, facecolor=fig.get_facecolor())
        plt.close(fig)
    else:
        plt.show()

def get_path(tree, input, Probs=False):
    """
    Return the decision path for a single input as breadth-first node indices, root first.

    Runs one forward pass, reads the most probable decision leaf that the tree recorded,
    and converts it into the integer node indices from the root to that leaf.

    Args:
        tree: VQCDTPredictor-like module. It must record the chosen leaf during forward;
            VQCDTPredictor does so (max_leaf_idx / max_leaf_idx_dc) only when built with
            greatest_path_probability=True.
        input: one unbatched input vector (list, numpy array or tensor).
        Probs: if True, also return ``tree.inner_probs`` as a numpy array.

    Returns:
        path_idx_int, or (path_idx_int, inner_probs) when Probs is True.

    Raises:
        RuntimeError: if the forward pass did not record a leaf index.
        AttributeError: if Probs is True and the tree has no ``inner_probs``.
    """
    # Run on the model's device; the original CPU-only input broke CUDA models.
    device = next(tree.parameters()).device
    with torch.no_grad():
        tree.forward(torch.as_tensor(input, dtype=torch.float32, device=device).unsqueeze(0))

    # VQCDTPredictor initialises max_leaf_idx to None and its greatest-path branch stores the
    # winner in max_leaf_idx_dc, so fall back to that instead of failing on int(None).
    max_leaf_idx = getattr(tree, 'max_leaf_idx', None)
    if max_leaf_idx is None:
        max_leaf_idx = getattr(tree, 'max_leaf_idx_dc', None)
    if max_leaf_idx is None:
        raise RuntimeError('forward() did not record a leaf index; '
                           'build the tree with greatest_path_probability=True')

    # path_from_prediction needs `max_depth`, which only draw_tree's annotated copies carry.
    path_tree = tree
    if not hasattr(tree, 'max_depth'):
        path_tree = copy.copy(tree)
        path_tree.max_depth = tree.decision_depth
    _, path_idx_int = path_from_prediction(path_tree, max_leaf_idx)

    if Probs:
        return path_idx_int, tree.inner_probs.squeeze().detach().cpu().numpy()
    return path_idx_int

import torch

# Model configuration
codebook = 16                # output_dim of the tree (codebook size)
INPUT_DIM = 60               # input_dim of the tree
CHECKPOINT_PATH = './weight/mlsh_weights_ep24.pth'


class HyperParameters:
    """Hyper-parameters read by VQCDTPredictor (attribute names must match what it reads)."""

    def __init__(self):
        self.codebook_K = codebook
        self.feature_learning_depth = -1      # < 0: plain soft decision tree, no FL tree
        self.decision_depth = 6               # 2 ** 6 = 64 decision leaves
        self.num_intermediate_variables = 20  # only used when feature_learning_depth >= 0
        self.greatest_path_probability = False
        self.beta_fl = False                  # False: fixed temperature of 1
        self.beta_dc = False
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.if_smooth = False
        self.tree_name = ''
        self.if_save = False
        self.if_discrete = False              # soft (sigmoid) routing in the decision tree


hp = HyperParameters()

tree = VQCDTPredictor(hp, INPUT_DIM, codebook)
tree.load_checkpoint(CHECKPOINT_PATH)
# The fixed temperature tensors are created directly on hp.device, while parameters start
# on the CPU (and the checkpoint is loaded with map_location='cpu'). Move the module after
# loading so forward() does not mix devices on CUDA machines.
tree = tree.to(hp.device)


use SDT


In [39]:

# Report the size of every tensor in the model's state_dict, then their total.
param_sizes = {name: tensor.numel() for name, tensor in tree.state_dict().items()}
for name, size in param_sizes.items():
    print(name, size)
num_params = sum(param_sizes.values())
print('Total number of parameters in model: ', num_params)

# Render the decision-making tree to a PDF.
draw_tree(tree, input_img=None, DrawTree='DM', savepath='./临时/mlsh_tree.pdf')

dc_leaves 1024
dc_inner_nodes.weight 3843
Total number of parameters in model:  4867
