In [1]:
# Notebook setup: enable live reloading of local modules, import dependencies,
# and configure library-level state (CUDA cache, transformers logging).
%load_ext autoreload
%autoreload 2
import torchtext
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import transformers
from transformers import DistilBertTokenizer
from torchinfo import summary
from ibm_dataset import IBMDebater
import utils
# NOTE(review): this `train_loop` is rebound further down by
# `from train_multimodal import train_loop`, so the text-only loop is not
# reachable under this name once the cell has finished executing.
from train_text import train_loop
from models.text_model import TextModel
from models.audio_model import AudioModel
from models.multimodal_model import MultimodalModel
# NOTE(review): duplicate of the DistilBertTokenizer import above (harmless).
from transformers import DistilBertTokenizer
import torch
# Ask PyTorch to release cached, currently unused GPU memory.
torch.cuda.empty_cache()
import numpy as np
import torchaudio
# This import wins over the earlier `train_text` one: from here on
# `train_loop` refers to the multimodal training loop.
from train_multimodal import train_loop
# Only show error-level messages from the transformers library.
transformers.logging.set_verbosity_error()

In [2]:
# Build the IBM Debater dataset and derive train/validation splits, plus a
# smaller random subset (with its own splits) for quick experiments.
data_path = 'data/ibm_debater/full'
text_transform = torchtext.transforms.ToTensor()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Fixed seed so that the splits and the subsample are identical after
# Restart Kernel -> Run All.
SEED = 42
split_generator = torch.Generator().manual_seed(SEED)
rng = np.random.default_rng(SEED)

data = IBMDebater(data_path, 'train', tokenizer=tokenizer, max_audio_len=5, text_transform=text_transform)
train_len = int(len(data)*0.7)
data_train, data_val = random_split(data, [train_len, len(data) - train_len], generator=split_generator)

# Fraction of the full dataset kept in the small subset.
small_data_dim = 0.2
# Sample WITHOUT replacement over every valid index (0 .. len(data)-1).
# Sampling with replacement (the np.random.choice default) can put the same
# example in the subset more than once, and therefore in both the small train
# and small validation splits; starting the range at 1 also excluded index 0.
rnd_idx = rng.choice(len(data), size=int(len(data)*small_data_dim), replace=False)
# Plain Python ints keep indexing into the dataset independent of numpy scalar types.
small_data = torch.utils.data.Subset(data, rnd_idx.tolist())
train_len = int(len(small_data)*0.7)
small_data_train, small_data_val = random_split(small_data, [train_len, len(small_data) - train_len], generator=split_generator)

In [3]:
# Wrap the train/validation splits in DataLoaders that share every setting
# except shuffling (training batches are shuffled, validation ones are not).
batch_size = 8

# Options common to both loaders.
# NOTE(review): drop_last=True is applied to validation as well, so a trailing
# partial batch of `data_val` is never evaluated — confirm whether the model or
# train_loop needs fixed-size batches before changing this.
shared_loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=utils.batch_generator_multimodal,
    drop_last=True,
)

loader_train = DataLoader(data_train, shuffle=True, **shared_loader_kwargs)
loader_val = DataLoader(data_val, shuffle=False, **shared_loader_kwargs)

In [4]:
# Instantiate the multimodal model on the GPU and display its
# per-layer parameter summary as the cell output.
model_config = {
    'chunk_size': 5,
    'audio_hidden_state_dim': 32,
    'device': 'cuda',
}
model = MultimodalModel(**model_config)
summary(model)

Layer (type:depth-idx)                                       Param #
MultimodalModel                                              --
├─TextModel: 1-1                                             --
│    └─DistilBertModel: 2-1                                  --
│    │    └─Embeddings: 3-1                                  (23,835,648)
│    │    └─Transformer: 3-2                                 42,527,232
│    └─Linear: 2-2                                           590,592
│    └─ReLU: 2-3                                             --
├─AudioModel: 1-2                                            --
│    └─FeatureExtractor: 2-4                                 --
│    │    └─ModuleList: 3-3                                  (4,200,448)
│    └─Linear: 2-5                                           8,000
│    └─LSTM: 2-6                                             1,576,960
│    └─Linear: 2-7                                           262,656
│    └─ReLU: 2-8                                    

In [5]:
# Train for 10 epochs on the GPU. `train_loop` here is the one imported from
# `train_multimodal` (the last import bound to that name in the setup cell).
train_loop(
    model,
    loader_train=loader_train,
    loader_val=loader_val,
    epochs=10,
    device='cuda',
)

100%|██████████| 186/186 [11:47<00:00,  3.81s/it]


train_loss: 0.6357030578518427 train_accuracy: 0.594758064516129	val_loss: 0.23303105030208826 val_accuracy: 0.9375


100%|██████████| 186/186 [11:53<00:00,  3.84s/it]


train_loss: 0.1518140535761592 train_accuracy: 0.948252688172043	val_loss: 0.1301706244936213 val_accuracy: 0.9609375


100%|██████████| 186/186 [11:22<00:00,  3.67s/it]


train_loss: 0.10413675051262622 train_accuracy: 0.9663978494623656	val_loss: 0.11475377284223214 val_accuracy: 0.965625


100%|██████████| 186/186 [11:00<00:00,  3.55s/it]


train_loss: 0.07978954175937801 train_accuracy: 0.9778225806451613	val_loss: 0.11580827509169467 val_accuracy: 0.9625


100%|██████████| 186/186 [11:17<00:00,  3.64s/it]


train_loss: 0.07805390325024404 train_accuracy: 0.9717741935483871	val_loss: 0.1139128506416455 val_accuracy: 0.96875


100%|██████████| 186/186 [11:01<00:00,  3.56s/it]


train_loss: 0.05043791734328073 train_accuracy: 0.9845430107526881	val_loss: 0.12072023550281301 val_accuracy: 0.96875


100%|██████████| 186/186 [10:57<00:00,  3.54s/it]


train_loss: 0.03689688009958494 train_accuracy: 0.9872311827956989	val_loss: 0.12722821427159942 val_accuracy: 0.9625


100%|██████████| 186/186 [11:13<00:00,  3.62s/it]


train_loss: 0.0378370397950783 train_accuracy: 0.9899193548387096	val_loss: 0.12102368611376732 val_accuracy: 0.9609375


100%|██████████| 186/186 [10:56<00:00,  3.53s/it]


train_loss: 0.02817545052409993 train_accuracy: 0.991263440860215	val_loss: 0.10647446470829891 val_accuracy: 0.9703125


100%|██████████| 186/186 [10:54<00:00,  3.52s/it]


train_loss: 0.015971869431493643 train_accuracy: 0.9952956989247311	val_loss: 0.13557864781287207 val_accuracy: 0.9734375
