https://yjs-program.tistory.com/165

In [21]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
#import torch.nn.functional as F
#from torchtext import data, datasets


In [10]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [19]:
# Hyperparameters.
# Model geometry: each sample is fed to the network as `sequence_length`
# time steps of `input_size` features.
sequence_length = 28  # time steps per sample
input_size = 28       # features per time step
hidden_size = 128     # width of the recurrent hidden state
num_layers = 2        # stacked recurrent layers
num_classes = 10      # number of output logits

# Optimisation settings.
batch_size = 100
num_epochs = 2
lr = 1e-3

In [20]:
# MNIST train/test splits, stored under ./data (downloaded on first use) and
# converted to tensors by ToTensor().
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

test_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [22]:
# Batched iterators over the two splits; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [107]:
class RNN(nn.Module):
    """Recurrent classifier: a stacked RNN/GRU/LSTM followed by a linear head.

    The recurrent layer is created with ``batch_first=True``, so inputs are
    laid out as ``(batch, sequence, input_size)``. Only the recurrent output
    at the final time step is fed to the linear layer, producing
    ``(batch, num_classes)`` logits.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the recurrent hidden state.
        num_layers: Number of stacked recurrent layers.
        num_classes: Number of output logits.
        model_type: One of ``"RNN"``, ``"GRU"`` or ``"LSTM"``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    # Supported recurrent layer constructors, keyed by their public name.
    _CELLS = {"RNN": nn.RNN, "GRU": nn.GRU, "LSTM": nn.LSTM}

    def __init__(self, input_size, hidden_size, num_layers, num_classes, model_type: str):
        super().__init__()
        # Fail at construction time instead of leaving `self.rnn` undefined
        # and crashing later inside forward().
        if model_type not in self._CELLS:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                f"expected one of {sorted(self._CELLS)}"
            )

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.model_type = model_type

        # x -> (batch, sequence, input_size)
        self.rnn = self._CELLS[model_type](
            input_size, hidden_size, num_layers, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``.

        ``x`` is expected as ``(batch, sequence, input_size)``.
        """
        # Zero initial hidden state, (num_layers, batch, hidden_size), created
        # on the same device/dtype as the input rather than a global device.
        h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

        if self.model_type == "LSTM":
            # Only the LSTM takes a cell state alongside the hidden state.
            c0 = torch.zeros_like(h0)
            out, _ = self.rnn(x, (h0, c0))  # (batch, sequence, hidden_size)
        else:
            out, _ = self.rnn(x, h0)  # (batch, sequence, hidden_size)

        # Classify from the output at the last time step only.
        return self.fc(out[:, -1, :])
        
    

In [108]:
# Baseline vanilla-RNN classifier built from the hyperparameters above
# (input_size = 28, hidden_size = 128, num_classes = 10), with a
# cross-entropy loss and an Adam optimizer over its parameters.
model = RNN(input_size, hidden_size, num_layers, num_classes, "RNN")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [165]:
model_type = ["RNN", "GRU", "LSTM"]


def make_model(model_type):
    """Build an RNN classifier of the given recurrent type and move it to `device`."""
    return RNN(input_size, hidden_size, num_layers, num_classes, model_type).to(device)


# Shape smoke test: push one training batch through every model variant.
x, y = next(iter(train_loader))
# The models live on `device`, so the batch must be moved there as well;
# otherwise this fails with a device mismatch whenever CUDA is available.
x = x.reshape(x.size(0), sequence_length, input_size).to(device)
models = [make_model(rnn_type) for rnn_type in model_type]
# No gradients are needed just to inspect output shapes.
with torch.no_grad():
    for rnn_name, model in zip(model_type, models):
        print(f"model_type = {rnn_name}, shape:{model(x).shape}")

model_type = RNN, shape:torch.Size([100, 10])
model_type = GRU, shape:torch.Size([100, 10])
model_type = LSTM, shape:torch.Size([100, 10])


In [168]:
# Train each model variant in turn, each with its own fresh Adam optimizer,
# logging the current batch loss every 100 steps.
n_total_steps = len(train_loader)
for rnn_name, model in zip(model_type, models):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        for step, (images, labels) in enumerate(train_loader, start=1):
            # Reshape each image batch to (batch, sequence_length, input_size).
            images = images.reshape(-1, sequence_length, input_size).to(device)
            labels = labels.to(device)

            # Forward pass and loss.
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f"[{rnn_name}] Epoch [{epoch + 1}/{num_epochs}], Step [{step}/{n_total_steps}], Loss:{loss.item():.4f}")
    print("\n")

[RNN] Epoch [1/2], Step [100/600], Loss:1.0091
[RNN] Epoch [1/2], Step [200/600], Loss:0.8586
[RNN] Epoch [1/2], Step [300/600], Loss:0.5610
[RNN] Epoch [1/2], Step [400/600], Loss:0.4879
[RNN] Epoch [1/2], Step [500/600], Loss:0.5852
[RNN] Epoch [1/2], Step [600/600], Loss:0.5508
[RNN] Epoch [2/2], Step [100/600], Loss:0.3775
[RNN] Epoch [2/2], Step [200/600], Loss:0.1994
[RNN] Epoch [2/2], Step [300/600], Loss:0.4172
[RNN] Epoch [2/2], Step [400/600], Loss:0.1360
[RNN] Epoch [2/2], Step [500/600], Loss:0.2592
[RNN] Epoch [2/2], Step [600/600], Loss:0.2185
[GRU] Epoch [1/2], Step [100/600], Loss:0.6300
[GRU] Epoch [1/2], Step [200/600], Loss:0.6417
[GRU] Epoch [1/2], Step [300/600], Loss:0.2719
[GRU] Epoch [1/2], Step [400/600], Loss:0.2289
[GRU] Epoch [1/2], Step [500/600], Loss:0.1510
[GRU] Epoch [1/2], Step [600/600], Loss:0.1237
[GRU] Epoch [2/2], Step [100/600], Loss:0.2230
[GRU] Epoch [2/2], Step [200/600], Loss:0.1653
[GRU] Epoch [2/2], Step [300/600], Loss:0.1623
[GRU] Epoch [