In [5]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 class scores.

    The input is flattened to 28*28 features, passed through two 512-unit
    hidden layers with ReLU activations, and projected to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Layers are listed in forward order; wrapping them in a Sequential
        # under the same attribute keeps the parameter names unchanged.
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (unnormalized) class scores for the batch ``x``."""
        return self.linear_relu_stack(x)

# Model and loss. The model is created first, as before, so parameter
# initialization happens at the same point in the cell.
model = NeuralNetwork()
loss_fn = nn.CrossEntropyLoss()

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# MNIST train / validation splits stored under `data`; images are converted
# to tensors by ToTensor.
training_data = datasets.MNIST(
    root="data", train=True, download=True, transform=ToTensor())
val_data = datasets.MNIST(
    root="data", train=False, download=True, transform=ToTensor())

# Training uses mini-batches of 64; validation is done one sample at a time.
train_dataloader = DataLoader(training_data, batch_size=64)
val_dataloader = DataLoader(val_data, batch_size=1)

# Dataset sizes, used for progress reporting and the accuracy denominator.
train_length = len(training_data)
val_length = len(val_data)

# Training: one pass over the training set.
model.train()
# `batch_idx` is used instead of `iter` so the builtin `iter` is not shadowed
# for the rest of the kernel session.
for batch_idx, (inputs, labels) in enumerate(train_dataloader):
    inputs, labels = inputs.to(device), labels.to(device)
    pred = model(inputs)
    loss = loss_fn(pred, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report the current loss every 100 batches. The float is kept in its own
    # name so `loss` always refers to the tensor.
    if batch_idx % 100 == 0:
        loss_value, current = loss.item(), batch_idx * len(inputs)
        print(f"loss: {loss_value:>7f}  [{current:>5d}/{train_length:>5d}]")

# Validation
model.eval()
correct = 0
# No gradients are needed for evaluation; disabling autograd avoids building
# a graph for every validation sample.
with torch.no_grad():
    for inputs, labels in val_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = model(inputs)
        pred_label = pred.argmax(dim=1)
        correct += (pred_label == labels).type(torch.float).sum().item()
# Convert the count of correct predictions into a fraction of the dataset.
correct /= val_length
print(f"Test: \n Accuracy: {(100*correct):>0.1f}% \n")

loss: 2.307098  [    0/60000]
loss: 2.291324  [ 6400/60000]
loss: 2.290518  [12800/60000]
loss: 2.294219  [19200/60000]
loss: 2.281699  [25600/60000]
loss: 2.275330  [32000/60000]
loss: 2.276254  [38400/60000]
loss: 2.278672  [44800/60000]


: 

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import read_image_file, read_label_file

from mmengine.dataset import BaseDataset
from mmengine.dataset.utils import pseudo_collate
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel, BaseModule
from mmengine.runner import Runner
from mmengine.structures import LabelData

from mmocr.datasets.transforms import (LoadImageFromNDArray,
                                       LoadOCRAnnotations, PackTextRecogInputs)
from mmocr.models.textrecog.data_preprocessors import TextRecogDataPreprocessor

class DemoDecoder(BaseModule):
    """Linear classification head bundled with its loss and postprocessor.

    Args:
        in_channels: Size of the incoming feature vector.
        out_channels: Number of output scores per sample.
        init_cfg: Forwarded unchanged to ``BaseModule``.
    """

    def __init__(self, in_channels, out_channels, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.cls = nn.Linear(in_channels, out_channels)
        self.module_loss = DemoLoss()
        self.postprocessor = DemoPostprocessor()

    def forward(self, x):
        """Project the features ``x`` to class scores."""
        scores = self.cls(x)
        return scores

    def loss(self, x, data_samples):
        """Return a dict holding the loss under the key ``loss_ce``."""
        loss_value = self.module_loss(self(x), data_samples)
        return {'loss_ce': loss_value}

    def predict(self, x, data_samples):
        """Score ``x`` and hand the scores to the postprocessor."""
        return self.postprocessor(self(x), data_samples)

class DemoLoss(nn.Module):
    """Cross-entropy loss that reads its targets from the data samples."""

    def __init__(self):
        super().__init__()
        self.loss_ce = nn.CrossEntropyLoss()

    def get_target(self, data_samples):
        """Stack every sample's ``gt_text.item`` into one target tensor."""
        per_sample = [sample.gt_text.item for sample in data_samples]
        return torch.stack(per_sample, dim=0)

    def forward(self, outputs, data_samples):
        """Return the cross-entropy between ``outputs`` and the targets."""
        targets = self.get_target(data_samples)
        return self.loss_ce(outputs, targets)

class DemoPostprocessor:
    """Turn class scores into predicted labels stored on the data samples."""

    def __call__(self, x, data_samples):
        """Attach the argmax prediction to the data sample(s).

        Args:
            x: Class scores; the argmax is taken over dimension 1.
            data_samples: Either a list/tuple holding one sample per row of
                ``x``, or a single sample object.

        Returns:
            The same ``data_samples`` object, with ``pred_text.item`` set.
        """
        pred = torch.argmax(x, dim=1)
        if isinstance(data_samples, (list, tuple)):
            # A batch arrives as a sequence of samples. Each sample gets its
            # own prediction, which is what the metric reads per sample;
            # setting the attribute on the list itself would fail.
            for data_sample, label in zip(data_samples, pred):
                data_sample.pred_text = LabelData()
                data_sample.pred_text.item = label
        else:
            # Single sample object: keep the original behaviour and store the
            # whole prediction tensor on it.
            data_samples.pred_text = LabelData()
            data_samples.pred_text.item = pred
        return data_samples

class DemoRecognizer(BaseModel):
    """Minimal recognizer: an MLP backbone followed by a ``DemoDecoder``.

    Args:
        data_preprocessor: Forwarded unchanged to ``BaseModel``.
        init_cfg: Forwarded unchanged to ``BaseModel``.
    """

    def __init__(self, data_preprocessor=None, init_cfg=None):
        super().__init__(data_preprocessor, init_cfg)
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 512),
                                      nn.ReLU(), nn.Linear(512, 512),
                                      nn.ReLU())
        self.decoder = DemoDecoder(512, 10)

    def loss(self, inputs, data_samples):
        """Return the decoder's loss dict for a batch."""
        feats = self.backbone(inputs)
        return self.decoder.loss(feats, data_samples)

    def predict(self, inputs, data_samples):
        """Return the data samples with predictions attached by the decoder."""
        feats = self.backbone(inputs)
        predictions = self.decoder.predict(feats, data_samples)
        return predictions

    def forward(self, inputs, data_samples=None, mode='tensor'):
        """Dispatch to loss / prediction / raw-score computation.

        Args:
            inputs: Batched input passed to the backbone.
            data_samples: Data samples of the batch; required for the
                ``'loss'`` and ``'predict'`` modes.
            mode: ``'loss'`` returns the loss dict, ``'predict'`` (or the
                legacy alias ``'pred'``) returns the data samples with
                predictions, ``'tensor'`` returns the decoder's raw scores.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        # MMEngine's BaseModel.val_step/test_step call forward with
        # mode='predict'; matching only 'pred' made validation silently
        # receive None. 'pred' stays accepted for existing callers.
        if mode in ('predict', 'pred'):
            return self.predict(inputs, data_samples)
        if mode == 'tensor':
            return self.decoder(self.backbone(inputs))
        # Fail loudly instead of returning None for an unknown mode.
        raise ValueError(
            f"Invalid mode '{mode}'. Expected 'loss', 'predict' or 'tensor'.")


class DemoMetric(BaseMetric):
    """Accuracy metric comparing ``pred_text.item`` with ``gt_text.item``."""

    @staticmethod
    def _field(container, key):
        """Read ``key`` from a dict or, failing that, as an attribute."""
        if isinstance(container, dict):
            return container[key]
        return getattr(container, key)

    def process(self, data_batch, data_samples):
        """Record 1.0 for every correct sample and 0.0 otherwise.

        Args:
            data_batch: Unused.
            data_samples: Iterable of samples carrying ``pred_text`` and
                ``gt_text``, each with an ``item`` field.
        """
        # NOTE(review): MMEngine's Evaluator is understood to convert data
        # elements to (nested) plain dicts before calling `process` — confirm
        # for the installed version. Both forms are accepted here so the
        # metric works whether it is called by the Evaluator or directly.
        for data_sample in data_samples:
            pred = self._field(self._field(data_sample, 'pred_text'), 'item')
            target = self._field(self._field(data_sample, 'gt_text'), 'item')
            self.results.append((pred == target).type(torch.float).item())

    def compute_metrics(self, results):
        """Return ``{'accuracy': mean of the recorded per-sample results}``."""
        return dict(accuracy=sum(results) / len(results))


class MNISTDatasets(BaseDataset):
    """Dataset built from a raw MNIST image file and label file."""

    def load_data_list(self):
        """Read images and labels and pair them into per-sample dicts.

        Returns:
            A list with one dict per sample: ``img`` holds the image and
            ``instances`` holds a single-element list ``[{'text': label}]``.
        """
        images = read_image_file(self.data_prefix['img_path'])
        labels = read_label_file(self.ann_file)

        # One entry per (image, label) pair, in file order.
        return [
            {'img': image, 'instances': [{'text': label}]}
            for image, label in zip(images, labels)
        ]

# Transform pipeline shared by the training and validation datasets.
pipeline = [
    LoadImageFromNDArray(),
    LoadOCRAnnotations(with_text=True),
    PackTextRecogInputs(meta_keys=('ori_shape', 'img_shape')),
]

# Training split.
train_dataset = MNISTDatasets(
    data_root='data/MNIST/raw',
    data_prefix={'img_path': 'train-images-idx3-ubyte'},
    ann_file='train-labels-idx1-ubyte',
    pipeline=pipeline,
    test_mode=False,
    serialize_data=False)

# Validation split.
val_dataset = MNISTDatasets(
    data_root='data/MNIST/raw',
    data_prefix={'img_path': 't10k-images-idx3-ubyte'},
    ann_file='t10k-labels-idx1-ubyte',
    pipeline=pipeline,
    test_mode=True,
    serialize_data=False)

# pseudo_collate is used as the collate function for both loaders.
# Shuffling is not enabled for the training loader.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    num_workers=0,
    collate_fn=pseudo_collate)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=pseudo_collate)

# The Runner drives one training epoch with validation after it.
runner = Runner(
    model=DemoRecognizer(
        data_preprocessor=TextRecogDataPreprocessor(mean=[0], std=[255])),
    work_dir='./work_dirs/demo',
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optim_wrapper={'optimizer': {'type': 'SGD', 'lr': 1e-3}},
    train_cfg={'by_epoch': True, 'max_epochs': 1, 'val_interval': 1},
    val_cfg={},
    val_evaluator={'type': DemoMetric})
runner.train()


在 MMOCR 1.0 中，我们对训练、测试、推理过程中涉及到的组件进行了新的抽象和划分，分为数据集、模型、评价指标三类，同时这些组件之间使用统一的接口（DataSample）进行数据传递。
这里以 PyTorch 官网的 MNIST 数据集为例，展示各个组件的使用方式。

In [None]:
from statistics import mode
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import read_image_file, read_label_file

from mmengine.dataset import BaseDataset
from mmengine.dataset.utils import pseudo_collate
from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel, BaseModule
from mmengine.runner import Runner
from mmengine.structures import LabelData

from mmocr.datasets.transforms import (LoadImageFromNDArray,
                                       LoadOCRAnnotations, PackTextRecogInputs)
from mmocr.models.textrecog.data_preprocessors import TextRecogDataPreprocessor



class NeuralNetwork(nn.Module):
    """MLP classifier whose ``forward`` computes a loss or predictions.

    The network flattens its input to 28*28 features, applies two 512-unit
    hidden layers with ReLU, and outputs 10 class scores.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)
        self.loss_ce = nn.CrossEntropyLoss()

    def forward(self, inputs, data_samples, mode):
        """Run the network and post-process according to ``mode``.

        Args:
            inputs: Batched input passed to the layer stack.
            data_samples: Iterable of samples; ``gt_text.item`` is read in
                ``'loss'`` mode and ``pred_text`` is written in ``'pred'``
                mode.
            mode: ``'loss'`` returns the cross-entropy loss; ``'pred'``
                returns ``data_samples`` with predictions attached. Any
                other value yields ``None``.
        """
        scores = self.linear_relu_stack(inputs)

        if mode == 'loss':
            # Gather the ground-truth labels of the batch into one tensor.
            gt_items = []
            for sample in data_samples:
                gt_items.append(sample.gt_text.item)
            return self.loss_ce(scores, torch.stack(gt_items))

        if mode == 'pred':
            # Store each sample's argmax class on the sample itself.
            labels = scores.argmax(dim=1)
            for sample, label in zip(data_samples, labels):
                sample.pred_text = LabelData()
                sample.pred_text.item = label
            return data_samples

        return None


class DemoMetric(BaseMetric):
    """Accuracy metric comparing ``pred_text.item`` with ``gt_text.item``."""

    def process(self, data_batch=None, data_samples=None):
        """Record 1.0 for each correctly predicted sample and 0.0 otherwise.

        Args:
            data_batch: Unused.
            data_samples: Iterable of samples exposing ``pred_text.item``
                and ``gt_text.item``.
        """
        for sample in data_samples:
            is_match = sample.pred_text.item == sample.gt_text.item
            self.results.append(is_match.type(torch.float).item())

    def compute_metrics(self, results):
        """Return the mean of ``results`` as a bare number (not a dict)."""
        n_correct = sum(results)
        return n_correct / len(results)


class MNISTDatasets(BaseDataset):
    """Dataset built from a raw MNIST image file and label file."""

    def load_data_list(self):
        """Pair every image with its label.

        Returns:
            A list of dicts, one per sample, with the image under ``img``
            and ``[{'text': label}]`` under ``instances``.
        """
        image_tensor = read_image_file(self.data_prefix['img_path'])
        label_tensor = read_label_file(self.ann_file)

        # Build the per-sample records in file order.
        return [
            dict(img=image, instances=[dict(text=label)])
            for image, label in zip(image_tensor, label_tensor)
        ]

# Transform pipeline shared by the training and validation datasets.
pipeline = [
    LoadImageFromNDArray(),
    LoadOCRAnnotations(with_text=True),
    PackTextRecogInputs(meta_keys=('ori_shape', 'img_shape')),
]

# Training split.
train_dataset = MNISTDatasets(
    data_root='data/MNIST/raw',
    data_prefix={'img_path': 'train-images-idx3-ubyte'},
    ann_file='train-labels-idx1-ubyte',
    pipeline=pipeline,
    test_mode=False,
    serialize_data=False)

# Validation split.
val_dataset = MNISTDatasets(
    data_root='data/MNIST/raw',
    data_prefix={'img_path': 't10k-images-idx3-ubyte'},
    ann_file='t10k-labels-idx1-ubyte',
    pipeline=pipeline,
    test_mode=True,
    serialize_data=False)

# pseudo_collate is used as the collate function for both loaders.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    num_workers=0,
    collate_fn=pseudo_collate)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=pseudo_collate)

# Dataset sizes, used for progress reporting and metric evaluation.
train_length = len(train_dataset)
val_length = len(val_dataset)

# Model, data preprocessor and metric. The model is created first, as
# before, so parameter initialization happens at the same point in the cell.
model = NeuralNetwork()
data_preprocessor = TextRecogDataPreprocessor(mean=[0], std=[255])
metric = DemoMetric()

# Run on the GPU when one is available, otherwise on the CPU; the model and
# the preprocessor are moved to the same device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
data_preprocessor = data_preprocessor.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Training: one pass over the training set.
model.train()
# `batch_idx` is used instead of `iter` so the builtin `iter` is not shadowed
# for the rest of the kernel session.
for batch_idx, data in enumerate(train_dataloader):
    data = data_preprocessor(data)
    loss = model(data['inputs'], data['data_samples'], mode='loss')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report the current loss every 100 batches. The float is kept in its own
    # name so `loss` always refers to the tensor.
    if batch_idx % 100 == 0:
        loss_value, current = loss.item(), batch_idx * len(data['inputs'])
        print(f"loss: {loss_value:>7f}  [{current:>5d}/{train_length:>5d}]")

# Validation
model.eval()
# No gradients are needed for evaluation; disabling autograd avoids building
# a graph for every validation sample.
with torch.no_grad():
    # Iterate over the loader directly: wrapping it in enumerate() would hand
    # an (index, batch) tuple to the preprocessor instead of the batch.
    for data in val_dataloader:
        data = data_preprocessor(data)
        preds = model(data['inputs'], data['data_samples'], mode='pred')
        metric.process(data_samples=preds)
# With this notebook's DemoMetric, evaluate() yields the accuracy as a number.
correct = metric.evaluate(val_length)
print(f"Test: \n Accuracy: {(100*correct):>0.1f}% \n")
