In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import io
import re
import glob
from pprint import pprint, pformat

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from IPython.display import display, Markdown, Latex
from tqdm import tqdm

from scripts.load_audio_v1 import  AudioFeatureDataset
from scripts.load_text import TextDataset
from scripts.load_video import VideoDataset
from scripts.load_multimodal_data import MultimodalDataset
from scripts.position_encoder import PositionalEncoding
from scripts.encoders import AudioEncoder, VideoEncoder, DotProductAttention
from scripts.decoder import MultimodalDecoder

In [3]:
# path_how2 = "/Volumes/LaCie/vision/data/" # Jeremy
path_how2 = "/Volumes/T7/data/" # Romain

video_path = os.path.join(path_how2, "resnext101-action-avgpool-300h", "train.npy")

texts_path = os.path.join(path_how2,"how2-300h-v1/data/train", "text.en")
embeddings_path = os.path.join(path_how2, "how2-release/word_embedding/","cmu_partition.train.vec")

In [4]:
video_dataset = VideoDataset(video_path)
audio_dataset = AudioFeatureDataset(path_how2,"train")
text_dataset = TextDataset(texts_path, embeddings_path)

print("Len Video: ", len(video_dataset))
print("Len Audio: ", len(audio_dataset))
print("Len Text: ", len(text_dataset))

multimodal_dataset = MultimodalDataset(video_dataset, audio_dataset, text_dataset)

print("\nLen Multimodal: ", len(multimodal_dataset))

  return  np.array(text), np.array(splitted), largest_split


Len Video:  184949
Len Audio:  184949
Len Text:  184949

Len Multimodal:  184949


In [5]:
class Net(nn.Module):
    def __init__(self, vocab_size, text_size=225):
        super().__init__()
        
        d_model = 480
        d_feedforward = 1920
        dropout = 0.2
        nhead = 6
        nlayer_audio = 6
        nlayer_video = 1

        video_dim = 2048
        audio_size = 10810
        audio_dim = 43
        tied = 48
        down_sampling_factor = 10

        self.audio_encoder = AudioEncoder(
            audio_dim,
            audio_size,
            tied,
            nhead,
            nlayer_audio,
            d_model,
            d_feedforward,
            dropout,
            down_sampling_factor
        )

        self.video_encoder = VideoEncoder(
            video_dim,
            nhead,
            nlayer_video,
            d_model,
            d_feedforward,
            dropout
        )

        self.fusion = DotProductAttention(d_model, 480, 480)
        
        text_dim = 100
        self.text_size = text_size
#         vocab_size = len(vocab_id_dict)
        n_layer = 4
        d_model = 480
        d_feedforward = 1920
        dropout = 0.2
        nhead = 6

        self.decoder = MultimodalDecoder(
            text_dim, self.text_size, vocab_size, nhead,  n_layer, d_model, d_feedforward, dropout
        )
        
    def forward(self, video, audio, text):
        audio_encoding = self.audio_encoder(audio)
        video_encoding = self.video_encoder(video)
        merge_encoding = self.fusion(audio_encoding, video_encoding, video_encoding)
        mask = self.decoder.generate_square_subsequent_mask(self.text_size)
        decoded = self.decoder(text, merge_encoding, tgt_mask=mask)
        return decoded

In [6]:
word2id= text_dataset.vocab_id_dict
id2word = text_dataset.id_vocab_dict

vocab_size = len(word2id)
net = Net(vocab_size)

In [7]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the CPU


In [10]:
def process_batch(batch):
    audio = batch["audio"].float()
    video = batch["video"]["video"].unsqueeze(1).float()
    text_emb = batch["text"]["embedding"].float()
    text_id = batch["text"]["id_embedding"].long()
    
    return video, audio, text_emb, text_id

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss() # includes softmax

def train(net):
    BATCH_SIZE = 10
    EPOCHS = 3
    
    dataloader = DataLoader(multimodal_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
    loss_hist = []
    
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch}")
        
        loss_hist_epoch = []
        
        for batch in tqdm(dataloader):
            
            video, audio, text_emb, text_id = process_batch(batch)
            
            audio = audio.to(device)
            video = video.to(device)
            text_emb = text_emb.to(device)
            text_id = text_id.to(device)
            
            net.zero_grad()
            
            output = net(video, audio, text_emb)
             # out is of shape [225, 4, 36756] => now of shape [4, 225, 36756]
            output = torch.transpose(output, 0, 1).reshape(-1, vocab_size)
            
            target = text_id.view(-1).long()
            loss = criterion(output,target)

            loss.backward()
            optimizer.step()
            
            loss_hist_epoch.append(loss.item())
            print(loss_hist_epoch[-1])
            
        loss_hist.append(loss_hist_epoch)
        break

In [11]:
train(net)

  0%|          | 0/18495 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/18495 [00:45<235:54:50, 45.92s/it]

2.2635603


  0%|          | 2/18495 [01:29<231:32:24, 45.07s/it]

1.456046


  0%|          | 2/18495 [01:36<247:15:01, 48.13s/it]


KeyboardInterrupt: 