In [1]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms


from tqdm import tqdm
from PIL import Image
from typing import List, Dict
from torchmetrics import CharErrorRate

# Dataloader

In [2]:
class OCRDataset(torch.utils.data.Dataset):
    def __init__(self, path, transform=None):
        self.paths = []
        self.transform = transform
        
        labels = []
        for file in os.listdir(path):
            full_path = os.path.join(path, file)
            if os.path.isdir(full_path):
                continue
            self.paths.append(full_path)
            labels.append(file.split(".")[0])
            
        self.vocab = self.__create_vocab(labels)
        self.inv_vocab = {item: key for key, item in self.vocab.items()}
        
        self.labels = []
        for label in labels:
            self.labels.append(self.encode(label))
        
        
    def __create_vocab(self, labels: List[str]) -> Dict[str, int]:
        vocab = {}
        current = 0
        for label in labels:
            for char in label:
                if char not in vocab:
                    vocab[char] = current
                    current += 1
        return vocab
    
    def encode(self, label: str) -> List[int]:
        result = []
        for char in label:
            if char in self.vocab:
                result.append(self.vocab[char])
        return result
    
    def decode(self, encoded: List[int]) -> str:
        result = ""
        for item in encoded:
            if item in self.inv_vocab:
                result += self.inv_vocab[item]
        return result
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert("L")
        target = torch.tensor(self.labels[index]).long()
        if self.transform is not None:
            img = self.transform(img)
        return img, target

In [3]:
transform = transforms.Compose([
    transforms.Resize((48, 80)),
    transforms.ToTensor(),
])

In [4]:
data = OCRDataset("data/samples", transform=transform)

In [5]:
len(data.vocab)

19

In [6]:
len(data)

1070

In [7]:
train_data, test_data = torch.utils.data.random_split(data, 
                                                      [int(0.8 * len(data)), int(0.2 * len(data))], 
                                                      generator=torch.Generator().manual_seed(42)
                                                     )

In [8]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=16, shuffle=True)

# Model

In [42]:
class DeepWiseBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=(1, 1), expand_ratio=1):
        super(DeepWiseBlock, self).__init__()
        
        hidden_dim = round(in_channels * expand_ratio)
        self.identity = stride[0] == 1 and stride[1] == 1 and in_channels == out_channels
        
        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(in_channels, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(in_channels, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(out_channels),
            )
                
    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)

class OCREncoder(nn.Module):
    def __init__(self):
        super(OCREncoder, self).__init__()
        
        self.init_conv = DeepWiseBlock(1, 16)
        self.layer1 = nn.Sequential(
            DeepWiseBlock(16, 32),
            DeepWiseBlock(32, 32, stride=(2, 2), expand_ratio=1),
        )
        self.layer2 = nn.Sequential(
            DeepWiseBlock(32, 64),
            DeepWiseBlock(64, 64, expand_ratio=2),
            DeepWiseBlock(64, 64, stride=(2, 2), expand_ratio=1),
        )
        self.layer3 = nn.Sequential(
            DeepWiseBlock(64, 128),
            DeepWiseBlock(128, 128, expand_ratio=2),
            DeepWiseBlock(128, 128, stride=(2, 2), expand_ratio=1),
        )
        self.layer4 = nn.Sequential(
            DeepWiseBlock(128, 128, expand_ratio=2),
            DeepWiseBlock(128, 128, expand_ratio=2),
            DeepWiseBlock(128, 128, stride=(2, 2), expand_ratio=1),
        )
        self.maxpool = nn.AdaptiveAvgPool2d((5, 1))
        self._initialize_weights()
        
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
        
    def forward(self, x):
        x = self.init_conv(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.maxpool(x)
        
        x = x.squeeze(3)
        x = x.transpose(1, 2)
        return x
    
class OCRDecoder(nn.Module):
    def __init__(self, output_dim=19):
        super(OCRDecoder, self).__init__()
        
        self.lstm = nn.LSTM(input_size=128, 
                            hidden_size=128, 
                            num_layers=2, 
                            bias=True, 
                            batch_first=True, 
                            dropout=0.2, 
                            bidirectional=True
                           )
        self.output = nn.Linear(256, output_dim)
        
    def forward(self, x):
        h0 = torch.zeros(2 * 2, x.size(0), 128, device=x.device)
        c0 = torch.randn(2 * 2, x.size(0), 128, device=x.device)
        
        output, (hn, cn) = self.lstm(x, (h0, c0))
        output = self.output(output)
        return output
        
class OCRModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(OCRModel, self).__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [43]:
encoder = OCREncoder()
decoder = OCRDecoder(output_dim=19)

In [44]:
x = torch.rand(1, 1, 80, 48)

In [45]:
decoder(encoder(x)).size()

torch.Size([1, 5, 19])

In [46]:
net = OCRModel(encoder, decoder)

In [47]:
y = net(x)

In [48]:
data.decode(y[0].argmax(dim=-1).tolist())

'55yyy'

# Train one batch

In [18]:
encoder = OCREncoder()
decoder = OCRDecoder(output_dim=19)
net = OCRModel(encoder, decoder)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.999), weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,30])
cer = CharErrorRate()

In [None]:
img, target = next(iter(train_loader))

In [None]:
img = img.cuda()
target = target.cuda()
net = net.cuda()

for i in range(1000):
    optimizer.zero_grad()
    output = net(img)
    loss = criterion(output.view(-1, output.size(2)), target.view(-1))
    loss.backward()
    optimizer.step()
    print(loss.item())

# Train

In [59]:
encoder = OCREncoder()
decoder = OCRDecoder(output_dim=len(data.vocab))
net = OCRModel(encoder, decoder)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.999), weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 70, 80])
cer = CharErrorRate()

In [60]:
def train(net, dataloader, criterion, optimizer, device="cuda"):
    net.train()
    net = net.to(device)
    
    mean_loss = []
    for img, target in dataloader:
        img = img.to(device)
        target = target.to(device)
        
        optimizer.zero_grad()
        output = net(img)
        loss = criterion(output.view(-1, output.size(2)), target.view(-1))
        loss.backward()
        
        nn.utils.clip_grad_norm_(net.parameters(), 1e-1)
        optimizer.step()
        
        mean_loss.append(loss.item())
    
    mean_loss = sum(mean_loss)/len(mean_loss)
    print("train loss:", round(mean_loss, 3))
    return mean_loss

In [61]:
def val(net, dataloader, criterion, metric, device="cuda"):
    net.eval()
    net = net.to(device)
    
    mean_loss = []
    mean_metric = []
    for img, target in dataloader:
        img = img.to(device)
        target = target.to(device)
        
        output = net(img)
        loss = criterion(output.view(-1, output.size(2)), target.view(-1))
        
        mean_loss.append(loss.item())
        
        preds = []
        real = []
        for i in range(output.size(0)):
            p = output[i].argmax(dim=-1).tolist()
            t = target[i].tolist()
            preds.append(p)
            real.append(t)
        
        mean_metric.append(metric(preds, real).item())
    
    mean_metric = sum(mean_metric)/len(mean_metric)
    mean_loss = sum(mean_loss)/len(mean_loss)
    
    print("test metric:", round(mean_metric, 3))
    print("test loss:", round(mean_loss, 3))
    return mean_loss, mean_metric

In [62]:
for epoch in range(100):
    print("Epoch", epoch)
    train(net, train_loader, criterion, optimizer)
    val(net, test_loader, criterion, cer)
    scheduler.step()
    print("==="*20)

Epoch 0
train loss: 2.881
test metric: 0.893
test loss: 3.006
Epoch 1
train loss: 2.612
test metric: 0.834
test loss: 2.585
Epoch 2
train loss: 2.306
test metric: 0.819
test loss: 2.567
Epoch 3
train loss: 2.126
test metric: 0.809
test loss: 2.574
Epoch 4
train loss: 1.948
test metric: 0.732
test loss: 2.15
Epoch 5
train loss: 1.807
test metric: 0.655
test loss: 1.934
Epoch 6
train loss: 1.68
test metric: 0.674
test loss: 1.929
Epoch 7
train loss: 1.568
test metric: 0.715
test loss: 2.538
Epoch 8
train loss: 1.462
test metric: 0.603
test loss: 1.899
Epoch 9
train loss: 1.414
test metric: 0.769
test loss: 3.346
Epoch 10
train loss: 1.305
test metric: 0.569
test loss: 1.664
Epoch 11
train loss: 1.266
test metric: 0.692
test loss: 2.339
Epoch 12
train loss: 1.238
test metric: 0.545
test loss: 1.584
Epoch 13
train loss: 1.211
test metric: 0.532
test loss: 1.552
Epoch 14
train loss: 1.17
test metric: 0.485
test loss: 1.421
Epoch 15
train loss: 1.119
test metric: 0.532
test loss: 1.65
Epoch 

In [63]:
val(net, train_loader, criterion, cer)

test metric: 0.0
test loss: 0.0


(0.00032756491769018965, 0.0)