In [1]:
import sys
sys.path.append("..")

import argparse
import os
from collections import Counter
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.extract_feature import batch_extract_features
from egh_vlm.hallucination_dataset import HallucinationDataset, load_features, split_stratified, hallucination_collate_fn
from egh_vlm.hallucination_detector import DetectorModule
from egh_vlm.training import train_detector, eval_detector
from egh_vlm.utils import ModelBundle, load_egh_dataset, load_phd_dataset

## Extract Features

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-2B-Instruct",
    dtype="auto",
    device_map=device
)
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen3-VL-2B-Instruct",
    max_pixels=1280 * 720)
model_bundle = ModelBundle(model, processor, device)

In [None]:
dataset = load_egh_dataset(folder_path = "../data/egh_vlm", file_name="egh_vlm.json", sample_size=5)

In [None]:
res = batch_extract_features(
    dataset, model_bundle, mask_mode=None
)
if len(res) > 0:
    print("Shape of embedding:", res.embs[0].shape)
    print("Shape of gradient:", res.grads[0].shape)

## Training

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default="../data/phd/phd_sampled_qwen3_vl_2b_full_balanced.json")
parser.add_argument("--img_folder_path", type=str, default="../data/phd/images")
parser.add_argument("--features_file_path", type=str, default="../data/phd/features.pt")
parser.add_argument("--detector_file_path", type=str, default="../data/phd/detector.pt")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-VL-2B-Instruct")
parser.add_argument("--train_ratio", type=float, default=0.7)
args = parser.parse_args("")
config = vars(args)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device used: {device}")
model = Qwen3VLForConditionalGeneration.from_pretrained(
    config["model_name"],
    dtype="auto",
    device_map=device
)
processor = AutoProcessor.from_pretrained(
    config["model_name"],
    max_pixels=1024 * 768)

model_bundle = ModelBundle(model, processor, device)

Device used: cuda


In [4]:
def get_balanced_raw_dataset(dataset):
    label_counts = Counter(item["label"] for item in dataset)
    print("Label distribution before balancing:\n", label_counts)

    min_count = min(label_counts.values())
    print(f"Max balanced count per label: {min_count}")

    balanced_dataset = []

    for label in label_counts:
        label_samples = [item for item in dataset if item["label"] == label]
        balanced_dataset.extend(random.sample(label_samples, min_count))
    print(f"Balanced dataset size: {len(balanced_dataset)}")
    print("Label distribution after balancing:\n", Counter(item["label"] for item in balanced_dataset))

def get_balanced_hallucination_dataset():
    indices_by_label = {label: [i for i, l in enumerate(dataset.labels) if l == label] 
                        for label in label_counts}
    label_counts = Counter(dataset.labels)
    min_count = min(label_counts.values())
    balanced_indices = []

    for label, indices in indices_by_label.items():
        sample_size = min(min_count, len(indices))
        balanced_indices.extend(random.sample(indices, sample_size))

    balanced_ids = [dataset.ids[i] for i in balanced_indices]
    balanced_embs = [dataset.embs[i] for i in balanced_indices]
    balanced_grads = [dataset.grads[i] for i in balanced_indices]
    balanced_labels = [dataset.labels[i] for i in balanced_indices]

    print(f"Balanced dataset: {len(balanced_indices)} samples ({min_count} per {len(label_counts)} classes)")
    return HallucinationDataset(balanced_ids, balanced_embs, balanced_grads, balanced_labels)

def get_features(save_path, dataset_path, img_folder_path, model_bundle: ModelBundle, mask_mode=None, sample_size=None):
    dataset = load_phd_dataset(dataset_path, img_folder_path, sample_size=sample_size)

    if os.path.isfile(save_path):
        processed_features = load_features(save_path)
    else:
        processed_features = None

    features = batch_extract_features(dataset, model_bundle, processed_features, mask_mode, save_path)
    return features, features.embs[0].size(-1) if len(features) > 0 else 0
        

In [8]:
dataset, hidden_size = get_features(
    save_path=args.features_file_path,
    dataset_path=args.dataset_path,
    img_folder_path=args.img_folder_path,
    model_bundle=model_bundle,
    mask_mode=None,
    sample_size=40,
)

print("Length of dataset:", len(dataset))
print("Hidden size:", hidden_size)

Successfully load the PhD dataset with: 40 samples.


Extract features:: 100%|██████████| 40/40 [04:44<00:00,  7.12s/it]

Length of dataset: 45
Hidden size: 2048





In [9]:
train_dataset, val_dataset = split_stratified(dataset, train_ratio=args.train_ratio)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
)

In [10]:
detector = DetectorModule(hidden_size, hidden_size, 1, 0.2)
epoch = 20
loss_function = nn.BCELoss()
optim = torch.optim.Adam(detector.parameters(), lr=1e-4)
best_acc = [0.0, -1]
best_f1 = [0.0, -1]
best_pr_auc = [0.0, -1]

for i in range(epoch):
    total_loss = train_detector(detector, loss_function, optim, train_dataloader)
    print(f"Epoch [{i + 1}/{epoch}], Loss: {total_loss / 2000:.4f}")
    acc, f1, pr_auc = eval_detector(detector, val_dataloader)
    print(f"Epoch [{i + 1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}")

    if acc > best_acc[0]:
        best_acc = [acc, i + 1]
    if f1 > best_f1[0]:
        best_f1 = [f1, i + 1]
    if pr_auc > best_pr_auc[0]:
        best_pr_auc = [pr_auc, i + 1]
    if total_loss < 1e-3:
        break

print(f"Eval ACC: {best_acc[0]:.4f} at epoch {best_acc[1]}")
print(f"Eval F1: {best_f1[0]:.4f} at epoch {best_f1[1]}")
print(f"Eval PR-AUC: {best_pr_auc[0]:.4f} at epoch {best_pr_auc[1]}")

torch.save(detector.state_dict(), args.detector_file_path)

Epoch [1/20], Loss: 0.0008
Epoch [1/20], ACC: 0.8571, F1: 0.8000, PR-AUC:0.9048
Epoch [2/20], Loss: 0.0005
Epoch [2/20], ACC: 0.7857, F1: 0.8000, PR-AUC:0.8333
Epoch [3/20], Loss: 0.0005
Epoch [3/20], ACC: 0.8571, F1: 0.8000, PR-AUC:0.9048
Epoch [4/20], Loss: 0.0004
Epoch [4/20], ACC: 0.7143, F1: 0.5000, PR-AUC:0.8095
Epoch [5/20], Loss: 0.0003
Epoch [5/20], ACC: 0.7857, F1: 0.6667, PR-AUC:0.8571
Epoch [6/20], Loss: 0.0002
Epoch [6/20], ACC: 0.7857, F1: 0.7692, PR-AUC:0.8095
Epoch [7/20], Loss: 0.0002
Epoch [7/20], ACC: 0.8571, F1: 0.8333, PR-AUC:0.8690
Epoch [8/20], Loss: 0.0001
Epoch [8/20], ACC: 0.7143, F1: 0.5000, PR-AUC:0.8095
Epoch [9/20], Loss: 0.0001
Epoch [9/20], ACC: 0.7143, F1: 0.5000, PR-AUC:0.8095
Epoch [10/20], Loss: 0.0001
Epoch [10/20], ACC: 0.7143, F1: 0.5000, PR-AUC:0.8095
Epoch [11/20], Loss: 0.0000
Epoch [11/20], ACC: 0.7857, F1: 0.7273, PR-AUC:0.8048
Epoch [12/20], Loss: 0.0000
Epoch [12/20], ACC: 0.8571, F1: 0.8333, PR-AUC:0.8690
Epoch [13/20], Loss: 0.0000
Epoch 