In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.init as init
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Report the installed PyTorch build, then select the compute device:
# the first CUDA GPU when one is available, otherwise the CPU.
print("PyTorch version:[%s]." % (torch.__version__,))
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("device:[%s]." % (device,))

In [None]:
from torchvision import datasets,transforms
# MNIST train/test splits, downloaded into DATA_ROOT on first use.
DATA_ROOT = './data/'
BATCH_SIZE = 64
mnist_train = datasets.MNIST(root=DATA_ROOT, train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root=DATA_ROOT, train=False, transform=transforms.ToTensor(), download=True)

# Training batches are reshuffled every epoch.
train_iter = torch.utils.data.DataLoader(
    mnist_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
# Accuracy does not depend on sample order, so the test loader is left
# unshuffled: evaluation stays deterministic and does not draw from the
# global RNG between training epochs.
test_iter = torch.utils.data.DataLoader(
    mnist_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)
print ("Done.")

In [None]:
class RecurrentNeuralNetworkClass(nn.Module):
    """Stacked ``nn.LSTMCell`` classifier with extra learnable per-layer gate weights.

    Each layer owns an ``nn.LSTMCell`` plus a ``3 * hdim`` parameter vector
    that is split into ``ci``, ``cf`` and ``co``.  At every time step the
    cell's output ``(h_next, c_next)`` is re-expressed through gate values
    derived from those vectors (see ``_step``) to produce the layer's new
    hidden and cell state.  The last hidden state of the top layer is fed to
    a linear layer to obtain ``ydim`` scores.

    Args:
        name: Free-form identifier stored on the instance.
        xdim: Size of each input time step.
        hdim: Hidden/cell state size of every layer.
        ydim: Number of output scores.
        n_layer: Number of stacked cells.
    """

    # Added to denominators in the gate computations.  The cell state is all
    # zeros at the first time step, so dividing by it unguarded yields
    # inf * 0 = NaN in the state update.
    EPSILON = 1e-8

    def __init__(self, name='rnn', xdim=28, hdim=256, ydim=10, n_layer=3):
        super(RecurrentNeuralNetworkClass, self).__init__()
        self.name = name
        self.xdim = xdim
        self.hdim = hdim
        self.ydim = ydim
        self.n_layer = n_layer

        # Layer 0 consumes the raw input; deeper layers consume the hidden
        # state of the layer below.
        self.cells = nn.ModuleList(
            [nn.LSTMCell(input_size=self.xdim if i == 0 else self.hdim, hidden_size=self.hdim)
             for i in range(n_layer)])
        # One [ci | cf | co] vector per layer.
        self.cell_weights = nn.ParameterList(
            [nn.Parameter(torch.randn(3 * self.hdim)) for _ in range(n_layer)])
        self.lin = nn.Linear(self.hdim, self.ydim)

        # Weight initialization: Xavier for input weights, orthogonal for
        # recurrent weights, zeros for biases.
        for cell in self.cells:
            for param_name, param in cell.named_parameters():
                if 'weight_ih' in param_name:
                    nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in param_name:
                    nn.init.orthogonal_(param.data)
                elif 'bias' in param_name:
                    param.data.fill_(0)

    def _step(self, layer, input_chunk, h_prev, c_prev):
        """Advance one layer by one time step and return ``(h_new, c_new)``."""
        h_next, c_next = self.cells[layer](input_chunk, (h_prev, c_prev))

        cell_weight = self.cell_weights[layer]
        ci = cell_weight[:self.hdim]
        cf = cell_weight[self.hdim:2 * self.hdim]
        co = cell_weight[2 * self.hdim:]
        eps = self.EPSILON

        input_gate = (c_next - c_prev * torch.sigmoid(cf)) / (1 - torch.sigmoid(cf + ci + eps))
        forget_gate = (c_next - input_gate * torch.sigmoid(ci)) / (c_prev + eps)
        output_gate = h_next / (torch.tanh(c_next) + eps) - co

        c_new = input_gate * torch.sigmoid(ci) + forget_gate * torch.sigmoid(cf) * c_prev
        h_new = torch.sigmoid(output_gate + co) * torch.tanh(c_new)
        return h_new, c_new

    def _unroll(self, x):
        """Run every layer over every time step of ``x`` ([N x L x xdim]).

        Returns ``(top_outputs, h, c)``: the top layer's hidden state at each
        time step (list of L tensors) and the final per-layer hidden and cell
        states (lists of ``n_layer`` tensors, each [N x hdim]).
        """
        batch_size = x.size(0)
        # States are kept as per-layer tensors that get *replaced* each step
        # (never written in place), so tensors stored in ``top_outputs`` are
        # not overwritten by later steps and autograd never sees an in-place
        # modification.
        h = [torch.zeros(batch_size, self.hdim, device=x.device) for _ in range(self.n_layer)]
        c = [torch.zeros(batch_size, self.hdim, device=x.device) for _ in range(self.n_layer)]

        top_outputs = []
        for t in range(x.size(1)):
            layer_input = x[:, t, :]
            for i in range(self.n_layer):
                h[i], c[i] = self._step(i, layer_input, h[i], c[i])
                layer_input = h[i]  # feeds the layer above at this time step
            top_outputs.append(h[-1])
        return top_outputs, h, c

    def forward(self, x):
        """Return class scores of shape [N x ydim] for input ``x`` ([N x L x xdim]).

        NOTE(review): the recurrent states are computed without gradient
        tracking, so only ``self.lin`` receives gradients during training;
        ``self.cells`` and ``self.cell_weights`` stay at their initial values.
        This mirrors the previous implementation, which detached every hidden
        state update.  Enabling gradients here means differentiating through
        the divisions in ``_step``, which should be checked for numerical
        stability first.
        """
        with torch.no_grad():
            _, h, _ = self._unroll(x)
        return self.lin(h[-1]).view([-1, self.ydim])

    def get_rnn_outputs(self, x):
        """Return the recurrent outputs for ``x`` ([N x L x xdim]).

        Returns:
            ``(hidden_seq, (h, c))`` where ``hidden_seq`` is the top layer's
            hidden state at every time step ([N x L x hdim]) and ``h`` / ``c``
            are the final hidden / cell states of all layers
            ([n_layer x N x hdim]).
        """
        top_outputs, h, c = self._unroll(x)
        hidden_seq = torch.stack(top_outputs, dim=1)
        return hidden_seq, (torch.stack(h, dim=0), torch.stack(c, dim=0))

# Build the classifier on the selected device, together with the
# cross-entropy objective and the Adam optimizer over all of its parameters.
R = RecurrentNeuralNetworkClass(name='rnn', xdim=28, hdim=128, ydim=10, n_layer=3)
R = R.to(device)
loss = nn.CrossEntropyLoss()
optm = optim.Adam(params=R.parameters(), lr=1e-2)
print("Done.")

In [None]:
# Smoke test: push a random batch through the recurrent stack and report
# the shapes of the returned sequence and final states.
np.set_printoptions(precision=3)
torch.set_printoptions(precision=3)

x_numpy = np.random.rand(2, 20, 28)  # [N x L x Q]
x_torch = torch.from_numpy(x_numpy).to(dtype=torch.float32).to(device)
rnn_out, final_states = R.get_rnn_outputs(x_torch)  # forward path
hn, cn = final_states

print("rnn_out:", rnn_out.shape)        # [N x L x D]
print("Hidden State hn:", hn.shape)     # [K x N x D]
print("Cell States cn:", cn.shape)      # [K x N x D]

In [None]:
# List every trainable parameter (name, shape, first five values) and
# report the total number of trainable scalars.
np.set_printoptions(precision=3)
n_param = 0
for p_idx, (param_name, param) in enumerate(R.named_parameters()):
    if not param.requires_grad:
        continue  # frozen parameters are neither listed nor counted
    param_numpy = param.detach().cpu().numpy()  # to numpy array
    n_param += param_numpy.reshape(-1).shape[0]
    print("[%d] name:[%s] shape:[%s]." % (p_idx, param_name, param_numpy.shape))
    print("    val:%s" % (param_numpy.reshape(-1)[:5]))
print("Total number of parameters:[%s]." % (format(n_param, ',d')))

In [None]:
# Smoke test: run the full classifier on a random batch and compare the
# input and output shapes.
np.set_printoptions(precision=3)
torch.set_printoptions(precision=3)

x_numpy = np.random.rand(2, 20, 28)  # [N x L x Q]
x_torch = torch.from_numpy(x_numpy).to(dtype=torch.float32).to(device)
y_torch = R.forward(x_torch)  # [N x R] where R is the output dim.
y_numpy = y_torch.detach().cpu().numpy()  # torch tensor to numpy array

print("x_numpy %s" % (x_numpy.shape,))
print("y_numpy %s" % (y_numpy.shape,))

In [None]:
def func_eval(model, data_iter, device, input_shape=(-1, 28, 28)):
    """Return the classification accuracy of ``model`` over ``data_iter``.

    Args:
        model: ``nn.Module`` mapping a reshaped input batch to per-class scores.
        data_iter: Iterable of ``(batch_in, batch_out)`` pairs, where
            ``batch_out`` holds integer class labels.
        device: Device the batches are moved to before the forward pass.
        input_shape: Shape passed to ``batch_in.view`` before the forward
            pass; defaults to a sequence of 28 rows with 28 features each.

    Returns:
        Fraction of correctly classified samples as a float, or ``nan`` when
        ``data_iter`` yields no samples.

    The model is put in eval mode for the duration of the call and then
    restored to whichever mode (train or eval) it was in on entry.
    """
    was_training = model.training
    n_total, n_correct = 0, 0
    model.eval()  # evaluate (affects DropOut and BN)
    try:
        with torch.no_grad():
            for batch_in, batch_out in data_iter:
                y_trgt = batch_out.to(device)
                # Call the module itself (not .forward) so registered hooks run.
                model_pred = model(batch_in.view(*input_shape).to(device))
                _, y_pred = torch.max(model_pred, 1)
                n_correct += (y_pred == y_trgt).sum().item()
                n_total += batch_in.size(0)
    finally:
        # Restore the caller's mode even if a batch raised.
        model.train(was_training)
    if n_total == 0:
        return float('nan')
    return n_correct / n_total
print ("Done")

In [None]:
# Baseline accuracy of the untrained model on both splits.
train_accr, test_accr = (
    func_eval(R, data_iter, device) for data_iter in (train_iter, test_iter)
)
print("train_accr:[%.3f] test_accr:[%.3f]." % (train_accr, test_accr))

In [None]:
# Train the classifier with mini-batch updates, reporting the mean loss and
# the train/test accuracy every `print_every` epochs and after the last epoch.
print ("Start training.")
R.train() # to train mode
EPOCHS,print_every = 30,1
for epoch in range(EPOCHS):
    loss_val_sum = 0.0
    for batch_in,batch_out in train_iter:
        # Forward path (call the module, not .forward, so hooks run)
        y_pred = R(batch_in.view(-1,28,28).to(device))
        loss_out = loss(y_pred,batch_out.to(device))
        # Update
        optm.zero_grad() # reset gradient
        loss_out.backward() # backpropagate
        optm.step() # optimizer update
        # Accumulate a plain Python float: summing the loss tensor itself
        # would keep each batch's autograd history alive for the whole epoch.
        loss_val_sum += loss_out.item()
    loss_val_avg = loss_val_sum/len(train_iter)
    # Print
    if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):
        train_accr = func_eval(R,train_iter,device)
        test_accr = func_eval(R,test_iter,device)
        print ("epoch:[%d] loss:[%.3f] train_accr:[%.3f] test_accr:[%.3f]."%
               (epoch,loss_val_avg,train_accr,test_accr))
print ("Done")

In [None]:
# Draw 25 distinct test images, classify them, and show each one in a
# 5 x 5 grid titled with the predicted and the true label.
n_sample = 25
sample_indices = np.random.choice(len(mnist_test.targets), n_sample, replace=False)
test_x = mnist_test.data[sample_indices]
test_y = mnist_test.targets[sample_indices]

with torch.no_grad():
    R.eval()  # to evaluation mode
    # Raw MNIST pixels are 0..255 integers; rescale to [0, 1] before inference.
    y_pred = R.forward(test_x.view(-1, 28, 28).type(torch.float).to(device) / 255.)
y_pred = y_pred.argmax(axis=1)

fig = plt.figure(figsize=(10, 10))
for idx in range(n_sample):
    ax = fig.add_subplot(5, 5, idx + 1)
    ax.imshow(test_x[idx], cmap='gray')
    ax.axis('off')
    ax.set_title("Pred:%d, Label:%d" % (y_pred[idx], test_y[idx]))
plt.show()
print("Done")