Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/data/butterfly_coords_text.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dataset:
caption_builder:
_target_: src.data.butterfly_caption_builder.ButterflyCaptionBuilder
templates_fname: v3.json
concepts_fname: v1.json
concepts_fname: v2.json
data_dir: ${paths.data_dir}/s2bms
seed: ${seed}

Expand Down
98 changes: 71 additions & 27 deletions src/models/text_alignment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,29 @@ def _setup(self, stage: str = "fit") -> None:
self.trainable_modules.append("geo_encoder.extra_projector")

# Configure contrastive retrieval evaluation
self.setup_retrieval_evaluation()
self.setup_retrieval_evaluation(verbose=0)
print("------------------------")

def setup_retrieval_evaluation(self):
def setup_retrieval_evaluation(self, use_train_threshold=True, verbose=0):
"""Set up the contrastive retrieval evaluation by computing dynamic k baselines and
initializing the validation object.

Parameters:
- use_train_threshold: whether to use the train set to compute the dynamic k baselines (if False, will calculate one for each set)
- verbose: whether to print the dynamic ks and their baselines
"""
self.concept_configs = self.trainer.datamodule.concept_configs
self.concepts = [c["concept_caption"] for c in self.concept_configs]
self.concept_names = [
f"{c['col'].replace('aux_', '')}_{'max' if c['is_max'] else 'min'}"
for c in self.concept_configs
]

dataset_names = ["train", "val", "test"]
dataset_names = [
"train",
"val",
"test",
] # ensure 'train' is first for use_train_threshold logic!
self.dynamic_k_baselines = {}
for dataset_name in dataset_names:
if not hasattr(self.trainer.datamodule, f"data_{dataset_name}"):
Expand All @@ -124,14 +135,34 @@ def setup_retrieval_evaluation(self):
c_name = self.concept_names[i_c]
aux_vals_current_ds = aux_vals_per_concept[i_c]

theta_k = self.find_elbow_point(aux_vals_current_ds)
self.concept_configs[i_c][
"theta_k"
] = theta_k # assign new theta_k to concept_configs for later use in validation
if use_train_threshold and dataset_name != "train":
theta_k = self.concept_configs[i_c]["theta_k"]
else:
theta_k = self.find_elbow_point(aux_vals_current_ds)
self.concept_configs[i_c][
"theta_k"
] = theta_k # assign new theta_k to concept_configs for later use in validation
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we save this computed theta_k to a file anywhere?
If not, maybe we should at least save it to the checkpoint?


if c["is_max"]:
n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds)
else:
n_baseline = sum(aux_val <= theta_k for aux_val in aux_vals_current_ds)

if n_baseline == n_ds:
n_baseline = (
n_ds - 1
) # to avoid having a baseline of 100% (will still yield index score of 1)
if dataset_name == "train":
theta_k = (
min(aux_vals_current_ds) + 1e-6
if c["is_max"]
else max(aux_vals_current_ds) - 1e-6
)

if verbose:
print(
f"Concept '{self.concept_names[i_c]}' in {dataset_name} set: is_max={c['is_max']}, original theta_k={self.concept_configs[i_c]['theta_k']:.6f}, new theta_k={theta_k:.6f}, baseline={n_baseline}/{n_ds} ({n_baseline / n_ds * 100:.1f}%)"
)
self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100

self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs)
Expand Down Expand Up @@ -213,7 +244,7 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"):

return loss

def _on_epoch_end(self, mode: str):
def _on_epoch_end(self, mode: str, verbose=0):

# Combine batches
geo_feats = torch.cat([x["geo_feats"] for x in self.outputs_epoch_memory], dim=0)
Expand All @@ -228,15 +259,17 @@ def _on_epoch_end(self, mode: str):

avr_scores = {f"{mode}_avr_top-{k}": [] for k in self.ks}
for i, result in concept_scores.items():
print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:')
if verbose:
print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:')
for k, v in result.items():
if k == "dynamic_k":
self.log(f"dyn_k_{self.concept_names[i]}", v, **self.log_kwargs)
indexed_v = (v - self.dynamic_k_baselines[mode][self.concept_names[i]]) / (
100 - self.dynamic_k_baselines[mode][self.concept_names[i]]
)
self.log(f"dyn_k_index_{self.concept_names[i]}", indexed_v, **self.log_kwargs)
print(f"Top-{k}: {v:.1f}%")
if verbose:
print(f"Top-{k}: {v:.1f}%")
avr_scores[f"{mode}_avr_top-{k}"].append(v)

for k, v in avr_scores.items():
Expand Down Expand Up @@ -279,20 +312,31 @@ def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor:

@staticmethod
def find_elbow_point(vals):
    """Return the elbow (knee) value of the sorted distribution of ``vals``.

    The values are sorted, NaNs are discarded, and every occurrence of the
    minimum is dropped. The remaining values are plotted against their
    normalised rank, and the elbow is the value lying farthest from the
    straight line (chord) joining the first and last remaining points.

    Parameters:
    - vals: a tensor, or an iterable of numbers and/or single-element tensors.

    Returns:
    - the elbow value as a Python float. If the values are all identical, or
      identical once the minimum has been dropped, that common value is
      returned.

    Raises:
    - ValueError: if no non-NaN value is left to work with.
    """
    # Convert to a flat float64 array without tracking gradients. Going through
    # Python floats avoids the silent float32 downcast of ``torch.tensor``.
    if torch.is_tensor(vals):
        arr = vals.detach().cpu().numpy()
    else:
        arr = np.asarray(
            [v.detach().cpu().item() if torch.is_tensor(v) else v for v in vals]
        )
    arr = arr.astype(float).ravel()
    arr = np.sort(arr[~np.isnan(arr)])  # remove NaN values, then sort ascending

    if arr.size == 0:
        raise ValueError("find_elbow_point needs at least one non-NaN value.")

    # Drop every occurrence of the minimum before looking for the elbow.
    y = arr[arr > arr[0]]

    # Degenerate cases: nothing left (all values equal the minimum), or the
    # remainder is flat. Both would make the chord undefined or horizontal, so
    # the constant value itself is the only sensible answer.
    if y.size == 0 or y[0] == y[-1]:
        constant = arr[0] if y.size == 0 else y[0]
        print(
            "All values are the same, returning the value itself as elbow point.",
            constant,
        )
        return float(constant)

    x = np.arange(y.size) / y.size
    slope = (y[-1] - y[0]) / (x[-1] - x[0])  # chord from first to last point
    intercept = y[0] - slope * x[0]

    # The perpendicular distance to the chord is |residual| / sqrt(1 + slope^2).
    # The denominator is the same for every point, so the vertical residual has
    # the same argmax and needs no division by the slope.
    residuals = np.abs(y - (slope * x + intercept))
    elbow_index = int(np.argmax(residuals))
    return float(y[elbow_index])
Loading