# DRIVE



In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive/VQA

/content/drive/MyDrive/VQA


# DATA

## AOKVQA

In [None]:
# `!export` would only set the variable inside the temporary subshell spawned for that one line.
# %env stores it in the kernel's environment, so the `!` cells below can expand $AOKVQA_DIR.
%env AOKVQA_DIR=/content/drive/MyDrive/VQA/data/aokvqa/

In [None]:
!echo $AOKVQA_DIR

/content/drive/MyDrive/VQA/data/aokvqa/


In [None]:
!mkdir -p $AOKVQA_DIR

In [None]:
!curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C $AOKVQA_DIR

aokvqa_v1p0_train.json
aokvqa_v1p0_val.json
aokvqa_v1p0_test.json
large_vocab_train.csv
specialized_vocab_train.csv


In [None]:
!ls /content/drive/MyDrive/VQA/data/aokvqa/

aokvqa_v1p0_test.json	aokvqa_v1p0_val.json   specialized_vocab_train.csv
aokvqa_v1p0_train.json	large_vocab_train.csv


## F-VQA

# CODE


## Cmd


In [None]:
!pwd

/content/drive/MyDrive/VQA


In [None]:
!mkdir -p /content/drive/MyDrive/VQA/cfgs

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/data

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code/utils

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code/model

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code/model/fusion_net

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code/model/answer_net

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code/data

In [None]:
!mkdir -p /content/drive/MyDrive/VQA/code/bash_script

In [17]:
!mkdir -p /content/drive/MyDrive/VQA/code/torchlight

In [None]:
!ls

cfgs  code  data  kg  run.ipynb


## Data


###### preprocess.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/data/preprocess.py
import os
import os.path as osp
import re
import random
import itertools
import h5py
import torch
import torch.utils.data as data
import pdb
from torch.utils.data.dataloader import default_collate
from collections import Counter
from PIL import Image
# this is used for normalizing questions
# Matches any (possibly empty) run of characters other than lowercase ASCII letters, digits and spaces.
_special_chars = re.compile('[^a-z0-9 ]*')

# these try to emulate the original normalization scheme for answers
# NOTE(review): '(?!<=\d)' is a negative lookahead for the literal text '<=' plus a digit, not the
# negative lookbehind '(?<!\d)'. Because the next character must be '.', that lookahead always succeeds,
# so this effectively matches any period that is not followed by a digit. Presumably kept on purpose to
# mimic the original scheme -- confirm before "fixing" it.
_period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)')
# A comma sitting directly between two digits, e.g. the ',' in '1,000'.
_comma_strip = re.compile(r'(\d)(,)(\d)')
# Punctuation characters, escaped so they can be dropped into a regex character class.
_punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!')
# Any single punctuation character. The characters are escaped a second time here; inside a character
# class the extra backslashes only denote a literal backslash, which is already part of the set, so the
# double escaping appears redundant rather than harmful.
_punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars)))
# A punctuation character that is directly preceded by a space or directly followed by a space.
_punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars))


def invert_dict(d):
    """Return a new dict that maps every value of ``d`` back to its key.

    When several keys share the same value, the key seen last during iteration wins.
    """
    return dict((value, key) for key, value in d.items())


def process_punctuation(s):
    """Normalize punctuation in an answer string.

    The steps intentionally follow an older normalization scheme that is somewhat broken, so
    odd-looking behaviour here may be deliberate. Compiled regexes are used instead of repeated
    string operations for speed.

    Returns ``s`` untouched when it contains no punctuation character. Otherwise returns the
    stripped, cleaned string, or the stripped input when cleaning leaves nothing behind.
    """
    if not _punctuation.search(s):
        return s

    # 1) drop punctuation that touches a space
    cleaned = _punctuation_with_a_space.sub('', s)
    # 2) if any comma sits between two digits, remove every comma
    if _comma_strip.search(cleaned):
        cleaned = cleaned.replace(',', '')
    # 3) remaining punctuation becomes a space, then periods are stripped
    cleaned = _period_strip.sub('', _punctuation.sub(' ', cleaned)).strip()

    return cleaned if cleaned != '' else s.strip()


def extract_vocab(iterable, top_k=None, start=0, input_vocab=None):
    """Build a token -> index vocabulary from an iterable of token lists.

    The tokens may be whole answers or the word tokens of questions.

    Args:
        iterable: iterable of token lists; all lists are flattened and counted together.
        top_k: when truthy, keep only the ``top_k`` most frequent tokens.
        start: index assigned to the first token; later tokens count up from it.
        input_vocab: accepted for interface compatibility; not used.

    Returns:
        dict mapping token to index, ordered by descending ``(count, token)``.
    """
    counts = Counter(itertools.chain.from_iterable(iterable))
    if top_k:
        candidates = [token for token, _ in counts.most_common(top_k)]
    else:
        candidates = list(counts.keys())
    # highest count first; ties are broken by reverse lexicographic token order
    candidates.sort(key=lambda token: (counts[token], token), reverse=True)
    return dict(zip(candidates, itertools.count(start)))


class CocoImages(data.Dataset):
    """Dataset over the ``.jpg`` files of one directory, indexed in ascending image-id order.

    The image id is the integer found between the last ``_`` of the filename and the first
    ``.`` that follows it.
    """

    def __init__(self, path, transform=None):
        """Scan ``path`` for images; ``transform`` (optional) is applied to each loaded image."""
        super().__init__()
        self.path = path
        self.id_to_filename = self._find_images()
        # sorted ids give a deterministic iteration order
        self.sorted_ids = sorted(self.id_to_filename)
        print('found {} images in {}'.format(len(self), self.path))
        self.transform = transform

    def _find_images(self):
        """Return a dict mapping image id to filename for every ``.jpg`` file in ``self.path``."""
        mapping = {}
        for name in os.listdir(self.path):
            if name.endswith('.jpg'):
                id_part = name.split('_')[-1].split('.')[0]
                mapping[int(id_part)] = name
        return mapping

    def __getitem__(self, item):
        """Return ``(image_id, image)`` for the ``item``-th id; the image is converted to RGB."""
        image_id = self.sorted_ids[item]
        full_path = os.path.join(self.path, self.id_to_filename[image_id])
        img = Image.open(full_path).convert('RGB')
        if self.transform is None:
            return image_id, img
        return image_id, self.transform(img)

    def __len__(self):
        """Number of images found."""
        return len(self.sorted_ids)


class Composite(data.Dataset):
    """ Dataset that is a composite of several Dataset objects. Useful for combining splits of a dataset.

    Items are addressed as if the wrapped datasets were concatenated in the order given.
    Vocabulary-style attributes and answer helpers are forwarded to the first wrapped dataset.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, item):
        """Return element ``item`` of the concatenation of all wrapped datasets.

        Negative indices count from the end of the whole composite.

        Raises:
            IndexError: if ``item`` lies outside ``[-len(self), len(self))``.
        """
        if item < 0:
            # translate to a non-negative position over the whole composite, not just the first dataset
            item = item + len(self)
            if item < 0:
                raise IndexError('Index too small for composite dataset')
        for d in self.datasets:
            size = len(d)
            if item < size:
                return d[item]
            item -= size
        raise IndexError('Index too large for composite dataset')

    def __len__(self):
        """Total number of items across all wrapped datasets."""
        return sum(map(len, self.datasets))

    # NOTE(review): VisualQA._get_answer_vectors in base.py takes (ways, answer_indices) and defines
    # _get_answer_sequences_w2v rather than _get_answer_sequences -- confirm which dataset classes
    # these two forwards are meant for.
    def _get_answer_vectors(self, answer_indices):
        return self.datasets[0]._get_answer_vectors(answer_indices)

    def _get_answer_sequences(self, answer_indices):
        return self.datasets[0]._get_answer_sequences(answer_indices)

    @property
    def vector(self):
        return self.datasets[0].vector

    @property
    def token_to_index(self):
        return self.datasets[0].token_to_index

    @property
    def answer_to_index(self):
        return self.datasets[0].answer_to_index

    @property
    def index_to_answer(self):
        return self.datasets[0].index_to_answer

    @property
    def num_tokens(self):
        return self.datasets[0].num_tokens

    @property
    def num_answer_tokens(self):
        return self.datasets[0].num_answer_tokens

    @property
    def vocab(self):
        return self.datasets[0].vocab


def eval_collate_fn(batch):
    """Collate a batch after sorting it in place by each sample's last field, largest first.

    The last field is the question length; descending order allows packed sequences later on.
    """
    def _last_field(sample):
        return sample[-1]

    batch.sort(key=_last_field, reverse=True)
    return default_collate(batch)

Writing /content/drive/MyDrive/VQA/code/data/preprocess.py


###### base.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/data/base.py
import json
import os
import os.path as osp
import nltk
import h5py
import torch
import torch.utils.data as data
import pdb
from nltk import word_tokenize, pos_tag
import re
import numpy as np
import sys
import pickle as pkl

################
from .preprocess import invert_dict


class VisualQA(data.Dataset):
    """Base dataset holding the question/answer vocabularies and the answer-embedding helpers.

    Subclasses are expected to provide ``self.questions`` and ``self.answers`` (used by
    ``__len__`` and ``max_answer_length``) and ``self.image_id_to_index`` (used by ``_load_image``).
    """

    def __init__(self,
                 args,
                 vector):
        """Load the vocabularies named in ``args`` and prepare the answer-embedding tables.

        Args:
            args: configuration object; the ``question_vocab_path``, ``fact_map``, ``relation_map``,
                ``method_choice`` attributes and the ``FVQA`` sub-config are read here.
            vector: word-vector store; indexed by token and expected to contain ``'UNK'``.
        """
        super(VisualQA, self).__init__()

        # vocab
        self.vector = vector
        self.args = args
        # process question
        # question_vocab_path: a joint question vocab across all dataset
        with open(self.args.question_vocab_path, 'r') as fd:
            question_vocab = json.load(fd)
        self.token_to_index = question_vocab['question']
        self._max_question_length = question_vocab['max_question_length']
        self.image_features_path = args.FVQA.feature_path
        self.index_to_token = invert_dict(self.token_to_index)

        answer_vocab_path = self.args.FVQA.answer_vocab_path
        fact_vocab_path = self.args.FVQA.fact_vocab_path
        relation_vocab_path = self.args.FVQA.relation_vocab_path

        # the prediction target decides which vocabulary plays the role of the "answer" vocabulary
        if self.args.fact_map:
            with open(fact_vocab_path, 'r') as fd:
                answer_vocab = json.load(fd)
        elif self.args.relation_map:
            with open(relation_vocab_path, 'r') as fd:
                answer_vocab = json.load(fd)
        else:
            with open(answer_vocab_path, 'r') as fd:
                answer_vocab = json.load(fd)
        self.answer_to_index = answer_vocab['answer']
        self.index_to_answer = invert_dict(self.answer_to_index)

        self.cached_answers_g2v = {}  # answer -> KG embedding
        self.cached_answers_w2v = {}  # answer -> averaged word vector
        # answer -> list of per-token word vectors; kept apart from cached_answers_w2v so the
        # vector encoder and the sequence encoder can never hand back each other's cached type
        self.cached_answers_w2v_seq = {}
        self.cached_answers_gae = {}
        self.cached_answers_bert = {}
        self.unk_vector = self.vector['UNK']
        if "KG" in self.args.method_choice:
            self._map_kg()
        if "GAE" in self.args.method_choice:
            # the "GAE" choice currently loads the BERT table; _map_gae() is not called
            self._map_bert()

    @property
    def max_question_length(self):
        """Maximum question length recorded in the question vocabulary file."""
        return self._max_question_length

    @property
    def max_answer_length(self):
        """Length of the longest entry in ``self.answers`` (computed once, then cached)."""
        assert hasattr(self, 'answers'), 'Dataloader must have access to answers'
        if not hasattr(self, '_max_answer_length'):
            self._max_answer_length = max(map(len, self.answers))
        return self._max_answer_length

    @property
    def num_tokens(self):
        """Number of entries in the question vocabulary."""
        return len(self.token_to_index)

    @property
    def num_answers(self):
        """Number of entries in the answer vocabulary."""
        return len(self.answer_to_index)

    def __len__(self):
        return len(self.questions)

    # Internal data utility---------------------------------------

    def _load_image(self, image_id):
        """ Load the stored features of one image.

        Requires ``self.features_file`` (opened by ``_create_image_id_to_index``) and
        ``self.image_id_to_index``.

        Returns:
            ``(features, spatial)`` tensors; ``spatial`` is a 1x1 zero placeholder unless the
            fusion model is ``'UD'`` or ``'BAN'``.
        """
        index = self.image_id_to_index[image_id]
        spa = torch.zeros([1, 1])  # init

        if self.args.fusion_model == 'UD' or self.args.fusion_model == 'BAN':
            spatials = self.features_file['spatial_features']
            dataset = self.features_file['image_features']  # read straight from the feature file
            spa = spatials[index].astype('float32')
            spa = torch.from_numpy(spa)
        else:
            dataset = self.features_file['features']  # read straight from the feature file

        img = dataset[index].astype('float32')

        return torch.from_numpy(img), spa

    def _create_image_id_to_index(self):
        """ Create a mapping from a COCO image id into the corresponding index into the h5 file """
        if not hasattr(self, 'features_file'):
            # Loading the h5 file has to be done here and not in __init__ because when the DataLoader
            # forks for multiple works, every child would use the same file object and fail
            # Having multiple readers using different file objects is fine though, so we just init in here.
            self.features_file = h5py.File(self.image_features_path, 'r')

        if self.args.fusion_model == 'UD' or self.args.fusion_model == 'BAN':
            import _pickle as cPickle
            # NOTE(review): unpickling can execute arbitrary code -- only point img_id2idx at trusted files.
            with open(self.args.FVQA.img_id2idx, "rb") as fd:
                image_id_to_index = cPickle.load(fd)
            self.s_dim = self.features_file['spatial_features'].shape[2]
            self.v_dim = self.features_file['image_features'].shape[2]  # read straight from the feature file

        else:
            with h5py.File(self.image_features_path, 'r') as features_file:
                image_ids = features_file['ids'][()]
            image_id_to_index = {id: i for i, id in enumerate(image_ids)}
        return image_id_to_index

    def _encode_question(self, question):
        """ Turn a question into a vector of indices and a question length

        Tokens missing from the vocabulary map to index 0. A question with more tokens than
        ``max_question_length`` raises IndexError.
        """
        vec = torch.zeros(self.max_question_length).long()
        for i, token in enumerate(question):
            index = self.token_to_index.get(token, 0)
            vec[i] = index
        return vec, len(question)

    def _map_kg(self):
        """Load the KG entity and relation embeddings and build the name -> row lookup ``stoi_kg``.

        Entity rows come first; relation ids are offset by the number of entities.
        Does nothing unless "KG" is part of ``args.method_choice``.
        """
        if "KG" not in self.args.method_choice:
            return
        entity_path = self.args.FVQA.entity_path  # vectors of the entities
        relation_path = self.args.FVQA.relation_path  # vectors of the relations
        relation2id_path = self.args.FVQA.relation2id_path  # relation name -> id
        entity2id_path = self.args.FVQA.entity2id_path  # entity name -> id

        a = np.load(entity_path)
        b = np.load(relation_path)
        self.map_kg = np.vstack((a, b))

        # NOTE(review): the embedding width is hard-coded to 300 here.
        self.map_kg = torch.Tensor(self.map_kg).view(-1, 300)

        self.stoi_kg = {}
        with open(os.path.join(entity2id_path), "r") as f:
            for line in f:
                line = re.split('\t|\n', line)[:2]
                self.stoi_kg[line[0]] = int(line[1])
        sz = len(self.stoi_kg)
        with open(os.path.join(relation2id_path), "r") as f:
            for line in f:
                line = re.split('\t|\n', line)[:2]
                self.stoi_kg[line[0]] = int(line[1]) + sz

    def _map_gae(self):
        """Load the GAE node embeddings and build the node -> row lookup ``stoi_gae``.

        Does nothing unless "GAE" is part of ``args.method_choice``.
        """
        if "GAE" not in self.args.method_choice:
            return

        _gae_path = self.args.FVQA.gae_path
        gae_path = osp.join(_gae_path, str(self.args.FVQA.gae_node_num) + "_init_" + self.args.FVQA.gae_init + ".pkl")
        print("gae file:", gae_path)
        # NOTE(review): unpickling can execute arbitrary code -- only load trusted embedding files.
        with open(gae_path, 'rb') as f:
            if sys.version_info > (3, 0):
                features = pkl.load(f, encoding='latin1')
            else:
                features = pkl.load(f)
        # row index -> gae vector
        self.map_gae = torch.FloatTensor(np.array(features)).view(-1, 300)
        vertices_f = osp.join(_gae_path, "g_nodes_" + str(self.args.FVQA.gae_node_num) + ".json")
        with open(vertices_f) as fp:
            vertices_list = json.load(fp)

        self.stoi_gae = {vertex: i for i, vertex in enumerate(vertices_list)}

    def _map_bert(self):
        """Build (or load from cache) the answer-index -> BERT embedding table ``map_bert``.

        An answer with an exact ConceptNet node name takes that node's embedding. Otherwise the
        embeddings of up to three nodes whose name contains, or is contained in, the answer are
        averaged. The result is cached as ``map_bert.pt`` under ``args.FVQA.bert_path``.
        Does nothing unless "GAE" is part of ``args.method_choice``.

        Raises:
            TypeError: if an answer matches no node at all, or the finished table contains NaN.
        """
        if "GAE" not in self.args.method_choice:
            return

        cache_path = osp.join(self.args.FVQA.bert_path, "map_bert.pt")
        if osp.exists(cache_path):
            _cache = torch.load(cache_path)
            self.map_bert = _cache['map_bert']  # embedding table
            self.stoi_bert = _cache['stoi_bert']  # answer -> row index
            return

        _bert_path = self.args.FVQA.bert_path

        bert_path = osp.join(_bert_path, "conceptnet_bert_embeddings.pt")
        print("bert file:", bert_path)
        _cache = torch.load(bert_path)  # torch.Size([78334, 1024])

        # zero-initialised so rows that no answer index points at are well defined
        self.map_bert = torch.zeros(self.args.FVQA.max_ans, self.args.bert_dim)

        with open(osp.join(_bert_path, "cn_node_names_for_embeddings.txt"), 'r', encoding='utf-8') as f:
            node_names = [re.split('\n', line)[0] for line in f]

        # name -> first row carrying that name (same row list.index() would find, without the scan)
        first_row = {}
        for row, node_name in enumerate(node_names):
            first_row.setdefault(node_name, row)

        self.stoi_bert = {}  # answer -> row index in map_bert
        for key, value in self.answer_to_index.items():
            self.stoi_bert[key] = value
            row = first_row.get(key)
            if row is not None:
                self.map_bert[value] = _cache[row, :]
                continue

            cnt = 0.0
            # accumulate on whatever device the embeddings were loaded to (no hard CUDA requirement)
            tmp = torch.zeros(1, self.args.bert_dim, device=_cache.device)
            for i, j in enumerate(node_names):
                if len(j) >= 4 and len(key) >= 3 and (key in j or j in key):
                    tmp += _cache[i, :]  # summed here, averaged below
                    cnt += 1
                if cnt >= 3:
                    break
            if cnt == 0:
                raise TypeError('cnt can not = 0 !!!')
            self.map_bert[value] = tmp / (cnt + 1e-12)

        # x != x is true only for NaN
        if (self.map_bert != self.map_bert).any():
            raise TypeError('map_bert contains NaN values')
        torch.save({'map_bert': self.map_bert, 'stoi_bert': self.stoi_bert}, cache_path)

    def _get_answer_vectors(self, ways, answer_indices):
        """Encode answer indices with the embedding family named by ``ways``.

        ``ways`` is one of ``'GAE'`` (BERT table), ``'KG'`` or ``'W2V'``; any other value falls
        through and returns ``None``.
        """
        dim = self.vector.dim
        if ways == 'GAE':
            dim = self.args.bert_dim
            return self._encode_answer_vector(self._encode_answer_vector_bert, dim, answer_indices)
        elif ways == 'KG':
            return self._encode_answer_vector(self._encode_answer_vector_g2v, dim, answer_indices)
        elif ways == 'W2V':
            return self._encode_answer_vector(self._encode_answer_vector_w2v, dim, answer_indices)

    def _encode_answer_vector(self, encode_model, dim, answer_indices):
        """Stack the ``encode_model`` embedding of every answer index into one tensor.

        ``answer_indices`` is either a list of lists (result shape N x C x dim) or a flat sequence
        of ints / single-element tensors (result shape len x dim). Index ``-1`` is encoded as
        ``self.unk_vector``.

        Returns:
            ``(vector, [])``; the second element is always an empty list.
        """
        if isinstance(answer_indices[0], list):
            N, C = len(answer_indices), len(answer_indices[0])
            vector = torch.zeros(N, C, dim)
            for i, answer_ids in enumerate(answer_indices):
                for j, answer_id in enumerate(answer_ids):
                    if answer_id != -1:
                        vector[i, j, :] = encode_model(self.index_to_answer[answer_id])
                    else:
                        vector[i, j, :] = self.unk_vector
        else:
            vector = torch.zeros(len(answer_indices), dim)
            for idx, answer_id in enumerate(answer_indices):

                if answer_id != -1:
                    if type(answer_id).__name__ == 'int':
                        vector[idx, :] = encode_model(self.index_to_answer[answer_id])
                    else:
                        # non-int entries (e.g. tensor elements) are unwrapped with .item()
                        vector[idx, :] = encode_model(self.index_to_answer[answer_id.item()])
                else:
                    vector[idx, :] = self.unk_vector
        return vector, []

    def _get_answer_sequences_w2v(self, answer_indices):
        """Encode answers as zero-padded sequences of word vectors.

        Accepts the same two input layouts as ``_encode_answer_vector``.

        Returns:
            ``(vector, lengths)`` where ``vector`` is padded to the longest sequence and
            ``lengths`` lists the unpadded length of every sequence in order.
        """
        seqs, lengths = [], []
        max_seq_length = 0
        if isinstance(answer_indices[0], list):
            N, C = len(answer_indices), len(answer_indices[0])
            for answer_ids in answer_indices:
                _seqs = []
                for answer_id in answer_ids:
                    if answer_id != -1:
                        _seqs.append(self._encode_answer_sequence_w2v(self.index_to_answer[answer_id]))
                    else:
                        _seqs.append([self.unk_vector])
                    max_seq_length = max(max_seq_length, len(_seqs[-1]))  # track the max length
                seqs.append(_seqs)

            vector = torch.zeros(N, C, max_seq_length, self.vector.dim)
            for i, _seqs in enumerate(seqs):
                for j, seq in enumerate(_seqs):
                    if len(seq) != 0:
                        vector[i, j, :len(seq), :] = torch.cat(seq, dim=0)
                    lengths.append(len(seq))
            assert len(lengths) == N * \
                C, 'Wrong lengths - length: {} vs N: {}, C: {} vs seqs: {}'.format(len(lengths), N, C, len(seqs))
        else:
            for answer_id in answer_indices:
                if answer_id != -1:
                    if type(answer_id).__name__ == 'int':
                        seqs.append(self._encode_answer_sequence_w2v(self.index_to_answer[answer_id]))
                    else:
                        seqs.append(self._encode_answer_sequence_w2v(self.index_to_answer[answer_id.item()]))
                else:
                    seqs.append([self.unk_vector])

                max_seq_length = max(max_seq_length, len(seqs[-1]))  # track the max length

            vector = torch.zeros(len(answer_indices), max_seq_length, self.vector.dim)
            for idx, seq in enumerate(seqs):
                if len(seq) != 0:
                    vector[idx, :len(seq), :] = torch.cat(seq, dim=0)
                lengths.append(len(seq))

        return vector, lengths

    def _encode_answer_vector_bert(self, answer):
        """Return the cached BERT embedding of ``answer``; zeros if it has no row in ``stoi_bert``."""
        if answer not in self.cached_answers_bert:
            answer_vec = torch.zeros(1, self.args.bert_dim)
            idk = self.stoi_bert.get(answer, -1)
            if idk >= 0:
                answer_vec = self.map_bert[idk]
            self.cached_answers_bert[answer] = answer_vec
        return self.cached_answers_bert[answer]

    def _encode_answer_vector_gae(self, answer):
        """Return the cached GAE embedding of ``answer``; zeros if it has no row in ``stoi_gae``."""
        if answer not in self.cached_answers_gae:
            answer_vec = torch.zeros(1, self.vector.dim)
            idk = self.stoi_gae.get(answer, -1)
            if idk >= 0:
                answer_vec = self.map_gae[idk].reshape(1, 300)
            self.cached_answers_gae[answer] = answer_vec
        return self.cached_answers_gae[answer]

    def _encode_answer_vector_g2v(self, answer):
        """Return the cached KG embedding of ``answer``; zeros if it has no row in ``stoi_kg``."""
        if answer not in self.cached_answers_g2v:
            answer_vec = torch.zeros(1, self.vector.dim)

            idk = self.stoi_kg.get(answer, -1)
            if idk >= 0:
                answer_vec = self.map_kg[idk].reshape(1, 300)
            self.cached_answers_g2v[answer] = answer_vec
        return self.cached_answers_g2v[answer]

    def _encode_answer_vector_w2v(self, answer):
        """Return the cached average of the word vectors of the known tokens in ``answer``."""
        if answer not in self.cached_answers_w2v:
            tokens = nltk.word_tokenize(answer)
            answer_vec = torch.zeros(1, self.vector.dim)
            cnt = 0
            for token in tokens:
                if self.vector.check(token):
                    answer_vec += self.vector[token]
                    cnt += 1
            # the epsilon keeps the division defined when no token is known
            self.cached_answers_w2v[answer] = answer_vec / (cnt + 1e-12)
        return self.cached_answers_w2v[answer]

    def _encode_answer_sequence_w2v(self, answer):
        """Return the cached list of 1 x dim word vectors, one per token of ``answer``.

        NOTE(review): unknown tokens use ``self.vector['<unk>']`` here while ``__init__`` reads
        ``self.vector['UNK']`` -- confirm both keys exist in the vector store.
        """
        if answer not in self.cached_answers_w2v_seq:
            tokens = nltk.word_tokenize(answer)
            answer_seq = []
            for token in tokens:
                if self.vector.check(token):
                    answer_seq.append(self.vector[token].view(1, self.vector.dim))
                else:
                    answer_seq.append(self.vector['<unk>'].view(1, self.vector.dim))
            self.cached_answers_w2v_seq[answer] = answer_seq

        return self.cached_answers_w2v_seq[answer]

    def _encode_multihot_labels(self, answers):
        """ Turn a list of answers into a count vector of length ``args.TEST.max_answer_index``.

        Answers missing from the vocabulary, or with an index at or beyond the limit, are ignored.
        """
        max_answer_index = self.args.TEST.max_answer_index
        answer_vec = torch.zeros(max_answer_index)
        for answer in answers:
            index = self.answer_to_index.get(answer)
            if index is not None:
                if index < max_answer_index:
                    answer_vec[index] += 1
        return answer_vec

    def evaluate(self, predictions):
        """Not implemented in the base class."""
        raise NotImplementedError

Writing /content/drive/MyDrive/VQA/code/data/base.py


###### fvqa.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/data/fvqa.py
import json
import os
import os.path as osp
import nltk
from collections import Counter
import torch
import torch.utils.data as data
import pdb

################
from .base import VisualQA
from .preprocess import process_punctuation


def get_loader(args, vector, train=False, val=False):
    """ Returns a data loader for the desired split

    Args:
        args: configuration; ``FVQA.data_choice``, ``FVQA.train_data_path`` / ``FVQA.test_data_path``,
            ``TRAIN.batch_size``, ``TEST.batch_size`` and ``TRAIN.data_workers`` are read here.
        vector: word-vector store handed on to the dataset.
        train: build the training split (shuffled, ``TRAIN.batch_size``).
        val: build the evaluation split (not shuffled, ``TEST.batch_size``).

    Exactly one of ``train`` / ``val`` must be true.
    """
    assert train + val == 1, 'need to set exactly one of {train, val} to True'
    id = args.FVQA.data_choice
    if train:
        filepath = "train" + id
        print("use train data:", id)
        filepath = os.path.join(args.FVQA.train_data_path, filepath)
    else:
        filepath = "test" + id
        filepath = os.path.join(args.FVQA.test_data_path, filepath)

    split = FVQA(  # dataset for this split
        args,
        path_for(args, train=train, val=val, filepath=filepath),  # question/answer json of the split
        vector,  # matching word vectors
        file_path=filepath
    )
    batch_size = args.TRAIN.batch_size
    if val:
        batch_size = args.TEST.batch_size
    loader = torch.utils.data.DataLoader(
        split,
        batch_size=batch_size,
        shuffle=bool(train),  # only shuffle the data in training
        pin_memory=True,
        num_workers=args.TRAIN.data_workers,
    )

    return loader


class FVQA(VisualQA):  # ok
    """ FVQA dataset, open-ended """

    def __init__(self, args, qa_path, vector, file_path=None):
        """Load the annotation json at ``qa_path`` and the encoded questions / answer indices.

        Depending on ``args.fact_map`` / ``args.relation_map`` the prediction target is the
        supporting fact, the relation, or (default) the answer. Encoded data is cached in a
        ``.pt`` file under ``file_path``.
        """
        self.args = args
        super(FVQA, self).__init__(args, vector)
        # load annotation
        with open(qa_path, 'r') as fd:
            self.qa_json = json.load(fd)

        # choose the target kind; each extractor yields one list of candidate strings per question
        if args.fact_map:
            name, extractor = "fact", prepare_fact
        elif args.relation_map:
            name, extractor = "relation", prepare_relation
        else:
            name, extractor = "answer", prepare_answers
        self.answers = list(extractor(self.qa_json))

        cache_filepath = self._get_cache_path(qa_path, file_path, name)

        # questions as (index vector, length) pairs, answers as vocabulary indices
        self.questions, self.answer_indices = self._qa_id_represent(cache_filepath)

    def open_hdf5(self):
        """Open the image feature store and build the per-question image id list."""
        self.image_features_path = self.args.FVQA.feature_path
        self.image_id_to_index = self._create_image_id_to_index()  # image id -> row in the feature file
        self.image_ids = self._get_img_id()

    def __getitem__(self, item):  # ok
        """Return ``(image, spa, question, label, item, question_length)`` for sample ``item``."""
        # the feature file is opened lazily on first access
        if not hasattr(self, 'image_ids'):
            self.open_hdf5()

        question, question_length = self.questions[item]
        # multi-hot count vector over the answer vocabulary
        label = self._encode_multihot_labels(self.answers[item])
        image, spa = self._load_image(self.image_ids[item])  # precomputed image features
        return image, spa, question, label, item, question_length

    def _get_cache_path(self, qa_path, file_path, name):
        """Compose the cache filename from target kind, split, method choice and answer count."""
        w2v = ""
        if "KG" in self.args.method_choice:
            prefix = "(w2vinit)_" if "w2v" in self.args.FVQA.entity_path else "_"
            w2v = prefix + self.args.FVQA.entity_num + "_" + self.args.FVQA.KGE

        # the split is inferred from the annotation path
        split_tag = "_and_question_train_" if "train" in qa_path else "_and_question_test_"
        filename = ("fvqa_" + name + split_tag + self.args.method_choice + w2v
                    + "_" + str(self.args.FVQA.max_ans) + ".pt")
        return osp.join(file_path, filename)

    def _qa_id_represent(self, cache_filepath):
        """Return ``(questions, answer_indices)``, reading the cache file or creating it."""
        if os.path.exists(cache_filepath):
            # train and test splits each have their own cache file
            _cache = torch.load(cache_filepath)
            return _cache['questions'], _cache['answer_indices']

        tokenized = list(prepare_questions(self.qa_json))  # one token list per question
        questions = list(map(self._encode_question, tokenized))  # (index vector, length) pairs

        # per question: vocabulary index of every candidate, -1 when it is not in the vocabulary
        lookup = self.answer_to_index.get
        answer_indices = [[lookup(candidate, -1) for candidate in candidates] for candidates in self.answers]
        torch.save({'questions': questions, 'answer_indices': answer_indices}, cache_filepath)
        return questions, answer_indices

    def _get_img_id(self):
        """Return one numeric image id per annotation entry, derived from its ``img_file`` name."""
        image_ids = []
        for entry_key in list(self.qa_json.keys()):
            filename = self.qa_json[entry_key]["img_file"]
            numeric_id = int(filename.split('_')[-1].split('.')[0])
            if not filename.endswith('.jpg'):
                numeric_id += 1000000  # keeps jpg and jpeg ids apart
            image_ids.append(numeric_id)
        return image_ids

    # def _generate_batch_answer(self, indices, counts):  # 获得每一个batch的500个候选答案。
    #     unique_answers = list(range(0, self.args.FVQA.max_ans))
    #     # unique_answers = list(set( aid for aids in indices for aid in aids ))
    #     answer_dict = {k: i for i, k in enumerate(unique_answers)}
    #     answer_vector = torch.zeros(len(indices), len(unique_answers))  # 128,500
    #
    #     for i in range(len(counts)):  # 128
    #         for j, c in zip(indices[i], counts[i]):
    #             answer_vector[i, answer_dict[j]] = c  # 把出现的次数附上
    #
    #     return unique_answers, answer_vector


def path_for(args, train=False, val=False, filepath=""):
    """Return the path of the question/answer json of the requested split.

    The filename is fixed to the ``_500`` release files. ``val`` is accepted but not consulted:
    anything other than ``train == True`` selects the test file.
    """
    if train == True:
        root, filename = args.FVQA.train_data_path, "all_qs_dict_release_train_500.json"
    else:
        root, filename = args.FVQA.test_data_path, "all_qs_dict_release_test_500.json"
    return os.path.join(root, filepath, filename)


def prepare_questions(questions_json):  # ok
    """ Tokenize and normalize the questions of an annotation dict.

    Every question is lowercased and its last character is dropped (presumably the trailing
    '?' -- verify against the data). Punctuation is then normalized and the text is split
    into a list of word tokens.
    """
    raw_questions = [questions_json[key]['question'] for key in list(questions_json.keys())]
    for raw in raw_questions:
        trimmed = raw.lower()[:-1]
        yield nltk.word_tokenize(process_punctuation(trimmed))  # e.g. ['I', 'LOVE', 'YOU']


def prepare_answers(answers_json):  # ok
    """ Yield, per annotation entry, its answer normalized and repeated ten times. """
    # outer list: one item per question; inner list: the answer sequence of that question
    repeated = [[answers_json[key]["answer"]] * 10 for key in list(answers_json.keys())]
    for answer_list in repeated:
        yield [process_punctuation(answer) for answer in answer_list]  # punctuation clean-up


def prepare_fact(answers_json):  # ok
    """ Yield, per annotation entry, its supporting fact normalized and repeated ten times.

    The ``fact`` field of an entry is read as ``(f1, relation, f2)``. The answer must equal
    ``f1`` or ``f2``; the supporting fact is the other one of the two.

    Raises:
        AssertionError: if the answer of an entry matches neither end of its fact.
    """
    support_facts = []
    for key in list(answers_json.keys()):
        answer = answers_json[key]["answer"]
        facts = answers_json[key]["fact"]
        f1 = facts[0]
        f2 = facts[2]
        if answer != f1 and answer != f2:
            # fail loudly instead of dropping into an interactive debugger, which would hang
            # non-interactive runs and DataLoader workers
            raise AssertionError(
                'answer {!r} of entry {!r} matches neither end of its fact {!r}'.format(answer, key, facts))
        fact = f2 if answer == f1 else f1
        support_facts.append([fact] * 10)  # inner list: the target sequence of one question
    for support_facts_list in support_facts:
        yield list(map(process_punctuation, support_facts_list))  # punctuation clean-up


def prepare_relation(answers_json):  # ok
    """Yield, for every record, element 1 of its fact repeated ten times and normalized.

    Element 1 of each record's ``"fact"`` sequence (named "relation" by this
    code) is copied ten times and every copy is passed through
    ``process_punctuation``.
    """
    # Outer list: one entry per question; inner list: ten copies of its relation.
    relations = [[answers_json[key]["fact"][1]] * 10 for key in answers_json.keys()]
    for copies in relations:
        # Strip punctuation etc. from each copy.
        yield [process_punctuation(rel) for rel in copies]

Writing /content/drive/MyDrive/VQA/code/data/fvqa.py


###### aokvqa.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/data/aokvqa.py
import json
import os
import os.path as osp
import nltk
from collections import Counter
import torch
import torch.utils.data as data
import pdb

################
from .base import VisualQA
from .preprocess import process_punctuation


def get_loader(args, vector, train=False, val=False):
    """Build the ``DataLoader`` for the requested split.

    Args:
        args: configuration object; reads ``FVQA.data_choice``,
            ``FVQA.train_data_path`` / ``FVQA.test_data_path``,
            ``TRAIN.batch_size``, ``TEST.batch_size`` and
            ``TRAIN.data_workers``.
        vector: word-vector object forwarded unchanged to the dataset.
        train: build the training split (data is shuffled).
        val: build the evaluation split (data is read in order).

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``FVQA`` dataset.

    Raises:
        AssertionError: unless exactly one of ``train`` / ``val`` is set.
    """
    assert train + val == 1, 'need to set exactly one of {train, val} to True'
    data_choice = args.FVQA.data_choice
    if train:
        filepath = "train" + data_choice
        print("use train data:", data_choice)
        filepath = os.path.join(args.FVQA.train_data_path, filepath)
    else:
        filepath = "test" + data_choice
        filepath = os.path.join(args.FVQA.test_data_path, filepath)

    split = FVQA(  # dataset holding the questions/answers of this split
        args,
        path_for(args, train=train, val=val, filepath=filepath),  # annotation json
        vector,  # matching word vectors
        file_path=filepath
    )
    batch_size = args.TRAIN.batch_size
    if val:
        batch_size = args.TEST.batch_size
    loader = torch.utils.data.DataLoader(
        split,
        batch_size=batch_size,
        # Only shuffle the data in training; evaluation keeps dataset order
        # so runs are reproducible.
        shuffle=bool(train),
        pin_memory=True,
        num_workers=args.TRAIN.data_workers,
    )

    return loader


class FVQA(VisualQA):  # ok
    """FVQA dataset, open-ended.

    Loads a question/answer annotation json, derives the per-question target
    strings (answers, supporting facts or relations, depending on ``args``)
    and caches the encoded questions / answer indices in a ``.pt`` file next
    to the data.  Image bookkeeping is set up lazily on first item access.

    Helpers such as ``_encode_question``, ``_encode_multihot_labels``,
    ``_load_image``, ``_create_image_id_to_index`` and the ``answer_to_index``
    mapping are not defined here; presumably they come from ``VisualQA`` --
    TODO confirm.
    """

    def __init__(self, args, qa_path, vector, file_path=None):
        """Read the annotations at ``qa_path`` and build/load the encoded cache.

        ``file_path`` is the directory in which the cache file is stored.
        """
        self.args = args
        # NOTE(review): this local is never used afterwards; the attribute
        # access itself still requires ``args.FVQA.answer_vocab_path`` to exist.
        answer_vocab_path = self.args.FVQA.answer_vocab_path
        super(FVQA, self).__init__(args, vector)
        # load annotation
        with open(qa_path, 'r') as fd:
            self.qa_json = json.load(fd)

        # print('extracting answers...')

        # Pick the prediction target; ``name`` becomes part of the cache file name.
        if args.fact_map:
            # Target is the supporting-fact entry.
            name = "fact"
            self.answers = list(prepare_fact(self.qa_json))  # list of per-question string lists, e.g. [[fact, fact, ...], ...]
        elif args.relation_map:
            name = "relation"
            self.answers = list(prepare_relation(self.qa_json))  # list of per-question string lists, e.g. [[rel, rel, ...], ...]
        else:
            name = "answer"
            self.answers = list(prepare_answers(self.qa_json))  # list of per-question string lists, e.g. [[ans, ans, ...], ...]

        cache_filepath = self._get_cache_path(qa_path, file_path, name)

        # self.support_relation = list(prepare_relation(self.qa_json))
        # Encoded questions plus the vocabulary index of every target string.
        self.questions, self.answer_indices = self._qa_id_represent(cache_filepath)
        # pdb.set_trace()
        # Image handling is deferred to open_hdf5().

    def open_hdf5(self):
        """Set up image bookkeeping; called lazily from ``__getitem__``.

        Only assigns three attributes here; whether a file is actually opened
        depends on ``_create_image_id_to_index`` (not visible in this class).
        """
        self.image_features_path = self.args.FVQA.feature_path
        self.image_id_to_index = self._create_image_id_to_index()  # map from image id to feature index
        # self.image_ids = [q['image_id'] for q in questions_json['questions']]
        self.image_ids = self._get_img_id()

    def __getitem__(self, item):  # ok
        """Return ``(image, spa, question, label, item, question_length)`` for one sample."""
        # First access on this object: initialise the image-side attributes.
        if not hasattr(self, 'image_ids'):
            self.open_hdf5()
        # if item > len(self.answers):
        #     pdb.set_trace()

        question, question_length = self.questions[item]  # encoded question and its length
        # sample answers
        # self.answer_indices[item]: [1,2,3] or [-1, -1 ...]
        # answer_cands = Counter(self.answer_indices[item])  # e.g. Counter({1: 1, 2: 1, 3: 1})
        # answer_indices = list(answer_cands.keys())  # distinct answer indices
        # counts = list(answer_cands.values())  # how often each answer occurs

        # The label is built from the target strings, not from self.answer_indices.
        label = self._encode_multihot_labels(self.answers[item])  # multi-hot encoding of the targets
        image_id = self.image_ids[item]
        image, spa = self._load_image(image_id)  # image features for this id
        # unique_answers, answer_vectors = self._generate_batch_answer(answer_indices, counts)
        # answer_vectors == label
        # assert answer_vectors == label
        # return image, spa, question, unique_answers, answer_vectors, label, item, question_length
        # pdb.set_trace()
        return image, spa, question, label, item, question_length

    def _get_cache_path(self, qa_path, file_path, name):
        """Compose the cache file path inside ``file_path``.

        The file name encodes the target ``name``, the split ("train" when the
        substring "train" occurs anywhere in ``qa_path``, otherwise "test"),
        ``args.method_choice``, an optional KG suffix and ``args.FVQA.max_ans``.
        """
        w2v = ""
        if "KG" in self.args.method_choice:
            # entity_num and KGE are concatenated directly, so they must
            # already be strings.
            if "w2v" in self.args.FVQA.entity_path:
                w2v = "(w2vinit)_" + self.args.FVQA.entity_num + "_" + self.args.FVQA.KGE
            else:
                w2v = "_" + self.args.FVQA.entity_num + "_" + self.args.FVQA.KGE
        if "train" in qa_path:
            cache_filepath = osp.join(file_path, "fvqa_" + name + "_and_question_train_" +
                                      self.args.method_choice + w2v + "_" + str(self.args.FVQA.max_ans) + ".pt")
        else:
            cache_filepath = osp.join(file_path, "fvqa_" + name + "_and_question_test_" + self.args.method_choice + w2v + "_" + str(
                self.args.FVQA.max_ans) + ".pt")
        return cache_filepath

    def _qa_id_represent(self, cache_filepath):
        """Return ``(questions, answer_indices)``, computing and caching them if needed."""
        if not os.path.exists(cache_filepath):
            # print('encoding questions...')
            questions = list(prepare_questions(self.qa_json))  # one token list per question
            questions = [self._encode_question(q) for q in questions]  # encode each token list

            # Map every target string to its vocabulary index; strings missing
            # from answer_to_index become -1.
            answer_indices = [[self.answer_to_index.get(_a, -1) for _a in a] for a in self.answers]
            torch.save({'questions': questions, 'answer_indices': answer_indices}, cache_filepath)

        else:
            # A cache for this split already exists (train and test use different files).
            _cache = torch.load(cache_filepath)
            questions = _cache['questions']  # encoded questions
            answer_indices = _cache['answer_indices']  # target indices
            # self.answer_vectors = _cache['answer_vectors']

        return questions, answer_indices

    def _get_img_id(self):
        """Derive one integer image id per record from its ``"img_file"`` name.

        The id is the integer between the last '_' and the first following '.'
        of the file name.
        """
        image_ids = []
        keys = list(self.qa_json.keys())
        for a in keys:
            filename = self.qa_json[a]["img_file"]
            id_and_extension = filename.split('_')[-1]
            id = int(id_and_extension.split('.')[0])
            if not filename.endswith('.jpg'):
                id += 1000000  # keep non-.jpg (e.g. .jpeg) files apart from .jpg ones
                # pdb.set_trace()
            image_ids.append(id)
        return image_ids

    # def _generate_batch_answer(self, indices, counts):  # 获得每一个batch的500个候选答案。
    #     unique_answers = list(range(0, self.args.FVQA.max_ans))
    #     # unique_answers = list(set( aid for aids in indices for aid in aids ))
    #     answer_dict = {k: i for i, k in enumerate(unique_answers)}
    #     answer_vector = torch.zeros(len(indices), len(unique_answers))  # 128,500
    #
    #     for i in range(len(counts)):  # 128
    #         for j, c in zip(indices[i], counts[i]):
    #             answer_vector[i, answer_dict[j]] = c  # 把出现的次数附上
    #
    #     return unique_answers, answer_vector


def path_for(args, train=False, val=False, filepath=""):
    """Return the annotation json path for the train or the test split.

    The file name is fixed to the 500-answer release.  ``val`` is accepted
    but not consulted: anything other than ``train == True`` selects the
    test file under ``args.FVQA.test_data_path``.
    """
    # tra = "all_qs_dict_release_train_" + str(args.FVQA.max_ans) + ".json"
    # tes = "all_qs_dict_release_test_" + str(args.FVQA.max_ans) + ".json"
    if train == True:
        root, filename = args.FVQA.train_data_path, "all_qs_dict_release_train_500.json"
    else:
        root, filename = args.FVQA.test_data_path, "all_qs_dict_release_test_500.json"
    return os.path.join(root, filepath, filename)


def prepare_questions(questions_json):  # ok
    """Yield one list of word tokens per question record.

    ``questions_json`` maps a question key to a record holding a
    ``'question'`` string.  Each question is lower-cased, its final character
    is dropped (presumably the trailing '?' -- TODO confirm every question
    ends with one), cleaned with ``process_punctuation`` and tokenized with
    NLTK.
    """
    # Gather every question string first so that a malformed record fails
    # before any token list is produced.
    raw_questions = [questions_json[key]['question'] for key in questions_json.keys()]
    for text in raw_questions:
        trimmed = text.lower()[:-1]
        # Result is a flat list of words, e.g. ['i', 'love', 'you'].
        yield nltk.word_tokenize(process_punctuation(trimmed))


def prepare_answers(answers_json):  # ok
    """Yield, for every record, its answer repeated ten times and normalized.

    Each record contributes one list: the record's ``"answer"`` string copied
    ten times, with ``process_punctuation`` applied to every copy.
    """
    # Outer list: one entry per question; inner list: that question's answers.
    replicated = [[answers_json[key]["answer"]] * 10 for key in answers_json.keys()]
    for copies in replicated:
        # Strip punctuation etc. from each copy.
        yield [process_punctuation(text) for text in copies]


def prepare_fact(answers_json):  # ok
    """Yield, for every record, the supporting-fact entry that is not the answer.

    Each record must hold an ``"answer"`` string and a ``"fact"`` sequence
    whose elements 0 and 2 are compared with the answer (presumably the two
    entities of an (entity, relation, entity) triple -- TODO confirm).  The
    element that differs from the answer is repeated ten times and each copy
    is passed through ``process_punctuation``.

    Raises:
        AssertionError: if the answer equals neither ``fact[0]`` nor
            ``fact[2]``.  The message names the offending record; no debugger
            is started, so batch jobs and DataLoader workers fail fast
            instead of blocking on stdin.
    """
    support_facts = []
    for key in answers_json.keys():
        record = answers_json[key]
        answer = record["answer"]
        facts = record["fact"]
        first, last = facts[0], facts[2]
        if answer == first:
            fact = last
        elif answer == last:
            fact = first
        else:
            raise AssertionError(
                "answer {!r} of question {!r} matches neither end of its fact {!r}".format(
                    answer, key, facts))
        # One inner list per question, mirroring the 10-answer VQA layout.
        support_facts.append([fact] * 10)
    for support_facts_list in support_facts:
        # Strip punctuation etc. from each copy.
        yield [process_punctuation(entry) for entry in support_facts_list]


def prepare_relation(answers_json):  # ok
    """Yield, for every record, element 1 of its fact repeated ten times and normalized.

    Element 1 of each record's ``"fact"`` sequence (named "relation" by this
    code) is copied ten times and every copy is passed through
    ``process_punctuation``.
    """
    # Outer list: one entry per question; inner list: ten copies of its relation.
    relations = [[answers_json[key]["fact"][1]] * 10 for key in answers_json.keys()]
    for copies in relations:
        # Strip punctuation etc. from each copy.
        yield [process_punctuation(rel) for rel in copies]

Writing /content/drive/MyDrive/VQA/code/data/aokvqa.py


## Model


###### \_\_init__.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/model/__init__.py
from .attention import BiAttention
from .classifier import SimpleClassifier
from .counting import Counter
from .fc import FCNet, BCNet

Writing /content/drive/MyDrive/VQA/code/model/__init__.py


###### fc.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/model/fc.py
from __future__ import print_function
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
import torch


class FCNet(nn.Module):
    """Stack of weight-normalized linear layers.

    For every consecutive pair in ``dims`` the stack holds an optional
    ``Dropout`` (when ``dropout > 0``), a weight-normalized ``Linear`` and an
    optional activation looked up by name in ``torch.nn`` (skipped when
    ``act`` is the empty string).
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        super().__init__()

        # The last pair is indexed explicitly so that a ``dims`` with fewer
        # than two entries still fails with an IndexError.
        widths = list(zip(dims[:-2], dims[1:-1])) + [(dims[-2], dims[-1])]

        modules = []
        for fan_in, fan_out in widths:
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
            modules.append(weight_norm(nn.Linear(fan_in, fan_out), dim=None))
            if act != '':
                modules.append(getattr(nn, act)())

        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the whole stack to ``x``."""
        return self.main(x)


class BCNet(nn.Module):
    """Bilinear connect network over two sets of feature vectors.

    ``v`` and ``q`` are each projected to ``h_dim * k`` features by their own
    ``FCNet``.  What ``forward`` returns depends on ``h_out``:

    * ``None``: the raw per-pair products, ``b x v x q x (h_dim * k)``;
    * ``h_out <= 32``: ``h_out`` maps weighted by the learned ``h_mat`` and
      shifted by ``h_bias``, ``b x h_out x v x q``;
    * otherwise: the products pushed through the linear layer ``h_net``,
      ``b x h_out x v x q``.
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=[.2, .5], k=3):
        super().__init__()

        self.c = 32  # h_out values up to this use the broadcast (h_mat) path
        self.k = k
        self.v_dim, self.q_dim = v_dim, q_dim
        self.h_dim, self.h_out = h_dim, h_out

        joint_dim = h_dim * self.k
        self.v_net = FCNet([v_dim, joint_dim], act=act, dropout=dropout[0])
        self.q_net = FCNet([q_dim, joint_dim], act=act, dropout=dropout[0])
        self.dropout = nn.Dropout(dropout[1])  # attention
        if k > 1:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out is not None:
            if h_out <= self.c:
                self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, joint_dim).normal_())
                self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
            else:
                self.h_net = weight_norm(nn.Linear(h_dim, h_out), dim=None)

    def forward(self, v, q):
        """Return the bilinear logits for every (v, q) pair; see the class docstring."""
        if self.h_out is None:
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            products = torch.matmul(v_, q_)  # b x h_dim x v x q
            return products.permute(0, 2, 3, 1)  # b x v x q x h_dim

        if self.h_out <= self.c:
            # Broadcast Hadamard product followed by a matrix product:
            # fast but memory hungry (epoch 1, time: 157.84).
            v_ = self.dropout(self.v_net(v)).unsqueeze(1)
            q_ = self.q_net(q)
            weighted = v_ * self.h_mat  # broadcast, b x h_out x v x h_dim
            logits = torch.matmul(weighted, q_.unsqueeze(1).transpose(2, 3))  # b x h_out x v x q
            return logits + self.h_bias  # b x h_out x v x q

        # Batch outer product followed by a linear projection:
        # memory efficient but slow (epoch 1, time: 304.87).
        v_ = self.dropout(self.v_net(v)).transpose(1, 2).unsqueeze(3)
        q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
        products = torch.matmul(v_, q_)  # b x h_dim x v x q
        logits = self.h_net(products.permute(0, 2, 3, 1))  # b x v x q x h_out
        return logits.permute(0, 3, 1, 2)  # b x h_out x v x q

    def forward_with_weights(self, v, q, w):
        """Pool the projected features with the given ``b x v x q`` weights ``w``."""
        v_ = self.v_net(v).transpose(1, 2).unsqueeze(2)  # b x d x 1 x v
        q_ = self.q_net(q).transpose(1, 2).unsqueeze(3)  # b x d x q x 1
        pooled = torch.matmul(torch.matmul(v_, w.unsqueeze(1)), q_)  # b x d x 1 x 1
        pooled = pooled.squeeze(3).squeeze(2)
        if self.k <= 1:
            return pooled
        # Average pooling over windows of k, times k, is sum-pooling.
        return self.p_net(pooled.unsqueeze(1)).squeeze(1) * self.k


class GroupMLP(nn.Module):
    """Two-layer MLP built from 1x1 ``Conv1d`` layers; the output layer may be grouped.

    Pipeline: conv1 -> LeakyReLU -> dropout -> conv2 (``groups`` groups).
    """

    def __init__(self, in_features, mid_features, out_features, drop=0.5, groups=1):
        super().__init__()

        self.conv1 = nn.Conv1d(in_features, mid_features, kernel_size=1)
        self.drop = nn.Dropout(p=drop)
        self.relu = nn.LeakyReLU()
        self.conv2 = nn.Conv1d(mid_features, out_features, kernel_size=1, groups=groups)

    def forward(self, a):
        """Map a 2-D ``(N, C)`` input to ``(N, out_features)``."""
        batch, channels = a.size()
        # Treat every sample as a length-1 sequence so Conv1d acts as a linear layer.
        hidden = self.conv1(a.view(batch, channels, 1))
        hidden = self.drop(self.relu(hidden))
        return self.conv2(hidden).view(batch, -1)


class GroupMLP_1lay(nn.Module):
    """``GroupMLP`` variant with a non-affine batch norm after the first layer.

    Pipeline: conv1 -> BatchNorm1d -> LeakyReLU -> dropout -> conv2.
    """

    def __init__(self, in_features, mid_features, out_features, drop=0.5, groups=1):
        super().__init__()

        self.conv1 = nn.Conv1d(in_features, mid_features, kernel_size=1)
        self.batch_norm_fusion = nn.BatchNorm1d(mid_features, affine=False)
        self.drop = nn.Dropout(p=drop)
        self.relu = nn.LeakyReLU()
        self.conv2 = nn.Conv1d(mid_features, out_features, kernel_size=1, groups=groups)

    def forward(self, a):
        """Map a 2-D ``(N, C)`` input to ``(N, out_features)``."""
        batch, channels = a.size()
        # Treat every sample as a length-1 sequence so Conv1d acts as a linear layer.
        hidden = self.batch_norm_fusion(self.conv1(a.view(batch, channels, 1)))
        hidden = self.drop(self.relu(hidden))
        return self.conv2(hidden).view(batch, -1)


class GroupMLP_2lay(nn.Module):
    """``GroupMLP`` variant with an extra grouped hidden layer.

    Pipeline: conv1 -> LeakyReLU -> conv2 -> BatchNorm1d (non-affine)
    -> LeakyReLU -> dropout -> conv3.
    """

    def __init__(self, in_features, mid_features, out_features, drop=0.5, groups=1):
        super().__init__()

        self.conv1 = nn.Conv1d(in_features, mid_features, kernel_size=1)
        self.batch_norm_fusion = nn.BatchNorm1d(mid_features, affine=False)
        self.drop = nn.Dropout(p=drop)
        self.relu = nn.LeakyReLU()
        self.conv2 = nn.Conv1d(mid_features, mid_features, kernel_size=1, groups=groups)
        self.conv3 = nn.Conv1d(mid_features, out_features, kernel_size=1, groups=groups)

    def forward(self, a):
        """Map a 2-D ``(N, C)`` input to ``(N, out_features)``."""
        batch, channels = a.size()
        # Treat every sample as a length-1 sequence so Conv1d acts as a linear layer.
        hidden = self.relu(self.conv1(a.view(batch, channels, 1)))
        hidden = self.relu(self.batch_norm_fusion(self.conv2(hidden)))
        return self.conv3(self.drop(hidden)).view(batch, -1)

Writing /content/drive/MyDrive/VQA/code/model/fc.py


###### vector.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/model/vector.py
import array
import zipfile
from tqdm import tqdm
from six.moves.urllib.request import urlretrieve
import os
import os.path as osp
import torch
import io

class Vector(object):
    """Pre-trained GloVe word vectors with an on-disk ``.pt`` cache.

    After construction the instance exposes ``itos`` (index -> token list),
    ``stoi`` (token -> index dict), ``vectors`` (``len(itos) x dim`` tensor)
    and ``dim``.
    """

    def __init__(self, cache_path,
                 vector_type='glove.840B', unk_init=torch.Tensor.zero_) -> object:
        """Load (downloading / parsing if necessary) the vectors into memory.

        Args:
            cache_path: directory holding the vector text file and its cache.
            vector_type: one of 'glove.42B', 'glove.840B', 'glove.6B'.
            unk_init: in-place initializer applied to the ``1 x dim`` tensor
                returned for unknown tokens.

        Raises:
            ValueError: if ``vector_type`` is not a known key.
            RuntimeError: if no vector file can be found or it is malformed.
        """
        urls = {
            'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip',
            'glove.840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
            'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip',
        }
        if vector_type not in urls:
            raise ValueError('unknown vector_type {!r}; expected one of {}'.format(
                vector_type, sorted(urls)))
        url = urls[vector_type]
        name = osp.splitext(osp.basename(url))[0] + '.txt'  # glove.840B.300d.txt

        self.unk_init = unk_init
        self.cache(name, cache_path, url=url)

    def __getitem__(self, token):
        """Return the vector of ``token``, or an ``unk_init``-ed ``1 x dim`` tensor."""
        if token in self.stoi:
            return self.vectors[self.stoi[token]]
        return self.unk_init(torch.Tensor(1, self.dim))

    def _prepare(self, vocab):
        """Build a ``len(vocab) x dim`` matrix from a token -> row-index mapping."""
        word2vec = torch.Tensor(len(vocab), self.dim)
        for token, idx in vocab.items():
            word2vec[idx, :] = self[token]

        return word2vec

    def check(self, token):
        """Return True when ``token`` has a pre-trained vector."""
        return token in self.stoi

    @staticmethod
    def _reporthook(progress_bar):
        """Adapt a tqdm bar to the ``urlretrieve`` ``reporthook`` callback signature."""
        last_block = [0]

        def update(count=1, block_size=1, total_size=None):
            if total_size is not None:
                progress_bar.total = total_size
            progress_bar.update((count - last_block[0]) * block_size)
            last_block[0] = count

        return update

    def cache(self, name, cache_path, url=None):
        """Populate ``itos`` / ``stoi`` / ``vectors`` / ``dim``.

        Order of preference: the ``<name>.pt`` cache, the extracted text file
        ``<name>``, and finally a download of ``url`` into ``cache_path``.
        A freshly parsed text file is written back as ``<name>.pt``.
        """
        # cache_path='.vector_cache',
        # name= "glove.840B.300d.txt"
        # url = 'http://nlp.stanford.edu/data/glove.840B.300d.zip'

        path = osp.join(cache_path, name)
        path_pt = "{}.pt".format(path)

        if osp.isfile(path_pt):
            print('* Loading vectors from {}'.format(path_pt))
            self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt)
            return

        # Download and extract the archive if the text file does not exist.
        if not osp.exists(path) and url:
            dest = osp.join(cache_path, os.path.basename(url))
            if not osp.exists(dest):
                print('[-] Downloading vectors from {}'.format(url))
                os.makedirs(cache_path, exist_ok=True)

                with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t:
                    urlretrieve(url, dest, reporthook=self._reporthook(t))

            print('[-] Extracting vectors into {}'.format(path))
            ext = os.path.splitext(dest)[1][1:]
            if ext == 'zip':
                with zipfile.ZipFile(dest, "r") as zf:
                    zf.extractall(cache_path)

        if not os.path.isfile(path):
            raise RuntimeError('no vectors found at {}'.format(path))

        # build vocab list
        itos, vectors, dim = [], array.array(str('d')), None

        # Try to read the whole file with utf-8 encoding.
        binary_lines = False
        try:
            with io.open(path, encoding="utf8") as f:
                lines = [line for line in f]
        # If there are malformed lines, read in binary mode
        # and manually decode each word from utf-8
        except UnicodeDecodeError:
            print("[!] Could not read {} as UTF8 file, "
                  "reading file as bytes and skipping "
                  "words with malformed UTF8.".format(path))
            with open(path, 'rb') as f:
                lines = [line for line in f]
            binary_lines = True

        # Explicitly splitting on " " is important, so we don't
        # get rid of Unicode non-breaking spaces in the vectors.
        separator = b" " if binary_lines else " "

        print("[-] Loading vectors from {}".format(path))
        for line in tqdm(lines, total=len(lines)):
            entries = line.rstrip().split(separator)
            word, entries = entries[0], entries[1:]
            if dim is None and len(entries) > 1:
                dim = len(entries)
            elif len(entries) == 1:
                print("Skipping token {} with 1-dimensional "
                      "vector {}; likely a header".format(word, entries))
                continue
            elif dim != len(entries):
                raise RuntimeError(
                    "Vector for token {} has {} dimensions, but previously "
                    "read vectors have {} dimensions. All vectors must have "
                    "the same number of dimensions.".format(word, len(entries), dim))

            if binary_lines:
                # In the bytes fallback every token is decoded on its own so a
                # single malformed word does not discard the whole file.
                try:
                    word = word.decode('utf-8')
                except UnicodeDecodeError:
                    print("Skipping non-UTF8 token {}".format(repr(word)))
                    continue

            vectors.extend(float(x) for x in entries)
            itos.append(word)

        self.itos = itos
        self.stoi = {word: i for i, word in enumerate(itos)}
        self.vectors = torch.Tensor(vectors).view(-1, dim)
        self.dim = dim
        print('* Caching vectors to {}'.format(path_pt))
        torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt)

Overwriting /content/drive/MyDrive/VQA/code/model/vector.py


###### attention.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/model/attention.py
import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
from .fc import FCNet, BCNet
import torch.nn.functional as F

class BaseAttention(nn.Module):
    """Attention over objects computed from the concatenated (object, question) features."""

    def __init__(self, v_dim, q_dim, num_hid):
        super().__init__()
        self.nonlinear = FCNet([v_dim + q_dim, num_hid])
        self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)

    def forward(self, v, q):
        """
        v: [batch, k, vdim]
        q: [batch, qdim]

        Returns the logits soft-maxed over the object dimension (dim 1).
        """
        return F.softmax(self.logits(v, q), dim=1)

    def logits(self, v, q):
        """Return one unnormalized score per object, shape [batch, k, 1]."""
        # Pair the question vector with every object before concatenating.
        tiled_q = q.unsqueeze(1).expand(-1, v.size(1), -1)
        joint_repr = self.nonlinear(torch.cat((v, tiled_q), 2))
        return self.linear(joint_repr)


class UpDnAttention(nn.Module):
    """Attention over objects from the element-wise product of projected object and question features."""

    def __init__(self, v_dim, q_dim, num_hid, dropout=0.2):
        super().__init__()

        self.v_proj = FCNet([v_dim, num_hid])
        self.q_proj = FCNet([q_dim, num_hid])
        self.dropout = nn.Dropout(dropout)
        self.linear = weight_norm(nn.Linear(num_hid, 1), dim=None)

    def forward(self, v, q):
        """
        v: [batch, k, vdim]
        q: [batch, qdim]

        Returns the logits soft-maxed over the object dimension (dim 1).
        """
        return F.softmax(self.logits(v, q), dim=1)

    def logits(self, v, q):
        """Return one unnormalized score per object, shape [batch, k, 1]."""
        _, k, _ = v.size()
        v_proj = self.v_proj(v)  # [batch, k, num_hid]
        # Pair the projected question with every object.
        q_proj = self.q_proj(q).unsqueeze(1).expand(-1, k, -1)
        joint_repr = self.dropout(v_proj * q_proj)
        return self.linear(joint_repr)

class SanAttention(nn.Module):
  """Convolutional attention producing ``glimpses`` unnormalized maps over a feature map."""

  def __init__(self, v_features, q_features, mid_features, glimpses, drop=0.0):
    super().__init__()
    # No bias on the visual branch: the bias of q_lin already covers the sum.
    self.v_conv = nn.Conv2d(v_features, mid_features, kernel_size=1, bias=False)
    self.q_lin = nn.Linear(q_features, mid_features)
    self.x_conv = nn.Conv2d(mid_features, glimpses, kernel_size=1)

    self.drop = nn.Dropout(drop)
    self.relu = nn.LeakyReLU(inplace=True)

  def forward(self, v, q):
    """Fuse the feature map ``v`` with the question vector ``q`` and return the attention logits."""
    v_feat = self.v_conv(self.drop(v))
    # Broadcast the projected question over every spatial position of v_feat.
    q_feat = tile_2d_over_nd(self.q_lin(self.drop(q)), v_feat)
    fused = self.relu(v_feat + q_feat)
    return self.x_conv(self.drop(fused))

def tile_2d_over_nd(feature_vector, feature_map):
  """ Broadcast a 2-D feature vector over every spatial position of a feature map.
    The vector must share the batch size and channel count of the feature map;
    the result is an expanded view shaped like ``feature_map``.
  """
  batch, channels = feature_vector.size()
  # One singleton axis per spatial dimension of the feature map.
  trailing_ones = (1,) * (feature_map.dim() - 2)
  return feature_vector.view(batch, channels, *trailing_ones).expand_as(feature_map)

def apply_attention(input, attention):
  """ Apply any number of attention maps over the input.
    The attention map has to have the same size in all dimensions except dim=1.

    Each glimpse is soft-maxed over all spatial positions and used to take a
    weighted sum of ``input``.  Returns a ``(n, glimpses * c)`` tensor.
  """
  n, c = input.size()[:2]
  glimpses = attention.size(1)

  # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged
  input = input.view(n, c, -1)
  attention = attention.view(n, glimpses, -1)
  s = input.size(2)

  # Collapse batch and glimpse dims so that every glimpse is normalized
  # separately.  The softmax dim is passed explicitly: relying on the implicit
  # choice is deprecated and warns on every call.
  attention = attention.view(n * glimpses, -1)
  attention = F.softmax(attention, dim=1)

  # apply the weighting by creating a new dim to tile both tensors over
  target_size = [n, glimpses, c, s]
  input = input.view(n, 1, c, s).expand(*target_size)
  attention = attention.view(n, glimpses, 1, s).expand(*target_size)
  weighted = input * attention
  # sum over only the spatial dimension
  weighted_mean = weighted.sum(dim=3)
  # the shape at this point is (n, glimpses, c)
  return weighted_mean.view(n, -1)


class BiAttention(nn.Module):
    """Bilinear attention: a joint distribution over (object, token) pairs per glimpse."""

    def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2, .5]):
        super().__init__()

        self.glimpse = glimpse
        bilinear = BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3)
        # Weight-normalize the bilinear network's ``h_mat`` parameter.
        self.logits = weight_norm(bilinear, name='h_mat', dim=None)

    def forward(self, v, q, v_mask=True):
        """
        v: [batch, k, vdim]
        q: [batch, q_num, qdim]

        Returns ``(p, logits)``, both shaped b x g x v x q.
        """
        return self.forward_all(v, q, v_mask)

    def forward_all(self, v, q, v_mask=True):
        """Compute the attention map and its logits; all-zero rows of ``v`` are masked out when ``v_mask``."""
        v_num, q_num = v.size(1), q.size(1)
        logits = self.logits(v, q)  # b x g x v x q

        if v_mask:
            # Objects whose feature vector is entirely zero get -inf logits,
            # i.e. zero attention after the softmax.
            empty_rows = v.abs().sum(2).eq(0)
            mask = empty_rows[:, None, :, None].expand(logits.size())
            logits.data.masked_fill_(mask.data, -float('inf'))

        # Normalize jointly over all (object, token) pairs of each glimpse.
        flat = logits.view(-1, self.glimpse, v_num * q_num)
        p = F.softmax(flat, dim=2)
        return p.view(-1, self.glimpse, v_num, q_num), logits

Overwriting /content/drive/MyDrive/VQA/code/model/attention.py


###### language_model.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/model/language_model.py
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn.init as init
import pdb


class WordEmbedding(nn.Module):
    """Word embedding table followed by dropout.

    The table has ``ntoken + 1`` rows; the extra last row (index ``ntoken``)
    is the padding index, which agrees *implicitly* with the definition in
    Dictionary.
    """

    def __init__(self, ntoken, emb_dim, dropout=0):
        super().__init__()
        self.emb = nn.Embedding(ntoken + 1, emb_dim, padding_idx=ntoken)
        self.dropout = nn.Dropout(dropout)
        self.ntoken, self.emb_dim = ntoken, emb_dim

    def init_embedding(self, np_file):
        """Copy the pre-trained ``ntoken x emb_dim`` weights into the non-padding rows.

        Despite its name, ``np_file`` is used directly as the weight matrix;
        nothing is loaded from disk here.
        """
        # weight_init = torch.from_numpy(np.load(np_file))
        expected_shape = (self.ntoken, self.emb_dim)
        assert np_file.shape == expected_shape
        self.emb.weight.data[:self.ntoken] = np_file

    def forward(self, x):
        """Look up the embeddings of the indices in ``x`` and apply dropout."""
        return self.dropout(self.emb(x))


class UpDnQuestionEmbedding(nn.Module):
    """GRU/LSTM encoder that maps a sequence of word embeddings to a question vector."""

    def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout=0, rnn_type='GRU'):
        """Module for question embedding
        """
        super().__init__()
        assert rnn_type in ('LSTM', 'GRU')
        rnn_cls = {'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]

        self.rnn = rnn_cls(
            in_dim, num_hid, nlayers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=True)

        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type
        self.ndirections = 1 + int(bidirect)

    def init_hidden(self, batch):
        """Return an all-zero initial state (a pair of tensors for LSTM)."""
        # A parameter is used only to pick up the dtype/device of the module.
        reference = next(self.parameters()).data
        state_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
        if self.rnn_type == 'LSTM':
            return reference.new_zeros(state_shape), reference.new_zeros(state_shape)
        return reference.new_zeros(state_shape)

    def _encode(self, x):
        """Run the RNN from a zero state and return its per-step outputs."""
        initial_state = self.init_hidden(x.size(0))
        self.rnn.flatten_parameters()
        outputs, _ = self.rnn(x, initial_state)
        return outputs

    def forward(self, x):
        """Return one vector per sequence.

        x: [batch, sequence, in_dim]
        """
        outputs = self._encode(x)
        if self.ndirections == 1:
            return outputs[:, -1]

        # Forward direction: last step; backward direction: first step.
        last_forward = outputs[:, -1, :self.num_hid]
        first_backward = outputs[:, 0, self.num_hid:]
        return torch.cat((last_forward, first_backward), dim=1)

    def forward_all(self, x):
        """Return the RNN outputs of every time step.

        x: [batch, sequence, in_dim]
        """
        return self._encode(x)


class QuestionEmbedding(nn.Module):
    """GRU/LSTM question encoder with optional word dropout and pre/post-RNN dropout.

    With ``bidirect=True`` the hidden size per direction is ``num_hid // 2``
    so that the returned vector has ``num_hid`` features in total.
    """

    def __init__(self, in_dim, num_hid, nlayers=1, bidirect=True, dropout=0, rnn_type='GRU', words_dropout=None,
                 dropout_before_rnn=None,
                 dropout_after_rnn=None):
        """Module for question embedding

        Args:
            in_dim: size of each input word embedding.
            num_hid: size of the returned question vector.
            nlayers: number of stacked RNN layers.
            bidirect: use a bidirectional RNN.
            dropout: dropout between RNN layers.
            rnn_type: 'LSTM' or 'GRU'.
            words_dropout: fraction of token positions zeroed per question
                during training (``None`` or 0 disables it).
            dropout_before_rnn: dropout probability on the RNN input, or None.
            dropout_after_rnn: dropout probability on the returned vector, or None.
        """
        super(QuestionEmbedding, self).__init__()
        assert rnn_type == 'LSTM' or rnn_type == 'GRU'
        rnn_cls = nn.LSTM if rnn_type == 'LSTM' else nn.GRU
        self.bidirect = bidirect
        self.ndirections = 1 + int(bidirect)
        if bidirect:
            num_hid = int(num_hid / 2)
        self.words_dropout = words_dropout
        if dropout_before_rnn is not None:
            self.dropout_before_rnn = nn.Dropout(p=dropout_before_rnn)
        else:
            self.dropout_before_rnn = None
        self.rnn = rnn_cls(
            in_dim, num_hid, nlayers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=True)
        if dropout_after_rnn is not None:
            self.dropout_after_rnn = nn.Dropout(p=dropout_after_rnn)
        else:
            self.dropout_after_rnn = None

        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type

    def init_hidden(self, batch):
        """Return an all-zero initial state (a pair of tensors for LSTM)."""
        # A parameter is used only to pick up the dtype/device of the module.
        weight = next(self.parameters()).data
        hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
        if self.rnn_type == 'LSTM':
            return weight.new_zeros(hid_shape), weight.new_zeros(hid_shape)
        return weight.new_zeros(hid_shape)

    def forward(self, x, qlen=None):
        """Encode ``x`` ([batch, sequence, in_dim]) into one vector per question.

        ``qlen`` is accepted for interface compatibility but not used.
        """
        batch = x.size(0)
        num_tokens = x.size(1)
        # Word dropout is a regularizer, so like every other dropout it is
        # applied in training mode only; evaluation must stay deterministic
        # and must not alter its input.
        if self.training and self.words_dropout is not None and self.words_dropout > 0:
            num_dropout = int(self.words_dropout * num_tokens)
            rand_ixs = np.random.randint(0, num_tokens, (batch, num_dropout))
            # NOTE(review): this zeroes the selected positions of the caller's
            # tensor in place -- confirm no caller reuses ``x`` afterwards.
            for bix, token_ixs in enumerate(rand_ixs):
                x[bix, token_ixs] *= 0
        hidden = self.init_hidden(batch)
        self.rnn.flatten_parameters()
        if self.dropout_before_rnn is not None:
            x = self.dropout_before_rnn(x)

        q_words_emb, hidden = self.rnn(x, hidden)  # q_words_emb: B x num_words x gru_dim, hidden: 1 x B x gru_dim

        if self.bidirect:
            # Forward direction: last step; backward direction: first step.
            forward_ = q_words_emb[:, -1, :self.num_hid]
            backward = q_words_emb[:, 0, self.num_hid:]
            out = torch.cat((forward_, backward), dim=1)
        else:
            out = q_words_emb[:, -1]

        if self.dropout_after_rnn is not None:
            out = self.dropout_after_rnn(out)
        return out

class Seq2SeqRNN(nn.Module):
  """Encode padded, variable-length sequences into one vector per sequence.

  The sequences are packed by length, run through an LSTM or GRU, and the
  final hidden state of the top layer is returned (both directions
  concatenated when bidirectional), in the original batch order.

  Args:
    input_features: size of each input token vector.
    rnn_features: hidden size per direction.
    num_layers: number of stacked recurrent layers.
    drop: dropout passed to the recurrent module.
    rnn_type: ``'LSTM'`` or ``'GRU'``; anything else raises ``ValueError``.
    rnn_bidirectional: whether to run the RNN in both directions.
  """

  def __init__(self, input_features, rnn_features, num_layers=1, drop=0.0,
               rnn_type='LSTM', rnn_bidirectional=False):
    super(Seq2SeqRNN, self).__init__()
    self.bidirectional = rnn_bidirectional

    if rnn_type == 'LSTM':
      rnn_cls = nn.LSTM
    elif rnn_type == 'GRU':
      rnn_cls = nn.GRU
    else:
      raise ValueError('Unsupported Type')

    self.rnn = rnn_cls(input_size=input_features,
                       hidden_size=rnn_features, dropout=drop,
                       num_layers=num_layers, batch_first=True,
                       bidirectional=rnn_bidirectional)

    self.init_weight(rnn_bidirectional, rnn_type)

  def init_weight(self, bidirectional, rnn_type):
    """Xavier-initialise every gate matrix and zero every bias, for all layers.

    Layer 0 is visited first and in the order ih, hh, reverse ih, reverse hh,
    so a single-layer network draws its random numbers in the same order as
    before.
    """
    suffixes = ['', '_reverse'] if bidirectional else ['']
    for layer in range(self.rnn.num_layers):
      for suffix in suffixes:
        tag = 'l{}{}'.format(layer, suffix)
        self._init_rnn(getattr(self.rnn, 'weight_ih_' + tag), rnn_type)
        self._init_rnn(getattr(self.rnn, 'weight_hh_' + tag), rnn_type)
        getattr(self.rnn, 'bias_ih_' + tag).data.zero_()
        getattr(self.rnn, 'bias_hh_' + tag).data.zero_()

  def _init_rnn(self, weight, rnn_type):
    """Initialise each gate's slice of a stacked gate matrix independently."""
    # PyTorch stacks 4 gate matrices for an LSTM and 3 for a GRU along dim 0.
    chunk_size = 4 if rnn_type == 'LSTM' else 3
    for w in weight.chunk(chunk_size, 0):
      init.xavier_uniform_(w)

  def forward(self, q_emb, q_len):
    """Return the final hidden state for every sequence in the batch.

    Args:
      q_emb: padded batch indexed as ``[batch, sequence, input_features]``.
      q_len: per-sequence lengths (anything ``torch.LongTensor`` accepts).

    Returns:
      Tensor of shape ``[batch, rnn_features * num_directions]`` on the same
      device as ``q_emb``, in the original batch order.
    """
    device = q_emb.device
    lengths = torch.LongTensor(q_len)
    # pack_padded_sequence expects the batch sorted by decreasing length.
    lens, indices = torch.sort(lengths, 0, True)

    packed = pack_padded_sequence(q_emb[indices.to(device)], lens.tolist(), batch_first=True)
    if isinstance(self.rnn, nn.LSTM):
      _, (outputs, _) = self.rnn(packed)
    else:
      _, outputs = self.rnn(packed)

    # `outputs` is (num_layers * num_directions, batch, hidden); the last
    # entries belong to the top layer (forward, then backward).
    if self.bidirectional:
      outputs = torch.cat([outputs[-2], outputs[-1]], dim=1)
    else:
      outputs = outputs[-1]

    # Undo the length sort so row i matches input sequence i again.
    _, _indices = torch.sort(indices, 0)
    return outputs[_indices.to(device)]

Overwriting /content/drive/MyDrive/VQA/code/model/language_model.py


#### Fusion Network


###### \_\_init__.py

In [3]:
%%writefile /content/drive/MyDrive/VQA/code/model/fusion_net/__init__.py
from .updn import UD
from .ban import BAN
from .san import SAN
from .mlp import MLP

Overwriting /content/drive/MyDrive/VQA/code/model/fusion_net/__init__.py


###### mlp.py

In [7]:
%%writefile /content/drive/MyDrive/VQA/code/model/fusion_net/mlp.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable

from ..fc import GroupMLP
from ..language_model import WordEmbedding

from utils import freeze_layer


class MLP(nn.Module):
    """Fusion network that concatenates a pooled image vector with a
    bag-of-words question vector and feeds the result through a ``GroupMLP``.

    Args:
        args: configuration object; ``freeze_w2v``, ``output_features``,
            ``hidden_size`` and ``embedding_size`` are read.
        dataset: only ``dataset.num_tokens`` is read, and only when
            ``embedding_weights`` is ``None``.
        embedding_weights: optional pre-trained word-embedding matrix.
        rnn_bidirectional: unused; kept for signature compatibility.
    """

    def __init__(self, args, dataset, embedding_weights=None, rnn_bidirectional=True):
        super(MLP, self).__init__()
        # Frozen word vectors do not need gradients.
        embedding_requires_grad = not args.freeze_w2v
        question_features = 300
        vision_features = args.output_features  # size of the image feature vector

        self.text = BagOfWordsProcessor(
            embedding_tokens=embedding_weights.size(0) if embedding_weights is not None else dataset.num_tokens,
            embedding_weights=embedding_weights,
            embedding_features=300,
            embedding_requires_grad=embedding_requires_grad,
        )
        self.mlp = GroupMLP(
            in_features=vision_features + question_features,
            mid_features=4 * args.hidden_size,
            out_features=args.embedding_size,
            drop=0.5,
            groups=64,
        )

        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, v, b, q, q_len):
        """Return the joint embedding of image features ``v`` and question ``q``.

        Args:
            v: 4-D image feature map; it is average-pooled over dims 2 and 3.
            b: unused.
            q: question token ids.
            q_len: per-question lengths (a tensor; ``q_len.data`` is read).
        """
        # Length-averaged question word vectors, L2-normalised.
        q = F.normalize(self.text(q, list(q_len.data)), p=2, dim=1)
        pooled = F.avg_pool2d(v, (v.size(2), v.size(3)))
        # Flatten explicitly: a bare squeeze() would also drop the batch axis
        # when the batch holds a single sample.
        v = F.normalize(pooled.view(pooled.size(0), -1), p=2, dim=1)

        combined = torch.cat([v, q], dim=1)
        embedding = self.mlp(combined)
        return embedding


class BagOfWordsProcessor(nn.Module):
    """Represent a question as the length-averaged sum of its word embeddings.

    Args:
        embedding_tokens: vocabulary size.
        embedding_features: dimensionality of each word vector.
        embedding_weights: optional pre-trained embedding matrix; when
            ``None`` the default ``nn.Embedding`` initialisation is kept.
        embedding_requires_grad: whether the embedding table is trainable.
    """

    def __init__(self, embedding_tokens, embedding_features,
                 embedding_weights, embedding_requires_grad):
        super(BagOfWordsProcessor, self).__init__()
        self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0)
        if embedding_weights is not None:
            self.embedding.weight.data = embedding_weights
        self.embedding.weight.requires_grad = embedding_requires_grad

    def forward(self, q, q_len):
        """Sum the word vectors of each question and divide by its length.

        Args:
            q: token-id tensor indexed as ``[batch, sequence]``.
            q_len: per-question lengths, one entry per batch row.
        """
        embedded = self.embedding(q)
        # The epsilon guards against division by a zero length.
        lengths = torch.Tensor(q_len).view(-1, 1) + 1e-12
        # Follow the embedding's device instead of assuming CUDA.
        lengths = lengths.to(embedded.device)

        return torch.div(torch.sum(embedded, 1), lengths)

Overwriting /content/drive/MyDrive/VQA/code/model/fusion_net/mlp.py


###### ban.py

In [6]:
%%writefile /content/drive/MyDrive/VQA/code/model/fusion_net/ban.py
"""
Bilinear Attention Networks
Jin-Hwa Kim, Jaehyun Jun, Byoung-Tak Zhang
https://arxiv.org/abs/1805.07932

This code is adapted from: https://github.com/jnhwkim/ban-vqa (written by Jin-Hwa Kim)
"""
import torch.nn as nn

from ..attention import BiAttention
from ..classifier import SimpleClassifier
from ..counting import Counter
from ..fc import FCNet, BCNet
from ..language_model import WordEmbedding, UpDnQuestionEmbedding

from utils import freeze_layer


class BAN(nn.Module):
    """Bilinear-attention fusion network with a counting branch.

    Args:
        args: configuration object; ``freeze_w2v``, ``embedding_size``,
            ``v_dim`` and ``glimpse`` are read.
        dataset: unused here.
        question_word2vec: word-embedding matrix; its first dimension gives
            the vocabulary size. It is loaded into the embedding layer (and
            the layer frozen) only when ``args.freeze_w2v`` is set.
    """

    def __init__(self, args, dataset, question_word2vec):
        super(BAN, self).__init__()
        self.args = args
        self.w_emb = WordEmbedding(question_word2vec.size(0), 300, .0)
        if args.freeze_w2v:
            self.w_emb.init_embedding(question_word2vec)
            freeze_layer(self.w_emb)
        self.q_emb = UpDnQuestionEmbedding(300, args.embedding_size, 1, False, .0)
        q_dim = self.q_emb.num_hid
        self.v_att = BiAttention(args.v_dim, q_dim, q_dim, args.glimpse)
        self.objects = 10  # minimum number of boxes

        # One (bilinear net, question projection, count projection) triple per
        # glimpse, created glimpse by glimpse.
        bilinear_nets, question_projs, count_projs = [], [], []
        for _ in range(args.glimpse):
            bilinear_nets.append(BCNet(args.v_dim, q_dim, q_dim, None, k=1))
            question_projs.append(FCNet([q_dim, q_dim], '', .2))
            count_projs.append(FCNet([self.objects + 1, q_dim], 'ReLU', .0))

        self.b_net = nn.ModuleList(bilinear_nets)
        self.q_prj = nn.ModuleList(question_projs)
        self.c_prj = nn.ModuleList(count_projs)
        self.counter = Counter(self.objects)
        self.drop = nn.Dropout(.5)
        self.tanh = nn.Tanh()

    def forward(self, v, b, q, q_len):
        """Fuse visual features with the question over ``args.glimpse`` steps.

        v: [batch, num_objs, obj_dim]
        b: [batch, num_objs, b_dim]; only the first four entries of the last
           dimension are used, as box coordinates for the counter
        q: [batch_size, seq_length]
        q_len: unused

        return: the refined question embedding summed over its token axis
        """
        token_vectors = self.w_emb(q)
        q_emb = self.q_emb.forward_all(token_vectors)  # [batch, q_len, q_dim]
        boxes = b[:, :, :4].transpose(1, 2)

        att, logits = self.v_att.forward_all(v, q_emb)  # b x g x v x q

        for g in range(self.args.glimpse):
            glimpse_emb = self.b_net[g].forward_with_weights(v, q_emb, att[:, g, :, :])  # b x l x h

            # Strongest attention logit per object, fed to the counter.
            object_scores = logits[:, g, :, :].max(2)[0]
            count_emb = self.counter(boxes, object_scores)

            q_emb = self.q_prj[g](glimpse_emb.unsqueeze(1)) + q_emb
            q_emb = q_emb + self.c_prj[g](count_emb).unsqueeze(1)

        return q_emb.sum(1)

Overwriting /content/drive/MyDrive/VQA/code/model/fusion_net/ban.py


###### san.py

In [9]:
%%writefile /content/drive/MyDrive/VQA/code/model/fusion_net/san.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

from ..attention import SanAttention, apply_attention
from ..fc import GroupMLP
from ..language_model import Seq2SeqRNN, WordEmbedding

import pdb
from utils import freeze_layer

class SAN(nn.Module):
    """Stacked-attention fusion network.

    The question is embedded and encoded by an LSTM, used to attend over the
    image features, and the attended image vector is concatenated with the
    question vector and passed through a ``GroupMLP``.

    Args:
        args: configuration object; ``freeze_w2v``, ``output_features``,
            ``hidden_size`` and ``embedding_size`` are read.
        dataset: unused here.
        embedding_weights: word-embedding matrix. Its shape is read
            unconditionally, so it must be provided despite the ``None``
            default.
        rnn_bidirectional: whether the question LSTM is bidirectional; the
            per-direction hidden size is halved accordingly.
    """

    def __init__(self, args, dataset, embedding_weights=None, rnn_bidirectional=True):
        super(SAN, self).__init__()
        question_features = 1024
        # Keep the total question feature size fixed regardless of direction count.
        rnn_features = int(question_features // 2) if rnn_bidirectional else int(question_features)
        vision_features = args.output_features
        glimpses = 2

        self.w_emb = WordEmbedding(embedding_weights.size(0), 300, .0)
        if args.freeze_w2v:
            self.w_emb.init_embedding(embedding_weights)
            freeze_layer(self.w_emb)

        self.drop = nn.Dropout(0.5)
        self.text = Seq2SeqRNN(
            input_features=embedding_weights.size(1),
            rnn_features=int(rnn_features),
            rnn_type='LSTM',
            rnn_bidirectional=rnn_bidirectional,
        )
        self.attention = SanAttention(
            v_features=vision_features,
            q_features=question_features,
            mid_features=512,
            # Same value the MLP input size below is computed from.
            glimpses=glimpses,
            drop=0.5,
        )
        self.mlp = GroupMLP(
            in_features=glimpses * vision_features + question_features,
            mid_features=4 * args.hidden_size,
            out_features=args.embedding_size,
            drop=0.5,
            groups=64,
        )

        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, v, b, q, q_len):
        """Return the joint embedding of image features ``v`` and question ``q``.

        Args:
            v: image features; L2-normalised along dim 1 before attention.
            b: unused.
            q: question token ids.
            q_len: per-question lengths (a tensor; ``q_len.data`` is read).
        """
        q = self.text(self.drop(self.w_emb(q)), list(q_len.data))

        v = F.normalize(v, p=2, dim=1)
        a = self.attention(v, q)
        v = apply_attention(v, a)

        combined = torch.cat([v, q], dim=1)
        embedding = self.mlp(combined)
        return embedding

Overwriting /content/drive/MyDrive/VQA/code/model/fusion_net/san.py


###### updn.py

In [10]:
%%writefile /content/drive/MyDrive/VQA/code/model/fusion_net/updn.py

import torch
import torch.nn as nn

from ..language_model import WordEmbedding, UpDnQuestionEmbedding
from ..attention import UpDnAttention
from ..classifier import SimpleClassifier
from ..fc import FCNet

from utils import freeze_layer

class UD(nn.Module):
    """Up-Down style fusion: question-guided attention over visual features.

    ``forward`` returns the element-wise product of the projected question
    and attended-image representations; no classifier is applied here.

    Args:
        args: configuration object; ``freeze_w2v``, ``embedding_size`` and
            ``v_dim`` are read.
        dataset: unused here.
        question_word2vec: word-embedding matrix; its first dimension gives
            the vocabulary size. It is loaded into the embedding layer (and
            the layer frozen) only when ``args.freeze_w2v`` is set.
    """

    def __init__(self, args, dataset, question_word2vec):
        super(UD, self).__init__()
        vocab_size = question_word2vec.size(0)
        joint_dim = args.embedding_size

        self.w_emb = WordEmbedding(vocab_size, 300, 0.0)
        if args.freeze_w2v:
            self.w_emb.init_embedding(question_word2vec)
            freeze_layer(self.w_emb)

        self.q_emb = UpDnQuestionEmbedding(300, joint_dim, 1, False, 0.0)
        question_dim = self.q_emb.num_hid
        self.v_att = UpDnAttention(args.v_dim, question_dim, joint_dim)
        self.q_net = FCNet([question_dim, joint_dim])
        self.v_net = FCNet([args.v_dim, joint_dim])

    def forward(self, v, b, q, qlen):
        """Fuse visual features and a question into a joint representation.

        v: [batch, num_objs, obj_dim]
        b: [batch, num_objs, b_dim] (unused)
        q: [batch_size, seq_length]
        qlen: unused

        return: joint question-image representation (not class logits)
        """
        question = self.q_emb(self.w_emb(q))  # [batch, q_dim]

        # Attention weights over objects, then the weighted sum of object features.
        attended = (self.v_att(v, question) * v).sum(1)  # [batch, v_dim]

        question_repr = self.q_net(question)
        visual_repr = self.v_net(attended)
        return question_repr * visual_repr

Overwriting /content/drive/MyDrive/VQA/code/model/fusion_net/updn.py


#### Answer Network

###### \_\_init__.py

In [4]:
%%writefile /content/drive/MyDrive/VQA/code/model/answer_net/__init__.py
from .mlp import MLP

Overwriting /content/drive/MyDrive/VQA/code/model/answer_net/__init__.py


###### mlp.py

In [5]:
%%writefile /content/drive/MyDrive/VQA/code/model/answer_net/mlp.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import pdb
import model.fc as FC
from ..fc import GroupMLP, GroupMLP_2lay, GroupMLP_1lay


class MLP(nn.Module):
    """Answer network: project an answer feature vector into the joint space.

    Args:
        args: configuration object; ``ans_net_lay`` selects the network
            (0: ``GroupMLP``, 1: ``GroupMLP_1lay``, 2: ``GroupMLP_2lay``), and
            ``ans_feature_len``, ``hidden_size`` and ``embedding_size`` give
            the input, middle and output widths.
        dataset: unused here.
    """

    def __init__(self, args, dataset):
        super(MLP, self).__init__()
        ans_net_list = ["GroupMLP", "GroupMLP_1lay", "GroupMLP_2lay"]
        ans_net = ans_net_list[args.ans_net_lay]
        self.mlp = getattr(FC, ans_net)(
            in_features=args.ans_feature_len,
            mid_features=args.hidden_size,
            out_features=args.embedding_size,
            drop=0.0,
            groups=64,
        )

        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, a, a_len=None):
        """Embed the L2-normalised answer features ``a``; ``a_len`` is unused."""
        # dim=1 is F.normalize's default, spelled out for clarity.
        return self.mlp(F.normalize(a, p=2, dim=1))

Overwriting /content/drive/MyDrive/VQA/code/model/answer_net/mlp.py


## Content


###### deal_data.py

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/data/deal_data.py
from config import cfg
import os.path as osp
import pickle
import json
import pdb
import re
from utils import dele_a, transfer, hand_remove, deal_fact
from collections import defaultdict
import tqdm
import Levenshtein
import wordninja
from data import fvqa, preprocess
import random
import numpy as np
from collections import defaultdict


class Runner:
    def __init__(self, args):
        self.args = args
        self.path = osp.join(args.data_root, "data", "FVQA/")
        self.data_path = osp.join(self.path, "new_dataset_release")
        self.split_path = osp.join(self.path, "Name_Lists")
        self.exp_data = osp.join(args.data_root, "fvqa", "exp_data")
        self.e1_list = []
        self.r_list = []
        self.e2_list = []
        self.entity_list = []
        # 这个entity 在哪些VQA pair中出现过。
        self.e1_show_key = defaultdict(list)
        self.e2_show_key = defaultdict(list)
        self.all_entity = []

    # Build a JSON file in which every fact is stored inline (no lookup through a separate fact table is needed)
    def get_new_all_json(self):
        path = osp.join(self.data_path, "all_qs_dict_release_combine_all.json")

        if not osp.exists(path):
            with open(osp.join(self.data_path, "all_fact_triples_release.json"), "r", encoding='utf8') as ffp:
                dic_all = json.load(ffp)
                # pdb.set_trace()
                for i in dic_all.keys():
                    # fact_source = dic[i]["fact"][0]
                    fact = dic_all[i]
                    fact['e1'] = deal_fact(dic_all[i], fact['e1'])
                    fact['e2'] = deal_fact(dic_all[i], fact['e2'])
                    dic_all[i]["fact"] = []
                    dic_all[i]["fact"].append(fact['e1'])
                    dic_all[i]["fact"].append(fact['r'].split('/')[-1])
                    dic_all[i]["fact"].append(fact['e2'])
                    # pdb.set_trace()
                    del dic_all[i]['KB']
                    del dic_all[i]['e1_label']
                    # del dic_all[i]['uri']
                    del dic_all[i]['e2_label']
                    # del dic_all[i]['sources']
                    # del dic_all[i]['context']
                    del dic_all[i]['score']
        else:
            # 需要人工去噪
            with open(path, 'w') as fd:
                json.dump(dic_all, fd)
                print("get_new_json_combile done!（remember to do some human check !!!）")

    # Build a JSON file in which every question's fact is stored inline (no lookup through the fact table is needed)
    def get_new_json(self):
        path = osp.join(self.data_path, "all_qs_dict_release_combine.json")
        if not osp.exists(path):
            with open(osp.join(self.data_path, "all_qs_dict_release_cp.json"), "r") as fp:
                dic = json.load(fp)
                with open(osp.join(self.data_path, "all_fact_triples_release.json"), "r", encoding='utf8') as ffp:
                    dic_all = json.load(ffp)
                    # pdb.set_trace()
                    for i in dic_all.keys():
                        fact_source = dic[i]["fact"][0]
                        fact = dic_all[fact_source]
                        fact['e1'] = deal_fact(dic[i], fact['e1'])
                        fact['e2'] = deal_fact(dic[i], fact['e2'])
                        dic[i]["fact"][0] = fact['e1']
                        dic[i]["fact"].append(fact['r'].split('/')[-1])
                        dic[i]["fact"].append(fact['e2'])
                        del dic[i]['ans_source']
                        del dic[i]['visual_concept']
            # 需要人工去噪
            with open(path, 'w') as fd:
                json.dump(dic, fd)
                print("get_new_json done!（remember to do some human check !!!）")

    def get_entity_filter(self):
        # 把头尾实体筛选一遍，并且储存

        with open(osp.join(self.data_path, "all_qs_dict_release_combine.json"), 'r') as fp:
            dic = json.load(fp)
            for i in dic.keys():
                for j in [0, 1, 2]:
                    dic[i]["fact"][j] = dic[i]["fact"][j].lower().replace("  ", " ")
                    if dic[i]["fact"][j][0] == " ":
                        dic[i]["fact"][j] = dic[i]["fact"][j][1:]
                    if len(dic[i]["fact"][j]) > 2 and dic[i]["fact"][j][-2] == "#":
                        dic[i]["fact"][j] = dic[i]["fact"][j][:-2]

                self.e1_list.append(dic[i]["fact"][0])
                self.r_list.append(dic[i]["fact"][1])
                self.e2_list.append(dic[i]["fact"][2])
                self.e1_show_key[dic[i]["fact"][0]].append(i)
                self.e2_show_key[dic[i]["fact"][2]].append(i)
            self.entity_list = set(self.e1_list + self.e2_list)
            self.entity_list = list(self.entity_list)
            self.r_list = list(set(self.r_list))
            # pdb.set_trace()
            print("get_entity_filter done!")

    def get_all_entity(self):
        path = osp.join(self.data_path, 'ids_new.data')
        if not osp.exists(path):

            # 得到所有的头尾实体，并且排序
            entity_list = []
            with open(osp.join(self.data_path, "FVQA_triple_new_2.txt"), 'r', encoding='utf-8') as f:
                # k = 0
                while 1:
                    line = f.readline()
                    if not line:
                        break
                    if line[:3] == '***':
                        continue
                    # k += 1
                    # if k % 1000 == 0:
                    #     print(k, len(lis))
                    line = re.split('\t|\n', line)
                    entity_list.append(line[0].lower().replace("-", " "))
                    entity_list.append(line[2].lower().replace("-", " "))
            entity_set = set(entity_list)

            def rule_4(a):
                return entity_list.count(a)

            entity_sort = list(set(entity_list))
            entity_sort.sort(key=rule_4, reverse=True)

            with open(path, 'wb') as f:
                pickle.dump(entity_sort, f)

        else:
            with open(path, 'rb') as f:  # 按出现数量排序过了的实体
                print("load ids_new.data")
                entity_sort = pickle.load(f)

        entity_sort.remove('y')
        entity_sort.remove('and')
        entity_sort.remove('yes')
        entity_sort.remove('no')

        path = osp.join(self.data_path, "all_qs_dict_release_combine_filter.json")

        if not osp.exists(path):
            with open(osp.join(self.data_path, "all_qs_dict_release_combine.json"), 'r') as fp:
                dic = json.load(fp)
                Noin = []
                for entity in tqdm.tqdm(self.entity_list):
                    entity_orig = entity
                    entity = entity.replace("_", " ").replace("-", " ")
                    entity = entity.replace("Category:", "").replace("category:", "")
                    entity = entity.replace("(", "").replace(")", "")
                    entity_list = [entity]

                    dele_a_list = dele_a(entity)
                    transfer_a = [transfer(entity)]
                    # entity_list.append(no_)
                    entity_list = entity_list + transfer_a  # 变形
                    entity_list = entity_list + dele_a_list  # 去冠词
                    entity_list = entity_list + dele_a(transfer_a[0])  # 变形后去冠词
                    for i in dele_a_list:
                        entity_list = entity_list + [transfer(i)]  # 去冠词后变形

                    entity_list = list(set(entity_list))
                    hand_list = []
                    for k in entity_list:
                        hand_list = hand_list + hand_remove(k)  # 手动去特殊形式
                    entity_list = entity_list + list(set(hand_list))
                    entity_list = list(set(entity_list))
                    flag = 0

                    # print("change entity...")
                    for key in entity_sort:
                        if key in entity_list:
                            flag = 1
                            self.all_entity.append(key)
                            for j in self.e1_show_key[entity_orig]:  # 答案是这个的编号
                                dic[j]['fact'][0] = key
                            for j in self.e2_show_key[entity_orig]:  # 答案是这个的编号
                                dic[j]['fact'][2] = key
                            break
                    if flag:
                        continue

                    Noin.append(entity_orig)
            print("all entity num:", len(list(set(self.all_entity))))
            print("no in :", Noin)
            print("no in num :", len(Noin))

            # entity 筛选过的。此时答案和entity 统一了

            with open(path, 'w') as fp:
                json.dump(dic, fp)
            print("get_all_entity filter done!")

    def fusion_answer_and_entity(self):
        # 把答案里面出现的entity对齐到entity中。
        # 使用编辑距离
        path = osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion.json")
        if not osp.exists(path):
            with open(osp.join(self.data_path, "all_qs_dict_release_combine_filter.json"), 'r') as fp:
                dic = json.load(fp)
                with open(osp.join(self.data_path, "ans_entity_map.txt"), 'w') as ffp:
                    for key in dic.keys():
                        strout = "not match: "
                        e1 = dic[key]["fact"][0]
                        e2 = dic[key]["fact"][2]
                        ans = dic[key]["answer"]
                        # 和头实体相似度大于尾实体
                        if Levenshtein.ratio(ans, e1) > Levenshtein.ratio(ans, e2):
                            strout += dic[key]["fact"][2]
                            strout += "\t\t\t\t\t  match: "
                            strout += dic[key]["fact"][0]
                            strout += " -> "
                            strout += ans
                            dic[key]["fact"][0] = ans

                        else:
                            strout += dic[key]["fact"][0]
                            strout += "\t\t\t\t\t  match: "
                            strout += dic[key]["fact"][2]
                            strout += " -> "
                            strout += ans
                            dic[key]["fact"][2] = ans
                        ffp.write(strout + "\n")

                    print("fusion_answer_and_entity done!")

                with open(osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion.json"), 'w') as fp:
                    json.dump(dic, fp)

    def statistics_of_ans_and_entity(self, name=None, path=None):
        # 数据统计
        if path == None:
            path = osp.join(self.data_path, name)

        with open(path, 'r') as fp:
            dic = json.load(fp)
            ans_set = set()
            entity_set = set()
            relation_set = set()
            dic_len = 0
            for key in dic.keys():
                dic_len += 1
                e1 = dic[key]["fact"][0]
                r = dic[key]["fact"][1]
                e2 = dic[key]["fact"][2]
                ans = dic[key]["answer"]
                ans_set.add(ans)
                entity_set.add(e1)
                entity_set.add(e2)
                relation_set.add(r)

            ans_or_entity = ans_set | entity_set
            ans_and_entity = ans_set & entity_set
            print("ans_set len:", len(ans_set))
            print("entity_set len:", len(entity_set))
            print("ans_or_entity len:", len(ans_or_entity))
            print("ans_and_entity len:", len(ans_and_entity))
            print("relation len:", len(relation_set))
            print("dic len:", dic_len)

    def filter_top500_IQA_pair(self):
        # read ans file
        # store the map from id to ans (with dic)
        # TODO: optimize the code with matrix
        path = osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion_500.json")
        if not osp.exists(path):
            ans_2_id = {}
            with open(osp.join(self.data_path, "ans.txt"), 'r', encoding='utf-8') as f:
                while 1:
                    line = f.readline()
                    if not line:
                        break
                    line = re.split('-|\n', line)
                    ans_2_id[line[1]] = int(line[0])
            print(len(ans_2_id))
            with open(osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion.json"), 'r') as fp:
                dic = json.load(fp)
                dic_500 = {key: value for key, value in dic.items() if
                           dic[key]["answer"] in ans_2_id.keys() and ans_2_id[dic[key]["answer"]] <= 500}

            with open(path, 'w') as fp:
                json.dump(dic_500, fp)
                print("filter_top500_IQA_pair done!")

    def deal_relation(self):
        path = osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion_500.json")
        if not osp.exists(path):
            with open(osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion_500.json"), 'r') as fp:
                dic = json.load(fp)
            relation_set = set()
            for key in dic.keys():
                relation_set.add(dic[key]["fact"][1])

            print("relation len:", len(relation_set))
            relation_map = {}
            for relation in list(relation_set):
                relation_orig = relation
                # 是否需要把关系去掉？
                if relation[-2] == "#":
                    relation = relation[:-2]

                relation_split = wordninja.split(relation)
                for i in range(len(relation_split)):
                    relation_split[i] = relation_split[i].lower()

                if relation == "transnbhd":
                    relation_map[relation_orig] = "belong to"
                else:
                    relation_map[relation_orig] = ' '.join(relation_split)

            print(relation_map)
            for key in dic.keys():
                tmp = dic[key]["fact"][1]
                dic[key]["fact"][1] = relation_map[tmp]

            with open(path, 'w') as fp:
                json.dump(dic, fp)
                print("deal_relation done!")

    def split_data(self):
        # 把数据集划分出来
        for i in range(0, 5):
            num = str(i)
            train_name = osp.join(self.args.FVQA.train_data_path, "train" + num, "all_qs_dict_release_train_500.json")
            test_name = osp.join(self.args.FVQA.test_data_path, "test" + num, "all_qs_dict_release_test_500.json")

            if osp.exists(train_name) and osp.exists(test_name):
                continue

            img_train = []
            img_test = []

            with open(osp.join(self.split_path, "train_list_" + num + ".txt"), "r") as f:
                while 1:
                    line = f.readline()
                    if not line:
                        break
                    line = re.split('\n', line)
                    img_train.append(line[0])

            with open(osp.join(self.split_path, "test_list_" + num + ".txt"), "r") as f:
                while 1:
                    line = f.readline()
                    if not line:
                        break
                    line = re.split('\n', line)
                    img_test.append(line[0])

            with open(osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion_500.json"), 'r') as fp:
                dic = json.load(fp)

                dic_train = {key: value for key, value in dic.items() if dic[key]["img_file"] in img_train}
                dic_test = {key: value for key, value in dic.items() if dic[key]["img_file"] in img_test}

            # train_name = osp.join(cfg.FVQA.train_data_path, "train" + num, "all_qs_dict_release_train_500.json")
            # test_name = osp.join(cfg.FVQA.test_data_path, "test" + num, "all_qs_dict_release_test_500.json")
            ans_train = []
            ans_test = []
            q_train = []
            q_test = []
            for key, value in dic_train.items():
                ans_train.append(dic_train[key]["answer"])
                q_train.append(dic_train[key]["question"])
            for key, value in dic_test.items():
                ans_test.append(dic_test[key]["answer"])
                q_test.append(dic_test[key]["question"])

            ans_train_set = set(ans_train)
            q_train_set = set(q_train)
            ans_test_set = set(ans_test)
            q_test_set = set(q_test)

            with open(train_name, "w") as ff:
                json.dump(dic_train, ff)
                print("save to:", train_name)

            with open(test_name, "w") as ff:
                json.dump(dic_test, ff)
                print("save to:", test_name)
            # ans_set len: 387
            # entity_set len: 1842
            # ans_or_entity len: 1842
            # ans_and_entity len: 387
            # relation len: 71
            # dic len: 2669
            print(num, " train :", len(dic_train), len(ans_train_set), len(q_train_set))
            self.statistics_of_ans_and_entity(train_name)

            # ans_set len: 403
            # entity_set len: 1958
            # ans_or_entity len: 1958
            # ans_and_entity len: 403
            # relation len: 87
            # dic len: 2823
            print(num, " test :", len(dic_test), len(ans_test_set), len(q_test_set))
            self.statistics_of_ans_and_entity(test_name)
            print("dataset " + num + " done!")

    def preprocess_answer(self):
        pass

    def preprocess_fact(self):
        """Build the fact vocabulary from the split-2 annotations and cache it as JSON.

        The train and test annotation dictionaries of split 2 are merged (test
        entries win on duplicate keys), turned into a fact list by
        ``fvqa.prepare_fact`` and then into a vocabulary by
        ``preprocess.extract_vocab``. The result is stored under the ``'answer'``
        key of ``answer.vocab.fvqa.fact.500.json`` in the common data directory.
        Nothing is recomputed when that file already exists.
        """
        fvqa_cfg = self.args.FVQA
        vocab_file = osp.join(fvqa_cfg.common_data_path, "answer.vocab.fvqa.fact.500.json")

        if osp.exists(vocab_file):
            # Cached result is present; just report completion.
            print("preprocess_fact done!")
            return

        split_id = str(2)
        annotation_files = (
            osp.join(fvqa_cfg.train_data_path, "train" + split_id, "all_qs_dict_release_train_500.json"),
            osp.join(fvqa_cfg.test_data_path, "test" + split_id, "all_qs_dict_release_test_500.json"),
        )
        loaded = []
        for annotation_file in annotation_files:
            with open(annotation_file, 'r') as handle:
                loaded.append(json.load(handle))

        # Later (test) entries override earlier (train) ones on key clashes.
        annotations = {**loaded[0], **loaded[1]}
        fact_vocab = preprocess.extract_vocab(fvqa.prepare_fact(annotations), top_k=None)

        print('* Dump output vocab to: {}'.format(vocab_file))
        with open(vocab_file, 'w') as handle:
            json.dump({'answer': fact_vocab}, handle)
        print("preprocess_fact done!")

    def preprocess_relation(self):
        """Build the relation vocabulary from the split-2 annotations and cache it as JSON.

        The train and test annotation dictionaries of split 2 are merged (test
        entries win on duplicate keys), turned into a relation list by
        ``fvqa.prepare_relation`` and then into a vocabulary by
        ``preprocess.extract_vocab``. The result is stored under the ``'answer'``
        key of ``answer.vocab.fvqa.relation.500.json`` in the common data
        directory. Nothing is recomputed when that file already exists.
        """
        fvqa_cfg = self.args.FVQA
        vocab_file = osp.join(fvqa_cfg.common_data_path, "answer.vocab.fvqa.relation.500.json")

        if osp.exists(vocab_file):
            # Cached result is present; just report completion.
            print("preprocess_relation done!")
            return

        split_id = str(2)
        annotation_files = (
            osp.join(fvqa_cfg.train_data_path, "train" + split_id, "all_qs_dict_release_train_500.json"),
            osp.join(fvqa_cfg.test_data_path, "test" + split_id, "all_qs_dict_release_test_500.json"),
        )
        loaded = []
        for annotation_file in annotation_files:
            with open(annotation_file, 'r') as handle:
                loaded.append(json.load(handle))

        # Later (test) entries override earlier (train) ones on key clashes.
        annotations = {**loaded[0], **loaded[1]}
        relation_vocab = preprocess.extract_vocab(fvqa.prepare_relation(annotations), top_k=None)

        print('* Dump output vocab to: {}'.format(vocab_file))
        with open(vocab_file, 'w') as handle:
            json.dump({'answer': relation_vocab}, handle)
        print("preprocess_relation done!")

    def split_unseen_data(self):
        """Create five random seen/unseen answer splits.

        Steps:
          1. Read ``ans.txt`` (lines of the form ``<id>-<answer>``) into an
             answer -> id mapping.
          2. Collect the image names listed in ``train_list_0.txt`` and
             ``test_list_0.txt`` and keep only the annotations whose
             ``img_file`` is among them.
          3. For each of the five splits, randomly pick 250 of the answer ids
             1..500 as "seen"; annotations with a seen answer go to the train
             file, annotations with one of the remaining ids go to the test
             file. ``statistics_of_ans_and_entity`` is then run on both files.

        A split is regenerated unless BOTH of its files already exist. The target
        directories must already exist.
        """
        # Split on the first hyphen only so answers that themselves contain "-"
        # keep their full text (they are looked up verbatim below).
        ans_2_id = {}
        with open(osp.join(self.data_path, "ans.txt"), 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip('\n')
                if not line:
                    continue  # tolerate blank lines
                ans_idx, ans_text = line.split('-', 1)
                ans_2_id[ans_text] = int(ans_idx)
        print(len(ans_2_id))

        # Image names of split 0 (train + test); a set keeps the filter below linear.
        num = "0"
        img = set()
        for list_name in ("train_list_" + num + ".txt", "test_list_" + num + ".txt"):
            with open(osp.join(self.split_path, list_name), "r") as f:
                for line in f:
                    img.add(line.rstrip('\n'))

        with open(osp.join(self.data_path, "all_qs_dict_release_combine_filter_fusion_500.json"), 'r') as fp:
            dic = json.load(fp)

        dic_all = {key: value for key, value in dic.items() if value["img_file"] in img}
        ans_id = list(range(1, 501))

        # split_unseen_data
        for i in range(5):
            num = str(i)
            train_name = osp.join(self.args.FVQA.seen_train_data_path, "train" + num, "all_qs_dict_release_train_500.json")
            test_name = osp.join(self.args.FVQA.unseen_test_data_path, "test" + num, "all_qs_dict_release_test_500.json")

            # Skip only when the complete train/test pair is already on disk.
            if osp.exists(train_name) and osp.exists(test_name):
                continue

            ans_seen = set(random.sample(ans_id, 250))
            ans_unseen = set(ans_id) - ans_seen

            dic_seen = {key: value for key, value in dic_all.items() if ans_2_id[value["answer"]] in ans_seen}
            dic_unseen = {key: value for key, value in dic_all.items() if ans_2_id[value["answer"]] in ans_unseen}

            with open(train_name, "w") as ff:
                json.dump(dic_seen, ff)
            with open(test_name, "w") as ff:
                json.dump(dic_unseen, ff)

            self.statistics_of_ans_and_entity(train_name)
            self.statistics_of_ans_and_entity(test_name)

            print("dataset " + num + " done!")

    def get_fact_relation_matrix(self):
        """Build (if missing) and load the ``"<fact_idx>-<relation_idx>" -> [answer_idx]`` map.

        When ``args.FVQA.fact_relation_to_ans_path`` does not exist, the answer,
        fact and relation vocabularies are loaded (and kept on ``self`` together
        with their inverses), the split-2 train and test annotations are merged,
        and for every annotation the non-answer entity of its fact triple plus
        its relation are mapped to the answer index. The mapping is written to
        ``fact_relation_to_ans_path`` - the same path that is checked and read
        back - so generation and loading can never disagree on the location.

        Returns:
            dict: the mapping as loaded from JSON (values are lists of answer
            indices). Earlier versions returned ``None``.

        Raises:
            AssertionError: if an annotation's answer is neither the head nor
                the tail of its fact triple.
            KeyError: if a fact, relation or answer is missing from its vocabulary.
        """
        fvqa_cfg = self.args.FVQA
        cache_path = fvqa_cfg.fact_relation_to_ans_path

        if not osp.exists(cache_path):
            with open(fvqa_cfg.fact_vocab_path, 'r') as fd:
                fact_vocab = json.load(fd)

            with open(fvqa_cfg.relation_vocab_path, 'r') as fd:
                relation_vocab = json.load(fd)

            with open(fvqa_cfg.answer_vocab_path, 'r') as fd:
                answer_vocab = json.load(fd)

            # The vocab files store their token -> index table under the 'answer' key.
            self.answer_to_index = answer_vocab['answer']
            self.index_to_answer = preprocess.invert_dict(self.answer_to_index)
            self.fact_to_index = fact_vocab['answer']
            self.index_to_fact = preprocess.invert_dict(self.fact_to_index)
            self.relation_to_index = relation_vocab['answer']
            self.index_to_relation = preprocess.invert_dict(self.relation_to_index)

            vqa_train_questions = osp.join(fvqa_cfg.train_data_path, "train2", "all_qs_dict_release_train_500.json")
            vqa_val_questions = osp.join(fvqa_cfg.test_data_path, "test2", "all_qs_dict_release_test_500.json")
            with open(vqa_train_questions, 'r') as fd:
                qaq1 = json.load(fd)
            with open(vqa_val_questions, 'r') as fd:
                qaq2 = json.load(fd)

            annotations = {**qaq1, **qaq2}
            fact_relation_to_ans = defaultdict(list)

            for record in annotations.values():
                answer = record["answer"]
                f1, rel, f2 = record["fact"][0], record["fact"][1], record["fact"][2]
                assert (answer == f1 or answer == f2)
                # The supporting "fact" entity is whichever end of the triple is NOT the answer.
                fact = f2 if answer == f1 else f1

                fact = preprocess.process_punctuation(fact)
                rel = preprocess.process_punctuation(rel)
                name = str(self.fact_to_index[fact]) + "-" + str(self.relation_to_index[rel])
                fact_relation_to_ans[name].append(self.answer_to_index[answer])

            with open(cache_path, 'w') as fd:
                json.dump(fact_relation_to_ans, fd)
                print("dump done!")

        with open(cache_path, 'r') as fd:
            fact_relation_to_ans = json.load(fd)
        return fact_relation_to_ans

    def preprocess_json_in_order(self):
        """Re-key the split-3 test annotations as "0", "1", ... in file order.

        Reads ``all_qs_dict_release_test_500.json`` from
        ``<exp_data>/test_data/test3`` and writes the same records, keyed by their
        position as a string, to ``all_qs_dict_release_test_500_inorder.json``
        beside it. Does nothing when the output file already exists.
        """
        split_dir = osp.join(self.exp_data, "test_data", "test" + "3")
        source_file = osp.join(split_dir, "all_qs_dict_release_test_500.json")
        ordered_file = osp.join(split_dir, "all_qs_dict_release_test_500_inorder.json")

        if osp.exists(ordered_file):
            return

        with open(source_file, 'r') as handle:
            annotations = json.load(handle)

        # Position in the source file becomes the new (string) key.
        reindexed = {str(position): record for position, record in enumerate(annotations.values())}

        with open(ordered_file, 'w') as handle:
            json.dump(reindexed, handle)
            print("dump done!")

    def disjoint_judge(self):
        """Count how many answer-vocabulary entries also occur in the fact vocabulary.

        Loads the ``'answer'`` tables of ``answer.vocab.fvqa.fact.500.json`` and
        ``answer.vocab.fvqa.500.json`` from the common data directory, stores them
        on ``self.fact_id`` / ``self.answer_id``, prints the number of answers that
        are also facts and returns it.

        Returns:
            int: size of the overlap (earlier versions only printed it).
        """
        common_dir = self.args.FVQA.common_data_path
        fact_id_path = osp.join(common_dir, "answer.vocab.fvqa.fact.500.json")
        answer_id_path = osp.join(common_dir, "answer.vocab.fvqa.500.json")

        with open(fact_id_path, 'r') as fd:
            self.fact_id = json.load(fd)['answer']
        with open(answer_id_path, 'r') as fd:
            self.answer_id = json.load(fd)['answer']

        # Membership is tested against the loaded container directly (a hash
        # lookup for the usual dict vocab) instead of a copied list, and the
        # counter no longer shadows the builtin ``all``.
        overlap = sum(1 for answer in self.answer_id if answer in self.fact_id)
        print(overlap)
        return overlap

    def data_analysis(self, name, num_splits=5):
        """Print train/test statistics averaged over the dataset splits.

        For every split ``0 .. num_splits-1`` the train and test annotation files
        are loaded and, per annotation, the question, answer, image file and
        "entity" are collected. The entity is the end of the fact triple that is
        LESS similar to the answer by ``Levenshtein.ratio`` (the head is used on
        ties). The following are accumulated and finally averaged:

          * ``train_triplet_num`` / ``test_triplet_num``: number of annotations;
          * ``and_<kind>_num``: train items of that kind that also occur in test
            (duplicates counted);
          * ``and_<kind>_class``: distinct such items;
          * ``train_<kind>_class`` / ``test_<kind>_class``: distinct items,

        for ``<kind>`` in answer, entity, question, image.

        Args:
            name: ``"zsl"`` selects the ``train_seen_data`` / ``test_unseen_data``
                directories; any other value selects ``train_data`` / ``test_data``.
                It is also the prefix of every printed label.
            num_splits: number of splits to average over (default 5).

        Returns:
            dict: statistic name -> averaged value, in print order (earlier
            versions returned ``None``).

        Raises:
            ValueError: if ``num_splits`` is smaller than 1.
        """
        if num_splits < 1:
            raise ValueError("num_splits must be at least 1")

        if name == "zsl":
            testpath = "test_unseen_data"
            trainpath = "train_seen_data"
        else:
            testpath = "test_data"
            trainpath = "train_data"

        # Print order of the report; kept identical to the historical output.
        stat_order = (
            "train_triplet_num", "test_triplet_num",
            "and_answer_num", "and_entity_num", "and_question_num", "and_image_num",
            "and_answer_class", "and_entity_class", "and_question_class", "and_image_class",
            "train_question_class", "test_question_class",
            "train_answer_class", "test_answer_class",
            "train_entity_class", "test_entity_class",
            "train_image_class", "test_image_class",
        )
        kinds = ("question", "answer", "entity", "image")
        totals = dict.fromkeys(stat_order, 0)

        def _columns(records):
            """Collect the question/answer/entity/image values of one annotation file."""
            columns = {kind: [] for kind in kinds}
            for record in records.values():
                answer = record["answer"]
                columns["question"].append(record["question"])
                columns["answer"].append(answer)
                columns["image"].append(record["img_file"])
                head, tail = record["fact"][0], record["fact"][2]
                # If the answer resembles the head entity more than the tail,
                # the tail is the supporting entity (and vice versa).
                if Levenshtein.ratio(answer, head) > Levenshtein.ratio(answer, tail):
                    columns["entity"].append(tail)
                else:
                    columns["entity"].append(head)
            return columns

        for split in range(num_splits):
            num = str(split)
            datapath_test = osp.join(self.exp_data, testpath, "test" + num, "all_qs_dict_release_test_500.json")
            datapath_train = osp.join(self.exp_data, trainpath, "train" + num, "all_qs_dict_release_train_500.json")

            with open(datapath_test, 'r') as fd:
                test_data = json.load(fd)
            with open(datapath_train, 'r') as fd:
                train_data = json.load(fd)

            test_cols = _columns(test_data)
            train_cols = _columns(train_data)

            # Number of question/answer/entity triples in this split.
            totals["train_triplet_num"] += len(train_cols["question"])
            totals["test_triplet_num"] += len(test_cols["question"])

            for kind in kinds:
                # Overlap keeps train-side duplicates; the set only speeds up lookup.
                test_values = set(test_cols[kind])
                shared = [value for value in train_cols[kind] if value in test_values]
                totals["and_" + kind + "_num"] += len(shared)
                totals["and_" + kind + "_class"] += len(set(shared))
                totals["train_" + kind + "_class"] += len(set(train_cols[kind]))
                totals["test_" + kind + "_class"] += len(test_values)

        averages = {}
        for stat in stat_order:
            averages[stat] = totals[stat] / float(num_splits)
            print(name + "_" + stat + ":", averages[stat])
        return averages

    def data_analysis_zsl_and_general(self):
        """Run ``data_analysis`` for every enabled setting (currently only "general")."""
        # "zsl" is deliberately left out, mirroring the disabled
        # self.data_analysis("zsl") call; add it here to report both settings.
        enabled_settings = ("general",)
        for setting in enabled_settings:
            self.data_analysis(setting)


# Script entry point: runs the preprocessing stages in order. The commented
# "len" figures below are statistics recorded by the original author.
if __name__ == '__main__':
    # NOTE: this rebinds the name `cfg` from the config class to its instance.
    cfg = cfg()
    args = cfg.get_args()
    cfg.update_train_configs(args)
    runner = Runner(cfg)

    runner.get_new_json()
    # runner.get_new_all_json()

    runner.get_entity_filter()
    runner.get_all_entity()
    runner.fusion_answer_and_entity()

    # At this point the resulting file all_qs_dict_release_combine_filter.json is already filtered.
    # It contains 5826 triples.

    # ans_set len: 833
    # entity_set len: 3294
    # ans_or_entity len: 3294
    # ans_and_entity len: 833
    # name = "all_qs_dict_release_combine_filter_fusion.json"

    # ans_set len: 500
    # entity_set len: 3027
    # ans_or_entity len: 3027
    # ans_and_entity len: 500
    # relation len: 108

    name = "all_qs_dict_release_combine_filter_fusion_500.json"

    runner.filter_top500_IQA_pair()
    runner.statistics_of_ans_and_entity(name=name)

    # NOTE(review): filter_top500_IQA_pair() is invoked a second time here;
    # presumably a no-op once its output exists - confirm before removing.
    runner.filter_top500_IQA_pair()
    runner.deal_relation()
    runner.split_data()

    # Vocabulary files for relations and facts (built from split 2).
    runner.preprocess_relation()
    runner.preprocess_fact()

    # runner.split_unseen_data()

    runner.get_fact_relation_matrix()

    runner.preprocess_json_in_order()

    # runner.data_analysis_zsl_and_general()

###### main.py


In [None]:
%%writefile /content/drive/MyDrive/VQA/code/main.py
import argparse
import copy
import os
import pdb
import warnings
from pprint import pprint

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# self-defined
# import model.fusion_net as fusion_net
# import model.answer_net as answer_net
# from model import Vector, SimpleClassifier
# from config import cfg
# from torchlight import initialize_exp, set_seed, snapshot, get_dump_path, show_params
# from utils import unseen_mask, freeze_layer, cosine_sim, Metrics, instance_bce_with_logits
# from data import fvqa
# import copy
# torch.multiprocessing.set_start_method('spawn')

warnings.filterwarnings('ignore')


class Runner:
    """Training / evaluation driver pairing a fusion network with an answer network.

    ``args`` is the project configuration object (the script entry point passes
    its ``cfg`` instance). It is stored as ``self.args`` and every method reads
    its settings from there, so the class no longer depends on a module-level
    ``args`` variable.

    NOTE(review): ``run`` still uses the module-level ``logger`` (and ``writer``
    when tensorboard is enabled) created by this file's ``__main__`` section;
    define them before driving the class from other code.
    """

    def __init__(self, args):
        """Prepare data loaders, models, optimizer and bookkeeping.

        Args:
            args: configuration object; when ``args.now_test`` is truthy the
                saved fusion/answer weights are loaded at the end.
        """
        # prepare for: data, model, loss function, optimizer
        self.args = args

        self.log_dir = get_dump_path(args)
        self.model_dir = os.path.join(self.log_dir, 'model')

        self.word2vec = Vector(args.FVQA.common_data_path)
        # data load
        self.train_loader = fvqa.get_loader(args, self.word2vec, train=True)
        self.val_loader = fvqa.get_loader(args, self.word2vec, val=True)

        self.avocab = default_collate(list(range(0, args.FVQA.max_ans)))

        # question_word2vec: get the word vector (for each word in question)
        # the id of which could map to the vector of corresponding token
        self.question_word2vec = self.word2vec._prepare(self.train_loader.dataset.token_to_index)

        # get the fusion_model and answer_net
        self._model_choice(args)

        # get the mask from zsl
        self.negtive_mux = unseen_mask(args, self.val_loader)

        # optimizer
        params_for_optimization = list(self.fusion_model.parameters()) + list(self.answer_net.parameters())
        self.optimizer = optim.Adam([p for p in params_for_optimization if p.requires_grad], lr=args.TRAIN.lr)

        # loss function
        self.log_softmax = nn.LogSoftmax(dim=1).cuda()

        # Recorder: [acc@1, acc@3, acc@10, acc_all]
        self.max_acc = [0, 0, 0, 0]
        self.max_zsl_acc = [0, 0, 0, 0]
        self.best_epoch = 0
        self.correspond_loss = 1e20

        self.early_stop = 0

        # Best snapshots / saved paths; None until an evaluation improves or a save happens.
        self.best_fusion_model = None
        self.best_answer_net = None
        self.fusion_model_path = None
        self.answer_net_path = None

        print("fusion_model:")
        pprint(self.fusion_model)
        print("Answer Model:")
        pprint(self.answer_net)

        # test stage:
        if self.args.now_test:
            print("begin test! ...")
            print("loading model  ...")
            self._load_model(self.fusion_model, "fusion")
            self._load_model(self.answer_net, "embedding")

    def run(self):
        """Run the epoch loop: warm-up / decay, training, evaluation, early stop, saving.

        For mapping-based methods (``method_choice != 'CLS'``) the answer
        embeddings of every method in ``self.method_list`` are row-normalised and
        concatenated into ``self.answer_var`` first. After the loop, when not in
        test mode and ``save_model`` is set, the best snapshots are written and
        their paths stored in ``fusion_model_path`` / ``answer_net_path``; if no
        evaluation ever produced a snapshot the current weights are saved instead.
        """
        # Answer embedding:
        # choices belong to: ['CLS', 'W2V', 'KG', 'GAE', 'KG_W2V', 'KG_GAE', 'GAE_W2V', 'KG_GAE_W2V']
        # well, we recommend only use the parameter: 'CLS' or 'W2V',
        # since the resources of the other choices need extra training.
        if self.args.method_choice != 'CLS':
            previous_var = None
            for method_choice in self.method_list:
                # get the corresponding choice embedding
                answer_var, _ = self.train_loader.dataset._get_answer_vectors(method_choice, self.avocab)

                # normalize each row, then concatenate along the feature axis
                answer_var = F.normalize(answer_var, p=2, dim=1)
                if previous_var is not None:
                    previous_var = torch.cat([previous_var, answer_var], dim=1)
                else:
                    previous_var = answer_var
            self.answer_var = Variable(previous_var.float()).cuda()

        # warm up (ref: ramen)
        self.gradual_warmup_steps = [i * self.args.TRAIN.lr for i in torch.linspace(0.5, 2.0, 7)]
        self.lr_decay_epochs = range(14, 47, self.args.TRAIN.lr_decay_step)

        # if test: two epochs are enough because evaluation starts at epoch 1
        if self.args.now_test:
            self.args.TRAIN.epochs = 2

        for epoch in range(self.args.TRAIN.epochs):

            self.early_stop += 1
            if self.args.patience < self.early_stop:
                # early stop
                break
            # warm up
            if epoch < len(self.gradual_warmup_steps):
                self.optimizer.param_groups[0]['lr'] = self.gradual_warmup_steps[epoch]
            elif epoch in self.lr_decay_epochs:
                self.optimizer.param_groups[0]['lr'] *= self.args.TRAIN.lr_decay_rate

            self.train_metrics = Metrics()
            self.val_metrics = Metrics()
            self.zsl_metrics = Metrics()
            # use TOP50 metrics for fact mapping:
            if self.args.fact_map == 1:
                self.train_metrics = Metrics(topnum=50)
                self.val_metrics = Metrics(topnum=50)
                self.zsl_metrics = Metrics(topnum=50)

            # train
            if not self.args.now_test:
                ######## begin training!! #######
                self.train(epoch)
                #################################
                lr = self.optimizer.param_groups[0]['lr']
                # record:
                logger.info(
                    f'Train Epoch {epoch}: LOSS={self.train_metrics.total_loss: .5f}, lr={lr: .6f}, acc1={self.train_metrics.acc_1: .2f},acc3={self.train_metrics.acc_3: .2f},acc10={self.train_metrics.acc_10: .2f}')
            # eval (every epoch except the very first)
            if epoch > 0:
                zsl_report = self.args.ZSL and not self.args.fact_map and not self.args.relation_map
                ######## begin evaluating!! #######
                self.eval(epoch)
                #################################
                logger.info('#################################################################################################################')
                logger.info(f'Test Epoch {epoch}: LOSS={self.val_metrics.total_loss: .5f}, acc1={self.val_metrics.acc_1: .2f}, acc3={self.val_metrics.acc_3: .2f}, acc10={self.val_metrics.acc_10: .2f}')
                if zsl_report:
                    logger.info(f'Zsl Epoch {epoch}: LOSS={self.zsl_metrics.total_loss: .5f}, acc1={self.zsl_metrics.acc_1: .2f}, acc3={self.zsl_metrics.acc_3: .2f}, acc10={self.zsl_metrics.acc_10: .2f}')
                logger.info('#################################################################################################################')

                # Accept a new best only on a clear margin (loss lower by more than 1,
                # or overall accuracy higher by more than 0.2) so the model does not
                # focus too much on hit@10 accuracy.
                if self.val_metrics.total_loss < (self.correspond_loss - 1) or self.val_metrics.acc_all > (self.max_acc[3] + 0.2):
                    # reset early_stop and update
                    self.early_stop = 0
                    self.best_epoch = epoch
                    self.correspond_loss = self.val_metrics.total_loss
                    self._updata_best_result(self.max_acc, self.val_metrics)

                    self.best_fusion_model = copy.deepcopy(self.fusion_model)
                    self.best_answer_net = copy.deepcopy(self.answer_net)

                    # ZSL result
                    if zsl_report:
                        self._updata_best_result(self.max_zsl_acc, self.zsl_metrics)

                if not self.args.no_tensorboard and not self.args.now_test:
                    writer.add_scalar('loss', self.val_metrics.total_loss, epoch)
                    writer.add_scalar('acc1', self.val_metrics.acc_1, epoch)
                    writer.add_scalar('acc3', self.val_metrics.acc_3, epoch)
                    writer.add_scalar('acc10', self.val_metrics.acc_10, epoch)

        # save the model
        if not self.args.now_test and self.args.save_model:
            # Fall back to the current weights when no evaluation stored a snapshot
            # (e.g. a single-epoch run never reaches the eval branch).
            best_fusion = getattr(self, 'best_fusion_model', None)
            best_answer = getattr(self, 'best_answer_net', None)
            if best_fusion is None:
                best_fusion = self.fusion_model
            if best_answer is None:
                best_answer = self.answer_net
            self.fusion_model_path = self._save_model(best_fusion, "fusion")
            self.answer_net_path = self._save_model(best_answer, "embedding")

    def train(self, epoch):
        """Train both networks for one epoch and update ``self.train_metrics``."""
        self.fusion_model.train()
        self.answer_net.train()
        prefix = "train"
        tq = tqdm(self.train_loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)

        for visual_features, boxes, question_features, answers, idx, q_len in tq:
            visual_features = Variable(visual_features.float()).cuda()
            boxes = Variable(boxes.float()).cuda()
            question_features = Variable(question_features).cuda()
            answers = Variable(answers).cuda()
            q_len = Variable(q_len).cuda()
            fusion_embedading = self.fusion_model(visual_features, boxes, question_features, q_len)

            # Classifier-based methods
            if self.args.method_choice == 'CLS':
                # TODO: Normalization?
                predicts = self.answer_net(fusion_embedading)
                loss = instance_bce_with_logits(predicts, answers / 10)
            # Mapping-based methods
            else:
                answer_embedding = self.answer_net(self.answer_var)
                # notice the temperature (corresponding to specific score)
                predicts = cosine_sim(fusion_embedading, answer_embedding) / self.args.loss_temperature
                predicts = predicts.to(torch.float64)
                nll = -self.log_softmax(predicts).to(torch.float64)
                # soft-target cross entropy: targets are the answers normalised per row
                loss = (nll * answers / answers.sum(1, keepdim=True)).sum(dim=1).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.train_metrics.update_per_batch(loss, answers.data, predicts.data)
        self.train_metrics.update_per_epoch()

    def eval(self, epoch):
        """Evaluate on the validation loader and update ``val_metrics`` (and ``zsl_metrics``)."""
        self.fusion_model.eval()
        self.answer_net.eval()
        prefix = "eval"
        tq = tqdm(self.val_loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
        # Extra ZSL scoring applies only to plain answer prediction with ZSL == 1.
        zsl_scoring = self.args.ZSL == 1 and not self.args.fact_map and not self.args.relation_map

        for visual_features, boxes, question_features, answers, idx, q_len in tq:
            with torch.no_grad():
                visual_features = Variable(visual_features.float()).cuda()
                boxes = Variable(boxes.float()).cuda()
                question_features = Variable(question_features).cuda()
                answers = Variable(answers).cuda()
                q_len = Variable(q_len).cuda()
                fusion_embedading = self.fusion_model(visual_features, boxes, question_features, q_len)

                if self.args.method_choice == 'CLS':
                    predicts = self.answer_net(fusion_embedading)
                    loss = instance_bce_with_logits(predicts, answers / 10)

                else:
                    answer_embedding = self.answer_net(self.answer_var)
                    predicts = cosine_sim(fusion_embedading, answer_embedding) / self.args.loss_temperature
                    predicts = predicts.to(torch.float64)
                    nll = -self.log_softmax(predicts).to(torch.float64)
                    loss = (nll * answers / answers.sum(1, keepdim=True)).sum(dim=1).mean()

                if zsl_scoring:
                    # add the mask rows matching this (possibly smaller, last) batch
                    zsl_predicts = predicts + self.negtive_mux[:predicts.shape[0], :]

            self.val_metrics.update_per_batch(loss, answers.data, predicts.data)
            if zsl_scoring:
                self.zsl_metrics.update_per_batch(loss, answers.data, zsl_predicts.data)

        self.val_metrics.update_per_epoch()
        if zsl_scoring:
            self.zsl_metrics.update_per_epoch()

    def _model_choice(self, args):
        """Instantiate ``self.fusion_model`` and ``self.answer_net`` from the configuration.

        Also sets ``self.method_list`` (sorted parts of ``args.method_choice``) and
        adds the feature length of each part to ``args.ans_feature_len`` in place.
        """
        assert args.fusion_model in ['SAN', 'MLP', 'BAN', 'UD']
        # models api
        self.fusion_model = getattr(fusion_net, args.fusion_model)(args, self.train_loader.dataset,
                                                                   self.question_word2vec).cuda()
        # freeze word embedding
        if args.freeze_w2v and args.fusion_model != 'MLP':
            freeze_layer(self.fusion_model.w_emb)

        # answer models
        assert args.method_choice in ['CLS', 'W2V', 'KG', 'GAE', 'KG_W2V', 'KG_GAE', 'GAE_W2V', 'KG_GAE_W2V']
        ans_len_table = {'W2V': 300, 'KG': 300, 'GAE': 1024, 'CLS': 0}
        self.method_list = args.method_choice.split('_')
        self.method_list.sort()
        # NOTE: this accumulates into the shared config object, so building a
        # second Runner from the same args adds the lengths again.
        for i in self.method_list:
            args.ans_feature_len += ans_len_table[i]
        # Mapping-based methods
        if args.method_choice != 'CLS':
            assert args.answer_embedding in ['MLP']
            self.answer_net = getattr(answer_net, args.answer_embedding)(args, self.train_loader.dataset).cuda()
        else:
            # Classifier-based methods
            self.answer_net = SimpleClassifier(args.embedding_size, 2 * args.hidden_size, args.FVQA.max_ans, 0.5).cuda()

    def _updata_best_result(self, max_acc, metrics):
        """Copy ``metrics`` into ``max_acc`` in place as [acc_1, acc_3, acc_10, acc_all]."""
        max_acc[3] = metrics.acc_all
        max_acc[2] = metrics.acc_10
        max_acc[1] = metrics.acc_3
        max_acc[0] = metrics.acc_1

    def _checkpoint_location(self, model, function):
        """Return ``(directory, file_path)`` of the checkpoint for ``model``.

        The file name encodes the prediction target (fact / relation / answer,
        prefixed with ``general_`` when ZSL is off), the model class name and the
        configured data split.
        """
        assert function == "fusion" or function == "embedding"
        # support entity mapping
        if self.args.fact_map:
            target = "fact"
        # relation mapping
        elif self.args.relation_map:
            target = "relation"
        else:
            target = "answer"
        model_name = type(model).__name__
        if not self.args.ZSL:
            target = "general_" + target
        directory = os.path.join(self.args.FVQA.model_save_path, function)
        file_path = os.path.join(directory, f'{target}_{model_name}_{self.args.FVQA.data_choice}.pkl')
        return directory, file_path

    def _load_model(self, model, function):
        """Load the saved state dict for ``function`` ("fusion" or "embedding") into ``model``."""
        _, save_path = self._checkpoint_location(model, function)

        model.load_state_dict(torch.load(save_path))
        print(f"loading {function} model done!")

    def _save_model(self, model, function):
        """Save ``model``'s state dict for ``function`` and return the file path."""
        directory, save_path = self._checkpoint_location(model, function)
        os.makedirs(directory, exist_ok=True)

        torch.save(model.state_dict(), save_path)
        return save_path


# Script entry point: load the configuration, set up logging / tensorboard,
# run training (or testing) and report the best results.
if __name__ == '__main__':
    # Config loading...
    # NOTE: this rebinds the name `cfg` from the config class to its instance.
    cfg = cfg()
    args = cfg.get_args()
    cfg.update_train_configs(args)
    set_seed(cfg.random_seed)

    # Environment initialization...
    logger = initialize_exp(cfg)
    logger_path = get_dump_path(cfg)
    # Always define `writer` so later checks never hit an undefined name;
    # it stays None when tensorboard logging is disabled.
    writer = None
    if not cfg.no_tensorboard:
        writer = SummaryWriter(log_dir=os.path.join(logger_path, 'tensorboard'))

    try:
        torch.cuda.set_device(cfg.gpu_id)

        # Run...
        runner = Runner(cfg)
        runner.run()

        #  information output:
        logger.info(f"best performance = {runner.max_acc[0]: .2f},{runner.max_acc[1]: .2f},{runner.max_acc[2]: .2f}. best epoch = {runner.best_epoch}, correspond_loss={runner.correspond_loss: .4f}")
        if args.ZSL == 1 and not args.fact_map and not args.relation_map:
            logger.info(f" zsl performance = {runner.max_zsl_acc[0]: .2f},{runner.max_zsl_acc[1]: .2f},{runner.max_zsl_acc[2]: .2f}")
        if not cfg.now_test:
            # The paths exist only when the run actually saved the models
            # (save_model enabled), so report just the ones that are set.
            for path_attr in ("fusion_model_path", "answer_net_path"):
                saved_path = getattr(runner, path_attr, None)
                if saved_path is not None:
                    logger.info(f" {path_attr} = {saved_path}")
    finally:
        # Flush and close the tensorboard writer even when the run fails.
        if writer is not None:
            writer.close()

Overwriting /content/drive/MyDrive/VQA/code/main.py


In [None]:
!ls

cfgs  code  data  kg  run.ipynb


###### config.py


In [None]:
%%writefile /content/drive/MyDrive/VQA/code/config.py
import os.path as osp
import numpy as np
import random
import torch
from easydict import EasyDict as edict
import argparse
import pdb


class cfg():
    """
    Central configuration object for the F-VQA experiments.

    Usage is three steps:
      1. ``cfg()``                      -- fill in defaults (paths, sizes, train/test settings);
      2. ``get_args()``                 -- parse the command line;
      3. ``update_train_configs(args)`` -- copy the parsed values onto this object and
                                           re-derive every setting that depends on them.
    """

    def __init__(self):
        """Populate every setting with its default value."""

        self.fusion_model_path = ""
        self.answer_net_path = ""

        self.joint_test_way = 0

        # All data paths are resolved relative to the directory holding this file.
        self.this_dir = osp.dirname(__file__)
        self.data_root = osp.abspath(osp.join(self.this_dir, '..', '..', 'data', 'KG_VQA'))
        self.project_root = osp.abspath(osp.join(self.this_dir, '..'))
        self.method_choice = "KG"
        self.ans_fusion = 'RNN_concate'
        self.fusion_model = ''
        self.requires_grad = 1
        self.bert_dim = 1024
        # NOTE: these two top-level defaults are NOT overwritten by the command
        # line; the parsed values live in self.FVQA.KGE / self.FVQA.KGE_init.
        self.KGE = "TransE"
        self.KGE_init = None  # none or w2v
        self.glimpse = 4
        self.ans_feature_len = 0
        self.patience = 30
        self.v_dim = 2048

        self.FVQA = edict()

        # FVQA params

        self.FVQA.max_ans = 500
        self.FVQA.data_choice = "0"

        self.FVQA.entity_num = "all"
        self.FVQA.data_path = osp.join(self.data_root, "fvqa")

        self.FVQA.exp_data_path = osp.join(self.FVQA.data_path, "exp_data")
        self.FVQA.common_data_path = osp.join(self.FVQA.exp_data_path, "common_data")
        self.FVQA.test_data_path = osp.join(self.FVQA.exp_data_path, "test_data")
        self.FVQA.train_data_path = osp.join(self.FVQA.exp_data_path, "train_data")
        self.FVQA.seen_train_data_path = osp.join(self.FVQA.exp_data_path, "train_seen_data")
        self.FVQA.unseen_test_data_path = osp.join(self.FVQA.exp_data_path, "test_unseen_data")
        self.FVQA.seen_test_data_path = osp.join(self.FVQA.exp_data_path, "test_seen_data")
        self.FVQA.model_save_path = osp.join(self.FVQA.data_path, "model_save")
        self.FVQA.runs_path = osp.join(self.FVQA.data_path, "model_save")

        self.FVQA.qa_path = self.FVQA.exp_data_path
        self.FVQA.feature_path = osp.join(self.FVQA.common_data_path, 'fvqa-resnet-14x14.h5')
        self.FVQA.answer_vocab_path = osp.join(
            self.FVQA.common_data_path, 'answer.vocab.fvqa.' + str(self.FVQA.max_ans) + '.json')
        self.FVQA.fact_vocab_path = osp.join(self.FVQA.common_data_path, 'answer.vocab.fvqa.fact.500.json')
        self.FVQA.relation_vocab_path = osp.join(self.FVQA.common_data_path, 'answer.vocab.fvqa.relation.500.json')

        self.FVQA.fact_relation_to_ans_path = osp.join(self.FVQA.common_data_path, "fact_relation_dict.data")
        self.FVQA.img_path = osp.join(self.FVQA.qa_path, 'images')

        self.FVQA.kg_path = osp.join(self.FVQA.common_data_path, "KG_embedding")
        self.FVQA.gae_path = osp.join(self.FVQA.common_data_path, "GAE_embedding")
        self.FVQA.bert_path = osp.join(self.FVQA.common_data_path, "BERT_embedding")

        self.FVQA.gae_node_num = 3463
        self.FVQA.gae_init = "w2v"  # or w2v
        # The settings below were marked as problematic and are disabled:
        # self.FVQA.qa = 'train2014'
        # self.FVQA.task = 'OpenEnded'
        # self.FVQA.dataset = 'mscoco'

        # self.dataset = self.FVQA

        self.cache_path = osp.join(self.data_root, '.cache')
        self.output_path = self.FVQA.model_save_path
        self.embedding_size = 1024  # embedding dimensionality
        # NOTE(review): hidden_size is derived from the default embedding_size
        # only; update_train_configs does not recompute it when
        # --embedding_size changes. Confirm whether that is intended.
        self.hidden_size = 2 * self.embedding_size  # hidden embedding
        # A joint question vocab across all datasets.
        # After changing this path, every cached (.pt) file must be deleted.
        self.question_vocab_path = osp.join(self.FVQA.common_data_path, 'question.vocab.json')

        # preprocess config
        self.image_size = 448
        self.output_size = self.image_size // 32
        self.preprocess_batch_size = 100  # 64
        self.output_features = 2048
        self.central_fraction = 0.875

        # Train params
        self.TRAIN = edict()
        self.TRAIN.epochs = 600
        self.TRAIN.batch_size = 128  # 128
        self.TRAIN.lr = 5e-4  # default Adam lr 1e-3
        self.TRAIN.lr_decay_step = 3
        self.TRAIN.lr_decay_rate = .70

        # self.TRAIN.data_workers = 20
        self.TRAIN.data_workers = 8  # 10
        self.TRAIN.answer_batch_size = self.FVQA.max_ans  # batch size for answer network
        self.TRAIN.max_negative_answer = self.FVQA.max_ans  # max negative answers to sample

        # Test params
        self.TEST = edict()
        self.TEST.batch_size = 128
        self.TEST.max_answer_index = self.FVQA.max_ans  # max answer index for computing acc   853

    def get_args(self):
        """
        Parse the command line (``sys.argv``) and return the argparse namespace.

        ``--exp_name`` is the only required option.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--gpu_id', default=1, type=int)
        parser.add_argument('--finetune', action='store_true')
        parser.add_argument('--batch_size', default=128, type=int)
        parser.add_argument('--max_ans', default=500, type=int)  # 3000 300
        parser.add_argument('--loss_temperature', default=0.01, type=float)
        # parser.add_argument('--pretrained_model', default=None, type=str)
        parser.add_argument('--answer_embedding', default='MLP', choices=['RNN', 'MLP'])  # answer encoder: MLP or RNN
        # parser.add_argument('--context_embedding', default='BoW', choices=['SAN', 'BoW'])  # question/image encoder
        parser.add_argument('--embedding_size', default=1024, choices=[1024, 300, 512], type=int)
        parser.add_argument('--epoch', default=800, type=int)  # number of training epochs
        # choice model
        parser.add_argument('--fusion_model', default='SAN', choices=['MLP', 'SAN', 'UD', 'MUTAN', 'BAN', 'ViLBERT'])
        parser.add_argument('--requires_grad', default=0, type=int, choices=[0, 1])
        # choice class
        parser.add_argument('--method_choice', default='W2V',
                            choices=['CLS', 'W2V', 'KG', 'GAE', 'KG_W2V', 'KG_GAE', 'GAE_W2V', 'KG_GAE_W2V'])
        parser.add_argument('--ans_fusion', default='Simple_concate',
                            choices=['RNN_concate', 'GATE_attention', 'GATE', 'RNN_GATE_attention', 'Simple_concate'])
        # KG situation
        parser.add_argument('--KGE', default='TransE',
                            choices=['TransE', 'ComplEx', "TransR", "DistMult"])  # KG embedding method
        parser.add_argument('--KGE_init', default="w2v")  # none or w2v
        parser.add_argument('--GAE_init', default="random")  # random or w2v
        parser.add_argument('--ZSL', type=int, default=0)  # 1 switches to the seen/unseen (zero-shot) split
        parser.add_argument('--entity_num', default="all", choices=['all', '4302'])  # todo: finish the other sub-graph cases

        parser.add_argument('--data_choice', default='0', choices=['0', '1', '2', '3', '4'])
        parser.add_argument('--name', default=None, type=str)  # name suffix

        parser.add_argument("--no-tensorboard", default=False, action="store_true")
        parser.add_argument("--exp_name", default="", type=str, required=True, help="Experiment name")
        parser.add_argument("--dump_path", default="dump/", type=str, help="Experiment dump path")
        parser.add_argument("--exp_id", default="", type=str, help="Experiment ID")
        parser.add_argument("--random_seed", default=4567, type=int)
        parser.add_argument("--freeze_w2v", default=1, type=int, choices=[0, 1])
        parser.add_argument("--ans_net_lay", default=0, type=int, choices=[0, 1, 2])
        parser.add_argument("--fact_map", default=0, type=int, choices=[0, 1])
        parser.add_argument("--relation_map", default=0, type=int, choices=[0, 1])

        parser.add_argument("--now_test", default=0, type=int, choices=[0, 1])
        parser.add_argument("--save_model", default=0, type=int, choices=[0, 1])

        parser.add_argument("--joint_test_way", default=0, type=int, choices=[0, 1])
        parser.add_argument("--top_rel", default=10, type=int)
        parser.add_argument("--top_fact", default=100, type=int)
        parser.add_argument("--soft_score", default=10, type=int)  # 10 or 10000
        parser.add_argument("--mrr", default=0, type=int)
        args = parser.parse_args()
        return args

    def update_train_configs(self, args):
        """
        Copy the parsed command-line values onto this object and rebuild the
        settings derived from them (answer-vocabulary size/path, data split,
        image-feature file, KG embedding paths).

        Args:
            args: the namespace returned by ``get_args``.
        """
        self.gpu_id = args.gpu_id
        self.finetune = args.finetune
        self.answer_embedding = args.answer_embedding
        self.name = args.name
        self.no_tensorboard = args.no_tensorboard
        self.exp_name = args.exp_name
        self.dump_path = args.dump_path
        self.exp_id = args.exp_id
        self.random_seed = args.random_seed
        self.freeze_w2v = args.freeze_w2v
        self.loss_temperature = args.loss_temperature
        self.ZSL = args.ZSL
        self.ans_net_lay = args.ans_net_lay
        self.fact_map = args.fact_map
        self.relation_map = args.relation_map
        self.now_test = args.now_test
        self.save_model = args.save_model
        self.joint_test_way = args.joint_test_way
        self.top_rel = args.top_rel
        self.top_fact = args.top_fact
        self.soft_score = args.soft_score
        self.mrr = args.mrr

        # Zero-shot setting: train on the "seen" split, test on the "unseen" one.
        if args.ZSL == 1:
            print("ZSL setting...")
            self.FVQA.test_data_path = self.FVQA.unseen_test_data_path
            self.FVQA.train_data_path = self.FVQA.seen_train_data_path

        # UD and BAN read a different image-feature file plus an image-id index.
        if args.fusion_model == 'UD' or args.fusion_model == 'BAN':
            self.FVQA.feature_path = osp.join(self.FVQA.common_data_path, 'fvqa_36.hdf5')
            self.FVQA.img_id2idx = osp.join(self.FVQA.common_data_path, 'fvqa36_imgid2idx.pkl')
        self.requires_grad = True if args.requires_grad == 1 else False
        self.fusion_model = args.fusion_model
        self.TRAIN.batch_size = args.batch_size
        # self.TRAIN.answer_batch_size = args.answer_batch_size
        self.method_choice = args.method_choice
        self.ans_fusion = args.ans_fusion
        self.embedding_size = args.embedding_size
        self.FVQA.data_choice = args.data_choice
        self.FVQA.max_ans = args.max_ans
        self.TRAIN.epochs = args.epoch
        self.FVQA.KGE = args.KGE
        self.FVQA.KGE_init = args.KGE_init
        self.FVQA.gae_init = args.GAE_init
        self.FVQA.entity_num = args.entity_num

        # Fact / relation prediction use fixed vocabulary sizes that override --max_ans.
        if self.fact_map:
            self.FVQA.max_ans = 2791
        if self.relation_map:
            self.FVQA.max_ans = 103

        self.TEST.max_answer_index = self.FVQA.max_ans
        self.TRAIN.answer_batch_size = self.FVQA.max_ans  # batch size for answer network
        self.TRAIN.max_negative_answer = self.FVQA.max_ans

        self.FVQA.answer_vocab_path = osp.join(
            self.FVQA.common_data_path, 'answer.vocab.fvqa.' + str(self.FVQA.max_ans) + '.json')

        if "KG" in self.method_choice:
            kg_dir = self.FVQA.kg_path
            entity_num = self.FVQA.entity_num
            self.FVQA.relation2id_path = osp.join(kg_dir, "relations_" + entity_num + ".tsv")
            self.FVQA.entity2id_path = osp.join(kg_dir, "entities_" + entity_num + ".tsv")
            # Build the embedding file names from the values that were just
            # parsed (self.FVQA.KGE / self.FVQA.KGE_init). The previous code
            # read self.KGE / self.KGE_init, which are never updated from the
            # command line, so --KGE and --KGE_init were silently ignored here.
            init_tag = "_w2v_" if self.FVQA.KGE_init == "w2v" else "_"
            stem = "fvqa_" + entity_num + init_tag + self.FVQA.KGE
            self.FVQA.entity_path = osp.join(kg_dir, stem + "_entity.npy")
            self.FVQA.relation_path = osp.join(kg_dir, stem + "_relation.npy")

Overwriting /content/drive/MyDrive/VQA/code/config.py


## Torch Light

###### \_\_init__.py

In [18]:
%%writefile /content/drive/MyDrive/VQA/code/torchlight/__init__.py
from .logger import initialize_exp, get_dump_path
from .metric import Metric, CategoricalAccuracy, PRMetric
from .module import LSTM4VarLenSeq
from .vocab import (PAD_TOKEN, UNK_TOKEN, BOS_TOKEN, EOS_TOKEN,
                    DefaultLookupDict,
                    Vocabulary)
from .utils import (personal_display_settings,
                    set_seed,
                    normalize,
                    snapshot,
                    show_params,
                    longest_substring,
                    pad,
                    to_cuda,
                    get_code_version,
                    cat_ragged_tensors,
                    topk_accuracy,
                    get_total_trainable_params)

Writing /content/drive/MyDrive/VQA/code/torchlight/__init__.py


###### logger.py

In [19]:
%%writefile /content/drive/MyDrive/VQA/code/torchlight/logger.py
import os
import re
import sys
import time
import json
import torch
import pickle
import random
import getpass
import logging
import argparse
import subprocess
import numpy as np
from datetime import timedelta, date
from .utils import get_code_version


class LogFormatter():
    """
    Formats a log record as ``LEVEL - <date time> - <elapsed> - message``.

    The elapsed part is measured from ``start_time`` (set at construction and
    resettable by the owner). Continuation lines of a multi-line message are
    indented so they line up under the first message line.
    """

    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        """Return the formatted line for `record`, or '' when its message is empty."""
        elapsed = timedelta(seconds=round(record.created - self.start_time))
        header = " - ".join([record.levelname, time.strftime('%x %X'), str(elapsed)])

        body = record.getMessage()
        if not body:
            return ''

        # +3 accounts for the " - " separator between header and message.
        continuation = '\n' + ' ' * (len(header) + 3)
        return header + " - " + body.replace('\n', continuation)


def create_logger(filepath, rank):
    """
    Configure and return the root logger.

    Args:
        filepath: log file to append to, or None for console-only logging.
        rank: process rank; ranks above 0 write to ``<filepath>-<rank>`` so
            that every process gets its own file.

    Returns:
        The root logger (DEBUG level) with any previous handlers dropped, a
        DEBUG file handler (if `filepath` is given) and an INFO console
        handler. It also carries a ``reset_time()`` attribute that restarts
        the elapsed-time counter shown in every line.
    """
    formatter = LogFormatter()
    handlers = []

    # File output (DEBUG and above), one file per process.
    if filepath is not None:
        log_file = '%s-%i' % (filepath, rank) if rank > 0 else filepath
        to_file = logging.FileHandler(log_file, "a", encoding='utf-8')
        to_file.setLevel(logging.DEBUG)
        to_file.setFormatter(formatter)
        handlers.append(to_file)

    # Console output (INFO and above).
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.INFO)
    to_console.setFormatter(formatter)
    handlers.append(to_console)

    # Reconfigure the root logger from scratch.
    root = logging.getLogger()
    root.handlers = []
    root.setLevel(logging.DEBUG)
    root.propagate = False
    for handler in handlers:
        root.addHandler(handler)

    def reset_time():
        formatter.start_time = time.time()
    root.reset_time = reset_time

    return root


def initialize_exp(params):
    """
    Initialize the experiment:
    - dump parameters
    - create a logger

    Args:
        params: configuration object; must provide `exp_name`, `dump_path`
            and `exp_id`. `params.command` is set here to the reconstructed
            command line (including the experiment id).

    Returns:
        The configured logger.
    """
    # dump parameters
    exp_folder = get_dump_path(params)
    # The content is JSON despite the historical '.pkl' file name (kept so
    # existing tooling still finds it). The context manager guarantees the
    # file is flushed and closed.
    with open(os.path.join(exp_folder, 'params.pkl'), 'w') as params_file:
        json.dump(vars(params), params_file, indent=4)

    # get running command
    command = ["python", sys.argv[0]]
    for x in sys.argv[1:]:
        if x.startswith('--'):
            assert '"' not in x and "'" not in x
            command.append(x)
        else:
            assert "'" not in x
            # Plain identifiers are appended as-is, anything else is single-quoted.
            if re.match('^[a-zA-Z0-9_]+$', x):
                command.append("%s" % x)
            else:
                command.append("'%s'" % x)
    command = ' '.join(command)
    params.command = command + ' --exp_id "%s"' % params.exp_id

    # check experiment name
    assert len(params.exp_name.strip()) > 0

    # create a logger
    logger = create_logger(os.path.join(exp_folder, 'train.log'), rank=getattr(params, 'global_rank', 0))
    logger.info("============ Initialized logger ============")
    # logger.info("\n".join("%s: %s" % (k, str(v))
    #                       for k, v in sorted(dict(vars(params)).items())))
    # text = f'# Git Version: {get_code_version()} #'
    # logger.info("\n".join(['=' * 24, text, '=' * 24]))
    logger.info("The experiment will be stored in %s\n" % exp_folder)
    logger.info("Running command: %s" % command)
    logger.info("")
    return logger


def get_dump_path(params):
    """
    Create (if needed) and return the directory that stores the experiment:
    ``<dump_path>/<MMDD>-<exp_name>/<exp_id>``.

    Args:
        params: object with `exp_name`, `dump_path` and `exp_id` attributes.
            When `exp_id` is the empty string a random, not-yet-used
            10-character id is generated and written back to `params.exp_id`.

    Returns:
        Path of the experiment folder.

    Raises:
        OSError: if a directory cannot be created.
    """
    assert len(params.exp_name) > 0
    assert params.dump_path not in ('', None), \
        'Please choose your favorite destination for dump.'
    dump_path = params.dump_path

    # create the sweep path if it does not exist
    # (os.makedirs instead of a shell `mkdir -p`: it works for names that
    # contain spaces or shell metacharacters and reports failures.)
    when = date.today().strftime('%m%d-')
    sweep_path = os.path.join(dump_path, when + params.exp_name)
    os.makedirs(sweep_path, exist_ok=True)

    # create a random ID for the job if it is not given in the parameters.
    if params.exp_id == '':
        chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
        while True:
            exp_id = ''.join(random.choice(chars) for _ in range(10))
            if not os.path.isdir(os.path.join(sweep_path, exp_id)):
                break
        params.exp_id = exp_id

    # create the dump folder / update parameters
    exp_folder = os.path.join(sweep_path, params.exp_id)
    os.makedirs(exp_folder, exist_ok=True)
    return exp_folder


if __name__ == '__main__':
    # Nothing to do when executed directly; this module is meant to be imported.
    pass

Writing /content/drive/MyDrive/VQA/code/torchlight/logger.py


###### metric.py

In [20]:
%%writefile /content/drive/MyDrive/VQA/code/torchlight/metric.py
# from abc import ABC, ABCMeta, abstractclassmethod
import torch
import numpy as np
from abc import ABC, abstractmethod, ABCMeta


class Metric(metaclass=ABCMeta):
    """
    Abstract interface shared by all metrics.
    Adapted, with small changes, from
    https://github.com/pytorch/ignite/metrics/metric.py

    Data usually arrives in batches, so the intended call pattern is:
    -   reset() at the beginning of every epoch,
    -   update() after every batch,
    -   compute() whenever the training/testing performance should be logged.
    """

    @abstractmethod
    def __init__(self):
        """Set up the metric; concrete metrics may chain to this via super()."""

    @abstractmethod
    def reset(self):
        """
        Reset the metric to its initial state.
        Intended to be called at the start of each epoch.
        """

    @abstractmethod
    def update(self, output):
        """
        Update the metric's state with the output of one batch.
        Intended to be called once per batch.
        Args:
            output: the output of the engine's process function
        """

    @abstractmethod
    def compute(self):
        """
        Compute the metric from its accumulated state.
        Intended to be called at the end of each epoch.
        Returns:
            Any: the actual quantity of interest
        Raises:
            NotComputableError: raised when the metric cannot be computed
        """


class CategoricalAccuracy(Metric):
    """
    Running categorical accuracy (fraction of correct arg-max predictions).
    - `update` takes a pair `(y_pred, y)`.
    - `y_pred` is expected in shape (batch_size, num_categories, ...)
    - `y` is expected in shape (batch_size, ...)
    """

    def __init__(self):
        super().__init__()
        self._num_examples = 0
        self._num_correct = 0

    def reset(self):
        """Forget everything accumulated so far."""
        self._num_examples = self._num_correct = 0

    def update(self, output):
        """Add one batch: the predicted class is the arg-max over dimension 1."""
        scores, gold = output
        predicted = torch.max(scores, 1)[1]
        hits = torch.eq(predicted, gold).view(-1)
        self._num_correct += torch.sum(hits).item()
        self._num_examples += hits.shape[0]

    def compute(self):
        """Return correct / seen; raises ZeroDivisionError if nothing was seen."""
        if self._num_examples != 0:
            return self._num_correct / self._num_examples
        raise ZeroDivisionError('CategoricalAccuracy must have at least'
                                ' one example before it can be computed')


class PRMetric(Metric):
    """
    Calculates per-class precision and recall from a confusion matrix.
    - `update` must receive output of the form `(y_pred, y)`.
    - `y_pred` must be in the following shape (batch_size, num_categories, ...)
    - `y` must be in the following shape (batch_size, ...)

    `confusion_matrix[i, j]` counts the examples predicted as class `i` whose
    gold label is class `j`.
    """

    def __init__(self, num_class=2):
        """
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)

        Args:
            num_class: number of classes (side length of the confusion matrix).
        """
        super().__init__()
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class, self.num_class),
                                         dtype=np.float32)

    def reset(self):
        """Clear the confusion matrix."""
        self.confusion_matrix = np.zeros((self.num_class, self.num_class),
                                         dtype=np.float32)

    def update(self, output):
        """Add one batch of (scores, gold labels) to the confusion matrix."""
        y_pred, y = output
        _, indices = torch.max(y_pred, 1)
        predicted = indices.cpu().numpy()
        gold = y.cpu().numpy()
        # np.add.at accumulates repeated (predicted, gold) pairs. A plain
        # `matrix[predicted, gold] += 1` is buffered and would count a pair
        # that occurs several times in the batch only once.
        np.add.at(self.confusion_matrix, (predicted, gold), 1)

    def compute(self):
        """
        Returns:
            (precision, recall): two arrays with one entry per class. A class
            that was never predicted (or never appears as gold) yields NaN in
            the corresponding array because of the 0/0 division.
        """
        tp = np.diag(self.confusion_matrix)
        total_pred = np.sum(self.confusion_matrix, axis=1)  # row sums: predicted counts
        total_gold = np.sum(self.confusion_matrix, axis=0)  # column sums: gold counts
        # tn don't care
        p = tp / total_pred
        r = tp / total_gold
        return p, r


if __name__ == '__main__':
    # Smoke test: feed a hand-made 3-class confusion matrix and print the
    # resulting (precision, recall) pair.
    counts = np.array([[2, 0, 2],
                       [0, 1, 0],
                       [1, 0, 0]])
    demo = PRMetric(3)
    demo.confusion_matrix = counts
    precision, recall = demo.compute()
    print((precision, recall))

Writing /content/drive/MyDrive/VQA/code/torchlight/metric.py


###### module.py

In [21]:
%%writefile /content/drive/MyDrive/VQA/code/torchlight/module.py
import math
from typing import Sequence, Union, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# NOTE: this is an import-time side effect — importing this module re-seeds
# PyTorch's global random number generator.
torch.manual_seed(10086)
# Type alias: a callable that maps a tensor to a tensor.
tensor_activation = Callable[[torch.Tensor], torch.Tensor]


class LSTM4VarLenSeq(nn.Module):
    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, bidirectional=False, init='orthogonal', take_last=True):
        """
        LSTM for batches of padded, variable-length sequences.

        No dropout support. The layout is fixed to batch-first: inputs and
        outputs are (batch, seq_len, feature).

        Args:
            input_size: feature size of every time step.
            hidden_size: size of the hidden state.
            num_layers: number of stacked LSTM layers.
            bias: whether the LSTM uses bias terms.
            bidirectional: whether the LSTM is bidirectional.
            init: ways to init the torch.nn.LSTM parameters,
                supports 'orthogonal' and 'uniform'
            take_last: 'True' if you only want the final hidden state
                otherwise 'False'

        Raises:
            NotImplementedError: if `init` is not a supported scheme.
        """
        super(LSTM4VarLenSeq, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            bias=bias,
                            bidirectional=bidirectional)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.bidirectional = bidirectional
        self.init = init
        self.take_last = take_last
        self.batch_first = True  # Please don't modify this

        self.init_parameters()

    def init_parameters(self):
        """
        Initialize the weights of every layer and direction; biases are zeroed.

        Orthogonal init generally yields better results than uniform init.
        """
        if self.init not in ('orthogonal', 'uniform'):
            raise NotImplementedError('Unsupported Initialization')

        bound = math.sqrt(1 / self.hidden_size)
        # all_weights has one entry per (layer, direction). Iterating over it
        # directly covers unidirectional and bidirectional stacks alike (the
        # former `range(num_layers * bidirectional)` was empty for
        # unidirectional LSTMs and covered only half of a bidirectional one).
        for weights in self.lstm.all_weights:
            # w_ih: (4 * hidden_size x input_size), w_hh: (4 * hidden_size x hidden_size)
            for weight in weights[:2]:
                if self.init == 'orthogonal':
                    nn.init.orthogonal_(weight, gain=1)  # gain=1 is the default
                else:
                    nn.init.uniform_(weight, -bound, bound)
            # b_ih, b_hh: (4 * hidden_size); the slice is empty when bias=False
            for bias in weights[2:]:
                nn.init.zeros_(bias)

    def forward(self, x, x_len, hx=None):
        """
        Args:
            x: padded batch, (batch, max_seq_len, input_size).
            x_len: 1-D integer tensor with the true length of every sequence.
            hx: optional initial state `(h_0, c_0)`, each of shape
                (num_layers * num_directions, batch, hidden_size), given in
                the same batch order as `x`.

        Returns:
            If `take_last`: the final hidden state with its leading dimension
            squeezed when it has size 1. Otherwise `(out, (h_n, c_n))` with
            `out` of shape (batch, longest_len, num_directions * hidden_size).
            Everything is returned in the original batch order.
        """
        # 1. Sort x and its corresponding length
        sorted_x_len, sorted_x_idx = torch.sort(x_len, descending=True)
        sorted_x = x[sorted_x_idx]
        # 2. Ready to unsort after LSTM forward pass
        _, unsort_x_idx = torch.sort(sorted_x_idx, descending=False)

        # The initial state has to follow the same permutation as the batch.
        if hx is not None:
            h0, c0 = hx
            hx = (h0[:, sorted_x_idx], c0[:, sorted_x_idx])

        # 3. Pack the sorted version of x and x_len, as required by the API.
        # The lengths must live on the CPU even when x is on the GPU.
        x_emb = pack_padded_sequence(sorted_x, sorted_x_len.cpu(),
                                     batch_first=self.batch_first)

        # 4. Forward lstm
        # output_packed.data.shape is (valid_seq, num_directions * hidden_dim).
        # See doc of torch.nn.LSTM for details.
        out_packed, (hn, cn) = self.lstm(x_emb, hx)

        # 5. unsort h along its batch dimension
        # (num_layers * num_directions, batch, hidden_size)
        hn = hn[:, unsort_x_idx]
        if self.take_last:
            return hn.squeeze(0)
        else:
            # unpack: out
            # (batch, max_seq_len, num_directions * hidden_size)
            out, _ = pad_packed_sequence(out_packed,
                                         batch_first=self.batch_first)
            out = out[unsort_x_idx]
            # unsort c the same way as h
            cn = cn[:, unsort_x_idx]
            return out, (hn, cn)


if __name__ == '__main__':
    # Smoke test for LSTM over variable-length sequences (to be moved into a
    # proper unittest module later): run a 3-layer bidirectional net on five
    # zero-padded token sequences and print the resulting shapes.
    net = LSTM4VarLenSeq(200, 100,
                         num_layers=3, bias=True, bidirectional=True, init='orthogonal', take_last=False)

    token_ids = torch.tensor([[1, 2, 3, 0],
                              [2, 3, 0, 0],
                              [2, 4, 3, 0],
                              [1, 4, 3, 0],
                              [1, 2, 3, 4]])
    embed = nn.Embedding(num_embeddings=5, embedding_dim=200, padding_idx=0)
    seq_lens = torch.LongTensor([3, 2, 3, 3, 4])

    seq_out, (h_last, c_last) = net(embed(token_ids), seq_lens)
    # seq_out: 5, 4, 200  -> batch, seq length, hidden_size * 2 (only last layer)
    # h_last:  6, 5, 100  -> num_layers * num_directions, batch, hidden_size
    # c_last:  6, 5, 100  -> num_layers * num_directions, batch, hidden_size
    for tensor in (seq_out, h_last, c_last):
        print(tensor.shape)

Writing /content/drive/MyDrive/VQA/code/torchlight/module.py


###### utils.py

In [22]:
%%writefile /content/drive/MyDrive/VQA/code/torchlight/utils.py
"""
Utilizations for common usages.
"""
import os
import random
import torch
import numpy as np
from difflib import SequenceMatcher
from unidecode import unidecode
from datetime import datetime
from torch.nn.parallel import DataParallel, DistributedDataParallel


def personal_display_settings():
    """
    Widen the pandas display limits and turn off scientific notation when
    printing NumPy arrays.

    Pandas options reference:
    https://pandas.pydata.org/pandas-docs/stable/generated/pandas.set_option.html
    """
    from pandas import set_option
    from numpy import set_printoptions

    pandas_options = (
        ('display.max_rows', 500),
        ('display.max_columns', 500),
        ('display.width', 2000),
        ('display.max_colwidth', 1000),
    )
    for option, value in pandas_options:
        set_option(option, value)
    set_printoptions(suppress=True)


def set_seed(seed):
    """
    Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with
    the same value for reproducibility.

    torch.cuda.manual_seed_all matters when random numbers are generated on
    GPUs, e.g. torch.cuda.FloatTensor(100).uniform_()
    """
    seeders = (random.seed,
               np.random.seed,
               torch.manual_seed,
               torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def normalize(s):
    """
    Strip surrounding whitespace, lower-case the text and pass it through
    `unidecode`, which maps accented characters (such as the vowels used in
    German or French) to plain ASCII counterparts.
    Example:
        āáǎà  -->  aaaa
        ōóǒò  -->  oooo
        ēéěè  -->  eeee
        īíǐì  -->  iiii
        ūúǔù  -->  uuuu
        ǖǘǚǜ  -->  uuuu

    :param s: unicode string
    :return:  unicode string with regular English characters.
    """
    return unidecode(s.strip().lower())


def snapshot(model, epoch, save_path):
    """
    Save the model's state dict to
    ``<save_path>/<ModelClassName>_<epoch>_epoch.pkl`` and return that path.

    The directory is created when missing. For DataParallel /
    DistributedDataParallel wrappers the wrapped module's state dict is
    stored (the file name still uses the wrapper's class name).
    """
    os.makedirs(save_path, exist_ok=True)
    # timestamp = datetime.now().strftime('%m%d_%H%M')
    target = os.path.join(save_path, f'{type(model).__name__}_{epoch}_epoch.pkl')

    is_wrapped = isinstance(model, (DataParallel, DistributedDataParallel))
    weights = (model.module if is_wrapped else model).state_dict()
    torch.save(weights, target)
    return target


def save_checkpoint(model, optimizer, epoch, path):
    """
    Write a training checkpoint to `path`: a dict holding the epoch, the
    model state dict (key 'models') and the optimizer state dict.
    """
    state = dict(
        epoch=epoch,
        models=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(state, path)


def load_checkpoint(path, map_location):
    """Load and return the object stored at `path`, remapping storages via `map_location`."""
    return torch.load(path, map_location=map_location)


def show_params(model):
    """
    Print every named parameter of the model together with its size,
    one per line, with the name left-aligned in a 16-character column.
    """
    for param_name, tensor in model.named_parameters():
        print(param_name.ljust(16), tensor.size())


def longest_substring(str1, str2):
    """
    Return the longest contiguous block shared by `str1` and `str2`
    (taken from `str1`), or "" when they have nothing in common.
    """
    matcher = SequenceMatcher(None, str1, str2)
    # The match is a (a, b, size) triple: start in str1, start in str2, length.
    start, _, size = matcher.find_longest_match(0, len(str1), 0, len(str2))
    if size:
        return str1[start: start + size]
    return ""


def pad(sent, max_len):
    """
    Bring a list to exactly `max_len` items: shorter lists are right-padded
    with zeros, longer ones are truncated.

    Truncation matters at test time, when a sentence may be longer than the
    max_len derived from the training data.
    """
    missing = max_len - len(sent)
    if missing > 0:
        return (sent + [0] * missing)[:max_len]
    return sent[:max_len]


def to_cuda(*args, device=None):
    """
    Move Tensors to a device (CUDA by default).

    Args:
        *args: tensors to move; `None` entries are allowed and passed through
            unchanged, so optional tensors can be forwarded as they are.
        device: target device. If not provided, defaults to 'cuda:0', i.e. the
            first card in CUDA_VISIBLE_DEVICES.

    Returns:
        A list with the moved tensors (and `None`s) in the order given.

    Raises:
        AssertionError: if an argument is neither a tensor nor None.
    """
    # None has to be accepted here; otherwise the None pass-through in the
    # return statement below could never be reached.
    assert all(x is None or torch.is_tensor(x) for x in args), \
        'Only support for tensors, please check if any nn.Module exists.'
    if device is None:
        device = torch.device('cuda:0')
    return [None if x is None else x.to(device) for x in args]


def get_code_version(short_sha=True):
    """
    Return the git commit SHA of HEAD for the current working directory.

    Args:
        short_sha: if True return only the first 7 characters, otherwise the
            full SHA (without the trailing newline printed by git).

    Returns:
        The SHA string, or None (after printing a notice) when the SHA cannot
        be determined, e.g. the directory is not a git repository or git is
        not installed.
    """
    from subprocess import check_output, DEVNULL, CalledProcessError
    try:
        # stderr is discarded so that git warnings can never end up inside
        # the returned SHA; an argument list avoids going through a shell.
        sha = check_output(['git', 'rev-parse', 'HEAD'], stderr=DEVNULL,
                           encoding='utf-8').strip()
    except (CalledProcessError, OSError):
        # Non-zero exit code, or the git executable could not be started.
        pwd = os.getcwd()
        print(f'Working dir {pwd} is not a git repo.')
        return None
    return sha[:7] if short_sha else sha


def cat_ragged_tensors(left, right):
    """
    Concatenate two zero-padded 2-D tensors row by row.

    For every row, the first `l1` entries of `left` are followed directly by
    the first `l2` entries of `right`, where `l1`/`l2` are the numbers of
    non-zero entries in the respective row; the remainder is zero. The result
    is a long tensor of width ``left.size(1) + right.size(1)`` on `left`'s
    device.
    """
    assert left.size(0) == right.size(0)
    n_rows = left.size(0)
    width = left.size(1) + right.size(1)

    # Row lengths = number of non-zero entries per row.
    left_lens = (left != 0).sum(dim=1)
    right_lens = (right != 0).sum(dim=1)

    # Starting from zeros takes care of the padding.
    merged = torch.zeros((n_rows, width), dtype=torch.long, device=left.device)
    for row in range(n_rows):
        n_left = left_lens[row].item()
        n_right = right_lens[row].item()
        merged[row, :n_left + n_right] = torch.cat((left[row][:n_left],
                                                    right[row][:n_right]))
    return merged


def topk_accuracy(inputs, labels, k=1, largest=True):
    """
    Count the rows whose label appears among the top-k scored columns.

    Args:
        inputs: 2-D score tensor, one row per example.
        labels: 2-D tensor that broadcasts against the ``k`` selected column
            indices of each row.
        k: number of columns selected per row.
        largest: select the highest scores when True, the lowest otherwise.

    Returns:
        ``(num_correct, num_example)`` where ``num_example`` is
        ``inputs.size(0)``.
    """
    assert inputs.dim() == 2
    assert labels.dim() == 2
    topk_indices = inputs.topk(k=k, largest=largest)[1]
    # Broadcast comparison: a row counts as correct as soon as one position
    # matches.
    row_has_match = (topk_indices == labels).any(dim=1)
    return row_has_match.sum().item(), inputs.size(0)


def get_total_trainable_params(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    # Manual check when the module is run as a script: print what
    # normalize() returns for a string of 'ü' characters carrying tone marks.
    print(normalize('ǖǘǚǜ'))

Writing /content/drive/MyDrive/VQA/code/torchlight/utils.py


###### vocab.py

In [23]:
%%writefile /content/drive/MyDrive/VQA/code/torchlight/vocab.py
# coding: utf-8
"""
Every NLP task needs a Vocabulary
Every Vocabulary is built from Instances
Every Instance is a collection of Fields
"""

__all__ = ['DefaultLookupDict', 'Vocabulary']

PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
BOS_TOKEN = '<bos>'
EOS_TOKEN = '<eos>'
PAD_IDX = 0
UNK_IDX = 1


class DefaultLookupDict(dict):
    """
    A dict whose ``[]`` lookup returns a fixed fallback for missing keys.

    Unlike ``collections.defaultdict``, looking up a missing key does not
    insert it.
    """

    def __init__(self, default):
        dict.__init__(self)
        # Value handed back by ``self[key]`` when ``key`` is absent.
        self._default = default

    def __getitem__(self, item):
        if item in self:
            return dict.__getitem__(self, item)
        return self._default


class Vocabulary:
    """
    Token <-> integer id mapping used to numericalize a field.

    Attributes:
        specials: the special tokens, always starting with the pad and
            unknown tokens, followed by any user-supplied ones.
        id2token: list of tokens; a token's position is its id.
        token2id: DefaultLookupDict from token to id; tokens that are not in
            the vocabulary resolve to the unknown-token id.
        embedding: placeholder for pretrained vectors, currently always None.

    Examples:
    >>> from torchlight.vocab import Vocabulary
    >>> from collections import Counter
    >>> text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
    >>> vocab = Vocabulary(Counter(text_data))
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=None):
        """
        Build the vocabulary from token frequencies.

        Args:
            counter: collections.Counter (or any mapping exposing ``items()``)
                from token to frequency.
            max_size: maximum number of non-special tokens to keep, or None
                for no limit. Default: None.
            min_freq: minimum frequency a token needs to be included. Values
                below 1 are raised to 1. Default: 1.
            specials: list of extra special tokens placed right after
                ['<pad>', '<unk>'], e.g. [CLS] [MASK] [SEP] for BERT or
                <bos> <eos> for machine translation.
        """
        min_freq = max(min_freq, 1)  # must be positive

        reserved = [PAD_TOKEN, UNK_TOKEN]
        if specials is not None:
            assert isinstance(specials, list), "'specials' is of type list"
            reserved = reserved + specials
        self.specials = reserved

        assert len(set(self.specials)) == len(self.specials), \
            "specials can not contain duplicates."

        # The limit is expressed on the total size, specials included.
        size_limit = None if max_size is None else len(self.specials) + max_size

        self.id2token = list(self.specials)
        self.token2id = DefaultLookupDict(UNK_IDX)
        for index, token in enumerate(self.id2token):
            self.token2id[token] = index

        # Most frequent first; ties are broken alphabetically because the
        # second (stable) sort keeps the order produced by the first one.
        alphabetical = sorted(counter.items(), key=lambda pair: pair[0])
        ranked = sorted(alphabetical, key=lambda pair: pair[1], reverse=True)

        for token, freq in ranked:
            # Frequencies only decrease from here on, so stop at the first
            # token that is too rare or once the vocabulary is full.
            if freq < min_freq or len(self.id2token) == size_limit:
                break
            if token in self.specials:
                continue
            self.token2id[token] = len(self.id2token)
            self.id2token.append(token)

        # TODO
        self.embedding = None

    def __len__(self):
        """Number of tokens in the vocabulary, specials included."""
        return len(self.id2token)

    def __repr__(self):
        return 'Vocab(size={}, specials="{}")'.format(len(self), self.specials)

    def __getitem__(self, tokens):
        """Look up the id(s) of one token or of a list/tuple of tokens.

        Parameters
        ----------
        tokens : str or list/tuple of strs
            A source token or tokens to be converted.
        Returns
        -------
        int or list of ints
            A single id for a single token, a list of ids for a list or
            tuple. The lookup is delegated to ``token2id``.
        """
        if isinstance(tokens, (list, tuple)):
            return [self.token2id[token] for token in tokens]
        return self.token2id[tokens]

    def __call__(self, tokens):
        """Same as ``self[tokens]``.

        Parameters
        ----------
        tokens : str or list/tuple of strs
            A source token or tokens to be converted.
        Returns
        -------
        int or list of ints
            A token id or a list of token ids according to the vocabulary.
        """
        return self[tokens]

    @classmethod
    def from_json(cls, json_str):
        """Placeholder: not implemented, returns None."""
        pass

    def to_json(self):
        """Placeholder: not implemented, returns None."""
        pass

    def set_embedding(self):
        """Placeholder: not implemented, returns None."""
        pass

Writing /content/drive/MyDrive/VQA/code/torchlight/vocab.py


## Utils

###### \_\_init__.py

In [16]:
%%writefile /content/drive/MyDrive/VQA/code/utils/__init__.py
from .tool import unseen_mask, freeze_layer, cosine_sim, batch_accuracy, instance_bce_with_logits, dele_a, transfer, hand_remove, deal_fact
from .metrics import Metrics

Overwriting /content/drive/MyDrive/VQA/code/utils/__init__.py


###### metrics.py

In [14]:
%%writefile /content/drive/MyDrive/VQA/code/utils/metrics.py
import os
import json

import torch
import torch.nn as nn
import time
import pdb


class Metrics:
    """
    Accumulates loss, soft top-k accuracy and ranking statistics (mean rank
    and mean reciprocal rank) over batches and turns them into averages.

    ``topnum`` selects which three rank cut-offs are tracked in the
    ``correct_1`` / ``correct_3`` / ``correct_10`` slots:
        10  -> top-1,  top-3,   top-10
        50  -> top-10, top-30,  top-50
        200 -> top-80, top-150, top-200
    """

    def __init__(self, topnum=10):

        self.topnum = topnum
        self.total_loss = 0

        # Running sums of per-example scores for the three cut-offs.
        self.correct_1 = 0
        self.correct_3 = 0
        self.correct_10 = 0
        # Percentages filled in by update_per_epoch().
        self.acc_all = 0
        self.acc_1 = 0
        self.acc_3 = 0
        self.acc_10 = 0
        self.num_examples = 0
        # Incremented once per update_per_batch() call, i.e. it counts batches.
        self.num_epoch = 0

        self.mrr = 0
        self.mr = 0
        self.mrr_all = 0
        self.mr_all = 0

    def update_per_batch(self, loss, answers, pred):
        """
        Fold one batch into the running totals.

        Args:
            loss: loss value of the batch, added to ``total_loss``.
            answers: ground-truth score tensor (see batch_accuracy_10).
            pred: prediction score tensor with one row per example.

        Raises:
            ValueError: if ``self.topnum`` is not one of 10, 50 or 200.
        """
        scorers = {
            10: self.batch_accuracy_10,
            50: self.batch_accuracy_50,
            200: self.batch_accuracy_200,
        }
        if self.topnum not in scorers:
            # Fail with a clear message before touching any running total.
            raise ValueError(
                'Unsupported topnum {!r}; expected one of {}'.format(
                    self.topnum, sorted(scorers)))

        self.total_loss += loss
        self.num_epoch += 1
        top1, top3, top10 = scorers[self.topnum](pred, answers)
        self.num_examples += top1.shape[0]

        self.correct_1 += top1.sum().item()
        self.correct_3 += top3.sum().item()
        self.correct_10 += top10.sum().item()

        mrr_tmp, mr_tmp = self.batch_mr_mrr(pred, answers)
        self.mrr_all += mrr_tmp.sum().item()
        self.mr_all += mr_tmp.sum().item()

    def update_per_epoch(self):
        """
        Convert the running totals into averages.

        Accuracies become percentages over ``num_examples``; ``total_loss``
        is replaced by its mean over the number of batches seen. Raises
        ZeroDivisionError when no batch has been recorded.
        """
        self.acc_1 = 100 * (self.correct_1 / self.num_examples)
        self.acc_3 = 100 * (self.correct_3 / self.num_examples)
        self.acc_10 = 100 * (self.correct_10 / self.num_examples)

        self.mr = self.mr_all / self.num_examples
        self.mrr = self.mrr_all / self.num_examples

        self.total_loss = self.total_loss / self.num_epoch
        self.acc_all = self.acc_1 + self.acc_3 + self.acc_10

    @staticmethod
    def _accuracy_at_cutoffs(predicted, true, cutoffs):
        """
        Soft accuracy at several rank cut-offs.

        For each cut-off ``c`` the ground-truth scores of the ``c`` highest
        ranked predictions are summed, scaled by 0.3 and capped at 1.

        Args:
            predicted: (batch, classes) prediction scores.
            true: ground-truth scores with the same layout; when it is 3-D
                only ``true[0]`` is used.
            cutoffs: increasing rank cut-offs, e.g. ``(1, 3, 10)``.

        Returns:
            A tuple with one (batch, 1) float tensor per cut-off, living on
            the same device as the inputs.
        """
        if len(true.shape) == 3:
            true = true[0]
        _, ranked = predicted.topk(max(cutoffs), dim=1)
        # Ground-truth score of every ranked prediction, accumulated along
        # the ranking. Everything stays on the inputs' device, so this works
        # on CPU as well as on GPU.
        running = true.gather(dim=1, index=ranked).to(torch.float).cumsum(dim=1)
        return tuple((running[:, c - 1:c] * 0.3).clamp(max=1) for c in cutoffs)

    def batch_accuracy_10(self, predicted, true):
        """ Compute the top-1 / top-3 / top-10 accuracies for a batch of predictions and answers """
        return self._accuracy_at_cutoffs(predicted, true, (1, 3, 10))

    def batch_accuracy_50(self, predicted, true):
        """ Compute the top-10 / top-30 / top-50 accuracies for a batch of predictions and answers """
        return self._accuracy_at_cutoffs(predicted, true, (10, 30, 50))

    def batch_accuracy_200(self, predicted, true):
        """ Compute the top-80 / top-150 / top-200 accuracies for a batch of predictions and answers """
        return self._accuracy_at_cutoffs(predicted, true, (80, 150, 200))

    def batch_mr_mrr(self, predicted, true):
        """
        Rank and reciprocal rank of the ground-truth answer.

        The ground-truth answer of a row is the column with the highest score
        in ``true``; its rank is its 1-based position when the columns of
        ``predicted`` are sorted by decreasing score.

        Returns:
            ``(mrr, mr)``: two (batch, 1) float64 tensors holding the
            reciprocal rank and the rank of each example.
        """
        if len(true.shape) == 3:
            true = true[0]

        num_classes = predicted.shape[1]
        # Column indices ordered by decreasing predicted score: batch x classes.
        _, ranked = predicted.topk(num_classes, dim=1)
        # Index of the ground-truth answer: batch x 1.
        _, gold = true.topk(1, dim=1)

        # Every column index occurs exactly once per row of `ranked`, so the
        # comparison has a single hit per row and argmax returns its position.
        position = (ranked == gold).to(torch.int64).argmax(dim=1, keepdim=True)
        mr = (position + 1).to(torch.float64)
        mrr = 1.0 / mr

        return mrr, mr

    # def print(self, epoch):
    #     print("Epoch {} Score {:.2f} Loss {}".format(epoch, 100 * self.raw_score / self.num_examples,
    #                                                  self.loss / self.num_examples))


# def accumulate_metrics(epoch, train_metrics, val_metrics, val_per_type_metric,
#                        best_val_score,
#                        best_val_epoch, save_val_metrics=True):
#     stats = {
#         "epoch": epoch,

#         "train_loss": float(train_metrics.loss),
#         "train_raw_score": float(train_metrics.raw_score),
#         "train_normalized_score": float(train_metrics.normalized_score),
#         "train_upper_bound": float(train_metrics.upper_bound),
#         "train_score": float(train_metrics.score),
#         "train_num_examples": train_metrics.num_examples,

#         "train_time": train_metrics.end_time - train_metrics.start_time,
#         "val_time": val_metrics.end_time - val_metrics.start_time
#     }
#     if save_val_metrics:
#         stats["val_raw_score"] = float(val_metrics.raw_score)
#         stats["val_normalized_score"] = float(val_metrics.normalized_score)
#         stats["val_upper_bound"] = float(val_metrics.upper_bound)
#         stats["val_loss"] = float(val_metrics.loss)
#         stats["val_score"] = float(val_metrics.score)
#         stats["val_num_examples"] = val_metrics.num_examples
#         stats["val_per_type_metric"] = val_per_type_metric.get_json()

#         stats["best_val_score"] = float(best_val_score)
#         stats["best_epoch"] = best_val_epoch

#     print(json.dumps(stats, indent=4))
#     return stats

Overwriting /content/drive/MyDrive/VQA/code/utils/metrics.py


###### tool.py

In [15]:
%%writefile /content/drive/MyDrive/VQA/code/utils/tool.py
import os
import json

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from nltk import word_tokenize, pos_tag
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer
import pdb


def instance_bce_with_logits(logits, labels):
    """
    Binary cross-entropy with logits, rescaled by the number of label columns.

    ``binary_cross_entropy_with_logits`` is called with its default
    reduction and the result is multiplied by ``labels.size(1)``.

    Args:
        logits: 2-D tensor of raw scores.
        labels: target tensor of the same shape.
    """
    assert logits.dim() == 2

    mean_loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    return mean_loss * labels.size(1)


def freeze_layer(layer):
    """Turn off gradient tracking for every parameter of ``layer`` (in place)."""
    for weight in layer.parameters():
        weight.requires_grad_(False)


def unseen_mask(args, val_loader):
    """
    Build the additive logit mask used in the zero-shot (ZSL) setting.

    Args:
        args: configuration object; reads ``args.ZSL``,
            ``args.TEST.batch_size`` and ``args.FVQA.max_ans``.
        val_loader: loader whose ``dataset.answer_indices`` is an iterable of
            iterables of answer ids (assumed to be integer column indices
            below ``max_ans`` -- TODO confirm against the dataset class).

    Returns:
        ``None`` when ``args.ZSL`` is not 1. Otherwise a
        (TEST.batch_size, max_ans) tensor that is 0 in the columns of every
        answer id found in the validation set and -1e13 everywhere else,
        moved to the GPU when one is available.
    """
    if args.ZSL != 1:
        return None

    indices = val_loader.dataset.answer_indices
    all_ans = set(aid for aids in indices for aid in aids)

    negtive_mux = torch.ones(args.TEST.batch_size, args.FVQA.max_ans)
    # Zero the validation-set answer columns for *every* row of the mask.
    # Indexing whole columns keeps the mask consistent with its own shape
    # instead of depending on TRAIN.batch_size matching TEST.batch_size.
    columns = torch.tensor(list(all_ans), dtype=torch.long)
    negtive_mux[:, columns] = 0
    negtive_mux = negtive_mux * (-1e13)
    # Stay on the CPU when no GPU is present rather than crashing.
    if torch.cuda.is_available():
        negtive_mux = negtive_mux.cuda()
    return negtive_mux


def cosine_sim(im, s):
    """
    Pairwise dot products between the rows of ``im`` and the rows of ``s``.

    The result equals the cosine similarity only when the rows of both
    inputs have unit L2 norm; no normalisation is applied here.
    """
    return torch.mm(im, s.t())


def batch_mc_acc(predicted):
    """
    Per-example correctness for multiple-choice scores.

    An example counts as correct (1.0) when its highest scoring column is
    the last one, otherwise it is 0.0. The squeezed input must be 2-D.

    Returns:
        A float tensor with a trailing dimension of size 1.
    """
    _, num_choices = predicted.squeeze().shape
    best_choice = predicted.max(dim=1, keepdim=True).indices
    return (best_choice == num_choices - 1).float()


def batch_top1(predicted, true):
    """
    Ground-truth score of the top prediction of every example.

    Picks, for each row, the entry of ``true`` at the column where
    ``predicted`` is largest, and caps it at 1.
    """
    best_column = predicted.max(dim=1, keepdim=True).indices
    picked = torch.gather(true, 1, best_column)
    return picked.clamp(max=1)


def batch_accuracy(predicted, true):
    """
    Mean soft accuracy of a batch at ranks 1, 3 and 10.

    For each cut-off the ground-truth scores of the highest ranked
    predictions are summed per example, scaled by 0.3, capped at 1 and then
    averaged over the batch.

    Args:
        predicted: (batch, classes) prediction scores, at least 10 classes.
        true: ground-truth scores with the same layout; when it is 3-D only
            ``true[0]`` is used.

    Returns:
        ``(top1, top3, top10)`` as Python floats.
    """
    if len(true.shape) == 3:
        true = true[0]
    _, ranked = predicted.topk(10, dim=1)
    # Accumulate the ground-truth score along the ranking. The computation
    # stays on the inputs' device, so it runs on CPU as well as on GPU.
    running = true.gather(dim=1, index=ranked).to(torch.float).cumsum(dim=1)

    batch = predicted.shape[0]
    top1, top3, top10 = (
        (running[:, cutoff - 1:cutoff] * 0.3).clamp(max=1).sum().item() / batch
        for cutoff in (1, 3, 10)
    )
    return top1, top3, top10


# def update_learning_rate(optimizer, epoch):
#     learning_rate = cfg.TRAIN.base_lr * 0.5 ** (float(epoch) / cfg.TRAIN.lr_decay)
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = learning_rate
#
#     return learning_rate


class Tracker:
    """ Keep track of results over time, while having access to monitors to display information about them. """

    def __init__(self):
        # name -> list of ListStorage objects, one per track() call.
        self.data = {}

    def track(self, name, *monitors):
        """ Track a set of results with given monitors under some name (e.g. 'val_acc').
            When appending to the returned list storage, use the monitors to retrieve useful information.
        """
        storage = Tracker.ListStorage(monitors)
        self.data.setdefault(name, []).append(storage)
        return storage

    def to_dict(self):
        """ Return ``{name: [list_of_items, ...]}`` with every list storage turned into a regular list. """
        return {k: list(map(list, v)) for k, v in self.data.items()}

    class ListStorage:
        """ Storage of data points that updates the given monitors """

        def __init__(self, monitors=None):
            self.data = []
            # A fresh list per instance: a mutable default argument would be
            # shared by every storage created without monitors.
            self.monitors = [] if monitors is None else monitors
            # Expose each monitor as an attribute named after monitor.name.
            for monitor in self.monitors:
                setattr(self, monitor.name, monitor)

        def append(self, item):
            """ Feed ``item`` to every monitor, then store it. """
            for monitor in self.monitors:
                monitor.update(item)
            self.data.append(item)

        def __iter__(self):
            return iter(self.data)

    class MeanMonitor:
        """ Take the mean over the given values """
        name = 'mean'

        def __init__(self):
            self.n = 0
            self.total = 0

        def update(self, value):
            self.total += value
            self.n += 1

        @property
        def value(self):
            # Raises ZeroDivisionError until update() has been called once.
            return self.total / self.n

    class MovingMeanMonitor:
        """ Take an exponentially moving mean over the given values """
        name = 'mean'

        def __init__(self, momentum=0.9):
            self.momentum = momentum
            self.first = True
            self.value = None

        def update(self, value):
            if self.first:
                # The first observation seeds the moving mean.
                self.value = value
                self.first = False
            else:
                m = self.momentum
                self.value = m * self.value + (1 - m) * value


class data_prefetcher():
    """
    Wraps a loader and copies the next batch to the GPU on a side CUDA
    stream while the current batch is being consumed.

    Every item produced by the loader must unpack into three values
    ``(features, targets, _)``; ``features`` is an iterable of tensors and
    ``targets`` is either a dict of tensors or an iterable of tensors.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and start its host-to-device copy on the side stream."""
        try:
            self.next_features, self.next_targets, _ = next(self.loader)
        except StopIteration:
            # Loader exhausted: next() will hand back (None, None).
            self.next_features = None
            self.next_targets = None
            return
        with torch.cuda.stream(self.stream):
            self.next_features = [single_feature.cuda(non_blocking=True) for single_feature in self.next_features]
            if isinstance(self.next_targets, dict):
                for key in self.next_targets.keys():
                    self.next_targets[key] = self.next_targets[key].cuda(non_blocking=True)
            else:
                self.next_targets = [single_target.cuda(non_blocking=True) for single_target in self.next_targets]

    def next(self):
        """
        Return the prefetched ``(features, targets)`` and start loading the
        following batch. Both are ``None`` once the loader is exhausted.
        ``targets`` keeps the container type produced by preload() (dict or
        list).
        """
        current = torch.cuda.current_stream()
        # Make the consumer stream wait for the pending copies.
        current.wait_stream(self.stream)
        features = self.next_features
        targets = self.next_targets
        # record_stream() returns None and is called only for its side effect
        # (marking the tensor as in use on the consumer stream), so the
        # tensors themselves are what must be returned.
        if features is not None:
            for tensor in features:
                tensor.record_stream(current)
        if targets is not None:
            target_tensors = targets.values() if isinstance(targets, dict) else targets
            for tensor in target_tensors:
                tensor.record_stream(current)
        self.preload()
        return features, targets


def get_transform(target_size, central_fraction=1.0):
    """
    Build the image preprocessing pipeline.

    The image is resized to ``int(target_size / central_fraction)``,
    center-cropped to ``target_size``, converted to a tensor and normalised
    with the per-channel mean and std given below.

    Args:
        target_size: side length of the final crop.
        central_fraction: fraction of the resized image kept by the crop;
            must be non-zero.
    """
    return transforms.Compose([
        # transforms.Scale was deprecated and later removed from torchvision;
        # Resize is its replacement and takes the same size argument.
        transforms.Resize(int(target_size / central_fraction)),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def dele_a(answer):  # strip punctuation and articles
    """
    Generate variants of ``answer`` with punctuation and articles removed.

    Periods, then commas, then the articles "the", "an" and "a" (each
    followed by a space) are removed one step after the other; the result of
    every step is kept.

    Returns:
        The distinct variants as a list, in no particular order.
    """
    import re

    no_period = answer.replace('.', '')
    no_punct = no_period.replace(',', '')
    # \b anchors the article at the start of a word, so word endings such as
    # "man ", "bathe " or "pizza " are left intact.
    no_the = re.sub(r'\bthe ', '', no_punct)
    no_an = re.sub(r'\ban ', '', no_the)
    no_a = re.sub(r'\ba ', '', no_an)

    return list(set([no_period, no_punct, no_the, no_an, no_a]))


def transfer(answer):  # lemmatize (e.g. plural -> singular)
    """
    Lemmatize every token of ``answer`` and join the results with spaces.

    Tokens are POS-tagged and lemmatized with WordNet, falling back to the
    noun reading when the tag has no WordNet equivalent. A few words are
    handled by hand, either before lemmatization ("as", and three "-ing"
    forms) or as corrections of the lemmatizer output.
    """
    lemmatizer = WordNetLemmatizer()
    # "-ing" forms whose base form is obtained by dropping "ing" and adding "e".
    ing_to_e = ("grazing", "timing", "bicycling")
    # Hand-written corrections applied to the lemmatizer output.
    corrections = {
        "ax": "axe",
        "people": "person",
        "teeth": "tooth",
        "worn": "wear",
    }

    lemmas = []
    for word, pos in pos_tag(word_tokenize(answer)):
        if word == "as":
            # Kept verbatim, never passed to the lemmatizer.
            lemmas.append("as")
        elif word in ing_to_e:
            lemmas.append(word.replace("ing", "") + "e")
        else:
            lemma = lemmatizer.lemmatize(word, pos=get_wordnet_pos(pos) or wordnet.NOUN)
            lemmas.append(corrections.get(lemma, lemma))
    return ' '.join(lemmas)


def hand_remove(answer):  # manually strip "ing", "s", "es", "er"
    """
    Generate variants of ``answer`` with common suffixes removed.

    For each of "ing", "s", "es" and "er" two variants are produced: one
    with every occurrence of the letters removed, and one where only
    occurrences followed by a space are removed (the space is kept).

    Returns:
        The distinct variants as a list, in no particular order.
    """
    variants = set()
    for suffix in ("ing", "s", "es", "er"):
        variants.add(answer.replace(suffix, ""))
        variants.add(answer.replace(suffix + " ", " "))
    return list(variants)


def get_wordnet_pos(tag):
    """
    Map a POS tag to the matching WordNet part-of-speech constant.

    Tags starting with 'J', 'V', 'N' and 'R' map to the adjective, verb,
    noun and adverb constants respectively; any other tag gives ``None``.
    """
    prefix_to_attr = (('J', 'ADJ'), ('V', 'VERB'), ('N', 'NOUN'), ('R', 'ADV'))
    for prefix, attr in prefix_to_attr:
        if tag.startswith(prefix):
            # Resolved lazily so wordnet is only touched for a matching tag.
            return getattr(wordnet, attr)
    return None


def deal_fact(dic, fact):
    """
    Extract the entity name from a '/'-separated fact identifier.

    The last path segment is used, or the one before it when the last
    segment is the marker "n" or "v". From that segment, the part after
    "Category:" is returned when the prefix is present, otherwise the part
    before the first ':'.

    Args:
        dic: not used by the current implementation.
        fact: the '/'-separated identifier string.
    """
    segments = fact.split('/')
    tail = segments[-2] if segments[-1] in ("n", "v") else segments[-1]

    pieces = tail.split(':')
    return pieces[1] if pieces[0] == "Category" else pieces[0]

Overwriting /content/drive/MyDrive/VQA/code/utils/tool.py


## Bash Script

###### run_FVQA_train.sh

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/bash_script/run_FVQA_train.sh
# Runs code/main.py three times. The runs differ only in --exp_name and in
# the extra mapping flag (none, --relation_map 1, --fact_map 1).
# Value passed to --data_choice for all three runs.
data=3
python code/main.py --gpu_id 1 --exp_name knowledge_space --exp_id W2V --fusion_model SAN --data_choice "${data}" --method_choice W2V --save_model 1
python code/main.py --gpu_id 1 --exp_name semantic_space --exp_id W2V --fusion_model SAN --data_choice "${data}" --method_choice W2V  --save_model 1 --relation_map 1
python code/main.py --gpu_id 1 --exp_name object_space --exp_id W2V --fusion_model SAN --data_choice "${data}" --method_choice W2V  --save_model 1 --fact_map 1

###### run_FVQA.sh

In [None]:
%%writefile /content/drive/MyDrive/VQA/code/bash_script/run_FVQA.sh
# Settings forwarded to joint_test.py.
data=3      # --data_choice
ke=10       # --top_fact
kr=3        # --top_rel
score=10    # --soft_score
zsl=0       # --ZSL

# The experiment id encodes the settings above.
exp_id="rel${kr}_fact${ke}data_${data}score_${score}"

python joint_test.py --gpu_id 1 --exp_name fusion_prediction \
    --ZSL "${zsl}" --exp_id "${exp_id}" --data_choice "${data}" \
    --top_rel "${kr}" --top_fact "${ke}" --soft_score "${score}" --mrr 1

# RUN

In [12]:
!python code/main.py --gpu_id 1 --exp_name knowledge_space --exp_id W2V --fusion_model SAN --data_choice 3 --method_choice W2V --save_model 1

2026-01-30 02:26:02.366887: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-01-30 02:26:02.953390: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-01-30 02:26:03.255784: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769739963.659626    5312 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769739963.748395    5312 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769739963.785135    5312 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linkin

# GITHUB


In [None]:
# %%bash
# TOKEN=$(cat /content/token.txt)
# export MY_ENV_VAR_NAME_1="$TOKEN"
# echo "Environment variable MY_ENV_VAR_NAME set to: $MY_ENV_VAR_NAME_1"

In [None]:
# !TOKEN=$(cat /content/token.txt)

In [None]:
# !export TOKEN="$TOKEN"

In [None]:
# !echo "Environment variable MY_ENV_VAR_NAME set to: $TOKEN"

Environment variable MY_ENV_VAR_NAME set to: 


In [None]:
# !more /content/token.txt | gh auth login --with-token

In [None]:
# !value=$(cat config.txt)

In [None]:
%%writefile /content/drive/MyDrive/VQA/.gitignore
kg/
data/
run.ipynb

Overwriting /content/drive/MyDrive/VQA/.gitignore


In [None]:
!git config --global --unset user.name

In [None]:
!git config user.name "LTBach"

In [None]:
!git config --global --unset user.email

In [None]:
!git config user.email "bach1346790852@gmail.com"

In [None]:
!git remote add origin https://<USERNAME>:<PASSWORD>@github.com/<USERNAME>/reponame.git

In [None]:
!git add .

In [None]:
!git stage

Nothing specified, nothing added.
[33mhint: Maybe you wanted to say 'git add .'?[m
[33mhint: Turn this message off by running[m
[33mhint: "git config advice.addEmptyPathspec false"[m


In [None]:
!git commit -m "Add code"

[main 120de67] Add code
 2 files changed, 2 insertions(+), 1 deletion(-)


In [None]:
!git push -u origin main

Enumerating objects: 25, done.
Counting objects:   4% (1/25)Counting objects:   8% (2/25)Counting objects:  12% (3/25)Counting objects:  16% (4/25)Counting objects:  20% (5/25)Counting objects:  24% (6/25)Counting objects:  28% (7/25)Counting objects:  32% (8/25)Counting objects:  36% (9/25)Counting objects:  40% (10/25)Counting objects:  44% (11/25)Counting objects:  48% (12/25)Counting objects:  52% (13/25)Counting objects:  56% (14/25)Counting objects:  60% (15/25)Counting objects:  64% (16/25)Counting objects:  68% (17/25)Counting objects:  72% (18/25)Counting objects:  76% (19/25)Counting objects:  80% (20/25)Counting objects:  84% (21/25)Counting objects:  88% (22/25)Counting objects:  92% (23/25)Counting objects:  96% (24/25)Counting objects: 100% (25/25)Counting objects: 100% (25/25), done.
Delta compression using up to 2 threads
Compressing objects: 100% (21/21), done.
Writing objects: 100% (25/25), 39.52 KiB | 554.00 KiB/s, done.
Total 25 (delta 4), r