本笔记本用于搭建 CRNN 网络，完成文本识别（recognition）任务。

In [1]:
import torch
from torch import nn
from torchvision import models
from utils import *

# Notebook-level training data; both names are reused by later cells.
# RecDataset and DataLoader are not imported explicitly above — presumably
# they are re-exported by `from utils import *` (TODO confirm).
dataset = RecDataset("IAM", "train")
# Small shuffled batches of 10, loaded by 4 worker processes, used for the shape checks below.
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4)

In [2]:
# Sanity check: print every top-level child module of a pretrained ResNet-18,
# each followed by a separator line, to see which layers a slice of
# `children()` would keep when the network is reused as a feature extractor.
model = models.resnet18(pretrained=True)
for child in model.children():
    print(child, '-------------------', sep='\n')



Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
-------------------
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
-------------------
ReLU(inplace=True)
-------------------
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
-------------------
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (co

这里用去掉最后一层 fc 的 resnet18 作为 CNN 部分，LSTM 作为 RNN 部分。<br>
输入为 1x128x128 的单通道图片，先用 1x1 卷积扩展为 3 通道，以匹配 resnet18 的输入。<br>
经过 CNN 部分，先卷积到 512x4x4，再经过平均池化层变成 512x1x1。<br>
然后展平为 512 维向量，分别经过两个线性变换，得到 LSTM 的初始 hidden state 和 cell state。

In [3]:
class CRNN(nn.Module):
    """Image-to-text recognizer: ResNet-18 encoder followed by an LSTM decoder.

    A single-channel image is lifted to 3 channels by a 1x1 convolution and
    encoded by ResNet-18 without its final ``fc`` layer. The pooled feature
    vector is projected by two linear layers into the initial hidden and cell
    state of a single-layer LSTM, which then emits one character distribution
    per step, feeding the embedding of its own greedy prediction back in.

    Args:
        num_classes: size of the character vocabulary (logit width).
        hidden_dim: LSTM hidden/cell state size.
        io_dim: character-embedding size, i.e. the LSTM input size.
        device: device the module is moved to at construction time; also
            kept as ``self.device`` for callers that place their data with it.
        max_len: number of decoding steps, i.e. the maximum number of
            characters in the generated text.
        start_token: vocabulary index fed to the LSTM at the first step.
    """

    def __init__(self, num_classes=128, hidden_dim=256, io_dim=512, device='cuda:0',
                 max_len=64, start_token=2):
        super().__init__()
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.io_dim = io_dim
        self.device = device
        self.max_len = max_len
        self.start_token = start_token
        # 1x1 convolution: lift the single input channel to the 3 channels ResNet expects.
        self.conv1 = nn.Conv2d(1, 3, 1)
        # NOTE(review): recent torchvision releases prefer `weights=` over
        # `pretrained=`; kept unchanged here — confirm against the installed version.
        backbone = models.resnet18(pretrained=True)
        # Width of the pooled feature vector (512 for ResNet-18), read from the
        # fc layer that is dropped on the next line.
        feature_dim = backbone.fc.in_features
        self.cnn = nn.Sequential(*list(backbone.children())[:-1])
        self.rnn = nn.LSTM(io_dim, hidden_dim, 1, batch_first=True)
        self.h0_fc = nn.Linear(feature_dim, hidden_dim)
        self.c0_fc = nn.Linear(feature_dim, hidden_dim)
        self.out_fc = nn.Linear(hidden_dim, num_classes)
        self.embedding = nn.Embedding(num_classes, io_dim)
        self.to(device)

    def init_state(self, img):
        """Encode ``img`` with the CNN into the LSTM's initial state.

        Args:
            img: image batch of shape ``(batch_size, 1, H, W)``.

        Returns:
            ``(h0, c0)``, each of shape ``(1, batch_size, hidden_dim)``.
        """
        x = self.conv1(img)         # batch_size, 3, H, W
        x = self.cnn(x)             # batch_size, feature_dim, 1, 1
        x = x.view(x.size(0), -1)   # batch_size, feature_dim
        x = x.unsqueeze(0)          # 1, batch_size, feature_dim (leading LSTM layer axis)
        h0 = self.h0_fc(x)          # 1, batch_size, hidden_dim
        c0 = self.c0_fc(x)          # 1, batch_size, hidden_dim
        return h0, c0

    def next_char(self, x, h_c_n):
        """Run the LSTM on ``x`` and map its outputs to character logits.

        Args:
            x: embedded input characters, ``(batch_size, steps, io_dim)``.
            h_c_n: ``(h_n, c_n)`` LSTM state carried over from the previous call.

        Returns:
            ``(logits, (h_n, c_n))`` where ``logits`` has shape
            ``(batch_size, steps, num_classes)`` and the state is updated.
        """
        h_n, c_n = h_c_n
        x, (h_n, c_n) = self.rnn(x, (h_n, c_n))
        x = self.out_fc(x)
        return x, (h_n, c_n)

    def forward(self, img):
        """Greedily decode ``max_len`` character distributions for each image.

        Args:
            img: image batch of shape ``(batch_size, 1, H, W)``.

        Returns:
            Logits of shape ``(batch_size, max_len, num_classes)``.
        """
        batch_size = img.size(0)
        state = self.init_state(img)
        # Create the start tokens on the input's device rather than on
        # `self.device`, which is stale once the module has been moved with `.to(...)`.
        tokens = torch.full((batch_size, 1), self.start_token,
                            dtype=torch.long, device=img.device)
        step_input = self.embedding(tokens)
        step_logits = []
        for _ in range(self.max_len):
            logits, state = self.next_char(step_input, state)
            step_logits.append(logits)
            # Feed the most likely character back in as the next input.
            step_input = self.embedding(logits.argmax(dim=-1))
        return torch.cat(step_logits, dim=1)

In [4]:
# Check that the CRNN output has the expected shape at prediction time.

crnn = CRNN()
for img, label in dataloader:
    img = img.to(crnn.device)
    label = label.to(crnn.device)
    print(img.shape)
    output = crnn(img)
    print(output.shape)
    break  # a single batch is enough for a shape check

torch.Size([10, 1, 128, 128])
torch.Size([10, 64, 128])


In [5]:
# Check that the initial LSTM state used during training has the expected shape.

crnn = CRNN()
for img, label in dataloader:
    img = img.to(crnn.device)
    label = label.to(crnn.device)
    print(img.shape)
    h0, c0 = crnn.init_state(img)
    print(h0.shape, c0.shape)
    break  # a single batch is enough for a shape check

torch.Size([10, 1, 128, 128])
torch.Size([1, 10, 256]) torch.Size([1, 10, 256])


In [6]:
# Hyperparameters that may need tuning: hidden_dim, io_dim, lr, batch_size, num_epochs, dataset_name
def get_model_name(hidden_dim, io_dim, lr, batch_size, num_epochs, dataset_name):
    """Return the checkpoint file name that encodes the given hyperparameters.

    The fields appear in argument order, joined by underscores, between the
    ``crnn_`` prefix and the ``.pth`` extension.
    """
    fields = (hidden_dim, io_dim, lr, batch_size, num_epochs, dataset_name)
    return "crnn_{}_{}_{}_{}_{}_{}.pth".format(*fields)


# Training code. NOTE: teacher forcing is NOT applied here — the loop calls
# `model(img)`, so the decoder is fed its own greedy predictions.
def train_crnn(model, dataloader, learning_rate, epochs, device):
    """Train ``model`` with SGD and cross-entropy, saving a checkpoint per epoch.

    Args:
        model: network whose call on an image batch returns per-step logits
            and which exposes ``num_classes``, ``hidden_dim`` and ``io_dim``.
        dataloader: yields ``(img, target)`` batches. ``target.argmax(-1)`` is
            used as the class index, so targets are expected to carry one
            score per class along the last axis. ``dataloader.batch_size`` and
            ``dataloader.dataset.name`` are used to name the checkpoints.
        learning_rate: SGD learning rate (momentum is fixed at 0.9).
        epochs: number of passes over ``dataloader``.
        device: device the model and every batch are moved to.

    Side effects:
        Prints the loss every 10 iterations and, after each epoch, writes the
        model's ``state_dict`` to the file named by ``get_model_name`` (the
        zero-based epoch index fills its ``num_epochs`` slot).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    model.train()
    model.to(device)
    for epoch in range(epochs):
        for i, (img, target) in enumerate(dataloader):
            img, target = img.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(img)
            # Flatten batch and time so every decoding step is one classification
            # sample. Use the model being trained, not a notebook-level global,
            # to obtain the number of classes.
            logits = output.view(-1, model.num_classes)
            labels = target.argmax(-1).view(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(f"Epoch {epoch}, Iter {i}, Loss {loss.item()}")
        # Take the dataset name from the loader that was actually passed in
        # instead of relying on a global `dataset` variable.
        model_name = get_model_name(model.hidden_dim, model.io_dim, learning_rate,
                                    dataloader.batch_size, epoch, dataloader.dataset.name)
        torch.save(model.state_dict(), model_name)
        print(f"Model saved as {model_name}")


# Build a fresh model, rebuild the loader with a larger batch size and more
# workers, then train for 10 epochs on the first GPU.
model = CRNN()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
train_crnn(model, dataloader, learning_rate=0.001, epochs=10, device="cuda:0")

Epoch 0, Iter 0, Loss 4.752656936645508
Epoch 0, Iter 10, Loss 4.150956630706787
Epoch 0, Iter 20, Loss 2.4928619861602783
Epoch 0, Iter 30, Loss 0.9234460592269897
Epoch 0, Iter 40, Loss 0.6512908339500427
Epoch 0, Iter 50, Loss 0.6668704152107239
Epoch 0, Iter 60, Loss 0.6871585845947266
Epoch 0, Iter 70, Loss 0.6272826790809631
Epoch 0, Iter 80, Loss 0.772942066192627
Epoch 0, Iter 90, Loss 0.7369575500488281
Epoch 0, Iter 100, Loss 0.6620486974716187
Epoch 0, Iter 110, Loss 0.6445013880729675
Epoch 0, Iter 120, Loss 0.5646364688873291
Epoch 0, Iter 130, Loss 0.5715126395225525
Epoch 0, Iter 140, Loss 0.5998734831809998
Epoch 0, Iter 150, Loss 0.6283571124076843
Epoch 0, Iter 160, Loss 0.5739874243736267
Epoch 0, Iter 170, Loss 0.5709889531135559
Epoch 0, Iter 180, Loss 0.6036087274551392
Epoch 0, Iter 190, Loss 0.5977658033370972
Epoch 0, Iter 200, Loss 0.5659731030464172
Epoch 0, Iter 210, Loss 0.5328385233879089
Epoch 0, Iter 220, Loss 0.573197603225708
Epoch 0, Iter 230, Loss 0.