## Setup development environment

This notebook has been tested on an Amazon SageMaker Notebook Instance with a single GPU (ml.g5.2xlarge).

In [1]:
!pip install transformers==4.38.1 datasets==2.17.1 peft==0.8.2 bitsandbytes==0.42.0 trl==0.7.11 --upgrade --quiet
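Because the install cell pins exact versions, it can be worth confirming which versions the kernel actually sees (for example, after a restart). Below is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper name `version_mismatches` is our own, not part of any library:

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the pip install cell above
PINNED = {
    "transformers": "4.38.1",
    "datasets": "2.17.1",
    "peft": "0.8.2",
    "bitsandbytes": "0.42.0",
    "trl": "0.7.11",
}

def version_mismatches(pins):
    """Return {package: (expected, installed)} for every package whose installed
    version differs from its pin; installed is None if the package is missing."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches

print(version_mismatches(PINNED) or "all pinned versions installed")
```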

## Load and prepare the dataset


### Choose a dataset

For the purpose of this tutorial, we will use Dolly (`databricks/databricks-dolly-15k`), an open-source dataset of roughly 15k instruction-following records.

Example record from dolly:
```
{
  "instruction": "Who was the first woman to have four country albums reach No. 1 on the Billboard 200?",
  "context": "",
  "response": "Carrie Underwood."
}
```


In [2]:
from datasets import load_dataset
from random import randrange

# Load dataset from the hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])

dataset size: 15011
{'instruction': 'Where is the Lighthouse Point, Bahamas', 'context': 'Lighthouse Point, Bahamas, or simply Lighthouse Point, is a private peninsula in The Bahamas which serves as an exclusive port for the Disney Cruise Line ships. It is located in the south-eastern region of Bannerman Town, Eleuthera. In March 2019, The Walt Disney Company purchased the peninsula from the Bahamian government, giving the company control over the area.', 'response': 'The Lighthouse Point, Bahamas, or simply Lighthouse Point, is a private peninsula in the Bahamas which serves as an exclusive port for the Disney Cruise Line ships. It is located in the south-eastern region of Bannerman Town, Eleuthera.', 'category': 'summarization'}
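The inference cells later in this notebook evaluate on a single fixed record (`dataset[6]`). This section does not show which records were used for fine-tuning. One way to guarantee that the evaluation record is unseen is to split the record indices into train and evaluation sets before fine-tuning.

Below is a standard-library sketch. The helper `split_indices`, the 10% fraction, and seed 42 are illustrative choices. With 🤗 Datasets, `dataset.train_test_split(test_size=0.1, seed=42)` does the same job directly on the dataset object.

```python
import random

def split_indices(n, eval_fraction=0.1, seed=42):
    """Shuffle indices 0..n-1 deterministically and split them into (train, eval)."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    n_eval = int(n * eval_fraction)
    return indices[n_eval:], indices[:n_eval]

train_idx, eval_idx = split_indices(15011)  # 15011 = dataset size printed above
print(len(train_idx), len(eval_idx))
```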


### Understand the Mistral format

`mistralai/Mistral-7B-Instruct-v0.1` is a conversational (instruction-tuned) chat model, which means it expects prompts in the following format:


```
<s> [INST] User Instruction 1 [/INST] Model answer 1</s> [INST] User instruction 2 [/INST]
```
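A conversation can be assembled into this layout programmatically. The sketch below rebuilds the exact string above from a list of (instruction, answer) turns. `build_mistral_prompt` is a hypothetical helper. For real use, treat the tokenizer's own chat template (`tokenizer.apply_chat_template`) as the authoritative source for spacing and special tokens.

```python
def build_mistral_prompt(turns, next_instruction):
    """Format a conversation in the [INST] layout shown above.

    turns: list of (user_instruction, model_answer) pairs already exchanged.
    next_instruction: the new user message the model should answer.
    """
    parts = ["<s>"]  # a single BOS token opens the whole conversation
    for instruction, answer in turns:
        # Each finished turn: instruction in [INST] markers, answer closed by EOS
        parts.append(f" [INST] {instruction} [/INST] {answer}</s>")
    # The open turn ends at [/INST]; the model generates the answer from here
    parts.append(f" [INST] {next_instruction} [/INST]")
    return "".join(parts)

print(build_mistral_prompt([("User Instruction 1", "Model answer 1")], "User instruction 2"))
```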


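For instruction fine-tuning, it is common to have two columns in the dataset: one for the prompt and one for the expected response. The `create_prompt` function below builds these as `prompt` and `completion` from each Dolly record. It wraps the instruction and context in Mistral's `[INST] ... [/INST]` markers.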

In [3]:
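# Define the create_prompt function
def create_prompt(sample):
    """Build Mistral-style `prompt` and `completion` fields from a Dolly record."""
    bos_token = "<s>"
    eos_token = "</s>"

    instruction = sample['instruction']
    context = sample['context']
    response = sample['response']

    # Instruction and context go inside the [INST] ... [/INST] markers.
    # Keep this template identical to the one used during fine-tuning.
    text_row = f"""[INST] Below is the question based on the context. Question: {instruction}. Below is the given the context {context}. Write a response that appropriately completes the request.[/INST]"""
    answer_row = response

    # The prompt starts with BOS; the completion ends with EOS so the model learns where to stop
    sample["prompt"] = bos_token + text_row
    sample["completion"] = answer_row + eos_token

    return sample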
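`create_prompt` always inserts the context sentence. As a result, records with an empty `context` (like the example record above) get an empty context clause in the prompt.

The sketch below is a hypothetical variant, `create_prompt_optional_context`, that omits the clause when there is no context. It uses different prompt wording, so it only makes sense for a new fine-tuning run. At inference time, prompts should match the template the model was trained on.

```python
def create_prompt_optional_context(sample, bos_token="<s>", eos_token="</s>"):
    """Hypothetical variant of create_prompt that drops the context
    sentence when a record has no context."""
    instruction = sample["instruction"].strip()
    context = sample["context"].strip()
    if context:
        body = (f"Below is a question and its context. Question: {instruction} "
                f"Context: {context} "
                "Write a response that appropriately completes the request.")
    else:
        # No context: leave out the context clause entirely
        body = (f"Below is a question. Question: {instruction} "
                "Write a response that appropriately completes the request.")
    sample["prompt"] = f"{bos_token}[INST] {body} [/INST]"
    sample["completion"] = sample["response"] + eos_token
    return sample

record = {
    "instruction": "Who was the first woman to have four country albums reach No. 1 on the Billboard 200?",
    "context": "",
    "response": "Carrie Underwood.",
}
print(create_prompt_optional_context(record)["prompt"])
```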

### Fine-tuned Mistral model inference

In [3]:
new_model_path = "./Mistral-Finetuned-Merged"  # local path of the merged fine-tuned model

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged fine-tuned model in float16 and place it on the available GPU(s)
new_model = AutoModelForCausalLM.from_pretrained(
    new_model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Settings carried over from the training setup; neither is required for generation
new_model.config.use_cache = False
new_model.config.pretraining_tp = 1
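Both models in this notebook are loaded with `torch_dtype=torch.float16`, which uses 2 bytes per parameter. The back-of-the-envelope sketch below estimates weight memory, assuming roughly 7.24 billion parameters for Mistral-7B (an approximate figure).

It shows why float16 matters on ml.g5.2xlarge, whose single NVIDIA A10G GPU has 24 GB of memory. Float16 weights take about 13.5 GiB, while float32 weights would need about 27 GiB. Activations and the KV cache used during generation need memory on top of this.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Memory needed just to hold the model weights, in GiB."""
    return num_params * bytes_per_param / 2**30

# Assumption: Mistral-7B has roughly 7.24 billion parameters
MISTRAL_7B_PARAMS = 7.24e9

for dtype, nbytes in [("float32", 4), ("float16", 2)]:
    print(f"{dtype}: ~{weight_memory_gib(MISTRAL_7B_PARAMS, nbytes):.1f} GiB")
```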


In [5]:
# Load the tokenizer saved alongside the fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(new_model_path)

In [6]:
#benchmark_test = create_prompt(dataset[randrange(len(dataset))])
benchmark_test = create_prompt(dataset[6])
eval_prompt = benchmark_test["prompt"]
eval_completion = benchmark_test["completion"]

print(eval_prompt)
print("Dataset ground truth:")
print(eval_completion)

<s>[INST] Below is the question based on the context. Question: Given a reference text about Lollapalooza, where does it take place, who started it and what is it?. Below is the given the context Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.

Lollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addictio

In [7]:
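# Note: the prompt already starts with "<s>", and the tokenizer prepends its own BOS
# token by default, which is why the decoded output below begins with "<s><s>".
# Pass add_special_tokens=False to the tokenizer call to avoid the duplicate.
model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

new_model.eval()
with torch.no_grad():
    # Generate up to 256 new tokens; the Mistral tokenizer defines no pad token, so reuse EOS
    output_ids = new_model.generate(**model_input, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=False))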

<s><s> [INST] Below is the question based on the context. Question: Given a reference text about Lollapalooza, where does it take place, who started it and what is it?. Below is the given the context Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.

Lollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addi
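With `skip_special_tokens=False`, the decoded text echoes the whole prompt (including the duplicated `<s>`) before the answer. To compare only the generated answer against `eval_completion`, one option is to take the text after the last `[/INST]` marker.

Below is a minimal string-based sketch. `extract_answer` is a hypothetical helper, and it assumes the answer itself never contains `[/INST]`. An alternative that avoids string parsing is to decode only the tokens after the prompt, i.e. to slice the generated sequence from `model_input["input_ids"].shape[1]` onward.

```python
def extract_answer(decoded_text, inst_close="[/INST]", eos_token="</s>"):
    """Return the text generated after the last [/INST] marker,
    with any end-of-sequence token removed."""
    _, sep, answer = decoded_text.rpartition(inst_close)
    if not sep:
        # No [/INST] marker found: return the input with surrounding whitespace stripped
        return decoded_text.strip()
    return answer.replace(eos_token, "").strip()

print(extract_answer("<s><s> [INST] Where is it held? [/INST] Grant Park, Chicago.</s>"))
```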

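#### You might notice that the output is identical to the dataset ground truth

This is expected behavior. The model was fine-tuned on a small number of samples formatted with this same template, so it tends to reproduce the reference answers closely rather than rephrase them.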
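Rather than judging "identical" by eye, an exact-match check can compare the generated answer with the ground truth after light normalization: lowercasing, stripping punctuation and extra whitespace, and removing the `</s>` token. The helpers `normalize` and `exact_match` below are illustrative, not part of any library.

```python
import string

def normalize(text):
    """Drop the </s> token, lowercase, remove punctuation, collapse whitespace."""
    text = text.replace("</s>", "").lower()  # remove EOS before punctuation stripping
    text = text.translate(str.maketrans("", "", string.punctuation))
    return " ".join(text.split())

def exact_match(prediction, ground_truth):
    return normalize(prediction) == normalize(ground_truth)

print(exact_match("Carrie Underwood.", "Carrie Underwood.</s>"))
```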

### Original Mistral model inference

In [4]:
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Two float16 7B models (~13.5 GiB of weights each) do not both fit in the 24 GB
# of GPU memory on ml.g5.2xlarge: free new_model or restart the kernel first.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Load the Mistral tokenizer (Mistral is supported natively, so trust_remote_code is not needed)
tokenizer = AutoTokenizer.from_pretrained(model_id)


In [6]:
#benchmark_test = create_prompt(dataset[randrange(len(dataset))])
benchmark_test = create_prompt(dataset[6])
eval_prompt = benchmark_test["prompt"]
eval_completion = benchmark_test["completion"]

print(eval_prompt)
print("Dataset ground truth:")
print(eval_completion)

<s>[INST] Below is the question based on the context. Question: Given a reference text about Lollapalooza, where does it take place, who started it and what is it?. Below is the given the context Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.

Lollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addictio

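#### You might notice that the output is semantically correct

This is expected behavior thanks to the zero-shot capabilities of the original Mistral model. Without any fine-tuning, it can still answer the question from the provided context, even if its wording differs from the ground truth.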
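Semantic correctness is ultimately a human judgment, but a token-overlap F1 score gives a rough automatic signal. It is similar in spirit to the metric used for SQuAD-style question answering. Note that it rewards shared words and gives no credit for paraphrases, so a correct but reworded answer from the base model can still score low. Below is a standard-library sketch with a hypothetical `token_f1` helper.

```python
from collections import Counter
import string

def tokens(text):
    """Lowercased, punctuation-free word tokens, with the </s> token removed."""
    text = text.replace("</s>", "").lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    return text.split()

def token_f1(prediction, ground_truth):
    pred, gold = tokens(prediction), tokens(ground_truth)
    # Multiset intersection: each shared word counts at most min(count_pred, count_gold) times
    overlap = sum((Counter(pred) & Counter(gold)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)

print(round(token_f1("It is held in Grant Park in Chicago.", "Grant Park, Chicago</s>"), 3))
```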

In [7]:
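# Note: the prompt already starts with "<s>", and the tokenizer prepends its own BOS
# token by default, which is why the decoded output below begins with "<s><s>".
# Pass add_special_tokens=False to the tokenizer call to avoid the duplicate.
model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

base_model.eval()
with torch.no_grad():
    # Generate up to 256 new tokens; the Mistral tokenizer defines no pad token, so reuse EOS
    output_ids = base_model.generate(**model_input, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=False))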

<s><s> [INST] Below is the question based on the context. Question: Given a reference text about Lollapalooza, where does it take place, who started it and what is it?. Below is the given the context Lollapalooza /ˌlɒləpəˈluːzə/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.

Lollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addi