<a href="https://colab.research.google.com/github/SashreekMallem/stitches-and-stories-model/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "offline"

# Now import the rest
import json
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

# Step 1: Load your dataset (text only)
with open("gemma_vision_training_data.json", "r") as f:
    data = json.load(f)
print(f"Loaded {len(data)} examples")

# Step 2: Prepare training data in text-only format
processed_data = []

for example in data:
    # Extract the prompt text and answer
    user_message = example["conversations"][0]
    assistant_message = example["conversations"][1]

    # Get the prompt text (ignore images)
    prompt_text = ""
    for content in user_message["content"]:
        if content["type"] == "text":
            prompt_text += content["text"]

    # Get the assistant's response
    response_text = ""
    for content in assistant_message["content"]:
        if content["type"] == "text":
            response_text += content["text"]

    # Create a simple example with just text
    processed_data.append({
        "prompt": f"Book Description Task: {prompt_text}",
        "response": response_text
    })

print(f"Processed {len(processed_data)} examples")

# Split into train and validation sets
train_size = int(0.8 * len(processed_data))
train_data = processed_data[:train_size]
val_data = processed_data[train_size:]

# Create HF datasets
train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

print(f"Ready! Train: {len(train_dataset)}, Val: {len(val_dataset)}")

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Loaded 3356 examples
Processed 3356 examples
Ready! Train: 2684, Val: 672


In [2]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "offline"

# Load the Gemma 3 4B model and tokenizer
model_id = "google/gemma-3-4b-it"  # Instruction-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
tokenizer.pad_token = tokenizer.eos_token

# Import BitsAndBytesConfig for quantization
from transformers import BitsAndBytesConfig

# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the model with quantization efficiency using BitsAndBytesConfig
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    token=True
)

# Define the LoRA configuration for Gemma 3 architecture
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# We'll apply LoRA through the SFTTrainer instead
# This way we ensure compatibility with the version of TRL being used
print(f"Will apply LoRA during training")

# Use TrainingArguments for configuration
from transformers import TrainingArguments

# Disable wandb completely
import os
os.environ["WANDB_DISABLED"] = "true"

# Set up basic training arguments - using older version compatible arguments
training_args = TrainingArguments(
    output_dir="book-damage-model",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=False,  # Use FP32 for compatibility
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    # Remove evaluation_strategy and related parameters as they're not available in older versions
    report_to=[],  # Empty list to disable all reporting
    run_name="book-damage-training"  # Set a different run_name to avoid the warning
)

# Prepare the dataset with simple text formatting for older versions of libraries
def format_for_training(example):
    # Older versions may not support chat templates, use a simple formatting approach
    return {
        "text": f"User: {example['prompt']}\nAssistant: {example['response']}"
    }

# Format datasets
train_dataset = train_dataset.map(
    format_for_training,
    remove_columns=train_dataset.column_names
)
val_dataset = val_dataset.map(
    format_for_training,
    remove_columns=val_dataset.column_names
)

# First apply LoRA directly (older TRL versions might not support passing peft_config)
model = get_peft_model(model, lora_config)
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Set up the trainer - simplified for compatibility
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    # If using older version, we can try without eval_dataset
    # eval_dataset=val_dataset,
)

print("Starting training...")
trainer.train()
print("Training complete!")

# Save the model
trainer.model.save_pretrained("./book-damage-model")
tokenizer.save_pretrained("./book-damage-model")
print("Model saved!")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Will apply LoRA during training


Map:   0%|          | 0/2684 [00:00<?, ? examples/s]

Map:   0%|          | 0/672 [00:00<?, ? examples/s]

Trainable parameters: 32788480


Adding EOS to train dataset:   0%|          | 0/2684 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/2684 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/2684 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Starting training...


Step,Training Loss
10,0.0
20,0.0
30,0.0
40,0.0
50,0.0
60,0.0
70,0.0
80,0.0
90,0.0
100,0.0


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
10,0.0
20,0.0
30,0.0
40,0.0
50,0.0
60,0.0
70,0.0
80,0.0
90,0.0
100,0.0


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Training complete!
Model saved!


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Cell 5: Package and Download the Model
!zip -r book-damage-model.zip ./book-damage-model
from google.colab import files
files.download('book-damage-model.zip')

  adding: book-damage-model/ (stored 0%)
  adding: book-damage-model/chat_template.jinja (deflated 70%)
  adding: book-damage-model/added_tokens.json (stored 0%)
  adding: book-damage-model/adapter_model.safetensors (deflated 96%)
  adding: book-damage-model/tokenizer.model (deflated 52%)
  adding: book-damage-model/adapter_config.json (deflated 58%)
  adding: book-damage-model/tokenizer_config.json (deflated 97%)
  adding: book-damage-model/checkpoint-2013/ (stored 0%)
  adding: book-damage-model/checkpoint-2013/training_args.bin (deflated 52%)
  adding: book-damage-model/checkpoint-2013/chat_template.jinja (deflated 70%)
  adding: book-damage-model/checkpoint-2013/rng_state.pth (deflated 25%)
  adding: book-damage-model/checkpoint-2013/added_tokens.json (stored 0%)
  adding: book-damage-model/checkpoint-2013/adapter_model.safetensors (deflated 96%)
  adding: book-damage-model/checkpoint-2013/tokenizer.model (deflated 52%)
  adding: book-damage-model/checkpoint-2013/adapter_config.jso

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [2]:
!pip install -q transformers datasets accelerate
!pip install -qU bitsandbytes  # Explicitly upgrade bitsandbytes
!pip install -q peft trl
!huggingface-cli login  # Add this to log in to Hugging Face

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/376.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m368.6/376.2 kB[0m [31m14.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m376.2/376.2 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/494.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 kB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/193.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the fol

In [1]:
!pip uninstall -y bitsandbytes
!pip install bitsandbytes

[0mCollecting bitsandbytes
  Using cached bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.2->bitsandbytes)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.2->bitsandbytes)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.2->bitsandbytes)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collectin

In [None]:
import os
os.kill(os.getpid(), 9)

In [4]:
from google.colab import files

files.download("./book-damage-model/adapter_model.safetensors")
files.download("./book-damage-model/tokenizer_config.json")
files.download("./book-damage-model/tokenizer.json")
files.download("./book-damage-model/tokenizer.model")
files.download("./book-damage-model/special_tokens_map.json")
files.download("./book-damage-model/generation_config.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

FileNotFoundError: Cannot find file: ./book-damage-model/generation_config.json

In [6]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "offline"

# Define a function to test the model - compatible with older versions
def test_model(question="Describe a book with a damaged spine and water damage"):
    # Add our context to make it clear we're talking about book damage assessment
    full_prompt = f"User: Book Description Task: {question}\nAssistant:"

    # Simple tokenization without chat templates for compatibility
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )

    # Decode only the generated part (length of input)
    input_length = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return response

# Example usage
result = test_model("Describe a book with water damage on the cover and torn pages")
print(result)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [7]:
!zip -r book-damage-model.zip ./book-damage-model
from google.colab import files
files.download('book-damage-model.zip')

updating: book-damage-model/ (stored 0%)
updating: book-damage-model/chat_template.jinja (deflated 70%)
updating: book-damage-model/added_tokens.json (stored 0%)
updating: book-damage-model/adapter_model.safetensors (deflated 96%)
updating: book-damage-model/tokenizer.model (deflated 52%)
updating: book-damage-model/adapter_config.json (deflated 58%)
updating: book-damage-model/tokenizer_config.json (deflated 97%)
updating: book-damage-model/checkpoint-2013/ (stored 0%)
updating: book-damage-model/checkpoint-2013/training_args.bin (deflated 52%)
updating: book-damage-model/checkpoint-2013/chat_template.jinja (deflated 70%)
updating: book-damage-model/checkpoint-2013/rng_state.pth (deflated 25%)
updating: book-damage-model/checkpoint-2013/added_tokens.json (stored 0%)
updating: book-damage-model/checkpoint-2013/adapter_model.safetensors (deflated 96%)
updating: book-damage-model/checkpoint-2013/tokenizer.model (deflated 52%)
updating: book-damage-model/checkpoint-2013/adapter_config.jso

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [8]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "offline"

# Define a function to test the model - compatible with older versions and avoids CUDA errors
def test_model(question="Describe a book with a damaged spine and water damage"):
    # Add our context to make it clear we're talking about book damage assessment
    full_prompt = f"User: Book Description Task: {question}\nAssistant:"

    # Simple tokenization without chat templates for compatibility
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate with greedy decoding to avoid CUDA errors with multinomial sampling
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=256,
            do_sample=False,  # Disable sampling to avoid CUDA errors
            num_beams=1,      # Use simple greedy search
            temperature=1.0   # No temperature when sampling is disabled
        )

    # Decode only the generated part (length of input)
    input_length = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return response

# Alternative test function in case the first one fails
def test_model_simple(question="Describe a book with a damaged spine and water damage"):
    """A simpler test function that uses pure logits to avoid CUDA errors"""
    try:
        # Add our context to make it clear we're talking about book damage assessment
        full_prompt = f"User: Book Description Task: {question}\nAssistant:"

        # Tokenize input
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(model.device)
        input_length = input_ids.shape[1]

        # Initialize with input
        generated = input_ids.clone()

        # Generate one token at a time (slower but more robust)
        max_new_tokens = 256
        for _ in range(max_new_tokens):
            # Get model output for the current sequence
            with torch.no_grad():
                outputs = model(generated)

            # Get the next token (greedy)
            next_token_id = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

            # Append the new token
            generated = torch.cat([generated, next_token_id], dim=-1)

            # Stop if we generated EOS
            if next_token_id.item() == tokenizer.eos_token_id:
                break

        # Decode only the generated part
        response = tokenizer.decode(generated[0][input_length:], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating with simple method: {str(e)}"

# Try both methods
try:
    print("Testing with standard generation:")
    result = test_model("Describe a book with water damage on the cover and torn pages")
    print(result)
except Exception as e:
    print(f"Error with standard test: {str(e)}")
    print("\nFalling back to simple generation:")
    result = test_model_simple("Describe a book with water damage on the cover and torn pages")
    print(result)

Testing with standard generation:
Error with standard test: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


Falling back to simple generation:
Error generating with simple method: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



In [9]:
!ls -la
!ls -la ./book-damage-model

total 82776
drwxr-xr-x 1 root root     4096 Jul 15 21:00 .
drwxr-xr-x 1 root root     4096 Jul 15 18:39 ..
drwxr-xr-x 7 root root     4096 Jul 15 20:47 book-damage-model
-rw-r--r-- 1 root root 82720745 Jul 15 21:00 book-damage-model.zip
drwxr-xr-x 4 root root     4096 Jul 14 13:37 .config
drwx------ 5 root root     4096 Jul 15 20:53 drive
-rw-r--r-- 1 root root  2011309 Jul 15 18:42 gemma_vision_training_data.json
drwxr-xr-x 1 root root     4096 Jul 14 13:37 sample_data
total 166544
drwxr-xr-x 7 root root      4096 Jul 15 20:47 .
drwxr-xr-x 1 root root      4096 Jul 15 21:00 ..
-rw-r--r-- 1 root root       903 Jul 15 20:47 adapter_config.json
-rw-r--r-- 1 root root 131252288 Jul 15 20:47 adapter_model.safetensors
-rw-r--r-- 1 root root        35 Jul 15 20:47 added_tokens.json
-rw-r--r-- 1 root root      1532 Jul 15 20:47 chat_template.jinja
drwxr-xr-x 2 root root      4096 Jul 15 19:53 checkpoint-1000
drwxr-xr-x 2 root root      4096 Jul 15 20:19 checkpoint-1500
drwxr-xr-x 2 root root 

In [10]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, LoraConfig

# Define the model ID and the path to your saved fine-tuned model
model_id = "google/gemma-3-4b-it"
fine_tuned_model_path = "./book-damage-model"

# Load the base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
tokenizer.pad_token = tokenizer.eos_token

# Import BitsAndBytesConfig for quantization
from transformers import BitsAndBytesConfig

# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the base model with quantization
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    token=True
)

# Load the LoRA adapters
model = PeftModel.from_pretrained(base_model, fine_tuned_model_path)

print("Fine-tuned model loaded successfully.")

# Define a function to test the model
def test_model(question="Describe a book with a damaged spine and water damage"):
    full_prompt = f"User: Book Description Task: {question}\nAssistant:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate with greedy decoding (since sampling caused CUDA errors before)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=256,
            do_sample=False,
            num_beams=1
        )

    input_length = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return response

# Example usage
print("\nTesting the model:")
result = test_model("Describe a book with water damage on the cover and torn pages")
print(result)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
