In [1]:
%load_ext autoreload
%autoreload 2
import torchtext
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import transformers
from transformers import DistilBertTokenizer
from torchinfo import summary
from ibm_dataset import IBMDebater
import utils
from train_text import train_loop
from models.text_model import TextModel
from models.audio_model import AudioModel
from models.multimodal_model import MultimodalModel
from transformers import DistilBertTokenizer
import torch
import numpy as np
import torchaudio
from train_multimodal import train_loop
transformers.logging.set_verbosity_error()

In [2]:
# Load the IBM Debater training split and carve out a small, reproducible subset
# (SMALL_DATA_FRACTION of the data) for fast iteration, then split it train/val.
SEED = 42
DATA_PATH = 'data/ibm_debater/full'
TRAIN_FRACTION = 0.7
SMALL_DATA_FRACTION = 0.2

# Seed every RNG used here (and by later model initialisation) so re-runs match.
np.random.seed(SEED)
torch.manual_seed(SEED)
split_generator = torch.Generator().manual_seed(SEED)

data_path = DATA_PATH
text_transform = torchtext.transforms.ToTensor()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# NOTE(review): max_audio_len=5 presumably has to agree with chunk_size=5 used when
# building the model — confirm.
data = IBMDebater(data_path, 'train', tokenizer=tokenizer, max_audio_len=5, text_transform=text_transform)
train_len = int(len(data) * TRAIN_FRACTION)
data_train, data_val = random_split(data, [train_len, len(data) - train_len], generator=split_generator)

small_data_dim = SMALL_DATA_FRACTION
# Sample distinct indices over the whole dataset (index 0 included). Sampling with
# replacement would duplicate examples and let the same example end up in both the
# small train split and the small val split below (leakage).
rng = np.random.default_rng(SEED)
rnd_idx = rng.choice(len(data), size=int(len(data) * small_data_dim), replace=False)
assert len(set(rnd_idx.tolist())) == len(rnd_idx), "subset indices must be unique"
small_data = torch.utils.data.Subset(data, rnd_idx.tolist())
train_len = int(len(small_data) * TRAIN_FRACTION)
small_data_train, small_data_val = random_split(
    small_data, [train_len, len(small_data) - train_len], generator=split_generator
)

In [3]:
# Train/validation DataLoaders over the small subsets. Both share the multimodal
# collate function and drop any final incomplete batch; only training is shuffled.
# NOTE(review): drop_last=True on the validation loader silently discards up to
# batch_size-1 validation samples — switch to False if the model/train loop accept
# a short last batch.
batch_size = 4
loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=utils.batch_generator_multimodal,
    drop_last=True,
)
loader_train = DataLoader(small_data_train, shuffle=True, **loader_kwargs)
loader_val = DataLoader(small_data_val, shuffle=False, **loader_kwargs)

In [4]:
# Build the text + audio model, move its parameters to the GPU, and show a
# per-layer parameter summary (torchinfo prints non-trainable counts in parentheses).
model = MultimodalModel(
    device='cuda',
    audio_hidden_state_dim=32,
    chunk_size=5,
).cuda()  # nn.Module.cuda() moves parameters in place and returns the module itself
summary(model)

Layer (type:depth-idx)                                       Param #
MultimodalModel                                              --
├─TextModel: 1-1                                             --
│    └─DistilBertModel: 2-1                                  --
│    │    └─Embeddings: 3-1                                  (23,835,648)
│    │    └─Transformer: 3-2                                 85,054,464
│    └─Linear: 2-2                                           590,592
│    └─Linear: 2-3                                           769
│    └─ReLU: 2-4                                             --
├─AudioModel: 1-2                                            --
│    └─FeatureExtractor: 2-5                                 --
│    │    └─ModuleList: 3-3                                  4,200,448
│    └─Linear: 2-6                                           8,000
│    └─Linear: 2-7                                           513
│    └─ReLU: 2-8                                             --


In [5]:
# Train the multimodal model on the GPU. `train_loop` must be the multimodal trainer
# from train_multimodal; train_text also defines a function with the same name.
# NOTE(review): in the recorded run, from epoch 2 on, both train and val loss sit at
# 0.6931471824645996 ≈ ln 2 and val accuracy is 0.5. That is consistent with the
# model emitting a constant 0.5 probability. The summary lists a ReLU after the last
# Linear in each branch; if that ReLU is applied to the output logit, it could die
# and zero the gradient — worth checking before longer runs.
EPOCHS = 4
train_loop(
    model,
    epochs=EPOCHS,
    device='cuda',
    loader_train=loader_train,
    loader_val=loader_val,
)

100%|██████████| 74/74 [02:46<00:00,  2.26s/it]


train_loss: 0.7058048191908244 train_accuracy: 0.5135135135135135	val_loss: 0.6931471824645996 val_accuracy: 0.5


100%|██████████| 74/74 [02:47<00:00,  2.27s/it]


train_loss: 0.6931471824645996 train_accuracy: 0.48986486486486486	val_loss: 0.6931471824645996 val_accuracy: 0.5


100%|██████████| 74/74 [02:49<00:00,  2.29s/it]


train_loss: 0.6931471824645996 train_accuracy: 0.4864864864864865	val_loss: 0.6931471824645996 val_accuracy: 0.5


  1%|▏         | 1/74 [00:03<04:22,  3.60s/it]


KeyboardInterrupt: 