# Graph-Based Audio-Visual Question Answering

CS224W (23/24 Fall) Project from Zhengyang Wei, Tianyuan Dai, & Haoyi Duan

**Before you get started:**
- This Colab contains our PyG implementation of the paper ***Graph-Based Video-Language Learning with Multi-Grained Audio-Visual Alignment***. The paper's authors have not released any code, so we implemented it ourselves.

    - PDF of the paper: https://dl.acm.org/doi/pdf/10.1145/3581783.3612132

- Make sure to **run all the cells in each section in order**, so that the variables and packages they define carry over to later cells.
- The data is stored in [Google Drive](https://drive.google.com/drive/folders/175T6bEFoC2X8qww7wfuxQu_yMSpik-yS?usp=drive_link). Feel free to make a copy to your drive!

## 1. Colab Tutorial Introduction

*Mount Google Drive to load the subset of the MUSIC-AVQA dataset.*

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


## 2. Environment Configuration

### 2.1 Install Dependencies

In [None]:
!pip install torch==1.13.0
!pip install torchaudio==0.13.0
!pip install torchvision==0.14.0
!pip install ffmpeg==1.4
!pip install numpy==1.21.5
!pip install tensorboardX
!pip install spacy
!pip install SceneGraphParser
!pip install ftfy
!pip install regex

Collecting torch==1.13.0
  Downloading torch-1.13.0-cp310-cp310-manylinux1_x86_64.whl (890.1 MB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==1.13.0)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==1.13.0)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==1.13.0)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2
Collecting SceneGraphParser
  Downloading SceneGraphParser-0.1.0-py3-none-any.whl (19 kB)
Installing collected packages: SceneGraphParser
Successfully installed SceneGraphParser-0.1.0
Collecting ftfy
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
Installing collected packages: ftfy
Successfully installed ftfy-6.1.3


### 2.2 Install PyTorch Geometric

- Check the installed PyTorch and CUDA versions

In [None]:
import os
import torch
import json
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

torch:  1.13 ; cuda:  cu117


- Install the PyTorch Geometric library

In [None]:
# Install PyTorch Geometric. The compiled extensions (torch-scatter, torch-sparse) must come
# from the wheel index that matches the installed torch build, e.g. torch-1.13.0+cu117.
import torch
pyg_wheel_index = f"https://pytorch-geometric.com/whl/torch-{torch.__version__}.html"
!pip install torch-scatter -f $pyg_wheel_index
!pip install torch-sparse -f $pyg_wheel_index
!pip install torch-geometric

Looking in links: https://pytorch-geometric.com/whl/torch-1.13.0+cu117.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... done
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl size=3522473 sha256=ffadc85ab6e802b8581141378dcfc923c8dd6676897dfd1fa1cf79e9ef621881
  Stored in directory: /root/.cache/pip/wheels/92/f1/2b/3b46d54b134259f58c8363568569053248040859b1a145b3ce
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2
Looking in links: https://pytorch-geometric.com/whl/torch-1.13.0+cu117.html
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz 

In [None]:
import torch_geometric
torch_geometric.__version__

'2.4.0'

## 3. Data Preprocessing

*You can **skip** this whole section since the processed data is already stored in [Google Drive](https://drive.google.com/drive/folders/175T6bEFoC2X8qww7wfuxQu_yMSpik-yS?usp=drive_link).*

- Load the CLIP model

In [None]:
import sys
sys.path = ["/content/drive/MyDrive/AVQA-GNN"] + sys.path
import clip_net.clip
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip_net.clip.load("ViT-B/32", device=device)

### 3.1 Question Feature Extraction

In [None]:
import os
import torch
from torchvision import transforms, utils
from PIL import Image
import numpy as np
import glob
import json
import ast
import csv
from tqdm import tqdm

def qst_feat_extract(qst):

    text = clip_net.clip.tokenize(qst).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text)

    return text_features


def QstCLIP_feat(json_path, dst_qst_path):
    # Encode every question in the JSON file with CLIP and save one .npy per question_id.

    with open(json_path, 'r') as f:
        samples = json.load(f)

    ques_vocab = ['<pad>']

    for sample in tqdm(samples):
        # Split into words and drop the final character of the last word (normally '?').
        question = sample['question_content'].rstrip().split(' ')
        question[-1] = question[-1][:-1]

        question_id = sample['question_id']
        save_file = os.path.join(dst_qst_path, str(question_id) + '.npy')

        if os.path.exists(save_file):
            print(question_id, "already exists!")
            continue

        # Replace template placeholders such as <Object> with the values listed in templ_values.
        p = 0
        for pos in range(len(question)):
            if '<' in question[pos]:
                question[pos] = ast.literal_eval(sample['templ_values'])[p]
                p += 1
        for wd in question:
            if wd not in ques_vocab:
                ques_vocab.append(wd)

        question = ' '.join(question)

        qst_feat = qst_feat_extract(question)

        qst_features = qst_feat.float().cpu().numpy()
        np.save(save_file, qst_features)

In [None]:
json_path = "/content/drive/MyDrive/AVQA-GNN/dataset/split_que_id/music_avqa.json"
dst_qst_path = "/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA/clip_qst/"
os.makedirs(dst_qst_path, exist_ok=True)

QstCLIP_feat(json_path, dst_qst_path)

100%|██████████| 45624/45624 [15:32<00:00, 48.94it/s]
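The placeholder substitution inside `QstCLIP_feat` can be tried on its own. The helper below is a hypothetical stand-alone version of that loop (the name `fill_question_template` and the sample question are made up for illustration, not taken from the dataset files). Like the cell, it assumes the last word ends with one punctuation character and that `templ_values` is a string holding a Python list literal.

```python
import ast


def fill_question_template(question_content, templ_values):
    """Rebuild the concrete question string the way QstCLIP_feat does.

    question_content: template such as "Is the <Object> ... ?"
    templ_values: string holding a Python list literal, e.g. "['cello']".
    """
    words = question_content.rstrip().split(' ')
    words[-1] = words[-1][:-1]            # drop the final character (normally '?')
    values = ast.literal_eval(templ_values)
    p = 0
    for pos, word in enumerate(words):
        if '<' in word:                   # a placeholder such as <Object>
            words[pos] = values[p]
            p += 1
    return ' '.join(words)


print(fill_question_template("Is the <Object> on the <LR> louder than the other?", "['piano', 'left']"))
# Is the piano on the left louder than the other
```

The resulting plain-text question is what gets tokenized and passed to CLIP's text encoder.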


### 3.2 Visual Feature Extraction

In [None]:
import os
import torch
from torchvision import transforms, utils
from PIL import Image
import numpy as np
import glob

def clip_feat_extract(img):

    image = preprocess(Image.open(img)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
    return image_features


def ImageClIP_Patch_feat_extract(dir_fps_path, dst_clip_path):

    video_list = os.listdir(dir_fps_path)
    video_idx = 0
    total_nums = len(video_list)

    for video in video_list:

        video_idx = video_idx + 1
        print("\n--> ", video_idx, video)

        save_file = os.path.join(dst_clip_path, video + '.npy')
        if os.path.exists(save_file):
            print(video + '.npy', "is already processed!")
            continue

        video_img_list = sorted(glob.glob(os.path.join(dir_fps_path, video, '*.jpg')))

        params_frames = len(video_img_list)
        samples = np.round(np.linspace(0, params_frames-1, params_frames))

        img_list  = [video_img_list[int(sample)] for sample in samples]
        img_features = torch.zeros(len(img_list), 50, 512)

        idx = 0
        for img_cont in img_list:
            img_idx_feat = clip_feat_extract(img_cont)
            img_features[idx] = img_idx_feat
            idx += 1

        img_features = img_features.float().cpu().numpy()
        np.save(save_file, img_features)

        print("Process: ", video_idx, " / ", total_nums, " ----- video id: ", video, " ----- save shape: ", img_features.shape)


def ImageClIP_feat_extract(dir_fps_path, dst_clip_path):

    video_list = os.listdir(dir_fps_path)
    video_idx = 0
    total_nums = len(video_list)

    for video in video_list:

        video_idx = video_idx + 1
        print("\n--> ", video_idx, video)

        save_file = os.path.join(dst_clip_path, video + '.npy')
        if os.path.exists(save_file):
            print(video + '.npy', "is already processed!")
            continue

        video_img_list = sorted(glob.glob(os.path.join(dir_fps_path, video, '*.jpg')))

        params_frames = len(video_img_list)
        samples = np.round(np.linspace(0, params_frames-1, params_frames))

        img_list  = [video_img_list[int(sample)] for sample in samples]
        img_features = torch.zeros(len(img_list), 512)

        idx = 0
        for img_cont in img_list:
            img_idx_feat = clip_feat_extract(img_cont)
            img_features[idx] = img_idx_feat
            idx += 1

        img_features = img_features.float().cpu().numpy()
        np.save(save_file, img_features)

        print("Process: ", video_idx, " / ", total_nums, " ----- video id: ", video, " ----- save shape: ", img_features.shape)


In [None]:
dir_fps_path = '/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA/frames'
dst_clip_path = '/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA/clip_vit_b32'
os.makedirs(dst_clip_path, exist_ok=True)

ImageClIP_feat_extract(dir_fps_path, dst_clip_path)
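Both extractors pick frames with `np.round(np.linspace(0, params_frames-1, params_frames))`. Because the number of samples equals the number of frames, this keeps every frame. The sketch below reproduces the pattern. `uniform_frame_indices` and its `num_samples` option are hypothetical (the notebook never subsamples). It shows how a smaller count would instead select evenly spaced frames.

```python
import numpy as np


def uniform_frame_indices(num_frames, num_samples=None):
    """Frame indices picked by the np.round(np.linspace(...)) pattern used above.

    num_samples defaults to num_frames, which is what both extractors do,
    so every frame is kept. A smaller num_samples picks evenly spaced frames.
    """
    if num_samples is None:
        num_samples = num_frames
    return np.round(np.linspace(0, num_frames - 1, num_samples)).astype(int)


print(uniform_frame_indices(5).tolist())       # [0, 1, 2, 3, 4]
print(uniform_frame_indices(60, 10).tolist())  # [0, 7, 13, 20, 26, 33, 39, 46, 52, 59]
```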

### 3.3 Scene Graph Generation

In [None]:
!pip install sng_parser
!python -m spacy download en

In [None]:
API_TOKEN = "Your Hugging-face API token"
API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-base"
headers = {"Authorization": f"Bearer {API_TOKEN}"}


def text_encoder(text):
    text = clip_net.clip.tokenize(text).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)

    return text_features

In [None]:
import os
import requests
import glob
import numpy as np
import sng_parser
import json
from tqdm import tqdm
import torch
import sys
import time

def query(filename):
    # Send the raw image bytes to the Hugging Face Inference API and return the parsed JSON.
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()


def scene_graph_parsing(dir_fps_path, dst_scenegraph_path):

    video_list = os.listdir(dir_fps_path)
    video_idx = 0
    total_nums = len(video_list)

    scene_graphs = {}
    name2idx = {}
    id = 0
    for video in video_list:

        video_idx = video_idx + 1
        print("\n--> ", video_idx, video)

        video_img_list = sorted(glob.glob(os.path.join(dir_fps_path, video, '*.jpg')))

        params_frames = len(video_img_list)
        samples = np.round(np.linspace(0, params_frames-1, params_frames))

        img_list = [video_img_list[int(sample)] for sample in samples]
        img_captions = []

        for img_count in img_list:
            # The API can return an error payload instead of a caption (for example while the
            # model is still loading), so retry until a caption arrives, pausing between tries.
            output = None
            while True:
                try:
                    output = query(img_count)
                    img_caption = output[0]['generated_text']
                    break
                except Exception as e:
                    print(output, str(e))
                    time.sleep(1)
            img_captions.append(img_caption)

        scene_graph = {}
        for caption_id, img_caption in enumerate(img_captions):
            graph = sng_parser.parse(img_caption)
            scene_graph[caption_id] = graph

        data_new = {}
        for key, value in scene_graph.items():
            for i, entity in enumerate(value['entities']):
                value['entities'][i]['span_embedding'] = text_encoder(entity['span'])
            for i, relation in enumerate(value['relations']):
                value['relations'][i]['relation_embedding'] = text_encoder(relation['relation'])
            data_new[int(key)] = value

        name = video
        scene_graphs[id] = data_new
        name2idx[name] = id
        id += 1

    np.save(os.path.join(dst_scenegraph_path, 'scene_graphs.npy'), scene_graphs)
    with open(os.path.join(dst_scenegraph_path, 'name2idx.json'), 'w') as file:
        json.dump(name2idx, file, indent=4)

In [None]:
dir_fps_path = "/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA/frames"
dst_scenegraph_path = "/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA/scene_graphs_npy"
os.makedirs(dst_scenegraph_path, exist_ok=True)

scene_graph_parsing(dir_fps_path, dst_scenegraph_path)
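The saved graphs keep sng_parser's dictionary layout. This applies to the query graphs in Section 3.4 as well. The models in Section 4, however, consume `x`, `edge_index` and `edge_attr`. Below is a minimal NumPy sketch of that conversion for one graph, under two assumptions. First, each relation stores integer `subject`/`object` indices into `entities` (sng_parser's convention). Second, the CLIP embeddings have already been converted to NumPy arrays. `graph_to_arrays` is a hypothetical helper, not this project's data loader.

```python
import numpy as np


def graph_to_arrays(graph, dim=512):
    """Turn one parsed scene/query graph into PyG-style arrays.

    Assumes graph['entities'] items carry a 'span_embedding' of shape [1, dim]
    and graph['relations'] items carry integer 'subject'/'object' indices plus
    a 'relation_embedding' of shape [1, dim].
    """
    entities, relations = graph['entities'], graph['relations']

    # Node features: one row per entity -> [num_nodes, dim]
    if entities:
        x = np.concatenate([np.asarray(e['span_embedding']).reshape(1, dim) for e in entities])
    else:
        x = np.zeros((0, dim), dtype=np.float32)

    # Directed edges subject -> object -> edge_index [2, num_edges]
    edge_index = np.array([[r['subject'] for r in relations],
                           [r['object'] for r in relations]], dtype=np.int64)

    # Edge features: one row per relation -> [num_edges, dim]
    if relations:
        edge_attr = np.concatenate([np.asarray(r['relation_embedding']).reshape(1, dim) for r in relations])
    else:
        edge_attr = np.zeros((0, dim), dtype=np.float32)

    return x, edge_index, edge_attr
```

Edges here point from subject to object only. Whether the real loader also adds reverse edges is not shown in this notebook.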

### 3.4 Query Graph Generation

In [None]:
import os
import requests
import numpy as np
import sng_parser
import json
from tqdm import tqdm
import ast
import torch
import sys

def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()


def query_graph_parsing(json_path, dst_qst_path):

    samples = json.load(open(json_path, 'r'))

    ques_vocab = ['<pad>']

    query_graphs = {}
    name2idx = {}
    id = 0
    i = 0
    for sample in tqdm(samples):
        i += 1
        question = sample['question_content'].rstrip().split(' ')
        question[-1] = question[-1][:-1]

        question_id = sample['question_id']

        p = 0
        for pos in range(len(question)):
            if '<' in question[pos]:
                question[pos] = ast.literal_eval(sample['templ_values'])[p]
                p += 1
        for wd in question:
            if wd not in ques_vocab:
                ques_vocab.append(wd)

        question = ' '.join(question)

        # parsing
        data = sng_parser.parse(question)

        for i, entity in enumerate(data['entities']):
            data['entities'][i]['span_embedding'] = text_encoder(entity['span'])
        for i, relation in enumerate(data['relations']):
            data['relations'][i]['relation_embedding'] = text_encoder(relation['relation'])

        name = str(question_id)
        query_graphs[id] = data
        name2idx[name] = id
        id += 1

    np.save(os.path.join(dst_qst_path, 'query_graphs.npy'), query_graphs)
    with open(os.path.join(dst_qst_path, 'name2idx.json'), 'w') as file:
        json.dump(name2idx, file, indent=4)

In [None]:
json_path = "/content/drive/MyDrive/AVQA-GNN/dataset/split_que_id/music_avqa.json"
dst_qst_path = "/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA/query_graphs_npy"
os.makedirs(dst_qst_path, exist_ok=True)

query_graph_parsing(json_path, dst_qst_path)

100%|██████████| 45624/45624 [40:48<00:00, 18.63it/s]
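Both parsing cells save a Python dict with `np.save`, which stores it as a 0-d object array. Alongside it they write a `name2idx.json` that maps a question_id (or, for scene graphs, a video name) to the dict key. The sketch below reads one graph back, assuming the file layout written above. `load_query_graph` is a hypothetical helper. The real files hold torch tensors created on the CLIP device, so unpickling them needs torch installed and may need a GPU.

```python
import json
import os

import numpy as np


def load_query_graph(graph_dir, question_id):
    """Look up one saved query graph by its question_id.

    np.save wraps a dict in a 0-d object array, so np.load needs
    allow_pickle=True and .item() to get the dict back.
    """
    graphs = np.load(os.path.join(graph_dir, 'query_graphs.npy'), allow_pickle=True).item()
    with open(os.path.join(graph_dir, 'name2idx.json')) as f:
        name2idx = json.load(f)
    return graphs[name2idx[str(question_id)]]
```

For many lookups, load the two files once and index into them directly instead of calling this per question.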


## 4. Model

### 4.1 GAT

In [None]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
import torch_scatter
from typing import Union, Tuple, Optional

class GAT(MessagePassing):
    def __init__(self,
                 in_channels: Union[int, Tuple[int, int]],
                 out_channels: int,
                 edge_in_channels: int,
                 heads: int = 1,
                 negative_slope: float = 0.2,
                 dropout: float = 0.0,
                 add_self_loops: bool = True,
                 **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops  # stored only; forward() does not add self-loops

        if isinstance(in_channels, int):
            self.lin_l = nn.Linear(in_channels, heads * out_channels, bias=False)
            self.lin_r = self.lin_l
        else:
            self.lin_l = nn.Linear(in_channels[0], heads * out_channels, bias=False)
            self.lin_r = nn.Linear(in_channels[1], heads * out_channels, bias=False)

        self.att_l = Parameter(torch.zeros(heads, out_channels))
        self.att_r = Parameter(torch.zeros(heads, out_channels))

        self.lin_e = nn.Linear(edge_in_channels, heads * out_channels, bias=False)
        self.att_e = Parameter(torch.zeros(heads, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.lin_e.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)
        nn.init.xavier_uniform_(self.att_e)

    def forward(self, x, edge_index, edge_attr, size = None):

        H, C = self.heads, self.out_channels

        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)
        alpha_l = self.att_l.unsqueeze(0) * x_l
        alpha_r = self.att_r.unsqueeze(0) * x_r

        e = self.lin_e(edge_attr).view(-1, H, C)
        alpha_e = self.att_e.unsqueeze(0) * e

        out = self.propagate(edge_index=edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r), alpha_e=alpha_e, size=size).view(-1, H*C)

        return out

    def message(self, x_j, alpha_j, alpha_i, alpha_e, index, ptr, size_i):

        alpha = alpha_i + alpha_j + alpha_e
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        out = alpha * x_j

        return out


class GNNStack(nn.Module):
    def __init__(self, in_channels, out_channels, edge_attr_dim, num_layers, heads=4, dropout=0., negative_slope=0.2):
        super(GNNStack, self).__init__()

        assert (num_layers >= 1), 'Number of layers is not >= 1'

        # The first layer maps in_channels to heads * out_channels; every later layer
        # takes the concatenated heads of the previous layer as input.
        self.convs = nn.ModuleList()
        self.convs.append(GAT(in_channels=in_channels, out_channels=out_channels,
                    edge_in_channels=edge_attr_dim, heads=heads, negative_slope=negative_slope, dropout=dropout))
        for l in range(num_layers - 1):
            self.convs.append(GAT(in_channels=heads * out_channels, out_channels=out_channels,
                    edge_in_channels=edge_attr_dim, heads=heads, negative_slope=negative_slope, dropout=dropout))

        self.bns = nn.ModuleList([nn.BatchNorm1d(heads * out_channels) for _ in range(num_layers - 1)])

        self.post_mp = nn.Sequential(
                    nn.Linear(heads * out_channels, out_channels), nn.Dropout(dropout),
                    nn.Linear(out_channels, out_channels))
        self.num_layers = num_layers
        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr, batch):

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_attr)
            if i != self.num_layers - 1:
                x = self.bns[i](x)
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)
        return x
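To make the message and softmax step concrete, here is a NumPy sketch of one forward pass of the `GAT` layer above for a single graph. It assumes shared projection weights for both endpoints (as when `in_channels` is an int), no dropout, and `add` aggregation. Like the cell, it keeps one attention score per head *and* per channel, rather than reducing to one score per head as in standard GAT. The function names are hypothetical.

```python
import numpy as np


def leaky_relu(z, negative_slope=0.2):
    return np.where(z > 0, z, negative_slope * z)


def segment_softmax(scores, index, num_nodes):
    """Softmax of per-edge scores over the edges that share a target node."""
    shape = (num_nodes,) + scores.shape[1:]
    # Subtract each group's maximum for numerical stability.
    group_max = np.full(shape, -np.inf)
    np.maximum.at(group_max, index, scores)
    exp = np.exp(scores - group_max[index])
    group_sum = np.zeros(shape)
    np.add.at(group_sum, index, exp)
    return exp / group_sum[index]


def gat_forward(x, edge_index, edge_attr, W, W_e, att_l, att_r, att_e, negative_slope=0.2):
    """One edge-aware attention layer in the style of the GAT cell above.

    x: [N, F]; edge_index: [2, E] with row 0 = source j, row 1 = target i;
    edge_attr: [E, F_e]; W: [F, H*C]; W_e: [F_e, H*C]; att_l, att_r, att_e: [H, C].
    """
    H, C = att_l.shape
    src, dst = edge_index
    x_h = (x @ W).reshape(-1, H, C)            # projected nodes [N, H, C]
    e_h = (edge_attr @ W_e).reshape(-1, H, C)  # projected edges [E, H, C]
    alpha = (att_l[None] * x_h[src]            # source term (alpha_j)
             + att_r[None] * x_h[dst]          # target term (alpha_i)
             + att_e[None] * e_h)              # edge term (alpha_e)
    alpha = segment_softmax(leaky_relu(alpha, negative_slope), dst, x.shape[0])
    # "add" aggregation: sum the weighted source features into each target node.
    out = np.zeros_like(x_h)
    np.add.at(out, dst, alpha * x_h[src])
    return out.reshape(-1, H * C)
```

A node with no incoming edges ends up with an all-zero output. This is one reason GAT implementations often add self-loops before propagating.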

### 4.2 LayerNorm

In [None]:
from torch_geometric.typing import OptTensor

import torch
from torch.nn import Parameter
from torch import Tensor
from torch_scatter import scatter
from torch_geometric.utils import degree

from torch_geometric.nn.inits import ones, zeros

class LayerNorm(torch.nn.Module):
    r"""Applies layer normalization over each individual example in a batch
    of node features as described in the `"Layer Normalization"
    <https://arxiv.org/abs/1607.06450>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    The mean and standard-deviation are calculated across all nodes and all
    node channels separately for each object in a mini-batch.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
    """
    def __init__(self, in_channels, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()

        self.in_channels = in_channels
        self.eps = eps

        if affine:
            # One learnable scale and shift per channel (reset to ones / zeros below).
            self.weight = Parameter(torch.empty(in_channels))
            self.bias = Parameter(torch.empty(in_channels))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        ones(self.weight)
        zeros(self.bias)

    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
        """"""
        if batch is None:
            x = x - x.mean()
            out = x / (x.std(unbiased=False) + self.eps)

        else:
            batch_size = int(batch.max()) + 1

            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
            norm = norm.mul_(x.size(-1)).view(-1, 1)

            mean = scatter(x, batch, dim=0, dim_size=batch_size,
                           reduce='add').sum(dim=-1, keepdim=True) / norm

            x = x - mean[batch]

            var = scatter(x * x, batch, dim=0, dim_size=batch_size,
                          reduce='add').sum(dim=-1, keepdim=True)
            var = var / norm

            out = x / (var.sqrt()[batch] + self.eps)

        if self.weight is not None and self.bias is not None:
            out = out * self.weight + self.bias

        return out

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels})'
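The `batch` branch of `LayerNorm` normalizes each graph with a single mean and variance, computed over all of that graph's nodes and channels together. Below is a NumPy sketch of just that statistic, with the learnable γ/β step left out. `graph_layer_norm` is a hypothetical name. Like the cell, it divides by `std + eps` rather than `sqrt(var + eps)`.

```python
import numpy as np


def graph_layer_norm(x, batch, eps=1e-5):
    """Per-graph normalisation as in the batch branch of the LayerNorm cell.

    x: [num_nodes, C]; batch: [num_nodes] graph id of each node.
    Mean and biased variance are shared by all nodes and channels of a graph.
    """
    out = np.zeros_like(x, dtype=float)
    for g in range(int(batch.max()) + 1):
        mask = batch == g
        if not mask.any():
            continue  # a graph id with no nodes contributes nothing
        centred = x[mask] - x[mask].mean()
        out[mask] = centred / (np.sqrt((centred ** 2).mean()) + eps)
    return out
```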

### 4.3 Main Model
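In the cell below, `MgA` applies unpadded 1-D convolutions with kernel sizes 1, 3 and 5. A length-10 sequence therefore becomes lengths 10, 8 and 6, the shapes noted in its comments. `CrossAttention` then attends between every video granularity and every audio granularity. The NumPy sketch below covers one video/audio pair. It assumes the softmax is taken over the audio time steps (the last axis) and uses a plain matrix `W` in place of `self.w`, with the bias omitted. The helper names are hypothetical.

```python
import numpy as np


def valid_conv1d_length(T, kernel_size):
    """Length after an unpadded, stride-1 Conv1d (the nn.Conv1d defaults used by MgA)."""
    return T - kernel_size + 1


def cross_attention(video, audio, W):
    """Scaled dot-product attention between one video and one audio granularity.

    video: [B, T_v, C]; audio: [B, T_a, C]; W: [C, C] stands in for self.w (no bias).
    Returns f_v [B, T_v, C] (audio attended from each video step) and
    f_a [B, T_a, C] (video gathered for each audio step with the same weights).
    """
    C = video.shape[-1]
    logits = (video @ W) @ audio.transpose(0, 2, 1) / np.sqrt(C)  # [B, T_v, T_a]
    logits = logits - logits.max(axis=-1, keepdims=True)          # numerical stability
    a = np.exp(logits)
    a = a / a.sum(axis=-1, keepdims=True)                         # softmax over audio steps
    f_v = a @ audio
    f_a = a.transpose(0, 2, 1) @ video
    return f_v, f_a


print([valid_conv1d_length(10, k) for k in (1, 3, 5)])  # [10, 8, 6]
```

Running `cross_attention` for all 3 x 3 pairs of granularities gives the nested `f_v` / `f_a` lists that `HierarchicalMatch` later weights and pools.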

In [None]:
import torch
import torch.nn as nn
from torch.nn import Sequential, Linear, ReLU, Bilinear
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import copy
import torch_geometric
from torch_scatter import scatter_mean, scatter_add


def get_gt_scene_graph_encoding_layer(num_node_features, num_edge_features):

    class EdgeModel(torch.nn.Module):
        def __init__(self):
            super(EdgeModel, self).__init__()
            self.edge_mlp = Sequential(
                Linear(2 * num_node_features + num_edge_features, num_edge_features),
                ReLU(),
                Linear(num_edge_features, num_edge_features)
            )

        def forward(self, src, dest, edge_attr, u, batch):
            out = torch.cat([src, dest, edge_attr], dim=1)
            return self.edge_mlp(out)

    class NodeModel(torch.nn.Module):
        def __init__(self):
            super(NodeModel, self).__init__()
            self.node_mlp_1 = Sequential(
                Linear(num_node_features + num_edge_features, num_node_features),
                ReLU(),
                Linear(num_node_features, num_node_features)
            )
            self.node_mlp_2 = Sequential(
                Linear(2 * num_node_features, num_node_features),
                ReLU(),
                Linear(num_node_features, num_node_features)
            )

        def forward(self, x, edge_index, edge_attr, u, batch):
            row, col = edge_index
            out = torch.cat([x[row], edge_attr], dim=1)
            out = self.node_mlp_1(out)
            out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
            out = torch.cat([x, out], dim=1)
            return self.node_mlp_2(out)

    op = torch_geometric.nn.MetaLayer(EdgeModel(), NodeModel())
    return op


class HierarchicalMatch(nn.Module):
    def __init__(self, N=3, dim=512):
        super(HierarchicalMatch, self).__init__()
        self.N = N
        self.dim = dim

    def forward(self, joint, f, B):
        # joint: [B, dim] joint embedding; f: N x N nested list, f[i][j] has shape [B, T_i, dim]
        N = self.N
        device = joint.device  # follow the input instead of hard-coding 'cuda'

        def score(feat):
            # <joint, temporal mean of feat> for every sample -> [B]
            return torch.bmm(joint.unsqueeze(1), feat.mean(dim=1).unsqueeze(-1)).squeeze()

        # Level 1: weight each f[i][j] by its score, normalised over j within each i.
        b = torch.zeros(N, N, B, device=device)
        for i, f_i in enumerate(f):
            denom = torch.stack([score(f[i][r]) for r in range(N)]).sum(dim=0)
            for j, f_ij in enumerate(f_i):
                b[i][j] = score(f_ij) / denom # [B]

        f_ii = []
        for i, f_i in enumerate(f):
            f_ii.append(torch.stack([b[i][j][:, None, None] * f_ij for j, f_ij in enumerate(f_i)]).sum(dim=0))

        # Level 2: weight each fused f_ii[i] by its score, normalised over i.
        lambda_i = torch.zeros(N, B, device=device)
        denom = torch.stack([score(f_ii[r]) for r in range(N)]).sum(dim=0)
        for i, f_i in enumerate(f_ii):
            lambda_i[i] = score(f_i) / denom

        return torch.stack([lambda_i[i][:, None] * f_i.mean(dim=1) for i, f_i in enumerate(f_ii)]).sum(dim=0) # [B, 512]


class CrossAttention(nn.Module):
    def __init__(self, dim=512):
        super(CrossAttention, self).__init__()
        self.w = Linear(dim, dim)
        nn.init.xavier_uniform_(self.w.weight)

    def forward(self, audio_conv_list, video_conv_list):
        (B, T, C) = audio_conv_list[0].shape
        f_v = [[], [], []]
        f_a = [[], [], []]
        for i, video_conv in enumerate(video_conv_list):
            for j, audio_conv in enumerate(audio_conv_list):
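                # Scaled dot-product scores between video and audio steps, e.g. [B, 10, 10];
                # the softmax runs over the audio axis so each video step attends over audio steps.
                a_ij = F.softmax(torch.bmm(self.w(video_conv), audio_conv.permute(0, 2, 1)) / (video_conv.shape[-1] ** 0.5), dim=-1)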
                f_v[i].append(torch.bmm(a_ij, audio_conv))
                f_a[j].append(torch.bmm(a_ij.permute(0, 2, 1), video_conv))

        return f_v, f_a


class MgA(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MgA, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, 1)
        self.conv2 = nn.Conv1d(in_channels, out_channels, 3)
        self.conv3 = nn.Conv1d(in_channels, out_channels, 5)

    def forward(self, input):
        conv_1 = self.conv1(input).permute(0, 2, 1) # [B, 10, 512]
        conv_2 = self.conv2(input).permute(0, 2, 1) # [B, 8, 512]
        conv_3 = self.conv3(input).permute(0, 2, 1) # [B, 6, 512]

        return conv_1, conv_2, conv_3


class AVQA_GNN(nn.Module):

    def __init__(self, args, num_node_features=512, num_edge_features=512):
        super(AVQA_GNN, self).__init__()
        self.scene_graph_encoding_layer = get_gt_scene_graph_encoding_layer(num_node_features=num_node_features, num_edge_features=num_edge_features)
        self.query_graph_encoding_layer = get_gt_scene_graph_encoding_layer(num_node_features=num_node_features, num_edge_features=num_edge_features)
        self.scene_graph_layernorm = LayerNorm(num_node_features)
        self.query_graph_layernorm = LayerNorm(num_node_features)
        out_channels = 512
        self.video_gat = GNNStack(in_channels=num_node_features,
                                  out_channels=out_channels,
                                  edge_attr_dim=num_edge_features,
                                  num_layers=5,
                                  heads=4,
                                  dropout=0.1,
                                  negative_slope=0.2)

        self.query_gat = GNNStack(in_channels=num_node_features,
                                  out_channels=out_channels,
                                  edge_attr_dim=num_edge_features,
                                  num_layers=5,
                                  heads=4,
                                  dropout=0.1,
                                  negative_slope=0.2)
        self.lin_v = Sequential(
                Linear(out_channels, out_channels),
                ReLU()
            )
        self.lin_q = Sequential(
                Linear(out_channels, out_channels),
                ReLU()
            )
        self.wv = Linear(out_channels, out_channels)
        self.wq = Parameter(torch.zeros(1, 10))

        self.joint_linear = Bilinear(out_channels, out_channels, out_channels)
        self.lin_a = Sequential(
                Linear(128, 512),
                ReLU(),
                Linear(512, 512)
            )

        self.mga_v = MgA(out_channels, out_channels)
        self.mga_a = MgA(out_channels, out_channels)
        self.cross_attn = CrossAttention(out_channels)
        self.match_v = HierarchicalMatch(N=3, dim=out_channels)
        self.match_a = HierarchicalMatch(N=3, dim=out_channels)
        self.tanh_avq = nn.Tanh()
        self.fc_answer_pred = nn.Linear(512, 42)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_v[0].weight)
        nn.init.xavier_uniform_(self.lin_q[0].weight)
        nn.init.xavier_uniform_(self.wv.weight)
        nn.init.xavier_uniform_(self.wq)
        nn.init.xavier_uniform_(self.joint_linear.weight)
        nn.init.xavier_uniform_(self.lin_a[0].weight)
        nn.init.xavier_uniform_(self.lin_a[-1].weight)
        nn.init.xavier_uniform_(self.fc_answer_pred.weight)

    def forward(self, audio_feat, visual_feat, question_feat, sg_data, qg_data):
        B = len(qg_data)

        s_x_encoded_list, s_edge_attr_encoded_list = [], []
        for i in range(len(sg_data)):
            s_x_encoded, s_edge_attr_encoded, _ = self.scene_graph_encoding_layer(
                x=sg_data[i].x,
                edge_index=sg_data[i].edge_index,
                edge_attr=sg_data[i].edge_attr,
                u=None,
                batch=sg_data[i].batch
            )
            s_x_encoded = self.scene_graph_layernorm(s_x_encoded, sg_data[i].batch)
            s_x_encoded_list.append(s_x_encoded)
            s_edge_attr_encoded_list.append(s_edge_attr_encoded)

        q_x_encoded, q_edge_attr_encoded, _ = self.query_graph_encoding_layer(
            x=qg_data.x,
            edge_index=qg_data.edge_index,
            edge_attr=qg_data.edge_attr,
            u=None,
            batch=qg_data.batch
        )
        q_x_encoded = self.query_graph_layernorm(q_x_encoded, qg_data.batch)

        x_executed_list = []
        for i, (s_x_encoded, s_edge_attr_encoded) in enumerate(zip(s_x_encoded_list, s_edge_attr_encoded_list)):
            x_executed = self.video_gat(x=s_x_encoded, edge_index=sg_data[i].edge_index, edge_attr=s_edge_attr_encoded, batch=sg_data[i].batch)
            x_executed_batch_list = []
            for j in range(B):
                x_executed_batch_list.append(x_executed[sg_data[i].batch==j].sum(dim=0))
            x_executed = torch.stack(x_executed_batch_list) # [B, 512]
            x_executed_list.append(x_executed)
        x_executed_list = torch.stack(x_executed_list).permute(1, 0, 2) # [B, 10, 512]

        q_executed = self.query_gat(x=q_x_encoded, edge_index=qg_data.edge_index, edge_attr=q_edge_attr_encoded, batch=qg_data.batch)
        q_executed_batch_list = []
        for i in range(B):
            q_executed_batch_list.append(q_executed[qg_data.batch==i])
        max_q = max(q.shape[0] for q in q_executed_batch_list)
        for i, q in enumerate(q_executed_batch_list):
            pad_size = max_q - q.shape[0]
            if pad_size > 0:
                q_executed_batch_list[i] = F.pad(q, (0, 0, 0, pad_size))
        q_executed = torch.stack(q_executed_batch_list) # [B, q, 512]

        video = self.lin_v(x_executed_list)
        query = self.lin_q(q_executed)

        sim = torch.bmm(video, query.permute(0, 2, 1)) # [B, n:10, q]

        temperature = 1.0
        # Note: zero-padded query rows are not masked, so they still enter sim and the mean/sum below
        v_joint = torch.bmm((sim / temperature).permute(0, 2, 1), video) # [B, q, 512]
        v_joint = self.wv(v_joint).mean(dim=1) # [B, 512]

        q_joint = torch.bmm((sim / temperature), query) # [B, 10, 512]
        q_joint = (self.wq.unsqueeze(-1) * q_joint).sum(dim=1) # [B, 512]

        vq_joint = self.joint_linear(v_joint, q_joint) # [B, 512]

        audio_feat = self.lin_a(audio_feat)
        (B, T, C) = audio_feat.shape

        audio_conv_1, audio_conv_2, audio_conv_3 = self.mga_a(audio_feat.permute(0, 2, 1)) # [B, l_a, 512]
        video_conv_1, video_conv_2, video_conv_3 = self.mga_v(visual_feat.permute(0, 2, 1)) # [B, l_v, 512]

        audio_conv_list = [audio_conv_1, audio_conv_2, audio_conv_3]
        video_conv_list = [video_conv_1, video_conv_2, video_conv_3]

        f_v, f_a = self.cross_attn(audio_conv_list, video_conv_list) # [3, 3, B, 10, 512]

        f_v = self.match_v(vq_joint, f_v, B) # [B, 512]
        f_a = self.match_a(vq_joint, f_a, B)

        z_v = torch.sigmoid(vq_joint * f_v) # [B, 512], element-wise gate for the visual branch
        z_a = torch.sigmoid(vq_joint * f_a) # element-wise gate for the audio branch

        f_m = z_v * f_v + z_a * f_a # [B, 512]

        avq_feat = f_m * question_feat.squeeze(1) # [B, 512]
        avq_feat = self.tanh_avq(avq_feat)

        answer_pred = self.fc_answer_pred(avq_feat)

        return answer_pred
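The cross-graph alignment in `forward` (the `sim`, `v_joint` and `q_joint` steps) is easiest to follow by tracking shapes. Below is a minimal NumPy sketch of that arithmetic, assuming toy sizes (B=2 videos, n=10 frame graphs, q=4 query nodes, d=8 instead of 512), random stand-ins for the `lin_v`/`lin_q` outputs, an identity matrix in place of the learned `wv` layer, and a uniform vector in place of the learned `wq`. It illustrates the data flow only, not the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
B, n, q, d = 2, 10, 4, 8          # toy sizes; the notebook uses d = 512
temperature = 1.0

video = rng.standard_normal((B, n, d))   # stand-in for self.lin_v(...): one vector per frame graph
query = rng.standard_normal((B, q, d))   # stand-in for self.lin_q(...): one vector per query node
W_v = np.eye(d)                          # stand-in for the learned self.wv (bias omitted)
w_q = np.full((1, n), 1.0 / n)           # stand-in for the learned self.wq, shape [1, n]

# Frame-to-query-node similarity, like torch.bmm(video, query.permute(0, 2, 1))
sim = video @ query.transpose(0, 2, 1) / temperature        # [B, n, q]

# Each query node gathers a similarity-weighted sum of frame vectors; then average over nodes
v_joint = (sim.transpose(0, 2, 1) @ video) @ W_v.T          # [B, q, d]
v_joint = v_joint.mean(axis=1)                              # [B, d]

# Each frame gathers a similarity-weighted sum of query-node vectors; then weight frames by w_q
q_joint = sim @ query                                       # [B, n, d]
q_joint = (w_q[..., None] * q_joint).sum(axis=1)            # [B, d]

print(sim.shape, v_joint.shape, q_joint.shape)
```

As in the notebook, `sim` is used without a softmax, so `temperature` only rescales both joint vectors linearly.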

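The last stage of `forward` gates the two hierarchical-match outputs with the video-query alignment vector and then modulates the fused vector with the question feature. A minimal NumPy sketch of that arithmetic, assuming random stand-ins for `vq_joint`, `f_v`, `f_a` and `question_feat` and a toy width of 8; the final `fc_answer_pred` classifier is omitted.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(1)
B, d = 3, 8
vq_joint = rng.standard_normal((B, d))   # stand-in for the video-query alignment vector
f_v = rng.standard_normal((B, d))        # stand-in for the visual output of match_v
f_a = rng.standard_normal((B, d))        # stand-in for the audio output of match_a
q = rng.standard_normal((B, d))          # stand-in for question_feat.squeeze(1)

# Element-wise gates in (0, 1), driven by agreement between the alignment vector and each branch
z_v = sigmoid(vq_joint * f_v)
z_a = sigmoid(vq_joint * f_a)

# Gated sum of the two modalities, then modulation by the question feature
f_m = z_v * f_v + z_a * f_a
avq_feat = np.tanh(f_m * q)              # bounded in [-1, 1] before the classifier
```

Because each gate is element-wise, a feature dimension can take most of its value from audio while a neighbouring dimension relies on vision.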
## 5. Dataloader

In [None]:
import numpy as np
import torch
import os
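# from torch.utils.data import Dataset, DataLoader
# PyG's graph-aware DataLoader (torch_geometric.data.DataLoader is a deprecated alias of it).
# In PyG 2.x it batches samples with its own Collater and ignores a collate_fn argument;
# that Collater still merges each graph slot into a Batch via Batch.from_data_list.
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset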

from torchvision import transforms
import pandas as pd
import ast
import json
from PIL import Image
import torch_geometric


def TransformImage(img):

    transform_list = []
    mean = [0.43216, 0.394666, 0.37645]
    std = [0.22803, 0.22145, 0.216989]

    transform_list.append(transforms.Resize([256,256]))
    transform_list.append(transforms.ToTensor())
    transform_list.append(transforms.Normalize(mean, std))
    trans = transforms.Compose(transform_list)
    frame_tensor = trans(img)

    return frame_tensor

def TransformImage_Resize(img):

    transform_list = []
    mean = [0.43216, 0.394666, 0.37645]
    std = [0.22803, 0.22145, 0.216989]

    transform_list.append(transforms.Resize([256,256]))
    # transform_list.append(transforms.ToTensor())
    # transform_list.append(transforms.Normalize(mean, std))
    trans = transforms.Compose(transform_list)
    frame_org = trans(img)

    return frame_org


def load_frame_info(img_path):

    img = Image.open(img_path).convert('RGB')
    # img2 = TransformImage_Resize(img)   # visualization
    frame_tensor = TransformImage(img)

    # return img2, frame_tensor    # visualization
    return frame_tensor


def image_info(frame_path, max_frames=60):
    """Load up to `max_frames` frames named 000001.jpg, 000002.jpg, ... from `frame_path`."""
    num_frames = min(len(os.listdir(frame_path)), max_frames)

    select_img = []
    for frame_idx in range(num_frames):
        video_frames_path = os.path.join(frame_path, str(frame_idx + 1).zfill(6) + ".jpg")
        frame_tensor_info = load_frame_info(video_frames_path)
        select_img.append(frame_tensor_info.cpu().numpy())

    select_img = np.array(select_img)    # [T, 3, 256, 256]
    return select_img



def ids_to_multinomial(ans, categories):
    """ label encoding
    Returns:
      int, the index of `ans` in `categories` (a class id for CrossEntropyLoss)
    """
    id_to_idx = {cat: index for index, cat in enumerate(categories)}

    return id_to_idx[ans]


class AVQA_dataset(Dataset):

    def __init__(self, args, label, audios_feat_dir,
                       clip_vit_b32_dir, clip_qst_dir):

        self.args = args

        # Build the question/answer vocabularies from the training split only
        with open(args['label_train'], 'r') as f:
            samples = json.load(f)

        # Question
        ques_vocab = ['<pad>']
        ans_vocab = []
        for sample in samples:
            question = sample['question_content'].rstrip().split(' ')
            question[-1] = question[-1][:-1]    # drop the trailing '?'

            # Fill template slots such as '<Object>' with this sample's template values
            p = 0
            for pos in range(len(question)):
                if '<' in question[pos]:
                    question[pos] = ast.literal_eval(sample['templ_values'])[p]
                    p += 1
            for wd in question:
                if wd not in ques_vocab:
                    ques_vocab.append(wd)
            if sample['anser'] not in ans_vocab:    # 'anser' is the (misspelled) key used in the label json files
                ans_vocab.append(sample['anser'])
        # ques_vocab.append('fifth')

        self.ques_vocab = ques_vocab
        self.ans_vocab = ans_vocab
        self.word_to_ix = {word: i for i, word in enumerate(self.ques_vocab)}

        with open(label, 'r') as f:
            self.samples = json.load(f)
        self.max_len = 14    # question length

        self.audios_feat_dir = audios_feat_dir

        self.clip_vit_b32_dir = clip_vit_b32_dir
        self.clip_qst_dir = clip_qst_dir
        self.frames_dir = args['frames_dir']    # used by get_frames_spatial
        self.scene_graph_dir = args['scene_graph_dir']
        self.query_graph_dir = args['query_graph_dir']
        self.scene_graphs = np.load(os.path.join(self.scene_graph_dir, "scene_graphs.npy"), allow_pickle=True).item()
        self.query_graphs = np.load(os.path.join(self.query_graph_dir, "query_graphs.npy"), allow_pickle=True).item()
        with open(os.path.join(self.scene_graph_dir, "name2idx.json"), 'r') as f:
            self.scene_name2idx = json.load(f)
        with open(os.path.join(self.query_graph_dir, "name2idx.json"), 'r') as f:
            self.query_name2idx = json.load(f)


    def __len__(self):
        return len(self.samples)

    def get_lstm_embeddings(self, question_input, sample):

        question = sample['question_content'].rstrip().split(' ')
        question[-1] = question[-1][:-1]

        p = 0
        for pos in range(len(question)):
            if '<' in question[pos]:
                question[pos] = ast.literal_eval(sample['templ_values'])[p]
                p += 1
        if len(question) < self.max_len:
            n = self.max_len - len(question)
            for i in range(n):
                question.append('<pad>')

        idxs = [self.word_to_ix[w] for w in question]
        ques = torch.tensor(idxs, dtype=torch.long)

        return ques

    def get_frames_spatial(self, video_name):

        frames_path = os.path.join(self.frames_dir, video_name)
        frames_spatial = image_info(frames_path)    # [T, 3, 256, 256]

        return frames_spatial

    def convert_to_pyg_graph(self, entities, relations):
        # Node features: one 512-d span embedding per entity
        x = torch.zeros(len(entities), 512)
        # Edge features: one 512-d relation embedding per (subject -> object) edge
        edge_features = torch.zeros(len(relations), 512)
        edge_topology = torch.zeros(len(relations), 2, dtype=torch.long)

        for i, entity in enumerate(entities):
            x[i] = entity['span_embedding']

        for i, relation in enumerate(relations):
            edge_features[i] = relation['relation_embedding']
            edge_topology[i] = torch.tensor((relation['subject'], relation['object']))

        # PyG expects edge_index in COO format with shape [2, num_edges]
        data = torch_geometric.data.Data(x=x, edge_index=edge_topology.t().contiguous(), edge_attr=edge_features)
        return data

    def convert_to_pyg_graphs(self, sg_this):
        pyg_datas = []
        for i in range(10):    # one scene graph per sampled frame (10 per video)
            sg = sg_this[i]
            pyg_data = self.convert_to_pyg_graph(sg['entities'], sg['relations'])
            pyg_datas.append(pyg_data)

        return pyg_datas

    def __getitem__(self, idx):

        sample = self.samples[idx]
        name = sample['video_id']
        question_id = sample['question_id']

        audio_feat = np.load(os.path.join(self.audios_feat_dir, name + '.npy'))
        audio_feat = audio_feat[::6, :]    # keep every 6th time step

        question_feat = np.load(os.path.join(self.clip_qst_dir, str(question_id) + '.npy'))

        visual_CLIP_feat = np.load(os.path.join(self.clip_vit_b32_dir, name + '.npy'))
        visual_feat = visual_CLIP_feat[::6, 0, :]    # every 6th frame, token 0 only (patch tokens start at index 1)

        # visual_CLIP_feat = np.load(os.path.join(self.clip_vit_b32_dir, name + '.npy'))
        # patch_feat = visual_CLIP_feat[:60, 1:, :]

        #########################################################################
        # All scene graphs (and all query graphs) live in one .npy dict; name2idx maps names to its keys
        sg_idx = self.scene_name2idx[name]
        scene_graphs = self.scene_graphs[sg_idx]
        sg_data = self.convert_to_pyg_graphs(scene_graphs) # list[pyg_data], one per sampled frame

        qg_idx = self.query_name2idx[str(question_id)]
        query_graph = self.query_graphs[qg_idx]
        qg_data = self.convert_to_pyg_graph(query_graph['entities'], query_graph['relations'])
        #########################################################################

        ### answer
        answer = sample['anser']
        answer_label = ids_to_multinomial(answer, self.ans_vocab)
        answer_label = torch.from_numpy(np.array(answer_label)).long()

        return (name, torch.from_numpy(audio_feat), torch.from_numpy(visual_feat), torch.from_numpy(question_feat), answer_label, sg_data, qg_data, question_id)


def AVQA_dataset_collate_fn(data):

    name, audio_feat, visual_feat, question_feat, answer_label, sg_data, qg_data, question_id = zip(*data)

    audio_feat = torch.stack(audio_feat)
    visual_feat = torch.stack(visual_feat)
    question_feat = torch.stack(question_feat)
    answer_label = torch.stack(answer_label)

    sg_data_list = {i:[] for i in range(len(sg_data[0]))}
    for sg_this in sg_data:
        for i in range(len(sg_this)):
            sg_data_list[i].append(sg_this[i])

    sg_data_out = []
    for i in range(len(sg_data[0])):
        sg_data_out.append(torch_geometric.data.Batch.from_data_list(sg_data_list[i]))

    qg_data = torch_geometric.data.Batch.from_data_list(qg_data)

    return (name, audio_feat, visual_feat, question_feat, answer_label, sg_data_out, qg_data, question_id)
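The question preprocessing shared by the vocabulary loop in `__init__` and by `get_lstm_embeddings` can be isolated as a small pure-Python function. This is a sketch: `question_to_ids` is a hypothetical helper name, the sample question is made up, and it assumes, as the `ast.literal_eval` calls imply, that `templ_values` holds the string form of a Python list.

```python
import ast

def question_to_ids(sample, word_to_ix, max_len=14, pad='<pad>'):
    """Tokenize a templated question, fill its slots, pad it, and map words to vocabulary ids."""
    words = sample['question_content'].rstrip().split(' ')
    words[-1] = words[-1][:-1]                    # drop the trailing '?'
    values = ast.literal_eval(sample['templ_values'])
    p = 0
    for pos, w in enumerate(words):
        if '<' in w:                              # template slot such as '<Object>'
            words[pos] = values[p]
            p += 1
    words += [pad] * (max_len - len(words))      # no-op when the question already has max_len words
    return [word_to_ix[w] for w in words]

# Hypothetical sample and vocabulary (index 0 is reserved for '<pad>', as in the notebook)
sample = {'question_content': 'How many <Object> are in the video?', 'templ_values': "['violin']"}
vocab = ['<pad>', 'How', 'many', 'violin', 'are', 'in', 'the', 'video']
word_to_ix = {w: i for i, w in enumerate(vocab)}
ids = question_to_ids(sample, word_to_ix)
```

Like the notebook, this only pads; a question longer than `max_len` words would produce a longer id list.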

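`Batch.from_data_list` (used for the scene and query graphs) merges several small graphs into one disjoint graph: node features are concatenated, each graph's `edge_index` is shifted by the number of nodes that precede it, and a `batch` vector records which graph every node came from. The per-graph loop in `forward` (`x_executed[sg_data[i].batch==j].sum(dim=0)`) then sums node rows per graph, which is what PyG's `global_add_pool(x, batch)` computes in a single call. A minimal NumPy sketch of both steps; `collate_graphs` and `sum_pool` are hypothetical helpers, not PyG functions.

```python
import numpy as np

def collate_graphs(graphs):
    """Merge (x, edge_index) pairs into one disjoint graph, PyG-Batch style."""
    xs, edges, batch = [], [], []
    offset = 0
    for g_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)        # shift node ids past the previous graphs
        batch.append(np.full(x.shape[0], g_id))  # graph id for every node of this graph
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

def sum_pool(x, batch, num_graphs):
    """Per-graph sum of node features, equivalent to the x[batch == j].sum(0) loop."""
    out = np.zeros((num_graphs, x.shape[1]), dtype=x.dtype)
    np.add.at(out, batch, x)                     # scatter-add each node row into its graph's row
    return out

# Two toy graphs: 3 nodes with edges 0->1, 1->2, and 2 nodes with edge 0->1
g1 = (np.arange(6, dtype=float).reshape(3, 2), np.array([[0, 1], [1, 2]]))
g2 = (np.ones((2, 2)), np.array([[0], [1]]))
x, edge_index, batch = collate_graphs([g1, g2])
pooled = sum_pool(x, batch, num_graphs=2)
```

The scatter-add replaces a Python loop over the batch with one vectorized call; for the padded query tensor, `torch_geometric.utils.to_dense_batch(x, batch)` plays a similar role and also returns a node mask.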

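The `[::6]` slicing in `__getitem__` keeps every sixth time step, so the audio and visual features end up with 10 steps. That count matches the 10 scene graphs built per video by `convert_to_pyg_graphs` and the `[B, 10, 512]` shapes noted in `forward`. A NumPy sketch with zero arrays, assuming 60 time steps per clip (suggested by the 60-frame cap in `image_info`), 128-d VGGish audio, and CLIP features of shape `[60, 50, 512]` (token 0 followed by patch tokens); all of these shapes are assumptions.

```python
import numpy as np

# Hypothetical per-video feature arrays with assumed shapes
audio = np.zeros((60, 128), dtype=np.float32)              # VGGish: [time steps, 128]
visual_tokens = np.zeros((60, 50, 512), dtype=np.float32)  # CLIP ViT-B/32: [frames, tokens, 512]

audio_feat = audio[::6, :]              # time steps 0, 6, ..., 54
visual_feat = visual_tokens[::6, 0, :]  # the same frames, token 0 only

kept_steps = list(range(60))[::6]
print(audio_feat.shape, visual_feat.shape, kept_steps)
```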
## 6. Training

### 6.1 Arguments

In [None]:
import os

args = {}
root_path = '/content/drive/MyDrive/AVQA-GNN/data/MUSIC-AVQA'
dataset_path = '/content/drive/MyDrive/AVQA-GNN/dataset/split_que_id'

### ======================== Dataset Configs ==========================
args["audios_feat_dir"] = os.path.join(root_path, 'vggish')
args["clip_vit_b32_dir"] = os.path.join(root_path, 'clip_vit_b32')
args["clip_qst_dir"] = os.path.join(root_path, 'clip_qst')
args["clip_word_dir"] = os.path.join(root_path, 'clip_word')
args["frames_dir"] = os.path.join(root_path, 'frames')
args["scene_graph_dir"] = os.path.join(root_path, 'scene_graphs_npy')
args["query_graph_dir"] = os.path.join(root_path, 'query_graphs_npy')

### ======================== Label Configs ==========================
args["label_train"] = os.path.join(dataset_path, "music_avqa_subset_train.json")
args["label_val"] = os.path.join(dataset_path, "music_avqa_subset_val.json")
args["label_test"] = os.path.join(dataset_path, "music_avqa_subset_test.json")

### ======================== Learning Configs ==========================
args['batch_size'] = 4
args['epochs'] = 30
args['lr'] = 1.2e-4
args['seed'] = 1

### ======================== Save Configs ==========================
args["checkpoint"] = 'AVQA_GNN_Net'
args["model_save_dir"] = '/content/drive/MyDrive/AVQA-GNN/models_avqa_gnn'
args["mode"] = 'train'

### ======================== Runtime Configs ==========================
args['log_interval'] = 5
args['num_workers'] = 2
args['gpu'] = '0'
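With `args['lr'] = 1.2e-4` and the `StepLR(optimizer, step_size=8, gamma=0.1)` scheduler created in `main()`, the learning rate drops tenfold every 8 epochs. A plain-Python sketch of the closed-form rate, assuming the scheduler is stepped once per epoch after training:

```python
base_lr = 1.2e-4             # args['lr']
step_size, gamma = 8, 0.1    # StepLR settings used in main()

def lr_after(num_steps):
    """Closed-form StepLR rate after `num_steps` scheduler steps."""
    return base_lr * gamma ** (num_steps // step_size)

# Epochs run 1..30; epoch e trains with the rate reached after e - 1 scheduler steps
schedule = {epoch: lr_after(epoch - 1) for epoch in range(1, args_epochs + 1)} if (args_epochs := 30) else {}
for first_epoch in (1, 9, 17, 25):
    print(first_epoch, schedule[first_epoch])
```

So epochs 1 to 8 train at 1.2e-4, 9 to 16 at 1.2e-5, 17 to 24 at 1.2e-6, and 25 to 30 at 1.2e-7, where the weights barely change.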

### 6.2 Training Setup

In [None]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim

import json
import numpy as np

from tqdm import tqdm

from datetime import datetime
TIMESTAMP = "{0:%Y-%m-%d-%H-%M-%S/}".format(datetime.now())


print("\n--------------- AVQA-GNN Training --------------- \n")


def train(args, model, train_loader, optimizer, criterion, epoch):

    model.train()
    print("-------- Training ... --------")
    for batch_idx, sample in enumerate(train_loader):
        name, audio_feat, visual_feat, question_feat, target, sg_data, qg_data, _ = sample
        qg_data = qg_data.to('cuda')
        for i in range(len(sg_data)):
            sg_data[i] = sg_data[i].to('cuda')
        audio_feat = audio_feat.to('cuda')
        visual_feat = visual_feat.to('cuda')
        question_feat = question_feat.to('cuda')
        target = target.to('cuda')

        optimizer.zero_grad()
        output_qa = model(audio_feat, visual_feat, question_feat, sg_data, qg_data)
        loss = criterion(output_qa, target)

        loss.backward()
        optimizer.step()

        if batch_idx % args['log_interval'] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_idx * len(audio_feat), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader), loss.item()))


def eval(model, val_loader, epoch):

    model.eval()
    total_qa = 0
    correct_qa = 0
    print("-------- Validating ... --------")
    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):
            name, audio_feat, visual_feat, question_feat, target, sg_data, qg_data, _ = sample
            # Debug: report batches whose first-frame scene graphs contain no nodes at all
            if sg_data[0].x.shape == (0, 512):
                print(batch_idx)
            qg_data = qg_data.to('cuda')
            for i in range(len(sg_data)):
                sg_data[i] = sg_data[i].to('cuda')
            audio_feat = audio_feat.to('cuda')
            visual_feat = visual_feat.to('cuda')
            question_feat = question_feat.to('cuda')
            target = target.to('cuda')

            preds_qa = model(audio_feat, visual_feat, question_feat, sg_data, qg_data)

            _, predicted = torch.max(preds_qa.data, 1)
            total_qa += preds_qa.size(0)
            correct_qa += (predicted == target).sum().item()

    print('Current Acc: %.2f %%' % (100 * correct_qa / total_qa))

    return 100 * correct_qa / total_qa



def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
    torch.manual_seed(args['seed'])
    os.makedirs(args['model_save_dir'], exist_ok=True)

    tensorboard_name = args['checkpoint']

    model = AVQA_GNN(args)
    model = nn.DataParallel(model).to('cuda')

    train_dataset = AVQA_dataset(label = args['label_train'],
                                 args = args,
                                 audios_feat_dir = args['audios_feat_dir'],
                                 clip_vit_b32_dir = args['clip_vit_b32_dir'],
                                 clip_qst_dir = args['clip_qst_dir'])
    train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=True, num_workers=args['num_workers'], collate_fn=AVQA_dataset_collate_fn, drop_last=True)

    val_dataset = AVQA_dataset(label = args['label_val'],
                               args = args,
                               audios_feat_dir = args['audios_feat_dir'],
                               clip_vit_b32_dir = args['clip_vit_b32_dir'],
                               clip_qst_dir = args['clip_qst_dir'])
    val_loader = DataLoader(val_dataset, batch_size=args['batch_size'], shuffle=False, num_workers=args['num_workers'], collate_fn=AVQA_dataset_collate_fn, drop_last=True)


    optimizer = optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)
    criterion = nn.CrossEntropyLoss()


    best_acc = 0
    best_epoch = 0
    for epoch in range(1, args['epochs'] + 1):

        # train for one epoch, then advance the learning-rate schedule
        train(args, model, train_loader, optimizer, criterion, epoch=epoch)
        scheduler.step()

        # evaluate on validation set
        current_acc = eval(model, val_loader, epoch)
        if current_acc >= best_acc:
            best_acc = current_acc
            best_epoch = epoch
            torch.save(model.state_dict(), os.path.join(args['model_save_dir'], args['checkpoint'] + ".pt"))

        print("Best Acc: %.2f %%"%best_acc)
        print("Best Epoch: ", best_epoch)
        print("*"*20)


--------------- AVQA-GNN Training --------------- 



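`eval()` scores a batch by taking the arg-max logit as the predicted answer class, and `main()` keeps the checkpoint whose validation accuracy is `>=` the best so far, so a later epoch that ties the best accuracy overwrites the saved weights. A NumPy sketch of both rules on made-up logits and accuracies:

```python
import numpy as np

def accuracy(logits, targets):
    """Percentage of rows whose arg-max class equals the target, as in eval()."""
    predicted = np.argmax(logits, axis=1)
    return 100.0 * np.sum(predicted == targets) / len(targets)

def best_epoch(accs):
    """Track the best accuracy with '>=' like main(): a later tie replaces the earlier epoch."""
    best_acc, best = 0.0, 0
    for epoch, acc in enumerate(accs, start=1):
        if acc >= best_acc:
            best_acc, best = acc, epoch
    return best_acc, best

# Hypothetical logits over 3 answer classes and hypothetical per-epoch accuracies
logits = np.array([[2.0, 0.1, 0.3], [0.2, 0.1, 0.9], [0.5, 0.4, 0.1], [0.0, 1.0, 0.0]])
targets = np.array([0, 2, 1, 1])
print(accuracy(logits, targets), best_epoch([40.0, 50.0, 45.0, 50.0]))
```

Switching to a strict `>` would keep the earliest epoch that reaches the best accuracy instead.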
In [None]:
main()



-------- Training ... --------




-------- Validating ... --------




Current Acc: 25.64 %
Best Acc: 25.64 %
Best Epoch:  1
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 35.90 %
Best Acc: 35.90 %
Best Epoch:  2
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 41.67 %
Best Acc: 41.67 %
Best Epoch:  3
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 42.31 %
Best Acc: 42.31 %
Best Epoch:  4
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 48.72 %
Best Acc: 48.72 %
Best Epoch:  5
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 51.28 %
Best Acc: 51.28 %
Best Epoch:  6
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 48.08 %
Best Acc: 51.28 %
Best Epoch:  6
********************
-------- Training ... --------
-------- Validating ... --------
Current Acc: 51.28 %
Best A

## 7. Testing

In [None]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim

import ast
import json
import numpy as np


print("\n--------------- Spatio-temporal Reasoning Network(PSTP-Net) --------------- \n")

def test(model, val_loader):

    model.eval()

    total = 0
    correct = 0
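    # The per-type bookkeeping below reads samples[batch_idx], so the loader must use batch_size=1, shuffle=False
    with open(args['label_test'], 'r') as f:
        samples = json.load(f)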

    # prediction save
    A_count = []
    A_compt = []
    V_count = []
    V_local = []
    AV_exist = []
    AV_count = []
    AV_local = []
    AV_compt = []
    AV_templ = []

    # results save
    que_id = []
    pred_results =[]
    grd_target = []
    pred_label = []


    with torch.no_grad():
        for batch_idx, sample in enumerate(val_loader):

            name, audio_feat, visual_feat, question_feat, target, sg_data, qg_data, question_id = sample
            # Debug: report samples whose first-frame scene graph contains no nodes
            if sg_data[0].x.shape == (0, 512):
                print(batch_idx)
            qg_data = qg_data.to('cuda')
            for i in range(len(sg_data)):
                sg_data[i] = sg_data[i].to('cuda')
            audio_feat = audio_feat.to('cuda')
            visual_feat = visual_feat.to('cuda')
            question_feat = question_feat.to('cuda')
            target = target.to('cuda')

            preds_qa = model(audio_feat, visual_feat, question_feat, sg_data, qg_data)

            preds = preds_qa

            _, predicted = torch.max(preds.data, 1)
            # print(preds.data, predicted, target)

            total += preds.size(0)
            correct += (predicted == target).sum().item()

            # result
            grd_target.append(target.cpu().item())
            pred_label.append(predicted.cpu().item())

            pred_bool = predicted == target
            for index in range(len(pred_bool)):
                pred_results.append(pred_bool[index].cpu().item())
                que_id.append(question_id[index].item())

            # Bucket this sample by question type, e.g. ['Audio', 'Counting']
            n_correct = pred_bool.sum().item()
            x = samples[batch_idx]
            q_type = ast.literal_eval(x['type'])
            if q_type[0] == 'Audio':
                if q_type[1] == 'Counting':
                    A_count.append(n_correct)
                elif q_type[1] == 'Comparative':
                    A_compt.append(n_correct)
            elif q_type[0] == 'Visual':
                if q_type[1] == 'Counting':
                    V_count.append(n_correct)
                elif q_type[1] == 'Location':
                    V_local.append(n_correct)
            elif q_type[0] == 'Audio-Visual':
                if q_type[1] == 'Existential':
                    AV_exist.append(n_correct)
                elif q_type[1] == 'Counting':
                    AV_count.append(n_correct)
                elif q_type[1] == 'Location':
                    AV_local.append(n_correct)
                elif q_type[1] == 'Comparative':
                    AV_compt.append(n_correct)
                elif q_type[1] == 'Temporal':
                    AV_templ.append(n_correct)

    print('\nAudio Count Acc: %.2f %%' % (100 * sum(A_count)/len(A_count)))
    print('Audio Compt Acc: %.2f %%' % (100 * sum(A_compt) / len(A_compt)))
    print('Audio Averg Acc: %.2f %%' % (100 * (sum(A_count) + sum(A_compt)) / (len(A_count) + len(A_compt))))

    print('\nVisual Count Acc: %.2f %%' % (100 * sum(V_count) / len(V_count)))
    print('Visual Local Acc: %.2f %%' % (100 * sum(V_local) / len(V_local)))
    print('Visual Averg Acc: %.2f %%' % (100 * (sum(V_count) + sum(V_local)) / (len(V_count) + len(V_local))))

    print('\nAudio-Visual Exist Acc: %.2f %%' % (100 * sum(AV_exist) / len(AV_exist)))
    print('Audio-Visual Count Acc: %.2f %%' % (100 * sum(AV_count) / len(AV_count)))
    print('Audio-Visual Local Acc: %.2f %%' % (100 * sum(AV_local) / len(AV_local)))
    print('Audio-Visual Compt Acc: %.2f %%' % (100 * sum(AV_compt) / len(AV_compt)))
    print('Audio-Visual Templ Acc: %.2f %%' % (100 * sum(AV_templ) / len(AV_templ)))
    print('Audio-Visual Averg Acc: %.2f %%' % (100 * (sum(AV_count) + sum(AV_local) + sum(AV_exist) + sum(AV_templ) + sum(AV_compt)) /
                                                     (len(AV_count) + len(AV_local) + len(AV_exist) + len(AV_templ) + len(AV_compt))))

    print('\n---->Overall Accuracy: %.2f %%' % (100 * correct / total), "\n")

    os.makedirs("results", exist_ok=True)
    with open("results/AVQA-GNN_Net.txt", 'w') as f:
        for index in range(len(que_id)):
            # print(que_id[index],' \t ',pred_results[index],' \t ',grd_target[index],' \t ',pred_label[index])
            f.write(str(que_id[index])+' \t '+str(pred_results[index])+' \t '+str(grd_target[index])+' \t '+str(pred_label[index])+'\n')

    return 100 * correct / total



def main_test():

    os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']

    torch.manual_seed(args['seed'])

    model = AVQA_GNN(args)
    model = nn.DataParallel(model)
    model = model.to('cuda')

    test_dataset = AVQA_dataset(args = args,
                                label = args['label_test'],
                                audios_feat_dir = args['audios_feat_dir'],
                                clip_vit_b32_dir = args['clip_vit_b32_dir'],
                                clip_qst_dir = args['clip_qst_dir'])

    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=args['num_workers'], collate_fn=AVQA_dataset_collate_fn)

    model.load_state_dict(torch.load(os.path.join(args['model_save_dir'], args['checkpoint'] + ".pt")))
    test(model, test_loader)


--------------- Spatio-temporal Reasoning Network(PSTP-Net) --------------- 


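The nine per-type lists in `test()` can be replaced by counters keyed on `(modality, question type)`. A pure-Python sketch, where `per_type_accuracy` is a hypothetical helper and the records are made up; it assumes, as the `ast.literal_eval(x['type'])` call implies, that each type is stored as the string form of a list. The `'Avg'` entry is pooled (total correct divided by total questions in that modality), matching the notebook's "Averg" formulas, and a type with no questions simply does not appear instead of raising `ZeroDivisionError`.

```python
import ast
from collections import defaultdict

def per_type_accuracy(records):
    """records: iterable of (type_string, is_correct), with type strings like "['Audio', 'Counting']"."""
    correct, total = defaultdict(int), defaultdict(int)
    for type_str, is_correct in records:
        modality, qtype = ast.literal_eval(type_str)[:2]
        for key in ((modality, qtype), (modality, 'Avg')):
            correct[key] += int(is_correct)
            total[key] += 1
    return {key: 100.0 * correct[key] / total[key] for key in total}

# Hypothetical test records
recs = [("['Audio', 'Counting']", True), ("['Audio', 'Counting']", False),
        ("['Audio', 'Comparative']", True), ("['Visual', 'Location']", False)]
acc = per_type_accuracy(recs)
```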

In [None]:
main_test()




Audio Count Acc: 54.17 %
Audio Compt Acc: 47.22 %
Audio Averg Acc: 50.00 %

Visual Count Acc: 65.52 %
Visual Local Acc: 30.00 %
Visual Averg Acc: 47.46 %

Audio-Visual Exist Acc: 59.46 %
Audio-Visual Count Acc: 50.00 %
Audio-Visual Local Acc: 35.00 %
Audio-Visual Compt Acc: 56.76 %
Audio-Visual Templ Acc: 20.69 %
Audio-Visual Averg Acc: 45.64 %

---->Overall Accuracy: 46.82 % 
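The "Averg" lines above are pooled accuracies: `test()` sums the correct answers over the question types of a modality and divides by the total number of those questions. Writing $c_k$ and $n_k$ for the correct and total counts of type $k$:

$$
\text{Acc}_{\text{avg}} = \frac{\sum_k c_k}{\sum_k n_k} = \sum_k \frac{n_k}{\sum_j n_j}\,\frac{c_k}{n_k},
$$

a question-count-weighted mean of the per-type accuracies rather than their plain average. The audio numbers show the difference: the plain mean is $(54.17 + 47.22)/2 \approx 50.7\%$, while the printed pooled value is $50.00\%$. Solving $54.17\,w + 47.22\,(1 - w) = 50.00$ gives $w \approx 0.40$, so counting questions make up about 40% of the audio questions in this test subset.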

