In [1]:
%load_ext autoreload
%autoreload 2
import torchtext
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import transformers
from transformers import DistilBertTokenizer
from torchinfo import summary
from ibm_dataset import IBMDebater
import utils
from train_text import train_loop
from models.text_model import TextModel
transformers.logging.set_verbosity_error()

In [2]:
text_transform = torchtext.transforms.ToTensor()
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

data_path = 'data/ibm_debater/full'

data = IBMDebater(data_path, 'train', tokenizer=tokenizer, text_transform=text_transform, load_audio=False)
train_len = int(len(data)*0.7)
data_train, data_val = random_split(data, [train_len, len(data) - train_len])

In [3]:
batch_size = 16
loader_train = DataLoader(data_train,
                    batch_size=batch_size,
                    shuffle=True,
                    collate_fn=utils.batch_generator_text,
                    drop_last=True)
loader_val = DataLoader(data_val,
                    batch_size=batch_size,
                    shuffle=False,
                    collate_fn=utils.batch_generator_text,
                    drop_last=True)

In [4]:
model = TextModel(classify=True)
summary(model)

Layer (type:depth-idx)                                  Param #
TextModel                                               --
├─DistilBertModel: 1-1                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              (23,440,896)
│    │    └─Embedding: 3-2                              (393,216)
│    │    └─LayerNorm: 3-3                              (1,536)
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             42,527,232
├─Dropout: 1-2                                          --
├─Linear: 1-3                                           590,592
├─Dropout: 1-4                                          --
├─ReLU: 1-5                                             --
├─Linear: 1-6                                           769
Total params: 66,954,241
Trainable params: 14,767,105
Non-trainable params

In [5]:
train_loop(model, loader_train, loader_val, 20, 'cuda')

100%|██████████| 93/93 [00:37<00:00,  2.48it/s]


train_loss: 0.6899694460694508 train_accuracy: 0.532258064516129	val_loss: 0.6751947402954102 val_accuracy: 0.5546875


100%|██████████| 93/93 [00:37<00:00,  2.50it/s]


train_loss: 0.5008030835018363 train_accuracy: 0.7446236559139785	val_loss: 0.2543546285480261 val_accuracy: 0.909375


100%|██████████| 93/93 [00:37<00:00,  2.51it/s]


train_loss: 0.15934282338987754 train_accuracy: 0.9495967741935484	val_loss: 0.1251414939528331 val_accuracy: 0.9578125


100%|██████████| 93/93 [00:37<00:00,  2.48it/s]


train_loss: 0.10878299258809576 train_accuracy: 0.9657258064516129	val_loss: 0.13296029062476009 val_accuracy: 0.9515625


100%|██████████| 93/93 [00:37<00:00,  2.45it/s]


train_loss: 0.07643299726569043 train_accuracy: 0.9758064516129032	val_loss: 0.12625482071889566 val_accuracy: 0.959375


100%|██████████| 93/93 [00:37<00:00,  2.47it/s]


train_loss: 0.06248075105450166 train_accuracy: 0.9798387096774194	val_loss: 0.12567857663379983 val_accuracy: 0.959375


100%|██████████| 93/93 [00:37<00:00,  2.48it/s]


train_loss: 0.054576165002760704 train_accuracy: 0.9811827956989247	val_loss: 0.1273996320553124 val_accuracy: 0.9546875


100%|██████████| 93/93 [00:37<00:00,  2.45it/s]


train_loss: 0.04944681028725319 train_accuracy: 0.9852150537634409	val_loss: 0.12579675366869197 val_accuracy: 0.9625


100%|██████████| 93/93 [00:37<00:00,  2.45it/s]


train_loss: 0.04830237257204229 train_accuracy: 0.9838709677419355	val_loss: 0.121961780579295 val_accuracy: 0.9609375


100%|██████████| 93/93 [00:37<00:00,  2.48it/s]


train_loss: 0.04241827991040003 train_accuracy: 0.9865591397849462	val_loss: 0.13045318468939512 val_accuracy: 0.9578125


100%|██████████| 93/93 [00:37<00:00,  2.50it/s]


train_loss: 0.042329616419049684 train_accuracy: 0.9879032258064516	val_loss: 0.12866212776862085 val_accuracy: 0.9578125


100%|██████████| 93/93 [00:37<00:00,  2.49it/s]


train_loss: 0.04456302620250211 train_accuracy: 0.9879032258064516	val_loss: 0.1275550359627232 val_accuracy: 0.959375


In [6]:
from train_text import validate
from torch import nn
validate(model, nn.BCEWithLogitsLoss(), loader_val, 'cuda')

val_loss: 0.12579675366869197 val_accuracy: 0.9625
{'val_loss': 0.12579675366869197, 'val_accuracy': 0.9625}
