In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [29]:
class Mulitiheadattention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(Mulitiheadattention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)
    
    def scaled_dot_product_attention(self, q, k, v, mask=None):
        d_k = q.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, v), attn

    def split_heads(self, x):
        # x: (batch_size, seq_length, num_heads, head_dim)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, head_dim = x.size()
        x = x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        return x

    def forward(self, x, mask=None, kv=None):
        # x: query, kv: key/value（可选）
        batch_size, seq_length, d_model = x.size()
        if kv is None:
            # 自注意力
            qkv = self.qkv(x).view(batch_size, seq_length, 3, self.num_heads, self.head_dim)
            q, k, v = qkv.unbind(dim=2)
        else:
            # 交叉注意力
            q = self.qkv(x).view(batch_size, seq_length, 3, self.num_heads, self.head_dim)[:, :, 0]
            k = self.qkv(kv).view(batch_size, kv.size(1), 3, self.num_heads, self.head_dim)[:, :, 1]
            v = self.qkv(kv).view(batch_size, kv.size(1), 3, self.num_heads, self.head_dim)[:, :, 2]
            # 恢复维度 (batch_size, seq_length, num_heads, head_dim)
            q = q.view(batch_size, seq_length, self.num_heads, self.head_dim)
            k = k.view(batch_size, kv.size(1), self.num_heads, self.head_dim)
            v = v.view(batch_size, kv.size(1), self.num_heads, self.head_dim)
        # 转换为 (batch_size, num_heads, seq_length, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn_output, attn_weights = self.scaled_dot_product_attention(q, k, v, mask)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        return self.out(attn_output)

In [30]:
class positionwisefeedforward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(positionwisefeedforward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
    
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
class positionencoder(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(positionencoder, self).__init__()
        self.d_model = d_model
        self.pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = self.pe.unsqueeze(0)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

In [31]:
class encoderlayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(encoderlayer, self).__init__()
        self.self_attn = Mulitiheadattention(d_model, num_heads)
        self.ffn = positionwisefeedforward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        attn_output = self.self_attn(x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ffn_output = self.ffn(x)
        return self.norm2(x + self.dropout2(ffn_output))

In [32]:
class decoderlayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(decoderlayer, self).__init__()
        self.self_attn = Mulitiheadattention(d_model, num_heads)
        self.cross_attn = Mulitiheadattention(d_model, num_heads)
        self.ffn = positionwisefeedforward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
    
    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        attn_output = self.self_attn(x, tgt_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        cross_attn_output = self.cross_attn(x, src_mask, kv=memory)
        x = self.norm2(x + self.dropout2(cross_attn_output))
        ffn_output = self.ffn(x)
        return self.norm3(x + self.dropout3(ffn_output))

In [None]:
class NodeTransformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super(NodeTransformer, self).__init__()
        self.d_model = d_model
        self.position_encoder = positionencoder(d_model)
        self.encoder_layers = nn.ModuleList(
            [encoderlayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [decoderlayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.final_layer = nn.Linear(d_model, d_model)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        src = self.position_encoder(src)
        tgt = self.position_encoder(tgt)
        
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, src_mask, tgt_mask)
        
        return self.final_layer(tgt)
    
transformer = NodeTransformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer.parameters(), lr=0.001)
src_data = torch.randint(0, 100, (32, 10, 512))  # Example source data
tgt_data = torch.randint(0, 100, (32, 10, 512))  # Example target data

for epoch in range(1000):
    transformer.train()
    optimizer.zero_grad()
    
    output = transformer(src_data, tgt_data)
    target = torch.randint(0, 100, (32, 10))  # Example target labels
    
    loss_value = loss(output.view(-1, 512), target.view(-1))
    loss_value.backward()
    
    optimizer.step()
    if (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch + 1}, Loss: {loss_value.item()}')

Epoch 100, Loss: 4.707938194274902
Epoch 200, Loss: 4.619059085845947
Epoch 200, Loss: 4.619059085845947
Epoch 300, Loss: 4.67251443862915
Epoch 300, Loss: 4.67251443862915
Epoch 400, Loss: 4.639707088470459
Epoch 400, Loss: 4.639707088470459
Epoch 500, Loss: 4.632154941558838
Epoch 500, Loss: 4.632154941558838
Epoch 600, Loss: 4.6318254470825195
Epoch 600, Loss: 4.6318254470825195
Epoch 700, Loss: 4.63408899307251
Epoch 700, Loss: 4.63408899307251
Epoch 800, Loss: 4.641055107116699
Epoch 800, Loss: 4.641055107116699
