In [11]:
from data_reader import Vocabulary, HWDBDatasetHelper, LMDBReader

# Paths to the data — adjust these for your machine.
# NOTE(review): train_path and test_path are absolute, machine-specific
# locations; test_path is assigned again in the evaluation cell further down.
train_path = r'/DATA/ichuviliaeva/ocr_data/train.lmdb'
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
gt_path = './gt.txt'

# Simple CNN baseline

PyTorch and LMDB are required for this baseline implementation (the code below also imports OpenCV, NumPy and tqdm).

## Baseline method

- Naively resize to 32x32 (DON'T DO THIS IN YOUR WORK — try to preserve the geometry somehow; it is important)
- Train LeNet-like CNN
- Enjoy :)

In [12]:
import cv2
import numpy as np

### Data tools

In [13]:
# Create a reader for the LMDB at train_path, open it, and wrap it in a
# dataset helper.
train_reader = LMDBReader(train_path)
train_reader.open()
train_helper = HWDBDatasetHelper(train_reader)

In [14]:
# Split the helper into a train part and a validation part.
# NOTE(review): this rebinds train_helper, so the cell is not idempotent —
# re-running it without re-running the previous cell would split the
# already-split train part again.
train_helper, val_helper = train_helper.train_val_split()

In [15]:
# Number of samples in the train and validation parts.
train_helper.size(), val_helper.size()

(2578433, 644609)

In [16]:
import torch

from torch.utils.data import Dataset, DataLoader
from torch import nn

class HWDBDataset(Dataset):
    """Torch ``Dataset`` view over a ``HWDBDatasetHelper``.

    Each item is a ``(image, label)`` pair where the image has been resized to
    32x32 with OpenCV and rescaled as ``(pixel - 127.5) / 255``.
    """

    def __init__(self, helper: HWDBDatasetHelper):
        self.helper = helper

    def __len__(self):
        # The helper owns the sample index; just forward its size.
        return self.helper.size()

    def __getitem__(self, idx):
        image, target = self.helper.get_item(idx)
        # Naive resize that ignores the aspect ratio (baseline behaviour).
        resized = cv2.resize(image, (32, 32))
        # Maps 0..255 inputs to -0.5..0.5 — assumes 8-bit images, TODO confirm.
        normalized = (resized - 127.5) / 255.
        return normalized, target

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
# Wrap the train and validation helpers in torch datasets.
train_dataset = HWDBDataset(train_helper)
val_dataset = HWDBDataset(val_helper)

### Model & training

In [18]:
class LeNet(nn.Module):
    """Small LeNet-style CNN for single-channel 32x32 inputs.

    Three convolutional stages, each ending in a 2x2 max-pool, shrink the
    feature map 32 -> 16 -> 8 -> 4. A two-layer fully connected head then maps
    the flattened 4*4*32 features to ``n_classes`` logits.

    Args:
        n_classes: number of output logits.
    """

    def __init__(self, n_classes):
        super().__init__()

        # Stage 1: 1 -> 8 channels, spatial size 32 -> 16.
        stage1 = [
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        ]
        # Stage 2: 8 -> 16 channels, spatial size 16 -> 8.
        stage2 = [
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2)),
        ]
        # Stage 3: 16 -> 32 channels, spatial size 8 -> 4.
        stage3 = [
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        ]
        # Classifier head; the last layer has no bias term.
        head = [
            nn.Flatten(),
            nn.Linear(4 * 4 * 32, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, n_classes, bias=False),
        ]

        # Layers keep their original positions inside the Sequential, so
        # state_dict keys ("nn.<index>.<param>") are unchanged.
        self.nn = nn.Sequential(*stage1, *stage2, *stage3, *head)

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape (N, 1, 32, 32)."""
        return self.nn(x)

In [19]:
# One logit per vocabulary class. eval() returns the module itself, so this
# cell also displays the layer layout.
model = LeNet(train_helper.vocabulary.num_classes())
model.eval()

LeNet(
  (nn): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=512, out_features=128, bias=True)
    (14): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=T

In [20]:
# Smoke test: run the first training image through the network as a
# (1, 1, 32, 32) float32 batch.
model(torch.tensor(train_dataset[0][0], dtype=torch.float32).view(1, 1, 32, 32))

tensor([[ 0.0035, -0.0090, -0.0059,  ..., -0.0022, -0.0024,  0.0095]],
       grad_fn=<MmBackward0>)

In [21]:
# Prefer GPU 2 (the device this baseline was originally pinned to). Fall back
# to the first GPU, or to the CPU, so the notebook still runs on machines with
# fewer devices instead of failing when the model is moved.
if torch.cuda.is_available():
    device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0'
else:
    device = 'cpu'

In [22]:
# Move the model to the selected device.
model = model.to(device)

In [23]:
# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps the dataset order and uses larger batches.
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [24]:
# Adam over all model parameters (lr=0.001) with a cross-entropy loss.
optim = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [25]:
from tqdm import tqdm


def run_validation(val_loader: DataLoader, model: nn.Module, n_steps=None):
    """Compute top-1 accuracy of ``model`` on batches drawn from ``val_loader``.

    Args:
        val_loader: iterable yielding ``(X, y)`` batches. ``X`` gets a channel
            axis inserted at dim 1 and is cast to float32 before the forward
            pass; ``y`` holds the class index of each sample.
        model: network to evaluate. It is switched to eval mode and left there.
        n_steps: maximum number of batches to evaluate. ``None`` (default)
            evaluates the whole loader and shows a tqdm progress bar.

    Returns:
        Fraction of correctly classified samples as a Python float, or ``0.0``
        when no samples were evaluated (empty loader or ``n_steps == 0``).
    """
    model.eval()

    # Run on the device the model lives on rather than on a notebook-level
    # global; a parameter-free model is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    batches = val_loader
    if n_steps is None:
        n_steps = len(val_loader)
        batches = tqdm(val_loader)

    n_good = 0
    n_all = 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(batches):
            if batch == n_steps:
                break
            logits = model(X.unsqueeze(1).to(torch.float32).to(model_device))
            classes = torch.argmax(logits, dim=1)
            # Count matches with a single tensor reduction; the builtin sum()
            # would walk the array element by element in Python.
            n_good += (classes == y.to(classes.device)).sum().item()
            n_all += len(classes)

    # Avoid ZeroDivisionError when nothing was evaluated.
    return n_good / n_all if n_all else 0.0


def train_epoch(train_loader: DataLoader, val_loader: DataLoader, model: nn.Module, optim, loss_fn):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(X, y)`` batches. ``X`` gets a channel
            axis inserted at dim 1 and is cast to float32; ``y`` is cast to long.
        val_loader: not used; kept so existing call sites keep working.
        model: network to optimise. It is switched to train mode.
        optim: optimizer exposing ``zero_grad()`` and ``step()``.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar tensor.

    Returns:
        Mean training loss over the processed batches as a float, or ``None``
        if the loader produced no batches.
    """
    # Train mode only needs to be set once per epoch, not once per batch.
    model.train()

    # Run on the device the model lives on rather than on a notebook-level
    # global; a parameter-free model is trained on the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    running_loss = 0.0
    n_batches = 0
    for X, y in tqdm(train_loader):
        logits = model(X.unsqueeze(1).to(torch.float32).to(model_device))
        loss = loss_fn(logits, y.to(torch.long).to(model_device))

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Accumulate on-device; converting every batch would force a sync.
        running_loss = running_loss + loss.detach()
        n_batches += 1

    return float(running_loss) / n_batches if n_batches else None

In [26]:
# Save the current weights; in notebook order this runs before the training loop.
torch.save(model.state_dict(), 'baseline.pth')

In [27]:
# Train for 10 epochs; after each one, report accuracy on the full validation
# loader and save a per-epoch checkpoint.
for epoch in range(10):
    print(f'Epoch {epoch}:')
    train_epoch(train_loader, val_loader, model, optim, loss_fn)
    accuracy = run_validation(val_loader, model)
    print(f'accuracy: {accuracy}')
    torch.save(model.state_dict(), f'baseline_epoch{epoch}.pth')

Epoch 0:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:10<00:00, 71.29it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:16<00:00, 18.72it/s]


accuracy: 0.7737512197316513
Epoch 1:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:08<00:00, 73.35it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:17<00:00, 18.40it/s]


accuracy: 0.8132992248013913
Epoch 2:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:07<00:00, 74.06it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:17<00:00, 18.29it/s]


accuracy: 0.8297169291772222
Epoch 3:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:08<00:00, 73.27it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:16<00:00, 18.90it/s]


accuracy: 0.838419879337707
Epoch 4:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:06<00:00, 75.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:16<00:00, 19.10it/s]


accuracy: 0.8372331134067318
Epoch 5:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:08<00:00, 73.39it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:17<00:00, 18.38it/s]


accuracy: 0.8425107313115392
Epoch 6:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:09<00:00, 72.31it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:17<00:00, 18.41it/s]


accuracy: 0.8480458696667282
Epoch 7:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:09<00:00, 72.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:17<00:00, 18.45it/s]


accuracy: 0.8452115933845168
Epoch 8:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:08<00:00, 73.06it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:17<00:00, 18.47it/s]


accuracy: 0.8492481488778468
Epoch 9:


100%|███████████████████████████████████████████████████████████████████████████████| 5036/5036 [01:09<00:00, 72.63it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 315/315 [00:16<00:00, 18.79it/s]

accuracy: 0.8548236217614088





### Evaluation

In [28]:
# NOTE(review): test_path repeats the value already assigned in the first cell.
test_path = r'/DATA/ichuviliaeva/ocr_data/test.lmdb'
pred_path = './pred.txt'

# Open the test LMDB and wrap it in a helper created with prefix='Test'.
test_reader = LMDBReader(test_path)
test_reader.open()
test_helper = HWDBDatasetHelper(test_reader, prefix='Test')

In [29]:
# Same dataset wrapper as for training; no shuffling, so predictions come out
# in dataset order.
test_dataset = HWDBDataset(test_helper)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=8)

In [30]:
# Collect the predicted class index for every test sample; the second element
# of each batch is ignored.
preds = []
model.eval()
with torch.no_grad():
    for X, _ in tqdm(test_loader):
        logits = model(X.unsqueeze(1).to(torch.float32).to(device))
        classes = torch.argmax(logits, dim=1).cpu().numpy()
        preds.extend(classes)

100%|█████████████████████████████████████████████████████████████████████████████████| 380/380 [00:19<00:00, 19.31it/s]


In [31]:
# Write one "<name> <class>" line per prediction; predicted indices are mapped
# back to classes through the training helper's vocabulary.
with open(pred_path, 'w') as f_pred:
    for idx, pred in enumerate(preds):
        name = test_helper.namelist[idx]
        cls = train_helper.vocabulary.class_by_index(pred)
        print(name, cls, file=f_pred)

In [41]:
from course_ocr_t2.evaluate import evaluate
# Accuracy = 0.7978

In [42]:
# Score the prediction file against the ground truth, using the path variables
# defined earlier in the notebook instead of repeating the literals.
evaluate(gt_path, pred_path)

0.7980111342484383