In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import shapely.geometry as geometry
import matplotlib.pyplot as plt

import matplotlib.animation as animation
from IPython.display import HTML

In [2]:
class BoxPredictor(nn.Module):
    """Small fully connected network that regresses a box as (x, y, w, h).

    The network maps a 4-element input vector to a 4-element output vector
    through one hidden layer of width 8 with a PReLU activation.
    """

    def __init__(self):
        super().__init__()

        # Layers are instantiated in network order (linear, activation, linear)
        # and then wrapped in a single Sequential container.
        hidden_units = 8
        layers = [
            nn.Linear(4, hidden_units),
            nn.PReLU(),
            nn.Linear(hidden_units, 4),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the fully connected stack.

        Args:
            x (torch.Tensor): input whose last dimension has 4 features

        Returns:
            torch.Tensor: output of the stack, 4 values on the last dimension
        """

        box = self.fc(x)
        return box

    def create_box(self, x: float, y: float, w: float, h: float) -> geometry.Polygon:
        """Build a shapely rectangle from a centroid and a size.

        Args:
            x (float): x coordinate of the box centroid
            y (float): y coordinate of the box centroid
            w (float): width of the box
            h (float): height of the box

        Returns:
            geometry.Polygon: axis-aligned rectangle centred on (x, y)
        """

        # Half extents turn the centroid/size form into two opposite corners.
        half_w, half_h = w / 2, h / 2
        lower_left = (x - half_w, y - half_h)
        upper_right = (x + half_w, y + half_h)

        return geometry.box(*lower_left, *upper_right)
    
    
class DIoU(nn.Module):
    """Computes the Distance-Intersection over Union (DIoU) loss.

    Boxes are 1-D tensors laid out as ``(x, y, w, h)``, where ``(x, y)`` is the
    box centroid. The loss is ``1 - IoU + d^2 / c^2`` with ``d`` the distance
    between the two centroids and ``c`` the diagonal of the smallest
    axis-aligned box enclosing both boxes.

    Every intermediate quantity is built from differentiable tensor ops, so
    gradients reach the prediction through the IoU term, the centroid
    distance and the enclosing-box diagonal alike.
    """

    def __init__(self, eps: float = 1e-7):
        """
        Args:
            eps (float): small constant added to both denominators so that a
                zero union or a zero diagonal does not divide by zero
        """
        super().__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> tuple:
        """Compute DIoU loss

        Args:
            pred (torch.Tensor): box predicted by the model, as (x, y, w, h)
            target (torch.Tensor): target box, as (x, y, w, h)

        Returns:
            tuple: ``(diou, iou, distance, diagonal)``, each a 0-dim tensor:
                the DIoU loss, the plain IoU, the centroid distance and the
                diagonal length of the smallest enclosing box
        """

        px, py, pw, ph = pred
        tx, ty, tw, th = target

        # Convert centroid/size form to corner coordinates once.
        p_x1, p_y1 = px - pw / 2, py - ph / 2
        p_x2, p_y2 = px + pw / 2, py + ph / 2
        t_x1, t_y1 = tx - tw / 2, ty - th / 2
        t_x2, t_y2 = tx + tw / 2, ty + th / 2

        # Overlap extents; clamping keeps the result a tensor (value 0) when
        # the boxes are disjoint instead of degrading to a Python int.
        inter_w = torch.clamp(torch.min(p_x2, t_x2) - torch.max(p_x1, t_x1), min=0)
        inter_h = torch.clamp(torch.min(p_y2, t_y2) - torch.max(p_y1, t_y1), min=0)

        intersection = inter_w * inter_h
        union = pw * ph + tw * th - intersection

        iou = intersection / (union + self.eps)

        distance = torch.linalg.norm(pred[:2] - target[:2])

        # Extents of the smallest box enclosing both boxes. torch.stack keeps
        # the autograd graph intact; rebuilding the values with torch.tensor
        # would create a constant and drop the gradient through the diagonal.
        enclose_w = torch.max(p_x2, t_x2) - torch.min(p_x1, t_x1)
        enclose_h = torch.max(p_y2, t_y2) - torch.min(p_y1, t_y1)
        diagonal = torch.linalg.norm(torch.stack([enclose_w, enclose_h]))

        diou = 1 - iou + (distance ** 2) / (diagonal ** 2 + self.eps)

        return diou, iou, distance, diagonal

In [3]:
# Seed torch's RNG before the model is built so the run is repeatable.
torch.manual_seed(777)

# Train

predictor = BoxPredictor()
loss_function = DIoU()
optimizer = optim.Adam(predictor.parameters())
epochs = 150

# Ground-truth box as (x, y, w, h), plus its shapely counterpart for plotting.
label_box = torch.tensor([1, 0, 3, 5], dtype=torch.float32)
label_box_geometry = predictor.create_box(*label_box.tolist())

# The single fixed input fed to the network on every step.
input_box = torch.tensor([8, -7.3, 1, 1], dtype=torch.float32)
input_box_geometry = predictor.create_box(*input_box.tolist())

# Per-step history, filled in by the loop below.
predictions, losses, ious, distances, diagonals = [], [], [], [], []

for epoch in range(epochs):
    # Drop gradients left over from the previous step before the new pass.
    optimizer.zero_grad()

    pred = predictor(input_box)
    diou, iou, distance, diagonal = loss_function(pred, label_box)

    diou.backward()
    optimizer.step()

    # Store a detached copy of the prediction and plain Python scalars.
    predictions.append(pred.detach().clone())
    losses.append(diou.item())
    ious.append(iou.item())
    distances.append(distance.item())
    diagonals.append(diagonal.item())

In [4]:
# Visualize

fig, ax = plt.subplots()

def animate(frame):
    """Redraw the whole axes for one recorded training step.

    Args:
        frame (int): index into the recorded ``predictions`` / ``losses`` /
            ``ious`` / ``distances`` histories

    Returns:
        tuple: always empty; the axes are cleared and rebuilt on every call,
            so there are no persistent artists to hand back
    """
    pred = predictions[frame]
    pred_geometry = predictor.create_box(*pred.tolist())

    # Segment joining the predicted and ground-truth centroids.
    center_link = geometry.LineString([pred_geometry.centroid, label_box_geometry.centroid])

    # ax.clear() resets limits, grid, aspect and legend, so all of them are
    # configured again for each frame.
    ax.clear()
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.grid(True, alpha=0.3)
    ax.plot(*input_box_geometry.exterior.xy, color="blue", label="input", linewidth=1)
    ax.plot(*label_box_geometry.exterior.xy, color="green", label="ground-truth", linewidth=1)
    ax.plot(*pred_geometry.exterior.xy, color="black", label="predicted", linewidth=1)
    ax.plot(*center_link.xy, color="red", label="distance", linewidth=1, linestyle="dotted")
    ax.set_aspect("equal")
    ax.legend(loc="upper left")
    ax.set_title(f"Epoch: {frame + 1}, DIoU: {losses[frame]:.5f}, IoU: {ious[frame]:.5f}, d: {distances[frame]:.5f} \n", fontsize=9)

    return ()

# blit=False: the callback rebuilds the full axes and returns no artists, so
# there is nothing for blitting to redraw. The frame count comes from the
# recorded history so it cannot drift from what the training loop stored.
anim = animation.FuncAnimation(fig, animate, frames=len(predictions), interval=50, blit=False, repeat=False)
plt.close(fig)
HTML(anim.to_jshtml())