Skip to content

Commit 3eaec1f

Browse files
author
SrGonao
committed
Debug prints
1 parent a0b6452 commit 3eaec1f

File tree

2 files changed

+60
-32
lines changed

2 files changed

+60
-32
lines changed

delphi/explainers/iterative/iterative.py

Lines changed: 54 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from typing import TypeVar
44

55
import torch
6+
import time
67

78
from delphi.latents import (
89
ActivatingExample,
@@ -35,7 +36,7 @@ def _to_string_examples(
3536
for i, example in enumerate(examples):
3637
str_toks = example.str_tokens
3738
activations = example.activations.tolist()
38-
highlighted_examples.append(self._highlight(str_toks, activations))
39+
highlighted_examples.append("Example " + str(i) + ": \n" + self._highlight(str_toks, activations))
3940

4041
if show_activations:
4142
assert (
@@ -106,7 +107,7 @@ class HillClimbing:
106107
explainer: IterativeExplainer
107108
"""Explainer to use for explanation generation."""
108109

109-
n_loops: int = 5
110+
n_loops: int = 1
110111
"""Number of loops to run the explanation generation."""
111112

112113
def _compute_score(self, results: list[list[ClassifierOutput]]) -> float:
@@ -151,7 +152,9 @@ def _get_wrong_examples(
151152
activations=torch.tensor(sample.activations),
152153
str_tokens=sample.str_tokens,
153154
)
154-
wrong_examples.append(new_example)
155+
# Check if there's no existing example with same str_tokens
156+
if not any(ex.str_tokens == new_example.str_tokens for ex in wrong_examples):
157+
wrong_examples.append(new_example)
155158
return wrong_examples
156159

157160
async def __call__(self, record: LatentRecord) -> list[ScorerResult] | None:
@@ -165,14 +168,17 @@ async def __call__(self, record: LatentRecord) -> list[ScorerResult] | None:
165168
random.shuffle(non_activating_examples)
166169

167170
first_generation_examples = train_examples
168-
held_out_set_size = len(record.test) // 3
171+
held_out_set_size = max(50, len(record.test) // 3)
169172
held_out_activating_examples = activating_examples[:held_out_set_size]
170173
held_out_non_activating_examples = non_activating_examples[:held_out_set_size]
171174
train_test_activating_examples = activating_examples[held_out_set_size:]
172175
train_test_non_activating_examples = non_activating_examples[held_out_set_size:]
173176

177+
start_time = time.time()
174178
first_explanation = await self.explainer(record)
175179
record.explanation = first_explanation.explanation
180+
end_time = time.time()
181+
print(f"Time taken for first explanation: {end_time - start_time} seconds")
176182

177183
print("----- First explanation ------")
178184
print(first_explanation.explanation)
@@ -188,56 +194,74 @@ async def __call__(self, record: LatentRecord) -> list[ScorerResult] | None:
188194
explanation=first_explanation.explanation,
189195
)
190196
results = []
197+
scores = []
198+
start_time = time.time()
191199
for scorer in self.scorers:
192200
result = await scorer(test_record)
193-
results.append(result.score)
194-
print("----- Holdout score ------")
195-
holdout_score = self._compute_score(results)
196-
197-
scores = [results]
201+
results.append(result)
202+
scores.append(result.score)
203+
end_time = time.time()
204+
print(f"Time taken for holdout score: {end_time - start_time} seconds")
205+
#print("----- Holdout score ------")
206+
holdout_score = self._compute_score(scores)
207+
208+
198209
for i in range(self.n_loops):
199-
print(f"----- Loop {i} ------")
210+
#print(f"----- Loop {i} ------")
200211

201212
random.shuffle(train_test_non_activating_examples)
202213
random.shuffle(train_test_activating_examples)
203214

204215
new_record = LatentRecord(
205216
latent=record.latent,
206217
train=first_generation_examples,
207-
not_active=train_test_non_activating_examples[:held_out_set_size],
208-
test=train_test_activating_examples[:held_out_set_size],
218+
not_active=train_test_non_activating_examples[:15],
219+
test=train_test_activating_examples[:15],
209220
explanation=first_explanation.explanation,
210221
)
211-
results = []
222+
scores = []
223+
start_time = time.time()
212224
for scorer in self.scorers:
213225
result = await scorer(new_record)
214-
results.append(result.score)
215-
print("----- Train score ------")
216-
_ = self._compute_score(results)
226+
scores.append(result.score)
227+
end_time = time.time()
228+
print(f"Time taken for train score: {end_time - start_time} seconds")
229+
#print("----- Train score ------")
230+
_ = self._compute_score(scores)
217231
# get the wrong examples
218-
wrong_examples = self._get_wrong_examples(results)
232+
wrong_examples = self._get_wrong_examples(scores)
219233
# update the record
220234
record.extra_examples.extend(wrong_examples)
221235
# update the explanation
236+
start_time = time.time()
222237
new_explanation = await self.explainer(record)
238+
end_time = time.time()
239+
print(f"Time taken for new explanation: {end_time - start_time} seconds")
223240
if new_explanation.explanation == "Explanation could not be parsed.":
224241
print("Error generating explanation")
225242
pass # we do not update the explanation
226243

227-
print("----- New explanation ------")
228-
print(new_explanation.explanation)
244+
#print("----- New explanation ------")
245+
#print(new_explanation.explanation)
229246
# compute the score in the held out set
230247
test_record.explanation = new_explanation.explanation
231-
results = []
248+
scores = []
249+
start_time = time.time()
232250
for scorer in self.scorers:
233251
result = await scorer(test_record)
234-
results.append(result.score)
235-
print("----- Holdout score ------")
236-
new_holdout_score = self._compute_score(results)
237-
scores.append(results)
238-
if new_holdout_score > holdout_score:
239-
holdout_score = new_holdout_score
240-
record.explanation = new_explanation.explanation
241-
first_explanation = new_explanation
242-
243-
return scores
252+
scores.append(result.score)
253+
results.append(result)
254+
end_time = time.time()
255+
print(f"Time taken for holdout score: {end_time - start_time} seconds")
256+
#print("----- Holdout score ------")
257+
final_score = self._compute_score(scores)
258+
record.explanation = new_explanation.explanation
259+
first_explanation = new_explanation
260+
#if new_holdout_score > holdout_score:
261+
# holdout_score = new_holdout_score
262+
# record.explanation = new_explanation.explanation
263+
# first_explanation = new_explanation
264+
print("Initial score: ", holdout_score)
265+
print("Last explanation: ", record.explanation)
266+
print("Final score: ", final_score)
267+
return results

delphi/explainers/iterative/prompt_builder.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,10 +10,14 @@ def build_prompt(
1010
messages = [{"role": "system", "content": SYSTEM}]
1111

1212
user_start = f"Current explanation: {explanation}\n\n"
13+
print("Current explanation: ", explanation)
1314
user_start += f"Normal examples:\n{normal_examples}\n\n"
14-
user_start += f"False positives:\n{false_positives}\n\n"
15+
#print("Normal examples: ", normal_examples)
1516
user_start += f"False negatives:\n{false_negatives}\n"
16-
17+
#print("False negatives: ", false_negatives)
18+
user_start += f"False positives:\n{false_positives}\n\n"
19+
#print("False positives: ", false_positives)
20+
1721
messages.append(
1822
{
1923
"role": "user",

0 commit comments

Comments
 (0)