In [None]:
#| hide

%load_ext autoreload
%autoreload 2

# DiffPASS – Example usage on prokaryotic datasets

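> Example usage of DiffPASS's `InformationAndReciprocalBestHits` class to pair the sequences of two MSAs (HK-RR or MALG-MALK).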

In [None]:
# NumPy
import numpy as np

# PyTorch
import torch
torch.set_num_threads(8)

# Plotting
from matplotlib import pyplot as plt

# DiffPASS imports
from diffpass.train import InformationAndReciprocalBestHits
from diffpass.msa_parsing import read_msa
from diffpass.data_utils import generate_dataset, dataset_tokenizer

Load one of the two prokaryotic datasets: HK-RR or MALG-MALK.

`get_species_name` extracts species names from the FASTA header.  

In [None]:
# PROKARYOTIC DATASETS

# Name of the dataset to load; must be one of the keys of `DATASETS` below.
DATASET_NAME = "HK-RR"


def _species_from_pipe_header(header):
    """Return the second `|`-separated field of a FASTA header."""
    return header.split("|")[1]


def _species_from_underscore_header(header):
    """Return the last `_`-separated field of a FASTA header."""
    return header.split("_")[-1]


# For each available dataset: the two FASTA files to read (in order) and the
# function used to extract the species name from a FASTA header.
DATASETS = {
    "HK-RR": {
        "msa_paths": (
            "data/HK-RR/HK_in_Concat_nnn.fasta",
            "data/HK-RR/RR_in_Concat_nnn.fasta",
        ),
        "get_species_name": _species_from_pipe_header,
    },
    "MALG-MALK": {
        "msa_paths": (
            "data/MALG-MALK/MALG_cov75_hmmsearch_extr5000_withLast_b.fasta",
            "data/MALG-MALK/MALK_cov75_hmmsearch_extr5000_withLast_b.fasta",
        ),
        "get_species_name": _species_from_underscore_header,
    },
}

if DATASET_NAME not in DATASETS:
    raise ValueError(
        f"Unknown dataset {DATASET_NAME!r}; expected one of {sorted(DATASETS)}."
    )

dataset_cfg = DATASETS[DATASET_NAME]

# NOTE(review): the second argument (-1) is passed to `read_msa` unchanged;
# presumably it means "read all sequences" -- confirm against `read_msa`.
msa_data = [read_msa(path, -1) for path in dataset_cfg["msa_paths"]]
get_species_name = dataset_cfg["get_species_name"]

In [None]:
# Settings passed to `generate_dataset`
parameters_dataset = dict(
    N=500,  # Average number of sequences in the input
    pos=0,  # Size of the context pairs to use as positive examples
    max_size=100,  # Max size of species MSAs (no size limit if equal to N)
    NUMPY_SEED=10,
    NUMPY_SEED_OTHER=11,
)

In [None]:
# Generate the dataset from the loaded MSAs, then tokenize it
dataset, group_sizes = generate_dataset(
    parameters_dataset,
    msa_data,
    get_species_name=get_species_name,
)
tokenized_dataset = dataset_tokenizer(dataset)

# Left and right tokenized MSAs, and the positive examples
x, y = (tokenized_dataset["msa"][side] for side in ("left", "right"))
positive_examples = tokenized_dataset["positive_examples"]

The following function randomly shuffles, in place, the sequences within each species. This can be a useful control.

In [None]:
def in_species_random_shuffle(
    x: torch.Tensor,
    *,
    group_sizes: np.ndarray,
    generator=None,
) -> None:
    """Randomly permute, in place, the rows of `x` within each group.

    The rows of `x` (its first dimension) are read as consecutive groups whose
    lengths are given by `group_sizes`; the rows of each group are permuted
    independently with `torch.randperm`. Rows never move across groups.

    Parameters
    ----------
    x
        Tensor whose first dimension is shuffled group by group. Modified in place.
    group_sizes
        Sizes of the consecutive groups. They must sum to `x.shape[0]`.
    generator
        Optional `torch.Generator` forwarded to `torch.randperm`. If `None`,
        PyTorch's global random number generator is used.

    Raises
    ------
    ValueError
        If the group sizes do not sum to the number of rows of `x`.
    """
    sizes = [int(size) for size in group_sizes]
    n_rows = x.shape[0]
    if sum(sizes) != n_rows:
        # Without this check a too-small total would silently leave the trailing
        # rows unshuffled, and a too-large total would fail with an obscure
        # out-of-range indexing error.
        raise ValueError(
            f"`group_sizes` sums to {sum(sizes)}, but `x` has {n_rows} rows."
        )

    start = 0
    for size in sizes:
        stop = start + size
        perm = torch.randperm(size, generator=generator)
        # Indexing with `perm` returns a copy, so writing back into `x` is safe
        x[start:stop] = x[start:stop][perm]
        start = stop

## Pairings using DiffPASS's `InformationAndReciprocalBestHits` class

### Configuration and initialization

In [None]:
# Seed PyTorch's global random number generator.
# The trailing `;` suppresses the notebook output of the call.
TORCH_SEED = 100
torch.manual_seed(TORCH_SEED);

In [None]:
# Soft permutations (Gumbel-Sinkhorn) settings
permutation_cfg = dict(
    tau=torch.tensor(1e-2),
    n_iter=10,
    noise=False,
)

# Information-theoretic term of the loss: two-body entropy or mutual information
information_measure = "TwoBodyEntropy"

# Settings for the reciprocal best hits term of the loss
hamming_similarities_cfg = dict(use_dot=False, p=1)
reciprocal_best_hits_cfg = dict(tau=torch.tensor(1e-1))
inter_group_loss_score_fn = torch.dot  # alternative: torch.nn.CosineSimilarity(dim=-1)

# Loss weights, keyed by loss term
loss_weights = {
    information_measure: 1.0,
    "ReciprocalBestHits": 1.0,
}

# Device
# Index of the CUDA device to use when CUDA is available. `None` selects
# PyTorch's current default CUDA device; set it to an integer (e.g. 1) to pick
# a specific GPU. This is a plain setting rather than an `input()` prompt so
# that the notebook can be run non-interactively (e.g. "Restart & Run All").
CUDA_DEVICE_INDEX = None

if not torch.cuda.is_available():
    device = torch.device("cpu")
elif CUDA_DEVICE_INDEX is None:
    device = torch.device("cuda")
else:
    if not 0 <= CUDA_DEVICE_INDEX < torch.cuda.device_count():
        raise ValueError(
            f"CUDA_DEVICE_INDEX={CUDA_DEVICE_INDEX} is out of range: "
            f"{torch.cuda.device_count()} CUDA device(s) available."
        )
    device = torch.device("cuda", CUDA_DEVICE_INDEX)

In [None]:
# Build the DiffPASS object from the settings defined above
dpass = InformationAndReciprocalBestHits(
    group_sizes=group_sizes,
    fixed_matchings=None,
    loss_weights=loss_weights,
    permutation_cfg=permutation_cfg,
    information_measure=information_measure,
    hamming_similarities_cfg=hamming_similarities_cfg,
    reciprocal_best_hits_cfg=reciprocal_best_hits_cfg,
    inter_group_loss_score_fn=inter_group_loss_score_fn,
)

# Move it, together with both tokenized MSAs, to the selected device
dpass.to(device)
x, y = x.to(device), y.to(device)

### Optimization

In [None]:
# Keyword arguments forwarded to `dpass.fit`
fit_cfg = dict(
    epochs=400,
    optimizer_name="SGD",
    optimizer_kwargs=dict(lr=1e-1, weight_decay=0.0),
    mean_centering=True,
    hamming_gradient_bypass=False,
)

In [None]:
# Run the optimization on the two tokenized MSAs with the settings in `fit_cfg`
results = dpass.fit(x, y, **fit_cfg)

### Visualizing the results

In [None]:
def _plot_loss_curve(losses, identity_loss, title, correct=None):
    """Plot a loss history together with its identity-permutation reference value.

    Parameters
    ----------
    losses
        Sequence of loss values, plotted in recorded order.
    identity_loss
        Scalar drawn as a dashed horizontal reference line.
    title
        Title of the figure.
    correct
        Optional sequence drawn on a secondary y-axis (counts of entries at
        which the hard permutations agree with the identity permutation).
    """
    _fig, ax = plt.subplots()
    handles = ax.plot(losses, color="C0", label="Loss")
    handles.append(
        ax.axhline(
            identity_loss,
            color="C1",
            linestyle="--",
            label="Loss at identity permutation",
        )
    )
    # NOTE(review): assumes one recorded value per optimization epoch -- confirm
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)

    legend_ax = ax
    if correct is not None:
        ax_correct = ax.twinx()
        handles += ax_correct.plot(
            correct, color="C2", label="Entries matching identity permutation"
        )
        ax_correct.set_ylabel("Entries matching identity permutation")
        # The twin axes are drawn on top, so the legend is attached to them
        legend_ax = ax_correct
    legend_ax.legend(handles=handles, loc="best")
    plt.show()


# For each element of `results.hard_perms`: number of entries at which its
# permutations agree with the identity permutation, summed over permutations.
# This does not depend on the loss kind, so it is computed once, before the loop.
correct = [
    sum(
        (perm == torch.arange(perm.shape[-1])).sum().item()
        for perm in perms
    )
    for perms in results.hard_perms
]

# Running totals over all loss kinds
hard_losses_total = 0
soft_losses_total = 0
hard_loss_id_total = 0
soft_loss_id_total = 0
for loss_kind in results.hard_losses:
    hard_losses = np.array([loss.item() for loss in results.hard_losses[loss_kind]])
    soft_losses = np.array([loss.item() for loss in results.soft_losses[loss_kind]])
    hard_loss_id = results.hard_losses_identity_perm[loss_kind]
    soft_loss_id = results.soft_losses_identity_perm[loss_kind]

    hard_losses_total += hard_losses
    soft_losses_total += soft_losses
    hard_loss_id_total += hard_loss_id
    soft_loss_id_total += soft_loss_id

    _plot_loss_curve(hard_losses, hard_loss_id, f"{loss_kind}, hard", correct=correct)
    _plot_loss_curve(soft_losses, soft_loss_id, f"{loss_kind}, soft")

_plot_loss_curve(hard_losses_total, hard_loss_id_total, "Total, hard")
_plot_loss_curve(soft_losses_total, soft_loss_id_total, "Total, soft")