Cell 1: Imports and Settings

In [None]:
from transformers import AutoModelForCausalLM
from tokenizers import Tokenizer
import torch

# Checkpoint identifier passed to from_pretrained(), and the torch device
# (GPU when one is visible, otherwise CPU) that the model and inputs live on.
MODEL_NAME = "hugohrban/progen2-medium"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# N-terminal seed fragment for each biomass-degrading enzyme family.
PROMPT_FRAGMENTS = {
    "GH10": "MSKQSSQASGASRAVYAKYT",
    "GH11": "MKYLLPTAAFCLVSCLALAA",
    "GH5": "MSKSFVILFSFLSAVTVLAK",
    "GH48": "MKFNSRLLISVTLAVAGSSS",
    "CE1": "MALQFLLLVVLLLSHQAQA",
    "PL1": "MKAVAAIAAVASLAGSVLAE",
    "PL7": "MNSTTAIALGAVPAAALTYA",
}

# Conditioning tag for each family, spelled "<|FAMILY|>" and listed in the
# same order as the seed fragments.
TAGS = {family: f"<|{family}|>" for family in PROMPT_FRAGMENTS}

# Upper bound on the total token count (prompt + generated) of one sample,
# and the number of samples drawn per family.
MAX_LENGTH = 1024
SAMPLES_PER_FAMILY = 5




Cell 2: Load Model and Tokenizer


In [None]:
print("Loading model and tokenizer...")

# trust_remote_code=True allows transformers to run the modelling code that
# ships with the checkpoint; the loaded model is then moved onto DEVICE.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = model.to(DEVICE)

# The tokenizer comes from the same checkpoint; padding is switched off so
# that encode() yields only the prompt's own token ids.
tokenizer = Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.no_padding()

print("Model loaded successfully...")


Cell 3: Define Sequence Generation Function


In [None]:
def generate_sequences(tag, seed_fragment, n_samples=5, stop_tokens=("<|eos|>", "2")):
    """Sample ``n_samples`` continuations of a tagged seed fragment.

    The prompt is tokenized once with the module-level ``tokenizer``; each
    sample is then grown one token at a time by drawing from the softmax of the
    module-level ``model``'s last-position logits (plain multinomial sampling,
    no temperature / top-k / top-p).

    Args:
        tag: Conditioning tag string inserted into the prompt.
        seed_fragment: Amino-acid fragment appended to the prompt after the tag.
        n_samples: Number of independent samples to draw.
        stop_tokens: Token strings that end a sample as soon as one of them is
            drawn. Strings the tokenizer does not know are ignored.
            NOTE(review): the defaults assume a ProGen2-style vocabulary where
            "<|eos|>" is end-of-sequence and "2" is presumably the C-terminus
            marker — TODO confirm against the loaded tokenizer.

    Returns:
        A list of ``n_samples`` decoded strings. Each one is the decoding of
        the prompt ids followed by the generated ids (including the stop token
        id when one was drawn), capped at ``MAX_LENGTH`` tokens in total.
    """
    # NOTE(review): confirm that "MASK_START <tag> <fragment>" is a format the
    # checkpoint was trained on; the string is kept exactly as it was.
    prompt = f"MASK_START {tag} {seed_fragment}"
    prompt_ids = torch.tensor(tokenizer.encode(prompt).ids).unsqueeze(0).to(DEVICE)

    # Resolve the stop ids once, up front. token_to_id() yields None for a
    # string that is not in the vocabulary; comparing a sampled id against None
    # can never match, which silently disables early stopping, so unknown
    # tokens are dropped here instead.
    stop_ids = {
        token_id
        for token_id in (tokenizer.token_to_id(token) for token in stop_tokens)
        if token_id is not None
    }

    # Number of tokens that may still be added before hitting MAX_LENGTH.
    max_new_tokens = max(MAX_LENGTH - prompt_ids.size(1), 0)

    generated_sequences = []
    for _ in range(n_samples):
        generated = prompt_ids.clone()
        with torch.no_grad():
            for _step in range(max_new_tokens):
                # Only the logits at the final position predict the next token.
                next_token_logits = model(generated).logits[:, -1, :]
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                # The stop token is appended before stopping, so it is part of
                # the ids handed to decode().
                if next_token.item() in stop_ids:
                    break

        generated_sequences.append(tokenizer.decode(generated[0].tolist()))

    return generated_sequences
print("Generation Function defined successfully...")


Cell 4: Generate Sequences for Each Family


In [None]:
print("Starting generation process...")


import time

def format_duration(seconds):
    """Render an elapsed time in seconds as a zero-padded ``MM:SS`` string.

    The fractional part of ``seconds`` is truncated. Minutes are not folded
    into hours, so an hour-long duration is shown as ``60:00``.
    """
    whole_seconds = int(seconds)
    return f"{whole_seconds // 60:02d}:{whole_seconds % 60:02d}"

# Maps each family name to the list of sequences generated for it.
results = {}

# perf_counter() is a monotonic clock, so the measured durations cannot be
# skewed (or go negative) if the system wall clock is adjusted mid-run.
total_start = time.perf_counter()

for fam, tag in TAGS.items():
    print(f"Generating sequences for {fam}...")
    fragment = PROMPT_FRAGMENTS[fam]

    # Time each family separately so slow families are easy to spot.
    start = time.perf_counter()
    sequences = generate_sequences(tag, fragment, SAMPLES_PER_FAMILY)
    duration = time.perf_counter() - start

    print(f"✅ {fam} done in {format_duration(duration)} (mm:ss)")
    results[fam] = sequences

total_end = time.perf_counter()
print(f"🏁 Total generation time: {format_duration(total_end - total_start)} (mm:ss)")


print("Generation completed...")


Cell 5: Save Generated Sequences to FASTA File



In [None]:
# Write one FASTA record per generated sequence: a ">family_sample_N" header
# line (N counted from 1 within each family) followed by the sequence line.
with open("generated_sequences.fasta", "w") as fasta_file:
    for fam, seqs in results.items():
        for sample_number, seq in enumerate(seqs, start=1):
            header = f">{fam}_sample_{sample_number}"
            fasta_file.write(header + "\n" + f"{seq}" + "\n")

print("✅ Sequences saved to 'generated_sequences.fasta'")
