## Data pre-processing

In [None]:
import os

import tokenizers
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [2]:
# Load the "dair-ai/emotion" dataset; the returned object holds the
# train / validation / test splits used throughout this notebook.
dataset = load_dataset("dair-ai/emotion")

In [3]:
# Show the available splits with their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})

In [4]:
# Peek at the first five training rows: raw text plus the integer class label.
dataset['train'][:5]

{'text': ['i didnt feel humiliated',
  'i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake',
  'im grabbing a minute to post i feel greedy wrong',
  'i am ever feeling nostalgic about the fireplace i will know that it is still on the property',
  'i am feeling grouchy'],
 'label': [0, 0, 3, 2, 3]}

In [5]:
import tiktoken

In [6]:
# Tokenizer (tiktoken's "gpt2" encoding) used by encode_text to turn text into integer ids.
token_encoder = tiktoken.get_encoding("gpt2")

In [7]:
# Vocabulary size of the encoder; the model config reuses this value as "vocab_size".
token_encoder.n_vocab

50257

In [8]:
def encode_text(x, num_classes=5, max_seq_len=36, pad_token_id=0):
    """Tokenize one dataset row into fixed-length ids and a one-hot label.

    Args:
        x: Mapping with a 'text' entry (converted with ``str``) and an integer
            'label' entry.
        num_classes: Length of the one-hot label vector.
        max_seq_len: Exact length of the returned id list; longer encodings
            are truncated, shorter ones are right-padded.
        pad_token_id: Id appended as padding.

    Returns:
        dict with 'text' (the string form of the input text), 'encoded_text'
        (list of ``max_seq_len`` ints) and 'label' (list of ``num_classes``
        ints containing a single 1).

    The defaults reproduce the values this function previously hard-coded, so
    ``dataset.map(encode_text)`` behaves exactly as before.
    """
    text = str(x['text'])

    # Slicing is a no-op for short inputs, so no length check is needed.
    token_ids = token_encoder.encode(text)[:max_seq_len]
    # Build a new list rather than extending the encoder's return value in place.
    # NOTE(review): id 0 appears to be an ordinary vocabulary entry in the GPT-2
    # encoding rather than a dedicated padding token - confirm the model masks
    # or otherwise tolerates padded positions.
    token_ids = token_ids + [pad_token_id] * (max_seq_len - len(token_ids))

    # The hot index is (label - 1), so label 0 wraps around to the LAST slot.
    # NOTE(review): the dair-ai/emotion dataset card appears to list six classes
    # (0-5) - confirm. If so, with num_classes=5 labels 0 and 5 both land on
    # index 4 and become indistinguishable. Passing num_classes=6 (and a
    # matching model output size) keeps all classes distinct.
    label = [0] * num_classes
    label[x['label'] - 1] = 1

    return {
        'text': text,
        'encoded_text': token_ids,
        'label': label,
    }


In [9]:
# Run encode_text over every row of each split; the unpacking order matches
# the order of the split names in the tuple.
tokenized_dataset_train, tokenized_dataset_test, tokenized_dataset_validation = (
    dataset[split_name].map(encode_text)
    for split_name in ('train', 'test', 'validation')
)

In [10]:
# Sanity check: an encoded row should hold exactly max_seq_len (36) token ids.
len(tokenized_dataset_train[8]['encoded_text'])

36

In [11]:
# Only the training loader is shuffled; evaluation loaders keep dataset order
# so per-example results are reproducible and can be lined up with the source rows.
train_dataloader = DataLoader(tokenized_dataset_train, batch_size=128, shuffle=True)
# NOTE(review): no batch_size is given here, so the test loader yields one
# example per batch (the DataLoader default). Left unchanged in case later
# cells rely on it - confirm whether 128 was intended.
test_dataloader = DataLoader(tokenized_dataset_test, shuffle=False)
val_dataloader = DataLoader(tokenized_dataset_validation, batch_size=128, shuffle=False)

## Model Building

In [15]:
import sys
# Add the parent directory to the import path so the `model.transformers`
# import in the next cell can be resolved.
# NOTE(review): assumes the notebook is run from a subdirectory of the project root - confirm.
sys.path.append('..')

In [16]:
from model.transformers import EncoderClassifier
import torch

In [17]:
# Model-side sizes. These must stay in sync with the values encode_text uses
# when building the one-hot labels and padded token lists.
num_classes = 5
max_seq_len = 36

In [18]:
# Hyperparameters handed to EncoderClassifier. The meaning of each key is
# defined by that class (not shown in this notebook). Values that already
# exist as notebook variables are referenced rather than repeated, so they
# cannot drift apart.
config = {
    "num_layers": 4,
    "vocab_size": token_encoder.n_vocab,
    "embed_dims": 768,
    "max_seq_len": max_seq_len,
    "n_segments": 5,
    "heads": 8,
    "dropout": 0.3,
    "device": "cpu",
    "ff_layer_sizes": [768, 256, 768],
    "batch_size": 128,
    "num_classes": num_classes,
}

In [19]:
# Build the classifier from the `config` hyperparameter dict.
model = EncoderClassifier(config)

In [20]:
from tqdm.autonotebook import tqdm
from torch.optim import Adam

# Adam over all model parameters with a fixed learning rate.
optim = Adam(model.parameters(), lr=1e-4)
# The training loop passes float one-hot vectors as targets to this loss.
# NOTE(review): the logged loss levels off around 0.94 while accuracy reaches
# ~0.96. CrossEntropyLoss applies log-softmax itself; if the model's forward
# pass already returns softmax probabilities, the loss cannot drop below
# roughly 0.905 for 5 classes - confirm the model returns raw logits.
criterion = torch.nn.CrossEntropyLoss()

In [21]:
epochs = 100
step = 0  # running count of optimizer steps across all epochs
device = config['device']

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = 'models'
os.makedirs(checkpoint_dir, exist_ok=True)

# Derive the example count from the loader instead of hard-coding the split size.
num_train_examples = len(train_dataloader.dataset)

for epoch in range(epochs):
    loop = tqdm(train_dataloader, leave=True)
    # Use the same 1-based epoch number as the summary line printed below.
    loop.set_description(f'Epoch {epoch + 1}')

    model.train()

    total_loss = 0.0
    correct_predictions = 0
    for batch in loop:
        optim.zero_grad()

        # Stacking the per-position / per-class tensors of the collated batch
        # puts the batch dimension second.
        # Inputs are passed to the model in that layout, exactly as before.
        # NOTE(review): presumably the model expects (seq_len, batch) - confirm.
        inputs = torch.stack(batch['encoded_text']).int().to(device)
        # Labels are transposed to (batch, num_classes) to line up with the model output.
        targets = torch.stack(batch['label']).float().to(device).transpose(0, 1)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optim.step()

        batch_loss = loss.item()
        loop.set_postfix(loss=batch_loss)
        step += 1

        total_loss += batch_loss
        # Accuracy bookkeeping needs no gradients; compare arg-max class indices.
        with torch.no_grad():
            predicted = outputs.argmax(dim=1)
            expected = targets.argmax(dim=1)
            correct_predictions += (predicted == expected).sum().item()

    average_loss = total_loss / len(train_dataloader)
    accuracy = correct_predictions / num_train_examples

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}")
    # One checkpoint per epoch so an interrupted run can be resumed or inspected.
    torch.save(model.state_dict(), f'{checkpoint_dir}/model_epoch_{epoch + 1}.pth')

  0%|          | 0/125 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 125/125 [08:34<00:00,  4.12s/it, loss=1.52]


Epoch 1/100, Loss: 1.5163, Accuracy: 0.3371


Epoch 1: 100%|██████████| 125/125 [08:21<00:00,  4.01s/it, loss=1.51]


Epoch 2/100, Loss: 1.4928, Accuracy: 0.3862


Epoch 2: 100%|██████████| 125/125 [08:26<00:00,  4.05s/it, loss=1.43]


Epoch 3/100, Loss: 1.4269, Accuracy: 0.4736


Epoch 3: 100%|██████████| 125/125 [08:27<00:00,  4.06s/it, loss=1.4] 


Epoch 4/100, Loss: 1.3915, Accuracy: 0.5124


Epoch 4: 100%|██████████| 125/125 [08:45<00:00,  4.20s/it, loss=1.34]


Epoch 5/100, Loss: 1.3646, Accuracy: 0.5415


Epoch 5: 100%|██████████| 125/125 [08:20<00:00,  4.00s/it, loss=1.26]


Epoch 6/100, Loss: 1.3257, Accuracy: 0.5844


Epoch 6: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=1.28]


Epoch 7/100, Loss: 1.2872, Accuracy: 0.6244


Epoch 7: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=1.23]


Epoch 8/100, Loss: 1.2636, Accuracy: 0.6416


Epoch 8: 100%|██████████| 125/125 [08:33<00:00,  4.11s/it, loss=1.26]


Epoch 9/100, Loss: 1.2538, Accuracy: 0.6476


Epoch 9: 100%|██████████| 125/125 [08:19<00:00,  4.00s/it, loss=1.25]


Epoch 10/100, Loss: 1.2480, Accuracy: 0.6521


Epoch 10: 100%|██████████| 125/125 [08:22<00:00,  4.02s/it, loss=1.19]


Epoch 11/100, Loss: 1.2456, Accuracy: 0.6541


Epoch 11: 100%|██████████| 125/125 [08:03<00:00,  3.87s/it, loss=1.26]


Epoch 12/100, Loss: 1.2390, Accuracy: 0.6639


Epoch 12: 100%|██████████| 125/125 [08:08<00:00,  3.91s/it, loss=1.22]


Epoch 13/100, Loss: 1.2168, Accuracy: 0.7049


Epoch 13: 100%|██████████| 125/125 [08:02<00:00,  3.86s/it, loss=1.15]


Epoch 14/100, Loss: 1.1971, Accuracy: 0.7272


Epoch 14: 100%|██████████| 125/125 [08:18<00:00,  3.99s/it, loss=1.19]


Epoch 15/100, Loss: 1.1817, Accuracy: 0.7361


Epoch 15: 100%|██████████| 125/125 [08:09<00:00,  3.91s/it, loss=1.2] 


Epoch 16/100, Loss: 1.1715, Accuracy: 0.7419


Epoch 16: 100%|██████████| 125/125 [08:13<00:00,  3.95s/it, loss=1.16]


Epoch 17/100, Loss: 1.1676, Accuracy: 0.7423


Epoch 17: 100%|██████████| 125/125 [08:07<00:00,  3.90s/it, loss=1.13]


Epoch 18/100, Loss: 1.1588, Accuracy: 0.7478


Epoch 18: 100%|██████████| 125/125 [08:16<00:00,  3.97s/it, loss=1.09]


Epoch 19/100, Loss: 1.1578, Accuracy: 0.7466


Epoch 19: 100%|██████████| 125/125 [08:10<00:00,  3.93s/it, loss=1.23]


Epoch 20/100, Loss: 1.1513, Accuracy: 0.7521


Epoch 20: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=1.08]


Epoch 21/100, Loss: 1.1454, Accuracy: 0.7570


Epoch 21: 100%|██████████| 125/125 [08:08<00:00,  3.91s/it, loss=1.12]


Epoch 22/100, Loss: 1.1382, Accuracy: 0.7681


Epoch 22: 100%|██████████| 125/125 [08:07<00:00,  3.90s/it, loss=1.13]


Epoch 23/100, Loss: 1.1200, Accuracy: 0.7985


Epoch 23: 100%|██████████| 125/125 [08:06<00:00,  3.89s/it, loss=1.1] 


Epoch 24/100, Loss: 1.1019, Accuracy: 0.8171


Epoch 24: 100%|██████████| 125/125 [08:06<00:00,  3.89s/it, loss=1.08]


Epoch 25/100, Loss: 1.0767, Accuracy: 0.8405


Epoch 25: 100%|██████████| 125/125 [08:05<00:00,  3.89s/it, loss=1.06]


Epoch 26/100, Loss: 1.0579, Accuracy: 0.8549


Epoch 26: 100%|██████████| 125/125 [07:57<00:00,  3.82s/it, loss=1.03] 


Epoch 27/100, Loss: 1.0471, Accuracy: 0.8634


Epoch 27: 100%|██████████| 125/125 [07:45<00:00,  3.72s/it, loss=1.03] 


Epoch 28/100, Loss: 1.0402, Accuracy: 0.8690


Epoch 28: 100%|██████████| 125/125 [07:48<00:00,  3.75s/it, loss=1.03] 


Epoch 29/100, Loss: 1.0347, Accuracy: 0.8734


Epoch 29: 100%|██████████| 125/125 [07:48<00:00,  3.75s/it, loss=1.03] 


Epoch 30/100, Loss: 1.0305, Accuracy: 0.8769


Epoch 30: 100%|██████████| 125/125 [07:43<00:00,  3.71s/it, loss=1.01] 


Epoch 31/100, Loss: 1.0277, Accuracy: 0.8791


Epoch 31: 100%|██████████| 125/125 [07:57<00:00,  3.82s/it, loss=0.989]


Epoch 32/100, Loss: 1.0254, Accuracy: 0.8814


Epoch 32: 100%|██████████| 125/125 [08:00<00:00,  3.84s/it, loss=0.977]


Epoch 33/100, Loss: 1.0239, Accuracy: 0.8831


Epoch 33: 100%|██████████| 125/125 [07:58<00:00,  3.83s/it, loss=1.04] 


Epoch 34/100, Loss: 1.0209, Accuracy: 0.8852


Epoch 34: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=1.02] 


Epoch 35/100, Loss: 1.0181, Accuracy: 0.8878


Epoch 35: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=1.04] 


Epoch 36/100, Loss: 1.0161, Accuracy: 0.8902


Epoch 36: 100%|██████████| 125/125 [08:13<00:00,  3.95s/it, loss=1.02] 


Epoch 37/100, Loss: 1.0152, Accuracy: 0.8904


Epoch 37: 100%|██████████| 125/125 [08:13<00:00,  3.95s/it, loss=1.01] 


Epoch 38/100, Loss: 1.0135, Accuracy: 0.8919


Epoch 38: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=1.05] 


Epoch 39/100, Loss: 1.0030, Accuracy: 0.9098


Epoch 39: 100%|██████████| 125/125 [08:10<00:00,  3.93s/it, loss=0.972]


Epoch 40/100, Loss: 0.9974, Accuracy: 0.9188


Epoch 40: 100%|██████████| 125/125 [08:15<00:00,  3.97s/it, loss=0.987]


Epoch 41/100, Loss: 0.9876, Accuracy: 0.9270


Epoch 41: 100%|██████████| 125/125 [08:09<00:00,  3.92s/it, loss=0.968]


Epoch 42/100, Loss: 0.9780, Accuracy: 0.9368


Epoch 42: 100%|██████████| 125/125 [08:08<00:00,  3.91s/it, loss=0.952]


Epoch 43/100, Loss: 0.9710, Accuracy: 0.9416


Epoch 43: 100%|██████████| 125/125 [07:58<00:00,  3.83s/it, loss=0.986]


Epoch 44/100, Loss: 0.9685, Accuracy: 0.9441


Epoch 44: 100%|██████████| 125/125 [07:59<00:00,  3.84s/it, loss=0.953]


Epoch 45/100, Loss: 0.9640, Accuracy: 0.9475


Epoch 45: 100%|██████████| 125/125 [08:09<00:00,  3.91s/it, loss=0.957]


Epoch 46/100, Loss: 0.9596, Accuracy: 0.9514


Epoch 46: 100%|██████████| 125/125 [08:06<00:00,  3.89s/it, loss=0.952]


Epoch 47/100, Loss: 0.9546, Accuracy: 0.9556


Epoch 47: 100%|██████████| 125/125 [08:05<00:00,  3.88s/it, loss=0.959]


Epoch 48/100, Loss: 0.9531, Accuracy: 0.9573


Epoch 48: 100%|██████████| 125/125 [08:00<00:00,  3.84s/it, loss=0.978]


Epoch 49/100, Loss: 0.9504, Accuracy: 0.9591


Epoch 49: 100%|██████████| 125/125 [07:58<00:00,  3.83s/it, loss=0.948]


Epoch 50/100, Loss: 0.9488, Accuracy: 0.9606


Epoch 50: 100%|██████████| 125/125 [07:52<00:00,  3.78s/it, loss=0.932]


Epoch 51/100, Loss: 0.9495, Accuracy: 0.9597


Epoch 51: 100%|██████████| 125/125 [08:08<00:00,  3.91s/it, loss=0.978]


Epoch 52/100, Loss: 0.9479, Accuracy: 0.9616


Epoch 52: 100%|██████████| 125/125 [08:12<00:00,  3.94s/it, loss=0.952]


Epoch 53/100, Loss: 0.9460, Accuracy: 0.9624


Epoch 53: 100%|██████████| 125/125 [08:11<00:00,  3.93s/it, loss=0.936]


Epoch 54/100, Loss: 0.9438, Accuracy: 0.9643


Epoch 54: 100%|██████████| 125/125 [07:55<00:00,  3.80s/it, loss=0.924]


Epoch 55/100, Loss: 0.9440, Accuracy: 0.9644


Epoch 55: 100%|██████████| 125/125 [07:59<00:00,  3.84s/it, loss=0.949]


Epoch 56/100, Loss: 0.9421, Accuracy: 0.9655


Epoch 56: 100%|██████████| 125/125 [07:58<00:00,  3.83s/it, loss=0.954]


Epoch 57/100, Loss: 0.9424, Accuracy: 0.9657


Epoch 57: 100%|██████████| 125/125 [08:09<00:00,  3.92s/it, loss=0.949]


Epoch 58/100, Loss: 0.9424, Accuracy: 0.9651


Epoch 58: 100%|██████████| 125/125 [08:21<00:00,  4.01s/it, loss=0.932]


Epoch 59/100, Loss: 0.9431, Accuracy: 0.9646


Epoch 59: 100%|██████████| 125/125 [08:09<00:00,  3.92s/it, loss=0.947]


Epoch 60/100, Loss: 0.9440, Accuracy: 0.9637


Epoch 60: 100%|██████████| 125/125 [07:58<00:00,  3.83s/it, loss=0.939]


Epoch 61/100, Loss: 0.9413, Accuracy: 0.9666


Epoch 61: 100%|██████████| 125/125 [08:04<00:00,  3.87s/it, loss=0.942]


Epoch 62/100, Loss: 0.9403, Accuracy: 0.9669


Epoch 62:  51%|█████     | 64/125 [04:11<03:59,  3.93s/it, loss=0.941]


KeyboardInterrupt: 