[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apetulante/Tutorials/blob/master/Concepts/Proxy_Tuning/Proxy_Tuning.ipynb)


This notebook follows the example in this tutorial: https://lightning.ai/lightning-ai/studios/improve-llms-with-proxy-tuning

The method comes from the original paper: https://arxiv.org/pdf/2401.08565

# Tuning by **Proxy**

Proxy tuning is a concept introduced by Liu et al. (2024). It's a simple, clever way to produce fine-tuned-like results from a larger model *without ever having to fine-tune the larger model*. Instead, you use a fine-tuned version of a smaller model as a "proxy" to show the larger model where its errors are.

Because only the smaller model is ever fine-tuned, this method can save considerable computational resources and time.

## 0. Background and Introduction

In this notebook, we will implement and demonstrate the concept of proxy-tuning using LLaMA models. Proxy-tuning involves using a smaller, fine-tuned model (expert) and its untuned counterpart (anti-expert) to guide the output of a larger, untuned base model *without modifying its weights*. This approach allows us to leverage the strengths of a fine-tuned model to influence a larger model's behavior.

### Why Proxy-Tuning Works

#### Logits as Raw Predictions

Logits are the raw, unnormalized scores produced by a model for each possible output token. They contain rich information about the model's confidence in its predictions.
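
To make this concrete, here is a minimal, dependency-free sketch of how logits become probabilities through softmax. The three-token vocabulary and the scores are made up for illustration; real models produce one logit per entry in a vocabulary of tens of thousands of tokens.

```python
import math

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the max improves numerical stability without changing the result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits over a toy vocabulary (illustrative values only)
vocab = ["7", "6", "apples"]
logits = [2.0, 1.0, 0.0]

probs = softmax(logits)
for token, p in zip(vocab, probs):
    print(f"{token!r}: {p:.3f}")
```

A higher logit always maps to a higher probability, which is why greedy decoding can pick the next token directly from the logits.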

#### Behavior of the Expert, Anti-Expert, and Target Models

- The expert model is a small model fine-tuned on a specific task, and it has learned to generate high-quality outputs for that task. Its logits reflect this specialized knowledge.
- The anti-expert is the same small model before fine-tuning. Its logits reflect general language knowledge without the task-specific behavior.
- The target model is a larger, untuned base model. It is more powerful and more general, but its logits lack task-specific refinements.

#### Adjusting Logits

By subtracting the logits of the anti-expert from the logits of the expert, we get a difference that isolates the task-specific adjustments introduced by fine-tuning. Adding this difference to the logits of the target model (an untuned model like the anti-expert, but larger) steers the target model's predictions toward the behavior the expert learned. This only works when all three models share a vocabulary, so that their logits line up token by token.

#### Mathematical Intuition

- Let $L_{\text{anti-expert}}$ be the logits of the anti-expert model.
- Let $L_{\text{expert}}$ be the logits of the expert model.
- Let $L_{\text{target}}$ be the logits of the target model.

The proxy-tuned logits are computed as:

$$
L_{\text{proxy-tuned}} = L_{\text{target}} + (L_{\text{expert}} - L_{\text{anti-expert}})
$$

This formula adjusts $L_{\text{target}}$ by incorporating the task-specific modifications captured in $L_{\text{expert}} - L_{\text{anti-expert}}$.
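
As a worked example of the formula, the sketch below applies it to made-up next-token logits over a three-token vocabulary. The numbers and token names are purely illustrative assumptions; the point is only to show how the expert/anti-expert difference can flip the target model's top choice.

```python
import math

def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def proxy_tune(target, expert, anti_expert):
    """Element-wise L_target + (L_expert - L_anti-expert)."""
    return [t + (e - a) for t, e, a in zip(target, expert, anti_expert)]

# Hypothetical next-token logits (illustrative values only)
vocab = ["7", "</p>", "Answer"]
l_target      = [1.0, 2.0, 0.5]   # large untuned model prefers "</p>"
l_expert      = [2.5, 0.0, 1.0]   # small tuned model prefers "7"
l_anti_expert = [1.0, 1.5, 0.5]   # small untuned model prefers "</p>"

l_proxy = proxy_tune(l_target, l_expert, l_anti_expert)
print(l_proxy)                              # [2.5, 0.5, 1.0]
print(vocab[l_proxy.index(max(l_proxy))])   # 7

# The same update in probability space: reweight the target distribution
# by the ratio p_expert / p_anti-expert, then renormalize.
p_t, p_e, p_a = softmax(l_target), softmax(l_expert), softmax(l_anti_expert)
ratio = [t * e / a for t, e, a in zip(p_t, p_e, p_a)]
reweighted = [r / sum(ratio) for r in ratio]
```

Here `reweighted` matches `softmax(l_proxy)` up to floating-point error, so adding logit differences is equivalent to multiplying the target's probabilities by how much more (or less) likely the expert finds each token than the anti-expert does.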

And it really is just that simple!


## 1. Setup

We'll be making use of Hugging Face `transformers` for this tutorial. Llama 2 7B and 13B are both available on the Hugging Face Hub for us to download and use, as are their respective chat-trained versions.

First, you'll need to visit one of the repos below (just one is enough) and agree to share your contact information to gain access:
- 7B repo: https://huggingface.co/meta-llama/Llama-2-7b-hf
- 7B (chat version) repo: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
- 13B repo: https://huggingface.co/meta-llama/Llama-2-13b-hf

We'll begin by installing some additional packages. This cell might take a moment to run, and you may need to restart the kernel once it's finished to make sure the packages are available.

In [None]:
%%capture
# we don't need to see package install messages
!pip install accelerate datasets

Then, you'll need your Hugging Face token to access the repos.

If you need help finding your token or making one, check out this tutorial: https://huggingface.co/docs/hub/security-tokens

In [None]:
from huggingface_hub import notebook_login

# This will prompt you to enter your Hugging Face token
notebook_login()


## 2. Load in and Test LLaMA Models

We'll start by taking a closer look at the models we'll be working with.

In the interest of saving compute and time, we'll be working with Llama 2 models, both 7-billion parameter and 13-billion parameter versions.

Let's start by loading our 3 models. Llama 7B-chat will be our **expert model**: it has been fine-tuned to answer questions. Llama 7B is our **anti-expert**: it's like our expert, but without the behavior that we desire. Llama 13B will be the model that we proxy-tune, aiming to get its performance closer to that of Llama 13B-chat.
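
Holding three models in memory at once is the main practical cost of this approach. The back-of-the-envelope sketch below estimates the weight memory, assuming the weights are stored in bfloat16 (2 bytes per parameter) and using the nominal parameter counts from the model names. It ignores activations and other runtime overhead, so treat the numbers as a rough lower bound.

```python
BYTES_PER_PARAM_BF16 = 2  # bfloat16 stores each weight in 2 bytes

# Nominal parameter counts (the "7b"/"13b" in the model names), in billions
models = {
    "Llama-2-7b-chat-hf (expert)": 7,
    "Llama-2-7b-hf (anti-expert)": 7,
    "Llama-2-13b-hf (target)": 13,
}

def weight_gb(params_billion, bytes_per_param=BYTES_PER_PARAM_BF16):
    """Approximate weight memory in GB (1 GB = 1e9 bytes), ignoring activations."""
    return params_billion * bytes_per_param

total = 0
for name, size in models.items():
    gb = weight_gb(size)
    total += gb
    print(f"{name}: ~{gb} GB")
print(f"All three models: ~{total} GB of weights")
```

Under these assumptions the three models need roughly 54 GB for weights alone, so a single smaller GPU will likely not fit them all.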

In [None]:
import torch
from transformers import AutoTokenizer, pipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Load the pipelines for the models
pipeline_7b_chat = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf",
                            model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")

pipeline_7b = pipeline("text-generation", model="meta-llama/Llama-2-7b-hf",
                       model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")

pipeline_13b = pipeline("text-generation", model="meta-llama/Llama-2-13b-hf",
                        model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")

print('Models loaded and ready!')


Let's see what Llama 7B says when it's not chat fine-tuned.

We'll ask it a simple math question. We don't expect a clean "answer" here. Remember that a base model simply tries to continue the text it was given, so the response can be almost anything!

In [None]:
input_text = '''If I have 5 apples and eat 2,
                  but then find 4 more on my way home,
                  how many do I have?'''

In [None]:
answer_7b = pipeline_7b(input_text)[0]['generated_text']
print(answer_7b)

If I have 5 apples and eat 2,
                  but then find 4 more on my way home,
                  how many do I have?

                  ANSWER: 7

  */

  // Write your code here
  var apples = 5;

  if (apples == 2) {
    console.log("I have 7 apples.");
  } else {
    console.log("I have " + (apples - 2) + " apples.");
  }
}

// Write your code here
function countApples(apples) {
  if (apples == 2) {
    console.log("I have 7 apples.");
  } else {
    console.log("I have " + (apples - 2) + " apples.");
  }
}

countApples(5);



Now, let's see what Llama 7B-chat has to say. It will attempt to respond to what we asked, but Llama 7B is a small model, so it might not be the best at the reasoning problem we've posed.

In [None]:
answer_7b_chat = pipeline_7b_chat(input_text)[0]['generated_text']
print(answer_7b_chat)

If I have 5 apples and eat 2,
                  but then find 4 more on my way home,
                  how many do I have?

A) 5
B) 7
C) 8
D) 10

Answer: B) 7

Explanation: You started with 5 apples and ate 2, so you had 5 - 2 = 3 apples left. Then, on your way home, you found 4 more apples, so you now have 3 + 4 = 7 apples.


We can see that Llama 7B-chat is able to act as an "expert" in question answering: it answers the question and explains its reasoning, whereas base Llama 7B drifted off into unrelated code.
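
One detail worth knowing: Llama 2 chat models were fine-tuned on prompts wrapped in `[INST] ... [/INST]` markers, with an optional `<<SYS>>` system block. This notebook passes raw text instead, which may be why the reply above reads like a quiz rather than a conversational answer. The helper below is a hypothetical sketch (the name `build_llama2_chat_prompt` is not part of any library) of how a single-turn prompt could be wrapped. It is not used anywhere else in this notebook.

```python
def build_llama2_chat_prompt(user_message, system_prompt=None):
    """Wrap a user message in Llama 2 chat markers (hypothetical helper)."""
    if system_prompt:
        # The optional system prompt goes inside the first [INST] block.
        user_message = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message}"
    # The tokenizer normally adds the beginning-of-sequence token itself,
    # so it is left out of the string here.
    return f"[INST] {user_message.strip()} [/INST]"

question = ("If I have 5 apples and eat 2, but then find 4 more "
            "on my way home, how many do I have?")
prompt = build_llama2_chat_prompt(question)
print(prompt)
```

If you try this, the wrapped string would simply replace `input_text` when calling the chat pipeline.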

## 3. Proxy Fine-Tuning LLaMA 13B

Now, let's proxy tune Llama 13B, leveraging the difference between Llama 7B and Llama 7B-chat to better guide its output.
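
Before running three multi-billion-parameter models, it can help to see the decoding loop in isolation. The sketch below is a toy version under stated assumptions: each "model" is a hypothetical stand-in function that maps a list of token ids to next-token logits over a three-token vocabulary, where id 2 plays the role of the end-of-sequence token. It is not the implementation used in this notebook, only an illustration of the same greedy loop.

```python
def proxy_greedy_decode(target, expert, anti_expert, prompt_ids, eos_id, max_new_tokens):
    """Greedy decoding where each step adds (expert - anti_expert) to the target logits.

    Each model is any callable mapping a list of token ids to next-token logits.
    """
    ids = list(prompt_ids)
    generated = []
    for _ in range(max_new_tokens):
        t, e, a = target(ids), expert(ids), anti_expert(ids)
        combined = [ti + (ei - ai) for ti, ei, ai in zip(t, e, a)]
        # Greedy choice: index of the largest combined logit
        next_id = max(range(len(combined)), key=combined.__getitem__)
        generated.append(next_id)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return generated

# Hypothetical stand-ins for real models (illustrative logits only)
def toy_target(ids):       # untuned large model: always prefers token 0
    return [2.0, 1.0, 0.0]

def toy_anti_expert(ids):  # untuned small model: same habit as the target
    return [2.0, 1.0, 0.0]

def toy_expert(ids):       # tuned small model: emit token 1 once, then stop with EOS
    return [0.0, 1.0, 3.0] if 1 in ids else [0.0, 3.0, 1.0]

result = proxy_greedy_decode(toy_target, toy_expert, toy_anti_expert,
                             prompt_ids=[0], eos_id=2, max_new_tokens=10)
print(result)  # [1, 2]
```

On its own, the toy target would repeat token 0 until it hit the token limit. With the expert/anti-expert difference added, it emits token 1 and then stops, which mirrors how the chat-tuned model teaches the larger model to answer and then end its turn.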

In [None]:
# Proxy-tuning function
def generate_proxy_tuning(pipeline_base, pipeline_tuned, pipeline_target, tokenizer, input_text, max_length):
    """
    Generates text using proxy tuning by combining logits from base, tuned, and target models.

    Parameters:
    pipeline_base (transformers.Pipeline): The small untuned (anti-expert) model pipeline.
    pipeline_tuned (transformers.Pipeline): The small tuned (expert) model pipeline.
    pipeline_target (transformers.Pipeline): The larger target model pipeline.
    tokenizer (transformers.PreTrainedTokenizer): The tokenizer shared by all three models.
    input_text (str): The input text to start the generation.
    max_length (int): The maximum number of new tokens to generate.

    Returns:
    str: The generated text after proxy tuning.
    """
    model_base = pipeline_base.model
    model_tuned = pipeline_tuned.model
    model_target = pipeline_target.model

    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_length):
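            # Proxy-tuning: score the next token with all three models.
            # Only the logits at the last position are needed.
            logits_base = model_base(input_ids).logits[:, -1, :]
            logits_tuned = model_tuned(input_ids).logits[:, -1, :]
            logits_target = model_target(input_ids).logits[:, -1, :]
            # Shift the target logits by the expert/anti-expert difference
            logits = (
                logits_target.to(device)
              + (logits_tuned.to(device) - logits_base.to(device))
            )

            # Greedy decoding: softmax preserves ordering, so argmax over the
            # logits picks the same token as argmax over the probabilities.
            next_token_id = torch.argmax(logits, dim=-1)
            generated_tokens.append(next_token_id.item())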

            # Append the new token to the input sequence for the next iteration
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

            if next_token_id.item() == tokenizer.eos_token_id:
                break

    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text

# Example input text
input_text = "If I have 5 apples and eat 2, but then find 4 more on my way home, how many do I have?"

# Generate text using proxy tuning
generated_text = generate_proxy_tuning(
    pipeline_base=pipeline_7b,
    pipeline_tuned=pipeline_7b_chat,
    pipeline_target=pipeline_13b,
    tokenizer=tokenizer,
    input_text=input_text,
    max_length=60
)

print("Proxy-Tuned Generated Text:")
print(generated_text)


Proxy-Tuned Generated Text:


Answer: 
You have 5 apples initially, then you eat 2, so you have 5 - 2 = 3 apples left. Then, you find 4 more apples on your way home, so you have 3 + 4 = 7


The proxy-tuned generated text should reflect the influence of the fine-tuned expert model (LLaMA 7B Chat) on the larger base model (LLaMA 13B). The output should be coherent and similar to what the expert model would generate, without directly fine-tuning the base model.

We can check this by looking at how Llama 13B and Llama 13B-chat perform on their own.


---


*Note: I ran the following cells by restarting the runtime and running just these cells, to deal with GPU memory constraints.*


In [None]:
import torch
from transformers import AutoTokenizer, pipeline

pipeline_13b = pipeline("text-generation", model="meta-llama/Llama-2-13b-hf",
                        model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")

pipeline_13b_chat = pipeline("text-generation", model="meta-llama/Llama-2-13b-chat-hf",
                             model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")


In [None]:
input_text = '''If I have 5 apples and eat 2,
                  but then find 4 more on my way home,
                  how many do I have?'''

In [None]:
answer_13b = pipeline_13b(input_text)[0]['generated_text']
print(answer_13b)

If I have 5 apples and eat 2,
                  but then find 4 more on my way home,
                  how many do I have?
                  Answer: 7
                </p>
                <p>
                  If I have 10 apples and eat 4,
                  but then find 6 more on my way home,
                  how many do I have?
                  Answer: 10
                </p>
                <p>
                  If I have 10 apples and eat 4,
                  but then find 6 more on my way home,
                  how many do I have?
                  Answer: 10
                </p>
              </div>
            </div>
          </div>
        </div>
      </div>
    </div>
  </div>
);

export default App;



In [None]:
answer_13b_chat = pipeline_13b_chat(input_text)[0]['generated_text']
print(answer_13b_chat)

If I have 5 apples and eat 2,
                  but then find 4 more on my way home,
                  how many do I have?

  Solution:
  You have 6 apples.

  Why:
  When you started, you had 5 apples.
  You ate 2, so you had 5 - 2 = 3 apples left.
  Then, you found 4 more apples, so you have 3 + 4 = 7 apples.

  So, the final answer is 7 apples.
