In [1]:
import torch

## Raw Linear Layer

In [2]:
W = torch.FloatTensor([[1,2],
                     [3,4],
                     [5,6]])
b = torch.FloatTensor([2,2])

In [3]:
print(W.size())
print(b.size())

torch.Size([3, 2])
torch.Size([2])


In [4]:
def linear(x, W, b):
    y = torch.matmul(x, W) + b
    return y

In [5]:
x = torch.FloatTensor([[1,1,1],
                     [2,2,2],
                     [3,3,3],
                     [4,4,4]])
print(x.size())

torch.Size([4, 3])


In [6]:
y = linear(x,W,b)

In [7]:
print(y.size())

torch.Size([4, 2])


## nn.Module

In [8]:
import torch.nn as nn

In [9]:
class MyLinear(nn.Module):
    def __init__(self, input_dim=3, output_dim=2):
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        super().__init__()
        
        self.W = torch.FloatTensor(input_dim, output_dim)
        self.b = torch.FloatTensor(output_dim)
    
    def forward(self,x):
        # |x| = (batch_size, input_dim)
        y = torch.matmul(x, self.W) + self.b
        # |y| = (batch_size, input_dim) * (input_dim, output_dim) = (batch_size, output_dim)
        return y

In [10]:
linear = MyLinear(3,2)
y = linear(x)

In [11]:
print(y.size())

torch.Size([4, 2])


In [12]:
for p in linear.parameters():
    print(p)
# 원래는 iternable한 parrameter가 있어야 학습 모델이 된다. => 잘못된 모델!

In [13]:
# 올바른 모델 => nn.Parameter 사용
class MyLinear(nn.Module):
    def __init__(self, input_dim=3, output_dim=2):
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        super().__init__()
        
        self.W = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
        self.b = nn.Parameter(torch.FloatTensor(output_dim))
    
    def forward(self,x):
        # |x| = (batch_size, input_dim)
        y = torch.matmul(x, self.W) + self.b
        # |y| = (batch_size, input_dim) * (input_dim, output_dim) = (batch_size, output_dim)
        return y

In [14]:
linear = MyLinear(3,2)
y = linear(x)

In [15]:
print(y.size())

torch.Size([4, 2])


In [16]:
for p in linear.parameters():
    print(p)

Parameter containing:
tensor([[0.0000e+00, 0.0000e+00],
        [1.8754e+28, 2.0110e+20],
        [4.3918e-05, 1.2859e-11]], requires_grad=True)
Parameter containing:
tensor([-1.5018e+32,  4.5908e-41], requires_grad=True)
