<a href="https://colab.research.google.com/github/Mgalvaz/license_plate-recognizer/blob/main/notebooks/train_LPD_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import json
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.functional import F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.ops import box_iou, nms
from torchvision.transforms import v2
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection._utils import Matcher, BoxCoder
from torch.utils.data import Dataset
from PIL import Image

In [4]:
def transform_img(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into the 384x384 float tensor the detector consumes.

    The three torchvision v2 transforms are applied one after another:
    PIL -> tensor, resize to (384, 384), then cast to ``torch.float`` with
    the ``scale`` flag set.
    """
    as_tensor = v2.PILToTensor()(image)
    resized = v2.Resize((384, 384))(as_tensor)
    return v2.ToDtype(torch.float, True)(resized)

class CarPlateTrainDataset(Dataset):
    """Training split of the licence-plate detection dataset.

    ``path`` is the dataset root (with trailing separator). It must contain
    ``train.txt`` with one sample name per line and a ``train/`` directory
    holding ``<name>.jpg`` and ``<name>.json`` for every listed sample.
    """

    def __init__(self, path: str) -> None:
        super().__init__()
        self.path = path + 'train/'
        with open(path + 'train.txt', 'r') as f:
            # Strip line endings (also '\r') and drop blank lines so that a
            # trailing newline does not turn into an empty sample name.
            self.train = [name for name in (line.strip() for line in f) if name]

    def __len__(self) -> int:
        return len(self.train)

    def __getitem__(self, item: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, boxes)`` for sample ``item``.

        ``image`` is the output of ``transform_img`` and ``boxes`` is an
        ``(N, 4)`` float32 tensor of ``[x_min, y_min, x_max, y_max]`` rows
        rescaled to the 384x384 input. Both are placed on the notebook-level
        ``device`` global. An image without annotated plates yields a
        ``(0, 4)`` tensor instead of failing.
        """
        image_path = self.path + self.train[item] + '.jpg'
        label_path = self.path + self.train[item] + '.json'
        # The context manager closes the file handle once the pixels are read.
        with Image.open(image_path) as raw:
            w, h = raw.size
            # The network's first convolution takes 3 channels, so grayscale
            # or RGBA files are normalised to RGB before the transform.
            image = transform_img(raw.convert('RGB'))
        scale_x = 384 / w
        scale_y = 384 / h
        with open(label_path) as f:
            full_label = json.load(f)
        boxes = []
        for lbl in full_label['lps']:
            # Column 0 is treated as x and column 1 as y of the plate polygon.
            poly = torch.tensor(lbl['poly_coord'], dtype=torch.float32)
            xs = poly[:, 0] * scale_x
            ys = poly[:, 1] * scale_y
            # Axis-aligned bounding box of the polygon.
            boxes.append(torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]))
        if boxes:
            labels = torch.stack(boxes).to(device=device, dtype=torch.float32)
        else:
            # torch.stack rejects an empty list; keep the (N, 4) contract.
            labels = torch.zeros((0, 4), dtype=torch.float32, device=device)
        return image.to(device), labels

class CarPlateTestDataset(Dataset):
    """Test split of the licence-plate detection dataset.

    ``path`` is the dataset root (with trailing separator). It must contain
    ``test.txt`` with one sample name per line and a ``test/`` directory
    holding ``<name>.jpg`` and ``<name>.json`` for every listed sample.
    """

    def __init__(self, path: str) -> None:
        super().__init__()
        self.path = path + 'test/'
        with open(path + 'test.txt', 'r') as f:
            # Strip line endings (also '\r') and drop blank lines so that a
            # trailing newline does not turn into an empty sample name.
            self.test = [name for name in (line.strip() for line in f) if name]

    def __len__(self) -> int:
        return len(self.test)

    def __getitem__(self, item: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, boxes)`` for sample ``item``.

        ``image`` is the output of ``transform_img`` and ``boxes`` is an
        ``(N, 4)`` float32 tensor of ``[x_min, y_min, x_max, y_max]`` rows
        rescaled to the 384x384 input. Both are placed on the notebook-level
        ``device`` global. An image without annotated plates yields a
        ``(0, 4)`` tensor instead of failing.
        """
        image_path = self.path + self.test[item] + '.jpg'
        label_path = self.path + self.test[item] + '.json'
        # The context manager closes the file handle once the pixels are read.
        with Image.open(image_path) as raw:
            w, h = raw.size
            # The network's first convolution takes 3 channels, so grayscale
            # or RGBA files are normalised to RGB before the transform.
            image = transform_img(raw.convert('RGB'))
        scale_x = 384 / w
        scale_y = 384 / h
        with open(label_path) as f:
            full_label = json.load(f)
        boxes = []
        for lbl in full_label['lps']:
            # Column 0 is treated as x and column 1 as y of the plate polygon.
            poly = torch.tensor(lbl['poly_coord'], dtype=torch.float32)
            xs = poly[:, 0] * scale_x
            ys = poly[:, 1] * scale_y
            # Axis-aligned bounding box of the polygon.
            boxes.append(torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]))
        if boxes:
            labels = torch.stack(boxes).to(device=device, dtype=torch.float32)
        else:
            # torch.stack rejects an empty list; keep the (N, 4) contract.
            labels = torch.zeros((0, 4), dtype=torch.float32, device=device)
        return image.to(device), labels

In [5]:
def detection_loss(cls_preds: torch.Tensor, reg_preds: torch.Tensor, anchors: list[torch.Tensor], gt_boxes: list[torch.Tensor], matcher: Matcher, box_coder: BoxCoder):
    """Batch-averaged classification + box-regression loss.

    Args:
        cls_preds: per-anchor class logits, indexed as ``cls_preds[i]`` per
            image with one row per anchor (class 1 = plate).
        reg_preds: per-anchor box deltas, indexed the same way.
        anchors: one ``(A, 4)`` anchor tensor per image.
        gt_boxes: one ``(N, 4)`` ground-truth box tensor per image; ``N`` may
            be 0 for an image without plates.
        matcher: assigns a ground-truth index (or a negative flag) to each
            anchor from the IoU matrix.
        box_coder: encodes matched ground-truth boxes relative to anchors.

    Returns:
        ``mean(cross_entropy) + mean(smooth_l1)`` averaged over the batch.
    """
    batch_size = cls_preds.size(0)
    cls_loss = 0.0
    reg_loss = 0.0

    for i in range(batch_size):
        # Every anchor starts as background (label 0).
        labels = torch.zeros_like(cls_preds[i][:, 0], dtype=torch.long)

        # Matcher raises on an empty IoU matrix, so an image without plates
        # skips matching and keeps all anchors as background.
        if gt_boxes[i].numel() > 0:
            # Match anchors with GT
            iou_matrix = box_iou(gt_boxes[i], anchors[i])
            matched_idxs = matcher(iou_matrix)
            matched_mask = matched_idxs >= 0
            labels[matched_mask] = 1
            # Anchors whose best IoU falls between the two thresholds are
            # ambiguous: exclude them from the classification loss instead
            # of training them as hard negatives.
            labels[matched_idxs == matcher.BETWEEN_THRESHOLDS] = -100

            # Regression loss for matched anchors
            if matched_mask.any():
                matched_gt_boxes = gt_boxes[i][matched_idxs[matched_mask]]
                matched_anchors = anchors[i][matched_mask]
                encoded_targets = box_coder.encode([matched_anchors], [matched_gt_boxes])[0]
                reg_loss = reg_loss + F.smooth_l1_loss(reg_preds[i][matched_mask], encoded_targets)

        # Classification loss (labels equal to -100 are ignored)
        cls_loss = cls_loss + F.cross_entropy(cls_preds[i], labels, ignore_index=-100)

    cls_loss = cls_loss / batch_size
    reg_loss = reg_loss / batch_size

    return cls_loss + reg_loss

def collate_fn(batch: list) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Collate ``(image, boxes)`` samples into a batch.

    Images are stacked along a new leading dimension; the per-image box
    tensors are returned as a plain list because their row counts differ.
    """
    image_column, box_column = zip(*batch)
    return torch.stack(image_column), list(box_column)

In [6]:
class InceptionResidual(nn.Module):

    def __init__(self, inc: int, outc: int, *args: tuple[int, int] | int, include_max_pool: bool =True) -> None:
        super().__init__()

        # Inception module
        num_branches = len(args) + (1 if include_max_pool else 0)
        outc_hidden = outc//num_branches
        self.inception = nn.ModuleList()
        for kernel in args:
            if isinstance(kernel, int):
                pad = kernel // 2
            else:
                pad = (kernel[0] // 2, kernel[1] // 2)
            self.inception.append(
                nn.Sequential(
                    nn.Conv2d(inc, outc_hidden, kernel, padding=pad),
                    nn.BatchNorm2d(outc_hidden),
                    nn.ReLU()
                )
            )
        if include_max_pool:
            self.inception.append(nn.Sequential(
                nn.MaxPool2d(3, 1, padding=1),
                nn.Conv2d(inc, outc_hidden, 1),
                nn.BatchNorm2d(outc_hidden),
                nn.ReLU()
            ))
        # Conv2D for the inception module
        self.final_conv =  nn.Conv2d(outc, outc, 1, bias=False)
        # Decoder
        self.decoder = nn.ReLU()

        # Residual module
        if inc == outc:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv2d(inc, outc, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = [module(x) for module in self.inception]
        out = torch.cat(out, dim=1)

        out = self.final_conv(out) + self.residual(x)

        out = self.decoder(out)
        return out

class Prediction(nn.Module):
    """Detection head producing per-location class logits and box deltas.

    Two parallel 3x3 convolutions map an ``(inc, H, W)`` feature map to
    2 class channels (plate / no plate) and 4 regression channels. Both
    outputs are flattened to one row per spatial location: ``(B, H*W, 2)``
    and ``(B, H*W, 4)``.
    """

    def __init__(self, inc: int) -> None:
        super().__init__()
        # e.g. (32, 16, 16) -> (2, 16, 16) or (48, 8, 8) -> (2, 8, 8)
        self.cls = nn.Conv2d(inc, 2, kernel_size=3, padding=1)
        # e.g. (32, 16, 16) -> (4, 16, 16) or (48, 8, 8) -> (4, 8, 8)
        self.reg = nn.Conv2d(inc, 4, kernel_size=3, padding=1)

    @staticmethod
    def _rows_per_location(fmap: torch.Tensor, width: int) -> torch.Tensor:
        """Reorder ``(B, C, H, W)`` to channels-last and flatten to ``(B, H*W, width)``."""
        channels_last = fmap.permute(0, 2, 3, 1)
        return channels_last.reshape(fmap.size(0), -1, width)

    def forward(self, x:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        class_logits = self._rows_per_location(self.cls(x), 2)  # (2, 16, 16) -> (256, 2)
        box_deltas = self._rows_per_location(self.reg(x), 4)    # (4, 16, 16) -> (256, 4)
        return class_logits, box_deltas

class LPD(nn.Module):
    """Licence-plate detector: five strided conv stages and two heads.

    Every stage is a strided convolution followed by an
    ``InceptionResidual`` block. The outputs of stages 4 and 5 feed one
    ``Prediction`` head each, and an ``AnchorGenerator`` provides one anchor
    per location on both feature maps. Shape comments assume a
    (3, 384, 384) input.
    """

    def __init__(self) -> None:
        super().__init__()

        self.block1 = self._stage(3, 16, kernel=5, stride=3, padding=2)  # (3, 384, 384) -> (16, 128, 128)
        self.block2 = self._stage(16, 24)  # (16, 128, 128) -> (24, 64, 64)
        self.block3 = self._stage(24, 32)  # (24, 64, 64) -> (32, 32, 32)
        self.block4 = self._stage(32, 32)  # (32, 32, 32) -> (32, 16, 16)
        self.block5 = self._stage(32, 48)  # (32, 16, 16) -> (48, 8, 8)

        # (32, 16, 16) -> (2, 16, 16) plate / no plate, (4, 16, 16) (dx, dy, dw, dh)
        self.FM0 = Prediction(32)
        # (48, 8, 8) -> (2, 8, 8) plate / no plate, (4, 8, 8) (dx, dy, dw, dh)
        self.FM1 = Prediction(48)

        # One anchor size per feature map, a single aspect ratio for both.
        self.anchors = AnchorGenerator(
            sizes=((39.6,), (79.2,)),
            aspect_ratios=((2.0,), (2.0,)),
        )

    @staticmethod
    def _stage(inc: int, outc: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> nn.Sequential:
        """Build one stage: strided convolution, then an InceptionResidual block."""
        return nn.Sequential(
            nn.Conv2d(inc, outc, kernel, stride=stride, padding=padding),
            InceptionResidual(outc, outc, 5, 3, 3, include_max_pool=True),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]:
        """Return ``(cls, reg, anchors)`` for a batch of images.

        ``cls`` and ``reg`` concatenate both heads along the anchor
        dimension (FM0 rows first, then FM1); ``anchors`` is the list
        returned by the anchor generator for the same two feature maps.
        """
        images = x
        batch, _, height, width = images.shape

        features = self.block3(self.block2(self.block1(images)))
        fm0 = self.block4(features)
        fm1 = self.block5(fm0)

        # Anchors are generated against the input resolution of each image.
        sizes = [(height, width)] * batch
        anchors = self.anchors(ImageList(images, sizes), [fm0, fm1])

        cls0, reg0 = self.FM0(fm0)
        cls1, reg1 = self.FM1(fm1)

        return torch.cat([cls0, cls1], dim=1), torch.cat([reg0, reg1], dim=1), anchors

In [8]:
# Use the GPU when one is available. The previous unconditional override to
# CPU made the message below wrong and left the GPU unused.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Model loading
model = LPD().to(device)
# Anchor/GT matching thresholds: positive at IoU >= 0.5, background below 0.3.
matcher = Matcher(0.5, 0.3)
# Unit weights for the (dx, dy, dw, dh) box encoding.
box_coder = BoxCoder((1., 1., 1., 1.))
optimizer = AdamW(model.parameters(), lr=0.0002, weight_decay=1e-5)
# Multiply the learning rate by 0.1 every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Overwritten by the checkpoint cell when resuming training.
last_epoch = 0

Using device: cuda


In [None]:
# Optional cell: resume from a previously saved checkpoint file 'model.pth'.
# Tensors are remapped onto the currently selected device while loading.
checkpoint = torch.load('model.pth', weights_only=True, map_location=device)
# Restore model weights together with optimizer and LR-scheduler state.
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# The training cell continues from the epoch after this one.
last_epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f'Checkpoint loaded. Last epoch: {last_epoch}, with loss: {loss}')

In [12]:
train_dataset = CarPlateTrainDataset('drive/MyDrive/dataset/')
test_dataset = CarPlateTestDataset('drive/MyDrive/dataset/')
# Shuffle the training samples every epoch; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=collate_fn)

model.train()
# Train one more epoch on top of whatever was already done.
num_epochs = 1 + last_epoch
for epoch in range(last_epoch + 1, num_epochs + 1):
    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
    for images, targets in loop:
        cls_preds, reg_preds, anchors = model(images)
        loss = detection_loss(cls_preds, reg_preds, anchors, targets, matcher, box_coder)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())
    # Step the LR schedule once per epoch. Stepping per batch made
    # StepLR(step_size=10) decay the learning rate every 10 batches.
    scheduler.step()

Epoch 0/1:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1/1:   0%|          | 0/25 [00:00<?, ?it/s]

In [None]:
# Recall of the detector on the test set: a ground-truth plate counts as
# detected when some prediction kept after NMS overlaps it with IoU >= 0.5.

# Set to True to print the intermediate tensors and stop after the first
# batch (debugging aid); False evaluates the whole test set silently.
debug_first_batch = False

model.eval()
total_gt = 0
correct_detections = 0
with torch.no_grad():
    for images, targets in test_loader:

        cls_preds, reg_preds, anchors = model(images)
        batch_size = images.size(0)

        for i in range(batch_size):

            gt_boxes = targets[i]
            total_gt += len(gt_boxes)

            # Probability of the "plate" class for every anchor.
            scores = cls_preds[i].softmax(dim=1)[:, 1]
            mask = scores > 0.6

            if mask.sum() == 0:
                continue

            # Decode boxes
            if debug_first_batch:
                print('anchors y cajas predichas antes de codificar')
                print(anchors[i][mask])
                print(reg_preds[i][mask])
            # BoxCoder.decode takes the regression deltas first and the list
            # of reference boxes (anchors) second.
            pred_boxes = box_coder.decode(reg_preds[i][mask], [anchors[i][mask]]).reshape(-1, 4)
            pred_scores = scores[mask]
            if debug_first_batch:
                print('cajas predichas y probabilidades despues de codificar')
                print(pred_boxes)
                print(pred_scores)

            # NMS
            keep = nms(pred_boxes, pred_scores, iou_threshold=0.5)
            pred_boxes = pred_boxes[keep]
            if debug_first_batch:
                print('cajas predichas tras nms')
                print(pred_boxes)

            # Compare every GT box with the predictions
            if len(gt_boxes) > 0 and len(pred_boxes) > 0:
                iou = box_iou(gt_boxes, pred_boxes)
                max_iou_per_gt, _ = iou.max(dim=1)

                # Count GT boxes detected with sufficient IoU
                correct_detections += (max_iou_per_gt >= 0.5).sum().item()

        if debug_first_batch:
            break
recall = correct_detections / total_gt if total_gt > 0 else 0
print("Recall:", recall)
print(f"Accuracy of detection = {recall * 100:.2f}%")