<a href="https://colab.research.google.com/github/Shi-pra-19/gemma-finetune-ui/blob/main/Gemma_2_Fine_Tuning_Dashboard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install Required Libraries
Install the Python packages needed to fine-tune the Gemma-2B model and serve the Streamlit dashboard, including Hugging Face Transformers, TRL, PEFT, and bitsandbytes:

In [None]:
!pip install streamlit pyngrok transformers torch datasets trl peft bitsandbytes ngrok tqdm pandas --quiet

In [None]:
!pip install --upgrade bitsandbytes transformers accelerate


## Authentication and Integration Setup
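This cell authenticates with the external services used for fine-tuning and deployment: Hugging Face (model access), ngrok (public tunnel to the app), and Weights & Biases (experiment tracking):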

In [None]:
import os

from huggingface_hub import login
from pyngrok import ngrok
import wandb

# Credentials are read from environment variables instead of being pasted into
# the notebook, where they would be saved with (and leak through) any shared
# copy. Set them before running this cell, e.g. os.environ["HF_TOKEN"] = "...".
_REQUIRED_SECRETS = ("HF_TOKEN", "NGROK_AUTHTOKEN", "WANDB_API_KEY")
_secrets = {name: os.environ.get(name) for name in _REQUIRED_SECRETS}
_missing = [name for name, value in _secrets.items() if not value]
if _missing:
    raise RuntimeError(
        "Set these environment variables before running this cell: "
        + ", ".join(_missing)
    )

# Hugging Face login, used when app.py downloads google/gemma-2b-it.
login(token=_secrets["HF_TOKEN"])

# Register the ngrok auth token through pyngrok's API rather than a shell escape.
ngrok.set_auth_token(_secrets["NGROK_AUTHTOKEN"])

# Weights & Biases login (the Trainer in app.py may report runs there).
wandb.login(key=_secrets["WANDB_API_KEY"])

# Gemma 2 Fine-Tuning Dashboard
- Select a dataset and preview its content.

- Configure training parameters such as batch size, learning rate, and training steps.

- Choose data columns for input, context, and response.

- Load the Gemma-2B model (`google/gemma-2b-it`) with 4-bit quantization and LoRA (Low-Rank Adaptation) adapters.

- Visualize real-time training progress, including loss curves and status updates.

- Train the model directly within the app using parameter-efficient fine-tuning.

In [None]:
%%writefile app.py
import streamlit as st
import torch
import pandas as pd
import transformers
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, TrainerCallback
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import plotly.graph_objects as go
import re
from random import randint
from transformers import pipeline

# Streamlit UI
st.title("Gemma 2 Fine-Tuning Dashboard")

# Streamlit re-executes this whole script on every interaction; session_state
# is what persists between those reruns, so each flag is created only once.
for _state_key in ("model_loaded", "training_started"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = False

# Sidebar: Dataset selection
with st.sidebar:
    st.header("Dataset Selection")
    # Display label -> Hugging Face Hub dataset repository id.
    datasets = dict([
        ("Code Alpaca", "TokenBender/code_instructions_122k_alpaca_style"),
        ("Databricks Dolly 15K", "databricks/databricks-dolly-15k"),
        ("OpenAssistant OASST1", "OpenAssistant/oasst1"),
        ("Guanaco OpenAssistant", "timdettmers/openassistant-guanaco"),
    ])
    # index=None: nothing is selected until the user picks a dataset.
    dataset_choice = st.selectbox(
        "Select a dataset:",
        list(datasets),
        index=None,
    )

# Load dataset and show a preview
@st.cache_resource
def _load_train_split(dataset_id):
    """Return the "train" split of the Hub dataset *dataset_id*.

    Cached with st.cache_resource so the load is not repeated on every
    Streamlit rerun (each widget interaction re-executes this script).
    """
    return load_dataset(dataset_id, split="train")


if dataset_choice:
    # `dataset` keeps the full split because the training step reads it;
    # only the on-screen preview is limited to the first 10 rows.
    dataset = _load_train_split(datasets[dataset_choice])
    df = dataset.select(range(min(10, len(dataset)))).to_pandas()
    st.write("### Dataset Preview:")
    st.write(df.head(10))

    # Let user select input, context (optional), and response fields
    columns = df.columns.tolist()
    input_field = st.selectbox("Select Input Field:", columns)
    context_field = st.selectbox("Select Context Field (Optional):", ["None"] + columns)
    response_field = st.selectbox("Select Response Field:", columns)

# Sidebar: Hyperparameter selection
with st.sidebar:
    st.header("Training Parameters")
    # Each control: label, lower bound, upper bound, initial value.
    batch_size = st.slider(
        "Batch Size", min_value=1, max_value=16, value=1
    )
    learning_rate = st.number_input(
        "Learning Rate",
        min_value=1e-6,
        max_value=1e-2,
        value=2e-4,
        format="%.6f",
    )
    gradient_accum_steps = st.slider(
        "Gradient Accumulation Steps", min_value=1, max_value=8, value=4
    )
    max_steps = st.slider(
        "Max Steps", min_value=10, max_value=1000, value=100
    )

# Load Model Button
if not (dataset_choice and input_field and response_field):
    st.warning("Please select a dataset and configure input/response fields before loading the model.")
elif not st.session_state.model_loaded and st.button("Load Model"):
    checkpoint = "google/gemma-2b-it"

    # 4-bit NF4 weights with double (nested) quantization; bfloat16 compute dtype.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Model and tokenizer live in session_state so they survive reruns;
    # device_map={"": 0} places the whole model on GPU 0.
    st.session_state.model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        quantization_config=quant_config,
        device_map={"": 0},
    )
    st.session_state.tokenizer = AutoTokenizer.from_pretrained(
        checkpoint, add_eos_token=True
    )
    st.session_state.model_loaded = True
    st.success("Model Loaded Successfully!")

# Start Training Button
# Offer the button only once a model is loaded and no training run is active.
awaiting_training = st.session_state.model_loaded and not st.session_state.training_started
if awaiting_training:
    if not (dataset_choice and input_field and response_field):
        st.warning("Make sure dataset and fields are selected before starting training.")
    elif st.button("Start Training"):
        st.session_state.training_started = True

# Training Logic
# Runs on the rerun triggered by "Start Training": formats and tokenizes the
# selected fields, wraps the 4-bit model with LoRA (once), trains with
# SFTTrainer, and streams progress/loss into the page.
if st.session_state.training_started:
    import math

    from peft import PeftModel

    st.write("Preparing dataset for training...")

    def generate_prompt(data_point):
        """Format one dataset row as a Gemma chat-turn training example.

        The user turn carries the selected input field, followed by the context
        field when one was chosen and the row's value is non-empty; the model
        turn carries the selected response field.
        """
        input_text = data_point[input_field]
        response_text = data_point[response_field]
        context_text = data_point[context_field] if context_field != "None" else None

        if context_text:
            return f"<start_of_turn>user {input_text}\nContext: {context_text}<end_of_turn>\n<start_of_turn>model {response_text}<end_of_turn>"
        else:
            return f"<start_of_turn>user {input_text}<end_of_turn>\n<start_of_turn>model {response_text}<end_of_turn>"

    # Each optimizer step consumes batch_size * gradient_accum_steps rows on the
    # single device the model was placed on, so training reads at most
    # max_steps times that. Shuffle first, then keep only that many rows (plus
    # the held-out test fraction) so formatting/tokenizing skips unused rows.
    test_fraction = 0.2
    rows_needed = max_steps * batch_size * gradient_accum_steps
    n_rows = min(len(dataset), math.ceil(rows_needed / (1 - test_fraction)))
    dataset = dataset.shuffle(seed=1234).select(range(n_rows))
    dataset = dataset.map(lambda data_point: {"prompt": generate_prompt(data_point)})

    def tokenize_function(examples):
        """Tokenize a batch of prompts, truncating to the tokenizer's max length.

        No padding here: the DataCollatorForLanguageModeling below pads each
        training batch to its own longest sequence, whereas padding at map time
        would pad every row to the longest prompt in its 1,000-row map batch.
        """
        return st.session_state.tokenizer(examples["prompt"], truncation=True)

    dataset = dataset.map(tokenize_function, batched=True)
    dataset = dataset.train_test_split(test_size=test_fraction)

    train_data = dataset["train"]
    test_data = dataset["test"]

    modules = ["q_proj", "v_proj", "k_proj", "o_proj", "down_proj", "gate_proj", "up_proj"]

    lora_config = LoraConfig(
        r=64,
        lora_alpha=32,
        target_modules=modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    # The model stored in session_state stays LoRA-wrapped after a run, so
    # wrapping again on a later run would stack a second adapter on top.
    if not isinstance(st.session_state.model, PeftModel):
        st.session_state.model.gradient_checkpointing_enable()
        st.session_state.model = prepare_model_for_kbit_training(st.session_state.model)
        st.session_state.model = get_peft_model(st.session_state.model, lora_config)

    training_args = TrainingArguments(
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accum_steps,
        warmup_steps=0,
        max_steps=max_steps,
        learning_rate=learning_rate,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        save_strategy="epoch",
    )

    st.write("Starting training...")
    progress_bar = st.progress(0)
    log_area = st.empty()
    loss_chart = st.empty()

    class StreamlitCallback(TrainerCallback):
        """Mirror Trainer step progress and logged training loss into Streamlit widgets."""

        def __init__(self, progress_bar, log_area, max_steps, loss_chart):
            self.progress_bar = progress_bar
            self.log_area = log_area
            self.max_steps = max_steps
            self.loss_chart = loss_chart
            self.steps = []   # global_step at which each loss was logged
            self.losses = []

        def on_step_end(self, args, state, control, **kwargs):
            # Clamp so the progress widget never receives a value above 1.0.
            progress = min(state.global_step / self.max_steps, 1.0)
            self.progress_bar.progress(progress)
            self.log_area.write(f"Step {state.global_step}/{self.max_steps}")

        def on_log(self, args, state, control, logs=None, **kwargs):
            # on_log receives the values just logged. Reading state.log_history
            # in on_step_end instead sees the previous step's entry, because
            # the Trainer logs only after on_step_end has returned.
            if logs and "loss" in logs:
                self.steps.append(state.global_step)
                self.losses.append(logs["loss"])
                self.plot_loss()

        def plot_loss(self):
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=self.steps,
                y=self.losses,
                mode='lines+markers',
                name='Training Loss',
                line=dict(color='royalblue')
            ))
            fig.update_layout(
                title="Training Loss Curve",
                xaxis_title="Steps",
                yaxis_title="Loss",
                template="plotly_white"
            )
            self.loss_chart.plotly_chart(fig, use_container_width=True)

    # The model is already a PeftModel, so no peft_config is passed here;
    # giving one as well would ask the trainer to apply LoRA a second time.
    trainer = SFTTrainer(
        model=st.session_state.model,
        train_dataset=train_data,
        eval_dataset=test_data,
        args=training_args,
        data_collator=transformers.DataCollatorForLanguageModeling(
            st.session_state.tokenizer, mlm=False
        ),
    )

    trainer.add_callback(StreamlitCallback(progress_bar, log_area, max_steps, loss_chart))
    try:
        trainer.train()
    finally:
        # Always clear the flag; otherwise every later rerun (any widget
        # interaction) would launch training again, even after a failure.
        st.session_state.training_started = False

    st.session_state.training_done = True
    st.success("Training Completed!")
elif st.session_state.get("training_done"):
    # Keep the completion notice visible on reruns after training finished.
    st.success("Training Completed!")


# Launching Streamlit with Ngrok

In [None]:
# Kill any running process whose command line mentions ngrok (e.g. a tunnel left over from an earlier launch).
!pkill -f ngrok

In [None]:
!streamlit run app.py &>/dev/null &
from pyngrok import ngrok
public_url = ngrok.connect(8501)
print(f"Streamlit app is running at: {public_url}")