In [1]:
import torch
import numpy as np

In [6]:
from typing import List, Union

def print_T(*xs):
    if not isinstance(xs, list) and not isinstance(xs, tuple):
        xs = [xs,]
    for x in xs:
        if isinstance(x, torch.Tensor):
            print(f'{x}\n  shape {x.shape} offset {x.storage_offset()} stride: {x.stride()}, data_ptr={x.untyped_storage().data_ptr()}')
        else: # numpy
            offset = (x.__array_interface__['data'][0] - x.base.__array_interface__['data'][0]) // x.itemsize if x.base is not None else 0
            print(f'{x}\n  shape {x.shape} offset {offset} stride: {(np.array(x.strides) // 8).tolist()}, dtype: {x.dtype}')

In [None]:
x = torch.ones(3, 4, 5, 6)
# x_str = torch.tensor(x.shape).flip(dims=[0]).cumprod(0)
ndim = len(x.shape)
x_stride = [1] * ndim
for i in reversed(range(0, ndim)):
    x_stride[i] = x_stride[i+1] * x.shape[i]
print(len(x.shape), x_stride, x.shape, x.stride())
# print(x.stride(), x.shape, x_str)


3 1
2 6
1 30
0 120
4 [1, 1, 1, 1] torch.Size([3, 4, 5, 6]) (120, 30, 6, 1)


In [None]:
# postfix multiply

In [55]:
x = np.arange(12).reshape(3, 4)
# print_T(x, np.flip(x, axis=[0]), np.flip(x, axis=[1]))
print_T(x, np.flip(x, axis=[0, 1]))

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
  shape (3, 4) offset 0 stride: [4, 1], dtype: int64
[[11 10  9  8]
 [ 7  6  5  4]
 [ 3  2  1  0]]
  shape (3, 4) offset 11 stride: [-4, -1], dtype: int64


In [None]:
print_T(x.flip(dims=[1]))

AttributeError: 'numpy.ndarray' object has no attribute 'flip'

In [None]:
x = torch.arange(18).reshape(2, 9, 1)

In [80]:
# slice operator
x = torch.arange(12).reshape(3, -1)
print_T(x, x[1:2, 1:4])
# yy = torch.arange(3)
# print(yy, yy.shape, yy.stride())
# print(tt.data, tt.shape, tt.stride())
# print(x.unsqueeze(0).stride())

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
  shape torch.Size([3, 4]) offset 0 stride: (4, 1), data_ptr=123878912
tensor([[5, 6, 7]])
  shape torch.Size([1, 3]) offset 5 stride: (4, 1), data_ptr=123878912


In [37]:
x, x.T, x.T.stride()
x.storage().data_ptr() == x.T.storage().data_ptr()

True

In [71]:
ex = torch.arange(16).reshape(4, -1)
# idx = torch.Tensor([0, 1, 0, 1, 0, 1]).int()
print_T(ex, ex[0:2])

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  shape torch.Size([4, 4]) offset 0 stride: (4, 1), data_ptr=123689856
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])
  shape torch.Size([2, 4]) offset 0 stride: (4, 1), data_ptr=123689856


In [79]:
so = torch.vstack((torch.hstack((ex[:2, :2], ex[2:, :2])), torch.hstack((ex[:2, 2:], ex[2:, 2:]))))
print_T(ex, so)

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  shape torch.Size([4, 4]) offset 0 stride: (4, 1), data_ptr=123689856
tensor([[ 0,  1,  8,  9],
        [ 4,  5, 12, 13],
        [ 2,  3, 10, 11],
        [ 6,  7, 14, 15]])
  shape torch.Size([4, 4]) offset 0 stride: (4, 1), data_ptr=33555072


In [81]:
a = torch.arange(24).reshape(2, 3, 4)
b = a.permute(2, 0, 1)
print_T(a, b)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
  shape torch.Size([2, 3, 4]) offset 0 stride: (12, 4, 1), data_ptr=33595840
tensor([[[ 0,  4,  8],
         [12, 16, 20]],

        [[ 1,  5,  9],
         [13, 17, 21]],

        [[ 2,  6, 10],
         [14, 18, 22]],

        [[ 3,  7, 11],
         [15, 19, 23]]])
  shape torch.Size([4, 2, 3]) offset 0 stride: (1, 12, 4), data_ptr=33595840


In [105]:
t = torch.arange(24).reshape(2, 3, 4)
x = torch.arange(4)
print_T(x, x.expand_as(t))

tensor([0, 1, 2, 3])
  shape torch.Size([4]) offset 0 stride: (1,), data_ptr=123953600
tensor([[[0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3]],

        [[0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3]]])
  shape torch.Size([2, 3, 4]) offset 0 stride: (0, 0, 1), data_ptr=123953600


In [None]:
def my_permutation(x: torch.Tensor, perm: torch.Tensor):
    newshape = (torch.Tensor(tuple(x.shape)).int())[perm]
    newstride = (torch.Tensor(x.stride()).int())[perm]
    tg1 = x.as_strided(tuple(newshape), tuple(newstride), x.storage_offset())
    tg2 = x.permute(tuple(perm))
    print_T(tg1, tg2)
# my_permutation(t, torch.Tensor([2, 0, 1]).int())

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 2

In [99]:
print(304//8, 192//8)

38 24
