# VIS: Model \#2
## Siamese Fusion Discrimination Network for Video-Audio Matching

In [1]:
from google.colab import drive

# Directory inside the Colab runtime under which Google Drive is mounted.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
# Switch the working directory to the project folder on the mounted Drive; the
# local imports (`dataloader`, `constants`) and the relative paths used below
# (`data_filepath`, `datasets/...`) are presumably resolved from here.
%cd /content/drive/MyDrive/MIT6-8300_Computer_Vision/Visually-Indicated-Sounds/

/content/drive/MyDrive/MIT6-8300_Computer_Vision/Visually-Indicated-Sounds


In [3]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from torch.utils.data import DataLoader

from dataloader import VideoAudioDataset, get_random_segment

from constants import AUDIO_SAMPLE_RATE

# !! Put the data directory location in a file named `data_filepath`
# (first line only). If `data_filepath` does not exist, the dataset is
# assumed to live relative to the current working directory.
def resolve_data_prefix(relative_dir, pointer_file='data_filepath'):
    """Return the dataset directory, optionally rooted at the path stored in `pointer_file`.

    Args:
        relative_dir: Dataset directory relative to the data root.
        pointer_file: Text file whose first line names the data root.

    Returns:
        `relative_dir` unchanged when `pointer_file` is missing; otherwise the
        first line of `pointer_file` (surrounding whitespace removed) joined
        with `relative_dir`.
    """
    if not os.path.isfile(pointer_file):
        return relative_dir

    with open(pointer_file, 'r') as f:
        # readline() keeps the trailing newline; strip it so it cannot end up
        # embedded in the middle of the resulting path.
        base_dir = f.readline().strip()

    # os.path.join copes with a stored root both with and without a trailing
    # separator, and returns `relative_dir` untouched when the file is empty.
    return os.path.join(base_dir, relative_dir)


filepath = resolve_data_prefix('vis-data-256/vis-data-256/')

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Active device: ", device)

Active device:  cuda


## Model structure

In [4]:
class FusionVIS(nn.Module):
    """Siamese fusion discriminator that scores whether a video and an audio clip match.

    Video frames and a 3-channel copy of the audio mel spectrogram are both
    passed through the same ResNet-18 backbone (its classification layer is
    replaced by an identity). The per-frame video embeddings are max-pooled
    over time, concatenated with the audio embedding, and fed to a small MLP
    ending in a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # audio preprocessing: waveform -> mel spectrogram -> decibel scale
        self.audio_preprocess = nn.Sequential(
            MelSpectrogram(sample_rate=AUDIO_SAMPLE_RATE, n_fft=2048, hop_length=512, n_mels=128),
            AmplitudeToDB()
        )

        # resnet backbone shared by both modalities; fc is removed so the
        # backbone returns its pooled feature vector instead of class logits
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone.fc = nn.Identity()

        # fully connected head; the 1024 inputs are the concatenation of the
        # video embedding and the audio embedding produced by the backbone
        self.fc1 = nn.Linear(in_features=1024, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=1)

    def forward(self, video, audio):
        """Score each (video, audio) pair in the batch.

        Args:
            video: 5-D tensor laid out as (batch, seq_len, channels, height, width).
            audio: Input accepted by the mel-spectrogram transform. Assumed to
                be a batch of waveforms, (batch, 1, samples) or
                (batch, samples) -- TODO confirm against the dataloader.

        Returns:
            1-D tensor of shape (batch,) with one sigmoid score per pair.
        """
        # video preprocessing: fold time into the batch axis so the 2-D
        # backbone sees one frame per row (reshape also handles
        # non-contiguous inputs, unlike view)
        batch_size, seq_len, c, h, w = video.size()
        frames = video.reshape(batch_size * seq_len, c, h, w)

        # audio preprocessing
        spectrogram = self.audio_preprocess(audio)
        if spectrogram.dim() == 3:
            # no channel axis present: add one so the repeat below tiles
            # channels rather than prepending a new leading axis
            spectrogram = spectrogram.unsqueeze(1)
        # the backbone expects 3 input channels; tile the spectrogram
        spec_3 = spectrogram.repeat(1, 3, 1, 1)

        # backbone
        frame_feat = self.backbone(frames)
        audio_feat = self.backbone(spec_3)

        # video postprocessing: unfold time again and max-pool over it, one
        # embedding per clip. Pooling over dim 0 of the flattened tensor would
        # merge every clip of the batch into a single vector.
        video_feat = frame_feat.reshape(batch_size, seq_len, -1).max(dim=1)[0]

        # concatenation
        fusion = torch.cat([video_feat, audio_feat], dim=1)

        fusion = F.relu(self.fc1(fusion))
        fusion = F.relu(self.fc2(fusion))
        fusion = torch.sigmoid(self.fc3(fusion))

        # drop only the trailing feature axis so a batch of one stays 1-D
        return fusion.squeeze(-1)

In [5]:
# Build the network first, then move its parameters onto the active device.
fusion_model = FusionVIS()
fusion_model = fusion_model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 368MB/s]


## Training

In [6]:
# Array stored on disk for the training split; handed to VideoAudioDataset below.
train_idxs = np.load(file='datasets/train_dataset.npy')

In [7]:
# Training dataset rooted at `filepath`, with `get_random_segment` as its transform.
train_dataset = VideoAudioDataset(
    train_idxs,
    device,
    filepath_prefix=filepath,
    transform=get_random_segment,
)

In [8]:
# Training schedule: number of passes over the data and samples per step.
n_epochs, batch_size = 10, 32

# Mini-batch iterator over the training set, reshuffled at every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [9]:
# Binary cross-entropy on the model's sigmoid output, optimised with Adam.
criterion = nn.BCELoss()
optimizer = optim.Adam(fusion_model.parameters(), lr=0.001)

# Number of batches between two progress reports.
LOG_EVERY = 8

# NOTE(review): the saved output of this cell is an OutOfMemoryError. Every
# frame of every clip in a batch goes through the backbone with gradients
# enabled; a smaller `batch_size` or a frozen backbone would presumably be
# needed to fit in memory -- confirm on the target GPU.

# Make sure layers with train/eval-specific behaviour are in training mode
# even if this cell is re-run after an evaluation cell.
fusion_model.train()

for epoch in range(n_epochs):

    c_loss = 0.0   # loss summed since the last report
    n_logged = 0   # batches summed since the last report

    for batch_idx, (video_feat, audio_feat, label) in enumerate(train_loader):
        optimizer.zero_grad()

        output = fusion_model(video_feat, audio_feat)
        loss = criterion(output, label.float())

        loss.backward()
        optimizer.step()

        c_loss += loss.item()
        n_logged += 1

        if batch_idx % LOG_EVERY == 0:
            # Average over the batches actually accumulated (1 for the first
            # report, LOG_EVERY afterwards) rather than a fixed divisor.
            print(f"Epoch {epoch+1}, batch {batch_idx+1}: loss={c_loss/n_logged:.3f}")
            c_loss = 0.0
            n_logged = 0

OutOfMemoryError: ignored