<a href="https://colab.research.google.com/github/Mgalvaz/license_plate-recognizer/blob/main/notebooks/train_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
import random
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
from torch.functional import F
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import v2
from PIL import Image, ImageDraw, ImageFont

In [33]:
def generate_plate_text() -> str:
  nums = ''.join(random.choices('0123456789', k=4))

  consonants = "BCDFGHJKLMNPQRSTVWXYZ"
  vowels = "AEIOU"
  alphabet = consonants + vowels
  weights = [10] * len(consonants) + [1] * len(vowels)
  letters = ''.join(random.choices(alphabet, weights=weights, k=3))

  return f"{nums}  {letters}"

def augment_image_v2(img: Image.Image) -> torch.Tensor:
  tr = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float, True),
    #v2.RandomRotation(degrees=8, fill=0.392),
    #v2.RandomPerspective(distortion_scale=0.25, p=1.0, fill=0.392),
    #v2.GaussianNoise(sigma=0.05),
    #v2.RandomApply([v2.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2))], p=0.7),
  ])
  return tr(img)

class SyntheticPlateDataset(Dataset):
  """
  Dataset that generates fake synthetic spanish license plates and augments them to simulate perspective.
  """

  def __init__(self, num_samples: int = 10000):
    super().__init__()
    self.num_samples = num_samples
    self.font_main = ImageFont.truetype('/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', 25)
    self.font_small = ImageFont.truetype('/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf', 8)
    self.transform = augment_image_v2
    self.translator = dict((l, n) for n, l in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', start=1))

  def __len__(self) -> int:
    return self.num_samples

  def __getitem__(self, idx: int) -> tuple:
    plate_text = generate_plate_text()
    plate = Image.new("L", (150, 32), color=230)
    draw = ImageDraw.Draw(plate)
    draw.rectangle([0, 0, 20, 32], fill=55)
    draw.text((8, 18), 'E', font=self.font_small, fill=230)
    draw.text((22, 2), plate_text, font=self.font_main, fill=50)
    plate = self.transform(plate).to(device)
    label = torch.tensor([self.translator[l] for l in plate_text if l != ' '], dtype=torch.long).to(device)
    return plate, label

In [40]:
print(dict((l, n) for n, l in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', start=1)))

{'A': 1, 'B': 2, 'C': 3, 'D': 4, 'E': 5, 'F': 6, 'G': 7, 'H': 8, 'I': 9, 'J': 10, 'K': 11, 'L': 12, 'M': 13, 'N': 14, 'O': 15, 'P': 16, 'Q': 17, 'R': 18, 'S': 19, 'T': 20, 'U': 21, 'V': 22, 'W': 23, 'X': 24, 'Y': 25, 'Z': 26, '0': 27, '1': 28, '2': 29, '3': 30, '4': 31, '5': 32, '6': 33, '7': 34, '8': 35, '9': 36}


In [43]:
def collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor]:
  imgs, labels = zip(*batch)
  imgs = torch.stack(imgs)
  labels = pad_sequence(labels, batch_first=True, padding_value=-1)  # rellena con -1
  return imgs, labels

def ctc_decode(pred_seq: torch.Tensor, blank: int=0) -> list[int]:
  decoded = []
  prev = None
  for p in pred_seq:
    if p != blank and p != prev:
      decoded.append(p)
    prev = p
  return decoded

class CRNN(nn.Module):

  def __init__(self):
    super(CRNN, self).__init__()

    self.cnn = nn.Sequential(
      nn.Conv2d(1, 64, (3, 3), padding=1),  # (1, 32, 150) -> (32, 32, 150)
      nn.ReLU(),
      nn.MaxPool2d(2, 2), # (32, 32, 150) -> (32, 16, 75)
      nn.Conv2d(64, 128, (3, 3), padding=1),  # (32, 16, 75) -> (64, 16, 75)
      nn.ReLU(),
      nn.MaxPool2d(2, 2),  # (64, 16, 75) -> (64, 8, 37)
      nn.Conv2d(128, 256, (3, 3), padding=1),  # (64, 8, 37) -> (128, 8, 37)
      nn.ReLU(),
      nn.MaxPool2d((1, 2), (2,1)),  # (128, 8, 37) -> (128, 4, 36)
      nn.Conv2d(256, 512, (3, 3), padding=1),  # (128, 4, 36) -> (256, 4, 36)
      nn.BatchNorm2d(512),
      nn.ReLU(),
      nn.Conv2d(512, 512, (3, 3), padding=1),  # (256, 4, 37) -> (256, 4, 36)
      nn.BatchNorm2d(512),
      nn.ReLU(),
      nn.Dropout2d(0.5),
      nn.MaxPool2d((2, 1), 1),  # (256, 4, 36) -> (256, 3, 36)
      nn.Conv2d(512, 512, (3, 3), padding=0),  # (256, 3, 36) -> (512, 1, 34)
      nn.BatchNorm2d(512),
      nn.ReLU(),
    )

    self.rnn = nn.GRU(512, 256, num_layers=2, batch_first=True, bidirectional=True) # (36, 512) -> (512, 36)

    self.fc = nn.Linear(512, 37) # (512, 36) -> (37, 36)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    _, ic, ih, iw = x.size() # (batch, channels=1, height=32, width=150)
    assert (ic, ih, iw) == (1, 32, 150), f'Input size ({ic}, {ih}, {iw}) does not correspond to expected size (1, 32, 150)'
    x = self.cnn(x) # (batch, channels=512, height=1, width=34)

    x = x.squeeze(2).permute(0, 2, 1)  # (batch, width=34, channels=512)

    x, _ = self.rnn(x) # (batch, seq_len=34, channels=512)
    x = self.fc(x) # (batch, seq_len=34, label=37)
    return x

In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [44]:
TRANSLATOR = dict((l, n) for n, l in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', start=1))
total = 10000
num_test = total//7
num_train = total - num_test
num_epochs = 3
dataset = SyntheticPlateDataset(num_samples=total)
train_dataset, test_dataset = random_split(dataset, [num_train, num_test])

train_loader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64,  collate_fn=collate_fn)

#Model loading
model = CRNN().to(device)
criterion = nn.CTCLoss(zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.005)

#Training
for epoch in range(1, num_epochs+1):
  loop = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
  for batch_images, batch_labels in loop:
    # zero the parameter gradients
    optimizer.zero_grad()
    # forward
    batch_output = model(batch_images)
    log_probs = F.log_softmax(batch_output, dim=2).permute(1, 0 ,2)
    input_lengths = torch.full(size=(batch_output.size(0),), fill_value=batch_output.size(1), dtype=torch.long)
    targets = batch_labels
    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]
    target_lengths = torch.tensor([len(l) for l in batch_labels], dtype=torch.long)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    # backward
    loss.backward()
    # optimize
    optimizer.step()
    loop.set_postfix(loss=loss.item())


#Testing
model.eval()
correct = 0
with torch.no_grad():
  for batch_images, batch_labels in test_loader:
    output = model(batch_images).permute(1, 0 ,2)
    preds = output.argmax(dim=2).cpu().numpy().T

    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]
    decoded_preds = [ctc_decode(seq) for seq in preds]

    for pred, target in zip(decoded_preds, batch_labels):
      target = target.cpu().numpy().tolist()
      if pred == target:
        correct += 1
  print(f"Exact match accuracy: {correct / num_test:.2%}")

Epoch 1/3:   0%|          | 0/134 [00:00<?, ?it/s]

Epoch 2/3:   0%|          | 0/134 [00:00<?, ?it/s]

Epoch 3/3:   0%|          | 0/134 [00:00<?, ?it/s]

Exact match accuracy: 0.00%


In [45]:
REVERSE_DICT = {v: k for k, v in TRANSLATOR.items()}
final_dataset = SyntheticPlateDataset(num_samples=2)
final_loader = DataLoader(final_dataset, batch_size=64,  collate_fn=collate_fn)
model.eval()
with torch.no_grad():
  for batch_images, batch_labels in final_loader:
    output = model(batch_images)
    preds = output.argmax(dim=2).cpu().numpy().T
    print(output)

    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]
    decoded_preds = [ctc_decode(seq) for seq in preds]

    for pred, target in zip(decoded_preds, batch_labels):
      target = target.cpu().numpy().tolist()
      print(pred, target, list(map(lambda x: REVERSE_DICT[x], target)))

tensor([[[ 6.2799, -3.9095, -3.3256,  ...,  0.9906,  1.7291,  1.3155],
         [ 6.1752, -3.8151, -3.2253,  ...,  0.8959,  1.5170,  1.2116],
         [ 6.1866, -3.7950, -3.2009,  ...,  0.8611,  1.4974,  1.1791],
         ...,
         [ 3.3325,  1.6309,  3.6533,  ..., -2.3319, -0.8925, -1.6921],
         [ 3.3319,  1.6384,  3.6663,  ..., -2.3395, -0.8902, -1.6960],
         [ 3.3194,  1.6501,  3.6797,  ..., -2.3490, -0.8964, -1.7029]],

        [[ 6.8638, -4.4178, -3.5406,  ...,  3.0256,  3.0736,  3.3268],
         [ 6.7867, -4.3937, -3.5195,  ...,  2.9951,  3.0113,  3.3548],
         [ 6.7900, -4.3820, -3.4978,  ...,  2.9829,  3.0075,  3.3574],
         ...,
         [ 3.3261,  1.6094,  3.6365,  ..., -2.3124, -0.8833, -1.6755],
         [ 3.3264,  1.6000,  3.6309,  ..., -2.3083, -0.8751, -1.6716],
         [ 3.3141,  1.6174,  3.6480,  ..., -2.3220, -0.8845, -1.6823]]],
       device='cuda:0')
[] [29, 33, 28, 29, 13, 20, 7] ['2', '6', '1', '2', 'M', 'T', 'G']
[] [34, 32, 27, 30, 19, 1