In [1]:
import torch

#计算softmax举例
x = torch.tensor([[100, 200]], dtype=torch.float32)
x_softmax = torch.softmax(x, dim=1)
print(x_softmax)

tensor([[3.7835e-44, 1.0000e+00]])


In [2]:
x.shape

torch.Size([1, 2])

In [3]:
x = torch.tensor([[1,2 ]], dtype=torch.float32)
x_softmax = torch.softmax(x, dim=1)
print(x_softmax)

tensor([[0.2689, 0.7311]])


In [7]:
import torch


def custom_embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
                     sparse=False):
    # max_norm限制了最大范数，norm_type指定范数计算方式，scale_grad_by_freq是否按词频缩放梯度
    """
    A simplified version of torch.nn.functional.embedding.
    """
    # 输入检查
    if not isinstance(input, torch.LongTensor):
        raise TypeError("Input must be a LongTensor")

    if padding_idx is not None:
        if padding_idx >= weight.size(0) or padding_idx < 0:
            raise ValueError("padding_idx must be within the range of weight size")

    # 获取嵌入矩阵的形状
    num_embeddings, embedding_dim = weight.size()

    # 初始化输出张量
    output = torch.zeros(input.size() + (embedding_dim,), dtype=weight.dtype, device=weight.device)   

    # 遍历输入张量的每个索引
    for idx in input.view(-1).tolist(): # 展开输入张量，并转换为列表
        if idx != padding_idx:
            output.view(-1, embedding_dim)[idx] = weight[idx]   # 填充输出张量

    # 如果指定了 max_norm，对嵌入向量进行归一化
    if max_norm is not None:
        with torch.no_grad():
            norm = weight.norm(p=norm_type, dim=1, keepdim=True)     # 计算weight的范数
            torch.clamp(norm, max=max_norm, out=norm)     # 限制范数的最大值
            weight.div_(norm)   # 对weight进行归一化

    return output   # output的形状为 (batch_size * seq_len * embedding_dim)


# 示例使用
input_indices = torch.tensor([[1, 2, 4], [4, 3, 2]], dtype=torch.long)  # shape: (batch_size, seq_len)
embedding_matrix = torch.randn(5, 10)  # 假设有 5 个词，每个词的嵌入维度为 10  shape: (vocab_size, embedding_dim)

output = custom_embedding(input_indices, embedding_matrix)
print(output)   # shape: (batch_size, seq_len, embedding_dim)
output.shape

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [-0.1098, -0.2201, -1.2564, -0.9372, -1.8769,  1.7540, -0.6517,
           0.3467,  0.8380,  0.4739],
         [ 0.9540,  0.2696, -1.5430, -0.4435,  0.5579, -0.1922,  0.2865,
           1.8943,  0.0369,  0.8779]],

        [[-0.8891,  1.1431,  1.5063, -0.3315,  0.2334,  1.1230,  0.4451,
          -0.6990,  0.0068,  0.8159],
         [ 0.1075,  0.8889, -1.5297,  1.5845, -0.3677,  1.6821, -1.0814,
           0.3575, -0.3537, -0.2370],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]]])


torch.Size([2, 3, 10])