Skip to content

Commit bb8d990

Browse files
committed
FEAT: Give Gemma last-frame context so that it can be more consistent with its predictions across frames
1 parent 8fe533d commit bb8d990

File tree

1 file changed

+39
-13
lines changed

1 file changed

+39
-13
lines changed

src/video_processing/material_tagging/gemma_loader.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323

2424

2525
class Gemma:
26-
def __init__(self, image_path: str, output_file: str) -> None:
26+
def __init__(self, image_path: str, output_file: str, debug: bool) -> None:
2727
self.client = genai.Client(api_key=GEMINI_API_KEY)
2828
self.model = "gemma-3-27b-it"
2929
self.image_path = image_path
3030
self.output_file = output_file
31+
self.debug = debug
3132

32-
def get_response(self, image_path) -> dict:
33+
def get_response(self, image_path: str, prev_resp: str = "") -> dict:
3334
"""
3435
Requests a single response from the Gemma model through the genai client.
3536
@@ -39,7 +40,7 @@ def get_response(self, image_path) -> dict:
3940
Returns:
4041
dict: The JSON dump of the model's response or error if one is generated.
4142
"""
42-
prompt = """
43+
base_prompt = """
4344
Analyze this image and identify every visible object.
4445
4546
For each object, return:
@@ -74,7 +75,19 @@ def get_response(self, image_path) -> dict:
7475
"bounding_box": [[0.2, 0.25], [0.45, 0.25], [0.45, 0.45], [0.2, 0.45]]
7576
}
7677
}
78+
79+
"""
80+
prev_resp = (
81+
"""
82+
The following is the output from the last frame that you predicted. You must attempt to be as
83+
consistent as possible, but keep in mind that new items may have appeared and bounding boxes will
84+
have changed:
85+
7786
"""
87+
+ prev_resp
88+
)
89+
if prev_resp:
90+
prompt = base_prompt + prev_resp
7891
img = PIL.Image.open(image_path)
7992
response = self.client.models.generate_content(
8093
model=self.model, contents=[prompt, img]
@@ -98,19 +111,30 @@ def run_nth_frame(self, n: int):
98111
str: The final json output
99112
"""
100113

101-
folder_path = Path(self.image_path[: self.image_path.rfind("/")])
114+
folder_path = Path(self.image_path).parent
102115
file_count = sum(1 for item in folder_path.iterdir() if item.is_file())
103116

104117
results = {}
105-
image_path = list(self.image_path)
118+
prev_resp = None
119+
120+
for i in range(1, file_count, n):
121+
image_path = folder_path / f"{i:04}.jpg"
122+
image_path_str = str(image_path)
123+
124+
print("Running Gemma on", image_path_str)
125+
126+
if prev_resp is None:
127+
resp = self.get_response(image_path=image_path_str)
128+
else:
129+
resp = self.get_response(
130+
image_path=image_path_str, prev_resp=str(prev_resp)
131+
)
132+
133+
results[i] = resp
134+
prev_resp = resp # Carry this frame's response forward as context for the next frame
106135

107-
i = 0
108-
while i < file_count:
109-
image_path[-8:] = f"{i + 1:04}.jpg"
110-
image_path_str = "".join(image_path)
111-
temp = self.get_response(image_path=image_path_str)
112-
results[i + 1] = temp # Keyed by frame number
113-
i += n
136+
if self.debug:
137+
print("DEBUG:", resp)
114138

115139
return results
116140

@@ -126,7 +150,9 @@ def save_to_json(self, json_input: dict) -> None:
126150

127151

128152
gemma = Gemma(
129-
"data/env_imgs/albert_room/frame_0001.jpg", "data/vision_json/albert_room.json"
153+
"data/env_imgs/albert_room/frame_0001.jpg",
154+
"data/vision_json/albert_room.json",
155+
False,
130156
)
131157
result = gemma.run_nth_frame(10)
132158
gemma.save_to_json(result)

0 commit comments

Comments
 (0)