In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable

import model
import dataloader
import dictionary

In [2]:
# Recognition network and loss used by the train() and predict() cells of this notebook.
# NOTE(review): 37 and 128 are passed straight to model.CRNN; presumably the number of
# output classes and a hidden size -- confirm against model.py.
net = model.CRNN(37, 128)
# CTC loss constructed with PyTorch's default arguments.
ctc_loss = nn.CTCLoss()

In [3]:
def train(epochs, lr, device='cuda'):
    """Train the module-level ``net`` with the module-level ``ctc_loss`` using Adam.

    Every epoch pulls batches from a fresh ``dataloader.dataloader()`` until it is
    exhausted. The sum of the batch losses is printed every 10th epoch.

    Args:
        epochs: Number of passes over the data loader.
        lr: Learning rate for the Adam optimizer. A new optimizer is created on
            each call, so optimizer state is not carried over between calls.
        device: Torch device that the model, the loss and the batches are moved to.

    Returns:
        None. ``net`` is updated in place.
    """
    import warnings

    # Model, loss and inputs are all cast to float64.
    net.train().double().to(device)
    ctc_loss.double().to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        running_loss = 0.
        loader = dataloader.dataloader()
        while True:
            try:
                imgs, targets = next(loader)
            except StopIteration:
                # Normal end of the epoch.
                break
            except Exception as e:
                # Keep the original best-effort behaviour (end the epoch), but do
                # not hide a genuine loader failure behind a silent break.
                warnings.warn(
                    'data loader raised {!r}; ending epoch {} early'.format(e, epoch)
                )
                break
            # targets is 2-D; its second dimension is used as the target length
            # of every sample in the batch.
            _, l_targets = targets.shape
            imgs = torch.from_numpy(imgs).double().to(device)
            targets = torch.from_numpy(targets).int().to(device)
            optimizer.zero_grad()
            outputs = net(imgs)
            # outputs is unpacked as (sequence length, batch size, classes).
            # NOTE(review): nn.CTCLoss expects log-probabilities; this assumes the
            # model already applies log_softmax -- confirm in model.py.
            l_outputs, b, _ = outputs.size()
            loss = ctc_loss(
                outputs,
                targets,
                torch.IntTensor([l_outputs] * b),
                torch.IntTensor([l_targets] * b),
            )
            loss.backward()
            running_loss += loss.item()
            optimizer.step()
        if epoch % 10 == 0:
            print(epoch, running_loss)

In [4]:
# Two-stage schedule: 100 epochs at lr=0.01, then 100 more at lr=0.001.
train(epochs=100, lr=0.01)
train(epochs=100, lr=0.001)

0 19.56829888204109
10 28.51731262108885
20 28.432716117989404
30 28.40175810073319
40 28.363011510441847
50 28.326295183179013
60 28.29474666250559
70 28.268908712662338
80 28.248336588443184
90 28.23224202195791
0 27.184383712609318
10 27.151468814219278
20 27.092606689752827
30 27.202075628975436
40 27.440667623263423
50 27.871181466220904
60 25.514324665007123
70 24.264049073534164
80 23.135564549832814
90 21.428659161932302


In [5]:
# Continue training: 100 epochs at lr=0.001, then 100 more at lr=0.0001.
train(epochs=100, lr=0.001)
train(epochs=100, lr=0.0001)

0 20.88055322039014
10 18.53818881067844
20 17.566887394750236
30 18.05347306884932
40 16.878291167536553
50 16.186994505686293
60 15.353458867153849
70 15.386133299067906
80 14.837307205455396
90 12.79016388546691
0 11.819876263906117
10 11.423398110225866
20 11.29483609906391
30 11.17220044114774
40 11.028084812762751
50 10.97398414647012
60 10.87583885081818
70 10.805586277897483
80 10.672831263012274
90 10.60511974240843


In [23]:
# Continue training: 100 epochs at lr=0.001, then 100 more at lr=0.0001.
# NOTE(review): this cell's execution count (23) is far ahead of the previous
# cell's (5), and its logged loss starts well below where the previous cell
# ended, so the intermediate training runs are not recorded in this notebook.
train(epochs=100, lr=0.001)
train(epochs=100, lr=0.0001)

0 1.6964632715868389
10 0.827119262212258
20 0.4309869432104593
30 2.0067856181550634
40 3.0955418439081295
50 0.6380744556414274
60 0.342501797091384
70 0.39936343818976916
80 2.062026730768948
90 0.21245933028397204
0 0.23000829728267402
10 0.23113880804283657
20 0.12658621828663746
30 0.21646170207073398
40 0.18225153363915178
50 0.11774005785998043
60 0.04585517360569485
70 0.062342100223253746
80 0.10216076462027479
90 0.06831715109992811


In [24]:
# Move the model to the CPU first so the saved state dict holds CPU tensors.
torch.save(net.cpu().state_dict(), 'checkpoint.pth')

In [12]:
import cv2
import matplotlib.pyplot as plt

In [19]:
def predict(path, device="cuda"):
    """Greedy-decode the text in the image at ``path`` with the module-level ``net``.

    Args:
        path: Path of the image file; it is read with OpenCV flag 0 (grayscale).
        device: Torch device used for inference.

    Returns:
        The decoded string: the best class per output step is mapped through
        ``dictionary.encode``, runs of identical consecutive symbols are
        collapsed to one, and spaces are then removed.

    Raises:
        OSError: If OpenCV cannot read an image from ``path``.
    """
    img = cv2.imread(path, 0)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError('could not read image: {!r}'.format(path))

    # The input is cast to float64 below, so the model must be float64 as well;
    # do not rely on an earlier train() call having converted it.
    net.eval().double().to(device)

    img = dataloader.normal(img)
    img = img[None, None, :, :]  # add two leading axes of size 1
    img = torch.from_numpy(img).double().to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = net(img).squeeze(1)
    _, idx = output.topk(1)
    idx = idx.view(-1).cpu().numpy()

    words = list(dictionary.encode(idx))
    # Keep only the first symbol of every run of identical consecutive symbols.
    collapsed = [
        symbol
        for pos, symbol in enumerate(words)
        if pos == 0 or symbol != words[pos - 1]
    ]
    return ''.join(collapsed).replace(' ', '')

In [25]:
import os
# Run the recogniser on every file in the directory and print the prediction
# next to the expected text, taken from the file name up to the first '_'.
root_dir = 'test_img/2/'
# os.listdir returns entries in arbitrary order; sort for reproducible output.
paths = sorted(os.listdir(root_dir))
for name in paths:
    print('target', name.split('_')[0])
    words = predict(os.path.join(root_dir, name))
    print('predict', words)
    print('-----------')

target 06
predict 0m
-----------
target 0r
predict 0w
-----------
target 1q
predict 1d
-----------
target 2d
predict 2f
-----------
target 2p
predict 2w
-----------
target 3n
predict 3h
-----------
target 6x
predict 6x
-----------
target 9f
predict 9f
-----------
target a4
predict a4
-----------
target d5
predict d5
-----------
target eu
predict eu
-----------
target jh
predict jh
-----------
target k9
predict k9
-----------
target kp
predict kp
-----------
target m6
predict m6
-----------
target me
predict me
-----------
target mw
predict mw
-----------
target pf
predict pf
-----------
target qu
predict qu
-----------
target r5
predict 5
-----------
target rf
predict rf
-----------
target tr
predict trw
-----------
target wm
predict wm
-----------
target xp
predict xp
-----------


In [27]:
# Run the recogniser on every file in the directory and print the prediction
# next to the expected text, taken from the file name up to the first '_'.
root_dir = 'test_img/3/'
# os.listdir returns entries in arbitrary order; sort for reproducible output.
paths = sorted(os.listdir(root_dir))
for name in paths:
    print('target', name.split('_')[0])
    words = predict(os.path.join(root_dir, name))
    print('predict', words)
    print('-----------')

target 0i6
predict 0i6
-----------
target 0uk
predict 0uk
-----------
target 32c
predict 32c
-----------
target 72z
predict 72z
-----------
target 9u2
predict 9u2
-----------
target a1e
predict a1e
-----------
target c1m
predict c1m
-----------
target ckw
predict ckw
-----------
target dsf
predict c6
-----------
target ib4
predict pf
-----------
target mgs
predict my
-----------
target n8b
predict api
-----------
target njk
predict 0ukx
-----------
target rkc
predict tpf
-----------


In [28]:
# Run the recogniser on every file in the directory and print the prediction
# next to the expected text, taken from the file name up to the first '_'.
root_dir = 'test_img/4/'
# os.listdir returns entries in arbitrary order; sort for reproducible output.
paths = sorted(os.listdir(root_dir))
for name in paths:
    print('target', name.split('_')[0])
    words = predict(os.path.join(root_dir, name))
    print('predict', words)
    print('-----------')

target 1oev
predict 1oev
-----------
target 2kvf
predict 2kvf
-----------
target 4pfo
predict 4pfo
-----------
target 52ie
predict 52ie
-----------
target 5816
predict 5816
-----------
target 63lu
predict 63lu
-----------
target 6jz9
predict 6jz9
-----------
target 6zp7
predict 6zp7
-----------
target 9egm
predict 9em
-----------
target a2yv
predict a2y
-----------
target anth
predict cfh
-----------
target b8p2
predict b8pf
-----------
target bjka
predict bjkpa
-----------
target e5ia
predict e3m
-----------
target jqx2
predict jqxp2
-----------
target kjpi
predict kjpw
-----------
target kxpf
predict kxpf
-----------
target m3gs
predict ms
-----------
target p9oy
predict p9oy
-----------
target tl0a
predict t0a
-----------
target u3gy
predict u3g
-----------
target vcem
predict vizm
-----------
target vh94
predict vo4
-----------
target vhmd
predict vmd
-----------
target x1yf
predict xpaf
-----------
target yum5
predict a8xpev
-----------
