In [None]:
import sys
sys.path.append('..')

import argparse
import gc
import os
import random
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.extract_feature import batch_extract_features
from egh_vlm.hallucination_dataset import HallucinationDataset, load_features, split_stratified, hallucination_collate_fn
from egh_vlm.hallucination_detector import DetectorModule
from egh_vlm.training import train_detector, eval_detector
from egh_vlm.utils import ModelBundle, load_egh_dataset, load_phd_dataset

## Extract Features

In [None]:
# Load Qwen3-VL-2B-Instruct and its processor for the feature-extraction demo,
# preferring the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = Qwen3VLForConditionalGeneration.from_pretrained(
    'Qwen/Qwen3-VL-2B-Instruct', device_map=device, dtype='auto',
)
# max_pixels limits the per-image pixel budget used by the processor (1280x720 here).
processor = AutoProcessor.from_pretrained('Qwen/Qwen3-VL-2B-Instruct', max_pixels=1280 * 720)
model_bundle = ModelBundle(model, processor, device)

In [None]:
# Load a small EGH-VLM sample (sample_size=5) to smoke-test feature extraction.
dataset = load_egh_dataset(
    folder_path='../data/egh_vlm',
    file_name='egh_vlm.json',
    sample_size=5,
)

In [None]:
# Run extraction on the small EGH subset loaded above (no masking) and show the
# shape of the first sample's embedding and gradient features.
res = batch_extract_features(dataset, model_bundle, mask_mode=None)
if len(res):
    print('Shape of embedding:', res.embs[0].shape)
    print('Shape of gradient:', res.grads[0].shape)

## Training

In [None]:
# Notebook stand-in for a CLI: every tunable path/setting is declared as an
# argparse flag, then an empty argv is parsed so only the defaults apply.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ('--dataset_path', str, '../data/phd/phd_sampled_qwen3_vl_2b_balanced.json'),
    ('--img_folder_path', str, '../data/phd/images'),
    ('--features_file_path', str, '../data/phd/features_question_only.pt'),
    ('--detector_file_path', str, '../data/phd/detector_question_only.pt'),
    ('--model_name', str, 'Qwen/Qwen3-VL-2B-Instruct'),
    ('--train_ratio', float, 0.7),
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args([])
# vars() exposes the parsed namespace as a dict for key-style access.
config = vars(args)

In [None]:
# Load the backbone VLM named in `config` for the training pipeline.
# The "Extract Features" demo above may already hold a model on the GPU. Drop
# those references (and collect them) *before* from_pretrained runs; otherwise
# the old and the new copy are both resident while the new one is loading.
for _stale_name in ('model_bundle', 'model', 'processor'):
    globals().pop(_stale_name, None)
del _stale_name
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device used: {device}')
model = Qwen3VLForConditionalGeneration.from_pretrained(
    config['model_name'],
    dtype='auto',
    device_map=device
)
processor = AutoProcessor.from_pretrained(
    config['model_name'],
    max_pixels=1024 * 768)

model_bundle = ModelBundle(model, processor, device)

Device used: cuda


In [None]:
def get_balanced_raw_dataset(dataset, seed=None):
    """Down-sample raw records so that every label occurs equally often.

    Every label keeps ``min_count`` records drawn without replacement, where
    ``min_count`` is the frequency of the rarest label. The label distribution
    before and after balancing is printed.

    Args:
        dataset: Iterable of dict-like records, each with a ``'label'`` key.
        seed: Optional seed for a private ``random.Random``. ``None`` (default)
            draws from the module-level ``random`` state.

    Returns:
        A new list of the sampled records, grouped by label in first-seen
        order. An empty input returns an empty list.
    """
    # Bucket records by label in a single pass; dict order keeps labels in
    # first-seen order, matching Counter's iteration order.
    samples_by_label = {}
    for item in dataset:
        samples_by_label.setdefault(item['label'], []).append(item)

    label_counts = Counter({label: len(items) for label, items in samples_by_label.items()})
    print('Label distribution before balancing:\n', label_counts)
    if not label_counts:
        # Nothing to balance; min() below would raise on an empty sequence.
        return []

    min_count = min(label_counts.values())
    print(f'Max balanced count per label: {min_count}')

    sampler = random if seed is None else random.Random(seed)
    balanced_dataset = []
    for label_samples in samples_by_label.values():
        balanced_dataset.extend(sampler.sample(label_samples, min_count))
    print(f'Balanced dataset size: {len(balanced_dataset)}')
    print('Label distribution after balancing:\n', Counter(item['label'] for item in balanced_dataset))
    # Return the result so callers can actually use the balanced list.
    return balanced_dataset

def get_balanced_hallucination_dataset(hallucination_dataset=None, seed=None):
    """Down-sample a feature dataset so every label has the same sample count.

    Args:
        hallucination_dataset: Object exposing parallel ``ids``, ``embs``,
            ``grads`` and ``labels`` sequences (e.g. a ``HallucinationDataset``).
            When ``None`` (default), the notebook-level ``dataset`` is used, so
            the zero-argument call keeps working.
        seed: Optional seed for a private ``random.Random``. ``None`` draws from
            the module-level ``random`` state.

    Returns:
        A new ``HallucinationDataset`` with ``min_count`` samples per label,
        where ``min_count`` is the frequency of the rarest label.

    Raises:
        ValueError: if the dataset contains no labels.
    """
    source = dataset if hallucination_dataset is None else hallucination_dataset

    # Labels must be counted before grouping, since grouping iterates over them.
    label_counts = Counter(source.labels)
    if not label_counts:
        raise ValueError('Cannot balance an empty dataset: no labels found.')
    min_count = min(label_counts.values())

    # Group sample indices by label in a single pass over the labels.
    indices_by_label = {label: [] for label in label_counts}
    for i, label in enumerate(source.labels):
        indices_by_label[label].append(i)

    sampler = random if seed is None else random.Random(seed)
    balanced_indices = []
    for indices in indices_by_label.values():
        # Each bucket holds at least min_count indices by construction.
        balanced_indices.extend(sampler.sample(indices, min_count))

    balanced_ids = [source.ids[i] for i in balanced_indices]
    balanced_embs = [source.embs[i] for i in balanced_indices]
    balanced_grads = [source.grads[i] for i in balanced_indices]
    balanced_labels = [source.labels[i] for i in balanced_indices]

    print(f'Balanced dataset: {len(balanced_indices)} samples ({min_count} per {len(label_counts)} classes)')
    return HallucinationDataset(balanced_ids, balanced_embs, balanced_grads, balanced_labels)

def get_features(processed_path, save_path, dataset_path, img_folder_path, model_bundle: ModelBundle, mask_mode=None, sample_size=None):
    """Load the PhD dataset and extract VLM features for it.

    If ``processed_path`` is an existing file, it is loaded with
    ``load_features`` and passed to ``batch_extract_features`` as previously
    processed features (presumably so finished samples can be reused — TODO
    confirm against ``batch_extract_features``).

    Args:
        processed_path: Path to a previously saved features file (may not exist).
        save_path: Where ``batch_extract_features`` should save its output.
        dataset_path: PhD dataset JSON file.
        img_folder_path: Folder with the dataset images.
        model_bundle: Model, processor and device used for extraction.
        mask_mode: Forwarded unchanged to ``batch_extract_features``.
        sample_size: Optional sample limit forwarded to ``load_phd_dataset``.

    Returns:
        ``(features, hidden_size)``, where ``hidden_size`` is the size of the
        last dimension of the first embedding, or 0 when nothing was extracted.
    """
    dataset = load_phd_dataset(dataset_path, img_folder_path, sample_size=sample_size)

    processed_features = load_features(processed_path) if os.path.isfile(processed_path) else None

    # Without a cache file there is nothing processed yet; report 0 instead of
    # crashing on ``None.ids``.
    processed_count = len(processed_features.ids) if processed_features is not None else 0
    print(f'Length: {processed_count}')

    features = batch_extract_features(dataset, model_bundle, processed_features, mask_mode, save_path)
    hidden_size = features.embs[0].size(-1) if len(features) > 0 else 0
    return features, hidden_size
        

In [None]:
# Extract PhD features with the image masked out (mask_mode='image'). Features
# already stored at config['features_file_path'], if that file exists, are
# handed to the extractor as previously processed.
dataset, hidden_size = get_features(
    dataset_path=config['dataset_path'],
    img_folder_path=config['img_folder_path'],
    processed_path=config['features_file_path'],
    save_path='../data/phd/features_question_only_test.pt',
    model_bundle=model_bundle,
    mask_mode='image',
)

print('Length of dataset:', len(dataset))
print('Hidden size:', hidden_size)

Successfully load the PhD dataset with: 3770 samples.
Length: 3770


Extract features::  11%|█         | 399/3770 [01:13<07:06,  7.91it/s]

In [None]:
# Stratified train/validation split with a fixed random_state.
train_dataset, val_dataset = split_stratified(dataset, train_ratio=args.train_ratio, random_state=42)

# Seed the training sampler so the per-epoch batch order is reproducible too.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Evaluation gains nothing from shuffling; a fixed order keeps validation
# passes deterministic (metrics over the full set should not depend on order).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=False,
)

In [None]:
# Train the hallucination detector and track the best validation ACC / F1 /
# PR-AUC together with the (1-based) epoch at which each was reached.
torch.manual_seed(42)  # reproducible detector weight initialisation
detector = DetectorModule(hidden_size, hidden_size, 1, 0.2)
epoch = 20
# NOTE(review): BCELoss expects the detector to output probabilities in [0, 1] — confirm.
loss_function = nn.BCELoss()
optim = torch.optim.Adam(detector.parameters(), lr=1e-4)
best_acc = [0.0, -1]
best_f1 = [0.0, -1]
best_pr_auc = [0.0, -1]
best_state = None  # detector weights from the epoch with the best PR-AUC

# Normalise the reported loss by the number of training batches instead of a
# hard-coded 2000 that matches neither the batch nor the sample count.
# NOTE(review): assumes train_detector returns a sum of per-batch losses — confirm.
num_train_batches = max(len(train_dataloader), 1)

for i in range(epoch):
    total_loss = train_detector(detector, loss_function, optim, train_dataloader)
    print(f'Epoch [{i + 1}/{epoch}], Loss: {total_loss / num_train_batches:.4f}')
    acc, f1, pr_auc = eval_detector(detector, val_dataloader)
    print(f'Epoch [{i + 1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}')

    if acc > best_acc[0]:
        best_acc = [acc, i + 1]
    if f1 > best_f1[0]:
        best_f1 = [f1, i + 1]
    if pr_auc > best_pr_auc[0]:
        best_pr_auc = [pr_auc, i + 1]
        # state_dict() holds references to the live tensors, so snapshot copies.
        best_state = {k: v.detach().clone() for k, v in detector.state_dict().items()}
    if total_loss < 1e-3:
        break

print(f'Eval ACC: {best_acc[0]:.4f} at epoch {best_acc[1]}')
print(f'Eval F1: {best_f1[0]:.4f} at epoch {best_f1[1]}')
print(f'Eval PR-AUC: {best_pr_auc[0]:.4f} at epoch {best_pr_auc[1]}')

# Save the weights that produced the reported best PR-AUC rather than whatever
# the final epoch happened to leave behind.
if best_state is not None:
    detector.load_state_dict(best_state)
torch.save(detector.state_dict(), args.detector_file_path)