<a href="https://colab.research.google.com/github/ahxlzjt/MedImagingDL/blob/CH01_Deep-Learning-Basic/RNN_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dset

In [None]:
# MNIST training split: stored in the current directory, fetched on first use,
# and each image converted to a tensor by ToTensor().
train_dataset = dset.MNIST(
    '.',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# MNIST test split: read from the same root directory (no download requested
# here), with the same tensor conversion.
test_dataset = dset.MNIST(
    '.',
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 100

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

In [None]:
# RNNs are used mostly in natural language processing (less often for images).
#  The goal here is classification.

# Structure of the RNN model
#  An RNN followed by a final FC layer that performs the classification.
#  The RNN takes a hidden dim argument that sets the size of its hidden state.
#  batch_first tells the RNN that the first dimension of the data is the batch.
#  ReLU is used as the activation so the hidden-state update is non-linear.
#  The FC layer takes an input of size hidden dim and produces output dim values.
#  The hidden state cannot be computed before any input has been seen, so it
#  starts from a fixed initial value (a zero vector).
#  -> For each input, the RNN updates the hidden state and produces an output;
#  the output of the last step is fed to the FC layer to get the final
#  classification scores.

class RNNModel(nn.Module):
    """ReLU RNN followed by a linear classifier on the last time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the RNN hidden state.
        layer_dim: Number of stacked RNN layers.
        output_dim: Number of output classes.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        # Kept on the instance because forward() needs them to size h0.
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return class scores of shape (batch, output_dim).

        Args:
            x: Input tensor of shape (batch, seq_len, input_dim).
        """
        # Zero initial hidden state, created directly on x's device and in x's
        # dtype: avoids a CPU allocation plus copy on every call and keeps the
        # model usable when it is not float32 (e.g. after .double()).
        h0 = torch.zeros(
            self.layer_dim, x.size(0), self.hidden_dim,
            device=x.device, dtype=x.dtype,
        )

        # The final hidden state returned by the RNN is not needed here.
        out, _ = self.rnn(x, h0)

        # Classify from the RNN output at the last time step only.
        out = self.fc(out[:, -1, :])

        return out

In [None]:
# Model sizes: features per time step, hidden-state size, number of stacked
# RNN layers, and number of output classes.
input_dim, hidden_dim, layer_dim, output_dim = 28, 100, 1, 10

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network and move its parameters to the selected device.
model = RNNModel(input_dim, hidden_dim, layer_dim, output_dim)
model = model.to(device)

In [None]:
# Cross-entropy loss; the training loop applies it to the model outputs and
# the batch labels.
criterion = nn.CrossEntropyLoss()

In [None]:
# Step size for the optimizer.
learning_rate = 0.01

# Plain stochastic gradient descent over all of the model's parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
# Training code
#  Each batch of images is reshaped with view() to (batch, input_dim, input_dim)
#  -> the model then reads every image as a sequence of input_dim steps with
#  input_dim features each, whatever extra dimensions the loader's batch had.

num_epochs = 20


for epoch in range(num_epochs):
    # ---- training pass over the whole training set ----
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        images = images.view(-1, input_dim, input_dim)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

    # ---- evaluation pass over the test set ----
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation; disabling them avoids building
    # autograd graphs for every test batch.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Same reshape as in training (was hard-coded to 28 before).
            images = images.view(-1, input_dim, input_dim)

            outputs = model(images)
            # Index of the largest score along the class dimension.
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            # .item() keeps the running count a plain Python int.
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total

    # NOTE: the loss printed here is the loss of the last training batch of
    # the epoch, not an average over the epoch.
    print('Epoch: {}. Loss: {}. Accuracy: {}'.format(epoch, loss.item(), accuracy))

Epoch: 0. Loss: 2.293511152267456. Accuracy: 19.809999465942383
Epoch: 1. Loss: 2.2594332695007324. Accuracy: 19.459999084472656
Epoch: 2. Loss: 1.3639010190963745. Accuracy: 55.28999710083008
Epoch: 3. Loss: 1.1228222846984863. Accuracy: 71.29000091552734
Epoch: 4. Loss: 1.0093928575515747. Accuracy: 75.12999725341797
Epoch: 5. Loss: 0.47615233063697815. Accuracy: 82.68000030517578
Epoch: 6. Loss: 0.5150405168533325. Accuracy: 83.18999481201172
Epoch: 7. Loss: 0.2995639145374298. Accuracy: 82.41999816894531
Epoch: 8. Loss: 0.36420580744743347. Accuracy: 89.5199966430664
Epoch: 9. Loss: 0.2922385334968567. Accuracy: 90.5
Epoch: 10. Loss: 0.17149147391319275. Accuracy: 92.6199951171875
Epoch: 11. Loss: 0.21500307321548462. Accuracy: 93.11000061035156
Epoch: 12. Loss: 0.18573856353759766. Accuracy: 94.11000061035156
Epoch: 13. Loss: 0.12494295835494995. Accuracy: 94.43999481201172
Epoch: 14. Loss: 0.2083677351474762. Accuracy: 95.18000030517578
Epoch: 15. Loss: 0.2874150574207306. Accura