In [5]:
! rm -rf nano-BERT
! rm -rf space-model
! git clone https://github.com/StepanTita/nano-BERT.git
! git clone https://github.com/StepanTita/space-model.git

Cloning into 'nano-BERT'...
remote: Enumerating objects: 32, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 32 (delta 13), reused 22 (delta 6), pack-reused 0[K
Receiving objects: 100% (32/32), 38.28 MiB | 27.78 MiB/s, done.
Resolving deltas: 100% (13/13), done.
Cloning into 'space-model'...
remote: Enumerating objects: 19, done.[K
remote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 19 (delta 4), reused 18 (delta 3), pack-reused 0[K
Receiving objects: 100% (19/19), 2.08 MiB | 13.81 MiB/s, done.
Resolving deltas: 100% (4/4), done.


In [6]:
import sys

sys.path.append('/content/nano-BERT')
sys.path.append('/content/space-model')

In [23]:
import math
import json
from collections import Counter

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from tqdm import tqdm

import matplotlib.pyplot as plt
import plotly.graph_objects as go

from model import NanoBertForClassification
from tokenizer import WordTokenizer

from space_model.model import SpaceModelForClassification
from space_model.loss import *

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
data = None
with open('nano-BERT/data/imdb_train.json') as f:
    data = [json.loads(l) for l in f.readlines()]

In [13]:
vocab = set()
for d in data:
    vocab |= set([w.lower() for w in d['text']])

In [14]:
test_data = None
with open('nano-BERT/data/imdb_test.json') as f:
   test_data = [json.loads(l) for l in f.readlines()]

In [15]:
def encode_label(label):
    if label == 'pos':
        return 1
    elif label == 'neg':
        return 0
    raise Exception(f'Unknown Label: {label}!')


class IMDBDataloader:
    def __init__(self, data, test_data, tokenizer, label_encoder, batch_size, val_frac=0.2):
        train_data, val_data = train_test_split(data, shuffle=True, random_state=42, test_size=val_frac)

        self.splits = {
            'train': [d['text'] for d in train_data],
            'test': [d['text'] for d in test_data],
            'val': [d['text'] for d in val_data]
        }

        self.labels = {
            'train': [d['label'] for d in train_data],
            'test': [d['label'] for d in test_data],
            'val': [d['label'] for d in val_data]
        }

        self.tokenized = {
            'train': [tokenizer(record).unsqueeze(0) for record in
                      tqdm(self.splits['train'], desc='Train Tokenization')],
            'test': [tokenizer(record).unsqueeze(0) for record in tqdm(self.splits['test'], desc='Test Tokenization')],
            'val': [tokenizer(record).unsqueeze(0) for record in tqdm(self.splits['val'], desc='Val Tokenization')],
        }

        self.encoded_labels = {
            'train': [label_encoder(label) for label in tqdm(self.labels['train'], desc='Train Label Encoding')],
            'test': [label_encoder(label) for label in tqdm(self.labels['test'], desc='Test Label Encoding')],
            'val': [label_encoder(label) for label in tqdm(self.labels['val'], desc='Val Label Encoding')],
        }

        self.curr_batch = 0
        self.batch_size = batch_size
        self.iterate_split = None

    def peek(self, split):
        return {
            'input_ids': self.splits[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)],
            'label_ids': self.labels[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)],
        }

    def take(self, split):
        batch = self.splits[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)]
        labels = self.labels[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)]
        self.curr_batch += 1
        return {
            'input_ids': batch,
            'label_ids': labels,
        }

    def peek_tokenized(self, split):
        return {
            'input_ids': torch.cat(
                self.tokenized[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)],
                dim=0),
            'label_ids': torch.tensor(
                self.encoded_labels[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)],
                dtype=torch.long),
        }

    def peek_index_tokenized(self, index, split):
        return {
            'input_ids': torch.cat(
                [self.tokenized[split][index]],
                dim=0),
            'label_ids': torch.tensor(
                [self.encoded_labels[split][index]],
                dtype=torch.long),
        }

    def peek_index(self, index, split):
        return {
            'input_ids': [self.splits[split][index]],
            'label_ids': [self.labels[split][index]],
        }

    def take_tokenized(self, split):
        batch = self.tokenized[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)]
        labels = self.encoded_labels[split][self.batch_size * self.curr_batch:self.batch_size * (self.curr_batch + 1)]
        self.curr_batch += 1
        return {
            'input_ids': torch.cat(batch, dim=0),
            'label_ids': torch.tensor(labels, dtype=torch.long),
        }

    def get_split(self, split):
        self.iterate_split = split
        return self

    def steps(self, split):
        return len(self.tokenized[split]) // self.batch_size

    def __iter__(self):
        self.reset()
        return self

    def __next__(self):
        if self.batch_size * self.curr_batch < len(self.splits[self.iterate_split]):
            return self.take_tokenized(self.iterate_split)
        else:
            raise StopIteration

    def reset(self):
        self.curr_batch = 0

In [16]:
NUM_EPOCHS = 50
BATCH_SIZE = 32
MAX_SEQ_LEN = 128
LEARNING_RATE = 1e-2

In [17]:
tokenizer = WordTokenizer(vocab=vocab, max_seq_len=MAX_SEQ_LEN)
tokenizer

Tokenizer[vocab=101522,self.special_tokens=['[PAD]', '[CLS]', '[SEP]', '[UNK]'],self.sep=' ',self.max_seq_len=128]

In [18]:
dataloader = IMDBDataloader(data, test_data, tokenizer, encode_label, batch_size=BATCH_SIZE)

Train Tokenization: 100%|██████████| 20000/20000 [00:02<00:00, 8334.15it/s]
Test Tokenization: 100%|██████████| 25000/25000 [00:01<00:00, 16258.75it/s]
Val Tokenization: 100%|██████████| 5000/5000 [00:00<00:00, 15976.46it/s]
Train Label Encoding: 100%|██████████| 20000/20000 [00:00<00:00, 1264010.85it/s]
Test Label Encoding: 100%|██████████| 25000/25000 [00:00<00:00, 1543567.10it/s]
Val Label Encoding: 100%|██████████| 5000/5000 [00:00<00:00, 1350387.64it/s]


In [19]:
bert = NanoBertForClassification(
    vocab_size=len(tokenizer.vocab),
    n_layers=1,
    n_heads=1,
    max_seq_len=MAX_SEQ_LEN,
    n_classes=2
).to(device)
bert

NanoBertForClassification(
  (nano_bert): NanoBERT(
    (embedding): BertEmbeddings(
      (word_embeddings): Embedding(101522, 3)
      (pos_embeddings): Embedding(128, 3)
      (layer_norm): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layers): ModuleList(
        (0): BertLayer(
          (layer_norm1): LayerNorm((3,), eps=1e-05, elementwise_affine=True)
          (self_attention): BertSelfAttention(
            (heads): ModuleList(
              (0): BertAttentionHead(
                (query): Linear(in_features=3, out_features=3, bias=True)
                (key): Linear(in_features=3, out_features=3, bias=True)
                (values): Linear(in_features=3, out_features=3, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (proj): Linear(in_features=3, out_features=3, bias=True)
            (dropout): Dropout(p=0.1, inplace=Fals

In [20]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [21]:
count_parameters(bert)

305123

In [22]:
optimizer = torch.optim.Adam(bert.parameters(), lr=LEARNING_RATE)

for i in range(NUM_EPOCHS):
    print(f'Epoch: {i + 1}')
    train_loss = 0.0
    train_preds = []
    train_labels = []

    bert.train()
    for step, batch in enumerate(tqdm(dataloader.get_split('train'), total=dataloader.steps('train'))):
        logits = bert(batch['input_ids'].to(device)) # (B, Seq_Len, 2)

        probs = F.softmax(logits[:, 0, :], dim=-1).cpu()
        pred = torch.argmax(probs, dim=-1) # (B)
        train_preds += pred.detach().tolist()
        train_labels += [l.item() for l in batch['label_ids']]

        loss = F.cross_entropy(logits[:, 0, :].cpu(), batch['label_ids'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    val_loss = 0.0
    val_preds = []
    val_labels = []

    bert.eval()
    for step, batch in enumerate(tqdm(dataloader.get_split('val'), total=dataloader.steps('val'))):
        logits = bert(batch['input_ids'].to(device))

        probs = F.softmax(logits[:, 0, :], dim=-1).cpu()
        pred = torch.argmax(probs, dim=-1) # (B)
        val_preds += pred.detach().tolist()
        val_labels += [l.item() for l in batch['label_ids']]

        loss = F.cross_entropy(logits[:, 0, :].cpu(), batch['label_ids'])

        val_loss += loss.item()

    print()
    print(f'Train loss: {train_loss / dataloader.steps("train")} | Val loss: {val_loss / dataloader.steps("val")}')
    print(f'Train acc: {accuracy_score(train_labels, train_preds)} | Val acc: {accuracy_score(val_labels, val_preds)}')
    print(f'Train f1: {f1_score(train_labels, train_preds)} | Val f1: {f1_score(val_labels, val_preds)}')

Epoch: 1


100%|██████████| 625/625 [00:05<00:00, 110.48it/s]
157it [00:00, 174.67it/s]                         



Train loss: 0.69424509973526 | Val loss: 0.6982253128901507
Train acc: 0.5035 | Val acc: 0.5034
Train f1: 0.44251066696609026 | Val f1: 0.5033006601320265
Epoch: 2


100%|██████████| 625/625 [00:03<00:00, 201.68it/s]
157it [00:00, 242.28it/s]                         



Train loss: 0.6143448389053344 | Val loss: 0.5143855911416885
Train acc: 0.6317 | Val acc: 0.7508
Train f1: 0.6353465346534652 | Val f1: 0.7579642579642579
Epoch: 3


100%|██████████| 625/625 [00:03<00:00, 203.81it/s]
157it [00:01, 141.68it/s]                         



Train loss: 0.42988802280426025 | Val loss: 0.35617833622755146
Train acc: 0.80655 | Val acc: 0.8488
Train f1: 0.8079996029973698 | Val f1: 0.850887573964497
Epoch: 4


100%|██████████| 625/625 [00:03<00:00, 166.83it/s]
157it [00:01, 154.40it/s]                         



Train loss: 0.3420516100406647 | Val loss: 0.24630708877856916
Train acc: 0.85795 | Val acc: 0.9056
Train f1: 0.8583819350979511 | Val f1: 0.9066824831949387
Epoch: 5


100%|██████████| 625/625 [00:04<00:00, 137.67it/s]
157it [00:00, 189.07it/s]                         



Train loss: 0.2893947966814041 | Val loss: 0.17396821052982256
Train acc: 0.8851 | Val acc: 0.9384
Train f1: 0.8854208216992421 | Val f1: 0.9388403494837172
Epoch: 6


100%|██████████| 625/625 [00:02<00:00, 226.95it/s]
157it [00:00, 245.23it/s]                         



Train loss: 0.24356074215769768 | Val loss: 0.12202644179491565
Train acc: 0.9037 | Val acc: 0.962
Train f1: 0.9034296028880866 | Val f1: 0.9622266401590457
Epoch: 7


100%|██████████| 625/625 [00:03<00:00, 197.43it/s]
157it [00:00, 186.45it/s]                         



Train loss: 0.21535038361400366 | Val loss: 0.0923539778077378
Train acc: 0.9183 | Val acc: 0.9698
Train f1: 0.918185459643501 | Val f1: 0.9700218383958706
Epoch: 8


100%|██████████| 625/625 [00:02<00:00, 228.86it/s]
157it [00:00, 242.09it/s]                         



Train loss: 0.19683323801085353 | Val loss: 0.06851655008563867
Train acc: 0.92585 | Val acc: 0.9812
Train f1: 0.9256455251942842 | Val f1: 0.9813195548489666
Epoch: 9


100%|██████████| 625/625 [00:02<00:00, 233.00it/s]
157it [00:00, 240.45it/s]                         



Train loss: 0.18306916678845883 | Val loss: 0.06520352853933731
Train acc: 0.93115 | Val acc: 0.9806
Train f1: 0.9311190035515983 | Val f1: 0.9806811392152958
Epoch: 10


100%|██████████| 625/625 [00:02<00:00, 233.31it/s]
157it [00:00, 189.45it/s]                         



Train loss: 0.16266239131987095 | Val loss: 0.05700734513587891
Train acc: 0.93935 | Val acc: 0.9808
Train f1: 0.9389808340459782 | Val f1: 0.9809296781883194
Epoch: 11


100%|██████████| 625/625 [00:04<00:00, 148.20it/s]
157it [00:00, 212.40it/s]                         



Train loss: 0.15268866265565156 | Val loss: 0.04742882347319466
Train acc: 0.94265 | Val acc: 0.986
Train f1: 0.9425954656924078 | Val f1: 0.9860945570123162
Epoch: 12


100%|██████████| 625/625 [00:03<00:00, 170.94it/s]
157it [00:01, 134.55it/s]                         



Train loss: 0.14203243795484305 | Val loss: 0.04384835634361176
Train acc: 0.949 | Val acc: 0.9866
Train f1: 0.9488567990373046 | Val f1: 0.986687860123187
Epoch: 13


100%|██████████| 625/625 [00:06<00:00, 89.98it/s] 
157it [00:01, 81.97it/s]                         



Train loss: 0.13767368436306715 | Val loss: 0.04215710751939183
Train acc: 0.9487 | Val acc: 0.988
Train f1: 0.9485611150105284 | Val f1: 0.9880999603332011
Epoch: 14


100%|██████████| 625/625 [00:04<00:00, 140.13it/s]
157it [00:00, 166.54it/s]                         



Train loss: 0.12467939331457019 | Val loss: 0.03255721302188109
Train acc: 0.95365 | Val acc: 0.9886
Train f1: 0.9536662168241117 | Val f1: 0.9886747466719651
Epoch: 15


100%|██████████| 625/625 [00:04<00:00, 135.80it/s]
157it [00:00, 221.61it/s]                         



Train loss: 0.12210243988651782 | Val loss: 0.0301575575482648
Train acc: 0.9555 | Val acc: 0.9908
Train f1: 0.9555088982203559 | Val f1: 0.9908730158730158
Epoch: 16


100%|██████████| 625/625 [00:04<00:00, 132.21it/s]
157it [00:01, 142.94it/s]                         



Train loss: 0.11276096128225327 | Val loss: 0.027088777085177405
Train acc: 0.95825 | Val acc: 0.9934
Train f1: 0.9582186639979985 | Val f1: 0.9934458788480635
Epoch: 17


100%|██████████| 625/625 [00:04<00:00, 142.57it/s]
157it [00:00, 206.83it/s]                         



Train loss: 0.11333809667378664 | Val loss: 0.023306277559477635
Train acc: 0.9598 | Val acc: 0.9938
Train f1: 0.9596871239470517 | Val f1: 0.993840651698788
Epoch: 18


100%|██████████| 625/625 [00:04<00:00, 140.05it/s]
157it [00:00, 173.93it/s]                         



Train loss: 0.11091501301638781 | Val loss: 0.027645489815068477
Train acc: 0.96 | Val acc: 0.9912
Train f1: 0.9600199900049975 | Val f1: 0.9912594358363132
Epoch: 19


100%|██████████| 625/625 [00:04<00:00, 140.15it/s]
157it [00:01, 98.73it/s]                          



Train loss: 0.10558255863711238 | Val loss: 0.023741665846804
Train acc: 0.96285 | Val acc: 0.9924
Train f1: 0.9627848735286751 | Val f1: 0.9924513309495432
Epoch: 20


100%|██████████| 625/625 [00:04<00:00, 140.60it/s]
157it [00:01, 146.02it/s]                         



Train loss: 0.10275407341057435 | Val loss: 0.024094740706529522
Train acc: 0.9632 | Val acc: 0.993
Train f1: 0.9631705364291434 | Val f1: 0.9930514194957315
Epoch: 21


100%|██████████| 625/625 [00:03<00:00, 168.00it/s]
157it [00:00, 186.62it/s]                         



Train loss: 0.10227889029048383 | Val loss: 0.01842693757563934
Train acc: 0.964 | Val acc: 0.9944
Train f1: 0.9640215870477713 | Val f1: 0.9944378228049265
Epoch: 22


100%|██████████| 625/625 [00:04<00:00, 147.90it/s]
157it [00:00, 200.84it/s]                         



Train loss: 0.09559211812727153 | Val loss: 0.012344700883589282
Train acc: 0.9661 | Val acc: 0.9968
Train f1: 0.9661440127833815 | Val f1: 0.9968190854870775
Epoch: 23


100%|██████████| 625/625 [00:04<00:00, 130.25it/s]
157it [00:01, 137.02it/s]                         



Train loss: 0.10053175151385367 | Val loss: 0.019436386892732863
Train acc: 0.96505 | Val acc: 0.9954
Train f1: 0.9649922371913657 | Val f1: 0.9954301609378103
Epoch: 24


100%|██████████| 625/625 [00:06<00:00, 93.05it/s] 
157it [00:00, 175.18it/s]                         



Train loss: 0.09550555308014154 | Val loss: 0.015518063343197589
Train acc: 0.9658 | Val acc: 0.9952
Train f1: 0.9658102569229232 | Val f1: 0.9952286282306163
Epoch: 25


100%|██████████| 625/625 [00:05<00:00, 123.43it/s]
157it [00:00, 198.29it/s]                         



Train loss: 0.08639993268474937 | Val loss: 0.013882212770159897
Train acc: 0.969 | Val acc: 0.9966
Train f1: 0.9689875950380153 | Val f1: 0.9966209501093222
Epoch: 26


100%|██████████| 625/625 [00:04<00:00, 142.71it/s]
157it [00:00, 219.79it/s]                         



Train loss: 0.08614280918100849 | Val loss: 0.01865940991042291
Train acc: 0.9696 | Val acc: 0.9946
Train f1: 0.9695543314972459 | Val f1: 0.9946354063182993
Epoch: 27


100%|██████████| 625/625 [00:05<00:00, 122.62it/s]
157it [00:01, 142.12it/s]                         



Train loss: 0.08764239957537502 | Val loss: 0.015412338793211324
Train acc: 0.96915 | Val acc: 0.9968
Train f1: 0.9692499377024669 | Val f1: 0.9968228752978554
Epoch: 28


100%|██████████| 625/625 [00:04<00:00, 135.97it/s]
157it [00:01, 112.05it/s]                         



Train loss: 0.08465413036113605 | Val loss: 0.01635553674644213
Train acc: 0.9696 | Val acc: 0.9966
Train f1: 0.9695969596959696 | Val f1: 0.9966222928670774
Epoch: 29


100%|██████████| 625/625 [00:02<00:00, 218.76it/s]
157it [00:00, 226.46it/s]                         



Train loss: 0.0843944850133732 | Val loss: 0.016169613807039403
Train acc: 0.9692 | Val acc: 0.996
Train f1: 0.9692184689186488 | Val f1: 0.9960238568588469
Epoch: 30


100%|██████████| 625/625 [00:03<00:00, 157.72it/s]
157it [00:00, 236.37it/s]                         



Train loss: 0.08383518650257028 | Val loss: 0.017412482612342454
Train acc: 0.97115 | Val acc: 0.9956
Train f1: 0.9711946482951425 | Val f1: 0.9956245027844074
Epoch: 31


100%|██████████| 625/625 [00:03<00:00, 191.32it/s]
157it [00:00, 210.59it/s]                         



Train loss: 0.07672271969066932 | Val loss: 0.016117104479293723
Train acc: 0.9714 | Val acc: 0.9948
Train f1: 0.9714627818798643 | Val f1: 0.994831013916501
Epoch: 32


100%|██████████| 625/625 [00:02<00:00, 226.95it/s]
157it [00:00, 190.89it/s]                         



Train loss: 0.07949097560010851 | Val loss: 0.014100697010853298
Train acc: 0.97115 | Val acc: 0.9976
Train f1: 0.9712147667747569 | Val f1: 0.9976143141153081
Epoch: 33


100%|██████████| 625/625 [00:02<00:00, 221.39it/s]
157it [00:00, 238.46it/s]                         



Train loss: 0.07336161033250392 | Val loss: 0.009892026024877291
Train acc: 0.9725 | Val acc: 0.998
Train f1: 0.9725 | Val f1: 0.9980142970611596
Epoch: 34


100%|██████████| 625/625 [00:02<00:00, 210.28it/s]
157it [00:00, 172.00it/s]                         



Train loss: 0.08022094918088987 | Val loss: 0.010980437859521212
Train acc: 0.97115 | Val acc: 0.998
Train f1: 0.9711312353029469 | Val f1: 0.9980119284294234
Epoch: 35


100%|██████████| 625/625 [00:02<00:00, 222.26it/s]
157it [00:00, 237.72it/s]                         



Train loss: 0.07320630355572794 | Val loss: 0.009288841460100882
Train acc: 0.9743 | Val acc: 0.9984
Train f1: 0.974310275889644 | Val f1: 0.9984095427435388
Epoch: 36


100%|██████████| 625/625 [00:02<00:00, 231.99it/s]
157it [00:00, 238.79it/s]                         



Train loss: 0.07408792081680149 | Val loss: 0.009541675976460958
Train acc: 0.97345 | Val acc: 0.9984
Train f1: 0.9734380471212044 | Val f1: 0.9984089101034209
Epoch: 37


100%|██████████| 625/625 [00:03<00:00, 195.35it/s]
157it [00:00, 189.47it/s]



Train loss: 0.07609309555776417 | Val loss: 0.011766991543295146
Train acc: 0.97455 | Val acc: 0.9974
Train f1: 0.9745944596955328 | Val f1: 0.9974139645912075
Epoch: 38


100%|██████████| 625/625 [00:03<00:00, 198.01it/s]
157it [00:00, 240.24it/s]                         



Train loss: 0.06929890594524331 | Val loss: 0.009933770963680918
Train acc: 0.9768 | Val acc: 0.9978
Train f1: 0.9767767767767768 | Val f1: 0.9978135559530908
Epoch: 39


100%|██████████| 625/625 [00:02<00:00, 232.29it/s]
157it [00:00, 234.60it/s]                         



Train loss: 0.06973353418051265 | Val loss: 0.010148341035323355
Train acc: 0.9747 | Val acc: 0.9986
Train f1: 0.9746848108865319 | Val f1: 0.9986086265156032
Epoch: 40


100%|██████████| 625/625 [00:02<00:00, 227.78it/s]
157it [00:00, 243.09it/s]                         



Train loss: 0.0671683953662403 | Val loss: 0.007835414672090571
Train acc: 0.9767 | Val acc: 0.9986
Train f1: 0.9766953390678136 | Val f1: 0.9986091794158554
Epoch: 41


100%|██████████| 625/625 [00:03<00:00, 186.43it/s]
157it [00:00, 236.47it/s]                         



Train loss: 0.07134580684474204 | Val loss: 0.012312432253822454
Train acc: 0.9739 | Val acc: 0.9976
Train f1: 0.9739052189562087 | Val f1: 0.9976143141153081
Epoch: 42


100%|██████████| 625/625 [00:02<00:00, 230.28it/s]
157it [00:00, 242.18it/s]                         



Train loss: 0.06902823394914158 | Val loss: 0.010113990743134193
Train acc: 0.97565 | Val acc: 0.9982
Train f1: 0.9756731105449823 | Val f1: 0.9982125124131083
Epoch: 43


100%|██████████| 625/625 [00:02<00:00, 230.42it/s]
157it [00:00, 237.03it/s]                         



Train loss: 0.0707148328538984 | Val loss: 0.010445693702562354
Train acc: 0.97605 | Val acc: 0.9982
Train f1: 0.9761132997556474 | Val f1: 0.9982125124131083
Epoch: 44


100%|██████████| 625/625 [00:03<00:00, 206.10it/s]
157it [00:00, 176.34it/s]                         



Train loss: 0.06817574074543081 | Val loss: 0.009008239233830513
Train acc: 0.9753 | Val acc: 0.9982
Train f1: 0.9753 | Val f1: 0.9982118021060997
Epoch: 45


100%|██████████| 625/625 [00:02<00:00, 226.32it/s]
157it [00:00, 236.87it/s]                         



Train loss: 0.06417671635304578 | Val loss: 0.01092439209549831
Train acc: 0.9773 | Val acc: 0.9972
Train f1: 0.977293187956387 | Val f1: 0.9972189114024632
Epoch: 46


100%|██████████| 625/625 [00:03<00:00, 174.22it/s]
157it [00:01, 118.30it/s]                         



Train loss: 0.06539606132116169 | Val loss: 0.009963327975548544
Train acc: 0.97585 | Val acc: 0.9976
Train f1: 0.9758632751986408 | Val f1: 0.9976152623211446
Epoch: 47


100%|██████████| 625/625 [00:06<00:00, 95.52it/s] 
157it [00:01, 126.46it/s]                         



Train loss: 0.06452466400489211 | Val loss: 0.006125287502725722
Train acc: 0.97665 | Val acc: 0.9992
Train f1: 0.9766791510611735 | Val f1: 0.9992047713717693
Epoch: 48


100%|██████████| 625/625 [00:04<00:00, 125.88it/s]
157it [00:00, 182.63it/s]



Train loss: 0.06323236619406379 | Val loss: 0.009533259160656225
Train acc: 0.9772 | Val acc: 0.9982
Train f1: 0.9772227772227773 | Val f1: 0.9982118021060997
Epoch: 49


100%|██████████| 625/625 [00:03<00:00, 185.82it/s]
157it [00:00, 235.86it/s]                         



Train loss: 0.06699056839924306 | Val loss: 0.009622795951005313
Train acc: 0.9772 | Val acc: 0.9986
Train f1: 0.9772227772227773 | Val f1: 0.9986080731755816
Epoch: 50


100%|██████████| 625/625 [00:02<00:00, 227.72it/s]
157it [00:00, 240.14it/s]                         


Train loss: 0.06226340431179851 | Val loss: 0.007051446901804854
Train acc: 0.97735 | Val acc: 0.9994
Train f1: 0.9774008480917935 | Val f1: 0.999403459932392





In [24]:
test_loss = 0.0
test_preds = []
test_labels = []

bert.eval()
for step, batch in enumerate(tqdm(dataloader.get_split('test'), total=dataloader.steps('test'))):
    logits = bert(batch['input_ids'].to(device))

    probs = F.softmax(logits[:, 0, :], dim=-1).cpu()
    pred = torch.argmax(probs, dim=-1) # (B)
    test_preds += pred.detach().tolist()
    test_labels += [l.item() for l in batch['label_ids']]

    loss = F.cross_entropy(logits[:, 0, :].cpu(), batch['label_ids'])

    test_loss += loss.item()

print()
print(f'Test loss: {test_loss / dataloader.steps("test")}')
print(f'Test acc: {accuracy_score(test_labels, test_preds)}')
print(f'Test f1: {f1_score(test_labels, test_preds)}')

782it [00:03, 225.12it/s]                         



Test loss: 0.0819332813621452
Test acc: 0.98692
Test f1: 0.987061290705496


# Add Space Model

In [96]:
class SpaceBertForClassification(torch.nn.Module):
    def __init__(self, base_model, n_embed=3, n_latent=3, n_concept_spaces=2, fine_tune=True):
        super().__init__()

        if fine_tune:
            for p in base_model.parameters():
                p.requires_grad_(False)
        self.bert = base_model
        self.space_model = SpaceModelForClassification(n_embed, n_latent, n_concept_spaces)

    def forward(self, x):
        embed = self.bert(x)

        out = self.space_model(embed)

        return out

In [97]:
space_bert = SpaceBertForClassification(bert.nano_bert).to(device)
space_bert

SpaceBertForClassification(
  (bert): NanoBERT(
    (embedding): BertEmbeddings(
      (word_embeddings): Embedding(101522, 3)
      (pos_embeddings): Embedding(128, 3)
      (layer_norm): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layers): ModuleList(
        (0): BertLayer(
          (layer_norm1): LayerNorm((3,), eps=1e-05, elementwise_affine=True)
          (self_attention): BertSelfAttention(
            (heads): ModuleList(
              (0): BertAttentionHead(
                (query): Linear(in_features=3, out_features=3, bias=True)
                (key): Linear(in_features=3, out_features=3, bias=True)
                (values): Linear(in_features=3, out_features=3, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (proj): Linear(in_features=3, out_features=3, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
 

In [98]:
count_parameters(space_bert)

32

In [100]:
NUM_OPTIM_EPOCHS = 5
LEARNING_RATE = 1e-2

l1 = 1e-6
l2 = 1e-6

In [101]:
optimizer = torch.optim.Adam(space_bert.parameters(), lr=LEARNING_RATE)

for i in range(NUM_OPTIM_EPOCHS):
    print(f'Epoch: {i + 1}')
    train_loss = 0.0
    train_preds = []
    train_labels = []

    bert.train()
    for step, batch in enumerate(tqdm(dataloader.get_split('train'), total=dataloader.steps('train'))):
        out = space_bert(batch['input_ids'].to(device)) # (B, 2)

        logits = out.logits.cpu()
        concept_spaces = [c.cpu() for c in out.concept_spaces]

        probs = F.softmax(logits, dim=-1).cpu()
        pred = torch.argmax(probs, dim=-1) # (B)
        train_preds += pred.detach().tolist()
        train_labels += [l.item() for l in batch['label_ids']]

        # CE + l1 * inter_loss + l2 * intra_loss
        loss = F.cross_entropy(logits, batch['label_ids']) + l1 * inter_space_loss(concept_spaces) + l2 * intra_space_loss(concept_spaces)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    val_loss = 0.0
    val_preds = []
    val_labels = []

    bert.eval()
    for step, batch in enumerate(tqdm(dataloader.get_split('val'), total=dataloader.steps('val'))):
        out = space_bert(batch['input_ids'].to(device)) # (B, 2)

        logits = out.logits.cpu()
        concept_spaces = [c.cpu() for c in out.concept_spaces]

        probs = F.softmax(logits, dim=-1).cpu()
        pred = torch.argmax(probs, dim=-1) # (B)
        val_preds += pred.detach().tolist()
        val_labels += [l.item() for l in batch['label_ids']]

        # CE + l1 * inter_loss + l2 * intra_loss
        loss = F.cross_entropy(logits, batch['label_ids']) + l1 * inter_space_loss(concept_spaces) + l2 * intra_space_loss(concept_spaces)

        val_loss += loss.item()

    print()
    print(f'Train loss: {train_loss / dataloader.steps("train")} | Val loss: {val_loss / dataloader.steps("val")}')
    print(f'Train acc: {accuracy_score(train_labels, train_preds)} | Val acc: {accuracy_score(val_labels, val_preds)}')
    print(f'Train f1: {f1_score(train_labels, train_preds)} | Val f1: {f1_score(val_labels, val_preds)}')

Epoch: 1


100%|██████████| 625/625 [00:05<00:00, 116.00it/s]
157it [00:00, 231.95it/s]                         



Train loss: 0.6956174991607667 | Val loss: 5567.979883438502
Train acc: 0.5101 | Val acc: 0.5054
Train f1: 0.5113216957605985 | Val f1: 0.6703985072637612
Epoch: 2


100%|██████████| 625/625 [00:03<00:00, 198.64it/s]
157it [00:00, 334.76it/s]                         



Train loss: 0.6931235491752624 | Val loss: 5530.198789938902
Train acc: 0.5141 | Val acc: 0.5046
Train f1: 0.5145854145854145 | Val f1: 0.6700412947915279
Epoch: 3


100%|██████████| 625/625 [00:02<00:00, 214.48it/s]
157it [00:00, 416.13it/s]                         



Train loss: 0.6923209614753724 | Val loss: 4875.984616010617
Train acc: 0.51565 | Val acc: 0.5046
Train f1: 0.5162546816479401 | Val f1: 0.6700412947915279
Epoch: 4


100%|██████████| 625/625 [00:02<00:00, 244.16it/s]
157it [00:00, 440.08it/s]                         



Train loss: 0.6919279090881347 | Val loss: 5133.811683067908
Train acc: 0.5174 | Val acc: 0.504
Train f1: 0.5151697809925658 | Val f1: 0.669773635153129
Epoch: 5


100%|██████████| 625/625 [00:02<00:00, 244.11it/s]
157it [00:00, 441.64it/s]                         


Train loss: 0.6911894980430603 | Val loss: 5574.917411608573
Train acc: 0.5207 | Val acc: 0.504
Train f1: 0.5222288676236044 | Val f1: 0.669773635153129





# Interpreting and visualizing the results

In [103]:
test_dataloader = IMDBDataloader(data, test_data, tokenizer, encode_label, batch_size=1)

Train Tokenization: 100%|██████████| 20000/20000 [00:01<00:00, 15903.95it/s]
Test Tokenization: 100%|██████████| 25000/25000 [00:02<00:00, 12417.93it/s]
Val Tokenization: 100%|██████████| 5000/5000 [00:00<00:00, 15466.56it/s]
Train Label Encoding: 100%|██████████| 20000/20000 [00:00<00:00, 1375452.22it/s]
Test Label Encoding: 100%|██████████| 25000/25000 [00:00<00:00, 1526022.73it/s]
Val Label Encoding: 100%|██████████| 5000/5000 [00:00<00:00, 749947.07it/s]


In [104]:
# examples with less than 16 words are easier to visualize, so focus on them
examples_ids = []
for i, v in enumerate(test_dataloader.splits['test']):
    if len(v) <= 16:
        examples_ids.append(i)
print(examples_ids)

[1959, 2939, 6394, 15789, 16349, 21487, 22019, 24588]


## Bert Embeddings

In [105]:
scatters = []
for sample_index in examples_ids:
    # extract example, decode to tokens and get the sequence length (ingoring padding)
    test_tokenized_batch = test_dataloader.peek_index_tokenized(index=sample_index, split='test')
    tokens = tokenizer.decode([t.item() for t in test_tokenized_batch['input_ids'][0] if t != 0], ignore_special=False).split(' ')[:MAX_SEQ_LEN]
    seq_len = len(tokens)

    embed = bert.nano_bert.embedding(test_tokenized_batch['input_ids'].to(device))

    x, y, z = embed[0, :seq_len, 0].detach().cpu().numpy(), embed[0, :seq_len, 1].detach().cpu().numpy(), embed[0, :seq_len, 2].detach().cpu().numpy()

    scatters.append(go.Scatter3d(
        x=x, y=y, z=z, mode='markers+text', name=f'Example: {sample_index}',
        text=tokens,
    ))

In [106]:
fig = go.Figure(
    data=scatters,
    layout=go.Layout(
        title=go.layout.Title(text='Embeddings')
    ))
fig.show()

In [107]:
scatters = []
for sample_index in examples_ids:
    # extract example, decode to tokens and get the sequence length (ingoring padding)
    test_tokenized_batch = test_dataloader.peek_index_tokenized(index=sample_index, split='test')
    tokens = tokenizer.decode([t.item() for t in test_tokenized_batch['input_ids'][0] if t != 0], ignore_special=False).split(' ')[:MAX_SEQ_LEN]
    seq_len = len(tokens)

    embed = bert.nano_bert(test_tokenized_batch['input_ids'].to(device))

    x, y, z = embed[0, :seq_len, 0].detach().cpu().numpy(), embed[0, :seq_len, 1].detach().cpu().numpy(), embed[0, :seq_len, 2].detach().cpu().numpy()

    scatters.append(go.Scatter3d(
        x=x, y=y, z=z, mode='markers+text', name=f'Example: {sample_index}',
        text=tokens,
    ))

In [108]:
fig = go.Figure(
    data=scatters,
    layout=go.Layout(
        title=go.layout.Title(text='Raw Embeddings')
    ))
fig.show()

## Space Embeddings

In [122]:
scatters = []
colors = ['blue', 'red']
for sample_index in examples_ids:
    # extract example, decode to tokens and get the sequence length (ingoring padding)
    test_tokenized_batch = test_dataloader.peek_index_tokenized(index=sample_index, split='test')
    tokens = tokenizer.decode([t.item() for t in test_tokenized_batch['input_ids'][0] if t != 0], ignore_special=False).split(' ')[:MAX_SEQ_LEN]
    seq_len = len(tokens)

    bert_embed = space_bert.bert(test_tokenized_batch['input_ids'].to(device))

    concept_spaces = space_bert.space_model.space_model(bert_embed).concept_spaces

    for c, embed in enumerate(concept_spaces):
        x, y, z = embed[0, :seq_len, 0].detach().cpu().numpy(), embed[0, :seq_len, 1].detach().cpu().numpy(), embed[0, :seq_len, 2].detach().cpu().numpy()

        scatters.append(go.Scatter3d(
            x=x, y=y, z=z, mode='markers+text',
            name=f'Example: {sample_index} ({c})',
            text=tokens,
            marker=dict(color=colors[c]),
            # hovertext=[]
        ))

In [123]:
fig = go.Figure(
    data=scatters,
    layout=go.Layout(
        title=go.layout.Title(text='Space Embeddings')
    ))
fig.show()