In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10, CIFAR100
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

In [2]:
path = './datasets/'

transform = transforms.Compose([transforms.ToTensor()])

train_data = MNIST(root=path, train=True, transform=transform, download=True)
test_data = MNIST(root=path, train=False, transform=transform, download=True)

batch_size = 100

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, num_workers=4)

print(train_data)
print(test_data)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 74085571.15it/s]


Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 78813073.41it/s]


Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 28128587.39it/s]


Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3245958.22it/s]


Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./datasets/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./datasets/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )




In [5]:
_, seq_len, input_size = train_data[0][0].shape # (1,28,28)
output_size = len(train_data.classes)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
hidden_size = input_size*2
model_name = 'LSTM'

In [11]:
class RNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.cell = nn.RNNCell(input_size=self.input_size,
                               hidden_size=self.hidden_size)
        self.fc = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        x = x.reshape(-1, seq_len, self.input_size).permute((1,0,2)) # transpose 28,100,28
         # img data라서 필요한 부분, 100x1x28x28 -> 100x28x28
        hidden_state = torch.zeros(batch_size, self.hidden_size).to(device) # inital hidden을 세팅해준다. h_0
        for i in range(seq_len):
            hidden_state = self.cell(x[i], hidden_state)
        out = self.fc(hidden_state)
        return out

In [12]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.cell = nn.LSTMCell(input_size=self.input_size,
                               hidden_size=self.hidden_size)
        self.fc = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        x = x.reshape(-1, seq_len, self.input_size).permute((1,0,2))
        hidden_state = torch.zeros(batch_size, self.hidden_size).to(device)
        cell_state = torch.zeros(batch_size, self.hidden_size).to(device)
        for i in range(seq_len):
            hidden_state, cell_state = self.cell(x[i], (hidden_state, cell_state))
        out = self.fc(hidden_state)
        return out

In [13]:
class GRUClassifier(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.cell = nn.GRUCell(input_size=self.input_size,
                               hidden_size=self.hidden_size)
        self.fc = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        x = x.reshape(-1, seq_len, self.input_size).permute((1,0,2))
        hidden_state = torch.zeros(batch_size, self.hidden_size).to(device)
        for i in range(seq_len):
            hidden_state = self.cell(x[i], hidden_state)
        out = self.fc(hidden_state)
        return out

In [14]:
if model_name == 'RNN':
    classifier = RNNClassifier
elif model_name == 'LSTM':
    classifier = LSTMClassifier
else:
    classifier = GRUClassifier

In [16]:
model = classifier(input_size, hidden_size).to(device)
loss = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [17]:
num_epoch = 10
train_loss_lst, test_loss_lst = list(), list()

for i in range(num_epoch):
    # training
    model.train()

    total_loss = 0
    cnt = 0

    for batch_idx, (x,y) in enumerate(train_loader):

        x,y = x.to(device), y.to(device)
        y_est = model.forward(x)
        cost = loss(y_est, y)

        total_loss += cost.item()

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        pred = torch.argmax(y_est, dim=1)
        cnt += (pred == y).sum().item()

    acc = cnt / len(train_data)
    ave_loss = total_loss / len(train_data)

    train_loss_lst.append(ave_loss)

    if i % 1 == 0:
        print(f"\nEpoch {i} Train : {ave_loss:.3f} / {acc:.3f}")

    #testing
    model.eval()

    total_loss = 0
    cnt = 0

    with torch.no_grad():
        for batch, (x,y) in enumerate(test_loader):

            x, y = x.to(device), y.to(device)

            y_est = model.forward(x)
            pred = torch.argmax(y_est, dim=1)

            total_loss += cost.item()

        acc = cnt / len(test_data)
        ave_loss = total_loss / len(test_data)

        test_loss_lst.append(ave_loss)

        if i % 1 == 0:
            print(f"Epoch {i} Test : {ave_loss:.3f} / {acc:.3f}")

print()
num_parameter = 0
for parameter in model.parameters():
    print(parameter.shape)
    num_parameter += np.prod(parameter.size())
print(num_parameter)




Epoch 0 Train : 0.748 / 0.758
Epoch 0 Test : 0.345 / 0.000

Epoch 1 Train : 0.230 / 0.933
Epoch 1 Test : 0.098 / 0.000

Epoch 2 Train : 0.160 / 0.953
Epoch 2 Test : 0.129 / 0.000

Epoch 3 Train : 0.127 / 0.963
Epoch 3 Test : 0.083 / 0.000

Epoch 4 Train : 0.109 / 0.968
Epoch 4 Test : 0.049 / 0.000

Epoch 5 Train : 0.095 / 0.971
Epoch 5 Test : 0.097 / 0.000

Epoch 6 Train : 0.084 / 0.975
Epoch 6 Test : 0.047 / 0.000

Epoch 7 Train : 0.075 / 0.977
Epoch 7 Test : 0.069 / 0.000

Epoch 8 Train : 0.069 / 0.979
Epoch 8 Test : 0.039 / 0.000

Epoch 9 Train : 0.062 / 0.981
Epoch 9 Test : 0.028 / 0.000

torch.Size([224, 28])
torch.Size([224, 56])
torch.Size([224])
torch.Size([224])
torch.Size([10, 56])
torch.Size([10])
19834
