<a href="https://colab.research.google.com/github/2525tanuki/gemma3-4b-it-fine-tuning/blob/main/sft_latest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets
!pip install bitsandbytes
!pip install "torchao>=0.4.0"
!pip install flash-attn

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

In [2]:
from __future__ import annotations

import os
from typing import Callable
import requests
from io import BytesIO

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import load_dataset
import datasets
from PIL import Image
import wandb
from wandb import UsageError, CommError

Image.MAX_IMAGE_PIXELS = None # PIL.Image.DecompressionBombError: Image size (181764096 pixels) exceeds limit of 178956970 pixels, could be decompression bomb DOS attack.

def create_collate_fn(processor: AutoProcessor) -> Callable[[list[dict]], dict[str, torch.Tensor]]:
    """
    processor をクロージャでキャプチャする collate_fn を返す関数です。

    前提:
      - データセットの「url」カラムは、すでに cast_column("url", Image()) により PIL.Image に変換済みです。
      - 各サンプルは、"question" と "category" のフィールドを持ちます。

    Returns:
      バッチ辞書（テキスト・画像データをトークナイズ済みの torch.Tensor として返します）。
    """
    def collate_fn(examples: list[dict]) -> dict[str, torch.Tensor]:
        texts: list[str] = []
        images: list[list[Image.Image]] = []

        for ex in examples:
            # 「url」カラムはキャスト済みなので、直接 PIL.Image として扱えます（RGB変換も実施）
            img: Image.Image = ex["url"].convert("RGB")

            # "question" と "category" からそれぞれ質問文と回答文を抽出
            question_text: str = ex["question"]
            answer_text: str = ex["category"]

            # ユーザー発話（画像と質問）とアシスタント発話（回答）の辞書を作成
            user_message = {
                "role": "user",
                "content": [
                    {"type": "image", "image": img},
                    {"type": "text", "text": question_text}
                ]
            }
            assistant_message = {
                "role": "assistant",
                "content": [{"type": "text", "text": answer_text}]
            }
            messages = [user_message, assistant_message]

            # processor を用いてチャット形式にフォーマット
            formatted: str = processor.apply_chat_template(
                messages,
                add_generation_prompt=False,
                tokenize=False
            )
            texts.append(formatted.strip())
            images.append([img])

        # processor によるテキストと画像の同時トークナイズ・前処理
        batch: dict[str, torch.Tensor] = processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )
        input_ids: torch.Tensor = batch["input_ids"]
        labels: torch.Tensor = input_ids.clone()

        # パディングトークンを損失計算対象外に（-100に）変更
        pad_id: int = processor.tokenizer.pad_token_id
        labels[labels == pad_id] = -100
        batch["labels"] = labels

        return batch

    return collate_fn

def valid_image(example):
    """
    指定された URL から画像をダウンロードして、PIL で開けるかどうかチェックする。
    問題がなければ True、エラーが発生すれば False を返す。
    """
    url = example["url"]
    try:
        response = requests.get(url, timeout=0.5)
        response.raise_for_status()
        # 画像の読み込みを試みる
        with BytesIO(response.content) as img_buffer:
            Image.open(img_buffer).verify()  # verify() で内部チェックを実施
        return True
    except Exception as e:
        print(f"Invalid image URL skipped: {url} ; Error: {e}")
        return False

In [7]:
from google.colab import userdata
os.environ["HUGGING_FACE_TOKEN"] = userdata.get("HUGGING_FACE_TOKEN")
os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")

from huggingface_hub import login
login(token=os.environ["HUGGING_FACE_TOKEN"])

In [9]:
  if torch.cuda.is_available():
      major, _ = torch.cuda.get_device_capability()
      if major < 8:
          raise EnvironmentError("GPU は bfloat16 をサポートする必要があります")
      torch_dtype = torch.bfloat16
      device_map = "auto"
  else:
      torch_dtype = torch.float32
      device_map = "cpu"

  # QLoRA 用の 4bit 量子化設定（NF4）
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_use_double_quant=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_compute_dtype=torch_dtype,
      bnb_4bit_quant_storage=torch_dtype
  )

In [10]:
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
if not WANDB_API_KEY:
    print(
        'It appears that the WANDB_API_KEY environment variable is not set. Please ensure you have exported your API key correctly in your .envrc file (using export WANDB_API_KEY="...") or other environment configuration.'
    )
else:
    try:
        # 必要に応じてログインを実行
        wandb.login(key=WANDB_API_KEY)
        wandb.init(project="Gemma3_JICVQA_finetune", name="Gemma3-4B-JICVQA-QLoRA")
    except UsageError as ue:
        # API キーの不備や使い方に起因するエラーの場合
        print(
            "We encountered a usage error with Weights & Biases. Please double-check that your API key is correct and that you are using the library properly. Detailed error info:",
            ue,
        )
    except CommError as ce:
        # 通信エラーの場合
        print(
            "A communication error occurred when connecting to Weights & Biases. Please check your internet connection or firewall settings. Detailed error info:",
            ce,
        )

    except Exception as e:
        # 他の想定外のエラーが発生した場合
        print(
            "An unexpected error occurred. Please review the error details and try again. Detailed error info:",
            e,
        )

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnnnnnn-4649-aki[0m ([33mnnnnnn-4649-aki-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
  ####### load JIC-VQA #######

  dataset = load_dataset("line-corporation/JIC-VQA")
  # train 分割の中で、train / validation に分割
  split_dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
  train_dataset = split_dataset["train"]
  val_dataset = split_dataset["test"]

  # 各サンプルの "url" フィールドに対して、画像が取得可能かチェックしてフィルタリング
  train_dataset = train_dataset.filter(valid_image)
  val_dataset = val_dataset.filter(valid_image)

  # キャストすることで、"url" フィールドが自動的に PIL.Image オブジェクトになるようにする
  train_dataset = train_dataset.cast_column("url", datasets.Image())
  val_dataset = val_dataset.cast_column("url", datasets.Image())

  print("Train dataset:", train_dataset)
  print("Validation dataset:", val_dataset)

  # Gemma 用のプロセッサの初期化（SigLIP ビジョンエンコーダも含む）
  model_id = "google/gemma-3-4b-it"
  processor = AutoProcessor.from_pretrained(model_id)
  collate_fn = create_collate_fn(processor)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

jafacility20.csv:   0%|          | 0.00/291k [00:00<?, ?B/s]

jaflower30.csv:   0%|          | 0.00/664k [00:00<?, ?B/s]

jafood101.csv:   0%|          | 0.00/1.41M [00:00<?, ?B/s]

jalandmark10.csv:   0%|          | 0.00/324k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7654 [00:00<?, ? examples/s]

Filter:   0%|          | 0/6123 [00:00<?, ? examples/s]

Invalid image URL skipped: https://live.staticflickr.com/3825/19838612021_6402477b43_o.jpg ; Error: 404 Client Error: Not Found for url: https://live.staticflickr.com/3825/19838612021_6402477b43_o.jpg
Invalid image URL skipped: https://live.staticflickr.com/3428/3397268105_8725943f11_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/3536/3699434695_7a89d756a8_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/2329/2350357064_1f29971fea_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/4036/4549508374_43c743b190_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL s

Filter:   0%|          | 0/1531 [00:00<?, ? examples/s]

Invalid image URL skipped: https://live.staticflickr.com/676/20915075553_1e002e160c_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/3357/3328003217_c6d7325639_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/2939/13997906170_a45446b1f0_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/4255/34836501963_8879867ca2_o.jpg ; Error: HTTPSConnectionPool(host='live.staticflickr.com', port=443): Read timed out. (read timeout=0.5)
Invalid image URL skipped: https://live.staticflickr.com/4526/37605324924_b1c45f8408_o.jpg ; Error: 404 Client Error: Not Found for url: https://live.staticflickr.com/4526/37605324924_b1c45f8408_o.jpg
Train dataset: Da

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [12]:
  ######## model loading ########

  # 4bit 量子化設定付きで Gemma-3-4B-IT モデルをロード
  model = AutoModelForImageTextToText.from_pretrained(
      model_id,
      quantization_config=bnb_config,
      device_map=device_map,
      torch_dtype=torch_dtype,
      attn_implementation="flash_attention_2"
  )
  # メモリ節約のため勾配チェックポイントを有効にする
  model.gradient_checkpointing_enable()
  # LoRA 用の設定（例：r=16、lora_alpha=16、ドロップアウト率0.05）
  lora_config = LoraConfig(
      r=16,
      lora_alpha=16,
      target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
      lora_dropout=0.05,
      bias="none",
      task_type="CAUSAL_LM",
      modules_to_save=["embed_tokens", "lm_head"]
  )
  # PEFT の LoRA アダプタを注入
  model = get_peft_model(model, lora_config)
  model.print_trainable_parameters()  # 更新対象パラメータ数の確認

  ####### Training #######
  # 学習引数の設定
  training_args = TrainingArguments(
      output_dir="gemma3-jicvqa-checkpoint",
      overwrite_output_dir=True,
      num_train_epochs=1,
      per_device_train_batch_size=1,
      gradient_accumulation_steps=4,
      learning_rate=2e-4,
      warmup_ratio=0.03,
      lr_scheduler_type="constant",
      optim="adamw_torch_4bit",   # bitsandbytes 用の4bit最適化オプティマイザ
      bf16=True,
      logging_steps=20,
      save_strategy="epoch",
      save_total_limit=1,
      report_to=["wandb"],
      run_name="Gemma3_4B_JICVQA_QLoRA",
      gradient_checkpointing=True,
      dataloader_num_workers=4,
      remove_unused_columns=False
  )

  # Trainer の初期化（検証用データセットを指定）
  trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=train_dataset,
      eval_dataset=val_dataset,
      data_collator=collate_fn,
      tokenizer=processor.tokenizer
  )

  # ファインチューニングの実施
  trainer.train()

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


trainable params: 1,375,293,440 || all params: 5,675,372,912 || trainable%: 24.2327


The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `flash_attention_2`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
20,23.0518
40,0.4181
60,0.2278
80,0.1932
100,0.1916
120,0.1625
140,0.1511
160,0.1427
180,0.1566
200,0.1442




TrainOutput(global_step=1524, training_loss=0.42407513159544763, metrics={'train_runtime': 5820.162, 'train_samples_per_second': 1.048, 'train_steps_per_second': 0.262, 'total_flos': 4.949474244485184e+16, 'train_loss': 0.42407513159544763, 'epoch': 0.9996720236142997})

In [14]:
  ####### store #######
  print("Saving the fine-tuned adapter to 'gemma3-jicvqa-adapter' directory...")
  trainer.model.save_pretrained("gemma3-jicvqa-adapter", safe_serialization=True)
  print("Adapter saved successfully in 'gemma3-jicvqa-adapter'.")

  ####### GPU メモリ解放  #######
  # ファインチューニング後、GPU上のTrainerやモデルを削除してメモリを解放する
  print("Releasing GPU resources...")
  del trainer
  del model
  torch.cuda.empty_cache()
  print("GPU memory has been successfully released.")

  ####### restore #######

  # ベースモデル（float16 で CPU 上にロード）を再ロードし、LoRA アダプタを統合する
  print("Reloading the base model on CPU with float16 precision...")
  base_model = AutoModelForImageTextToText.from_pretrained(
      model_id, device_map={"": "cpu"}, torch_dtype=torch.float16
  )
  print(
      "Loading the saved LoRA adapter into the base model from 'gemma3-jicvqa-adapter'..."
  )
  base_model = PeftModel.from_pretrained(base_model, "gemma3-jicvqa-adapter")
  # LoRA 重みをマージして、従来の Hugging Face モデル形式に変換
  print("Successfully loaded the LoRA adapter.")
  print(
      "Merging LoRA weights into the base model to create a standard Hugging Face model format..."
  )
  merged_model = base_model.merge_and_unload()
  print("Merging completed successfully.")

  # Hugging Face Hub にアップロード可能な形式で保存（プロセッサも一緒に保存）
  print("Saving the merged model and processor for evaluation...")
  save_dir = "gemma3-jicvqa-finetuned"
  merged_model.save_pretrained(
      save_dir, safe_serialization=True, max_shard_size="2GB"
  )
  processor.save_pretrained(save_dir)
  print(f"The merged model and processor have been saved to '{save_dir}'.")

Saving the fine-tuned adapter to 'gemma3-jicvqa-adapter' directory...


NameError: name 'trainer' is not defined

In [17]:
# Hugging Face Hub にアップロード可能な形式で保存（プロセッサも一緒に保存）
print("Saving the merged model and processor for evaluation...")
save_dir = "./drive/MyDrive/gemma3-jicvqa-finetuned"
merged_model.save_pretrained(
    save_dir, safe_serialization=True, max_shard_size="2GB"
)
processor.save_pretrained(save_dir)
print(f"The merged model and processor have been saved to '{save_dir}'.")

Saving the merged model and processor for evaluation...
The merged model and processor have been saved to './drive/MyDrive/gemma3-jicvqa-finetuned'.
