In [None]:
!pip install bitsandbytes peft trl

In [None]:
import os
os.environ.update({'WANDB_DISABLED': 'true'})  # keep the Trainer from reporting to Weights & Biases


from datasets import load_dataset
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor, Qwen3VLProcessor, BitsAndBytesConfig

from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer

import warnings
warnings.filterwarnings('ignore')

In [None]:
# System prompt placed at the start of every conversation (used for both training and generation).
system_message = """You are a very advanced agent that is specialized in analyzing and interpreting images and text. Your task is to process images to understand if the content provided to you is safe or not for individuals. Please keep in mind the safety of others while categorizing images as safe or unsafe"""

def format_data(sample):
    """Convert one UnsafeBench row into a chat-style conversation.

    Args:
        sample: mapping with at least the keys ``'image'``, ``'text'`` and
            ``'safety_label'``.

    Returns:
        A list of three messages: the system prompt, the user turn (the image
        followed by its accompanying text), and the assistant turn whose only
        content is the safety label, i.e. the answer the model should produce.
    """
    return [
        {
            'role': 'system',
            'content': [{'type': 'text', 'text': system_message}],
        },
        {
            'role': 'user',
            'content': [
                {
                    'type': 'image',
                    'image': sample['image'],
                },
                {
                    'type': 'text',
                    'text': sample['text'],
                }
            ]
        },
        {
            'role': 'assistant',
            'content': [{'type': 'text', 'text': sample['safety_label']}]
        }
    ]

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
# Authenticate with the Hugging Face Hub using the token stored as the Kaggle secret "hf_key".
user_secrets = UserSecretsClient()
login(token=user_secrets.get_secret("hf_key"))

# Only a small slice of each split is loaded: the first 1% of train and 2% of test.
dataset_id = "yiting/UnsafeBench"
requested_splits = ['train[:1%]', 'test[:2%]']
train_dataset, test_dataset = load_dataset(dataset_id, split=requested_splits)

In [None]:
print(len(train_dataset), len(test_dataset), sep='\n')

In [None]:
print(train_dataset, test_dataset, sep='\n')

In [None]:
from IPython.display import display

first_train_row = train_dataset[0]
print(first_train_row)
display(first_train_row['image'])

In [None]:
first_test_row = test_dataset[0]
print(first_test_row)
display(first_test_row['image'])

In [None]:
# Each raw row becomes one [system, user, assistant] conversation.
train_data = list(map(format_data, train_dataset))
test_data = list(map(format_data, test_dataset))

In [None]:
# Formatting is one-to-one: no rows may be dropped or duplicated.
assert (len(train_data), len(test_data)) == (len(train_dataset), len(test_dataset))

In [None]:
first_conversation = train_data[0]
system_turn, user_turn = first_conversation[0], first_conversation[1]
divider = '-' * 80

print(first_conversation)
print(divider)
print(system_turn['content'][0]['text'])   # system prompt
print(divider)
print(user_turn['content'][1]['text'])     # text that accompanies the image
print(divider)
display(user_turn['content'][0]['image'])

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

MODEL_ID = "Qwen/Qwen3-VL-2B-Instruct"
EPOCHS = 1
BATCH_SIZE = 1
GRADIENT_CHECKPOINTING = True
USE_REENTRANT = False  # intended for SFTConfig(gradient_checkpointing_kwargs={'use_reentrant': ...})
OPTIM = 'paged_adamw_32bit'
LEARNING_RATE = 2e-5
LOGGING_STEPS = 100  # counted in optimizer updates, i.e. after gradient accumulation
EVAL_STEPS = 0
SAVE_STEPS = EVAL_STEPS
EVAL_STRATEGY = 'no'
# NOTE(review): under the 'steps' strategy a save_steps of 0 appears to mean no intermediate
# checkpoints are written; the model is only persisted by the explicit save_model call at the end.
SAVE_STRATEGY = 'steps'
METRIC_FOR_BEST_MODEL = 'eval_loss'
LOAD_BEST_MODEL_AT_END = False
MAX_GRAD_NORM = 1
WARMUP_STEPS = 0
GRADIENT_ACCUMULATION_STEPS = 64
DATASET_KWARGS = {'skip_prepare_dataset': True}  # tokenization happens in the custom collate_fn
REMOVE_UNUSED_COLUMNS = False
MAX_SEQ_LEN = 128

# Micro-batches per epoch; a trailing partial batch is still processed, hence the ceiling division.
BATCHES_PER_EPOCH = (len(train_data) + BATCH_SIZE - 1) // BATCH_SIZE
# Total forward/backward passes: every epoch repeats all batches, so this scales *with* EPOCHS.
NUM_STEPS = BATCHES_PER_EPOCH * EPOCHS
# Optimizer updates: gradients are accumulated over GRADIENT_ACCUMULATION_STEPS micro-batches.
# Approximate — how a trailing partial accumulation window is counted depends on the transformers version.
NUM_OPTIMIZER_STEPS = (
    (BATCHES_PER_EPOCH + GRADIENT_ACCUMULATION_STEPS - 1) // GRADIENT_ACCUMULATION_STEPS
) * EPOCHS
print(f'NUM_STEPS: {NUM_STEPS}')
print(f'NUM_OPTIMIZER_STEPS (approx.): {NUM_OPTIMIZER_STEPS}')
if NUM_OPTIMIZER_STEPS < LOGGING_STEPS:
    print(f'note: LOGGING_STEPS={LOGGING_STEPS} exceeds ~{NUM_OPTIMIZER_STEPS} optimizer steps; '
          'no intermediate training loss will be logged')


In [None]:
if device == 'cuda':
    # 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map='auto',
        quantization_config=bnb_config
    )
else:
    # No bitsandbytes quantization off-GPU: load the model with default settings.
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        MODEL_ID
    )

processor = Qwen3VLProcessor.from_pretrained(MODEL_ID)
# Right padding for training batches. The tokenizer attribute is `padding_side`;
# assigning to any other name (e.g. `padding_size`) just creates an unused attribute.
processor.tokenizer.padding_side = 'right'


In [None]:
# Display the prompt turns (system + user) of the first test conversation.
test_data[0][0:2]

In [None]:
def text_generator(sample_data):
    """Generate the model's answer for one formatted conversation.

    Uses the notebook-level ``processor``, ``model``, ``device`` and ``MAX_SEQ_LEN``.

    Args:
        sample_data: a conversation as built by ``format_data``:
            ``[system, user, assistant]``, where the user turn's first content
            item holds the image.

    Returns:
        ``(generated_text, reference_text)``: the decoded newly generated tokens
        and the reference label taken from the assistant turn.
    """
    # Only the system and user turns form the prompt; the assistant turn is the reference.
    prompt_turns = sample_data[:2]
    text = processor.apply_chat_template(
        prompt_turns, tokenize=False, add_generation_prompt=True
    )

    image_inputs = sample_data[1]['content'][0]['image']
    inputs = processor(
        text=[text],
        images=image_inputs,
        return_tensors='pt'
    )
    inputs = inputs.to(device)

    generated_ids = model.generate(**inputs, max_new_tokens=MAX_SEQ_LEN)

    # generate() returns the prompt followed by the completion; drop the prompt
    # tokens so that only the model's answer is decoded.
    prompt_length = inputs['input_ids'].shape[1]
    answer_ids = generated_ids[:, prompt_length:]

    output_text = processor.batch_decode(
        answer_ids, skip_special_tokens=True
    )
    del inputs
    actual_answer = sample_data[2]['content'][0]['text']
    return output_text[0], actual_answer

#gen_answer, answer = text_generator(test_data[0])

In [None]:
#print(gen_answer)
#print(answer)

In [None]:
# LoRA adapters on the modules named q_proj / v_proj (presumably the attention query/value projections).
peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    target_modules=['q_proj', 'v_proj'],
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias='none',
)

print(f'before adding adapter parameters: {model.num_parameters()}')
# NOTE(review): get_peft_model modifies `model` in place. The trainer cell later passes the same
# `model` together with `peft_config`, so confirm the adapters are not applied a second time.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()

In [None]:
collate_sample = [train_data[0], train_data[1]]

# Qwen-VL placeholder tokens marking the image region of the prompt. They carry no
# text for the model to predict, so they are excluded from the loss.
_VISION_PLACEHOLDER_TOKENS = ('<|vision_start|>', '<|image_pad|>', '<|vision_end|>')


def _vision_token_ids(tokenizer):
    """Return the ids of those vision placeholder tokens that exist in `tokenizer`'s vocabulary."""
    token_ids = set()
    for token in _VISION_PLACEHOLDER_TOKENS:
        token_id = tokenizer.convert_tokens_to_ids(token)
        # Tokens missing from the vocabulary map to None / the unk id; never mask those.
        if token_id is not None and token_id != tokenizer.unk_token_id:
            token_ids.add(token_id)
    return token_ids


def collate_fn(examples):
    """Build a padded training batch (with labels) from formatted conversations.

    Args:
        examples: list of conversations as produced by ``format_data``.

    Returns:
        The processor's batch (``input_ids``, ``attention_mask``, image tensors, ...)
        plus ``labels``: a copy of ``input_ids`` in which padding and vision
        placeholder positions are set to -100 so they are ignored by the loss.
    """
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
    image_inputs = [example[1]['content'][0]['image'] for example in examples]

    batch = processor(
        text=texts, images=image_inputs, return_tensors='pt', padding=True
    )

    # Clone so that masking labels never touches the model inputs.
    labels = batch['input_ids'].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    for token_id in _vision_token_ids(processor.tokenizer):
        labels[labels == token_id] = -100
    batch['labels'] = labels

    return batch

collated_data = collate_fn(collate_sample)
print(collated_data.keys())

In [None]:
# Trainer configuration; every value comes from the constants cell.
training_args = SFTConfig(
    output_dir=".",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_checkpointing=GRADIENT_CHECKPOINTING,
    # Forward the configured checkpointing mode; otherwise the Trainer silently uses its own default.
    gradient_checkpointing_kwargs={'use_reentrant': USE_REENTRANT},
    learning_rate=LEARNING_RATE,
    logging_steps=LOGGING_STEPS,
    eval_steps=EVAL_STEPS,
    eval_strategy=EVAL_STRATEGY,
    save_strategy=SAVE_STRATEGY,
    save_steps=SAVE_STEPS,
    metric_for_best_model=METRIC_FOR_BEST_MODEL,
    load_best_model_at_end=LOAD_BEST_MODEL_AT_END,
    max_grad_norm=MAX_GRAD_NORM,
    warmup_steps=WARMUP_STEPS,
    dataset_kwargs=DATASET_KWARGS,
    max_length=MAX_SEQ_LEN,
    remove_unused_columns=REMOVE_UNUSED_COLUMNS,
    optim=OPTIM,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    # Disable logging integrations explicitly; transformers recommends report_to over WANDB_DISABLED.
    report_to='none',
)

In [None]:
# NOTE(review): `model` was already modified in place by get_peft_model in the LoRA cell, and
# `peft_config` is handed to SFTTrainer again here, so the LoRA config may be applied twice —
# confirm against the installed TRL/PEFT versions.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=training_args,
    data_collator=collate_fn,
    processing_class=processor.tokenizer,
    train_dataset=train_data,
    eval_dataset=test_data,
)

In [None]:
# Evaluate on the first 10 test conversations before any training, then train.
small_eval_data = test_data[:10]
separator = '-' * 80

print(separator)
print('initial evaluation stage')
metric = trainer.evaluate(eval_dataset=small_eval_data)
print(metric)
print(separator)

print('training model')
trainer.train()
print(separator)


In [None]:
trainer.save_model(output_dir=training_args.output_dir)  # persist the trained result to the configured output dir