In [1]:
from stateval.src.configs import MSAConfig
from stateval.src.msa import Msa 
from stateval.src.sh_entropies import ShannonEntropies 
from stat_eval_utils import fasta2dict


# Parse the training sample and key every raw sequence by its FASTA identifier.
fasta_records = fasta2dict("./mdh_train_sample.fasta")
sequences = dict(zip(fasta_records["ID"], fasta_records["Seq"]))

# Align the sequences, then derive the entropy profile from the alignment.
# NOTE(review): the exact meaning of max_gap_ratio is defined by MSAConfig,
# which is not visible in this notebook — confirm before tuning it.
msa = Msa(MSAConfig(max_gap_ratio=0.8))
aligned_seqs = msa.align(sequences)

se = ShannonEntropies()
entropies = se.calculate_entropies(aligned_seqs)




Using 3 threads
Read 500 sequences (type: Protein) from msa_input.fasta
Calculating pairwise ktuple-distances...
Ktuple-distance calculation progress done. CPU time: 14.00u 0.00s 00:00:14.00 Elapsed: 00:00:05
Guide-tree computation done.
Progressive alignment progress done. CPU time: 13.26u 0.08s 00:00:13.34 Elapsed: 00:00:07
Alignment written to _temp_msa_output.fasta


In [3]:
# Length of the longest sequence in the training sample.
max(len(seq) for seq in sequences.values())

430

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import yaml
from torch.nn import functional as F
import torch
# Read the generation settings from the YAML file next to the notebook.
with open('generation_config.yml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Tokenizer and model are both restored from the same checkpoint directory.
checkpoint_path = config["model_checkpoint_path"]

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
# Register a dedicated padding token so batches can be padded.
# NOTE(review): this adds a token to the tokenizer without resizing the model's
# embedding matrix; presumably safe only while no [PAD] id is ever fed to the
# model — confirm before raising the batch size.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    device_map="auto",
    load_in_8bit=True,
)
model.to_bettertransformer()
model.config.max_length = 512


  from .autonotebook import tqdm as notebook_tqdm
The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.


In [6]:
# Sanity check: report the device holding the model's first parameter.
first_param = next(model.parameters())
first_param.device

device(type='cuda', index=0)

In [8]:
def generate(model, max_new_tokens, segments_indicies, temps, prompt, top_k=None, eos_token_id=0):
    """Sample a continuation of ``prompt`` using a position-dependent temperature.

    The sequence is extended one token at a time. The sampling temperature is
    ``temps[temp_index]``; ``temp_index`` starts at 0 and moves to the next
    entry once the running sequence length exceeds
    ``segments_indicies[temp_index]`` (at most one step per generated token,
    and never past the last temperature).

    Args:
        model: Callable invoked as ``model(input_ids, past_key_values=...)``
            returning an object with ``.logits`` of shape
            (batch, fed_length, vocab) and ``.past_key_values``.
        max_new_tokens: Upper bound on the number of tokens to sample.
        segments_indicies: Sequence-length thresholds, one per temperature
            segment, compared against the current token count.
        temps: Temperatures, one per segment. Logits are divided by the
            active value, so entries must be non-zero.
        prompt: Integer tensor of shape (batch, prompt_length) to continue.
        top_k: If given, only the ``top_k`` highest logits stay eligible for
            sampling. Values larger than the vocabulary are clamped.
        eos_token_id: Token id that terminates a sequence.

    Returns:
        ``torch.long`` tensor of shape (batch, prompt_length + n_generated)
        holding the prompt followed by the sampled tokens. With a batch of
        more than one row, rows that finish early are filled with
        ``eos_token_id`` until every row has finished.
    """
    temp_index = 0
    idx = prompt
    past_key_values = None
    # First forward pass consumes the whole prompt; afterwards only the newly
    # sampled token is fed, because the cache already covers everything else.
    model_input = prompt
    finished = torch.zeros(prompt.shape[0], dtype=torch.bool, device=prompt.device)

    for _ in range(max_new_tokens):
        outputs = model(model_input, past_key_values=past_key_values)
        past_key_values = outputs.past_key_values

        # Keep the logits of the last position and scale by the active temperature.
        logits = outputs.logits[:, -1, :] / temps[temp_index]

        # Optionally restrict sampling to the k most likely tokens.
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            top_values, _top_ids = torch.topk(logits, k)
            logits[logits < top_values[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        # Rows that already emitted EOS keep emitting EOS instead of new samples.
        idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)
        idx = torch.cat((idx, idx_next), dim=1)

        # NOTE(review): thresholds are compared against token counts; whether
        # they are meant to be token positions or alignment columns cannot be
        # told from this function — confirm with the caller.
        if idx.shape[1] > segments_indicies[temp_index] and temp_index + 1 < len(temps):
            temp_index += 1

        finished = finished | (idx_next.squeeze(1) == eos_token_id)
        if finished.all():
            break

        # If the model returned no cache, fall back to re-feeding the full sequence.
        model_input = idx if past_key_values is None else idx_next

    return idx.to(torch.long)

In [9]:
import numpy as np
# Split the entropy profile into equally sized segments and turn the median
# entropy of each segment into a sampling temperature (median * 2).
n_segments = 7
indicies = np.linspace(0, len(entropies), n_segments + 1).astype(int)

entropies_median = []
for seg_start, seg_end in zip(indicies[:-1], indicies[1:]):
    entropies_median.append(np.median(entropies[seg_start:seg_end]))

# NOTE(review): a segment whose median entropy is 0 would yield a temperature
# of 0, which is later used as a divisor — confirm this cannot occur.
temps = [median * 2 for median in entropies_median]
temps


[0.8504, 0.8922, 0.8676, 0.9131, 0.9812, 1.0258, 1.1472]

In [10]:
from tqdm import tqdm
# Generate one sequence per training example: the first
# config["sequence_prompt_index"] tokens of each example are used as the
# prompt and the rest is sampled with the segment-wise temperatures.
device = "cuda"
generated_sequences = []
batch_size = 1
seqs = fasta2dict("./mdh_train_sample.fasta")["Seq"]

# Iterate over `seqs` itself so the loop bound always matches the list being sliced.
for batch_idx in tqdm(range(0, len(seqs), batch_size)):
    inputs = tokenizer.batch_encode_plus(
        seqs[batch_idx:batch_idx + batch_size],
        padding="longest",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    input_ids = inputs.input_ids
    prompt = input_ids[:, :config["sequence_prompt_index"]]

    # Sample up to as many new tokens as the encoded source sequence has.
    with torch.no_grad():
        outputs = generate(model, input_ids.shape[1], indicies[1:], temps, prompt.clone().long())

    # Decode and store the generated sequences.
    for generated_output in outputs:
        generated_text = tokenizer.decode(generated_output, skip_special_tokens=True)
        # Strip any literal end-of-text marker that survived decoding. The full
        # marker is removed first so no stray "<" is left behind; the original
        # partial pattern is kept as a fallback.
        generated_text = generated_text.replace("<|endoftext|>", "").replace("|endoftext|>", "")
        generated_sequences.append(generated_text)


  0%|          | 0/500 [00:00<?, ?it/s]