In [1]:
import torch
import torch.nn as nn
import unittest

In [2]:
class LastElementExtractor(nn.Module): 
    def __init__(self): 
        super(LastElementExtractor, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.cpu = torch.device('cpu')
    
    def forward(self, packed, lengths): 
        lengths = torch.tensor(lengths, device=self.device)
        sum_batch_sizes = torch.cat((
            torch.zeros(2, dtype=torch.int64, device=self.device),
            torch.cumsum(packed.batch_sizes, 0).to(self.device)
        ))
        sorted_lengths = lengths[packed.sorted_indices]
        last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0), device=self.device)
        last_seq_items = packed.data[last_seq_idxs]
        last_seq_items = last_seq_items[packed.unsorted_indices]
        return last_seq_items


In [3]:
extractor = LastElementExtractor()

In [4]:
tensor1 = torch.rand((15, 3)).cuda()
tensor2 = torch.rand((5, 3)).cuda()
tensor3 = torch.rand((7, 3)).cuda()

In [5]:
x_lens = [len(x) for x in [tensor1, tensor2, tensor3]]

In [7]:
padded = torch.nn.utils.rnn.pad_sequence([tensor1, tensor2, tensor3], batch_first=True, padding_value=0).cuda()
lengths = x_lens

In [9]:
# Example input data (replace this with your actual data)
# Assuming you have a batch with sequences of varying lengths
packed_data = torch.nn.utils.rnn.pack_padded_sequence(
    input=padded, 
    lengths=x_lens,
    batch_first=True, 
    enforce_sorted=False
)

In [10]:
packed_data.to(extractor.device)

PackedSequence(data=tensor([[0.3184, 0.2381, 0.4290],
        [0.1733, 0.7734, 0.2084],
        [0.8003, 0.6720, 0.9785],
        [0.1066, 0.5522, 0.3547],
        [0.8433, 0.2275, 0.9411],
        [0.4051, 0.9581, 0.3979],
        [0.7188, 0.9364, 0.6651],
        [0.4706, 0.4933, 0.1621],
        [0.2045, 0.0809, 0.8412],
        [0.4068, 0.3141, 0.4388],
        [0.9332, 0.2752, 0.3174],
        [0.1401, 0.6576, 0.2207],
        [0.9246, 0.0241, 0.4466],
        [0.2646, 0.1211, 0.1426],
        [0.9581, 0.1979, 0.9731],
        [0.6771, 0.6680, 0.4588],
        [0.8474, 0.4225, 0.5481],
        [0.0614, 0.9151, 0.1420],
        [0.6538, 0.7743, 0.3885],
        [0.9349, 0.2641, 0.5217],
        [0.4619, 0.1913, 0.5321],
        [0.4055, 0.4580, 0.7845],
        [0.5721, 0.8156, 0.7532],
        [0.1942, 0.5352, 0.5089],
        [0.2887, 0.1266, 0.5590],
        [0.1021, 0.9943, 0.8693],
        [0.5918, 0.6741, 0.4941]], device='cuda:0'), batch_sizes=tensor([3, 3, 3, 3, 3, 2, 2, 1,

In [11]:
last_elements = extractor(packed_data, lengths)

In [16]:
last_elements.shape

torch.Size([3, 3])

# Conclusion: this last element extractor is correct. 