In [2]:
# Switch to the project directory so the relative paths used later
# ("Data/..." CSV files and the "./models/..." output directory) resolve.
%cd /scratch/mt/ashapiro/Hate_Speech/Multitask_trial/

/scratch/mt/ashapiro/Hate_Speech/Multitask_trial


In [3]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
import logging
from datasets import load_dataset
from model import * 
logging.basicConfig(level=logging.INFO)

# Large Data

## Preparing Data

In [4]:
# Names of the tasks handled in this notebook; the same three strings are the
# keys of every per-task dictionary built below.
task_names = ['offensive', 'hatespeech', 'hatespeech_classes']

In [5]:
# (train CSV, test CSV) pair for each task, relative to the working directory.
csv_files_per_task = {
    "offensive": ("Data/train/trainA_prepro_large.csv", "Data/test/testA_prepro.csv"),
    "hatespeech": ("Data/train/trainB_prepro_large.csv", "Data/test/testB_prepro.csv"),
    "hatespeech_classes": ("Data/train/trainC_prepro.csv", "Data/test/testC_prepro.csv"),
}

# One dataset per task, each loaded with a 'train' and a 'test' split.
dataset_dict = {
    task: load_dataset("csv", data_files={'train': train_csv, 'test': test_csv})
    for task, (train_csv, test_csv) in csv_files_per_task.items()
}

Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1595.40it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 90.82it/s]

Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-a5a3e645893b88fd/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...



100%|██████████| 2/2 [00:00<00:00, 345.54it/s]

Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-a5a3e645893b88fd/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.



Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1402.54it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 102.52it/s]
100%|██████████| 2/2 [00:00<00:00, 485.79it/s]

Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-35a74ee417a0e4c9/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...
Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-35a74ee417a0e4c9/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.



Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1580.67it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 71.24it/s]

Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-efbc25be205baaf6/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...



100%|██████████| 2/2 [00:00<00:00, 536.94it/s]

Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-efbc25be205baaf6/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.





## Setting Up the Model

In [6]:
model_name = "/scratch/mt/ashapiro/Hate_Speech/Models/Marbertv2/"

# Number of labels configured for each task.
num_labels_per_task = {
    "offensive": 2,
    "hatespeech": 2,
    "hatespeech_classes": 7,
}

# All tasks use the same model class; only `num_labels` in the config differs.
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        task: transformers.AutoModelForSequenceClassification
        for task in num_labels_per_task
    },
    model_config_dict={
        task: transformers.AutoConfig.from_pretrained(model_name, num_labels=n_labels)
        for task, n_labels in num_labels_per_task.items()
    },
)

In [7]:
# Load the tokenizer from the same `model_name` location that the model configs use.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [8]:
# Length passed to the tokenizer as `max_length` for every example.
max_length = 512


def convert_to_features(example_batch):
    """Tokenize a batch of examples and attach their labels.

    Args:
        example_batch: Mapping with a 'text' column (iterable of strings) and
            a "labels" column.

    Returns:
        The mapping returned by ``tokenizer.batch_encode_plus`` (called with
        ``max_length=max_length`` and ``pad_to_max_length=True``), with an
        extra "labels" entry copied from ``example_batch``.
    """
    texts = list(example_batch['text'])
    encoded = tokenizer.batch_encode_plus(
        texts,
        max_length=max_length,
        pad_to_max_length=True,
    )
    # The batch's labels are passed through unchanged next to the encodings.
    encoded["labels"] = example_batch["labels"]
    return encoded

# Every task is preprocessed with the same conversion function.
convert_func_dict = dict.fromkeys(
    ("offensive", "hatespeech", "hatespeech_classes"),
    convert_to_features,
)

In [9]:
# Columns kept when a dataset is formatted as torch tensors (one list per task).
columns_dict = {
    task: ['input_ids', 'attention_mask', 'labels']
    for task in ("offensive", "hatespeech", "hatespeech_classes")
}

# features_dict[task_name][phase] holds the mapped dataset for that task/split.
features_dict = {}
for task_name, dataset in dataset_dict.items():
    task_features = features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        encoded = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        task_features[phase] = encoded
        # Row counts are printed before and after set_format.
        print(task_name, phase, len(phase_dataset), len(encoded))
        encoded.set_format(type="torch", columns=columns_dict[task_name])
        print(task_name, phase, len(phase_dataset), len(encoded))

100%|██████████| 20/20 [00:08<00:00,  2.28ba/s]


offensive train 19906 19906
offensive train 19906 19906


100%|██████████| 2/2 [00:00<00:00,  3.89ba/s]


offensive test 1270 1270
offensive test 1270 1270


100%|██████████| 5/5 [00:02<00:00,  2.16ba/s]


hatespeech train 4800 4800
hatespeech train 4800 4800


100%|██████████| 2/2 [00:00<00:00,  3.91ba/s]


hatespeech test 1270 1270
hatespeech test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.48ba/s]


hatespeech_classes train 8887 8887
hatespeech_classes train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.88ba/s]

hatespeech_classes test 1270 1270
hatespeech_classes test 1270 1270





In [10]:
# Per-task "test" split taken from the mapped features.
eval_dataset = {name: splits["test"] for name, splits in features_dict.items()}

In [11]:
# Per-task "train" split taken from the mapped features.
train_dataset = {name: splits["train"] for name, splits in features_dict.items()}

# Training hyper-parameters, gathered in one place before building the arguments.
training_kwargs = dict(
    output_dir="./models/multitask_model/3_main_tasks/large_data/",
    overwrite_output_dir=True,
    learning_rate=2e-5,
    do_train=True,
    num_train_epochs=4,
    # Lower the batch size if it does not fit in GPU memory.
    per_device_train_batch_size=16,
    save_steps=3000,
)
args = transformers.TrainingArguments(**training_kwargs)

trainer = MultitaskTrainer(
    model=multitask_model,
    args=args,
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)

[34m[1mwandb[0m: Currently logged in as: [33mahmadshapiro[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [12]:
# Run training; the value returned by train() is shown as the cell result.
trainer.train()

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]
Iteration:   0%|          | 0/2101 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/2101 [00:00<20:54,  1.67it/s][A
Iteration:   0%|          | 2/2101 [00:01<19:39,  1.78it/s][A
Iteration:   0%|          | 3/2101 [00:01<18:53,  1.85it/s][A
Iteration:   0%|          | 4/2101 [00:02<18:49,  1.86it/s][A
Iteration:   0%|          | 5/2101 [00:02<18:59,  1.84it/s][A
Iteration:   0%|          | 6/2101 [00:03<19:09,  1.82it/s][A
Iteration:   0%|          | 7/2101 [00:03<19:15,  1.81it/s][A
Iteration:   0%|          | 8/2101 [00:04<19:17,  1.81it/s][A
Iteration:   0%|          | 9/2101 [00:04<18:47,  1.85it/s][A
Iteration:   0%|          | 10/2101 [00:05<18:45,  1.86it/s][A
Iteration:   1%|          | 11/2101 [00:05<18:40,  1.86it/s][A
Iteration:   1%|          | 12/2101 [00:06<18:44,  1.86it/s][A
Iteration:   1%|          | 13/2101 [00:07<18:33,  1.87it/s][A
Iteration:   1%|          | 14/2101 [00:07<18:40,  1.86it/s][A
Iteration:   

{"loss": 0.49617886539548633, "learning_rate": 1.8810090433127083e-05, "epoch": 0.23798191337458352, "step": 500}



Iteration:  24%|██▍       | 501/2101 [04:30<15:10,  1.76it/s][A
Iteration:  24%|██▍       | 502/2101 [04:31<14:47,  1.80it/s][A
Iteration:  24%|██▍       | 503/2101 [04:31<14:22,  1.85it/s][A
Iteration:  24%|██▍       | 504/2101 [04:32<14:18,  1.86it/s][A
Iteration:  24%|██▍       | 505/2101 [04:32<14:20,  1.86it/s][A
Iteration:  24%|██▍       | 506/2101 [04:33<14:23,  1.85it/s][A
Iteration:  24%|██▍       | 507/2101 [04:33<14:23,  1.85it/s][A
Iteration:  24%|██▍       | 508/2101 [04:34<14:20,  1.85it/s][A
Iteration:  24%|██▍       | 509/2101 [04:34<14:19,  1.85it/s][A
Iteration:  24%|██▍       | 510/2101 [04:35<14:17,  1.86it/s][A
Iteration:  24%|██▍       | 511/2101 [04:35<14:16,  1.86it/s][A
Iteration:  24%|██▍       | 512/2101 [04:36<14:16,  1.85it/s][A
Iteration:  24%|██▍       | 513/2101 [04:37<14:16,  1.85it/s][A
Iteration:  24%|██▍       | 514/2101 [04:37<14:18,  1.85it/s][A
Iteration:  25%|██▍       | 515/2101 [04:38<14:20,  1.84it/s][A
Iteration:  25%|██▍     

{"loss": 0.3969753061942756, "learning_rate": 1.7620180866254168e-05, "epoch": 0.47596382674916704, "step": 1000}



Iteration:  48%|████▊     | 1001/2101 [08:59<10:22,  1.77it/s][A
Iteration:  48%|████▊     | 1002/2101 [09:00<10:11,  1.80it/s][A
Iteration:  48%|████▊     | 1003/2101 [09:00<10:05,  1.81it/s][A
Iteration:  48%|████▊     | 1004/2101 [09:01<10:02,  1.82it/s][A
Iteration:  48%|████▊     | 1005/2101 [09:01<09:59,  1.83it/s][A
Iteration:  48%|████▊     | 1006/2101 [09:02<09:56,  1.83it/s][A
Iteration:  48%|████▊     | 1007/2101 [09:02<09:52,  1.85it/s][A
Iteration:  48%|████▊     | 1008/2101 [09:03<09:51,  1.85it/s][A
Iteration:  48%|████▊     | 1009/2101 [09:03<09:51,  1.85it/s][A
Iteration:  48%|████▊     | 1010/2101 [09:04<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1011/2101 [09:04<09:49,  1.85it/s][A
Iteration:  48%|████▊     | 1012/2101 [09:05<09:48,  1.85it/s][A
Iteration:  48%|████▊     | 1013/2101 [09:06<09:47,  1.85it/s][A
Iteration:  48%|████▊     | 1014/2101 [09:06<09:46,  1.85it/s][A
Iteration:  48%|████▊     | 1015/2101 [09:07<09:45,  1.85it/s][A
Iteration

{"loss": 0.361205443199724, "learning_rate": 1.643027129938125e-05, "epoch": 0.7139457401237506, "step": 1500}



Iteration:  71%|███████▏  | 1501/2101 [13:30<05:43,  1.75it/s][A
Iteration:  71%|███████▏  | 1502/2101 [13:30<05:37,  1.78it/s][A
Iteration:  72%|███████▏  | 1503/2101 [13:31<05:33,  1.79it/s][A
Iteration:  72%|███████▏  | 1504/2101 [13:31<05:29,  1.81it/s][A
Iteration:  72%|███████▏  | 1505/2101 [13:32<05:25,  1.83it/s][A
Iteration:  72%|███████▏  | 1506/2101 [13:32<05:24,  1.84it/s][A
Iteration:  72%|███████▏  | 1507/2101 [13:33<05:23,  1.84it/s][A
Iteration:  72%|███████▏  | 1508/2101 [13:33<05:20,  1.85it/s][A
Iteration:  72%|███████▏  | 1509/2101 [13:34<05:20,  1.85it/s][A
Iteration:  72%|███████▏  | 1510/2101 [13:35<05:20,  1.85it/s][A
Iteration:  72%|███████▏  | 1511/2101 [13:35<05:20,  1.84it/s][A
Iteration:  72%|███████▏  | 1512/2101 [13:36<05:20,  1.84it/s][A
Iteration:  72%|███████▏  | 1513/2101 [13:36<05:18,  1.84it/s][A
Iteration:  72%|███████▏  | 1514/2101 [13:37<05:16,  1.85it/s][A
Iteration:  72%|███████▏  | 1515/2101 [13:37<05:15,  1.86it/s][A
Iteration

{"loss": 0.3321925585283898, "learning_rate": 1.524036173250833e-05, "epoch": 0.9519276534983341, "step": 2000}



Iteration:  95%|█████████▌| 2001/2101 [17:59<00:57,  1.75it/s][A
Iteration:  95%|█████████▌| 2002/2101 [18:00<00:55,  1.78it/s][A
Iteration:  95%|█████████▌| 2003/2101 [18:00<00:54,  1.80it/s][A
Iteration:  95%|█████████▌| 2004/2101 [18:01<00:53,  1.81it/s][A
Iteration:  95%|█████████▌| 2005/2101 [18:01<00:52,  1.82it/s][A
Iteration:  95%|█████████▌| 2006/2101 [18:02<00:52,  1.83it/s][A
Iteration:  96%|█████████▌| 2007/2101 [18:02<00:51,  1.83it/s][A
Iteration:  96%|█████████▌| 2008/2101 [18:03<00:50,  1.84it/s][A
Iteration:  96%|█████████▌| 2009/2101 [18:03<00:49,  1.84it/s][A
Iteration:  96%|█████████▌| 2010/2101 [18:04<00:49,  1.84it/s][A
Iteration:  96%|█████████▌| 2011/2101 [18:04<00:48,  1.85it/s][A
Iteration:  96%|█████████▌| 2012/2101 [18:05<00:47,  1.86it/s][A
Iteration:  96%|█████████▌| 2013/2101 [18:06<00:47,  1.85it/s][A
Iteration:  96%|█████████▌| 2014/2101 [18:06<00:46,  1.85it/s][A
Iteration:  96%|█████████▌| 2015/2101 [18:07<00:46,  1.86it/s][A
Iteration

{"loss": 0.26052306150924415, "learning_rate": 1.4050452165635413e-05, "epoch": 1.1899095668729176, "step": 2500}



Iteration:  19%|█▉        | 400/2101 [03:35<16:16,  1.74it/s][A
Iteration:  19%|█▉        | 401/2101 [03:36<15:59,  1.77it/s][A
Iteration:  19%|█▉        | 402/2101 [03:37<15:49,  1.79it/s][A
Iteration:  19%|█▉        | 403/2101 [03:37<15:39,  1.81it/s][A
Iteration:  19%|█▉        | 404/2101 [03:38<15:32,  1.82it/s][A
Iteration:  19%|█▉        | 405/2101 [03:38<15:22,  1.84it/s][A
Iteration:  19%|█▉        | 406/2101 [03:39<15:19,  1.84it/s][A
Iteration:  19%|█▉        | 407/2101 [03:39<15:17,  1.85it/s][A
Iteration:  19%|█▉        | 408/2101 [03:40<15:18,  1.84it/s][A
Iteration:  19%|█▉        | 409/2101 [03:40<15:17,  1.84it/s][A
Iteration:  20%|█▉        | 410/2101 [03:41<15:18,  1.84it/s][A
Iteration:  20%|█▉        | 411/2101 [03:41<15:03,  1.87it/s][A
Iteration:  20%|█▉        | 412/2101 [03:42<15:06,  1.86it/s][A
Iteration:  20%|█▉        | 413/2101 [03:42<15:06,  1.86it/s][A
Iteration:  20%|█▉        | 414/2101 [03:43<15:07,  1.86it/s][A
Iteration:  20%|█▉      

{"loss": 0.2272716256796848, "learning_rate": 1.2860542598762496e-05, "epoch": 1.4278914802475011, "step": 3000}



Iteration:  43%|████▎     | 899/2101 [08:08<30:03,  1.50s/it][A
Iteration:  43%|████▎     | 900/2101 [08:08<24:17,  1.21s/it][A
Iteration:  43%|████▎     | 901/2101 [08:09<20:13,  1.01s/it][A
Iteration:  43%|████▎     | 902/2101 [08:09<17:18,  1.15it/s][A
Iteration:  43%|████▎     | 903/2101 [08:10<15:18,  1.30it/s][A
Iteration:  43%|████▎     | 904/2101 [08:10<13:56,  1.43it/s][A
Iteration:  43%|████▎     | 905/2101 [08:11<13:00,  1.53it/s][A
Iteration:  43%|████▎     | 906/2101 [08:11<12:17,  1.62it/s][A
Iteration:  43%|████▎     | 907/2101 [08:12<11:47,  1.69it/s][A
Iteration:  43%|████▎     | 908/2101 [08:12<11:28,  1.73it/s][A
Iteration:  43%|████▎     | 909/2101 [08:13<11:15,  1.76it/s][A
Iteration:  43%|████▎     | 910/2101 [08:13<11:04,  1.79it/s][A
Iteration:  43%|████▎     | 911/2101 [08:14<10:59,  1.80it/s][A
Iteration:  43%|████▎     | 912/2101 [08:15<10:54,  1.82it/s][A
Iteration:  43%|████▎     | 913/2101 [08:15<10:49,  1.83it/s][A
Iteration:  44%|████▎   

{"loss": 0.2342777248180937, "learning_rate": 1.1670633031889577e-05, "epoch": 1.6658733936220846, "step": 3500}



Iteration:  67%|██████▋   | 1400/2101 [12:37<06:39,  1.75it/s][A
Iteration:  67%|██████▋   | 1401/2101 [12:38<06:33,  1.78it/s][A
Iteration:  67%|██████▋   | 1402/2101 [12:38<06:27,  1.81it/s][A
Iteration:  67%|██████▋   | 1403/2101 [12:39<06:18,  1.84it/s][A
Iteration:  67%|██████▋   | 1404/2101 [12:39<06:16,  1.85it/s][A
Iteration:  67%|██████▋   | 1405/2101 [12:40<06:16,  1.85it/s][A
Iteration:  67%|██████▋   | 1406/2101 [12:40<06:12,  1.87it/s][A
Iteration:  67%|██████▋   | 1407/2101 [12:41<06:08,  1.88it/s][A
Iteration:  67%|██████▋   | 1408/2101 [12:41<06:09,  1.88it/s][A
Iteration:  67%|██████▋   | 1409/2101 [12:42<06:10,  1.87it/s][A
Iteration:  67%|██████▋   | 1410/2101 [12:43<06:09,  1.87it/s][A
Iteration:  67%|██████▋   | 1411/2101 [12:43<06:10,  1.86it/s][A
Iteration:  67%|██████▋   | 1412/2101 [12:44<06:10,  1.86it/s][A
Iteration:  67%|██████▋   | 1413/2101 [12:44<06:09,  1.86it/s][A
Iteration:  67%|██████▋   | 1414/2101 [12:45<06:09,  1.86it/s][A
Iteration

{"loss": 0.2253049932155991, "learning_rate": 1.048072346501666e-05, "epoch": 1.9038553069966682, "step": 4000}



Iteration:  90%|█████████ | 1900/2101 [17:06<01:55,  1.74it/s][A
Iteration:  90%|█████████ | 1901/2101 [17:07<01:52,  1.77it/s][A
Iteration:  91%|█████████ | 1902/2101 [17:07<01:50,  1.80it/s][A
Iteration:  91%|█████████ | 1903/2101 [17:08<01:49,  1.81it/s][A
Iteration:  91%|█████████ | 1904/2101 [17:08<01:47,  1.82it/s][A
Iteration:  91%|█████████ | 1905/2101 [17:09<01:46,  1.84it/s][A
Iteration:  91%|█████████ | 1906/2101 [17:09<01:45,  1.84it/s][A
Iteration:  91%|█████████ | 1907/2101 [17:10<01:45,  1.84it/s][A
Iteration:  91%|█████████ | 1908/2101 [17:10<01:44,  1.84it/s][A
Iteration:  91%|█████████ | 1909/2101 [17:11<01:44,  1.84it/s][A
Iteration:  91%|█████████ | 1910/2101 [17:12<01:43,  1.85it/s][A
Iteration:  91%|█████████ | 1911/2101 [17:12<01:43,  1.84it/s][A
Iteration:  91%|█████████ | 1912/2101 [17:13<01:42,  1.85it/s][A
Iteration:  91%|█████████ | 1913/2101 [17:13<01:41,  1.85it/s][A
Iteration:  91%|█████████ | 1914/2101 [17:14<01:41,  1.85it/s][A
Iteration

{"loss": 0.18151792278746143, "learning_rate": 9.290813898143742e-06, "epoch": 2.1418372203712517, "step": 4500}



Iteration:  14%|█▍        | 299/2101 [02:41<16:52,  1.78it/s][A
Iteration:  14%|█▍        | 300/2101 [02:41<16:28,  1.82it/s][A
Iteration:  14%|█▍        | 301/2101 [02:42<16:21,  1.83it/s][A
Iteration:  14%|█▍        | 302/2101 [02:42<16:15,  1.84it/s][A
Iteration:  14%|█▍        | 303/2101 [02:43<16:13,  1.85it/s][A
Iteration:  14%|█▍        | 304/2101 [02:43<16:11,  1.85it/s][A
Iteration:  15%|█▍        | 305/2101 [02:44<16:10,  1.85it/s][A
Iteration:  15%|█▍        | 306/2101 [02:44<15:54,  1.88it/s][A
Iteration:  15%|█▍        | 307/2101 [02:45<15:54,  1.88it/s][A
Iteration:  15%|█▍        | 308/2101 [02:45<15:55,  1.88it/s][A
Iteration:  15%|█▍        | 309/2101 [02:46<15:58,  1.87it/s][A
Iteration:  15%|█▍        | 310/2101 [02:46<15:54,  1.88it/s][A
Iteration:  15%|█▍        | 311/2101 [02:47<15:55,  1.87it/s][A
Iteration:  15%|█▍        | 312/2101 [02:48<15:56,  1.87it/s][A
Iteration:  15%|█▍        | 313/2101 [02:48<15:55,  1.87it/s][A
Iteration:  15%|█▍      

{"loss": 0.14009080386220013, "learning_rate": 8.100904331270823e-06, "epoch": 2.379819133745835, "step": 5000}



Iteration:  38%|███▊      | 799/2101 [07:10<12:30,  1.74it/s][A
Iteration:  38%|███▊      | 800/2101 [07:11<12:15,  1.77it/s][A
Iteration:  38%|███▊      | 801/2101 [07:11<12:05,  1.79it/s][A
Iteration:  38%|███▊      | 802/2101 [07:12<11:57,  1.81it/s][A
Iteration:  38%|███▊      | 803/2101 [07:12<11:51,  1.83it/s][A
Iteration:  38%|███▊      | 804/2101 [07:13<11:49,  1.83it/s][A
Iteration:  38%|███▊      | 805/2101 [07:13<11:49,  1.83it/s][A
Iteration:  38%|███▊      | 806/2101 [07:14<11:48,  1.83it/s][A
Iteration:  38%|███▊      | 807/2101 [07:15<11:45,  1.83it/s][A
Iteration:  38%|███▊      | 808/2101 [07:15<11:40,  1.85it/s][A
Iteration:  39%|███▊      | 809/2101 [07:16<11:39,  1.85it/s][A
Iteration:  39%|███▊      | 810/2101 [07:16<11:38,  1.85it/s][A
Iteration:  39%|███▊      | 811/2101 [07:17<11:39,  1.84it/s][A
Iteration:  39%|███▊      | 812/2101 [07:17<11:38,  1.85it/s][A
Iteration:  39%|███▊      | 813/2101 [07:18<11:37,  1.85it/s][A
Iteration:  39%|███▊    

{"loss": 0.13337958228413482, "learning_rate": 6.910994764397906e-06, "epoch": 2.6178010471204187, "step": 5500}



Iteration:  62%|██████▏   | 1299/2101 [11:40<07:43,  1.73it/s][A
Iteration:  62%|██████▏   | 1300/2101 [11:41<07:34,  1.76it/s][A
Iteration:  62%|██████▏   | 1301/2101 [11:41<07:28,  1.78it/s][A
Iteration:  62%|██████▏   | 1302/2101 [11:42<07:24,  1.80it/s][A
Iteration:  62%|██████▏   | 1303/2101 [11:42<07:19,  1.81it/s][A
Iteration:  62%|██████▏   | 1304/2101 [11:43<07:16,  1.82it/s][A
Iteration:  62%|██████▏   | 1305/2101 [11:43<07:16,  1.83it/s][A
Iteration:  62%|██████▏   | 1306/2101 [11:44<07:13,  1.83it/s][A
Iteration:  62%|██████▏   | 1307/2101 [11:45<07:11,  1.84it/s][A
Iteration:  62%|██████▏   | 1308/2101 [11:45<07:11,  1.84it/s][A
Iteration:  62%|██████▏   | 1309/2101 [11:46<07:12,  1.83it/s][A
Iteration:  62%|██████▏   | 1310/2101 [11:46<07:09,  1.84it/s][A
Iteration:  62%|██████▏   | 1311/2101 [11:47<07:08,  1.84it/s][A
Iteration:  62%|██████▏   | 1312/2101 [11:47<07:09,  1.84it/s][A
Iteration:  62%|██████▏   | 1313/2101 [11:48<07:08,  1.84it/s][A
Iteration

{"loss": 0.15297743409778922, "learning_rate": 5.721085197524988e-06, "epoch": 2.8557829604950022, "step": 6000}



Iteration:  86%|████████▌ | 1798/2101 [16:12<07:00,  1.39s/it][A
Iteration:  86%|████████▌ | 1799/2101 [16:12<05:42,  1.13s/it][A
Iteration:  86%|████████▌ | 1800/2101 [16:13<04:47,  1.05it/s][A
Iteration:  86%|████████▌ | 1801/2101 [16:13<04:10,  1.20it/s][A
Iteration:  86%|████████▌ | 1802/2101 [16:14<03:43,  1.34it/s][A
Iteration:  86%|████████▌ | 1803/2101 [16:15<03:24,  1.46it/s][A
Iteration:  86%|████████▌ | 1804/2101 [16:15<03:10,  1.56it/s][A
Iteration:  86%|████████▌ | 1805/2101 [16:16<03:00,  1.64it/s][A
Iteration:  86%|████████▌ | 1806/2101 [16:16<02:53,  1.70it/s][A
Iteration:  86%|████████▌ | 1807/2101 [16:17<02:49,  1.73it/s][A
Iteration:  86%|████████▌ | 1808/2101 [16:17<02:45,  1.77it/s][A
Iteration:  86%|████████▌ | 1809/2101 [16:18<02:43,  1.79it/s][A
Iteration:  86%|████████▌ | 1810/2101 [16:18<02:40,  1.81it/s][A
Iteration:  86%|████████▌ | 1811/2101 [16:19<02:39,  1.82it/s][A
Iteration:  86%|████████▌ | 1812/2101 [16:19<02:38,  1.83it/s][A
Iteration

{"loss": 0.11536938468395966, "learning_rate": 4.531175630652071e-06, "epoch": 3.0937648738695858, "step": 6500}



Iteration:   9%|▉         | 198/2101 [01:46<18:13,  1.74it/s][A
Iteration:   9%|▉         | 199/2101 [01:47<17:53,  1.77it/s][A
Iteration:  10%|▉         | 200/2101 [01:47<17:41,  1.79it/s][A
Iteration:  10%|▉         | 201/2101 [01:48<17:28,  1.81it/s][A
Iteration:  10%|▉         | 202/2101 [01:49<17:20,  1.82it/s][A
Iteration:  10%|▉         | 203/2101 [01:49<17:17,  1.83it/s][A
Iteration:  10%|▉         | 204/2101 [01:50<17:10,  1.84it/s][A
Iteration:  10%|▉         | 205/2101 [01:50<17:08,  1.84it/s][A
Iteration:  10%|▉         | 206/2101 [01:51<17:06,  1.85it/s][A
Iteration:  10%|▉         | 207/2101 [01:51<17:07,  1.84it/s][A
Iteration:  10%|▉         | 208/2101 [01:52<17:03,  1.85it/s][A
Iteration:  10%|▉         | 209/2101 [01:52<17:01,  1.85it/s][A
Iteration:  10%|▉         | 210/2101 [01:53<17:03,  1.85it/s][A
Iteration:  10%|█         | 211/2101 [01:53<17:00,  1.85it/s][A
Iteration:  10%|█         | 212/2101 [01:54<17:00,  1.85it/s][A
Iteration:  10%|█       

{"loss": 0.09109661727852654, "learning_rate": 3.3412660637791533e-06, "epoch": 3.3317467872441693, "step": 7000}



Iteration:  33%|███▎      | 698/2101 [06:16<13:24,  1.74it/s][A
Iteration:  33%|███▎      | 699/2101 [06:16<13:12,  1.77it/s][A
Iteration:  33%|███▎      | 700/2101 [06:17<13:01,  1.79it/s][A
Iteration:  33%|███▎      | 701/2101 [06:17<12:38,  1.84it/s][A
Iteration:  33%|███▎      | 702/2101 [06:18<12:32,  1.86it/s][A
Iteration:  33%|███▎      | 703/2101 [06:18<12:35,  1.85it/s][A
Iteration:  34%|███▎      | 704/2101 [06:19<12:34,  1.85it/s][A
Iteration:  34%|███▎      | 705/2101 [06:20<12:31,  1.86it/s][A
Iteration:  34%|███▎      | 706/2101 [06:20<12:31,  1.86it/s][A
Iteration:  34%|███▎      | 707/2101 [06:21<12:27,  1.86it/s][A
Iteration:  34%|███▎      | 708/2101 [06:21<12:25,  1.87it/s][A
Iteration:  34%|███▎      | 709/2101 [06:22<12:21,  1.88it/s][A
Iteration:  34%|███▍      | 710/2101 [06:22<12:19,  1.88it/s][A
Iteration:  34%|███▍      | 711/2101 [06:23<12:19,  1.88it/s][A
Iteration:  34%|███▍      | 712/2101 [06:23<12:15,  1.89it/s][A
Iteration:  34%|███▍    

{"loss": 0.09081352867558598, "learning_rate": 2.1513564969062355e-06, "epoch": 3.569728700618753, "step": 7500}



Iteration:  57%|█████▋    | 1198/2101 [10:45<08:41,  1.73it/s][A
Iteration:  57%|█████▋    | 1199/2101 [10:45<08:30,  1.77it/s][A
Iteration:  57%|█████▋    | 1200/2101 [10:46<08:22,  1.79it/s][A
Iteration:  57%|█████▋    | 1201/2101 [10:46<08:19,  1.80it/s][A
Iteration:  57%|█████▋    | 1202/2101 [10:47<08:15,  1.81it/s][A
Iteration:  57%|█████▋    | 1203/2101 [10:47<08:12,  1.82it/s][A
Iteration:  57%|█████▋    | 1204/2101 [10:48<08:09,  1.83it/s][A
Iteration:  57%|█████▋    | 1205/2101 [10:49<08:07,  1.84it/s][A
Iteration:  57%|█████▋    | 1206/2101 [10:49<08:05,  1.85it/s][A
Iteration:  57%|█████▋    | 1207/2101 [10:50<08:05,  1.84it/s][A
Iteration:  57%|█████▋    | 1208/2101 [10:50<08:05,  1.84it/s][A
Iteration:  58%|█████▊    | 1209/2101 [10:51<08:03,  1.84it/s][A
Iteration:  58%|█████▊    | 1210/2101 [10:51<08:02,  1.85it/s][A
Iteration:  58%|█████▊    | 1211/2101 [10:52<08:03,  1.84it/s][A
Iteration:  58%|█████▊    | 1212/2101 [10:52<08:02,  1.84it/s][A
Iteration

{"loss": 0.09069429180584848, "learning_rate": 9.614469300333174e-07, "epoch": 3.8077106139933363, "step": 8000}



Iteration:  81%|████████  | 1698/2101 [15:15<03:53,  1.73it/s][A
Iteration:  81%|████████  | 1699/2101 [15:15<03:48,  1.76it/s][A
Iteration:  81%|████████  | 1700/2101 [15:16<03:44,  1.79it/s][A
Iteration:  81%|████████  | 1701/2101 [15:17<03:41,  1.81it/s][A
Iteration:  81%|████████  | 1702/2101 [15:17<03:38,  1.82it/s][A
Iteration:  81%|████████  | 1703/2101 [15:18<03:38,  1.83it/s][A
Iteration:  81%|████████  | 1704/2101 [15:18<03:36,  1.83it/s][A
Iteration:  81%|████████  | 1705/2101 [15:19<03:35,  1.83it/s][A
Iteration:  81%|████████  | 1706/2101 [15:19<03:35,  1.83it/s][A
Iteration:  81%|████████  | 1707/2101 [15:20<03:34,  1.84it/s][A
Iteration:  81%|████████▏ | 1708/2101 [15:20<03:34,  1.83it/s][A
Iteration:  81%|████████▏ | 1709/2101 [15:21<03:33,  1.84it/s][A
Iteration:  81%|████████▏ | 1710/2101 [15:21<03:32,  1.84it/s][A
Iteration:  81%|████████▏ | 1711/2101 [15:22<03:31,  1.85it/s][A
Iteration:  81%|████████▏ | 1712/2101 [15:23<03:30,  1.85it/s][A
Iteration

TrainOutput(global_step=8404, training_loss=0.21418483871537108)

In [13]:
# Re-declares the same task list as the start of the notebook; evaluate()
# reads this module-level name when looping over tasks.
task_names = ['offensive','hatespeech','hatespeech_classes']

In [32]:
def evaluate(trainer, features_dict, epochs):
    """Predict on every task's test split, print a report per task, and return the predictions.

    Args:
        trainer: object providing ``get_eval_dataloader(eval_dataset=...)`` and
            ``_prediction_loop(dataloader, description=...)``.
        features_dict: mapping ``task name -> {"test": dataset, ...}``.
        epochs: number of epochs trained so far; only used in the printed heading.

    Returns:
        dict mapping each task name to the result of ``trainer._prediction_loop``
        (read below through its ``.predictions`` and ``.label_ids`` attributes).
        Assign the result (``preds_dict = evaluate(...)``) so that later metric
        cells can reuse the predictions and the notebook does not echo the dict.
    """
    from sklearn.metrics import classification_report

    preds_dict = {}
    # `task_names` is the notebook-level list of task names.
    for task_name in task_names:
        eval_dataloader = DataLoaderWithTaskname(
            task_name,
            trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["test"]),
        )
        print(eval_dataloader.data_loader.collate_fn)
        task_output = trainer._prediction_loop(
            eval_dataloader,
            description=f"Test: {task_name}",
        )
        preds_dict[task_name] = task_output
        print(f"Classification Report for Task {task_name} trained for {epochs} epochs")
        print(
            classification_report(
                y_true=task_output.label_ids,
                # argmax over axis 1 picks the highest-scoring class per example.
                y_pred=np.argmax(task_output.predictions, axis=1),
                digits=16,
                # Classes that are never predicted score 0 either way; being
                # explicit only silences the UndefinedMetricWarning.
                zero_division=0,
            )
        )
    return preds_dict

In [26]:
evaluate(trainer, features_dict, epochs=4)

Test: offensive:   1%|          | 1/159 [00:00<00:16,  9.79it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.99it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:13, 11.60it/s]

Classification Report for Task offensive trained for 4 epochs
              precision    recall  f1-score   support

           0       0.92      0.86      0.89       866
           1       0.73      0.83      0.78       404

    accuracy                           0.85      1270
   macro avg       0.82      0.84      0.83      1270
weighted avg       0.86      0.85      0.85      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 11.98it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 12.02it/s]

Classification Report for Task hatespeech trained for 4 epochs
              precision    recall  f1-score   support

           0       0.98      0.96      0.97      1161
           1       0.65      0.75      0.70       109

    accuracy                           0.94      1270
   macro avg       0.81      0.86      0.83      1270
weighted avg       0.95      0.94      0.95      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.93it/s]

Classification Report for Task hatespeech_classes trained for 4 epochs
              precision    recall  f1-score   support

           0       0.97      0.97      0.97      1161
           1       0.58      0.64      0.61        28
           2       0.00      0.00      0.00         4
           3       0.56      0.64      0.60        14
           4       0.00      0.00      0.00         1
           5       0.20      0.10      0.13        10
           6       0.67      0.77      0.71        52

    accuracy                           0.94      1270
   macro avg       0.43      0.45      0.43      1270
weighted avg       0.94      0.94      0.94      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Let's train for one more epoch

In [27]:
# Reuse the existing trainer for one extra epoch on top of the run above.
# NOTE(review): the recorded log of the following train() call starts again at
# step 500 with learning_rate ~1.5e-05, so the step counter and LR schedule
# start over instead of continuing the earlier decay — confirm this is intended.
trainer.args.num_train_epochs = 1

In [28]:
trainer.train()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|          | 0/2101 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/2101 [00:00<18:22,  1.90it/s][A
Iteration:   0%|          | 2/2101 [00:01<17:57,  1.95it/s][A
Iteration:   0%|          | 3/2101 [00:01<18:14,  1.92it/s][A
Iteration:   0%|          | 4/2101 [00:02<18:25,  1.90it/s][A
Iteration:   0%|          | 5/2101 [00:02<18:30,  1.89it/s][A
Iteration:   0%|          | 6/2101 [00:03<18:38,  1.87it/s][A
Iteration:   0%|          | 7/2101 [00:03<18:41,  1.87it/s][A
Iteration:   0%|          | 8/2101 [00:04<18:40,  1.87it/s][A
Iteration:   0%|          | 9/2101 [00:04<18:41,  1.86it/s][A
Iteration:   0%|          | 10/2101 [00:05<18:43,  1.86it/s][A
Iteration:   1%|          | 11/2101 [00:05<18:44,  1.86it/s][A
Iteration:   1%|          | 12/2101 [00:06<18:45,  1.86it/s][A
Iteration:   1%|          | 13/2101 [00:06<18:45,  1.85it/s][A
Iteration:   1%|          | 14/2101 [00:07<18:46,  1.85it/s][A
Iteration:   

{"loss": 0.12110608770386898, "learning_rate": 1.524036173250833e-05, "epoch": 0.23798191337458352, "step": 500}



Iteration:  24%|██▍       | 501/2101 [04:29<14:27,  1.84it/s][A
Iteration:  24%|██▍       | 502/2101 [04:30<14:25,  1.85it/s][A
Iteration:  24%|██▍       | 503/2101 [04:30<14:23,  1.85it/s][A
Iteration:  24%|██▍       | 504/2101 [04:31<14:22,  1.85it/s][A
Iteration:  24%|██▍       | 505/2101 [04:31<14:21,  1.85it/s][A
Iteration:  24%|██▍       | 506/2101 [04:32<14:23,  1.85it/s][A
Iteration:  24%|██▍       | 507/2101 [04:33<14:24,  1.84it/s][A
Iteration:  24%|██▍       | 508/2101 [04:33<14:23,  1.84it/s][A
Iteration:  24%|██▍       | 509/2101 [04:34<14:21,  1.85it/s][A
Iteration:  24%|██▍       | 510/2101 [04:34<14:23,  1.84it/s][A
Iteration:  24%|██▍       | 511/2101 [04:35<14:22,  1.84it/s][A
Iteration:  24%|██▍       | 512/2101 [04:35<14:22,  1.84it/s][A
Iteration:  24%|██▍       | 513/2101 [04:36<14:21,  1.84it/s][A
Iteration:  24%|██▍       | 514/2101 [04:36<14:20,  1.84it/s][A
Iteration:  25%|██▍       | 515/2101 [04:37<14:20,  1.84it/s][A
Iteration:  25%|██▍     

{"loss": 0.12189705061679706, "learning_rate": 1.048072346501666e-05, "epoch": 0.47596382674916704, "step": 1000}



Iteration:  48%|████▊     | 1001/2101 [08:59<09:58,  1.84it/s][A
Iteration:  48%|████▊     | 1002/2101 [09:00<09:57,  1.84it/s][A
Iteration:  48%|████▊     | 1003/2101 [09:00<09:54,  1.85it/s][A
Iteration:  48%|████▊     | 1004/2101 [09:01<09:53,  1.85it/s][A
Iteration:  48%|████▊     | 1005/2101 [09:01<09:52,  1.85it/s][A
Iteration:  48%|████▊     | 1006/2101 [09:02<09:51,  1.85it/s][A
Iteration:  48%|████▊     | 1007/2101 [09:02<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1008/2101 [09:03<09:51,  1.85it/s][A
Iteration:  48%|████▊     | 1009/2101 [09:04<09:53,  1.84it/s][A
Iteration:  48%|████▊     | 1010/2101 [09:04<09:51,  1.84it/s][A
Iteration:  48%|████▊     | 1011/2101 [09:05<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1012/2101 [09:05<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1013/2101 [09:06<09:49,  1.85it/s][A
Iteration:  48%|████▊     | 1014/2101 [09:06<09:50,  1.84it/s][A
Iteration:  48%|████▊     | 1015/2101 [09:07<09:48,  1.84it/s][A
Iteration

{"loss": 0.09711660955034312, "learning_rate": 5.721085197524988e-06, "epoch": 0.7139457401237506, "step": 1500}



Iteration:  71%|███████▏  | 1501/2101 [13:28<05:24,  1.85it/s][A
Iteration:  71%|███████▏  | 1502/2101 [13:28<05:24,  1.85it/s][A
Iteration:  72%|███████▏  | 1503/2101 [13:29<05:23,  1.85it/s][A
Iteration:  72%|███████▏  | 1504/2101 [13:29<05:23,  1.85it/s][A
Iteration:  72%|███████▏  | 1505/2101 [13:30<05:22,  1.85it/s][A
Iteration:  72%|███████▏  | 1506/2101 [13:30<05:21,  1.85it/s][A
Iteration:  72%|███████▏  | 1507/2101 [13:31<05:21,  1.85it/s][A
Iteration:  72%|███████▏  | 1508/2101 [13:31<05:19,  1.86it/s][A
Iteration:  72%|███████▏  | 1509/2101 [13:32<05:19,  1.85it/s][A
Iteration:  72%|███████▏  | 1510/2101 [13:32<05:19,  1.85it/s][A
Iteration:  72%|███████▏  | 1511/2101 [13:33<05:19,  1.85it/s][A
Iteration:  72%|███████▏  | 1512/2101 [13:34<05:18,  1.85it/s][A
Iteration:  72%|███████▏  | 1513/2101 [13:34<05:18,  1.85it/s][A
Iteration:  72%|███████▏  | 1514/2101 [13:35<05:17,  1.85it/s][A
Iteration:  72%|███████▏  | 1515/2101 [13:35<05:16,  1.85it/s][A
Iteration

{"loss": 0.10913289758379688, "learning_rate": 9.614469300333174e-07, "epoch": 0.9519276534983341, "step": 2000}



Iteration:  95%|█████████▌| 2001/2101 [17:56<00:52,  1.91it/s][A
Iteration:  95%|█████████▌| 2002/2101 [17:57<00:51,  1.90it/s][A
Iteration:  95%|█████████▌| 2003/2101 [17:58<00:51,  1.90it/s][A
Iteration:  95%|█████████▌| 2004/2101 [17:58<00:51,  1.89it/s][A
Iteration:  95%|█████████▌| 2005/2101 [17:59<00:50,  1.89it/s][A
Iteration:  95%|█████████▌| 2006/2101 [17:59<00:50,  1.88it/s][A
Iteration:  96%|█████████▌| 2007/2101 [18:00<00:50,  1.87it/s][A
Iteration:  96%|█████████▌| 2008/2101 [18:00<00:49,  1.88it/s][A
Iteration:  96%|█████████▌| 2009/2101 [18:01<00:48,  1.90it/s][A
Iteration:  96%|█████████▌| 2010/2101 [18:01<00:48,  1.89it/s][A
Iteration:  96%|█████████▌| 2011/2101 [18:02<00:47,  1.88it/s][A
Iteration:  96%|█████████▌| 2012/2101 [18:02<00:47,  1.87it/s][A
Iteration:  96%|█████████▌| 2013/2101 [18:03<00:47,  1.87it/s][A
Iteration:  96%|█████████▌| 2014/2101 [18:03<00:46,  1.86it/s][A
Iteration:  96%|█████████▌| 2015/2101 [18:04<00:46,  1.86it/s][A
Iteration

TrainOutput(global_step=2101, training_loss=0.11203017262371649)

In [33]:
evaluate(trainer, features_dict, epochs=5)

Test: offensive:   1%|          | 1/159 [00:00<00:16,  9.84it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.94it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:12, 12.20it/s]

Classification Report for Task offensive trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9061371841155235 0.8695150115473441 0.8874484384207424       866
               1  0.7425968109339408 0.8069306930693070 0.7734282325029657       404

        accuracy                      0.8496062992125984      1270
       macro avg  0.8243669975247321 0.8382228523083255 0.8304383354618541      1270
    weighted avg  0.8541133173711460 0.8496062992125984 0.8511774437823314      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 12.01it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 12.04it/s]

Classification Report for Task hatespeech trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9731136166522116 0.9664082687338501 0.9697493517718236      1161
               1  0.6666666666666666 0.7155963302752294 0.6902654867256638       109

        accuracy                      0.9448818897637795      1270
       macro avg  0.8198901416594391 0.8410022995045398 0.8300074192487437      1270
    weighted avg  0.9468122642518775 0.9448818897637795 0.9457621539056573      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.96it/s]

Classification Report for Task hatespeech_classes trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9699054170249355 0.9715762273901809 0.9707401032702238      1161
               1  0.6666666666666666 0.5714285714285714 0.6153846153846153        28
               2  0.0000000000000000 0.0000000000000000 0.0000000000000000         4
               3  0.5333333333333333 0.5714285714285714 0.5517241379310344        14
               4  0.0000000000000000 0.0000000000000000 0.0000000000000000         1
               5  0.1538461538461539 0.2000000000000000 0.1739130434782609        10
               6  0.6909090909090909 0.7307692307692307 0.7102803738317757        52

        accuracy                      0.9385826771653544      1270
       macro avg  0.4306658088257400 0.4350289430023649 0.4317203248422729      1270
    weighted avg  0.9367395722559195 0.9385826771653544 0.9375258873484790      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### 2 Epochs

In [18]:
# Macro-averaged F1 for the "offensive" task from the stored predictions.
# NOTE(review): `f1` and a notebook-level `preds_dict` are not defined in the
# surrounding cells (evaluate() only fills a local preds_dict) — presumably they
# come from an earlier session; TODO confirm where they are created.
# NOTE(review): several values under the epoch headings below repeat exactly
# (e.g. the 3- and 4-epoch hatespeech scores) — verify they are not stale.
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [19]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8371149406524729}

In [20]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.3926597611174843}

### 3 Epochs

In [21]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 4 Epochs

In [25]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8469622661091527}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 5 Epochs

In [43]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8387301587301587}

In [44]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8107379458902904}

In [45]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4355893433350976}

# Combo

## Preparing Data

In [4]:
task_names = ['offensive', 'hatespeech', 'hatespeech_classes']

In [5]:
# One CSV-backed dataset (train + test split) per task. Only the first two
# tasks use the "_combo" training files; hatespeech_classes keeps trainC_prepro.csv.
dataset_dict = {
    task: load_dataset("csv", data_files={'train': train_csv, 'test': test_csv})
    for task, train_csv, test_csv in (
        ("offensive", "Data/train/trainA_prepro_combo.csv", "Data/test/testA_prepro.csv"),
        ("hatespeech", "Data/train/trainB_prepro_combo.csv", "Data/test/testB_prepro.csv"),
        ("hatespeech_classes", "Data/train/trainC_prepro.csv", "Data/test/testC_prepro.csv"),
    )
}

Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-1337ce598b95f03a/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1432.97it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 69.37it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-1337ce598b95f03a/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 571.86it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-ea8e0edd0e0f031f/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1398.57it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 72.30it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-ea8e0edd0e0f031f/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 578.56it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-0210a04391128139/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1934.20it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 77.70it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-0210a04391128139/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 331.76it/s]


## Setting Model

In [6]:
model_name = "/scratch/mt/ashapiro/Hate_Speech/Models/Marbertv2/"

# Every task gets a sequence-classification model type and its own config;
# the two binary tasks have 2 labels, hatespeech_classes has 7.
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        task: transformers.AutoModelForSequenceClassification
        for task in ("offensive", "hatespeech", "hatespeech_classes")
    },
    model_config_dict={
        task: transformers.AutoConfig.from_pretrained(model_name, num_labels=n_labels)
        for task, n_labels in (
            ("offensive", 2),
            ("hatespeech", 2),
            ("hatespeech_classes", 7),
        )
    },
)

In [7]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [8]:
max_length = 512


def convert_to_features(example_batch):
    """Tokenize a batch of examples and attach their labels.

    Reads the ``'text'`` and ``'labels'`` entries of ``example_batch``, encodes
    the texts with the notebook-level ``tokenizer`` (``max_length`` and
    ``pad_to_max_length=True`` are passed through), and returns the tokenizer
    output with a ``"labels"`` entry added.
    """
    encoded = tokenizer.batch_encode_plus(
        list(example_batch['text']),
        max_length=max_length,
        pad_to_max_length=True,
    )
    encoded["labels"] = example_batch["labels"]
    return encoded


# All three tasks share the same conversion function.
convert_func_dict = dict.fromkeys(
    ("offensive", "hatespeech", "hatespeech_classes"),
    convert_to_features,
)

In [9]:
# Columns exposed as torch tensors; identical for every task.
columns_dict = {
    task: ['input_ids', 'attention_mask', 'labels']
    for task in ("offensive", "hatespeech", "hatespeech_classes")
}

# features_dict[task][phase] holds the tokenized dataset for that split.
features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        # Always re-tokenize instead of reusing a cached map result.
        encoded = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        features_dict[task_name][phase] = encoded
        # Row counts before/after mapping, printed once more after set_format.
        print(f"{task_name} {phase} {len(phase_dataset)} {len(encoded)}")
        encoded.set_format(
            type="torch",
            columns=columns_dict[task_name],
        )
        print(f"{task_name} {phase} {len(phase_dataset)} {len(encoded)}")

100%|██████████| 42/42 [00:19<00:00,  2.21ba/s]


offensive train 41212 41212
offensive train 41212 41212


100%|██████████| 2/2 [00:00<00:00,  3.99ba/s]


offensive test 1270 1270
offensive test 1270 1270


100%|██████████| 33/33 [00:15<00:00,  2.13ba/s]


hatespeech train 32630 32630
hatespeech train 32630 32630


100%|██████████| 2/2 [00:00<00:00,  3.90ba/s]


hatespeech test 1270 1270
hatespeech test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.53ba/s]


hatespeech_classes train 8887 8887
hatespeech_classes train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.96ba/s]

hatespeech_classes test 1270 1270
hatespeech_classes test 1270 1270





In [10]:
# Test split of every task, keyed by task name.
eval_dataset = {name: splits["test"] for name, splits in features_dict.items()}

In [11]:
# Training split of every task, keyed by task name.
train_dataset = {name: splits["train"] for name, splits in features_dict.items()}

args = transformers.TrainingArguments(
    output_dir="./models/multitask_model/3_main_tasks/combo_data/4_epochs/",
    overwrite_output_dir=True,
    do_train=True,
    num_train_epochs=4,
    learning_rate=2e-5,
    # Lower the per-device batch size if the GPU runs out of memory.
    per_device_train_batch_size=16,
    save_steps=3000,
)

# Only training data is handed to the trainer; test splits are evaluated separately.
trainer = MultitaskTrainer(
    model=multitask_model,
    args=args,
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)

In [13]:
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt after
# 6 of 5172 iterations, yet the next evaluation is labelled "trained for 4 epochs"
# — presumably the model was trained in another run; confirm and re-run so the
# recorded outputs match the code.
trainer.train()

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]
Iteration:   0%|          | 0/5172 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/5172 [00:00<48:34,  1.77it/s][A
Iteration:   0%|          | 2/5172 [00:01<47:29,  1.81it/s][A
Iteration:   0%|          | 3/5172 [00:01<46:55,  1.84it/s][A
Iteration:   0%|          | 4/5172 [00:02<46:37,  1.85it/s][A
Iteration:   0%|          | 5/5172 [00:02<46:25,  1.85it/s][A
Iteration:   0%|          | 6/5172 [00:03<53:38,  1.61it/s][A
Epoch:   0%|          | 0/4 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [14]:
task_names = ['offensive','hatespeech','hatespeech_classes']

In [15]:
def evaluate(trainer, features_dict, epochs):
    """Predict on every task's test split, print a report per task, and return the predictions.

    Args:
        trainer: object providing ``get_eval_dataloader(eval_dataset=...)`` and
            ``_prediction_loop(dataloader, description=...)``.
        features_dict: mapping ``task name -> {"test": dataset, ...}``.
        epochs: number of epochs trained so far; only used in the printed heading.

    Returns:
        dict mapping each task name to the result of ``trainer._prediction_loop``
        (read below through its ``.predictions`` and ``.label_ids`` attributes).
        Assign the result (``preds_dict = evaluate(...)``) so that later metric
        cells can reuse the predictions and the notebook does not echo the dict.
    """
    from sklearn.metrics import classification_report

    preds_dict = {}
    # `task_names` is the notebook-level list of task names.
    for task_name in task_names:
        eval_dataloader = DataLoaderWithTaskname(
            task_name,
            trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["test"]),
        )
        print(eval_dataloader.data_loader.collate_fn)
        task_output = trainer._prediction_loop(
            eval_dataloader,
            description=f"Test: {task_name}",
        )
        preds_dict[task_name] = task_output
        print(f"Classification Report for Task {task_name} trained for {epochs} epochs")
        print(
            classification_report(
                y_true=task_output.label_ids,
                # argmax over axis 1 picks the highest-scoring class per example.
                y_pred=np.argmax(task_output.predictions, axis=1),
                digits=16,
                # Classes that are never predicted score 0 either way; being
                # explicit only silences the UndefinedMetricWarning.
                zero_division=0,
            )
        )
    return preds_dict

In [16]:
evaluate(trainer, features_dict, epochs=4)

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffe9019a350>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.92it/s]


Classification Report for Task offensive trained for 4 epochs
                  precision    recall  f1-score   support

               0  0.9414062500000000 0.8348729792147807 0.8849449204406366       866
               1  0.7151394422310757 0.8886138613861386 0.7924944812362031       404

        accuracy                      0.8519685039370078      1270
       macro avg  0.8282728461155379 0.8617434203004597 0.8387197008384198      1270
    weighted avg  0.8694284623317752 0.8519685039370078 0.8555354893866278      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffe9019a350>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 11.87it/s]


Classification Report for Task hatespeech trained for 4 epochs
                  precision    recall  f1-score   support

               0  0.9756944444444444 0.9681309216192937 0.9718979680069174      1161
               1  0.6864406779661016 0.7431192660550459 0.7136563876651981       109

        accuracy                      0.9488188976377953      1270
       macro avg  0.8310675612052730 0.8556250938371698 0.8427771778360578      1270
    weighted avg  0.9508687274789804 0.9488188976377953 0.9497339268594784      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffe9019a350>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.77it/s]

Classification Report for Task hatespeech_classes trained for 4 epochs
                  precision    recall  f1-score   support

               0  0.9699312714776632 0.9724375538329026 0.9711827956989247      1161
               1  0.6818181818181818 0.5357142857142857 0.6000000000000000        28
               2  0.0000000000000000 0.0000000000000000 0.0000000000000000         4
               3  0.4000000000000000 0.5714285714285714 0.4705882352941176        14
               4  0.0000000000000000 0.0000000000000000 0.0000000000000000         1
               5  0.0833333333333333 0.1000000000000000 0.0909090909090909        10
               6  0.6923076923076923 0.6923076923076923 0.6923076923076923        52

        accuracy                      0.9362204724409449      1270
       macro avg  0.4039129255624102 0.4102697290404932 0.4035696877442608      1270
    weighted avg  0.9351294870943380 0.9362204724409449 0.9353075212674489      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Let's train for one more epoch

In [17]:
# Reuse the existing trainer for one extra epoch.
# NOTE(review): the recorded log of the following train() call starts again at
# step 500 with learning_rate ~1.8e-05, so the step counter and LR schedule
# start over rather than continuing — confirm this is intended.
trainer.args.num_train_epochs = 1

In [None]:
trainer.train()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|          | 0/5172 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/5172 [00:00<52:36,  1.64it/s][A
Iteration:   0%|          | 2/5172 [00:01<49:49,  1.73it/s][A
Iteration:   0%|          | 3/5172 [00:01<48:16,  1.78it/s][A
Iteration:   0%|          | 4/5172 [00:02<47:34,  1.81it/s][A
Iteration:   0%|          | 5/5172 [00:02<47:03,  1.83it/s][A
Iteration:   0%|          | 6/5172 [00:03<46:55,  1.84it/s][A
Iteration:   0%|          | 7/5172 [00:03<46:46,  1.84it/s][A
Iteration:   0%|          | 8/5172 [00:04<46:42,  1.84it/s][A
Iteration:   0%|          | 9/5172 [00:04<46:27,  1.85it/s][A
Iteration:   0%|          | 10/5172 [00:05<46:31,  1.85it/s][A
Iteration:   0%|          | 11/5172 [00:06<46:31,  1.85it/s][A
Iteration:   0%|          | 12/5172 [00:06<46:24,  1.85it/s][A
Iteration:   0%|          | 13/5172 [00:07<46:26,  1.85it/s][A
Iteration:   0%|          | 14/5172 [00:07<46:24,  1.85it/s][A
Iteration:   

{"loss": 0.09300880393770057, "learning_rate": 1.806651198762568e-05, "epoch": 0.09667440061871617, "step": 500}



Iteration:  10%|▉         | 501/5172 [04:31<42:05,  1.85it/s][A
Iteration:  10%|▉         | 502/5172 [04:32<42:00,  1.85it/s][A
Iteration:  10%|▉         | 503/5172 [04:32<42:05,  1.85it/s][A
Iteration:  10%|▉         | 504/5172 [04:33<41:56,  1.86it/s][A
Iteration:  10%|▉         | 505/5172 [04:33<41:51,  1.86it/s][A
Iteration:  10%|▉         | 506/5172 [04:34<41:57,  1.85it/s][A
Iteration:  10%|▉         | 507/5172 [04:34<42:02,  1.85it/s][A
Iteration:  10%|▉         | 508/5172 [04:35<42:09,  1.84it/s][A
Iteration:  10%|▉         | 509/5172 [04:36<42:09,  1.84it/s][A
Iteration:  10%|▉         | 510/5172 [04:36<42:12,  1.84it/s][A
Iteration:  10%|▉         | 511/5172 [04:37<42:15,  1.84it/s][A
Iteration:  10%|▉         | 512/5172 [04:37<42:12,  1.84it/s][A
Iteration:  10%|▉         | 513/5172 [04:38<42:12,  1.84it/s][A
Iteration:  10%|▉         | 514/5172 [04:38<42:06,  1.84it/s][A
Iteration:  10%|▉         | 515/5172 [04:39<41:59,  1.85it/s][A
Iteration:  10%|▉       

In [33]:
# NOTE(review): the output recorded for this cell is identical to the 5-epoch
# evaluation in the Large section (same scores and the same NLPDataCollator
# address 0x7fff8edc4ee0, whereas the Combo 4-epoch evaluation shows
# 0x7ffe9019a350), and the preceding train() output stops at about 10%.
# It looks stale — re-run after the extra epoch finishes.
evaluate(trainer, features_dict, epochs=5)

Test: offensive:   1%|          | 1/159 [00:00<00:16,  9.84it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.94it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:12, 12.20it/s]

Classification Report for Task offensive trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9061371841155235 0.8695150115473441 0.8874484384207424       866
               1  0.7425968109339408 0.8069306930693070 0.7734282325029657       404

        accuracy                      0.8496062992125984      1270
       macro avg  0.8243669975247321 0.8382228523083255 0.8304383354618541      1270
    weighted avg  0.8541133173711460 0.8496062992125984 0.8511774437823314      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 12.01it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 12.04it/s]

Classification Report for Task hatespeech trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9731136166522116 0.9664082687338501 0.9697493517718236      1161
               1  0.6666666666666666 0.7155963302752294 0.6902654867256638       109

        accuracy                      0.9448818897637795      1270
       macro avg  0.8198901416594391 0.8410022995045398 0.8300074192487437      1270
    weighted avg  0.9468122642518775 0.9448818897637795 0.9457621539056573      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.96it/s]

Classification Report for Task hatespeech_classes trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9699054170249355 0.9715762273901809 0.9707401032702238      1161
               1  0.6666666666666666 0.5714285714285714 0.6153846153846153        28
               2  0.0000000000000000 0.0000000000000000 0.0000000000000000         4
               3  0.5333333333333333 0.5714285714285714 0.5517241379310344        14
               4  0.0000000000000000 0.0000000000000000 0.0000000000000000         1
               5  0.1538461538461539 0.2000000000000000 0.1739130434782609        10
               6  0.6909090909090909 0.7307692307692307 0.7102803738317757        52

        accuracy                      0.9385826771653544      1270
       macro avg  0.4306658088257400 0.4350289430023649 0.4317203248422729      1270
    weighted avg  0.9367395722559195 0.9385826771653544 0.9375258873484790      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### 2 Epochs

In [18]:
# Macro-averaged F1 for the "offensive" task from the stored predictions.
# NOTE(review): the values recorded under the Combo epoch headings are
# digit-for-digit the same as those under the Large section's epoch headings,
# so they appear to be copied rather than recomputed — re-run to confirm.
# NOTE(review): `f1` and a notebook-level `preds_dict` are not defined in the
# surrounding cells; TODO confirm where they are created.
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [19]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8371149406524729}

In [20]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.3926597611174843}

### 3 Epochs

In [21]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 4 Epochs

In [25]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8469622661091527}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 5 Epochs

In [43]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8387301587301587}

In [44]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8107379458902904}

In [45]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4355893433350976}

## Loading

In [46]:
def load_model(dot_bin_file, map_location="cpu"):
    """Rebuild the three-task ``MultitaskModel`` and restore its weights.

    A fresh model is created from the notebook-level ``model_name`` with the
    same task heads used for training ("offensive" and "hatespeech" with 2
    labels, "hatespeech_classes" with 7), and the checkpoint is loaded into it.

    Args:
        dot_bin_file: File (path or file-like object accepted by
            ``torch.load``) whose content is passed to ``load_state_dict``,
            i.e. it is expected to hold a state dict.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written from a GPU can still be restored on a host
            without CUDA. ``load_state_dict`` copies the values into the new
            model's own parameters, so this does not decide where the
            returned model lives.

    Returns:
        The newly created ``MultitaskModel`` with the checkpoint weights loaded.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match the model.
    """
    # Single source of truth for the task heads and their label counts.
    task_num_labels = {"offensive": 2, "hatespeech": 2, "hatespeech_classes": 7}
    multitask_model = MultitaskModel.create(
        model_name=model_name,
        model_type_dict={
            task: transformers.AutoModelForSequenceClassification
            for task in task_num_labels
        },
        model_config_dict={
            task: transformers.AutoConfig.from_pretrained(model_name, num_labels=num_labels)
            for task, num_labels in task_num_labels.items()
        },
    )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(dot_bin_file, map_location=map_location)
    multitask_model.load_state_dict(state_dict)
    return multitask_model

# Without Last Task

In [1]:
# Switch to the project directory so the relative "Data/..." and "./models/..."
# paths used in the cells below resolve.
%cd /scratch/mt/ashapiro/Hate_Speech/Multitask_trial/

/scratch/mt/ashapiro/Hate_Speech/Multitask_trial


In [2]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
import logging
from datasets import load_dataset
from model import * 
logging.basicConfig(level=logging.INFO)

## Preparing Data

In [3]:
# Names of the three tasks; the same strings key the dataset, model and
# feature dictionaries built in the following cells.
task_names = [
    'offensive',
    'hatespeech',
    'hatespeech_classes',
]

In [4]:
# Each task reads its own train/test CSV pair; the file names differ only in
# the letter (A/B/C) embedded in them.
dataset_dict = {
    task: load_dataset(
        "csv",
        data_files={
            'train': f"Data/train{letter}_prepro.csv",
            'test': f"Data/test{letter}_prepro.csv",
        },
    )
    for task, letter in (
        ("offensive", "A"),
        ("hatespeech", "B"),
        ("hatespeech_classes", "C"),
    )
}

100%|██████████| 2/2 [00:00<00:00, 193.53it/s]
100%|██████████| 2/2 [00:00<00:00, 208.36it/s]
100%|██████████| 2/2 [00:00<00:00, 206.70it/s]


## Setting Model

In [5]:
model_name = "/scratch/mt/ashapiro/Hate_Speech/Models/Marbertv2/"

# All tasks use the same sequence-classification model class; only the number
# of labels passed to the config differs (2, 2 and 7).
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        task: transformers.AutoModelForSequenceClassification
        for task in ("offensive", "hatespeech", "hatespeech_classes")
    },
    model_config_dict={
        task: transformers.AutoConfig.from_pretrained(model_name, num_labels=label_count)
        for task, label_count in (
            ("offensive", 2),
            ("hatespeech", 2),
            ("hatespeech_classes", 7),
        )
    },
)

In [6]:
# Load the tokenizer from the same path (model_name) that the model was created from.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [7]:
max_length = 512


def convert_to_features(example_batch):
    """Tokenise one batch of examples and attach its labels.

    Args:
        example_batch: Mapping with a 'text' entry (iterable of strings) and a
            "labels" entry.

    Returns:
        The object returned by ``tokenizer.batch_encode_plus`` (called with
        ``max_length=max_length`` and ``pad_to_max_length=True``), with the
        batch's labels stored under the "labels" key.
    """
    texts = list(example_batch['text'])
    encoded = tokenizer.batch_encode_plus(
        texts,
        max_length=max_length,
        pad_to_max_length=True,
    )
    encoded["labels"] = example_batch["labels"]
    return encoded


# Every task is pre-processed with the same conversion function.
convert_func_dict = dict.fromkeys(
    ("offensive", "hatespeech", "hatespeech_classes"),
    convert_to_features,
)

In [8]:
# All three tasks expose the same columns to the trainer.
columns_dict = {
    task: ['input_ids', 'attention_mask', 'labels']
    for task in ("offensive", "hatespeech", "hatespeech_classes")
}

# Convert every split of every task with its conversion function, then
# restrict the result to torch-formatted versions of the columns above.
features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        encoded = phase_dataset.map(
            convert_func_dict[task_name], batched=True, load_from_cache_file=False
        )
        features_dict[task_name][phase] = encoded
        # Row counts of the source split and the mapped split, reported once
        # before and once after set_format.
        print(task_name, phase, len(phase_dataset), len(encoded))
        encoded.set_format(type="torch", columns=columns_dict[task_name])
        print(task_name, phase, len(phase_dataset), len(encoded))

100%|██████████| 9/9 [00:03<00:00,  2.52ba/s]


offensive train 8887 8887
offensive train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.92ba/s]


offensive test 1270 1270
offensive test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.35ba/s]


hatespeech train 8887 8887
hatespeech train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.93ba/s]


hatespeech test 1270 1270
hatespeech test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.35ba/s]


hatespeech_classes train 8887 8887
hatespeech_classes train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.87ba/s]

hatespeech_classes test 1270 1270
hatespeech_classes test 1270 1270





In [9]:
# Per-task "test" split of the converted features.
eval_dataset = {name: splits["test"] for name, splits in features_dict.items()}

In [10]:
# Per-task "train" split of the converted features.
train_dataset = {name: splits["train"] for name, splits in features_dict.items()}

args = transformers.TrainingArguments(
    output_dir="./models/multitask_model/7",
    overwrite_output_dir=True,
    do_train=True,
    num_train_epochs=2,
    learning_rate=2e-5,
    # Lower the batch size if this does not fit in GPU memory.
    per_device_train_batch_size=16,
    save_steps=3000,
)

trainer = MultitaskTrainer(
    args=args,
    model=multitask_model,
    train_dataset=train_dataset,
    data_collator=NLPDataCollator(),
)

[34m[1mwandb[0m: Currently logged in as: [33mahmadshapiro[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [11]:
# Run training with the trainer configured above; the returned value is shown
# as the cell result.
trainer.train()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1668 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1668 [00:00<17:53,  1.55it/s][A
Iteration:   0%|          | 2/1668 [00:01<16:14,  1.71it/s][A
Iteration:   0%|          | 3/1668 [00:01<15:44,  1.76it/s][A
Iteration:   0%|          | 4/1668 [00:02<15:26,  1.80it/s][A
Iteration:   0%|          | 5/1668 [00:02<15:19,  1.81it/s][A
Iteration:   0%|          | 6/1668 [00:03<15:17,  1.81it/s][A
Iteration:   0%|          | 7/1668 [00:03<15:07,  1.83it/s][A
Iteration:   0%|          | 8/1668 [00:04<15:13,  1.82it/s][A
Iteration:   1%|          | 9/1668 [00:05<15:13,  1.82it/s][A
Iteration:   1%|          | 10/1668 [00:05<15:09,  1.82it/s][A
Iteration:   1%|          | 11/1668 [00:06<15:04,  1.83it/s][A
Iteration:   1%|          | 12/1668 [00:06<15:06,  1.83it/s][A
Iteration:   1%|          | 13/1668 [00:07<15:04,  1.83it/s][A
Iteration:   1%|          | 14/1668 [00:07<15:02,  1.83it/s][A
Iteration:   

{"loss": 0.41617555819451807, "learning_rate": 1.7002398081534774e-05, "epoch": 0.2997601918465228, "step": 500}



Iteration:  30%|███       | 501/1668 [04:32<11:34,  1.68it/s][A
Iteration:  30%|███       | 502/1668 [04:32<11:18,  1.72it/s][A
Iteration:  30%|███       | 503/1668 [04:33<11:07,  1.74it/s][A
Iteration:  30%|███       | 504/1668 [04:33<10:59,  1.77it/s][A
Iteration:  30%|███       | 505/1668 [04:34<10:52,  1.78it/s][A
Iteration:  30%|███       | 506/1668 [04:34<10:47,  1.80it/s][A
Iteration:  30%|███       | 507/1668 [04:35<10:45,  1.80it/s][A
Iteration:  30%|███       | 508/1668 [04:36<10:42,  1.80it/s][A
Iteration:  31%|███       | 509/1668 [04:36<10:39,  1.81it/s][A
Iteration:  31%|███       | 510/1668 [04:37<10:37,  1.82it/s][A
Iteration:  31%|███       | 511/1668 [04:37<10:35,  1.82it/s][A
Iteration:  31%|███       | 512/1668 [04:38<10:33,  1.82it/s][A
Iteration:  31%|███       | 513/1668 [04:38<10:33,  1.82it/s][A
Iteration:  31%|███       | 514/1668 [04:39<10:33,  1.82it/s][A
Iteration:  31%|███       | 515/1668 [04:39<10:32,  1.82it/s][A
Iteration:  31%|███     

{"loss": 0.2891820684103295, "learning_rate": 1.4004796163069546e-05, "epoch": 0.5995203836930456, "step": 1000}



Iteration:  60%|██████    | 1001/1668 [09:03<06:25,  1.73it/s][A
Iteration:  60%|██████    | 1002/1668 [09:04<06:18,  1.76it/s][A
Iteration:  60%|██████    | 1003/1668 [09:04<06:13,  1.78it/s][A
Iteration:  60%|██████    | 1004/1668 [09:05<06:10,  1.79it/s][A
Iteration:  60%|██████    | 1005/1668 [09:05<06:07,  1.80it/s][A
Iteration:  60%|██████    | 1006/1668 [09:06<05:59,  1.84it/s][A
Iteration:  60%|██████    | 1007/1668 [09:06<05:56,  1.85it/s][A
Iteration:  60%|██████    | 1008/1668 [09:07<05:56,  1.85it/s][A
Iteration:  60%|██████    | 1009/1668 [09:07<05:58,  1.84it/s][A
Iteration:  61%|██████    | 1010/1668 [09:08<05:58,  1.84it/s][A
Iteration:  61%|██████    | 1011/1668 [09:08<05:58,  1.83it/s][A
Iteration:  61%|██████    | 1012/1668 [09:09<05:57,  1.83it/s][A
Iteration:  61%|██████    | 1013/1668 [09:09<05:56,  1.84it/s][A
Iteration:  61%|██████    | 1014/1668 [09:10<05:56,  1.84it/s][A
Iteration:  61%|██████    | 1015/1668 [09:11<05:55,  1.84it/s][A
Iteration

{"loss": 0.22568737766169944, "learning_rate": 1.1007194244604318e-05, "epoch": 0.8992805755395683, "step": 1500}



Iteration:  90%|████████▉ | 1501/1668 [13:34<01:36,  1.73it/s][A
Iteration:  90%|█████████ | 1502/1668 [13:35<01:34,  1.76it/s][A
Iteration:  90%|█████████ | 1503/1668 [13:35<01:32,  1.78it/s][A
Iteration:  90%|█████████ | 1504/1668 [13:36<01:31,  1.79it/s][A
Iteration:  90%|█████████ | 1505/1668 [13:36<01:30,  1.80it/s][A
Iteration:  90%|█████████ | 1506/1668 [13:37<01:29,  1.81it/s][A
Iteration:  90%|█████████ | 1507/1668 [13:37<01:28,  1.81it/s][A
Iteration:  90%|█████████ | 1508/1668 [13:38<01:28,  1.80it/s][A
Iteration:  90%|█████████ | 1509/1668 [13:38<01:27,  1.83it/s][A
Iteration:  91%|█████████ | 1510/1668 [13:39<01:25,  1.84it/s][A
Iteration:  91%|█████████ | 1511/1668 [13:39<01:25,  1.83it/s][A
Iteration:  91%|█████████ | 1512/1668 [13:40<01:25,  1.83it/s][A
Iteration:  91%|█████████ | 1513/1668 [13:41<01:23,  1.87it/s][A
Iteration:  91%|█████████ | 1514/1668 [13:41<01:22,  1.88it/s][A
Iteration:  91%|█████████ | 1515/1668 [13:42<01:21,  1.87it/s][A
Iteration

{"loss": 0.16795708821783772, "learning_rate": 8.00959232613909e-06, "epoch": 1.1990407673860912, "step": 2000}



Iteration:  20%|█▉        | 333/1668 [03:01<13:12,  1.68it/s][A
Iteration:  20%|██        | 334/1668 [03:01<12:51,  1.73it/s][A
Iteration:  20%|██        | 335/1668 [03:02<12:37,  1.76it/s][A
Iteration:  20%|██        | 336/1668 [03:02<12:27,  1.78it/s][A
Iteration:  20%|██        | 337/1668 [03:03<12:19,  1.80it/s][A
Iteration:  20%|██        | 338/1668 [03:04<12:10,  1.82it/s][A
Iteration:  20%|██        | 339/1668 [03:04<12:04,  1.83it/s][A
Iteration:  20%|██        | 340/1668 [03:05<12:05,  1.83it/s][A
Iteration:  20%|██        | 341/1668 [03:05<12:06,  1.83it/s][A
Iteration:  21%|██        | 342/1668 [03:06<12:06,  1.83it/s][A
Iteration:  21%|██        | 343/1668 [03:06<12:05,  1.83it/s][A
Iteration:  21%|██        | 344/1668 [03:07<12:03,  1.83it/s][A
Iteration:  21%|██        | 345/1668 [03:07<12:04,  1.83it/s][A
Iteration:  21%|██        | 346/1668 [03:08<12:04,  1.83it/s][A
Iteration:  21%|██        | 347/1668 [03:08<12:02,  1.83it/s][A
Iteration:  21%|██      

{"loss": 0.13271094839542638, "learning_rate": 5.011990407673861e-06, "epoch": 1.498800959232614, "step": 2500}



Iteration:  50%|████▉     | 833/1668 [07:34<07:51,  1.77it/s][A
Iteration:  50%|█████     | 834/1668 [07:34<07:45,  1.79it/s][A
Iteration:  50%|█████     | 835/1668 [07:35<07:41,  1.81it/s][A
Iteration:  50%|█████     | 836/1668 [07:36<07:37,  1.82it/s][A
Iteration:  50%|█████     | 837/1668 [07:36<07:35,  1.83it/s][A
Iteration:  50%|█████     | 838/1668 [07:37<07:33,  1.83it/s][A
Iteration:  50%|█████     | 839/1668 [07:37<07:32,  1.83it/s][A
Iteration:  50%|█████     | 840/1668 [07:38<07:32,  1.83it/s][A
Iteration:  50%|█████     | 841/1668 [07:38<07:30,  1.84it/s][A
Iteration:  50%|█████     | 842/1668 [07:39<07:30,  1.83it/s][A
Iteration:  51%|█████     | 843/1668 [07:39<07:29,  1.84it/s][A
Iteration:  51%|█████     | 844/1668 [07:40<07:27,  1.84it/s][A
Iteration:  51%|█████     | 845/1668 [07:40<07:27,  1.84it/s][A
Iteration:  51%|█████     | 846/1668 [07:41<07:27,  1.84it/s][A
Iteration:  51%|█████     | 847/1668 [07:42<07:27,  1.84it/s][A
Iteration:  51%|█████   

{"loss": 0.11166536170669134, "learning_rate": 2.0143884892086333e-06, "epoch": 1.7985611510791366, "step": 3000}



Iteration:  80%|███████▉  | 1332/1668 [12:08<08:28,  1.51s/it][A
Iteration:  80%|███████▉  | 1333/1668 [12:09<06:50,  1.22s/it][A
Iteration:  80%|███████▉  | 1334/1668 [12:09<05:40,  1.02s/it][A
Iteration:  80%|████████  | 1335/1668 [12:10<04:51,  1.14it/s][A
Iteration:  80%|████████  | 1336/1668 [12:10<04:17,  1.29it/s][A
Iteration:  80%|████████  | 1337/1668 [12:11<03:53,  1.42it/s][A
Iteration:  80%|████████  | 1338/1668 [12:11<03:36,  1.52it/s][A
Iteration:  80%|████████  | 1339/1668 [12:12<03:25,  1.60it/s][A
Iteration:  80%|████████  | 1340/1668 [12:12<03:16,  1.67it/s][A
Iteration:  80%|████████  | 1341/1668 [12:13<03:10,  1.71it/s][A
Iteration:  80%|████████  | 1342/1668 [12:13<03:06,  1.75it/s][A
Iteration:  81%|████████  | 1343/1668 [12:14<03:03,  1.77it/s][A
Iteration:  81%|████████  | 1344/1668 [12:14<03:00,  1.79it/s][A
Iteration:  81%|████████  | 1345/1668 [12:15<02:58,  1.81it/s][A
Iteration:  81%|████████  | 1346/1668 [12:16<02:57,  1.82it/s][A
Iteration

TrainOutput(global_step=3336, training_loss=0.21281198457132977)

In [12]:
# Tasks to evaluate; the prediction loop below iterates over this list.
task_names = [
    'offensive',
    'hatespeech',
    'hatespeech_classes',
]

In [13]:
import datasets

In [14]:
# F1 metric object; the cells below call f1.compute(..., average='macro').
f1 = datasets.load_metric('f1')

In [15]:
# Recall metric object.
# NOTE(review): not referenced in the cells visible in this section — confirm it is still needed.
recall = datasets.load_metric('recall')

In [16]:
# Precision metric object.
# NOTE(review): not referenced in the cells visible in this section — confirm it is still needed.
precision = datasets.load_metric('precision')

In [17]:
# Collect one prediction result per task from that task's "test" features.
preds_dict = {}
for task_name in task_names:
    test_split = features_dict[task_name]["test"]
    base_loader = trainer.get_eval_dataloader(eval_dataset=test_split)
    # Wrap the trainer's loader together with the task name before it is
    # handed to the prediction loop.
    eval_dataloader = DataLoaderWithTaskname(task_name, base_loader)
    print(eval_dataloader.data_loader.collate_fn)
    preds_dict[task_name] = trainer._prediction_loop(
        eval_dataloader, description=f"Test: {task_name}"
    )

Test: offensive:   1%|▏         | 2/159 [00:00<00:15, 10.40it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffef9fbb070>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.91it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:13, 11.97it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffef9fbb070>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 11.88it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 11.92it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffef9fbb070>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.83it/s]


### 2 Epochs

In [18]:
# Macro-averaged F1 for the "offensive" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),
    references=preds_dict["offensive"].label_ids,
    average='macro',
)

{'f1': 0.8418723545933512}

In [19]:
# Macro-averaged F1 for the "hatespeech" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),
    references=preds_dict["hatespeech"].label_ids,
    average='macro',
)

{'f1': 0.8371149406524729}

In [20]:
# Macro-averaged F1 for the "hatespeech_classes" task: class = argmax over axis 1
# of the stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),
    references=preds_dict["hatespeech_classes"].label_ids,
    average='macro',
)

{'f1': 0.3926597611174843}

### 3 Epochs

In [21]:
# Macro-averaged F1 for the "offensive" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),
    references=preds_dict["offensive"].label_ids,
    average='macro',
)

{'f1': 0.8418723545933512}

In [27]:
# Macro-averaged F1 for the "hatespeech" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),
    references=preds_dict["hatespeech"].label_ids,
    average='macro',
)

{'f1': 0.8325985296056062}

In [28]:
# Macro-averaged F1 for the "hatespeech_classes" task: class = argmax over axis 1
# of the stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),
    references=preds_dict["hatespeech_classes"].label_ids,
    average='macro',
)

{'f1': 0.4176057674898591}

### 4 Epochs

In [25]:
# Macro-averaged F1 for the "offensive" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),
    references=preds_dict["offensive"].label_ids,
    average='macro',
)

{'f1': 0.8469622661091527}

In [27]:
# Macro-averaged F1 for the "hatespeech" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),
    references=preds_dict["hatespeech"].label_ids,
    average='macro',
)

{'f1': 0.8325985296056062}

In [28]:
# Macro-averaged F1 for the "hatespeech_classes" task: class = argmax over axis 1
# of the stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),
    references=preds_dict["hatespeech_classes"].label_ids,
    average='macro',
)

{'f1': 0.4176057674898591}

### 5 Epochs

In [43]:
# Macro-averaged F1 for the "offensive" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),
    references=preds_dict["offensive"].label_ids,
    average='macro',
)

{'f1': 0.8387301587301587}

In [44]:
# Macro-averaged F1 for the "hatespeech" task: class = argmax over axis 1 of the
# stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),
    references=preds_dict["hatespeech"].label_ids,
    average='macro',
)

{'f1': 0.8107379458902904}

In [45]:
# Macro-averaged F1 for the "hatespeech_classes" task: class = argmax over axis 1
# of the stored predictions, compared against the stored label ids.
f1.compute(
    predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),
    references=preds_dict["hatespeech_classes"].label_ids,
    average='macro',
)

{'f1': 0.4355893433350976}

## Loading

In [46]:
def load_model(dot_bin_file, map_location="cpu"):
    """Rebuild the three-task ``MultitaskModel`` and restore its weights.

    A fresh model is created from the notebook-level ``model_name`` with the
    same task heads used for training ("offensive" and "hatespeech" with 2
    labels, "hatespeech_classes" with 7), and the checkpoint is loaded into it.

    Args:
        dot_bin_file: File (path or file-like object accepted by
            ``torch.load``) whose content is passed to ``load_state_dict``,
            i.e. it is expected to hold a state dict.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written from a GPU can still be restored on a host
            without CUDA. ``load_state_dict`` copies the values into the new
            model's own parameters, so this does not decide where the
            returned model lives.

    Returns:
        The newly created ``MultitaskModel`` with the checkpoint weights loaded.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match the model.
    """
    # Single source of truth for the task heads and their label counts.
    task_num_labels = {"offensive": 2, "hatespeech": 2, "hatespeech_classes": 7}
    multitask_model = MultitaskModel.create(
        model_name=model_name,
        model_type_dict={
            task: transformers.AutoModelForSequenceClassification
            for task in task_num_labels
        },
        model_config_dict={
            task: transformers.AutoConfig.from_pretrained(model_name, num_labels=num_labels)
            for task, num_labels in task_num_labels.items()
        },
    )
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state_dict = torch.load(dot_bin_file, map_location=map_location)
    multitask_model.load_state_dict(state_dict)
    return multitask_model