## 01-02 Tensor Operations

In [47]:
import torch

# Seed PyTorch's global RNG once, up front, so every random cell below
# (randint / randperm / rand / randn) is reproducible on Restart & Run All;
# the bare last expression displays the installed PyTorch version.
torch.random.manual_seed(0)
torch.__version__

'1.3.0'

使用`torch.cat`进行张量拼接

In [48]:
# cat joins tensors along an EXISTING axis, so the rank stays 2:
# dim=0 appends rows (2x3 -> 4x3), dim=1 appends columns (2x3 -> 2x6).
temp_tensor = torch.ones(2, 3)
pair = (temp_tensor, temp_tensor)

test_tensor = torch.cat(pair, 0)
# end='\n\n' emits the same blank separator line as a following print().
print('test_tensor数据：\n', test_tensor, end='\n\n')

test_tensor = torch.cat(pair, 1)
print('test_tensor数据：\n', test_tensor)

test_tensor数据：
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

test_tensor数据：
 tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


使用`torch.stack`进行张量拼接（与`torch.cat`不同，会在`dim`指定的位置新增一个维度）

In [49]:
# stack inserts a brand-new axis at `dim`, so three 2x3 inputs become a
# 3x2x3 tensor (unlike cat, which keeps the number of dimensions).
temp_tensor = torch.ones(2, 3)
test_tensor = torch.stack([temp_tensor] * 3, dim=0)
print('test_tensor数据：\n', test_tensor, test_tensor.shape)

test_tensor数据：
 tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) torch.Size([3, 2, 3])


使用`torch.chunk`进行张量切分（按`chunks`份数均分；不能整除时，最后一块较小）

In [50]:
# Split the 7 columns into 3 chunks: chunk sizes are ceil(7 / 3) = 3, so the
# pieces are 3, 3 and a 1-column remainder.
temp_tensor = torch.arange(14).view(2, 7)
print('temp_tensor数据：\n', temp_tensor, end='\n\n')

tensor_list = temp_tensor.chunk(3, dim=1)
for chunk_no, test_tensor in enumerate(tensor_list, start=1):
    print(f'第{chunk_no}个张量：\n', test_tensor, test_tensor.shape)

temp_tensor数据：
 tensor([[ 0,  1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12, 13]])

第1个张量：
 tensor([[0, 1, 2],
        [7, 8, 9]]) torch.Size([2, 3])
第2个张量：
 tensor([[ 3,  4,  5],
        [10, 11, 12]]) torch.Size([2, 3])
第3个张量：
 tensor([[ 6],
        [13]]) torch.Size([2, 1])


使用`torch.split`进行张量切分（可传入整数指定每块长度，或传入列表逐一指定各块长度）

In [51]:
# split accepts either a fixed piece width (2 -> widths 2, 2, 1 with a short
# tail) or an explicit list of widths that must sum to the size of `dim` (5).
temp_tensor = torch.arange(10).view(2, 5)
print('temp_tensor数据：\n', temp_tensor)

for split_spec in (2, [2, 1, 2]):
    print()  # blank separator before each group of pieces
    tensor_list = temp_tensor.split(split_spec, dim=1)
    for piece_no, test_tensor in enumerate(tensor_list, start=1):
        print(f'第{piece_no}个张量：\n', test_tensor, test_tensor.shape)

temp_tensor数据：
 tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])

第1个张量：
 tensor([[0, 1],
        [5, 6]]) torch.Size([2, 2])
第2个张量：
 tensor([[2, 3],
        [7, 8]]) torch.Size([2, 2])
第3个张量：
 tensor([[4],
        [9]]) torch.Size([2, 1])

第1个张量：
 tensor([[0, 1],
        [5, 6]]) torch.Size([2, 2])
第2个张量：
 tensor([[2],
        [7]]) torch.Size([2, 1])
第3个张量：
 tensor([[3, 4],
        [8, 9]]) torch.Size([2, 2])


使用`torch.index_select`进行张量索引

In [52]:
# Random 3x3 integers in [0, 9); same RNG call as before, so the seeded
# stream (and thus the displayed values) is unchanged.
temp_tensor = torch.randint(low=0, high=9, size=(3, 3))
print('temp_tensor数据：\n', temp_tensor, end='\n\n')

# index_select needs an int64 index tensor; pick rows 0 and 2 along dim 0.
index_tensor = torch.tensor([0, 2], dtype=torch.int64)
test_tensor = temp_tensor.index_select(0, index_tensor)
print('test_tensor数据：\n', test_tensor)

temp_tensor数据：
 tensor([[8, 0, 2],
        [6, 7, 6],
        [7, 1, 1]])

test_tensor数据：
 tensor([[8, 0, 2],
        [7, 1, 1]])


使用`torch.masked_select`进行张量索引（按布尔掩码取出元素，结果为一维张量）

In [53]:
temp_tensor = torch.randint(low=0, high=9, size=(3, 3))
print('temp_tensor数据：\n', temp_tensor, end='\n\n')

# Boolean mask that is True wherever the element is <= 5.
mask_tensor = torch.le(temp_tensor, 5)
print('mask_tensor数据：\n', mask_tensor, end='\n\n')

# masked_select flattens: the selected values come back as a 1-D tensor
# in row-major order, regardless of the input's shape.
test_tensor = temp_tensor.masked_select(mask_tensor)
print('test_tensor数据：\n', test_tensor)

temp_tensor数据：
 tensor([[0, 8, 2],
        [6, 3, 1],
        [2, 0, 0]])

mask_tensor数据：
 tensor([[ True, False,  True],
        [False,  True,  True],
        [ True,  True,  True]])

test_tensor数据：
 tensor([0, 2, 3, 1, 2, 0, 0])


使用`torch.reshape`变换张量形状（输入内存连续时返回与原张量共享数据的视图，否则可能返回拷贝）

In [54]:
# reshape returns a view whenever the requested shape is compatible with the
# input's layout (true here: randperm produces a fresh contiguous tensor), so
# a write through temp_tensor is visible through test_tensor. For
# non-viewable inputs reshape may return a copy instead.
temp_tensor = torch.randperm(8)
test_tensor = torch.reshape(temp_tensor, (-1, 2, 2))  # -1 is inferred: 8 / (2 * 2) = 2
print('temp_tensor数据：\n', temp_tensor)
print('test_tensor数据：\n', test_tensor, test_tensor.shape)
print()

temp_tensor[0] = 10
print('修改temp_tensor后')
print('temp_tensor数据：\n', temp_tensor)
print('test_tensor数据：\n', test_tensor)
print()

# Compare the address of the underlying storage with data_ptr().
# id(x.data) is NOT a memory address of the data: every `.data` access builds
# a new temporary Python object, and two temporaries whose lifetimes do not
# overlap may receive the same id(), so equal ids prove nothing about sharing.
print('temp_tensor.data内存地址：', temp_tensor.data_ptr())
print('test_tensor.data内存地址：', test_tensor.data_ptr())

temp_tensor数据：
 tensor([2, 3, 4, 7, 1, 0, 6, 5])
test_tensor数据：
 tensor([[[2, 3],
         [4, 7]],

        [[1, 0],
         [6, 5]]]) torch.Size([2, 2, 2])

修改temp_tensor后
temp_tensor数据：
 tensor([10,  3,  4,  7,  1,  0,  6,  5])
test_tensor数据：
 tensor([[[10,  3],
         [ 4,  7]],

        [[ 1,  0],
         [ 6,  5]]])

temp_tensor.data内存地址： 139971984023376
test_tensor.data内存地址： 139971984023376


使用`torch.transpose`交换张量的维度

In [55]:
# transpose swaps exactly two axes (here 1 and 2): (2, 3, 4) -> (2, 4, 3).
# The result shares storage with the input rather than copying it.
temp_tensor = torch.rand(2, 3, 4)
print('temp_tensor数据：\n', temp_tensor, temp_tensor.shape)

test_tensor = temp_tensor.transpose(1, 2)
print('test_tensor数据：\n', test_tensor, test_tensor.shape)

temp_tensor数据：
 tensor([[[0.5529, 0.9527, 0.0362, 0.1852],
         [0.3734, 0.3051, 0.9320, 0.1759],
         [0.2698, 0.1507, 0.0317, 0.2081]],

        [[0.9298, 0.7231, 0.7423, 0.5263],
         [0.2437, 0.5846, 0.0332, 0.1387],
         [0.2422, 0.8155, 0.7932, 0.2783]]]) torch.Size([2, 3, 4])
test_tensor数据：
 tensor([[[0.5529, 0.3734, 0.2698],
         [0.9527, 0.3051, 0.1507],
         [0.0362, 0.9320, 0.0317],
         [0.1852, 0.1759, 0.2081]],

        [[0.9298, 0.2437, 0.2422],
         [0.7231, 0.5846, 0.8155],
         [0.7423, 0.0332, 0.7932],
         [0.5263, 0.1387, 0.2783]]]) torch.Size([2, 4, 3])


使用`torch.squeeze`压缩张量中长度为1的维度（指定`dim`时，仅当该维长度为1才会被移除）

In [56]:
temp_tensor = torch.rand(1, 2, 3, 1)

# Three variants, printed in order:
#   no dim  -> every size-1 axis is dropped            -> (2, 3)
#   dim=0   -> axis 0 has size 1, so it is removed     -> (2, 3, 1)
#   dim=1   -> axis 1 has size 2, so nothing changes   -> (1, 2, 3, 1)
for squeeze_kwargs in ({}, {'dim': 0}, {'dim': 1}):
    test_tensor = torch.squeeze(temp_tensor, **squeeze_kwargs)
    print('test_tensor形状：\n', test_tensor.shape)

test_tensor形状：
 torch.Size([2, 3])
test_tensor形状：
 torch.Size([2, 3, 1])
test_tensor形状：
 torch.Size([1, 2, 3, 1])


使用`torch.add`进行张量加法运算（`alpha`是第二个张量的乘数，即计算 input + alpha × other）

In [58]:
# add(input, other, alpha=a) computes input + a * other; with `other` all
# ones, every element of temp_tensor_0 is shifted by exactly 10.
temp_tensor_0 = torch.randn(3, 3)
temp_tensor_1 = torch.ones_like(temp_tensor_0)
test_tensor = temp_tensor_0.add(temp_tensor_1, alpha=10)

for label, value in (
    ('temp_tensor_0', temp_tensor_0),
    ('temp_tensor_1', temp_tensor_1),
    ('test_tensor', test_tensor),
):
    print(f'{label}数据：\n', value)

temp_tensor_0数据：
 tensor([[-0.1848, -1.1938, -0.2233],
        [-1.2706,  0.0193,  0.8868],
        [ 0.0552,  0.6880,  1.2326]])
temp_tensor_1数据：
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
test_tensor数据：
 tensor([[ 9.8152,  8.8062,  9.7767],
        [ 8.7294, 10.0193, 10.8868],
        [10.0552, 10.6880, 11.2326]])
