这个记事本是用来搭建CRNN网络的，用来做recognition

In [None]:
import torch
from torch import nn
from torchvision import models
from utils import *

dataset = RecDataset("IAM", "train")
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4)

In [None]:
# test
model = models.resnet18(pretrained=True)
# model_cnn = nn.Sequential(*list(model.children())[:-1])
for chi in model.children():
    print(chi)
    print('-------------------')

这里用resnet18除去最后的fc作为cnn的部分，lstm作为rnn的部分。<br>
输入1x128x128图片<br>
经过cnn部分，先是卷到了512x4x4，然后经过平均池化层变成512x1x1<br>
然后展平，经过线性变换放入lstm的hidden和cell

In [41]:
class CRNN(nn.Module):
    def __init__(self, num_classes=128, hidden_dim=256, io_dim=512, device='cuda:0'):
        super(CRNN, self).__init__()
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.io_dim = io_dim
        self.device = device
        self.max_len = 64  # max num of characters of the generated text
        self.conv1 = nn.Conv2d(1, 3, 1)
        self.cnn = models.resnet18(pretrained=True)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-1])  # output dim is 512
        self.rnn = nn.LSTM(io_dim, hidden_dim, 1, batch_first=True)
        self.h0_fc = nn.Linear(512, hidden_dim)
        self.c0_fc = nn.Linear(512, hidden_dim)
        self.out_fc = nn.Linear(hidden_dim, num_classes)
        self.embedding = nn.Embedding(num_classes, io_dim)
        self.to(device)
    
    def init_state(self, img):
        # 通过CNN卷出 lstm 的 hidden state 和 cell state
        x = self.conv1(img)         # batch_size, 3, 64, 64
        x = self.cnn(x)           # batch_size, 512, 1, 1
        x = x.view(x.size(0), -1)   # batch_size, 512
        x = x.unsqueeze(0)          # 1, batch_size, 512
        h0 = self.h0_fc(x)          # 1, batch_size, hidden_dim
        c0 = self.c0_fc(x)          # 1, batch_size, hidden_dim
        return h0, c0
    
    def next_char(self, x, h_c_n):
        # print("next char x shape: ", x.shape)
        h_n, c_n = h_c_n
        # x: the embedding of the last character
        # h_n: the hidden state of the last character
        # c_n: the cell state of the last character
        x, (h_n, c_n) = self.rnn(x, (h_n, c_n))
        # print("next char rnn output x shape: ", x.shape)
        x = self.out_fc(x)
        # print("next char output x shape: ", x.shape)
        return x, (h_n, c_n)
    
    def forward(self, img):
        batch_size = img.size(0)
        h0, c0 = self.init_state(img)
        x = 2  # the index of the start token
        x = torch.tensor([x] * batch_size, dtype=torch.long).view(batch_size, 1).to(self.device)
        x = self.embedding(x)
        # print("after embedding x shape: ", x.shape)
        h_c_n = (h0, c0)
        output = []
        for i in range(self.max_len):
            x, h_c_n = self.next_char(x, h_c_n)
            output.append(x)
            x = x.argmax(dim=-1)
            x = self.embedding(x)
        output = torch.cat(output, dim=1)
        return output

In [44]:
# 测试CRNN进行预测时的形状是否符合要求

crnn = CRNN()
for i, (img, label) in enumerate(dataloader):
    img, target = img.to(crnn.device), target.to(crnn.device)
    print(img.shape)
    output = crnn(img)
    print(output.shape)
    break

torch.Size([10, 1, 128, 128])
torch.Size([10, 64, 128])


In [46]:
# 测试CRNN进行训练时的形状是否符合要求

crnn = CRNN()
for i, (img, label) in enumerate(dataloader):
    img, target = img.to(crnn.device), target.to(crnn.device)
    print(img.shape)
    h0, c0 = crnn.init_state(img)
    print(h0.shape, c0.shape)
    break



torch.Size([10, 1, 128, 128])
torch.Size([1, 10, 256]) torch.Size([1, 10, 256])


In [48]:
# 可能要调的超参数有：hidden_dim, io_dim, lr, batch_size, num_epochs, dataset_name
def get_model_name(hidden_dim, io_dim, lr, batch_size, num_epochs, dataset_name):
    return f"crnn_{hidden_dim}_{io_dim}_{lr}_{batch_size}_{num_epochs}_{dataset_name}.pth"


# 下面是训练的代码，使用教师强制训练
def train_crnn(model, dataloader, learning_rate, epochs, device):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    model.train()
    model.to(device)
    for epoch in range(epochs):
        for i, (img, target) in enumerate(dataloader):
            img, target = img.to(device), target.to(device)
            optimizer.zero_grad()
            h0, c0 = model.init_state(img)
            output = model(img)
            loss = criterion(output.view(-1, model.num_classes), target.view(-1))
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(f"Epoch {epoch}, Iter {i}, Loss {loss.item()}")
        model_name = get_model_name(model.hidden_dim, model.io_dim, learning_rate, dataloader.batch_size, epochs, dataset.name)
        torch.save(model.state_dict(), model_name)
        print(f"Model saved as {model_name}")


model = CRNN()
train_crnn(model, dataloader, 0.001, 10, "IAM")



RuntimeError: Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: IAM