# Arithmetic Operations on Tensors

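Note: every function whose name ends in `_` is an in-place operation that directly modifies the original tensor.

An in-place function returns the tensor it just modified rather than a new tensor (it does not return `None`), which is why printing `a.add_(b)` and then `a` shows the same values.

In-place operations are generally discouraged: they overwrite values that later computations may still need, for example values that autograd saved for the backward pass.

Their main use is memory optimization, because no new tensor is allocated for the result.

With an in-place operation you need to anticipate the result dtype in advance. If the result cannot be cast to the dtype of the tensor being overwritten (for example, a float result written into an integer tensor), an error is raised.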
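The notes above can be checked with a minimal sketch. The variable names (`returned`, `ints`, `x`, `y`) are illustrative and not part of the original notebook, and the exact error messages vary between PyTorch versions, so the sketch only checks that a `RuntimeError` is raised.

```python
import torch

# 1) An in-place method returns the tensor it modified, not None.
a = torch.zeros(3)
returned = a.add_(1.0)
shares_storage = returned.data_ptr() == a.data_ptr()  # same underlying memory
print("returned shares storage with a:", shares_storage, "| a:", a)

# 2) The result dtype must be castable to the dtype of the tensor being overwritten.
ints = torch.tensor([1, 2, 3])  # dtype torch.int64
try:
    ints.add_(torch.tensor([0.5, 0.5, 0.5]))  # float result cannot be stored in int64
    dtype_error = None
except RuntimeError as exc:
    dtype_error = exc
print("dtype error raised:", dtype_error is not None)

# 3) Overwriting a value that autograd saved for backward makes backward() fail.
x = torch.ones(3, requires_grad=True)
y = x.exp()      # exp saves its output for the backward pass
y.add_(1.0)      # overwrites that saved output in place
try:
    y.sum().backward()
    autograd_error = None
except RuntimeError as exc:
    autograd_error = exc
print("autograd error raised:", autograd_error is not None)
```

Replacing `y.add_(1.0)` with the out-of-place `y = y + 1.0` should let the backward pass run, at the cost of one extra allocation.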

In [1]:
import torch

## Basic Arithmetic

### Addition: `add`

In [2]:
a = torch.randn(2, 3)
b = torch.randn(2, 3)
print("a:", a)
print("b:", b)
print("a + b:", a + b)
print("a.add(b):", a.add(b))
print("torch.add(a, b):", torch.add(a, b))
print("a.add_(b):", a.add_(b))
print("a after a.add_(b):", a)

a: tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
b: tensor([[-0.7666, -1.6691, -0.3594],
        [-0.6778, -0.6803,  0.1596]])
a + b: tensor([[-0.7423, -1.6025, -0.5190],
        [-1.4874, -0.5546,  0.0895]])
a.add(b): tensor([[-0.7423, -1.6025, -0.5190],
        [-1.4874, -0.5546,  0.0895]])
torch.add(a, b): tensor([[-0.7423, -1.6025, -0.5190],
        [-1.4874, -0.5546,  0.0895]])
a.add_(b): tensor([[-0.7423, -1.6025, -0.5190],
        [-1.4874, -0.5546,  0.0895]])
a after a.add_(b): tensor([[-0.7423, -1.6025, -0.5190],
        [-1.4874, -0.5546,  0.0895]])


### Subtraction: `sub`

In [3]:
# a still holds a + b from the in-place add_ above, so a - b recovers the original a
print("a - b:", a - b)
print("a.sub(b):", a.sub(b))
print("torch.sub(a, b):", torch.sub(a, b))
print("a.sub_(b):", a.sub_(b))
print("a after a.sub_(b):", a)

a - b: tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
a.sub(b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
torch.sub(a, b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
a.sub_(b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
a after a.sub_(b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])


### Multiplication: `mul`

In [4]:
print("a * b:", a * b)
print("a.mul(b):", a.mul(b))
print("torch.mul(a, b):", torch.mul(a, b))
print("a.mul_(b):", a.mul_(b))
print("a after a.mul_(b):", a)

a * b: tensor([[-0.0186, -0.1112,  0.0574],
        [ 0.5487, -0.0855, -0.0112]])
a.mul(b): tensor([[-0.0186, -0.1112,  0.0574],
        [ 0.5487, -0.0855, -0.0112]])
torch.mul(a, b): tensor([[-0.0186, -0.1112,  0.0574],
        [ 0.5487, -0.0855, -0.0112]])
a.mul_(b): tensor([[-0.0186, -0.1112,  0.0574],
        [ 0.5487, -0.0855, -0.0112]])
a after a.mul_(b): tensor([[-0.0186, -0.1112,  0.0574],
        [ 0.5487, -0.0855, -0.0112]])


### Division: `div`

In [5]:
# a holds a * b after the in-place mul_ above, so dividing by b recovers the original a
print("a.div(b):", a.div(b))
print("torch.div(a, b):", torch.div(a, b))
print("a.div_(b):", a.div_(b))
print("a after a.div_(b):", a)

a.div(b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
torch.div(a, b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
a.div_(b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])
a after a.div_(b): tensor([[ 0.0242,  0.0666, -0.1596],
        [-0.8096,  0.1257, -0.0701]])


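## Matrix and High-Dimensional Tensor Multiplication

`matmul` and `mm` have no in-place `_` variants. The shape of the result generally differs from the shape of either operand, so the result cannot be written back into an input, and these operations always create a new tensor.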

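### Matrix multiplication: `matmul`
- `torch.matmul` performs matrix multiplication and is what the `@` operator calls. For 2-D inputs `torch.mm` gives the same result, but `torch.mm` accepts only 2-D tensors and does not broadcast.
- If `a` is an m×n matrix and `b` is an n×p matrix, then `c = a @ b` is an m×p matrix.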
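As a small sketch of the shape rule above, the following checks that the three spellings agree for 2-D inputs and that mismatched inner dimensions are rejected. The shapes and the name `shape_error` are chosen only for illustration.

```python
import torch

a = torch.randn(2, 3)   # m x n
b = torch.randn(3, 4)   # n x p

c = a @ b               # same as torch.matmul(a, b)
print(c.shape)          # torch.Size([2, 4])

# For 2-D inputs, `@`, torch.matmul and torch.mm all give the same result.
same_as_matmul = torch.allclose(c, torch.matmul(a, b))
same_as_mm = torch.allclose(c, torch.mm(a, b))
print(same_as_matmul, same_as_mm)

# The inner dimensions must match: (2, 3) @ (2, 4) is not defined.
try:
    torch.matmul(a, torch.randn(2, 4))
    shape_error = None
except RuntimeError as exc:
    shape_error = exc
print("shape error raised:", shape_error is not None)
```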

In [6]:
a = torch.randn(2, 3)
b = torch.randn(3, 2)
print("a:", a)
print("b:", b)
print("a @ b:", a @ b)
print("torch.matmul(a, b):", torch.matmul(a, b))
print("torch.mm(a, b):", torch.mm(a, b))
print("a.matmul(b):", a.matmul(b))
print("a.mm(b):", a.mm(b))

a: tensor([[ 0.6085, -0.9408, -1.3502],
        [ 1.4057, -0.8519,  1.1425]])
b: tensor([[ 0.3712,  1.0338],
        [-0.0474,  0.2137],
        [-1.0411, -0.1239]])
a @ b: tensor([[ 1.6762,  0.5953],
        [-0.6274,  1.1296]])
torch.matmul(a, b): tensor([[ 1.6762,  0.5953],
        [-0.6274,  1.1296]])
torch.mm(a, b): tensor([[ 1.6762,  0.5953],
        [-0.6274,  1.1296]])
a.matmul(b): tensor([[ 1.6762,  0.5953],
        [-0.6274,  1.1296]])
a.mm(b): tensor([[ 1.6762,  0.5953],
        [-0.6274,  1.1296]])


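### High-dimensional tensor multiplication: `matmul`
- For tensors with more than two dimensions, `matmul` and the equivalent `@` operator both work. `mm` does not, because it accepts only 2-D matrices (`torch.bmm` covers the 3-D case, without broadcasting).
- For a high-dimensional tensor (dim > 2), the matrix product is taken over the last two dimensions only. The leading dimensions act as batch indices, like indices into a stack of matrices, and must be broadcastable against each other (most simply, identical).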
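The batching behaviour described above can be sketched as follows. The extra tensors (`shared`, `stacked`) and the tolerance `atol=1e-4` are assumptions made for the illustration, not part of the original notebook.

```python
import torch

a = torch.randn(2, 3, 4)   # a batch of two 3x4 matrices
b = torch.randn(2, 4, 5)   # a batch of two 4x5 matrices

batched = a @ b            # `@` dispatches to matmul, so it handles batches too
print(batched.shape)       # torch.Size([2, 3, 5])

# Each batch entry is an ordinary matrix product of the matching slices.
per_slice_ok = all(
    torch.allclose(batched[i], a[i] @ b[i], atol=1e-4) for i in range(a.shape[0])
)

# torch.bmm covers the strict 3-D case (same batch size, no broadcasting).
bmm_ok = torch.allclose(batched, torch.bmm(a, b), atol=1e-4)

# Batch dimensions only need to be broadcastable, not identical.
shared = torch.randn(4, 5)              # one matrix reused for every batch entry
broadcast_shape = (a @ shared).shape    # torch.Size([2, 3, 5])
stacked = torch.randn(7, 1, 3, 4) @ b   # batch dims (7, 1) and (2,) broadcast to (7, 2)
print(broadcast_shape, stacked.shape)

# torch.mm is 2-D only and rejects the batched inputs.
try:
    torch.mm(a, b)
    mm_error = None
except RuntimeError as exc:
    mm_error = exc
print("mm error raised:", mm_error is not None)
```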

In [7]:
a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
print("torch.matmul(a, b).shape:", torch.matmul(a, b).shape)
print("a.matmul(b).shape:", a.matmul(b).shape)

torch.matmul(a, b).shape: torch.Size([2, 3, 5])
a.matmul(b).shape: torch.Size([2, 3, 5])


## Other Arithmetic Operations

### Power: `pow`
- `pow` raises each element of a tensor to the given power.
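A short sketch of the element-wise behaviour, assuming the default dtype (`torch.float32`) has not been changed. It also shows the dtype pitfall from the introduction applied to `pow_`: a fractional exponent on an integer tensor works out of place but not in place. All names here are illustrative.

```python
import torch

a = torch.tensor([1, 2, 3])                   # dtype torch.int64

cubes = a.pow(3)                              # scalar exponent applied to every element
print(cubes)                                  # values 1, 8, 27

per_element = a.pow(torch.tensor([2, 3, 0]))  # one exponent per element
print(per_element)                            # values 1, 8, 1

roots = a.pow(0.5)                            # fractional exponent promotes the result to float
print(roots.dtype)

# In place, the float result would have to be stored back into an int64 tensor.
try:
    a.pow_(0.5)
    pow_error = None
except RuntimeError as exc:
    pow_error = exc
print("in-place error raised:", pow_error is not None)
```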

In [8]:
a = torch.tensor([1, 2])
print("torch.pow(a, 3):", torch.pow(a, 3))
print("a.pow(3):", a.pow(3))
print("a ** 3:", a ** 3)
print("a.pow_(3):", a.pow_(3))
print("a after a.pow_(3):", a)

torch.pow(a, 3): tensor([1, 8])
a.pow(3): tensor([1, 8])
a ** 3: tensor([1, 8])
a.pow_(3): tensor([1, 8])
a after a.pow_(3): tensor([1, 8])


### Natural exponential: `exp`

In [9]:
a = torch.tensor([1, 2], dtype=torch.float32)
print("a.type():", a.type())
print("torch.exp(a):", torch.exp(a))
print("torch.exp_(a):", torch.exp_(a))
print("a after torch.exp_(a):", a)
# a was overwritten by exp_ above, so this prints exp(exp(x)) of the original values
print("a.exp():", a.exp())
print("a.exp_():", a.exp_())
print("a after a.exp_():", a)

a.type(): torch.FloatTensor
torch.exp(a): tensor([2.7183, 7.3891])
torch.exp_(a): tensor([2.7183, 7.3891])
a after torch.exp_(a): tensor([2.7183, 7.3891])
a.exp(): tensor([  15.1543, 1618.1781])
a.exp_(): tensor([  15.1543, 1618.1781])
a after a.exp_(): tensor([  15.1543, 1618.1781])


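### Logarithms: `log`, `log2`, `log10`
- `log`, `log2`, and `log10` compute the element-wise logarithm with base e, 2, and 10 respectively.
- Prefer a floating-point tensor. Recent PyTorch versions promote integer input to floating point for the out-of-place functions, but the in-place variants (`log_` and friends) raise an error on integer tensors because the float result cannot be stored in them.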
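The sketch below illustrates both points, assuming a recent PyTorch release in which out-of-place `torch.log` promotes integer input to the default float dtype (older releases may raise instead). The change-of-base check and the variable names are illustrative additions.

```python
import math

import torch

a = torch.tensor([1.0, 2.0, 10.0])
print(torch.log(a), torch.log2(a), torch.log10(a))

# The three functions differ only by a constant factor (change of base).
log2_ok = torch.allclose(torch.log2(a), torch.log(a) / math.log(2))
log10_ok = torch.allclose(torch.log10(a), torch.log(a) / math.log(10))
print(log2_ok, log10_ok)

ints = torch.tensor([1, 2, 10])   # dtype torch.int64
promoted = torch.log(ints)        # assumed to be promoted to float (recent PyTorch)
print(promoted.dtype)

# In place, the float result cannot be written back into the int64 tensor.
try:
    ints.log_()
    log_error = None
except RuntimeError as exc:
    log_error = exc
print("in-place error raised:", log_error is not None)
```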

In [10]:
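a = torch.tensor([10, 2], dtype=torch.float32)  # only `log` is shown here; `log2` and `log10` are used the same way
print("torch.log(a):", torch.log(a))
print("torch.log_(a):", torch.log_(a))
print("a after torch.log_(a):", a)
# a was overwritten by log_ above, so this prints log(log(x)) of the original values
print("a.log():", a.log())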
print("a.log_():", a.log_())
print("a after a.log_():", a)

torch.log(a): tensor([2.3026, 0.6931])
torch.log_(a): tensor([2.3026, 0.6931])
a after torch.log_(a): tensor([2.3026, 0.6931])
a.log(): tensor([ 0.8340, -0.3665])
a.log_(): tensor([ 0.8340, -0.3665])
a after a.log_(): tensor([ 0.8340, -0.3665])


### Square root: `sqrt`

In [11]:
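a = torch.tensor([4, 9], dtype=torch.float32)  # use a floating-point dtype so the in-place sqrt_ below works
print("torch.sqrt(a):", torch.sqrt(a))
print("torch.sqrt_(a):", torch.sqrt_(a))  # raises an error if the input is not floating point
print("a after torch.sqrt_(a):", a)
# a was overwritten by sqrt_ above, so this prints the square root of [2., 3.]
print("a.sqrt():", a.sqrt())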
print("a.sqrt_():", a.sqrt_())
print("a after a.sqrt_():", a)

torch.sqrt(a): tensor([2., 3.])
torch.sqrt_(a): tensor([2., 3.])
a after torch.sqrt_(a): tensor([2., 3.])
a.sqrt(): tensor([1.4142, 1.7321])
a.sqrt_(): tensor([1.4142, 1.7321])
a after a.sqrt_(): tensor([1.4142, 1.7321])
