# base

> DiffPASS base classes and mixins

In [None]:
#| default_exp base

In [None]:
#| export

# Stdlib imports
from collections.abc import Iterable, Sequence
from typing import Optional, Union, Any

# NumPy
# import numpy as np

# PyTorch
import torch

# PLOTTING
# from matplotlib import colormaps as cm
# import matplotlib.pyplot as plt
# from matplotlib.colors import CenteredNorm

In [None]:
#| export

class DiffPASSMixin:
    """Mixin with the allowed configuration options and input validators.

    The class-level sets/dicts enumerate the option names accepted by the
    ``validate_*`` methods. Classes using this mixin are expected to define
    the attributes annotated below before calling the validators that read
    them (``validate_similarities_cfg`` and ``validate_inputs``).
    """

    allowed_permutation_cfg_keys = {
        "tau",
        "n_iter",
        "noise",
        "noise_factor",
        "noise_std",
    }
    allowed_information_measures = {"MI", "TwoBodyEntropy"}
    allowed_similarity_kinds = {"Hamming", "Blosum62"}
    # Allowed `similarities_cfg` keys, indexed by similarity kind
    allowed_similarities_cfg_keys = {
        "Hamming": {"use_dot", "p"},
        "Blosum62": {"use_scoredist", "aa_to_int", "gaps_as_stars"},
    }
    allowed_reciprocal_best_hits_cfg_keys = {"tau"}

    # Sizes of the groups; their sum must equal the first dimension of the inputs
    group_sizes: Iterable[int]
    information_measure: str
    similarity_kind: str

    @staticmethod
    def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
        """Drop the entries of the last dimension that are never active in `x`.

        An entry of the last dimension is kept if `x` is non-zero at that entry
        for at least one position of the leading dimensions.
        """
        used_tokens = x
        # Collapse every leading dimension, leaving a 1D mask over the last one
        for _ in range(x.ndim - 1):
            used_tokens = used_tokens.any(-2)

        return x[..., used_tokens]

    def validate_permutation_cfg(self, permutation_cfg: dict) -> None:
        """Raise ``ValueError`` if `permutation_cfg` contains unknown keys."""
        invalid_keys = set(permutation_cfg) - self.allowed_permutation_cfg_keys
        if invalid_keys:
            raise ValueError(f"Invalid keys in `permutation_cfg`: {invalid_keys}")

    def validate_information_measure(self, information_measure: str) -> None:
        """Raise ``ValueError`` if `information_measure` is not an allowed value."""
        if information_measure not in self.allowed_information_measures:
            # Report the rejected argument, not the (possibly unset) attribute
            raise ValueError(
                f"Invalid information measure: {information_measure}. "
                f"Allowed values are: {self.allowed_information_measures}"
            )

    def validate_similarity_kind(self, similarity_kind: str) -> None:
        """Raise ``ValueError`` if `similarity_kind` is not an allowed value."""
        if similarity_kind not in self.allowed_similarity_kinds:
            # Report the rejected argument, not the (possibly unset) attribute
            raise ValueError(
                f"Invalid similarity kind: {similarity_kind}. "
                f"Allowed values are: {self.allowed_similarity_kinds}"
            )

    def validate_similarities_cfg(self, similarities_cfg: dict) -> None:
        """Raise ``ValueError`` if `similarities_cfg` contains keys that are not
        allowed for the current ``self.similarity_kind``."""
        allowed_keys = self.allowed_similarities_cfg_keys[self.similarity_kind]
        invalid_keys = set(similarities_cfg) - allowed_keys
        if invalid_keys:
            raise ValueError(f"Invalid keys in `similarities_cfg`: {invalid_keys}")

    def validate_reciprocal_best_hits_cfg(self, reciprocal_best_hits_cfg: dict) -> None:
        """Raise ``ValueError`` if `reciprocal_best_hits_cfg` contains unknown keys."""
        invalid_keys = (
            set(reciprocal_best_hits_cfg) - self.allowed_reciprocal_best_hits_cfg_keys
        )
        if invalid_keys:
            raise ValueError(
                f"Invalid keys in `reciprocal_best_hits_cfg`: {invalid_keys}"
            )

    def validate_inputs(
        self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False
    ) -> None:
        """Validate input tensors representing aligned objects.

        Both inputs must be 3D, with dimensions unpacked as
        (size, length, alphabet_size).

        Raises
        ------
        ValueError
            If either input is not 3D, if the sizes (first dimensions) of `x`
            and `y` differ, if `check_same_alphabet_size` is True and the last
            dimensions differ, or if the size of the inputs does not equal
            ``sum(self.group_sizes)``.
        """
        if x.ndim != 3 or y.ndim != 3:
            raise ValueError(
                f"Inputs must be 3D tensors, got shapes {tuple(x.shape)} and "
                f"{tuple(y.shape)}."
            )
        size_x, alphabet_size_x = x.shape[0], x.shape[-1]
        size_y, alphabet_size_y = y.shape[0], y.shape[-1]
        if size_x != size_y:
            raise ValueError(f"Size mismatch between x ({size_x}) and y ({size_y}).")
        if check_same_alphabet_size and (alphabet_size_x != alphabet_size_y):
            raise ValueError("Inputs must have the same alphabet size.")

        # Validate size attribute
        total_size = sum(self.group_sizes)
        if size_x != total_size:
            raise ValueError(
                f"Inputs have size {size_x} but `group_sizes` implies a total "
                f"size of {total_size}."
            )

In [None]:
#| export

def scalar_or_1d_tensor(
    *, param: Any, param_name: str, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Return `param` as a tensor, accepting only floats and tensors of dimension <= 1.

    A Python float is converted to a 0-dimensional tensor of the given `dtype`.
    A tensor is returned as is (its dtype is left untouched).

    Raises
    ------
    TypeError
        If `param` is neither a float nor a ``torch.Tensor``.
    ValueError
        If `param` is a tensor with more than one dimension.
    """
    if isinstance(param, float):
        return torch.tensor(param, dtype=dtype)

    if isinstance(param, torch.Tensor):
        if param.ndim > 1:
            raise ValueError(
                f"`{param_name}` must be a scalar or a tensor of dimension <= 1."
            )
        return param

    raise TypeError(f"`{param_name}` must be a float or a torch.Tensor.")


class EnsembleMixin:
    """Mixin for validating parameters that may vary along one ensemble dimension."""

    def _validate_ensemble_param(
        self,
        *,
        param: Union[float, torch.Tensor],
        param_name: str,
        ensemble_shape: Sequence[int],
        dim_in_ensemble: Optional[int] = None,
        n_dims_per_instance: Optional[int] = None,
    ) -> torch.Tensor:
        """Convert `param` to a tensor and reshape it for broadcasting.

        `param` is first passed through ``scalar_or_1d_tensor`` and then through
        ``_reshape_ensemble_param``; see the latter for the meaning of the
        remaining arguments. `dim_in_ensemble` and `n_dims_per_instance` are
        only required when `param` is 1D.
        """
        param = scalar_or_1d_tensor(param=param, param_name=param_name)

        return self._reshape_ensemble_param(
            param=param,
            ensemble_shape=ensemble_shape,
            dim_in_ensemble=dim_in_ensemble,
            n_dims_per_instance=n_dims_per_instance,
            param_name=param_name,
        )

    @staticmethod
    def _reshape_ensemble_param(
        *,
        param: torch.Tensor,
        ensemble_shape: Sequence[int],
        dim_in_ensemble: Optional[int],
        n_dims_per_instance: Optional[int],
        param_name: str,
    ) -> torch.Tensor:
        """Reshape a 1D `param` so it lies along one dimension of the ensemble.

        If `param` is not 1D it is returned unchanged. Otherwise it is cast to
        float32 and viewed with shape
        ``(1, ..., 1, len(param), 1, ..., 1) + (1,) * n_dims_per_instance``,
        where the non-singleton entry sits at index `dim_in_ensemble` among the
        ``len(ensemble_shape)`` ensemble dimensions. Negative values of
        `dim_in_ensemble` index from the end of `ensemble_shape`.

        Raises
        ------
        ValueError
            If `param` is 1D and: `dim_in_ensemble` is None or not a valid
            index into `ensemble_shape`; ``len(param)`` differs from
            ``ensemble_shape[dim_in_ensemble]``; or `n_dims_per_instance` is
            None.
        """
        if param.ndim != 1:
            return param

        n_ensemble_dims = len(ensemble_shape)
        if dim_in_ensemble is None:
            raise ValueError(
                f"`dim_in_ensemble` cannot be None if {param_name} is 1D."
            )
        param = param.to(torch.float32)
        # `param` is not a scalar: broadcast it along the `dim_in_ensemble`-th ensemble dimension
        if dim_in_ensemble >= n_ensemble_dims or dim_in_ensemble < -n_ensemble_dims:
            raise ValueError(
                f"Ensemble dimension for {param_name} must be an available index "
                f"in `ensemble_shape`."
            )
        if len(param) != ensemble_shape[dim_in_ensemble]:
            raise ValueError(
                f"Parameter `{param_name}` must have the same length as "
                f"``ensemble_shape[dim_in_ensemble]`` = "
                f"{ensemble_shape[dim_in_ensemble]}."
            )
        if n_dims_per_instance is None:
            # Fail with a clear message instead of `(1,) * None` raising a TypeError
            raise ValueError(
                f"`n_dims_per_instance` cannot be None if {param_name} is 1D."
            )

        # Negative indices pass the range check above; normalize them so that the
        # counts of leading/trailing singleton dimensions below are correct
        dim = dim_in_ensemble % n_ensemble_dims
        new_shape = (
            (1,) * dim
            + tuple(param.shape)
            + (1,) * (n_ensemble_dims - dim - 1)
            + (1,) * n_dims_per_instance
        )

        return param.view(*new_shape)

In [None]:
# class PermutationsMixin:
#     """Mixin class for validating input and plotting the results of the optimization."""
# 
#     std_init: float
#     device: torch.device
#     group_sizes: list[int]
# 
#     def _init_log_alpha(self, skip=False):
#         """Initialize log_alpha as a list of matrices of shape (s, s) where s is the
#         size of the species MSA. The matrices are initialized with normal entries scaled by `std_init`.
#         """
#         if not skip:
#             # Permutations restricted to species
#             self.log_alpha = [
#                 (self.std_init * torch.randn(s, s, device=self.device)).requires_grad_()
#                 for s in self._effective_sizes_not_fixed
#             ]
# 
#     def _validator(self, input_1, input_2, fixed_pairings=None):
#         """Validate input MSAs and check fixed pairings."""
#         # Validate input MSAs
#         size_1, length_1, alphabet_size_1 = input_1.shape[1:]
#         size_2, length_2, alphabet_size_2 = input_2.shape[1:]
#         length_1 -= 1
#         length_2 -= 1
#         if size_1 != size_2:
#             raise ValueError(
#                 f"Size mismatch between MSA 1 ({size_1}) and MSA 2 " f"({size_2})"
#             )
#         if alphabet_size_1 != alphabet_size_2:
#             raise ValueError("Input MSAs must have the same alphabet size/")
#         self._alphabet_size = alphabet_size_1
# 
#         # Validate size attribute
#         self._total_size = sum(self.group_sizes)
#         if size_1 != self._total_size:
#             raise ValueError(
#                 f"Input MSAs have size {size_1} but model expects a total "
#                 f"size of {self._total_size}"
#             )
#         self._length_1, self._length_2 = length_1, length_2
# 
#         self._effective_non_fixed_pairs = torch.ones(
#             self._total_size, self._total_size, dtype=torch.bool, device=self.device
#         )
# 
#         if fixed_pairings is not None:
#             if len(fixed_pairings) != len(self.group_sizes):
#                 raise ValueError(
#                     f"`fixed_pairings` has length {len(fixed_pairings)} but "
#                     f"there are {self.group_sizes} species."
#                 )
#             _fixed_pairings = fixed_pairings
# 
#             start = 0
#             self._effective_sizes_not_fixed = []
#             self._effective_fixed_pairings_zip = []
#             for species_idx, (species_size, species_fixed_pairings) in enumerate(
#                 zip(self.group_sizes, _fixed_pairings)
#             ):
#                 # Check uniqueness of pairs (i, j)
#                 n_fixed = len(set(species_fixed_pairings))
#                 if len(species_fixed_pairings) > n_fixed:
#                     raise ValueError(
#                         "Repeated indices for fixed pairings at species "
#                         f"{species_idx}: {species_fixed_pairings}"
#                     )
#                 fixed_pairings_arr = np.zeros((species_size, species_size), dtype=int)
#                 if species_fixed_pairings:
#                     species_fixed_pairings_zip = tuple(zip(*species_fixed_pairings))
#                 else:
#                     # species_fixed_pairings is an empty list
#                     species_fixed_pairings_zip = (tuple(), tuple())
#                 try:
#                     fixed_pairings_arr[species_fixed_pairings_zip] = 1
#                 except IndexError:
#                     raise ValueError(
#                         f"Fixed pairings indices out of bounds: passed {species_fixed_pairings} "
#                         f"for species {species_idx} with size {species_size}."
#                     )
#                 partial_sum_0 = fixed_pairings_arr.sum(axis=0)
#                 partial_sum_1 = fixed_pairings_arr.sum(axis=1)
#                 if (partial_sum_0 > 1).any() or (partial_sum_1 > 1).any():
#                     raise ValueError(
#                         f"Passed fixed pairings for species {species_idx} are either not one-one "
#                         "or a multiply-defined mapping from row to column indices: "
#                         f"{species_fixed_pairings}"
#                     )
#                 for i, j in species_fixed_pairings:
#                     self._effective_non_fixed_pairs[start + i, :] = False
#                     self._effective_non_fixed_pairs[:, start + j] = False
#                 total_minus_fixed = species_size - n_fixed
#                 # If species_size - n_fixed <= 1 then actually everything is fixed
#                 self._effective_sizes_not_fixed.append(
#                     int(total_minus_fixed > 1) * total_minus_fixed
#                 )
#                 if total_minus_fixed == 1:
#                     # Deduce implicitly fixed pair
#                     i_implicit, j_implicit = np.argmin(partial_sum_1), np.argmin(
#                         partial_sum_0
#                     )
#                     self._effective_non_fixed_pairs[start + i_implicit, :] = False
#                     self._effective_non_fixed_pairs[:, start + j_implicit] = False
#                     species_fixed_pairings_zip = (
#                         species_fixed_pairings_zip[0] + (i_implicit,),
#                         species_fixed_pairings_zip[1] + (j_implicit,),
#                     )
#                 self._effective_fixed_pairings_zip.append(species_fixed_pairings_zip)
#                 start += species_size
#         else:
#             self._effective_sizes_not_fixed = self.group_sizes
#             self._effective_fixed_pairings_zip = None
# 
#         self._default_target_idx = torch.arange(
#             self._total_size, dtype=torch.int64, device=self.device
#         )
# 
#     def plot_real_time(
#         self,
#         it,
#         gs_matching_mat_np,
#         gs_mat_np,
#         list_idx,
#         target_idx,
#         list_log_alpha,
#         losses,
#         batch_size,
#         epochs,
#         lr,
#         tar_loss,
#         new_noise_factor,
#         output_dir,
#         only_loss_plot,
#     ):
#         """Plot the results of the optimization in real time."""
#         n_correct = [sum(idx == target_idx) for idx in list_idx[::batch_size]]
# 
#         cmap = cm.get_cmap("Blues")
#         normalizer = None
#         fig, axes = plt.subplots(figsize=(30, 5), ncols=5, constrained_layout=True)
# 
#         null_model = len(self.group_sizes)
#         _size = [0] + list(np.cumsum(self.group_sizes))
#         for k in range(1, len(_size)):
#             for ii in range(2):
#                 elem, elem1 = _size[k - 1], _size[k]
#                 axes[ii].plot(
#                     [elem - 0.5, elem1 - 0.5, elem1 - 0.5, elem - 0.5],
#                     [elem - 0.5, elem - 0.5, elem1 - 0.5, elem1 - 0.5],
#                     color="r",
#                 )
#                 axes[ii].plot(
#                     [elem - 0.5, elem - 0.5, elem1 - 0.5, elem1 - 0.5],
#                     [elem - 0.5, elem1 - 0.5, elem1 - 0.5, elem - 0.5],
#                     color="r",
#                 )
# 
#         ims_soft = axes[0].imshow(gs_mat_np, cmap=cmap, norm=normalizer)
#         axes[0].set_title(f"Soft {it // batch_size}")
#         axes[1].imshow(gs_matching_mat_np, cmap=cmap, norm=normalizer)
#         axes[1].set_title("Hard")
# 
#         curr_log_alpha = list_log_alpha[-1]
#         ims_log_alpha = axes[2].imshow(curr_log_alpha, norm=CenteredNorm(), cmap=cm.bwr)
#         axes[2].set_title("Log-alpha")
# 
#         prev_log_alpha = (
#             list_log_alpha[-2] if len(list_log_alpha) > 1 else curr_log_alpha
#         )
#         diff_log_alpha = curr_log_alpha - prev_log_alpha
#         if np.nansum(np.abs(diff_log_alpha)):
#             ims_log_alpha_diff = axes[3].imshow(
#                 diff_log_alpha, norm=CenteredNorm(), cmap=cm.bwr
#             )
#             fig.colorbar(ims_log_alpha_diff, ax=axes[3], shrink=0.8)
#         else:
#             axes[3].imshow(np.zeros_like(diff_log_alpha), cmap=cm.bwr)
#         axes[3].set_title("Log-alpha diff")
# 
#         avg_loss = np.mean(np.array(losses).reshape(-1, batch_size), axis=1)
#         axes[4].plot(avg_loss, color="b", linewidth=1)
#         if not only_loss_plot:
#             if tar_loss is not None:
#                 axes[4].axhline(tar_loss, color="b", linewidth=2)
#             diff = avg_loss[0] - tar_loss
#             axes[4].set_ylim([tar_loss - 0.6 * diff, avg_loss[0] + 0.5 * diff])
#             ax3_2 = axes[4].twinx()
#             ax3_2.set_ylabel("No. of correct pairs", color="red")
#             ax3_2.plot(n_correct, color="red", linewidth=1)
#             ax3_2.axhline(null_model, color="red", linewidth=2)
#             ax3_2.tick_params(axis="y", labelcolor="red")
#         axes[4].set_ylabel("Loss")
#         axes[4].set_xlim([0, epochs])
#         axes[4].set_title(f"lr: {lr:.3g}, noise:{new_noise_factor:.3g}")
#         fig.colorbar(ims_soft, ax=axes[0], shrink=0.8)
#         fig.colorbar(ims_log_alpha, ax=axes[2], shrink=0.8)
#         if output_dir is not None:
#             fig.savefig(output_dir / "Iterations" / f"Epoch={it // batch_size}.svg")
#         plt.show()