# Example notebook for tuning a prompt for RITA on protein family PF03272.

In [None]:
#@title Setup for Colab only
import os
colab = 'google.colab' in str(get_ipython())
if colab:
  colab_prefix = "/content/drive/MyDrive/"
  !pip install transformers==4.20.1
  !pip install biopython
  !pip install git+https://github.com/AndreaNathansen/protein-prompt-tuning.git#egg=protein-prompt-tuning --log PIP_LOG

  from google.colab import drive
  drive.mount('/content/drive')
  # Download datasets
  if not os.path.exists(colab_prefix + "datasets"):
    os.makedirs(colab_prefix + "datasets") 
  !wget -N https://raw.githubusercontent.com/AndreaNathansen/protein-prompt-tuning/main/datasets/InterProUniprotPF03272prepared_train.fasta -P /content/drive/MyDrive/datasets/
  !wget -N https://raw.githubusercontent.com/AndreaNathansen/protein-prompt-tuning/main/datasets/InterProUniprotPF03272prepared_validation.fasta -P /content/drive/MyDrive/datasets/
  !wget -N https://raw.githubusercontent.com/AndreaNathansen/protein-prompt-tuning/main/datasets/InterProUniprotPF03272prepared_test.fasta -P /content/drive/MyDrive/datasets/
else:
  colab_prefix=""

In [None]:
import json
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformers import AdamW, AutoTokenizer, AutoModelForCausalLM

from mkultra.evaluator import Evaluator
import mkultra.sequence_loader as sequence_loader
from mkultra.trainers import SoftPromptTrainer
from mkultra.tuning import RITAPromptTuningLM

## Set up training

In [None]:
# Single seed reused for both data shuffling and soft-prompt initialisation
# when the trainer is constructed.
seed = 1_234_567_890

In [None]:
# Name of this soft prompt; it becomes the last component of the project path.
sp_name = "RITA-prompt-tuning-example"
# Specify the project directory base.
project_dir = f"{colab_prefix}soft_prompts/{sp_name}/"

# Create the directory (and missing parents) in one call. exist_ok=True keeps
# the cell re-runnable and avoids the check-then-create race of testing
# os.path.exists() first.
os.makedirs(project_dir, exist_ok=True)

# Hugging Face model identifier used for both the model and its tokenizer.
model_name = "lightonai/RITA_s"

In [None]:
# Length of the soft prompt, in tokens.
n_tokens = 10
# Prompt tokens plus sequence tokens add up to 1024 positions (presumably the
# model's context size -- confirm). The block size is derived from n_tokens so
# it cannot go stale when the prompt length is changed.
total_sequence_length = 1024
block_size = total_sequence_length - n_tokens  # 1014 for a 10-token prompt
batch_size = 2
# Keyword arguments forwarded to the optimizer class chosen for the trainer.
optimizer_params = {"lr": 0.001}
num_epochs = 2
# NOTE(review): checkpoint_interval and patience are presumably counted in
# epochs -- confirm against SoftPromptTrainer.
checkpoint_interval = 1
patience = 2
init_from_vocab = True

In [None]:
# FASTA files for the three dataset splits; they differ only in the split
# name embedded in the file name.
dataset_file_train, dataset_file_validation, dataset_file_test = (
    f"{colab_prefix}datasets/InterProUniprotPF03272prepared_{split_name}.fasta"
    for split_name in ("train", "validation", "test")
)

In [None]:
# Load the pretrained weights into the prompt-tuning model class, then cast
# to half precision and move the model to the GPU.
model = RITAPromptTuningLM.from_pretrained(model_name)
model = model.half().to("cuda")

# The tokenizer is loaded from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Special-token ids handed to FastaDataset as its last three positional
# arguments, in the order <PAD>, <EOS>, <EOS>.
pad_id = tokenizer.vocab['<PAD>']
eos_id = tokenizer.vocab['<EOS>']

# Training split: reshuffled by the loader.
dataset = sequence_loader.FastaDataset(
    dataset_file_train, tokenizer, block_size, pad_id, eos_id, eos_id
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Validation split: kept in file order.
dataset_val = sequence_loader.FastaDataset(
    dataset_file_validation, tokenizer, block_size, pad_id, eos_id, eos_id
)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

In [None]:
# Collect the trainer configuration in one place before constructing it.
# The same seed drives both data shuffling and prompt initialisation.
trainer_config = {
    "model": model,
    "optimizer_class": AdamW,
    "optimizer_params": optimizer_params,
    "project_dir": project_dir,
    "data_loader_train": dataloader,
    "data_loader_eval": dataloader_val,
    "checkpoint_interval": checkpoint_interval,
    "patience": patience,
    "n_tokens": n_tokens,
    "shuffle_seed": seed,
    "init_from_vocab": init_from_vocab,
    "prompt_init_seed": seed,
}
trainer = SoftPromptTrainer(**trainer_config)

# Run the prompt tuning for the configured number of epochs.
trainer.train(num_epochs=num_epochs)

## Evaluate trained prompt

In [None]:
# Held-out test split, built with the same tokenizer, block size and
# special-token ids (<PAD>, <EOS>, <EOS>) as the other splits.
pad_id = tokenizer.vocab['<PAD>']
eos_id = tokenizer.vocab['<EOS>']
dataset_test = sequence_loader.FastaDataset(
    dataset_file_test, tokenizer, block_size, pad_id, eos_id, eos_id
)
# No shuffling for evaluation.
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

In [None]:
# Evaluate the prompt-tuned model on the test split. project_dir is passed
# along with is_prompt_tuned=True (presumably so the evaluator can locate the
# trained prompt -- confirm against Evaluator).
evaluator = Evaluator(
    model=model,
    is_prompt_tuned=True,
    data_loader_test=dataloader_test,
    project_dir=project_dir,
)
perplexity = evaluator.evaluate_perplexity()
print(perplexity)

## Compare to base model

In [None]:
# Load the unmodified pretrained model as a baseline. trust_remote_code=True
# lets transformers execute the model code shipped with the checkpoint.
base_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
# Same precision and device as the prompt-tuned model.
base_model = base_model.half().to("cuda")

In [None]:
# Baseline: perplexity of the plain model, without a soft prompt, on the same
# test loader.
base_evaluator = Evaluator(
    model=base_model,
    is_prompt_tuned=False,
    data_loader_test=dataloader_test,
)
base_perplexity = base_evaluator.evaluate_perplexity()
print(base_perplexity)

## Generate sequences with the prompt-tuned model

In [None]:
# leave out the EOS token that the RITA tokenizer always appends
input_ids = tokenizer("<EOS>", return_tensors="pt").input_ids[:, :-1].to("cuda")

# Sample two sequences from the prompt-tuned model, starting from a single
# <EOS> token.
output = model.generate(
    input_ids=input_ids,
    max_length=block_size,
    do_sample=True,
    top_k=950,
    repetition_penalty=1.2,
    num_return_sequences=2,
    # NOTE(review): 2 is presumably the id of '<EOS>' in the RITA vocabulary;
    # confirm it equals tokenizer.vocab['<EOS>'] before replacing the literal.
    eos_token_id=2,
)
sequences = [tokenizer.decode(output_ids) for output_ids in output]
# Strip the <EOS> markers and the spaces from the decoded text so that only
# the residue letters remain.
print([sequence.replace('<EOS>', '').replace(' ', '') for sequence in sequences])