# MobileNetV3 YOLOv3 on ICDAR 2015

A text detection model with a MobileNetV3 backbone, trained according to the YOLOv3 paradigm on the ICDAR 2015 dataset.<br>
One-shot pruned and quantized for deployment on edge devices.

In [1]:
import gc
import csv
import math
import torch
import random
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch_optimizer as optim
import matplotlib.patches as patches
import torch.nn.utils.prune as prune
import torchvision.transforms.functional as TF

from PIL import Image
from pathlib import Path
from torchvision import transforms
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from mobileyolov3 import MobileYOLOv3, DSConv, Resizer

%matplotlib inline

---

## Hyperparameters

In [None]:
# Seed every RNG this notebook draws from: torch (weight init, DataLoader shuffling),
# NumPy (augmentation coin flips and `get_batch` sampling) and Python's `random`
# (random sample picking in the sanity check / visualization cells).
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using training device: {device}")

# Data loading / optimization
batch_size = 128
num_workers = 4
num_classes = 1
learning_rate = 5e-4
num_epochs = 500
lr_warmup = num_epochs * 0.1
weight_decay = 1e-5
optim_k = 5          # Lookahead: synchronize slow weights every k steps
optim_alpha = 0.3    # Lookahead: slow-weight interpolation factor
# NOTE: warmup_start, warmup_epochs, scheduler_t0, scheduler_tmult and target_architecture
# are only echoed in the training log; the OneCycleLR scheduler used for training does not read them.
warmup_start = 0.35
warmup_epochs = lr_warmup
scheduler_t0 = warmup_epochs
scheduler_tmult = 2
prune_amount = 0.2   # fraction used by both pruning passes after training
dropout_rate = 0.3
target_architecture = 'cuda'
# Anchor priors as (width, height), normalized to the image size.
anchors = [(0.28, 0.35), (0.43, 0.58), (0.62, 0.78)]
num_anchors = len(anchors)

# https://www.kaggle.com/datasets/bestofbests9/icdar2015
dataset_path = Path("./icdar2015/")
train_path = dataset_path / 'ch4_training_images'
train_labels = dataset_path / 'ch4_training_localization_transcription_gt'
test_path = dataset_path / 'ch4_test_images'
test_labels = dataset_path / 'ch4_test_localization_transcription_gt'

model_path = 'mobileyolov3_icdar2015.pth'

---

## Activation Functions

We will need to incorporate the correct activation functions and place them throughout our model.<br>
If you expect a value to only ever lie in $[0, 1]$, an unbounded activation such as LeakyReLU (range $(-\infty, \infty)$) is a poorer fit than Sigmoid, whose range is $(0, 1)$.

Plotting a few candidate activation functions side by side helps in selecting the one that best fits your setting:

In [None]:
# This is interesting too: https://pat.chormai.org/blog/2020-relu-softplus (considered, but didn't end up using it)
# For Swish: https://arxiv.org/pdf/1710.05941.pdf (esp pp. 5-6) (considered, but didn't end up using it)

# Compare candidate activations on one chart.
# Considered but not adopted:
#   - Softplus vs. ReLU: https://pat.chormai.org/blog/2020-relu-softplus
#   - Swish: https://arxiv.org/pdf/1710.05941.pdf (esp. pp. 5-6)
x = np.linspace(-10, 10, 400)

y_sigmoid = 1 / (1 + np.exp(-x))
y_relu = np.maximum(0, x)
y_elu = np.where(x > 0, x, np.exp(x) - 1)
y_leaky_relu_0_1 = np.where(x > 0, x, 0.1 * x)
y_leaky_relu_0_02 = np.where(x > 0, x, 0.02 * x)
y_softplus_beta_0_5 = (1 / 0.5) * np.log(1 + np.exp(0.5 * x))
y_softplus_beta_1 = (1 / 1) * np.log(1 + np.exp(1 * x))
# d/dx [x * sigmoid(x)], written as x*s + s*(1 - x*s)
y_swish_derivative = (x * y_sigmoid) + y_sigmoid * (1 - (x * y_sigmoid))

plt.figure(figsize=(8, 4))
# One (curve, style) pair per line; order fixes the legend order.
for _y, _kw in (
    (y_relu, dict(label="ReLU", linewidth=2, color='black')),
    (y_sigmoid, dict(label="Sigmoid", linewidth=2, color='red')),
    (y_elu, dict(label="ELU", linewidth=2, linestyle='dashed', color='brown')),
    (y_leaky_relu_0_1, dict(label="Leaky ReLU (α=0.1)", linewidth=2, linestyle='dotted', color='green')),
    (y_leaky_relu_0_02, dict(label="Leaky ReLU (α=0.02)", linewidth=2, linestyle='dotted', color='gray')),
    (y_softplus_beta_0_5, dict(label="Softplus (β=0.5)", linewidth=2, color='blue')),
    (y_softplus_beta_1, dict(label="Softplus (β=1)", linewidth=2, linestyle='dashed', color='orange')),
    (y_swish_derivative, dict(label="Swish Derivative", linewidth=2, color='purple')),
):
    plt.plot(x, _y, **_kw)
del _y, _kw

plt.title('Compared Activation Functions', fontsize=14)
plt.xlabel('x', fontsize=12)
plt.ylabel('activation(x)', fontsize=12)
plt.legend(loc='upper left')
plt.axhline(0, color='black', linewidth=0.5)
plt.axvline(0, color='black', linewidth=0.5)
plt.grid(True)
plt.xlim(-7, 7)
plt.ylim(-1, 3)

plt.show()

---

## Pruning & Quantization Definition

Pruning and quantization are applied once model training has concluded.

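Pruning picks out tiny or non-contributing weights and sets them to zero (PyTorch's pruning utilities mask weights rather than physically removing them from the model structure).<br>
With sparse-aware storage or kernels this can free up memory and processing resources at little to no cost in accuracy, though the impact on accuracy varies per use case.<br>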
Quantization is an additional step that stores weights at a coarser numeric precision (e.g. 8-bit integers instead of 32-bit floats), exploiting the gap between the precision a weight actually needs and the precision it is granted.<br>
If a weight can be represented at coarser precision with little loss of information, this decreases memory and computation demands.

In [4]:
def prune_model(model, amount):
    """
    One-shot prune the convolutions (and linear layers) of ``model`` in place.

    Two passes run, and every target module receives each pass exactly once:

    1. L2 structured pruning along ``dim=0`` (whole output filters) for every
       ``nn.Conv2d``, the depthwise/pointwise convs of ``DSConv`` blocks and the
       conv of ``Resizer`` blocks.
    2. Global L1 unstructured pruning across all collected conv/linear weights.

    ``model.modules()`` yields a ``Resizer``, then its inner ``DSConv``, then that
    block's convs. Targets are therefore de-duplicated by identity. Without this,
    nested convs would be pruned two or three times and lose far more than
    ``amount``.

    Args:
        model: network to prune. It is modified in place through
            ``torch.nn.utils.prune`` reparametrizations (``weight_orig`` + ``weight_mask``).
        amount: fraction passed to both pruning passes.

    Returns:
        The same ``model`` object.
    """
    # id(module) -> module; dicts keep first-seen order and drop repeats.
    structured = {}

    def _add_structured(conv):
        structured.setdefault(id(conv), conv)

    for module in model.modules():
        if isinstance(module, DSConv):
            _add_structured(module.depthwise)
            _add_structured(module.pointwise)
        elif isinstance(module, Resizer):
            if isinstance(module.conv, DSConv):
                _add_structured(module.conv.depthwise)
                _add_structured(module.conv.pointwise)
            else:
                _add_structured(module.conv)
        elif isinstance(module, nn.Conv2d):
            _add_structured(module)

    for conv in structured.values():
        prune.ln_structured(conv, name='weight', amount=amount, n=2, dim=0)

    unstructured = {}
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            unstructured.setdefault(id(module), module)
        elif isinstance(module, DSConv):
            unstructured.setdefault(id(module.depthwise), module.depthwise)
            unstructured.setdefault(id(module.pointwise), module.pointwise)

    # global_unstructured cannot handle an empty parameter list.
    if unstructured:
        prune.global_unstructured([(m, 'weight') for m in unstructured.values()],
                                  pruning_method=prune.L1Unstructured, amount=amount)
    return model

def quantize_model(model, device):
    """
    Crunch numeric precision of weights (and, for quantized layers, activations) via
    PyTorch dynamic int8 quantization.

    The conversion runs with ``inplace=True`` (this avoids deepcopy issues), so ``model``
    itself is converted and returned.

    Args:
        model: float model. It is moved to CPU before conversion.
        device: requested target device. Dynamically quantized layers keep their packed
            int8 weights on CPU. Moving the result to CUDA would move only the remaining
            float layers and produce a mixed-device model. The result is therefore moved to
            ``device`` only when ``device`` is a CPU device; otherwise it stays on CPU.

    Returns:
        The quantized model.
    """
    model = model.cpu()  # quantize_dynamic works only on CPU
    # NOTE(review): nn.Conv2d is requested here, but PyTorch's default dynamic-quantization
    # mapping has no Conv2d entry, so conv layers presumably stay float — verify.
    quantized_model = torch.quantization.quantize_dynamic(model, {nn.Conv2d, nn.Linear}, dtype=torch.qint8, inplace=True)
    if torch.device(device).type == 'cpu':
        return quantized_model.to(device)
    return quantized_model

def lift_pruning(model):
    """
    Make every pruning reparametrization in ``model`` permanent, in place.

    For each submodule that carries a ``weight_mask`` together with ``weight_orig``
    (the pair added by ``torch.nn.utils.prune``), ``prune.remove`` does three things:

    * bakes ``weight_orig * weight_mask`` into a plain ``weight`` Parameter;
    * deletes the mask buffer and the ``weight_orig`` parameter;
    * unregisters the pruning forward pre-hook.

    ``model.modules()`` recurses into every child, so the convs inside ``DSConv`` and
    ``Resizer`` blocks are reached without type-specific branches.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if hasattr(module, 'weight_mask') and hasattr(module, 'weight_orig'):
            prune.remove(module, 'weight')
    return model

---

## Datasets

In [5]:
class ICDAR2015(Dataset):
    """
    ICDAR2015 Dataset for YOLOv3 training.

    Each sample is ``(image, label)``:

    * ``image``: the image converted to RGB, resized to ``img_size`` and turned into a
      tensor. When ``augment`` is set, colour/sharpness jitter is applied and the image is
      horizontally flipped with probability 0.25.
    * ``label``: the flat concatenation of three target grids (7x7, 14x14, 28x28). Each
      grid has shape ``(grid, grid, num_anchors * (5 + num_classes))``. Per anchor the
      slots are ``(x_off, y_off, w / anchor_w, h / anchor_h, objectness, class...)``:
      ``x_off``/``y_off`` is the box centre's offset inside its cell, and ``w``/``h`` are
      normalized by the original image size. Class slots are left at 0.

    Labels are parsed once in ``__init__``; images are loaded lazily in ``__getitem__``.
    """

    _GRID_SIZES = (7, 14, 28)  # must match the 7/14/28 split hard-coded in YoLoss.forward

    def __init__(self, input_path, label_path, num_classes=1, num_anchors=3, img_size=(224, 224), img_format='.jpg', anchors=anchors, augment=True):
        self.input_path = Path(input_path)  # Directory with the images
        self.label_path = Path(label_path)  # Directory with gt_img_<id>.txt files
        self.num_classes = num_classes      # Number of associable classes
        self.num_anchors = num_anchors      # Anchor slots per grid cell
        self.img_size = img_size            # Size passed to transforms.Resize
        self.batch_count = 0                # Cursor for sequential get_batch calls
        self.anchors = anchors              # (w, h) anchor priors, normalized to image size
        self.augment = augment              # Enable jitter/sharpness/flip augmentation
        steps = [transforms.Resize(img_size)]
        if augment:
            # Encounter same image multiple times, different augmentations each time
            steps += [
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.RandomAdjustSharpness(sharpness_factor=3.0, p=0.5),
            ]
        steps.append(transforms.ToTensor())
        self.transform = transforms.Compose(steps)
        self.files = self._assemble_files(img_format=img_format)
        self.labels = [self._parse_label(label, self._image_size(img)) for img, label in self.files]

    @staticmethod
    def _image_size(img_path):
        """Return ``(width, height)`` of the image at ``img_path``, closing the file afterwards."""
        with Image.open(img_path) as img:
            return img.size

    def _flip_image_and_labels(self, img, labels):
        """
        Horizontally flip an image tensor and mirror its grid-encoded labels to match.

        A box in column ``gx`` with in-cell offset ``x_off`` moves to column
        ``grid - 1 - gx`` with offset ``1 - x_off``. Widths and heights are unchanged,
        and cells without an object stay all-zero.
        """
        img = TF.hflip(img)
        flipped_labels = labels.clone()
        depth = 5 + self.num_classes
        start_idx = 0
        for grid_size in self._GRID_SIZES:
            level_size = grid_size * grid_size * self.num_anchors * depth
            # `level` is a view into flipped_labels, so the x update below writes through.
            level = flipped_labels[start_idx:start_idx + level_size].view(grid_size, grid_size, self.num_anchors, depth)
            occupied = level[..., 4] > 0
            level[..., 0] = torch.where(occupied, 1 - level[..., 0], level[..., 0])
            # Dim 1 is the column (x) axis: _to_grid indexes grid[grid_y, grid_x].
            flipped_labels[start_idx:start_idx + level_size] = torch.flip(level, [1]).reshape(-1)
            start_idx += level_size
        return img, flipped_labels

    def _assemble_files(self, img_format):
        """Pair each image with its ``gt_img_<id>.txt`` label file (sorted for a stable order)."""
        data = []
        for img_file in sorted(self.input_path.glob(f'*{img_format}')):
            img_id = img_file.stem.split('_')[-1]
            label_file = self.label_path / f"gt_img_{img_id}.txt"
            if label_file.exists():
                data.append((img_file, label_file))
            else:
                print(f"Warning: No matching label file found for {img_file.name}")
        return data

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path, _ = self.files[idx]
        # Close the file handle right away; force 3 channels so batches always stack.
        with Image.open(img_path) as raw:
            img = self.transform(raw.convert('RGB'))
        label = self.labels[idx]
        if self.augment and np.random.rand() < 0.25:
            return self._flip_image_and_labels(img, label)
        return img, label

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= len(self):
            raise StopIteration
        item = self[self.index]
        self.index += 1
        return item

    def _calculate_iou(self, box1, box2):
        """IoU of two ``(x1, y1, w, h)`` boxes; 0 when the union area is not positive."""
        # Calculate intersection
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[0] + box1[2], box2[0] + box2[2])
        y2 = min(box1[1] + box1[3], box2[1] + box2[3])
        inter_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area, box2_area = box1[2] * box1[3], box2[2] * box2[3]
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area > 0 else 0

    def _to_grid(self, grid, box):
        """
        Write ``box`` (normalized ``x, y, w, h, obj, cls``) into the cell holding its centre.

        The anchor slot is the one whose ``(w, h)`` prior has the highest shape IoU with
        the box, with both treated as rectangles sharing a corner. Zero-area boxes match
        no anchor and are skipped. A later box landing in the same cell and anchor
        overwrites an earlier one.
        """
        grid_h, grid_w = grid.size(0), grid.size(1)  # grid is indexed [row (y), column (x)]
        x, y, w, h, obj, _ = (float(v) for v in box)
        grid_x = min(max(int(x * grid_w), 0), grid_w - 1)
        grid_y = min(max(int(y * grid_h), 0), grid_h - 1)
        x_off, y_off = x * grid_w - grid_x, y * grid_h - grid_y
        best_iou, best_anchor_idx = 0, -1
        for anchor_idx, (anchor_w, anchor_h) in enumerate(self.anchors):
            iou = self._calculate_iou([0.0, 0.0, w, h], [0.0, 0.0, anchor_w, anchor_h])
            if iou > best_iou:
                best_iou, best_anchor_idx = iou, anchor_idx
        if best_anchor_idx < 0:
            return
        anchor_w, anchor_h = self.anchors[best_anchor_idx]
        start = best_anchor_idx * (5 + self.num_classes)
        grid[grid_y, grid_x, start:start + 4] = torch.tensor([x_off, y_off, w / anchor_w, h / anchor_h])
        grid[grid_y, grid_x, start + 4] = obj

    def _parse_label(self, label_path, img_size):
        """
        Parse one ICDAR2015 ground-truth file into the flat multi-scale target tensor.

        Each usable row starts with the 8 corner coordinates ``x1,y1,...,x4,y4``; any
        following fields (the transcription) are ignored. Each quadrilateral is reduced
        in four steps:

        * take the mean of its vertices as the centre;
        * take its axis-aligned extent as the width and height;
        * clamp these into the image;
        * normalize by ``img_size`` (width, height).

        Rows with fewer than 8 fields are skipped.
        """
        depth = self.num_anchors * (5 + self.num_classes)
        grids = [torch.zeros((size, size, depth)) for size in self._GRID_SIZES]
        img_w, img_h = img_size
        with open(label_path, 'r', encoding='utf-8-sig', newline='') as file:
            for row in csv.reader(file, delimiter=','):
                if len(row) < 8:
                    continue  # blank or malformed line
                points = torch.tensor([float(i) for i in row[:8]])  # (x1, y1, x2, y2, ...)
                xs, ys = points[0::2], points[1::2]
                x, y = xs.mean().item(), ys.mean().item()
                w, h = (xs.max() - xs.min()).item(), (ys.max() - ys.min()).item()
                # Clamp into the image so the centre always maps to a valid grid cell
                x = min(max(x, 0.0), img_w - 1e-3)
                y = min(max(y, 0.0), img_h - 1e-3)
                w, h = min(w, img_w - 1e-3), min(h, img_h - 1e-3)
                obj, cls = 1.0, 0.0  # Change latter for accommodating multi-class settings
                box = torch.tensor([x / img_w, y / img_h, w / img_w, h / img_h, obj, cls])
                for grid in grids:  # coarse, medium, fine
                    self._to_grid(grid, box)
        # Flatten and concatenate the labels from all grid levels
        return torch.cat([grid.flatten() for grid in grids], dim=0)

    def get_batch(self, batch_size, randomized=True):
        """Return a stacked ``(images, labels)`` batch, sampled randomly or sequentially (wrapping)."""
        if randomized:
            indices = np.random.choice(len(self), batch_size, replace=False)
        else:
            indices = np.arange(self.batch_count, self.batch_count + batch_size) % len(self)
            self.batch_count += batch_size
        # Fetch each sample once: __getitem__ draws a fresh random flip per call, so
        # indexing twice could pair a flipped image with unflipped labels.
        return self.collate_fn([self[int(i)] for i in indices])

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image, label)`` pairs into batch tensors."""
        images, labels = zip(*batch)
        images = torch.stack(images, dim=0)
        labels = torch.stack(labels, dim=0)
        return images, labels

---

## Loss

In [6]:
class YoLoss(nn.Module):
    """
    Multi-scale YOLOv3-style loss.

    total = lambda_coord * box_loss + lambda_noobj * obj_loss + lambda_class * cls_loss, where

    * box_loss: (1 - CIoU) + L1 on anchors whose target objectness > 0;
    * obj_loss: focal BCE on channel 4 over every anchor;
    * cls_loss: focal BCE on channels 5+ (only when num_classes > 1).

    ``targets`` is the flat per-sample label tensor (batch first). It is split into
    7x7, 14x14 and 28x28 grids and viewed as ``(B, g, g, num_anchors, 5 + num_classes)``.
    ``predictions`` is a sequence of three per-level tensors indexed the same way.
    NOTE(review): presumably the model output already has the target layout and holds
    logits in channels 4+ — confirm against MobileYOLOv3.
    """
    def __init__(self, num_classes=1, num_anchors=3, lambda_coord=7.0, lambda_noobj=1.0,
                 lambda_class=1.0, iou_threshold=0.5, focal_alpha=0.25, focal_gamma=2.0, label_smoothing=0.1, anchors=anchors):
        super(YoLoss, self).__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.lambda_coord = lambda_coord  # weight of the box term
        self.lambda_noobj = lambda_noobj  # weight of the whole objectness term (positives and negatives)
        self.lambda_class = lambda_class  # weight of the class term
        self.iou_threshold = iou_threshold  # pred/target IoU above which the objectness target becomes 1
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.label_smoothing = label_smoothing
        self.anchors = torch.tensor(anchors)  # stored, but not read by any method below
        self.eps = 1e-7
    
    def gaussian_objectness(self, x, y, sigma=0.3):
        """Isotropic Gaussian exp(-(x^2 + y^2) / (2 sigma^2)); not called anywhere in this notebook."""
        return torch.exp(-((x ** 2 + y ** 2) / (2 * sigma ** 2)))

    def focal_loss(self, pred, target):
        """
        Element-wise sigmoid focal loss of logits ``pred`` against (possibly soft) ``target``.

        loss = alpha_t * (1 - p_t)^gamma * BCEWithLogits(pred, target).
        Shortcut: if the whole ``pred`` tensor equals ``target`` within ``eps`` (e.g. a label
        fed in as its own prediction, as the sanity check does), all-zero loss is returned.
        """
        if torch.allclose(pred.float(), target.float(), atol=self.eps):
            return torch.zeros_like(pred)
        pred_prob = torch.clamp(torch.sigmoid(pred), self.eps, 1 - self.eps)
        # p_t: model probability assigned to the (soft) target
        p_t = target * pred_prob + (1 - target) * (1 - pred_prob)
        p_t = torch.clamp(p_t, self.eps, 1 - self.eps)
        alpha_factor = target * self.focal_alpha + (1 - target) * (1 - self.focal_alpha)
        modulating_factor = (1.0 - p_t).pow(self.focal_gamma)
        loss = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        return alpha_factor * modulating_factor * loss

    def bbox_iou(self, box1, box2, DIoU=False, CIoU=False):
        """
        IoU of boxes given as (cx, cy, w, h) along the last dim, optionally with the DIoU or
        CIoU penalty. Returns a tensor over the leading dims.

        Shortcut: if the two tensors are allclose as a whole, ones are returned for every box
        (no per-box check, and the DIoU/CIoU flags are ignored).
        """
        # Convert centre/size to corner coordinates
        b1_x1, b1_x2 = box1[..., 0] - box1[..., 2] / 2.0, box1[..., 0] + box1[..., 2] / 2.0
        b1_y1, b1_y2 = box1[..., 1] - box1[..., 3] / 2.0, box1[..., 1] + box1[..., 3] / 2.0
        b2_x1, b2_x2 = box2[..., 0] - box2[..., 2] / 2.0, box2[..., 0] + box2[..., 2] / 2.0
        b2_y1, b2_y2 = box2[..., 1] - box2[..., 3] / 2.0, box2[..., 1] + box2[..., 3] / 2.0
        if torch.allclose(box1.float(), box2.float(), atol=self.eps):
            return torch.ones_like(box1[..., 0])
        # Overlap width * overlap height, each clamped at 0
        inter = torch.clamp((torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)), min=0) * \
                torch.clamp((torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)), min=0)
        inter = torch.clamp(inter, min=0)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
        union = (w1 * h1 + w2 * h2 - inter) + self.eps
        iou = inter / torch.clamp(union, min=self.eps)
        if CIoU or DIoU:
            # Squared diagonal of the smallest enclosing box
            cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
            ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
            c2 = (cw ** 2 + ch ** 2) + self.eps
            # Squared distance between the box centres
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4
            if DIoU:
                return iou - rho2 / c2
            elif CIoU:
                # Aspect-ratio consistency term and its trade-off weight
                v = (4 / (math.pi ** 2)) * torch.pow(torch.atan(w2 / (h2 + self.eps)) - torch.atan(w1 / (h1 + self.eps)), 2)
                alpha = v / (1 - iou + v + self.eps)
                return iou - (rho2 / c2 + v * alpha)
        return iou

    def get_box_loss(self, predictions, targets):
        """Sum over levels of mean(1 - CIoU) + mean L1 on anchors with target objectness > 0 (int 0 if none)."""
        box_loss = 0
        for pi, ti in zip(predictions, targets):
            mask = ti[..., 4] > 0
            p_boxes, t_boxes = pi[mask][..., :4], ti[mask][..., :4]
            if p_boxes.numel() > 0:
                iou = self.bbox_iou(p_boxes, t_boxes, CIoU=True)
                box_loss += torch.mean(1.0 - iou)
                box_loss += F.l1_loss(p_boxes, t_boxes, reduction='mean')
        return box_loss

    def get_obj_loss(self, predictions, targets):
        """
        Sum over levels of the mean weighted focal objectness loss.

        Targets are smoothed to 0.05 (empty) / 0.95 (object). They are then raised to 1.0
        wherever the plain IoU between the predicted and target box exceeds
        ``iou_threshold``, regardless of target objectness. Anchors whose final target is
        > 0.5 are weighted 8.0; all others are weighted 0.5.
        """
        obj_loss = 0.0
        pos_weight = 8.0
        neg_weight = 0.5
        smooth_factor = 0.1
        for pi, ti in zip(predictions, targets):
            pred_obj = pi[..., 4]
            target_obj = (1 - smooth_factor) * ti[..., 4] + smooth_factor * 0.5
            # Soft objectness targets based on predicted IoU
            pred_box = pi[..., :4]
            target_box = ti[..., :4]
            ious = self.bbox_iou(pred_box, target_box)
            soft_target_obj = torch.where(ious > self.iou_threshold, torch.ones_like(ious), target_obj)
            # Calculate binary cross-entropy with Focal Loss
            obj_loss_per_anchor = self.focal_loss(pred_obj, soft_target_obj)
            # Weight positive and negative samples differently
            pos_mask = (soft_target_obj > 0.5).float()
            neg_mask = (soft_target_obj <= 0.5).float()
            weighted_loss = pos_weight * pos_mask * obj_loss_per_anchor + neg_weight * neg_mask * obj_loss_per_anchor
            obj_loss += torch.mean(weighted_loss)
        return obj_loss

    def get_cls_loss(self, predictions, targets):
        """Sum over levels of mean focal class loss (label-smoothed targets); int 0 when num_classes <= 1."""
        cls_loss = 0
        if self.num_classes > 1:
            for pi, ti in zip(predictions, targets):
                pred_cls = pi[..., 5:]
                target_cls = ti[..., 5:]
                # Smoothing is skipped when pred already equals target, so focal_loss's zero shortcut still fires
                if not torch.allclose(pred_cls.float(), target_cls.float(), atol=self.eps):
                    target_cls = (1 - self.label_smoothing) * target_cls + self.label_smoothing / self.num_classes
                cls_loss += torch.mean(self.focal_loss(pred_cls, target_cls))
        return cls_loss

    def forward(self, predictions, targets):
        """
        Split the flat ``targets`` (B, coarse + medium + fine) into per-level grids and return
        the weighted total loss. The individual components are printed if the total is NaN.
        """
        b_size = targets.size(0)
        coarse_size = 7 * 7 * self.num_anchors * (5 + self.num_classes)
        medium_size = 14 * 14 * self.num_anchors * (5 + self.num_classes)
        fine_size = 28 * 28 * self.num_anchors * (5 + self.num_classes)

        flat_coarse, flat_medium, flat_fine = torch.split(targets, [coarse_size, medium_size, fine_size], dim=1)
        t_coarse = flat_coarse.view(b_size, 7, 7, self.num_anchors, (5 + self.num_classes))
        t_medium = flat_medium.view(b_size, 14, 14, self.num_anchors, (5 + self.num_classes))
        t_fine = flat_fine.view(b_size, 28, 28, self.num_anchors, (5 + self.num_classes))

        targets_split = [t_coarse, t_medium, t_fine]
        
        box_loss = self.get_box_loss(predictions, targets_split)
        obj_loss = self.get_obj_loss(predictions, targets_split)
        cls_loss = self.get_cls_loss(predictions, targets_split)
        
        total_loss = self.lambda_coord * box_loss + self.lambda_noobj * obj_loss + self.lambda_class * cls_loss

        if torch.isnan(total_loss):
            print(f'total_loss={total_loss} box_loss={box_loss} obj_loss={obj_loss} cls_loss={cls_loss}')

        return total_loss

---

### Data + Loss Sanity Check

In [7]:
# Loss with default weights, plus the training split for quick data/loss checks.
sanity_criterion = YoLoss()
sanity_dataset = ICDAR2015(train_path, train_labels, num_classes=num_classes)

In [None]:
# Assuming train_path, train_labels, and num_classes are already defined
img, label = sanity_dataset[random.randint(0, len(sanity_dataset) - 1)]
_, labem = sanity_dataset[1]
imgb, labelb = sanity_dataset.get_batch(2)

print('Image:', img.shape, '\t\tLabel:', label.shape)
print('Image Batch:', imgb.shape, '\tLabel:', labelb.shape)

def get_loss(label_a, label_b, title='', grid_sizes=((7, 7), (14, 14), (28, 28))):
    """
    Print ``sanity_criterion``'s loss when the flat label ``label_a`` is fed in as a prediction.

    ``label_a`` is split into one chunk per ``(rows, cols)`` entry of ``grid_sizes``. Each
    chunk is reshaped to ``(1, rows, cols, num_anchors, 5 + num_classes)``. ``label_b``
    receives a leading batch dimension. The notebook-level ``num_anchors``, ``num_classes``
    and ``sanity_criterion`` are used. ``grid_sizes`` defaults to an immutable tuple.
    """
    depth = 5 + num_classes
    sizes = [rows * cols * num_anchors * depth for rows, cols in grid_sizes]
    levels = torch.split(label_a, sizes, dim=0)
    predictions = [flat.view(rows, cols, num_anchors, depth).unsqueeze(0)
                   for flat, (rows, cols) in zip(levels, grid_sizes)]
    print(title, sanity_criterion(predictions, label_b.unsqueeze(0)), '\n', '-' * 50)

def calculate_iou(box1, box2):
    """
    Return the intersection-over-union of two ``(x, y, w, h)`` boxes.

    ``(x, y)`` is the top-left corner. Returns 0 when the union area is not positive.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right = min(box1[0] + box1[2], box2[0] + box2[2])
    bottom = min(box1[1] + box1[3], box2[1] + box2[3])
    overlap = max(0, right - left) * max(0, bottom - top)
    union = box1[2] * box1[3] + box2[2] * box2[3] - overlap
    if union > 0:
        return overlap / union
    return 0

def show_image_with_bboxes(img, label, num_anchors=3, num_classes=1, grid_sizes=((7, 7), (14, 14), (28, 28)), anchors=anchors):
    """
    Plot ``img`` with every ground-truth box encoded in the flat ``label`` tensor.

    ``label`` is split into one grid per ``(rows, cols)`` entry of ``grid_sizes``. Each
    grid is indexed ``[row, col, anchor, field]``. Every anchor slot with objectness > 0
    is decoded back to pixels:

    * centre = ``(cell + offset) / grid * image_size``;
    * size = ``stored_wh * anchor_wh * image_size``.

    All occupied slots are drawn, so a cell holding boxes in several anchors shows all of them.
    """
    pil_img = TF.to_pil_image(img)
    _, ax = plt.subplots(1)
    ax.imshow(pil_img)
    ax.set_xticks([])
    ax.set_yticks([])
    depth = 5 + num_classes
    sizes = [rows * cols * num_anchors * depth for rows, cols in grid_sizes]
    for flat, (rows, cols) in zip(torch.split(label, sizes, dim=0), grid_sizes):
        grid = flat.view(rows, cols, num_anchors, depth)
        for gy in range(rows):
            for gx in range(cols):
                for a in range(num_anchors):
                    box = grid[gy, gx, a]
                    if box[4] <= 0:
                        continue  # empty anchor slot
                    anchor_w, anchor_h = anchors[a]
                    w = box[2].item() * anchor_w * pil_img.width
                    h = box[3].item() * anchor_h * pil_img.height
                    cx = (box[0].item() + gx) / cols * pil_img.width
                    cy = (box[1].item() + gy) / rows * pil_img.height
                    ax.add_patch(patches.Rectangle((cx - w / 2, cy - h / 2), w, h,
                                                   linewidth=1, edgecolor='g', facecolor='none'))
    plt.show()

# Visual check of the encoded targets, then two loss checks: a label scored against itself
# and against the label of another sample.
show_image_with_bboxes(img, label)
get_loss(label, label, title='Self Loss:')
get_loss(label, labem, title='Random Label Comparison Loss:')

---

## Training

In [9]:
# Training split (re-using the `sanity_dataset` name) is shuffled; the test split keeps a fixed order.
sanity_dataset = ICDAR2015(train_path, train_labels, num_classes)
train_loader = DataLoader(
    sanity_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=ICDAR2015.collate_fn,
    pin_memory=True,
)

val_dataset = ICDAR2015(test_path, test_labels, num_classes)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=ICDAR2015.collate_fn,
    pin_memory=True,
)

<img src="https://pytorch.org/tutorials/_images/pinmem.png" alt="why pin_memory" width="350" height="auto">

In [10]:
def evaluate(model, criterion, data_loader, device):
    """
    Return the mean per-batch loss of ``model`` over ``data_loader``, without gradients.

    Puts ``model`` in eval mode and leaves it there. Forward passes run under
    ``torch.amp.autocast`` for the device *type*, so ``device`` may be a string or a
    ``torch.device``, including indexed ones such as ``'cuda:0'``. An empty loader yields
    ``nan`` instead of a division by zero.
    """
    device_type = torch.device(device).type  # autocast expects 'cuda'/'cpu', not 'cuda:0'
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for images, targets in data_loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            with torch.amp.autocast(device_type=device_type):
                outputs = model(images)
                loss = criterion(outputs, targets)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

def adaptive_gradient_clipping(model, clip_factor=0.01, eps=1e-3):
    """
    Adaptive gradient clipping: cap each gradient's norm relative to its parameter's norm.

    For every parameter that has a gradient, the allowed gradient norm is
    ``clip_factor * (||param|| + eps)``. Gradients above it are rescaled to exactly that
    norm; gradients below it are left untouched (never scaled up). Parameters whose
    ``.grad`` is ``None`` are skipped. Operates in place. Not called by the training loop
    in this notebook, which uses ``clip_grad_norm_``.
    """
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is None:
                continue
            grad_norm = torch.norm(param.grad)
            max_norm = clip_factor * (torch.norm(param) + eps)
            if grad_norm > max_norm:
                param.grad.mul_(max_norm / grad_norm)

In [11]:
model = MobileYOLOv3(
    num_classes=num_classes,
    dropout_rate=dropout_rate,
    anchors=torch.tensor(anchors, dtype=torch.float32, device=device),
).to(device)
criterion = YoLoss()

# AdamW decouples weight decay from the gradient update. The backbone gets a slightly lower LR
# and half the weight decay; each neck/head module is its own group with the defaults below.
base_optimizer = torch.optim.AdamW(
    [{'params': model.backbone.parameters(), 'lr': learning_rate * 0.8, 'weight_decay': weight_decay * 0.5}]
    + [{'params': getattr(model, name).parameters()} for name in (
        'conv_7', 'eca_7', 'det1', 'r_1024_128', 'r_48_128',
        'conv_14', 'eca_14', 'det2', 'r_512_64', 'r_24_64',
        'conv_28', 'eca_28', 'det3',
    )],
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Lookahead keeps slow weights and pulls them toward the fast (AdamW) weights every k steps
optimizer = optim.Lookahead(base_optimizer, k=optim_k, alpha=optim_alpha)

# One-cycle schedule: warm up to 2x the base LR over 30 % of steps, then cosine-anneal down
scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate * 2,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=10.0,
    final_div_factor=10000.0,
)

# Loss scaling guards mixed-precision gradients against underflow; disabled on CPU
scaler = torch.amp.GradScaler(enabled=(str(device) != 'cpu'))

In [None]:
# Expect this to take ~10 minutes on a standard 3060
lossi, losst = [], []
lowsi = float('inf')

# The fine selection of hyperparameters
print(f"{batch_size} | {learning_rate} | {weight_decay} | {num_epochs} | {lr_warmup} | {optim_k} | {optim_alpha} | {warmup_start} | {scheduler_t0} | {scheduler_tmult} | {prune_amount} | {target_architecture} | {num_anchors} | {dropout_rate}")

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    
    for files, targets in train_loader:
        files = files.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        
        with torch.amp.autocast(device_type=str(device)):
            logits = model(files)
            loss = criterion(logits, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
        
    epoch_loss /= len(train_loader)
    lossi.append(epoch_loss)

    t_loss = evaluate(model, criterion, val_loader, device)
    losst.append(t_loss)
    
    gc.collect()
    torch.cuda.empty_cache()

    # Print training and test loss
    print(f'Epoch [{epoch+1:3}/{num_epochs}] | Train: {epoch_loss:8.6f} | Test: {t_loss:8.6f} | LR: {optimizer.param_groups[-1]["lr"]:.6f}')

In [13]:
# Save unaltered model
torch.save(model.state_dict(), f'solid_{model_path}')

# Prune, Quantize
pruned_model = prune_model(model, amount=prune_amount)
lifted_model = lift_pruning(pruned_model)
quantized_model = quantize_model(lifted_model, device)

# Save the quantized model
torch.save(quantized_model.state_dict(), model_path)

---

## Loss Graph

In [None]:
# Train/test loss per recorded epoch. The x-axis follows the recorded history rather than
# num_epochs, so the cell also works after an interrupted run.
plt.figure(figsize=(10, 6))
plt.plot(range(len(lossi)), lossi, label="Training Loss", color='blue', marker='o', linestyle='-', markersize=3)
plt.plot(range(len(losst)), losst, label="Test Loss", color="red", marker='o', linestyle='-', markersize=3)

plt.title('Loss Curves', fontsize=16)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.grid(True)
# Tick every 4 epochs for short runs, but cap at ~25 labels (500 epochs would give 125 overlapping ticks).
plt.xticks(range(0, len(lossi), max(4, math.ceil(len(lossi) / 25))))
plt.legend(loc='upper right')
plt.show();

---

## Evaluate

In [15]:
def load_model(model_class, num_classes, model_path, device='cpu'):
    """
    Load a PyTorch model for inference on the target device, regardless of where it was originally trained.

    The checkpoint may be a bare state dict or a dict holding one under ``'model_state_dict'``.
    Keys are normalized before loading:

    * a leading ``'module.'`` (added by wrappers such as SWA's ``AveragedModel``) is stripped,
      and only as a prefix, so names like ``submodule.weight`` stay intact;
    * pruning reparametrizations ``<name>_orig`` / ``<name>_mask`` are folded back into
      ``<name> = orig * mask``, so pruned weights stay pruned.

    Loading is non-strict; any missing or unexpected keys are reported with ``print``.

    Args:
        model_class: callable that builds the model from ``num_classes``.
        num_classes: forwarded to ``model_class``.
        model_path: checkpoint file written by ``torch.save``.
        device: target device (string or ``torch.device``).

    Returns:
        The model on ``device``, in eval mode.
    """
    if isinstance(device, str):
        device = torch.device(device)

    # Load to CPU first; the model is moved to `device` once the weights are in place.
    # NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=False)

    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']

    # Remove the 'module.' prefix caused by SWA
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    # Fold pruning reparametrizations back into plain tensors
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('_mask') and key[:-len('_mask')] + '_orig' in state_dict:
            continue  # consumed together with the matching *_orig entry
        if key.endswith('_orig'):
            base = key[:-len('_orig')]
            mask = state_dict.get(base + '_mask')
            if mask is not None:
                value = value * mask
            key = base
        new_state_dict[key] = value

    model = model_class(num_classes)
    result = model.load_state_dict(new_state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"load_model: missing keys {result.missing_keys}, unexpected keys {result.unexpected_keys}")
    model = model.to(device)
    model.eval()
    return model

In [None]:
# Reload onto the notebook's device (CPU when no GPU is available) instead of hard-coding 'cuda'.
# NOTE(review): model_path holds the pruned + quantized state dict while load_model builds a
# float model non-strictly; keys of quantized layers presumably won't match — consider
# loading f'solid_{model_path}' instead.
model = load_model(MobileYOLOv3, num_classes, model_path, device=device)
test_dataset = ICDAR2015(test_path, test_labels)

In [22]:
def display_with_boxes(img, outputs):
    """
    Show a single image with the boxes predicted in ``outputs``.

    Args:
        img: image batch of one, ``(1, 3, H, W)``. The batch dim is squeezed and the
            channels are moved last for ``imshow``.
        outputs: the three per-level predictions. After the batch dim is squeezed, each is
            indexed ``[row, col, anchor, field]``. Fields 0-3 are treated as ``(cx, cy, w, h)``
            normalized to the image; field 4 is treated as objectness.

    Boxes with a positive coordinate sum and objectness >= 0.2 are drawn. The x values are
    scaled by the image width and the y values by its height. Uses the notebook-level
    ``num_anchors``.
    """
    _, ax = plt.subplots(1)

    img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()  # (H, W, C)
    ax.imshow(img)
    levels = [outputs[0].squeeze(0).cpu(), outputs[1].squeeze(0).cpu(), outputs[2].squeeze(0).cpu()]
    img_h, img_w = img.shape[0], img.shape[1]

    for target in levels:  # coarse, medium, fine
        for i in range(target.shape[0]):  # Iterate over rows
            for j in range(target.shape[1]):  # Iterate over columns
                for k in range(num_anchors):  # Iterate over anchors
                    box = target[i, j, k, :4]  # (x, y, w, h)
                    obj = target[i, j, k, 4]
                    if box.sum() > 0 and obj >= 0.2:
                        x, y, w, h = box.tolist()
                        x, y, w, h = x * img_w, y * img_h, w * img_w, h * img_h
                        rect = patches.Rectangle((x - w / 2, y - h / 2), w, h, linewidth=1, edgecolor='g', facecolor='none')
                        ax.add_patch(rect)

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()

def visualize_inference(model, dataset, num_images=5):
    """
    Run ``model`` on ``num_images`` randomly chosen samples of ``dataset`` and plot the predictions.

    Inputs are moved to the notebook-level ``device``; inference runs without gradients.
    """
    for _ in range(num_images):
        # randint is inclusive on both ends, so the upper bound must be len - 1
        img, _label = dataset[random.randint(0, len(dataset) - 1)]
        with torch.no_grad():
            img = img.unsqueeze(0).to(device)
            outputs = model(img)
            display_with_boxes(img, outputs)

In [None]:
# Draw predictions for five random test images.
visualize_inference(model, test_dataset, num_images=5)