In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Report the installed PyTorch build, then choose the compute device:
# the first CUDA GPU when one is visible, otherwise the CPU.
print('Pytorch version :', torch.__version__)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'device : [{device}]')

Pytorch version : 1.7.1
device : [cuda:0]


In [11]:
from torchvision import datasets, transforms
# MNIST train/test splits, each image converted to a float tensor by ToTensor().
# download=True only fetches the files when they are missing under ./data, so the
# cell also works on a fresh checkout instead of raising "Dataset not found".
to_tensor = transforms.ToTensor()
train_data = datasets.MNIST(root='./data',train=True,transform=to_tensor,download=True)
test_data = datasets.MNIST(root='./data',train=False,transform=to_tensor,download=True)

In [12]:
BATCH_SIZE = 256
# Training batches are drawn in a freshly shuffled order on every pass.
train_iter = DataLoader(train_data,shuffle=True,batch_size=BATCH_SIZE,num_workers=1)
# Evaluation gains nothing from shuffling; a fixed order keeps runs repeatable.
test_iter = DataLoader(test_data,shuffle=False,batch_size=BATCH_SIZE,num_workers=1)

In [23]:
class LSTM(nn.Module):
    """Sequence classifier using a hand-written LSTM cell made of per-gate ``nn.Linear`` layers.

    The cell is unrolled over the sequence dimension and the hidden state of
    the last time step is projected to ``y_dim`` logits.

    Args:
        x_dim: Number of input features per time step.
        h_dim: Size of the hidden and cell states.
        y_dim: Number of output logits.
        n_layer: Stored on the instance for interface compatibility. The gate
            layers defined here describe a single recurrent layer, so this
            value is not used by ``forward``.
    """

    def __init__(self,x_dim=28,h_dim=256,y_dim=10,n_layer=3 ):
        super(LSTM,self).__init__()
        self.x_dim = x_dim
        self.h_dim = h_dim
        self.y_dim = y_dim
        self.n_layer = n_layer
        # forget gate
        self.lin_xh_f = nn.Linear(self.x_dim,self.h_dim)
        self.lin_hh_f = nn.Linear(self.h_dim,self.h_dim)
        # input gate (in1) and candidate cell state (in2)
        self.lin_xh_in1 = nn.Linear(self.x_dim,self.h_dim)
        self.lin_hh_in1 = nn.Linear(self.h_dim,self.h_dim)
        self.lin_xh_in2 = nn.Linear(self.x_dim,self.h_dim)
        self.lin_hh_in2 = nn.Linear(self.h_dim,self.h_dim)
        # output gate
        self.lin_xh_out = nn.Linear(self.x_dim,self.h_dim)
        self.lin_hh_out = nn.Linear(self.h_dim,self.h_dim)

        # classification head applied to the final hidden state
        self.lin = nn.Linear(self.h_dim,self.y_dim)

        self.init_param()

    def forward(self,input):
        """Run the LSTM over ``input`` and return logits of shape ``(batch, y_dim)``.

        Args:
            input: Tensor whose first dimension is the batch and whose last
                dimension is ``x_dim``. Any dimensions in between are
                flattened into the sequence axis, so both ``(batch, seq, x_dim)``
                and image batches such as ``(batch, 1, 28, 28)`` are accepted.

        Returns:
            Tensor of shape ``(batch, y_dim)`` computed from the hidden state
            after the last time step.
        """
        x = input.reshape(input.size(0), -1, self.x_dim)
        batch_size, seq_len = x.size(0), x.size(1)

        # Initial states are created on the input's device/dtype so the module
        # does not depend on a notebook-level ``device`` variable.
        h_t = x.new_zeros(batch_size, self.h_dim)
        c_t = x.new_zeros(batch_size, self.h_dim)

        # The input-to-hidden projections do not depend on the recurrence, so
        # compute them once for every time step instead of inside the loop.
        x_f = self.lin_xh_f(x)
        x_i = self.lin_xh_in1(x)
        x_g = self.lin_xh_in2(x)
        x_o = self.lin_xh_out(x)

        for t in range(seq_len):
            f_t = torch.sigmoid(x_f[:, t] + self.lin_hh_f(h_t))    # forget gate
            i_t = torch.sigmoid(x_i[:, t] + self.lin_hh_in1(h_t))  # input gate
            g_t = torch.tanh(x_g[:, t] + self.lin_hh_in2(h_t))     # candidate cell state
            o_t = torch.sigmoid(x_o[:, t] + self.lin_hh_out(h_t))  # output gate
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

        out = self.lin(h_t).view(-1,self.y_dim)
        return out

    def init_param(self):
        """Kaiming-normal initialise every ``nn.Linear`` weight and zero its bias."""
        # Iterate over modules rather than parameter names: parameter names look
        # like 'lin_xh_f.weight' and never contain the class name 'Linear'.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                nn.init.zeros_(module.bias)

In [24]:
# Build the classifier and move it to the selected device, then create the
# Adam optimizer over its parameters and the cross-entropy criterion.
model = LSTM(x_dim=28, h_dim=256, y_dim=10, n_layer=3)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss = nn.CrossEntropyLoss()


In [21]:
# Show the model itself followed by each of its submodules.
# NOTE(review): the saved output below prints `RNN(` although `model` is built
# from `LSTM`; it looks stale — re-run after a kernel restart to confirm.
[submodule for submodule in model.modules()]

[RNN(
   (lin_xh_f): Linear(in_features=28, out_features=256, bias=True)
   (lin_hh_f): Linear(in_features=256, out_features=256, bias=True)
   (lin_xh_in1): Linear(in_features=28, out_features=256, bias=True)
   (lin_hh_in1): Linear(in_features=256, out_features=256, bias=True)
   (lin_xh_in2): Linear(in_features=28, out_features=256, bias=True)
   (lin_hh_in2): Linear(in_features=256, out_features=256, bias=True)
   (lin_xh_out): Linear(in_features=28, out_features=256, bias=True)
   (lin_hh_out): Linear(in_features=256, out_features=256, bias=True)
   (lin): Linear(in_features=256, out_features=10, bias=True)
 ),
 Linear(in_features=28, out_features=256, bias=True),
 Linear(in_features=256, out_features=256, bias=True),
 Linear(in_features=28, out_features=256, bias=True),
 Linear(in_features=256, out_features=256, bias=True),
 Linear(in_features=28, out_features=256, bias=True),
 Linear(in_features=256, out_features=256, bias=True),
 Linear(in_features=28, out_features=256, bias=Tr

In [22]:
# Show (name, module) pairs for the model and each of its submodules.
# NOTE(review): the saved output below prints `RNN(` although `model` is built
# from `LSTM`; it looks stale — re-run after a kernel restart to confirm.
[(name, submodule) for name, submodule in model.named_modules()]

[('',
  RNN(
    (lin_xh_f): Linear(in_features=28, out_features=256, bias=True)
    (lin_hh_f): Linear(in_features=256, out_features=256, bias=True)
    (lin_xh_in1): Linear(in_features=28, out_features=256, bias=True)
    (lin_hh_in1): Linear(in_features=256, out_features=256, bias=True)
    (lin_xh_in2): Linear(in_features=28, out_features=256, bias=True)
    (lin_hh_in2): Linear(in_features=256, out_features=256, bias=True)
    (lin_xh_out): Linear(in_features=28, out_features=256, bias=True)
    (lin_hh_out): Linear(in_features=256, out_features=256, bias=True)
    (lin): Linear(in_features=256, out_features=10, bias=True)
  )),
 ('lin_xh_f', Linear(in_features=28, out_features=256, bias=True)),
 ('lin_hh_f', Linear(in_features=256, out_features=256, bias=True)),
 ('lin_xh_in1', Linear(in_features=28, out_features=256, bias=True)),
 ('lin_hh_in1', Linear(in_features=256, out_features=256, bias=True)),
 ('lin_xh_in2', Linear(in_features=28, out_features=256, bias=True)),
 ('lin_hh_