-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
49 lines (37 loc) · 1.35 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import chainer
from chainer import Variable
import chainer.functions as F
import utils
import dataset
from PIL import Image
import models.crnn as crnn
import numpy as np
import skimage.io as skio
def main(img_path='./data/demo.png',
         alphabet='0123456789abcdefghijklmnopqrstuvwxyz',
         gpu=-1,
         model_path=None):
    """Run the CRNN model on one image and print the decoded text.

    Args:
        img_path: Path of the image to recognise; it is read as grayscale.
        alphabet: Characters the model can emit. The network is built with
            ``len(alphabet) + 1`` output classes (the extra class is
            presumably the CTC blank expected by ``utils.strLabelConverter``
            -- confirm against utils).
        gpu: CUDA device id; a negative value runs everything on the CPU.
        model_path: Optional path to a Chainer ``.npz`` snapshot loaded into
            the model. When ``None`` the model keeps its freshly initialised
            weights, so the printed text is not meaningful.

    Returns:
        Tuple ``(raw_pred, sim_pred)`` exactly as returned by
        ``converter.decode`` with ``raw=True`` and ``raw=False``.
    """
    # Local import: the original referenced ``cuda`` without importing it,
    # so every GPU run failed with NameError.
    from chainer import cuda

    # For the default 36-character alphabet this yields the 37 classes the
    # demo originally hard-coded.
    nclass = len(alphabet) + 1
    model = crnn.CRNN(32, 1, nclass, 256)
    if model_path is not None:
        print('loading pretrained model from %s' % model_path)
        chainer.serializers.load_npz(model_path, model)
    if gpu >= 0:
        # Chainer links move to the GPU via ``to_gpu``; ``.cuda()`` is the
        # PyTorch API and does not exist on a Chainer Link.
        model.to_gpu(gpu)

    converter = utils.strLabelConverter(alphabet)
    transformer = dataset.resizeNormalize((32, 100))

    # ``as_grey`` was deprecated in scikit-image 0.14 and removed in 0.16;
    # ``as_gray`` is the supported spelling.
    image = np.expand_dims(skio.imread(img_path, as_gray=True), axis=0)
    image = transformer(image)
    # Add a leading batch axis and match the float32 dtype Chainer expects.
    image = np.expand_dims(image, axis=0).astype(np.float32)
    if gpu >= 0:
        image = cuda.to_gpu(image, device=gpu)
    image = Variable(image)

    # Inference only: disable train-mode behaviour and skip building the
    # backprop graph (saves memory and time).
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        # Author's shape notes for the default setup: (26, 1, 37) logits,
        # then (26, 1) best-class indices, then a flat (26,) sequence.
        preds = model(image)
        preds = F.argmax(preds, axis=2)
        preds = F.transpose(preds, axes=(1, 0)).reshape(-1)

    # The decoder works on host arrays; to_cpu is a no-op for NumPy input.
    pred_labels = cuda.to_cpu(preds.data)
    preds_size = np.array([pred_labels.shape[0]])
    raw_pred = converter.decode(pred_labels, preds_size, raw=True)
    sim_pred = converter.decode(pred_labels, preds_size, raw=False)
    print('%-20s => %-20s' % (raw_pred, sim_pred))
    return raw_pred, sim_pred
if __name__ == '__main__':
    # Script entry point: run the demo only when executed directly, never on
    # import.
    main()