## EGH VLM

#### Extract features

In [1]:
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

from egh_vlm.extract_feature import extract_features
from egh_vlm.utils import load_egh_dataset

In [2]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The same checkpoint id is used for both the model weights and the processor.
checkpoint = "Qwen/Qwen3-VL-2B-Instruct"

# dtype="auto" leaves the dtype choice to from_pretrained; device_map places
# the loaded weights on the device selected above.
model = Qwen3VLForConditionalGeneration.from_pretrained(
    checkpoint, dtype="auto", device_map=device
)
processor = AutoProcessor.from_pretrained(checkpoint)

In [3]:
# Directory the samples are read from, relative to the working directory.
dataset_dir = "data/egh_vlm"
dataset = load_egh_dataset(dataset_dir)

Successfully load the EHG dataset with: 10 samples.


In [4]:
# Map each sample's 'id' to whatever extract_features returns for it.
# Arguments are passed positionally, in the order the helper is called with
# here: query, image path, answer, model, processor, device.
res = dict()

for data in dataset:
    res[data['id']] = extract_features(
        data['query'], data['image_path'], data['answer'],
        model, processor, device,
    )

In [5]:
# Example: show the shape of each of the first four features stored for sample '001'.
example = res['001']
captions = (
    "Shape of qa embedding:",
    "Shape of qa gradient:",
    "Shape of ia embedding:",
    "Shape of ia gradient:",
)
# Caption order matches the index order of the stored features (0..3).
for position, caption in enumerate(captions):
    print(caption, example[position].shape)

Shape of qa embedding: torch.Size([9, 2048])
Shape of qa gradient: torch.Size([9, 2048])
Shape of ia embedding: torch.Size([9, 2048])
Shape of ia gradient: torch.Size([9, 2048])


#### Training

In [None]:
import argparse

import torch

from egh_vlm.training import Trainer

# Only default to CUDA when torch reports a usable GPU. This mirrors the
# cuda-or-cpu fallback used when the model is loaded for feature extraction,
# so the cell no longer requests "cuda" unconditionally on a CPU-only machine.
default_device = "cuda" if torch.cuda.is_available() else "cpu"

# Training configuration; every option has a default, so the cell runs as-is.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir_path", type=str, default="data/egh_vlm")
parser.add_argument("--model_name", type=str, default="Qwen/Qwen3-VL-2B-Instruct")
parser.add_argument("--train_ratio", type=float, default=0.5)
parser.add_argument("--device", type=str, default=default_device)

# parse_known_args() returns (namespace, leftovers). Leftover argv entries that
# are not declared above are discarded instead of raising an error
# (presumably the kernel launcher's own flags when run inside Jupyter).
args, _ = parser.parse_known_args()

# Trainer receives the parsed namespace; run() starts the training routine.
trainer = Trainer(args)
trainer.run()