In [1]:
import sys
sys.path.append('..')

import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.extract_feature import batch_extract_features
from egh_vlm.hallucination_dataset import load_features, save_features, split_stratified, hallucination_collate_fn, paired_hallucination_collate_fn
from egh_vlm.hallucination_detector import DetectorModule, PairedDetectorModule
from egh_vlm.training import train_detector, train_paired_detector, eval_detector, eval_paired_detector
from egh_vlm.utils import ModelBundle, load_egh_dataset, load_hallusion_bench_dataset

## Extract Features

In [2]:
# Run on the GPU when CUDA reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The model and its processor are loaded from the same checkpoint identifier.
checkpoint = "Qwen/Qwen3-VL-2B-Instruct"
model = Qwen3VLForConditionalGeneration.from_pretrained(
    checkpoint, dtype="auto", device_map=device
)
# NOTE(review): max_pixels presumably caps the image size handed to the model —
# confirm against the processor documentation.
processor = AutoProcessor.from_pretrained(checkpoint, max_pixels=1280 * 720)

# Group model, processor and device into the single object passed to batch_extract_features.
model_bundle = ModelBundle(model, processor, device)

In [3]:
# Load the EGH-VLM data from disk, requesting 10 samples for this demo.
dataset = load_egh_dataset(
    dir_path="../data/egh_vlm",
    sample_size=10,
)

Successfully load the EHG dataset with: 10 samples.


In [4]:
# Extract features for every loaded sample, with mask_mode left as None.
res = batch_extract_features(dataset, model_bundle, mask_mode=None)

# Show the shapes of the first sample's tensors; skipped when nothing was extracted.
if len(res):
    print("Shape of embedding:", res.embs[0].shape)
    print("Shape of gradient:", res.grads[0].shape)

Extract features:: 100%|██████████| 10/10 [02:43<00:00, 16.32s/it]

Shape of embedding: torch.Size([9, 2048])
Shape of gradient: torch.Size([9, 2048])





## Training

In [None]:
# Paths and hyper-parameters for the training section, declared as argparse
# options. Each entry is (flag, type, default).
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ("--dataset_dir_path", str, "../data/hallusion_bench"),
    ("--features_path", str, "../data/hallusion_bench/features.pt"),
    ("--detector_path", str, "../data/hallusion_bench/detector.pt"),
    ("--model_name", str, "Qwen/Qwen3-VL-2B-Instruct"),
    ("--train_ratio", float, 0.7),
):
    parser.add_argument(_flag, type=_type, default=_default)

# Parse an explicit empty argument list so the kernel's own sys.argv is not
# read; every option therefore takes its default value.
args = parser.parse_args([])

# Dict view of the same namespace, for key-based access (e.g. config['model_name']).
config = vars(args)

In [None]:
# Same setup as the feature-extraction section, except that the checkpoint
# name is read from `config` instead of being written inline.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

checkpoint = config['model_name']
model = Qwen3VLForConditionalGeneration.from_pretrained(
    checkpoint, dtype="auto", device_map=device
)
# NOTE(review): max_pixels presumably caps the image size handed to the model —
# confirm against the processor documentation.
processor = AutoProcessor.from_pretrained(checkpoint, max_pixels=1280 * 720)

# Group model, processor and device into the single object passed to get_features.
model_bundle = ModelBundle(model, processor, device)

In [None]:
def get_features(save_path, dataset_dir, model_bundle: ModelBundle, mask_mode=None, sample_size=None):
    """Return cached features from ``save_path``, extracting them first if needed.

    When ``save_path`` is an existing file the features are loaded from it and
    ``dataset_dir``, ``model_bundle``, ``mask_mode`` and ``sample_size`` are
    not used. Otherwise the HallusionBench dataset is loaded from
    ``dataset_dir``, features are extracted with ``model_bundle`` and written
    to ``save_path``.

    Args:
        save_path: Path of the feature cache file.
        dataset_dir: Directory passed to ``load_hallusion_bench_dataset``.
        model_bundle: Model/processor/device bundle passed to the extractor.
        mask_mode: Forwarded unchanged to ``batch_extract_features``.
        sample_size: Forwarded to ``load_hallusion_bench_dataset``.

    Returns:
        A ``(features, hidden_size)`` tuple, where ``hidden_size`` is the size
        of the last dimension of the first embedding, or ``0`` when
        ``features`` is empty.
    """
    if not os.path.isfile(save_path):
        dataset = load_hallusion_bench_dataset(dataset_dir, sample_size=sample_size)
        features = batch_extract_features(dataset, model_bundle, mask_mode, save_path)
        # NOTE(review): save_path is already handed to batch_extract_features
        # above — confirm whether this second save is redundant.
        save_features(features, save_path)
    else:
        features = load_features(save_path)

    hidden_size = features.embs[0].size(-1) if len(features) > 0 else 0
    return features, hidden_size

In [None]:
# Load the cached features, or extract them for 5 samples when no cache file exists.
dataset, hidden_size = get_features(
    args.features_path,
    args.dataset_dir_path,
    model_bundle,
    mask_mode=None,
    sample_size=5,
)

# Split into train/validation parts according to the configured ratio.
train_dataset, val_dataset = split_stratified(dataset, train_ratio=args.train_ratio)

# Both loaders share the same batch size, collate function and shuffling.
# NOTE(review): shuffling the validation loader looks unnecessary — confirm
# eval_detector does not depend on sample order before turning it off.
loader_options = dict(
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
)
train_dataloader = DataLoader(train_dataset, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)

In [None]:
# Build the detector and its optimisation setup.
# NOTE(review): the positional arguments presumably mean
# (input size, hidden size, output size, dropout) — confirm against DetectorModule.
detector = DetectorModule(hidden_size, hidden_size, 1, 0.2)
epoch = 20
loss_function = nn.BCELoss()
optim = torch.optim.Adam(detector.parameters(), lr=1e-4)

# Best validation scores seen so far; tracked independently of each other.
best_acc = 0.0
best_f1 = 0.0

for i in range(epoch):
    total_loss = train_detector(detector, loss_function, optim, train_dataloader)
    # NOTE(review): 2000 is a hard-coded divisor unrelated to the loader size
    # visible here; confirm what train_detector returns and normalise accordingly.
    print(f'Epoch [{i + 1}/{epoch}], Loss: {total_loss / 2000:.4f}')
    acc, f1, pr_auc = eval_detector(detector, val_dataloader)
    print(f'Epoch [{i + 1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}')

    if acc > best_acc:
        best_acc = acc
        print(f'Best ACC at epoch {i + 1}')
    if f1 > best_f1:
        best_f1 = f1
        print(f'Best F1 at epoch {i + 1}')
    # Stop early once the (un-normalised) training loss is effectively zero.
    if total_loss < 1e-3:
        break

print(f"Eval accuracy: {best_acc:.4f}, F1: {best_f1:.4f}")
print('Finished!')

# The parser declares `--detector_path`, so the namespace attribute is
# `args.detector_path`. The previous `args.detector_file_path` does not exist
# and raised AttributeError after training, so the weights were never saved.
# NOTE(review): this stores the weights from the last epoch run, not from the
# epoch that produced best_acc / best_f1.
torch.save(detector.state_dict(), args.detector_path)