In [None]:
!pip install -q pandas numpy biopython tqdm transformers accelerate sentencepiece peft

import pandas as pd
import numpy as np
import torch
import random
from tqdm.auto import tqdm
from Bio.SeqUtils.ProtParam import ProteinAnalysis

seed = 42
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Device:", device)


In [None]:
# Names of the Biopython-derived properties scored for each generated sequence
# (each name is a key understood by `get_score`).
traits = [
    "charge_pH7",
    "gravy",
    "aromaticity",
    "instability_index",
    "mol_weight",
    "iso_point",
]

# Traits flagged as "steer down" (presumably: lower values are the desired
# direction — confirm against the downstream steering code).
steer_down_traits = {"aromaticity", "instability_index"}

def is_steer_down(prop: str) -> bool:
    """Return True if `prop` is one of the traits in `steer_down_traits`."""
    flagged = prop in steer_down_traits
    return flagged

def get_score(seq: str, trait_name: str):
    """Compute Biopython-based property for a protein sequence."""
    try:
        pa = ProteinAnalysis(seq)
        fns = {
            "charge_pH7":        lambda x: x.charge_at_pH(7.0),
            "gravy":             lambda x: x.gravy(),
            "aromaticity":       lambda x: x.aromaticity(),
            "instability_index": lambda x: x.instability_index(),
            "mol_weight":        lambda x: x.molecular_weight(),
            "iso_point":         lambda x: x.isoelectric_point(),
        }
        return fns[trait_name](pa)
    except Exception:
        return np.nan


In [None]:
from transformers import LlamaForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

MODEL_NAME = "GreatCaptainNemo/ProLLaMA"
REV = "main"          # Hub branch/tag/commit; pin a commit hash for exact reproducibility
ADAPTER_PATH = None   # optional LoRA adapter path; None = base model only

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Prefer bf16 where the GPU supports it, otherwise fp16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    # Half precision on CPU is very slow and, depending on the torch version,
    # some ops are not implemented for float16 at all, so use full precision.
    dtype = torch.float32

# Load base model
model = LlamaForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    revision=REV,
)
model.to(device).eval()

# Load the slow (non-fast) tokenizer from the same revision as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False, revision=REV)
# The tokenizer may define no PAD token; reuse EOS so `generate` has a pad id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Optionally wrap the base model with a LoRA adapter.
if ADAPTER_PATH:
    model = PeftModel.from_pretrained(model, ADAPTER_PATH)
    model.to(device).eval()
    print(f"Loaded LoRA adapter from: {ADAPTER_PATH}")

# The 20 canonical amino-acid one-letter codes; used to filter generated text.
AA = set("ACDEFGHIKLMNPQRSTVWY")

print(f"Loaded ProLLaMA on {device} (dtype={dtype})")


In [None]:
import re
import torch

# Default superfamily used to condition ProLLaMA generation.
SUPERFAMILY_PROMPT = "Ankyrin repeat-containing domain superfamily"

def extract_aa(text: str, max_len: int) -> str:
    """Return the characters of `text` that belong to the global `AA` set, in order,
    truncated to the first `max_len` of them.

    Every other character is dropped, including lowercase letters, whitespace
    and punctuation.
    """
    kept_residues = [ch for ch in text if ch in AA]
    return "".join(kept_residues)[:max_len]

# ProLLaMA-style answer field "Seq=<...>". The closing ">" is optional because
# generation may be cut off by `max_new_tokens` before it is emitted.
_SEQ_FIELD = re.compile(r"Seq=<([^>]*)")

def generate_sequence_prollama(
    model,
    tokenizer,
    length: int = 500,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.9,
    max_new_tokens: int = 512,
    repetition_penalty: float = 1.2,
    superfamily: str = SUPERFAMILY_PROMPT,
) -> str:
    """Sample one protein sequence from ProLLaMA, conditioned on a superfamily.

    Args:
        model: Causal LM; generation runs on the device of its first parameter.
        tokenizer: Tokenizer matching `model`.
        length: Maximum number of residues to keep from the completion.
        temperature, top_k, top_p, repetition_penalty: Sampling parameters
            forwarded to `model.generate`.
        max_new_tokens: Upper bound on generated tokens.
        superfamily: Superfamily name inserted into the generation prompt.

    Returns:
        Up to `length` canonical amino-acid letters taken from the model's
        completion only (never from the prompt). The string may be empty if
        the model emitted no valid residues.
    """
    model_device = next(model.parameters()).device
    prompt = f"[Generate by superfamily] Superfamily=<{superfamily}>"

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model_device)
    attn_mask = inputs["attention_mask"].to(model_device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # For decoder-only models `generate` returns prompt + continuation. Decode
    # only the continuation. Otherwise uppercase letters of the prompt itself
    # ("G"enerate, "S"uperfamily, "A"nkyrin) pass the amino-acid filter and
    # get prepended to every sequence.
    new_tokens = output_ids[0, input_ids.shape[1]:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # If the answer is wrapped as "Seq=<...>", keep only the wrapped residues,
    # so the "S" of "Seq" is not counted as a serine.
    match = _SEQ_FIELD.search(completion)
    if match:
        completion = match.group(1)

    # Extract AA-only and cap to `length`
    return extract_aa(completion, max_len=length)


In [None]:
import itertools
import sys

n_trials        = 1      # independent repetitions; trial t is seeded with `seed + t`
n_samples       = 500    # sequences generated per trial
sequence_length = 500    # max residues kept per generated sequence
temperature     = 0.7
top_k           = 40
top_p           = 0.9
seed            = 42


def _seed_everything(s: int) -> None:
    """Seed Python `random`, NumPy and torch (CPU and, if present, all CUDA devices)."""
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)


def _summarise(trial_means, all_scores, n_trials):
    """Return (mean, std, ci95_half_width, n) for one trait.

    With one trial (and more than one valid score), the statistics are taken
    over the per-sample scores. Otherwise they are taken over the per-trial
    means, with n = total number of valid scores. The CI is the
    normal-approximation half-width 1.96 * std / sqrt(count).
    """
    if n_trials == 1 and all_scores and len(all_scores[0]) > 1:
        per_sample = np.asarray(all_scores[0], dtype=float)
        n = len(per_sample)
        mean = float(np.mean(per_sample))
        std = float(np.std(per_sample, ddof=1))
        ci95 = 1.96 * (std / np.sqrt(n))
    else:
        mean = float(np.mean(trial_means)) if trial_means else float("nan")
        std = float(np.std(trial_means, ddof=1)) if len(trial_means) > 1 else 0.0
        n = sum(len(x) for x in all_scores) if all_scores else 0
        ci95 = 1.96 * (std / np.sqrt(max(n_trials, 1))) if n_trials > 0 else 0.0
    return mean, std, ci95, n


# Per-trait accumulators, one entry appended per trial.
means_by_trait  = {p: [] for p in traits}
stds_by_trait   = {p: [] for p in traits}
scores_by_trait = {p: [] for p in traits}

for trial in range(n_trials):
    # Distinct but reproducible seed per trial. Reusing one seed for every
    # trial would make repeated trials identical copies of each other.
    _seed_everything(seed + trial)

    # Generate the sample once and score every trait on the same sequences.
    # Regenerating an identical seeded sample per trait wastes 6x the compute.
    sequences = [
        generate_sequence_prollama(
            model=model,
            tokenizer=tokenizer,
            length=sequence_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        for _ in tqdm(
            range(n_samples),
            desc=f"Generating (trial {trial + 1}/{n_trials})",
            dynamic_ncols=True,
            mininterval=0.1,
            leave=False,
            file=sys.stdout,
        )
    ]

    for property_label in traits:
        arr = np.asarray([get_score(s, property_label) for s in sequences], dtype=float)
        arr = arr[~np.isnan(arr)]   # drop sequences that could not be scored
        means_by_trait[property_label].append(float(np.mean(arr)) if len(arr) else float("nan"))
        stds_by_trait[property_label].append(float(np.std(arr, ddof=1)) if len(arr) > 1 else 0.0)
        scores_by_trait[property_label].append(arr.tolist())

for property_label in tqdm(traits, desc="Properties", dynamic_ncols=True, leave=True, file=sys.stdout):
    # Keep the original per-trait names bound, so later cells see the last trait's data.
    trial_means = means_by_trait[property_label]
    trial_stds  = stds_by_trait[property_label]
    all_scores  = scores_by_trait[property_label]

    mean_ps, std_ps, ci95, n = _summarise(trial_means, all_scores, n_trials)
    print(f"[SUMMARY] {property_label}: mean={mean_ps:.6f} | std={std_ps:.6f} | 95% CI=±{ci95:.6f} | n={n}", flush=True)

    # Save
    pd.DataFrame({'trial': np.arange(1, n_trials + 1), 'mean_score': trial_means, 'std_score': trial_stds}) \
      .to_csv(f"prollama_trial_stats_{property_label}.csv", index=False)

    flat_scores = list(itertools.chain.from_iterable(all_scores))
    pd.DataFrame({'score': flat_scores}).to_csv(f"prollama_generated_scores_{property_label}.csv", index=False)
