<a href="https://colab.research.google.com/github/KimBbaoro/KEB_toy/blob/master/GRU_Cell_practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dataset
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
import math

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
torch.manual_seed(125)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(125)

In [3]:
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])

In [4]:
from torchvision.datasets import MNIST
download_root = "/content/sample_data"


train_dataset = MNIST(download_root, transform = mnist_transform, train = True, download = True)
valid_dataset = MNIST(download_root, transform = mnist_transform, train = False, download = True)
test_dataset = MNIST(download_root, transform = mnist_transform, train = False, download = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 75415455.86it/s]


Extracting /content/sample_data/MNIST/raw/train-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 33288181.87it/s]


Extracting /content/sample_data/MNIST/raw/train-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 20894689.22it/s]


Extracting /content/sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3700568.91it/s]


Extracting /content/sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw



In [5]:
batch_size = 64
train_loader = DataLoader(dataset = train_dataset,
                          batch_size = batch_size,
                          shuffle = True)
valid_loader = DataLoader(dataset = test_dataset,
                          batch_size = batch_size,
                          shuffle = True)
test_loader = DataLoader(dataset = train_dataset,
                          batch_size = batch_size,
                          shuffle = True)

In [6]:
batch_size = 100
n_iters = 6000
num_epochs = n_iters / (len(train_dataset)/batch_size)
num_epochs = int(num_epochs)

In [9]:
class GRUCell(nn.Module):
  def __init__(self, input_size, hidden_size, bias = True):
    super(GRUCell, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.bias = bias
    self.x2h = nn.Linear(input_size, 3*hidden_size, bias)
    self.h2h = nn.Linear(hidden_size, 3*hidden_size, bias = bias)

  def reset_parameters(self):
    std = 1.0/math.sqrt(self.hidden_size)
    for w in self.parameters():
      w.data.uniform_(-std, std)
  
  def forward(self,x, hidden):
    x = x.view(-1, x.size(1))

    gate_x = self.x2h(x) #lstm은 x2h + h2h, gur에서는 개별적인 상태
    gate_h = self.h2h(hidden)
    gate_x = gate_x.squeeze()
    gate_h = gate_h.squeeze()

    i_r, i_i, i_n = gate_x.chunk(3,1) #세개로 쪼갬//탄젠트 활성화 함수가 적용되는 부분을 newgate로.
    h_r, h_i, h_n = gate_h.chunk(3,1)

    resetgate = F.sigmoid(i_r + h_r)
    inputgate = F.sigmoid(i_i + h_i)
    newgate = F.tanh(i_n +(resetgate*h_n))
    hy = newgate + inputgate*(hidden-newgate)
    return hy

In [22]:
class GRUModel(nn.Module):
  def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias = True):
    super(GRUModel, self).__init__()
    self.hidden_dim = hidden_dim
    self.layer_dim = layer_dim
    self.gru_cell = GRUCell(input_dim, hidden_dim , layer_dim)
    self.fc = nn.Linear(hidden_dim, output_dim)

  def forward(self,x):
    if torch.cuda.is_available():
      h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).cuda())
    else:
      h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))

    outs = []
    hn = h0[0,:,:]
    for seq in range(x.size(1)):
      hn = self.gru_cell(x[:,seq,:], hn)
      outs.append(hn)
      out = outs[-1].squeeze()
      out = self.fc(out)
    return out

In [23]:
input_dim = 28
hidden_dim = 128
layer_dim = 1
output_dim = 10

model = GRUModel(input_dim, hidden_dim,layer_dim,output_dim)

if torch.cuda.is_available():
  model.cuda()

criterion = nn.CrossEntropyLoss()
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

In [None]:
seq_dim = 28
loss_list = []
iter = 0
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    if torch.cuda.is_available():
      images = Variable(images.view(-1, seq_dim, input_dim).cuda())
      labels = Variable(labels.cuda())

    else:
      images = Variable(images.view(-1, seq_dim, input_dim))
      labels = Variable(labels)

    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    if torch.cuda.is_available():
      loss.cuda()
    loss.backward()
    optimizer.step()

    loss_list.append(loss.item())
    iter +=1


    if iter % 500 == 0:         
      correct = 0
      total = 0
      for images, labels in valid_loader:
          if torch.cuda.is_available():
              images = Variable(images.view(-1, seq_dim, input_dim).cuda())
          else:
              images = Variable(images.view(-1 , seq_dim, input_dim))
          
          outputs = model(images)
          _, predicted = torch.max(outputs.data, 1)
          total += labels.size(0)

          if torch.cuda.is_available():
              correct += (predicted.cpu() == labels.cpu()).sum()
          else:
              correct += (predicted == labels).sum()
        
      accuracy = 100 * correct / total
      print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))



Iteration: 500. Loss: 0.7814511060714722. Accuracy: 67.12999725341797
Iteration: 1000. Loss: 0.6065189838409424. Accuracy: 88.0
Iteration: 1500. Loss: 0.22339999675750732. Accuracy: 92.55999755859375
Iteration: 2000. Loss: 0.09184202551841736. Accuracy: 94.7300033569336
