In [1]:
import os
import json
import torch
import evaluate
import open_clip
import collections
import numpy as np
import torch.nn as nn
from PIL import Image
from tqdm.auto import tqdm
import torch.nn.functional as F
from collections import OrderedDict
from prototype import MultiThreadMemory
from datasets import Dataset, load_metric
from transformers import default_data_collator
from open_clip import tokenizer as clip_tokenizer
from torch.utils.data.dataloader import default_collate
from transformers import BertTokenizerFast, BertForQuestionAnswering, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import subprocess
result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

https://huggingface.co/docs/transformers/v4.41.2/en/model_doc/bert#transformers.BertForQuestionAnswering

In [3]:
tokenizer = BertTokenizerFast.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
qa_bert = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
max_length = 384 # The maximum length of a feature (question and context)
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.

In [5]:
tokenizer("What is your name?", "My name is Sylvain.")

{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [6]:
# 判断并选择设备
def select_device():
       if torch.cuda.is_available():
              device = torch.device("cuda")  # 优先使用CUDA（NVIDIA GPU）
              print("Using CUDA (GPU)")
       elif torch.backends.mps.is_available():
              device = torch.device("mps")  # 如果CUDA不可用但MPS可用，使用MPS（Apple Silicon）
              print("Using MPS (Apple Silicon)")
       else:
              device = torch.device("cpu")  # 如果都不可用，使用CPU
              print("Using CPU")
       return device

device = select_device()
# device = torch.device("cpu")

Using CUDA (GPU)


# 自定义数据集

In [7]:
class VQADataset:
    def __init__(self, key, questions, contexts, answers, img_names):
        questions = questions
        contexts = contexts
        answers = answers
        img_names = img_names
        key = key

    def __getitem__(self, idx):
        return {
            'key': key[idx],
            'question': questions[idx],
            'context': contexts[idx],
            'answers': answers[idx],
            'image_name': img_names[idx]
        }

    def __len__(self):
        return len(questions)
 
 
def load_datasets(data_dir, sub_folders, img_dir):
    questions = []
    contexts = []
    answers = []
    img_names = []
    keys = []
    for folder in sub_folders:
        json_file_name = 'all_qs_dict_release_train_500.json' if 'train' in folder else 'all_qs_dict_release_test_500.json'
        json_file_path = os.path.join(data_dir, folder, json_file_name)
        with open(json_file_path) as f:
            data = json.load(f)
            for key in data.keys():
                keys.append(key)
                questions.append(data[key]['question'])
                contexts.append(data[key]['fact_surface'].replace("[[", "").replace("]]", ""))
                # 处理答案格式
                answer_text = data[key]['answer']
                answer_start = contexts[-1].find(answer_text)  # 通过查找答案在上下文中的位置
                answers.append({
                    'answer_start': [answer_start if answer_start != -1 else 0],
                    'text': [answer_text]
                })
                img_names.append(os.path.join(img_dir, data[key]['img_file']))
    
    # 创建Hugging Face datasets对象
    dataset = Dataset.from_dict({
        'key': keys,
        'question': questions,
        'context': contexts,
        'answers': answers,
        'image_name': img_names
    })
    
    return dataset

In [8]:
project_root = os.getcwd()
train_data_dir = os.path.join(project_root, 'data/KG_VQA/fvqa/exp_data/train_seen_data')
test_data_dir = os.path.join(project_root, 'data/KG_VQA/fvqa/exp_data/test_unseen_data')
img_dir = os.path.join(project_root, "data/KG_VQA/fvqa/exp_data/images/images")
sub_folders_train = ['train0', 'train1', 'train2', 'train3', 'train4']
sub_folders_test = ['test0', 'test1', 'test2', 'test3', 'test4']

train_dataset = load_datasets(train_data_dir, sub_folders_train, img_dir)
# test_dataset_all = load_datasets(test_data_dir, sub_folders_test, img_dir)
validation_dataset = load_datasets(test_data_dir, sub_folders_test, img_dir) 
# split_datasets = test_dataset_all.train_test_split(test_size=0.2)  
# test_dataset = split_datasets['train']
# validation_dataset = split_datasets['test']

print(train_dataset[0])
print('训练集大小：', len(train_dataset))
# print('测试集大小：', len(test_dataset))
print('验证集大小：', len(validation_dataset))

{'key': '270', 'question': 'Which object can be found in a jazz club', 'context': 'You are likely to find a trumpet in a jazz club', 'answers': {'answer_start': [25], 'text': ['trumpet']}, 'image_name': '/root/autodl-tmp/vqa/VQA-with-XProNet/data/KG_VQA/fvqa/exp_data/images/images/ILSVRC2012_test_00050748.JPEG'}
训练集大小： 13662
验证集大小： 13798


# 数据处理

In [9]:
pad_on_right = tokenizer.padding_side == "right"

def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=256,
       #  stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [10]:
features = prepare_train_features(train_dataset[:5])
features

{'input_ids': [[101, 2029, 4874, 2064, 2022, 2179, 1999, 1037, 4166, 2252, 102, 2017, 2024, 3497, 2000, 2424, 1037, 9368, 1999, 1037, 4166, 2252, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 2029, 4874, 1999, 2023, 3746, 2038, 1037, 5725, 102, 1037, 20497, 2038, 1037, 5725, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [10]:
tokenized_train_dataset = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)

Map: 100%|██████████| 13662/13662 [00:04<00:00, 3360.43 examples/s]


In [23]:
# tokenized_test_dataset = test_dataset.map(prepare_train_features, batched=True, remove_columns=test_dataset.column_names)

In [11]:
tokenized_val_dataset = validation_dataset.map(prepare_train_features, batched=True, remove_columns=validation_dataset.column_names)

Map:   0%|          | 0/13798 [00:00<?, ? examples/s]

Map: 100%|██████████| 13798/13798 [00:04<00:00, 3209.36 examples/s]


In [25]:
tokenized_train_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 13662
})

# QA Train Example

In [6]:
# 设置训练参数
model_name = "qa-bert"
batch_size = 16

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-squad",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)



In [16]:
data_collator = default_data_collator

# 初始化Trainer
trainer = Trainer(
    model=qa_bert,
    args=args,
    train_dataset=tokenized_train_dataset,
    # eval_dataset 仅用于评估模型的性能，不会用于训练。它的主要作用是在训练过程中定期评估模型，
    # 以监控模型的性能并防止过拟合。可以在训练后继续使用 eval_dataset 进行评估和计算各种指标。
    eval_dataset=tokenized_val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [19]:
# 开始训练
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.1869,0.0762
2,0.0411,0.064588
3,0.0136,0.046178


TrainOutput(global_step=2562, training_loss=0.0630003849828178, metrics={'train_runtime': 1793.4681, 'train_samples_per_second': 22.853, 'train_steps_per_second': 1.429, 'total_flos': 1.9031992389679104e+16, 'train_loss': 0.0630003849828178, 'epoch': 3.0})

In [20]:
trainer.save_model("test-squad-trained")

## Evaluation

In [26]:
# 重新加载模型
trained_model = BertForQuestionAnswering.from_pretrained("test-squad-trained")
# 重新加载分词器
bert_tokenizer = BertTokenizerFast.from_pretrained("test-squad-trained")

In [27]:
for key, question, context, answers in zip(validation_dataset['key'], validation_dataset['question'], validation_dataset['context'], validation_dataset['answers']):
    # 预测答案
    inputs = bert_tokenizer(question, context, return_tensors="pt")
    print(inputs)
    with torch.no_grad():
       outputs = trained_model(**inputs)
       
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    predict_answer = tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
    print(f"predict answer: {predict_answer}")
    print(f"true answer: {answers}")
    break 


{'input_ids': tensor([[  101,  2425,  2033,  1996,  2171,  1997,  1996, 25381,  3491,  1999,
          2023,  3746,  1029,   102, 22359,  7460,  2000,  1996,  4696,  1997,
         25381,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
predict answer: lipstick
true answer: {'answer_start': [0], 'text': ['lipstick']}


In [19]:
for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

odict_keys(['loss', 'start_logits', 'end_logits'])

In [20]:
def prepare_validation_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["key"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [21]:
validation_features = validation_dataset.map(
    prepare_validation_features,
    batched=True,
    remove_columns=validation_dataset.column_names
)

Map: 100%|██████████| 13798/13798 [00:04<00:00, 2978.07 examples/s]


In [22]:
raw_predictions = trainer.predict(validation_features) 

In [23]:
validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

In [24]:
def postprocess_qa_predictions(examples, features, raw_predictions, squad_v2=True, n_best_size = 20, max_answer_length = 30):
    """
    该函数可以由tokenizer.decode 完成输出最终的答案文本。
    """
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["key"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []
        
        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}
        
        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2:
            predictions[example["key"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["key"]] = answer

    return predictions

In [25]:
# squad_v2=True允许模型在无法从context中找到answer时,预测"不可能回答",从而处理这种情况。
squad_v2 = False
final_predictions = postprocess_qa_predictions(validation_dataset, validation_features, raw_predictions.predictions, squad_v2=squad_v2)

Post-processing 13798 example predictions split into 13798 features.


100%|██████████| 13798/13798 [00:28<00:00, 477.46it/s]


In [30]:
final_predictions

OrderedDict([('271', 'lipstick'),
             ('3519', 'bus'),
             ('3518', 'pizza'),
             ('3513', 'lizard'),
             ('3517', 'pizza'),
             ('3516', '*Something you find on a pizza is cheese'),
             ('3515', 'dog'),
             ('3514', 'temporary residence'),
             ('98', 'luggage'),
             ('91', 'glove'),
             ('93', 'pineapple'),
             ('92', 'pineapple'),
             ('1176', 'carrot'),
             ('877', 'bus'),
             ('3432', 'refrigerator'),
             ('623', 'laptop'),
             ('875', 'motorcycle'),
             ('622', 'An ipod is for listening'),
             ('872', 'a bicycle'),
             ('3435', 'Dogs'),
             ('627', 'bus'),
             ('4593', 'Baseball'),
             ('624', 'bus'),
             ('4591', 'baseball'),
             ('2745', 'cat'),
             ('2747', 'cat'),
             ('2749', 'cat'),
             ('4845', 'goldfish'),
             ('4735', 'pizza

In [26]:
# 自定义 hit@k 评估函数
def hit_at_k(predictions, references, k):
    hits = 0
    for pred, ref in zip(predictions, references):
        if ref in pred[:k]:
            hits += 1
    return hits / len(references)

In [29]:
# 提取预测和参考答案
predictions = []
references = []
for data in validation_dataset:
    key = data['key']
    if key in final_predictions:
        predictions.append(final_predictions[key].split())  # 将预测结果分割成单词列表
        references.append(data['answers']['text'][0])
        
# 计算 hit@5 和 hit@10
hit1 = hit_at_k(predictions, references, 1)
hit5 = hit_at_k(predictions, references, 5)
hit10 = hit_at_k(predictions, references, 10)

print(f"hit@1: {hit1:.2%}")
print(f"hit@5: {hit5:.2%}")
print(f"hit@10: {hit10:.2%}")

hit@1: 43.22%
hit@5: 66.31%
hit@10: 66.57%


其他获取答案的方法

In [45]:
# 提取预测和参考答案
predictions = []
references = []

for key, question, context, answers in zip(validation_dataset['key'], validation_dataset['question'], validation_dataset['context'], validation_dataset['answers']):
    # 预测答案
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
       outputs = trained_model(**inputs)
       
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    predict_answer = tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
    predictions.append(predict_answer.split())
    references.append(answers['text'][0])

# 计算 hit@5 和 hit@10
hit1 = hit_at_k(predictions, references, 1)
hit5 = hit_at_k(predictions, references, 5)
hit10 = hit_at_k(predictions, references, 10)

print(f"hit@1: {hit1:.2%}")
print(f"hit@5: {hit5:.2%}")
print(f"hit@10: {hit10:.2%}")

hit@1: 72.68%
hit@5: 73.49%
hit@10: 74.71%


In [46]:

# 自定义 hit@k 评估函数
def hit_at_k(predictions, references, k):
    hits = 0
    for pred, ref in zip(predictions, references):
        if ref in pred[:k]:
            hits += 1
    return hits / len(references)

# 计算 hit@5 和 hit@10
hit1 = hit_at_k(predictions, references, 1)
hit5 = hit_at_k(predictions, references, 5)
hit10 = hit_at_k(predictions, references, 10)

print(f"hit@1: {hit1:.2%}")
print(f"hit@5: {hit5:.2%}")
print(f"hit@10: {hit10:.2%}")

hit@1: 72.68%
hit@5: 73.49%
hit@10: 74.71%


# Openclip Example

In [10]:
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k', device='cuda')
# vit_tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [18]:
train_dataset[0]

{'key': '270',
 'question': 'Which object can be found in a jazz club',
 'context': 'You are likely to find a trumpet in a jazz club',
 'answers': {'answer_start': [25], 'text': ['trumpet']},
 'image_name': '/root/autodl-tmp/vqa/VQA-with-XProNet/data/KG_VQA/fvqa/exp_data/images/images/ILSVRC2012_test_00050748.JPEG'}

In [19]:
image = Image.open(train_dataset[0]['image_name']).convert("RGB")
image_input = clip_preprocess(image).unsqueeze(0).to(device)  # Unsqueeze 添加一个批次维度
text_tokens = clip_tokenizer.tokenize(train_dataset[0]['context']).to(device)

with torch.no_grad():
    image_features = clip_model.encode_image(image_input).float()
    text_features = clip_model.encode_text(text_tokens).float()
    
image_features.shape, text_features.shape

(torch.Size([1, 512]), torch.Size([1, 512]))

## 获取 bert 嵌入层

In [13]:
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
qa_bert = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.

In [116]:
# 准备输入数据
input_ids = tokenizer.encode(train_dataset['question'][0],  train_dataset['context'][0], add_special_tokens=False, return_tensors="pt")
# 获取嵌入层输出
with torch.no_grad():
    # 调用模型，确保请求隐藏状态
    outputs = qa_bert(input_ids, output_hidden_states=True)
    # 获取 last_hidden_state
    last_hidden_state = outputs.hidden_states[-1]

In [40]:
len(outputs.hidden_states), input_ids

(25,
 tensor([[2029, 4874, 2064, 2022, 2179, 1999, 1037, 4166, 2252, 2017, 2024, 3497,
          2000, 2424, 1037, 9368, 1999, 1037, 4166, 2252]]))

In [33]:
last_hidden_state.shape, last_hidden_state

(torch.Size([1, 20, 1024]),
 tensor([[[ 0.2163, -0.2789, -1.2007,  ...,  0.0475,  0.8559, -0.8639],
          [-0.7926, -0.3257, -0.0169,  ...,  0.9857,  1.0101, -1.0077],
          [-1.0194,  0.0947,  0.0371,  ...,  1.0660,  0.6618, -0.1563],
          ...,
          [-1.4010, -0.9902, -0.1381,  ...,  0.6080,  0.8213, -0.9767],
          [-1.3012, -0.9542, -0.0553,  ...,  0.3923,  0.6864, -1.2293],
          [-1.3286, -0.7271, -0.7457,  ...,  0.6030,  1.1488, -0.9676]]]))

## 拼接

In [43]:
# 扩展 image_features 和 text_features 到 [1, 20, 512]
image_features_expanded = image_features.repeat(1, 20, 1).to(device)
text_features_expanded = text_features.repeat(1, 20, 1).to(device)

# 拼接这些特征
combined_features = torch.cat([last_hidden_state.to(device), image_features_expanded, text_features_expanded], dim=2)  # 拼接在特征维度
combined_features.shape

torch.Size([1, 20, 2048])

## 输入嵌入向量到 qabert

inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.

In [51]:
# 创建一个线性层，将特征从 2048 降维到 1024
fc = nn.Linear(combined_features.shape[-1], 1024).to(device)
# 应用线性层
final_features = fc(combined_features).to(device)
qa_bert.to(device)
qa_bert(inputs_embeds=final_features)

QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[-2.7434, -1.8820, -3.2943, -3.1643, -3.0274, -3.2030, -3.0887, -3.0251,
         -1.3093, -3.1568, -3.4538, -3.1751, -3.3870, -3.2753, -3.1460, -2.0125,
         -2.8449, -2.3112, -2.6830, -2.0979]], device='cuda:0',
       grad_fn=<CloneBackward0>), end_logits=tensor([[-2.1292, -1.1948, -2.5703, -2.5366, -2.4104, -2.6203, -2.4670, -2.2586,
         -0.5762, -2.2948, -2.5812, -2.4727, -2.5212, -2.4964, -2.3861, -1.4723,
         -2.2931, -1.6980, -1.9327, -1.5011]], device='cuda:0',
       grad_fn=<CloneBackward0>), hidden_states=None, attentions=None)

# VQA train Example

In [11]:
tokenizer = BertTokenizerFast.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
qa_bert = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
qa_bert = qa_bert.to(device)
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='openai', device=device)
# clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k', device=device)
# clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k_augreg', device=device)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# image = Image.open(train_dataset[0]['image_name']).convert("RGB")
# image_input = clip_preprocess(image).unsqueeze(0).to(device)  # Unsqueeze 添加一个批次维度
# text_tokens = clip_tokenizer.tokenize(train_dataset[0]['context']).to(device)

# with torch.no_grad():
#     image_features = clip_model.encode_image(image_input).float()
#     text_features = clip_model.encode_text(text_tokens).float()
    
# image_features.shape, text_features.shape

In [12]:
def process_combine_data(batch):
    # 处理一批图像数据
    images = [Image.open(path).convert("RGB") for path in batch['image_name']]
    image_inputs = torch.stack([clip_preprocess(image) for image in images]).to(device)

    # 处理一批文本数据
    text_tokens = clip_tokenizer.tokenize([context for context in batch['context']]).to(device)

    with torch.no_grad():
        image_features = clip_model.encode_image(image_inputs).float()
        # print(image_features.shape)
        text_features = clip_model.encode_text(text_tokens).float()
        # print(text_features.shape)

    # 准备输入数据
    # input_ids = tokenizer.encode(batch['question'], batch['context'], add_special_tokens=False, return_tensors="pt")

    # 处理一批问题和上下文数据
    # input_ids = [tokenizer.encode(q, c, add_special_tokens=True, return_tensors="pt") for q, c in zip(batch['question'], batch['context'])]
    input_ids = [tokenizer.encode(q, c, add_special_tokens=True, return_tensors="pt", max_length=max_length, padding='max_length', truncation=True) for q, c in zip(batch['question'], batch['context'])]
    input_ids = torch.cat(input_ids, dim=0).to(device)  # 合并批次数据

    with torch.no_grad():
        outputs = qa_bert(input_ids, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        # print(last_hidden_states.shape)

    # 扩展图像特征和文本特征
    image_features_expanded = image_features.unsqueeze(1).repeat(1, last_hidden_states.shape[1], 1).to(device)
    # print(image_features_expanded.shape)
    text_features_expanded = text_features.unsqueeze(1).repeat(1, last_hidden_states.shape[1], 1).to(device)
    # print(text_features_expanded.shape)

    # 拼接特征
    combined_features = torch.cat([last_hidden_states.to(device), image_features_expanded, text_features_expanded], dim=2)

    return {"combined_features": combined_features}
    # return {"combined_features": combined_features.cpu().numpy()}


In [14]:
features = process_combine_data(train_dataset[:2])
features 

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


{'combined_features': tensor([[[ 1.4703, -1.1017,  0.0225,  ..., -0.3972, -0.2359,  0.1485],
          [-0.1402, -1.5653,  0.3490,  ..., -0.3972, -0.2359,  0.1485],
          [-0.3292, -1.4199,  0.1754,  ..., -0.3972, -0.2359,  0.1485],
          ...,
          [ 1.2099, -0.9790,  0.1319,  ..., -0.3972, -0.2359,  0.1485],
          [ 1.1496, -1.0239,  0.1406,  ..., -0.3972, -0.2359,  0.1485],
          [ 1.1948, -0.9850,  0.0982,  ..., -0.3972, -0.2359,  0.1485]],
 
         [[ 1.3351, -1.0875, -0.0421,  ..., -0.2256,  0.3205, -0.0110],
          [ 0.0961, -1.3786,  0.3776,  ..., -0.2256,  0.3205, -0.0110],
          [ 0.2673, -0.9926, -0.0607,  ..., -0.2256,  0.3205, -0.0110],
          ...,
          [ 1.2144, -0.9850,  0.0387,  ..., -0.2256,  0.3205, -0.0110],
          [ 1.1847, -1.0080,  0.0383,  ..., -0.2256,  0.3205, -0.0110],
          [ 1.1952, -0.9801,  0.0129,  ..., -0.2256,  0.3205, -0.0110]]],
        device='cuda:0')}

In [14]:
# 随机生成100个唯一的索引
indices = np.random.permutation(len(train_dataset))[:100]
# 使用选定的索引切割数据集
small_train_dataset = train_dataset.select(indices)

indices = np.random.permutation(len(validation_dataset))[:20]
small_val_dataset = validation_dataset.select(indices)

In [15]:
# 应用函数到数据集的每一批数据
# NOTE：先使用少量数据进行测试
processed_train_dataset = small_train_dataset.map(process_combine_data, batched=True, batch_size=32)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Map: 100%|██████████| 100/100 [00:04<00:00, 20.64 examples/s]


In [16]:
processed_val_dataset = small_val_dataset.map(process_combine_data, batched=True, batch_size=32)
processed_val_dataset

Map: 100%|██████████| 20/20 [00:00<00:00, 22.44 examples/s]


Dataset({
    features: ['key', 'question', 'context', 'answers', 'image_name', 'combined_features'],
    num_rows: 20
})

In [18]:
# 处理数据集
tokenized_combine_train_dataset = processed_train_dataset.map(prepare_train_features, batched=True, remove_columns=['key', 'question', 'context', 'answers', 'image_name'])
tokenized_combine_val_dataset = processed_val_dataset.map(prepare_train_features, batched=True, remove_columns=['key', 'question', 'context', 'answers', 'image_name'])


Map: 100%|██████████| 100/100 [00:00<00:00, 220.41 examples/s]
Map: 100%|██████████| 20/20 [00:00<00:00, 392.69 examples/s]


In [19]:
tokenized_combine_val_dataset

Dataset({
    features: ['combined_features', 'input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 20
})

In [80]:
class CustomModel(nn.Module):
    def __init__(self, bert_model):
        super(CustomModel, self).__init__()
        self.bert = bert_model
        self.fc = nn.Linear(2048, 1024)  # 拼接后的特征维度为 2048,将其降维到 1024

    def forward(self, combined_features, start_positions, end_positions):
        reduced_features = self.fc(combined_features)
        outputs = self.bert(inputs_embeds=reduced_features, start_positions=start_positions, end_positions=end_positions)
        return outputs.loss, outputs.start_logits, outputs.end_logits

In [40]:
vqa_model = CustomModel(qa_bert).to(device)  # 确保模型也在正确的设备上
# 自定义数据收集器
# data_collator = MyDataCollator()

# 设置训练参数
model_name = "vqa-bert"
batch_size = 16

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-squad",
    evaluation_strategy="epoch",
    learning_rate=0.001,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=5
)

trainer = Trainer(
    model=vqa_model,
    args=args,
    train_dataset=tokenized_combine_train_dataset,
    eval_dataset=tokenized_combine_val_dataset,
    data_collator=default_data_collator,  
    tokenizer=tokenizer
)

trainer.train()



Epoch,Training Loss,Validation Loss
1,6.3749,No log


TrainOutput(global_step=7, training_loss=6.259158406938825, metrics={'train_runtime': 56.5143, 'train_samples_per_second': 1.769, 'train_steps_per_second': 0.124, 'total_flos': 0.0, 'train_loss': 6.259158406938825, 'epoch': 1.0})

# VQA + Prototype

## Prototype Matrix Initialization

In [20]:
import numpy as np
from sklearn.cluster import KMeans

# 假设 processed_train_dataset 是已经加载的数据集
features = processed_train_dataset['combined_features']

# 将特征列表转换为 NumPy 数组，以便用于 K-Means 算法
features_matrix = np.array(features)
print(f"features_matrix shape: {features_matrix.shape}")

# features_matrix = features_matrix.reshape(features_matrix.shape[1], -1)
features_matrix = features_matrix.mean(axis=1)
print(f"features_matrix shape: {features_matrix.shape}")

# 设置 K-Means 算法，聚类数为 10
kmeans = KMeans(n_clusters=10, random_state=0).fit(features_matrix)

# 获取每个簇的中心点，这些中心点即为原型向量
prototype_vectors = kmeans.cluster_centers_
print(prototype_vectors.shape)
# 打印原型向量，查看结果
print("Prototype vectors:\n", prototype_vectors)
prototype_vectors = torch.tensor(prototype_vectors, dtype=torch.float32)
prototype_vectors = prototype_vectors.unsqueeze(0).to(device)

features_matrix shape: (100, 384, 2048)
features_matrix shape: (100, 2048)
(10, 2048)
Prototype vectors:
 [[ 1.19071023e+00 -9.64432258e-01  1.30624041e-04 ...  4.99812750e-02
  -8.89559196e-02  7.53895827e-02]
 [-6.18201035e-01 -8.81037543e-01 -2.01249824e-01 ... -2.08062132e-02
  -1.51775781e-01  7.27241635e-02]
 [ 1.18580058e+00 -9.63028422e-01 -1.37879116e-03 ...  4.00370925e-02
   2.01106527e-02  2.19146271e-01]
 ...
 [ 1.14892395e+00 -1.00487485e+00  4.84245156e-02 ... -4.68075089e-03
  -4.69022592e-02 -1.44994088e-02]
 [ 1.18890511e+00 -9.72748384e-01 -1.28681172e-02 ...  2.22443674e-01
  -3.89508198e-03  2.30963608e-02]
 [ 1.17891174e+00 -9.66156564e-01 -6.33837884e-03 ...  1.41137085e-01
  -7.55139544e-02  1.41927724e-01]]


## Prototype Querying And Responsing

In [49]:
# class Prototype(nn.Module):
#     def __init__(self, prototype_vectors, feature_dim=2048, num_prototypes=3):
#         """
#         Args:
#             prototype_vectors (np.ndarray): 原型向量（n_clusters， 2048）
#             feature_dim (int): 转换目标特征维度
#             num_prototypes (int): 每个特征选择的 topk 相似原型数量
#         """
#         super(Prototype, self).__init__()
#         self.prototypes = nn.Parameter(torch.tensor(prototype_vectors, dtype=torch.float32), requires_grad=False)
#         self.dim_reduction = nn.Linear(prototype_vectors.shape[1], feature_dim) # 用于降维原型向量到与combined_features相同的维度
#         self.num_prototypes = num_prototypes

#     def forward(self, combined_features):
#         # print(f"combined_features: {combined_features.shape}")
#         # 线性投影
#         projected_features = self.dim_reduction(self.prototypes)
#         # print(f"projected_features: {projected_features.shape}")
        
#         projected_features = projected_features.unsqueeze(1)
#         print(f"projected_features: {projected_features.shape}")
        
#         # combined_features = combined_features.unsqueeze(0)
#         # print(f"combined_features: {combined_features.shape}")
        
#         # 相似度计算（余弦相似度）
#         similarity = F.cosine_similarity(projected_features, combined_features, dim=2)
#         similarity = similarity.unsqueeze(0)
#         print(f"similarity: {similarity.shape}")
#         # 原型选择：选择每个特征最相关的前 num_prototypes 个原型
#         topk_vals, topk_indices = torch.topk(similarity, self.num_prototypes, dim=1)
#         # 权重计算：这里使用 softmax 来归一化权重
#         weights = F.softmax(topk_vals, dim=2)
#         # print(f"weights: {weights.shape}")
#         # 响应生成：加权求和原型向量
#         selected_prototypes = self.prototypes[topk_indices.squeeze(0)] 
#         # print(f"selected_prototypes: {selected_prototypes.shape}")
#         response = torch.sum(weights.unsqueeze(-1) * selected_prototypes, dim=1)
#         return response

# class Prototype(nn.Module):
#     def __init__(self, prototype_vectors, feature_dim=2048, num_prototypes=3):
#         super(Prototype, self).__init__()
#         self.prototypes = nn.Parameter(torch.tensor(prototype_vectors, dtype=torch.float32), requires_grad=False)
#         self.dim_reduction = nn.Linear(prototype_vectors.shape[1], feature_dim)
#         self.num_prototypes = num_prototypes

#     def forward(self, combined_features):
#         # 线性投影
#         projected_features = self.dim_reduction(self.prototypes)  # [num_prototypes, feature_dim]
#         projected_features = projected_features.unsqueeze(0)  # [1, num_prototypes, feature_dim]
        
#         # 扩展批处理尺寸
#         projected_features = projected_features.expand(combined_features.size(0), -1, -1)  # [batch_size, num_prototypes, feature_dim]
        
#         # 相似度计算（余弦相似度）
#         similarity = F.cosine_similarity(projected_features, combined_features.unsqueeze(1), dim=2)  # [batch_size, num_prototypes, num_features]
        
#         # 原型选择：选择每个特征最相关的前 num_prototypes 个原型
#         topk_vals, topk_indices = torch.topk(similarity, self.num_prototypes, dim=1)
        
#         # 权重计算：这里使用 softmax 来归一化权重
#         weights = F.softmax(topk_vals, dim=1)  # [batch_size, num_prototypes, num_features]

#         # 响应生成：加权求和原型向量
#         selected_prototypes = self.prototypes[topk_indices]  # [batch_size, num_features, num_prototypes, feature_dim]
#         response = torch.sum(weights.unsqueeze(-1) * selected_prototypes, dim=2)  # [batch_size, num_features, feature_dim]
#         return response


In [17]:
# 假设有一个输入特征
combined_features = torch.tensor(processed_train_dataset[0]['combined_features'])
prototype_vectors = kmeans.cluster_centers_
prototype_vectors = torch.tensor(prototype_vectors, dtype=torch.float32)

combined_features = combined_features.unsqueeze(0).to(device).float() 
prototype_vectors = prototype_vectors.unsqueeze(0).to(device).float()  
print(combined_features.shape, prototype_vectors.shape)
mtm = MultiThreadMemory(h=16, d_model=2048, topk=5, dropout=0.1)

In [18]:
mtm.to(device)

In [19]:
output = mtm(combined_features, prototype_vectors, prototype_vectors)
output.shape, output

## Training

In [21]:
# tokenizer = BertTokenizerFast.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
# qa_bert = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
# qa_bert = qa_bert.to(device)
# clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k', device=device)

prototype_vectors = kmeans.cluster_centers_
prototype_vectors = torch.tensor(prototype_vectors, dtype=torch.float32)
prototype_vectors = prototype_vectors.unsqueeze(0).to(device)

In [22]:
class VqaPrototypeModel(nn.Module):
    def __init__(self, prototype_vectors, device, batch_size=16, seq_length=38):
        super(VqaPrototypeModel, self).__init__()
        self.device = device
        self.bert = qa_bert.to(self.device)
        self.prototype = MultiThreadMemory(h=16, d_model=2048, topk=3, dropout=0.1, device=self.device).to(self.device)
        self.prototype_vectors = prototype_vectors.repeat(batch_size, seq_length, 1).to(self.device)
        # print(f"prototype_vectors shape: {prototype_vectors.shape}")
        self.fc = nn.Linear(4096, 1024).to(self.device)  # 注意，如果原始特征和响应被拼接，这里的输入维度应为 2048 + feature_dim

    def forward(self, combined_features, start_positions, end_positions):
        combined_features = torch.tensor(combined_features).to(self.device)
        # print(f"combined_features shape: {combined_features.shape}")

        response = self.prototype(combined_features, self.prototype_vectors, self.prototype_vectors, device=self.device)
        # print("Shape of Response:", response.shape)
        
        # 拼接原型响应和 BERT 输出
        final_combined_features = torch.cat([combined_features, response], dim=2)
        # print("Shape of final_combined_features:", final_combined_features.shape)
        
        reduced_features = self.fc(final_combined_features)
        # print("Shape of reduced_features:", reduced_features.shape)
        
        outputs = self.bert(inputs_embeds=reduced_features, start_positions=start_positions, end_positions=end_positions)
        
        return outputs.loss, outputs.start_logits, outputs.end_logits

In [21]:
# vqa_model = VqaPrototypeModel(prototype_vectors).to(device) 
# test_input = torch.tensor(tokenized_combine_train_dataset[0]['combined_features'], dtype=torch.float32)
# test_input = test_input.unsqueeze(0).to(device)
# print(test_input.shape)

# output = vqa_model(test_input, torch.tensor([0]).to(device), torch.tensor([0]).to(device))

In [24]:
torch.cuda.empty_cache()
vqa_model = VqaPrototypeModel(prototype_vectors, device=device).to(device) 

In [25]:
# 设置训练参数
model_name = "vqa-bert"
batch_size = 16

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-squad",
    evaluation_strategy="epoch",
    learning_rate=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=5,
    gradient_accumulation_steps=4,
)

trainer = Trainer(
    model=vqa_model,
    args=args,
    train_dataset=tokenized_combine_train_dataset,
    eval_dataset=tokenized_combine_val_dataset,
    data_collator=default_data_collator,  
    tokenizer=tokenizer
)

trainer.train()



  combined_features = torch.tensor(combined_features).to(self.device)


OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 