In [2]:
import json 
import torch
from datasets import Dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


### preprocess data such that a datapoint has the following
* input is all the utterances from the text till that moment concatenated
* label is the emotion of the last concatenated utterance 

In [5]:
json_path = "full_dataset.json"

with open(json_path, 'r') as file:
    dataset = json.load(file)

class_mapping = {
            "anger":0,
            "disgust":1,
            "fear":2,
            "joy":3,
            "sadness":4,
            "surprise":5,
            "neutral":6
}

all_data = {
    "text" : [],
    "label" : []
}

for val in dataset:
    concat_utt = ""
    for utt in val["conversation"]:
        concat_utt = concat_utt + utt["text"]
        all_data["text"].append(concat_utt)
        all_data["label"].append(class_mapping[utt["emotion"]])

created_dataset = Dataset.from_dict(all_data)
created_dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 13619
})

#### Train args

In [6]:
model_name = "bert-base-uncased"
num_labels = 7
batch_size = 8
results = "results/{}".format(model_name)
lr = 2e-5
num_epochs = 5 # [8,5,2]

#### Tokenize data

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)


def prepare_for_train(tp_dat, split=True):
    def tokenize(batch):
        return tokenizer(
            batch["text"],
            padding=True,
            truncation=True
            )
    tp_dat = tp_dat.map(
        tokenize,
        batched=True,
        batch_size=None)
    
    tp_dat.set_format(
        "torch",
        columns=["input_ids", "attention_mask", "label"])

    if split:
        tp_dat = tp_dat.train_test_split(
            test_size=0.2,
            shuffle = False,
            seed = 1337
            )
        
    return tp_dat

train_test_data = prepare_for_train(created_dataset, True)

train_test_data

Map: 100%|██████████| 13619/13619 [00:03<00:00, 4311.93 examples/s]


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 10895
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2724
    })
})

#### Train setup

In [9]:
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}


from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels
    ).to(device)


logging_steps = len(train_test_data["train"]) // batch_size

training_args = TrainingArguments(output_dir=results,
                                  num_train_epochs=num_epochs,
                                  learning_rate=lr,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="f1",
                                  weight_decay=0.01,
                                  evaluation_strategy="epoch",
                                  save_strategy="epoch",
                                  disable_tqdm=False)

trainer = Trainer(model=model,
                  args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=train_test_data["train"],
                  eval_dataset=train_test_data["test"])


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
trainer.train()

  7%|▋         | 500/6810 [03:46<48:03,  2.19it/s]

{'loss': 1.3799, 'learning_rate': 1.853157121879589e-05, 'epoch': 0.37}


 15%|█▍        | 1000/6810 [07:35<44:16,  2.19it/s]

{'loss': 1.2035, 'learning_rate': 1.7063142437591777e-05, 'epoch': 0.73}


 20%|██        | 1362/6810 [10:20<39:53,  2.28it/s]
 20%|██        | 1362/6810 [11:12<39:53,  2.28it/s]Checkpoint destination directory results/bert-base-uncased/checkpoint-1362 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 1.2236084938049316, 'eval_accuracy': 0.5939794419970631, 'eval_f1': 0.5676638144701666, 'eval_runtime': 51.8583, 'eval_samples_per_second': 52.528, 'eval_steps_per_second': 6.576, 'epoch': 1.0}


 22%|██▏       | 1500/6810 [12:31<40:15,  2.20it/s]   

{'loss': 1.1059, 'learning_rate': 1.5594713656387664e-05, 'epoch': 1.1}


 29%|██▉       | 2000/6810 [16:19<36:34,  2.19it/s]

{'loss': 0.9244, 'learning_rate': 1.4126284875183555e-05, 'epoch': 1.47}


 37%|███▋      | 2500/6810 [20:08<32:52,  2.19it/s]

{'loss': 0.9333, 'learning_rate': 1.2657856093979443e-05, 'epoch': 1.84}


 40%|████      | 2724/6810 [21:50<29:58,  2.27it/s]
 40%|████      | 2724/6810 [22:42<29:58,  2.27it/s]Checkpoint destination directory results/bert-base-uncased/checkpoint-2724 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 1.1991859674453735, 'eval_accuracy': 0.6020558002936858, 'eval_f1': 0.5789671747048933, 'eval_runtime': 51.9059, 'eval_samples_per_second': 52.48, 'eval_steps_per_second': 6.57, 'epoch': 2.0}


 44%|████▍     | 3000/6810 [25:04<29:02,  2.19it/s]   

{'loss': 0.7901, 'learning_rate': 1.1189427312775332e-05, 'epoch': 2.2}


 51%|█████▏    | 3500/6810 [28:52<25:07,  2.20it/s]

{'loss': 0.6501, 'learning_rate': 9.72099853157122e-06, 'epoch': 2.57}


 59%|█████▊    | 4000/6810 [32:40<21:24,  2.19it/s]

{'loss': 0.645, 'learning_rate': 8.252569750367108e-06, 'epoch': 2.94}


 60%|██████    | 4086/6810 [33:19<19:55,  2.28it/s]
 60%|██████    | 4086/6810 [34:11<19:55,  2.28it/s]Checkpoint destination directory results/bert-base-uncased/checkpoint-4086 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'eval_loss': 1.3886299133300781, 'eval_accuracy': 0.5796622613803231, 'eval_f1': 0.577820646482216, 'eval_runtime': 51.9165, 'eval_samples_per_second': 52.469, 'eval_steps_per_second': 6.568, 'epoch': 3.0}


 66%|██████▌   | 4500/6810 [37:36<17:37,  2.19it/s]   

{'loss': 0.4661, 'learning_rate': 6.784140969162997e-06, 'epoch': 3.3}


 73%|███████▎  | 5000/6810 [41:25<13:47,  2.19it/s]

{'loss': 0.4329, 'learning_rate': 5.3157121879588845e-06, 'epoch': 3.67}


 80%|████████  | 5448/6810 [44:49<09:58,  2.28it/s]
 80%|████████  | 5448/6810 [45:41<09:58,  2.28it/s]

{'eval_loss': 1.6512377262115479, 'eval_accuracy': 0.5884728340675477, 'eval_f1': 0.574765816492121, 'eval_runtime': 51.8194, 'eval_samples_per_second': 52.567, 'eval_steps_per_second': 6.581, 'epoch': 4.0}


 81%|████████  | 5500/6810 [46:17<09:54,  2.20it/s]  

{'loss': 0.4302, 'learning_rate': 3.847283406754773e-06, 'epoch': 4.04}


 88%|████████▊ | 6000/6810 [50:05<06:10,  2.18it/s]

{'loss': 0.2868, 'learning_rate': 2.378854625550661e-06, 'epoch': 4.41}


 95%|█████████▌| 6500/6810 [53:54<02:21,  2.18it/s]

{'loss': 0.3149, 'learning_rate': 9.104258443465493e-07, 'epoch': 4.77}


100%|██████████| 6810/6810 [56:16<00:00,  2.27it/s]
100%|██████████| 6810/6810 [57:08<00:00,  2.27it/s]

{'eval_loss': 1.8411591053009033, 'eval_accuracy': 0.5833333333333334, 'eval_f1': 0.575769359999327, 'eval_runtime': 51.8801, 'eval_samples_per_second': 52.506, 'eval_steps_per_second': 6.573, 'epoch': 5.0}


100%|██████████| 6810/6810 [57:24<00:00,  1.98it/s]

{'train_runtime': 3444.5086, 'train_samples_per_second': 15.815, 'train_steps_per_second': 1.977, 'train_loss': 0.7167956953006694, 'epoch': 5.0}





TrainOutput(global_step=6810, training_loss=0.7167956953006694, metrics={'train_runtime': 3444.5086, 'train_samples_per_second': 15.815, 'train_steps_per_second': 1.977, 'train_loss': 0.7167956953006694, 'epoch': 5.0})

### Evaluate model 

In [16]:
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
import torch 

model_name = "results/bert-base-uncased_com_inp/checkpoint-2724"
num_labels = 7

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels
    ).to(device)


eval_loader = DataLoader(train_test_data["test"], batch_size=16)

model.eval()

true_labels = []
pred_labels = []

with torch.no_grad():
    for batch in eval_loader:
        input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch["label"]
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        # labels = labels.to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Store true and predicted labels
        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(torch.argmax(logits, axis=1).cpu().numpy())


from sklearn.metrics import accuracy_score, classification_report
accuracy = accuracy_score(true_labels, pred_labels)
report = classification_report(true_labels, pred_labels)

print(f"Accuracy: {accuracy}")
print(report)


Accuracy: 0.6020558002936858
              precision    recall  f1-score   support

           0       0.51      0.45      0.48       346
           1       0.27      0.04      0.07        81
           2       0.09      0.03      0.05        61
           3       0.61      0.48      0.54       456
           4       0.43      0.32      0.36       263
           5       0.59      0.60      0.59       329
           6       0.65      0.82      0.73      1188

    accuracy                           0.60      2724
   macro avg       0.45      0.39      0.40      2724
weighted avg       0.57      0.60      0.58      2724

