In [2]:
import pandas as pd

In [4]:
df = pd.read_csv("MITLL_AAlphaBio_Ab_Binding_dataset.csv")

In [6]:
df["input"] = df["CDRH3"] + df["CDRL3"]

In [8]:
df["output"] = df["Sequence"]


In [9]:
with open("MITLL_AAlphaBio_Ab_Binding_dataset.csv", "w") as f:
    for i in range(len(df)):
        f.write(f"{df.loc[i, 'input']} -> {df.loc[i, 'output']}\n")

In [12]:
!pip install transformers datasets peft accelerate bitsandbytes --progress-bar=on




In [14]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

model_name = "nferruz/ProtGPT2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Required for padding

# Load model
model = AutoModelForCausalLM.from_pretrained(model _name)

# LoRA config
lora_config = LoraConfig(
    r=8,                      # LoRA rank
    lora_alpha=32,            # Scaling factor
    target_modules=["c_attn"],# GPT2-specific modules
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Add LoRA adapters
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.


trainable params: 1,474,560 || all params: 775,504,640 || trainable%: 0.1901




In [33]:
small_dataset = dataset.select(range(100))  # first 10k samples only


In [34]:
from datasets import Dataset

# Load your dataset (you can use select(range(300)) to test faster)
with open("MITLL_AAlphaBio_Ab_Binding_dataset.csv", "r") as f:
    lines = f.readlines()
dataset = Dataset.from_list([{"text": l.strip()} for l in lines])

def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=256,
    )

tokenized_small_dataset = small_dataset.map(tokenize_function, batched=True)



Map: 100%|██████████| 100/100 [00:00<00:00, 1187.32 examples/s]


In [37]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./protgpt2-lora-finetuned",
    run_name="protgpt2-lora-run",  # Appears in wandb
    per_device_train_batch_size=4,  # Increased batch size if memory allows
    gradient_accumulation_steps=2,  # Effective batch size = 8
    num_train_epochs=1,  # Or reduce to 1 for quick iteration
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=100,  # Log loss etc. every 100 steps
    report_to="wandb",  # or "none" if not using wandb
    fp16=True,  # Mixed precision (speeds up training on GPU)
    dataloader_num_workers=2,  # Useful on Colab or multi-core
    warmup_steps=50,  # Helps prevent instability at start
)


In [31]:
pip install wandb

Note: you may need to restart the kernel to use updated packages.


In [38]:
from transformers import Trainer, DataCollatorForLanguageModeling

# Use this if you're training a causal language model (like ProtGPT2)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # Important: set to False for autoregressive models
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_small_dataset,
    data_collator=data_collator,
)

# Start training
trainer.train()


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss


TrainOutput(global_step=13, training_loss=8.213111290564903, metrics={'train_runtime': 391.932, 'train_samples_per_second': 0.255, 'train_steps_per_second': 0.033, 'total_flos': 109035257856000.0, 'train_loss': 8.213111290564903, 'epoch': 1.0})

In [39]:
from transformers import pipeline

def generate_antigen(cdr_sequence: str, max_length: int = 100, num_return_sequences: int = 1):
    """
    Generate antigen sequence(s) from a given CDR input using the trained model.

    Args:
        cdr_sequence (str): The input CDR amino acid sequence.
        max_length (int): Maximum length of the generated antigen sequence.
        num_return_sequences (int): Number of sequences to return.

    Returns:
        List[str]: Generated antigen sequences.
    """
    # Prepare the text generation pipeline
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1
    )

    # Prompt format (adjust based on training formatting)
    prompt = f"CDR: {cdr_sequence} Antigen:"

    # Generate output
    outputs = generator(
        prompt,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )

    # Extract just the generated antigen sequences
    antigen_sequences = [output["generated_text"].split("Antigen:")[-1].strip() for output in outputs]
    return antigen_sequences


In [40]:
cdr_input = "EVQLVETGGGLVQPGGSLRLSCAASGFTLNSYGISWVRQAPGKGPEWV"
generated_antigens = generate_antigen(cdr_input)

for i, antigen in enumerate(generated_antigens, 1):
    print(f"Generated Antigen {i}: {antigen}")


Device set to use cpu
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Both `max_new_tokens` (=256) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Generated Antigen 1: DPTRHYAHVDCPGHADYVKNMITGAAQMDGAILVVAATDGPMPQTREHILLGRQVGVPYIIVF
ANKMDMVDDEELLELVEMEVRDLLTQYEFDGDNAPIIRGSALKALEGDKELGEKAVMELPDG
TPFLEAIEKFDAIPADHDRPFTFPTRYTTKDQFTVKGQVQVFDGRLTKGTEMVMPGDNTAXV
TFXFTAPIAMAMGLRFAIREGGRTVGSGVITEVVD
