# PyTorch Tensor

In [1]:
import torch

## Tensor Allocation

In [2]:
# A 2x2 tensor of 32-bit floats, created with the legacy typed constructor.
ft = torch.FloatTensor([[1, 2], [3, 4]])
ft

tensor([[1., 2.],
        [3., 4.]])

In [3]:
# LongTensor holds 64-bit signed integers (torch.int64), wider than a
# 32-bit int; it is the usual dtype for tensors that store indices.
lt = torch.LongTensor([[1, 2], [3, 4]])
lt

tensor([[1, 2],
        [3, 4]])

In [4]:
# ByteTensor holds unsigned 8-bit integers (torch.uint8), not booleans --
# note the dtype=torch.uint8 in the output. It can serve as a 0/1 mask, but
# the dedicated boolean dtype is torch.bool (torch.BoolTensor).
bt = torch.ByteTensor([[1, 0],
                       [0, 1]])
bt

tensor([[1, 0],
        [0, 1]], dtype=torch.uint8)

In [5]:
# Allocate a 3x2 float32 tensor by shape. torch.FloatTensor(3, 2) (like
# torch.empty(3, 2)) returns *uninitialized* memory, so its contents are
# arbitrary and not reproducible; torch.zeros gives defined values.
x = torch.zeros(3, 2)
x

tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

## NumPy Compatibility

In [6]:
import numpy as np

# Create a 2x2 integer NumPy array, then show it together with its type.
x = np.array([[1, 2], [3, 4]])
print(x, type(x))

[[1 2]
 [3 4]] <class 'numpy.ndarray'>


In [7]:
# Wrap the NumPy array as a tensor. torch.from_numpy shares memory with the
# array (no copy) and keeps its dtype -- int32 in the recorded output.
x = torch.from_numpy(x)
print(x, type(x))

tensor([[1, 2],
        [3, 4]], dtype=torch.int32) <class 'torch.Tensor'>


In [8]:
# Convert the tensor back to a NumPy array; for a CPU tensor the result
# shares the tensor's memory rather than copying it.
x = x.numpy()
print(x, type(x))

[[1 2]
 [3 4]] <class 'numpy.ndarray'>


## Tensor Type-casting

In [9]:
# Re-display the float tensor `ft`; it is the source of the casts below.
ft

tensor([[1., 2.],
        [3., 4.]])

In [10]:
# Cast to 64-bit integers (equivalent to ft.to(torch.int64)).
ft.long()

tensor([[1, 2],
        [3, 4]])

In [11]:
# Cast to 32-bit floats (equivalent to lt.to(torch.float32)).
lt.float()

tensor([[1., 2.],
        [3., 4.]])

In [12]:
# Build a float tensor and cast it to uint8 (equivalent to .to(torch.uint8)).
torch.FloatTensor([1, 0]).byte()

tensor([1, 0], dtype=torch.uint8)

## Get Shape

In [13]:
# The float32 values 1..12 laid out as 3 blocks of 2x2 (shape (3, 2, 2)).
x = torch.arange(1, 13, dtype=torch.float32).view(3, 2, 2)

In [14]:
# Show the 3x2x2 float tensor defined above.
print(x)

tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]])


Get tensor shape.

In [15]:
# Both spellings return the same torch.Size: size() is a method, shape an attribute.
print(x.size())
print(x.shape)

torch.Size([3, 2, 2])
torch.Size([3, 2, 2])


Get the number of dimensions of the tensor.

In [16]:
# dim() is the number of dimensions; it equals the length of the shape.
print(x.dim())
print(len(x.size()))

3
3


Get the size (number of elements) along a given dimension of the tensor.

In [17]:
print(x.size(0))  # size of the 1st dimension (dim 0)
print(x.shape[0])  # same value via indexing the shape

3
3


In [15]:
print(x.size(1))  # size of the 2nd dimension (dim 1)
print(x.shape[1])  # same value via indexing the shape

2
2


Get the size of the last dimension (negative indices count from the end).

In [18]:
# Negative indices count from the end: -1 selects the last dimension.
print(x.size(-1))
print(x.shape[-1])

2
2
