In [1]:
import sys
sys.path.append('..')

import argparse
import os
from collections import Counter
import logging
import json
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.hallucination_dataset import HallucinationDataset, load_features, split_stratified, hallucination_collate_fn
from egh_vlm.hallucination_detector import DetectorModule
from egh_vlm.training import train_detector, eval_detector
from egh_vlm.utils import ModelBundle

## Training

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default='../data/phd/prototype/phd_sample_qwen3_vl_2b_balanced.json')
parser.add_argument('--img_folder_path', type=str, default='../data/phd/images')
parser.add_argument('--features_file_path', type=str, default='../data/phd/prototype/features_image_only_new.pt')
parser.add_argument('--processed_features_file_path', type=str, default='../data/phd/prototype/features_image_only.pt')
parser.add_argument('--detector_file_path', type=str, default='../data/phd/prototype/detector_image_only.pt')
parser.add_argument('--model_name', type=str, default='Qwen/Qwen3-VL-2B-Instruct')
parser.add_argument('--train_ratio', type=float, default=0.7)
parser.add_argument('--seed', type=int, default=42)
# Parse an explicit empty argv: inside a Jupyter kernel sys.argv carries the
# kernel's own flags, which this parser would reject. Defaults are used as-is.
args = parser.parse_args([])
config = vars(args)

# Seed the Python and PyTorch global RNGs (random.sample in the balancing
# helpers, and any PyTorch-side randomness such as weight initialisation) so a
# Restart & Run All reproduces the same results.
random.seed(args.seed)
torch.manual_seed(args.seed)

In [3]:
# Use CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device used: {device}')

# The model weights and the processor are both loaded from the same checkpoint name.
# NOTE(review): `max_pixels` presumably caps the per-image pixel budget of the
# Qwen3-VL image processor — confirm against the processor documentation.
model = Qwen3VLForConditionalGeneration.from_pretrained(config['model_name'], dtype='auto', device_map=device)
processor = AutoProcessor.from_pretrained(config['model_name'], max_pixels=1024 * 768)

# Bundle model, processor and device for the egh_vlm helpers.
# NOTE(review): while the feature-extraction calls in get_features are commented
# out, this bundle is passed along but not used; consider skipping the model load
# when cached features already exist.
model_bundle = ModelBundle(model, processor, device)

Device used: cuda


In [4]:
def get_balanced_raw_dataset(dataset, seed=None):
    """Undersample labelled records so that every label appears equally often.

    Args:
        dataset: sequence of dict-like records, each with a ``'label'`` key.
        seed: optional seed for a private ``random.Random``. When ``None`` the
            global ``random`` module is used.

    Returns:
        A new list containing ``min_count`` randomly chosen records per label,
        where ``min_count`` is the frequency of the rarest label. Records are
        grouped by label in first-seen order. An empty input yields ``[]``.
    """
    rng = random if seed is None else random.Random(seed)

    label_counts = Counter(item['label'] for item in dataset)
    print('Label distribution before balancing:\n', label_counts)
    if not label_counts:
        # min() of an empty sequence would raise; there is nothing to balance.
        print('Balanced dataset size: 0')
        return []

    min_count = min(label_counts.values())
    print(f'Max balanced count per label: {min_count}')

    # Group records by label in one pass (dataset order is preserved per label),
    # instead of re-scanning the whole dataset once per label.
    samples_by_label = {}
    for item in dataset:
        samples_by_label.setdefault(item['label'], []).append(item)

    balanced_dataset = []
    for label_samples in samples_by_label.values():
        balanced_dataset.extend(rng.sample(label_samples, min_count))

    print(f'Balanced dataset size: {len(balanced_dataset)}')
    print('Label distribution after balancing:\n', Counter(item['label'] for item in balanced_dataset))
    # The original version built this list but never returned it.
    return balanced_dataset

def get_balanced_hallucination_dataset(dataset, seed=None):
    """Undersample a HallucinationDataset so every label has the same number of samples.

    Args:
        dataset: object exposing parallel ``ids``, ``embs``, ``grads`` and
            ``labels`` sequences (indexed by the same position).
        seed: optional seed for a private ``random.Random``. When ``None`` the
            global ``random`` module is used.

    Returns:
        A new ``HallucinationDataset`` holding ``min_count`` randomly chosen
        samples per label, where ``min_count`` is the rarest label's frequency.
        ids/embs/grads/labels stay aligned per sample.

    Raises:
        ValueError: if ``dataset.labels`` is empty.
    """
    rng = random if seed is None else random.Random(seed)

    # Count labels before using them. The original referenced label_counts on
    # the line above its assignment, so every call raised UnboundLocalError.
    label_counts = Counter(dataset.labels)
    if not label_counts:
        raise ValueError('Cannot balance an empty dataset')
    min_count = min(label_counts.values())

    # Positions of each label, grouped in a single pass (first-seen label order).
    indices_by_label = {}
    for idx, label in enumerate(dataset.labels):
        indices_by_label.setdefault(label, []).append(idx)

    balanced_indices = []
    for indices in indices_by_label.values():
        # Every label has at least min_count samples by construction.
        balanced_indices.extend(rng.sample(indices, min_count))

    balanced_ids = [dataset.ids[i] for i in balanced_indices]
    balanced_embs = [dataset.embs[i] for i in balanced_indices]
    balanced_grads = [dataset.grads[i] for i in balanced_indices]
    balanced_labels = [dataset.labels[i] for i in balanced_indices]

    print(f'Balanced dataset: {len(balanced_indices)} samples ({min_count} per {len(label_counts)} classes)')
    return HallucinationDataset(balanced_ids, balanced_embs, balanced_grads, balanced_labels)

In [5]:
def get_features(processed_path, save_path, dataset_path, img_folder_path, model_bundle: 'ModelBundle', mask_mode=None, sample_size=None):
    """Load pre-extracted hallucination features and report their hidden size.

    Feature extraction from the raw dataset is currently disabled (see the
    commented-out calls below), so only ``processed_path`` is read. The other
    parameters are kept so the signature still matches the extraction path.

    Args:
        processed_path: path to a features file readable by ``load_features``.
        save_path, dataset_path, img_folder_path, model_bundle, mask_mode,
        sample_size: used only by the disabled extraction path.

    Returns:
        ``(features, hidden_size)``. ``hidden_size`` is ``features.embs[0].size(-1)``,
        or 0 when ``features`` is empty.

    Raises:
        FileNotFoundError: if ``processed_path`` does not exist. Extraction is
            disabled, so there is no other way to obtain features.
    """
    if os.path.isfile(processed_path):
        processed_features = load_features(processed_path)
    else:
        processed_features = None
    print(f'Length of processed features: {len(processed_features.ids) if processed_features is not None else 0}')

    if processed_features is None:
        # Previously this fell through and crashed later on len(None).
        raise FileNotFoundError(
            f'No processed features found at {processed_path!r}; feature extraction '
            'is disabled here, so the processed features file must already exist.'
        )

    # TODO: re-enable extraction (load_phd_dataset / batch_extract_features) when
    # new features need to be computed instead of loaded from processed_path.
    # dataset = load_phd_dataset(dataset_path, img_folder_path, sample_size=sample_size)
    # features = batch_extract_features(dataset, model_bundle, processed_features, mask_mode, save_path)
    features = processed_features
    hidden_size = features.embs[0].size(-1) if len(features) > 0 else 0
    return features, hidden_size

# Where the cached features come from (and where extraction would read/write).
# hidden_size is the embedding width of the loaded features.
feature_kwargs = dict(
    processed_path=args.processed_features_file_path,
    save_path=args.features_file_path,
    dataset_path=args.dataset_path,
    img_folder_path=args.img_folder_path,
    model_bundle=model_bundle,
    mask_mode=None,
)
dataset, hidden_size = get_features(**feature_kwargs)

print('Length of dataset:', len(dataset))
print('Hidden size:', hidden_size)

Length of processed features: 3774
Length of dataset: 3774
Hidden size: 2048


In [6]:
BATCH_SIZE = 16

# NOTE(review): split_stratified presumably keeps label proportions equal across
# both partitions (per its name) — confirm in egh_vlm.hallucination_dataset.
train_dataset, val_dataset = split_stratified(dataset, train_ratio=args.train_ratio, random_state=42)

# Shuffle only the training data. A dedicated, seeded generator (same value as
# the split's random_state) makes the batch order reproducible across restarts.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Validation is evaluated in a fixed order. Shuffling here only made
# evaluation non-deterministic.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=hallucination_collate_fn,
    shuffle=False,
)

In [7]:
def training_pipeline(weight, epoch, lr, save_path='log.txt'):
    logging.basicConfig(filename=save_path, level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

    result = {
        'w': weight,
        'train_ratio': args.train_ratio,
        'epoch': epoch,
        'lr': lr,
        'acc': {
            'value': 0.0,
            'epoch': -1
        },
        'f1': {
            'value': 0.0,
            'epoch': -1
        },
        'pr_auc': {
            'value': 0.0,
            'epoch': -1
        }
    }

    # Init
    detector = DetectorModule(hidden_size, hidden_size, 1, weight)
    loss_fn = nn.BCELoss()
    optim = torch.optim.Adam(detector.parameters(), lr=lr)
    
    logging.debug(f'Training w/ weight: {weight}')
    for i in range(epoch):
        total_loss = train_detector(detector, loss_fn, optim, train_dataloader)
        acc, f1, pr_auc = eval_detector(detector, val_dataloader)
        
        logging.debug(f'Epoch [{i+1}/{epoch}], Loss: {total_loss:.4f}')
        logging.debug(f'Epoch [{i + 1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}\n')
        if acc > result['acc']['value']:
            result['acc']['value'] = acc
            result['acc']['epoch'] = i + 1
        if f1 > result['f1']['value']:
            result['f1']['value'] = f1
            result['f1']['epoch'] = i + 1
        if pr_auc > result['pr_auc']['value']:
            result['pr_auc']['value'] = pr_auc
            result['pr_auc']['epoch'] = i + 1
        if total_loss < 1e-4:
            break
    logging.debug(f'Eval ACC: {result["acc"]["value"]:.4f} at epoch {result["acc"]["epoch"]}')
    logging.debug(f'Eval F1: {result["f1"]["value"]:.4f} at epoch {result["f1"]["epoch"]}')
    logging.debug(f'Eval PR-AUC: {result["pr_auc"]["value"]:.4f} at epoch {result["pr_auc"]["epoch"]}')

    # Clean up logger handlers
    logger = logging.getLogger()
    for handler in logger.handlers[:]:
        handler.close()
        logger.removeHandler(handler)

    return result

In [8]:
# Sweep the detector weight over 0.1, 0.2, ..., 0.9 (rounded to avoid float noise in file names).
weights = [round(0.1 * i, 2) for i in range(1, 10)]
# NOTE(review): not referenced in the visible cells; presumably a shorter sweep for quick runs.
weights_test = [0.3, 0.5, 0.7]
epoch = 30
lr = 1e-4

# Per-weight log files go here. Create the folder so opening the log file
# cannot fail on a fresh checkout where it does not exist yet.
log_dir = '../data/logs'
os.makedirs(log_dir, exist_ok=True)

records = []

for weight in weights:
    result = training_pipeline(weight, epoch, lr, save_path=f'{log_dir}/egh_vlm_image_only_{weight}_log.txt')
    records.append(result)

In [9]:
# Persist every sweep record (one result dict per weight) as indented JSON.
with open('../data/phd/egh_vlm_image_only_eval.json', mode='w') as results_file:
    json.dump(records, results_file, indent=4)