In [1]:
import cv2
import numpy as np
import os,glob,uuid
import torch
from torchvision import transforms
from PIL import Image

import sys
sys.path.append('../')

from learn import CustomResNet,SimpleCNN,SimpleCNN_1

In [2]:
os.makedirs('../dataset/label', exist_ok=True)

In [3]:
def apply_hsv_threshold(image):
    low_h, high_h = 90, 100  # Example range for yellow hue
    low_s, high_s = 140, 255 # Example range for saturation
    low_v, high_v = 0, 255 # Example range for value
    

    # Convert the image from RGB to HSV
    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    
    # Define the lower and upper bounds of the HSV threshold
    lower_bound = np.array([low_h, low_s, low_v])
    upper_bound = np.array([high_h, high_s, high_v])
    
    # Create a mask where pixels within the threshold are white, and others are black
    mask = cv2.inRange(hsv_image, lower_bound, upper_bound)
    
    # Create an all black image
    black_image = np.zeros_like(image)
    
    # Copy the pixels from the original image where the mask is white
    result_image = np.where(mask[:, :, None] == 255, image, black_image)
    
    return result_image

In [4]:
input_directory = "../dataset/collect"
output_directory = "../dataset/label/"

entries = os.listdir(input_directory)

# Filter to include only .jpg files
img_files = [file for file in entries if file.endswith('.jpg')]

In [5]:
img_files

['xy_151_054_032_928fde56-a09a-11ee-9d87-40ed000c4092.jpg',
 'xy_149_148_007_3a77b912-a09c-11ee-9d5b-40ed000c4092.jpg',
 'xy_078_140_006_d83afc98-a099-11ee-85d3-40ed000c4092.jpg',
 'xy_117_074_039_037078b2-a099-11ee-a36e-40ed000c4092.jpg',
 'xy_027_056_026_233d3e18-a09a-11ee-85d3-40ed000c4092.jpg',
 'xy_109_128_012_f5c51f9c-a098-11ee-a36e-40ed000c4092.jpg',
 'xy_020_114_012_e2d0ee2c-a09b-11ee-9d5b-40ed000c4092.jpg',
 'xy_168_140_006_851db3e8-a099-11ee-85d3-40ed000c4092.jpg',
 'xy_119_110_023_fc0965b4-a09a-11ee-b33d-40ed000c4092.jpg',
 'xy_183_131_007_ed4b03c6-a099-11ee-85d3-40ed000c4092.jpg',
 'xy_115_074_025_045f77dc-a09e-11ee-bacf-40ed000c4092.jpg',
 'xy_109_144_006_7969c442-a099-11ee-85d3-40ed000c4092.jpg',
 'xy_096_050_032_ee7880c2-a09c-11ee-9d5b-40ed000c4092.jpg',
 'xy_122_041_030_286247d0-a09a-11ee-85d3-40ed000c4092.jpg',
 'xy_205_114_015_34a7ce4c-a09b-11ee-b33d-40ed000c4092.jpg',
 'xy_120_138_011_1dd6c52e-a09b-11ee-b33d-40ed000c4092.jpg',
 'xy_114_139_009_d9c2317e-a09b-11ee-9d5b

In [6]:
# デバイスの設定（CUDAが利用可能な場合はGPUを使用）
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight = '../weight/light/simpleCNN_v3.pth'
# モデルのインスタンスを作成
model = SimpleCNN()
# モデルの重みをロードし、推論モードに設定
model.load_state_dict(torch.load(weight, map_location=device))
model.to(device)
model.eval()

SimpleCNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (fc1): Linear(in_features=100352, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=3, bias=True)
)

In [7]:
input_directory = '../dataset/collect/'
output_path = '../dataset/label/'

cnt = 0
for img in img_files:
    img_path = input_directory + img
    image = cv2.imread(img_path)

    mask_image = apply_hsv_threshold(image)
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    image_tensor = transform(pil_image).unsqueeze(0).to(device)

    #モデルにmask=imageを入力し、出力から座標、半径を得る
    #ファイルの名前を、filename = '_x_y_r_uuid.jpg'として保存
    # モデルを使用して予測
    with torch.no_grad():
        outputs = model(image_tensor)
        predicted_coords = outputs.cpu().numpy()[0]
        pos_x, pos_y, pos_r = predicted_coords

    # 座標と半径をスケーリング
    pos_x = int(pos_x * 224)
    pos_y = int(pos_y * 224)
    pos_r = int(pos_r * (224 * 1.414))

    if pos_r < 0:
        pos_r = 0

    filename = f'{cnt}_{pos_x}_{pos_y}_{pos_r}_{uuid.uuid4().hex}.jpg'
    cnt += 1
    cv2.imwrite(output_path+filename,image)

    
    