In [None]:
import sys

from training.egh_vlm_hallusion_bench import hidden_size

sys.path.append('..')

import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.extract_feature import extract_features, batch_extract_features
from egh_vlm.utils import load_hallusion_bench_dataset

## Extract Features

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-2B-Instruct",
    dtype="auto",
    device_map=device
)
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen3-VL-2B-Instruct",
    min_pixels=256 * 28 * 28,
    max_pixels=1280 * 28 * 28)

In [None]:
dataset = load_hallusion_bench_dataset("../data/hallusion_bench", sample_size=2)

In [None]:
dataset

In [None]:
res = batch_extract_features(
    dataset, model, processor, device
)

In [None]:
res[0][0].shape

In [None]:
res[0][0].shape

In [None]:
res[0][0].shape

In [None]:
# Example
print("Shape of qa embedding:", res[0][0].shape)
print("Shape of qa gradient:", res[0][1].shape)
print("Shape of ia embedding:", res[0][2].shape)
print("Shape of ia gradient:", res[0][3].shape)

## Training

In [1]:
import sys
sys.path.append('..')

import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.extract_feature import extract_features, batch_extract_features
from egh_vlm.hallucination_detector import DetectorModule
from egh_vlm.hallucination_dataset import hallucination_collate_fn
from egh_vlm.training import get_features, eval_detector, train
from egh_vlm.utils import load_hallusion_bench_dataset

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir_path", type=str, default="../data/hallusion_bench")
parser.add_argument("--features_file_path", type=str, default="../data/hallusion_bench/features.pt")
parser.add_argument("--detector_file_path", type=str, default="../data/hallusion_bench/detector.pt")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-VL-2B-Instruct")
parser.add_argument("--train_ratio", type=float, default=0.7)
args, _ = parser.parse_known_args()

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Qwen3VLForConditionalGeneration.from_pretrained(
    args.model_name,
    dtype="auto",
    device_map="auto"
)
processor = AutoProcessor.from_pretrained(
    args.model_name,
    min_pixels=256 * 28 * 28,
    max_pixels=1280 * 28 * 28)

In [4]:
dataset, hidden_size = get_features( args.features_file_path,
    args.dataset_dir_path, args.detector_file_path, model, processor, device)
train_dataset = dataset
test_dataset = dataset
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
)

In [5]:
detector = DetectorModule(hidden_size, hidden_size, 1, 0.2)
epoch = 2
loss_function = nn.BCELoss()
optim = torch.optim.Adam(detector.parameters(), lr=1e-4)
best_acc = 0.0
best_f1 = 0.0

for i in range(epoch):
    total_loss = train(detector, loss_function, optim, train_dataloader)
    print(f'Epoch [{i + 1}/{epoch}], Loss: {total_loss / 2000:.4f}')
    acc, f1, pr_auc = eval_detector(detector, test_dataloader)
    print(f'Epoch [{i + 1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}')

    if acc > best_acc:
        best_acc = acc
        print(f'Best ACC at epoch {epoch + 1}')
    if f1 > best_f1:
        best_f1 = f1
        print(f'Best F1 at epoch {epoch + 1}')

    if total_loss < 1e-5:
        break

print(f'Training Finished!')

Epoch [1/2], Loss: 0.0003
Epoch [1/2], ACC: 1.0000, F1: 1.0000, PR-AUC:1.0000
Best ACC at epoch 3
Best F1 at epoch 3
Epoch [2/2], Loss: 0.0002
Epoch [2/2], ACC: 1.0000, F1: 1.0000, PR-AUC:1.0000
Training Finished!


In [None]:
torch.save(model.state_dict(), args.detector_file_path)