## Importing Packages
**bitsandbytes**: A lightweight wrapper around custom CUDA functions that make LLMs run faster: optimizers, matrix multiplication, and quantization.
**peft**: A library by Hugging Face that enables parameter-efficient fine-tuning.

In [None]:
!pip install -q -U "bitsandbytes>=0.39.0" "accelerate>=0.20.0" peft datasets scipy einops evaluate trl rouge_score geopy
!pip install flash-attn --no-build-isolation
!pip install --upgrade git+https://github.com/huggingface/transformers
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
    GenerationConfig
)
from tqdm import tqdm
from trl import SFTTrainer
import torch
import time
import pandas as pd
import numpy as np
from huggingface_hub import interpreter_login
import os

interpreter_login()

## Data Collection
We will collect responses from two larger instruction-tuned models, namely Llama-2-7b-chat-hf and Mistral-7B-Instruct-v0.2.

In [2]:
df = pd.DataFrame(columns=['response', 'model'])
prompts = []

We use ChatGPT to generate a few prompt templates. Earlier drafts are kept below as comments; only the final template is appended to `prompts`.

In [17]:
# prompts.append('''
#     Pretend you are a personal assistant designed to entertain Singaporean users and provide traffic recommendations to user about road conditions and travel. Speak in first person to your user and use the warm tone of a radio station host. The user's route travels pass the roads {route_path}. This is the predicted traffic volume for each road in the current time, next hour, and 2 hours from now: {traffic_volume_json}

#     The ERP (Electronic Road Pricing) pricing will roughly ${erp_pricing} for the trip overall. There are available carparks near the destination which is {destination_location} at {nearby_carparks}. The weather at {destination_location} is {weather_forecast}. Provide a concise and comprehensive summary to the user about his/her trip and offer recommendations. ### Response ###''')
# prompts.append('''
#     Pretend you're a personal assistant tailored for Singaporean travelers, offering insights and recommendations on road conditions and travel plans. Speak directly to the user in a warm, friendly tone resembling a lively radio host. Craft words that resonate and can be effortlessly translated into engaging speech.

#     The user's journey passed through the following roads {route_path}. Let's dive into the forecasted traffic volumes for each road, spanning the current time, the next hour, and two hours ahead: {traffic_volume_json}

#     Inform the user to expect an estimated ERP (Electronic Road Pricing) pricing of around ${erp_pricing} for the entire trip. Conveniently, there are parking options near the user's destination at {nearby_carparks}, ensuring hassle-free parking. As for the weather at {destination_location}, it is going to be {weather_forecast}.

#     Now, let's piece together a concise and informative summary of the user's trip, sprinkled with personalized recommendations. Engage the user and offer insights that enrich their travel experience. ### Response ###''')
# prompts.append('''
#     Imagine you're the ultimate travel companion for Singaporean explorers, armed with the latest traffic insights and travel recommendations. Embrace the role of a friendly guide, speaking directly to the user in a warm, engaging tone reminiscent of a trusted advisor.

#     Help the user visualize his/her journey unfolding along the roads {route_path} by diving into the projected traffic volumes for each road, covering the current time, the next hour, and two hours ahead: {traffic_volume_json}

#     Prepare the user for an estimated ERP (Electronic Road Pricing) cost of around ${erp_pricing} for the entire trip. Plus, highlight the nearby parking options at {nearby_carparks} to ensure a smooth arrival. As for the weather forecast at {destination_location}, it's shaping up to be {weather_forecast}.

#     With all this in mind, craft a concise yet comprehensive summary of the user's trip, packed with tailored recommendations to elevate their travel experience. ### Response ###''')

In [10]:
prompts.append('''
    Pretend you're a personal assistant tailored for Singaporean travelers, offering insights and recommendations on road conditions and travel plans. Speak directly to the user in a warm, friendly tone resembling a lively radio host. Craft words that resonate and can be effortlessly translated into engaging speech.

    Let's explore the user's journey that passes through the following roads: {route_path}. Delve into the forecasted traffic volumes for each road, covering the current time, the next hour, and two hours ahead: {traffic_volume_json}

    Prepare the user to anticipate an estimated ERP (Electronic Road Pricing) pricing of around ${erp_pricing} for the entire trip. Conveniently, there are parking options near the user's destination at {nearby_carparks}, ensuring hassle-free parking. As for the weather at {destination_location}, it is expected to be {weather_forecast}.

    Now, let's craft a concise and informative summary of the user's trip, sprinkled with personalized recommendations. Engage the user and offer insights that enhance their travel experience. 
    
    ### Response ###
    ''')
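The template is later filled with `str.format`, so each `{name}` placeholder needs a matching keyword argument. The following is a minimal sketch of that mechanism using a shortened stand-in template; the helper name `fill_template` and the sample values are illustrative assumptions, not part of the notebook.

```python
# Shortened stand-in for the full template above; placeholder names match it.
template = (
    "Route: {route_path}. "
    "Estimated ERP: ${erp_pricing}. "
    "Weather at {destination_location}: {weather_forecast}."
)

def fill_template(template, **values):
    """Fill the named placeholders; str.format raises KeyError if one is missing."""
    return template.format(**values)

example = fill_template(
    template,
    route_path="KPE(ECP) -> Kallang Bahru",   # illustrative values only
    erp_pricing=3.5,
    destination_location="Bedok Mall",
    weather_forecast="Cloudy",
)
print(example)
```

The `$` before `{erp_pricing}` is a literal character to `str.format`. Braces inside substituted values, such as the traffic JSON, are also safe because only the template itself is parsed for placeholders.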

In [7]:
import random
import json
from geopy.geocoders import Nominatim

camera_id_labels = {
    1001: "KPE(ECP)",
    1002: "Kallang Bahru",
    1003: "KPE(PIE)",
    1004: "Kallang Way Flyover",
    1005: "Defu Flyover",
    1006: "Tampines Flyover",
    1111: "TPE(PIE), Exit 2 to Loyang Ave",
    1112: "TPE(PIE), Tampines Viaduct",
    1113: "Tanah Merah Coast Road towards Changi",
    1501: "Maxwell Road",
    1502: "Marina Coastal Drive (Towards AYE)",
    1503: "MCE Eastbound",
    1504: "MCE(AYE)",
    1505: "Marina Boulevard (Towards ECP)",
    1701: "Moulmein Flyover",
    1702: "Braddell Flyover",
    1703: "CTE(PIE)",
    1704: "Chin Swee Road",
    1705: "Ang Mo Kio Ave 5 Flyover",
    1706: "Yio Chu Kang Flyover",
    1707: "Bukit Merah Flyover",
    1709: "Bukit Timah Road",
    1711: "Ang Mo Kio Avenue 1 Flyover",
    2701: "Woodlands Causeway",
    2702: "Woodlands Checkpoint",
    2703: "BKE(PIE)",
    2704: "Woodlands Flyover",
    2705: "Dairy Farm Flyover",
    2706: "Turf Club Avenue",
    2707: "Mandai Road",
    2708: "BKE(KJE)",
    3702: "ECP(PIE)",
    3704: "ECP(MCE)",
    3705: "Changi Coast Road",
    3793: "Laguna Flyover",
    3795: "Marine Parade Flyover",
    3796: "Tanjong Katong Flyover",
    3797: "Maxwell Road",
    3798: "Benjamin Sheares Bridge",
    4701: "Upper Thomson Flyover",
    4702: "Keppel Viaduct",
    4703: "Tuas Second Link",
    4704: "Lower Delta Road",
    4705: "Yuan Ching Road",
    4706: "Near NUS",
    4707: "Jln Ahmad Ibrahim",
    4708: "Near Dover Drive",
    4709: "Clementi Avenue 6",
    4710: "Pandan Gardens",
    4712: "Tuas West Road",
    4713: "Tuas Checkpoint",
    4714: "West Coast Walk",
    4716: "Benoi Road",
    4798: "Sentosa Gateway (Towards Telok Blangah)",
    4799: "Sentosa Gateway (Towards Sentosa)",
    5794: "Bedok North",
    5795: "Eunos Flyover",
    5797: "Paya Lebar Flyover",
    5798: "Kallang Way",
    5799: "Woodsville Flyover",
    6701: "Kim Keat Link",
    6703: "Thomson Flyover",
    6704: "Mount Pleasant",
    6705: "Adam Road",
    6706: "PIE (BKE)",
    6708: "Nanyang Flyover",
    6710: "Jalan Anak Bukit",
    6711: "Changi Airport",
    6712: "Clementi Avenue 6",
    6713: "Simei Avenue",
    6714: "PIE(KJE)",
    6715: "Hong Kah Flyover",
    6716: "Tuas Flyover",
    7791: "Upper Changi Flyover",
    7793: "Tampines Avenue 10",
    7794: "TPE(KPE)",
    7795: "Tampines Flyover",
    7796: "Punggol Flyover",
    7797: "Seletar West Link",
    7798: "Seletar Flyover",
    8701: "Choa Chu Kang West Flyover",
    8702: "KJE(BKE)",
    8704: "Choa Chu Kang Drive",
    8706: "Tengah Flyover",
    9701: "Lentor Flyover",
    9702: "Upper Thomson Flyover",
    9703: "SLE(BKE)",
    9704: "Woodlands Avenue 12",
    9705: "Marsiling Flyover",
    9706: "Mandai Flyover"
}

carparks = [
    "Junction 8",
    "CapitaSpring",
    "Clarke Quay",
    "Funan",
    "Plaza Singapura",
    "Raffles City Shopping Centre",
    "The Atrium@Orchard",
    "Bedok Mall",
    "Tampines Mall",
    "Bukit Panjang Plaza",
    "Six Battery Road",
    "Capital Tower",
    "CapitaGreen",
    "Westgate",
    "IMM Building",
    "Lot One Shoppers’ Mall",
    "Bugis +",
    "Resorts World Sentosa",
    "Harbourfront Centre",
    "Sentosa",
    "VivoCity P3",
    "VivoCity P2",
    "JCube",
    "Westgate",
    "IMM Building",
    "Millenia Singapore",
    "Singapore Flyer",
    "The Esplanade",
    "Raffles City",
    "Marina Square",
    "National Gallery",
    "Suntec City",
    "Concorde Hotel",
    "Far East Plaza",
    "Tang Plaza",
    "Wheelock Place",
    "Ngee Ann City",
    "Mandarin Hotel",
    "Cineleisure",
    "Centrepoint",
    "313@Somerset",
    "Orchard Point",
    "Paragon",
    "Orchard Gateway",
    "Orchard Central",
    "The Heeren",
    "ION Orchard",
    "Plaza Singapura",
    "The Atrium@Orchard",
    "Wisma Atria",
    "Funan Mall",
    "Bedok Mall",
    "Junction 8",
    "Lot One",
    "Bugis+",
    "The Star Vista",
    "Clarke Quay",
    "Bukit Panjang Plaza",
    "Tampines Mall"
]

# randomise route
def random_route():
    num_roads = random.randint(3, 18)

    # Randomly select keys from the dictionary
    # random.sample needs a sequence; passing dict keys directly is deprecated since Python 3.9
    keys_selected = random.sample(list(camera_id_labels), num_roads)

    # Get the labels corresponding to the selected keys
    labels_selected = [camera_id_labels[key] for key in keys_selected]

    route_string = " -> ".join(labels_selected)
    
    return labels_selected, route_string

# randomise traffic conditions for route
def random_traffic_json(labels_selected):
    traffic_volume = {}
    traffic_volume_levels = ["light", "moderate", "congested"]
    for label in labels_selected:
        # Randomly select traffic volume for each label
        traffic_volume[label] = [random.choice(traffic_volume_levels) for _ in range(3)]
    
    
    json_string = json.dumps(traffic_volume, indent=4)
    return json_string

# randomise erp
def random_erp_charge():
    return round(random.uniform(2, 5), 2)

# randomise destination
def random_destination_location():
    geolocator = Nominatim(user_agent="coordinateconverter")
    min_latitude, max_latitude = 1.2, 1.5
    min_longitude, max_longitude = 103.6, 104.0
    
    # Generate random latitude and longitude coordinates
    latitude = random.uniform(min_latitude, max_latitude)
    longitude = random.uniform(min_longitude, max_longitude)
    
    address = str(latitude) + " " + str(longitude)
    location = geolocator.reverse(address)
    return location

# randomise weather
def random_weather_forecast():
    weathers = ['Partly Cloudy (Day)', 'Partly Cloudy (Night)', 'Thundery Showers',
                "Rainy Showers", "Windy", "Cloudy", "Fair (Day)", "Fair (Night)"]
    return random.choice(weathers)

# random carpark(s)
def random_carparks():
    num_carparks = random.randint(1, 3)
    selected_carparks = random.sample(carparks, num_carparks)
    return selected_carparks

# random prompt
def generate_random_prompt():
#     random_index = random.randint(0, len(prompts) - 1)
    prompt_template = prompts[0]
    # Replace placeholders with actual values
    route_labels, route_str = random_route()
    prompt = prompt_template.format(route_path=route_str,
                                    traffic_volume_json=random_traffic_json(route_labels),
                                    erp_pricing=random_erp_charge(),
                                    destination_location=random_destination_location(),
                                    nearby_carparks=random_carparks(),
                                    weather_forecast=random_weather_forecast())
    return prompt
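The helpers above draw from the global `random` state, so each run produces different prompts. If reproducible prompts are wanted, one option is to pass a seeded `random.Random` instance through the sampling steps. The sketch below assumes a small stand-in dictionary and a hypothetical function name, `seeded_route_and_traffic`; it is not the notebook's own code.

```python
import json
import random

# Small stand-in for camera_id_labels; the full dictionary is defined above.
sample_labels = {
    1001: "KPE(ECP)",
    1002: "Kallang Bahru",
    1003: "KPE(PIE)",
    1004: "Kallang Way Flyover",
    1005: "Defu Flyover",
}

def seeded_route_and_traffic(labels, seed, min_roads=3, max_roads=5):
    """Return (route_string, traffic_json) drawn from a private, seeded RNG."""
    rng = random.Random(seed)
    num_roads = rng.randint(min_roads, min(max_roads, len(labels)))
    keys = rng.sample(list(labels), num_roads)
    roads = [labels[key] for key in keys]
    levels = ["light", "moderate", "congested"]
    # Three entries per road: current time, next hour, two hours ahead.
    traffic = {road: [rng.choice(levels) for _ in range(3)] for road in roads}
    return " -> ".join(roads), json.dumps(traffic, indent=4)

route_a, traffic_a = seeded_route_and_traffic(sample_labels, seed=42)
route_b, traffic_b = seeded_route_and_traffic(sample_labels, seed=42)
print(route_a == route_b and traffic_a == traffic_b)
```

Note that the full `camera_id_labels` dictionary maps several camera IDs to the same label (for example "Tampines Flyover"), so a sampled route can name a road twice while the traffic JSON holds only one entry for it.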

#### Llama-2-7b-chat-hf
First, let's collect traffic data summarization from Llama-2-7b-chat-hf.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    token="YOUR_API_KEY"
    llama_model = AutoModelForCausalLM.from_pretrained(model_id, token=token, 
                                                       device_map="auto", 
                                                       load_in_8bit=True).requires_grad_(False)
    llama_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token, padding_side="left")
    # Llama 2 ships without a padding token; reuse EOS so batched prompts can be padded.
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.use_default_system_prompt = False

# Not required; set explicitly to silence the tokenizers parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
batch_size = 8
total_prompts = 250

for batch_start in range(0, total_prompts, batch_size):
    print("Row: ", batch_start)
    batch_prompts = []
    for i in range(batch_start, min(batch_start + batch_size, total_prompts)):
        random_prompt = generate_random_prompt()
        batch_prompts.append(random_prompt)

    # Tokenize the batch of prompts
    inputs = llama_tokenizer(batch_prompts, return_tensors="pt", padding=True).to('cuda')

    # Generate responses from the Llama model
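    # max_new_tokens is set explicitly; the default generation length is too short for these prompts
    outputs = llama_model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.5)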
    batch_responses = llama_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Add prompts, responses, and model name to DataFrame
    for prompt, response in zip(batch_prompts, batch_responses):
        df = pd.concat([df, pd.DataFrame([{'prompt': prompt, 'response': response, 'model': 'Llama-2-7b-chat-hf'}])], ignore_index=True)


print(df.head())
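The loop above slices the 250 prompts into batches of 8, which leaves a short final batch. The sketch below reproduces only that index arithmetic and collects rows in a plain list; `fake_generate` is a hypothetical stand-in for the tokenizer and model calls, so no GPU is needed to run it.

```python
def batch_ranges(total, batch_size):
    """Yield (start, end) index pairs, mirroring the slicing in the loop above."""
    for start in range(0, total, batch_size):
        yield start, min(start + batch_size, total)

def fake_generate(prompts):
    """Hypothetical stand-in for tokenize + model.generate + batch_decode."""
    return [p + " <generated text>" for p in prompts]

rows = []
for start, end in batch_ranges(total=250, batch_size=8):
    batch_prompts = [f"prompt {i}" for i in range(start, end)]
    for prompt, response in zip(batch_prompts, fake_generate(batch_prompts)):
        rows.append({"prompt": prompt, "response": response, "model": "stub"})

ranges = list(batch_ranges(250, 8))
print(len(ranges), ranges[-1], len(rows))  # 32 (248, 250) 250
```

Collecting dictionaries in a list and calling `pd.DataFrame(rows)` once at the end is a common alternative to calling `pd.concat` inside the loop, which copies the growing frame on every iteration.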

In [15]:
# Saved under ./data so the distillation step can read it back from the same location
df.to_csv("./data/llama_generations.csv")

#### Mistral-7B-Instruct-v0.2

Now let's collect traffic data summarization from Mistral-7B-Instruct-v0.2

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    token="YOUR_API_KEY"
    mistral_model = AutoModelForCausalLM.from_pretrained(model_id, token=token,
                                                       device_map="auto", 
                                                       load_in_8bit=True).requires_grad_(False)
    mistral_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token, padding_side="left")
    # The Mistral tokenizer has no padding token; reuse EOS so batched prompts can be padded.
    mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
    mistral_tokenizer.use_default_system_prompt = False

# Not required; set explicitly to silence the tokenizers parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

Let's test how well it does first.

In [6]:
random_prompt = generate_random_prompt()
inputs = mistral_tokenizer(random_prompt, return_tensors="pt", padding=True).to('cuda')

# Generate a response from the Mistral model
outputs = mistral_model.generate(**inputs, max_new_tokens=250, do_sample=True, temperature=0.5)
responses = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
responses

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'\n    Pretend you\'re a personal assistant tailored for Singaporean travelers, offering insights and recommendations on road conditions and travel plans. Speak directly to the user in a warm, friendly tone resembling a lively radio host. Craft words that resonate and can be effortlessly translated into engaging speech.\n\n    Let\'s explore the user\'s journey that passes through the following roads: PIE (BKE) -> Changi Coast Road -> ECP(MCE) -> Seletar Flyover -> Nanyang Flyover. Delve into the forecasted traffic volumes for each road, covering the current time, the next hour, and two hours ahead: {\n    "PIE (BKE)": [\n        "moderate",\n        "moderate",\n        "light"\n    ],\n    "Changi Coast Road": [\n        "light",\n        "congested",\n        "moderate"\n    ],\n    "ECP(MCE)": [\n        "congested",\n        "moderate",\n        "moderate"\n    ],\n    "Seletar Flyover": [\n        "moderate",\n        "moderate",\n        "moderate"\n    ],\n    "Nanyang Flyover"

In [20]:
df = pd.DataFrame(columns=['response', 'model'])

In [None]:
batch_size = 8
total_prompts = 128

for batch_start in range(0, total_prompts, batch_size):
    print("Row: ", batch_start)
    batch_prompts = []
    for i in range(batch_start, min(batch_start + batch_size, total_prompts)):
        random_prompt = generate_random_prompt()
        batch_prompts.append(random_prompt)

    # Tokenize the batch of prompts
    inputs = mistral_tokenizer(batch_prompts, return_tensors="pt", padding=True).to('cuda')

    # Generate responses from the Mistral model
    outputs = mistral_model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.5)
    batch_responses = mistral_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Add prompts, responses, and model name to DataFrame
    for prompt, response in zip(batch_prompts, batch_responses):
        df = pd.concat([df, pd.DataFrame([{'prompt': prompt, 'response': response, 'model': 'Mistral-7B-Instruct-v0.2'}])], ignore_index=True)


print(df.head())

In [None]:
df.to_csv("./data/mistral_generations.csv", mode='a')

### Distillation to smaller model
Now we distill the summarization capabilities of the larger models into a smaller model, Phi-1 with 1.3B parameters. We will load Phi-1 with 4-bit quantization.
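To see why 4-bit loading matters, it helps to estimate the memory taken by the weights alone. The sketch below is a rough back-of-the-envelope calculation, not a measurement: it ignores quantization metadata, layers kept in higher precision, activations, and optimizer state, and the two parameter counts are illustrative values for small models of this class.

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Approximate storage for the weights only, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9

for num_params in (1.3e9, 2.7e9):  # illustrative parameter counts
    fp16 = weight_memory_gb(num_params, 16)
    nf4 = weight_memory_gb(num_params, 4)
    print(f"{num_params / 1e9:.1f}B params: fp16 ~{fp16:.2f} GB, 4-bit ~{nf4:.2f} GB")
```

Under these assumptions, 4-bit storage needs about a quarter of the memory of float16 weights, which is what leaves room on a single GPU for the LoRA adapters and their optimizer state.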

In [4]:
# df = pd.read_csv("./llama_generations.csv", index_col=0)
df1 = pd.read_csv("./data/mistral_generations.csv", index_col=0)
df2 = pd.read_csv("./data/llama_generations.csv", index_col=0)
df = pd.concat([df1,df2], axis=0)

In [5]:
# BitsAndBytesConfig to load the model in 4-bit NF4 format, computing in float16
compute_dtype = torch.float16
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=False,
    )

model_name='microsoft/phi-1'
device_map = {"": 0}
distil_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map,
                                                      quantization_config=bnb_config,
                                                      trust_remote_code=True,
                                                      use_auth_token=True)
tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True,padding_side="left",add_eos_token=True,add_bos_token=True,use_fast=False)
tokenizer.pad_token = tokenizer.eos_token



Let's first test how well the base model does with zero-shot inference. As the output below shows, its performance before fine-tuning is quite poor.

In [11]:
random_prompt = generate_random_prompt()
    
inputs = tokenizer(random_prompt, return_tensors="pt").to("cuda")

outputs = distil_model.generate(**inputs, max_new_tokens=250)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

response



'\n      Pretend you\'re a personal assistant tailored for Singaporean travelers, offering insights and recommendations on road conditions and travel plans. Speak directly to the user in a warm, friendly tone resembling a lively radio host. Craft words that resonate and can be effortlessly translated into engaging speech.\n\n      Let\'s explore the user\'s journey that passes through the following roads: Upper Thomson Flyover -> Clementi Avenue 6 -> Tuas West Road -> Nanyang Flyover -> Benoi Road -> Tuas Flyover -> Mandai Road -> Changi Coast Road -> Benjamin Sheares Bridge -> Woodlands Avenue 12 -> Upper Changi Flyover -> Eunos Flyover -> Mount Pleasant -> Upper Thomson Flyover. Delve into the forecasted traffic volumes for each road, covering the current time, the next hour, and two hours ahead: {\n      "Upper Thomson Flyover": [\n          "light",\n          "congested",\n          "light"\n      ],\n      "Clementi Avenue 6": [\n          "congested",\n          "light",\n      

We use PEFT to fine-tune the model with LoRA, attaching adapters to the attention projection layers.

In [12]:
from peft import LoraConfig, get_peft_model


config = LoraConfig(
    r=32, #Rank
    lora_alpha=32,
    target_modules=[
        'q_proj',
        'k_proj',
        'v_proj',
        'dense'
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

# 1 - Enabling gradient checkpointing to reduce memory usage during fine-tuning
distil_model.gradient_checkpointing_enable()

distil_model.enable_input_require_grads()
peft_model = get_peft_model(distil_model, config)

def print_number_of_trainable_model_parameters(model):
    """
    Print the number of trainable parameters in the given model.
    """
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {num_params}")
    
print_number_of_trainable_model_parameters(peft_model)

Number of trainable parameters: 12582912
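The reported count can be checked by hand. A LoRA adapter of rank r on a linear layer with shape d_in x d_out adds r * (d_in + d_out) parameters. The sketch below assumes that `microsoft/phi-1` has a hidden size of 2048 and 24 transformer layers, and that each of the four target modules (`q_proj`, `k_proj`, `v_proj`, `dense`) is a square hidden-by-hidden projection matched once per layer; these architecture figures are assumptions, not values read from the notebook.

```python
def lora_param_count(rank, layer_shapes, num_layers):
    """Trainable LoRA parameters: rank * (d_in + d_out) per adapted linear layer."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in layer_shapes)
    return per_layer * num_layers

hidden = 2048                      # assumed hidden size of phi-1
shapes = [(hidden, hidden)] * 4    # q_proj, k_proj, v_proj, dense (assumed square)
total = lora_param_count(rank=32, layer_shapes=shapes, num_layers=24)
print(total)  # 12582912
```

The result agrees with the printed number of trainable parameters, which suggests the adapters were attached to the attention projections only.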


In [13]:
from functools import partial
from datasets import Dataset
from transformers import set_seed
seed = 42
set_seed(seed)

# Helper function to get the maximum length
def get_max_length(model):
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(conf, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length

# Helper function to preprocess a batch
def preprocess_batch(batch, tokenizer, max_length):
    """
    Tokenizing a batch
    """
    return tokenizer(
        batch["response"],
        max_length=max_length,
        truncation=True,
    )

# Helper function to preprocess the entire dataset
def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, dataset):
    """Format & tokenize it so it is ready for training
    :param tokenizer (AutoTokenizer): Model Tokenizer
    :param max_length (int): Maximum number of tokens to emit from tokenizer
    """
    print("Preprocessing dataset...")
    
    # Create a Dataset object from DataFrame
    dataset = Dataset.from_pandas(dataset)
    
    # Apply preprocessing to each batch of the dataset and remove unnecessary columns
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=['model'],
    )
    # Shuffle dataset
    dataset = dataset.shuffle(seed=seed)

    return dataset

# Get the maximum length
max_length = get_max_length(distil_model)

# Preprocess the dataset
preprocessed_dataset = preprocess_dataset(tokenizer, max_length, df)
preprocessed_dataset = preprocessed_dataset.train_test_split(test_size=0.1)
train_dataset = preprocessed_dataset['train']
eval_dataset = preprocessed_dataset['test']

Found max length: 2048
Preprocessing dataset...


Map:   0%|          | 0/665 [00:00<?, ? examples/s]

In [None]:
import transformers
import time

output_dir = f'./peft-traffic-summary-training-{int(time.time())}'

peft_training_args = TrainingArguments(
    output_dir = output_dir,
    warmup_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    max_steps=500,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    logging_steps=25,
    logging_dir="./logs",
    save_strategy="steps",
    save_steps=25,
    evaluation_strategy="steps",
    eval_steps=25,
    do_eval=True,
    gradient_checkpointing=True,
    report_to="none",
    overwrite_output_dir=True,
    group_by_length=True,
)

peft_model.config.use_cache = False

peft_trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=peft_training_args,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

peft_trainer.train()
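The training arguments above imply a schedule that can be worked out with simple arithmetic. The sketch below assumes a single GPU, the 665 examples shown in the preprocessing output, and that the 10% test split is rounded up; the function name `schedule_summary` is hypothetical.

```python
import math

def schedule_summary(num_examples, test_size, per_device_batch,
                     grad_accum, max_steps, save_steps):
    """Derive effective batch size, samples seen, and checkpoint count."""
    n_test = math.ceil(num_examples * test_size)  # assumed rounding of the split
    n_train = num_examples - n_test
    effective_batch = per_device_batch * grad_accum  # single-GPU assumption
    samples_seen = effective_batch * max_steps
    return {
        "train_examples": n_train,
        "effective_batch": effective_batch,
        "samples_seen": samples_seen,
        "approx_epochs": samples_seen / n_train,
        "checkpoints": max_steps // save_steps,
    }

summary = schedule_summary(num_examples=665, test_size=0.1, per_device_batch=1,
                           grad_accum=4, max_steps=500, save_steps=25)
print(summary)
```

Under these assumptions each optimizer step sees 4 examples, so 500 steps cover 2000 examples, a little over three passes through roughly 598 training examples, with a checkpoint and an evaluation every 25 steps.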

In [None]:
# Save the LoRA adapter weights.
peft_model_path = os.path.join("models", "lora_model")

peft_trainer.model.save_pretrained(peft_model_path)  

In [None]:
from peft import PeftModel

base_model_id="microsoft/phi-1"
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, 
                                                      device_map=device_map,
                                                      quantization_config=bnb_config,
                                                      trust_remote_code=True,
                                                      use_auth_token=True)

base_model.enable_input_require_grads()
loaded_model = PeftModel.from_pretrained(base_model,
                                        "models/lora_model",
                                        is_trainable=False)

merged_model = loaded_model.merge_and_unload()
merged_model.save_pretrained("./models/traffic-distilphi-1")

In [None]:
#Test the distilled model

random_prompt = generate_random_prompt()
    
inputs = tokenizer(random_prompt, return_tensors="pt").to('cuda')

outputs = loaded_model.generate(**inputs, max_new_tokens=512)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
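The decoded text echoes the prompt, so the completion has to be separated from it at the `### Response ###` marker. The following is a minimal sketch of a slightly more defensive variant of that split; the helper name `extract_response` is hypothetical. Unlike a plain `replace('\n', '')`, it collapses newlines into single spaces so sentences do not run together, and it does not raise if the marker is missing.

```python
RESPONSE_MARKER = "### Response ###"

def extract_response(generated_text, marker=RESPONSE_MARKER):
    """Return the text after the first marker, with whitespace collapsed."""
    _, found, completion = generated_text.partition(marker)
    if not found:
        # Marker absent: fall back to the whole text rather than raising IndexError.
        return generated_text.strip()
    return " ".join(completion.split())

sample = "Pretend you're an assistant...\n    ### Response ###\n    Hey there!\nTraffic is light."
print(extract_response(sample))
```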


In [30]:
response.split("### Response ###")[1].strip().replace('\n', '')

"Hey there, fellow Singaporean traveler! 😊 Get ready for an exciting adventure as we navigate your route from Tampines Flyover to Marina Coastal Drive (Towards AYE) via some of our beloved island's most iconic landmarks. 🚗🏙️First up, we've got the Tampines Flyover, which promises to be light on traffic for the next hour. 🌟 However, things might get a bit more congested during the next two hours, so be sure to pace yourself and take breaks when needed. 😅Next, we'll cruise over to the Mandai Flyover, where moderate traffic awaits. 🚗 Don't worry, though – it's nothing a quick pit stop at the nearby MacRitchie Reservoir Park can't fix! 🌳As we approach the Nanyang Flyover, traffic starts to clear up a bit, with light congestion expected. 🌟 Keep on truckin', my friend! 😎Now, we hit the Ang Mo Kio Ave 5 Flyover, where things might get a bit more complicated. 🚗 Be prepared for some congestion, but don't worry – there are plenty of parking options nearby at Clarke Quay, so you can easily drop o