In [8]:
import numpy as np
import random
import math
import dgl
from dgl import function as fn
from dgl._ffi.base import DGLError
from dgl.nn.pytorch import edge_softmax
from dgl.nn.pytorch.utils import Identity
from dgl.utils import expand_as_pair
from dgl.nn.pytorch import GraphConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

#获取节点序列

class HOConv(nn.Module):
    def __init__(self, dim_in, dim_out, norm='both', num_type=6, weight=True, bias=True):
        super(HOConv, self).__init__()
        if norm not in ('none', 'both', 'right', 'left'):
            raise DGLError('Invalid norm value. Must be either "none", "both", "right" or "left".'
                           ' But got "{}".'.format(norm))
        
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.norm = norm
        
        #权重、偏置初始化
        if weight:
            self.weight = nn.Parameter(torch.Tensor(dim_in, dim_out))
        else:
            self.register_parameter('weight', None)
        if bias:
            self.bias = nn.Parameter(torch.Tensor(dim_out))
        else:
            self.register_parameter('bias', None)
        self.weight_type = nn.Parameter(torch.ones(num_type))
        self.reset_parameters()

        self.activation = nn.ReLU()

    #重置参数（xavier初始化权重）
    def reset_parameters(self):
        if self.weight is not None:
            init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, G: dgl.DGLGraph):
        feat = G.srcdata['feat']        
        if self.norm in ['left', 'both']:
            degs = G.out_degrees().float().clamp(min=1)     #计算节点出度
            if self.norm == 'both':
                norm = torch.pow(degs, -0.5)
            else:
                norm = 1.0 / degs
            shp = norm.shape + (1,) * (feat.dim() - 1)
            norm = torch.reshape(norm, shp)
            feat = feat * norm

        G.srcdata['feat'] = feat 
        G.update_all(fn.copy_src('feat', 'm'), fn.sum(msg='m', out='feat'))
        rst = G.dstdata['feat']
        rst = torch.matmul(rst, self.weight)

        if self.norm in ['right', 'both']:
            degs = G.in_degrees().float().clamp(min=1)      #计算节点入度
            if self.norm == 'both':
                norm = torch.pow(degs, -0.5)
            else:
                norm = 1.0 / degs
            shp = norm.shape + (1,) * (feat.dim() - 1)
            norm = torch.reshape(norm, shp)
            rst = rst * norm

        if self.bias is not None:
            rst = rst + self.bias
        if self.activation is not None:
            rst = self.activation(rst)
        
        G.dstdata['feat'] = rst
        return G

#transformer部分！！！(先不动)
class AGTLayer(nn.Module):
    def __init__(self, dim_in, nheads=2, att_dropout=0.1, emb_dropout=0.1, temper=1.0):
        super(AGTLayer, self).__init__()

        self.nheads = nheads
        self.dim_in = dim_in    #输入维度：dim_in
        self.head_dim = self.dim_in // self.nheads

        #激活函数
        self.leaky = nn.LeakyReLU(0.01)

        self.temper = temper
        self.linear_q = nn.Linear(
            self.dim_in, self.head_dim * self.nheads, bias=False)
        self.linear_k = nn.Linear(
            self.dim_in, self.head_dim * self.nheads, bias=False)
        self.linear_v = nn.Linear(
            self.dim_in, self.head_dim * self.nheads, bias=False)

        self.linear_final = nn.Linear(self.head_dim * self.nheads, self.dim_in, bias=False)
        
        self.dropout1 = nn.Dropout(att_dropout) #注意力dropout
        self.dropout2 = nn.Dropout(emb_dropout) #嵌入dropout
        self.LN = nn.LayerNorm(dim_in)

    def forward(self, h):
        ''' transpose：交换两个维度，适用于二维或更高维的张量。
            permute：用于重排多个维度顺序，更灵活适应复杂的维度变换。
        '''
        batch_size = h.size()[0]
        #张量后两位是头数和头维度
        k = self.linear_k(h).reshape(batch_size, -1, self.nheads, self.head_dim).transpose(1,2)
        q = self.linear_q(h).reshape(batch_size, -1, self.nheads, self.head_dim).permute(0, 2, 3, 1)
        v = self.linear_v(h).reshape(batch_size, -1, self.nheads, self.head_dim).transpose(1, 2)

        score = k @ q
        score = score / self.head_dim

        score = score / self.temper
        score = F.softmax(score, dim=-1)    #行和为1
        score = self.dropout1(score)
        print(v.shape)
        context = score @ v

        h_sa = context.transpose(1,2).reshape(batch_size, -1, self.head_dim * self.nheads)
        fh = self.linear_final(h_sa)
        fh = self.dropout2(fh)

        h = self.LN(h + fh)

        return h, score

class HoGT(nn.Module):
    def __init__(self, dim_in=768, dim_hidden=512, num_layers=1, num_gnns=3, nheads=2, dropout=0.1,  temper=1.0, num_type=6):
        super(HoGT, self).__init__()
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers
        self.num_gnns = num_gnns      
        self.nheads = nheads
        self.dropout = dropout
        self.num_type = num_type

        self.fc = nn.Linear(dim_in, dim_hidden)

        self.GCNLayers = torch.nn.ModuleList()
        self.GTLayers = torch.nn.ModuleList()
        for layer in range(self.num_gnns):
            self.GCNLayers.append(HOConv(self.dim_hidden, self.dim_hidden))
        for layer in range(self.num_layers):
            self.GTLayers.append(AGTLayer(self.dim_hidden, self.nheads, self.dropout, self.dropout, temper=temper))
        self.Drop = nn.Dropout(dropout)     
        
        self.fc1 = nn.Sequential(nn.Linear(dim_hidden, dim_hidden), nn.LeakyReLU())
    
    def forward(self, G:dgl.DGLGraph):
        device = G.device
        if G.ndata['feat'].shape[-1] != 512:
            G.ndata['feat'] = self.fc(G.ndata['feat'])
       
        for layer in range(self.num_gnns):
            G = self.GCNLayers[layer](G)

        h = G.ndata['feat'].unsqueeze(0)
        h = self.Drop(h)

        for layer in range(self.num_layers):
            h , global_attn= self.GTLayers[layer](h)
            
        # print("wrong", h.shape)
        G.ndata['feat'] = h.squeeze(0)
        global_attn = global_attn.squeeze(0)
        global_attn = global_attn.mean(dim = 0)

        return G, global_attn
    

In [9]:
G_homo = torch.load('/data115_3/TG/Graphdata/graph/STAD/homogeneous/TCGA-3M-AB46-01Z-00-DX1.70F638A0-BDCB-4BDE-BBFE-6D78A1A08C5B.pt')['g']
model = HoGT()
G, global_attn = model(G_homo)
print(global_attn.shape)

torch.Size([1, 2, 741, 256])
torch.Size([741, 741])


In [1]:
import pickle

# 打开 .pkl 文件（读取模式为二进制读取 'rb'）
with open('/home/ouyangqi/MICCAI_2025/result_baseline/5foldcv/MSGT_nll_surv_a0.0_5foldcv_gc32/tcga_kirc_MSGT_nll_surv_a0.0_5foldcv_gc32_s1/split_latest_val_0_results.pkl', 'rb') as file:
    data = pickle.load(file)

# 查看读取的数据
print(data)


{'TCGA-A3-3319': {'slide_id': array('TCGA-A3-3319', dtype='<U12'), 'risk': -3.8478031158447266, 'disc_label': 3.0, 'survival': 81.14, 'censorship': 1}, 'TCGA-A3-3323': {'slide_id': array('TCGA-A3-3323', dtype='<U12'), 'risk': -1.1038873195648193, 'disc_label': 3.0, 'survival': 54.43, 'censorship': 0}, 'TCGA-A3-3325': {'slide_id': array('TCGA-A3-3325', dtype='<U12'), 'risk': -3.941983699798584, 'disc_label': 3.0, 'survival': 61.93, 'censorship': 1}, 'TCGA-A3-3329': {'slide_id': array('TCGA-A3-3329', dtype='<U12'), 'risk': -2.0405383110046387, 'disc_label': 2.0, 'survival': 31.27, 'censorship': 0}, 'TCGA-A3-3349': {'slide_id': array('TCGA-A3-3349', dtype='<U12'), 'risk': -1.8341548442840576, 'disc_label': 0.0, 'survival': 5.32, 'censorship': 0}, 'TCGA-A3-3362': {'slide_id': array('TCGA-A3-3362', dtype='<U12'), 'risk': -3.5370707511901855, 'disc_label': 2.0, 'survival': 46.12, 'censorship': 0}, 'TCGA-A3-3372': {'slide_id': array('TCGA-A3-3372', dtype='<U12'), 'risk': -3.9562864303588867, 

In [3]:
import torch
import torch.nn as nn

# 定义一个 2x2 的平均池化层
class AvgPoolingModel(nn.Module):
    def __init__(self):
        super(AvgPoolingModel, self).__init__()
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)  # 2x2 池化

    def forward(self, x):
        return self.avg_pool(x)

N = 8  # N 必须是 2 的倍数
input_tensor = torch.randn(1, 1, N, N)  # 增加 batch 和 channel 维度

# 实例化模型
model = AvgPoolingModel()

# 前向传播
output_tensor = model(input_tensor)

print("输入张量的形状:", input_tensor.shape)  # torch.Size([1, 1, 8, 8])
print("输出张量的形状:", output_tensor.shape)  # torch.Size([1, 1, 4, 4])

# 去掉 batch 和 channel 维度，只保留 (N/2, N/2)
output_tensor = output_tensor.squeeze(0).squeeze(0)
print("最终输出形状:", output_tensor.shape)  # torch.Size([4, 4])


输入张量的形状: torch.Size([1, 1, 8, 8])
输出张量的形状: torch.Size([1, 1, 4, 4])
最终输出形状: torch.Size([4, 4])
