In [778]:
import torch

# Raw linear layer

In [779]:
W = torch.FloatTensor([
    [1, 2],
    [3, 4],
    [5, 6],
])
b = torch.FloatTensor([
    2, 2
])
b_prime = torch.FloatTensor([
    [1],
    [2],
    [3]
])

f = lambda x: print(x, x.size())

f(W)
f(b)
f(b_prime)

tensor([[1., 2.],
        [3., 4.],
        [5., 6.]]) torch.Size([3, 2])
tensor([2., 2.]) torch.Size([2])
tensor([[1.],
        [2.],
        [3.]]) torch.Size([3, 1])


In [780]:
def linear(x, W, b):
    y = torch.matmul(x, W) + b
    return y

In [781]:
x = torch.FloatTensor([
    [1, 1, 1],
    [2, 2, 2],
    [3, 3, 3],
    [4, 4, 4],
])
f(x)

tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
        [4., 4., 4.]]) torch.Size([4, 3])


In [782]:
y = linear(x, W, b)
f(y)

tensor([[11., 14.],
        [20., 26.],
        [29., 38.],
        [38., 50.]]) torch.Size([4, 2])


# nn.Module

In [783]:
import torch.nn as nn

In [784]:
class MyLinear(nn.Module):
    def __init__(self, input_dim=3, output_dim=2):
        self.input_dim = input_dim
        self.output_dim = output_dim

        super().__init__()

        # Non-registered
        # self.W = torch.FloatTensor(input_dim, output_dim) 
        # self.b = torch.FloatTensor(output_dim)

        # Registering
        self.W = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
        self.b = nn.Parameter(torch.FloatTensor(output_dim))

    def forward(self, x):
        y = torch.matmul(x, self.W) + self.b
        print('forward() processed!')
        return y

In [785]:
linear = MyLinear(3, 2)
y = linear(x)
f(y)

forward() processed!
tensor([[-0.1760,  2.2103],
        [-0.3520,  2.1706],
        [-0.5280,  2.1309],
        [-0.7040,  2.0913]], grad_fn=<AddBackward0>) torch.Size([4, 2])


In [786]:
for param in linear.parameters():
    print(param)

Parameter containing:
tensor([[ 0.3477, -0.5286],
        [ 0.0493,  0.0533],
        [-0.5729,  0.4356]], requires_grad=True)
Parameter containing:
tensor([0.0000, 2.2500], requires_grad=True)


# nn.Linear

In [787]:
linear = nn.Linear(3, 2)
y = linear(x)

In [788]:
f(y)

tensor([[-0.0140,  0.3910],
        [-0.1345,  0.3138],
        [-0.2549,  0.2366],
        [-0.3754,  0.1594]], grad_fn=<AddmmBackward0>) torch.Size([4, 2])


In [789]:
for param in linear.parameters():
    print(param)

Parameter containing:
tensor([[ 0.1408,  0.0559, -0.3172],
        [ 0.2370, -0.0269, -0.2873]], requires_grad=True)
Parameter containing:
tensor([0.1064, 0.4682], requires_grad=True)


# nn.Module can contain other nn.Module's child classes

In [790]:
class MyLinear(nn.Module):
    def __init__(self, input_dim=3, output_dim=2):
        self.input_dim = input_dim
        self.output_dim = output_dim

        super().__init__()

        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # print('forward() called!')
        y = self.linear(x)
        return y

In [791]:
linear = MyLinear() # (3, 2)
# y = linear.forward(x)
y = linear(x)
f(y)

tensor([[0.6431, 0.2304],
        [0.9490, 0.4604],
        [1.2549, 0.6904],
        [1.5607, 0.9204]], grad_fn=<AddmmBackward0>) torch.Size([4, 2])


In [792]:
for p in linear.parameters():
    print(p)

Parameter containing:
tensor([[ 0.0727,  0.0206,  0.2126],
        [-0.0845,  0.4962, -0.1817]], requires_grad=True)
Parameter containing:
tensor([0.3373, 0.0004], requires_grad=True)
