# train

> Perform optimization using DiffPaSS models

In [None]:
#| default_exp train

In [None]:
#| hide

# Development convenience: load IPython's autoreload extension and reload all
# modules before executing code, so edits to library files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#| hide

from nbdev.showdoc import *

In [None]:
#| export

# Stdlib imports
from collections.abc import Callable, Sequence
from typing import Optional, Any, Literal

# NumPy
import numpy as np

# PyTorch
import torch

# DiffPaSS imports
from diffpass.base import DiffPaSSModel
from diffpass.model import (
    MatrixApply,
    PermutationConjugate,
    apply_hard_permutation_batch_to_similarity,
    TwoBodyEntropyLoss,
    MILoss,
    InterGroupSimilarityLoss,
    IntraGroupSimilarityLoss,
)

In [None]:
#| export

class InformationPairing(DiffPaSSModel):
    """DiffPaSS model for information-theoretic pairing of multiple sequence alignments (MSAs).

    ``x`` is transformed by the current (soft or hard) permutation matrices and
    the result is scored against ``y`` with an information-theoretic loss:
    ``TwoBodyEntropyLoss`` when ``information_measure="TwoBodyEntropy"``
    (default) or ``MILoss`` when ``information_measure="MI"``.

    Args:
        group_sizes: Sizes of the groups within which sequences can be permuted.
        fixed_pairings: Optional pairings to keep fixed, forwarded to
            ``init_permutation``.
        permutation_cfg: Optional configuration forwarded to ``init_permutation``.
        information_measure: Which information-theoretic loss to optimize.

    Raises:
        ValueError: If ``information_measure`` is not one of the supported
            values and ``validate_information_measure`` did not already reject it.
    """

    # Marks this model as operating on MSAs (presumably consulted by
    # DiffPaSSModel's input validation)
    are_inputs_msas = True

    def __init__(
        self,
        group_sizes: Sequence[int],
        fixed_pairings: Optional[Sequence[Sequence[Sequence[int]]]] = None,
        permutation_cfg: Optional[dict[str, Any]] = None,
        information_measure: Literal["MI", "TwoBodyEntropy"] = "TwoBodyEntropy",
    ):
        super().__init__()

        # Initialize permutation and matrix apply modules
        # (self.permutation and self.matrix_apply)
        self.init_permutation(
            group_sizes=group_sizes,
            fixed_pairings=fixed_pairings,
            permutation_cfg=permutation_cfg,
        )
        self.matrix_apply = MatrixApply(group_sizes=self.group_sizes)

        # Validate the requested measure and build the matching loss module
        self.validate_information_measure(information_measure)
        self.information_measure = information_measure
        if self.information_measure == "TwoBodyEntropy":
            self.information_loss = TwoBodyEntropyLoss()
        elif self.information_measure == "MI":
            self.information_loss = MILoss()
        else:
            # Defensive: never leave ``self.information_loss`` undefined, which
            # would only surface later as an AttributeError in ``forward``.
            raise ValueError(
                f"Unsupported information_measure {information_measure!r}; "
                "expected 'MI' or 'TwoBodyEntropy'."
            )

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Permute ``x`` with the current permutations and score it against ``y``.

        Returns:
            A dict with the permutation matrices (``"perms"``), the permuted
            ``x`` (``"x_perm"``) and the information-theoretic loss (``"loss"``).
        """
        # Soft or hard permutations (list), depending on the current mode
        perms = self.permutation()
        x_perm = self.matrix_apply(x, mats=perms)

        # Information-theoretic loss: two-body entropy or MI, as chosen in __init__
        loss = self.information_loss(x_perm, y)

        return {"perms": perms, "x_perm": x_perm, "loss": loss}

    def prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Validate ``x`` and ``y`` before fitting (same alphabet size required)."""
        self.validate_inputs(x, y, check_same_alphabet_size=True)

    def compute_losses_identity_perm(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> dict[str, float]:
        """Return the hard and soft losses obtained with the identity permutation.

        The loss is evaluated on the unpermuted ``x`` against ``y``. No
        permutation matrix enters the computation, so the hard and soft values
        are the same number.
        """
        # NOTE(review): ``self.hard_()`` presumably switches the model to hard
        # mode and the previous mode is not restored; the loss below does not
        # depend on it. Confirm whether callers rely on this side effect.
        self.hard_()
        with torch.no_grad():
            hard_loss_identity_perm = self.information_loss(x, y).item()
            soft_loss_identity_perm = hard_loss_identity_perm

        return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

In [None]:
# Render the API documentation (signature and docstring) of InformationPairing
show_doc(InformationPairing)

---

[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/train.py#L30){target="_blank" style="float:right; font-size:smaller"}

### InformationPairing

>      InformationPairing (group_sizes:collections.abc.Sequence[int], fixed_pair
>                          ings:Optional[collections.abc.Sequence[collections.ab
>                          c.Sequence[collections.abc.Sequence[int]]]]=None,
>                          permutation_cfg:Optional[dict[str,Any]]=None, informa
>                          tion_measure:Literal['MI','TwoBodyEntropy']='TwoBodyE
>                          ntropy')

DiffPaSS model for information-theoretic pairing of multiple sequence alignments (MSAs).

In [None]:
def test_information_bootstrap():
    """Bootstrap-fit ``InformationPairing`` on synthetic MSAs with a known true pairing.

    ``y`` is a token-wise function of ``x`` (each token shifted by one, modulo
    ``n_classes``), so the identity pairing is the ground truth. ``x`` is
    shuffled within each group before fitting. The test then checks that the
    optimized hard loss matches the hard loss of the ground-truth pairing.
    """
    # Data: two perfectly correlated MSAs (y is a token-wise function of x)
    n_classes = 3
    length = 5
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [
        torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)
    ]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    # Pass num_classes explicitly. Otherwise one_hot infers the alphabet size
    # from the largest token present, so x and y could get different alphabet
    # sizes (presumably rejected by check_same_alphabet_size=True).
    dtype = torch.get_default_dtype()
    x = torch.nn.functional.one_hot(x_tok, num_classes=n_classes).to(dtype)
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle, num_classes=n_classes).to(dtype)
    y = torch.nn.functional.one_hot(y_tok, num_classes=n_classes).to(dtype)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = InformationPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    hard_loss_identity_perm = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): hard_losses[-2][-1] is used as the final optimized hard loss;
    # confirm this indexing against DiffPaSSModel.fit_bootstrap's results layout.
    assert np.abs(results.hard_losses[-2][-1] - hard_loss_identity_perm) < 1e-4


test_information_bootstrap()

In [None]:
#| export

class BestHitsPairing(DiffPaSSModel):
    """DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their orthology networks, constructed using (reciprocal) best hits.

    Best-hits matrices are derived from the sequence similarity matrices of
    ``x`` and ``y``. Their hard and soft versions are precomputed by
    ``prepare_fit``, or by ``compute_losses_identity_perm``. The loss compares
    the best hits of the permuted ``x`` with those of ``y``.

    Args:
        group_sizes: Sizes of the groups within which sequences can be permuted.
        fixed_pairings: Optional pairings to keep fixed, forwarded to
            ``init_permutation``.
        permutation_cfg: Optional configuration forwarded to ``init_permutation``.
        similarity_kind: Similarity measure between sequences, forwarded to
            ``init_similarities``.
        similarities_cfg: Optional configuration forwarded to ``init_similarities``.
        compute_in_group_best_hits: Stored on the instance before
            ``init_best_hits`` is called (presumably read by it).
        best_hits_cfg: Optional configuration forwarded to ``init_best_hits``.
        similarities_comparison_loss: Optional callable ``loss(bh_x, bh_y)``
            returning a single-element tensor. Defaults to
            ``InterGroupSimilarityLoss(group_sizes=self.group_sizes)``.
        compare_soft_best_hits_to_hard: In soft mode, compare the soft best hits
            of the permuted ``x`` to the hard (``True``) or soft (``False``)
            best hits of ``y``.
    """

    # Marks this model as operating on MSAs (presumably consulted by
    # DiffPaSSModel's input validation)
    are_inputs_msas = True

    def __init__(
        self,
        group_sizes: Sequence[int],
        fixed_pairings: Optional[Sequence[Sequence[Sequence[int]]]] = None,
        permutation_cfg: Optional[dict[str, Any]] = None,
        similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming",
        similarities_cfg: Optional[dict[str, Any]] = None,
        compute_in_group_best_hits: bool = True,
        best_hits_cfg: Optional[dict[str, Any]] = None,
        similarities_comparison_loss: Optional[Callable[..., torch.Tensor]] = None,
        compare_soft_best_hits_to_hard: bool = True,
    ):
        super().__init__()

        # Initialize permutation and matrix apply modules
        # (self.permutation and self.matrix_apply)
        self.init_permutation(
            group_sizes=group_sizes,
            fixed_pairings=fixed_pairings,
            permutation_cfg=permutation_cfg,
        )
        self.matrix_apply = MatrixApply(group_sizes=self.group_sizes)

        # Validate similarity kind/config and initialize similarities module
        self.init_similarities(
            similarity_kind=similarity_kind, similarities_cfg=similarities_cfg
        )

        # Validate best hits config and initialize best hits module.
        # compute_in_group_best_hits is set first so init_best_hits can see it.
        self.compute_in_group_best_hits = compute_in_group_best_hits
        self.init_best_hits(best_hits_cfg)

        self.compare_soft_best_hits_to_hard = compare_soft_best_hits_to_hard

        # Similarities comparison loss: user-supplied callable or default module
        self.similarities_comparison_loss = similarities_comparison_loss
        if self.similarities_comparison_loss is None:
            self.effective_similarities_comparison_loss_ = InterGroupSimilarityLoss(
                group_sizes=self.group_sizes
            )
        else:
            self.effective_similarities_comparison_loss_ = (
                self.similarities_comparison_loss
            )

    def _precompute_bh(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Compute hard and soft best-hits matrices of ``x`` and ``y`` and store them as buffers.

        Registers ``_bh_hard_x``, ``_bh_hard_y``, ``_bh_soft_x`` and
        ``_bh_soft_y``. The best-hits module's mode is restored afterwards, also
        when an exception is raised part-way through.
        """
        initial_mode = self.best_hits.mode
        try:
            # Temporarily switch to hard BH
            self.best_hits.hard_()
            similarities_x = self.similarities(x)
            self.register_buffer("_bh_hard_x", self.best_hits(similarities_x))
            similarities_y = self.similarities(y)
            self.register_buffer("_bh_hard_y", self.best_hits(similarities_y))

            # Switch to soft BH, reusing the similarity matrices computed above
            self.best_hits.soft_()
            self.register_buffer("_bh_soft_x", self.best_hits(similarities_x))
            self.register_buffer("_bh_soft_y", self.best_hits(similarities_y))
        finally:
            # Restore initial mode even on failure, so the model is not left
            # silently stuck in soft mode
            self.best_hits.mode = initial_mode

    @property
    def _bh_y_for_soft_x(self):
        """Best-hits matrix of ``y`` that soft best hits of ``x`` are compared to.

        This is the hard version if ``compare_soft_best_hits_to_hard`` is set,
        and the soft version otherwise.
        """
        if self.compare_soft_best_hits_to_hard:
            return self._bh_hard_y
        return self._bh_soft_y

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> dict[str, torch.Tensor]:
        """Permute ``x`` and compare its best hits with the precomputed best hits of ``y``.

        ``y`` is not used here: its best hits must already have been stored by
        ``prepare_fit`` or ``compute_losses_identity_perm``.

        Returns:
            A dict with the permutation matrices (``"perms"``), the permuted
            ``x`` (``"x_perm"``) and the comparison loss (``"loss"``).
        """
        mode = self.permutation.mode
        assert (
            mode == self.best_hits.mode
        ), "Permutation and best hits must be either both in soft mode or both in hard mode."

        # Soft or hard permutations
        perms = self.permutation()
        x_perm = self.matrix_apply(x, mats=perms)

        # Best hits loss, with shortcut for hard permutations
        if mode == "soft":
            similarities_x = self.similarities(x_perm)
            bh_x = self.best_hits(similarities_x)
            # Ensure comparisons are soft_x-{soft,hard}_y, depending on
            # self.compare_soft_best_hits_to_hard
            loss = self.effective_similarities_comparison_loss_(
                bh_x, self._bh_y_for_soft_x
            )
        else:
            # Hard mode: permute the precomputed hard best hits of x directly
            # instead of recomputing similarities and best hits
            bh_x = apply_hard_permutation_batch_to_similarity(
                x=self._bh_hard_x, perms=perms
            )
            loss = self.effective_similarities_comparison_loss_(bh_x, self._bh_hard_y)

        return {
            "perms": perms,
            "x_perm": x_perm,
            "loss": loss,
        }

    def prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Validate inputs and precompute the best-hits buffers used by ``forward``."""
        self.validate_inputs(x, y, check_same_alphabet_size=True)

        # Precompute matrices of best hits
        self._precompute_bh(x, y)

    def compute_losses_identity_perm(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> dict[str, float]:
        """Return the hard and soft losses obtained with the identity permutation.

        As a side effect, this (re)computes the best-hits buffers from ``x`` and
        ``y``.
        """
        # Precompute matrices of best hits
        self._precompute_bh(x, y)

        # Compute hard/soft losses when using identity permutation
        with torch.no_grad():
            hard_loss_identity_perm = self.effective_similarities_comparison_loss_(
                self._bh_hard_x, self._bh_hard_y
            ).item()
            soft_loss_identity_perm = self.effective_similarities_comparison_loss_(
                self._bh_soft_x, self._bh_y_for_soft_x
            ).item()

        return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

In [None]:
# Render the API documentation (signature and docstring) of BestHitsPairing
show_doc(BestHitsPairing)

---

[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/train.py#L89){target="_blank" style="float:right; font-size:smaller"}

### BestHitsPairing

>      BestHitsPairing (group_sizes:collections.abc.Sequence[int], fixed_pairing
>                       s:Optional[collections.abc.Sequence[collections.abc.Sequ
>                       ence[collections.abc.Sequence[int]]]]=None,
>                       permutation_cfg:Optional[dict[str,Any]]=None,
>                       similarity_kind:Literal['Hamming','Blosum62']='Hamming',
>                       similarities_cfg:Optional[dict[str,Any]]=None,
>                       compute_in_group_best_hits:bool=True,
>                       best_hits_cfg:Optional[dict[str,Any]]=None,
>                       similarities_comparison_loss:Optional[<built-
>                       infunctioncallable>]=None,
>                       compare_soft_best_hits_to_hard:bool=True)

DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their orthology networks, constructed using (reciprocal) best hits .

In [None]:
def test_besthits_bootstrap():
    """Bootstrap-fit ``BestHitsPairing`` on synthetic MSAs with a known true pairing.

    ``y`` is a token-wise function of ``x``, so the identity pairing is the
    ground truth. ``x`` is shuffled within each group before fitting. The test
    checks that the optimized hard loss reaches at least 70% of the
    ground-truth hard loss (ratio > 0.7).
    """
    # Data: two perfectly correlated MSAs (y is a token-wise function of x)
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [
        torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)
    ]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    # Pass num_classes explicitly. Otherwise one_hot infers the alphabet size
    # from the largest token present, so x and y could get different alphabet
    # sizes (presumably rejected by check_same_alphabet_size=True).
    dtype = torch.get_default_dtype()
    x = torch.nn.functional.one_hot(x_tok, num_classes=n_classes).to(dtype)
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle, num_classes=n_classes).to(dtype)
    y = torch.nn.functional.one_hot(y_tok, num_classes=n_classes).to(dtype)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = BestHitsPairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): hard_losses[-2][-1] is used as the final optimized hard loss;
    # confirm this indexing against DiffPaSSModel.fit_bootstrap's results layout.
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.7


test_besthits_bootstrap()

In [None]:
#| export

class MirrortreePairing(DiffPaSSModel):
    """DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their sequence distance networks as in the Mirrortree method.

    The similarity matrices of ``x`` and ``y`` are precomputed by
    ``prepare_fit``, or by ``compute_losses_identity_perm``. The loss compares
    the similarity matrix of the permuted ``x`` with that of ``y``.

    Args:
        group_sizes: Sizes of the groups within which sequences can be permuted.
        fixed_pairings: Optional pairings to keep fixed, forwarded to
            ``init_permutation``.
        permutation_cfg: Optional configuration forwarded to ``init_permutation``.
        similarity_kind: Similarity measure between sequences, forwarded to
            ``init_similarities``.
        similarities_cfg: Optional configuration forwarded to ``init_similarities``.
        similarities_comparison_loss: Optional callable
            ``loss(similarities_x, similarities_y)`` returning a single-element
            tensor. Defaults to
            ``IntraGroupSimilarityLoss(group_sizes=self.group_sizes)``.
    """

    # Marks this model as operating on MSAs (presumably consulted by
    # DiffPaSSModel's input validation)
    are_inputs_msas = True

    def __init__(
        self,
        group_sizes: Sequence[int],
        fixed_pairings: Optional[Sequence[Sequence[Sequence[int]]]] = None,
        permutation_cfg: Optional[dict[str, Any]] = None,
        similarity_kind: Literal["Hamming", "Blosum62"] = "Hamming",
        similarities_cfg: Optional[dict[str, Any]] = None,
        similarities_comparison_loss: Optional[Callable[..., torch.Tensor]] = None,
    ):
        super().__init__()

        # Initialize permutation and matrix apply modules
        # (self.permutation and self.matrix_apply)
        self.init_permutation(
            group_sizes=group_sizes,
            fixed_pairings=fixed_pairings,
            permutation_cfg=permutation_cfg,
        )
        self.matrix_apply = MatrixApply(group_sizes=self.group_sizes)

        # Validate similarity kind/config and initialize similarities module
        self.init_similarities(
            similarity_kind=similarity_kind, similarities_cfg=similarities_cfg
        )

        # Similarities comparison loss: user-supplied callable or default module
        self.similarities_comparison_loss = similarities_comparison_loss
        if self.similarities_comparison_loss is None:
            self.effective_similarities_comparison_loss_ = IntraGroupSimilarityLoss(
                group_sizes=self.group_sizes
            )
        else:
            self.effective_similarities_comparison_loss_ = (
                self.similarities_comparison_loss
            )

    def _precompute_similarities(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Compute the similarity matrices of ``x`` and ``y`` and store them as buffers.

        Registers ``_similarities_hard_x`` and ``_similarities_hard_y``.
        """
        self.register_buffer("_similarities_hard_x", self.similarities(x))
        self.register_buffer("_similarities_hard_y", self.similarities(y))

    def forward(
        self, x: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> dict[str, torch.Tensor]:
        """Permute ``x`` and compare its similarity matrix with the precomputed one of ``y``.

        ``y`` is not used here: its similarity matrix must already have been
        stored by ``prepare_fit`` or ``compute_losses_identity_perm``.

        Returns:
            A dict with the permutation matrices (``"perms"``), the permuted
            ``x`` (``"x_perm"``) and the comparison loss (``"loss"``).
        """
        mode = self.permutation.mode

        # Soft or hard permutations (list)
        perms = self.permutation()
        x_perm = self.matrix_apply(x, mats=perms)

        # Compute similarity matrix of soft- or hard-permuted x
        if mode == "soft":
            similarities_x = self.similarities(x_perm)
        else:
            # Hard mode: permute the precomputed similarity matrix of x directly
            # instead of recomputing similarities from x_perm
            similarities_x = apply_hard_permutation_batch_to_similarity(
                x=self._similarities_hard_x, perms=perms
            )

        loss = self.effective_similarities_comparison_loss_(
            similarities_x, self._similarities_hard_y
        )

        return {
            "perms": perms,
            "x_perm": x_perm,
            "loss": loss,
        }

    def prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Validate inputs and precompute the similarity buffers used by ``forward``."""
        self.validate_inputs(x, y, check_same_alphabet_size=True)

        # Precompute similarity matrices
        self._precompute_similarities(x, y)

    def compute_losses_identity_perm(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> dict[str, float]:
        """Return the hard and soft losses obtained with the identity permutation.

        Both values are the same number: the loss is computed once on the
        unpermuted similarity matrices. As a side effect, this (re)computes the
        similarity buffers from ``x`` and ``y``.
        """
        # Precompute similarity matrices
        self._precompute_similarities(x, y)

        # Compute hard/soft losses when using identity permutation
        with torch.no_grad():
            hard_loss_identity_perm = self.effective_similarities_comparison_loss_(
                self._similarities_hard_x, self._similarities_hard_y
            ).item()
            soft_loss_identity_perm = hard_loss_identity_perm

        return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

In [None]:
# Render the API documentation (signature and docstring) of MirrortreePairing
show_doc(MirrortreePairing)

---

[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/train.py#L219){target="_blank" style="float:right; font-size:smaller"}

### MirrortreePairing

>      MirrortreePairing (group_sizes:collections.abc.Sequence[int], fixed_pairi
>                         ngs:Optional[collections.abc.Sequence[collections.abc.
>                         Sequence[collections.abc.Sequence[int]]]]=None,
>                         permutation_cfg:Optional[dict[str,Any]]=None, similari
>                         ty_kind:Literal['Hamming','Blosum62']='Hamming',
>                         similarities_cfg:Optional[dict[str,Any]]=None,
>                         similarities_comparison_loss:Optional[<built-
>                         infunctioncallable>]=None)

DiffPaSS model for pairing of multiple sequence alignments (MSAs) by aligning their sequence distance networks as in the Mirrortree method.

In [None]:
def test_mirrortree_bootstrap():
    """Bootstrap-fit ``MirrortreePairing`` on synthetic MSAs with a known true pairing.

    ``y`` is a token-wise function of ``x``, so the identity pairing is the
    ground truth. ``x`` is shuffled within each group before fitting. The test
    checks that the optimized hard loss reaches at least 95% of the
    ground-truth hard loss (ratio > 0.95).
    """
    # Data: two perfectly correlated MSAs (y is a token-wise function of x)
    n_classes = 3
    length = 100
    size_each_group = 10
    n_groups = 10
    # Define first MSA group by group
    x_tok_by_group = [
        torch.randint(0, n_classes, (size_each_group, length)) for _ in range(n_groups)
    ]
    # Within-group shuffling to control for algorithmic biases towards identity permutation
    x_tok_by_group_shuffle = [x[torch.randperm(size_each_group)] for x in x_tok_by_group]
    x_tok = torch.cat(x_tok_by_group, dim=0)
    x_tok_shuffle = torch.cat(x_tok_by_group_shuffle, dim=0)
    y_tok = (x_tok + 1) % n_classes
    # Pass num_classes explicitly. Otherwise one_hot infers the alphabet size
    # from the largest token present, so x and y could get different alphabet
    # sizes (presumably rejected by check_same_alphabet_size=True).
    dtype = torch.get_default_dtype()
    x = torch.nn.functional.one_hot(x_tok, num_classes=n_classes).to(dtype)
    x_shuffle = torch.nn.functional.one_hot(x_tok_shuffle, num_classes=n_classes).to(dtype)
    y = torch.nn.functional.one_hot(y_tok, num_classes=n_classes).to(dtype)

    group_sizes = [size_each_group] * n_groups

    # Model
    model = MirrortreePairing(group_sizes=group_sizes)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): hard_losses[-2][-1] is used as the final optimized hard loss;
    # confirm this indexing against DiffPaSSModel.fit_bootstrap's results layout.
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95


test_mirrortree_bootstrap()

In [None]:
#| export

class GraphAlignment(DiffPaSSModel):
    """DiffPaSS model for general graph alignment starting from the weighted adjacency matrices of two graphs.

    The adjacency matrix ``x`` is conjugated by the current permutation(s)
    (``P @ x @ P.T``) and compared with ``y`` by a comparison loss, which is
    called as ``loss(x_perm, y, mats=perms)``.

    Args:
        group_sizes: Optional sizes of the groups within which nodes can be
            permuted, forwarded to ``init_permutation``.
        fixed_pairings: Optional pairings to keep fixed, forwarded to
            ``init_permutation``.
        permutation_cfg: Optional configuration forwarded to ``init_permutation``.
        comparison_loss: Optional callable ``loss(x_perm, y, mats=perms)``
            returning a single-element tensor. Defaults to
            ``IntraGroupSimilarityLoss(group_sizes=None)``.
    """

    # Inputs are adjacency matrices, not MSAs (presumably consulted by
    # DiffPaSSModel's input validation)
    are_inputs_msas = False

    def __init__(
        self,
        group_sizes: Optional[Sequence[int]] = None,
        fixed_pairings: Optional[Sequence[Sequence[Sequence[int]]]] = None,
        permutation_cfg: Optional[dict[str, Any]] = None,
        comparison_loss: Optional[Callable[..., torch.Tensor]] = None,
    ):
        super().__init__()

        # Initialize the permutation module (self.permutation) and the module
        # that conjugates x by the permutations (self.permutation_conjugate)
        self.init_permutation(
            group_sizes=group_sizes,
            fixed_pairings=fixed_pairings,
            permutation_cfg=permutation_cfg,
        )
        self.permutation_conjugate = PermutationConjugate(group_sizes=self.group_sizes)

        # Comparison loss: user-supplied callable or default module
        self.comparison_loss = comparison_loss
        if self.comparison_loss is None:
            # Default: dot product between all upper triangular elements
            self.effective_comparison_loss_ = IntraGroupSimilarityLoss(group_sizes=None)
        else:
            self.effective_comparison_loss_ = self.comparison_loss

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> dict[str, torch.Tensor]:
        """Conjugate ``x`` by the current permutations and compare the result with ``y``.

        Returns:
            A dict with the permutation matrices (``"perms"``), the conjugated
            ``x`` (``"x_perm"``) and the comparison loss (``"loss"``).
        """
        mode = self.permutation.mode

        # Soft or hard permutations (list)
        perms = self.permutation()

        # Conjugate adjacency matrix x by soft/hard permutation P: P @ x @ P.T
        if mode == "soft":
            x_perm = self.permutation_conjugate(x, mats=perms)
        else:
            x_perm = apply_hard_permutation_batch_to_similarity(x=x, perms=perms)
        loss = self.effective_comparison_loss_(x_perm, y, mats=perms)

        return {
            "perms": perms,
            "x_perm": x_perm,
            "loss": loss,
        }

    def prepare_fit(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Validate ``x`` and ``y`` before fitting."""
        self.validate_inputs(x, y)

    def compute_losses_identity_perm(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> dict[str, float]:
        """Return the hard and soft losses obtained with identity permutation matrices.

        Both values are the same number: the loss is computed once with one
        identity matrix per group.
        """
        # NOTE(review): assumes self.group_sizes has been resolved to a sequence
        # of ints by init_permutation, even when group_sizes=None was passed.
        # Identity matrices are created directly on x's device (same default
        # dtype as before) instead of on CPU followed by a copy.
        identity_mats = [torch.eye(s, device=x.device) for s in self.group_sizes]

        # Compute hard/soft losses when using identity permutation
        with torch.no_grad():
            hard_loss_identity_perm = self.effective_comparison_loss_(
                x, y, mats=identity_mats
            ).item()
            soft_loss_identity_perm = hard_loss_identity_perm

        return {"hard": hard_loss_identity_perm, "soft": soft_loss_identity_perm}

In [None]:
# Render the API documentation (signature and docstring) of GraphAlignment
show_doc(GraphAlignment)

---

[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/train.py#L312){target="_blank" style="float:right; font-size:smaller"}

### GraphAlignment

>      GraphAlignment (group_sizes:Optional[collections.abc.Sequence[int]]=None,
>                      fixed_pairings:Optional[collections.abc.Sequence[collecti
>                      ons.abc.Sequence[collections.abc.Sequence[int]]]]=None,
>                      permutation_cfg:Optional[dict[str,Any]]=None,
>                      comparison_loss:Optional[<built-
>                      infunctioncallable>]=None)

DiffPaSS model for general graph alignment starting from the weighted adjacency matrices of two graphs.

In [None]:
def test_graph_alignment_bootstrap():
    """Check that ``GraphAlignment`` re-aligns a graph with a within-group-shuffled copy of itself.

    The optimized hard loss must reach at least 95% of the hard loss obtained
    with the ground-truth (identity) alignment.
    """
    group_size = 10
    num_groups = 10
    num_nodes = group_size * num_groups

    # Data: two identical weighted adjacency matrices with log-normal weights
    x = torch.exp(torch.randn((num_nodes, num_nodes))).to(torch.get_default_dtype())
    y = x.clone()

    # Within-group shuffling to control for algorithmic biases towards identity
    # permutation. Each permutation matrix is the identity with its rows
    # reordered by a random permutation.
    identity = torch.eye(group_size, dtype=x.dtype, device=x.device)
    perm_mats = [identity[torch.randperm(group_size)] for _ in range(num_groups)]
    x_shuffle = apply_hard_permutation_batch_to_similarity(x=x, perms=perm_mats)

    # Fit the model, then score the ground-truth alignment on unshuffled data
    model = GraphAlignment(group_sizes=[group_size] * num_groups)
    results = model.fit_bootstrap(x_shuffle, y)
    target_hard_loss = model.compute_losses_identity_perm(x, y)["hard"]

    # Check that the hard loss of the optimized permutation is close to the ground truth.
    # NOTE(review): hard_losses[-2][-1] is used as the final optimized hard loss;
    # confirm this indexing against DiffPaSSModel.fit_bootstrap's results layout.
    assert results.hard_losses[-2][-1] / target_hard_loss > 0.95


test_graph_alignment_bootstrap()