# Train

In [1]:
import os

import cv2
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CTCLoss

from dataset import Synth90kDataset, synth90k_collate_fn
from model import CRNN
from evaluate import evaluate
from config import train_config as config

## Train Functions

### Comments

Shape

- images: (N, C, H, W) -> (32, 1, 32, 100)
- targets: (X,), one-dimensional; holds the labels of the 32 samples in the batch
- logits: (T, N, n_class) -> (24, 32, 37)
- input_lengths: (N,) -> (32,), one-dimensional; the input sequence length of each of the 32 samples in the batch
- target_lengths: (N,) -> (32,), one-dimensional; the target (label) length of each of the 32 samples in the batch

Init signature: CTCLoss(blank:int=0, reduction:str='mean', zero_infinity:bool=False)
Docstring:     
The Connectionist Temporal Classification loss.

Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
probability of possible alignments of input to target, producing a loss value which is differentiable
with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
limits the length of the target sequence such that it must be :math:`\leq` the input length.

- log_probs：shape为(T, N, C)的模型输出张量，其中，T表示CTCLoss的输入长度也即输出序列长度，N表示训练的batch size长度，C则表示包含有空白标签的所有要预测的字符集总长度，log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中；

- targets：shape为(N, S) 或(sum(target_lengths))的张量，其中第一种类型，N表示训练的batch size长度，S则为标签长度，第二种类型，则为所有标签长度之和，但是需要注意的是targets不能包含有空白标签；

- input_lengths：shape为(N)的张量或元组，其每一个元素的值必须小于等于T（即输出序列长度）；一般来说模型输出序列长度固定时，该张量或元组的元素值均相同，都等于T；

- target_lengths：shape为(N)的张量或元组，其每一个元素指示每个训练输入序列的标签长度，但标签长度是可以变化的；

Args:
    blank (int, optional): blank label. Default :math:`0`.
    reduction (string, optional): Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        ``'mean'``: the output losses will be divided by the target lengths and
        then the mean over the batch is taken. Default: ``'mean'``
    zero_infinity (bool, optional):
        Whether to zero infinite losses and the associated gradients.
        Default: ``False``
        Infinite losses mainly occur when the inputs are too short
        to be aligned to the targets.

Shape:
    - Log_probs: Tensor of size :math:`(T, N, C)`,
      where :math:`T = \text{input length}`,
      :math:`N = \text{batch size}`, and
      :math:`C = \text{number of classes (including blank)}`.
      The logarithmized probabilities of the outputs (e.g. obtained with
      :func:`torch.nn.functional.log_softmax`).
    - Targets: Tensor of size :math:`(N, S)` or
      :math:`(\operatorname{sum}(\text{target\_lengths}))`,
      where :math:`N = \text{batch size}` and
      :math:`S = \text{max target length, if shape is } (N, S)`.
      It represent the target sequences. Each element in the target
      sequence is a class index. And the target index cannot be blank (default=0).
      In the :math:`(N, S)` form, targets are padded to the
      length of the longest sequence, and stacked.
      In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form,
      the targets are assumed to be un-padded and
      concatenated within 1 dimension.
    - Input_lengths: Tuple or tensor of size :math:`(N)`,
      where :math:`N = \text{batch size}`. It represent the lengths of the
      inputs (must each be :math:`\leq T`). And the lengths are specified
      for each sequence to achieve masking under the assumption that sequences
      are padded to equal lengths.
    - Target_lengths: Tuple or tensor of size :math:`(N)`,
      where :math:`N = \text{batch size}`. It represent lengths of the targets.
      Lengths are specified for each sequence to achieve masking under the
      assumption that sequences are padded to equal lengths. If target shape is
      :math:`(N,S)`, target_lengths are effectively the stop index
      :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for
      each target in a batch. Lengths must each be :math:`\leq S`
      If the targets are given as a 1d tensor that is the concatenation of individual
      targets, the target_lengths must add up to the total length of the tensor.
    - Output: scalar. If :attr:`reduction` is ``'none'``, then
      :math:`(N)`, where :math:`N = \text{batch size}`.

Examples::

    >>> # Target are to be padded
    >>> T = 50      # Input sequence length
    >>> C = 20      # Number of classes (including blank)
    >>> N = 16      # Batch size
    >>> S = 30      # Target sequence length of longest target in batch (padding length)
    >>> S_min = 10  # Minimum target length, for demonstration purposes
    >>>
    >>> # Initialize random batch of input vectors, for *size = (T,N,C)
    >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
    >>>
    >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
    >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
    >>>
    >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
    >>> ctc_loss = nn.CTCLoss()
    >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
    >>> loss.backward()
    >>>
    >>>
    >>> # Target are to be un-padded
    >>> T = 50      # Input sequence length
    >>> C = 20      # Number of classes (including blank)
    >>> N = 16      # Batch size
    >>>
    >>> # Initialize random batch of input vectors, for *size = (T,N,C)
    >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
    >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    >>>
    >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
    >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
    >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
    >>> ctc_loss = nn.CTCLoss()
    >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
    >>> loss.backward()

Reference:
    A. Graves et al.: Connectionist Temporal Classification:
    Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
    https://www.cs.toronto.edu/~graves/icml_2006.pdf

Note:
    In order to use CuDNN, the following must be satisfied: :attr:`targets` must be
    in concatenated format, all :attr:`input_lengths` must be `T`.  :math:`blank=0`,
    :attr:`target_lengths` :math:`\leq 256`, the integer arguments must be of
    dtype :attr:`torch.int32`.

    The regular implementation uses the (more common in PyTorch) `torch.long` dtype.


Note:
    In some circumstances when using the CUDA backend with CuDNN, this operator
    may select a nondeterministic algorithm to increase performance. If this is
    undesirable, you can try to make the operation deterministic (potentially at
    a performance cost) by setting ``torch.backends.cudnn.deterministic =
    True``.
    Please see the notes on :doc:`/notes/randomness` for background.
Init docstring: Initializes internal Module state, shared by both nn.Module and ScriptModule.
File:           d:\anaconda3\envs\env_gpu\lib\site-packages\torch\nn\modules\loss.py
Type:           type
Subclasses:   

In [3]:
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes

In [4]:
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input.shape

torch.Size([50, 16, 20])

In [5]:
input[:,0,:][0] ## the first sample and the first time sequence

tensor([-3.7639, -2.6277, -2.6374, -4.4793, -2.2352, -4.4302, -4.1634, -4.2208,
        -3.6696, -1.5673, -5.3828, -3.0010, -3.2726, -2.5471, -3.9482, -3.0687,
        -1.9266, -4.3272, -3.9098, -3.7618], grad_fn=<SelectBackward>)

In [6]:
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
target.shape 

torch.Size([16, 30])

In [7]:
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)

In [8]:
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

In [9]:
target_lengths

tensor([28, 17, 29, 27, 28, 21, 20, 17, 15, 16, 19, 13, 27, 14, 13, 18])

In [10]:
ctc_loss = CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)

In [11]:
loss

tensor(6.1596, grad_fn=<MeanBackward0>)

In [12]:
loss.backward()

## Train

#### Parameters

In [13]:
# Read the run settings from train_config in one pass. The names on the left
# and the keys on the right are listed in the same order.
(epochs, train_batch_size, eval_batch_size, lr,
 show_interval, valid_interval, save_interval,
 cpu_workers, reload_checkpoint, valid_max_iter,
 img_width, img_height, data_dir) = [
    config[key] for key in (
        'epochs', 'train_batch_size', 'eval_batch_size', 'lr',
        'show_interval', 'valid_interval', 'save_interval',
        'cpu_workers', 'reload_checkpoint', 'valid_max_iter',
        'img_width', 'img_height', 'data_dir',
    )
]

# One class per LABEL2CHAR entry plus one extra slot
# (presumably the CTC blank label -- confirm against the dataset module).
num_class = 1 + len(Synth90kDataset.LABEL2CHAR)

# Deliberately override the configured worker count for this notebook session;
# the value is passed to both DataLoaders as `num_workers`.
cpu_workers = 0

In [14]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'device: {device}')

device: cuda


### Load Data

In [15]:
# Both splits read from the same root directory with the same image geometry;
# only the `mode` differs.
shared_dataset_args = dict(root_dir=data_dir,
                           img_height=img_height,
                           img_width=img_width)
train_dataset = Synth90kDataset(mode='train', **shared_dataset_args)
valid_dataset = Synth90kDataset(mode='dev', **shared_dataset_args)

# Loader options common to training and validation; only the dataset and the
# batch size are split-specific.
shared_loader_args = dict(shuffle=True,
                          num_workers=cpu_workers,
                          collate_fn=synth90k_collate_fn)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=train_batch_size,
                          **shared_loader_args)
valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=eval_batch_size,
                          **shared_loader_args)

In [16]:
# Architecture options taken straight from the training config.
crnn_options = {
    'map_to_seq_hidden': config['map_to_seq_hidden'],
    'rnn_hidden': config['rnn_hidden'],
    'leaky_relu': config['leaky_relu'],
}
# The leading 1 is the first positional argument of CRNN
# (presumably the number of image channels -- confirm in model.py).
crnn = CRNN(1, img_height, img_width, num_class, **crnn_options)

# Resume from a saved state dict when a checkpoint path is configured;
# map_location remaps the stored tensors onto the selected device.
if reload_checkpoint:
    saved_state = torch.load(reload_checkpoint, map_location=device)
    crnn.load_state_dict(saved_state)
crnn.to(device)

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU(inplace=True)
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu4): ReLU(inplace=True)
    (conv5): Conv2d(512, 512, 

In [17]:
# CTC loss summed over the batch (reduction='sum'); the training loop divides
# by the batch size itself when it reports per-sample values.
criterion = CTCLoss(reduction='sum')
# RMSprop over all model parameters with the configured learning rate.
optimizer = optim.RMSprop(params=crnn.parameters(), lr=lr)
criterion.to(device)

CTCLoss()

In [18]:
# Every checkpoint step must also be a validation step: the save branch of the
# training loop reads `evaluation`, which is only assigned in the validation
# branch.
assert save_interval % valid_interval == 0

#### Start Training

In [19]:
## testing
i = 0
for train_data in train_loader:
    sample_train_data = train_data
    i+=1
    if i==1:
        break

In [20]:
_ = crnn.train()

In [21]:
images, targets, target_lengths = [d.to(device) for d in sample_train_data]

In [22]:
logits = crnn(images)

In [23]:
logits.shape # (T, N, C)

torch.Size([24, 32, 37])

In [24]:
log_probs = torch.nn.functional.log_softmax(logits, dim=2) # softmax

In [25]:
log_probs[:,0,:][0]

tensor([-3.5903, -3.6474, -3.6094, -3.5912, -3.5892, -3.6407, -3.6404, -3.6426,
        -3.5924, -3.5884, -3.5724, -3.6011, -3.6278, -3.6061, -3.5878, -3.5853,
        -3.5867, -3.6093, -3.5839, -3.6416, -3.6299, -3.5777, -3.6182, -3.5736,
        -3.6282, -3.6013, -3.6061, -3.6388, -3.6249, -3.5818, -3.6279, -3.6625,
        -3.6213, -3.6277, -3.6130, -3.6108, -3.6363], device='cuda:0',
       grad_fn=<SelectBackward>)

In [26]:
batch_size = images.size(0)

In [27]:
input_lengths = torch.LongTensor([logits.size(0)] * batch_size)

In [28]:
input_lengths

tensor([24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
        24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24])

In [29]:
target_lengths = torch.flatten(target_lengths)

In [30]:
target_lengths 

tensor([ 6,  5, 10,  5, 14,  7,  8, 12, 11,  7, 11,  9,  5, 11,  7,  7, 13,  7,
         7,  5,  9,  6, 10,  7,  7, 11,  9,  8, 10,  7,  9,  8],
       device='cuda:0')

In [31]:
loss = criterion(log_probs, targets, input_lengths, target_lengths)

In [32]:
loss

tensor(2139.9348, device='cuda:0', grad_fn=<SumBackward0>)

In [34]:
_ = crnn.eval()

In [35]:
def train_batch(crnn, data, optimizer, criterion, device):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        crnn: Model called on the image tensor. Its output is log-softmaxed
            over dim 2 and its first dimension is used as the sequence length.
        data: Iterable of exactly three tensors
            ``(images, targets, target_lengths)``; each is moved to ``device``.
        optimizer: Optimizer whose gradients are reset, then stepped.
        criterion: Loss callable invoked as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``.
        device: Device the batch tensors are moved to.

    Returns:
        The batch loss as a Python float (``loss.item()``).
    """
    crnn.train()
    images, targets, target_lengths = (tensor.to(device) for tensor in data)

    log_probs = torch.nn.functional.log_softmax(crnn(images), dim=2)

    # Every sample is assigned the full output sequence length.
    seq_len = log_probs.size(0)
    n_samples = images.size(0)
    input_lengths = torch.full((n_samples,), seq_len, dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths,
                     target_lengths.flatten())

    # Clear stale gradients only after the forward pass, then update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [36]:
epochs

10000

In [37]:
epochs = 10

In [38]:
# Running batch index across all epochs. Initialised here so the cell does not
# depend on a value left behind by an earlier scratch cell and restarts
# cleanly when re-run.
i = 1
for epoch in range(1, epochs + 1):
    print(f'epoch: {epoch}')
    tot_train_loss = 0.
    tot_train_count = 0
    for train_data in train_loader:
        loss = train_batch(crnn, train_data, optimizer, criterion, device)
        train_size = train_data[0].size(0)

        tot_train_loss += loss
        tot_train_count += train_size
        if i % show_interval == 0:
            # The criterion sums over the batch, so divide to report a
            # per-sample value.
            print('train_batch_loss[', i, ']: ', loss / train_size)

        if i % valid_interval == 0:
            evaluation = evaluate(crnn, valid_loader, criterion,
                                  decode_method=config['decode_method'],
                                  beam_size=config['beam_size'])
            # A sample counts as correct only when every element of its
            # predicted sequence is correct.
            print('valid_evaluation: loss={loss}, acc={acc}'.format(**evaluation))

        if i % save_interval == 0:
            # `evaluation` is available here because save_interval is a
            # multiple of valid_interval (asserted before training starts).
            prefix = 'crnn'
            # Separate name so the per-batch training `loss` is not clobbered.
            valid_loss = evaluation['loss']
            save_model_path = os.path.join(config['checkpoints_dir'],
                                           f'{prefix}_{i:06}_loss{valid_loss}.pt')
            torch.save(crnn.state_dict(), save_model_path)
            print('save model at ', save_model_path)

        i += 1

    print('train_loss: ', tot_train_loss / tot_train_count)

epoch: 1
train_batch_loss[ 10 ]:  27.803953170776367
train_batch_loss[ 20 ]:  24.912765502929688
train_batch_loss[ 30 ]:  23.342370986938477
train_batch_loss[ 40 ]:  27.656431198120117
train_batch_loss[ 50 ]:  24.8463134765625
train_loss:  27.45540482855296
epoch: 2
train_batch_loss[ 60 ]:  24.50200080871582
train_batch_loss[ 70 ]:  26.64198875427246
train_batch_loss[ 80 ]:  24.154861450195312
train_batch_loss[ 90 ]:  27.972400665283203
train_batch_loss[ 100 ]:  24.27780532836914
train_batch_loss[ 110 ]:  26.269683837890625
train_loss:  25.37368762473143
epoch: 3
train_batch_loss[ 120 ]:  24.99655532836914
train_batch_loss[ 130 ]:  25.585840225219727
train_batch_loss[ 140 ]:  26.731746673583984
train_batch_loss[ 150 ]:  26.63530921936035
train_batch_loss[ 160 ]:  25.008487701416016
train_loss:  25.16807632062474
epoch: 4
train_batch_loss[ 170 ]:  25.0599422454834
train_batch_loss[ 180 ]:  24.829959869384766
train_batch_loss[ 190 ]:  24.508867263793945
train_batch_loss[ 200 ]:  23.78364

Evaluate:   0%|                                                                                  | 0/1 [00:00<?, ?it/s]

train_batch_loss[ 500 ]:  25.051483154296875


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.25s/it]


valid_evaluation: loss=25.001851821836098, acc=0.0
train_loss:  24.688604603272168
epoch: 10
train_batch_loss[ 510 ]:  29.2703857421875
train_batch_loss[ 520 ]:  25.738025665283203
train_batch_loss[ 530 ]:  23.810016632080078
train_batch_loss[ 540 ]:  23.943490982055664
train_batch_loss[ 550 ]:  25.516605377197266
train_batch_loss[ 560 ]:  27.051362136314655
train_loss:  24.622436114034553
