# YOLOv1 implementation in PyTorch

## Oshri Fatkiev

In [None]:
import torch
import os
import glob
import cv2

from tqdm import tqdm
from datetime import datetime
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter

from torchvision.ops import nms, box_iou, distance_box_iou
import torch.nn as nn
from torch.nn import functional as F

import matplotlib.pyplot as plt
import numpy as np

### Hyperparameters

In [None]:
# --- Training schedule ---
BATCH_SIZE = 64              # Samples per DataLoader batch
EPOCHS = 100                 # Training epochs (run in addition to WARMUP_EPOCHS)
WARMUP_EPOCHS = 0            # Extra epochs prepended to the training loop

# --- Optimiser ---
LEARNING_RATE = 1e-4         # Learning rate handed to the optimiser
MOMENTUM = 0.9               # Referenced only by the commented-out SGD optimiser in the visible source
WEIGHT_DECAY = 5e-4          # Referenced only by the commented-out SGD optimiser in the visible source

EPSILON = 1e-6               # Small constant: stands in for zero IoU unions and keeps sqrt() away from 0
IMAGE_SIZE = (448,448)       # Passed to cv2.resize as dsize; [0] scales x/width, [1] scales y/height

# --- Augmentation ---
FLIPUD = 0.5                 # Probability of an up-down flip
FLIPLR = 0.5                 # Probability of a left-right flip

# --- Grid layout of the target / prediction tensor: (S, S, 5 * B + C) ---
S = 7                        # Grid cells per image side
B = 2                        # Bounding-box slots per grid cell
C = 1                        # Number of classes (length of the one-hot class vector)

In [None]:
class Albumentations:
    """Optional augmentation pipeline: a 90-degree rotation followed by random flips.

    The pipeline is built with the ``albumentations`` package when it can be
    imported. If the package is missing, ``self.transform`` stays ``None`` and
    calling the object returns its inputs untouched.
    """

    def __init__(self):
        self.transform = None
        try:
            # Imported here because no top-level import binds the name `A`;
            # without this the constructor raised NameError.
            import albumentations as A
        except ImportError:
            return

        T = [
            A.RandomRotate90(p=1),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
        ]
        self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

    def __call__(self, img, labels, p=1.0):
        """Augment ``img`` and its boxes with probability ``p``.

        Args:
            img: Image array forwarded to the pipeline as ``image``.
            labels: 2-D array whose column 0 is the class label and whose
                remaining columns are the YOLO-format box.
            p: Probability of applying the pipeline at all.

        Returns:
            ``(img, labels)``; the inputs themselves when nothing is applied.
        """
        if self.transform is not None and np.random.random() < p:
            new = self.transform(image=img, bboxes=labels[:, 1:], class_labels=labels[:, 0])
            rows = [[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])]
            # Keep the result 2-D even when every box was dropped by the transform.
            labels = np.array(rows) if rows else np.zeros((0, labels.shape[1]))
            img = new['image']
        return img, labels
    

### Dataset

In [None]:
class YoloDataset(Dataset):
    """Pairs ``.npy`` images with YOLO-format ``.txt`` labels and encodes grid targets.

    Every item is ``(img, ground_truth, original_img)``:

    * ``img`` - the (optionally augmented) image, channels first.
    * ``ground_truth`` - an ``(S, S, 5 * B + C)`` tensor: ``C`` one-hot class
      entries followed by ``B`` slots of ``(x, y, width, height, confidence)``.
    * ``original_img`` - the image before augmentation, channels first.
    """

    def __init__(self, images_path, labels_path, augment=False):
        """Index every label file under ``labels_path`` that has a matching image.

        Args:
            images_path: Directory holding ``<stem>.npy`` images.
            labels_path: Directory holding ``<stem>.txt`` label files.
            augment: When true, apply random 90-degree rotations and flips.
        """
        self.images, self.labels = [], []
        for label in sorted(glob.glob(f"{labels_path}/*.txt")):
            im = self._find_image(images_path, label)
            if im is not None:
                self.labels.append(label)
                self.images.append(im)

        self.augment = augment
        self.classes = {0: 0}       # Class id found in the label file -> one-hot index

    @staticmethod
    def _find_image(images_path, label):
        """Return the image path belonging to ``label``, or ``None`` if there is none."""
        stem = os.path.splitext(os.path.basename(label))[0]
        candidates = (
            os.path.join(images_path, stem + '.npy'),
            # Legacy lookup kept for callers that relied on the path substitution.
            label.replace('labels', 'images').replace('txt', 'npy'),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def __getitem__(self, i):
        img = np.load(self.images[i])
        # img.shape is (height, width, channels) while IMAGE_SIZE is used as
        # cv2's (width, height); compare only the spatial part so images that
        # already have the target size are not resized again.
        if img.shape[:2] != (IMAGE_SIZE[1], IMAGE_SIZE[0]):
            img = cv2.resize(img, dsize=IMAGE_SIZE, interpolation=cv2.INTER_CUBIC)

        # Rows are (class, x, y, width, height). atleast_2d covers single-row
        # files; an empty file becomes a (0, 5) tensor instead of a malformed row.
        labels = torch.tensor(np.atleast_2d(np.loadtxt(self.labels[i])))
        if labels.numel() == 0:
            labels = labels.reshape(0, 5)

        original_img = img
        if self.augment:
            nl = len(labels)

            # One to four counter-clockwise quarter turns (four is the identity).
            n_rots = np.random.randint(low=1, high=5)
            for _ in range(n_rots):
                img = np.rot90(img)
                if nl:
                    labels_rot = labels.clone()                  # Keeps the class column intact
                    labels_rot[..., 1] = labels[..., 2]          # Rotate x
                    labels_rot[..., 2] = 1 - labels[..., 1]      # Rotate y
                    labels_rot[..., 3] = labels[..., 4]          # Rotate width
                    labels_rot[..., 4] = labels[..., 3]          # Rotate height
                    labels = labels_rot

            # Flip up-down
            if np.random.random() < FLIPUD:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if np.random.random() < FLIPLR:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

        # rot90/flip return negatively-strided views; make them contiguous for torch.
        img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
        original_img = torch.from_numpy(np.ascontiguousarray(original_img)).permute(2, 0, 1)

        grid_size_x = img.size(dim=2) / S  # Images in PyTorch have size (channels, height, width)
        grid_size_y = img.size(dim=1) / S

        # Process bounding boxes into the SxSx(5*B+C) ground truth tensor
        boxes = {}                          # cell -> number of bbox slots already filled
        class_names = {}                    # cell -> class assigned to that cell
        ground_truth = torch.zeros((S, S, 5 * B + C))

        for name, x, y, width, height in labels.tolist():
            assert name in self.classes, f"Unrecognized class '{name}'"
            class_index = self.classes[name]

            # Pixel position of the box centre and the grid cell containing it
            mid_x = x * IMAGE_SIZE[0]
            mid_y = y * IMAGE_SIZE[1]
            col = int(mid_x // grid_size_x)
            row = int(mid_y // grid_size_y)
            if not (0 <= col < S and 0 <= row < S):
                continue

            cell = (row, col)
            # A cell holds a single class; boxes of another class are skipped.
            if cell in class_names and name != class_names[cell]:
                continue

            one_hot = torch.zeros(C)
            one_hot[class_index] = 1.0
            ground_truth[row, col, :C] = one_hot
            class_names[cell] = name

            bbox_index = boxes.get(cell, 0)
            if bbox_index >= B:
                continue
            bbox_truth = (
                (mid_x - col * grid_size_x) / IMAGE_SIZE[0],    # x offset inside the cell, image-normalised
                (mid_y - row * grid_size_y) / IMAGE_SIZE[1],    # y offset inside the cell, image-normalised
                width,
                height,
                1.0,                                            # Confidence
            )

            # Fill this slot and every later one with the current box, leaving
            # earlier slots alone, so no all-zero box remains to disturb the
            # IoU computation in the loss.
            bbox_start = 5 * bbox_index + C
            ground_truth[row, col, bbox_start:] = torch.tensor(bbox_truth).repeat(B - bbox_index)
            boxes[cell] = bbox_index + 1

        return img, ground_truth, original_img

    def __len__(self):
        return len(self.images)
    

### Loss

In [None]:
def bbox_attr(data, i):
    """Select attribute number ``i`` of every bounding box held in ``data``.

    The first ``C`` entries of the last axis are skipped; from there on every
    fifth entry is taken, starting at offset ``i``.
    """
    return data[..., slice(C + i, None, 5)]

def xywh2xyxy(t):
    """Convert centre/size boxes into a (top-left, bottom-right) pair of corner tensors.

    Each returned tensor gains a trailing axis of length 2 holding (x, y).
    """
    half_w = bbox_attr(t, 2) / 2.0
    half_h = bbox_attr(t, 3) / 2.0
    cx = bbox_attr(t, 0)
    cy = bbox_attr(t, 1)

    top_left = torch.stack((cx - half_w, cy - half_h), dim=4)
    bottom_right = torch.stack((cx + half_w, cy + half_h), dim=4)
    return top_left, bottom_right

def get_iou(p, a):
    """Pairwise IoU between the boxes of ``p`` and the boxes of ``a`` in each grid cell.

    The result has shape (batch, S, S, B, B): axis 3 indexes the boxes of ``p``
    and axis 4 the boxes of ``a``.
    """
    pair_shape = (-1, -1, -1, B, B, 2)

    # Corners broadcast to every (p box, a box) pairing: p along axis 3, a along axis 4.
    p_tl, p_br = (corner.unsqueeze(4).expand(pair_shape) for corner in xywh2xyxy(p))
    a_tl, a_br = (corner.unsqueeze(3).expand(pair_shape) for corner in xywh2xyxy(a))

    # The overlap runs from the larger top-left to the smaller bottom-right;
    # a negative extent means the boxes do not intersect.
    overlap = torch.clamp(torch.min(p_br, a_br) - torch.max(p_tl, a_tl), min=0.0)
    intersection = overlap[..., 0] * overlap[..., 1]                 # (batch, S, S, B, B)

    p_area = (bbox_attr(p, 2) * bbox_attr(p, 3)).unsqueeze(4).expand_as(intersection)
    a_area = (bbox_attr(a, 2) * bbox_attr(a, 3)).unsqueeze(3).expand_as(intersection)
    union = p_area + a_area - intersection

    # Guard the division: wherever the union is exactly zero report an IoU of 0.
    degenerate = union == 0.0
    union[degenerate] = EPSILON
    intersection[degenerate] = 0.0

    return intersection / union

In [None]:
class YoloLoss(nn.Module):
    """Sum-of-squared-errors YOLOv1 loss, averaged over the batch.

    Both inputs are laid out as ``(batch, S, S, 5 * B + C)``: ``C`` class
    scores followed by ``B`` boxes of ``(x, y, width, height, confidence)``.
    """

    def __init__(self):
        super().__init__()
        self.l_coord = 5        # Weight of the localisation (position + size) terms
        self.l_noobj = 0.5      # Weight of the confidence term for non-responsible boxes

    @staticmethod
    def _signed_sqrt(t):
        """Square root of ``|t|`` (offset by EPSILON) carrying the sign of ``t``."""
        return torch.sign(t) * torch.sqrt(torch.abs(t) + EPSILON)

    def forward(self, p, a):
        """Return the scalar loss of predictions ``p`` against ground truth ``a``."""
        # IoU of each predicted bbox against each ground-truth bbox of the same cell
        iou = get_iou(p, a)                     # (batch, S, S, B, B)
        max_iou = torch.max(iou, dim=-1)[0]     # (batch, S, S, B): best IoU per predicted box

        # Masks
        bbox_mask = bbox_attr(a, 4) > 0.0
        p_template = bbox_attr(p, 4) > 0.0      # Only used as a bool template of the right shape
        obj_i = bbox_mask[..., 0:1]             # True if the cell holds any object at all
        responsible = torch.zeros_like(p_template).scatter_(       # (batch, S, S, B)
            -1,
            torch.argmax(max_iou, dim=-1, keepdim=True),
            value=1                             # The predicted box with the highest IoU is "responsible"
        )
        obj_ij = obj_i * responsible            # Object exists AND this bbox is responsible
        noobj_ij = ~obj_ij                      # Everything else: confidence should be 0

        # XY position losses
        pos_losses = (
            mse_loss(obj_ij * bbox_attr(p, 0), obj_ij * bbox_attr(a, 0))
            + mse_loss(obj_ij * bbox_attr(p, 1), obj_ij * bbox_attr(a, 1))
        )

        # Bbox dimension losses, compared in square-root space. Predictions may
        # be negative, hence the signed square root on the predicted side.
        width_losses = mse_loss(
            obj_ij * self._signed_sqrt(bbox_attr(p, 2)),
            obj_ij * torch.sqrt(bbox_attr(a, 2))
        )
        height_losses = mse_loss(
            obj_ij * self._signed_sqrt(bbox_attr(p, 3)),
            obj_ij * torch.sqrt(bbox_attr(a, 3))
        )
        dim_losses = width_losses + height_losses

        # Confidence losses. The target for a responsible box is 1 (not its IoU).
        obj_confidence_losses = mse_loss(
            obj_ij * bbox_attr(p, 4),
            obj_ij * torch.ones_like(max_iou)
        )
        noobj_confidence_losses = mse_loss(
            noobj_ij * bbox_attr(p, 4),
            torch.zeros_like(max_iou)
        )

        # Classification losses, only for cells that contain an object
        class_losses = mse_loss(
            obj_i * p[..., :C],
            obj_i * a[..., :C]
        )

        total = self.l_coord * (pos_losses + dim_losses) \
                + obj_confidence_losses \
                + self.l_noobj * noobj_confidence_losses \
                + class_losses

        # Average over the batch actually received, so batches of a size other
        # than the configured BATCH_SIZE are weighted correctly.
        return total / p.size(0)


def mse_loss(a, b):
    """Summed squared error between ``a`` and ``b``.

    All leading axes of both tensors are collapsed into one, and ``b`` is then
    expanded to the collapsed shape of ``a`` before the errors are summed.
    """
    lhs = a.flatten(end_dim=-2)
    rhs = b.flatten(end_dim=-2).expand_as(lhs)
    return F.mse_loss(lhs, rhs, reduction='sum')

### Model

In [None]:
class YOLOv1(nn.Module):
    """YOLOv1-style detector: a convolutional stack followed by two linear layers.

    The network maps a ``(batch, n_channels, height, width)`` input to a
    ``(batch, S, S, 5 * B + C)`` prediction tensor. The first linear layer
    expects an ``S x S`` map with 1024 channels from the convolutional part.
    """

    def __init__(self, n_channels=6):
        """Build the network.

        Args:
            n_channels: Number of channels of the input images.
        """
        super().__init__()
        self.depth = B * 5 + C      # Values predicted per grid cell

        # NOTE: the order of the layers fixes the keys of the state dict; keep
        # it unchanged so previously saved weights still load.
        layers = [
            nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(192, 128, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(256, 256, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ]

        # Four 1x1 reduction + 3x3 expansion pairs at 512 channels
        for _ in range(4):
            layers += [
                nn.Conv2d(512, 256, kernel_size=1),
                nn.Conv2d(256, 512, kernel_size=3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            ]

        layers += [
            nn.Conv2d(512, 512, kernel_size=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ]

        # Two 1x1 reduction + 3x3 expansion pairs at 1024 channels
        for _ in range(2):
            layers += [
                nn.Conv2d(1024, 512, kernel_size=1),
                nn.Conv2d(512, 1024, kernel_size=3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            ]

        layers += [
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
        ]

        for _ in range(2):
            layers += [
                nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            ]

        # Detection head
        layers += [
            nn.Flatten(),
            nn.Linear(S * S * 1024, 1024),
            nn.Dropout(p=0.5),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(1024, S * S * self.depth),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` of shape [batch_size, n_channels, height, width].

        Returns:
            Tensor of shape (batch_size, S, S, self.depth).
        """
        # as_tensor converts arrays and casts to float32 while keeping the
        # device of tensor inputs. The legacy torch.Tensor(x) constructor can
        # reject tensors that are not already CPU float32 (e.g. CUDA batches).
        x = torch.as_tensor(x, dtype=torch.float)
        res = self.model(x)         # Call the module (not .forward) so hooks run
        return torch.reshape(res, (x.size(dim=0), S, S, self.depth))

In [None]:
# Smoke test: a random batch must come out as (32, S, S, 5 * B + C).
model = YOLOv1(n_channels=6)
x = torch.randn((32, 6, 448, 448))
with torch.no_grad():       # Only the shape is inspected, so skip building the autograd graph
    res = model(x)
res.size()

### Train

In [None]:
# Hand cached-but-unused GPU memory back to the driver before the training cell runs.
torch.cuda.empty_cache()

In [None]:
# %time

# Fix the seeds of every random number generator used by the data pipeline and the model
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.autograd.set_detect_anomaly(True)         # Check for nan loss
writer = SummaryWriter()
now = datetime.now()

# NOTE(review): YOLOv1Lite8 is not defined in the visible source - confirm it is
# provided elsewhere in the notebook (only YOLOv1 is defined here).
model = YOLOv1Lite8(n_channels=6).to(device)
loss_function = YoloLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=utils.scheduler_lambda)

folder = 'mocks_6d_crop_norm_64_0to255_VELA6_toy_clumps' #_v2 # 'vela3_6d_crop_64_0to255' # /128

# Load the dataset
images_train = f'/sci/labs/dekel/oshri.fatkiev/{folder}/images/train'
labels_train = f'/sci/labs/dekel/oshri.fatkiev/{folder}/labels/train'
train_set = YoloDataset(images_path=images_train, labels_path=labels_train, augment=True)

images_test = f'/sci/labs/dekel/oshri.fatkiev/{folder}/images/val'
labels_test = f'/sci/labs/dekel/oshri.fatkiev/{folder}/labels/val'
# Validation data is left un-augmented so the validation loss is deterministic
# and comparable from one epoch to the next.
val_set = YoloDataset(images_path=images_test, labels_path=labels_test, augment=False)

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    num_workers=4,
    persistent_workers=True,
    drop_last=True,
    shuffle=True
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    num_workers=4,
    persistent_workers=True,
    drop_last=True
)

# Create folders: models/yolo_v1/<date>/<time>/weights
root = os.path.join('models', 'yolo_v1', now.strftime('%m_%d_%Y'), now.strftime('%H_%M_%S'))
weight_dir = os.path.join(root, 'weights')
os.makedirs(weight_dir, exist_ok=True)

# Metrics, stored as 2 x N arrays: row 0 holds the epoch, row 1 the value
train_losses, train_errors = np.empty((2, 0)), np.empty((2, 0))
val_losses, val_errors = np.empty((2, 0)), np.empty((2, 0))

def save_metrics():
    """Write the current metric arrays into the run folder as .npy files."""
    np.save(os.path.join(root, 'train_losses'), train_losses)
    np.save(os.path.join(root, 'val_losses'), val_losses)
    np.save(os.path.join(root, 'train_errors'), train_errors)
    np.save(os.path.join(root, 'val_errors'), val_errors)


l_trn, l_val = [], []
for epoch in tqdm(range(WARMUP_EPOCHS + EPOCHS), desc='Epoch'):
    model.train()
    train_loss = 0
    for data, labels, _ in tqdm(train_loader, desc='Train', position=0, leave=True, colour='green'):
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        predictions = model(data)
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()

        # Running mean of the batch losses over the epoch
        train_loss += loss.item() / len(train_loader)
        del data, labels, predictions

    # Step and graph scheduler once an epoch
    # writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], epoch)
    # scheduler.step()

    train_losses = np.append(train_losses, [[epoch], [train_loss]], axis=1)
    writer.add_scalar('Loss/train', train_loss, epoch)

    l_trn.append(train_loss)

    # Validate, log and checkpoint the metrics every fifth epoch
    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for data, labels, _ in tqdm(val_loader, desc='Val', position=0, leave=True):
                data = data.to(device)
                labels = labels.to(device)

                predictions = model(data)
                loss = loss_function(predictions, labels)

                val_loss += loss.item() / len(val_loader)
                del data, labels, predictions
        val_losses = np.append(val_losses, [[epoch], [val_loss]], axis=1)
        writer.add_scalar('Loss/val', val_loss, epoch)
        save_metrics()

        print(f'\nepoch: {epoch}/{EPOCHS} Loss/train: {train_loss:.3f}, Loss/val: {val_loss:.3f}')
        l_val.append(val_loss)

save_metrics()
torch.save(model.state_dict(), os.path.join(weight_dir, 'final'))
writer.close()      # Flush pending TensorBoard events to disk

### Evaluate

In [None]:
def calculate_precision_recall_with_nms(true_boxes, pred_boxes, scores, iou_thres=0.5, nms_thres=0.3):
    """Precision and recall of one image's predictions after non-maximum suppression.

    Args:
        true_boxes: Ground-truth boxes as (x1, y1, x2, y2) rows; may be empty.
        pred_boxes: Predicted boxes as (x1, y1, x2, y2) rows; may be empty.
        scores: One score per predicted box, used to rank boxes during NMS.
        iou_thres: Minimum overlap score for a prediction to match a true box.
        nms_thres: IoU threshold handed to NMS.

    Returns:
        ``(precision, recall)`` as Python floats. Precision is the fraction of
        kept predictions that match at least one true box; recall is the
        fraction of true boxes matched by at least one kept prediction.
        Conventions for empty inputs: nothing predicted and nothing to find
        gives (1.0, 1.0); nothing predicted but objects present gives
        (0.0, 0.0); predictions without any true box give (0.0, 1.0).
    """
    # reshape keeps empty inputs (e.g. a 1-D empty tensor) valid as (0, 4) box lists
    true_boxes = true_boxes.float().reshape(-1, 4)
    pred_boxes = pred_boxes.float().reshape(-1, 4)
    n_true = true_boxes.size(0)

    if pred_boxes.size(0) == 0:
        return (1.0, 1.0) if n_true == 0 else (0.0, 0.0)

    # Apply NMS to predicted boxes
    keep_indices = nms(pred_boxes, scores=scores.float().reshape(-1), iou_threshold=nms_thres)
    pred_boxes = pred_boxes[keep_indices]

    if n_true == 0:
        # Every kept prediction is a false positive and no object was missed.
        return 0.0, 1.0

    # Row i, column j holds the score between pred[i, :] and true[j, :].
    # NOTE(review): distance_box_iou returns Distance-IoU, which is never larger
    # than plain IoU - confirm that DIoU rather than box_iou is intended here.
    iou_matrix = distance_box_iou(pred_boxes, true_boxes)
    hits = iou_matrix >= iou_thres

    # Count each prediction and each true box once, however many partners it
    # has, so that several overlaps cannot inflate the number of true positives.
    pred_matched = hits.any(dim=1)
    true_matched = hits.any(dim=0)

    precision = pred_matched.float().mean()
    recall = true_matched.float().mean()

    return precision.item(), recall.item()


def get_bboxes_xyxy(pred, grid_size_x, grid_size_y, conf_thres=0.2):
    """Decode one grid tensor into a list of corner-format boxes in pixels.

    Args:
        pred: Tensor of shape (rows, cols, C + 5 * n_boxes) for a single image:
            ``C`` class scores followed by (x, y, width, height, confidence)
            per box slot.
        grid_size_x: Width of one grid cell in pixels.
        grid_size_y: Height of one grid cell in pixels.
        conf_thres: Boxes whose class score times box confidence is below this
            value are dropped.

    Returns:
        Float tensor of shape (N, 6) with rows
        ``(x1, y1, x2, y2, confidence, class_index)``. Exact duplicates are
        removed and ``N`` is 0 when no box passes the threshold.
    """
    n_rows = pred.size(dim=0)
    n_cols = pred.size(dim=1)
    n_boxes = (pred.size(dim=2) - C) // 5

    bboxes = []
    seen = set()        # O(1) duplicate detection; `bboxes` keeps first-seen order
    for i in range(n_rows):
        for j in range(n_cols):
            cell = pred[i, j]
            class_index = torch.argmax(cell[:C]).item()
            class_score = cell[class_index].item()
            for k in range(n_boxes):
                bbox_start = 5 * k + C
                x, y, w, h, box_conf = cell[bbox_start:bbox_start + 5].tolist()
                confidence = class_score * box_conf          # pr(c) * IOU
                if confidence < conf_thres:
                    continue

                # The centre is stored as an image-normalised offset inside
                # cell (i, j); width and height are image-normalised too.
                cx = x * IMAGE_SIZE[0] + j * grid_size_x
                cy = y * IMAGE_SIZE[1] + i * grid_size_y
                half_w = w * IMAGE_SIZE[0] / 2
                half_h = h * IMAGE_SIZE[1] / 2
                b = (cx - half_w, cy - half_h, cx + half_w, cy + half_h, confidence, class_index)
                if b not in seen:
                    seen.add(b)
                    bboxes.append(b)

    # The reshape keeps the (N, 6) layout even for N == 0, so callers can
    # always slice columns.
    return torch.tensor(bboxes, dtype=torch.float32).reshape(-1, 6)


In [None]:
%matplotlib inline

np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

classes = [0]

images_test = '/sci/labs/dekel/oshri.fatkiev/vela3_6d_crop_64_0to255/images/test'
labels_test = '/sci/labs/dekel/oshri.fatkiev/vela3_6d_crop_64_0to255/labels/test'

dataset = YoloDataset(images_path=images_test, labels_path=labels_test)
dataloader = DataLoader(dataset, batch_size=16)

# NOTE(review): YOLOv1Lite8 is not defined in the visible source - confirm it is
# provided elsewhere in the notebook.
model = YOLOv1Lite8(n_channels=6)
model.eval()
# The model stays on the CPU in this cell, so map the stored weights to the CPU
# as well; otherwise weights saved from a GPU run fail to load without CUDA.
model.load_state_dict(torch.load(os.path.join(weight_dir, 'final'), map_location=torch.device('cpu')))

precision, recall, count = 0, 0, 0

with torch.no_grad():
    for image, labels, original in tqdm(dataloader, colour='green'):

        res = model(image)
        batch_size = image.size(dim=0)

        # The cell size is the same for every image of the batch
        grid_size_x = image.size(dim=3) / S
        grid_size_y = image.size(dim=2) / S

        for idx in range(batch_size):
            pred_bboxes = get_bboxes_xyxy(res[idx, ...], grid_size_x, grid_size_y, conf_thres=0.2)
            # Ground-truth boxes carry confidence 1, so a threshold of 1 keeps exactly those
            true_bboxes = get_bboxes_xyxy(labels[idx, ...], grid_size_x, grid_size_y, conf_thres=1)

            if len(pred_bboxes) == 0:
                # Nothing predicted: perfect if there was nothing to find, else a total miss
                p_temp, r_temp = (1, 1) if len(true_bboxes) == 0 else (0, 0)
            elif len(true_bboxes) == 0:
                # Only false positives; no object was missed
                p_temp, r_temp = 0, 1
            else:
                p_temp, r_temp = calculate_precision_recall_with_nms(
                    true_bboxes[..., :4],
                    pred_bboxes[..., :4],
                    pred_bboxes[..., 4],        # Column 4 is the confidence; column 5 is the class index
                    iou_thres=0.3,
                    nms_thres=0.5
                )

            precision += p_temp
            recall += r_temp
            count += 1

if count:
    print(f'Avg. Precision: {precision/count:.3f}, Avg. Recall: {recall/count:.3f}, based on {count} images')
else:
    print('No images were evaluated')