<a href="https://colab.research.google.com/github/JXqwq/dl_lecture_VQA/blob/main/FinalProject_VQA_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import re
import random
import time
from statistics import mode

from PIL import Image
import numpy as np
import pandas
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer

In [None]:
# Mount Google Drive so the "/content/drive/My Drive" paths used in this notebook
# (dataset archives, extracted images, checkpoints, submission) are accessible.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the training images (-o: overwrite existing files without prompting).
!unzip -o "/content/drive/My Drive/VQA/train.zip" -d "/content/drive/My Drive/data/train"

In [None]:
# Extract the validation images (-o: overwrite existing files without prompting).
!unzip -o "/content/drive/My Drive/VQA/valid.zip" -d "/content/drive/My Drive/data/valid"

In [4]:
def set_seed(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Parameters
    ----------
    seed : int
        Seed passed to Python's ``random``, NumPy, and PyTorch (CPU and CUDA).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Spelled-out numbers normalized to digits (whole words only).
_NUM_WORD_TO_DIGIT = {
    'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4',
    'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9',
    'ten': '10'
}
# Contractions written without an apostrophe, mapped to their standard spelling.
_CONTRACTIONS = {
    "dont": "don't", "isnt": "isn't", "arent": "aren't", "wont": "won't",
    "cant": "can't", "wouldnt": "wouldn't", "couldnt": "couldn't"
}
# \b anchors keep e.g. "someone", "tennis", "cantaloupe" from being rewritten.
_NUM_WORD_RE = re.compile(r'\b(' + '|'.join(_NUM_WORD_TO_DIGIT) + r')\b')
_CONTRACTION_RE = re.compile(r'\b(' + '|'.join(_CONTRACTIONS) + r')\b')
# A period that is not sandwiched between two digits, i.e. not a decimal point.
_NON_DECIMAL_PERIOD_RE = re.compile(r'(?<!\d)\.|\.(?!\d)')
_ARTICLE_RE = re.compile(r'\b(a|an|the)\b')
# Everything except word chars, whitespace, apostrophes, colons and the (decimal)
# periods that survived the step above counts as punctuation.
_PUNCTUATION_RE = re.compile(r"[^\w\s':.]")
_WHITESPACE_RE = re.compile(r'\s+')


def process_text(text):
    """Normalize a free-form answer string so equivalent answers compare equal.

    Steps: lowercase; spelled-out numbers zero..ten -> digits (whole words only);
    drop periods that are not decimal points; drop the articles a/an/the; restore
    apostrophes in common contractions; turn remaining punctuation into spaces;
    collapse whitespace.

    Parameters
    ----------
    text : str
        Raw answer text.

    Returns
    -------
    str
        The normalized text, e.g. ``"Two dogs."`` -> ``"2 dogs"``, ``"1.5"`` -> ``"1.5"``.
    """
    text = text.lower()
    text = _NUM_WORD_RE.sub(lambda m: _NUM_WORD_TO_DIGIT[m.group(1)], text)
    text = _NON_DECIMAL_PERIOD_RE.sub('', text)
    text = _ARTICLE_RE.sub('', text)
    text = _CONTRACTION_RE.sub(lambda m: _CONTRACTIONS[m.group(1)], text)
    text = _PUNCTUATION_RE.sub(' ', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


# 1. Dataset construction
class VQADataset(torch.utils.data.Dataset):
    """Image / question (/ answers) dataset backed by a JSON annotation file.

    Parameters
    ----------
    df_path : str
        JSON file loaded with ``pandas.read_json``. It must provide ``image`` and
        ``question`` columns and, when ``answer`` is True, an ``answers`` column whose
        entries are lists of dicts with an ``"answer"`` key.
    image_dir : str
        Directory containing the files named in the ``image`` column.
    transform : callable, optional
        Preprocessing applied to each RGB ``PIL.Image``. When None the PIL image is
        returned unchanged.
    answer : bool
        True for labeled data: builds the answer vocabulary and returns labels.
    """

    def __init__(self, df_path, image_dir, transform=None, answer=True):
        self.transform = transform  # image preprocessing
        self.image_dir = image_dir  # directory holding the image files
        self.df = pandas.read_json(df_path)  # one row per sample: image, question, (answers)
        self.answer = answer

        # Answer vocabulary: normalized answer string <-> class id. Questions are kept as
        # raw strings, so no question vocabulary is built.
        self.answer2idx = {}
        self.idx2answer = {}

        if self.answer:
            # Ids are assigned in order of first appearance of each normalized answer.
            for answers in self.df["answers"]:
                for answer in answers:
                    word = process_text(answer["answer"])
                    if word not in self.answer2idx:
                        self.answer2idx[word] = len(self.answer2idx)
            self.idx2answer = {v: k for k, v in self.answer2idx.items()}  # reverse lookup

    def update_dict(self, dataset):
        """
        Replace this dataset's answer vocabulary with the training dataset's one, so
        validation/test predictions use the same class ids.

        Parameters
        ----------
        dataset : Dataset
            The training Dataset.
        """
        self.answer2idx = dataset.answer2idx
        self.idx2answer = dataset.idx2answer

    def __getitem__(self, idx):
        """
        Return the sample (image, question[, answers, mode answer]) at ``idx``.

        Parameters
        ----------
        idx : int
            Index of the sample to fetch.

        Returns
        -------
        image : torch.Tensor  (C, H, W) if ``transform`` produces a tensor, else PIL.Image
            The image, converted to RGB and then passed through ``transform``.
        question : str
            The raw question text (no encoding is done here).
        answers : torch.Tensor  (n_answer,)
            Class id of each annotator's answer (float tensor). Only when ``answer`` is True.
        mode_answer_idx : int
            The most frequent of those ids, used as the training label. Only when
            ``answer`` is True.
        """
        image_path = f"{self.image_dir}/{self.df['image'][idx]}"
        # Load eagerly and release the file handle; convert to RGB so grayscale, RGBA or
        # palette images still yield 3 channels and batches can be stacked.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        question = self.df["question"][idx]

        if not self.answer:
            return image, question

        answers = [self.answer2idx[process_text(answer["answer"])] for answer in self.df["answers"][idx]]
        mode_answer_idx = mode(answers)  # most frequent answer id (the label)

        return image, question, torch.Tensor(answers), int(mode_answer_idx)

    def __len__(self):
        return len(self.df)

In [5]:
# 2. Evaluation metric
# (For a simpler training objective, BCE could be used instead.)
def VQA_criterion(batch_pred: torch.Tensor, batch_answers: torch.Tensor):
    """Return the mean VQA accuracy of a batch of predicted answer ids.

    For each sample, every annotator i is left out in turn and the prediction scores
    ``min(#other annotators who gave the predicted answer / 3, 1)``; the sample's
    accuracy is the mean of those scores over all annotators.

    Parameters
    ----------
    batch_pred : torch.Tensor  (batch,)
        Predicted answer id per sample.
    batch_answers : torch.Tensor  (batch, n_annotators)
        Annotators' answer ids per sample.

    Returns
    -------
    float
        Accuracy averaged over the batch; 0.0 for an empty batch.
    """
    if len(batch_pred) == 0:
        return 0.0

    total_acc = 0.
    for pred, answers in zip(batch_pred, batch_answers):
        if len(answers) == 0:
            continue  # no annotations: contributes 0 accuracy
        # matches[i] is True when annotator i gave the predicted answer.
        matches = [bool(answer == pred) for answer in answers]
        n_match = sum(matches)
        # Leaving annotator i out removes its own match (if any) from the count.
        acc = sum(min((n_match - is_match) / 3, 1) for is_match in matches)
        total_acc += acc / len(answers)

    return total_acc / len(batch_pred)

In [6]:
# 3. Model
# Image features come from a pretrained ResNet50 and question features from a
# pretrained DistilBERT encoder.
# Reference: https://farooqsk.medium.com/introduction-to-visual-question-answering-in-pytorch-7b5cc61c86d, https://towardsdatascience.com/feature-extraction-with-bert-for-text-classification-533dde44dc2f

class VQAModel(nn.Module):
    """Answer classifier over concatenated image and question features.

    Parameters
    ----------
    n_answer : int
        Number of answer classes (width of the output logits).
    """

    def __init__(self, n_answer: int):
        super().__init__()

        # Image branch: pretrained ResNet50 (fine-tuned) whose classification head is
        # replaced by an MLP projecting the pooled features to `image_dim` values.
        self.resnet = torchvision.models.resnet50(pretrained=True)
        num_ftrs = self.resnet.fc.in_features
        image_dim = 32
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_ftrs, 512),
            nn.Tanh(),
            nn.Linear(512, 128),
            nn.Tanh(),
            nn.Linear(128, image_dim),
        )

        # Question branch. "distilbert-base-uncased" is a DistilBERT checkpoint, so it
        # must be loaded with the DistilBERT classes: its parameter names do not match
        # BertModel's, which would leave most of a BertModel randomly initialised.
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        text_dim = self.text_encoder.config.hidden_size  # 768 for distilbert-base

        # Fusion head over the concatenated [image, question] features (32 + 768 = 800).
        self.fc2 = nn.Linear(image_dim + text_dim, 400)
        self.fc3 = nn.Linear(400, n_answer)

        self.dropout = nn.Dropout(0.5)

    def forward(self, image, question):
        """Return answer logits of shape (batch, n_answer).

        Parameters
        ----------
        image : torch.Tensor
            Batch of images accepted by the ResNet50 backbone.
        question : list of str
            Raw question strings, one per image; tokenized here.
        """
        image_feature = self.resnet(image)  # image features

        # Tokenize on the fly and move the token tensors to the images' device.
        question_tokens = self.tokenizer(question, padding=True, truncation=True, return_tensors="pt")
        question_tokens = {k: v.to(image.device) for k, v in question_tokens.items()}
        outputs = self.text_encoder(**question_tokens)
        # Hidden state of the first ([CLS]) token serves as the question representation.
        question_feature = outputs.last_hidden_state[:, 0, :]

        x = torch.cat([image_feature, question_feature], dim=1)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)

In [7]:
# 4. Training / evaluation loops
def train(model, dataloader, optimizer, criterion, device):
    """Run one training epoch over ``dataloader``.

    Returns
    -------
    tuple
        (mean loss, mean VQA accuracy, mean simple accuracy, elapsed seconds), where
        the means are taken over batches.
    """
    model.train()

    total_loss = 0
    total_acc = 0
    simple_acc = 0

    start = time.time()
    for image, question, answers, mode_answer in dataloader:
        image, mode_answer = image.to(device), mode_answer.to(device)

        pred = model(image, question)
        # With the default collate mode_answer is already 1-D (batch,). Squeezing it
        # would turn a final batch of size 1 into a 0-d target, which
        # CrossEntropyLoss rejects.
        loss = criterion(pred, mode_answer)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred_answer = pred.argmax(1)
        total_loss += loss.item()
        # VQA_criterion compares element by element in Python; doing it on the CPU
        # avoids a device synchronization for every comparison.
        total_acc += VQA_criterion(pred_answer.cpu(), answers.cpu())  # VQA accuracy
        simple_acc += (pred_answer == mode_answer).float().mean().item()  # simple accuracy

    n_batches = len(dataloader)
    return total_loss / n_batches, total_acc / n_batches, simple_acc / n_batches, time.time() - start


def eval(model, dataloader, optimizer, criterion, device):
    """Evaluate ``model`` on a labeled dataloader without updating its weights.

    ``optimizer`` is unused; it is accepted for signature compatibility with ``train``.

    Returns
    -------
    tuple
        (mean loss, mean VQA accuracy, mean simple accuracy, elapsed seconds), where
        the means are taken over batches.
    """
    model.eval()

    total_loss = 0
    total_acc = 0
    simple_acc = 0

    start = time.time()
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for image, question, answers, mode_answer in dataloader:
            image, mode_answer = image.to(device), mode_answer.to(device)

            pred = model(image, question)
            # mode_answer is already (batch,); no squeeze, so a batch of size 1 still works.
            loss = criterion(pred, mode_answer)

            pred_answer = pred.argmax(1)
            total_loss += loss.item()
            total_acc += VQA_criterion(pred_answer.cpu(), answers.cpu())  # VQA accuracy
            # Cast to float: mean() is not defined for bool tensors.
            simple_acc += (pred_answer == mode_answer).float().mean().item()  # simple accuracy

    n_batches = len(dataloader)
    return total_loss / n_batches, total_acc / n_batches, simple_acc / n_batches, time.time() - start

In [None]:
# Reproducibility and device selection
set_seed(42)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Image preprocessing: resize to 224x224, then convert to a tensor.
# (transforms.RandomHorizontalFlip() is a candidate for data augmentation.)
# NOTE(review): no ImageNet mean/std normalization is applied although the ResNet50
# backbone is pretrained; adding it would change the inputs seen by checkpoints
# already trained with this pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Datasets on the mounted Google Drive (Colab). The validation split is loaded without
# answers and reuses the training answer vocabulary.
train_dataset = VQADataset(
    df_path="/content/drive/My Drive/VQA/train.json",
    image_dir="/content/drive/My Drive/data/train/train",
    transform=transform,
)
test_dataset = VQADataset(
    df_path="/content/drive/My Drive/VQA/valid.json",
    image_dir="/content/drive/My Drive/data/valid/valid",
    transform=transform,
    answer=False,
)
test_dataset.update_dict(train_dataset)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# One output logit per distinct normalized training answer.
model = VQAModel(n_answer=len(train_dataset.answer2idx)).to(device)


In [None]:
# Optimizer and loss function
num_epoch = 5
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Train for num_epoch epochs; after each epoch print the metrics and pickle the whole
# model object to Drive under a per-epoch file name.
for epoch in range(num_epoch):
    train_loss, train_acc, train_simple_acc, train_time = train(
        model, train_loader, optimizer, criterion, device
    )
    print(
        f"【{epoch + 1}/{num_epoch}】\n"
        f"train time: {train_time:.2f} [s]\n"
        f"train loss: {train_loss:.4f}\n"
        f"train acc: {train_acc:.4f}\n"
        f"train simple acc: {train_simple_acc:.4f}"
    )
    path_name = f"/content/drive/My Drive/data/VQA_epoch{epoch + 1}.pth"
    torch.save(model, path_name)

In [None]:
# Build the submission file from a saved checkpoint.
# The checkpoint was written with torch.save(model, ...), i.e. the whole pickled model
# object, so it needs weights_only=False (PyTorch >= 2.6 defaults to True and would
# refuse it). Unpickling can execute arbitrary code: only load checkpoints you created.
# map_location puts the weights on the current device even if saved on another one.
model = torch.load(
    "/content/drive/My Drive/data/VQA_epoch4.pth",
    map_location=device,
    weights_only=False,
)
model.eval()

submission = []
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for image, question in test_loader:
        pred = model(image.to(device), question)
        submission.append(pred.argmax(1).cpu().item())

# Map predicted class ids back to answer strings from the training vocabulary.
submission = np.array([train_dataset.idx2answer[answer_id] for answer_id in submission])
np.save("/content/drive/My Drive/submission.npy", submission)