In [4]:
"""
CS 4391 Homework 5 Programming
Run this script for YOLO training
"""
import torch
import torch.utils.data as data
import os, math
import sys
import time
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from data import CrackerBox
from model import YOLO
from loss import compute_loss


# plot losses
def plot_losses(losses, filename='train_loss.pdf'):
    """Plot the mean training loss of every epoch and save it to ``filename``.

    Args:
        losses: 2-D array-like of shape (num_epochs, iters_per_epoch) holding
            the loss recorded at each iteration; every row is averaged into a
            single point of the curve.
        filename: path the figure is saved to.
    """
    losses = np.asarray(losses)
    num_epochs = losses.shape[0]
    mean_losses = losses.mean(axis=1)

    # Draw on a dedicated figure so artists left on the current pyplot figure
    # (e.g. from an earlier notebook cell) cannot leak into the saved plot.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(range(num_epochs), mean_losses, marker='o', alpha=0.5, ms=4)
    ax.set_title('Loss')
    ax.set_xlabel('Epoch')

    fig.savefig(filename, bbox_inches='tight')
    print('save training loss plot to %s' % (filename))
    plt.close(fig)


if __name__ == '__main__':

    # hyper-parameters
    # you can tune these for your training
    num_epochs = 2
    batch_size = 8
    learning_rate = 1e-5
    num_workers = 1

    # dataset
    dataset_train = CrackerBox('train')
    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    epoch_size = len(train_loader)  # number of mini-batches per epoch

    # network
    num_classes = 1
    num_boxes = 2
    network = YOLO(num_boxes, num_classes)
    image_size = network.image_size
    grid_size = network.grid_size
    network.train()

    # Optimizer: Adam
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

    # create output directory (no-op when it already exists)
    output_dir = 'checkpoints'
    print('Output will be saved to `{:s}`'.format(output_dir))
    os.makedirs(output_dir, exist_ok=True)

    # per-iteration loss history, one row per epoch; averaged by plot_losses()
    losses = np.zeros((num_epochs, epoch_size), dtype=np.float32)
    for epoch in range(num_epochs):
        for i, sample in enumerate(train_loader):
            image = sample['image']
            gt_box = sample['gt_box']
            gt_mask = sample['gt_mask']

            # forward pass: the network returns a dict whose 'fc_output' and
            # 'pred_box' entries are both consumed by compute_loss
            outputs = network(image)
            fc_output, pred_box = outputs['fc_output'], outputs['pred_box']
            loss = compute_loss(fc_output, pred_box, gt_box, gt_mask,
                                num_boxes, num_classes, grid_size, image_size)

            # optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() gives a plain Python float detached from the autograd
            # graph -- safe both for logging and for the numpy history array
            loss_value = loss.item()
            print('epoch %d/%d, iter %d/%d, lr %.6f, loss %.4f'
                  % (epoch, num_epochs, i, epoch_size, learning_rate, loss_value))
            losses[epoch, i] = loss_value

        # save checkpoint for every epoch (epoch number is 1-based in the name)
        filename = 'yolo_epoch_{:d}'.format(epoch + 1) + '.checkpoint.pth'
        torch.save(network.state_dict(), os.path.join(output_dir, filename))
        print(filename)

    # save the final checkpoint
    filename = 'yolo_final.checkpoint.pth'
    torch.save(network.state_dict(), os.path.join(output_dir, filename))
    print(filename)

    # plot loss
    plot_losses(losses)

"""
CS 4391 Homework 5 Programming
Run this script for YOLO testing
"""
import torch
import torch.utils.data as data
import os, math
import sys
import time
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from data import CrackerBox
from model import YOLO
from voc_eval import voc_eval


# from the network prediction, extract the bounding boxes with confidences larger than threshold
# pred_box: (batch_size, num_boxes * 5 + num_classes, 7, 7), predicted bounding boxes from the network (see the forward() function)
def extract_detections(pred_box, threshold, num_boxes):
    """Gather every predicted box whose confidence is above ``threshold``.

    Only the first image of the batch (``pred_box[0]``) is inspected.

    Args:
        pred_box: tensor laid out as (batch_size, num_boxes * 5 + num_classes,
            H, W); for box slot ``b`` the channels ``5*b .. 5*b+4`` are read as
            (cx, cy, w, h, confidence).
        threshold: strict lower bound applied to the confidence channel.
        num_boxes: number of box slots predicted per grid cell.

    Returns:
        Tensor of shape (K, 5) with rows (x1, y1, x2, y2, confidence). Rows are
        grouped by box slot and, inside a slot, ordered row-major over the grid.
        K may be 0.
    """
    frame = pred_box[0]

    # seed with an empty (0, 5) tensor so the result keeps that shape when
    # nothing passes the threshold (or num_boxes is 0)
    pieces = [torch.zeros((0, 5), dtype=torch.float32, device=pred_box.device)]
    for slot in range(num_boxes):
        base = 5 * slot
        rows, cols = torch.where(frame[base + 4] > threshold)
        # (5, K) -> (K, 5): one row per selected grid cell
        pieces.append(frame[base:base + 5, rows, cols].t())
    centred = torch.cat(pieces, dim=0)

    # centre/size form -> corner form; the confidence column passes through
    cx, cy, w, h, score = centred.unbind(dim=1)
    half_w = w * 0.5
    half_h = h * 0.5
    return torch.stack((cx - half_w, cy - half_h, cx + half_w, cy + half_h, score), dim=1)



# visualize the detections
def visualize(image, gt, detections, show=False):
    """Show ground-truth and predicted boxes side by side for one image.

    The previous version kept its whole body inside a string literal, so it
    never drew anything. That default is preserved: with ``show=False`` the
    function returns immediately, so calling it from an evaluation loop costs
    nothing. Pass ``show=True`` to open one matplotlib window per call.

    Args:
        image: network input batch; only ``image[0]`` is drawn. Assumed to be a
            (C, H, W) BGR image normalised as ``(pixel - pixel_mean) / 255``,
            which is what the de-normalisation below inverts -- TODO confirm
            against CrackerBox.
        gt: ground-truth boxes with rows (x1, y1, x2, y2); only row 0 is drawn.
        detections: tensor of shape (K, 5) with rows (x1, y1, x2, y2, score),
            as produced by extract_detections().
        show: render and display the figure when True.
    """
    if not show:
        return

    pixel_mean = np.array([[[102.9801, 115.9465, 122.7717]]], dtype=np.float32)
    im = image[0].permute(1, 2, 0).detach().cpu().numpy()
    # undo the normalisation; clip before the uint8 cast so out-of-range
    # values saturate instead of wrapping around
    im = np.clip(im * 255.0 + pixel_mean, 0, 255).astype(np.uint8)
    rgb = im[:, :, (2, 1, 0)]  # reverse channel order for imshow

    fig = plt.figure()

    # left panel: ground truth
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(rgb)
    ax.add_patch(patches.Rectangle((gt[0, 0], gt[0, 1]), gt[0, 2] - gt[0, 0], gt[0, 3] - gt[0, 1],
                                   linewidth=2, edgecolor='g', facecolor="none"))
    ax.set_title('ground truth')

    # right panel: detections with their centre and score
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(rgb)
    ax.set_title('prediction')
    # .cpu().tolist() works for CPU and CUDA tensors alike
    for x1, y1, x2, y2, score in detections.detach().cpu().tolist():
        ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       linewidth=2, edgecolor='g', facecolor="none"))
        ax.plot((x1 + x2) / 2, (y1 + y2) / 2, 'ro')
        ax.text(x1, y1, '%.2f' % score, color='y')

    plt.show()
    plt.close(fig)
    



# main function for testing
if __name__ == '__main__':

    # dataset; evaluation needs no shuffling and a fixed order keeps runs
    # reproducible
    dataset = CrackerBox('val')
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    epoch_size = len(data_loader)

    # network
    num_classes = 1
    num_boxes = 2
    network = YOLO(num_boxes, num_classes)

    # load the checkpoint written by the training script; map_location lets a
    # checkpoint saved from a GPU load on a CPU-only machine
    output_dir = 'checkpoints'
    filename = os.path.join(output_dir, 'yolo_final.checkpoint.pth')
    network.load_state_dict(torch.load(filename, map_location='cpu'))
    network.eval()

    # detection threshold on the predicted confidence
    threshold = 0.1

    # main test loop
    results_gt = []
    results_pred = []
    # inference only: no_grad skips building autograd graphs, and the
    # detections handed to voc_eval carry no gradient history
    with torch.no_grad():
        for i, sample in enumerate(data_loader):
            image = sample['image']

            # forward pass
            outputs = network(image)
            pred_box = outputs['pred_box']

            # convert the gt box from its per-cell encoding to pixel corners.
            # Judging by the arithmetic, channels 0-1 hold the centre offset
            # inside the cell (in cell units) and channels 2-3 the width/height
            # as a fraction of the image -- confirm against CrackerBox.
            gt_box = sample['gt_box'][0].numpy()
            gt_mask = sample['gt_mask'][0].numpy()
            y, x = np.where(gt_mask == 1)
            cx = gt_box[0, y, x] * dataset.yolo_grid_size + x * dataset.yolo_grid_size
            cy = gt_box[1, y, x] * dataset.yolo_grid_size + y * dataset.yolo_grid_size
            w = gt_box[2, y, x] * dataset.yolo_image_size
            h = gt_box[3, y, x] * dataset.yolo_image_size
            # one (x1, y1, x2, y2) row per object cell: shape (1, 4) for the
            # usual single-object image, and no crash on 0 or several cells
            gt = np.stack([cx - w * 0.5, cy - h * 0.5, cx + w * 0.5, cy + h * 0.5], axis=1)
            results_gt.append(gt)

            # extract predictions
            detections = extract_detections(pred_box, threshold, num_boxes)
            results_pred.append(detections)
            print('image %d/%d, %d objects detected' % (i + 1, epoch_size, detections.shape[0]))

            # visualize the detections for this image
            visualize(image, gt, detections)

    # evaluation
    rec, prec, ap = voc_eval(results_gt, results_pred)
    print('Detection AP', ap)

    # save the PR curve, then close the figure so it is not left open
    fig = plt.figure(figsize=(6, 4))
    plt.plot(rec, prec)
    plt.xlabel('recall')
    plt.ylabel('precision')
    plt.title('AP: %.2f' % ap)
    plt.savefig('test_ap.pdf', bbox_inches='tight')
    plt.close(fig)


100 images for training
Output will be saved to `checkpoints`
epoch 0/2, iter 0/13, lr 0.000010, loss 7961807.0000
epoch 0/2, iter 1/13, lr 0.000010, loss 9534487.0000
epoch 0/2, iter 2/13, lr 0.000010, loss 8072014.0000
epoch 0/2, iter 3/13, lr 0.000010, loss 7865175.5000
epoch 0/2, iter 4/13, lr 0.000010, loss 7922552.0000
epoch 0/2, iter 5/13, lr 0.000010, loss 8149928.0000
epoch 0/2, iter 6/13, lr 0.000010, loss 7964957.5000
epoch 0/2, iter 7/13, lr 0.000010, loss 7907393.0000
epoch 0/2, iter 8/13, lr 0.000010, loss 9373734.0000
epoch 0/2, iter 9/13, lr 0.000010, loss 7478009.0000
epoch 0/2, iter 10/13, lr 0.000010, loss 7017914.0000
epoch 0/2, iter 11/13, lr 0.000010, loss 8313464.5000
epoch 0/2, iter 12/13, lr 0.000010, loss 3164800.0000
yolo_epoch_1.checkpoint.pth
epoch 1/2, iter 0/13, lr 0.000010, loss 7874880.5000
epoch 1/2, iter 1/13, lr 0.000010, loss 7306433.5000
epoch 1/2, iter 2/13, lr 0.000010, loss 8459700.0000
epoch 1/2, iter 3/13, lr 0.000010, loss 8100984.0000
epoch 

<Figure size 600x400 with 0 Axes>