In [1]:
# Clone only when yolov5/ is missing so the cell can be re-run (Restart & Run All) without a git error.
![ -d yolov5 ] || git clone -q https://github.com/ultralytics/yolov5.git

In [2]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q -r yolov5/requirements.txt

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m586.5/586.5 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.1/205.1 kB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [42]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import pickle
import time

In [None]:
# Model loading

# Detector: upload model.pt (trained weights from yolo/one_class); it is loaded
# through the local yolov5 clone made in the first cell.
yolo_model = torch.hub.load('yolov5', 'custom', 'model.pt', source='local')

# Classifier: upload the pickled CNN model of your choice as infantry.pkl.
# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with open("infantry.pkl", "rb") as model_file:  # closes the handle once loaded
    lenet_model = pickle.load(model_file)

# Configurations
#class_names = pd.read_csv('classes.csv', index_col=0)
# dsize (width, height) passed to cv2.resize for every detected crop before
# classification; presumably must match the classifier's input size -- confirm.
crop_shape = (64,64)

In [40]:
def locate_unit_symbols(canvas):
    """Run the YOLOv5 detector on ``canvas`` and return its raw results object.

    ``yolo_model`` is the hub model loaded earlier in the notebook; callers read
    the boxes via ``.pandas().xyxy[0]``.

    NOTE: YOLOv5's AutoShape wrapper treats numpy input as RGB, so an image read
    with ``cv2.imread`` (BGR) should be passed as ``canvas[..., ::-1]``.
    Use ``%time`` on the calling cell if you need timing information.
    """
    return yolo_model(canvas)

def classify(canvas, d):
    """Classify the unit symbol inside one detection box.

    The box is cut out of ``canvas`` (min edges truncated with ``int``, max
    edges rounded up with ``np.ceil`` so the crop never loses a partial pixel),
    resized to ``crop_shape`` and passed to ``lenet_model.predict`` as a batch
    containing a single image.

    Args:
        canvas: image array the detection was made on.
        d: detection row exposing ``xmin``, ``ymin``, ``xmax``, ``ymax``.

    Returns:
        The raw ``lenet_model.predict`` output for the one-image batch
        (callers take ``[0].argmax()`` of it).
    """
    top, bottom = int(d.ymin), int(np.ceil(d.ymax))
    left, right = int(d.xmin), int(np.ceil(d.xmax))

    crop = canvas[top:bottom, left:right]
    patch = cv2.resize(crop, crop_shape, interpolation=cv2.INTER_AREA)

    # NOTE(review): the crop keeps canvas' channel order -- confirm it matches
    # what the classifier was trained on.
    batch = np.stack([patch])
    return lenet_model.predict(batch, verbose=0)

def draw_rectangle(canvas, d, pred_class):
    """Annotate ``canvas`` with one detection box and its predicted class index.

    Drawing happens in place: no copy is made, and the returned array is the
    one produced by ``cv2.putText`` on ``canvas`` itself.

    Args:
        canvas: image array to draw on (modified in place).
        d: detection row exposing ``xmin``, ``ymin``, ``xmax``, ``ymax``.
        pred_class: classifier output; the label drawn is ``pred_class[0].argmax()``.

    Returns:
        The annotated image.
    """
    left, top = int(d.xmin), int(d.ymin)
    right, bottom = int(np.ceil(d.xmax)), int(np.ceil(d.ymax))
    box_color = (255, 0, 0)  # blue if the canvas is in OpenCV's BGR order

    cv2.rectangle(canvas, (left, top), (right, bottom), box_color, 2)

    label = str(pred_class[0].argmax())
    label_origin = (left - 10, top - 10)  # slightly above and left of the box
    return cv2.putText(canvas, label, label_origin,
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, box_color, 1, cv2.LINE_AA)

def main(img_dir="images", save_dir=None):
    """Detect, classify and annotate unit symbols in every test image.

    Every file in ``img_dir`` whose name contains ``"img"`` is read, passed
    through the YOLO detector, and each detected box is classified by the CNN.
    An annotated copy is written to ``save_dir`` under the same filename.
    Images with no detections are still written (unannotated), and unreadable
    files are skipped with a message.

    Args:
        img_dir: folder holding the test images (default ``"images"``).
        save_dir: output folder; defaults to ``img_dir + "_pred"``.
    """
    if save_dir is None:
        save_dir = img_dir + "_pred"
    os.makedirs(save_dir, exist_ok=True)

    for img_name in sorted(os.listdir(img_dir)):
        if "img" not in img_name:
            continue

        canvas = cv2.imread(os.path.join(img_dir, img_name))
        if canvas is None:  # cv2.imread returns None instead of raising
            print(f"Skipping unreadable file: {img_name}")
            continue

        # YOLOv5 AutoShape expects RGB numpy input; cv2.imread yields BGR.
        # The box coordinates are unaffected by the channel flip.
        symbol_locs = locate_unit_symbols(canvas[..., ::-1])

        # Draw on a copy so boxes/labels drawn for one detection never leak into
        # the crops classified for later (possibly overlapping) detections. This
        # also guarantees each image writes its own output even with 0 detections.
        annotated = canvas.copy()
        for _, d in symbol_locs.pandas().xyxy[0].iterrows():
            pred_class = classify(canvas, d)
            annotated = draw_rectangle(annotated, d, pred_class)

        cv2.imwrite(os.path.join(save_dir, img_name), annotated)

In [None]:
# Run the detect -> classify -> annotate pipeline over the "images" folder.
main()