<a href="https://colab.research.google.com/github/Mgalvaz/license_plate-recognizer/blob/main/notebooks/train_OCR_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [57]:
import random
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
from torch.functional import F
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import v2
from PIL import Image, ImageDraw, ImageFont

In [58]:
def generate_plate_text() -> str:
  nums = ''.join(random.choices('0123456789', k=4))

  consonants = "BCDFGHJKLMNPQRSTVWXYZ"
  vowels = "AEIOU"
  alphabet = consonants + vowels
  weights = [10] * len(consonants) + [1] * len(vowels)
  letters = ''.join(random.choices(alphabet, weights=weights, k=3))

  return f"{nums}  {letters}"

def augment_image_v2(img: Image.Image) -> torch.Tensor:
  tr = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float, True),
    v2.RandomRotation(degrees=8, fill=0.392),
    v2.RandomPerspective(distortion_scale=0.25, p=1.0, fill=0.392),
    v2.GaussianNoise(sigma=0.05),
    v2.RandomApply([v2.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2))], p=0.7),
  ])
  return tr(img)

class SyntheticPlateDataset(Dataset):
  """
  Dataset that generates fake synthetic spanish license plates and augments them to simulate perspective.
  """

  def __init__(self, num_samples: int = 10000):
    super().__init__()
    self.num_samples = num_samples
    self.font_main = ImageFont.truetype('/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', 25)
    self.font_small = ImageFont.truetype('/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', 8)
    self.transform = augment_image_v2
    self.translator = dict((l, n) for n, l in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', start=1))

  def __len__(self) -> int:
    return self.num_samples

  def __getitem__(self, idx: int) -> tuple:
    plate_text = generate_plate_text()
    plate = Image.new("L", (150, 32), color=230)
    draw = ImageDraw.Draw(plate)
    draw.rectangle([0, 0, 20, 32], fill=55)
    draw.text((8, 18), 'E', font=self.font_small, fill=230)
    draw.text((22, 2), plate_text, font=self.font_main, fill=50)
    plate = self.transform(plate).to(device)
    label = torch.tensor([self.translator[l] for l in plate_text if l != ' '], dtype=torch.long).to(device)
    return plate, label

In [59]:
TRANSLATOR = dict((l, n) for n, l in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', start=1))
NUM_CLASSES = len(TRANSLATOR) + 1


In [60]:
def collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor]:
  imgs, labels = zip(*batch)
  imgs = torch.stack(imgs)
  labels = pad_sequence(labels, batch_first=True, padding_value=-1)
  return imgs, labels

def ctc_decode(pred_seq: torch.Tensor, blank: int=0) -> list[int]:
  decoded = []
  prev = None
  for p in pred_seq:
    if p != blank and p != prev:
      decoded.append(p)
    prev = p
  return decoded

class CRNN(nn.Module):

  def __init__(self):
    super(CRNN, self).__init__()

    self.cnn = nn.Sequential(
      nn.Conv2d(1, 64, (3, 3), padding=1),  # (1, 32, 150) -> (32, 32, 150)
      nn.ReLU(),
      nn.MaxPool2d(2, 2), # (32, 32, 150) -> (32, 16, 75)
      nn.Conv2d(64, 128, (3, 3), padding=1),  # (32, 16, 75) -> (64, 16, 75)
      nn.ReLU(),
      nn.MaxPool2d(2, 2),  # (64, 16, 75) -> (64, 8, 37)
      nn.Conv2d(128, 256, (3, 3), padding=1),  # (64, 8, 37) -> (128, 8, 37)
      nn.ReLU(),
      nn.MaxPool2d((1, 2), (2,1)),  # (128, 8, 37) -> (128, 4, 36)
      nn.Conv2d(256, 512, (3, 3), padding=1),  # (128, 4, 36) -> (256, 4, 36)
      nn.BatchNorm2d(512),
      nn.ReLU(),
      nn.Conv2d(512, 512, (3, 3), padding=1),  # (256, 4, 37) -> (256, 4, 36)
      nn.BatchNorm2d(512),
      nn.ReLU(),
      nn.Dropout2d(0.5),
      nn.MaxPool2d((2, 1), 1),  # (256, 4, 36) -> (256, 3, 36)
      nn.Conv2d(512, 512, (3, 3), padding=0),  # (256, 3, 36) -> (512, 1, 34)
      nn.BatchNorm2d(512),
      nn.ReLU(),
    )

    self.rnn = nn.GRU(512, 256, num_layers=2, batch_first=True, bidirectional=True) # (34, 512) -> (512, 34)

    self.decoder = nn.Linear(512, NUM_CLASSES) # (512, 34) -> (37, 34)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    _, ic, ih, iw = x.size() # (batch, channels=1, height=32, width=150)
    assert (ic, ih, iw) == (1, 32, 150), f'Input size ({ic}, {ih}, {iw}) does not correspond to expected size (1, 32, 150)'
    x = self.cnn(x) # (batch, channels=512, height=1, width=34)

    x = x.squeeze(2).permute(0, 2, 1)  # (batch, width=34, channels=512)

    x, _ = self.rnn(x) # (batch, seq_len=34, channels=512)

    x = self.decoder(x) # (batch, seq_len=34, label=37)
    return x

In [61]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [62]:
total = 50000
num_test = total//7
num_train = total - num_test
num_epochs = 5
dataset = SyntheticPlateDataset(num_samples=total)
train_dataset, test_dataset = random_split(dataset, [num_train, num_test])

train_loader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64,  collate_fn=collate_fn)

#Model loading
model = CRNN().to(device)
criterion = nn.CTCLoss(blank=0, zero_infinity=True, reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=0.0001)

#Training
torch.autograd.set_detect_anomaly(True)
for epoch in range(1, num_epochs+1):
  loop = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
  for batch_images, batch_labels in loop:
    # zero the parameter gradients
    optimizer.zero_grad()
    # forward
    batch_output = model(batch_images)
    log_probs = F.log_softmax(batch_output, dim=2).permute(1, 0 ,2)
    input_lengths = torch.full(size=(batch_output.size(0),), fill_value=batch_output.size(1), dtype=torch.long)
    targets = batch_labels
    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]
    target_lengths = torch.tensor([len(l) for l in batch_labels], dtype=torch.long)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    # backward
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # optimize
    optimizer.step()
    loop.set_postfix(loss=loss.item())


#Testing
model.eval()
correct = 0
with torch.no_grad():
  for batch_images, batch_labels in test_loader:
    output = model(batch_images).permute(1, 0 ,2)
    preds = output.argmax(dim=2).cpu().numpy().T

    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]
    decoded_preds = [ctc_decode(seq) for seq in preds]

    for pred, target in zip(decoded_preds, batch_labels):
      target = target.cpu().numpy().tolist()
      if pred == target:
        correct += 1
  print(f"Exact match accuracy: {correct / num_test:.2%}")

Epoch 1/5:   0%|          | 0/670 [00:00<?, ?it/s]

Epoch 2/5:   0%|          | 0/670 [00:00<?, ?it/s]

Epoch 3/5:   0%|          | 0/670 [00:00<?, ?it/s]

Epoch 4/5:   0%|          | 0/670 [00:00<?, ?it/s]

Epoch 5/5:   0%|          | 0/670 [00:00<?, ?it/s]

Exact match accuracy: 98.07%


In [63]:
REVERSE_DICT = {v: k for k, v in TRANSLATOR.items()}
final_dataset = SyntheticPlateDataset(num_samples=20)
final_loader = DataLoader(final_dataset, batch_size=64,  collate_fn=collate_fn)
model.eval()
with torch.no_grad():
  for batch_images, batch_labels in final_loader:
    output = model(batch_images).permute(1, 0 ,2)
    preds = output.argmax(dim=2).cpu().numpy().T

    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]
    decoded_preds = [ctc_decode(seq) for seq in preds]

    for pred, target in zip(decoded_preds, batch_labels):
      target = target.cpu().numpy().tolist()
      print(pred, target, list(map(lambda x: REVERSE_DICT[x], target)))

[np.int64(29), np.int64(36), np.int64(31), np.int64(32), np.int64(7), np.int64(23), np.int64(19)] [29, 36, 31, 32, 7, 23, 19] ['2', '9', '4', '5', 'G', 'W', 'S']
[np.int64(28), np.int64(28), np.int64(36), np.int64(30), np.int64(2), np.int64(18), np.int64(4)] [28, 28, 36, 30, 2, 18, 4] ['1', '1', '9', '3', 'B', 'R', 'D']
[np.int64(32), np.int64(36), np.int64(29), np.int64(36), np.int64(10), np.int64(10), np.int64(26)] [32, 36, 29, 36, 10, 10, 26] ['5', '9', '2', '9', 'J', 'J', 'Z']
[np.int64(28), np.int64(27), np.int64(36), np.int64(32), np.int64(13), np.int64(22), np.int64(16)] [28, 27, 36, 32, 13, 22, 16] ['1', '0', '9', '5', 'M', 'V', 'P']
[np.int64(34), np.int64(28), np.int64(31), np.int64(34), np.int64(6), np.int64(4), np.int64(18)] [34, 28, 31, 34, 6, 4, 18] ['7', '1', '4', '7', 'F', 'D', 'R']
[np.int64(33), np.int64(29), np.int64(33), np.int64(36), np.int64(4), np.int64(20), np.int64(7)] [33, 29, 33, 36, 4, 20, 7] ['6', '2', '6', '9', 'D', 'T', 'G']
[np.int64(33), np.int64(36), n

In [64]:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, 'OCR_model.pth')