<a href="https://colab.research.google.com/github/aarshitaacharya/peft-techniques/blob/main/R13_QLoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages
!pip install transformers peft datasets bitsandbytes dash pyngrok

# Import required libraries
import os
import random
from getpass import getpass

import bitsandbytes as bnb
import torch
from dash import Dash, dcc, html
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from pyngrok import ngrok
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling



In [None]:
from transformers import BitsAndBytesConfig

# Load dataset (JavaScript subset of HumanEval-X)
dataset = load_dataset("THUDM/humaneval-x", "js")

# Model and tokenizer configuration
model_name = "bigscience/bloomz-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = 'right'

# 4-bit quantization settings.
# QLoRA stores the frozen base weights as NF4 with double quantization;
# when these options are omitted bitsandbytes falls back to plain FP4.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load model with 4-bit quantization for QLoRA
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # Automatically assigns the model to available devices
)

In [None]:
# Tokenization function
def tokenize_function(examples):
    """Tokenize the ``prompt`` field and attach ``labels``.

    Each prompt is truncated or padded to exactly 512 tokens using the
    notebook-level ``tokenizer``. ``labels`` is a shallow copy of
    ``input_ids``.
    """
    encoded = tokenizer(
        examples['prompt'],
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    encoded['labels'] = list(encoded['input_ids'])
    return encoded

# Tokenize the dataset
tokenized_dataset = dataset['test'].map(tokenize_function, batched=True)

# Split the dataset into training (80%) and evaluation (20%) sets.
# A dedicated seeded RNG makes the split identical on every run, and slicing
# one shuffled permutation guarantees the two sets are disjoint and complete.
SPLIT_SEED = 42
num_examples = len(tokenized_dataset)
train_size = int(0.8 * num_examples)
shuffled_indices = random.Random(SPLIT_SEED).sample(range(num_examples), num_examples)
train_indices = shuffled_indices[:train_size]
eval_indices = shuffled_indices[train_size:]
train_dataset = tokenized_dataset.select(train_indices)
eval_dataset = tokenized_dataset.select(eval_indices)

# Sanity check: every example landed in exactly one of the two sets.
assert len(train_dataset) + len(eval_dataset) == num_examples

In [None]:
# LoRA configuration
lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query_key_value"],
    task_type="CAUSAL_LM",  # wrap the model as a causal-LM PEFT model
)

# Apply LoRA to the model
lora_model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="./lora_model",
    eval_strategy="epoch",  # Evaluate once per epoch (replaces the deprecated `evaluation_strategy`)
    # The whole run is only a few dozen optimizer steps, so a step-based
    # interval of 100 never fires; log once per epoch, alongside evaluation.
    logging_strategy="epoch",
    # The Trainer may fail to discover label names on a PEFT-wrapped model;
    # without them no validation loss is computed.
    label_names=["labels"],
    learning_rate=5e-5,  # Reduced learning rate for stability
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # Adjust to balance batch size
    num_train_epochs=5,
    weight_decay=0.01,  # Slightly lower weight decay
    fp16=True,  # Use mixed precision for speed and memory
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
)

# Trainer setup
trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # mlm=False selects causal-LM collation.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Start fine-tuning
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch,Training Loss,Validation Loss
0,No log,No log
1,No log,No log
2,No log,No log
4,No log,No log


TrainOutput(global_step=40, training_loss=2.575897979736328, metrics={'train_runtime': 159.6092, 'train_samples_per_second': 4.104, 'train_steps_per_second': 0.251, 'total_flos': 591425905360896.0, 'train_loss': 2.575897979736328, 'epoch': 4.848484848484849})

In [None]:
# After training, define the app for model inference.
# Inline styles are named up front so the component tree below stays readable.
PAGE_STYLE = {'backgroundColor': '#f8f9fa', 'padding': '20px'}
TITLE_STYLE = {'textAlign': 'center', 'color': '#343a40'}
PROMPT_INPUT_STYLE = {
    'width': '100%',
    'padding': '10px',
    'fontSize': '18px',
    'marginBottom': '10px',
}
SUBMIT_BUTTON_STYLE = {
    'backgroundColor': '#007bff',
    'color': 'white',
    'padding': '10px 20px',
    'border': 'none',
    'borderRadius': '5px',
    'cursor': 'pointer',
}
PREDICTION_BOX_STYLE = {
    'marginTop': '20px',
    'fontSize': '18px',
    'color': '#495057',
    'border': '1px solid #ced4da',
    'padding': '10px',
    'borderRadius': '5px',
}

app = Dash(__name__)

# Page: title, prompt text box, submit button, and the output area.
app.layout = html.Div(
    style=PAGE_STYLE,
    children=[
        html.H1("JavaScript Code Generation Model Inference", style=TITLE_STYLE),
        dcc.Input(
            id='input-text',
            type='text',
            placeholder='Enter prompt for JavaScript code generation...',
            style=PROMPT_INPUT_STYLE,
        ),
        html.Button('Submit', id='submit-button', n_clicks=0, style=SUBMIT_BUTTON_STYLE),
        html.Div(id='output-prediction', style=PREDICTION_BOX_STYLE),
    ],
)

# Define callback for model inference.
# The text box is registered as State, not Input: its value is read when the
# button is clicked, but typing alone does not trigger a (slow) generation.
@app.callback(
    Output('output-prediction', 'children'),
    [Input('submit-button', 'n_clicks')],
    [State('input-text', 'value')]
)
def update_output(n_clicks, input_text):
    """Generate JavaScript code for the submitted prompt.

    Args:
        n_clicks: Number of times the Submit button has been clicked
            (may be ``None`` or 0 before the first click).
        input_text: Current contents of the prompt text box.

    Returns:
        An ``html.Pre`` element containing the decoded generation, or a
        plain instruction string when there is nothing to generate yet.
    """
    if not n_clicks or not input_text:
        return "Enter a prompt and click Submit."

    prompt = f"// Generate a JavaScript function for the following task:\n{input_text}\n"
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    # Switch off dropout (including LoRA dropout) for inference.
    model.eval()
    with torch.no_grad():
        # max_new_tokens bounds only the generated continuation; max_length
        # would also count the prompt and can leave no room for output.
        outputs = model.generate(**inputs, max_new_tokens=150)

    decoded_prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return html.Pre(f"Generated Code:\n{decoded_prediction}", style={'whiteSpace': 'pre-wrap'})

# Start the app server.
# `app.run` is the supported entry point; `app.run_server` is deprecated.
app.run(port=8050)

<IPython.core.display.Javascript object>

In [None]:
# Start ngrok tunnel for public access.
# The auth token is read from the NGROK_AUTH_TOKEN environment variable (with
# an interactive prompt as fallback) so it is never stored in the notebook.
ngrok_auth_token = os.environ.get("NGROK_AUTH_TOKEN") or getpass("ngrok auth token: ")
ngrok.set_auth_token(ngrok_auth_token)
public_url = ngrok.connect(8050)  # tunnel to the Dash server's port
print(f"Public URL: {public_url.public_url}")

Public URL: NgrokTunnel: "https://4e1e-35-199-163-134.ngrok-free.app" -> "http://localhost:8050"
