In [None]:
from ProtMamba_ssm.core import *
from ProtMamba_ssm.dataloaders import *
from ProtMamba_ssm.utils import *
from ProtMamba_ssm.modules import *
import torch
import numpy as np
from tqdm import tqdm
import pickle

In [None]:
def sample_sequences(dataset,
                     model,
                     n_samples_per_family,
                     max_length=1000,
                     family_idxs=None,
                     parameters_list=None,
                     fim_generation=False,
                     save_path=None):
    """
    Function to sample sequences from the model. Given a dataset, a list of families (their indexes in the dataset)
    and a set of generating parameters, it draws `n_samples_per_family` samples for each family and each parameter set.

    The parameters are in a list of tuples with the following structure:
    parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]
    `nr_seqs_ctx` is the number of context sequences; if the family has fewer sequences than requested,
    the context falls back to the one obtained with `nr_seqs_ctx = -1` (cut at the last sequence start).

    The function returns a dictionary with the following structure:
    gen_seqs = {family_idx: {parameters: {sequence: info}}}
    where `info` is {"perplexity": float} and, when `fim_generation` is True, additionally holds
    "original_input", "original_input_fim" and "generated_input_fim".
    Because the innermost dictionary is keyed by the generated sequence, identical samples overwrite each
    other, and samples that are not shorter than `max_length` (or, in FIM mode, whose generated part does
    not have the same length as the original masked part) are discarded. A family can therefore end up
    with fewer than `n_samples_per_family` entries.

    If `save_path` is given, `gen_seqs` is pickled to that path after every family (overwriting the
    previous checkpoint), so partial results survive an interrupted run.
    """
    import os  # local import: only needed to prepare the output directory

    # None instead of mutable list defaults; an omitted argument still means "nothing to do"
    family_idxs = [] if family_idxs is None else family_idxs
    parameters_list = [] if parameters_list is None else parameters_list

    # Make sure the checkpoint can be written *before* spending time on generation
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    gen_seqs = {}
    for j in family_idxs:
        gen_seqs[j] = {}
        print("Sampling sequences for family {}".format(j))
        for params in tqdm(parameters_list):
            gen_seqs[j][params] = {}
            n_seqs_ctx , temperature, top_k, top_p = params
            for _ in range(n_samples_per_family):
                # Sample the dataset to get the input
                data = dataset[j]
                tokens = data["input_ids"][None,:].to("cuda")
                pos_ids = data["position_ids"][None,:].to("cuda")
                # Positions of token id 0, used as sequence starts (presumably the <cls> token -- confirm)
                start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()
                n_starts = len(start_seqs)
                # NOTE: the requested n_seqs_ctx is never overwritten here, so the clamping done for
                # one sample cannot leak into the following samples of the same parameter set.
                if fim_generation:
                    if n_seqs_ctx > 0 and n_seqs_ctx + 1 < n_starts:
                        # cut right before the start following the sequence to be infilled
                        L = start_seqs[n_seqs_ctx+1]
                    elif n_seqs_ctx > 0:
                        # fewer sequences than requested: same cut as n_seqs_ctx = -1
                        # (previously indexed one past the end of start_seqs -> IndexError)
                        L = start_seqs[-1]
                    else:
                        L = start_seqs[n_seqs_ctx]
                    context_tokens, context_pos_ids, tokens_fim, pos_ids_fim, is_fim_dict = prepare_dataset_for_fim_generation(tokens[:,:L], pos_ids[:,:L])
                    is_fim = is_fim_dict
                else:
                    # Clamp to the last valid start index (previously clamped to len(start_seqs),
                    # which is out of bounds); negative values keep their from-the-end meaning.
                    ctx_idx = min(n_seqs_ctx, n_starts - 1)
                    # +1 keeps the start token itself, so generation continues right after it
                    L = start_seqs[ctx_idx]+1
                    context_tokens = tokens[:,:L]
                    context_pos_ids = pos_ids[:,:L]
                    is_fim=False
                # Generate the new sequence
                output = generate_sequence(model,
                                        context_tokens,
                                        position_ids=context_pos_ids,
                                        is_fim=is_fim,
                                        max_length=(L+max_length),
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p,
                                        return_dict_in_generate=True,
                                        output_scores=True,
                                        eos_token_id=torch.tensor([AA_TO_ID["<cls>"]]).to("cuda"),
                                        device="cuda")
                # Get the perplexity of the generated sequence:
                # exp of the mean cross-entropy between the returned scores and the generated tokens
                output_seq = output["generated"]
                loss = torch.nn.functional.cross_entropy(torch.from_numpy(output["scores"]).permute(0, 2, 1),
                                                        torch.from_numpy(output["generated_tokens"][0][None,:]))
                # save only sequences with length < max_length
                if len(output_seq[0]) < max_length:
                    if fim_generation:
                        original_input = output["input"][0].split("<cls>")[-1]
                        original_input_continuation = decode_sequence(tokens_fim[0].cpu().numpy())+"<cls>"
                        generated_input_continuation = output_seq[0]
                        # keep the sample only if the infilled part has the same length as the original one
                        if len(original_input_continuation) == len(generated_input_continuation):
                            outp_str = reorder_masked_sequence(original_input + generated_input_continuation)
                            gen_seqs[j][params][outp_str] = {"original_input": original_input,
                                                            "original_input_fim": original_input_continuation,
                                                            "generated_input_fim": generated_input_continuation,
                                                            "perplexity": torch.exp(loss).item()}
                        else:
                            print("Lengths of original and generated FIM do not match. {} vs {}".format(original_input_continuation, generated_input_continuation))
                    else:
                        gen_seqs[j][params][output_seq[0]] = {"perplexity": torch.exp(loss).item()}
        # Checkpoint after every family so a crash does not lose the families already sampled
        if save_path is not None:
            with open(save_path, "wb") as f:
                pickle.dump(gen_seqs, f)
    return gen_seqs

In [None]:
# Dataset: encoded MSAs read from the shared OpenProteinSet folder.
# NOTE(review): the file name points at the *test* split -- confirm this is the intended one.
dataset_name = "encoded_MSAs_test.pkl"
fim_strategy = "multiple_span"
mask_fraction = 0.2
dataset_kwargs = dict(
    filepath="/data1/common/OpenProteinSet/",
    sample=False,
    max_msa_len=-1,
    max_patches=5,
    mask_fraction=mask_fraction,
    fim_strategy=fim_strategy,
    max_position_embeddings=2048,
    add_position_ids="1d",
)
dataset = Uniclust30_Dataset(dataset_name, **dataset_kwargs)

# Pretrained model: restore the checkpoint on the GPU in bfloat16 and switch it to eval mode
checkpoint = "../../nbs/results/train_100M_FIM_restart-spikes_merged/checkpoint_131k-3750"
model = load_model(
    checkpoint,
    model_class=MambaLMHeadModelwithPosids,
    device="cuda",
    dtype=torch.bfloat16,
    checkpoint_mixer=False,
).eval()

## Sample from different families using different generation methods

In [None]:
# Families (dataset indexes) to sample from in this run.
# Other batches used: [0, 1, 2, 3, 4, 5, 6, 8, 9, 10] and [20, 21, 23, 24, 25]
family_idxs = [11, 13, 14, 16, 18]

# Generation grid: every context size is combined with every sampling setting.
# Each entry of parameters_list is (nr_seqs_ctx, temperature, top_k, top_p).
# A smaller grid used previously: [(100,1.,10,0.), (-1,0.9,10,0.95)]
context_sizes = [10, 100, 500, 1000, -1]
sampling_settings = [
    (1., 10, 0.),
    (1., 15, 0.),
    (1., 10, 0.95),
    (0.9, 10, 0.95),
    (0.8, 10, 0.9),
]
parameters_list = [
    (n_ctx, temperature, top_k, top_p)
    for n_ctx in context_sizes
    for (temperature, top_k, top_p) in sampling_settings
]

n_samples_per_family = 100
generate_fim = False

In [None]:
# For every selected family, report how many sequence starts (token id 0) it contains
# and how many of them lie before token position 131072.
for family_idx in family_idxs:
    family_tokens = dataset[family_idx]["input_ids"][None, :].to("cuda")
    seq_starts = torch.argwhere(family_tokens[0] == 0)[:, 0].cpu().numpy()
    starts_within_limit = seq_starts[seq_starts < 131072]
    print(
        f"Family name: {dataset.cluster_names[family_idx]}\t",
        "\tNumber of sequences: ",
        len(seq_starts),
        "\tNum sequences < 131072: ",
        len(starts_within_limit),
    )

In [None]:
import os  # only needed in this cell, to prepare the output directory

# Output file: the suffix records whether FIM or full-sequence generation was used
end_str = "_fim" if generate_fim else "_full"
save_path = f"figures/generated_sequences/check-131k(11-18)_gen_seqs{end_str}.pkl"

# Create the output directory before sampling: the first write to save_path only happens
# after a whole family has been generated, so a missing directory would otherwise
# raise FileNotFoundError after that work has already been done.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

gen_seqs = sample_sequences(dataset,
                            model,
                            n_samples_per_family=n_samples_per_family,
                            max_length=1000,
                            family_idxs=family_idxs,
                            parameters_list=parameters_list,
                            fim_generation=generate_fim,
                            save_path=save_path
                            )
# sample_sequences already writes save_path after each family; this final dump also
# covers the case where no family was sampled (empty family_idxs).
with open(save_path, "wb") as f:
    pickle.dump(gen_seqs, f)