Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 86 additions & 4 deletions src/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,25 @@ def get_state(self)->TreeState:
interactive_elements=self.get_interactive_elements()
return TreeState(interactive_elements=interactive_elements)


def get_interactive_elements(self)->list:
interactive_elements=[]
element_tree = self.get_element_tree()
nodes=element_tree.findall('.//node[@visible-to-user="true"][@enabled="true"]')

try:
element_tree = self.get_element_tree()
nodes=element_tree.findall('.//node[@visible-to-user="true"][@enabled="true"]')

except Exception as e:
print(f"Error getting element tree: {e}")
return interactive_elements

for node in nodes:
attributes=node.attrib
if attributes.get('text') or attributes.get('content-desc') or attributes.get('class') in INTERACTIVE_CLASSES:
try:
attributes=node.attrib

if not self.is_interactive_element(attributes) or not attributes.get('bounds'):
continue

x1,y1,x2,y2 = extract_cordinates(attributes.get('bounds'))
name=attributes.get('text') or attributes.get('content-desc')
x_center,y_center = get_center_cordinates((x1,y1,x2,y2))
Expand All @@ -37,8 +49,78 @@ def get_interactive_elements(self)->list:
'coordinates':CenterCord(x=x_center,y=y_center),
'bounding_box':BoundingBox(x1=x1,y1=y1,x2=x2,y2=y2)
}))

except (ValueError, TypeError, AttributeError) as e:
print(f"Error processing element: {e}, skipping element")
continue

except Exception as e:
print(f"Unexpected error processing element: {e}, skipping element")
continue
return interactive_elements


def is_interactive_element(self, attributes: dict) -> bool:

element_class = attributes.get('class', '')

if element_class in INTERACTIVE_CLASSES:
return True

elif attributes.get('clickable') == 'true':
return True

elif attributes.get('focusable') == 'true':
element_class = attributes.get('class', '')
if any(input_class in element_class for input_class in [
'EditText', 'AutoCompleteTextView', 'MultiAutoCompleteTextView'
]):
return True

elif attributes.get('scrollable') == 'true':
bounds = attributes.get('bounds')

if bounds:
coords = extract_cordinates(bounds)
if coords:
x1, y1, x2, y2 = coords
width, height = x2 - x1, y2 - y1
if width > 100 and height > 100:
return True

elif element_class == "android.widget.TextView":
text = attributes.get('text', '').strip()

button_keywords = [
'login', 'submit', 'send', 'save', 'delete', 'edit', 'cancel', 'ok', 'yes', 'no',
'continue', 'next', 'previous', 'back', 'home', 'menu', 'settings', 'help',
'sign in', 'sign up', 'log in', 'log out', 'register', 'create', 'update',
'confirm', 'proceed', 'finish', 'done', 'apply', 'reset', 'clear'
]

if text and any(keyword in text.lower() for keyword in button_keywords):
return True

if element_class == "android.widget.ImageView":
content_desc = attributes.get('content-desc', '').strip()
if content_desc:
interactive_descriptions = [
'button', 'menu', 'icon', 'avatar', 'profile', 'settings', 'back', 'close',
'search', 'filter', 'sort', 'refresh', 'reload', 'share', 'favorite',
'bookmark', 'like', 'download', 'upload', 'play', 'pause', 'stop'
]
if any(desc in content_desc.lower() for desc in interactive_descriptions):
return True

if element_class in [
"android.widget.ListView",
"androidx.recyclerview.widget.RecyclerView",
]:
return True

return False


def annotated_screenshot(self, nodes: list[ElementNode],scale:float=0.7) -> Image.Image:
screenshot = self.mobile.get_screenshot(scale=scale)
# Add padding
Expand Down
3 changes: 2 additions & 1 deletion src/tree/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
"android.widget.RadioButton",
"android.widget.Spinner",
"android.widget.SeekBar"
]
]