<a href="https://colab.research.google.com/github/Mgalvaz/license_plate-recognizer/blob/main/notebooks/train_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [30]:
import random
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import v2
from PIL import Image, ImageDraw, ImageFont

In [31]:
def generate_plate_text() -> str:
  """
  Build a random plate string laid out as four digits, two spaces, then three letters.

  Letters are sampled with replacement; every consonant is ten times as likely as every vowel.
  NOTE(review): current Spanish plates use consonants only and exclude Q. Keeping Q and
  low-weight vowels may be intentional noise -- confirm.
  """
  digits = ''.join(random.choices('0123456789', k=4))

  consonant_pool = "BCDFGHJKLMNPQRSTVWXYZ"
  vowel_pool = "AEIOU"
  letter_weights = [10 for _ in consonant_pool] + [1 for _ in vowel_pool]
  letter_part = ''.join(random.choices(consonant_pool + vowel_pool, weights=letter_weights, k=3))

  # Two spaces separate the digit group from the letter group when the plate is drawn.
  return digits + '  ' + letter_part

def augment_image_v2(img: Image.Image) -> torch.Tensor:
  """
  Turn a PIL plate image into a float tensor and apply random augmentation.

  Pipeline:
  1. PIL -> uint8 tensor, then float scaled to [0, 1].
  2. Random rotation (up to +/-8 degrees) and a random perspective warp (always applied).
  3. Additive Gaussian noise (sigma 0.05).
  4. With probability 0.7, brightness/contrast jitter in [0.8, 1.2].

  Pixels uncovered by the geometric transforms are filled with 0.392 (about 100/255).
  """
  border_fill = 0.392
  steps = [
    v2.PILToTensor(),
    v2.ToDtype(torch.float, scale=True),
    v2.RandomRotation(degrees=8, fill=border_fill),
    v2.RandomPerspective(distortion_scale=0.25, p=1.0, fill=border_fill),
    v2.GaussianNoise(sigma=0.05),
    v2.RandomApply([v2.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2))], p=0.7),
  ]
  pipeline = v2.Compose(steps)
  return pipeline(img)

class SyntheticPlateDataset(Dataset):
  """
  Dataset that generates fake synthetic Spanish license plates and augments them to simulate perspective.

  Plates are rendered on the fly: ``__getitem__`` ignores ``idx`` and draws a new random plate on
  every access. ``num_samples`` therefore only sets the dataset length, and repeated accesses to
  the same index return different samples.

  Each item is ``(image, label)``; both are moved to the module-level ``device``:
    * image -- the output of ``augment_image_v2`` for a 150x32 grayscale ("L") plate.
    * label -- long tensor with one class index per plate character (spaces dropped).
      Indices start at 1, leaving 0 free for the CTC blank.
  """

  # Linux-specific font location (the notebook targets Colab); override via ``font_path``.
  DEFAULT_FONT_PATH = '/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf'
  # Symbols in class-index order: the class index of a symbol is its position + 1.
  ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'

  def __init__(self, num_samples: int = 10000, font_path: str = DEFAULT_FONT_PATH):
    """
    Args:
      num_samples: value reported by ``__len__``.
      font_path: TrueType font used to draw the plate (defaults to Liberation Sans Regular).
    """
    super().__init__()
    self.num_samples = num_samples
    self.font_main = ImageFont.truetype(font_path, 25)  # plate characters
    self.font_small = ImageFont.truetype(font_path, 8)  # 'E' marker in the side band
    self.transform = augment_image_v2
    self.translator = {char: index for index, char in enumerate(self.ALPHABET, start=1)}

  def __len__(self) -> int:
    return self.num_samples

  def __getitem__(self, idx: int) -> tuple:
    plate_text = generate_plate_text()
    plate = Image.new("L", (150, 32), color=230)  # light-gray background
    draw = ImageDraw.Draw(plate)
    draw.rectangle([0, 0, 20, 32], fill=55)  # dark band along the left edge
    draw.text((8, 18), 'E', font=self.font_small, fill=230)
    draw.text((22, 2), plate_text, font=self.font_main, fill=50)
    plate = self.transform(plate).to(device)
    # Spaces are layout only; every other character becomes a class index.
    label = torch.tensor([self.translator[ch] for ch in plate_text if ch != ' '], dtype=torch.long).to(device)
    return plate, label

In [32]:
def collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor]:
  """
  Merge a list of ``(image, label)`` samples into batched tensors.

  Images are stacked along a new leading batch dimension, so they must all share one shape.
  The variable-length label tensors are right-padded with -1 into a single
  ``(batch, max_len)`` tensor. -1 is never a valid class index, so the padding can be filtered
  out (``lbl != -1``) before computing the CTC loss.
  """
  images = torch.stack([image for image, _ in batch])
  labels = pad_sequence([label for _, label in batch], batch_first=True, padding_value=-1)
  return images, labels

def ctc_decode(pred_seq: torch.Tensor, blank: int = 0) -> list[int]:
  """
  Greedy CTC collapse of a sequence of per-time-step class indices.

  Consecutive repeats are merged and blanks are removed. A blank between two equal symbols keeps
  both of them, e.g. [1, 1, 0, 1] -> [1, 1].

  Args:
    pred_seq: 1-D sequence of class indices -- a tensor, a NumPy array, or a plain list.
    blank: index of the CTC blank symbol.

  Returns:
    The decoded class indices as plain Python ints.
  """
  # Normalise tensors/arrays to Python ints so the result really is list[int]. Otherwise
  # NumPy scalars or 0-d tensors leak into the output, and comparisons against label lists
  # depend on their element-wise __eq__.
  values = pred_seq.tolist() if hasattr(pred_seq, 'tolist') else pred_seq
  decoded = []
  prev = None
  for value in values:
    token = int(value)
    if token != blank and token != prev:
      decoded.append(token)
    prev = token  # blanks also update prev, so they separate genuine repeats
  return decoded

class CRNN(nn.Module):
  """
  Convolutional-recurrent network (CNN feature extractor + bidirectional GRU) for CTC text recognition.

  Expects inputs of shape (batch, 1, 32, 150) and returns per-time-step log-probabilities of
  shape (batch, 34, num_classes). Permute to (34, batch, num_classes) before ``nn.CTCLoss``.
  """

  def __init__(self, num_classes: int = 37):
    """
    Args:
      num_classes: number of output classes including the CTC blank. The default of 37 covers
        the 36 plate symbols (A-Z, 0-9; indices 1..36 in the dataset) plus the blank at index 0.
    """
    super().__init__()

    # Shapes below are (channels, height, width), per sample.
    self.cnn = nn.Sequential(
      nn.Conv2d(1, 32, (3, 3), padding=1),  # (1, 32, 150) -> (32, 32, 150)
      nn.ReLU(),
      nn.MaxPool2d(2, 2),  # (32, 32, 150) -> (32, 16, 75)
      nn.Conv2d(32, 64, (3, 3), padding=1),  # (32, 16, 75) -> (64, 16, 75)
      nn.ReLU(),
      nn.MaxPool2d(2, 2),  # (64, 16, 75) -> (64, 8, 37)
      nn.Conv2d(64, 128, (3, 3), padding=1),  # (64, 8, 37) -> (128, 8, 37)
      nn.ReLU(),
      # kernel (1, 2) with stride (2, 1): halves the height and trims the width by one
      nn.MaxPool2d((1, 2), (2, 1)),  # (128, 8, 37) -> (128, 4, 36)
      nn.Conv2d(128, 256, (3, 3), padding=1),  # (128, 4, 36) -> (256, 4, 36)
      nn.BatchNorm2d(256),
      nn.ReLU(),
      nn.Conv2d(256, 256, (3, 3), padding=1),  # (256, 4, 36) -> (256, 4, 36)
      nn.BatchNorm2d(256),
      nn.ReLU(),
      nn.MaxPool2d((2, 1), 1),  # (256, 4, 36) -> (256, 3, 36)
      nn.Conv2d(256, 512, (3, 3), padding=0),  # (256, 3, 36) -> (512, 1, 34)
      nn.BatchNorm2d(512),
      nn.ReLU(),
    )

    # 2-layer bidirectional GRU, 256 hidden units per direction -> 512 features per time step.
    self.rnn = nn.GRU(512, 256, num_layers=2, batch_first=True, bidirectional=True)  # (batch, 34, 512) -> (batch, 34, 512)

    self.fc = nn.Linear(512, num_classes)  # (batch, 34, 512) -> (batch, 34, num_classes)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
      x: image batch of shape (batch, 1, 32, 150).

    Returns:
      Log-probabilities over classes, shape (batch, 34, num_classes).
    """
    _, ic, ih, iw = x.size()  # (batch, channels=1, height=32, width=150)
    # The CNN collapses the height to exactly 1 only for 32-pixel-high inputs.
    assert (ic, ih, iw) == (1, 32, 150), f'Input size ({ic}, {ih}, {iw}) does not correspond to expected size (1, 32, 150)'
    x = self.cnn(x)  # (batch, 512, 1, 34)

    x = x.squeeze(2).permute(0, 2, 1)  # (batch, 34, 512): one feature vector per horizontal position

    x, _ = self.rnn(x)  # (batch, 34, 512)
    x = self.fc(x)  # (batch, 34, num_classes)
    return x.log_softmax(2)

In [33]:
# Run on the default CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Usando dispositivo:", device)


Usando dispositivo: cuda


In [34]:
# Character -> class index (0 is left free for the CTC blank). It mirrors
# SyntheticPlateDataset.translator and is not used in this cell.
TRANSLATOR = {char: index for index, char in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', start=1)}
total = 700
num_test = total // 7  # hold out 1/7 of the samples for evaluation
num_train = total - num_test
num_epochs = 1
dataset = SyntheticPlateDataset(num_samples=total)
# NOTE: SyntheticPlateDataset.__getitem__ ignores idx and renders a fresh random plate each time,
# so this split only fixes how many samples each loader yields per pass.
train_dataset, test_dataset = random_split(dataset, [num_train, num_test])

train_loader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=collate_fn)

# Model loading
model = CRNN().to(device)
# Default blank index 0; zero_infinity zeroes losses from alignments CTC cannot produce.
criterion = nn.CTCLoss(zero_infinity=True)
# NOTE(review): lr=1e-2 is high for Adam on a GRU + CTC model. Together with a single epoch over
# 600 samples, this likely explains the 0% accuracy. Consider lr=1e-3 and more data/epochs.
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training
num_iterations = len(train_loader)
for epoch in range(num_epochs):
  print(f'Epoch {epoch+1} out of {num_epochs}')
  model.train()  # BatchNorm layers must use batch statistics while optimizing
  for iteration, (batch_images, batch_labels) in enumerate(train_loader, start=1):
    optimizer.zero_grad()

    batch_output = model(batch_images)  # (batch, seq_len, classes) log-probabilities
    log_probs = batch_output.permute(1, 0, 2)  # CTCLoss expects (seq_len, batch, classes)
    batch_size, seq_len = batch_output.shape[:2]
    input_lengths = torch.full(size=(batch_size,), fill_value=seq_len, dtype=torch.long)
    batch_labels = [lbl[lbl != -1] for lbl in batch_labels]  # remove the -1 padding from collate_fn
    targets = torch.cat(batch_labels)
    target_lengths = torch.tensor([len(lbl) for lbl in batch_labels], dtype=torch.long)

    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    loss.backward()
    optimizer.step()
    print(f'iteration: {iteration} out of {num_iterations}, loss: {loss.item():.4f}')
  print()

# Testing
model.eval()
correct = 0
evaluated = 0
with torch.no_grad():
  for batch_images, batch_labels in test_loader:
    # Greedy decoding: most likely class per time step, (batch, seq_len) as Python ints.
    preds = model(batch_images).argmax(dim=2).cpu().tolist()
    decoded_preds = [ctc_decode(seq) for seq in preds]
    targets = [lbl[lbl != -1].cpu().tolist() for lbl in batch_labels]

    correct += sum(pred == target for pred, target in zip(decoded_preds, targets))
    evaluated += len(targets)
  print(f"Exact match accuracy: {correct / max(evaluated, 1):.2%}")

Epoch 1 out of 1
iteration: 1 out of 10
iteration: 2 out of 10
iteration: 3 out of 10
iteration: 4 out of 10
iteration: 5 out of 10
iteration: 6 out of 10
iteration: 7 out of 10
iteration: 8 out of 10
iteration: 9 out of 10
iteration: 10 out of 10

Exact match accuracy: 0.00%
