forked from MorvanZhou/PyTorch-Tutorial
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_rnn.py
147 lines (123 loc) · 5.28 KB
/
train_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# python -m pdb tutorial-contents/train_rnn.py train -mnist_dir /Users/Natsume/Downloads/morvan_new_pytorch/mnist -batch_size 32 -test_size 100 -num_epochs 10 -num_batches 50 -net /Users/Natsume/Downloads/temp_folders/402_train_rnn/net.pkl -log /Users/Natsume/Downloads/temp_folders/402_train_rnn/log.pkl
import argparse
import sys
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np
from prepare_mnist import prepareData
from build_rnn import build_net
def train(args):
    """Train (or resume training) an RNN classifier on MNIST.

    With ``train_again=True`` (hard-coded below) the previously saved model
    is reloaded from ``args.net_path`` and its training log (steps, train
    losses, validation losses, accuracies) from ``args.log_path``, and
    training continues from there; otherwise a fresh network, optimizer and
    loss function come from ``build_net(args)``. Every 10 batches the
    averaged training loss, validation loss and test accuracy are recorded
    and printed. After training, the model and log are saved back to the
    same paths and the loss curves are optionally plotted.

    Args:
        args: parsed CLI namespace; uses batch_size, num_batches,
            num_epochs, net_path, log_path, plus whatever prepareData
            reads (mnist_dir, test_size).

    Returns:
        None. Side effects: torch.save() of model and log, console prints,
        and a matplotlib window when plot_loss is True.
    """
    ## hyper-parameters (hard-coded switches; flip train_again to False to
    ## build a brand-new network instead of resuming from disk)
    train_again = True
    plot_loss = True
    # prepare dataset — presumably returns a DataLoader plus a fixed
    # test set already wrapped as Variables (see squeeze comment below);
    # TODO confirm against prepare_mnist.prepareData
    train_loader, test_images, test_labels = prepareData(args)
    # total number of batches of an epoch
    total_train_samples = train_loader.dataset.train_data.__len__()
    # NOTE(review): floor division — drops the final partial batch; the
    # author's own comment below says it should ceil(). This value is
    # never used afterwards, so the bug is currently harmless.
    total_num_batches = int(total_train_samples / args.batch_size)
    # it should add 1 or do ceil()
    ## load rnn or create rnn from scratch
    if train_again:
        # NOTE(review): torch.load of a whole pickled model — only safe
        # for trusted files the script itself wrote.
        rnn = torch.load(args.net_path)
        # optimizer state is NOT restored — a fresh Adam (lr=0.01) is
        # created on every resume, so momentum/moment estimates reset.
        optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)
        loss_func = nn.CrossEntropyLoss()
        steps, losses_train, losses_val, accuracies = torch.load(args.log_path)
        # continue the global step counter from where the last run stopped
        previous_step = steps[-1]
    else:
        rnn, optimizer, loss_func = build_net(args)
        # fresh log buffers, appended to every 10 batches of training
        losses_val = []
        losses_train = []
        accuracies = []
        steps = []
        previous_step = 0
    # for every epoch of training
    for epoch_idx in range(args.num_epochs):
        # per-batch losses, averaged and flushed every 10 batches
        loss_list = []
        # loop every batch of an epoch
        for batch_idx, (batch_img, batch_lab) in enumerate(train_loader):
            # the rnn expects 3-D input (batch, seq, feature):
            # batch_img comes as (batch, 1, 28, 28); view() drops the
            # channel dim to give (batch, 28, 28)
            b_img = Variable(batch_img.view(-1, 28, 28))
            b_lab = Variable(batch_lab)
            # forward/backward/step for one batch.
            # NOTE(review): rnn(...) appears to return
            # (r_out, h_n, h_c, out) with `out` being the class logits —
            # confirm against build_rnn.
            r_out, h_n, h_c, out = rnn(b_img)
            loss = loss_func(out, b_lab)
            optimizer.zero_grad()
            loss.backward()
            ## todo: b_img is required_grad=True right? yes, proved
            optimizer.step()
            # store every batch loss (old-torch idiom: loss is a 1-element
            # Variable, so .data.numpy()[0] extracts the scalar; on
            # torch >= 0.4 this would need loss.item())
            loss_list.append(loss.data.numpy()[0])
            # every 10 batches, print log; use args.num_batches_log
            if batch_idx % 10 == 0:
                # global step index; scales epochs by args.num_batches
                # because the inner loop breaks after num_batches batches
                steps.append(args.num_batches * epoch_idx + batch_idx + previous_step)
                # average training loss over the last <=10 batches
                avg_loss_batch_train = np.array(loss_list).mean()
                loss_list = []
                losses_train.append(avg_loss_batch_train)
                # rnn takes (batch, step_size, feature_size), cnn takes
                # (batch, 1, img_width, img_height); squeeze drops the
                # singleton channel dim (idempotent, so re-running it each
                # log step is harmless). test_images is already a Variable.
                test_images = torch.squeeze(test_images) # already variable
                _, _, _, test_output = rnn(test_images)
                # validation loss
                loss_val = loss_func(test_output, test_labels)
                losses_val.append(loss_val.data.numpy()[0])
                # use test|validation set, get accuracy
                pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
                accuracy = sum(pred_y == test_labels.data.numpy()) / test_labels.data.numpy().shape[0]
                accuracies.append(accuracy)
                # print log
                print('Epoch: '+ str(epoch_idx+1) + ' Batches: %04d' % (batch_idx+1) + ' | avg_train loss: %.4f' % avg_loss_batch_train, 'loss_val: %.4f'%loss_val.data.numpy()[0], ' | test accuracy: %.2f' % accuracy)
            # stop the epoch early after args.num_batches batches instead
            # of consuming the whole DataLoader
            if args.num_batches-1 == batch_idx:
                break
    # save net and log (whole-model pickle + tuple of log lists)
    torch.save(rnn, args.net_path)
    torch.save((steps, losses_train, losses_val, accuracies), args.log_path)
    if plot_loss:
        plt.plot(steps, losses_train, c='blue', label='train')
        plt.plot(steps, losses_val, c='red', label='val')
        plt.legend(loc='best')
        plt.show()
#########################################################
def build_parser():
    """Construct the command-line parser for this script.

    Registers a single ``train`` sub-command whose arguments configure the
    dataset location, batch/test sizes, model and log paths, and the number
    of epochs/batches to run; the sub-command dispatches to :func:`train`
    via ``set_defaults(func=train)``.

    Returns:
        tuple: ``(parser, subparsers)`` — the configured
        ``argparse.ArgumentParser`` and its subparsers action, so callers
        may register further sub-commands.
    """
    parser = argparse.ArgumentParser(description='my argparse tool')
    # one sub-command per top-level action; the chosen name lands in args.cmd
    subparsers = parser.add_subparsers(dest='cmd', help='Sub-command help.')
    #########################################################
    train_parser = subparsers.add_parser('train', help='Trains a model for the first time.')
    # (flag spec, keyword options) pairs for the `train` sub-command,
    # registered in order so --help output matches the original layout
    train_arg_table = (
        (('-mnist_dir',), dict(required=True, help="Path where mnist stored")),
        (('-batch_size',), dict(type=int, default=32, help="Number of samples in each batch")),
        (('-num_batches',), dict(type=int, default=100, help="Number of batches to train in each epoch")),
        (('-test_size',), dict(type=int, default=1000, help="Number of samples to test during testing")),
        (('-net', '--net_path'), dict(required=True, help="Path to save neuralnet model")),
        (('-log', '--log_path'), dict(required=True, help="Path to save log information: losses, steps")),
        (('-num_epochs',), dict(type=int, default=1, help="Number of epochs to train this time")),
    )
    for flags, options in train_arg_table:
        train_parser.add_argument(*flags, **options)
    # dispatch target: main() calls args.func(args)
    train_parser.set_defaults(func=train)
    return parser, subparsers
def parse_args(parser):
    """Parse the process command line with the given parser.

    Args:
        parser: a configured ``argparse.ArgumentParser``.

    Returns:
        argparse.Namespace: the arguments parsed from ``sys.argv``.
    """
    parsed = parser.parse_args()
    return parsed
def main():
    """Script entry point: build the CLI, parse it, run the sub-command.

    Exits the process with the sub-command's return value (or 0 when the
    handler returns a falsy value such as None).
    """
    arg_parser, _unused_subparsers = build_parser()
    cli_args = parse_args(arg_parser)
    # args.func is set by set_defaults() on the chosen sub-command
    exit_status = cli_args.func(cli_args) or 0
    sys.exit(exit_status)
# Guard the entry point so importing this module has no side effects.
if __name__ == '__main__':
    main()
# run the line below in terminal (debugs the companion model-builder script)
# python -m pdb 401_build_rnn.py build_net