Skip to content

Commit

Permalink
GH-1 loading ARW images included into script
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoMandel committed Dec 27, 2023
1 parent 0d0469c commit f70008a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
24 changes: 20 additions & 4 deletions inference_single_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,27 @@
import os
from tqdm import tqdm
import time
from argparse import ArgumentParser

# import required functions, classes
import torch
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.file import list_files
from sahi.utils.cv import IMAGE_EXTENSIONS, read_image_as_pil
from argparse import ArgumentParser

from utils import find_model_files, read_arw_as_pil

IMAGE_EXTENSIONS += [".arw"]

def parse_args():
fdir = os.path.abspath(os.path.dirname(__file__))
datadir = os.path.join(fdir, '..', 'data', 'inference')
datadir = os.path.join(fdir, 'data', 'inference', 'arw')
outdir = os.path.join(datadir, "labels")
parser = ArgumentParser(description="File for creating labels on a folder of inference images using SAHI")
parser.add_argument("-i", "--input", required=False, type=str, help="Location of the input folder", default=datadir)
parser.add_argument("-o", "--output", required=False, help="which output folder to put the labels to", default=outdir)
parser.add_argument("-m", "--model", default=None, help="Path to model file. If None given, will take first file from <config> directory")
args = parser.parse_args()
return vars(args)

Expand Down Expand Up @@ -68,7 +72,16 @@ def convert_pred_to_txt(pred, target_dir, img_name : str = "labels"):

if __name__=="__main__":
args = parse_args()
yolov5_model_path = "yolov5/ohw/combined_m_2/weights/best.pt"

# getting the model
if args["model"] == None:
fdir = os.path.abspath(os.path.dirname(__file__))
confdir = os.path.join(fdir, "config")
mfs = find_model_files(confdir)
yolov5_model_path = mfs[0]
else:
yolov5_model_path = args["model"]

model_type = "yolov5"
model_path = yolov5_model_path
model_device = "cuda:0" # or 'cuda:0'
Expand Down Expand Up @@ -109,7 +122,10 @@ def convert_pred_to_txt(pred, target_dir, img_name : str = "labels"):
for ind, image_path in enumerate(
tqdm(image_iterator, f"Performing inference on {source_image_dir}")
):
image_as_pil = read_image_as_pil(image_path)
if image_path.endswith("ARW"):
image_as_pil = read_arw_as_pil(image_path)
else:
image_as_pil = read_image_as_pil(image_path)
# test_img = "data/OHW/Inference/test_ds/DSC00009.png"
result = get_sliced_prediction(
image_as_pil,
Expand Down
23 changes: 23 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os.path
from pathlib import Path
import rawpy
import numpy as np
from PIL import Image

def find_model_files(dir : str) -> list:
p = Path(dir)
return list(p.glob("*.pt"))

def load_arw(fpath : str) -> np.ndarray:
"""
Loading an ARW image. Thanks to Rob
"""
raw = rawpy.imread(fpath)
return raw.postprocess(use_camera_wb=True, output_bps=8)

def read_arw_as_pil(fpath : str) -> Image.Image:
"""
function to return an ARW image in PIL format for SAHI
"""
np_arr = load_arw(fpath)
return Image.fromarray(np_arr)

0 comments on commit f70008a

Please sign in to comment.