In [24]:
#Part of code taken from https://learnopencv.com/weighted-boxes-fusion/

import torch 
import cv2 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from PIL import Image
import os
from ensemble_boxes import *

In [19]:
# Load the custom-trained YOLOv5 checkpoint through torch.hub (served from the local
# hub cache when present) and pin inference to the CPU.
model = torch.hub.load(
    "ultralytics/yolov5",
    "custom",  # user-trained weights rather than a stock pretrained variant
    path="best_yolov5_weights.pt",
    device="cpu",
)


Using cache found in /Users/smudge/.cache/torch/hub/ultralytics_yolov5_master
YOLOv5 🚀 2023-9-28 Python-3.11.4 torch-2.1.0 CPU

Fusing layers... 
YOLOv5s summary: 157 layers, 7015519 parameters, 0 gradients, 15.8 GFLOPs
Adding AutoShape... 


In [20]:
# Dataset locations (plain Python constants, not environment variables).
# Set the WHEAT_DATASET_ROOT environment variable to point at another copy of the
# Wheat_Head_detection dataset. Every path keeps a trailing separator because image
# paths are built by appending file names to these directories.
DATASET_ROOT = os.environ.get(
    'WHEAT_DATASET_ROOT',
    '/Volumes/T7 Shield/Smudge/Datasets/Wheat_Head_detection/',
)
PATH_TO_TEST = os.path.join(DATASET_ROOT, 'test', '')
PATH_TO_INPAINT = os.path.join(DATASET_ROOT, 'modified_test_set', 'test_enhanced_inpainting', '')
PATH_TO_AUTO = os.path.join(DATASET_ROOT, 'modified_test_set', 'autoencoder_enhanced', '')


In [21]:
# File names of the images in each test-set variant.
# os.listdir returns entries in arbitrary order, so sort them to make the
# processing order (and the order of every per-image dict built later) reproducible.
test_ids = sorted(os.listdir(PATH_TO_TEST))
inpaint_ids = sorted(os.listdir(PATH_TO_INPAINT))
auto_ids = sorted(os.listdir(PATH_TO_AUTO))

In [22]:
# Confidence threshold of the YOLOv5 hub (AutoShape) model: detections scoring
# below it are dropped by the model itself before any fusion happens.
CONF_THRESHOLD = 0.3
model.conf = CONF_THRESHOLD

def get_predictions(img_names, path, detector=None):
    """Run the detector on every image in ``img_names`` located in directory ``path``.

    Args:
        img_names: iterable of image file names (e.g. from ``os.listdir``). Names
            containing '_' are skipped. NOTE(review): presumably this ignores helper
            files such as macOS ``._*`` metadata files; confirm before relying on it.
        path: directory holding the images. It is joined with each name via
            ``os.path.join``, so a trailing separator is optional.
        detector: callable taking an image path and returning a YOLOv5-style result
            exposing ``xyxy``. Defaults to the notebook-level ``model``.

    Returns:
        ``(pred_scores, pred_boxes, pred_classes)``: three dicts keyed by image name.
        Per image they hold, as NumPy arrays, column 4 (confidence), columns 0-3
        (``xmin, ymin, xmax, ymax``) and column 5 (class id) of ``result.xyxy[0]``.
    """
    if detector is None:
        detector = model

    pred_scores = dict()
    pred_boxes = dict()
    pred_classes = dict()

    for img in img_names:
        if '_' in img:
            continue

        prediction = detector(os.path.join(path, img))
        # One image in -> xyxy[0] holds all of its detections, one row per box.
        # .cpu() keeps this working if the model is ever moved off the CPU.
        detections = prediction.xyxy[0].detach().cpu().numpy()

        pred_scores[img] = detections[:, 4]
        pred_boxes[img] = detections[:, 0:4]
        pred_classes[img] = detections[:, 5]

    return pred_scores, pred_boxes, pred_classes

In [23]:
# Get detections for each test-set variant (original and inpainted images).
# Each variant contributes one entry per list; WBF later fuses them image by image.
pred_confs = []
pred_boxes = []
pred_classes = []

test_paths = [PATH_TO_TEST, PATH_TO_INPAINT]
test_imgids = [test_ids, inpaint_ids]
assert len(test_paths) == len(test_imgids), "each test directory needs its id list"

for img_ids, img_dir in zip(test_imgids, test_paths):
    confs_scores, box_preds, cls_pred = get_predictions(img_ids, img_dir)

    pred_confs.append(confs_scores)
    pred_boxes.append(box_preds)
    pred_classes.append(cls_pred)

In [29]:
# Weighted Boxes Fusion across several prediction sets
def perform_wbf(pred_confs_models, pred_boxes_models, pred_classes_models,
                image_ids=None, image_size=1024, iou_thr=0.5,
                skip_box_thr=0.30, score_thr=0.28, weights=None):
    """Fuse per-image detections from several prediction sets with ``weighted_boxes_fusion``.

    Args:
        pred_confs_models: list of dicts, one per prediction set, mapping image id to
            an array of confidence scores.
        pred_boxes_models: list of dicts mapping image id to an (N, 4) array of
            ``xmin, ymin, xmax, ymax`` pixel boxes.
        pred_classes_models: list of dicts mapping image id to an array of class ids.
        image_ids: ids of the images to fuse. Defaults to the keys of the first
            prediction set. Ids containing '_' are skipped. Every prediction set
            must contain every processed id (KeyError otherwise).
        image_size: side length in pixels used to normalise boxes to [0, 1]; images
            are assumed square (1024x1024 for this dataset).
        iou_thr, skip_box_thr, weights: forwarded to ``weighted_boxes_fusion``.
        score_thr: fused boxes whose score is not strictly greater than this are dropped.

    Returns:
        ``(wbf_boxes_dict, wbf_scores_dict)``:
        - ``wbf_boxes_dict`` maps image id to a list of integer
          ``[xmin, ymin, width, height]`` boxes, clipped to the image.
        - ``wbf_scores_dict`` maps image id to a list of one-element ``[score]``
          lists, with scores rounded to 5 decimals.
    """
    if image_ids is None:
        image_ids = list(pred_boxes_models[0].keys()) if pred_boxes_models else []

    # One scale factor per coordinate (xmin, ymin, xmax, ymax); hoisted out of the loop.
    scale = np.array([image_size] * 4)
    max_coord = image_size - 1

    wbf_boxes_dict = dict()
    wbf_scores_dict = dict()

    for image_id in image_ids:
        if '_' in image_id:
            continue

        all_model_boxes = []
        all_model_scores = []
        all_model_classes = []

        for boxes, scores, classes in zip(pred_boxes_models, pred_confs_models, pred_classes_models):
            # weighted_boxes_fusion expects coordinates normalised to [0, 1].
            all_model_boxes.append((boxes[image_id] / scale).clip(min=0., max=1.))
            all_model_scores.append(scores[image_id])
            all_model_classes.append(classes[image_id])

        boxes, scores, _labels = weighted_boxes_fusion(
            all_model_boxes, all_model_scores, all_model_classes,
            weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)

        keep = scores > score_thr
        final_scores = scores[keep]
        # Back to pixel coordinates, clipped to the image and truncated to ints.
        final_boxes = (boxes[keep] * scale).clip(min=0, max=max_coord).astype(int)
        # Convert xmin, ymin, xmax, ymax -> xmin, ymin, width, height.
        final_boxes[:, 2:] -= final_boxes[:, :2]

        wbf_boxes_dict[image_id] = final_boxes.tolist()
        wbf_scores_dict[image_id] = np.expand_dims(np.round(final_scores, 5), axis=-1).tolist()

    return wbf_boxes_dict, wbf_scores_dict


        

In [30]:
# Fuse the per-variant predictions (original + inpainted test images) image by image.
boxes_dict_wbf, scores_dict_wbf = perform_wbf(
    pred_confs_models=pred_confs,
    pred_boxes_models=pred_boxes,
    pred_classes_models=pred_classes,
)

In [31]:
# Draw the outputs using OpenCV drawing tools
def draw_bbox_conf(image, boxes, scores, color=(255, 0, 0), thickness=-1):
    """Draw labelled detection boxes on a copy of ``image`` and blend it with the original.

    Args:
        image: NumPy image array accepted by the OpenCV drawing functions.
        boxes: iterable of ``[xmin, ymin, width, height]`` pixel boxes. Values are cast
            to ``int`` because OpenCV drawing functions require integer points.
        scores: iterable of per-box scores; each is a sequence whose first element is
            the confidence (e.g. ``[0.87]``).
        color: colour of the box and of the label background, in ``image``'s channel order.
        thickness: box line thickness in pixels. ``-1`` (the default) fills the box
            completely; pass a positive value for outline-only boxes.

    Returns:
        A new image: the annotated overlay blended 75/25 with the original ``image``.
    """
    overlay = image.copy()

    # Scale the label font with the image size, clamped to [0.5, 0.8].
    font_size = 0.25 + 0.07 * min(overlay.shape[:2]) / 100
    font_size = min(max(font_size, 0.5), 0.8)
    text_offset = 7

    for box, score in zip(boxes, scores):
        xmin, ymin, width, height = (int(v) for v in box[:4])
        xmax = xmin + width
        ymax = ymin + height

        overlay = cv2.rectangle(overlay, (xmin, ymin), (xmax, ymax), color, thickness)

        display_text = f"wheat_head: {score[0]:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(display_text, cv2.FONT_HERSHEY_SIMPLEX, font_size, 2)

        # Filled label background drawn just above the box's top-left corner.
        cv2.rectangle(overlay, (xmin, ymin),
                      (xmin + text_width + text_offset, ymin - text_height - int(15 * font_size)),
                      color, thickness=-1)

        overlay = cv2.putText(overlay, display_text,
                              (xmin + text_offset, ymin - int(10 * font_size)),
                              cv2.FONT_HERSHEY_SIMPLEX, font_size, (255, 255, 255), 2,
                              lineType=cv2.LINE_AA)

    return cv2.addWeighted(overlay, 0.75, image, 0.25, 0)


In [None]:
# TODO: test the outputs — draw the fused WBF boxes on a few test images and display them
