### Complex Model

1. used the pretrianed Bert-base model to get the initial embeddings
2. use two linear layers to deal with the multilable task
    - one linear layer cope with the 8-class classification task that predict the current speaker given a line.
    - the other linear layer cope with the 9-class classification task that predict the next speaker given a line.
    - the total loss of the model is the sum of the cross-entropy loss of these two tasks.
3. Tried both balanced-training and unbalanced-training
4. the results are shown in the last two cells

In [None]:
import pandas as pd
import torch
import numpy as np
from transformers import BertTokenizer
from torch import nn
from transformers import BertModel
from torch.optim import Adam
from tqdm import tqdm
from sklearn.utils import shuffle
from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
from sklearn.metrics import accuracy_score, recall_score, classification_report

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
labels = {
    "Sheldon" : 0,
    "Penny" : 1,
    "Leonard" : 2,
    "Raj" : 3,
    "Howard" : 4,
    "Amy" : 5,
    "Bernadette" : 6,
    "Secondary" : 7,
    "End" : 8
}

In [198]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.cur_labels = [labels[label] for label in df['cur_speaker_label']]
        self.next_labels = [labels[label] for label in df['next_speaker_label']]
        self.lines = [tokenizer(line, padding = 'max_length', max_length = 512, truncation= True, return_tensors = "pt") for line in df["raw_line"]]
    
    
    def __len__(self):
        return len(self.cur_labels)

    def get_batch_labels(self, idx):
        return {"cur": np.array(self.cur_labels[idx]), "next" : np.array(self.next_labels[idx])}
    
    def get_batch_lines(self, idx):
        return self.lines[idx]
    
    def __getitem__(self, idx):
        batch_x = self.get_batch_lines(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_x, batch_y


In [199]:
class BertClassifier(nn.Module):
    def __init__(self, dropout = 0.5):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.cur = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(768, 8),
            nn.ReLU()
        )
        self.next = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(768, 9),
            nn.ReLU()
        )

    def forward(self, input_ids, mask):
        _, cls_embedding = self.bert(input_ids = input_ids, attention_mask = mask, return_dict = False)
        return {
            'cur' : self.cur(cls_embedding),
            'next' : self.next(cls_embedding)
        }

In [200]:
def get_loss_acc(criterion_cur, criterion_next, output, label, device):
    label['cur'] = label['cur'].type(torch.LongTensor)
    label['next'] = label['next'].type(torch.LongTensor)
    label['cur'] = label['cur'].to(device)
    label['next'] = label['next'].to(device)
    cur_loss = criterion_cur(output['cur'], label['cur'])
    next_loss = criterion_next(output['next'], label['next'])
    loss = cur_loss + next_loss
    cur_acc = (output['cur'].argmax(dim=1) == label['cur']).sum().item()
    next_acc = (output['next'].argmax(dim=1) == label['next']).sum().item()
    return loss, {'cur' : cur_loss, 'next' : next_loss}, {'cur' : cur_acc, 'next': next_acc}

In [201]:
def run_model(model, train_data, val_data, learning_rate, epochs, check_point, batch_size):
    train, val = Dataset(train_data), Dataset(val_data)
    train_dataloader = torch.utils.data.DataLoader(train, batch_size = batch_size, shuffle = True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size = batch_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
  
    #The weight here is set by the value we get from the "Fasttext_unsupervised_training.ipynb"
    criterion_cur = nn.CrossEntropyLoss(weight = torch.tensor([0.48333895, 0.90484957, 0.71429165, 1.27128289, 1.06559713, 1.87805677, 2.52688014, 1.29423714]))
    criterion_next = nn.CrossEntropyLoss(weight = torch.tensor([0.5671942,  0.75431904, 0.59435462, 1.2725995,  1.03293404, 1.61986817, 2.12974311, 1.1452633, 1.94252484]))
    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = ExponentialLR(optimizer, gamma=0.2)

    if use_cuda:
        model = model.cuda()
        criterion_cur = criterion_cur.cuda()
        criterion_next = criterion_next.cuda()
    
    i = 0
    train_iteration = []
    val_iteration = []

    total_train_loss = []
    cur_train_loss = []
    next_train_loss = []
    cur_train_acc = []
    next_train_acc = []

    total_val_loss = []
    cur_val_loss = []
    next_val_loss = []
    cur_val_acc = []
    next_val_acc = []

    for epoch_num in range(epochs):
        total_loss_train = 0
        cur_loss_train = 0
        next_loss_train = 0
        cur_acc_train = 0
        next_acc_train = 0
        
        for train_input, train_label in tqdm(train_dataloader):
            mask = train_input['attention_mask'].to(device)
            input_id = train_input['input_ids'].squeeze(1).to(device)
            
            output = model(input_id, mask)
                    
            batch_loss, categorical_loss, categorical_acc = get_loss_acc(criterion_cur, criterion_next, output, train_label, device)
                
            total_loss_train += batch_loss.item()
            cur_loss_train += categorical_loss['cur'].item()
            next_loss_train += categorical_loss['next'].item()
            cur_acc_train += categorical_acc['cur']
            next_acc_train += categorical_acc['next']
            
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()

            train_iteration.append(i)
            total_train_loss.append(batch_loss.item()/batch_size)
            cur_train_loss.append(categorical_loss['cur'].item()/batch_size)
            next_train_loss.append(categorical_loss['next'].item()/batch_size)
            cur_train_acc.append(categorical_acc['cur']/batch_size)
            next_train_acc.append(categorical_acc['next']/batch_size)

            i+=1

            if i%check_point==0:
                total_loss_val = 0
                cur_loss_val = 0
                next_loss_val = 0
                cur_acc_val = 0
                next_acc_val = 0
            
                with torch.no_grad():
                        
                    for val_input, val_label in val_dataloader:
                        mask = val_input['attention_mask'].to(device)
                        input_id = val_input['input_ids'].squeeze(1).to(device)
        
                        output = model(input_id, mask)

                        batch_loss, categorical_loss, categorical_acc = get_loss_acc(criterion_cur, criterion_next, output, val_label, device)
                        total_loss_val += batch_loss.item()
                        cur_loss_val += categorical_loss['cur'].item()
                        next_loss_val += categorical_loss['next'].item()
                        cur_acc_val += categorical_acc['cur']
                        next_acc_val += categorical_acc['next']

                print(
                    f'''Epochs: {epoch_num + 1} 
                    | Iterations: {i}
                    | Total Train Loss: {total_loss_train / (check_point*batch_size): .3f} 
                    | Cur Train Loss : {cur_loss_train / (check_point*batch_size): .3f} 
                    | Next Train Loss : {next_loss_train / (check_point*batch_size): .3f}
                    | Cur Train Accuracy : {cur_acc_train / (check_point*batch_size): .3f}
                    | Next Train Accuracy: {next_acc_train / (check_point*batch_size): .3f} 
                    | Total Val Loss: {total_loss_val / len(val_data): .3f} 
                    | Cur Val Loss : {cur_loss_val / len(val_data): .3f} 
                    | Next Val Loss : {next_loss_val / len(val_data): .3f}
                    | Cur Val Accuracy : {cur_acc_val / len(val_data): .3f}
                    | Next Val Accuracy: {next_acc_val / len(val_data): .3f}''')  
                    
                val_iteration.append(i)
                total_val_loss.append(total_loss_val / len(val_data))
                cur_val_loss.append(cur_loss_val / len(val_data))
                next_val_loss.append(next_loss_val / len(val_data))
                cur_val_acc.append(cur_acc_val / len(val_data))
                next_val_acc.append(next_acc_val / len(val_data))

                total_loss_train = 0
                cur_loss_train = 0
                next_loss_train = 0
                cur_acc_train = 0
                next_acc_train = 0
                

        scheduler.step()

    
    total_loss_val = 0
    cur_loss_val = 0
    next_loss_val = 0
    cur_acc_val = 0
    next_acc_val = 0
        
    with torch.no_grad():
                    
        for val_input, val_label in val_dataloader:
            mask = val_input['attention_mask'].to(device)
            input_id = val_input['input_ids'].squeeze(1).to(device)
    
            output = model(input_id, mask)

            batch_loss, categorical_loss, categorical_acc = get_loss_acc(criterion_cur, criterion_next, output, val_label, device)
            total_loss_val += batch_loss.item()
            cur_loss_val += categorical_loss['cur'].item()
            next_loss_val += categorical_loss['next'].item()
            cur_acc_val += categorical_acc['cur']
            next_acc_val += categorical_acc['next']
    print(
        f'''Total Val Loss: {total_loss_val / len(val_data): .3f} 
            | Cur Val Loss : {cur_loss_val / len(val_data): .3f} 
            | Next Val Loss : {next_loss_val / len(val_data): .3f}
            | Cur Val Accuracy : {cur_acc_val / len(val_data): .3f}
            | Next Val Accuracy: {next_acc_val / len(val_data): .3f}''')  

In [202]:
df_val = pd.read_csv("data/test_data.csv")
df_train = pd.read_csv("data/train_data.csv")

In [203]:
df = pd.read_csv("data/processed_lines.csv")

EPOCHS = 2
model = BertClassifier()
LR = 1e-5
run_model(model, df_train, df_val, LR, EPOCHS, 500, 5)
torch.save(model.state_dict(), "Models/Balanced_MultiLabel_Bert.pt")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  7%|▋         | 500/7203 [02:26<17:44:05,  9.52s/it]

Epochs: 1 
                    | Iterations: 500
                    | Total Train Loss:  0.860 
                    | Cur Train Loss :  0.419 
                    | Next Train Loss :  0.442
                    | Cur Train Accuracy :  0.151
                    | Next Train Accuracy:  0.129 
                    | Total Val Loss:  0.858 
                    | Cur Val Loss :  0.418 
                    | Next Val Loss :  0.441
                    | Cur Val Accuracy :  0.204
                    | Next Val Accuracy:  0.164


 14%|█▍        | 1000/7203 [05:04<17:07:06,  9.93s/it]

Epochs: 1 
                    | Iterations: 1000
                    | Total Train Loss:  0.853 
                    | Cur Train Loss :  0.414 
                    | Next Train Loss :  0.440
                    | Cur Train Accuracy :  0.242
                    | Next Train Accuracy:  0.176 
                    | Total Val Loss:  0.852 
                    | Cur Val Loss :  0.411 
                    | Next Val Loss :  0.441
                    | Cur Val Accuracy :  0.251
                    | Next Val Accuracy:  0.156


 21%|██        | 1500/7203 [07:44<16:05:55, 10.16s/it]

Epochs: 1 
                    | Iterations: 1500
                    | Total Train Loss:  0.849 
                    | Cur Train Loss :  0.409 
                    | Next Train Loss :  0.440
                    | Cur Train Accuracy :  0.255
                    | Next Train Accuracy:  0.188 
                    | Total Val Loss:  0.851 
                    | Cur Val Loss :  0.411 
                    | Next Val Loss :  0.439
                    | Cur Val Accuracy :  0.253
                    | Next Val Accuracy:  0.202


 28%|██▊       | 2000/7203 [10:24<14:30:00, 10.03s/it]

Epochs: 1 
                    | Iterations: 2000
                    | Total Train Loss:  0.846 
                    | Cur Train Loss :  0.408 
                    | Next Train Loss :  0.437
                    | Cur Train Accuracy :  0.259
                    | Next Train Accuracy:  0.212 
                    | Total Val Loss:  0.847 
                    | Cur Val Loss :  0.408 
                    | Next Val Loss :  0.439
                    | Cur Val Accuracy :  0.259
                    | Next Val Accuracy:  0.201


 35%|███▍      | 2500/7203 [13:05<13:18:47, 10.19s/it]

Epochs: 1 
                    | Iterations: 2500
                    | Total Train Loss:  0.844 
                    | Cur Train Loss :  0.406 
                    | Next Train Loss :  0.437
                    | Cur Train Accuracy :  0.259
                    | Next Train Accuracy:  0.199 
                    | Total Val Loss:  0.843 
                    | Cur Val Loss :  0.405 
                    | Next Val Loss :  0.438
                    | Cur Val Accuracy :  0.279
                    | Next Val Accuracy:  0.199


 42%|████▏     | 3000/7203 [15:44<11:49:22, 10.13s/it]

Epochs: 1 
                    | Iterations: 3000
                    | Total Train Loss:  0.839 
                    | Cur Train Loss :  0.403 
                    | Next Train Loss :  0.436
                    | Cur Train Accuracy :  0.282
                    | Next Train Accuracy:  0.208 
                    | Total Val Loss:  0.846 
                    | Cur Val Loss :  0.407 
                    | Next Val Loss :  0.439
                    | Cur Val Accuracy :  0.274
                    | Next Val Accuracy:  0.182


 49%|████▊     | 3500/7203 [18:25<10:24:33, 10.12s/it]

Epochs: 1 
                    | Iterations: 3500
                    | Total Train Loss:  0.837 
                    | Cur Train Loss :  0.402 
                    | Next Train Loss :  0.435
                    | Cur Train Accuracy :  0.304
                    | Next Train Accuracy:  0.202 
                    | Total Val Loss:  0.840 
                    | Cur Val Loss :  0.403 
                    | Next Val Loss :  0.437
                    | Cur Val Accuracy :  0.296
                    | Next Val Accuracy:  0.195


 56%|█████▌    | 4000/7203 [21:05<8:56:55, 10.06s/it] 

Epochs: 1 
                    | Iterations: 4000
                    | Total Train Loss:  0.836 
                    | Cur Train Loss :  0.403 
                    | Next Train Loss :  0.434
                    | Cur Train Accuracy :  0.303
                    | Next Train Accuracy:  0.226 
                    | Total Val Loss:  0.840 
                    | Cur Val Loss :  0.403 
                    | Next Val Loss :  0.437
                    | Cur Val Accuracy :  0.298
                    | Next Val Accuracy:  0.191


 62%|██████▏   | 4500/7203 [23:46<7:43:29, 10.29s/it]

Epochs: 1 
                    | Iterations: 4500
                    | Total Train Loss:  0.834 
                    | Cur Train Loss :  0.401 
                    | Next Train Loss :  0.433
                    | Cur Train Accuracy :  0.285
                    | Next Train Accuracy:  0.198 
                    | Total Val Loss:  0.840 
                    | Cur Val Loss :  0.401 
                    | Next Val Loss :  0.438
                    | Cur Val Accuracy :  0.278
                    | Next Val Accuracy:  0.160


 69%|██████▉   | 5000/7203 [26:27<6:14:00, 10.19s/it]

Epochs: 1 
                    | Iterations: 5000
                    | Total Train Loss:  0.832 
                    | Cur Train Loss :  0.399 
                    | Next Train Loss :  0.433
                    | Cur Train Accuracy :  0.301
                    | Next Train Accuracy:  0.197 
                    | Total Val Loss:  0.838 
                    | Cur Val Loss :  0.401 
                    | Next Val Loss :  0.437
                    | Cur Val Accuracy :  0.310
                    | Next Val Accuracy:  0.194


 76%|███████▋  | 5500/7203 [29:08<4:49:54, 10.21s/it]

Epochs: 1 
                    | Iterations: 5500
                    | Total Train Loss:  0.827 
                    | Cur Train Loss :  0.397 
                    | Next Train Loss :  0.430
                    | Cur Train Accuracy :  0.295
                    | Next Train Accuracy:  0.220 
                    | Total Val Loss:  0.837 
                    | Cur Val Loss :  0.401 
                    | Next Val Loss :  0.436
                    | Cur Val Accuracy :  0.309
                    | Next Val Accuracy:  0.185


 83%|████████▎ | 6000/7203 [31:50<3:30:09, 10.48s/it]

Epochs: 1 
                    | Iterations: 6000
                    | Total Train Loss:  0.823 
                    | Cur Train Loss :  0.391 
                    | Next Train Loss :  0.431
                    | Cur Train Accuracy :  0.315
                    | Next Train Accuracy:  0.204 
                    | Total Val Loss:  0.835 
                    | Cur Val Loss :  0.399 
                    | Next Val Loss :  0.436
                    | Cur Val Accuracy :  0.296
                    | Next Val Accuracy:  0.179


 90%|█████████ | 6500/7203 [34:33<1:58:52, 10.15s/it]

Epochs: 1 
                    | Iterations: 6500
                    | Total Train Loss:  0.822 
                    | Cur Train Loss :  0.392 
                    | Next Train Loss :  0.430
                    | Cur Train Accuracy :  0.312
                    | Next Train Accuracy:  0.209 
                    | Total Val Loss:  0.829 
                    | Cur Val Loss :  0.395 
                    | Next Val Loss :  0.434
                    | Cur Val Accuracy :  0.301
                    | Next Val Accuracy:  0.179


 97%|█████████▋| 7000/7203 [37:15<35:04, 10.37s/it]  

Epochs: 1 
                    | Iterations: 7000
                    | Total Train Loss:  0.819 
                    | Cur Train Loss :  0.389 
                    | Next Train Loss :  0.430
                    | Cur Train Accuracy :  0.328
                    | Next Train Accuracy:  0.203 
                    | Total Val Loss:  0.826 
                    | Cur Val Loss :  0.393 
                    | Next Val Loss :  0.433
                    | Cur Val Accuracy :  0.310
                    | Next Val Accuracy:  0.190


100%|██████████| 7203/7203 [38:08<00:00,  3.15it/s]
  4%|▍         | 297/7203 [01:49<19:35:30, 10.21s/it]

Epochs: 2 
                    | Iterations: 7500
                    | Total Train Loss:  0.475 
                    | Cur Train Loss :  0.225 
                    | Next Train Loss :  0.250
                    | Cur Train Accuracy :  0.207
                    | Next Train Accuracy:  0.148 
                    | Total Val Loss:  0.820 
                    | Cur Val Loss :  0.388 
                    | Next Val Loss :  0.432
                    | Cur Val Accuracy :  0.306
                    | Next Val Accuracy:  0.207


 11%|█         | 797/7203 [04:31<18:12:01, 10.23s/it]

Epochs: 2 
                    | Iterations: 8000
                    | Total Train Loss:  0.795 
                    | Cur Train Loss :  0.370 
                    | Next Train Loss :  0.426
                    | Cur Train Accuracy :  0.368
                    | Next Train Accuracy:  0.230 
                    | Total Val Loss:  0.820 
                    | Cur Val Loss :  0.388 
                    | Next Val Loss :  0.432
                    | Cur Val Accuracy :  0.317
                    | Next Val Accuracy:  0.206


 18%|█▊        | 1297/7203 [07:11<16:31:43, 10.08s/it]

Epochs: 2 
                    | Iterations: 8500
                    | Total Train Loss:  0.792 
                    | Cur Train Loss :  0.369 
                    | Next Train Loss :  0.423
                    | Cur Train Accuracy :  0.358
                    | Next Train Accuracy:  0.244 
                    | Total Val Loss:  0.821 
                    | Cur Val Loss :  0.388 
                    | Next Val Loss :  0.432
                    | Cur Val Accuracy :  0.341
                    | Next Val Accuracy:  0.205


 25%|██▍       | 1797/7203 [09:51<15:10:31, 10.11s/it]

Epochs: 2 
                    | Iterations: 9000
                    | Total Train Loss:  0.788 
                    | Cur Train Loss :  0.367 
                    | Next Train Loss :  0.421
                    | Cur Train Accuracy :  0.366
                    | Next Train Accuracy:  0.244 
                    | Total Val Loss:  0.818 
                    | Cur Val Loss :  0.387 
                    | Next Val Loss :  0.431
                    | Cur Val Accuracy :  0.314
                    | Next Val Accuracy:  0.196


 32%|███▏      | 2297/7203 [12:31<13:55:42, 10.22s/it]

Epochs: 2 
                    | Iterations: 9500
                    | Total Train Loss:  0.784 
                    | Cur Train Loss :  0.363 
                    | Next Train Loss :  0.421
                    | Cur Train Accuracy :  0.374
                    | Next Train Accuracy:  0.248 
                    | Total Val Loss:  0.819 
                    | Cur Val Loss :  0.388 
                    | Next Val Loss :  0.432
                    | Cur Val Accuracy :  0.332
                    | Next Val Accuracy:  0.218


 39%|███▉      | 2797/7203 [15:13<12:35:10, 10.28s/it]

Epochs: 2 
                    | Iterations: 10000
                    | Total Train Loss:  0.789 
                    | Cur Train Loss :  0.366 
                    | Next Train Loss :  0.423
                    | Cur Train Accuracy :  0.376
                    | Next Train Accuracy:  0.232 
                    | Total Val Loss:  0.816 
                    | Cur Val Loss :  0.385 
                    | Next Val Loss :  0.432
                    | Cur Val Accuracy :  0.311
                    | Next Val Accuracy:  0.206


 46%|████▌     | 3297/7203 [17:55<11:13:43, 10.35s/it]

Epochs: 2 
                    | Iterations: 10500
                    | Total Train Loss:  0.792 
                    | Cur Train Loss :  0.366 
                    | Next Train Loss :  0.426
                    | Cur Train Accuracy :  0.360
                    | Next Train Accuracy:  0.227 
                    | Total Val Loss:  0.814 
                    | Cur Val Loss :  0.383 
                    | Next Val Loss :  0.431
                    | Cur Val Accuracy :  0.318
                    | Next Val Accuracy:  0.202


 53%|█████▎    | 3797/7203 [20:37<9:48:55, 10.37s/it] 

Epochs: 2 
                    | Iterations: 11000
                    | Total Train Loss:  0.786 
                    | Cur Train Loss :  0.365 
                    | Next Train Loss :  0.421
                    | Cur Train Accuracy :  0.372
                    | Next Train Accuracy:  0.232 
                    | Total Val Loss:  0.814 
                    | Cur Val Loss :  0.384 
                    | Next Val Loss :  0.429
                    | Cur Val Accuracy :  0.321
                    | Next Val Accuracy:  0.204


 60%|█████▉    | 4297/7203 [23:23<8:35:38, 10.65s/it]

Epochs: 2 
                    | Iterations: 11500
                    | Total Train Loss:  0.779 
                    | Cur Train Loss :  0.357 
                    | Next Train Loss :  0.422
                    | Cur Train Accuracy :  0.395
                    | Next Train Accuracy:  0.230 
                    | Total Val Loss:  0.812 
                    | Cur Val Loss :  0.382 
                    | Next Val Loss :  0.430
                    | Cur Val Accuracy :  0.329
                    | Next Val Accuracy:  0.206


 67%|██████▋   | 4797/7203 [26:06<6:50:58, 10.25s/it]

Epochs: 2 
                    | Iterations: 12000
                    | Total Train Loss:  0.790 
                    | Cur Train Loss :  0.365 
                    | Next Train Loss :  0.424
                    | Cur Train Accuracy :  0.382
                    | Next Train Accuracy:  0.213 
                    | Total Val Loss:  0.811 
                    | Cur Val Loss :  0.383 
                    | Next Val Loss :  0.428
                    | Cur Val Accuracy :  0.307
                    | Next Val Accuracy:  0.226


 74%|███████▎  | 5297/7203 [28:46<5:26:42, 10.28s/it]

Epochs: 2 
                    | Iterations: 12500
                    | Total Train Loss:  0.781 
                    | Cur Train Loss :  0.361 
                    | Next Train Loss :  0.420
                    | Cur Train Accuracy :  0.382
                    | Next Train Accuracy:  0.240 
                    | Total Val Loss:  0.811 
                    | Cur Val Loss :  0.381 
                    | Next Val Loss :  0.429
                    | Cur Val Accuracy :  0.338
                    | Next Val Accuracy:  0.208


 80%|████████  | 5797/7203 [31:28<4:00:29, 10.26s/it]

Epochs: 2 
                    | Iterations: 13000
                    | Total Train Loss:  0.773 
                    | Cur Train Loss :  0.355 
                    | Next Train Loss :  0.418
                    | Cur Train Accuracy :  0.394
                    | Next Train Accuracy:  0.245 
                    | Total Val Loss:  0.811 
                    | Cur Val Loss :  0.382 
                    | Next Val Loss :  0.428
                    | Cur Val Accuracy :  0.324
                    | Next Val Accuracy:  0.225


 87%|████████▋ | 6297/7203 [34:09<2:33:10, 10.14s/it]

Epochs: 2 
                    | Iterations: 13500
                    | Total Train Loss:  0.769 
                    | Cur Train Loss :  0.354 
                    | Next Train Loss :  0.416
                    | Cur Train Accuracy :  0.391
                    | Next Train Accuracy:  0.254 
                    | Total Val Loss:  0.813 
                    | Cur Val Loss :  0.382 
                    | Next Val Loss :  0.431
                    | Cur Val Accuracy :  0.339
                    | Next Val Accuracy:  0.205


 94%|█████████▍| 6797/7203 [36:49<1:08:37, 10.14s/it]

Epochs: 2 
                    | Iterations: 14000
                    | Total Train Loss:  0.772 
                    | Cur Train Loss :  0.356 
                    | Next Train Loss :  0.415
                    | Cur Train Accuracy :  0.392
                    | Next Train Accuracy:  0.257 
                    | Total Val Loss:  0.811 
                    | Cur Val Loss :  0.382 
                    | Next Val Loss :  0.429
                    | Cur Val Accuracy :  0.331
                    | Next Val Accuracy:  0.214


100%|██████████| 7203/7203 [38:33<00:00,  3.11it/s]  


Total Val Loss:  0.811 
            | Cur Val Loss :  0.382 
            | Next Val Loss :  0.429
            | Cur Val Accuracy :  0.330
            | Next Val Accuracy:  0.209


In [204]:
def test_model(model, test_data):
    model.eval()
    test = Dataset(test_data)
    test_dataloader = torch.utils.data.DataLoader(test, batch_size = 5)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
  
    #The weight here is set by the value we get from the "Fasttext_unsupervised_training.ipynb"
    criterion_cur = nn.CrossEntropyLoss()
    criterion_next = nn.CrossEntropyLoss()

    if use_cuda:
        model = model.cuda()
        criterion_cur = criterion_cur.cuda()
        criterion_next = criterion_next.cuda()
    
    total_loss_val = 0
    cur_loss_val = 0
    next_loss_val = 0
    cur_acc_val = 0
    next_acc_val = 0

    cur_speaker_pred = []
    cur_speaker_y = []
    next_speaker_pred = []
    next_speaker_y = []
        
    with torch.no_grad():
                    
        for val_input, val_label in test_dataloader:
            mask = val_input['attention_mask'].to(device)
            input_id = val_input['input_ids'].squeeze(1).to(device)
    
            output = model(input_id, mask)

            
            cur_speaker_pred += torch.argmax(output['cur'], dim = 1).tolist()
            next_speaker_pred += torch.argmax(output['next'], dim = 1).tolist()
            cur_speaker_y += val_label['cur'].tolist()
            next_speaker_y += val_label['next'].tolist()


            batch_loss, categorical_loss, categorical_acc = get_loss_acc(criterion_cur, criterion_next, output, val_label, device)
            total_loss_val += batch_loss.item()
            cur_loss_val += categorical_loss['cur'].item()
            next_loss_val += categorical_loss['next'].item()
            cur_acc_val += categorical_acc['cur']
            next_acc_val += categorical_acc['next']
    print(
        f'''Total Test Loss: {total_loss_val / len(test_data): .3f} 
        | Cur_speaker Test Loss : {cur_loss_val / len(test_data): .3f} 
        | Next_speaker Test Loss : {next_loss_val / len(test_data): .3f}
        | Cur_speaker Test Accuracy : {cur_acc_val / len(test_data): .3f}
        | Next_speaker Test Accuracy: {next_acc_val / len(test_data): .3f}''') 

    print("====================================================================")
    
    print("The classification report for current speaker prediction:")
    cur_speaker_m = classification_report(cur_speaker_y, cur_speaker_pred, target_names = ["Sheldon", "Penny", "Leonard", "Raj","Howard","Amy","Bernadette","Secondary"])
    print(cur_speaker_m)

    print("====================================================================")
    
    print("The classification report for next speaker prediction:")
    next_speaker_m = classification_report(next_speaker_y, next_speaker_pred, target_names = ["Sheldon", "Penny", "Leonard", "Raj","Howard","Amy","Bernadette","Secondary", "Ending"])
    print(next_speaker_m)


In [205]:
df_test = pd.read_csv("data/test_data.csv")

In [206]:
model_test = BertClassifier()
model_test.load_state_dict(torch.load("Models/MultiLabel_Bert.pt"))
test_model(model_test, df_test)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Total Test Loss:  0.758 
        | Cur_speaker Test Loss :  0.355 
        | Next_speaker Test Loss :  0.403
        | Cur_speaker Test Accuracy :  0.365
        | Next_speaker Test Accuracy:  0.269
The classification report for current speaker prediction:
              precision    recall  f1-score   support

     Sheldon       0.50      0.72      0.59       475
       Penny       0.34      0.42      0.38       251
     Leonard       0.25      0.50      0.33       298
         Raj       0.32      0.12      0.17       195
      Howard       0.28      0.18      0.22       250
         Amy       0.50      0.01      0.01       135
  Bernadette       0.50      0.02      0.04        94
   Secondary       0.33      0.12      0.18       198

    accuracy                           0.36      1896
   macro avg       0.38      0.26      0.24      1896
weighted avg       0.37      0.36      0.32      1896

The classification report for next speaker prediction:
              precision    recall  f1

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [207]:
balanced_model_test = BertClassifier()
balanced_model_test.load_state_dict(torch.load("Models/Balanced_MultiLabel_Bert.pt"))
test_model(balanced_model_test, df_test)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Total Test Loss:  0.780 
        | Cur_speaker Test Loss :  0.359 
        | Next_speaker Test Loss :  0.421
        | Cur_speaker Test Accuracy :  0.342
        | Next_speaker Test Accuracy:  0.228
The classification report for current speaker prediction:
              precision    recall  f1-score   support

     Sheldon       0.60      0.57      0.59       475
       Penny       0.31      0.53      0.39       251
     Leonard       0.29      0.27      0.28       298
         Raj       0.26      0.22      0.24       195
      Howard       0.24      0.18      0.21       250
         Amy       0.19      0.21      0.20       135
  Bernadette       0.09      0.09      0.09        94
   Secondary       0.26      0.20      0.22       198

    accuracy                           0.34      1896
   macro avg       0.28      0.28      0.28      1896
weighted avg       0.34      0.34      0.34      1896

The classification report for next speaker prediction:
              precision    recall  f1