In [7]:
import torch 
from copy import deepcopy

class Vector:
    def __init__(self, 
        data_shape: list[int], 
        dtype: torch.dtype, 
        device: torch.device | str | int, 
        dim: int, 
        cap: int | None = None 
    ):
        self.data_shape = data_shape # mutable
        self.dtype = dtype
        self.device = device 

        # push and pop dim
        if 0 <= dim <= len(self.data_shape) - 1:
            self.dim = dim
        elif -len(self.data_shape) <= dim <= -1:
            self.dim = len(self.data_shape) + dim
        else:
            raise ValueError('dim error')

        # capacity of storage
        if cap is None:
            self.cap = data_shape[dim] * 3 // 2
        else:
            self.cap = cap  
        
        # init storage
        self.storage_shape = [s if d != self.dim else self.cap for d, s in enumerate(self.data_shape)] # mutable
        self.storage = torch.zeros(self.storage_shape, dtype=self.dtype, device=self.device)
        
    @property 
    def rear(self):
        return self.data_shape[self.dim]  
    
    def empty(self):
        return self.rear == 0
    
    @classmethod
    def from_tensor(cls, tensor: torch.Tensor, dim: int):
        vec = cls(list(tensor.shape), tensor.dtype, tensor.device, dim)
        indices = [slice(0, s) for s in tensor.shape]
        vec.storage[*indices].copy_(tensor[*indices]) 
        return vec

    @property
    def shape(self):
        return tuple(self.data_shape)
    
    def size(self):
        return tuple(self.data_shape)

    def check_copyable(self, x: torch.Tensor):
        assert len(x.shape) == len(self.storage_shape), "dimension number mismatch"

        for d, (x, s) in enumerate(zip(x.shape, self.storage_shape)):
            if d != self.dim and x != s:
                return False 
        return True

    def increase_storage(self, push_len):
        # change storage_shape, reallocate & copy storage
        self.cap = (self.rear + push_len) * 3 // 2
        self.storage_shape[self.dim] = self.cap
        tmp = torch.zeros(self.storage_shape, dtype=self.dtype, device=self.device) 
        data_indices = [slice(0, s) for s in self.data_shape]
        tmp[*data_indices].copy_(self.storage[*data_indices])
        self.storage = tmp

    def shrink_storage(self):
        # change storage_shape, reallocate & copy storage(data)
        self.cap = self.rear * 3 // 2
        self.storage_shape[self.dim] = self.cap
        tmp = torch.zeros(self.storage_shape, dtype=self.dtype, device=self.device) 
        data_indices = [slice(0, s) for s in self.data_shape]
        tmp[*data_indices].copy_(self.storage[*data_indices])
        self.storage = tmp

    def push_back(self, x: torch.Tensor):
        # change data_shape (& self.rear simultaneously), copy x to storage, optional change storage_shape
        assert self.check_copyable(x)

        push_len = x.shape[self.dim]
        if self.rear + push_len > self.cap:
            self.increase_storage(push_len)
            
        push_slice = [None for _ in range(len(self.storage_shape))]
        push_slice[self.dim] = slice(self.rear, self.rear + push_len)
        self.storage[*push_slice].copy_(x)

        self.data_shape[self.dim] += push_len 

    def pop_back(self, pop_len: int = 1):
        # rear -= 1 (by modifying self.data_shape)
        pop_slice = [None for _ in range(len(self.storage_shape))]
        assert self.rear - pop_len >= 0
        pop_slice[self.dim] = slice(self.rear - pop_len, self.rear)
        ret = deepcopy(self.storage[*pop_slice])
        self.data_shape[self.dim] -= pop_len  

        if self.rear < self.cap // 2:
            self.shrink_storage() 

        return ret

    def __repr__(self) -> str:
        return f"{self.storage.tolist()}"

# Demo of Vector along its only dim: push single values, pop back down,
# push pairs, pop in chunks of pop_len, then push single values again.
# Each line prints: storage length (capacity), logical shape, raw storage,
# and the popped tensor where applicable.
# Use CUDA device 0 when present, otherwise fall back to CPU so the cell
# also runs on machines without a GPU.
dev = 0 if torch.cuda.is_available() else 'cpu'

t = torch.tensor([1, 2, 3], device=dev)
v = Vector.from_tensor(t, -1)
for i in range(4, 30):
    v.push_back(torch.tensor([i], device=dev))
    print(len(v.storage), v.shape, v)


while v.rear > 20:
    p = v.pop_back()
    print(len(v.storage), v.shape, v, p)

for i in range(1, 21, 2):
    v.push_back(torch.tensor([i, i + 1], device=dev))
    print(len(v.storage), v.shape, v)

pop_len = 5
while not v.empty() and v.rear >= pop_len:
    p = v.pop_back(pop_len)
    print(len(v.storage), v.shape, v, p)


for i in range(1, 11):
    v.push_back(torch.tensor([i], device=dev))
    print(len(v.storage), v.shape, v)



4 (4,) [1, 2, 3, 4]
7 (5,) [1, 2, 3, 4, 5, 0, 0]
7 (6,) [1, 2, 3, 4, 5, 6, 0]
7 (7,) [1, 2, 3, 4, 5, 6, 7]
12 (8,) [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0]
12 (9,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0]
12 (10,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0]
12 (11,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0]
12 (12,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
19 (13,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 0, 0, 0, 0, 0]
19 (14,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, 0, 0]
19 (15,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0]
19 (16,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 0]
19 (17,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 0, 0]
19 (18,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 0]
19 (19,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
30 (20,) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
30 (21,) [1, 2, 3, 4, 5, 6, 7,