To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions in our docs [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News


Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma etc. VLMs with GSPO - even faster with less VRAM.

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
!pip install unsloth
# !pip install transformers==4.55.4
# !pip install --no-deps trl==0.22.2


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# %%capture
# # These are mamba kernels and we must have these for faster training
# !pip install --no-build-isolation mamba_ssm==2.2.5
# !pip install --no-build-isolation causal_conv1d==1.5.2

### Unsloth

In [3]:
from unsloth import FastLanguageModel
import torch

fourbit_models = [
    "unsloth/granite-4.0-micro",
    "unsloth/granite-4.0-h-micro",
    "unsloth/granite-4.0-h-tiny",
    "unsloth/granite-4.0-h-small",

    # Base pretrained Granite 4 models
    "unsloth/granite-4.0-micro-base",
    "unsloth/granite-4.0-h-micro-base",
    "unsloth/granite-4.0-h-tiny-base",
    "unsloth/granite-4.0-h-small-base",

    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/Phi-4",
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" # [NEW] We support TTS models!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/granite-4.0-micro",
    max_seq_length = 2048,   # Choose any for long context!
    load_in_4bit = True,    # 4 bit quantization to reduce memory
    load_in_8bit = False,    # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 10-12 20:34:53 [__init__.py:216] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.1: Fast Granitemoehybrid patching. Transformers: 4.56.2. vLLM: 0.11.0.
   \\   /|    NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.179 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu129. CUDA: 9.0. CUDA Toolkit: 12.9. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


We now add LoRA adapters, so we only need to update a small fraction of the parameters!

In [4]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "shared_mlp.input_linear", "shared_mlp.output_linear"],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth: Making `model.base_model.model.model` require gradients
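LoRA keeps each original weight matrix `W` frozen and learns a low-rank update `B @ A`, scaled by `lora_alpha / r`. The NumPy sketch below illustrates the idea for a single linear layer using the `r = 32` and `lora_alpha = 32` chosen above. The 2048×2048 layer size is a hypothetical shape picked for illustration (not Granite's actual dimensions), and this is not Unsloth's or PEFT's implementation.

```python
import numpy as np

# Hypothetical layer size, chosen for illustration only
d_in, d_out = 2048, 2048
r, lora_alpha = 32, 32            # same values as the get_peft_model cell
scaling = lora_alpha / r          # standard (non-rsLoRA) LoRA scaling

rng = np.random.default_rng(3407)
W = rng.standard_normal((d_out, d_in)).astype(np.float32)     # frozen base weight
A = rng.standard_normal((r, d_in)).astype(np.float32) * 0.01  # trainable, small random init
B = np.zeros((d_out, r), dtype=np.float32)                     # trainable, zero init

def lora_forward(x):
    # Base output plus the scaled low-rank update: x W^T + s * (x A^T) B^T
    return x @ W.T + scaling * (x @ A.T) @ B.T

x = rng.standard_normal((4, d_in)).astype(np.float32)
full_params = d_in * d_out        # parameters updated by full finetuning of this layer
lora_params = r * (d_in + d_out)  # parameters in A and B
print(f"full: {full_params:,}  lora: {lora_params:,}  ratio: {lora_params / full_params:.2%}")

# Because B starts at zero, the adapted layer initially matches the base layer exactly
print(np.allclose(lora_forward(x), x @ W.T))
```

Only `A` and `B` receive gradients, which is why the training banner further down reports roughly 1.7% of the model's parameters as trainable.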


<a name="Data"></a>
### Data Prep
#### 📄 Using a CSV File as Training Data
Our goal is to fine-tune a model that predicts a product's price from its catalog listing.

The examples are stored in a CSV file (`train.csv`); we use two of its columns:

- **catalog_content**: The product's catalog text (item name, bullet points, value, unit, etc.)
- **price**: The product's price, which the model learns to predict

This keeps things simple: any spreadsheet exported to CSV works, and no database setup is required.  
<br>

---
<br>

#### 🔍 Why This Format?

This setup works well for tasks like:

- `Input snippet → Suggested reply`
- `Prompt → Rewrite`
- `Bug report → Diagnosis`
- `Text → Label or Category`

Just collect examples in a spreadsheet, and you’ve got usable training data.  
<br>

---
<br>

#### ✅ What You'll Learn

We’ll show how to:

1. Load the CSV file into your notebook
2. Format it into a dataset
3. Use it to train or prompt an LLM


The native chat template for Granite 4.0 looks like this:
```
<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.
Today's Date: June 24, 2025.
You are Granite, developed by IBM. You are a helpful AI assistant.<|end_of_text|>

<|start_of_role|>user<|end_of_role|>How do astronomers determine the original wavelength of light emitted by a celestial body at rest, which is necessary for measuring its speed using the Doppler effect?<|end_of_text|>

<|start_of_role|>assistant<|end_of_role|>Astronomers make use of the unique spectral fingerprints of elements found in stars...<|end_of_text|>
```

In [5]:
import pandas as pd
import numpy as np
import re
from datasets import Dataset

# Enhanced text cleaning function - extracts key features AND keeps full text
def clean_text_enhanced(text):
    if pd.isnull(text):
        return ""
    
    # Convert to string and clean basic issues
    text = str(text).strip()
    
    # Extract ALL structured information (not just top 3)
    item_name = re.search(r"Item Name:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    brand = re.search(r"Brand:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    color = re.search(r"Color:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    size = re.search(r"Size:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    material = re.search(r"Material:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    model = re.search(r"Model:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    
    # Extract bullet points (all of them)
    bp1 = re.search(r"Bullet Point\s*1:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    bp2 = re.search(r"Bullet Point\s*2:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    bp3 = re.search(r"Bullet Point\s*3:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    bp4 = re.search(r"Bullet Point\s*4:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    bp5 = re.search(r"Bullet Point\s*5:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    
    # Extract value and unit
    value = re.search(r"Value:\s*([\d.,]+)", text, re.IGNORECASE)
    unit = re.search(r"Unit:\s*([A-Za-z]+)", text, re.IGNORECASE)
    
    # Extract description if present
    description = re.search(r"Description:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
    
    # Build structured output with KEY features first, then append everything else
    structured_parts = []
    
    # Top priority features (Item Name, Value, Unit)
    if item_name:
        structured_parts.append(f"Item: {item_name.group(1).strip()}")
    if value and unit:
        structured_parts.append(f"Quantity: {value.group(1).strip()} {unit.group(1).strip()}")
    elif value:
        structured_parts.append(f"Value: {value.group(1).strip()}")
    
    # Additional important features
    if brand:
        structured_parts.append(f"Brand: {brand.group(1).strip()}")
    if color:
        structured_parts.append(f"Color: {color.group(1).strip()}")
    if size:
        structured_parts.append(f"Size: {size.group(1).strip()}")
    if material:
        structured_parts.append(f"Material: {material.group(1).strip()}")
    if model:
        structured_parts.append(f"Model: {model.group(1).strip()}")
    
    # All bullet points
    if bp1:
        structured_parts.append(f"Feature 1: {bp1.group(1).strip()}")
    if bp2:
        structured_parts.append(f"Feature 2: {bp2.group(1).strip()}")
    if bp3:
        structured_parts.append(f"Feature 3: {bp3.group(1).strip()}")
    if bp4:
        structured_parts.append(f"Feature 4: {bp4.group(1).strip()}")
    if bp5:
        structured_parts.append(f"Feature 5: {bp5.group(1).strip()}")
    
    if description:
        structured_parts.append(f"Description: {description.group(1).strip()}")
    
    # Join structured parts
    cleaned_text = ". ".join(structured_parts)
    
    # IMPORTANT: Append the FULL original text (cleaned) so nothing is lost
    # This ensures ALL information is available to the model
    full_text_cleaned = text.lower()
    full_text_cleaned = re.sub(r'[^\w\s.,:\-]', ' ', full_text_cleaned)
    full_text_cleaned = re.sub(r'\s+', ' ', full_text_cleaned)
    full_text_cleaned = full_text_cleaned.strip()
    
    # Combine: structured features first, then full text for additional context
    if cleaned_text and full_text_cleaned:
        final_text = f"{cleaned_text}. Full Details: {full_text_cleaned}"
    elif cleaned_text:
        final_text = cleaned_text
    else:
        final_text = full_text_cleaned
    
    return final_text

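train_path = '/root/train.csv'
print(f"Loading training data from {train_path}...")
# latin1 never raises decode errors but garbles UTF-8 punctuation (e.g. ’ becomes "â" plus
# two control characters, as in "Youâll" below); try encoding='utf-8' if the file is UTF-8
train_df = pd.read_csv(train_path, encoding='latin1')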

print(f"Original data shape: {train_df.shape}")
print(f"Columns: {train_df.columns.tolist()}")

# Apply text cleaning
print("\nApplying enhanced text cleaning...")
train_df['catalog_content'] = train_df['catalog_content'].apply(clean_text_enhanced)

# Filter out empty or very short text
train_df['text_length'] = train_df['catalog_content'].str.len()
train_df = train_df[train_df['text_length'] > 10].copy()

print(f"Data shape after cleaning: {train_df.shape}")
print(f"\nPrice statistics:")
print(train_df['price'].describe())

# Convert to HuggingFace Dataset format
dataset = Dataset.from_pandas(train_df[['catalog_content', 'price']])

print(f"\n✅ Dataset loaded: {len(dataset)} samples")

Loading training data from dataset/train.csv...
Original data shape: (75000, 4)
Columns: ['sample_id', 'catalog_content', 'image_link', 'price']

Applying enhanced text cleaning...
Data shape after cleaning: (75000, 5)

Price statistics:
count    75000.000000
mean        23.647654
std         33.376932
min          0.130000
25%          6.795000
50%         14.000000
75%         28.625000
max       2796.000000
Name: price, dtype: float64

✅ Dataset loaded: 75000 samples
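The core of `clean_text_enhanced` is two steps: pull labelled `Key: value` lines out with a regex, then append a lowercased copy of the whole text with most punctuation replaced by spaces (which is why `Member's Mark` shows up as `member s mark` under *Full Details*). The compact sketch below reproduces both steps on a made-up sample. `extract_fields` and `normalize` are hypothetical helper names, and only a subset of the notebook's fields is handled.

```python
import re

# Subset of the "Key: value" fields handled by clean_text_enhanced (hypothetical helper)
FIELDS = [("Item Name", "Item"), ("Brand", "Brand"), ("Color", "Color")]

def extract_fields(text: str) -> list[str]:
    parts = []
    for key, label in FIELDS:
        # Non-greedy capture up to the next newline or the end of the string, as in the notebook
        m = re.search(rf"{key}:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
        if m:
            parts.append(f"{label}: {m.group(1).strip()}")
    for n in range(1, 6):  # "Bullet Point N" becomes "Feature N"
        m = re.search(rf"Bullet Point\s*{n}:\s*(.*?)(?=\n|$)", text, re.IGNORECASE)
        if m:
            parts.append(f"Feature {n}: {m.group(1).strip()}")
    return parts

def normalize(text: str) -> str:
    # Lowercase, turn anything outside [word chars, whitespace, . , : -] into a space,
    # then collapse runs of whitespace (including newlines) into single spaces
    text = re.sub(r"[^\w\s.,:\-]", " ", text.lower())
    return re.sub(r"\s+", " ", text).strip()

sample = "Item Name: Member's Mark, Basil, 6.25 oz\nBullet Point 1: Green Herb\nBullet Point 2: Large Size"
structured = ". ".join(extract_fields(sample))
print(f"{structured}. Full Details: {normalize(sample)}")
```

Keeping both the structured fields and the normalized full text means nothing is dropped if a listing uses a field the regexes don't cover, at the cost of longer prompts.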


We've just loaded the CSV file as a Hugging Face `Dataset`, but we still need to format each row as a conversation like the one below and then apply the chat template.

```
{"role": "system", "content": "You are an assistant"}
{"role": "user", "content": "What is 2+2?"}
{"role": "assistant", "content": "It's 4."}
```

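We'll use a helper function, `formatting_prompts_func`, to do both! Note that the cell below replaces the tokenizer's native Granite chat template with a simpler `[INST] ... [/INST]` template, so prompts at inference time must use the same format.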

In [6]:
tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

def formatting_prompts_func(examples):
    catalog_texts = examples['catalog_content']
    prices = examples['price']
    
    messages = [
        [{"role": "user", "content": f"Predict the price for this product: {catalog_text}"},
         {"role": "assistant", "content": f"The predicted price is ${price:.2f}"}] 
        for catalog_text, price in zip(catalog_texts, prices)
    ]
    
    # Render each conversation to one training string (no generation prompt: the answer is included)
    texts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False) 
             for message in messages]
    
    return {"text": texts}

print("Formatting dataset with chat template...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print(f"✅ Dataset formatted: {len(dataset)} samples")

Formatting dataset with chat template...



✅ Dataset formatted: 75000 samples
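The Jinja template assigned to `tokenizer.chat_template` above puts the BOS token first, wraps each user turn in `[INST] ... [/INST]`, and appends the EOS token after each assistant turn. The plain-Python sketch below mirrors that logic to make the resulting string explicit. `render_inst` is a hypothetical helper, and using `<|end_of_text|>` as the default for both BOS and EOS is an assumption taken from the formatted sample printed further down, not read from the tokenizer.

```python
def render_inst(messages, bos_token="<|end_of_text|>", eos_token="<|end_of_text|>"):
    """Plain-Python mirror of the [INST] chat template set above (hypothetical helper)."""
    out = bos_token
    for i, message in enumerate(messages):
        # Roles must alternate user/assistant, starting with user
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        if message["role"] == "user":
            out += "[INST] " + message["content"] + " [/INST]"
        elif message["role"] == "assistant":
            out += message["content"] + eos_token
        else:
            raise ValueError("Only user and assistant roles are supported!")
    return out

messages = [
    {"role": "user", "content": "Predict the price for this product: Item: Basil"},
    {"role": "assistant", "content": "The predicted price is $18.50"},
]
print(render_inst(messages))
```

At inference time the prompt should therefore stop right after `[/INST]` so the model generates the answer; a prompt in Granite's native `<|start_of_role|>` format would not match what the model saw during training.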


We now look at the raw input data before formatting.

In [7]:
# Show raw catalog content before formatting
print("Sample catalog content:")
print(dataset[5]["catalog_content"][:500])  # Show first 500 chars

Sample catalog content:
Item: Member's Mark Member's Mark, Basil, 6.25 oz. Quantity: 6.25 ounce. Feature 1: Green Herb, Italian Staple, Great mixed with Oregano. Feature 2: Large Size, Chef Bottle. Feature 3: Packed in the USA. Full Details: item name: member s mark member s mark, basil, 6.25 oz bullet point 1: green herb, italian staple, great mixed with oregano bullet point 2: large size, chef bottle bullet point 3: packed in the usa value: 6.25 unit: ounce


In [8]:
# Show the corresponding price
print("Sample price:")
print(f"${dataset[5]['price']:.2f}")

Sample price:
$18.50


And we see how the chat template transformed these conversations.

In [9]:
dataset[5]["text"]

"<|end_of_text|>[INST] Predict the price for this product: Item: Member's Mark Member's Mark, Basil, 6.25 oz. Quantity: 6.25 ounce. Feature 1: Green Herb, Italian Staple, Great mixed with Oregano. Feature 2: Large Size, Chef Bottle. Feature 3: Packed in the USA. Full Details: item name: member s mark member s mark, basil, 6.25 oz bullet point 1: green herb, italian staple, great mixed with oregano bullet point 2: large size, chef bottle bullet point 3: packed in the usa value: 6.25 unit: ounce [/INST]The predicted price is $18.50<|end_of_text|>"

<a name="Train"></a>
### Train the model
Now let's train our model. This run makes 2 full passes over the data (`num_train_epochs = 2`). For a quick test, uncomment `max_steps = 60` instead; when set, it overrides `num_train_epochs`.

In [10]:
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 35,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 2, # Number of full passes over the training data
        # max_steps = 60,     # Uncomment for a quick test run (overrides num_train_epochs)
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
    ),
)


Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
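With these settings the number of optimizer steps follows directly from the dataset size and the effective batch size. The arithmetic below assumes that a final, partial gradient-accumulation window still counts as one optimizer step; under that assumption it reproduces the `Total steps = 1,072` and the `35 x 4 x 1 = 140` batch size reported in the training banner.

```python
import math

num_examples = 75_000
per_device_batch = 35
grad_accum = 4
num_gpus = 1
epochs = 2

effective_batch = per_device_batch * grad_accum * num_gpus  # 35 x 4 x 1
batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)  # partial window counted (assumption)
total_steps = steps_per_epoch * epochs

print(effective_batch, batches_per_epoch, steps_per_epoch, total_steps)
```

Lowering `per_device_train_batch_size` while raising `gradient_accumulation_steps` keeps the effective batch (and roughly the step count) the same while using less VRAM per step.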


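Unsloth's `train_on_responses_only` function trains only on the assistant outputs and ignores the loss on the user's inputs, which can help increase the accuracy of finetunes. It is commented out in this run. Note that the `instruction_part` and `response_part` below use Granite's native role markers, whereas the data was formatted with the `[INST] ... [/INST]` template set earlier, so they would need to be changed to match before enabling it.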

In [11]:
# from unsloth.chat_templates import train_on_responses_only
# trainer = train_on_responses_only(
#     trainer,
#     instruction_part = "<|start_of_role|>user<|end_of_role|>",
#     response_part = "<|start_of_role|>assistant<|end_of_role|>",
# )

Let's inspect the tokenized training data. First, decode the full formatted example at index 100:

In [12]:
# Verify the full formatted text (input_ids)
if len(trainer.train_dataset) > 100:
    print("Full formatted example:")
    print(tokenizer.decode(trainer.train_dataset[100]["input_ids"]))
else:
    print(f"Dataset only has {len(trainer.train_dataset)} samples. Showing first sample:")
    print(tokenizer.decode(trainer.train_dataset[0]["input_ids"]))

Full formatted example:
<|end_of_text|>[INST] Predict the price for this product: Item: Amazon Grocery, Lemonade Drink Mix, 10 packets, 1.4 Oz (Previously Happy Belly, Packaging May Vary). Quantity: 1.4 Ounce. Feature 1: 10 packets of Lemonade Drink Mix. Feature 2: Some of your favorite Happy Belly products are now part of the Amazon Grocery brand! Although packaging may vary during the transition, the ingredients and product remain the same. Thank you for your continued trust in our brands. Feature 3: Sugar Free, Low Sodium. Feature 4: 10 calories per serving. Feature 5: Amazon Grocery has all the favorites you love for less. Youâll find everything you need for great-tasting meals in one shopping trip. Description: 10 packets of Lemonade Drink Mix. Full Details: item name: amazon grocery, lemonade drink mix, 10 packets, 1.4 oz previously happy belly, packaging may vary bullet point 1: 10 packets of lemonade drink mix bullet point 2: some of your favorite happy belly products are now

Next, try to print the masked labels. With response-only masking enabled you would see only the answer; in this run the tokenized dataset has no `labels` field:

In [13]:
# Now let's print the masked out example - you should see only the assistant response
if len(trainer.train_dataset) > 100:
    sample_idx = 100
else:
    sample_idx = 0

if "labels" in trainer.train_dataset[sample_idx]:
    masked_labels = [tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[sample_idx]["labels"]]
    decoded = tokenizer.decode(masked_labels)
    if tokenizer.pad_token:
        decoded = decoded.replace(tokenizer.pad_token, " ")
    print("Masked output (only assistant response should be visible):")
    print(decoded)
else:
    print("Labels field not found. The masking will be applied during training.")

Labels field not found. The masking will be applied during training.
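Response-only training works by setting the label of every prompt token to `-100`, the index that PyTorch's `CrossEntropyLoss` ignores by default, so only the answer tokens contribute to the loss. The sketch below shows that idea on made-up token IDs. `mask_prompt_labels` is a hypothetical helper that takes the response start position as given; it is not Unsloth's `train_on_responses_only` implementation.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_prompt_labels(input_ids, response_start):
    """Copy input_ids into labels, hiding every token before response_start (hypothetical helper)."""
    labels = list(input_ids)
    for i in range(min(response_start, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels

# Made-up token IDs: 4 prompt tokens followed by 3 answer tokens
input_ids = [101, 2054, 2003, 102, 7, 8, 9]
labels = mask_prompt_labels(input_ids, response_start=4)
print(labels)

# Only positions whose label is not IGNORE_INDEX contribute to the loss
trained_positions = [i for i, t in enumerate(labels) if t != IGNORE_INDEX]
print(trained_positions)
```

For this task the answer (`The predicted price is $...`) is short compared with the long catalog prompt, so without masking most of the loss comes from reproducing product text rather than predicting the price.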


In [14]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA H100 80GB HBM3. Max memory = 79.179 GB.
3.289 GB of memory reserved.


Let's train the model! To resume a training run, call `trainer.train(resume_from_checkpoint = True)`.

```
Note: you may have to wait ~10 minutes for the Mamba kernels to compile. Please be patient!
```

In [15]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 75,000 | Num Epochs = 2 | Total steps = 1,072
O^O/ \_/ \    Batch size per device = 35 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (35 x 4 x 1) = 140
 "-____-"     Trainable parameters = 58,982,400 of 3,461,818,880 (1.70% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|---:|---:|
| 1 | 1.8164 |
| 2 | 1.891 |
| 3 | 1.723 |
| 4 | 1.5817 |
| 5 | 1.4925 |
| 6 | 1.4194 |
| 7 | 1.3003 |
| 8 | 1.1951 |
| 9 | 1.1568 |
| 10 | 1.0876 |


<a name="Inference"></a>
### Inference
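Let's run batched inference with vLLM on the test set (`test.csv`), which is not part of the training data, to get a sense of what was learned. The cell below writes a standalone inference script; the later cells load the merged model with vLLM directly in the notebook.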

In [16]:
# Create a fast vLLM inference script
vllm_script = '''
import pandas as pd
import numpy as np
import re
from vllm import LLM, SamplingParams
from tqdm import tqdm

# Same text cleaning function
def clean_text_enhanced(text):
    if pd.isnull(text):
        return ""
    
    text = str(text).strip()
    
    # Extract structured information
    item_name = re.search(r"Item Name:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    brand = re.search(r"Brand:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    color = re.search(r"Color:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    size = re.search(r"Size:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    material = re.search(r"Material:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    model = re.search(r"Model:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    
    bp1 = re.search(r"Bullet Point\\s*1:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    bp2 = re.search(r"Bullet Point\\s*2:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    bp3 = re.search(r"Bullet Point\\s*3:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    bp4 = re.search(r"Bullet Point\\s*4:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    bp5 = re.search(r"Bullet Point\\s*5:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    
    value = re.search(r"Value:\\s*([\\d.,]+)", text, re.IGNORECASE)
    unit = re.search(r"Unit:\\s*([A-Za-z]+)", text, re.IGNORECASE)
    description = re.search(r"Description:\\s*(.*?)(?=\\n|$)", text, re.IGNORECASE)
    
    structured_parts = []
    
    if item_name:
        structured_parts.append(f"Item: {item_name.group(1).strip()}")
    if value and unit:
        structured_parts.append(f"Quantity: {value.group(1).strip()} {unit.group(1).strip()}")
    elif value:
        structured_parts.append(f"Value: {value.group(1).strip()}")
    
    if brand:
        structured_parts.append(f"Brand: {brand.group(1).strip()}")
    if color:
        structured_parts.append(f"Color: {color.group(1).strip()}")
    if size:
        structured_parts.append(f"Size: {size.group(1).strip()}")
    if material:
        structured_parts.append(f"Material: {material.group(1).strip()}")
    if model:
        structured_parts.append(f"Model: {model.group(1).strip()}")
    
    if bp1:
        structured_parts.append(f"Feature 1: {bp1.group(1).strip()}")
    if bp2:
        structured_parts.append(f"Feature 2: {bp2.group(1).strip()}")
    if bp3:
        structured_parts.append(f"Feature 3: {bp3.group(1).strip()}")
    if bp4:
        structured_parts.append(f"Feature 4: {bp4.group(1).strip()}")
    if bp5:
        structured_parts.append(f"Feature 5: {bp5.group(1).strip()}")
    
    if description:
        structured_parts.append(f"Description: {description.group(1).strip()}")
    
    cleaned_text = ". ".join(structured_parts)
    
    full_text_cleaned = text.lower()
    full_text_cleaned = re.sub(r\'[^\\w\\s.,:\\-]\', \' \', full_text_cleaned)
    full_text_cleaned = re.sub(r\'\\s+\', \' \', full_text_cleaned)
    full_text_cleaned = full_text_cleaned.strip()
    
    if cleaned_text and full_text_cleaned:
        final_text = f"{cleaned_text}. Full Details: {full_text_cleaned}"
    elif cleaned_text:
        final_text = cleaned_text
    else:
        final_text = full_text_cleaned
    
    return final_text

print("🚀 Loading model with vLLM...")
llm = LLM(
    model="/mnt/neel-amazon/granite_price_predictor_vllm",  # Same path the merged model is saved to below
    tensor_parallel_size=1,  # Adjust based on your GPU setup
    max_model_len=2048,
    gpu_memory_utilization=0.9,
    trust_remote_code=True
)

print("📂 Loading test data...")
test_df = pd.read_csv(\'dataset/test.csv\', encoding=\'latin1\')
print(f"Test data shape: {test_df.shape}")

# Clean text
print("🧹 Cleaning text...")
test_df[\'catalog_content\'] = test_df[\'catalog_content\'].apply(clean_text_enhanced)

# Create prompts. They must match the [INST] template used for training; the training
# text started with the BOS token, which is <|end_of_text|> for this tokenizer
print("📝 Creating prompts...")
prompts = [
    f"<|end_of_text|>[INST] Predict the price for this product: {text} [/INST]"
    for text in test_df[\'catalog_content\']
]

# Greedy decoding (temperature=0) for deterministic output
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=64,
    stop=["<|end_of_text|>", "\\n\\n"]
)

print(f"\\n⚡ Generating predictions for {len(prompts)} samples with vLLM...")
print("This should be MUCH faster than one-by-one generation!\\n")

# Batch inference - THIS IS THE KEY!
outputs = llm.generate(prompts, sampling_params)

# Extract prices
print("💰 Extracting prices from predictions...")
all_predictions = []

for i, output in enumerate(tqdm(outputs, desc="Processing outputs")):
    predicted_text = output.outputs[0].text
    
    # Extract price from text
    price_match = re.search(r\'\\$(\\d+\\.?\\d*)|price is (\\d+\\.?\\d*)\', predicted_text, re.IGNORECASE)
    
    if price_match:
        price = float(price_match.group(1) or price_match.group(2))
    else:
        # Fallback
        price = 50.0
    
    all_predictions.append(price)

# Create submission
print("\\n💾 Creating submission file...")
submission = pd.DataFrame({
    \'sample_id\': test_df[\'sample_id\'],
    \'price\': all_predictions
})

submission.to_csv(\'submission_granite_vllm.csv\', index=False)

print(f"\\n✅ Submission saved to submission_granite_vllm.csv")
print(f"Shape: {submission.shape}")
print(f"\\nPrice statistics:")
print(submission[\'price\'].describe())
print(f"\\n🎉 Done! Predictions completed in minutes instead of hours!")
'''

# Save the script
with open('vllm_inference.py', 'w') as f:
    f.write(vllm_script)

print("✅ vLLM inference script saved to 'vllm_inference.py'")
print("\n📋 To run fast inference:")
print("1. First, complete training and model saving (cells above)")
print("2. Install vLLM: pip install vllm")
print("3. Run: python vllm_inference.py")
print("\n⚡ This will generate predictions in MINUTES instead of 30+ hours!")

✅ vLLM inference script saved to 'vllm_inference.py'

📋 To run fast inference:
1. First, complete training and model saving (cells above)
2. Install vLLM: pip install vllm
3. Run: python vllm_inference.py

⚡ This will generate predictions in MINUTES instead of 30+ hours!
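The generated script turns each completion into a number with a regex and falls back to a constant when nothing matches. Below is a standalone sketch of that parsing step, using the same pattern and the same `50.0` fallback as the script, so it can be checked without loading a model. `parse_price` and `PRICE_RE` are hypothetical names.

```python
import re

# Same pattern as in vllm_inference.py: "$12.34" or "price is 12.34"
PRICE_RE = re.compile(r"\$(\d+\.?\d*)|price is (\d+\.?\d*)", re.IGNORECASE)

def parse_price(predicted_text: str, fallback: float = 50.0) -> float:
    m = PRICE_RE.search(predicted_text)
    if m:
        # Exactly one of the two groups is set, depending on which alternative matched
        return float(m.group(1) or m.group(2))
    return fallback

for text in ["The predicted price is $18.50", "price is 7", "no number here"]:
    print(repr(text), "->", parse_price(text))
```

A comma-grouped value such as `$1,234.00` would parse as `1.0`. That is harmless here because the training targets were written with `:.2f`, which never inserts commas, but it is worth checking how often the fallback is hit, since `50.0` is about twice the training mean price.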


In [17]:
# Save LoRA adapters first (lightweight backup)
model.save_pretrained("/mnt/neel-amazon/lora_model")
tokenizer.save_pretrained("/mnt/neel-amazon/lora_model")
print("✅ LoRA adapters saved to 'lora_model/'")

# IMPORTANT: Merge and save to 16-bit for vLLM inference
print("\n🔄 Merging LoRA weights and saving for vLLM...")
print("This may take a few minutes...")
model.save_pretrained_merged("/mnt/neel-amazon/granite_price_predictor_vllm", tokenizer, save_method="merged_16bit")
print("✅ Model saved in vLLM-compatible format to 'granite_price_predictor_vllm/'")
print("\n⚡ Ready for fast batched inference with vLLM!")

✅ LoRA adapters saved to 'lora_model/'

🔄 Merging LoRA weights and saving for vLLM...
This may take a few minutes...
Found HuggingFace hub cache directory: /root/.cache/huggingface/hub



Checking cache directory for required files...
Cache check failed: model-00001-of-00002.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files: 100%|███████████████████████████████| 2/2 [00:00<00:00, 3929.09it/s]
Unsloth: Merging weights into 16bit: 100%|███████████████████████████████████████| 2/2 [00:17<00:00,  8.68s/it]


Unsloth: Merge process complete. Saved to `/mnt/neel-amazon/granite_price_predictor_vllm`
✅ Model saved in vLLM-compatible format to 'granite_price_predictor_vllm/'

⚡ Ready for fast batched inference with vLLM!
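Merging folds each LoRA update into its base weight, `W_merged = W + (lora_alpha / r) * B @ A`, so the saved 16-bit checkpoint can be loaded as an ordinary model without any adapter code, as vLLM does in the next cells. The NumPy sketch below checks that identity on small random matrices. The shapes are arbitrary, and this is not Unsloth's merge code; since the base weights in this notebook were loaded in 4-bit, a real merge must also dequantize them first.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, lora_alpha = 64, 48, 8, 8  # arbitrary small shapes for illustration
scaling = lora_alpha / r

W = rng.standard_normal((d_out, d_in))     # base weight (already dequantized)
A = rng.standard_normal((r, d_in))         # trained LoRA factors
B = rng.standard_normal((d_out, r))

W_merged = W + scaling * (B @ A)           # fold the adapter into the base weight

x = rng.standard_normal((5, d_in))
adapter_out = x @ W.T + scaling * (x @ A.T) @ B.T  # base layer plus adapter branch
merged_out = x @ W_merged.T                        # single merged layer
print(np.allclose(adapter_out, merged_out))
```

The merged layer does one matrix multiply instead of three, which is one reason merged weights are convenient for high-throughput serving; the trade-off is that the adapter can no longer be swapped out without keeping the separate `lora_model` copy.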


In [18]:
# Install vLLM for fast batched inference
!pip install vllm -q
print("✅ vLLM installed successfully!")


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
✅ vLLM installed successfully!


In [19]:
from vllm import LLM, SamplingParams
import pandas as pd
import numpy as np
import re
from tqdm.auto import tqdm
import time

print("🚀 FAST BATCHED INFERENCE WITH vLLM")
print("="*60)

# Load model with vLLM
print("\n📦 Loading model with vLLM...")
print("This will take a minute to initialize...\n")

llm = LLM(
    model="/mnt/neel-amazon/granite_price_predictor_vllm",
    tensor_parallel_size=1,  # Use 1 GPU, increase if you have multiple
    # max_model_len=3033,
    gpu_memory_utilization=0.8,  # Let vLLM use up to 80% of GPU memory
    trust_remote_code=True,
    dtype="float16"
)

print("✅ Model loaded successfully!\n")

🚀 FAST BATCHED INFERENCE WITH vLLM

📦 Loading model with vLLM...
This will take a minute to initialize...

INFO 10-13 00:58:37 [utils.py:233] non-default args: {'trust_remote_code': True, 'dtype': 'float16', 'gpu_memory_utilization': 0.8, 'disable_log_stats': True, 'model': '/mnt/neel-amazon/granite_price_predictor_vllm'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 10-13 00:58:37 [model.py:547] Resolved architecture: GraniteMoeHybridForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 10-13 00:58:37 [model.py:1510] Using max model len 131072
INFO 10-13 00:58:37 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=16384.
INFO 10-13 00:58:37 [config.py:297] Hybrid or mamba-based model detected: disabling prefix caching since it is not yet supported.
INFO 10-13 00:58:37 [config.py:308] Hybrid or mamba-based model detected: setting cudagraph mode to FULL_AND_PIECEWISE in order to optimize performance.
INFO 10-13 00:58:37 [config.py:376] Setting attention block size to 1312 tokens to ensure that attention page size is >= mamba page size.
INFO 10-13 00:58:37 [config.py:397] Padding mamba page size by 1.20% to ensure that mamba page size and attention page size are exactly equal.
INFO 10-13 00:58:44 [__init__.py:216] Automatically detected platform cuda.
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:58:45 [core.py:644] Waiting for init message from front-end.
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:58:45 [core.py:77] Initializi

[1;36m(EngineCore_DP0 pid=4328)[0;0m Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
[1;36m(EngineCore_DP0 pid=4328)[0;0m Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:03<00:03,  3.72s/it]
[1;36m(EngineCore_DP0 pid=4328)[0;0m Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:05<00:00,  2.41s/it]
[1;36m(EngineCore_DP0 pid=4328)[0;0m Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:05<00:00,  2.60s/it]
[1;36m(EngineCore_DP0 pid=4328)[0;0m 


[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:58:54 [default_loader.py:267] Loading weights took 5.23 seconds
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:58:55 [gpu_model_runner.py:2653] Model loading took 6.3599 GiB and 5.543192 seconds
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:04 [backends.py:548] Using cache directory: /root/.cache/vllm/torch_compile_cache/f7381a4b43/rank_0_0/backbone for vLLM's torch.compile
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:04 [backends.py:559] Dynamo bytecode transform time: 8.96 s


[1;36m(EngineCore_DP0 pid=4328)[0;0m [rank0]:W1013 00:59:04.853000 4328 site-packages/torch/_inductor/remote_cache.py:356] [0/0] Unable to create a remote cache
[1;36m(EngineCore_DP0 pid=4328)[0;0m [rank0]:W1013 00:59:04.853000 4328 site-packages/torch/_inductor/remote_cache.py:356] [0/0] Traceback (most recent call last):
[1;36m(EngineCore_DP0 pid=4328)[0;0m [rank0]:W1013 00:59:04.853000 4328 site-packages/torch/_inductor/remote_cache.py:356] [0/0]   File "/usr/local/lib/python3.12/site-packages/torch/_inductor/remote_cache.py", line 353, in create_cache
[1;36m(EngineCore_DP0 pid=4328)[0;0m [rank0]:W1013 00:59:04.853000 4328 site-packages/torch/_inductor/remote_cache.py:356] [0/0]     return cache_cls(key)
[1;36m(EngineCore_DP0 pid=4328)[0;0m [rank0]:W1013 00:59:04.853000 4328 site-packages/torch/_inductor/remote_cache.py:356] [0/0]            ^^^^^^^^^^^^^^
[1;36m(EngineCore_DP0 pid=4328)[0;0m [rank0]:W1013 00:59:04.853000 4328 site-packages/torch/_inductor/remote_cache.p

[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:09 [backends.py:164] Directly load the compiled graph(s) for dynamic shape from the cache, took 4.593 s
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:10 [monitor.py:34] torch.compile takes 8.96 s in total
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:11 [gpu_worker.py:298] Available KV cache memory: 53.08 GiB
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:11 [kv_cache_utils.py:1087] GPU KV cache size: 695,360 tokens
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:11 [kv_cache_utils.py:1091] Maximum concurrency for 131,072 tokens per request: 5.30x


[1;36m(EngineCore_DP0 pid=4328)[0;0m Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   0%|                          | 0/67 [00:00<?, ?it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   1%|▎                 | 1/67 [00:00<00:12,  5.26it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   3%|▌                 | 2/67 [00:00<00:11,  5.78it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   4%|▊                 | 3/67 [00:00<00:10,  5.98it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   6%|█                 | 4/67 [00:00<00:10,  6.12it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   7%|█▎                | 5/67 [00:00<00:10,  5.77it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   9%|█▌                | 6/67 [00:01<00:10,  5.88it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):  10%|█▉                | 7/67 [00:01<00:10,  5.98it/s]Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):  12%|██

[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:29 [gpu_model_runner.py:3480] Graph capturing finished in 17 secs, took 1.25 GiB
[1;36m(EngineCore_DP0 pid=4328)[0;0m INFO 10-13 00:59:29 [core.py:210] init engine (profile, create kv cache, warmup model) took 34.06 seconds
INFO 10-13 00:59:30 [llm.py:306] Supported_tasks: ['generate']
✅ Model loaded successfully!
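The engine log above reports a KV cache of 695,360 tokens, an attention block size of 1312 tokens and a max model length of 131,072 tokens, and estimates a maximum concurrency of 5.30x. Dividing the two token counts directly gives about 5.305 (which would print as 5.31), so the logged figure appears to be counted in whole cache blocks. The sketch below reproduces it under that assumption; `max_concurrency` is a hypothetical helper, not a vLLM function, and vLLM's exact accounting (especially for hybrid Mamba models) may differ.

```python
import math


def max_concurrency(kv_cache_tokens: int, block_size: int, max_model_len: int) -> float:
    """Estimate how many max-length requests fit in the KV cache at once.

    Assumption (not taken from vLLM's source): capacity is counted in whole
    blocks, and each request reserves ceil(max_model_len / block_size) blocks.
    """
    total_blocks = kv_cache_tokens // block_size           # 695,360 // 1312 = 530
    blocks_per_request = math.ceil(max_model_len / block_size)
    return total_blocks / blocks_per_request


# Numbers copied from the engine log above
print(f"{max_concurrency(695_360, 1312, 131_072):.2f}x")
```

Because the catalog prompts here are far shorter than 131,072 tokens, this is a worst-case bound, not the batch size actually reached. Under the same assumption, capping the context at 4096 tokens (for instance via the `max_model_len` argument of vLLM's `LLM(...)` constructor) gives `max_concurrency(695_360, 1312, 4096) == 132.5`.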



In [20]:
# Load test data
print("📂 Loading test data...")
test_df = pd.read_csv('/root/test.csv', encoding='latin1')
print(f"   Test samples: {len(test_df):,}")

# Apply the same text cleaning used for the training data
print("\n🧹 Cleaning text...")
test_df['catalog_content_cleaned'] = test_df['catalog_content'].apply(clean_text_enhanced)

# Create prompts in Granite chat format (should match the fine-tuning format)
print("\n📝 Creating prompts...")
prompts = [
    f"<|start_of_role|>user<|end_of_role|>Predict the price for this product: {text}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"
    for text in test_df['catalog_content_cleaned']
]

print(f"   Created {len(prompts):,} prompts")
print("\n✅ Data prepared for inference")

📂 Loading test data...
   Test samples: 75,000

🧹 Cleaning text...

📝 Creating prompts...
   Created 75,000 prompts

✅ Data prepared for inference
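The prompt template can be factored into a small function so that exactly the same string is built at training and inference time. The sketch below (the name `build_granite_prompt` is hypothetical) reproduces the string built in the cell above. An alternative is Hugging Face's `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)`, but a model's bundled chat template may add extra content such as a default system turn, so its output is not guaranteed to match this hand-written string; whichever format the model was fine-tuned on is the one to use.

```python
def build_granite_prompt(text: str) -> str:
    """Wrap cleaned catalog text in the Granite-style chat markup used in this notebook.

    The role markers are copied from the prompt cell above; they are assumed to
    match the format used during fine-tuning.
    """
    return (
        "<|start_of_role|>user<|end_of_role|>"
        f"Predict the price for this product: {text}<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )


print(build_granite_prompt("Item: evian Natural Spring Water 500 ml, 6 Count"))
```

With this helper, the list comprehension becomes `prompts = [build_granite_prompt(t) for t in test_df['catalog_content_cleaned']]`.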


In [21]:
import time

from vllm import SamplingParams

# Sampling parameters
sampling_params = SamplingParams(
    temperature=0.1,   # Low temperature for near-deterministic outputs
    top_p=0.95,
    max_tokens=100,    # Ample room for "The predicted price is $XX.XX"
    stop=["<|end_of_text|>", "\n\n"],  # Stop at end of turn or at a blank line
)

# Assumed (not measured) runtime for one-by-one generation, used only for
# the speed-up estimate printed at the end of this cell
ASSUMED_SEQUENTIAL_SECONDS = 30 * 60

print("\n⚡ RUNNING BATCHED INFERENCE WITH vLLM")
print("="*60)
print(f"Processing {len(prompts):,} samples...\n")

start_time = time.time()

# Batched generation: pass all prompts at once and let vLLM schedule them together
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

end_time = time.time()
total_time = end_time - start_time

print("\n✅ Inference complete!")
print(f"   Total time: {total_time/60:.1f} minutes")
print(f"   Speed: {len(prompts)/total_time:.1f} samples/second")
print(f"\n🎉 That's {ASSUMED_SEQUENTIAL_SECONDS/total_time:.0f}x faster than one-by-one!")


⚡ RUNNING BATCHED INFERENCE WITH vLLM
Processing 75,000 samples...





✅ Inference complete!
   Total time: 7.4 minutes
   Speed: 169.2 samples/second

🎉 That's 4x faster than one-by-one!
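The speed figures printed above are simple ratios of sample count and wall-clock time, and the "4x" claim divides an assumed 30-minute one-by-one runtime (the `30*60` baseline in the cell) by the measured time, so it is only as reliable as that assumption. The hypothetical helper below makes the arithmetic explicit; its 443.3 s input is back-calculated from the printed 169.2 samples/second, not a logged value.

```python
def throughput_report(n_samples: int, total_seconds: float,
                      baseline_seconds: float) -> dict:
    """Summarize a batch inference run the way the inference cell prints it."""
    return {
        "minutes": total_seconds / 60,
        "samples_per_second": n_samples / total_seconds,
        "seconds_per_sample": total_seconds / n_samples,  # amortized over the batch
        # Meaningful only if baseline_seconds reflects a real sequential run
        "speedup_vs_baseline": baseline_seconds / total_seconds,
    }


report = throughput_report(75_000, 443.3, baseline_seconds=30 * 60)
print(f"{report['minutes']:.1f} min, "
      f"{report['samples_per_second']:.1f} samples/s, "
      f"{report['seconds_per_sample']:.3f} s/sample, "
      f"{report['speedup_vs_baseline']:.0f}x vs. assumed baseline")
```

Timing a small sequential subset and extrapolating would turn the assumed baseline into a measured one.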


In [22]:
import re

import numpy as np
from tqdm.auto import tqdm

# Extract prices from outputs
print("\n💰 Extracting prices from predictions...")
all_predictions = []

# Matches "$XX.XX", "price is XX.XX" or "predicted price is XX.XX"
price_pattern = re.compile(
    r'\$(\d+\.?\d*)|price is (\d+\.?\d*)|predicted price is (\d+\.?\d*)',
    re.IGNORECASE,
)

for output in tqdm(outputs, desc="Processing outputs"):
    predicted_text = output.outputs[0].text
    
    price_match = price_pattern.search(predicted_text)
    
    if price_match:
        # Only one alternative can match, so take its single non-None group
        price = float(next(g for g in price_match.groups() if g is not None))
    else:
        # Fixed fallback price (50.0) if parsing fails
        price = 50.0
    
    # Clip to a plausible price range
    price = np.clip(price, 0.01, 10000.0)
    all_predictions.append(price)

print(f"✅ Extracted {len(all_predictions):,} prices")


💰 Extracting prices from predictions...



✅ Extracted 75,000 prices
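The extraction pattern above stops at the first comma, so an output such as `$1,299.99` would be parsed as 1.0 and kept, since it lies inside the clip range. The hypothetical `extract_price` helper below is a variant, not the notebook's method, that also accepts thousands separators while keeping the same fallback (50.0) and clip range [0.01, 10000]. It only matters if the model ever emits prices of 1,000 or more.

```python
import re

# Hypothetical variant of the notebook's pattern: "$" or "price is", followed by
# either a comma-grouped number (1,299.99) or a plain one (5.99)
PRICE_RE = re.compile(
    r"(?:\$\s*|price is\s+)(\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?)",
    re.IGNORECASE,
)


def extract_price(text: str, fallback: float = 50.0,
                  lo: float = 0.01, hi: float = 10000.0) -> float:
    """Return the first price found in `text`, clipped to [lo, hi], else `fallback`."""
    match = PRICE_RE.search(text)
    if match is None:
        return fallback
    value = float(match.group(1).replace(",", ""))  # drop thousands separators
    return min(max(value, lo), hi)


print(extract_price("The predicted price is $5.99"))
print(extract_price("The predicted price is $1,299.99"))
```

It could replace the loop body as `all_predictions = [extract_price(o.outputs[0].text) for o in outputs]`.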


In [23]:
# Create submission DataFrame (one prediction per test row, in test_df order)
print("\n📊 Creating submission DataFrame...")
assert len(all_predictions) == len(test_df), "Number of predictions must match number of test rows"
submission = pd.DataFrame({
    'sample_id': test_df['sample_id'],
    'price': all_predictions
})

# Save submission
submission_file = '/mnt/neel-amazon/submission_granite_vllm.csv'
submission.to_csv(submission_file, index=False)

print(f"\n✅ Submission saved to: {submission_file}")
print(f"   Shape: {submission.shape}")
print("\n📈 Price Statistics:")
print(submission['price'].describe())

print("\n" + "="*60)
print("🎉 FAST INFERENCE COMPLETE!")
print("="*60)
print(f"Generated {len(submission):,} predictions in {total_time/60:.1f} minutes")
print(f"Average: {total_time/len(submission):.3f} seconds per sample")
print("\n🚀 Ready for submission!")


📊 Creating submission DataFrame...

✅ Submission saved to: /mnt/neel-amazon/submission_granite_vllm.csv
   Shape: (75000, 2)

📈 Price Statistics:
count    75000.000000
mean        15.125915
std         18.572606
min          0.240000
25%          4.990000
50%          8.990000
75%         19.990000
max        399.990000
Name: price, dtype: float64

🎉 FAST INFERENCE COMPLETE!
Generated 75,000 predictions in 7.4 minutes
Average: 0.006 seconds per sample

🚀 Ready for submission!
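Before uploading, a quick structural check of the written CSV can catch row-count mismatches, duplicate IDs or bad values. The stdlib-only sketch below assumes the expected format is exactly the `sample_id,price` header written above with strictly positive prices; these requirements are inferred from this notebook, not from the competition's stated rules, and `validate_submission` is a hypothetical helper.

```python
import csv
import io
import math


def validate_submission(csv_text: str, expected_rows: int) -> list:
    """Return a list of problems found in a sample_id,price CSV (empty if none)."""
    problems = []
    reader = csv.DictReader(io.StringIO(csv_text))
    if reader.fieldnames != ["sample_id", "price"]:
        problems.append(f"unexpected header: {reader.fieldnames}")
        return problems

    seen = set()
    n_rows = 0
    for row in reader:
        n_rows += 1
        sid = row["sample_id"]
        if sid in seen:
            problems.append(f"duplicate sample_id: {sid}")
        seen.add(sid)
        try:
            price = float(row["price"])
        except ValueError:
            problems.append(f"non-numeric price for {sid}: {row['price']!r}")
            continue
        if not math.isfinite(price) or price <= 0:
            problems.append(f"non-positive or non-finite price for {sid}: {price}")

    if n_rows != expected_rows:
        problems.append(f"expected {expected_rows} rows, found {n_rows}")
    return problems


print(validate_submission("sample_id,price\n1,5.99\n2,3.98\n", expected_rows=2))
```

For the file saved above, `validate_submission(open(submission_file).read(), len(test_df))` is expected to return an empty list, since every price was clipped to at least 0.01.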


In [24]:
# Show 5 random samples
import random

print("Sample Predictions:\n" + "="*80)

for i in random.sample(range(len(test_df)), min(5, len(test_df))):
    print(f"\nSample ID: {test_df.iloc[i]['sample_id']}")
    print(f"Catalog (first 150 chars): {test_df.iloc[i]['catalog_content'][:150]}...")
    print(f"Cleaned text (first 150 chars): {test_df.iloc[i]['catalog_content_cleaned'][:150]}...")
    print(f"Model output: {outputs[i].outputs[0].text}")
    print(f"Extracted price: ${all_predictions[i]:.2f}")
    print("-"*80)

Sample Predictions:

Sample ID: 144705
Catalog (first 150 chars): Item Name: evian Natural Spring Water 500 ml, 16.9 Ounce, 6 Count, Bottled Naturally Filtered Spring Water in Individual-Sized Bottles
Bullet Point 1:...
Cleaned text (first 150 chars): Item: evian Natural Spring Water 500 ml, 16.9 Ounce, 6 Count, Bottled Naturally Filtered Spring Water in Individual-Sized Bottles. Quantity: 101.4 oun...
Model output: The predicted price is $5.99
Extracted price: $5.99
--------------------------------------------------------------------------------

Sample ID: 45028
Catalog (first 150 chars): Item Name: Betty Crocker Super Moist Cake Mix, Butter Recipe Yellow - 15.25 oz box, 2 pack
Bullet Point 1: 2 Boxes 15.25 Oz. Butter Yellow - There's P...
Cleaned text (first 150 chars): Item: Betty Crocker Super Moist Cake Mix, Butter Recipe Yellow - 15.25 oz box, 2 pack. Quantity: 30.5 Ounce. Feature 1: 2 Boxes 15.25 Oz. Butter Yello...
Model output: The predicted price is $3.98
Extracted price: $3