In [2]:
import torch
import torch.nn
import matplotlib.pyplot as plt

In [3]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')

In [4]:
device

device(type='cuda')

In [15]:
gen = torch.Generator(device=device)
x = torch.randn((1024, 1024), device=device, generator=gen)
x.shape, x.stride(), x.data_ptr(), x.dtype

(torch.Size([1024, 1024]), (1024, 1), 23449854541824, torch.float32)

In [16]:
hex(x.data_ptr())

'0x1553d8400000'

In [17]:
x[0][0].data_ptr()

23449854541824

In [18]:
x[0][1].data_ptr()

23449854541828

In [19]:
h, w = x.shape

In [20]:
x[0][w-1].data_ptr(), x[1][0].data_ptr()

(23449854545916, 23449854545920)

In [25]:
x1 = torch.randn((1024, 512, 256), device=device, generator=gen)
x1.shape, x1.stride(), x1.data_ptr()

(torch.Size([1024, 512, 256]), (131072, 256, 1), 23447837081600)

In [24]:
print(512*512)

262144


In [27]:
y = x1.transpose(-2, -1)
y.shape, y.stride(), y.data_ptr()

(torch.Size([1024, 256, 512]), (131072, 1, 256), 23447837081600)

In [28]:
x = torch.randn((1024, 1024), device=device)
y = x.transpose(0, 1)
x.shape, y.shape, x.stride(), y.stride(), x.data_ptr(), y.data_ptr()

(torch.Size([1024, 1024]),
 torch.Size([1024, 1024]),
 (1024, 1),
 (1, 1024),
 23449850347520,
 23449850347520)

In [29]:
y.is_contiguous()

False

In [30]:
x.is_contiguous()

True

In [None]:
def is_contiguous_fast(z):
    stride_values = z.stride()
    i = len(stride_values)-1
    shape_values = z.shape
    current = 1
    while i>=0:
        if stride_values[i] != current:
            return False
        current *= shape_values[i]
        i-=1
    return True
          

In [45]:
is_contiguous_fast(y)

False

In [57]:
x = torch.randn((1024, 512, 256))
x.shape, x.stride(), x.data_ptr()

(torch.Size([1024, 512, 256]), (131072, 256, 1), 23447300206656)

In [58]:
y = x.transpose(-2, -1)
y.shape, y.stride(), y.data_ptr()

(torch.Size([1024, 256, 512]), (131072, 1, 256), 23447300206656)

In [59]:
is_contiguous_fast(x)

True

In [51]:
print(256*512)

131072


In [60]:
is_contiguous_fast(y)

(131072, 1, 256) torch.Size([1024, 256, 512]) 2 1


False

In [61]:
is_contiguous_fast(y.view(-1))

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [63]:
y = y.view(-1)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [64]:
y.view(-1)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [65]:
y.shape

torch.Size([1024, 256, 512])

In [68]:
x.transpose(-1, -1).view(-1)

tensor([ 1.4496,  0.0308, -1.4092,  ...,  1.0776,  0.9064,  0.2648])

In [67]:
y.view(-1)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.