In [1]:
import sys
sys.path.append('./yolo')
from detection_fusion import EnsembleModel
import numpy as np
import matplotlib.pyplot as plt
from rcnn.dataset import ChestCocoDetection
import matplotlib.patches as patches
from torch.utils.data import DataLoader
import tqdm
import torch

In [2]:
from yolo.utils.general import xywh2xyxy, scale_coords, box_iou
from yolo.utils.metrics import ap_per_class
import torch
from rcnn.model import ChestRCNN
from yolo.yolo import Model
from yolo.utils.datasets import LoadImages
from yolo.utils.general import check_img_size

In [49]:
YOLO_FINAL_MODEL_PATH = "./yolo/models_final/yolov5_epoch_26.pt"
FASTER_RCNN_FINAL_MODEL_PATH = "./rcnn/models/fasterrcnn_epoch_23.pt"
RESNET_BACKBONE_PATH = "./resnet/models/resnext101_32x8d_epoch_35.pt"

yolov5_weights = torch.load(YOLO_FINAL_MODEL_PATH)
fasterrcnn_r101_weights = torch.load(FASTER_RCNN_FINAL_MODEL_PATH)

yolo = Model(cfg="./yolo/yolo5l.yaml",ch=3,nc=1)
yolo.load_state_dict(yolov5_weights, strict=False) 

fasterRcnn = ChestRCNN(RESNET_BACKBONE_PATH)
fasterRcnn.load_state_dict(fasterrcnn_r101_weights)
ensemble = EnsembleModel(fasterRcnn=fasterRcnn, yolo=yolo)

### Evaluation of Ensemble method

In [50]:
def collate_fn(batch):
    imgs, targets, path = zip(*batch)
    imgs = torch.stack(imgs)
    return tuple((imgs, targets, path))

In [51]:
# need bs = 1 because model fusion cannot handle batches 
test = ChestCocoDetection(root="F:\\aml-project\data\\siim-covid19-detection", ann_file="F:\\aml-project\\data\\siim-covid19-detection\\test.json", training=False, image_size=512, detection_fusion=True)
test_loader = DataLoader(test, batch_size=1, shuffle=False, pin_memory=False, num_workers=0, collate_fn=collate_fn)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [52]:
VISUALIZATION = True
jdict, stats, ap, ap_class = [], [], [], []
stats_yolo, stats_fcnn = [], []
iouv = torch.linspace(0.5, 0.55, 1) # iou vector for mAP@0.5:0.95
niou = iouv.numel()
general_test_results = []
visualization_sample_data = {}
for i, (images, targets, path) in enumerate(tqdm.tqdm(test_loader)):
    nb, _, height, width = images.shape
    assert nb == 1

    # Load image with Yolo transformations since we use the FCNN data loader per default
    path = '.\..\\data\\siim-covid19-detection\\' + path[0]
    stride = int(yolo.stride.max())  # model stride
    img_size = check_img_size(512, s=stride)  # check img_size
    yolo_img = LoadImages(path,img_size=img_size,stride=stride).__iter__().__next__()[1]
    yolo_img = torch.from_numpy(yolo_img).float()
    yolo_img /= 255.0
    if yolo_img.ndimension() == 3:
         yolo_img = yolo_img.unsqueeze(0)

    _, boxes, scores, pred_labels, yolo_box, yolo_scores, yolo_labels, fcnn_box, fcnn_scores, fcnn_labels = ensemble.detection_fusion(images,yolo_img, extended_output=True)
    boxes = torch.tensor(boxes)
    yolo_box = torch.tensor(yolo_box)
    fcnn_box = torch.tensor(fcnn_box)
    gt_boxes = targets[0]['boxes']
    labels = targets[0]['labels']
    nl = len(labels)
    tcls = labels[:].tolist() if nl else []  # target class

    # save one sample prediction for visualization
    if len(boxes) and i == 5:
        visualization_sample_data["img"] = images
        visualization_sample_data["gt_box"] = gt_boxes
        visualization_sample_data["yolo_box"] = yolo_box
        visualization_sample_data["yolo_scores"] = yolo_scores
        visualization_sample_data["fcnn_box"] = fcnn_box
        visualization_sample_data["fcnn_scores"] = fcnn_scores
        visualization_sample_data["ensemble_box"] = boxes
        visualization_sample_data["ensemble_scores"] = scores
        if VISUALIZATION:
            break


    if len(boxes) == 0:
        if nl:
            stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
            stats_yolo.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
            stats_fcnn.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))

        continue

    correct = torch.zeros(len(boxes), niou, dtype=torch.bool)
    correct_yolo = torch.zeros(len(yolo_box), niou, dtype=torch.bool)
    correct_fcnn = torch.zeros(len(fcnn_box), niou, dtype=torch.bool)
    
    if nl:
        detected = []
        detected_yolo = []
        detected_fcnn = []
        tcls_tensor = labels[:]

        # Per target class
        for cls in torch.tensor([1]):
            ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)
            pi = (cls == torch.tensor(pred_labels)).nonzero(as_tuple=False).view(-1)
            pi_yolo = (cls == torch.tensor(yolo_labels)).nonzero(as_tuple=False).view(-1)
            pi_fcnn = (cls == torch.tensor(fcnn_labels)).nonzero(as_tuple=False).view(-1)

            # Search for detections
            if pi.shape[0]:        
                ious, i = box_iou(boxes[pi], gt_boxes[ti]).max(1)  # best ious, indices
                # Append detections - Ensemble
                detected_set = set()
                for j in (ious > iouv[0]).nonzero(as_tuple=False):
                    d = ti[i[j]]  
                    if d.item() not in detected_set:
                        detected_set.add(d.item())
                        detected.append(d)
                        correct[pi[j]] = ious[j] > iouv  
                        if len(detected) == nl: 
                            break
                            
            if pi_yolo.shape[0]:
                ious_yolo, i_yolo = box_iou(yolo_box[pi_yolo], gt_boxes[ti]).max(1)
                # Yolo
                detected_set_yolo = set()
                for j in (ious_yolo > iouv[0]).nonzero(as_tuple=False):
                    d = ti[i_yolo[j]]  
                    if d.item() not in detected_set_yolo:
                        detected_set_yolo.add(d.item())
                        detected_yolo.append(d)
                        correct_yolo[pi_yolo[j]] = ious_yolo[j] > iouv  
                        if len(detected_yolo) == nl: 
                            break
                            
            if pi_fcnn.shape[0]:
                ious_fcnn, i_fcnn = box_iou(fcnn_box[pi_fcnn], gt_boxes[ti]).max(1)
                # Faster R-CNN
                detected_set_fcnn = set()
                for j in (ious_fcnn > iouv[0]).nonzero(as_tuple=False):
                    d = ti[i_fcnn[j]]  
                    if d.item() not in detected_set_fcnn:
                        detected_set_fcnn.add(d.item())
                        detected_fcnn.append(d)
                        correct_fcnn[pi_fcnn[j]] = ious_fcnn[j] > iouv  
                        if len(detected_fcnn) == nl: 
                            break



        # Append statistics (correct, conf, pcls, tcls)
        stats.append((correct.cpu(), torch.tensor(scores),torch.tensor(pred_labels), tcls))
        stats_yolo.append((correct_yolo.cpu(), torch.tensor(yolo_scores),torch.tensor(yolo_labels), tcls))
        stats_fcnn.append((correct_fcnn.cpu(), torch.tensor(fcnn_scores),torch.tensor(fcnn_labels), tcls))

for statistic in [stats, stats_yolo, stats_fcnn]:
    statistic = [np.concatenate(x, 0) for x in zip(*statistic)] 
    if len(statistic) and statistic[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*statistic)
        ap50, ap = ap[:, 0], ap.mean(1) 
        mp, mr, map50, m = p.mean(), r.mean(), ap50.mean(), ap.mean()
        nt = np.bincount(statistic[3].astype(np.int64), minlength=1)

        general_test_results.append({
        "precision": p,
        "recall": r,
        "ap": ap,
        "f1": f1,
        "ap_class": ap_class,
        "ap": ap,
        "ap50": ap50,
        "mp": mp,
        "mr": mr,
        "map50": map50,
        "map": m
    })

    else:
        nt = torch.zeros(1)


  1%|          | 5/859 [00:23<1:05:44,  4.62s/it]


In [18]:
# Ensemble
general_test_results[0]

{'precision': array([    0.59357]),
 'recall': array([     0.5302]),
 'ap': array([    0.45424]),
 'f1': array([     0.5601]),
 'ap_class': array([1]),
 'ap50': array([    0.45424]),
 'mp': 0.5935737555211111,
 'mr': 0.5301970756516211,
 'map50': 0.454241758431956,
 'map': 0.454241758431956}

In [19]:
# YOLO
general_test_results[1]

{'precision': array([    0.50593]),
 'recall': array([    0.45048]),
 'ap': array([    0.34455]),
 'f1': array([    0.47659]),
 'ap_class': array([1]),
 'ap50': array([    0.34455]),
 'mp': 0.505925645110543,
 'mr': 0.4504762082687064,
 'map50': 0.34455343562825014,
 'map': 0.34455343562825014}

In [20]:
# Faster R-CNN
general_test_results[2]

{'precision': array([    0.60309]),
 'recall': array([    0.50328]),
 'ap': array([    0.40929]),
 'f1': array([    0.54868]),
 'ap_class': array([1]),
 'ap50': array([    0.40929]),
 'mp': 0.6030946457346565,
 'mr': 0.5032776105936924,
 'map50': 0.40928837471026075,
 'map': 0.40928837471026075}

In [46]:
visualization_sample_data

{'img': tensor([[[[-1.63841, -1.65554, -1.65554,  ..., -1.91241, -1.77541, -1.67266],
           [-1.56991, -1.70691, -1.75828,  ..., -1.80966, -1.87816, -1.80966],
           [-1.68979, -1.70691, -1.77541,  ..., -1.84391, -1.91241, -1.91241],
           ...,
           [-1.80966, -1.92953, -1.98091,  ..., -1.15892, -1.09042, -1.00479],
           [-1.72403, -1.98091, -1.99803,  ..., -1.07329, -1.07329, -0.98767],
           [-1.89528, -1.80966, -1.89528,  ..., -1.21029, -1.12467, -1.05617]],
 
          [[-1.54552, -1.56303, -1.56303,  ..., -1.82563, -1.68557, -1.58053],
           [-1.47549, -1.61555, -1.66807,  ..., -1.72059, -1.79062, -1.72059],
           [-1.59804, -1.61555, -1.68557,  ..., -1.75560, -1.82563, -1.82563],
           ...,
           [-1.72059, -1.84314, -1.89566,  ..., -1.05532, -0.98529, -0.89776],
           [-1.63305, -1.89566, -1.91317,  ..., -0.96779, -0.96779, -0.88025],
           [-1.80812, -1.72059, -1.80812,  ..., -1.10784, -1.02031, -0.95028]],
 
       

In [54]:
def add_bounding_boxes(target, ax):
    boxes = target['boxes']
    print(boxes.shape)
    for box in boxes:
        mp_box = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="r", facecolor='none')
        ax.add_patch(mp_box)

def show_samples_for(gt, yolo, fcnn, ensemble):
    fig, axs = plt.subplots(1, 4, figsize=(12,8))
    for im_i, (img, box, name) in enumerate([gt, yolo, fcnn, ensemble]):
        ax = axs[im_i] if isinstance(axs, np.ndarray) else axs
        ax.set_title(name)
        plt.figure()
        ax.imshow(img.squeeze_(0).permute(1,2,0))#,cmap=plt.cm.bone)
        add_bounding_boxes(box, ax)
    
    fig.savefig("combo.png",bbox_inches='tight', dpi=200)


show_samples_for(
    (visualization_sample_data['img'], {"boxes":visualization_sample_data['gt_box']}, "Ground Truth"), 
    (visualization_sample_data['img'], {"boxes":visualization_sample_data['yolo_box']}, "YOLO"),
    (visualization_sample_data['img'], {"boxes":visualization_sample_data['fcnn_box']}, "Faster R-CNN"),
    (visualization_sample_data['img'], {"boxes":visualization_sample_data['ensemble_box']}, "Ensemble")
)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
