In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from sklearn.metrics import f1_score, accuracy_score, precision_recall_curve, auc
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

import argparse
import json
from tqdm import tqdm
import gc

In [None]:
# NOTE(review): `!unzip` with no archive argument only prints unzip's usage text and
# extracts nothing. Presumably this cell was meant to unpack the EGH dataset archive so that
# `data/egh_vlm/dataset.json` and `data/egh_vlm/images/` exist — TODO supply the archive path.
!unzip

In [None]:
def get_mean(input_list):
    """Mean-pool each tensor over its first axis and stack the pooled results.

    Args:
        input_list: Non-empty sequence of tensors. Each tensor is averaged over
            dim 0 (for the features built in this notebook that is the
            answer-token axis — TODO confirm for other callers), then
            ``squeeze(0)`` drops a leading singleton dimension if one remains.

    Returns:
        A tensor with one row per element of ``input_list``, placed on the
        device of the first pooled tensor.

    Raises:
        RuntimeError: from ``torch.stack`` if ``input_list`` is empty.
    """
    pooled = [torch.mean(x, dim=0).squeeze(0) for x in input_list]
    # Index the first pooled tensor for its device (the list itself has no `.device`).
    return torch.stack(pooled).to(pooled[0].device)

class DetectorModule(nn.Module):
    """Hallucination detector head: fuse four pooled feature streams, then score them with an MLP.

    Each of the four forward inputs is a list of per-sample tensors. Every list
    is mean-pooled with ``get_mean`` (one row per list element), the four pooled
    streams are combined as a fixed weighted sum, and the fused vector goes
    through a 3-layer ReLU MLP followed by a sigmoid.

    Args:
        input_dim: Size of the fused feature vector fed to ``fc1``.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of sigmoid outputs.
        w_qa_embedding, w_qa_gradient, w_ia_embedding, w_ia_gradient: Fixed
            (non-trainable) mixing weights for the four feature streams.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        w_qa_embedding=0.2,
        w_qa_gradient=0.2,
        w_ia_embedding=0.2,
        w_ia_gradient=0.2,
    ):
        super().__init__()
        # Created in this order so parameter registration / random init order stays fc1 -> fc2 -> fc3.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        # Plain Python floats rather than nn.Parameter: the mixing weights are not learned.
        self.w_qa_embedding, self.w_qa_gradient = w_qa_embedding, w_qa_gradient
        self.w_ia_embedding, self.w_ia_gradient = w_ia_embedding, w_ia_gradient

    def forward(self, qa_embedding, qa_gradient, ia_embedding, ia_gradient):
        """Return sigmoid scores, one row per sample, for a batch of four feature lists."""
        weighted_streams = (
            (self.w_qa_embedding, qa_embedding),
            (self.w_qa_gradient, qa_gradient),
            (self.w_ia_embedding, ia_embedding),
            (self.w_ia_gradient, ia_gradient),
        )
        fused = sum(weight * get_mean(stream) for weight, stream in weighted_streams)

        hidden = F.relu(self.fc1(fused))
        hidden = F.relu(self.fc2(hidden))
        return torch.sigmoid(self.fc3(hidden))


In [None]:
def _tokenize_conversation(messages, processor, device):
    """Apply the processor's chat template to ``messages`` and move every returned tensor to ``device``."""
    encoded = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False,
        return_dict=True,
        return_tensors="pt"
    )
    return {k: v.to(device) for k, v in encoded.items()}


def _answer_kl(reference_logits, context_logits):
    """Summed KL(softmax(reference_logits) || softmax(context_logits)) over all positions.

    Computed in log space and float32: with fp16 logits, ``softmax(x).log()`` yields
    log(0) = -inf for underflowed probabilities, and ``0 * -inf`` turns both the KL
    and its gradient into NaN.
    """
    ref_log_probs = torch.log_softmax(reference_logits.float(), dim=-1)
    ctx_log_probs = torch.log_softmax(context_logits.float(), dim=-1)
    return torch.sum(ref_log_probs.exp() * (ref_log_probs - ctx_log_probs))


def extract_features(
        query: str, image_path: str, answer: str, model, processor, device
):
    """Compute embedding-difference and KL-gradient features for one (query, image, answer) sample.

    The answer is tokenized in three conversations via ``processor.apply_chat_template``:
    image + answer, query + answer, and the answer alone. The answer-only run is the
    reference. For each context (query, image) the function returns:

    * an embedding difference: last-layer hidden states of the trailing answer
      positions in the context run minus the answer-only hidden states (first token dropped);
    * a gradient: d KL(p_answer_only || p_context) / d(answer-only last hidden states),
      restricted to the same positions.

    NOTE(review): the slicing assumes the last ``a_length - 1`` tokens of every context
    sequence line up with tokens ``1:`` of the answer-only sequence — TODO confirm
    for the chat template in use.

    Args:
        query: User question text.
        image_path: Image path/URL handed to the processor.
        answer: Assistant answer being scored.
        model: VLM whose outputs expose ``logits`` and ``hidden_states``; its parameters
            must require grad so the answer-only hidden states are part of an autograd graph.
        processor: Object exposing ``apply_chat_template``.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``(qa_embedding, qa_gradient, ia_embedding, ia_gradient)``: detached float32
        CPU tensors, each sliced as ``[positions, hidden]`` from the model outputs.
    """
    # region Get inputs
    # image + answer
    i_inputs = _tokenize_conversation([
        {"role": "user", "content": [{"type": "image", "image": image_path}]},
        {"role": "assistant", "content": [{"type": "text", "text": answer}]},
    ], processor, device)
    # query + answer
    q_inputs = _tokenize_conversation([
        {"role": "user", "content": [{"type": "text", "text": query}]},
        {"role": "assistant", "content": [{"type": "text", "text": answer}]},
    ], processor, device)
    # answer only (reference distribution)
    a_inputs = _tokenize_conversation([
        {"role": "assistant", "content": [{"type": "text", "text": answer}]},
    ], processor, device)
    # endregion

    model.eval()

    # The context runs only supply constant target distributions and embeddings;
    # gradients are taken w.r.t. the answer-only hidden states, so no graph is needed here.
    with torch.no_grad():
        q_output = model(**q_inputs, output_hidden_states=True)
        i_output = model(**i_inputs, output_hidden_states=True)

    n_answer = a_inputs["input_ids"].shape[1] - 1
    q_start = q_inputs["input_ids"].shape[1] - n_answer
    i_start = i_inputs["input_ids"].shape[1] - n_answer

    with torch.enable_grad():
        a_output = model(**a_inputs, output_hidden_states=True)
        a_vector = a_output.hidden_states[-1]
        a_logits = a_output.logits[0, 1:, :]

        qa_kl = _answer_kl(a_logits, q_output.logits[0, q_start:, :])
        # Keep the graph: the image-context gradient below reuses it.
        qa_gradient = torch.autograd.grad(
            outputs=qa_kl, inputs=a_vector, retain_graph=True,
        )[0][0, 1:, :]

        ia_kl = _answer_kl(a_logits, i_output.logits[0, i_start:, :])
        # Last backward pass: let autograd free the graph.
        ia_gradient = torch.autograd.grad(
            outputs=ia_kl, inputs=a_vector, retain_graph=False,
        )[0][0, 1:, :]

    a_embedding = a_vector[0, 1:, :].detach()
    qa_embedding = q_output.hidden_states[-1][0, q_start:, :] - a_embedding
    ia_embedding = i_output.hidden_states[-1][0, i_start:, :] - a_embedding

    return (
        qa_embedding.detach().float().to("cpu"),
        qa_gradient.detach().float().to("cpu"),
        ia_embedding.detach().float().to("cpu"),
        ia_gradient.detach().float().to("cpu"))

def batch_extract_features(data_list, model, processor, device):
    """Run ``extract_features`` over every sample and collect the four feature streams.

    Args:
        data_list: Iterable of dicts providing ``'query'``, ``'image_path'`` and ``'answer'``.
        model, processor, device: Forwarded unchanged to ``extract_features``.

    Returns:
        ``(qa_embeddings, qa_gradients, ia_embeddings, ia_gradients)``: four lists
        aligned index-by-index with ``data_list``.
    """
    qa_embeddings, qa_gradients, ia_embeddings, ia_gradients = [], [], [], []
    columns = (qa_embeddings, qa_gradients, ia_embeddings, ia_gradients)

    for processed, sample in enumerate(tqdm(data_list, desc="Extract features:"), start=1):
        qa_emb, qa_grad, ia_emb, ia_grad = extract_features(
            query=sample['query'],
            image_path=sample['image_path'],
            answer=sample['answer'],
            model=model,
            processor=processor,
            device=device,
        )
        for column, value in zip(columns, (qa_emb, qa_grad, ia_emb, ia_grad)):
            column.append(value)

        # Every 10 samples, run the garbage collector and release cached CUDA blocks.
        if processed % 10 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return qa_embeddings, qa_gradients, ia_embeddings, ia_gradients

In [None]:
class HallucinationDataset(Dataset):
    """In-memory dataset pairing four per-sample feature streams with a label.

    All constructor arguments are indexable sequences aligned by sample index;
    the dataset length is taken from ``labels``.
    """

    def __init__(self, qa_embeddings, qa_gradients, ia_embeddings, ia_gradients, labels):
        self.qa_embeddings, self.qa_gradients = qa_embeddings, qa_gradients
        self.ia_embeddings, self.ia_gradients = ia_embeddings, ia_gradients
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(qa_embedding, qa_gradient, ia_embedding, ia_gradient, label)`` for sample ``idx``."""
        columns = (
            self.qa_embeddings,
            self.qa_gradients,
            self.ia_embeddings,
            self.ia_gradients,
            self.labels,
        )
        return tuple(column[idx] for column in columns)


def hallucination_collate_fn(batch):
    """Collate samples into four lists of variable-length feature tensors plus a label tensor.

    The feature tensors are kept as Python lists (not stacked) because each
    sample's tensors may have a different number of rows; only the labels
    (element 4 of each sample) are packed into a tensor.
    """
    qa_embeddings = [sample[0] for sample in batch]
    qa_gradients = [sample[1] for sample in batch]
    ia_embeddings = [sample[2] for sample in batch]
    ia_gradients = [sample[3] for sample in batch]
    label_tensor = torch.tensor([sample[4] for sample in batch])
    return qa_embeddings, qa_gradients, ia_embeddings, ia_gradients, label_tensor

def load_egh_dataset(folder_path) -> list:
    """Read ``<folder_path>/dataset.json`` and resolve each record's image path.

    Args:
        folder_path: Dataset root (str or path-like) containing ``dataset.json``
            and an ``images/`` sub-directory.

    Returns:
        A list of dicts with keys ``id``, ``image_path``, ``query``, ``answer`` and
        ``label``; ``image_path`` is ``<folder_path>/images/<image_id>``.

    Raises:
        KeyError: if a record lacks ``id``, ``image_id``, ``query``, ``answer`` or ``label``.
    """
    import os

    dataset_path = os.path.join(folder_path, 'dataset.json')
    images_dir = os.path.join(folder_path, 'images')

    # JSON text is UTF-8; do not depend on the platform's default locale encoding.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data_list = json.load(f)

    dataset = [
        {
            "id": data['id'],
            "image_path": os.path.join(images_dir, data['image_id']),
            "query": data['query'],
            "answer": data['answer'],
            "label": data['label'],
        }
        for data in data_list
    ]
    print(f"Successfully load the EHG dataset with: {len(dataset)} samples.")
    return dataset

In [None]:
def train(detector, loss_function, optimizer, data_loader):
    """Run one training epoch of ``detector`` over ``data_loader``.

    Args:
        detector: Module called as ``detector(qa_emb, qa_grad, ia_emb, ia_grad)``
            returning scores whose last dimension is squeezed away.
        loss_function: Loss called as ``loss_function(scores, float_labels)``.
        optimizer: Optimizer over the detector's parameters.
        data_loader: Iterable of ``(qa_emb, qa_grad, ia_emb, ia_grad, label)`` batches.

    Returns:
        The sum of the per-batch loss values as a Python float.
    """
    detector.train()
    total_loss = 0.0

    for qa_embedding, qa_gradient, ia_embedding, ia_gradient, label in data_loader:
        optimizer.zero_grad()
        # squeeze(-1), not squeeze(): a 1-sample batch must stay shape (1,) to match its label.
        output = detector(qa_embedding, qa_gradient, ia_embedding, ia_gradient).squeeze(-1)
        loss = loss_function(output, label.float())
        loss.backward()
        optimizer.step()
        # .item() drops the autograd graph; summing the tensor kept every batch's graph alive.
        total_loss += loss.item()
    return total_loss

def eval_detector(detector, data_loader):
    """Score ``detector`` on ``data_loader`` and return ``(accuracy, f1, pr_auc)``.

    Accuracy and F1 use predictions thresholded at 0.5. The precision-recall
    curve (and its AUC) is traced over the raw detector scores. The detector's
    train/eval mode is restored before returning.

    Args:
        detector: Module called as ``detector(qa_emb, qa_grad, ia_emb, ia_grad)``.
        data_loader: Iterable of ``(qa_emb, qa_grad, ia_emb, ia_grad, label)`` batches.

    Returns:
        ``(acc, f1, pr_auc)`` as computed by scikit-learn.
    """
    total_label, total_pred, total_out = [], [], []

    was_training = detector.training
    detector.eval()
    try:
        with torch.no_grad():
            for qa_embedding, qa_gradient, ia_embedding, ia_gradient, label in data_loader:
                # squeeze(-1), not squeeze(): a 1-sample batch must still yield a list.
                scores = detector(qa_embedding, qa_gradient, ia_embedding, ia_gradient).squeeze(-1).tolist()
                total_out += scores
                total_label += label.tolist()
                # Exactly 0.5 counts as negative, matching Python's round-half-to-even.
                total_pred += [int(score > 0.5) for score in scores]
    finally:
        detector.train(was_training)

    f1 = f1_score(total_label, total_pred)
    acc = accuracy_score(total_label, total_pred)
    # The PR curve needs continuous scores; hard 0/1 predictions collapse it to a single point.
    precision, recall, _thresholds = precision_recall_curve(total_label, total_out)
    pr_auc = auc(recall, precision)
    return acc, f1, pr_auc

def load_dataset(dir_path, model, processor, device):
    """Load the EGH samples under ``dir_path`` and wrap their extracted features in a dataset.

    Returns:
        ``(HallucinationDataset, hidden_size)``, where ``hidden_size`` is the last
        dimension of the first sample's QA embedding (IndexError if there are no samples).
    """
    samples = load_egh_dataset(dir_path)
    (qa_embeddings, qa_gradients,
     ia_embeddings, ia_gradients) = batch_extract_features(samples, model, processor, device)
    labels = [sample['label'] for sample in samples]
    feature_dim = qa_embeddings[0].size(-1)
    dataset = HallucinationDataset(qa_embeddings, qa_gradients, ia_embeddings, ia_gradients, labels)
    return dataset, feature_dim

#### Training

In [None]:
# Configuration. parse_known_args presumably tolerates extra arguments injected by the notebook kernel.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir_path", type=str, default="data/egh_vlm")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-VL-2B-Instruct")
parser.add_argument("--train_ratio", type=float, default=0.5)
parser.add_argument("--seed", type=int, default=42)
args, _ = parser.parse_known_args()

# Seed torch's global RNG so detector initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(args.seed)

# VLM model (fp16 on GPU, fp32 on CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Qwen3VLForConditionalGeneration.from_pretrained(
    args.model_name,
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained(args.model_name)

In [None]:
# Dataset: extract features once, then split into disjoint train/test subsets
# (evaluating on the training samples would overstate every metric).
dataset, hidden_size = load_dataset(args.dataset_dir_path, model, processor, device)

SPLIT_SEED = 42  # fixed so the train/test partition is reproducible across runs
train_size = int(len(dataset) * args.train_ratio)
test_size = len(dataset) - train_size
if train_size == 0 or test_size == 0:
    raise ValueError(
        f"train_ratio={args.train_ratio} leaves an empty split for {len(dataset)} samples"
    )
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=True,
)
# Evaluation metrics are order-independent, so the test loader is not shuffled.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=16,
    collate_fn=hallucination_collate_fn,
    shuffle=False,
)

In [None]:
# Training
detector = DetectorModule(hidden_size, hidden_size, 1)
epoch = 2  # number of training epochs
loss_function = nn.BCELoss()  # the detector already applies a sigmoid
optim = torch.optim.Adam(detector.parameters(), lr=1e-4)
best_acc = 0.0
best_f1 = 0.0

for i in range(epoch):
    total_loss = train(detector, loss_function, optim, train_dataloader)
    # Report the mean per-batch loss instead of dividing by a hard-coded sample count.
    print(f'Epoch [{i + 1}/{epoch}], Loss: {total_loss / max(len(train_dataloader), 1):.4f}')
    acc, f1, pr_auc = eval_detector(detector, test_dataloader)
    print(f'Epoch [{i + 1}/{epoch}], ACC: {acc:.4f}, F1: {f1:.4f}, PR-AUC:{pr_auc:.4f}')

    if acc > best_acc:
        best_acc = acc
        print(f'Best ACC at epoch {i + 1}')
    if f1 > best_f1:
        best_f1 = f1
        print(f'Best F1 at epoch {i + 1}')

    # Stop early once the summed epoch loss is negligible.
    if total_loss < 1e-5:
        break

print('Training Finished!')