In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.activation import MultiheadAttention

In [None]:
class TextClassificationModel_MultiScaleAttention(nn.Module):
    """Two-branch text classifier with a shared BiGRU and per-branch self-attention.

    A word-index branch and a character-feature branch are each embedded,
    passed through the same bidirectional GRU, refined by their own
    multi-head self-attention layer and mean-pooled over the sequence axis.
    The two pooled vectors are concatenated and fed to a shared hidden layer
    followed by two linear heads (emotion logits and feature logits).

    Args:
        vocab_size: Number of rows in the word embedding table.
        feature_size: Number of rows in the character-feature embedding table.
        embed_dim: Embedding width used by both embedding tables.
        gru_hidden_dim: Hidden size of the GRU per direction.
        attn_heads: Number of attention heads; ``2 * gru_hidden_dim`` must be
            divisible by it (enforced by ``MultiheadAttention``).
        num_classes: Output size of the emotion head.
        feature_classes: Output size of the feature head.
        embeddings: Optional pretrained word-embedding tensor. When given it
            replaces the word embedding weights and is frozen.
    """

    def __init__(self, vocab_size, feature_size, embed_dim, gru_hidden_dim, attn_heads, num_classes, feature_classes, embeddings=None):
        super().__init__()
        self.embedding_dim = embed_dim
        self.name = 'TextClassificationModel_MultiScaleAttention'
        self.gru_hidden_dim = gru_hidden_dim
        self.attn_heads = attn_heads
        self.num_classes = num_classes
        self.feature_classes = feature_classes

        # Word-level embedding; pretrained vectors (if supplied) are frozen.
        self.embedding_word = nn.Embedding(vocab_size, embed_dim)
        if embeddings is not None:
            self.embedding_word.weight = nn.Parameter(embeddings, requires_grad=False)

        # Character-level embedding
        self.embedding_char = nn.Embedding(feature_size, embed_dim)

        # Bidirectional GRU shared by both branches; outputs are
        # (batch, seq, 2 * gru_hidden_dim) because batch_first=True.
        self.gru = nn.GRU(embed_dim, gru_hidden_dim, batch_first=True, bidirectional=True)

        # Attention mechanisms for different scales.
        # batch_first=True must match the GRU layout: with the default
        # (seq, batch, embed) layout the attention would run across the
        # samples of a batch instead of across the tokens of each sample.
        attn_dim = gru_hidden_dim * 2
        self.attention_word = MultiheadAttention(attn_dim, attn_heads, batch_first=True)
        self.attention_char = MultiheadAttention(attn_dim, attn_heads, batch_first=True)

        # Fully connected layers: the input is the concatenation of the two
        # pooled branch vectors, each of width 2 * gru_hidden_dim.
        self.fc1 = nn.Linear(attn_dim * 2, gru_hidden_dim)
        self.fc2_emotion = nn.Linear(gru_hidden_dim, num_classes)
        self.fc2_feature = nn.Linear(gru_hidden_dim, feature_classes)

        self.dropout = nn.Dropout(0.5)

    def forward(self, text, char_features):
        """Compute emotion and feature logits for a batch.

        Args:
            text: Integer word indices; assumed to be shaped
                ``(batch, seq_len)`` -- TODO confirm against the data loader.
            char_features: Integer character-feature indices; assumed to be
                shaped ``(batch,)`` since a length-1 sequence axis is inserted
                after embedding -- TODO confirm against the data loader.

        Returns:
            Tuple ``(emotion_logits, feature_logits)`` with last dimensions
            ``num_classes`` and ``feature_classes`` respectively.
        """
        # Word-level embedding and processing
        embedded_word = self.dropout(self.embedding_word(text))
        gru_word_output, _ = self.gru(embedded_word)

        # Character-level embedding and processing; unsqueeze(1) adds the
        # sequence axis expected by the batch-first GRU.
        embedded_char = self.embedding_char(char_features).unsqueeze(1)
        embedded_char = self.dropout(embedded_char)
        gru_char_output, _ = self.gru(embedded_char)

        # Self-attention within each sample, then mean-pool over the
        # sequence axis (dim=1) to get one vector per sample.
        word_attention_output, _ = self.attention_word(gru_word_output, gru_word_output, gru_word_output)
        word_attention_output = word_attention_output.mean(dim=1)

        char_attention_output, _ = self.attention_char(gru_char_output, gru_char_output, gru_char_output)
        char_attention_output = char_attention_output.mean(dim=1)

        # Concatenate the pooled outputs of both scales along the feature axis.
        concatenated_output = torch.cat((word_attention_output, char_attention_output), dim=-1)

        # Fully connected layers
        fc1_output = F.relu(self.fc1(concatenated_output))
        fc1_output = self.dropout(fc1_output)

        # Predict emotion and feature classes
        emotion_logits = self.fc2_emotion(fc1_output)
        feature_logits = self.fc2_feature(fc1_output)

        return emotion_logits, feature_logits