In [None]:
from huggingface_hub import login

import os

# Authenticate with the Hugging Face Hub (presumably needed to download the meta-llama
# checkpoint used below). The token is read from the HF_TOKEN environment variable
# instead of being hardcoded in the notebook; if it is unset, `login()` prompts for it.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
def format_data(sample):
    """Turn one dataset row into the chat-style record consumed by the trainer.

    Args:
        sample: Mapping with keys ``"image"``, ``"query"`` and ``"label"``.

    Returns:
        ``{"messages": [user_turn, assistant_turn], "images": [image]}``. The user turn
        holds the image followed by the query text; the assistant turn holds the label.
        The same image object appears in both the user turn and ``"images"``.
    """
    image = sample["image"]
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": sample["query"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["label"]}],
    }
    return dict(messages=[user_turn, assistant_turn], images=[image])

In [None]:
import os
from PIL import Image
from datasets import Dataset, Features, Value, Image

# Dataset variant to train on (other variants listed by the author: 'distorted', 'thin_hands', 'standard').
prefix = 'real_clocks'
# Training folder inside the variant (alternatives listed: 'Train_1000', 'Train_5000').
folder_name = 'Train'

png_folder_path = '/'.join(('Data', prefix, folder_name))
def get_png_files(path):
    """Return the names of the ``.png`` files directly inside ``path``, sorted by name.

    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent order. Sorting
    makes the dataset row order, and therefore the seeded train/eval split built from
    it, reproducible across machines.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        Sorted list of file names (not full paths) ending in ``'.png'``.
    """
    return sorted(f for f in os.listdir(path) if f.endswith('.png'))

# Import the `datasets` Image feature under its own alias: this notebook also does
# `from PIL import Image` (both before and after this cell), so relying on the bare
# name `Image` breaks as soon as cells are re-run out of order.
from datasets import Image as DatasetImage

features = Features({
    "image": DatasetImage(decode=True),  # so that accessing the column returns a PIL.Image
    "query": Value("string"),
    "label": Value("string"),
})

prompt = '''What time is shown on the clock in the given image?'''
png_files = get_png_files(png_folder_path)
if not png_files:
    # Fail early with a clear message rather than inside train_test_split.
    raise FileNotFoundError(f"No .png files found in {png_folder_path!r}")

images_list = []
queries_list = []
labels_list = []
for file_name in png_files:
    # The file name (without extension) encodes the ground-truth time, with '_'
    # standing in for ':' (e.g. a name like "10_30.png" -> "10:30").
    correct_answer = os.path.splitext(file_name)[0].replace('_', ':')
    # Store the path; the Image feature decodes it on access.
    images_list.append(os.path.join(png_folder_path, file_name))
    queries_list.append(prompt)
    labels_list.append(f'The time shown on the clock is **{correct_answer}**.')

dataset = Dataset.from_dict({
    "image": images_list,
    "query": queries_list,
    "label": labels_list,
}, features=features)

# Hold out 1/16 of the examples for evaluation; seeded shuffle for reproducibility.
total_len = len(dataset)
train_size = int(total_len * 15 / 16)
split_dataset = dataset.train_test_split(
    train_size=train_size, shuffle=True, seed=42
)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

In [None]:
# Convert every dataset row into a chat-style {"messages", "images"} dict for collate_fn.
train_dataset = list(map(format_data, train_dataset))
eval_dataset = list(map(format_data, eval_dataset))

In [None]:
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Base vision-language model in bfloat16; device_map="auto" lets transformers place
# the weights on the available devices.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Matching processor (chat template, tokenizer and image preprocessing).
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
    """Greedily generate the model's answer to the user turn of a chat sample.

    Args:
        model: Vision-language model exposing ``generate`` and ``device``.
        processor: Matching processor (chat template, image/text preprocessing, decoding).
        sample: List of chat messages, e.g. ``format_data(row)["messages"]``. Only
            ``sample[0]`` (the user turn) is used. Its ``content`` must start with an
            image item (``{"type": "image", "image": ...}``) followed by a text item.
        max_new_tokens: Maximum number of tokens to generate (previously ignored in
            favour of a hardcoded 500).
        device: Unused; inputs are moved to ``model.device``. Kept so existing callers
            that pass it keep working.

    Returns:
        The generated text only (prompt excluded), with special tokens removed and
        surrounding whitespace stripped.
    """
    # Keep only the user turn, so the reference assistant answer is not in the prompt.
    user_turn = sample[:1]
    image = user_turn[0]["content"][0]["image"]

    input_text = processor.apply_chat_template(user_turn, add_generation_prompt=True)
    # The rendered chat template is tokenized as-is, without adding extra special tokens.
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding. top_p/temperature are set to None, presumably to override
    # sampling defaults from the model's generation config.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        top_p=None,
        temperature=None,
    )

    # Decode only the newly generated tokens instead of splitting the whole transcript
    # on a chat-header string and stripping end-of-turn markers by hand.
    prompt_length = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

In [None]:
# Inspect the first formatted training example ({"messages": ..., "images": ...}).
train_dataset[0]

In [None]:
import gc
import time


def clear_memory(settle_seconds=2.0):
    """Drop the large notebook globals and release cached GPU memory.

    Removes ``inputs``, ``model``, ``processor``, ``trainer``, ``peft_model`` and
    ``bnb_config`` from this module's global namespace if present. Only the names are
    removed; the objects are freed once nothing else references them. Then runs
    garbage collection and, when CUDA is available, empties the CUDA cache.

    Args:
        settle_seconds: Pause between cleanup steps (default keeps the original
            2-second pauses). Pass 0 to skip the pauses.

    Side effects:
        Prints allocated/reserved CUDA memory, or a notice when CUDA is unavailable
        (previously ``torch.cuda.synchronize()`` raised in that case).
    """
    namespace = globals()
    for name in ("inputs", "model", "processor", "trainer", "peft_model", "bnb_config"):
        namespace.pop(name, None)
    time.sleep(settle_seconds)

    cuda_available = torch.cuda.is_available()

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(settle_seconds)
    if cuda_available:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    time.sleep(settle_seconds)
    gc.collect()
    time.sleep(settle_seconds)

    if cuda_available:
        print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA not available; no GPU memory to report.")


# Free the model/processor loaded above; the fine-tuning section reloads the model.
clear_memory()

In [None]:
#################################
# Fine-Tune the Model using TRL
#################################

In [None]:
from transformers import BitsAndBytesConfig

# BitsAndBytesConfig int-4 config: NF4 weights with double quantization,
# computation in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Pass the quantization config so the base model is actually loaded in 4-bit.
# Without it, bnb_config is unused and the 11B model loads in full bfloat16.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapter settings: rank 8, alpha 16, applied to modules named q_proj and v_proj.
# The adapter is not attached here; SFTTrainer receives this config via `peft_config`.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
from trl import SFTConfig

# Configure training arguments. Checkpoints go to Finetuned_models/<prefix>/LLaMa3.2_<folder_name>.
training_args = SFTConfig(
    output_dir=f"Finetuned_models/{prefix}/LLaMa3.2_{folder_name}",  # Directory to save the model
    num_train_epochs=3,  # Number of training epochs
    per_device_train_batch_size=1,  # Batch size for training
    per_device_eval_batch_size=1,  # Batch size for evaluation
    gradient_accumulation_steps=8,  # Effective batch size per device = 1 x 8
    gradient_checkpointing=True,  # Trade recomputation for activation memory
    # Optimizer and scheduler settings
    optim="adamw_torch_fused",  # Optimizer type
    learning_rate=2e-4,  # Learning rate for training
    # The plain "constant" scheduler ignores warmup entirely;
    # "constant_with_warmup" actually applies warmup_ratio below.
    lr_scheduler_type="constant_with_warmup",
    warmup_ratio=0.03,  # Fraction of total steps used for warmup
    # Logging and evaluation
    logging_steps=10,  # Steps interval for logging
    eval_steps=80,  # Steps interval for evaluation
    eval_strategy="steps",  # Strategy for evaluation
    save_strategy="steps",  # Strategy for saving the model
    save_steps=160,  # Must stay a multiple of eval_steps since load_best_model_at_end=True
    metric_for_best_model="eval_loss",  # Metric to evaluate the best model
    greater_is_better=False,  # Lower eval loss is better
    load_best_model_at_end=True,  # Load the best model after training
    # Mixed precision and gradient settings
    bf16=True,  # Use bfloat16 precision
    tf32=True,  # Use TensorFloat-32 precision (needs a GPU that supports it)
    max_grad_norm=0.3,  # Maximum norm for gradient clipping
    # Hub and reporting
    push_to_hub=False,  # Whether to push model to Hugging Face Hub
    report_to="wandb",  # Reporting tool for tracking metrics
    # Gradient checkpointing settings
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Dataset configuration: collate_fn does the tokenization, so TRL's own
    # dataset preparation is skipped (max_seq_length is intentionally left unset).
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    # Keep the "messages"/"images" columns: they are not model forward() arguments,
    # so the Trainer would otherwise drop them before collate_fn sees them.
    remove_unused_columns=False,
)

In [None]:
import wandb

# Name the W&B project/run after the dataset variant and the model actually being
# fine-tuned (Llama 3.2 Vision, matching the output_dir naming). The previous
# "Gemma3" label was a copy-paste leftover that mislabelled runs.
wandb.init(
    project=f"{prefix}_LLaMa3.2-{folder_name}",
    name=f"{prefix}_LLaMa3.2-{folder_name}",
    # Plain dict of hyperparameters rather than the SFTConfig object itself.
    config=training_args.to_dict(),
)

In [None]:
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration


def collate_fn(examples):
    """Build a training batch from examples produced by ``format_data``.

    Uses the notebook globals ``processor`` and ``model``.

    Args:
        examples: Sequence of dicts with a ``"messages"`` chat list and an ``"images"`` list.

    Returns:
        The processor's batch plus a ``"labels"`` tensor: a copy of ``input_ids`` in which
        padding tokens and the image placeholder token are set to -100 (ignored by the loss).
        NOTE(review): prompt tokens are not masked, so the loss also covers the user turn.
    """
    # Render each conversation with the chat template and gather its images.
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"] for example in examples]
    if isinstance(model, LlavaForConditionalGeneration):
        # LLaVA-1.5 does not support multiple images per sample.
        images = [image_list[0] for image_list in images]

    # Tokenize the rendered chat-template text as-is (add_special_tokens=False), the same
    # way generate_text_from_sample tokenizes prompts. Otherwise the tokenizer may add its
    # special tokens (presumably a second BOS) on top of those the template already emits,
    # and only during training.
    batch = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )

    # Labels are the input ids, with padding and image placeholder tokens excluded from the loss.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch

In [None]:
# train_dataset[0]

In [None]:
# collate_fn([train_dataset[0]])

In [None]:
from trl import SFTTrainer
# from unsloth.trainer import UnslothVisionDataCollator


trainer = SFTTrainer(
    model=model,
    args=training_args,
    # collate_fn builds input ids, image inputs and labels itself; TRL's dataset
    # preparation is skipped via dataset_kwargs in training_args.
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # The LoRA adapter is attached by the trainer from this config.
    peft_config=peft_config,
    processing_class=processor.tokenizer,
)

In [None]:
# Run fine-tuning; logging, evaluation and checkpointing follow the step intervals in training_args.
trainer.train()

In [None]:
# Save the final model to training_args.output_dir (with peft_config this is presumably only the LoRA adapter).
trainer.save_model(training_args.output_dir)

In [None]:
# Release the model, processor and trainer globals and empty the CUDA cache.
clear_memory()