## 🚀 Day 8/15 — Fine-Tuning with Unsloth AI

## Project: Fine-Tuning Gemma 3n (2B) for Image-to-Prompt Generation

This notebook demonstrates how to fine-tune a powerful multimodal model, Google's Gemma 3n, to perform a specific and creative task: generating descriptive text-to-image prompts from an input image.

### 🎯 Project Goal

The primary objective is to create an "AI art deconstruction" tool. Given an image, the fine-tuned model should analyze its visual content, style, and composition, then generate a high-quality text prompt that a text-to-image model (such as DALL-E 3 or Midjourney) could use to create a similar image.

### 💾 The Dataset

We are using a subset of the **`jackyhate/text-to-image-2M`** dataset available on Hugging Face.

*   **Content:** This dataset contains 2 million pairs of images and the high-quality text prompts used to generate them. That makes it a good fit for our task, because every pair is a direct example of the image-to-prompt relationship we want the model to learn.
*   **Format:** The data is provided in the `webdataset` format (`.tar` files), which is optimized for large-scale streaming.
*   **Preprocessing:** For this demonstration, we perform the following steps:
    1.  Stream the dataset directly from Hugging Face.
    2.  Take a small subset of **2,000 examples** to ensure the process is fast and manageable within the Colab environment.
    3.  Transform each sample into a conversational format suitable for training.

You can find the full dataset here: [https://huggingface.co/datasets/jackyhate/text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M)
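
The three preprocessing steps form a lazy pipeline: since the stream is consumed on demand, taking 2,000 examples never touches the rest of the 2 million. Here is a minimal sketch of that idea in plain Python. `fake_stream` and `to_conversation` are hypothetical stand-ins invented for this illustration, not part of the dataset or of any library.

```python
from itertools import islice

def fake_stream(total):
    """Hypothetical stand-in for a streamed dataset: yields one sample at a time."""
    for i in range(total):
        yield {"id": i, "prompt": f"prompt {i}"}

def to_conversation(sample):
    """Toy transform: wrap a sample in a user/assistant message pair."""
    return {
        "messages": [
            {"role": "user", "content": "Describe this image."},
            {"role": "assistant", "content": sample["prompt"]},
        ]
    }

# Steps 1-3: stream lazily, keep the first 2,000 samples, convert each one.
stream = fake_stream(2_000_000)
subset = [to_conversation(s) for s in islice(stream, 2000)]

print(len(subset))
# The generator has only advanced past the samples we took.
print(next(stream)["id"])
```

Only the samples that were actually requested are ever produced, which is why streaming keeps the subset step cheap.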

---
### 👋🏻 About Me

Hi, I'm **Aasher Kamal** — a Generative & Agentic AI developer passionate about building intelligent systems with LLMs.

I have started a **15-day challenge** to master fine-tuning using the open-source **Unsloth AI** framework. This journey will cover everything from LoRA and QLoRA to reinforcement learning, vision, and TTS fine-tuning — all hands-on, all open-source.

I'll be documenting my learnings, experiments, and challenges daily.

---

### 🌐 Connect with Me

- [LinkedIn](https://www.linkedin.com/in/aasher-kamal/)
- [GitHub](https://github.com/aasherkamal216)
- [X (Twitter)](https://x.com/Aasher_Kamal)
- [Facebook](https://www.facebook.com/aasher.kamal)
- [Website](https://aasherkamal.framer.website/)

Let’s build and learn together! 💡

### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
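
The install cell branches on whether the notebook is running in Colab by looking for `COLAB_` in the environment variable names. The same check can be sketched as a small helper. `running_in_colab` is a hypothetical name, and it tests for a `COLAB_` prefix on each variable, which is slightly stricter than the substring check in the cell above.

```python
import os

def running_in_colab(environ=None):
    """Return True if any environment variable name starts with 'COLAB_'."""
    env = os.environ if environ is None else environ
    return any(name.startswith("COLAB_") for name in env)

# Passing a dict makes the check easy to try outside Colab.
print(running_in_colab({"COLAB_GPU": "1"}))
print(running_in_colab({"PATH": "/usr/bin"}))
```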

In [None]:
from unsloth import FastVisionModel
import torch

model, processor = FastVisionModel.from_pretrained(
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    load_in_4bit = True, 
    max_seq_length = 2048,
    use_gradient_checkpointing = "unsloth",
)

In [None]:
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules       = True,

    r = 32,
    lora_alpha = 32,
    lora_dropout = 0,
    bias = "none",
    random_state = 42,
    use_rslora = False,
    loftq_config = None,
    target_modules = "all-linear",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)
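
To get a feel for what `r`, `lora_alpha`, and `use_rslora` mean, here is a small arithmetic sketch. It assumes the usual LoRA conventions: the low-rank update is scaled by `lora_alpha / r` (or `lora_alpha / sqrt(r)` with rank-stabilized LoRA), and each adapted linear layer gains `r * (d_in + d_out)` trainable parameters. The 2048 x 2048 layer size is made up for illustration and is not taken from the Gemma 3n architecture.

```python
import math

def lora_scaling(r, lora_alpha, use_rslora=False):
    """Scaling factor applied to the low-rank update B @ A."""
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r

def lora_params(d_in, d_out, r):
    """Trainable parameters added to one linear layer: A is (r, d_in), B is (d_out, r)."""
    return r * (d_in + d_out)

r, lora_alpha = 32, 32
print(lora_scaling(r, lora_alpha))                  # 1.0
print(round(lora_scaling(r, lora_alpha, True), 2))  # 5.66

# Hypothetical 2048 x 2048 projection, chosen only for illustration.
d = 2048
added = lora_params(d, d, r)
print(added, f"{added / (d * d):.1%} of the frozen weight")
```

With `r = lora_alpha`, the adapter output is added at a scale of 1.0, so changing the rank later also changes the effective scale unless `lora_alpha` is adjusted with it.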

### Loading Dataset

In [None]:
import json
from datasets import load_dataset
from PIL import Image
import io

# 1. Load the streaming dataset
base_url = "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_{i:06d}.tar"
num_shards = 46
urls = [base_url.format(i=i) for i in range(num_shards)]
full_dataset_stream = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)

# 2. Take a subset and load it into a Python list in memory.
subset_list = list(full_dataset_stream.take(2000))

# 3. Define your formatting function
instruction = "Generate a detailed, descriptive prompt for this image, suitable for a text-to-image model."

def convert_to_conversation(sample):
    image = sample["jpg"]
    prompt_text = sample["json"]["prompt"]

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
                {"type": "image", "image": image}
            ]
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": prompt_text}
            ]
        },
    ]
    # Return the dictionary in the format the trainer expects
    return {"messages": messages}

# 4. Use a list comprehension to create the final processed list
dataset = [convert_to_conversation(sample) for sample in subset_list]

# Let's verify that we have a live PIL Image object, NOT bytes
first_image_object = dataset[0]['messages'][0]['content'][1]['image']
print("Type of the image object in our final list:", type(first_image_object))


In [None]:
dataset[0]
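
Before training, it can help to confirm that every sample has the structure built by `convert_to_conversation`: one user turn with exactly one image, followed by one assistant turn with non-empty text. The checker below is a hypothetical helper written for this notebook, not an Unsloth or TRL function, and the toy sample uses a placeholder string where the PIL image would be.

```python
def check_sample(sample):
    """Return a list of problems found in one conversation-format sample."""
    problems = []
    messages = sample.get("messages", [])
    roles = [m.get("role") for m in messages]
    if roles != ["user", "assistant"]:
        problems.append(f"unexpected roles: {roles}")
        return problems

    user_parts, assistant_parts = (m["content"] for m in messages)
    n_images = sum(1 for p in user_parts if p.get("type") == "image")
    if n_images != 1:
        problems.append(f"expected 1 image in the user turn, found {n_images}")

    target = " ".join(
        p.get("text", "") for p in assistant_parts if p.get("type") == "text"
    )
    if not target.strip():
        problems.append("assistant prompt text is empty")
    return problems

# Toy sample with a placeholder string where the PIL image would be.
toy = {"messages": [
    {"role": "user", "content": [{"type": "text", "text": "Describe."},
                                 {"type": "image", "image": "<PIL image>"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "a red fox in snow"}]},
]}
print(check_sample(toy))  # []
```

On the real data, something like `[i for i, s in enumerate(dataset) if check_sample(s)]` would list the indices of any malformed samples.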

### Chat Template

In [None]:
from unsloth import get_chat_template

processor = get_chat_template(
    processor,
    "gemma-3n"
)

### Training the Model

In [None]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

FastVisionModel.for_training(model) # Enable for training!

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,

        # use non-reentrant checkpointing
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper
        warmup_ratio = 0.1,
        num_train_epochs = 1,
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy="steps",
        save_steps = 100,
        save_total_limit = 3,
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 42,
        output_dir = "outputs",
        report_to = "none",             # "none" disables logging; use "wandb" for Weights & Biases

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)
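
The batch size, accumulation, warmup, and scheduler settings together determine how long the run is and how the learning rate moves. The sketch below works through that arithmetic for this configuration. It assumes a single GPU and a standard linear-warmup, cosine-decay schedule; `training_plan` and `lr_at` are hypothetical helpers, and the exact step counts reported by the trainer may differ slightly.

```python
import math

def training_plan(num_examples, per_device_batch, grad_accum, epochs,
                  warmup_ratio, num_devices=1):
    """Estimate effective batch size, optimizer steps, and warmup steps."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps

def lr_at(step, total_steps, warmup_steps, peak_lr):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

effective_batch, total_steps, warmup_steps = training_plan(2000, 2, 4, 1, 0.1)
print(effective_batch, total_steps, warmup_steps)  # 8 250 25

for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(step, f"{lr_at(step, total_steps, warmup_steps, 2e-4):.2e}")
```

Under these assumptions the run takes about 250 optimizer steps, so `save_steps = 100` would fall on steps 100 and 200.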

In [None]:
trainer_stats = trainer.train()
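
`trainer.train()` returns an object whose `metrics` dictionary typically includes entries such as `train_runtime` (in seconds) and `train_loss`. A small, hypothetical helper for printing them is sketched below; the numbers in the example call are made up.

```python
def summarize_run(metrics):
    """Format a few fields from a metrics dict such as `trainer_stats.metrics`."""
    runtime = metrics.get("train_runtime")
    loss = metrics.get("train_loss")
    parts = []
    if runtime is not None:
        parts.append(f"runtime: {runtime / 60:.1f} min")
    if loss is not None:
        parts.append(f"mean training loss: {loss:.4f}")
    return ", ".join(parts) or "no metrics found"

# Made-up numbers, for illustration only.
print(summarize_run({"train_runtime": 1234.5, "train_loss": 1.2345}))
```

In this notebook the call would be `summarize_run(trainer_stats.metrics)`.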

---

### Inference

In [None]:
from transformers import TextStreamer
import requests
from PIL import Image
from io import BytesIO

FastVisionModel.for_inference(model)

image_url = "https://substackcdn.com/image/fetch/$s_!fCoi!,w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F81899f64-aaae-469c-9463-6f8be3f6b2ab_1920x1080.jpeg"
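response = requests.get(image_url, timeout=30)
response.raise_for_status()  # fail early if the download did not succeed
image = Image.open(BytesIO(response.content)).convert("RGB")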

instruction = "Generate a detailed, descriptive prompt for this image, suitable for a text-to-image model."

messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": instruction}],
    }
]

In [None]:
image

In [None]:
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")

text_streamer = TextStreamer(processor, skip_prompt=True)

result = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=256,
    do_sample=True,  # make sampling explicit so temperature and top_p are applied
    temperature=0.8,
    top_p=0.9,
)
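
To see what `temperature` and `top_p` do during generation, here is a toy, pure-Python illustration of temperature scaling followed by nucleus (top-p) filtering. This is a sketch of the general technique on made-up scores, not the code path `model.generate` actually runs.

```python
import math
import random

def top_p_filter(logits, temperature=0.8, top_p=0.9):
    """Return {token_index: probability} after temperature scaling and nucleus filtering."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative mass reaches top_p.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in ranked:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Renormalize so the surviving probabilities sum to 1.
    return {i: probs[i] / mass for i in kept}

toy_logits = [4.0, 3.0, 1.0, 0.0, -2.0]  # made-up scores for a 5-token vocabulary
nucleus = top_p_filter(toy_logits)
print(sorted(nucleus))  # [0, 1]

# Sample a few tokens from the filtered distribution.
rng = random.Random(0)
tokens, weights = zip(*nucleus.items())
print(rng.choices(tokens, weights=weights, k=5))
```

A temperature below 1 sharpens the distribution, and `top_p` then drops the unlikely tail, so the generated prompts stay varied without drifting into low-probability tokens.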

### Saving the Model

In [None]:
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')

model.push_to_hub_merged("Aasher/Image2Prompt_Generator_Gemma_3n_2B", processor, token = hf_token)
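
The cell above reads the token from Colab secrets, which only works inside Colab. For other environments, one option is to fall back to an environment variable. `get_hf_token` below is a hypothetical helper, and it assumes the token is stored under the name `HF_TOKEN` in both places.

```python
import os

def get_hf_token(name="HF_TOKEN"):
    """Read the token from Colab secrets when available, else from the environment."""
    try:
        from google.colab import userdata  # only importable inside Colab
        token = userdata.get(name)
    except Exception:
        # Not running in Colab, or the secret is missing or not accessible.
        token = None
    return token or os.environ.get(name)

hf_token = get_hf_token()
print("token found" if hf_token else "no token configured")
```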