In [1]:
import os
import json
import re

folder = r"D:\aircraft_images"
subfolder = ["02", "05", "06"]

In [2]:
data = []

# Iterate through each subfolder
for subf in subfolder:
    # Get all the files in each sub folder
    # subf/data/...
    files = os.listdir(os.path.join(folder, subf, "data"))
    # Sort the files
    # Each image and caption share the same ID
    files.sort()

    # Iterate through all the files in each sub folder
    for i, file in enumerate(files):
        if i % 2 == 0:
            # image file, process the image and txt file here
            # Get the file id
            file_id = file.split(".")[0]
            # Get the text file, increment i by 1 to get the next file in sequence which is the text file with the same id
            next_file = files[i+1]
            next_file_id = next_file.split(".")[0]

            # Check that file_id and next_file_id are the same
            if not file_id == next_file_id:
                raise Exception(f"{file} does not have the same id as {next_file}")

            # Get the caption
            with open(os.path.join(folder, subf, "data", next_file), "r", encoding='utf-8') as f:
                caption = f.read()

            # Split the caption by comma, caption has the format
            # <Year> <Airline> <Aircraft model>, <registration>, <description>
            split_caption = caption.split(",")

            # Some captions are incomplete, do not use those
            if len(split_caption) != 3:
                continue

            # Initialise the entry dict that does into the data list
            entry = {}
            entry.update({
                "id": file_id, # Add in the id
                "image": os.path.join(folder, subf, "data", file_id+".jpg"), # Add in the image path
                "conversations": [
                    {
                        "from": "human",
                        "value": "<image>\nWhat is model of the plane in this image?"
                    },
                    {
                        "from": "gpt",
                        "value": split_caption[0].strip()
                    },
                    {
                        "from": "human",
                        "value": "What is the registration number of the plane?"
                    },
                    {
                        "from": "gpt",
                        "value": split_caption[1].strip()
                    },
                    {
                        "from": "human",
                        "value": "Describe what the plane is doing."
                    },
                    {
                        "from": "gpt",
                        "value": split_caption[2].strip()
                    },
                ]
            })
            data.append(entry)

        else:
            # txt file, skip to get to the next image file
            continue
    break       

In [3]:
data

[{'id': '1186GMrpjP0eV7NYeK0',
  'image': 'D:\\aircraft_images\\02\\data\\1186GMrpjP0eV7NYeK0.jpg',
  'conversations': [{'from': 'human',
    'value': '<image>\nWhat is model of the plane in this image?'},
   {'from': 'gpt', 'value': '2009 LATAM Airlines Brasil Airbus A319-132'},
   {'from': 'human', 'value': 'What is the registration number of the plane?'},
   {'from': 'gpt', 'value': 'registration PT-TMA'},
   {'from': 'human', 'value': 'Describe what the plane is doing.'},
   {'from': 'gpt',
    'value': 'is parked at the gate with catering truck and ground crew'}]},
 {'id': '1186KWAYw9AJ45ApNya',
  'image': 'D:\\aircraft_images\\02\\data\\1186KWAYw9AJ45ApNya.jpg',
  'conversations': [{'from': 'human',
    'value': '<image>\nWhat is model of the plane in this image?'},
   {'from': 'gpt', 'value': '1993 United Airlines Boeing 767-322ER(WL)'},
   {'from': 'human', 'value': 'What is the registration number of the plane?'},
   {'from': 'gpt', 'value': 'registration N658UA'},
   {'from':

In [4]:
import torch
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
from trl import SFTTrainer
from peft import LoraConfig

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id,
                                                      torch_dtype=torch.float16)

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 27.99it/s]


In [6]:
LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for conversation in conversations %}{% if conversation['from'] == 'human' %}USER: {% else %}ASSISTANT: {% endif %}{% if '<image>' in conversation['value'] %}{{ conversation['value'] }}{% else %}{{ conversation['value'] }}{% endif %}{% if conversation['from'] == 'human' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer = tokenizer

Fetching 2 files: 100%|██████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [8]:
from PIL import Image

class LLavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images_batch = []

        for example in examples:
            messages = []
            example_images = []

            for conv in example["conversations"]:
                role = "user" if conv["from"] == "human" else "assistant"
                content_items = []

                for line in conv["value"].splitlines():
                    if line.strip() == "<image>":
                        content_items.append({"type": "image"})
                        # append the actual image
                        img = Image.open(example["image"]).convert("RGB")
                        example_images.append(img)
                    elif line.strip():
                        content_items.append({"type": "text", "text": line.strip()})

                messages.append({"role": role, "content": content_items})

            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images_batch.append(example_images)

        batch = self.processor(
            text=texts,
            images=images_batch,
            return_tensors="pt",
            padding=True
        )

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch


data_collator = LLavaDataCollator(processor)

In [9]:
training_args = TrainingArguments(
    output_dir="llava-hf/llava-1.5-7b-hf-fine_tuned-aircraft",
    report_to="tensorboard",
    learning_rate=1.4e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    logging_steps=5,
    num_train_epochs=1,
    push_to_hub=False,
    gradient_checkpointing=True,
    remove_unused_columns=False,
    fp16=True,
    bf16=False
)

In [10]:
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules="all-linear"
)

In [None]:
from huggingface_hub import login
login("")

In [12]:
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=data[:1200],
    eval_dataset=data[1200:],
    peft_config=lora_config,
    data_collator=data_collator,
)

Fetching 2 files: 100%|██████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]
  return t.to(


In [13]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 2, 'bos_token_id': 1}.


Step,Training Loss


KeyboardInterrupt: 