2323
2424
2525class Gemma :
26- def __init__ (self , image_path : str , output_file : str ) -> None :
26+ def __init__ (self , image_path : str , output_file : str , debug : bool ) -> None :
2727 self .client = genai .Client (api_key = GEMINI_API_KEY )
2828 self .model = "gemma-3-27b-it"
2929 self .image_path = image_path
3030 self .output_file = output_file
31+ self .debug = debug
3132
32- def get_response (self , image_path ) -> dict :
33+ def get_response (self , image_path : str , prev_resp : str = "" ) -> dict :
3334 """
3435 Requests a single response from the model through the genai client.
3536
@@ -39,7 +40,7 @@ def get_response(self, image_path) -> dict:
3940 Returns:
4041 dict: The JSON dump of the model's response or error if one is generated.
4142 """
42- prompt = """
43+ base_prompt = """
4344Analyze this image and identify every visible object.
4445
4546For each object, return:
@@ -74,7 +75,19 @@ def get_response(self, image_path) -> dict:
7475 "bounding_box": [[0.2, 0.25], [0.45, 0.25], [0.45, 0.45], [0.2, 0.45]]
7576 }
7677}
78+
79+ """
80+ prev_resp = (
81+ """
82+ The following is the output from the last frame that you predicted. You must attempt to be as
83+ consistent as possible, but keep in mind that new items may have appeared and bounding boxes will
84+ have changed:
85+
7786"""
87+ + prev_resp
88+ )
89+ if prev_resp :
90+ prompt = base_prompt + prev_resp
7891 img = PIL .Image .open (image_path )
7992 response = self .client .models .generate_content (
8093 model = self .model , contents = [prompt , img ]
@@ -98,19 +111,30 @@ def run_nth_frame(self, n: int):
98111 str: The final json output
99112 """
100113
101- folder_path = Path (self .image_path [: self . image_path . rfind ( "/" )])
114+ folder_path = Path (self .image_path ). parent
102115 file_count = sum (1 for item in folder_path .iterdir () if item .is_file ())
103116
104117 results = {}
105- image_path = list (self .image_path )
118+ prev_resp = None
119+
120+ for i in range (1 , file_count , n ):
121+ image_path = folder_path / f"{ i :04} .jpg"
122+ image_path_str = str (image_path )
123+
124+ print ("Running Gemma on" , image_path_str )
125+
126+ if prev_resp is None :
127+ resp = self .get_response (image_path = image_path_str )
128+ else :
129+ resp = self .get_response (
130+ image_path = image_path_str , prev_resp = str (prev_resp )
131+ )
132+
133+ results [i ] = resp
134+ prev_resp = resp # Setting it simply to the response from this one
106135
107- i = 0
108- while i < file_count :
109- image_path [- 8 :] = f"{ i + 1 :04} .jpg"
110- image_path_str = "" .join (image_path )
111- temp = self .get_response (image_path = image_path_str )
112- results [i + 1 ] = temp # Keyed by frame number
113- i += n
136+ if self .debug :
137+ print ("DEBUG:" , resp )
114138
115139 return results
116140
@@ -126,7 +150,9 @@ def save_to_json(self, json_input: dict) -> None:
126150
127151
# Driver: run the model on every 10th frame of the albert_room capture and
# persist the collected responses.
gemma = Gemma(
    image_path="data/env_imgs/albert_room/frame_0001.jpg",
    output_file="data/vision_json/albert_room.json",
    debug=False,
)
result = gemma.run_nth_frame(n=10)
gemma.save_to_json(json_input=result)
0 commit comments