In [1]:
# Reload edited local modules (e.g. utils, train_text, models.text_model) automatically
# before each cell executes, so changes to the project .py files are picked up live.
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import torch
import torchtext
import transformers
from torch.utils.data import DataLoader, random_split
from torchinfo import summary
from transformers import DistilBertTokenizer

import utils
from ibm_dataset import IBMDebater
from models.text_model import TextModel
from train_text import train_loop
# Only show error-level messages from the transformers library (suppresses info/warning logs).
transformers.logging.set_verbosity_error()

In [2]:
# Reproducibility and split configuration.
SEED = 42
TRAIN_FRACTION = 0.7

# Seed torch's global RNG so later stochastic steps in this notebook (model weight
# initialisation, DataLoader shuffling) are repeatable across Restart & Run All.
torch.manual_seed(SEED)

text_transform = torchtext.transforms.ToTensor()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

data_path = 'data/ibm_debater/full'

# Text-only view of the 'train' split (audio loading disabled).
data = IBMDebater(data_path, 'train', tokenizer=tokenizer, text_transform=text_transform, load_audio=False)

# Hold out (1 - TRAIN_FRACTION) of the examples for validation. A dedicated, seeded
# generator makes the partition identical on every run, independent of how much of
# the global RNG stream earlier cells consumed.
train_len = int(len(data) * TRAIN_FRACTION)
assert 0 < train_len < len(data), "train/val split produced an empty partition"
data_train, data_val = random_split(
    data,
    [train_len, len(data) - train_len],
    generator=torch.Generator().manual_seed(SEED),
)

In [3]:
# Batch size shared by the training and validation loaders.
batch_size = 16

# Training loader: reshuffled every epoch; the final incomplete batch is dropped.
loader_train = DataLoader(data_train,
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=utils.batch_generator_text,
                          drop_last=True)

# Validation loader: fixed order, and the final incomplete batch is KEPT so every
# held-out example is evaluated (drop_last=True silently discarded up to
# batch_size - 1 validation samples, biasing the reported metrics).
# NOTE(review): assumes train_loop / utils.batch_generator_text handle a smaller
# last batch — confirm.
loader_val = DataLoader(data_val,
                        batch_size=batch_size,
                        shuffle=False,
                        collate_fn=utils.batch_generator_text,
                        drop_last=False)

In [4]:
# Build the text model in classification mode and display a torchinfo summary
# (the rendered table reports total, trainable and non-trainable parameter counts).
model = TextModel(classify=True)
summary(model=model)

Layer (type:depth-idx)                                  Param #
TextModel                                               --
├─DistilBertModel: 1-1                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              (23,440,896)
│    │    └─Embedding: 3-2                              (393,216)
│    │    └─LayerNorm: 3-3                              (1,536)
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             42,527,232
├─Linear: 1-2                                           590,592
├─ReLU: 1-3                                             --
├─Linear: 1-4                                           769
Total params: 66,954,241
Trainable params: 14,767,105
Non-trainable params: 52,187,136

In [None]:
# Presumably the number of training epochs (4th positional argument of train_loop) — confirm.
NUM_EPOCHS = 4

# Use the GPU when present; fall back to CPU instead of crashing on a hard-coded 'cuda'.
# NOTE(review): assumes train_loop accepts a device string such as 'cpu' — confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_loop(model, loader_train, loader_val, NUM_EPOCHS, device)