In [1]:
import datasets
import torch

from transformers import BertTokenizer, BertModel

In [2]:
DataFilePath = "../personality_dataset"
tokenizer = BertTokenizer.from_pretrained('../bert-base-uncase')

In [3]:
labels = ["ENTJ", "ENTP", "ENFJ", "ENFP", "ESFJ", "ESFP", "ESTJ", "ESTP", "INTP", "INTJ", "INFP", "INFJ", "ISFP", "ISFJ", "ISTP", "ISTJ"]

In [4]:
import torch.utils
import numpy as np


class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, number):
        self.trian = datasets.load_from_disk(DataFilePath+"/"+data)
        self.labels = [labels.index(label) for label in self.trian["personality"][0 : number if number < len(self.trian["personality"]) else -1]]
        self.texts = [tokenizer(person.replace("|||", "[SEP]"),
                                padding="max_length",
                                max_length=512,
                                truncation=True,
                                return_tensors="pt")
                      for person in self.trian["content"][0 : number if number < len(self.trian["content"]) else -1]]
    
    def classes(self):
        return self.labels
    
    def __len__(self):
        return len(self.labels)
    
    def get_batch_labels(self, idx):
        return np.array(self.labels[idx])
    
    def get_batch_texts(self, idx):
        return self.texts[idx]
    
    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y

In [5]:
from torch import nn

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('../bert-base-uncase')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 16)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer

In [6]:
from torch.optim import Adam
from tqdm import tqdm

def train(model, train_data, val_data, learning_rate, epochs, number):
    
    train, val = Dataset(train_data, number), Dataset(val_data, -1)
    
    train_dataloader = torch.utils.data.DataLoader(train, batch_size=2, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=2)
    
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    if use_cuda:
            model = model.cuda()
            criterion = criterion.cuda()
    
    for epoch_num in range(epochs):
        
        total_acc_train = 0
        total_loss_train = 0
        
        for train_input, train_label in tqdm(train_dataloader):

            train_label = train_label.to(device)
            mask = train_input['attention_mask'].to(device)
            input_id = train_input['input_ids'].squeeze(1).to(device)
            
            output = model(input_id, mask)
            
            batch_loss = criterion(output, train_label.to(torch.int64))
            total_loss_train += batch_loss.item()
            
            acc = (output.argmax(dim=1) == train_label).sum().item()
            total_acc_train += acc
            
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()
            
            
            total_acc_val = 0
            total_loss_val = 0
            
            with torch.no_grad():
                
                for val_input, val_label in val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)
                    output = model(input_id, mask)

                    batch_loss = criterion(output, val_label.to(torch.int64))
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            print(
                f'''Epochs: {epoch_num + 1} 
                | Train Loss: {total_loss_train / len(train_data): .3f} 
                | Train Accuracy: {total_acc_train / len(train_data): .3f} 
                | Val Loss: {total_loss_val / len(val_data): .3f} 
                | Val Accuracy: {total_acc_val / len(val_data): .3f}''')   

In [7]:
EPOCHS = 1
model = BertClassifier()
LR = 1e-6
number = 100

train(model, "train", "valid",  LR, EPOCHS, number)

  return self.fget.__get__(instance, owner)()
  2%|▏         | 1/50 [03:17<2:41:12, 197.40s/it]

Epochs: 1 
                | Train Loss:  0.620 
                | Train Accuracy:  0.000 
                | Val Loss:  871.500 
                | Val Accuracy:  33.400


  4%|▍         | 2/50 [06:26<2:34:08, 192.69s/it]

Epochs: 1 
                | Train Loss:  1.207 
                | Train Accuracy:  0.000 
                | Val Loss:  868.304 
                | Val Accuracy:  41.000


  6%|▌         | 3/50 [09:35<2:29:34, 190.94s/it]

Epochs: 1 
                | Train Loss:  1.777 
                | Train Accuracy:  0.000 
                | Val Loss:  868.880 
                | Val Accuracy:  35.200


  8%|▊         | 4/50 [12:44<2:25:48, 190.19s/it]

Epochs: 1 
                | Train Loss:  2.327 
                | Train Accuracy:  0.000 
                | Val Loss:  867.814 
                | Val Accuracy:  39.800


 10%|█         | 5/50 [15:53<2:22:14, 189.65s/it]

Epochs: 1 
                | Train Loss:  2.886 
                | Train Accuracy:  0.000 
                | Val Loss:  864.964 
                | Val Accuracy:  38.200


 12%|█▏        | 6/50 [19:02<2:18:49, 189.30s/it]

Epochs: 1 
                | Train Loss:  3.437 
                | Train Accuracy:  0.000 
                | Val Loss:  863.095 
                | Val Accuracy:  39.400


 14%|█▍        | 7/50 [22:10<2:15:29, 189.05s/it]

Epochs: 1 
                | Train Loss:  4.030 
                | Train Accuracy:  0.000 
                | Val Loss:  865.272 
                | Val Accuracy:  39.200


 16%|█▌        | 8/50 [25:18<2:12:11, 188.83s/it]

Epochs: 1 
                | Train Loss:  4.614 
                | Train Accuracy:  0.000 
                | Val Loss:  864.490 
                | Val Accuracy:  43.600


 18%|█▊        | 9/50 [28:27<2:08:57, 188.73s/it]

Epochs: 1 
                | Train Loss:  5.164 
                | Train Accuracy:  0.000 
                | Val Loss:  860.877 
                | Val Accuracy:  46.000


 20%|██        | 10/50 [31:35<2:05:45, 188.63s/it]

Epochs: 1 
                | Train Loss:  5.727 
                | Train Accuracy:  0.000 
                | Val Loss:  860.607 
                | Val Accuracy:  50.000


 22%|██▏       | 11/50 [34:44<2:02:35, 188.59s/it]

Epochs: 1 
                | Train Loss:  6.328 
                | Train Accuracy:  0.000 
                | Val Loss:  862.641 
                | Val Accuracy:  42.800


 24%|██▍       | 12/50 [38:01<2:01:00, 191.07s/it]

Epochs: 1 
                | Train Loss:  6.926 
                | Train Accuracy:  0.000 
                | Val Loss:  859.109 
                | Val Accuracy:  47.200


 26%|██▌       | 13/50 [41:16<1:58:41, 192.47s/it]

Epochs: 1 
                | Train Loss:  7.533 
                | Train Accuracy:  0.000 
                | Val Loss:  860.828 
                | Val Accuracy:  45.800


 28%|██▊       | 14/50 [44:32<1:56:01, 193.38s/it]

Epochs: 1 
                | Train Loss:  8.087 
                | Train Accuracy:  0.000 
                | Val Loss:  858.923 
                | Val Accuracy:  47.600


 30%|███       | 15/50 [47:47<1:53:11, 194.04s/it]

Epochs: 1 
                | Train Loss:  8.646 
                | Train Accuracy:  0.000 
                | Val Loss:  856.978 
                | Val Accuracy:  48.600


 32%|███▏      | 16/50 [51:03<1:50:12, 194.49s/it]

Epochs: 1 
                | Train Loss:  9.186 
                | Train Accuracy:  0.000 
                | Val Loss:  857.506 
                | Val Accuracy:  51.600


 34%|███▍      | 17/50 [54:19<1:47:09, 194.85s/it]

Epochs: 1 
                | Train Loss:  9.716 
                | Train Accuracy:  0.000 
                | Val Loss:  856.374 
                | Val Accuracy:  51.000


 36%|███▌      | 18/50 [57:34<1:44:03, 195.10s/it]

Epochs: 1 
                | Train Loss:  10.236 
                | Train Accuracy:  0.200 
                | Val Loss:  853.883 
                | Val Accuracy:  53.600


 38%|███▊      | 19/50 [1:00:50<1:40:53, 195.27s/it]

Epochs: 1 
                | Train Loss:  10.775 
                | Train Accuracy:  0.200 
                | Val Loss:  854.618 
                | Val Accuracy:  54.600


 40%|████      | 20/50 [1:04:06<1:37:48, 195.60s/it]

Epochs: 1 
                | Train Loss:  11.343 
                | Train Accuracy:  0.200 
                | Val Loss:  854.102 
                | Val Accuracy:  55.600


 42%|████▏     | 21/50 [1:07:23<1:34:41, 195.93s/it]

Epochs: 1 
                | Train Loss:  11.842 
                | Train Accuracy:  0.400 
                | Val Loss:  853.416 
                | Val Accuracy:  56.600


 44%|████▍     | 22/50 [1:10:39<1:31:31, 196.11s/it]

Epochs: 1 
                | Train Loss:  12.409 
                | Train Accuracy:  0.400 
                | Val Loss:  854.032 
                | Val Accuracy:  53.600


 46%|████▌     | 23/50 [1:13:56<1:28:16, 196.17s/it]

Epochs: 1 
                | Train Loss:  12.959 
                | Train Accuracy:  0.400 
                | Val Loss:  851.736 
                | Val Accuracy:  59.600


 48%|████▊     | 24/50 [1:17:16<1:25:29, 197.27s/it]

Epochs: 1 
                | Train Loss:  13.484 
                | Train Accuracy:  0.600 
                | Val Loss:  853.444 
                | Val Accuracy:  53.400


 50%|█████     | 25/50 [1:20:27<1:21:25, 195.43s/it]

Epochs: 1 
                | Train Loss:  14.018 
                | Train Accuracy:  0.600 
                | Val Loss:  853.364 
                | Val Accuracy:  54.800


 52%|█████▏    | 26/50 [1:23:41<1:18:03, 195.13s/it]

Epochs: 1 
                | Train Loss:  14.573 
                | Train Accuracy:  0.800 
                | Val Loss:  851.556 
                | Val Accuracy:  61.200


 54%|█████▍    | 27/50 [1:26:57<1:14:52, 195.31s/it]

Epochs: 1 
                | Train Loss:  15.143 
                | Train Accuracy:  0.800 
                | Val Loss:  851.399 
                | Val Accuracy:  59.400


 56%|█████▌    | 28/50 [1:30:14<1:11:46, 195.74s/it]

Epochs: 1 
                | Train Loss:  15.697 
                | Train Accuracy:  0.800 
                | Val Loss:  852.106 
                | Val Accuracy:  56.200


 58%|█████▊    | 29/50 [1:33:30<1:08:32, 195.85s/it]

Epochs: 1 
                | Train Loss:  16.300 
                | Train Accuracy:  0.800 
                | Val Loss:  853.504 
                | Val Accuracy:  55.600


 60%|██████    | 30/50 [1:36:47<1:05:23, 196.19s/it]

Epochs: 1 
                | Train Loss:  16.856 
                | Train Accuracy:  0.800 
                | Val Loss:  851.396 
                | Val Accuracy:  63.400


 62%|██████▏   | 31/50 [1:40:03<1:02:10, 196.35s/it]

Epochs: 1 
                | Train Loss:  17.318 
                | Train Accuracy:  1.000 
                | Val Loss:  849.939 
                | Val Accuracy:  60.000


 64%|██████▍   | 32/50 [1:43:20<58:57, 196.50s/it]  

Epochs: 1 
                | Train Loss:  17.826 
                | Train Accuracy:  1.000 
                | Val Loss:  851.786 
                | Val Accuracy:  56.000


 66%|██████▌   | 33/50 [1:46:38<55:45, 196.76s/it]

Epochs: 1 
                | Train Loss:  18.407 
                | Train Accuracy:  1.000 
                | Val Loss:  851.515 
                | Val Accuracy:  59.400


 68%|██████▊   | 34/50 [1:49:54<52:28, 196.76s/it]

Epochs: 1 
                | Train Loss:  18.972 
                | Train Accuracy:  1.000 
                | Val Loss:  850.828 
                | Val Accuracy:  63.400


 70%|███████   | 35/50 [1:53:11<49:09, 196.63s/it]

Epochs: 1 
                | Train Loss:  19.517 
                | Train Accuracy:  1.000 
                | Val Loss:  850.000 
                | Val Accuracy:  60.200


 72%|███████▏  | 36/50 [1:56:27<45:50, 196.49s/it]

Epochs: 1 
                | Train Loss:  20.058 
                | Train Accuracy:  1.000 
                | Val Loss:  848.502 
                | Val Accuracy:  69.000


 74%|███████▍  | 37/50 [1:59:43<42:32, 196.37s/it]

Epochs: 1 
                | Train Loss:  20.599 
                | Train Accuracy:  1.000 
                | Val Loss:  851.466 
                | Val Accuracy:  59.400


 76%|███████▌  | 38/50 [2:02:57<39:09, 195.79s/it]

Epochs: 1 
                | Train Loss:  21.126 
                | Train Accuracy:  1.000 
                | Val Loss:  849.603 
                | Val Accuracy:  64.400


 78%|███████▊  | 39/50 [2:06:15<35:58, 196.19s/it]

Epochs: 1 
                | Train Loss:  21.649 
                | Train Accuracy:  1.200 
                | Val Loss:  850.787 
                | Val Accuracy:  58.200


 80%|████████  | 40/50 [2:09:27<32:30, 195.03s/it]

Epochs: 1 
                | Train Loss:  22.245 
                | Train Accuracy:  1.200 
                | Val Loss:  850.530 
                | Val Accuracy:  54.400


 82%|████████▏ | 41/50 [2:12:47<29:28, 196.52s/it]

Epochs: 1 
                | Train Loss:  22.774 
                | Train Accuracy:  1.200 
                | Val Loss:  849.043 
                | Val Accuracy:  63.800


 84%|████████▍ | 42/50 [2:16:06<26:18, 197.37s/it]

Epochs: 1 
                | Train Loss:  23.290 
                | Train Accuracy:  1.400 
                | Val Loss:  848.938 
                | Val Accuracy:  67.000


 86%|████████▌ | 43/50 [2:19:28<23:11, 198.82s/it]

Epochs: 1 
                | Train Loss:  23.786 
                | Train Accuracy:  1.400 
                | Val Loss:  848.880 
                | Val Accuracy:  63.000


 88%|████████▊ | 44/50 [2:22:47<19:53, 198.84s/it]

Epochs: 1 
                | Train Loss:  24.377 
                | Train Accuracy:  1.400 
                | Val Loss:  847.114 
                | Val Accuracy:  66.400


 90%|█████████ | 45/50 [2:26:07<16:35, 199.04s/it]

Epochs: 1 
                | Train Loss:  24.893 
                | Train Accuracy:  1.400 
                | Val Loss:  848.171 
                | Val Accuracy:  67.400


 92%|█████████▏| 46/50 [2:29:23<13:12, 198.20s/it]

Epochs: 1 
                | Train Loss:  25.460 
                | Train Accuracy:  1.400 
                | Val Loss:  850.390 
                | Val Accuracy:  61.600


 94%|█████████▍| 47/50 [2:32:38<09:51, 197.31s/it]

Epochs: 1 
                | Train Loss:  25.987 
                | Train Accuracy:  1.400 
                | Val Loss:  849.718 
                | Val Accuracy:  62.800


 96%|█████████▌| 48/50 [2:35:56<06:34, 197.39s/it]

Epochs: 1 
                | Train Loss:  26.565 
                | Train Accuracy:  1.400 
                | Val Loss:  847.783 
                | Val Accuracy:  66.200


 98%|█████████▊| 49/50 [2:39:16<03:18, 198.05s/it]

Epochs: 1 
                | Train Loss:  27.139 
                | Train Accuracy:  1.400 
                | Val Loss:  848.910 
                | Val Accuracy:  69.200


100%|██████████| 50/50 [2:42:35<00:00, 195.12s/it]

Epochs: 1 
                | Train Loss:  27.724 
                | Train Accuracy:  1.400 
                | Val Loss:  849.141 
                | Val Accuracy:  64.600



