In [1]:
import sys
sys.path.append('..')

import argparse
import logging
import os
import json
from collections import Counter
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.hallucination_dataset import HallucinationDataset, PairedHallucinationDataset, load_features, split_stratified, paired_hallucination_collate_fn
from egh_vlm.hallucination_detector import PairedDetectorModule
from egh_vlm.training import train_paired_detector, eval_paired_detector
from egh_vlm.utils import ModelBundle

## Training

In [2]:
def _build_parser():
    """Return the argument parser that holds this notebook's configuration defaults."""
    # (flag, default) pairs for every string-valued option, in declaration order.
    string_options = (
        ('--dataset_path', '../data/phd/prototype/phd_sample_qwen3_vl_2b_balanced.json'),
        ('--img_folder_path', '../data/phd/images'),
        ('--features_question_only_path', '../data/phd/prototype/features_question_only.pt'),
        ('--features_image_only_path', '../data/phd/prototype/features_image_only.pt'),
        ('--features_full_path', '../data/phd/prototype/features_full.pt'),
        ('--detector_path', '../data/phd/prototype/detector_paired.pt'),
        ('--model_name', 'Qwen/Qwen3-VL-2B-Instruct'),
    )
    built = argparse.ArgumentParser()
    for flag, default in string_options:
        built.add_argument(flag, type=str, default=default)
    built.add_argument('--train_ratio', type=float, default=0.7)
    return built


parser = _build_parser()
# Parse an explicitly empty argv: inside a kernel sys.argv holds Jupyter's own
# arguments, so only the defaults above should apply.
args = parser.parse_args([])
config = vars(args)

In [3]:
def get_features(save_path, sample_size=None):
    """Load cached features from ``save_path`` and report their hidden size.

    Args:
        save_path: Path to a feature file readable by ``load_features``.
        sample_size: Unused; kept only so existing callers that pass it keep working.

    Returns:
        ``(features, hidden_size)``. ``hidden_size`` is ``features.embs[0].size(-1)``,
        or 0 when the loaded dataset is empty. If ``save_path`` is not a file, an
        empty ``HallucinationDataset`` and 0 are returned, and a warning is emitted so
        a mistyped or not-yet-generated path is not silently ignored.
    """
    # Local import: warnings is only needed here. Deliberately not `logging`, which
    # would auto-configure the root logger and turn the training cell's
    # basicConfig(filename=...) into a no-op.
    import warnings

    if not os.path.isfile(save_path):
        warnings.warn(f'Feature file not found: {save_path}; using an empty dataset', stacklevel=2)
        return HallucinationDataset(), 0

    features = load_features(save_path)
    hidden_size = features.embs[0].size(-1) if len(features) > 0 else 0
    return features, hidden_size

features_question_only, hidden_size_question_only = get_features(
    save_path=args.features_question_only_path)
features_image_only, hidden_size_image_only = get_features(
    save_path=args.features_image_only_path)
features_full, hidden_size_full = get_features(
    save_path=args.features_full_path)

# get_features reports a hidden size of 0 when a file is missing or empty. Fail here
# with a clear message instead of a misleading "sizes must match" error or a
# zero-width detector in the training cell.
assert hidden_size_question_only > 0, f"No question-only features loaded from {args.features_question_only_path}"
assert hidden_size_image_only > 0, f"No image-only features loaded from {args.features_image_only_path}"
assert hidden_size_question_only == hidden_size_image_only, "Hidden sizes of question-only and image-only features must match"
hidden_size = hidden_size_question_only

print('Length of question only features:', len(features_question_only))
print('Length of image only features:', len(features_image_only))
print('Length of full features:', len(features_full))
print('Hidden size:', hidden_size)

Length of question only features: 3770
Length of image only features: 3774
Length of full features: 3774
Hidden size: 2048


In [4]:
features_paired = PairedHallucinationDataset()

# Pair every question-only sample with the image-only sample that shares its id.
# Only ids present in the question-only features are paired. The loop variable is
# `sample_id` so the built-in `id` is not shadowed at notebook global scope.
for sample_id in features_question_only.ids:
    feature_image_only = features_image_only.get_by_id(sample_id)
    feature_question_only = features_question_only.get_by_id(sample_id)
    # NOTE(review): assumes get_by_id returns None (rather than raising) on a miss — TODO confirm.
    assert feature_image_only is not None, f"id {sample_id} has no image-only features"
    # Both views of a sample must agree on element 3 (presumably the label — TODO confirm).
    assert feature_image_only[3] == feature_question_only[3], f"Element 3 differs between views for id {sample_id}"
    features_paired.add_item(
        sample_id,
        [
            [feature_image_only[1], feature_image_only[2]],
            [feature_question_only[1], feature_question_only[2]],
        ],
        feature_image_only[3],
    )
print('Length of paired features:', len(features_paired))

Length of paired features: 3770


In [5]:
batch_size = 16

# random_state is fixed so the train/val split is the same on every run.
train_dataset, val_dataset = split_stratified(features_paired, train_ratio=args.train_ratio, random_state=42)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=paired_hallucination_collate_fn,
    shuffle=True,
)
# Evaluation gains nothing from shuffling; a fixed order keeps the validation
# batches (and any per-batch computation in eval) deterministic.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=paired_hallucination_collate_fn,
    shuffle=False,
)

In [6]:
epoch = 30
lr = 1e-4
seed = 42  # seeds weight init and DataLoader shuffling so runs are reproducible
early_stop_loss = 1e-4  # stop training once an epoch's loss falls below this

log_save_path = '../data/logs/egh_vlm_paired_log.txt'
# logging's FileHandler does not create missing directories.
os.makedirs(os.path.dirname(log_save_path), exist_ok=True)
# force=True: without it basicConfig is a silent no-op whenever the root logger
# already has handlers (e.g. left over from an interrupted earlier run of this
# cell), and none of the metrics below would reach the log file.
logging.basicConfig(
    filename=log_save_path,
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)

# Best value of each validation metric and the 1-based epoch it was reached in.
result = {
    'train_ratio': args.train_ratio,
    'epoch': epoch,
    'lr': lr,
    'acc': {
        'value': 0.0,
        'epoch': -1
    },
    'f1': {
        'value': 0.0,
        'epoch': -1
    },
    'pr_auc': {
        'value': 0.0,
        'epoch': -1
    }
}

try:
    # Seed before building the model and before the first shuffled epoch.
    random.seed(seed)
    torch.manual_seed(seed)

    detector = PairedDetectorModule(hidden_size, hidden_size, 1)
    loss_fn = nn.BCELoss()
    optim = torch.optim.Adam(detector.parameters(), lr=lr)

    for i in range(epoch):
        total_loss = train_paired_detector(detector, loss_fn, optim, train_dataloader)
        acc, f1, pr_auc = eval_paired_detector(detector, val_dataloader)

        logging.debug(f'Epoch [{i+1}/{epoch}], Loss: {total_loss:.4f}')
        logging.debug(f'Epoch [{i+1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}\n')
        for metric_name, metric_value in (('acc', acc), ('f1', f1), ('pr_auc', pr_auc)):
            if metric_value > result[metric_name]['value']:
                result[metric_name]['value'] = metric_value
                result[metric_name]['epoch'] = i + 1
        if total_loss < early_stop_loss:
            break
    logging.debug(f'Eval ACC: {result["acc"]["value"]:.4f} at epoch {result["acc"]["epoch"]}')
    logging.debug(f'Eval F1: {result["f1"]["value"]:.4f} at epoch {result["f1"]["epoch"]}')
    logging.debug(f'Eval PR-AUC: {result["pr_auc"]["value"]:.4f} at epoch {result["pr_auc"]["epoch"]}')
    # NOTE(review): args.detector_path is configured but the trained detector is
    # never saved in this notebook — TODO confirm whether it should be.
finally:
    # Close and detach every root handler (including the file handler configured
    # above) even if training raises, so the log file is released and a re-run of
    # this cell starts from a clean logger.
    logger = logging.getLogger()
    for handler in logger.handlers[:]:
        handler.close()
        logger.removeHandler(handler)