In [1]:
%load_ext autoreload
%autoreload 2
import torchtext
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import transformers
from transformers import DistilBertTokenizer
from torchinfo import summary
from ibm_dataset import IBMDebater
import utils
from train_text import train_loop
from models.text_model import TextModel
transformers.logging.set_verbosity_error()

In [3]:
text_transform = torchtext.transforms.ToTensor()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

data_path = 'data/ibm_debater/full'

data = IBMDebater(data_path, 'train', tokenizer=tokenizer, text_transform=text_transform, load_audio=False)
train_len = int(len(data)*0.7)
data_train, data_val = random_split(data, [train_len, len(data) - train_len])

In [8]:
batch_size = 16
loader_train = DataLoader(data_train,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=utils.batch_generator_text,
                    drop_last=True)
loader_val = DataLoader(data_val,
                    batch_size=batch_size,
                    shuffle=False,
                    collate_fn=utils.batch_generator_text,
                    drop_last=True)

In [9]:
model = TextModel()
summary(model)

Layer (type:depth-idx)                                  Param #
TextModel                                               --
├─DistilBertModel: 1-1                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              (23,440,896)
│    │    └─Embedding: 3-2                              (393,216)
│    │    └─LayerNorm: 3-3                              (1,536)
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             85,054,464
├─Linear: 1-2                                           590,592
├─Linear: 1-3                                           769
├─ReLU: 1-4                                             --
Total params: 109,481,473
Trainable params: 14,767,105
Non-trainable params: 94,714,368

In [24]:
train_loop(model, loader_train, loader_val, 8, 'cuda')

train_loss: 0.6811088868366775 train_accuracy: 0.5376344086021505	val_loss: 0.6221732035279274 val_accuracy: 0.671875
train_loss: 0.3237089231530184 train_accuracy: 0.8729838709677419	val_loss: 0.1835200794506818 val_accuracy: 0.9359375
train_loss: 0.12031921399857408 train_accuracy: 0.959005376344086	val_loss: 0.14195448660757393 val_accuracy: 0.9609375
train_loss: 0.08611857353319083 train_accuracy: 0.9717741935483871	val_loss: 0.14197878866689279 val_accuracy: 0.9546875
train_loss: 0.06781581145340718 train_accuracy: 0.9778225806451613	val_loss: 0.11057145184604451 val_accuracy: 0.96875
train_loss: 0.05014113288232556 train_accuracy: 0.9865591397849462	val_loss: 0.12479066994856111 val_accuracy: 0.9671875
train_loss: 0.03370617529607668 train_accuracy: 0.989247311827957	val_loss: 0.16990083383279853 val_accuracy: 0.95
train_loss: 0.02571467857467391 train_accuracy: 0.9932795698924731	val_loss: 0.13143969392112922 val_accuracy: 0.965625
