# Proxy-Tuning with Llama 7

## Runs in Single-GPU Setting

In [1]:
%load_ext watermark
%watermark -p torch,transformers,bitsandbytes

torch       : 2.1.2+cu121
transformers: 4.37.1
bitsandbytes: 0.42.0



Make sure that you use a machine with a GPU.

In [2]:
import torch

if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")
    if num_gpus < 1:
        raise Exception("At least 1 GPU is required for this setup.")
else:
    raise Exception("No GPUs found, make sure a GPU is available")


Number of GPUs available: 1


### Loading the Models

In this section, we are applying proxy-tuning to boost the performance of a Llama 7B base model to that of a Llama 13B Chat model. (We focus on the 13B model because it conveniently fits into the RAM of an affordable GPU like A10G, but the same methodology could also be applied to 70B models for even more impressive results.)

First, we download and load each model to ensure the models are available:

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import os


model_names = [
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-13b-chat-hf",
]

device = "cuda"

for i, model_name in enumerate(model_names):
    model_dir = os.path.join("models", model_name)
    print("Loading:", model_dir)

    if not os.path.exists(model_dir): # use HF token to download models if they don't exist on disk yet
        hf_token = "your access token"

        if "7b" in model_name:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, token=hf_token, cache_dir="models", device_map=device,
                attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, token=hf_token, cache_dir="models", device_map=device,
                attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, load_in_4bit=True
            )
        del model
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, cache_dir="models")
        del tokenizer
    else:
        if "7b" in model_name:
            model = AutoModelForCausalLM.from_pretrained(
                model_dir, device_map=device, attn_implementation="flash_attention_2", 
                torch_dtype=torch.bfloat16
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_dir, device_map=device, attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16, load_in_4bit=True
            )
        del model
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        del tokenizer


print("Models have been downloaded and are available.")

Loading: models/meta-llama/Llama-2-7b-hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading: models/meta-llama/Llama-2-7b-chat-hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading: models/meta-llama/Llama-2-13b-hf


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading: models/meta-llama/Llama-2-13b-chat-hf


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Models have been downloaded and are available.


### Generating Text with Unmodified Models

As a reminder, note that the "Llama-2-13b-chat-hf" is a control experiment. We aim to use a smaller available "Llama-2-7b-chat-hf" model to achieve the same performance when we apply it on top of the "Llama-2-13b-hf" base model.

Next, let's check the outputs that these models produce without applying proxy-tuning, yet:

In [4]:
def load_model_and_tokenizer(model_dir):

    hf_token = "your access token"

    if "7b" in model_name:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, token=hf_token, cache_dir="models", device_map=device,
            attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, token=hf_token, cache_dir="models", device_map=device,
            attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, load_in_4bit=True
        )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, cache_dir="models")
    return model, tokenizer


def generate_text(model, tokenizer, input_text, max_length):
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_length):
            
            outputs = model(input_ids)
            logits = outputs.logits
            predictions = torch.softmax(logits[:, -1, :], dim=-1)
            next_token_id = torch.argmax(predictions).unsqueeze(0)
            generated_tokens.append(next_token_id.item())
            
            # Append the new token to the input sequence for the next iteration
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

            if next_token_id.item() == tokenizer.eos_token_id:
                break

    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text


model_names = [
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-13b-chat-hf",
]

for model_name in model_names:
    input_text = ("If I have 5 apples and eat 2, but then find 3 more on my way home, how many do I have?")

    print("Loading", model_name)
    model, tokenizer = load_model_and_tokenizer(model_name)
    
    print(f"{50*'='}\n\nAsking {model_name}:")
    generated = generate_text(model, tokenizer, input_text, max_length=60)
    print(generated, f"\nLength: {len((input_text+generated).split(' '))}")

    del model
    del tokenizer

Loading meta-llama/Llama-2-7b-hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Asking meta-llama/Llama-2-7b-hf:

I have 5 apples.
I eat 2.
I find 3 more.
I have 10 apples.
I have 5 apples.
I eat 2.
I find 3 more.
I have 10 apples. 
Length: 44
Loading meta-llama/Llama-2-7b-chat-hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Asking meta-llama/Llama-2-7b-chat-hf:


Answer: You have 5 apples in total.

Explanation:

When you started your journey, you had 5 apples.

When you ate 2 apples, you had 5 - 2 = 3 apples left.
 
Length: 49
Loading meta-llama/Llama-2-13b-hf


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


Asking meta-llama/Llama-2-13b-hf:

If I have 5 apples and eat 2, but then find 3 more on my way home, how many do I have?
I have 5 apples and eat 2, but then find 3 more on my way home, how many do I have?
 
Length: 63
Loading meta-llama/Llama-2-13b-chat-hf


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


Asking meta-llama/Llama-2-13b-chat-hf:


You start with 5 apples and eat 2, so you have 5 - 2 = 3 apples left.
Then, you find 3 more apples on your way home, so you have 3 + 3 = 6 apples in total. 
Length: 59


As we can see, only the "Llama-2-13b-chat-hf" model (our control experiment) can achieve the correct response. 

### Proxy-tune the 13B Base model using Llama 2 7B Chat

In [5]:
import sys
import os
import contextlib
from tqdm import tqdm


@contextlib.contextmanager
def suppress_output():
    """Context manager to suppress stdout and stderr."""
    new_stdout, new_stderr = os.devnull, os.devnull
    old_stdout, old_stderr = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = open(new_stdout, 'w'), open(new_stderr, 'w')
    try:
        yield
    finally:
        sys.stdout.close()
        sys.stderr.close()
        sys.stdout, sys.stderr = old_stdout, old_stderr


def generate_proxy_tuning(model_base_name, model_tuned_name, model_target_name, input_text, max_length):
    # Load the tokenizer once as we assume it's the same for all models
    _, tokenizer = load_model_and_tokenizer(model_base_name)

    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    generated_tokens = []

    with torch.no_grad():
        for _ in tqdm(range(max_length), desc="Generating tokens"):
            # Suppress output during model loading
            with suppress_output():
                model_base, _ = load_model_and_tokenizer(model_base_name)
            logits_base = model_base(input_ids).logits
            del model_base

            with suppress_output():
                model_tuned, _ = load_model_and_tokenizer(model_tuned_name)
            logits_tuned = model_tuned(input_ids).logits
            del model_tuned

            with suppress_output():
                model_target, _ = load_model_and_tokenizer(model_target_name)
            logits_target = model_target(input_ids).logits
            del model_target

            # Calculate final logits
            logits = logits_target + (logits_tuned - logits_base)

            # Generate next token
            predictions = torch.softmax(logits[:, -1, :], dim=-1)
            next_token_id = torch.argmax(predictions).unsqueeze(0)
            generated_tokens.append(next_token_id.item())

            # Append the new token to the input sequence for the next iteration
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

            # Break if end of sequence token is generated
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return generated_text


generated = generate_proxy_tuning(
    model_base_name="meta-llama/Llama-2-7b-hf",
    model_tuned_name="meta-llama/Llama-2-7b-chat-hf",
    model_target_name="meta-llama/Llama-2-13b-hf",
    input_text=input_text,
    max_length=100
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   1%|          | 1/100 [00:26<43:55, 26.63s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   2%|▏         | 2/100 [00:53<43:38, 26.72s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   3%|▎         | 3/100 [01:20<43:16, 26.76s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   4%|▍         | 4/100 [01:47<42:50, 26.78s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   5%|▌         | 5/100 [02:14<42:31, 26.85s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   6%|▌         | 6/100 [02:40<42:04, 26.86s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   7%|▋         | 7/100 [03:08<41:48, 26.97s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   8%|▊         | 8/100 [03:34<41:13, 26.89s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:   9%|▉         | 9/100 [04:01<40:50, 26.93s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  10%|█         | 10/100 [04:28<40:19, 26.88s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  11%|█         | 11/100 [04:55<39:53, 26.89s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  12%|█▏        | 12/100 [05:22<39:26, 26.90s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  13%|█▎        | 13/100 [05:49<39:02, 26.93s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  14%|█▍        | 14/100 [06:16<38:33, 26.90s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  15%|█▌        | 15/100 [06:43<38:12, 26.98s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  16%|█▌        | 16/100 [07:10<37:50, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  17%|█▋        | 17/100 [07:37<37:24, 27.04s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  18%|█▊        | 18/100 [08:04<36:54, 27.01s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  19%|█▉        | 19/100 [08:31<36:27, 27.00s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  20%|██        | 20/100 [08:58<36:01, 27.01s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  21%|██        | 21/100 [09:25<35:34, 27.02s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  22%|██▏       | 22/100 [09:52<35:03, 26.96s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  23%|██▎       | 23/100 [10:19<34:34, 26.94s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  24%|██▍       | 24/100 [10:46<34:08, 26.95s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  25%|██▌       | 25/100 [11:13<33:46, 27.02s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  26%|██▌       | 26/100 [11:40<33:20, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  27%|██▋       | 27/100 [12:07<32:53, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  28%|██▊       | 28/100 [12:34<32:27, 27.05s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  29%|██▉       | 29/100 [13:01<31:58, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  30%|███       | 30/100 [13:28<31:35, 27.08s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  31%|███       | 31/100 [13:55<31:09, 27.09s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  32%|███▏      | 32/100 [14:22<30:40, 27.06s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  33%|███▎      | 33/100 [14:50<30:13, 27.07s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  34%|███▍      | 34/100 [15:17<29:47, 27.09s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  35%|███▌      | 35/100 [15:44<29:19, 27.06s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  36%|███▌      | 36/100 [16:11<28:50, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  37%|███▋      | 37/100 [16:38<28:20, 27.00s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  38%|███▊      | 38/100 [17:05<27:53, 26.99s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  39%|███▉      | 39/100 [17:31<27:23, 26.94s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  40%|████      | 40/100 [17:58<26:56, 26.94s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  41%|████      | 41/100 [18:25<26:31, 26.97s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  42%|████▏     | 42/100 [18:53<26:08, 27.04s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  43%|████▎     | 43/100 [19:20<25:42, 27.06s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  44%|████▍     | 44/100 [19:47<25:15, 27.06s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  45%|████▌     | 45/100 [20:14<24:45, 27.01s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  46%|████▌     | 46/100 [20:40<24:14, 26.94s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  47%|████▋     | 47/100 [21:08<23:51, 27.01s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  48%|████▊     | 48/100 [21:35<23:26, 27.05s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  49%|████▉     | 49/100 [22:02<23:01, 27.09s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  50%|█████     | 50/100 [22:29<22:33, 27.07s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  51%|█████     | 51/100 [22:56<22:04, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  52%|█████▏    | 52/100 [23:23<21:38, 27.06s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  53%|█████▎    | 53/100 [23:50<21:14, 27.11s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  54%|█████▍    | 54/100 [24:17<20:46, 27.10s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  55%|█████▌    | 55/100 [24:44<20:16, 27.03s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  56%|█████▌    | 56/100 [25:11<19:43, 26.91s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  57%|█████▋    | 57/100 [25:38<19:20, 26.98s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  58%|█████▊    | 58/100 [26:05<18:54, 27.02s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  59%|█████▉    | 59/100 [26:32<18:28, 27.05s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  60%|██████    | 60/100 [26:59<18:04, 27.11s/it]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Generating tokens:  60%|██████    | 60/100 [27:26<18:17, 27.45s/it]


As we can see, the `"Llama-2-13b-hf"` model, which previously couldn't answer the question correctly, can now produce an answer similar to the answer created via `"Llama-2-13b-chat-hf"`. And we were able to achieve this without changing the weights of that `"Llama-2-13b-hf"` model!

In [6]:
print(generated)



You start with 5 apples and eat 2, so you have 5 - 2 = 3 apples left.
Then, you find 3 more apples on your way home, so you have 3 + 3 = 6 apples in total.
