In [31]:
import torch.nn as nn 
import torch

class PositionEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config 
        
        self.position_embedding = nn.Embedding(config.max_position_embeddeings, config.hidden_size)
    
    def forward(self, input_tensor, seq_lengths, shift_step):
        """
        input_tensor: with shape of [L1, D], L1 = sum(seq_lengths)
        seq_lengths: with shape of [B].
        shift_step: int.
        """
        start_indices = seq_lengths.cumsum(dim=0).roll(shifts=1, dims=0)
        start_indices[0] = 0
        position_ids = torch.repeat_interleave(start_indices, repeats=seq_lengths)
        position_ids = torch.arange(position_ids.size(0), device=position_ids.device) - position_ids
        position_ids = (position_ids + shift_step * torch.ones_like(position_ids, device=position_ids.device)).int()
        
        print(position_ids)
        
        output_tensor = input_tensor + self.position_embedding(position_ids)
        return output_tensor

class Config(object):
    pass

config = Config()
config.max_position_embeddeings = 5
config.hidden_size = 3
input_tensor = torch.randn(9,3)
seq_lenghts = torch.FloatTensor([3,2,3,1]).long()
shift_step = 0

position_embedding = PositionEmbedding(config)

In [33]:
import numpy as np
np.set_printoptions(suppress=True, precision=3)
print("position embedding:\n {}".format(position_embedding.position_embedding.weight.detach().numpy()))

print("input tensor:\n {}".format(input_tensor.numpy()))

position_embedding.to(torch.device("cuda"))
position_embedding(input_tensor.cuda(), seq_lenghts.cuda(), shift_step=1)

position embedding:
 [[-0.363 -0.274 -0.172]
 [ 1.708 -1.276 -1.036]
 [-0.919  0.369  0.947]
 [ 0.801  0.909 -0.798]
 [-0.289 -0.736 -1.239]]
input tensor:
 [[ 0.548 -1.63  -0.931]
 [ 0.214 -0.733  0.908]
 [ 0.845  0.189  0.418]
 [-0.476  1.679  0.034]
 [-0.05  -0.292  0.846]
 [-1.908 -0.647 -0.081]
 [-0.71   0.757 -0.616]
 [ 0.812  1.437 -1.466]
 [ 0.533  1.36   1.713]]
tensor([1, 2, 3, 1, 2, 1, 2, 3, 1], device='cuda:0', dtype=torch.int32)


tensor([[ 2.2554, -2.9069, -1.9668],
        [-0.7050, -0.3639,  1.8553],
        [ 1.6466,  1.0979, -0.3799],
        [ 1.2315,  0.4024, -1.0020],
        [-0.9688,  0.0770,  1.7927],
        [-0.2000, -1.9235, -1.1167],
        [-1.6287,  1.1258,  0.3314],
        [ 1.6130,  2.3463, -2.2642],
        [ 2.2410,  0.0836,  0.6770]], device='cuda:0', grad_fn=<AddBackward0>)