Skip to content

Commit

Permalink
Fixed CV capturing system; filtered specific pets and animals
Browse files Browse the repository at this point in the history
  • Loading branch information
andreped committed Jan 12, 2024
1 parent 829a783 commit 1185361
Showing 1 changed file with 56 additions and 99 deletions.
155 changes: 56 additions & 99 deletions smp/image_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@
# get screen resolution scale, store as global variable in this scope
dimensions_scale = get_screen_scale()

supported_pets = ["ant", "beaver", "cricket", "duck", "fish", "horse", "otter", "pig", "sloth"]
supported_food = [
"apple", "banana", "canned_food", "chili", "chocolate", "cupcake", "garlic",
"melon", "pear", "pizza", "salad_bowl", "sleeping_pill", "steak", "sushi",
]

print("DIMENSIONS_SCALE:", dimensions_scale)


def get_img_from_coords(coords, to_numpy=True):
"""
Expand Down Expand Up @@ -57,7 +65,12 @@ def get_animal_from_screen():
"""
captures images of the current animals on screen (to be classified at a later stage)
"""
img = get_img_from_coords(coords=(450, 620, 1500, 750), to_numpy=False) # bbox: left, top, right, bottom
# img = get_img_from_coords(coords=(450, 620, 1500, 750), to_numpy=False) # bbox: left, top, right, bottom
img = get_img_from_coords(coords=(310, 620, 1630, 750), to_numpy=False)
# img = get_img_from_coords(coords=(0, 1920, ))

# plt.imshow(img)
# plt.show()

# template dimensions -> to be scaled if necessary
img_n_width = 130
Expand All @@ -67,8 +80,10 @@ def get_animal_from_screen():
[300, 0, 430, img_n_width],
[445, 0, 575, img_n_width],
[590, 0, 720, img_n_width],
[730, 0, 860, img_n_width],
[875, 0, 1005, img_n_width],
[1015, 0, 1145, img_n_width],
[1160, 0, 1290, img_n_width],
#[730, 0, 860, img_n_width],
#[875, 0, 1005, img_n_width],
]

images = []
Expand All @@ -93,61 +108,14 @@ def matching(image, needle_img):
"""

needle_img = cv2.resize(needle_img, image.shape[:2][::-1])

needle_img = needle_img[..., ::-1]
image = image[..., ::-1]

#image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
#needle_img = cv2.cvtColor(needle_img, cv2.COLOR_RGB2GRAY)

#needle_img = (needle_img > 220).astype("uint8")
#image = (image > 220).astype("uint8")

#hd95 = hausdorff_distance(image, needle_img, method="modified")
#hd95 = dice(image, needle_img)

#max_val = hd95

#mask = np.zeros_like(needle_img.copy())#[..., 0]
#mask[needle_img > 0] = 1
#mask = needle_img.copy()
# mask = 1 - mask

# pad image to fit needle image into
#tmp = np.zeros((int(image.shape[0] * 3), int(image.shape[1] * 3), 3), dtype="uint8")
#tmp[image.shape[0]:(2*image.shape[0]), image.shape[1]:(2*image.shape[1]), :] = image
#image = tmp.copy()

"""
# Initialize the ORB detector and detect the keypoints in the query image and scene
orb = cv2.ORB_create()
query_keypoints, query_descriptors = orb.detectAndCompute(image, None)
scene_keypoints, scene_descriptors = orb.detectAndCompute(needle_img, None)
# Match the keypoints using Brute Force Matcher
bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
matches = bf.match(query_descriptors, scene_descriptors)
# Sort the matches by distance
matches = sorted(matches, key=lambda x: x.distance)
# if these match, there should be a lot of matches,
# use that as criteria to state whether it is a match
# Calculate the distance between the matched keypoints
distance = np.mean([match.distance for match in matches])
"""

# Preprocess the images
image_pr = cv2.resize(image, (112, 112))
#image_pr = Image.resize(image_pr, (112, 112))
#image_pr = imutils.resize(image, width=112)
image_pr = tf.keras.applications.vgg16.preprocess_input(image_pr)

needle_img_pr = cv2.resize(needle_img, (112, 112))
#needle_img = Image.resize(needle_img, (112, 112))
#needle_img_pr = imutils.resize(needle_img, width=112)
needle_img_pr = tf.keras.applications.vgg16.preprocess_input(needle_img_pr)

# Extract the features from the images
Expand All @@ -161,50 +129,26 @@ def matching(image, needle_img):
similarity = cosine_similarity(query_features.reshape(1, -1), scene_features.reshape(1, -1))
distance = similarity[0][0]

# Print the similarity score
print(f"The similarity score between the two images is {similarity[0][0]}")

#plt.imshow(image_pr)
#plt.show()

print(distance)
#return 1

# needle_img[needle_img == 0] = 255
#result = cv2.matchTemplate(image, needle_img, cv2.TM_SQDIFF_NORMED) # cv2.TM_SQDIFF_NORMED) # cv2.TM_CCOEFF_NORMED) # TM_CCORR_NORMED
#min_val, max_val, _, _ = cv2.minMaxLoc(result)
# print(max_val)

#print(min_val, max_val)
# Print the similarity score
# print(f"The similarity score between the two images is {similarity[0][0]}")

if distance > 0.52: # distance < 1300:
#"""
if distance > 0.40: # distance < 1300:
"""
fig, ax = plt.subplots(1, 3)
ax[0].imshow(image)
ax[1].imshow(needle_img)
#ax[1].set_title(str(min_val) + " | " + str(max_val))
ax[1].set_title(str(distance))
#ax[2].imshow(mask)
plt.show()
#"""
"""

return 1

#if max_val > 0.7:
# return 1

if False: # (distance < 57.5) and (len(matches) > 8):
#for m in matches:
# print(m.distance)

fig, ax = plt.subplots(1, 3)
ax[0].imshow(image)
ax[1].imshow(needle_img)
#ax[1].set_title(str(min_val) + " | " + str(max_val))
ax[1].set_title(str(distance))
#ax[2].imshow(mask)
plt.show()
return 1

#if min_val < 0.75:
# return 1

return 0

Expand All @@ -222,38 +166,51 @@ def get_image_directory(directory):
yield os.path.join(directory, file).replace("\\", "/")


def find_the_animals(directory: str):
def find_the_animals(pets_directory: str, food_directory: str):
"""
overall method for detecting which animals are on screen
"""
list_of_animals = []
images, references = get_animal_from_screen()

# go through all the animals images in the directory
pet_paths = [directory + filename for filename in os.listdir(directory)]
pet_paths = [[pets_directory, filename] for filename in os.listdir(pets_directory)]
food_paths = [[food_directory, filename] for filename in os.listdir(food_directory)]

print("N IMAGES:", len(images))
for image in images:
#plt.imshow(image)
#plt.show()
for filename in os.listdir(directory):
pet_path = directory + filename
im = cv2.imread(pet_path, cv2.COLOR_BGR2RGB)[..., :3][:, ::-1, :]
# im = cv2.resize(im, (150, 150))

# matching returns which animals
for image in images[:5]:
# plt.imshow(image)
# plt.show()
# continue
for directory, filename in pet_paths:
# supported pets
if filename.split(".")[0] not in supported_pets:
continue
if filename.startswith(".DS_Store"):
continue
im = cv2.imread(directory + filename, cv2.COLOR_BGR2RGB)[..., :3][:, ::-1, :]
if matching(image, im):
list_of_animals.append("pet-" + filename.split(".")[0])
list_of_animals.append(filename.split(".")[0])
break

for image in images[5:]:
# plt.imshow(image)
# plt.show()
# continue
for directory, filename in food_paths:
if filename.split(".")[0] not in supported_food:
continue
if filename.startswith(".DS_Store"):
continue
im = cv2.imread(directory + filename, cv2.COLOR_BGR2RGB)[..., :3][:, ::-1, :]
if matching(image, im):
list_of_animals.append(filename.split(".")[0])
break

if len(list_of_animals) > 7:
return 0

#list_of_animals1 = []
#for i in list_of_animals:
# temp = i.split('/')
# list_of_animals1.append(temp[-2])
list_of_animals1 = list_of_animals.copy()

print(list_of_animals)

list_of_animals1 = tuple(list_of_animals1)
Expand Down

0 comments on commit 1185361

Please sign in to comment.