In [None]:
!pip install -q pandas numpy biopython tqdm transformers accelerate sentencepiece peft
import torch, math, time, os, csv, contextlib, gc
import numpy as np
import pandas as pd
from tqdm import tqdm
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from transformers import LlamaForCausalLM, AutoTokenizer

# Target device: the default CUDA device when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face Hub repository id passed to from_pretrained() by load_prollama().
MODEL_NAME = "GreatCaptainNemo/ProLLaMA"

# Lazily-filled cache used by load_prollama(); both slots stay None until the first load.
_PROL = dict(model=None, tok=None)

def load_prollama():
    """Return the cached ``(model, tokenizer)`` pair, loading ProLLaMA on first call.

    The weights are loaded in bfloat16 when CUDA with bf16 support is available,
    float16 on other CUDA devices, and float32 on CPU. The model is moved to
    ``DEVICE`` and put in eval mode. The pair is stored in ``_PROL`` so later
    calls skip the download/load entirely.
    """
    cached_model = _PROL["model"]
    if cached_model is not None:
        return cached_model, _PROL["tok"]

    has_cuda = torch.cuda.is_available()
    if has_cuda and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    elif has_cuda:
        dtype = torch.float16
    else:
        dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
    if tokenizer.pad_token_id is None:
        # No pad token defined by the tokenizer: reuse EOS so batched generate() can pad.
        tokenizer.pad_token = tokenizer.eos_token

    network = LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    network = network.to(DEVICE).eval()
    # Keep the model config's pad id in sync with the tokenizer's.
    network.config.pad_token_id = tokenizer.pad_token_id

    _PROL.update(model=network, tok=tokenizer)
    print(f"[ProLLaMA] Loaded dtype={dtype}, device={DEVICE}")
    return network, tokenizer


In [None]:

# The 20 canonical amino-acid one-letter codes (uppercase only).
_AA = set("ACDEFGHIKLMNPQRSTVWY")

def extract_aa(s: str, max_len: int) -> str:
    """Pull a fixed-length amino-acid string out of decoded model text.

    If *s* contains a ProLLaMA-style ``Seq=<...>`` wrapper, only the text between
    the first ``Seq=<`` and the next ``>`` is used. If the closing bracket was never
    generated, the text runs to the end of the string. Otherwise the whole of *s*
    is used. This keeps capital letters from surrounding text (the prompt, the
    "S" of "Seq") from being counted as residues.

    From that text only the characters in ``_AA`` are kept; lowercase letters
    are dropped. The result is truncated to *max_len*, or right-padded with
    ``"A"`` when shorter.

    Note: the ``"A"`` padding means short or empty generations still yield a
    full-length sequence, which skews composition toward alanine.
    """
    marker = "Seq=<"
    start = s.find(marker)
    if start != -1:
        s = s[start + len(marker):].split(">", 1)[0]
    residues = "".join(ch for ch in s if ch in _AA)
    # ljust is a no-op once the slice already has max_len characters.
    return residues[:max_len].ljust(max_len, "A")

SUPERFAMILY_PROMPT = "Ankyrin repeat-containing domain superfamily"

@torch.no_grad()
def generate_prollama_one(model, tok, length=100, temperature=0.7, top_k=40, top_p=0.9,
                          max_new_tokens=512, repetition_penalty=1.2, superfamily=SUPERFAMILY_PROMPT):
    """Sample one protein sequence from ProLLaMA conditioned on a superfamily.

    Args:
        model, tok: the causal LM and its tokenizer (e.g. from ``load_prollama``).
        length: residue count of the returned string (see ``extract_aa``).
        temperature, top_k, top_p, max_new_tokens, repetition_penalty:
            forwarded unchanged to ``model.generate`` with ``do_sample=True``.
        superfamily: text placed inside ``Superfamily=<...>`` in the prompt.

    Returns:
        The generated residues, truncated or "A"-padded to ``length`` by ``extract_aa``.
    """
    prompt = f"[Generate by superfamily] Superfamily=<{superfamily}>"
    inp = tok(prompt, return_tensors="pt")
    prompt_len = inp["input_ids"].shape[1]
    out_ids = model.generate(
        input_ids=inp["input_ids"].to(DEVICE),
        attention_mask=inp["attention_mask"].to(DEVICE),
        do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p,
        max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
        eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation so
    # capital letters in the prompt ("G", "S", "A") are not read as residues.
    text = tok.decode(out_ids[0, prompt_len:], skip_special_tokens=True)
    # If the model wrapped its answer as "Seq=<...>", keep only the payload so the
    # "S" of "Seq" is not counted as a serine.
    if "Seq=<" in text:
        text = text.split("Seq=<", 1)[1].split(">", 1)[0]
    return extract_aa(text, max_len=length)

@torch.no_grad()
def generate_prollama_batch(model, tok, n_total=64, length=100, batch_size=8,
                            temperature=0.7, top_k=40, top_p=0.9,
                            max_new_tokens=512, repetition_penalty=1.2, superfamily=SUPERFAMILY_PROMPT):
    """Sample ``n_total`` sequences from ProLLaMA, ``batch_size`` prompts per call.

    Every row uses the same superfamily prompt. The prompt batch is therefore
    tokenized once and sliced down for the last, possibly smaller, batch.
    Sampling arguments are forwarded unchanged to ``model.generate`` with
    ``do_sample=True``.

    Returns:
        A list of ``n_total`` strings (empty when ``n_total <= 0``), each truncated
        or "A"-padded to ``length`` by ``extract_aa``.

    Raises:
        ValueError: if ``batch_size < 1``, since the loop could never make progress.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    out = []
    prompt = f"[Generate by superfamily] Superfamily=<{superfamily}>"
    enc = tok([prompt] * batch_size, return_tensors="pt", padding=True)
    while len(out) < n_total:
        B = min(batch_size, n_total - len(out))
        input_ids = enc["input_ids"][:B].to(DEVICE)
        attn_mask = enc["attention_mask"][:B].to(DEVICE)
        out_ids = model.generate(
            input_ids=input_ids, attention_mask=attn_mask,
            do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p,
            max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
            eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
        )
        # generate() returns prompt + continuation; drop the prompt columns so its
        # capital letters ("G", "S", "A") are not read as residues.
        prompt_len = input_ids.shape[1]
        decoded = tok.batch_decode(out_ids[:, prompt_len:], skip_special_tokens=True)
        for text in decoded:
            # Keep only the payload of a "Seq=<...>" wrapper, if the model emitted one.
            if "Seq=<" in text:
                text = text.split("Seq=<", 1)[1].split(">", 1)[0]
            out.append(extract_aa(text, max_len=length))
    return out[:n_total]


In [None]:
traits = ["charge_pH7","gravy","aromaticity","instability_index","mol_weight","iso_point"]
steer_down_traits = {"aromaticity", "instability_index"}

# Trait name -> scoring function applied to a Bio.SeqUtils.ProtParam.ProteinAnalysis.
# Built once at import time instead of on every get_score() call.
_TRAIT_FNS = {
    "charge_pH7":        lambda pa: pa.charge_at_pH(7.0),
    "gravy":             lambda pa: pa.gravy(),
    "aromaticity":       lambda pa: pa.aromaticity(),
    "instability_index": lambda pa: pa.instability_index(),
    "mol_weight":        lambda pa: pa.molecular_weight(),
    "iso_point":         lambda pa: pa.isoelectric_point(),
}

def get_score(seq: str, trait_name: str):
    """Compute one ProtParam property of *seq*.

    Args:
        seq: protein sequence in one-letter codes.
        trait_name: one of the keys of ``_TRAIT_FNS`` (the same names as ``traits``).

    Returns:
        The property value, or ``np.nan`` if Biopython fails to score the sequence.

    Raises:
        ValueError: for an unknown *trait_name*. Previously this was swallowed
            and returned NaN for every sequence, which made the sampling loop
            spin forever.
    """
    try:
        fn = _TRAIT_FNS[trait_name]
    except KeyError:
        raise ValueError(
            f"unknown trait {trait_name!r}; expected one of {sorted(_TRAIT_FNS)}"
        ) from None
    try:
        return fn(ProteinAnalysis(seq))
    except Exception:
        # Best effort: ProtParam may raise on sequences it cannot score; treat as missing.
        return np.nan


In [None]:
# ---- Generation / sampling configuration -----------------------------------
N_SAMPLES = 500        # scored sequences to collect per trait
SEQ_LEN = 100          # residue length each generated sequence is cut/padded to
BATCH_GEN = 16         # prompts per model.generate() call

# Sampling knobs forwarded to generate_prollama_batch().
TEMP, TOP_K, TOP_P = 0.7, 40, 0.9
MAX_NEW_TOK = 512
REP_PEN = 1.2
SUPERFAM = SUPERFAMILY_PROMPT

# Load the model once (load_prollama caches it) before iterating over traits.
model, tok = load_prollama()

class OnlineStats:
    """Running mean and sample standard deviation via Welford's algorithm.

    Values are folded in one at a time with :meth:`add`, so the sample never has
    to be kept in memory.
    """

    def __init__(self):
        self.n = 0       # number of values added so far
        self.mean = 0.0  # running mean of the values added so far
        self.M2 = 0.0    # running sum of squared deviations from the mean

    def add(self, x):
        """Fold one observation *x* into the running statistics."""
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        # The second factor uses the *updated* mean (Welford's update).
        self.M2 += delta * (x - self.mean)

    def finalize(self):
        """Return ``(mean, std, ci95)`` for the values added so far.

        ``std`` is the sample (n - 1) standard deviation. ``ci95`` is the
        half-width of a normal-approximation 95% confidence interval for the
        mean, ``1.96 * std / sqrt(n)``. With fewer than two values ``std`` and
        ``ci95`` are NaN. With no values the mean is NaN too, rather than a
        misleading 0.0.
        """
        nan = float("nan")
        if self.n == 0:
            return nan, nan, nan
        if self.n < 2:
            return self.mean, nan, nan
        std = math.sqrt(self.M2 / (self.n - 1))
        ci95 = 1.96 * std / math.sqrt(self.n)
        return self.mean, std, ci95

# Abort a trait when this many consecutive generation batches yield no scorable
# sequence (every score NaN); otherwise the while-loop below would never finish.
_MAX_EMPTY_BATCHES = 20

for prop in traits:
    out_path = f"prollama_generated_{prop}.csv"
    # The CSV is opened in append mode, so re-running adds N_SAMPLES new rows to an
    # existing file; write the header only when the file is new or still empty.
    write_header = not os.path.exists(out_path) or os.path.getsize(out_path) == 0
    stats = OnlineStats()
    produced = 0
    empty_batches = 0  # consecutive batches that contributed no rows
    start = time.time()

    with open(out_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["sequence", "score"])

        with tqdm(total=N_SAMPLES, desc=f"Generating ({prop})", leave=True, dynamic_ncols=True, mininterval=0.2) as pbar:
            while produced < N_SAMPLES:
                need = N_SAMPLES - produced
                this_batch = min(BATCH_GEN, need)
                seqs = generate_prollama_batch(
                    model, tok,
                    n_total=this_batch, length=SEQ_LEN, batch_size=this_batch,
                    temperature=TEMP, top_k=TOP_K, top_p=TOP_P,
                    max_new_tokens=MAX_NEW_TOK, repetition_penalty=REP_PEN,
                    superfamily=SUPERFAM,
                )
                produced_before = produced
                for s in seqs:
                    sc = get_score(s, prop)
                    if np.isnan(sc):
                        continue  # unscorable: skip; the while-loop generates a replacement
                    writer.writerow([s, sc])
                    stats.add(float(sc))
                    produced += 1
                    pbar.update(1)
                    # live ETA every BATCH_GEN accepted sequences
                    if produced % max(1, BATCH_GEN) == 0:
                        elapsed = time.time() - start
                        rate = produced / max(elapsed, 1e-9)
                        remain = (N_SAMPLES - produced) / max(rate, 1e-9)
                        pbar.set_postfix_str(f"{rate:.1f} seq/s | ETA {remain/60:.1f}m")
                    if produced >= N_SAMPLES:
                        break
                # Persist each batch so an interrupted run keeps what it produced.
                f.flush()

                if produced == produced_before:
                    empty_batches += 1
                    if empty_batches >= _MAX_EMPTY_BATCHES:
                        raise RuntimeError(
                            f"{prop}: {empty_batches} consecutive batches produced "
                            f"no scorable sequence; aborting"
                        )
                else:
                    empty_batches = 0

    mean, std, ci95 = stats.finalize()
    print(f"[SUMMARY] {prop}: mean={mean:.6f} | std={std:.6f} | 95% CI=±{ci95:.6f} | n={stats.n}")

    # aggressive cleanup between properties
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
