<a href="https://colab.research.google.com/github/KennyThinh/Algorithm/blob/master/05_LoadSaveModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
!mkdir /dataset
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz -P /dataset
!tar -zxvf /dataset/MNIST.tar.gz -C /dataset/

--2021-03-24 06:35:21--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2021-03-24 06:35:22--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘/dataset/MNIST.tar.gz’

MNIST.tar.gz            [            <=>     ]  33.20M  6.47MB/s    in 5.4s    

2021-03-24 06:35:28 (6.17 MB/s) - ‘/dataset/MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-idx3-u

In [4]:
#input N*1*28*28 --> 28 time series, each series has 28 features
class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super(RNN, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.rnn = nn.RNN(input_size, hidden_size = hidden_size, num_layers= num_layers, batch_first = True)
    self.fc = nn.Linear(hidden_size*sequence_length,num_classes)
    

  def forward(self, x):# x shape N*sequence_length*input_size
    h0 = torch.zeros(self.num_layers,x.shape[0], self.hidden_size).to(device)
    out, _ = self.rnn(x, h0) #because we set batch_first = True, then output will be (batch, seq, hidden_size)
    out = out.reshape(out.shape[0],-1)
    out = self.fc(out)
    return out

class GRU(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super(GRU, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.gru = nn.GRU(input_size, hidden_size = hidden_size, num_layers= num_layers, batch_first = True)
    self.fc = nn.Linear(hidden_size,num_classes)
    

  def forward(self, x):# x shape N*sequence_length*input_size
    h0 = torch.zeros(self.num_layers,x.shape[0], self.hidden_size).to(device)
    out, _ = self.gru(x, h0) 
    out = out[:,-1,:] #we want out is (N,hidden_size) so that we can feed to FC
    out = self.fc(out)
    return out

class LSTM(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super(LSTM, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.lstm = nn.LSTM(input_size, hidden_size = hidden_size, num_layers= num_layers, batch_first = True)
    self.fc = nn.Linear(hidden_size,num_classes)    

  def forward(self, x):# x shape N*sequence_length*input_size
    h0 = torch.zeros(self.num_layers,x.shape[0], self.hidden_size).to(device)
    c0 = torch.zeros(self.num_layers,x.shape[0], self.hidden_size).to(device)
    out, _ = self.lstm(x, (h0,c0))
    out = out[:,-1,:] #we want out is (N,hidden_size) so that we can feed to FC
    out = self.fc(out)
    return out


In [20]:
sequence_length= 28 
input_size = 28 #features
num_layers = 2 #no of RNN cell
hidden_size = 128 #size of hidden state
num_classes = 10
num_epochs = 6
batch_size = 64
learning_rate = 0.001
is_load_model = True

In [21]:
def save_model(dict_state, filename="checkpoint.pth.tar"):
  print("--> Saving checkpoint")
  torch.save(dict_state, filename)
def load_model(filename="checkpoint.pth.tar"):
  print("<-- Load checkpoint")
  checkpoint = torch.load(filename)
  model.load_state_dict(checkpoint["state_dict"])
  optimizer.load_state_dict(checkpoint["optimizer"])


In [22]:
train_dataset = datasets.MNIST(root="/dataset/",train=True, transform=transforms.ToTensor(), download=False)
train_loader = DataLoader(train_dataset,batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root="/dataset/", train=True, transform=transforms.ToTensor(), download=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [23]:
# model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# model = GRU(input_size, hidden_size, num_layers, num_classes).to(device)
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)

In [24]:
#loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
#train
if is_load_model:
  load_model()

for epoch in range(num_epochs):

  if epoch % 2 == 0:
    checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    save_model(checkpoint)

  for batch_idx, (data, target) in enumerate(train_loader):
    data = data.to(device).squeeze(1) #shape (64, 28, 28)
    
    target = target.to(device)

    #forwad
    scores = model(data)
    loss = criterion(scores, target)

    #backward
    optimizer.zero_grad()
    loss.backward()

    #gradient
    optimizer.step()
  print(f"loss at {epoch} is {loss}")

<-- Load checkpoint
--> Saving checkpoint
loss at 0 is 0.030677832663059235


KeyboardInterrupt: ignored

In [None]:
def check_accuracy(loader, model):
  num_correct = 0
  num_samples = 0
  model.eval()
  with torch.no_grad():
    for x, y in loader:
      x = x.to(device).squeeze(1)
      y = y.to(device)
      scores = model(x)
      _,predictions = torch.max(scores, dim= 1)
      num_correct += (predictions==y).sum()
      num_samples += predictions.size(0)
    print(f'Got {num_correct}/{num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')
check_accuracy(train_loader, model)
check_accuracy(test_loader, model)

Got 58767/60000 with accuracy 97.95
Got 58767/60000 with accuracy 97.95


In [None]:
# RNN: 
# Got 58541/60000 with accuracy 97.57
# Got 58541/60000 with accuracy 97.57

# GRU
# Got 58597/60000 with accuracy 97.66
# Got 58597/60000 with accuracy 97.66

#LSTM
# Got 58767/60000 with accuracy 97.95
# Got 58767/60000 with accuracy 97.95