In [1]:
import torch
print(torch.__version__)  # should be 2.6.0.dev...


2.6.0.dev20241112+cu121


In [2]:
import pandas as pd

df = pd.read_csv("MITLL_AAlphaBio_Ab_Binding_dataset.csv")
df["input"] = df["CDRH3"] + df["CDRL3"]
df["output"] = df["Sequence"]

with open("formatted_dataset.txt", "w") as f:
    for i in range(len(df)):
        f.write(f"{df.loc[i, 'input']} -> {df.loc[i, 'output']}\n")


In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "nferruz/ProtGPT2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # required for GPT-2

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_safetensors=True,  # ✅ Enforce safetensors
)


In [5]:
from peft import get_peft_model, LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["c_attn"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 2,949,120 || all params: 776,979,200 || trainable%: 0.3796




In [11]:
from datasets import load_dataset, Dataset

with open("formatted_dataset.txt") as f:
    lines = f.readlines()

dataset = Dataset.from_list([{"text": l.strip()} for l in lines])
subset = dataset.select(range(700))  

def tokenize(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=256)

tokenized_dataset = subset.map(tokenize, batched=True)


Map: 100%|██████████| 700/700 [00:00<00:00, 3621.10 examples/s]


In [12]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./protgpt2-lora-finetuned",
    num_train_epochs=2,  # ✅ Adjusted
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    logging_dir="./logs",
    logging_steps=50,
    save_strategy="epoch",
    fp16=True,
    report_to="none",
)


In [13]:
from transformers import Trainer, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
50,6.8834
100,5.1525
150,4.1986


TrainOutput(global_step=176, training_loss=5.184010765769265, metrics={'train_runtime': 1420.1902, 'train_samples_per_second': 0.986, 'train_steps_per_second': 0.124, 'total_flos': 1529664503808000.0, 'train_loss': 5.184010765769265, 'epoch': 2.0})

In [15]:
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

def generate_antigen(cdr_sequence):
    prompt = f"CDR: {cdr_sequence} Antigen:"
    outputs = generator(prompt, max_new_tokens=500, do_sample=True)
    return [out["generated_text"].split("Antigen:")[-1].strip() for out in outputs]

cdr_input = "GRAAGTFDSQQYHRLPLS"
antigens = generate_antigen(cdr_input)
print(antigens)


Device set to use cuda:0


['NMGDGCELMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMSNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGN

In [22]:
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

def generate_antigen(cdr_sequence):
    prompt = f"CDR: {cdr_sequence} Antigen:"
    outputs = generator(prompt, max_new_tokens=500, do_sample=True)
    return [out["generated_text"].split("Antigen:")[-1].strip() for out in outputs]

cdr_input = "GRAAGTFDSQQYHRLPLS"
antigens = generate_antigen(cdr_input)
print(antigens)


Device set to use cuda:0


['VHGGQQFHPGQGGQGFVQGG\nGWGQGGQQVCQSQTVSGGFGGGFGGGFGGGFGGGFGGGFGGFGGGFGGGFGGGFGGFGGNQ\nEVKTSYSAQTVQSNRVSGGQSSGGQSGGLGGGGFGGAQGGGFGGSSGGGFGGSSGGFGGN\nQGGSSGGGFGGGNQGGSSGGGFGSGGFGGGQGGSSGGGFGGSSGGFGGNQGGSSGGGFGG\nNSGGSSGGGFGGGQGGSSGGGFGASSGGSSGGFGGGSGGGFGSSSGGGFGGGQGGSSGGG\nFGGNQGGSSGGGFGGSSGGGFGGSSGGGFGGSSGGGFGGSSGGGFGGSSGGGFGGSSGGG\nFGGSSGGFGGSSGGGFGGGQGGSSGGGFGGSSGGGFGGSSGGGFGGSSGGFGGSSGGGFG']


In [24]:
from Bio import pairwise2
from Bio.pairwise2 import format_alignment

In [30]:
seq1 = "NMGDGCELMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMSNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEI"
print(len(seq1))
seq2 = "EVQLVETGGGLVQPGGSLRLSCAASGFTLNSYGISWVRQAPGKGPEWVSVIYSDGRRTFYGDSVKGRFTISRDTSTNTVYLQMNSLRVEDTAVYYCAKGRAAGTFDSWGQGTLVTVSSGGGGSGGGGSGGGGSDVVMTQSPESLAVSLGERATISCHSSQSVGYESRMKNSVAWYQQKAGQPPKLLIYWASTRESGVPDRFSGSGSGTDFTLTISSLQAEDAAVYYCQQYHRLPLSFGGGTKVEIK"
print(len(seq2))

1002
246


In [26]:
alignments = pairwise2.align.globalxx(seq1, seq2)
for alignment in alignments:
    print(format_alignment(*alignment))

NMGDGCE--LMGNACEIMGNACEIMGNACEIMGNAC-EIM-GNACEIMGNACEIMGNACEIM----GNACEIMGNACEIMSNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNACEIMGNA----CEIMGNACEIMGNACEIM-G---NACEIM--GNACEIMGN-----ACEIM-GNACEIM-GNAC-E----IM---GNACEIM-----GNACEIM----GNACE---IMG------NACEI-----MGNAC----EIMGN--A---CEIMGNACEIM-GN-ACEIMGNACEIMGNACEIM-----GNACEIM-GNACEIM-------GNACEIMGNACEIMGNACEIMGNACEIM-GNACEIMGNACEIMGNACEIMGNACEIM-GNACEIMGNACEIMGNACEIMGNACEI----MGNAC----EIMGN--ACEIM---GNACEIMGN-ACE-IMGNA-CEIM----