In [1]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import json
import torchvision
from torchvision import transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
# YOLO model: custom YOLOv5 weights loaded from the local 'yolov5' checkout.
yolo_model =  torch.hub.load('yolov5', 'custom', 'models/model_yolo.pt', source='local')

# Configurations

img_dir = "images"              # directory scanned for input images
save_dir = img_dir + "_pred"    # annotated copies of the images are written here
results_file = 'results.json'   # JSON file receiving the detections
crop_shape = (64,64)            # (height, width) of the patch fed to the classifiers

# Transformations
#
# Pre-processing applied to every detected symbol crop before classification.
# NOTE: this name rebinds `transforms`, shadowing the `torchvision.transforms`
# module imported at the top of the notebook; it is kept because later cells
# refer to the pipeline by this name.
#
# The crop is a CenterCrop rather than a RandomCrop: this pipeline is used for
# inference, and a random crop would make the classification of a non-square
# symbol change from run to run.
transforms = A.Compose(
    [
        # Resize so the shorter side covers the crop in both dimensions.
        A.SmallestMaxSize(max_size=max(crop_shape)),
        A.CenterCrop(height=crop_shape[0], width=crop_shape[1]),
        A.Normalize(mean=0.5, std=0.5),
        ToTensorV2(),
    ]
)

YOLOv5  v7.0-157-g5178d41 Python-3.9.16 torch-2.0.0+cpu CPU

Fusing layers... 


[31m[1mrequirements:[0m C:\Users\henri\Desktop\nato_project\repo\nn23_project\generator\requirements.txt not found, check failed.


Model summary: 276 layers, 35248920 parameters, 0 gradients, 48.9 GFLOPs
Adding AutoShape... 


In [3]:
# Classification models
#
# One classifier per label. Each entry of `models` is a dict with the keys
# "label" (name reported for a positive prediction), "path" (TorchScript
# checkpoint on disk) and "model" (the loaded module, filled in below).
_label_and_path = [
    ("infantry", "models/model_infantry.pt"),
    ("anti-tank", "models/model_at.pt"),
    ("supply", "models/model_supply.pt"),
    ("hq", "models/model_hq.pt"),
    ("medic", "models/model_medic.pt"),
    ("reconnaissance", "models/model_recce.pt"),
    # Extend the list with additional models
]

models = [
    {"label": label, "path": path, "model": None}
    for label, path in _label_and_path
]

for model in models:
    # Load each TorchScript checkpoint on the CPU and switch it to eval mode
    # (Module.eval() returns the module itself, so it can be stored directly).
    model["model"] = torch.jit.load(
        model["path"], map_location=torch.device("cpu")
    ).eval()

In [4]:
def detect_symbols(save_results=True):
    """Detect symbols in every image of ``img_dir`` and classify each of them.

    For each file in ``img_dir`` whose name contains ``"img"``:

    1. ``yolo_model`` locates the symbols (bounding boxes).
    2. Every box is cropped from the grayscale image, passed through
       ``transforms`` and fed to each classifier in ``models``; the label of
       every classifier that answers positively is attached to the symbol.
    3. Labelled symbols are outlined and annotated on the image, which is then
       written to ``save_dir`` under the same file name.

    Parameters
    ----------
    save_results : bool, default True
        When true, the detections are written to ``results_file`` as JSON: a
        list of ``{"img": <file name>, "symbols": [...]}`` objects, where each
        symbol has ``xmin``, ``ymin``, ``xmax``, ``ymax`` (integer pixel
        coordinates) and ``labels`` (list of label strings).

    Returns
    -------
    None
        Results are delivered through the files written to disk.

    Notes
    -----
    Relies on the module-level ``img_dir``, ``save_dir``, ``results_file``,
    ``yolo_model``, ``models`` and ``transforms``.
    """
    os.makedirs(save_dir, exist_ok=True)

    detections = []

    # os.listdir returns entries in arbitrary order; sort so the processing
    # order and the JSON output are reproducible.
    for img in sorted(os.listdir(img_dir)):
        if "img" not in img:
            continue

        print('Processing image', img)

        canvas = cv2.imread(os.path.join(img_dir, img))
        if canvas is None:
            # cv2.imread reports an unreadable / non-image file by returning
            # None instead of raising; skip it rather than crash in cvtColor.
            print('Skipping unreadable image', img)
            continue

        canvas_gray = cv2.cvtColor(canvas, cv2.COLOR_BGR2GRAY)

        # cv2.imread yields BGR, whereas the YOLOv5 AutoShape wrapper expects
        # RGB for numpy input. Reverse the channel axis for detection only;
        # `canvas` itself stays BGR for drawing and cv2.imwrite.
        symbol_locs = yolo_model(canvas[..., ::-1])

        detection = {
            "img": img,
            "symbols": [],
        }

        for _, d in symbol_locs.pandas().xyxy[0].iterrows():
            # Floor the top-left corner and ceil the bottom-right one so the
            # integer box fully contains the predicted box.
            xmin, ymin = int(d.xmin), int(d.ymin)
            xmax, ymax = int(np.ceil(d.xmax)), int(np.ceil(d.ymax))

            symbol = {
                "xmin": xmin,
                "ymin": ymin,
                "xmax": xmax,
                "ymax": ymax,
                "labels": [],
            }
            detection["symbols"].append(symbol)

            unit_symbol = canvas_gray[ymin:ymax, xmin:xmax]
            if unit_symbol.size == 0:
                # Degenerate box: nothing to classify, keep it without labels.
                continue

            batch = transforms(image=unit_symbol)["image"].unsqueeze(0)

            translate_label = 0 # Vertical pixel offset of the next label written on the image
            for model in models:
                # Inference only: no autograd bookkeeping needed.
                with torch.no_grad():
                    logits = model["model"](batch)

                if logits.shape[1] == 2:
                    # Two-logit head: class index 1 is the positive class.
                    classify_result = torch.argmax(logits) == 1
                else:
                    # Single-logit head: threshold the sigmoid at 0.5.
                    classify_result = torch.round(torch.sigmoid(logits)).squeeze()

                if not classify_result:
                    continue

                if not symbol["labels"]:
                    # Outline the symbol once, when its first label is found.
                    cv2.rectangle(canvas, (xmin, ymin), (xmax, ymax), (255,0,0), 2)

                symbol["labels"].append(model["label"])
                canvas = cv2.putText(
                    canvas,
                    model["label"],
                    (xmin-10, ymin-translate_label),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                    (255, 0, 255), 1, cv2.LINE_AA)
                translate_label += 18

        cv2.imwrite(os.path.join(save_dir, img), canvas)
        detections.append(detection)

    if save_results:
        print('Saving results to', results_file)
        with open(results_file, 'w') as file:
            json.dump(detections, file, indent=4)

In [5]:
# Run detection and classification over every image in `img_dir`; with the
# default save_results=True the detections are also written to `results_file`.
detect_symbols()

Processing image img0.jpg
Processing image img1.jpg
Processing image img10.jpg
Processing image img11.jpg
Processing image img12.jpg
Processing image img13.jpg
Processing image img14.jpg
Processing image img15.jpg
Processing image img16.jpg
Processing image img17.jpg
Processing image img18.jpg
Processing image img19.jpg
Processing image img2.jpg
Processing image img20.jpg
Processing image img21.jpg
Processing image img22.jpg
Processing image img23.jpg
Processing image img24.jpg
Processing image img25.jpg
Processing image img26.jpg
Processing image img27.jpg
Processing image img28.jpg
Processing image img29.jpg
Processing image img3.jpg
Processing image img30.jpg
Processing image img31.jpg
Processing image img32.jpg
Processing image img33.jpg
Processing image img34.jpg
Processing image img35.jpg
Processing image img36.jpg
Processing image img37.jpg
Processing image img38.jpg
Processing image img39.jpg
Processing image img4.jpg
Processing image img5.jpg
Processing image img6.jpg
Processi