In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import io
import re
import glob
from pprint import pprint, pformat

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from IPython.display import display, Markdown, Latex
from tqdm import tqdm

from load_audio import  AudioFeatureDataset
from load_text import TextDataset
from load_video import VideoDataset
from load_multimodal_data import MultimodalDataset
from position_encoder import PositionalEncoding
from encoders import AudioEncoder, VideoEncoder, DotProductAttention
from decoder import MultimodalDecoder

In [3]:
# path_how2 = "/Volumes/LaCie/vision/data/" # Jeremy
path_how2 = "/Volumes/T7/data/" # Romain

video_path = os.path.join(path_how2, "resnext101-action-avgpool-300h", "train.npy")

texts_path = os.path.join(path_how2,"how2-300h-v1/data/train", "text.en")
embeddings_path = os.path.join(path_how2, "how2-release/word_embedding/","cmu_partition.train.vec")

In [6]:
video_dataset = VideoDataset(video_path)
audio_dataset = AudioFeatureDataset(path_how2,"train")
text_dataset = TextDataset(texts_path, embeddings_path)

print("Len Video: ", len(video_dataset))
print("Len Audio: ", len(audio_dataset))
print("Len Text: ", len(text_dataset))

multimodal_dataset = MultimodalDataset(video_dataset, audio_dataset, text_dataset)

print("\nLen Multimodal: ", len(multimodal_dataset))

Len Video:  184949
Len Audio:  184949
Len Text:  184949

Len Multimodal:  184949


In [33]:
class Net(nn.Module):
    def __init__(self, vocab_size, text_size=225):
        super().__init__()
        
        d_model = 480
        d_feedforward = 1920
        dropout = 0.2
        nhead = 6
        nlayer_audio = 6
        nlayer_video = 1

        video_dim = 2048
        audio_size = 10810
        audio_dim = 43
        tied = 48
        down_sampling_factor = 10

        self.audio_encoder = AudioEncoder(
            audio_dim,
            audio_size,
            tied,
            nhead,
            nlayer_audio,
            d_model,
            d_feedforward,
            dropout,
            down_sampling_factor
        )

        self.video_encoder = VideoEncoder(
            video_dim,
            nhead,
            nlayer_video,
            d_model,
            d_feedforward,
            dropout
        )

        self.fusion = DotProductAttention(d_model, 480, 480)
        
        text_dim = 100
        self.text_size = text_size
#         vocab_size = len(vocab_id_dict)
        n_layer = 4
        d_model = 480
        d_feedforward = 1920
        dropout = 0.2
        nhead = 6

        self.decoder = MultimodalDecoder(
            text_dim, self.text_size, vocab_size, nhead,  n_layer, d_model, d_feedforward, dropout
        )
        
    def forward(self, video, audio, text):
        audio_encoding = self.audio_encoder(audio)
        video_encoding = self.video_encoder(video)
        merge_encoding = self.fusion(audio_encoding, video_encoding, video_encoding)
        mask = self.decoder.generate_square_subsequent_mask(self.text_size)
        decoded = self.decoder(text, merge_encoding, tgt_mask=mask)
        return decoded

In [30]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # you can continue going on here, like cuda:1 cuda:2....etc. 
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the CPU


In [69]:
def process_batch(batch):
    audio = batch["audio"].float()
    video = batch["video"]["video"].unsqueeze(1).float()
    text_emb = batch["text"]["embedding"].float()
    text_id = batch["text"]["id_embedding"].long()
    
    return video, audio, text_emb, text_id

word2id= text_dataset.vocab_id_dict
id2word = text_dataset.id_vocab_dict

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss() # includes softmax

def train(net):
    BATCH_SIZE = 10
    EPOCHS = 3
    
    dataloader = DataLoader(multimodal_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
    loss_hist = []
    
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch}")
        
        loss_hist_epoch = []
        
        for batch in tqdm(dataloader):
            
            video, audio, text_emb, text_id = process_batch(batch)
            
            audio = audio.to(device)
            video = video.to(device)
            text_emb = text_emb.to(device)
            text_id = text_id.to(device)
            
            net.zero_grad()
            
            output = net(video, audio, text_emb)
             # out is of shape [225, 4, 36756] => now of shape [4, 225, 36756]
            output = torch.transpose(output, 0, 1).reshape(-1, vocab_size)
            
            target = text_id.view(-1).long()
            loss = criterion(output,target)

            loss.backward()
            optimizer.step()
            
            loss_hist_epoch.append(loss)
            print(loss)
            
        loss_hist.append(loss_hist_epoch)
        break

In [70]:
vocab_size = len(word2id)
net = Net(vocab_size)

train(net)

  0%|          | 0/18495 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/18495 [00:25<131:16:40, 25.55s/it]

tensor(9.8445, grad_fn=<NllLossBackward>)


  0%|          | 2/18495 [00:51<131:18:08, 25.56s/it]

tensor(9.8333, grad_fn=<NllLossBackward>)


  0%|          | 3/18495 [01:16<131:19:20, 25.57s/it]

tensor(9.7865, grad_fn=<NllLossBackward>)


  0%|          | 4/18495 [01:42<131:01:41, 25.51s/it]

tensor(9.7953, grad_fn=<NllLossBackward>)


  0%|          | 5/18495 [02:07<131:12:00, 25.54s/it]

tensor(9.7647, grad_fn=<NllLossBackward>)


  0%|          | 6/18495 [02:33<131:54:16, 25.68s/it]

tensor(9.7525, grad_fn=<NllLossBackward>)


  0%|          | 7/18495 [02:59<132:06:49, 25.73s/it]

tensor(9.7162, grad_fn=<NllLossBackward>)


  0%|          | 8/18495 [03:26<133:34:30, 26.01s/it]

tensor(9.7050, grad_fn=<NllLossBackward>)


  0%|          | 9/18495 [03:52<134:15:50, 26.15s/it]

tensor(9.7530, grad_fn=<NllLossBackward>)


  0%|          | 10/18495 [04:18<134:12:23, 26.14s/it]

tensor(9.7994, grad_fn=<NllLossBackward>)


  0%|          | 11/18495 [04:44<133:46:03, 26.05s/it]

tensor(9.7612, grad_fn=<NllLossBackward>)


  0%|          | 12/18495 [05:10<133:29:45, 26.00s/it]

tensor(9.7150, grad_fn=<NllLossBackward>)


  0%|          | 13/18495 [05:36<133:45:35, 26.05s/it]

tensor(9.7163, grad_fn=<NllLossBackward>)


  0%|          | 14/18495 [06:03<134:44:35, 26.25s/it]

tensor(9.7668, grad_fn=<NllLossBackward>)


  0%|          | 15/18495 [06:29<134:21:14, 26.17s/it]

tensor(9.7646, grad_fn=<NllLossBackward>)


  0%|          | 16/18495 [06:55<134:03:35, 26.12s/it]

tensor(9.7504, grad_fn=<NllLossBackward>)


  0%|          | 17/18495 [07:21<133:39:41, 26.04s/it]

tensor(9.7912, grad_fn=<NllLossBackward>)


  0%|          | 18/18495 [07:47<133:29:38, 26.01s/it]

tensor(9.7486, grad_fn=<NllLossBackward>)


  0%|          | 19/18495 [08:13<133:16:51, 25.97s/it]

tensor(9.7669, grad_fn=<NllLossBackward>)


  0%|          | 20/18495 [08:39<134:24:07, 26.19s/it]

tensor(9.7733, grad_fn=<NllLossBackward>)


  0%|          | 21/18495 [09:06<134:59:20, 26.31s/it]

tensor(9.7426, grad_fn=<NllLossBackward>)


  0%|          | 22/18495 [09:32<134:33:58, 26.22s/it]

tensor(9.7497, grad_fn=<NllLossBackward>)


  0%|          | 23/18495 [09:58<133:47:59, 26.08s/it]

tensor(9.7345, grad_fn=<NllLossBackward>)


  0%|          | 24/18495 [10:24<134:41:24, 26.25s/it]

tensor(9.8353, grad_fn=<NllLossBackward>)


  0%|          | 25/18495 [10:50<134:21:09, 26.19s/it]

tensor(9.7842, grad_fn=<NllLossBackward>)


  0%|          | 26/18495 [11:17<134:50:27, 26.28s/it]

tensor(9.7974, grad_fn=<NllLossBackward>)


  0%|          | 27/18495 [11:43<134:15:30, 26.17s/it]

tensor(9.8029, grad_fn=<NllLossBackward>)


  0%|          | 28/18495 [12:09<134:52:15, 26.29s/it]

tensor(9.7414, grad_fn=<NllLossBackward>)


  0%|          | 29/18495 [12:36<134:49:54, 26.29s/it]

tensor(9.7929, grad_fn=<NllLossBackward>)


  0%|          | 30/18495 [13:01<134:13:39, 26.17s/it]

tensor(9.7768, grad_fn=<NllLossBackward>)


  0%|          | 31/18495 [13:27<133:49:44, 26.09s/it]

tensor(9.8137, grad_fn=<NllLossBackward>)


  0%|          | 32/18495 [13:53<133:23:53, 26.01s/it]

tensor(9.7461, grad_fn=<NllLossBackward>)


  0%|          | 33/18495 [14:19<133:25:35, 26.02s/it]

tensor(9.7357, grad_fn=<NllLossBackward>)


  0%|          | 34/18495 [14:45<133:32:08, 26.04s/it]

tensor(9.7470, grad_fn=<NllLossBackward>)


  0%|          | 35/18495 [15:11<133:01:30, 25.94s/it]

tensor(9.7690, grad_fn=<NllLossBackward>)


  0%|          | 36/18495 [15:37<132:42:05, 25.88s/it]

tensor(9.6890, grad_fn=<NllLossBackward>)


  0%|          | 37/18495 [16:03<132:47:33, 25.90s/it]

tensor(9.7741, grad_fn=<NllLossBackward>)


  0%|          | 37/18495 [16:28<137:02:42, 26.73s/it]


KeyboardInterrupt: 

tensor(10.9977, grad_fn=<NllLossBackward>)