# We built off the excellent PEFT notebook by Younes Belkada

## Fine-tune large models using 🤗 `peft` adapters, `transformers` & `bitsandbytes`

In this tutorial we will cover how we can fine-tune large language models using the very recent `peft` library and `bitsandbytes` for loading large models in 8-bit.
The fine-tuning relies on a recent method called "Low-Rank Adapters" (LoRA): instead of fine-tuning the entire model, you only fine-tune these small adapters and load them into the model.
After fine-tuning the model you can also share your adapters on the 🤗 Hub and load them very easily. Let's get started!

### Install requirements

First, run the cells below to install the requirements:

In [None]:
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git
!pip install biopython

### Model loading

Here, let's load the `opt-6.7b` model. Its weights in half precision (float16) are about 13GB on the Hub; loading them in 8-bit requires only around 7GB of memory instead.

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

from transformers import BitsAndBytesConfig

# Base checkpoint. Experiments used either "facebook/opt-6.7b" or
# "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" (the results figure was
# produced with TinyLlama). The name lives in one place so the tokenizer can
# never drift out of sync with the model weights when switching checkpoints.
MODEL_NAME = "facebook/opt-6.7b"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # Load the weights in 8-bit through bitsandbytes (replaces the deprecated
    # `load_in_8bit=True` keyword argument).
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map='auto',  # let accelerate place layers on the available device(s)
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

### Post-processing on the model

Next, we apply some post-processing to the 8-bit model to enable training: we freeze all layers and cast the layer-norm parameters to `float32` for stability. For the same reason, we also cast the output of the last layer to `float32`.

In [None]:
# Freeze the base model: only the LoRA adapters attached later will be trained.
for param in model.parameters():
    param.requires_grad = False
    if param.ndim != 1:
        continue
    # 1-D tensors (e.g. layer-norm weights) are tiny, so holding them in
    # float32 costs little memory and improves numerical stability.
    param.data = param.data.to(torch.float32)

# Recompute activations during the backward pass rather than storing them all.
model.gradient_checkpointing_enable()
# Make the input-embedding output require grad so adapter weights can be
# trained while the embeddings themselves stay frozen.
model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
    """Run the wrapped module(s) and upcast the result to float32.

    Wrapped around ``lm_head`` so the logits leave the model in full
    precision, for numerical stability.
    """

    def forward(self, x):
        result = super().forward(x)
        return result.to(torch.float32)


model.lm_head = CastOutputToFloat(model.lm_head)

### Apply LoRA

Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (LoRA), using the `get_peft_model` utility function from `peft`.

In [None]:
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Every tensor yielded by ``model.named_parameters()`` is counted by
    ``numel()``; those with ``requires_grad`` set are counted as trainable.
    Prints the trainable count, the total count and the trainable percentage.
    A model with no parameters reports 0.0% instead of raising
    ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard the division: an empty model would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA setup for causal language modelling: rank-16 adapters (alpha 32,
# dropout 0.05) injected into the attention query and value projections.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",  # leave all bias terms untrained
)

# Wrap the frozen base model with the adapters and report the trainable share.
model = get_peft_model(model, config)
print_trainable_parameters(model)

### Training

In [None]:
from Bio import Entrez
import xml.etree.ElementTree as ET
import transformers
from datasets import Dataset, load_dataset
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

# NCBI asks E-utilities clients to identify themselves; replace with a real address.
Entrez.email = "your-email@example.com"
query = "material science catalysis AND (\"2017\"[Date - Publication] : \"2020\"[Date - Publication])"

# Search PubMed for the IDs of matching articles.
handle = Entrez.esearch(db="pubmed", term=query, retmax=10000)
try:
    record = Entrez.read(handle)
finally:
    handle.close()

id_list = record['IdList']
abstracts = []

BATCH_SIZE = 100  # number of PubMed IDs requested per efetch call


def _abstract_from_article(article):
    """Return the complete abstract text of a <PubmedArticle> element, or None.

    Structured abstracts are split over several <AbstractText> sections, and a
    section can contain inline markup (e.g. <sub>, <sup>, <i>) whose text is
    not part of ``element.text``. Joining ``itertext()`` of every section keeps
    the whole abstract instead of truncating it at the first tag.
    """
    sections = article.findall('.//Abstract/AbstractText')
    parts = ["".join(section.itertext()).strip() for section in sections]
    text = " ".join(part for part in parts if part)
    return text or None


# Fetch records in batches and keep only abstracts of moderate length.
for start in range(0, len(id_list), BATCH_SIZE):
    fetch_handle = Entrez.efetch(
        db="pubmed", id=id_list[start:start + BATCH_SIZE], rettype="xml"
    )
    try:
        data = fetch_handle.read()
    finally:
        fetch_handle.close()  # release the connection even if the read fails

    root = ET.fromstring(data)
    for article in root.findall('.//PubmedArticle'):
        abstract = _abstract_from_article(article)
        if abstract is not None and 800 < len(abstract) < 2200:  # Check length of the abstract
            abstracts.append(abstract)
print(f"Fetched {len(abstracts)} abstracts.")

In [None]:
MINCITES = 200       # keep only papers with more than this many "cited in" links
MAXABSTRACTS = 2000  # stop once this many highly-cited abstracts are collected


def getcitations(pmids):
    """Return ``{pmid: citation_link_count}`` for the given PubMed IDs.

    Issues one elink request per PMID and counts the links in the first
    returned link set. PMIDs with no link set are omitted from the result.
    """
    numcites = {}
    for pmid in pmids:
        # NOTE(review): LinkName "pubmed_pubmed_citedin" names a pubmed->pubmed
        # link while db="pmc" is requested; confirm which target db is intended.
        handle = Entrez.elink(dbfrom="pubmed", db="pmc", LinkName="pubmed_pubmed_citedin", from_uid=pmid)
        try:
            q = Entrez.read(handle)
        finally:
            handle.close()
        if len(q) > 0 and len(q[0]['LinkSetDb']) > 0:
            numcites[pmid] = len(q[0]['LinkSetDb'][0]['Link'])
    return numcites


def _abstract_from_article(article):
    """Return the complete abstract text of a <PubmedArticle> element, or None.

    Joins ``itertext()`` of every <AbstractText> section so structured
    abstracts and sections containing inline markup are kept whole.
    """
    sections = article.findall('.//Abstract/AbstractText')
    parts = ["".join(section.itertext()).strip() for section in sections]
    text = " ".join(part for part in parts if part)
    return text or None


citedqueries = "material science catalysis AND (\"2017\"[Date - Publication] : \"2020\"[Date - Publication])"
citedhandle = Entrez.esearch(db="pubmed", term=citedqueries, retmax=10000)
try:
    citedrecord = Entrez.read(citedhandle)
finally:
    citedhandle.close()

# Use this cell's own search results so the cell is self-contained.
cited_id_list = citedrecord['IdList']
citeNumbers = getcitations(cited_id_list)
cited_abstracts = []
for pmid, count in citeNumbers.items():
    if len(cited_abstracts) >= MAXABSTRACTS:  # cap at exactly MAXABSTRACTS
        break
    if count <= MINCITES:
        continue
    fetch_handle = Entrez.efetch(db="pubmed", id=pmid, rettype="xml")
    try:
        data = fetch_handle.read()
    finally:
        fetch_handle.close()

    # Parse XML
    root = ET.fromstring(data)
    for article in root.findall('.//PubmedArticle'):
        # A dedicated local name: the training corpus `abstracts` must not be
        # overwritten, and each article's own text must be the one filtered.
        cited_abstract = _abstract_from_article(article)
        if cited_abstract is not None and 800 < len(cited_abstract) < 2200:  # Check length of the abstract
            cited_abstracts.append(cited_abstract)
print(f"Fetched {len(cited_abstracts)} highly-cited abstracts.")

In [None]:
# Assume abstracts is a list of strings containing your fetched data
abstracts_dataset = Dataset.from_dict({'text': abstracts})
tokenized_dataset = abstracts_dataset.map(lambda x: tokenizer(x['text']), batched=True)

# Setup Trainer using the whole dataset
trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=1e-4,
        fp16=True,
        logging_steps=1,
        output_dir='outputs',
        report_to="wandb"
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)

model.config.use_cache = False  # Adjust cache settings for training

# Start training
trainer.train()
model.push_to_hub("alxfgh/opt-6.7b-materials-science-catalysis-lora", use_auth_token=True)
wandb.finish()

In [None]:
def computeLogP(abstract, *, lm=None, tok=None, verbose=True):
    """Score an abstract's continuation under a causal language model.

    The first two space-separated words form the prompt. Every token of the
    remaining text is scored given the prompt and all preceding continuation
    tokens. This is done in a single teacher-forced forward pass, which is
    equivalent to extending the sequence one token at a time.

    Args:
        abstract: Text to score.
        lm: Causal LM whose output has ``.logits`` indexed as
            [batch, position, vocab]. Defaults to the global ``model``.
        tok: Tokenizer providing ``encode``/``decode``. Defaults to the global
            ``tokenizer``.
        verbose: If true, print each token's log-probability and the total.

    Returns:
        ``(log_sum, logp_tokens)``: the summed natural-log probability of the
        continuation and the per-token log-probabilities in order.
        ``(0, [])`` when there is no continuation.

    Raises:
        ValueError: if the prompt yields no tokens but a continuation exists
            (the first continuation token would have nothing to condition on).
    """
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok

    words = abstract.split(" ")
    prompt = " ".join(words[:2])
    rest = " ".join(words[2:])

    prompt_ids = tok.encode(prompt, add_special_tokens=False, return_tensors="pt")[0]
    n_prompt = prompt_ids.shape[0]

    if rest:
        # Tokenize prompt and continuation together so the first continuation
        # token keeps its leading space (BPE tokenizers such as OPT's encode
        # " word" and "word" differently). Fall back to separate encoding if
        # the joint tokenization does not begin with the prompt tokens.
        full_ids = tok.encode(f"{prompt} {rest}", add_special_tokens=False, return_tensors="pt")[0]
        if full_ids.shape[0] > n_prompt and torch.equal(full_ids[:n_prompt], prompt_ids):
            cont_ids = full_ids[n_prompt:]
        else:
            cont_ids = tok.encode(rest, add_special_tokens=False, return_tensors="pt")[0]
    else:
        cont_ids = prompt_ids[:0]  # nothing to score
    n_cont = cont_ids.shape[0]

    log_sum = 0
    logp_tokens = []
    if n_cont > 0:
        if n_prompt == 0:
            raise ValueError("prompt produced no tokens; cannot score the first continuation token")

        input_ids = torch.cat([prompt_ids, cont_ids]).unsqueeze(0)

        # Score in eval mode (no dropout) so results are deterministic, e.g.
        # right after trainer.train() left the model in training mode; the
        # caller's mode is restored afterwards.
        was_training = lm.training
        lm.eval()
        try:
            with torch.no_grad():
                logits = lm(input_ids).logits[0]
        finally:
            if was_training:
                lm.train()

        # The logits at position t predict token t + 1, so continuation token i
        # is scored by row n_prompt - 1 + i.
        rows = logits[n_prompt - 1 : n_prompt - 1 + n_cont].float()
        log_probs = torch.nn.functional.log_softmax(rows, dim=-1)
        targets = cont_ids.to(log_probs.device)
        picked = log_probs[torch.arange(n_cont, device=log_probs.device), targets]
        logp_tokens = picked.tolist()

        for i, log_token_prob in enumerate(logp_tokens):
            log_sum += log_token_prob
            if verbose:
                last_token = tok.decode(cont_ids[i : i + 1])
                print(f"Token: {last_token}, Log Prob: {log_token_prob}")

    if verbose:
        print(f"Total Log Sum Probability: {log_sum}")
    return log_sum, logp_tokens