In [None]:
#| hide

# Reload edited modules (e.g. `diffpass`) automatically before each cell is executed
%load_ext autoreload
%autoreload 2

# DiffPASS – Example usage on prokaryotic datasets

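> Pair the sequences of two prokaryotic MSAs (HK-RR or MALG-MALK) with DiffPASS, then repeatedly refit while holding a growing random subset of the inferred pairings fixed.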

In [None]:
# Stdlib
from collections.abc import Sequence
from typing import Optional

# Progress bars
from tqdm import tqdm

# NumPy
import numpy as np

# PyTorch
import torch

# Plotting
from matplotlib import pyplot as plt

# DiffPASS imports
from diffpass.train import InformationAndReciprocalBestHits, compute_num_correct_matchings
from diffpass.msa_parsing import read_msa
from diffpass.data_utils import generate_dataset, dataset_tokenizer


# Cap PyTorch's CPU intra-op parallelism at 8 threads
torch.set_num_threads(8)

# Compute device: CPU when CUDA is unavailable; otherwise CUDA, prompting for the GPU index
# only when several GPUs are visible.
# NOTE(review): the interactive prompt blocks an unattended "Restart & Run All" on multi-GPU hosts.
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    DEVICE = torch.device(f"cuda:{input('Enter the CUDA device number:')}")
else:
    DEVICE = torch.device("cuda")
print(f"Using device: {DEVICE}")

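Load one of the two prokaryotic datasets, HK-RR or MALG-MALK, by commenting/uncommenting the corresponding block below.

`get_species_name` extracts the species name from a FASTA header; since the two datasets use different header formats, its definition differs between them.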

In [None]:
# PROKARYOTIC DATASETS
# Exactly one dataset block below should be active, and `get_species_name` must match the
# FASTA header format of that dataset.
# NOTE(review): the second argument of `read_msa` is presumably the number of sequences to read
# (-1 = all) — confirm in `diffpass.msa_parsing`.

# HK-RR: the species name is the second "|"-separated field of each header
msa_data = [
    read_msa("data/HK-RR/HK_in_Concat_nnn.fasta", -1),
    read_msa("data/HK-RR/RR_in_Concat_nnn.fasta", -1),
]


def get_species_name(header: str) -> str:
    """Return the species name of an HK-RR FASTA header (its second "|"-separated field)."""
    return header.split("|")[1]


# MALG-MALK: the species name is the last "_"-separated field of each header
# msa_data = [
#     read_msa("data/MALG-MALK/MALG_cov75_hmmsearch_extr5000_withLast_b.fasta", -1),
#     read_msa("data/MALG-MALK/MALK_cov75_hmmsearch_extr5000_withLast_b.fasta", -1),
# ]
#
#
# def get_species_name(header: str) -> str:
#     """Return the species name of a MALG-MALK FASTA header (its last "_"-separated field)."""
#     return header.split("_")[-1]

In [None]:
# Settings passed to `generate_dataset`
parameters_dataset = dict(
    N=500,  # Average number of sequences in the input
    pos=0,  # Size of the context pairs to use as positive examples
    max_size=100,  # Max size of species MSAs (no limit on size if equal to N)
    NUMPY_SEED=10,
    NUMPY_SEED_OTHER=11,
)

In [None]:
# Build the dataset (grouped via `get_species_name`) and tokenize it directly on DEVICE
dataset, group_sizes = generate_dataset(parameters_dataset, msa_data, get_species_name=get_species_name)
tokenized_dataset = dataset_tokenizer(dataset, device=DEVICE)

# The two tokenized MSAs whose sequences are to be paired
x, y = (tokenized_dataset["msa"][side] for side in ("left", "right"))

# Total number of sequences (presumably equal to sum(group_sizes) — confirm)
n_seqs = len(x)

## Pairings using DiffPASS's `InformationAndReciprocalBestHits` class

### Configuration and initialization

In [None]:
# Seed PyTorch's global RNG (used e.g. by `torch.randperm` when choosing fixed matchings)
TORCH_SEED = 100
torch.random.manual_seed(TORCH_SEED);  # trailing ";" hides the returned Generator in the cell output

In [None]:
# --- Soft permutations (Gumbel-Sinkhorn) ---
permutation_cfg = dict(tau=torch.tensor(1e-2), n_iter=10, noise=False)

# --- Information-theoretic loss term: two-body entropy ("TwoBodyEntropy") or mutual information ---
information_measure = "TwoBodyEntropy"

# --- Reciprocal-best-hits loss term ---
similarity_kind = "Hamming"
similarities_cfg = dict(use_dot=False, p=1)
reciprocal_best_hits_cfg = dict(tau=torch.tensor(1e-1))
inter_group_loss_score_fn = torch.nn.CosineSimilarity(dim=-1)
# NOTE(review): in `optimize` this is only referenced by a commented-out argument
intra_group_loss_score_fn = torch.nn.CosineSimilarity(dim=-1)

# --- Loss weights (keyed by loss-term name) ---
loss_weights_rbh = {information_measure: 1.0, "ReciprocalBestHits": 1.0}
# Alternative weighting with a "Mirrortree" term, which has weight 0 here
loss_weights_mt = {information_measure: 1.0, "Mirrortree": 0.0}

### Optimization

In [None]:
# Optimization parameters
fit_cfg = {
    "epochs": 1,
    "optimizer_name": "SGD",
    "optimizer_kwargs": {"lr": 1e-1, "weight_decay": 0.},
    "mean_centering": True,
    # "similarity_gradient_bypass": False,
    "show_pbar": False,
    "compute_final_soft": False
}

ipa_cfg = {
    "n_start": 20,
    "n_end": None,
    "max_iters": 1,
    "compute_correct_index_based": True
}

In [None]:
def _summarize_fit(results, *, index_based: bool):
    """Summarize one `fit` run as ``(hard_losses_this_step, correct_this_step)``.

    ``hard_losses_this_step`` maps each loss name to the ``(first, last)`` entries of
    ``results.hard_losses[name]``; ``correct_this_step`` is the output of
    ``compute_num_correct_matchings`` with the given ``index_based`` flag.
    """
    hard_losses_this_step = {
        k: (results.hard_losses[k][0], results.hard_losses[k][-1])
        for k in results.hard_losses
    }
    correct_this_step = compute_num_correct_matchings(results, index_based=index_based)
    return hard_losses_this_step, correct_this_step


def _sample_fixed_matchings(
    results,
    n_fixed: int,
    *,
    n_seqs: int,
    n_groups: int,
    offsets: torch.Tensor,
    group_idxs: torch.Tensor,
):
    """Draw ``n_fixed`` random sequences and fix each to the partner given by the previous fit.

    Returns one list per group of ``(mapped_rel_idx, fixed_rel_idx)`` pairs, with both indices
    relative to the start of the group; this is the value passed as ``fixed_matchings`` to the
    DiffPASS model.
    """
    # Absolute index each sequence is mapped to by the hard permutation of the previous fit.
    # NOTE(review): index 1 is hard-coded; with `fit_cfg["epochs"] == 1` it is presumably the hard
    # permutation after the single epoch — confirm whether `[-1]` is intended when epochs > 1.
    mapped_idxs = offsets + torch.cat(results.hard_perms[1])

    # Random subset of sequences (sorted) whose current pairing will be held fixed
    fixed_idxs = torch.sort(torch.randperm(n_seqs)[:n_fixed]).values
    selected_mapped_idxs = mapped_idxs[fixed_idxs]
    # A sequence and its mapped partner lie in the same group, so subtracting each one's own
    # group offset gives indices relative to that group.
    fixed_rel_idxs = fixed_idxs - offsets[fixed_idxs]
    mapped_rel_idxs = selected_mapped_idxs - offsets[selected_mapped_idxs]

    fixed_matchings = [[] for _ in range(n_groups)]
    for group_idx, mapped_rel_idx, fixed_rel_idx in zip(
        group_idxs[fixed_idxs], mapped_rel_idxs, fixed_rel_idxs
    ):
        fixed_matchings[group_idx].append((mapped_rel_idx.item(), fixed_rel_idx.item()))
    return fixed_matchings


def optimize(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    group_sizes: Sequence[int],
    n_start: int,
    n_end: Optional[int],
    max_iters: int,
    compute_correct_index_based: bool,
    show_pbar: bool = True,
):
    """Pair ``x`` with ``y`` using DiffPASS, refitting with a growing random set of fixed matchings.

    First fits a DiffPASS model with no fixed matchings. Then, for each ``N`` in
    ``range(n_start, n_end)``, refits ``max_iters`` times. Before each refit, ``N`` sequences are
    drawn uniformly at random and fixed to the partners assigned to them by the hard permutation
    of the previous fit.

    Relies on notebook-level globals: ``fit_cfg``, ``DEVICE`` and the model settings
    (``loss_weights_rbh``, ``permutation_cfg``, ``information_measure``, ``similarity_kind``,
    ``similarities_cfg``, ``reciprocal_best_hits_cfg``, ``inter_group_loss_score_fn``).

    Args:
        x, y: the two tokenized MSAs to pair, passed unchanged to ``fit``.
        group_sizes: number of sequences in each group. Groups are treated as contiguous blocks
            of sequences in this order (presumably ``sum(group_sizes) == len(x)`` — confirm).
        n_start: number of fixed matchings at the first refinement step.
        n_end: exclusive upper bound on the number of fixed matchings; ``None`` means ``len(x)``.
        max_iters: number of refits per value of ``N``; only the last refit is recorded.
        compute_correct_index_based: forwarded as ``index_based`` to
            ``compute_num_correct_matchings``.
        show_pbar: whether to show a ``tqdm`` progress bar over the values of ``N``.

    Returns:
        A dict with the following keys:

        - ``"hard_perms"``, ``"hard_losses"`` and ``"correct"``: lists with one entry for the
          initial fit followed by one entry per value of ``N``. Each ``hard_losses`` entry maps a
          loss name to its ``(first, last)`` recorded hard value.
        - ``"hard_losses_identity_perm"``: taken from the initial fit.
    """
    def create_dpass(fixed_matchings):
        # A fresh model for every fit; `fixed_matchings` is None for the initial fit.
        return InformationAndReciprocalBestHits(
            group_sizes=group_sizes,
            fixed_matchings=fixed_matchings,
            loss_weights=loss_weights_rbh,
            # loss_weights=loss_weights_mt,
            permutation_cfg=permutation_cfg,
            information_measure=information_measure,
            similarity_kind=similarity_kind,
            similarities_cfg=similarities_cfg,
            reciprocal_best_hits_cfg=reciprocal_best_hits_cfg,
            inter_group_loss_score_fn=inter_group_loss_score_fn,
            # intra_group_loss_score_fn=intra_group_loss_score_fn
        ).to(DEVICE)

    n_seqs = len(x)
    n_groups = len(group_sizes)
    # For every sequence: the start index of its group, and the index of its group
    offsets = torch.from_numpy(
        np.repeat(np.cumsum([0] + list(group_sizes))[:-1], repeats=group_sizes)
    )
    group_idxs = torch.from_numpy(np.repeat(np.arange(n_groups), repeats=group_sizes))

    if n_end is None:
        n_end = n_seqs

    # Initial fit with no fixed matchings
    dpass = create_dpass(None)
    results = dpass.fit(x, y, **fit_cfg)
    hard_losses_identity_perm = results.hard_losses_identity_perm
    hard_losses_this_step, correct_this_step = _summarize_fit(
        results, index_based=compute_correct_index_based
    )
    hard_perms = [results.hard_perms]
    hard_losses = [hard_losses_this_step]
    correct = [correct_this_step]

    # Subsequent fits: each one fixes matchings sampled from the previous fit's hard permutation
    n_fixed_values = list(range(n_start, n_end))
    for n_fixed in (tqdm(n_fixed_values) if show_pbar else n_fixed_values):
        for _ in range(max_iters):
            fixed_matchings = _sample_fixed_matchings(
                results,
                n_fixed,
                n_seqs=n_seqs,
                n_groups=n_groups,
                offsets=offsets,
                group_idxs=group_idxs,
            )
            dpass = create_dpass(fixed_matchings)
            results = dpass.fit(x, y, **fit_cfg)

        hard_losses_this_step, correct_this_step = _summarize_fit(
            results, index_based=compute_correct_index_based
        )
        hard_perms.append(results.hard_perms)
        hard_losses.append(hard_losses_this_step)
        correct.append(correct_this_step)

    return {
        "hard_perms": hard_perms,
        "hard_losses": hard_losses,
        "correct": correct,
        "hard_losses_identity_perm": hard_losses_identity_perm
    }

In [None]:
ipa_results = optimize(x, y, group_sizes=group_sizes, **ipa_cfg)

# Per-step lists (index 0: initial fit without fixed matchings) and the identity-permutation baseline
hard_losses, correct, hard_losses_identity_perm = (
    ipa_results[name] for name in ("hard_losses", "correct", "hard_losses_identity_perm")
)

In [None]:
# One figure per loss term: hard loss at the start and end of each step's fit. The dashed line is
# the hard loss of the identity permutation. It is drawn in an explicit colour and labelled because
# `axhline` otherwise defaults to the same colour (C0) as the "start" curve and stays out of the legend.
for key in hard_losses[0]:
    fig, ax = plt.subplots()
    ax.plot([step_losses[key][0] for step_losses in hard_losses], ".-", label="start")
    ax.plot([step_losses[key][1] for step_losses in hard_losses], ".-", label="end")
    ax.axhline(
        hard_losses_identity_perm[key], color="gray", linestyle="--", label="identity permutation"
    )
    ax.set(title=f"{key} loss", xlabel="Step (0: no fixed matchings)", ylabel="Hard loss")
    ax.legend()
    plt.show()

# Fraction of correctly paired sequences at the start and end of each step's fit
fig, ax = plt.subplots()
ax.plot([c[0] / n_seqs for c in correct], ".-", label="start")
ax.plot([c[1] / n_seqs for c in correct], ".-", label="end")
ax.set(title="Fraction correct", xlabel="Step (0: no fixed matchings)", ylabel="Fraction correct")
ax.legend()
plt.show()