## CheXpert

In [1]:
class TreeNode:
    """A labelled node of a label hierarchy.

    Attributes:
        label: name of the node.
        children: list of child ``TreeNode`` objects; a fresh empty list is
            used when ``children`` is omitted, ``None`` or empty.
    """

    def __init__(self, label, children=None):
        self.label = label
        # Allocate a new list per node so leaves never share one list object.
        self.children = children if children else []

def build_tree():
    """Build the CheXpert label hierarchy.

    The root "Chest X-ray Findings" has four category nodes (heart, lung,
    pleural, other). Under them are the 13 finding labels; "Lung Opacity" in
    turn groups Edema, Consolidation and Pneumonia. "No Finding" is not part
    of the tree.

    Returns:
        TreeNode: the root node.
    """
    def group(label, *members):
        # Category node whose children are the given member nodes, in order.
        return TreeNode(label, list(members))

    opacity = group("Lung Opacity",
                    TreeNode("Edema"), TreeNode("Consolidation"), TreeNode("Pneumonia"))

    return group(
        "Chest X-ray Findings",
        group("Heart Related Issues",
              TreeNode("Enlarged Cardiomediastinum"), TreeNode("Cardiomegaly")),
        group("Lung Issues",
              TreeNode("Lung Lesion"), opacity,
              TreeNode("Atelectasis"), TreeNode("Pneumothorax")),
        group("Pleural Issues",
              TreeNode("Pleural Effusion"), TreeNode("Pleural Other")),
        group("Other Issues",
              TreeNode("Fracture"), TreeNode("Support Devices")),
    )
  
  


### structure vis

#####  Functions

In [5]:
import plotly.graph_objects as go
import networkx as nx

class TreeNode:
    """A labelled node of a label hierarchy.

    Attributes:
        label: name of the node.
        children: list of child ``TreeNode`` objects; a fresh empty list is
            used when ``children`` is omitted, ``None`` or empty.
    """

    def __init__(self, label, children=None):
        self.label = label
        # Allocate a new list per node so leaves never share one list object.
        self.children = children if children else []

def build_tree():
    """Build the CheXpert label hierarchy rooted at "Chest X-ray Findings".

    The hierarchy is described as data: each category maps to its members,
    where a member is either a leaf label or a ``(label, [members])`` pair
    for a nested group ("Lung Opacity").

    Returns:
        TreeNode: the root node.
    """
    hierarchy = {
        "Heart Related Issues": ["Enlarged Cardiomediastinum", "Cardiomegaly"],
        "Lung Issues": [
            "Lung Lesion",
            ("Lung Opacity", ["Edema", "Consolidation", "Pneumonia"]),
            "Atelectasis",
            "Pneumothorax",
        ],
        "Pleural Issues": ["Pleural Effusion", "Pleural Other"],
        "Other Issues": ["Fracture", "Support Devices"],
    }

    def make(entry):
        # Tuples describe nested groups; plain strings are leaves.
        if isinstance(entry, tuple):
            label, members = entry
            return TreeNode(label, [make(member) for member in members])
        return TreeNode(entry)

    categories = [
        TreeNode(category, [make(member) for member in members])
        for category, members in hierarchy.items()
    ]
    return TreeNode("Chest X-ray Findings", categories)

def tree_to_networkx(node, graph=None, parent=None):
    """Add ``node`` and all of its descendants to an undirected networkx graph.

    Graph nodes are keyed by ``label``, so repeated labels collapse into a
    single graph node. Nodes are added in depth-first pre-order, each followed
    by the edge to its parent.

    Args:
        node: root ``TreeNode`` of the (sub)tree to add.
        graph: graph to extend; a new ``nx.Graph`` is created when None.
        parent: ``TreeNode`` that ``node`` should be connected to, if any.

    Returns:
        The graph that was filled.
    """
    if graph is None:
        graph = nx.Graph()

    pending = [(node, parent)]
    while pending:
        current, above = pending.pop()
        graph.add_node(current.label)
        if above is not None:
            graph.add_edge(above.label, current.label)
        # Push children reversed so they are visited in their original order.
        pending.extend((child, current) for child in reversed(current.children))

    return graph

def create_plotly_tree(root):
    """Render a ``TreeNode`` hierarchy as a static Plotly network figure.

    The tree is converted with ``tree_to_networkx`` and positioned with
    ``nx.spring_layout``. No seed is passed, so the layout differs between
    runs. Edges are thin grey lines; nodes are labelled markers coloured by
    their number of neighbours.

    Args:
        root: root ``TreeNode`` of the hierarchy.

    Returns:
        plotly.graph_objects.Figure: the assembled figure (not shown).
    """
    G = tree_to_networkx(root)
    pos = nx.spring_layout(G)

    # All edges go into one trace; None breaks the line between segments.
    edge_x = []
    edge_y = []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    nodes = list(G.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        hoverinfo='text',
        text=nodes,
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            # Marker colour = number of neighbours; G.adjacency() yields nodes
            # in the same order as G.nodes().
            color=[len(neighbours) for _, neighbours in G.adjacency()],
            colorbar=dict(
                thickness=15,
                # `titleside` was removed in plotly 6; `title.side` is the
                # supported spelling.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        # `titlefont_size` was removed in plotly 6; use title.font.size.
                        title=dict(text='Chest X-ray Findings Tree -- CheXpert',
                                   font=dict(size=16)),
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )

    return fig

##### vis

In [6]:
# Build the CheXpert tree and render it as a Plotly figure.
root = build_tree()
fig = create_plotly_tree(root)
fig.show()

### Distance functions

In [2]:
def find_lowest_common_ancestor(root, node1, node2):
    """Return the lowest ``TreeNode`` whose subtree contains both labels.

    ``node1`` and ``node2`` are labels, not nodes. A node whose own label
    matches either one is returned immediately, so if one label is an ancestor
    of the other, the ancestor is the answer. If only one label is present,
    that label's node is returned. If neither is present, the result is None.
    """
    if root is None:
        return None
    if root.label in (node1, node2):
        return root

    found = None
    for subtree in root.children:
        candidate = find_lowest_common_ancestor(subtree, node1, node2)
        if not candidate:
            continue
        if found:
            # Matches in two different child subtrees: this node is the LCA.
            return root
        # First match so far: remember it and keep scanning siblings.
        found = candidate
    return found

def calculate_distance(root, node1, node2):
    """Return the number of edges between the nodes labelled ``node1`` and ``node2``.

    The lowest common ancestor is located first. The result is the sum of the
    depths of both labels measured from that ancestor. If a label is not in
    the tree, the result is ``float('inf')``.
    """
    lca = find_lowest_common_ancestor(root, node1, node2)
    return sum(find_distance_to_node(lca, label, 0) for label in (node1, node2))

def find_distance_to_node(current, target, distance):
    """Return ``distance`` plus the depth of the first ``target`` label below ``current``.

    The subtree is searched in depth-first pre-order. The first match wins,
    which is not necessarily the shallowest one. ``distance`` is the offset
    added to the result; callers pass 0. Returns ``float('inf')`` when
    ``current`` is None or the label does not occur.
    """
    missing = float('inf')
    pending = [(current, distance)]
    while pending:
        node, depth = pending.pop()
        if node is None:
            continue
        if node.label == target:
            return depth
        # Reverse so children are explored in their original order.
        pending.extend((child, depth + 1) for child in reversed(node.children))
    return missing


In [44]:
# Build the CheXpert label tree.
tree_root = build_tree()

# Pick the two labels to compare.
node1_label = "Enlarged Cardiomediastinum"
node2_label = "Fracture"

# Number of tree edges between the two labels (via their lowest common ancestor).
distance = calculate_distance(tree_root, node1_label, node2_label)

print(f"Distance between Node {node1_label} and Node {node2_label}: {distance}")


# CheXpert finding labels in alphabetical order. This order fixes the
# row/column order of the distance and correlation matrices built from it.
# "No Finding" is left out because it is not a node of the label tree.
CHEXPERT_LABELS = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Enlarged Cardiomediastinum",
    "Fracture",
    "Lung Lesion",
    "Lung Opacity",
    "Pleural Effusion",
    "Pleural Other",
    "Pneumonia",
    "Pneumothorax",
    "Support Devices",
]

import torch
import torch.nn.functional as F
n = len(CHEXPERT_LABELS)  # side length of the square label-by-label matrices

import numpy as np

def create_negative_one_matrix(n):
    """Return an ``n x n`` NumPy integer array with every entry set to -1."""
    return -np.ones((n, n), dtype=int)

negative_one_matrix = create_negative_one_matrix(n)
matrix = negative_one_matrix.copy()

# Fill matrix[i][x] with the tree distance between label i and label x.
# The diagonal is 1 instead of 0 (a GCN-style self-loop). This also keeps the
# element-wise reciprocal below finite.
for i, j in enumerate(CHEXPERT_LABELS):
    for x, y in enumerate(CHEXPERT_LABELS):
        if j != y:
            node1_label, node2_label = j, y
            distance = calculate_distance(tree_root, node1_label, node2_label)
            matrix[i][x] = distance
        else:
            matrix[i][x] = 1

# Switch to a float tensor; 1 / distance is used as the label-correlation score.
matrix = torch.tensor(matrix).float()
corr_matrix = 1 / matrix

def safe_divide(a, b):
    """Element-wise ``a / b`` that yields 0 wherever ``b`` is 0.

    Args:
        a: numerator (tensor or Python number), broadcastable against ``b``.
        b: denominator (tensor or Python number).

    Returns:
        torch.Tensor: ``a / b`` with the positions where ``b == 0`` set to 0.

    Zero denominators are swapped for 1 before dividing, so the discarded
    branch never holds inf/nan. With ``torch.where(b != 0, a / b, 0)`` those
    inf values reach the backward pass and turn the gradient of ``b`` into NaN.
    """
    b = torch.as_tensor(b)
    nonzero = b != 0
    denom = torch.where(nonzero, b, torch.ones_like(b))
    quotient = a / denom
    # zeros_like matches the quotient's dtype and device, unlike a fixed
    # CPU float32 scalar tensor.
    return torch.where(nonzero, quotient, torch.zeros_like(quotient))
  
# Unused alternative: zero-safe reciprocal via safe_divide.
# matrix = safe_divide(1, matrix)
print(matrix)
# Row-wise L2 normalisation (dim=1): each row of the result has unit Euclidean norm.
normalized_distance_matrix = F.normalize(matrix, p=2, dim=1)
normalized_corr_matrix = F.normalize(corr_matrix, p=2, dim=1)
# NOTE(review): the commented-out save below refers to `normalized_matrix`,
# which is not defined in the visible code; only the correlation matrix is saved.
# torch.save(normalized_matrix, './constants/normalized_distance_matrix.pt')
print("distance:\n",normalized_distance_matrix)
print("correlation:\n",normalized_corr_matrix)
# Assumes ../constants/CHEXPERT/ already exists; torch.save does not create directories.
torch.save(normalized_corr_matrix, '../constants/CHEXPERT/normalized_corr_matrix.pt')

Distance between Node Enlarged Cardiomediastinum and Node Fracture: 4
tensor([[1., 4., 3., 3., 4., 4., 2., 2., 4., 4., 3., 2., 4.],
        [4., 1., 5., 5., 2., 4., 4., 4., 4., 4., 5., 4., 4.],
        [3., 5., 1., 2., 5., 5., 3., 1., 5., 5., 2., 3., 5.],
        [3., 5., 2., 1., 5., 5., 3., 1., 5., 5., 2., 3., 5.],
        [4., 2., 5., 5., 1., 4., 4., 4., 4., 4., 5., 4., 4.],
        [4., 4., 5., 5., 4., 1., 4., 4., 4., 4., 5., 4., 2.],
        [2., 4., 3., 3., 4., 4., 1., 2., 4., 4., 3., 2., 4.],
        [2., 4., 1., 1., 4., 4., 2., 1., 4., 4., 1., 2., 4.],
        [4., 4., 5., 5., 4., 4., 4., 4., 1., 2., 5., 4., 4.],
        [4., 4., 5., 5., 4., 4., 4., 4., 2., 1., 5., 4., 4.],
        [3., 5., 2., 2., 5., 5., 3., 1., 5., 5., 1., 3., 5.],
        [2., 4., 3., 3., 4., 4., 2., 2., 4., 4., 3., 1., 4.],
        [4., 4., 5., 5., 4., 2., 4., 4., 4., 4., 5., 4., 1.]])
distance:
 tensor([[0.0857, 0.3430, 0.2572, 0.2572, 0.3430, 0.3430, 0.1715, 0.1715, 0.3430,
         0.3430, 0.2572, 0.1715

基于graph中各个节点之间的距离，用节点之间的距离来反映不同标签之间的相关性（距离越近，相关性越强；相关性取距离的倒数）。假设graph上每条边的权重都是1。
借鉴GCN在信息传递时给每个节点加上自环的做法，将每个节点到自身的距离设为1（即距离矩阵的对角线为1），表示每个节点在信息传递时也包含自身。

## NIH

In [54]:
class TreeNode:
    """A labelled node of a label hierarchy.

    Attributes:
        label: name of the node.
        children: list of child ``TreeNode`` objects; a fresh empty list is
            used when ``children`` is omitted, ``None`` or empty.
    """

    def __init__(self, label, children=None):
        self.label = label
        # Allocate a new list per node so leaves never share one list object.
        self.children = children if children else []

# Build the NIH label hierarchy: respiratory, cardiac and other conditions.
# Several labels occur more than once (e.g. 'Pneumonia', 'Atelectasis', 'Mass',
# 'Edema', 'Effusion', 'Consolidation'), so labels do not uniquely identify nodes.
root = TreeNode('Medical Conditions', [
    TreeNode('Respiratory Conditions', [
        TreeNode('Atelectasis', [
            TreeNode('Mass'),
            TreeNode('Pneumonia')
        ]),
        TreeNode('Consolidation', [
            TreeNode('Pneumonia')
        ]),
        TreeNode('Edema', [
            TreeNode('Effusion')
        ]),
        TreeNode('Emphysema'),
        TreeNode('Fibrosis'),
        TreeNode('Infiltration'),
        TreeNode('Mass', [
            TreeNode('Atelectasis')
        ]),
        TreeNode('Nodule'),
        TreeNode('Pleural_Thickening'),
        TreeNode('Pneumonia', [
            TreeNode('Consolidation')
        ]),
        TreeNode('Pneumothorax')
    ]),
    TreeNode('Cardiac Conditions', [
        TreeNode('Cardiomegaly', [
            TreeNode('Edema'),
            TreeNode('Effusion')
        ])
    ]),
    TreeNode('Other Conditions', [
        TreeNode('Hernia')
    ])
])

def print_tree(node, level=0):
    """Print ``node`` and its descendants in pre-order, one label per line.

    Each line is indented by four spaces per depth level. ``level`` is the
    depth assigned to ``node`` itself.
    """
    pending = [(node, level)]
    while pending:
        current, depth = pending.pop()
        print(' ' * depth * 4 + current.label)
        # Reverse so children print in their original left-to-right order.
        pending.extend((child, depth + 1) for child in reversed(current.children))

print_tree(root)  # show the NIH hierarchy as an indented outline




Medical Conditions
    Respiratory Conditions
        Atelectasis
            Mass
            Pneumonia
        Consolidation
            Pneumonia
        Edema
            Effusion
        Emphysema
        Fibrosis
        Infiltration
        Mass
            Atelectasis
        Nodule
        Pleural_Thickening
        Pneumonia
            Consolidation
        Pneumothorax
    Cardiac Conditions
        Cardiomegaly
            Edema
            Effusion
    Other Conditions
        Hernia


#### Structure Vis.

In [12]:
import plotly.graph_objects as go
import networkx as nx

class TreeNode:
    """A labelled node of a label hierarchy.

    Attributes:
        label: name of the node.
        children: list of child ``TreeNode`` objects; a fresh empty list is
            used when ``children`` is omitted, ``None`` or empty.
    """

    def __init__(self, label, children=None):
        self.label = label
        # Allocate a new list per node so leaves never share one list object.
        self.children = children if children else []

def build_tree():
    """Return the NIH label hierarchy rooted at 'Medical Conditions'.

    The root has three groups: respiratory, cardiac and other conditions.
    Several labels appear more than once ('Pneumonia', 'Atelectasis', 'Mass',
    'Edema', 'Effusion', 'Consolidation'), so labels are not unique node
    identifiers.
    """
    def node(label, *children):
        # Node whose children are the given nodes, in order.
        return TreeNode(label, list(children))

    respiratory = node(
        'Respiratory Conditions',
        node('Atelectasis', node('Mass'), node('Pneumonia')),
        node('Consolidation', node('Pneumonia')),
        node('Edema', node('Effusion')),
        node('Emphysema'),
        node('Fibrosis'),
        node('Infiltration'),
        node('Mass', node('Atelectasis')),
        node('Nodule'),
        node('Pleural_Thickening'),
        node('Pneumonia', node('Consolidation')),
        node('Pneumothorax'),
    )
    cardiac = node('Cardiac Conditions',
                   node('Cardiomegaly', node('Edema'), node('Effusion')))
    other = node('Other Conditions', node('Hernia'))

    return node('Medical Conditions', respiratory, cardiac, other)

def tree_to_networkx(node, graph=None, parent=None):
    """Add ``node`` and all of its descendants to an undirected networkx graph.

    Graph nodes are keyed by ``label``, so repeated labels collapse into a
    single graph node. Nodes are added in depth-first pre-order, each followed
    by the edge to its parent.

    Args:
        node: root ``TreeNode`` of the (sub)tree to add.
        graph: graph to extend; a new ``nx.Graph`` is created when None.
        parent: ``TreeNode`` that ``node`` should be connected to, if any.

    Returns:
        The graph that was filled.
    """
    if graph is None:
        graph = nx.Graph()

    pending = [(node, parent)]
    while pending:
        current, above = pending.pop()
        graph.add_node(current.label)
        if above is not None:
            graph.add_edge(above.label, current.label)
        # Push children reversed so they are visited in their original order.
        pending.extend((child, current) for child in reversed(current.children))

    return graph

def create_interactive_tree(root):
    """Render a ``TreeNode`` hierarchy as an interactive Plotly ``FigureWidget``.

    Nodes are keyed by label (via ``tree_to_networkx``), so repeated labels
    collapse into a single graph node. Positions come from an unseeded
    ``nx.spring_layout`` and change between runs. Markers are coloured by
    shortest-path distance (in edges) from ``root.label`` in that graph.

    Args:
        root: root ``TreeNode`` of the hierarchy.

    Returns:
        plotly.graph_objects.FigureWidget: the figure, with a click handler
        attached to the node trace.
    """
    G = tree_to_networkx(root)
    pos = nx.spring_layout(G, k=0.9, iterations=50)

    def edge_coordinates():
        # Flat x/y lists for every edge; None breaks the line between
        # segments so a single trace draws all edges.
        xs, ys = [], []
        for u, v in G.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            xs.extend([x0, x1, None])
            ys.extend([y0, y1, None])
        return xs, ys

    edge_x, edge_y = edge_coordinates()
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_depths = nx.shortest_path_length(G, source=root.label)
    nodes = list(G.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        hoverinfo='text',
        text=nodes,
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale='Viridis',
            size=10,
            color=[node_depths[node] for node in nodes],
            colorbar=dict(
                thickness=15,
                # `titleside` was removed in plotly 6; `title.side` is the
                # supported spelling.
                title=dict(text='Node Depth', side='right'),
                xanchor='left',
            ),
        ),
    )

    fig = go.FigureWidget(data=[edge_trace, node_trace],
                          layout=go.Layout(
                              # `titlefont_size` was removed in plotly 6; use title.font.size.
                              title=dict(
                                  text='Medical Conditions Tree (Drag nodes to reposition)',
                                  font=dict(size=16),
                              ),
                              showlegend=False,
                              hovermode='closest',
                              margin=dict(b=20, l=5, r=5, t=40),
                              xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                              yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                              dragmode='pan'
                          ))

    def update_point(trace, points, selector):
        """Click handler: store the clicked node's coordinates and redraw."""
        # NOTE(review): points.xs/ys are the clicked marker's current
        # coordinates, so this re-applies the existing layout; Plotly scatter
        # markers are not draggable despite the figure title. Confirm intent.
        if len(points.point_inds) > 0:
            ind = points.point_inds[0]
            node_name = trace.text[ind]
            pos[node_name] = (points.xs[0], points.ys[0])

            with fig.batch_update():
                # Move the node markers ...
                trace.x = [pos[node][0] for node in G.nodes()]
                trace.y = [pos[node][1] for node in G.nodes()]
                # ... and redraw the edges attached to them.
                fig.data[0].x, fig.data[0].y = edge_coordinates()

    fig.data[1].on_click(update_point)

    return fig

# Build the NIH tree and display the interactive figure.
# `display` is not imported here; it is provided by the IPython/Jupyter kernel.
root = build_tree()
fig = create_interactive_tree(root)
display(fig)

FigureWidget({
    'data': [{'hoverinfo': 'none',
              'line': {'color': '#888', 'width': 0.5},
              'mode': 'lines',
              'type': 'scatter',
              'uid': '201c2574-a579-4e47-837f-4d540285fefb',
              'x': [0.5087912555878398, -0.17264195540217742, None,
                    0.5087912555878398, 0.2173753401240641, None,
                    0.5087912555878398, 1.0, None, -0.17264195540217742,
                    -0.4904634598516633, None, -0.17264195540217742,
                    -0.2600346940083654, None, -0.17264195540217742,
                    0.12524757949813514, None, -0.17264195540217742,
                    -0.6721569049018644, None, -0.17264195540217742,
                    -0.49384776720521656, None, -0.17264195540217742,
                    -0.379681057356912, None, -0.17264195540217742,
                    -0.1356664413746154, None, -0.17264195540217742,
                    -0.5437455253891325, None, -0.17264195540217742,
           

In [56]:
from collections import deque

class TreeNode:
    """A labelled node of a label hierarchy.

    Attributes:
        label: name of the node.
        children: list of child ``TreeNode`` objects; a fresh empty list is
            used when ``children`` is omitted, ``None`` or empty.
    """

    def __init__(self, label, children=None):
        self.label = label
        # Allocate a new list per node so leaves never share one list object.
        self.children = children if children else []

def find_node(root, label):
    """Return the first node labelled ``label`` in depth-first pre-order, or None."""
    pending = [root]
    while pending:
        node = pending.pop()
        if node.label == label:
            return node
        # Reverse so the leftmost child is examined first.
        pending.extend(reversed(node.children))
    return None

def find_path_to_root(node, target_label):
    """Return the labels on the path from ``node`` down to ``target_label``.

    Despite the name, the path runs top-down. It starts with ``node.label``
    and ends with the first node labelled ``target_label`` in depth-first
    pre-order. When labels repeat, that is the occurrence reached first, not
    necessarily the shallowest one.

    Returns:
        list of labels, or None if ``node`` is None or the label is absent.
    """
    if node is None:
        return None
    # Single DFS that carries the path to each node; this avoids re-searching
    # every child subtree at every level.
    pending = [(node, [node.label])]
    while pending:
        current, path = pending.pop()
        if current.label == target_label:
            return path
        # Reverse so the leftmost child is explored first (pre-order).
        for child in reversed(current.children):
            pending.append((child, path + [child.label]))
    return None

def find_lowest_common_ancestor(root, label1, label2):
    """Return the label of the lowest common ancestor of two labels, or None.

    Both labels are resolved with ``find_path_to_root``. The ancestor is the
    last label of the common prefix of the two root-to-label paths.
    """
    path1 = find_path_to_root(root, label1)
    path2 = find_path_to_root(root, label2)
    if not path1 or not path2:
        return None

    shared = None
    for step1, step2 in zip(path1, path2):
        if step1 != step2:
            break
        shared = step1
    return shared

def calculate_distance(root, label1, label2):
    """Return the tree distance (number of edges) between two labels.

    Each label is resolved to its first occurrence in depth-first pre-order
    (via ``find_path_to_root``). The distance is the number of edges from each
    node up to their lowest common ancestor, summed.

    Returns:
        int, or None if either label is not in the tree.
    """
    path1 = find_path_to_root(root, label1)
    path2 = find_path_to_root(root, label2)
    if not path1 or not path2:
        return None

    # Length of the shared root-side prefix; its last element is the LCA.
    # Counting positions avoids ``list.index(lca)``, which finds the first
    # matching label and would over-count if the LCA's label also appeared
    # higher up the same path.
    shared = 0
    for step1, step2 in zip(path1, path2):
        if step1 != step2:
            break
        shared += 1
    if shared == 0:
        return None

    # Each side contributes len(path) - shared edges below the LCA.
    return len(path1) + len(path2) - 2 * shared

# Example usage: the NIH label hierarchy.
# Labels repeat (e.g. 'Pneumonia', 'Atelectasis', 'Mass'); calculate_distance
# resolves each label to its first occurrence in depth-first pre-order.
root = TreeNode('Medical Conditions', [
    TreeNode('Respiratory Conditions', [
        TreeNode('Atelectasis', [
            TreeNode('Mass'),
            TreeNode('Pneumonia')
        ]),
        TreeNode('Consolidation', [
            TreeNode('Pneumonia')
        ]),
        TreeNode('Edema', [
            TreeNode('Effusion')
        ]),
        TreeNode('Emphysema'),
        TreeNode('Fibrosis'),
        TreeNode('Infiltration'),
        TreeNode('Mass', [
            TreeNode('Atelectasis')
        ]),
        TreeNode('Nodule'),
        TreeNode('Pleural_Thickening'),
        TreeNode('Pneumonia', [
            TreeNode('Consolidation')
        ]),
        TreeNode('Pneumothorax')
    ]),
    TreeNode('Cardiac Conditions', [
        TreeNode('Cardiomegaly', [
            TreeNode('Edema'),
            TreeNode('Effusion')
        ])
    ]),
    TreeNode('Other Conditions', [
        TreeNode('Hernia')
    ])
])

# Quick checks: distance = edges from each label up to their lowest common ancestor.
print(calculate_distance(root, 'Atelectasis', 'Pneumonia'))  # expected 1: Pneumonia is first found under Atelectasis
print(calculate_distance(root, 'Atelectasis', 'Cardiomegaly'))  # expected 4: both at depth 2, LCA is the root
print(calculate_distance(root, 'Hernia', 'Pneumothorax'))  # expected 4: both at depth 2, LCA is the root
print(calculate_distance(root, 'Edema', 'Effusion'))  # expected 1: Effusion is first found under Edema

1
4
4
1


In [57]:
import numpy as np

class TreeNode:
    """A labelled node of a label hierarchy.

    Attributes:
        label: name of the node.
        children: list of child ``TreeNode`` objects; a fresh empty list is
            used when ``children`` is omitted, ``None`` or empty.
    """

    def __init__(self, label, children=None):
        self.label = label
        # Allocate a new list per node so leaves never share one list object.
        self.children = children if children else []

def find_node(root, label):
    """Return the first node labelled ``label`` in depth-first pre-order, or None."""
    pending = [root]
    while pending:
        node = pending.pop()
        if node.label == label:
            return node
        # Reverse so the leftmost child is examined first.
        pending.extend(reversed(node.children))
    return None

def find_path_to_root(node, target_label):
    """Return the labels on the path from ``node`` down to ``target_label``.

    Despite the name, the path runs top-down. It starts with ``node.label``
    and ends with the first node labelled ``target_label`` in depth-first
    pre-order. When labels repeat, that is the occurrence reached first, not
    necessarily the shallowest one.

    Returns:
        list of labels, or None if ``node`` is None or the label is absent.
    """
    if node is None:
        return None
    # Single DFS that carries the path to each node; this avoids re-searching
    # every child subtree at every level.
    pending = [(node, [node.label])]
    while pending:
        current, path = pending.pop()
        if current.label == target_label:
            return path
        # Reverse so the leftmost child is explored first (pre-order).
        for child in reversed(current.children):
            pending.append((child, path + [child.label]))
    return None

def find_lowest_common_ancestor(root, label1, label2):
    """Return the label of the lowest common ancestor of two labels, or None.

    Both labels are resolved with ``find_path_to_root``. The ancestor is the
    last label of the common prefix of the two root-to-label paths.
    """
    path1 = find_path_to_root(root, label1)
    path2 = find_path_to_root(root, label2)
    if not path1 or not path2:
        return None

    shared = None
    for step1, step2 in zip(path1, path2):
        if step1 != step2:
            break
        shared = step1
    return shared

def calculate_distance(root, label1, label2):
    """Return the tree distance (number of edges) between two labels.

    Identical labels give 0 without consulting the tree. Otherwise each label
    is resolved to its first occurrence in depth-first pre-order (via
    ``find_path_to_root``). The distance is the number of edges from each node
    up to their lowest common ancestor, summed.

    Returns:
        int, or None if either label is not in the tree.
    """
    if label1 == label2:
        return 0

    path1 = find_path_to_root(root, label1)
    path2 = find_path_to_root(root, label2)
    if not path1 or not path2:
        return None

    # Length of the shared root-side prefix; its last element is the LCA.
    # Counting positions avoids ``list.index(lca)``, which finds the first
    # matching label and would over-count if the LCA's label also appeared
    # higher up the same path.
    shared = 0
    for step1, step2 in zip(path1, path2):
        if step1 != step2:
            break
        shared += 1
    if shared == 0:
        return None

    # Each side contributes len(path) - shared edges below the LCA.
    return len(path1) + len(path2) - 2 * shared

def create_distance_matrix(root, labels):
    """Return the symmetric integer matrix of pairwise tree distances.

    Entry ``[i, j]`` is ``calculate_distance(root, labels[i], labels[j])``.
    Only the upper triangle is computed and then mirrored, and the diagonal
    stays 0. If ``calculate_distance`` returns None for a missing label, the
    integer array cannot store it and the assignment raises.
    """
    size = len(labels)
    result = np.zeros((size, size), dtype=int)

    for row, first in enumerate(labels):
        for col in range(row + 1, size):
            d = calculate_distance(root, first, labels[col])
            result[row, col] = result[col, row] = d

    return result

# Build the NIH label hierarchy.
# Labels repeat (e.g. 'Pneumonia', 'Atelectasis', 'Mass'); find_path_to_root
# resolves each label to its first occurrence in depth-first pre-order, and
# the distance matrix below uses those occurrences.
root = TreeNode('Medical Conditions', [
    TreeNode('Respiratory Conditions', [
        TreeNode('Atelectasis', [
            TreeNode('Mass'),
            TreeNode('Pneumonia')
        ]),
        TreeNode('Consolidation', [
            TreeNode('Pneumonia')
        ]),
        TreeNode('Edema', [
            TreeNode('Effusion')
        ]),
        TreeNode('Emphysema'),
        TreeNode('Fibrosis'),
        TreeNode('Infiltration'),
        TreeNode('Mass', [
            TreeNode('Atelectasis')
        ]),
        TreeNode('Nodule'),
        TreeNode('Pleural_Thickening'),
        TreeNode('Pneumonia', [
            TreeNode('Consolidation')
        ]),
        TreeNode('Pneumothorax')
    ]),
    TreeNode('Cardiac Conditions', [
        TreeNode('Cardiomegaly', [
            TreeNode('Edema'),
            TreeNode('Effusion')
        ])
    ]),
    TreeNode('Other Conditions', [
        TreeNode('Hernia')
    ])
])

# NIH label list; its order defines the rows/columns of the distance matrix.
labels = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
          'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass',
          'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax']

# Compute the pairwise tree-distance matrix.
distance_matrix = create_distance_matrix(root, labels)

# Print the raw distance matrix.
print("Distance Matrix:")
print(distance_matrix)

# For readability, show the same matrix as a labelled pandas DataFrame.
import pandas as pd

df = pd.DataFrame(distance_matrix, index=labels, columns=labels)
print("\nDistance Matrix (Pandas DataFrame):")
print(df)

Distance Matrix:
[[0 4 2 2 3 2 2 4 2 1 2 2 1 2]
 [4 0 4 4 5 4 4 4 4 5 4 4 5 4]
 [2 4 0 2 3 2 2 4 2 3 2 2 3 2]
 [2 4 2 0 1 2 2 4 2 3 2 2 3 2]
 [3 5 3 1 0 3 3 5 3 4 3 3 4 3]
 [2 4 2 2 3 0 2 4 2 3 2 2 3 2]
 [2 4 2 2 3 2 0 4 2 3 2 2 3 2]
 [4 4 4 4 5 4 4 0 4 5 4 4 5 4]
 [2 4 2 2 3 2 2 4 0 3 2 2 3 2]
 [1 5 3 3 4 3 3 5 3 0 3 3 2 3]
 [2 4 2 2 3 2 2 4 2 3 0 2 3 2]
 [2 4 2 2 3 2 2 4 2 3 2 0 3 2]
 [1 5 3 3 4 3 3 5 3 2 3 3 0 3]
 [2 4 2 2 3 2 2 4 2 3 2 2 3 0]]

Distance Matrix (Pandas DataFrame):
                    Atelectasis  Cardiomegaly  Consolidation  Edema  Effusion  \
Atelectasis                   0             4              2      2         3   
Cardiomegaly                  4             0              4      4         5   
Consolidation                 2             4              0      2         3   
Edema                         2             4              2      0         1   
Effusion                      3             5              3      1         0   
Emphysema                