In [15]:
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
import math

In [25]:
class TokenEmbedding(nn.Module):
    """Learned lookup table that turns integer token ids into dense vectors.

    Args:
        vocab_size: number of distinct token ids (rows of the table).
        d_model: width of each embedding vector (columns of the table).
        device: device on which the embedding weight is allocated.
    """

    def __init__(self, vocab_size, d_model, device):
        super().__init__()
        # Weight table of shape (vocab_size, d_model), created directly on `device`.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            device=device,
        )

    def forward(self, x):
        """Look up the embedding row for every id in ``x``.

        The result has the shape of ``x`` with a trailing ``d_model`` axis appended.
        """
        return self.embedding(x)
        


In [41]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    For position ``pos`` and channel-pair index ``i``::

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: encoding width. May be odd; the last (sine) column then has
            no cosine partner.
        max_len: number of positions precomputed. Longer inputs raise ValueError.
        device: device on which the table is created.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEmbedding, self).__init__()
        pos = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)  # (max_len, 1)
        _2i = torch.arange(0, d_model, 2, dtype=torch.float, device=device)  # (ceil(d_model/2),)
        # Angle table computed once and shared by the sine and cosine columns.
        angles = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model/2))

        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 angles get a cosine column.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so it follows module.to(device)/.cuda() and is
        # never trained. persistent=False keeps it out of state_dict, because it
        # is fully recomputed on construction. torch.zeros already yields a tensor
        # with requires_grad=False, so no extra flag is needed.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the positional encodings for the sequence length of ``x``.

        Args:
            x: 2-D token-id tensor of shape (batch_size, seq_len). Only its size
                is used, not its values.

        Returns:
            Tensor of shape (seq_len, d_model). It broadcasts over the batch axis
            when added to (batch_size, seq_len, d_model) token embeddings.

        Raises:
            ValueError: if seq_len exceeds the max_len given at construction.
        """
        _, seq_len = x.size()
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional table"
            )
        return self.encoding[:seq_len, :]

In [30]:
class TransformerEmbedding(nn.Module):
    """Input embedding for a Transformer: token embedding + positional encoding, then dropout.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding width shared by both components.
        max_len: longest sequence covered by the positional table.
        drop_prob: dropout probability applied to the summed embeddings.
        device: device on which the sub-modules allocate their tensors.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model, device)
        self.positional_embedding = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed token ids ``x`` and add positional information.

        The positional term has no batch axis, so it is broadcast across the
        batch when added to the token embeddings.
        """
        summed = self.token_embedding(x) + self.positional_embedding(x)
        return self.dropout(summed)

In [42]:
vocab_size = 10
d_model = 8
max_len = 16
drop_prob = 0.5
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed before building the model so weights, token ids and the dropout mask are reproducible.
torch.manual_seed(0)

embedding = TransformerEmbedding(vocab_size, d_model, max_len, drop_prob, device)
x = torch.randint(0, vocab_size, (2, 12), device=device)  # batch_size=2, seq_len=12
print(x)
out = embedding(x)
print("输出shape:", out.shape)  # expected: (2, 12, 8)
assert out.shape == (2, 12, d_model), f"unexpected output shape {tuple(out.shape)}"

tensor([[0, 5, 0, 7, 5, 6, 3, 7, 4, 3, 0, 6],
        [9, 5, 4, 0, 6, 1, 9, 9, 6, 2, 1, 0]], device='cuda:0')
输出shape: torch.Size([2, 12, 8])
