In [1]:
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch.nn import CTCLoss

import crnn.utils as utils
import crnn.dataset as dataset
import crnn.models.crnn as crnn

In [2]:
manualSeed = 5213

In [3]:
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

<torch._C.Generator at 0x109cc4590>

In [4]:
cudnn.benchmark = True

In [5]:
trainroot = '/Users/chienan/job/asr/competition/lmdb/train/'
valroot = '/Users/chienan/job/asr/competition/lmdb/test/'

In [6]:
train_dataset = dataset.lmdbDataset(root=trainroot)

In [7]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    sampler=None,
    collate_fn=dataset.alignCollate(H=50, W=200))

In [8]:
test_dataset = dataset.lmdbDataset(root=valroot)

In [9]:
alphabet = '23456789ABCDEFGHJKNPQRSTUVXYZ'
char_dict = {'2':0, '3':1, '4':2, '5':3, '6':4, '7':5, '8':6, '9':7, 
         'A':8, 'B':9, 'C':10, 'D':11, 'E':12, 'F':13, 'G':14, 'H':15, 'J':16, 'K':17, 'N':18, 'P':19, 'Q':20,
         'R':21, 'S':22, 'T':23, 'U':24, 'V':25, 'X':26, 'Y':27, 'Z':28}

In [10]:
nclass = len(alphabet) + 1
nc = 1
batchSize = 2

In [11]:
converter = utils.strLabelConverter(alphabet)
criterion = CTCLoss(blank=nclass-1, reduction='mean')

In [12]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [13]:
crnn = crnn.CRNN(nc=nc, nclass=nclass, rnn_node=64, n_rnn=2, leakyRelu=False)

In [14]:
crnn.apply(weights_init)

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 32, kernel_size=(3, 7), stride=(1, 1))
    (relu0): ReLU(inplace=True)
    (conv1): Conv2d(32, 32, kernel_size=(3, 7), stride=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(32, 32, kernel_size=(3, 7), stride=(1, 1))
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(32, 32, kernel_size=(3, 7), stride=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv2d(32, 32, kernel_size=(3, 7), stride=(1, 1))
    (relu4): ReLU(inplace=True)
    (conv5): Conv2d(32, 32, kernel_size=(3, 7), stride=(1, 1))
    (relu5): ReLU(inplace=True)
    (pooling3): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (rnn): Sequential(
    (0): BidirectionalLSTM(
      (rnn): LSTM(32, 64, bidirectional=True)
      (embedding): Linear(in_fea

In [15]:
image = torch.FloatTensor(batchSize, 1, 50, 200)
text = torch.IntTensor(batchSize * 5)
length = torch.IntTensor(batchSize)

In [16]:
image = Variable(image)
text = Variable(text)
length = Variable(length)

In [17]:
# loss averager
loss_avg = utils.averager()

In [18]:
optimizer = optim.Adam(crnn.parameters(), lr=0.01, betas=(0.5, 0.999))

In [19]:
def trainBatch(net, criterion, optimizer):
    data = train_iter.next()
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()
    return cost

In [20]:
train_iter = iter(train_loader)

In [21]:
crnn.training

True

In [None]:
for epoch in range(10):
    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):
        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        if i % 2 == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, 10, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()
        print(i)

1
[0/10][2/4651] Loss: -35.303837
2
3
[0/10][4/4651] Loss: 355.850464
4
5
[0/10][6/4651] Loss: -107.520813
6
7
[0/10][8/4651] Loss: 89.589409
8
9
[0/10][10/4651] Loss: 28.127769
10
11
[0/10][12/4651] Loss: -18.494911
12
13
[0/10][14/4651] Loss: -21.281744
14
15
[0/10][16/4651] Loss: -4.464666
16
17
[0/10][18/4651] Loss: -12.296000
18
19
[0/10][20/4651] Loss: 28.673447
20
21
[0/10][22/4651] Loss: 68.519493
22
23
[0/10][24/4651] Loss: 41.993011
24
25
[0/10][26/4651] Loss: 14.069054
26
27
[0/10][28/4651] Loss: -45.508568
28
29
[0/10][30/4651] Loss: 18.811386
30
31
[0/10][32/4651] Loss: 18.662685
32
33
[0/10][34/4651] Loss: 4.627020
34
35
[0/10][36/4651] Loss: 4.330544
36
37
[0/10][38/4651] Loss: -9.569729
38
39
[0/10][40/4651] Loss: -1.719812
40
41
[0/10][42/4651] Loss: -3.211915
42
43
[0/10][44/4651] Loss: -1.390604
44
45
[0/10][46/4651] Loss: -1.529884
46
47
[0/10][48/4651] Loss: 1.454750
48
49
[0/10][50/4651] Loss: -6.452713
50
51
[0/10][52/4651] Loss: -11.824207
52
53
[0/10][54/4651] 