diff --git a/configs/data/butterfly_coords_text.yaml b/configs/data/butterfly_coords_text.yaml index 9c98968..90f3d9e 100644 --- a/configs/data/butterfly_coords_text.yaml +++ b/configs/data/butterfly_coords_text.yaml @@ -14,7 +14,7 @@ dataset: caption_builder: _target_: src.data.butterfly_caption_builder.ButterflyCaptionBuilder templates_fname: v3.json - concepts_fname: v1.json + concepts_fname: v2.json data_dir: ${paths.data_dir}/s2bms seed: ${seed} diff --git a/src/models/text_alignment_model.py b/src/models/text_alignment_model.py index be3f072..4d537ae 100644 --- a/src/models/text_alignment_model.py +++ b/src/models/text_alignment_model.py @@ -89,10 +89,17 @@ def _setup(self, stage: str = "fit") -> None: self.trainable_modules.append("geo_encoder.extra_projector") # Configure contrastive retrieval evaluation - self.setup_retrieval_evaluation() + self.setup_retrieval_evaluation(verbose=0) print("------------------------") - def setup_retrieval_evaluation(self): + def setup_retrieval_evaluation(self, use_train_threshold=True, verbose=0): + """Set up the contrastive retrieval evaluation by computing dynamic k baselines and + initializing the validation object. + + Parameters: + - use_train_threshold: whether to use the train set to compute the dynamic k baselines (if False, will calculate one for each set) + - verbose: whether to print the dynamic ks and their baselines + """ self.concept_configs = self.trainer.datamodule.concept_configs self.concepts = [c["concept_caption"] for c in self.concept_configs] self.concept_names = [ @@ -100,7 +107,11 @@ def setup_retrieval_evaluation(self): for c in self.concept_configs ] - dataset_names = ["train", "val", "test"] + dataset_names = [ + "train", + "val", + "test", + ] # ensure 'train' is first for use_train_threshold logic! self.dynamic_k_baselines = {} for dataset_name in dataset_names: if not hasattr(self.trainer.datamodule, f"data_{dataset_name}"): @@ -124,14 +135,34 @@ def setup_retrieval_evaluation(self): c_name = self.concept_names[i_c] aux_vals_current_ds = aux_vals_per_concept[i_c] - theta_k = self.find_elbow_point(aux_vals_current_ds) - self.concept_configs[i_c][ - "theta_k" - ] = theta_k # assign new theta_k to concept_configs for later use in validation + if use_train_threshold and dataset_name != "train": + theta_k = self.concept_configs[i_c]["theta_k"] + else: + theta_k = self.find_elbow_point(aux_vals_current_ds) + self.concept_configs[i_c][ + "theta_k" + ] = theta_k # assign new theta_k to concept_configs for later use in validation + if c["is_max"]: n_baseline = sum(aux_val >= theta_k for aux_val in aux_vals_current_ds) else: n_baseline = sum(aux_val <= theta_k for aux_val in aux_vals_current_ds) + + if n_baseline == n_ds: + n_baseline = ( + n_ds - 1 + ) # to avoid having a baseline of 100% (will still yield index score of 1) + if dataset_name == "train": + theta_k = ( + min(aux_vals_current_ds) + 1e-6 + if c["is_max"] + else max(aux_vals_current_ds) - 1e-6 + ) + + if verbose: + print( + f"Concept '{self.concept_names[i_c]}' in {dataset_name} set: is_max={c['is_max']}, original theta_k={self.concept_configs[i_c]['theta_k']:.6f}, new theta_k={theta_k:.6f}, baseline={n_baseline}/{n_ds} ({n_baseline / n_ds * 100:.1f}%)" + ) self.dynamic_k_baselines[dataset_name][c_name] = n_baseline / n_ds * 100 self.contrastive_val = RetrievalContrastiveValidation(self.ks, self.concept_configs) @@ -213,7 +244,7 @@ def _step(self, batch: Dict[str, torch.Tensor], mode: str = "train"): return loss - def _on_epoch_end(self, mode: str): + def _on_epoch_end(self, mode: str, verbose=0): # Combine batches geo_feats = torch.cat([x["geo_feats"] for x in self.outputs_epoch_memory], dim=0) @@ -228,7 +259,8 @@ def _on_epoch_end(self, mode: str): avr_scores = {f"{mode}_avr_top-{k}": [] for k in self.ks} for i, result in concept_scores.items(): - print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') + if verbose: + print(f'\nConcept "{self.concepts[i]}" average top-k accuracies in {mode} split:') for k, v in result.items(): if k == "dynamic_k": self.log(f"dyn_k_{self.concept_names[i]}", v, **self.log_kwargs) @@ -236,7 +268,8 @@ def _on_epoch_end(self, mode: str): 100 - self.dynamic_k_baselines[mode][self.concept_names[i]] ) self.log(f"dyn_k_index_{self.concept_names[i]}", indexed_v, **self.log_kwargs) - print(f"Top-{k}: {v:.1f}%") + if verbose: + print(f"Top-{k}: {v:.1f}%") avr_scores[f"{mode}_avr_top-{k}"].append(v) for k, v in avr_scores.items(): @@ -279,20 +312,31 @@ def concept_similarities(self, geo_embeds, concept=None) -> torch.Tensor: @staticmethod def find_elbow_point(vals): - vals = np.sort(vals) - x = np.arange(len(vals)) / len(vals) - y = vals - slope = (y[-1] - y[0]) / (x[-1] - x[0]) # diagonal from first to last point - intercept = y[0] - slope * x[0] - orthogonal_slope = -1 / slope - - intercepts_orthogonal = y - orthogonal_slope * x - intersection_diagonal_orthogonal = (intercepts_orthogonal - intercept) / ( - slope - orthogonal_slope - ) - distances = np.sqrt( - (x - intersection_diagonal_orthogonal) ** 2 + (y - (slope * x + intercept)) ** 2 - ) # distance to diagonal - elbow_index = np.argmax(distances) - elbow_point = y[elbow_index] - return elbow_point + """Vals is a list of tensor values.""" + with torch.no_grad(): + vals = torch.tensor(vals).cpu().numpy() + vals = vals[~np.isnan(vals)] # remove NaN values + + vals = np.sort(vals) + vals = vals[vals > vals[0]] + x = np.arange(len(vals)) / len(vals) + y = vals + if x[0] == x[-1]: # all values are the same + print( + "All values are the same, returning the value itself as elbow point.", vals[0] + ) + return vals[0] + slope = (y[-1] - y[0]) / (x[-1] - x[0]) # diagonal from first to last point + intercept = y[0] - slope * x[0] + orthogonal_slope = -1 / slope + + intercepts_orthogonal = y - orthogonal_slope * x + intersection_diagonal_orthogonal = (intercepts_orthogonal - intercept) / ( + slope - orthogonal_slope + ) + distances = np.sqrt( + (x - intersection_diagonal_orthogonal) ** 2 + (y - (slope * x + intercept)) ** 2 + ) # distance to diagonal + elbow_index = np.argmax(distances) + elbow_point = y[elbow_index] + return elbow_point