In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp simulations

# DiffPALM – Example usage on prokaryotic datasets

> Example of running DiffPALM on one of the prokaryotic datasets (HK-RR or MALG-MALK) used in our paper.

In [None]:
import csv
import pickle
from datetime import datetime
from pathlib import Path

import numpy as np

from diffpalm.core import DiffPALM
from diffpalm.datasets import dataset_tokenizer, generate_dataset
from diffpalm.msa_parsing import read_msa


def save_parameters(parameters_all, filepath):
    """Save each group of parameters both as a pickle and as a readable CSV.

    Parameters
    ----------
    parameters_all : dict
        Maps a group name (e.g. ``"train"``) to a dict of parameters. Each
        group is written to ``<filepath>/<name>.pkl`` and
        ``<filepath>/<name>.csv`` (one ``key,value`` row per parameter).
    filepath : str or pathlib.Path
        Existing directory in which the files are written.
    """
    filepath = Path(filepath)  # also accept plain string paths
    for name, parameters in parameters_all.items():
        with open(filepath / f"{name}.pkl", "wb") as f:
            pickle.dump(parameters, f)
        # Use the csv module so that values containing commas (e.g. the
        # scheduler/optimizer kwargs dicts) are quoted and stay a single field.
        with open(filepath / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerows(parameters.items())

In [None]:
# Interactive configuration: press Enter to accept the defaults.
# The prompt is passed positionally: the built-in input() is positional-only
# (ipykernel's replacement happens to accept prompt=, plain Python does not).
RESULTS_DIR = Path(
    input("Insert path to directory where results will be stored (default: 'Results'): ").strip()
    or "Results"
).expanduser()  # allow paths such as "~/diffpalm_results"
# parents=True so that nested result paths also work.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = input("Choose PyTorch device (default: 'cuda'): ").strip() or "cuda"

Load one of the two prokaryotic datasets used in our paper: HK-RR or MALG-MALK.

`get_species_name` extracts the species name from a FASTA header. The header format, and therefore this function, differs between the two datasets.

In [None]:
# PROKARYOTIC DATASETS
# Each entry lists the two FASTA files to be paired and how to extract the
# species name from a FASTA header of that dataset.
PROKARYOTIC_DATASETS = {
    "HK-RR": {
        "fasta_paths": (
            "data/HK-RR/HK_in_Concat_nnn.fasta",
            "data/HK-RR/RR_in_Concat_nnn.fasta",
        ),
        "get_species_name": lambda header: header.split("|")[1],
    },
    "MALG-MALK": {
        "fasta_paths": (
            "data/MALG-MALK/MALG_cov75_hmmsearch_extr5000_withLast_b.fasta",
            "data/MALG-MALK/MALK_cov75_hmmsearch_extr5000_withLast_b.fasta",
        ),
        "get_species_name": lambda header: header.split("_")[-1],
    },
}
DATASET_NAME = "HK-RR"  # set to "MALG-MALK" to use the other dataset

dataset_config = PROKARYOTIC_DATASETS[DATASET_NAME]
# NOTE(review): the -1 argument presumably means "read all sequences" — confirm in read_msa.
msa_data = [read_msa(path, -1) for path in dataset_config["fasta_paths"]]
get_species_name = dataset_config["get_species_name"]

In [None]:
# Options forwarded to generate_dataset (key order is kept: it is also the
# row order of the saved dataset.csv).
parameters_dataset = dict(
    N=50,  # average number of sequences in the input
    pos=0,  # size of the set of context pairs given as positive examples
    max_size=100,  # max size of a species MSA (equal to N means no size limit)
    NUMPY_SEED=10,  # NumPy seed for dataset generation
    NUMPY_SEED_OTHER=11,  # second NumPy seed; its exact role is defined in generate_dataset
)

In [None]:
# Build the dataset (and per-species block sizes) from the parsed MSAs,
# then tokenize it on the chosen device.
dataset, species_sizes = generate_dataset(
    parameters_dataset,
    msa_data,
    get_species_name=get_species_name,
)
tokenized_dataset = dataset_tokenizer(dataset, device=DEVICE)

left_msa = tokenized_dataset["msa"]["left"]
right_msa = tokenized_dataset["msa"]["right"]
positive_examples = tokenized_dataset["positive_examples"]

## Train single block

In [None]:
EPOCHS = 100
TORCH_SEED = 100

# Keyword arguments for the DiffPALM constructor.
parameters_init = dict(
    device=DEVICE,
    p_mask=0.7,
    random_seed=TORCH_SEED,
)

# Keyword arguments for DiffPALM.train.
parameters_train = dict(
    std_init=0.0,
    scheduler_name="ReduceLROnPlateau",
    scheduler_kwargs=dict(mode="min", factor=0.8, patience=20),
    optimizer_name="Adadelta",
    optimizer_kwargs=dict(lr=9, weight_decay=1e-1),
    tau=1.0,
    n_sink_iter=10,
    batch_size=1,
    epochs=EPOCHS,
    noise=True,
    noise_factor=0.1,  # if noise_std is False, this is just the std of the noise
    noise_scheduler=True,
    noise_std=True,
    use_rand_perm=True,
)

# Keyword arguments for DiffPALM.target_loss.
parameters_target_loss = dict(batch_size=200)

# Every parameter group, so they can all be saved next to the results.
parameters_all = dict(
    init=parameters_init,
    target_loss=parameters_target_loss,
    train=parameters_train,
    dataset=parameters_dataset,
)

In [None]:
# Model for the species block sizes of this dataset; parameters_init supplies
# device, p_mask and random_seed.
dpalm = DiffPALM(
    species_sizes,
    **parameters_init,
)

In [None]:
# One sub-directory per run, named by its start time. The name must not
# contain ':' because that character is invalid in Windows file names.
date = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
output_dir = RESULTS_DIR / date
# exist_ok stays False so an existing run directory is never silently reused.
output_dir.mkdir(parents=True)

save_parameters(parameters_all, output_dir)

When `save_all_figs=True`, a figure illustrating the current state of the optimization is saved and shown after each gradient step. This slows down the optimization and may cause memory leaks. Set `save_all_figs=False` to save and show the figure only after the last gradient step.

The plotting function can show the number of correctly predicted pairs because the ground-truth pairs are known. The model assumes that the input pairs are already correctly matched (i.e. the correct matching matrix is the identity, a diagonal matrix). This holds because in the HK-RR and MALG-MALK datasets the sequences are already ordered so that correct partners occupy the same position in the two MSAs.

In [None]:
# Target loss; its mean is handed to dpalm.train as `tar_loss`.
# NOTE(review): presumably the loss of the ground-truth pairing — confirm in DiffPALM.target_loss.
tar_loss = dpalm.target_loss(
    left_msa, right_msa, positive_examples=positive_examples, **parameters_target_loss
)

# Train DiffPALM to predict the pairing of left/right sequences; figures and
# intermediate outputs are written to output_dir.
(
    losses,
    list_scheduler,
    shuffled_indexes,
    mat_perm,
    mat_gs,
    list_log_alpha,
) = dpalm.train(
    left_msa,
    right_msa,
    positive_examples=positive_examples,
    tar_loss=np.mean(tar_loss),  # train receives the mean target loss
    output_dir=output_dir,
    save_all_figs=True,  # False: only plot after the last gradient step
    **parameters_train,
)

# Everything produced by this run, grouped for later analysis.
# NOTE(review): "trainng_results" is misspelled; kept unchanged in case code
# reading saved results relies on this exact key.
results = dict(
    trainng_results=(
        losses,
        list_scheduler,
        shuffled_indexes,
        [mat_perm, mat_gs],
        list_log_alpha,
    ),
    target_loss=tar_loss,
    species_sizes=species_sizes,
)