# MobileNetV3 YOLOv3 for Text Detection

A text detection model based on MobileNetV3 and YOLOv3.<br>
Pruned and quantized for deployment on edge devices.

- [x] Pretrained MobileNetV2 backbone
- [x] YOLOv3 top end
- [x] Basic pruning and quantization integration
- [x] Training pipeline (for ICDAR 2015)
- [x] Switch backbone to MobileNetV3
- [x] Mixed Precision Training
- [x] Advanced Pruning and quantization

- [ ] Basic Inference
- [ ] Performance Evaluation
- [ ] Refactor the Jupyter notebook into a proper file/module structure
- [ ] Advanced training pipeline (COCO-Text dataset, batch augmentation, etc.)
- [ ] Live Image-Feed Inference

In [14]:
import os
import csv
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.init as init
import torchvision.ops as ops
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torchvision.models as models
import torch.nn.utils.prune as prune
import torch.quantization.quantize_fx as quantize_fx

from PIL import Image, ImageDraw
from pathlib import Path
from torch.nn import functional as F
from torchvision import transforms
from torch.cuda.amp import autocast
from torch.quantization import quantize_dynamic
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

%matplotlib inline

In [15]:
# Seed every RNG this notebook draws from: Python's `random` (process_batch sampling),
# NumPy (ICDAR2015.next_batch) and torch on the CPU and on all CUDA devices.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42);

---

## Model Definition

In [19]:
class MobileYOLOv3(nn.Module):
    """
    YOLOv3-style single-scale detector on a MobileNetV3-Small backbone.

    In:  (batch_size, 3, 448, 224)
    Out: (batch_size, 7, 7, num_anchors * (5 + num_classes))   # grid_size=(7, 7) by default
          5 + num_classes = 4 (x,y,w,h)
                            + 1 (objectness)
                            + num_classes (class probabilities)
    Making for 7x7=49 grid tiles with num_anchors many (5 + num_classes) values each.
    YoloV3 originally starts 13x13, but I deviate for a smaller model.

    The backbone alone downsamples by 32, so a 448x224 input (224x112 after conv1)
    only yields a 7x4 feature map. A parameter-free adaptive average pool resamples
    it to ``grid_size`` so the output grid always matches the (rows, cols) grid the
    targets are encoded on.

    The head emits raw logits (no activation): apply a sigmoid to x, y, w, h,
    objectness and class scores when decoding or computing the loss.

    Args:
        num_classes: number of object classes (1 = text only).
        num_anchors: boxes predicted per grid cell (3, like the original YOLOv3).
        grid_size: (rows, cols) of the output grid.
    """

    def __init__(self, num_classes=1, num_anchors=3, grid_size=(7, 7)):
        super().__init__()
        self.num_classes = num_classes  # 1, but keep this flexible
        self.num_anchors = num_anchors  # 3, like the original YOLOv3
        self.grid_size = tuple(grid_size)
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1)
        self.mobilenet = models.mobilenet_v3_small(weights='IMAGENET1K_V1').features
        # No learnable state, so checkpoints saved before this layer existed still load.
        self.pool = nn.AdaptiveAvgPool2d(self.grid_size)
        self.conv2 = nn.Conv2d(576, num_anchors * (5 + num_classes), kernel_size=1, stride=1, padding=0)

    def forward(self, x):                       # (batch_size, 3, 448, 224)
        x = F.relu(self.conv1(x))               # (batch_size, 3, 224, 112)
        x = self.mobilenet(x)                   # (batch_size, 576, 7, 4)
        x = self.pool(x)                        # (batch_size, 576, *grid_size)
        # No ReLU on the head: clamping logits at 0 would pin sigmoid(objectness) >= 0.5.
        x = self.conv2(x)                       # (batch_size, num_anchors * (5 + num_classes), *grid_size)
        return x.permute(0, 2, 3, 1).contiguous()   # (batch_size, rows, cols, num_anchors * (5 + num_classes))

---

## Pruning & Quantization Definition

In [17]:
def prune_model(model, amount=0.3):
    """
    Prune ``model`` in place and return it with every pruning step made permanent.

    Every nn.Conv2d first gets L1 unstructured pruning (``amount`` of its weights)
    and then L2 structured pruning over output channels (``amount`` of its filters).
    Its bias gets L1 unstructured pruning of ``amount / 2``. Finally, ``amount / 2`` of
    all Conv2d/Linear weights are pruned globally by L1 magnitude.

    All re-parametrisations are removed with ``prune.remove``, so the returned model
    has plain ``weight``/``bias`` parameters (no ``*_orig``/``*_mask`` entries in its
    state dict) and the zeros persist when the state dict is saved.

    NOTE(review): the global step ranks weights by magnitude, and the already-zeroed
    weights rank lowest. When the per-layer steps already zeroed more than
    ``amount / 2`` of the weights, the global step adds no new zeros.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)
            prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)
            prune.remove(module, 'weight')
            if module.bias is not None:
                prune.l1_unstructured(module, name='bias', amount=amount / 2)
                prune.remove(module, 'bias')

    parameters_to_prune = [
        (module, 'weight') for module in model.modules() if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    # global_unstructured cannot handle an empty parameter list.
    if parameters_to_prune:
        prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=amount / 2)
        # Bake the global masks into the weights so the state dict stays plain.
        for module, name in parameters_to_prune:
            prune.remove(module, name)
    return model

def quantize_model(model, device):
    """
    Apply PyTorch dynamic int8 quantization to ``model`` and return it on ``device``.

    The model is first moved to the CPU, because quantization happens only on the CPU.
    It is then converted in place, so the caller's module object is modified and returned.

    NOTE(review): PyTorch's dynamic quantization documents support for nn.Linear and
    recurrent layers; nn.Conv2d entries in the spec are not converted by the default
    dynamic mapping. Verify the converted module types before relying on any
    size or speed gain.
    """
    cpu_model = model.cpu()
    target_layers = {nn.Conv2d, nn.Linear}
    converted = torch.quantization.quantize_dynamic(
        cpu_model,
        qconfig_spec=target_layers,
        dtype=torch.qint8,
        inplace=True,  # convert in place instead of deep-copying the whole network
    )
    return converted.to(device)

---

## Dataset

In [None]:
class ICDAR2015(Dataset):
    """
    ICDAR2015 Dataset for YOLOv3 training.

    Required Inputs: input_path (directory of *.jpg images),
                     label_path (directory of *.txt ground-truth files)
    Item:   (3, 448, 224) image, (7, 7, num_anchors * (5 + num_classes)) grid target
    Batch:  (batch_size, 3, 448, 224), (batch_size, 7, 7, num_anchors * (5 + num_classes))
    (5 = 4 (x,y,w,h) + 1 (objectness), followed by num_classes class probabilities)

    Images are paired with labels by file name: ``img_1.jpg`` -> ``gt_img_1.txt``
    (the ICDAR 2015 convention) or ``img_1.txt``. Images without a label are skipped.
    If no names match at all, the sorted image and label lists are paired by position.

    All tensors are created on ``device`` (CUDA when available by default). When
    the dataset feeds a DataLoader with ``num_workers > 0``, pass ``device='cpu'``:
    forked worker processes cannot create CUDA tensors.

    Args:
        transform: (height, width) that images are resized to.
        grid_size: (rows, cols) of the target grid.
    """
    def __init__(self, input_path, label_path, num_classes=1, num_anchors=3, transform=(448, 224), grid_size=(7, 7), device=None):
        self.input_path = Path(input_path)
        self.label_path = Path(label_path)
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.image_size = transform
        self.grid_size = grid_size
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_count = 0

        self.transform = transforms.Compose([
            transforms.Resize(transform),
            transforms.ToTensor()
        ])

        self.data = self._pair_files(self.input_path, self.label_path)
        self.labels = [self._parse_label(label, self._image_size(img)) for img, label in self.data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, _ = self.data[idx]
        # Context manager closes the file; convert() guards against grayscale/RGBA inputs.
        with Image.open(img_path) as img:
            tensor = self.transform(img.convert('RGB'))
        return tensor.to(self.device), self.labels[idx]

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= len(self):
            raise StopIteration
        item = self[self.index]
        self.index += 1
        return item

    @staticmethod
    def _pair_files(input_path, label_path):
        """Return (image, label) path pairs matched by file name, sorted by image name."""
        input_path, label_path = Path(input_path), Path(label_path)
        images = sorted(input_path.glob('*.jpg'))
        pairs = []
        for img in images:
            for candidate in (label_path / f'gt_{img.stem}.txt', label_path / f'{img.stem}.txt'):
                if candidate.is_file():
                    pairs.append((img, candidate))
                    break
        if not pairs:
            # Unknown naming scheme: at least make the positional pairing deterministic.
            pairs = list(zip(images, sorted(label_path.glob('*.txt'))))
        return pairs

    @staticmethod
    def _image_size(img_path):
        """Return the (width, height) of an image without leaving its file open."""
        with Image.open(img_path) as img:
            return img.size

    def _parse_label(self, label_path, img_size):
        """
        Encode one ground-truth file as a (rows, cols, num_anchors * (5 + num_classes)) grid.

        Each line is ``x1,y1,x2,y2,x3,y3,x4,y4,transcription``. A box goes into the cell
        containing the mean of its four corners, stored as
        [x, y, w, h, objectness, class probs...]. Here x and y are offsets inside that cell,
        and w and h are the axis-aligned extent divided by the image width/height.

        img_size: (width, height) of the original image.
        """
        img_w, img_h = img_size
        grid_rows, grid_cols = self.grid_size[0], self.grid_size[1]
        stride = 5 + self.num_classes
        labels = torch.zeros((grid_rows, grid_cols, self.num_anchors * stride), device=self.device)
        # Every ICDAR box is text, i.e. class 0 -> one-hot target.
        class_target = [1.0 if i == 0 else 0.0 for i in range(self.num_classes)]

        # utf-8-sig drops the byte-order mark ICDAR 2015 ground-truth files can start with.
        with open(label_path, 'r', encoding='utf-8-sig', newline='') as file:
            for row in csv.reader(file, delimiter=','):
                if len(row) < 8:
                    continue  # blank or malformed line
                # Only the first 8 fields are coordinates; the transcription may contain commas.
                coords = [float(value) for value in row[:8]]
                xs, ys = coords[0::2], coords[1::2]

                x = min(sum(xs) / 4, img_w - 1e-3)
                y = min(sum(ys) / 4, img_h - 1e-3)
                w = min(max(xs) - min(xs), img_w - 1e-3)
                h = min(max(ys) - min(ys), img_h - 1e-3)

                x, y = x / img_w * grid_cols, y / img_h * grid_rows
                grid_x = min(max(int(x), 0), grid_cols - 1)
                grid_y = min(max(int(y), 0), grid_rows - 1)
                x, y = x - grid_x, y - grid_y
                w, h = w / img_w, h / img_h

                # 4 (x,y,w,h) + 1 (objectness) + num_classes (class probabilities)
                box = torch.tensor([x, y, w, h, 1.0] + class_target, device=self.device)
                # Distance of the normalised (w, h) from (0.5, 0.5); decides which box keeps an occupied slot.
                box_distance = ((w - 0.5) ** 2 + (h - 0.5) ** 2) ** 0.5

                for anchor in range(self.num_anchors):
                    anchor_slice = slice(anchor * stride, (anchor + 1) * stride)
                    slot = labels[grid_y, grid_x, anchor_slice]
                    if slot.sum() == 0:
                        labels[grid_y, grid_x, anchor_slice] = box
                        break
                    if slot[4] == 1.0:
                        slot_distance = torch.sqrt((slot[2] - 0.5) ** 2 + (slot[3] - 0.5) ** 2).item()
                        if box_distance < slot_distance:
                            # The displaced box is dropped, not moved to another anchor.
                            labels[grid_y, grid_x, anchor_slice] = box
                            break
        return labels

    def next_batch(self, batch_size, randomized=True):
        """Return (images, labels) for `batch_size` items on self.device; sequential mode wraps around."""
        if randomized:
            indices = np.random.choice(len(self), batch_size, replace=False)
        else:
            indices = np.arange(self.batch_count, self.batch_count + batch_size) % len(self)
            self.batch_count += batch_size

        batch_images = torch.stack([self[i][0] for i in indices]) # Images
        batch_labels = torch.stack([self.labels[i] for i in indices]) # Labels
        return batch_images.to(self.device), batch_labels.to(self.device)

    @staticmethod
    def collate_fn(batch):
        images, labels = zip(*batch)
        images = torch.stack(images)
        labels = torch.stack(labels)
        return images, labels

---

## Loss

In [5]:
class YoLoss(nn.Module):
    """
    YOLO-style detection loss for channel-last prediction grids.

    predictions: (batch, rows, cols, num_anchors * (5 + num_classes)) raw logits.
    targets:     same shape. Per anchor: [x, y, w, h, objectness, class probs...], where
                 x and y are offsets inside the cell and w and h are normalised by the image size.

    Sigmoid is applied to the predicted x, y, w, h, objectness and class logits, so
    the loss works directly on the raw logits from MobileYOLOv3.

    Terms, summed and divided by the batch size:
      - coordinate: lambda_coord * (SSE on x, y + SSE on sqrt(w), sqrt(h)), object anchors only
      - objectness: BCE on anchors holding an object
      - no-object:  lambda_noobj * BCE on empty anchors
      - class:      BCE on object anchors (only when num_classes > 1)
    Returns a scalar tensor.
    """
    def __init__(self, num_classes=1, num_anchors=3, lambda_coord=5, lambda_noobj=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, predictions, targets):
        batch_size, grid_rows, grid_cols = predictions.shape[:3]
        per_anchor = 5 + self.num_classes
        shape = (batch_size, grid_rows, grid_cols, self.num_anchors, per_anchor)

        # Work in float32: under autocast the head emits float16, where sqrt/BCE lose precision.
        preds = predictions.float().reshape(shape)
        targs = targets.float().reshape(shape)

        obj_mask = targs[..., 4] > 0.5
        noobj_mask = ~obj_mask

        pred_xy = torch.sigmoid(preds[..., 0:2])
        pred_wh = torch.sigmoid(preds[..., 2:4])   # > 0, so sqrt is always defined
        target_xy = targs[..., 0:2]
        target_wh = targs[..., 2:4].clamp(min=0)

        # Coordinate loss (object anchors only; sums over empty selections are 0)
        coord_loss = self.lambda_coord * (
            F.mse_loss(pred_xy[obj_mask], target_xy[obj_mask], reduction='sum')
            + F.mse_loss(pred_wh[obj_mask].sqrt(), target_wh[obj_mask].sqrt(), reduction='sum')
        )

        # Objectness loss: the logits variant is numerically stable and autocast-safe
        obj_bce = F.binary_cross_entropy_with_logits(preds[..., 4], targs[..., 4], reduction='none')
        obj_loss = obj_bce[obj_mask].sum()
        noobj_loss = self.lambda_noobj * obj_bce[noobj_mask].sum()

        # Class loss (only if num_classes > 1)
        if self.num_classes > 1:
            class_loss = F.binary_cross_entropy_with_logits(
                preds[..., 5:][obj_mask], targs[..., 5:][obj_mask], reduction='sum'
            )
        else:
            class_loss = preds.new_zeros(())

        return (coord_loss + obj_loss + noobj_loss + class_loss) / batch_size

---

## Training

In [None]:
batch_size = 32
num_workers = 4
num_classes = 1
learning_rate = 1e-3
num_epochs = 25
target_architecture = 'cuda' # else 'cpu'  NOTE(review): not read by any visible cell

# https://www.kaggle.com/datasets/bestofbests9/icdar2015
dataset_path = Path('/kaggle/input/icdar2015')
train_path = dataset_path / 'ch4_training_images'
train_labels = dataset_path / 'ch4_training_localization_transcription_gt'
test_path = dataset_path / 'ch4_test_images'
test_labels = dataset_path / 'ch4_test_localization_transcription_gt'

model_path = 'pq_yolov3_mobilenetv3.pth'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using training device: {device}")

# Keep dataset tensors on the CPU: the DataLoader workers below are forked processes,
# which cannot create CUDA tensors. The training loop moves every batch to `device`.
train_dataset = ICDAR2015(train_path, train_labels, num_classes, device=torch.device('cpu'))
test_dataset = ICDAR2015(test_path, test_labels, num_classes, device=torch.device('cpu'))

# Pinned host memory is what lets the training loop's `non_blocking=True` copies run asynchronously.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=device.type == 'cuda')
# Batch order does not change the mean test loss, so the test set is not shuffled.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=device.type == 'cuda')

In [None]:
def test_loss(model, loader, criterion, device):
    """
    Return the mean per-batch ``criterion`` value of ``model`` over ``loader``.

    Runs in eval mode without gradients. Autocast is enabled only on CUDA, matching the
    training loop below. The model is left in eval mode.
    """
    device = torch.device(device)
    model.eval()
    total = 0.0
    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            with torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
                total += criterion(model(data), targets).item()
    return total / max(len(loader), 1)


# Mixed precision (autocast + loss scaling) only pays off on CUDA; on CPU both are disabled.
use_amp = device.type == 'cuda'

model = MobileYOLOv3(num_classes=num_classes).to(device)
criterion = YoLoss(num_classes=num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=2e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

lossi = []  # mean training loss per epoch
losst = []  # mean test loss per epoch

for epoch in range(num_epochs):
    model.train()  # test_loss() leaves the model in eval mode
    epoch_loss = 0

    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(data)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    epoch_loss = epoch_loss / len(train_loader)
    lossi.append(epoch_loss)

    t_loss = test_loss(model, test_loader, criterion, device)
    losst.append(t_loss)

    scheduler.step(epoch_loss)

    # Print training and test loss
    print(f'Epoch [{epoch+1:3}/{num_epochs}] | Train Loss: {epoch_loss:8.3f} | Test Loss: {t_loss:8.3f} | LR: {optimizer.param_groups[0]["lr"]:.6f}')

In [None]:
# Prune, Quantize
# prune_model modifies `model` in place and returns it, so `pruned_model` is the trained model object itself.
pruned_model = prune_model(model)
# quantize_model moves the model to the CPU, runs quantize_dynamic in place, then moves the result to `device`.
quantized_model = quantize_model(pruned_model, device)

# Save the quantized model
# Only the state_dict is written; load_model rebuilds the architecture from its class and loads these weights.
torch.save(quantized_model.state_dict(), model_path)

In [None]:
# Plot the per-epoch training and test loss. Epochs are numbered from 1, like the training log,
# and the x-axis follows the recorded history, so the plot still works if training stopped early.
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(lossi) + 1), lossi, label="Training Loss", color='blue', marker='o', linestyle='-', markersize=3)
plt.plot(range(1, len(losst) + 1), losst, label="Test Loss", color="red", marker='o', linestyle='-', markersize=3)

plt.title('Training + Test Loss Over Epochs', fontsize=16)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.grid(True)
plt.legend(loc='upper right')

plt.show();

In [None]:
def _fold_pruning_masks(state_dict):
    """Replace torch prune re-parametrisations (`<name>_orig`, `<name>_mask`) with plain `<name>` tensors."""
    folded = {}
    for key, value in state_dict.items():
        if key.endswith('_mask') and key[:-len('_mask')] + '_orig' in state_dict:
            continue  # folded into its `_orig` partner below
        if key.endswith('_orig'):
            base = key[:-len('_orig')]
            mask = state_dict.get(base + '_mask')
            # The effective pruned tensor is orig * mask; taking orig alone would undo the pruning.
            folded[base] = value if mask is None else value * mask
        else:
            folded[key] = value
    return folded


def load_model(model_class, num_classes, model_path, target_device='cpu'):
    """
    Load a PyTorch model for inference on the target device, regardless of where it was originally trained.

    Args:
        model_class: callable that builds the architecture from ``num_classes``.
        num_classes: forwarded to ``model_class``.
        model_path: file holding a state dict, or a checkpoint dict with a 'model_state_dict' entry.
        target_device: torch.device or device string the tensors are mapped to.
    Returns:
        The model on ``target_device``, in eval mode. Weights that are still pruning
        re-parametrised (``*_orig`` / ``*_mask``) are folded into plain, still-pruned tensors.
    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No file at {model_path}")

    if isinstance(target_device, str):
        target_device = torch.device(target_device)

    # map_location re-targets every tensor, so CUDA-trained files load on CPU-only hosts.
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
    state_dict = torch.load(model_path, map_location=target_device, weights_only=False)

    # If it's a full checkpoint, extract just the model state dict
    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']

    model = model_class(num_classes)
    model.load_state_dict(_fold_pruning_masks(state_dict))
    model = model.to(target_device)
    model.eval()
    return model

In [None]:
def yolo_to_icdar2015(predictions, original_sizes, conf_threshold=0.5, grid_size=(8, 14), input_size=(256, 448),
                      num_anchors=3, num_classes=1):
    """
    Decode MobileYOLOv3 output into ICDAR 2015 style four-point boxes.

    Args:
        predictions: (batch, rows, cols, num_anchors * (5 + num_classes)) raw logits,
            i.e. the channel-last layout returned by MobileYOLOv3.forward.
        original_sizes: per-image (width, height) that boxes are scaled to.
        conf_threshold: minimum sigmoid(objectness) for a box to be kept.
        grid_size: unused; the grid is read from ``predictions``. Kept for backward compatibility.
        input_size: (width, height) of the network input the predictions were made on.
        num_anchors, num_classes: layout of the last dimension.
    Returns:
        One list per image of [x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max]
        float boxes (upright, clockwise from the top-left corner), ordered by anchor,
        then row, then column.
    Raises:
        ValueError: if the last dimension does not equal num_anchors * (5 + num_classes).
    """
    batch_size, grid_h, grid_w, channels = predictions.shape
    per_anchor = 5 + num_classes
    if channels != num_anchors * per_anchor:
        raise ValueError(
            f"expected {num_anchors} * (5 + {num_classes}) = {num_anchors * per_anchor} channels, got {channels}"
        )

    # (batch, anchors, rows, cols, values) with every value squashed to (0, 1)
    decoded = (predictions.detach().float().cpu()
               .reshape(batch_size, grid_h, grid_w, num_anchors, per_anchor)
               .permute(0, 3, 1, 2, 4)
               .sigmoid())

    input_w, input_h = input_size  # Resized width and height
    cell_w, cell_h = input_w / grid_w, input_h / grid_h
    bboxes_per_image = []

    for img_idx in range(batch_size):
        orig_w, orig_h = original_sizes[img_idx]
        # Scaling factors between the network input and the requested output size
        scale_w, scale_h = orig_w / input_w, orig_h / input_h
        img_preds = decoded[img_idx]
        boxes = []

        for anchor, y, x in (img_preds[..., 4] > conf_threshold).nonzero().tolist():
            cx, cy, w, h = img_preds[anchor, y, x, :4].tolist()

            # Cell offset -> pixel centre; normalised extent -> pixel size (network-input space)
            cx_abs = (x + cx) * cell_w
            cy_abs = (y + cy) * cell_h
            w_abs = w * input_w
            h_abs = h * input_h

            x_min = (cx_abs - w_abs / 2) * scale_w
            x_max = (cx_abs + w_abs / 2) * scale_w
            y_min = (cy_abs - h_abs / 2) * scale_h
            y_max = (cy_abs + h_abs / 2) * scale_h

            boxes.append([
                x_min, y_min,  # Top-left corner
                x_max, y_min,  # Top-right corner
                x_max, y_max,  # Bottom-right corner
                x_min, y_max,  # Bottom-left corner
            ])

        bboxes_per_image.append(boxes)
    return bboxes_per_image

In [None]:
def process_batch(model, dataset, device, num_images=10, conf_threshold=0.5):
    """
    Run ``model`` on randomly sampled dataset images and decode the predicted boxes.

    Samples up to ``num_images`` distinct indices (fewer if the dataset is smaller).
    Returns ``(boxes_per_image, images)``, where the images are the dataset tensors
    and the boxes are in those tensors' pixel coordinates, ready to draw on them.
    """
    model.eval()
    sampled_indices = random.sample(range(len(dataset)), min(num_images, len(dataset)))

    original_images = [dataset[idx][0] for idx in sampled_indices]
    batch = torch.stack(original_images).to(device)

    # Boxes are decoded in the pixel space of the displayed tensors, so the network
    # input size and the "original" size coincide (scale factor 1).
    input_size = (batch.shape[3], batch.shape[2])  # (width, height)
    original_sizes = [input_size] * len(original_images)

    with torch.no_grad():
        predictions = model(batch)

    return yolo_to_icdar2015(predictions, original_sizes, conf_threshold, input_size=input_size), original_images

In [None]:
def visualize_predictions(model, dataset, device, num_images=10, conf_threshold=0.5):
    """
    Show randomly sampled dataset images with the model's predicted boxes drawn in red.

    Each box from process_batch is an 8-value corner list; it is drawn as an upright
    rectangle spanning its first corner (x1, y1) to its third corner (x3, y3).
    """
    predicted_boxes, original_images = process_batch(model, dataset, device, num_images, conf_threshold)

    # zip: draw exactly as many images as process_batch returned, even if fewer than requested.
    for img, boxes in zip(original_images, predicted_boxes):
        img_disp = img.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow

        fig, ax = plt.subplots(1)
        ax.imshow(img_disp)
        for x1, y1, _x2, _y2, x3, y3, _x4, _y4 in boxes:
            rect = patches.Rectangle((x1, y1), x3 - x1, y3 - y1, linewidth=2, edgecolor='red', facecolor='none')
            ax.add_patch(rect)

        plt.axis('off')  # Hide axes
        plt.show()
        plt.close(fig)  # release the figure when not using an auto-closing inline backend

In [None]:
# Rebuild MobileYOLOv3 from the saved weights on the CPU and visualise predictions.
model = load_model(MobileYOLOv3, num_classes, model_path, 'cpu')
# NOTE(review): samples come from the *training* split. The dataset uses its default device
# (CUDA when available); process_batch moves the sampled batch to 'cpu' before inference.
dataset = ICDAR2015(train_path, train_labels)
visualize_predictions(model, dataset, 'cpu', num_images=10, conf_threshold=0.973) # Still horrible