<a href="https://colab.research.google.com/github/JaeHeee/Pytorch_Tutorial/blob/main/code/PYTORCH_CONTROL_FLOW_%2B_WEIGHT_SHARING.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PYTORCH: CONTROL FLOW + WEIGHT SHARING

PyTorch 동적 그래프의 강력함을 보여주기 위해 매우 이상한 모델을 구현한다.  

The model is a fully-connected ReLU network that, on each forward pass, randomly chooses a number between 1 and 4 and uses that many hidden layers, reusing the same weights multiple times to compute the innermost hidden layers.

In [None]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-sampled on every call.

    Each forward pass uses between 1 and 4 hidden layers. The first hidden
    layer comes from ``input_linear``; the remaining 0-3 hidden layers all
    reuse the single ``middle_linear`` module, so they share weights.

    Args:
        D_in: size of the last dimension of the input tensor.
        H: width of every hidden layer.
        D_out: size of the last dimension of the output tensor.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Layers are created in input -> middle -> output order so parameter
        # registration order (and the torch RNG draws used by the default
        # initialization) stay the same.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Run one forward pass with a randomly chosen number of hidden layers.

        Draws exactly one value from Python's ``random`` module per call.
        ``clamp(min=0)`` acts as the ReLU non-linearity.
        """
        hidden = self.input_linear(x).clamp(min=0)
        # Dynamic graph: how many extra weight-shared layers to apply is
        # decided at run time, 0 to 3 inclusive.
        extra_layers = random.randint(0, 3)
        while extra_layers > 0:
            hidden = self.middle_linear(hidden).clamp(min=0)
            extra_layers -= 1
        return self.output_linear(hidden)

# Training demo: fit DynamicNet to a fixed random batch with SGD + momentum.
#
# N : batch size
# D_in : input dimension
# H : hidden dimension
# D_out : output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

# Random tensors holding the inputs and the target outputs.
# NOTE: x and y are independent Gaussian noise, so the model can only
# memorize this one fixed batch; there is no underlying function to learn.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an optimizer.
# reduction='sum' adds up the squared errors over all elements instead of
# averaging them.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
# 500 iterations on the same batch; every model(x) call re-samples the
# network depth, so consecutive steps may train differently-shaped graphs.
for t in range(500):
    # Forward pass
    y_pred = model(x)

    # Compute the loss; report it every 100 iterations (t = 99, 199, ...).
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Before the backward pass, use the optimizer to zero the gradients of
    # all the parameters it will update. This is needed because, by default,
    # each .backward() call accumulates gradients into the buffers instead of
    # overwriting them.
    optimizer.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to the
    # model's parameters.
    loss.backward()

    # Calling the optimizer's step function updates the parameters.
    optimizer.step()

99 44.13999938964844
199 2.320537805557251
299 0.7980716824531555
399 0.22657804191112518
499 0.41260507702827454
