In [1]:
import torch
from torch import nn

# nn.Linear()

In [2]:
# Affine layer mapping 5 input features to 1 output feature, with a bias term.
linear = nn.Linear(in_features=5, out_features=1, bias=True)
# Switch to evaluation mode; eval() returns the module, so the cell echoes its repr.
linear.eval()

Linear(in_features=5, out_features=1, bias=True)

## 初始化方式1

In [6]:
def init_weights(m):
    """Re-initialise the weights of linear layers from a normal distribution.

    Intended to be passed to ``nn.Module.apply``, which calls it on every
    submodule and on the module itself.

    Args:
        m: Any ``nn.Module``. If it is an ``nn.Linear`` (or a subclass), its
            ``weight`` is overwritten in place with samples drawn from a
            normal distribution with mean 0 and std 0.01. The bias is not
            modified, and other module types are left untouched.
    """
    # isinstance (rather than an exact type comparison) also matches
    # subclasses of nn.Linear, which `type(m) == nn.Linear` silently skipped.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0, std=0.01)

In [7]:
# Module.apply calls init_weights on the module (and any submodules) and returns the module.
linear.apply(init_weights)

Linear(in_features=5, out_features=1, bias=True)

## 初始化方式2

In [9]:
# Initialisation, variant 2: overwrite the weights in place with samples from N(mean=0, std=0.01).
linear.weight.data.normal_(mean=0, std=0.01)

tensor([[-0.0009, -0.0166,  0.0066, -0.0157, -0.0096]])

In [10]:
# Inspect the weight tensor after in-place initialisation; its shape is [out_features, in_features] = [1, 5].
linear.weight.data

tensor([[-0.0009, -0.0166,  0.0066, -0.0157, -0.0096]])

In [11]:
# Initialisation: reset the bias to zero in place.
linear.bias.data.zero_()

tensor([0.])

In [12]:
# Inspect the bias after zeroing it; one entry per output feature.
linear.bias.data

tensor([0.])

# 一维输入

In [13]:
# 1-D input: the five feature values 1.0 ... 5.0, with no batch dimension.
x1 = torch.arange(1.0, 6.0)
x1

tensor([1., 2., 3., 4., 5.])

In [14]:
# A single feature vector: shape [in_features].
x1.shape

torch.Size([5])

In [15]:
# Forward pass on the 1-D input under inference_mode, so no autograd graph is recorded for y1.
with torch.inference_mode():
    y1 = linear(x1)
y1

tensor([-0.1248])

In [16]:
# Manual check: multiply element-wise at matching positions, sum, then add the bias.
(x1 * linear.weight.data).sum() + linear.bias.data

tensor([-0.1248])

# 二维输入

## batch = 1

In [17]:
# Same five values as x1, viewed as a batch of one sample: shape [1, 5].
x2 = x1.view(1, 5)
x2

tensor([[1., 2., 3., 4., 5.]])

In [18]:
# Forward pass on the [1, 5] batch under inference_mode; the result keeps the batch dimension.
with torch.inference_mode():
    y2 = linear(x2)
y2

tensor([[-0.1248]])

In [19]:
# Manual check: multiply element-wise at matching positions, sum over the feature axis, then add the bias.
# keepdim=True keeps the trailing size-1 axis so the result has the same shape as y2.
(x2 * linear.weight.data).sum(dim=-1, keepdim=True) + linear.bias.data

tensor([[-0.1248]])

## batch != 1

In [20]:
# Batch of two samples with five features each: the values 1.0 ... 10.0 laid out row by row.
x3 = torch.arange(1.0, 11.0).view(2, 5)
x3

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])

In [21]:
# [batch, in_features] = [2, 5].
x3.shape

torch.Size([2, 5])

In [22]:
# Forward pass on the [2, 5] batch under inference_mode; the layer is applied to each row.
with torch.inference_mode():
    y3 = linear(x3)
y3

tensor([[-0.1248],
        [-0.3054]])

In [23]:
# [batch, out_features] = [2, 1]: only the last dimension changed (5 -> 1).
y3.shape

torch.Size([2, 1])

In [24]:
# Manual check: the [1, 5] weight broadcasts across both rows; multiply element-wise,
# sum over the feature axis (keeping it as size 1), then add the bias.
(x3 * linear.weight.data).sum(dim=-1, keepdim=True) + linear.bias.data

tensor([[-0.1248],
        [-0.3054]])

In [25]:
# Same check one sample at a time: split the batch along dim 0 and apply the
# element-wise multiply / sum / add-bias computation to each row separately.
for x3_ in x3.unbind(dim=0):
    print((x3_ * linear.weight.data).sum() + linear.bias.data)

tensor([-0.1248])
tensor([-0.3054])


# 多维输入

## 三维输入

In [26]:
# 3-D input of shape [2, 2, 5]: add a leading size-1 axis to x3, then expand that axis
# to size 2 so the same [2, 5] block appears twice.
x4 = x3.view(1, 2, 5).expand(2, 2, 5)
x4

tensor([[[ 1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10.]],

        [[ 1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10.]]])

In [27]:
# Forward pass on the 3-D input under inference_mode; the layer acts on the last dimension only.
with torch.inference_mode():
    y4 = linear(x4)
y4

tensor([[[-0.1248],
         [-0.3054]],

        [[-0.1248],
         [-0.3054]]])

In [28]:
# Leading dimensions are preserved; only the last one changed (5 -> 1): [2, 2, 1].
y4.shape

torch.Size([2, 2, 1])

In [29]:
# Manual check for the 3-D input: multiply element-wise with the broadcast weight,
# sum over the feature axis (keeping it as size 1), then add the bias.
(x4 * linear.weight.data).sum(dim=-1, keepdim=True) + linear.bias.data

tensor([[[-0.1248],
         [-0.3054]],

        [[-0.1248],
         [-0.3054]]])

In [30]:
# Shape of the manual computation, for comparison with y4.shape.
((x4 * linear.weight.data).sum(dim=-1, keepdim=True) + linear.bias.data).shape

torch.Size([2, 2, 1])