In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse

class GCNLayer(nn.Module):
    """One graph-convolution layer: propagate node features over ``adj``, then Linear + ReLU.

    Args:
        in_features: size of the last dimension of the incoming node features.
        out_features: size of the last dimension of the produced node features.
        normalize: if True (default), replace ``adj`` by D^{-1/2} A D^{-1/2}, where D is
            the diagonal matrix of row sums of ``adj``, before propagation. Without this,
            each node's output is a plain *sum* over its neighbours, so magnitudes scale
            with node degree and compound across stacked layers. Pass False to get the
            raw ``adj @ x`` behaviour.
    """

    def __init__(self, in_features, out_features, normalize=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.normalize = normalize

    @staticmethod
    def _normalize_adj(adj):
        """Return D^{-1/2} A D^{-1/2} using row sums as degrees; zero-degree rows map to 0."""
        degree = adj.sum(dim=-1)
        inv_sqrt = degree.pow(-0.5)
        # Isolated nodes would give inf here (and later inf * 0 = NaN); zero them instead.
        inv_sqrt = inv_sqrt.masked_fill(degree == 0, 0.0)
        # Broadcast the row and column scalings over the last two dims (works for 2-D or batched adj).
        return inv_sqrt.unsqueeze(-1) * adj * inv_sqrt.unsqueeze(-2)

    def forward(self, x, adj):
        """Apply one propagation step followed by a linear projection and ReLU.

        Args:
            x: node features, expected [batch_size, seq_len, in_features].
            adj: adjacency, expected [batch_size, seq_len, seq_len], same dtype as ``x``.

        Returns:
            Tensor of shape [batch_size, seq_len, out_features], non-negative (ReLU).
        """
        if self.normalize:
            adj = self._normalize_adj(adj)
        x = torch.matmul(adj, x)  # graph propagation
        x = self.linear(x)
        return F.relu(x)

class TextGCN(nn.Module):
    """Text classifier fusing mean-pooled token embeddings with mean-pooled GCN features.

    Each sequence becomes a graph whose nodes are token *positions*. Positions i and j
    are connected when |i - j| <= ``window_size`` (for window_size >= 0 this includes the
    diagonal, i.e. every node has a self-loop). The graph is passed as a binary adjacency
    to two ``GCNLayer``s.

    Args:
        vocab_size: rows in the embedding table; token ids must lie in [0, vocab_size).
        embed_dim: embedding width, also the input width of the first GCN layer.
        gcn_dim: output width of both GCN layers.
        num_classes: number of output logits.
        window_size: maximum positional distance between two connected tokens.
    """

    def __init__(self, vocab_size=10000, embed_dim=128, gcn_dim=64, num_classes=14, window_size=3):
        super().__init__()
        self.window_size = window_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # GCN stack
        self.gcn1 = GCNLayer(embed_dim, gcn_dim)
        self.gcn2 = GCNLayer(gcn_dim, gcn_dim)

        # Classifier over the concatenation [text_pooled ; gcn_pooled]
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim + gcn_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def build_cooccurrence_matrix(self, batch, dtype=None):
        """Build the batched sliding-window adjacency for ``batch``.

        Entry [b, i, j] is 1 when |i - j| <= ``self.window_size`` and 0 otherwise.
        NOTE(review): despite the name, token ids are never inspected. The matrix depends
        only on the sequence length and is identical for every row in the batch.

        Args:
            batch: token ids, expected [batch_size, seq_len]. Only its shape and device are used.
            dtype: dtype of the returned matrix; defaults to ``torch.get_default_dtype()``.

        Returns:
            Tensor of shape [batch_size, seq_len, seq_len] on ``batch.device``. An empty
            batch yields shape [0, seq_len, seq_len] instead of raising.
        """
        batch_size, seq_len = batch.shape[0], batch.shape[1]
        positions = torch.arange(seq_len, device=batch.device)
        # |i - j| <= w is exactly the window [i - w, i + w] clipped to the sequence.
        band = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() <= self.window_size
        adj = band.to(dtype if dtype is not None else torch.get_default_dtype())
        # repeat (not expand) so each sample owns its matrix, as with the former stack.
        return adj.unsqueeze(0).repeat(batch_size, 1, 1)

    def forward(self, x):
        """Return class logits of shape [batch_size, num_classes] for token ids ``x``.

        Args:
            x: token ids, [batch_size, seq_len].
        """
        embeddings = self.embedding(x)  # [batch, seq_len, embed_dim]

        # Build the adjacency in the embedding dtype so matmul also works after .double()/.half().
        adj = self.build_cooccurrence_matrix(x, dtype=embeddings.dtype)  # [batch, seq_len, seq_len]

        # GCN processing
        gcn_out = self.gcn1(embeddings, adj)
        gcn_out = self.gcn2(gcn_out, adj)  # [batch, seq_len, gcn_dim]

        # Pooling. NOTE(review): the mean covers every position, including any padding tokens.
        gcn_pooled = gcn_out.mean(dim=1)  # [batch, gcn_dim]
        text_pooled = embeddings.mean(dim=1)  # [batch, embed_dim]

        # Feature fusion
        combined = torch.cat([text_pooled, gcn_pooled], dim=1)
        return self.classifier(combined)

# Usage example
if __name__ == "__main__":
    # Seed the RNG so the random token ids and the model initialisation are reproducible
    # on Restart & Run All.
    torch.manual_seed(0)

    # Example data: batch_size=2, seq_len=10, vocab_size=10000
    BATCH_SIZE, SEQ_LEN, VOCAB_SIZE, NUM_CLASSES = 2, 10, 10000, 14
    inputs = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))

    model = TextGCN(vocab_size=VOCAB_SIZE, num_classes=NUM_CLASSES)
    outputs = model(inputs)
    # Check the expected logits shape instead of relying on a comment.
    assert outputs.shape == (BATCH_SIZE, NUM_CLASSES)
    print(outputs.shape)  # torch.Size([2, 14])

torch.Size([2, 14])


In [2]:
# Display the token ids generated by the usage example in the previous cell.
# NOTE(review): `inputs` is created under `if __name__ == "__main__":` there, so this
# cell fails on a fresh kernel unless that cell has been run first.
inputs

tensor([[8473,  661, 4243, 2415, 7734, 8163, 2561, 6944, 1088, 5565],
        [3613, 6045,   28, 8119, 6925,  560, 6913, 4765, 9400, 2793]])