33from typing import TypeVar
44
55import torch
6+ import time
67
78from delphi .latents import (
89 ActivatingExample ,
@@ -35,7 +36,7 @@ def _to_string_examples(
3536 for i , example in enumerate (examples ):
3637 str_toks = example .str_tokens
3738 activations = example .activations .tolist ()
38- highlighted_examples .append (self ._highlight (str_toks , activations ))
39+ highlighted_examples .append ("Example " + str ( i ) + ": \n " + self ._highlight (str_toks , activations ))
3940
4041 if show_activations :
4142 assert (
@@ -106,7 +107,7 @@ class HillClimbing:
106107 explainer : IterativeExplainer
107108 """Explainer to use for explanation generation."""
108109
109- n_loops : int = 5
110+ n_loops : int = 1
110111 """Number of loops to run the explanation generation."""
111112
112113 def _compute_score (self , results : list [list [ClassifierOutput ]]) -> float :
@@ -151,7 +152,9 @@ def _get_wrong_examples(
151152 activations = torch .tensor (sample .activations ),
152153 str_tokens = sample .str_tokens ,
153154 )
154- wrong_examples .append (new_example )
155+ # Check if there's no existing example with same str_tokens
156+ if not any (ex .str_tokens == new_example .str_tokens for ex in wrong_examples ):
157+ wrong_examples .append (new_example )
155158 return wrong_examples
156159
157160 async def __call__ (self , record : LatentRecord ) -> list [ScorerResult ] | None :
@@ -165,14 +168,17 @@ async def __call__(self, record: LatentRecord) -> list[ScorerResult] | None:
165168 random .shuffle (non_activating_examples )
166169
167170 first_generation_examples = train_examples
168- held_out_set_size = len (record .test ) // 3
171+ held_out_set_size = max ( 50 , len (record .test ) // 3 )
169172 held_out_activating_examples = activating_examples [:held_out_set_size ]
170173 held_out_non_activating_examples = non_activating_examples [:held_out_set_size ]
171174 train_test_activating_examples = activating_examples [held_out_set_size :]
172175 train_test_non_activating_examples = non_activating_examples [held_out_set_size :]
173176
177+ start_time = time .time ()
174178 first_explanation = await self .explainer (record )
175179 record .explanation = first_explanation .explanation
180+ end_time = time .time ()
181+ print (f"Time taken for first explanation: { end_time - start_time } seconds" )
176182
177183 print ("----- First explanation ------" )
178184 print (first_explanation .explanation )
@@ -188,56 +194,74 @@ async def __call__(self, record: LatentRecord) -> list[ScorerResult] | None:
188194 explanation = first_explanation .explanation ,
189195 )
190196 results = []
197+ scores = []
198+ start_time = time .time ()
191199 for scorer in self .scorers :
192200 result = await scorer (test_record )
193- results .append (result .score )
194- print ("----- Holdout score ------" )
195- holdout_score = self ._compute_score (results )
196-
197- scores = [results ]
201+ results .append (result )
202+ scores .append (result .score )
203+ end_time = time .time ()
204+ print (f"Time taken for holdout score: { end_time - start_time } seconds" )
205+ #print("----- Holdout score ------")
206+ holdout_score = self ._compute_score (scores )
207+
208+
198209 for i in range (self .n_loops ):
199- print (f"----- Loop { i } ------" )
210+ # print(f"----- Loop {i} ------")
200211
201212 random .shuffle (train_test_non_activating_examples )
202213 random .shuffle (train_test_activating_examples )
203214
204215 new_record = LatentRecord (
205216 latent = record .latent ,
206217 train = first_generation_examples ,
207- not_active = train_test_non_activating_examples [:held_out_set_size ],
208- test = train_test_activating_examples [:held_out_set_size ],
218+ not_active = train_test_non_activating_examples [:15 ],
219+ test = train_test_activating_examples [:15 ],
209220 explanation = first_explanation .explanation ,
210221 )
211- results = []
222+ scores = []
223+ start_time = time .time ()
212224 for scorer in self .scorers :
213225 result = await scorer (new_record )
214- results .append (result .score )
215- print ("----- Train score ------" )
216- _ = self ._compute_score (results )
226+ scores .append (result .score )
227+ end_time = time .time ()
228+ print (f"Time taken for train score: { end_time - start_time } seconds" )
229+ #print("----- Train score ------")
230+ _ = self ._compute_score (scores )
217231 # get the wrong examples
218- wrong_examples = self ._get_wrong_examples (results )
232+ wrong_examples = self ._get_wrong_examples (scores )
219233 # update the record
220234 record .extra_examples .extend (wrong_examples )
221235 # update the explanation
236+ start_time = time .time ()
222237 new_explanation = await self .explainer (record )
238+ end_time = time .time ()
239+ print (f"Time taken for new explanation: { end_time - start_time } seconds" )
223240 if new_explanation .explanation == "Explanation could not be parsed." :
224241 print ("Error generating explanation" )
225242 pass # we do not update the explanation
226243
227- print ("----- New explanation ------" )
228- print (new_explanation .explanation )
244+ # print("----- New explanation ------")
245+ # print(new_explanation.explanation)
229246 # compute the score in the held out set
230247 test_record .explanation = new_explanation .explanation
231- results = []
248+ scores = []
249+ start_time = time .time ()
232250 for scorer in self .scorers :
233251 result = await scorer (test_record )
234- results .append (result .score )
235- print ("----- Holdout score ------" )
236- new_holdout_score = self ._compute_score (results )
237- scores .append (results )
238- if new_holdout_score > holdout_score :
239- holdout_score = new_holdout_score
240- record .explanation = new_explanation .explanation
241- first_explanation = new_explanation
242-
243- return scores
252+ scores .append (result .score )
253+ results .append (result )
254+ end_time = time .time ()
255+ print (f"Time taken for holdout score: { end_time - start_time } seconds" )
256+ #print("----- Holdout score ------")
257+ final_score = self ._compute_score (scores )
258+ record .explanation = new_explanation .explanation
259+ first_explanation = new_explanation
260+ #if new_holdout_score > holdout_score:
261+ # holdout_score = new_holdout_score
262+ # record.explanation = new_explanation.explanation
263+ # first_explanation = new_explanation
264+ print ("Initial score: " , holdout_score )
265+ print ("Last explanation: " , record .explanation )
266+ print ("Final score: " , final_score )
267+ return results
0 commit comments