In [4]:
%cd /scratch/mt/ashapiro/Hate_Speech/Multitask_trial/

/scratch/mt/ashapiro/Hate_Speech/Multitask_trial


In [8]:
%ls Data/multitask_data/

train_48.csv  train_lhsab.csv  train_thsab.csv


In [5]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
import logging
from datasets import load_dataset
from model import * 
logging.basicConfig(level=logging.INFO)

# All

## Preparing Data

In [19]:
task_names = ['offensive', 'hatespeech', 'forty_eight','lhsab', 'thsab']

In [9]:
dataset_dict = {
    "offensive": load_dataset("csv", data_files={'train': "Data/train/trainA_prepro_large.csv", 'test': "Data/test/testA_prepro.csv" } ),
    "hatespeech": load_dataset("csv", data_files={'train': "Data/train/trainB_prepro_large.csv", 'test': "Data/test/testB_prepro.csv" } ),
    "forty_eight": load_dataset("csv", data_files={'train': "Data/multitask_data/train_48.csv" } ),
    "lhsab": load_dataset("csv", data_files={'train': "Data/multitask_data/train_lhsab.csv"} ),
    "thsab": load_dataset("csv", data_files={'train': "Data/multitask_data/train_thsab.csv" } )    
}

Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-2328e34458fcd88e/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1411.75it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 74.86it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-2328e34458fcd88e/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 278.70it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-6be373aac13c1a94/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1344.33it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 84.28it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-6be373aac13c1a94/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 394.20it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-8ae56efe40ca880e/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 1008.97it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 83.84it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-8ae56efe40ca880e/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 278.64it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-71d561b6517dac88/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 1059.97it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 83.48it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-71d561b6517dac88/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 408.28it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-b9870a1789795d7b/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 1054.91it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 73.91it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-b9870a1789795d7b/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 318.62it/s]


## Setting Model

In [11]:
model_name = "/scratch/mt/ashapiro/Hate_Speech/Models/Marbertv2/"
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        "offensive": transformers.AutoModelForSequenceClassification,
        "hatespeech": transformers.AutoModelForSequenceClassification,
        "forty_eight": transformers.AutoModelForSequenceClassification,
        "lhsab": transformers.AutoModelForSequenceClassification,
        "thsab": transformers.AutoModelForSequenceClassification
        
    },
    model_config_dict={
        "offensive": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "hatespeech": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "forty_eight": transformers.AutoConfig.from_pretrained(model_name, num_labels=3),
        "lhsab": transformers.AutoConfig.from_pretrained(model_name, num_labels=3),
        "thsab": transformers.AutoConfig.from_pretrained(model_name, num_labels=3)
    },
)

In [12]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [13]:
max_length = 512

def convert_to_features(example_batch):
    inputs = list(example_batch['text'])
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["labels"]
    return features

convert_func_dict = {
    "offensive": convert_to_features,
    "hatespeech": convert_to_features,
    "forty_eight": convert_to_features,
    "lhsab": convert_to_features,
    "thsab": convert_to_features,
}

In [14]:
columns_dict = {
    "offensive": ['input_ids', 'attention_mask', 'labels'],
    "hatespeech": ['input_ids', 'attention_mask', 'labels'],
    "forty_eight": ['input_ids', 'attention_mask', 'labels'],
    "lhsab": ['input_ids', 'attention_mask', 'labels'],
    "thsab": ['input_ids', 'attention_mask', 'labels'],
}

features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

100%|██████████| 20/20 [00:08<00:00,  2.47ba/s]


offensive train 19906 19906
offensive train 19906 19906


100%|██████████| 2/2 [00:00<00:00,  4.01ba/s]


offensive test 1270 1270
offensive test 1270 1270


100%|██████████| 5/5 [00:02<00:00,  2.31ba/s]


hatespeech train 4800 4800
hatespeech train 4800 4800


100%|██████████| 2/2 [00:00<00:00,  4.22ba/s]


hatespeech test 1270 1270
hatespeech test 1270 1270


100%|██████████| 2/2 [00:00<00:00,  3.74ba/s]


forty_eight train 1359 1359
forty_eight train 1359 1359


100%|██████████| 11/11 [00:04<00:00,  2.71ba/s]


lhsab train 10950 10950
lhsab train 10950 10950


100%|██████████| 12/12 [00:04<00:00,  2.83ba/s]


thsab train 11460 11460
thsab train 11460 11460


In [15]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}
args = transformers.TrainingArguments(
        output_dir="./models/multitask_model/5_main_tasks/4_epochs/",
        overwrite_output_dir=True,
        learning_rate=2e-5,
        do_train=True,
        num_train_epochs=4,
        # Adjust batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=16,  
        save_steps=3000,)

trainer = MultitaskTrainer(
    model=multitask_model,
    args=args,
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)

[34m[1mwandb[0m: Currently logged in as: [33mahmadshapiro[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [16]:
trainer.train()

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]
Iteration:   0%|          | 0/3032 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/3032 [00:00<31:54,  1.58it/s][A
Iteration:   0%|          | 2/3032 [00:01<29:32,  1.71it/s][A
Iteration:   0%|          | 3/3032 [00:01<28:25,  1.78it/s][A
Iteration:   0%|          | 4/3032 [00:02<27:58,  1.80it/s][A
Iteration:   0%|          | 5/3032 [00:02<27:41,  1.82it/s][A
Iteration:   0%|          | 6/3032 [00:03<27:19,  1.85it/s][A
Iteration:   0%|          | 7/3032 [00:03<27:14,  1.85it/s][A
Iteration:   0%|          | 8/3032 [00:04<27:10,  1.85it/s][A
Iteration:   0%|          | 9/3032 [00:04<27:08,  1.86it/s][A
Iteration:   0%|          | 10/3032 [00:05<27:10,  1.85it/s][A
Iteration:   0%|          | 11/3032 [00:06<27:10,  1.85it/s][A
Iteration:   0%|          | 12/3032 [00:06<27:14,  1.85it/s][A
Iteration:   0%|          | 13/3032 [00:07<27:14,  1.85it/s][A
Iteration:   0%|          | 14/3032 [00:07<27:16,  1.84it/s][A
Iteration:   

{"loss": 0.6605726473927498, "learning_rate": 1.9175461741424804e-05, "epoch": 0.16490765171503957, "step": 500}



Iteration:  17%|█▋        | 501/3032 [04:31<24:45,  1.70it/s][A
Iteration:  17%|█▋        | 502/3032 [04:31<24:09,  1.75it/s][A
Iteration:  17%|█▋        | 503/3032 [04:32<23:47,  1.77it/s][A
Iteration:  17%|█▋        | 504/3032 [04:32<23:29,  1.79it/s][A
Iteration:  17%|█▋        | 505/3032 [04:33<23:16,  1.81it/s][A
Iteration:  17%|█▋        | 506/3032 [04:34<23:09,  1.82it/s][A
Iteration:  17%|█▋        | 507/3032 [04:34<23:05,  1.82it/s][A
Iteration:  17%|█▋        | 508/3032 [04:35<23:00,  1.83it/s][A
Iteration:  17%|█▋        | 509/3032 [04:35<22:55,  1.83it/s][A
Iteration:  17%|█▋        | 510/3032 [04:36<22:52,  1.84it/s][A
Iteration:  17%|█▋        | 511/3032 [04:36<22:53,  1.84it/s][A
Iteration:  17%|█▋        | 512/3032 [04:37<22:53,  1.83it/s][A
Iteration:  17%|█▋        | 513/3032 [04:37<22:51,  1.84it/s][A
Iteration:  17%|█▋        | 514/3032 [04:38<22:46,  1.84it/s][A
Iteration:  17%|█▋        | 515/3032 [04:38<22:44,  1.84it/s][A
Iteration:  17%|█▋      

{"loss": 0.5118207466900349, "learning_rate": 1.8350923482849604e-05, "epoch": 0.32981530343007914, "step": 1000}



Iteration:  33%|███▎      | 1001/3032 [09:02<19:24,  1.74it/s][A
Iteration:  33%|███▎      | 1002/3032 [09:02<19:07,  1.77it/s][A
Iteration:  33%|███▎      | 1003/3032 [09:03<18:50,  1.80it/s][A
Iteration:  33%|███▎      | 1004/3032 [09:03<18:39,  1.81it/s][A
Iteration:  33%|███▎      | 1005/3032 [09:04<18:32,  1.82it/s][A
Iteration:  33%|███▎      | 1006/3032 [09:04<18:27,  1.83it/s][A
Iteration:  33%|███▎      | 1007/3032 [09:05<18:24,  1.83it/s][A
Iteration:  33%|███▎      | 1008/3032 [09:05<18:21,  1.84it/s][A
Iteration:  33%|███▎      | 1009/3032 [09:06<18:19,  1.84it/s][A
Iteration:  33%|███▎      | 1010/3032 [09:07<18:18,  1.84it/s][A
Iteration:  33%|███▎      | 1011/3032 [09:07<18:17,  1.84it/s][A
Iteration:  33%|███▎      | 1012/3032 [09:08<18:16,  1.84it/s][A
Iteration:  33%|███▎      | 1013/3032 [09:08<18:14,  1.85it/s][A
Iteration:  33%|███▎      | 1014/3032 [09:09<18:13,  1.85it/s][A
Iteration:  33%|███▎      | 1015/3032 [09:09<18:14,  1.84it/s][A
Iteration

{"loss": 0.43515530635416505, "learning_rate": 1.7526385224274407e-05, "epoch": 0.4947229551451187, "step": 1500}



Iteration:  50%|████▉     | 1501/3032 [13:33<14:35,  1.75it/s][A
Iteration:  50%|████▉     | 1502/3032 [13:33<14:19,  1.78it/s][A
Iteration:  50%|████▉     | 1503/3032 [13:34<14:09,  1.80it/s][A
Iteration:  50%|████▉     | 1504/3032 [13:34<14:03,  1.81it/s][A
Iteration:  50%|████▉     | 1505/3032 [13:35<13:58,  1.82it/s][A
Iteration:  50%|████▉     | 1506/3032 [13:35<13:54,  1.83it/s][A
Iteration:  50%|████▉     | 1507/3032 [13:36<13:51,  1.83it/s][A
Iteration:  50%|████▉     | 1508/3032 [13:36<13:49,  1.84it/s][A
Iteration:  50%|████▉     | 1509/3032 [13:37<13:48,  1.84it/s][A
Iteration:  50%|████▉     | 1510/3032 [13:38<13:47,  1.84it/s][A
Iteration:  50%|████▉     | 1511/3032 [13:38<13:46,  1.84it/s][A
Iteration:  50%|████▉     | 1512/3032 [13:39<13:44,  1.84it/s][A
Iteration:  50%|████▉     | 1513/3032 [13:39<13:43,  1.85it/s][A
Iteration:  50%|████▉     | 1514/3032 [13:40<13:43,  1.84it/s][A
Iteration:  50%|████▉     | 1515/3032 [13:40<13:42,  1.84it/s][A
Iteration

{"loss": 0.398882093295455, "learning_rate": 1.670184696569921e-05, "epoch": 0.6596306068601583, "step": 2000}



Iteration:  66%|██████▌   | 2001/3032 [18:04<09:57,  1.73it/s][A
Iteration:  66%|██████▌   | 2002/3032 [18:04<09:44,  1.76it/s][A
Iteration:  66%|██████▌   | 2003/3032 [18:05<09:36,  1.79it/s][A
Iteration:  66%|██████▌   | 2004/3032 [18:05<09:29,  1.81it/s][A
Iteration:  66%|██████▌   | 2005/3032 [18:06<09:24,  1.82it/s][A
Iteration:  66%|██████▌   | 2006/3032 [18:06<09:21,  1.83it/s][A
Iteration:  66%|██████▌   | 2007/3032 [18:07<09:17,  1.84it/s][A
Iteration:  66%|██████▌   | 2008/3032 [18:07<09:15,  1.84it/s][A
Iteration:  66%|██████▋   | 2009/3032 [18:08<09:14,  1.85it/s][A
Iteration:  66%|██████▋   | 2010/3032 [18:08<09:13,  1.85it/s][A
Iteration:  66%|██████▋   | 2011/3032 [18:09<09:13,  1.84it/s][A
Iteration:  66%|██████▋   | 2012/3032 [18:10<09:12,  1.85it/s][A
Iteration:  66%|██████▋   | 2013/3032 [18:10<09:11,  1.85it/s][A
Iteration:  66%|██████▋   | 2014/3032 [18:11<09:11,  1.85it/s][A
Iteration:  66%|██████▋   | 2015/3032 [18:11<09:11,  1.85it/s][A
Iteration

{"loss": 0.36344905727356674, "learning_rate": 1.5877308707124012e-05, "epoch": 0.8245382585751979, "step": 2500}



Iteration:  82%|████████▏ | 2501/3032 [22:34<05:02,  1.75it/s][A
Iteration:  83%|████████▎ | 2502/3032 [22:35<04:57,  1.78it/s][A
Iteration:  83%|████████▎ | 2503/3032 [22:36<04:53,  1.80it/s][A
Iteration:  83%|████████▎ | 2504/3032 [22:36<04:50,  1.81it/s][A
Iteration:  83%|████████▎ | 2505/3032 [22:37<04:49,  1.82it/s][A
Iteration:  83%|████████▎ | 2506/3032 [22:37<04:47,  1.83it/s][A
Iteration:  83%|████████▎ | 2507/3032 [22:38<04:46,  1.83it/s][A
Iteration:  83%|████████▎ | 2508/3032 [22:38<04:45,  1.84it/s][A
Iteration:  83%|████████▎ | 2509/3032 [22:39<04:44,  1.84it/s][A
Iteration:  83%|████████▎ | 2510/3032 [22:39<04:43,  1.84it/s][A
Iteration:  83%|████████▎ | 2511/3032 [22:40<04:42,  1.84it/s][A
Iteration:  83%|████████▎ | 2512/3032 [22:40<04:41,  1.84it/s][A
Iteration:  83%|████████▎ | 2513/3032 [22:41<04:41,  1.85it/s][A
Iteration:  83%|████████▎ | 2514/3032 [22:42<04:40,  1.85it/s][A
Iteration:  83%|████████▎ | 2515/3032 [22:42<04:40,  1.85it/s][A
Iteration

{"loss": 0.33389562773145737, "learning_rate": 1.5052770448548815e-05, "epoch": 0.9894459102902374, "step": 3000}



Iteration:  99%|█████████▉| 3000/3032 [27:08<00:45,  1.41s/it][A
Iteration:  99%|█████████▉| 3001/3032 [27:08<00:35,  1.15s/it][A
Iteration:  99%|█████████▉| 3002/3032 [27:09<00:29,  1.03it/s][A
Iteration:  99%|█████████▉| 3003/3032 [27:09<00:24,  1.19it/s][A
Iteration:  99%|█████████▉| 3004/3032 [27:10<00:21,  1.33it/s][A
Iteration:  99%|█████████▉| 3005/3032 [27:10<00:18,  1.45it/s][A
Iteration:  99%|█████████▉| 3006/3032 [27:11<00:16,  1.55it/s][A
Iteration:  99%|█████████▉| 3007/3032 [27:11<00:15,  1.62it/s][A
Iteration:  99%|█████████▉| 3008/3032 [27:12<00:14,  1.68it/s][A
Iteration:  99%|█████████▉| 3009/3032 [27:12<00:13,  1.73it/s][A
Iteration:  99%|█████████▉| 3010/3032 [27:13<00:12,  1.76it/s][A
Iteration:  99%|█████████▉| 3011/3032 [27:14<00:11,  1.79it/s][A
Iteration:  99%|█████████▉| 3012/3032 [27:14<00:11,  1.81it/s][A
Iteration:  99%|█████████▉| 3013/3032 [27:15<00:10,  1.82it/s][A
Iteration:  99%|█████████▉| 3014/3032 [27:15<00:09,  1.83it/s][A
Iteration

{"loss": 0.2364832679014653, "learning_rate": 1.4228232189973616e-05, "epoch": 1.154353562005277, "step": 3500}



Iteration:  15%|█▌        | 469/3032 [04:14<24:41,  1.73it/s][A
Iteration:  16%|█▌        | 470/3032 [04:14<24:14,  1.76it/s][A
Iteration:  16%|█▌        | 471/3032 [04:15<23:53,  1.79it/s][A
Iteration:  16%|█▌        | 472/3032 [04:15<23:36,  1.81it/s][A
Iteration:  16%|█▌        | 473/3032 [04:16<23:25,  1.82it/s][A
Iteration:  16%|█▌        | 474/3032 [04:16<23:17,  1.83it/s][A
Iteration:  16%|█▌        | 475/3032 [04:17<23:14,  1.83it/s][A
Iteration:  16%|█▌        | 476/3032 [04:17<23:09,  1.84it/s][A
Iteration:  16%|█▌        | 477/3032 [04:18<23:07,  1.84it/s][A
Iteration:  16%|█▌        | 478/3032 [04:18<23:05,  1.84it/s][A
Iteration:  16%|█▌        | 479/3032 [04:19<23:06,  1.84it/s][A
Iteration:  16%|█▌        | 480/3032 [04:19<23:04,  1.84it/s][A
Iteration:  16%|█▌        | 481/3032 [04:20<23:02,  1.85it/s][A
Iteration:  16%|█▌        | 482/3032 [04:21<23:02,  1.84it/s][A
Iteration:  16%|█▌        | 483/3032 [04:21<22:59,  1.85it/s][A
Iteration:  16%|█▌      

{"loss": 0.23473819272220134, "learning_rate": 1.3403693931398417e-05, "epoch": 1.3192612137203166, "step": 4000}



Iteration:  32%|███▏      | 969/3032 [08:44<19:50,  1.73it/s][A
Iteration:  32%|███▏      | 970/3032 [08:45<19:27,  1.77it/s][A
Iteration:  32%|███▏      | 971/3032 [08:45<19:11,  1.79it/s][A
Iteration:  32%|███▏      | 972/3032 [08:46<18:59,  1.81it/s][A
Iteration:  32%|███▏      | 973/3032 [08:46<18:50,  1.82it/s][A
Iteration:  32%|███▏      | 974/3032 [08:47<18:40,  1.84it/s][A
Iteration:  32%|███▏      | 975/3032 [08:47<18:36,  1.84it/s][A
Iteration:  32%|███▏      | 976/3032 [08:48<18:38,  1.84it/s][A
Iteration:  32%|███▏      | 977/3032 [08:48<18:35,  1.84it/s][A
Iteration:  32%|███▏      | 978/3032 [08:49<18:33,  1.84it/s][A
Iteration:  32%|███▏      | 979/3032 [08:50<18:34,  1.84it/s][A
Iteration:  32%|███▏      | 980/3032 [08:50<18:33,  1.84it/s][A
Iteration:  32%|███▏      | 981/3032 [08:51<18:30,  1.85it/s][A
Iteration:  32%|███▏      | 982/3032 [08:51<18:25,  1.85it/s][A
Iteration:  32%|███▏      | 983/3032 [08:52<18:27,  1.85it/s][A
Iteration:  32%|███▏    

{"loss": 0.2170686432113871, "learning_rate": 1.257915567282322e-05, "epoch": 1.4841688654353562, "step": 4500}



Iteration:  48%|████▊     | 1469/3032 [13:15<14:51,  1.75it/s][A
Iteration:  48%|████▊     | 1470/3032 [13:16<14:38,  1.78it/s][A
Iteration:  49%|████▊     | 1471/3032 [13:16<14:25,  1.80it/s][A
Iteration:  49%|████▊     | 1472/3032 [13:17<14:17,  1.82it/s][A
Iteration:  49%|████▊     | 1473/3032 [13:17<14:12,  1.83it/s][A
Iteration:  49%|████▊     | 1474/3032 [13:18<14:08,  1.84it/s][A
Iteration:  49%|████▊     | 1475/3032 [13:18<14:06,  1.84it/s][A
Iteration:  49%|████▊     | 1476/3032 [13:19<14:05,  1.84it/s][A
Iteration:  49%|████▊     | 1477/3032 [13:19<14:05,  1.84it/s][A
Iteration:  49%|████▊     | 1478/3032 [13:20<14:03,  1.84it/s][A
Iteration:  49%|████▉     | 1479/3032 [13:20<14:01,  1.85it/s][A
Iteration:  49%|████▉     | 1480/3032 [13:21<13:59,  1.85it/s][A
Iteration:  49%|████▉     | 1481/3032 [13:21<14:00,  1.85it/s][A
Iteration:  49%|████▉     | 1482/3032 [13:22<14:00,  1.84it/s][A
Iteration:  49%|████▉     | 1483/3032 [13:23<13:59,  1.85it/s][A
Iteration

{"loss": 0.21282043099123985, "learning_rate": 1.1754617414248021e-05, "epoch": 1.6490765171503958, "step": 5000}



Iteration:  65%|██████▍   | 1969/3032 [17:46<10:11,  1.74it/s][A
Iteration:  65%|██████▍   | 1970/3032 [17:47<09:58,  1.77it/s][A
Iteration:  65%|██████▌   | 1971/3032 [17:47<09:51,  1.79it/s][A
Iteration:  65%|██████▌   | 1972/3032 [17:48<09:46,  1.81it/s][A
Iteration:  65%|██████▌   | 1973/3032 [17:49<09:41,  1.82it/s][A
Iteration:  65%|██████▌   | 1974/3032 [17:49<09:38,  1.83it/s][A
Iteration:  65%|██████▌   | 1975/3032 [17:50<09:36,  1.83it/s][A
Iteration:  65%|██████▌   | 1976/3032 [17:50<09:34,  1.84it/s][A
Iteration:  65%|██████▌   | 1977/3032 [17:51<09:32,  1.84it/s][A
Iteration:  65%|██████▌   | 1978/3032 [17:51<09:30,  1.85it/s][A
Iteration:  65%|██████▌   | 1979/3032 [17:52<09:28,  1.85it/s][A
Iteration:  65%|██████▌   | 1980/3032 [17:52<09:28,  1.85it/s][A
Iteration:  65%|██████▌   | 1981/3032 [17:53<09:28,  1.85it/s][A
Iteration:  65%|██████▌   | 1982/3032 [17:53<09:27,  1.85it/s][A
Iteration:  65%|██████▌   | 1983/3032 [17:54<09:27,  1.85it/s][A
Iteration

{"loss": 0.2163077299145516, "learning_rate": 1.0930079155672824e-05, "epoch": 1.8139841688654355, "step": 5500}



Iteration:  81%|████████▏ | 2469/3032 [22:17<05:21,  1.75it/s][A
Iteration:  81%|████████▏ | 2470/3032 [22:18<05:14,  1.79it/s][A
Iteration:  81%|████████▏ | 2471/3032 [22:18<05:11,  1.80it/s][A
Iteration:  82%|████████▏ | 2472/3032 [22:19<05:08,  1.82it/s][A
Iteration:  82%|████████▏ | 2473/3032 [22:19<05:06,  1.83it/s][A
Iteration:  82%|████████▏ | 2474/3032 [22:20<05:03,  1.84it/s][A
Iteration:  82%|████████▏ | 2475/3032 [22:20<05:03,  1.84it/s][A
Iteration:  82%|████████▏ | 2476/3032 [22:21<05:01,  1.84it/s][A
Iteration:  82%|████████▏ | 2477/3032 [22:21<05:01,  1.84it/s][A
Iteration:  82%|████████▏ | 2478/3032 [22:22<05:00,  1.85it/s][A
Iteration:  82%|████████▏ | 2479/3032 [22:22<04:59,  1.85it/s][A
Iteration:  82%|████████▏ | 2480/3032 [22:23<04:58,  1.85it/s][A
Iteration:  82%|████████▏ | 2481/3032 [22:24<04:58,  1.84it/s][A
Iteration:  82%|████████▏ | 2482/3032 [22:24<04:57,  1.85it/s][A
Iteration:  82%|████████▏ | 2483/3032 [22:25<04:57,  1.84it/s][A
Iteration

{"loss": 0.20584245995245873, "learning_rate": 1.0105540897097625e-05, "epoch": 1.978891820580475, "step": 6000}



Iteration:  98%|█████████▊| 2968/3032 [26:50<01:35,  1.50s/it][A
Iteration:  98%|█████████▊| 2969/3032 [26:51<01:16,  1.21s/it][A
Iteration:  98%|█████████▊| 2970/3032 [26:51<01:03,  1.02s/it][A
Iteration:  98%|█████████▊| 2971/3032 [26:52<00:53,  1.14it/s][A
Iteration:  98%|█████████▊| 2972/3032 [26:52<00:46,  1.29it/s][A
Iteration:  98%|█████████▊| 2973/3032 [26:53<00:41,  1.42it/s][A
Iteration:  98%|█████████▊| 2974/3032 [26:54<00:37,  1.53it/s][A
Iteration:  98%|█████████▊| 2975/3032 [26:54<00:35,  1.61it/s][A
Iteration:  98%|█████████▊| 2976/3032 [26:55<00:33,  1.68it/s][A
Iteration:  98%|█████████▊| 2977/3032 [26:55<00:31,  1.72it/s][A
Iteration:  98%|█████████▊| 2978/3032 [26:56<00:30,  1.76it/s][A
Iteration:  98%|█████████▊| 2979/3032 [26:56<00:29,  1.79it/s][A
Iteration:  98%|█████████▊| 2980/3032 [26:57<00:28,  1.80it/s][A
Iteration:  98%|█████████▊| 2981/3032 [26:57<00:28,  1.82it/s][A
Iteration:  98%|█████████▊| 2982/3032 [26:58<00:27,  1.83it/s][A
Iteration

{"loss": 0.12576990044629202, "learning_rate": 9.281002638522428e-06, "epoch": 2.1437994722955147, "step": 6500}



Iteration:  14%|█▍        | 437/3032 [03:55<24:34,  1.76it/s][A
Iteration:  14%|█▍        | 438/3032 [03:55<24:07,  1.79it/s][A
Iteration:  14%|█▍        | 439/3032 [03:56<23:50,  1.81it/s][A
Iteration:  15%|█▍        | 440/3032 [03:57<23:37,  1.83it/s][A
Iteration:  15%|█▍        | 441/3032 [03:57<23:34,  1.83it/s][A
Iteration:  15%|█▍        | 442/3032 [03:58<23:33,  1.83it/s][A
Iteration:  15%|█▍        | 443/3032 [03:58<23:29,  1.84it/s][A
Iteration:  15%|█▍        | 444/3032 [03:59<23:28,  1.84it/s][A
Iteration:  15%|█▍        | 445/3032 [03:59<23:22,  1.84it/s][A
Iteration:  15%|█▍        | 446/3032 [04:00<23:13,  1.86it/s][A
Iteration:  15%|█▍        | 447/3032 [04:00<23:05,  1.87it/s][A
Iteration:  15%|█▍        | 448/3032 [04:01<23:08,  1.86it/s][A
Iteration:  15%|█▍        | 449/3032 [04:01<23:12,  1.85it/s][A
Iteration:  15%|█▍        | 450/3032 [04:02<23:12,  1.85it/s][A
Iteration:  15%|█▍        | 451/3032 [04:02<23:13,  1.85it/s][A
Iteration:  15%|█▍      

{"loss": 0.13335266750631855, "learning_rate": 8.456464379947231e-06, "epoch": 2.308707124010554, "step": 7000}



Iteration:  31%|███       | 937/3032 [08:24<19:46,  1.77it/s][A
Iteration:  31%|███       | 938/3032 [08:24<19:23,  1.80it/s][A
Iteration:  31%|███       | 939/3032 [08:25<19:10,  1.82it/s][A
Iteration:  31%|███       | 940/3032 [08:25<19:04,  1.83it/s][A
Iteration:  31%|███       | 941/3032 [08:26<19:01,  1.83it/s][A
Iteration:  31%|███       | 942/3032 [08:26<18:55,  1.84it/s][A
Iteration:  31%|███       | 943/3032 [08:27<18:52,  1.85it/s][A
Iteration:  31%|███       | 944/3032 [08:27<18:54,  1.84it/s][A
Iteration:  31%|███       | 945/3032 [08:28<18:59,  1.83it/s][A
Iteration:  31%|███       | 946/3032 [08:28<18:45,  1.85it/s][A
Iteration:  31%|███       | 947/3032 [08:29<18:43,  1.86it/s][A
Iteration:  31%|███▏      | 948/3032 [08:30<18:43,  1.85it/s][A
Iteration:  31%|███▏      | 949/3032 [08:30<18:45,  1.85it/s][A
Iteration:  31%|███▏      | 950/3032 [08:31<18:43,  1.85it/s][A
Iteration:  31%|███▏      | 951/3032 [08:31<18:30,  1.87it/s][A
Iteration:  31%|███▏    

{"loss": 0.11941558906412683, "learning_rate": 7.631926121372032e-06, "epoch": 2.4736147757255935, "step": 7500}



Iteration:  47%|████▋     | 1437/3032 [12:52<15:17,  1.74it/s][A
Iteration:  47%|████▋     | 1438/3032 [12:53<15:01,  1.77it/s][A
Iteration:  47%|████▋     | 1439/3032 [12:53<14:48,  1.79it/s][A
Iteration:  47%|████▋     | 1440/3032 [12:54<14:32,  1.82it/s][A
Iteration:  48%|████▊     | 1441/3032 [12:55<14:25,  1.84it/s][A
Iteration:  48%|████▊     | 1442/3032 [12:55<14:22,  1.84it/s][A
Iteration:  48%|████▊     | 1443/3032 [12:56<14:22,  1.84it/s][A
Iteration:  48%|████▊     | 1444/3032 [12:56<14:21,  1.84it/s][A
Iteration:  48%|████▊     | 1445/3032 [12:57<14:17,  1.85it/s][A
Iteration:  48%|████▊     | 1446/3032 [12:57<14:18,  1.85it/s][A
Iteration:  48%|████▊     | 1447/3032 [12:58<14:16,  1.85it/s][A
Iteration:  48%|████▊     | 1448/3032 [12:58<14:18,  1.85it/s][A
Iteration:  48%|████▊     | 1449/3032 [12:59<14:18,  1.84it/s][A
Iteration:  48%|████▊     | 1450/3032 [12:59<14:12,  1.86it/s][A
Iteration:  48%|████▊     | 1451/3032 [13:00<14:12,  1.86it/s][A
Iteration

{"loss": 0.12226586146268528, "learning_rate": 6.807387862796835e-06, "epoch": 2.638522427440633, "step": 8000}



Iteration:  64%|██████▍   | 1937/3032 [17:22<10:17,  1.77it/s][A
Iteration:  64%|██████▍   | 1938/3032 [17:22<10:03,  1.81it/s][A
Iteration:  64%|██████▍   | 1939/3032 [17:23<09:58,  1.83it/s][A
Iteration:  64%|██████▍   | 1940/3032 [17:23<09:55,  1.83it/s][A
Iteration:  64%|██████▍   | 1941/3032 [17:24<09:51,  1.84it/s][A
Iteration:  64%|██████▍   | 1942/3032 [17:24<09:54,  1.83it/s][A
Iteration:  64%|██████▍   | 1943/3032 [17:25<09:55,  1.83it/s][A
Iteration:  64%|██████▍   | 1944/3032 [17:25<09:52,  1.84it/s][A
Iteration:  64%|██████▍   | 1945/3032 [17:26<09:49,  1.84it/s][A
Iteration:  64%|██████▍   | 1946/3032 [17:26<09:46,  1.85it/s][A
Iteration:  64%|██████▍   | 1947/3032 [17:27<09:34,  1.89it/s][A
Iteration:  64%|██████▍   | 1948/3032 [17:27<09:34,  1.89it/s][A
Iteration:  64%|██████▍   | 1949/3032 [17:28<09:36,  1.88it/s][A
Iteration:  64%|██████▍   | 1950/3032 [17:29<09:39,  1.87it/s][A
Iteration:  64%|██████▍   | 1951/3032 [17:29<09:39,  1.87it/s][A
Iteration

{"loss": 0.12086807915638201, "learning_rate": 5.982849604221637e-06, "epoch": 2.8034300791556728, "step": 8500}



Iteration:  80%|████████  | 2437/3032 [21:50<05:34,  1.78it/s][A
Iteration:  80%|████████  | 2438/3032 [21:51<05:24,  1.83it/s][A
Iteration:  80%|████████  | 2439/3032 [21:51<05:19,  1.86it/s][A
Iteration:  80%|████████  | 2440/3032 [21:52<05:18,  1.86it/s][A
Iteration:  81%|████████  | 2441/3032 [21:52<05:18,  1.85it/s][A
Iteration:  81%|████████  | 2442/3032 [21:53<05:18,  1.85it/s][A
Iteration:  81%|████████  | 2443/3032 [21:53<05:16,  1.86it/s][A
Iteration:  81%|████████  | 2444/3032 [21:54<05:15,  1.86it/s][A
Iteration:  81%|████████  | 2445/3032 [21:54<05:15,  1.86it/s][A
Iteration:  81%|████████  | 2446/3032 [21:55<05:14,  1.86it/s][A
Iteration:  81%|████████  | 2447/3032 [21:56<05:14,  1.86it/s][A
Iteration:  81%|████████  | 2448/3032 [21:56<05:15,  1.85it/s][A
Iteration:  81%|████████  | 2449/3032 [21:57<05:14,  1.85it/s][A
Iteration:  81%|████████  | 2450/3032 [21:57<05:14,  1.85it/s][A
Iteration:  81%|████████  | 2451/3032 [21:58<05:12,  1.86it/s][A
Iteration

{"loss": 0.13150540708273184, "learning_rate": 5.158311345646439e-06, "epoch": 2.9683377308707124, "step": 9000}



Iteration:  97%|█████████▋| 2936/3032 [26:21<02:20,  1.46s/it][A
Iteration:  97%|█████████▋| 2937/3032 [26:22<01:51,  1.18s/it][A
Iteration:  97%|█████████▋| 2938/3032 [26:22<01:31,  1.03it/s][A
Iteration:  97%|█████████▋| 2939/3032 [26:23<01:18,  1.19it/s][A
Iteration:  97%|█████████▋| 2940/3032 [26:24<01:09,  1.33it/s][A
Iteration:  97%|█████████▋| 2941/3032 [26:24<01:02,  1.46it/s][A
Iteration:  97%|█████████▋| 2942/3032 [26:25<00:57,  1.56it/s][A
Iteration:  97%|█████████▋| 2943/3032 [26:25<00:54,  1.64it/s][A
Iteration:  97%|█████████▋| 2944/3032 [26:26<00:51,  1.70it/s][A
Iteration:  97%|█████████▋| 2945/3032 [26:26<00:49,  1.75it/s][A
Iteration:  97%|█████████▋| 2946/3032 [26:27<00:48,  1.79it/s][A
Iteration:  97%|█████████▋| 2947/3032 [26:27<00:46,  1.81it/s][A
Iteration:  97%|█████████▋| 2948/3032 [26:28<00:46,  1.82it/s][A
Iteration:  97%|█████████▋| 2949/3032 [26:28<00:45,  1.83it/s][A
Iteration:  97%|█████████▋| 2950/3032 [26:29<00:43,  1.87it/s][A
Iteration

{"loss": 0.08640403769724071, "learning_rate": 4.33377308707124e-06, "epoch": 3.133245382585752, "step": 9500}



Iteration:  13%|█▎        | 405/3032 [03:37<24:46,  1.77it/s][A
Iteration:  13%|█▎        | 406/3032 [03:37<24:16,  1.80it/s][A
Iteration:  13%|█▎        | 407/3032 [03:38<24:05,  1.82it/s][A
Iteration:  13%|█▎        | 408/3032 [03:39<23:56,  1.83it/s][A
Iteration:  13%|█▎        | 409/3032 [03:39<23:50,  1.83it/s][A
Iteration:  14%|█▎        | 410/3032 [03:40<23:42,  1.84it/s][A
Iteration:  14%|█▎        | 411/3032 [03:40<23:39,  1.85it/s][A
Iteration:  14%|█▎        | 412/3032 [03:41<23:38,  1.85it/s][A
Iteration:  14%|█▎        | 413/3032 [03:41<23:37,  1.85it/s][A
Iteration:  14%|█▎        | 414/3032 [03:42<23:35,  1.85it/s][A
Iteration:  14%|█▎        | 415/3032 [03:42<23:28,  1.86it/s][A
Iteration:  14%|█▎        | 416/3032 [03:43<23:27,  1.86it/s][A
Iteration:  14%|█▍        | 417/3032 [03:43<23:26,  1.86it/s][A
Iteration:  14%|█▍        | 418/3032 [03:44<23:24,  1.86it/s][A
Iteration:  14%|█▍        | 419/3032 [03:44<23:29,  1.85it/s][A
Iteration:  14%|█▍      

{"loss": 0.07247035594657064, "learning_rate": 3.5092348284960427e-06, "epoch": 3.2981530343007917, "step": 10000}



Iteration:  30%|██▉       | 905/3032 [08:05<21:32,  1.65it/s][A
Iteration:  30%|██▉       | 906/3032 [08:06<20:45,  1.71it/s][A
Iteration:  30%|██▉       | 907/3032 [08:06<20:16,  1.75it/s][A
Iteration:  30%|██▉       | 908/3032 [08:07<19:55,  1.78it/s][A
Iteration:  30%|██▉       | 909/3032 [08:07<19:38,  1.80it/s][A
Iteration:  30%|███       | 910/3032 [08:08<19:26,  1.82it/s][A
Iteration:  30%|███       | 911/3032 [08:09<19:20,  1.83it/s][A
Iteration:  30%|███       | 912/3032 [08:09<19:14,  1.84it/s][A
Iteration:  30%|███       | 913/3032 [08:10<19:10,  1.84it/s][A
Iteration:  30%|███       | 914/3032 [08:10<19:08,  1.84it/s][A
Iteration:  30%|███       | 915/3032 [08:11<19:06,  1.85it/s][A
Iteration:  30%|███       | 916/3032 [08:11<18:58,  1.86it/s][A
Iteration:  30%|███       | 917/3032 [08:12<18:56,  1.86it/s][A
Iteration:  30%|███       | 918/3032 [08:12<19:01,  1.85it/s][A
Iteration:  30%|███       | 919/3032 [08:13<19:01,  1.85it/s][A
Iteration:  30%|███     

{"loss": 0.0628563115184661, "learning_rate": 2.6846965699208443e-06, "epoch": 3.4630606860158313, "step": 10500}



Iteration:  46%|████▋     | 1405/3032 [12:34<15:43,  1.73it/s][A
Iteration:  46%|████▋     | 1406/3032 [12:34<15:20,  1.77it/s][A
Iteration:  46%|████▋     | 1407/3032 [12:35<15:06,  1.79it/s][A
Iteration:  46%|████▋     | 1408/3032 [12:35<14:39,  1.85it/s][A
Iteration:  46%|████▋     | 1409/3032 [12:36<14:34,  1.86it/s][A
Iteration:  47%|████▋     | 1410/3032 [12:36<14:32,  1.86it/s][A
Iteration:  47%|████▋     | 1411/3032 [12:37<14:40,  1.84it/s][A
Iteration:  47%|████▋     | 1412/3032 [12:38<14:40,  1.84it/s][A
Iteration:  47%|████▋     | 1413/3032 [12:38<14:21,  1.88it/s][A
Iteration:  47%|████▋     | 1414/3032 [12:39<14:22,  1.88it/s][A
Iteration:  47%|████▋     | 1415/3032 [12:39<14:22,  1.87it/s][A
Iteration:  47%|████▋     | 1416/3032 [12:40<14:23,  1.87it/s][A
Iteration:  47%|████▋     | 1417/3032 [12:40<14:18,  1.88it/s][A
Iteration:  47%|████▋     | 1418/3032 [12:41<14:13,  1.89it/s][A
Iteration:  47%|████▋     | 1419/3032 [12:41<14:14,  1.89it/s][A
Iteration

{"loss": 0.07594620484649203, "learning_rate": 1.8601583113456467e-06, "epoch": 3.627968337730871, "step": 11000}



Iteration:  63%|██████▎   | 1905/3032 [17:02<10:27,  1.80it/s][A
Iteration:  63%|██████▎   | 1906/3032 [17:02<10:20,  1.81it/s][A
Iteration:  63%|██████▎   | 1907/3032 [17:03<10:16,  1.82it/s][A
Iteration:  63%|██████▎   | 1908/3032 [17:03<10:14,  1.83it/s][A
Iteration:  63%|██████▎   | 1909/3032 [17:04<10:12,  1.83it/s][A
Iteration:  63%|██████▎   | 1910/3032 [17:05<10:08,  1.84it/s][A
Iteration:  63%|██████▎   | 1911/3032 [17:05<10:08,  1.84it/s][A
Iteration:  63%|██████▎   | 1912/3032 [17:06<10:05,  1.85it/s][A
Iteration:  63%|██████▎   | 1913/3032 [17:06<10:02,  1.86it/s][A
Iteration:  63%|██████▎   | 1914/3032 [17:07<10:03,  1.85it/s][A
Iteration:  63%|██████▎   | 1915/3032 [17:07<10:01,  1.86it/s][A
Iteration:  63%|██████▎   | 1916/3032 [17:08<10:01,  1.85it/s][A
Iteration:  63%|██████▎   | 1917/3032 [17:08<10:01,  1.85it/s][A
Iteration:  63%|██████▎   | 1918/3032 [17:09<09:59,  1.86it/s][A
Iteration:  63%|██████▎   | 1919/3032 [17:09<09:58,  1.86it/s][A
Iteration

{"loss": 0.06983707585709635, "learning_rate": 1.0356200527704487e-06, "epoch": 3.7928759894459105, "step": 11500}



Iteration:  79%|███████▉  | 2405/3032 [21:30<05:34,  1.87it/s][A
Iteration:  79%|███████▉  | 2406/3032 [21:30<05:26,  1.92it/s][A
Iteration:  79%|███████▉  | 2407/3032 [21:31<05:24,  1.92it/s][A
Iteration:  79%|███████▉  | 2408/3032 [21:31<05:25,  1.92it/s][A
Iteration:  79%|███████▉  | 2409/3032 [21:32<05:27,  1.90it/s][A
Iteration:  79%|███████▉  | 2410/3032 [21:32<05:22,  1.93it/s][A
Iteration:  80%|███████▉  | 2411/3032 [21:33<05:17,  1.96it/s][A
Iteration:  80%|███████▉  | 2412/3032 [21:33<05:18,  1.94it/s][A
Iteration:  80%|███████▉  | 2413/3032 [21:34<05:26,  1.90it/s][A
Iteration:  80%|███████▉  | 2414/3032 [21:34<05:27,  1.89it/s][A
Iteration:  80%|███████▉  | 2415/3032 [21:35<05:21,  1.92it/s][A
Iteration:  80%|███████▉  | 2416/3032 [21:35<05:20,  1.92it/s][A
Iteration:  80%|███████▉  | 2417/3032 [21:36<05:21,  1.91it/s][A
Iteration:  80%|███████▉  | 2418/3032 [21:37<05:24,  1.89it/s][A
Iteration:  80%|███████▉  | 2419/3032 [21:37<05:25,  1.88it/s][A
Iteration

{"loss": 0.07198260226415004, "learning_rate": 2.1108179419525068e-07, "epoch": 3.9577836411609497, "step": 12000}



Iteration:  96%|█████████▌| 2904/3032 [25:52<03:17,  1.54s/it][A
Iteration:  96%|█████████▌| 2905/3032 [25:52<02:38,  1.25s/it][A
Iteration:  96%|█████████▌| 2906/3032 [25:53<02:10,  1.04s/it][A
Iteration:  96%|█████████▌| 2907/3032 [25:53<01:50,  1.13it/s][A
Iteration:  96%|█████████▌| 2908/3032 [25:54<01:36,  1.29it/s][A
Iteration:  96%|█████████▌| 2909/3032 [25:54<01:26,  1.42it/s][A
Iteration:  96%|█████████▌| 2910/3032 [25:55<01:20,  1.52it/s][A
Iteration:  96%|█████████▌| 2911/3032 [25:55<01:15,  1.60it/s][A
Iteration:  96%|█████████▌| 2912/3032 [25:56<01:11,  1.69it/s][A
Iteration:  96%|█████████▌| 2913/3032 [25:56<01:08,  1.75it/s][A
Iteration:  96%|█████████▌| 2914/3032 [25:57<01:06,  1.78it/s][A
Iteration:  96%|█████████▌| 2915/3032 [25:57<01:04,  1.81it/s][A
Iteration:  96%|█████████▌| 2916/3032 [25:58<01:03,  1.82it/s][A
Iteration:  96%|█████████▌| 2917/3032 [25:58<01:02,  1.83it/s][A
Iteration:  96%|█████████▌| 2918/3032 [25:59<01:02,  1.83it/s][A
Iteration

TrainOutput(global_step=12128, training_loss=0.21598201864807112)

In [17]:
def evaluate(trainer, features_dict, epochs):
    from sklearn.metrics import classification_report
    preds_dict = {}
    for task_name in task_names:
        eval_dataloader = DataLoaderWithTaskname(
            task_name,
            trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["test"])
        )
        print(eval_dataloader.data_loader.collate_fn)
        preds_dict[task_name] = trainer._prediction_loop(
            eval_dataloader, 
            description=f"Test: {task_name}",
        )
        print(f"Classification Report for Task {task_name} trained for {epochs} epochs")
        print(classification_report(y_pred=np.argmax(preds_dict[task_name].predictions, axis=1),  y_true=preds_dict[task_name].label_ids, digits=16))

In [20]:
evaluate(trainer, features_dict, epochs=4)

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff56df1990>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 12.06it/s]


Classification Report for Task offensive trained for 4 epochs
                  precision    recall  f1-score   support

               0  0.9110576923076923 0.8752886836027713 0.8928150765606596       866
               1  0.7534246575342466 0.8168316831683168 0.7838479809976246       404

        accuracy                      0.8566929133858268      1270
       macro avg  0.8322411749209695 0.8460601833855441 0.8383315287791422      1270
    weighted avg  0.8609130103797615 0.8566929133858268 0.8581515280508437      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff56df1990>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 12.18it/s]

Classification Report for Task hatespeech trained for 4 epochs
                  precision    recall  f1-score   support

               0  0.9884615384615385 0.8854435831180018 0.9341208541572013      1161
               1  0.4217391304347826 0.8899082568807339 0.5722713864306785       109

        accuracy                      0.8858267716535433      1270
       macro avg  0.7051003344481606 0.8876759199993678 0.7531961202939399      1270
    weighted avg  0.9398215837568801 0.8858267716535433 0.9030644825176808      1270






KeyError: 'test'

#### Lets see for two more epoch

In [22]:
trainer.args.num_train_epochs = 2

In [None]:
trainer.train()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Iteration:   0%|          | 0/3032 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/3032 [00:00<28:43,  1.76it/s][A
Iteration:   0%|          | 2/3032 [00:01<27:49,  1.81it/s][A
Iteration:   0%|          | 3/3032 [00:01<27:30,  1.83it/s][A
Iteration:   0%|          | 4/3032 [00:02<27:15,  1.85it/s][A
Iteration:   0%|          | 5/3032 [00:02<26:59,  1.87it/s][A
Iteration:   0%|          | 6/3032 [00:03<27:08,  1.86it/s][A
Iteration:   0%|          | 7/3032 [00:03<27:24,  1.84it/s][A
Iteration:   0%|          | 8/3032 [00:04<27:20,  1.84it/s][A
Iteration:   0%|          | 9/3032 [00:04<27:14,  1.85it/s][A
Iteration:   0%|          | 10/3032 [00:05<27:22,  1.84it/s][A
Iteration:   0%|          | 11/3032 [00:05<27:15,  1.85it/s][A
Iteration:   0%|          | 12/3032 [00:06<27:01,  1.86it/s][A
Iteration:   0%|          | 13/3032 [00:07<26:59,  1.86it/s][A
Iteration:   0%|          | 14/3032 [00:07<27:00,  1.86it/s][A
Iteration:   

{"loss": 0.09856156264507444, "learning_rate": 1.8350923482849604e-05, "epoch": 0.16490765171503957, "step": 500}



Iteration:  17%|█▋        | 501/3032 [04:27<22:36,  1.87it/s][A
Iteration:  17%|█▋        | 502/3032 [04:28<22:40,  1.86it/s][A
Iteration:  17%|█▋        | 503/3032 [04:28<22:43,  1.85it/s][A
Iteration:  17%|█▋        | 504/3032 [04:29<22:48,  1.85it/s][A
Iteration:  17%|█▋        | 505/3032 [04:29<22:46,  1.85it/s][A
Iteration:  17%|█▋        | 506/3032 [04:30<22:50,  1.84it/s][A
Iteration:  17%|█▋        | 507/3032 [04:30<22:51,  1.84it/s][A
Iteration:  17%|█▋        | 508/3032 [04:31<22:49,  1.84it/s][A
Iteration:  17%|█▋        | 509/3032 [04:31<22:43,  1.85it/s][A
Iteration:  17%|█▋        | 510/3032 [04:32<22:42,  1.85it/s][A
Iteration:  17%|█▋        | 511/3032 [04:32<22:42,  1.85it/s][A
Iteration:  17%|█▋        | 512/3032 [04:33<22:43,  1.85it/s][A
Iteration:  17%|█▋        | 513/3032 [04:34<22:43,  1.85it/s][A
Iteration:  17%|█▋        | 514/3032 [04:34<22:41,  1.85it/s][A
Iteration:  17%|█▋        | 515/3032 [04:35<22:44,  1.84it/s][A
Iteration:  17%|█▋      

{"loss": 0.11112148779537528, "learning_rate": 1.670184696569921e-05, "epoch": 0.32981530343007914, "step": 1000}



Iteration:  33%|███▎      | 1001/3032 [08:55<17:54,  1.89it/s][A
Iteration:  33%|███▎      | 1002/3032 [08:55<17:45,  1.90it/s][A
Iteration:  33%|███▎      | 1003/3032 [08:56<17:38,  1.92it/s][A
Iteration:  33%|███▎      | 1004/3032 [08:56<17:41,  1.91it/s][A
Iteration:  33%|███▎      | 1005/3032 [08:57<17:50,  1.89it/s][A
Iteration:  33%|███▎      | 1006/3032 [08:57<17:52,  1.89it/s][A
Iteration:  33%|███▎      | 1007/3032 [08:58<17:38,  1.91it/s][A
Iteration:  33%|███▎      | 1008/3032 [08:58<17:41,  1.91it/s][A
Iteration:  33%|███▎      | 1009/3032 [08:59<17:51,  1.89it/s][A
Iteration:  33%|███▎      | 1010/3032 [08:59<17:58,  1.87it/s][A
Iteration:  33%|███▎      | 1011/3032 [09:00<18:05,  1.86it/s][A
Iteration:  33%|███▎      | 1012/3032 [09:00<17:49,  1.89it/s][A
Iteration:  33%|███▎      | 1013/3032 [09:01<17:42,  1.90it/s][A
Iteration:  33%|███▎      | 1014/3032 [09:01<17:42,  1.90it/s][A
Iteration:  33%|███▎      | 1015/3032 [09:02<17:51,  1.88it/s][A
Iteration

{"loss": 0.11249433731511817, "learning_rate": 1.5052770448548815e-05, "epoch": 0.4947229551451187, "step": 1500}



Iteration:  50%|████▉     | 1501/3032 [13:22<13:42,  1.86it/s][A
Iteration:  50%|████▉     | 1502/3032 [13:23<13:41,  1.86it/s][A
Iteration:  50%|████▉     | 1503/3032 [13:23<13:40,  1.86it/s][A
Iteration:  50%|████▉     | 1504/3032 [13:24<13:41,  1.86it/s][A
Iteration:  50%|████▉     | 1505/3032 [13:24<13:37,  1.87it/s][A
Iteration:  50%|████▉     | 1506/3032 [13:25<13:36,  1.87it/s][A
Iteration:  50%|████▉     | 1507/3032 [13:25<13:36,  1.87it/s][A
Iteration:  50%|████▉     | 1508/3032 [13:26<13:37,  1.86it/s][A
Iteration:  50%|████▉     | 1509/3032 [13:26<13:37,  1.86it/s][A
Iteration:  50%|████▉     | 1510/3032 [13:27<13:38,  1.86it/s][A
Iteration:  50%|████▉     | 1511/3032 [13:27<13:41,  1.85it/s][A
Iteration:  50%|████▉     | 1512/3032 [13:28<13:41,  1.85it/s][A
Iteration:  50%|████▉     | 1513/3032 [13:29<13:38,  1.86it/s][A
Iteration:  50%|████▉     | 1514/3032 [13:29<13:34,  1.86it/s][A
Iteration:  50%|████▉     | 1515/3032 [13:30<13:35,  1.86it/s][A
Iteration

{"loss": 0.10871931239272817, "learning_rate": 1.3403693931398417e-05, "epoch": 0.6596306068601583, "step": 2000}



Iteration:  66%|██████▌   | 2001/3032 [17:50<09:14,  1.86it/s][A
Iteration:  66%|██████▌   | 2002/3032 [17:51<09:05,  1.89it/s][A
Iteration:  66%|██████▌   | 2003/3032 [17:51<09:04,  1.89it/s][A
Iteration:  66%|██████▌   | 2004/3032 [17:52<09:05,  1.88it/s][A
Iteration:  66%|██████▌   | 2005/3032 [17:52<09:08,  1.87it/s][A
Iteration:  66%|██████▌   | 2006/3032 [17:53<09:13,  1.85it/s][A
Iteration:  66%|██████▌   | 2007/3032 [17:53<09:10,  1.86it/s][A
Iteration:  66%|██████▌   | 2008/3032 [17:54<09:06,  1.87it/s][A
Iteration:  66%|██████▋   | 2009/3032 [17:55<09:03,  1.88it/s][A
Iteration:  66%|██████▋   | 2010/3032 [17:55<09:06,  1.87it/s][A
Iteration:  66%|██████▋   | 2011/3032 [17:56<08:59,  1.89it/s][A
Iteration:  66%|██████▋   | 2012/3032 [17:56<09:00,  1.89it/s][A
Iteration:  66%|██████▋   | 2013/3032 [17:57<09:02,  1.88it/s][A
Iteration:  66%|██████▋   | 2014/3032 [17:57<09:06,  1.86it/s][A
Iteration:  66%|██████▋   | 2015/3032 [17:58<09:07,  1.86it/s][A
Iteration

{"loss": 0.11563910545190446, "learning_rate": 1.1754617414248021e-05, "epoch": 0.8245382585751979, "step": 2500}



Iteration:  82%|████████▏ | 2501/3032 [22:17<04:46,  1.85it/s][A
Iteration:  83%|████████▎ | 2502/3032 [22:17<04:38,  1.90it/s][A
Iteration:  83%|████████▎ | 2503/3032 [22:18<04:41,  1.88it/s][A
Iteration:  83%|████████▎ | 2504/3032 [22:18<04:41,  1.87it/s][A
Iteration:  83%|████████▎ | 2505/3032 [22:19<04:42,  1.87it/s][A
Iteration:  83%|████████▎ | 2506/3032 [22:19<04:43,  1.86it/s][A
Iteration:  83%|████████▎ | 2507/3032 [22:20<04:38,  1.88it/s][A
Iteration:  83%|████████▎ | 2508/3032 [22:20<04:32,  1.92it/s][A
Iteration:  83%|████████▎ | 2509/3032 [22:21<04:30,  1.93it/s][A
Iteration:  83%|████████▎ | 2510/3032 [22:21<04:30,  1.93it/s][A
Iteration:  83%|████████▎ | 2511/3032 [22:22<04:31,  1.92it/s][A
Iteration:  83%|████████▎ | 2512/3032 [22:22<04:34,  1.89it/s][A
Iteration:  83%|████████▎ | 2513/3032 [22:23<04:35,  1.88it/s][A
Iteration:  83%|████████▎ | 2514/3032 [22:23<04:36,  1.88it/s][A
Iteration:  83%|████████▎ | 2515/3032 [22:24<04:36,  1.87it/s][A
Iteration

{"loss": 0.10495833750563907, "learning_rate": 1.0105540897097625e-05, "epoch": 0.9894459102902374, "step": 3000}



Iteration:  99%|█████████▉| 3000/3032 [26:45<00:47,  1.48s/it][A
Iteration:  99%|█████████▉| 3001/3032 [26:46<00:37,  1.20s/it][A
Iteration:  99%|█████████▉| 3002/3032 [26:46<00:30,  1.00s/it][A
Iteration:  99%|█████████▉| 3003/3032 [26:47<00:25,  1.16it/s][A
Iteration:  99%|█████████▉| 3004/3032 [26:48<00:21,  1.30it/s][A
Iteration:  99%|█████████▉| 3005/3032 [26:48<00:18,  1.43it/s][A
Iteration:  99%|█████████▉| 3006/3032 [26:49<00:16,  1.54it/s][A
Iteration:  99%|█████████▉| 3007/3032 [26:49<00:15,  1.65it/s][A
Iteration:  99%|█████████▉| 3008/3032 [26:50<00:14,  1.70it/s][A
Iteration:  99%|█████████▉| 3009/3032 [26:50<00:13,  1.75it/s][A
Iteration:  99%|█████████▉| 3010/3032 [26:51<00:12,  1.78it/s][A
Iteration:  99%|█████████▉| 3011/3032 [26:51<00:11,  1.80it/s][A
Iteration:  99%|█████████▉| 3012/3032 [26:52<00:10,  1.84it/s][A
Iteration:  99%|█████████▉| 3013/3032 [26:52<00:10,  1.85it/s][A
Iteration:  99%|█████████▉| 3014/3032 [26:53<00:09,  1.85it/s][A
Iteration

{"loss": 0.05659188774196082, "learning_rate": 8.456464379947231e-06, "epoch": 1.154353562005277, "step": 3500}



Iteration:  15%|█▌        | 469/3032 [04:09<22:39,  1.89it/s][A
Iteration:  16%|█▌        | 470/3032 [04:10<22:37,  1.89it/s][A
Iteration:  16%|█▌        | 471/3032 [04:10<22:41,  1.88it/s][A
Iteration:  16%|█▌        | 472/3032 [04:11<22:47,  1.87it/s][A
Iteration:  16%|█▌        | 473/3032 [04:11<22:55,  1.86it/s][A
Iteration:  16%|█▌        | 474/3032 [04:12<22:41,  1.88it/s][A
Iteration:  16%|█▌        | 475/3032 [04:12<22:48,  1.87it/s][A
Iteration:  16%|█▌        | 476/3032 [04:13<22:53,  1.86it/s][A
Iteration:  16%|█▌        | 477/3032 [04:13<22:56,  1.86it/s][A
Iteration:  16%|█▌        | 478/3032 [04:14<22:56,  1.86it/s][A
Iteration:  16%|█▌        | 479/3032 [04:14<22:33,  1.89it/s][A
Iteration:  16%|█▌        | 480/3032 [04:15<22:23,  1.90it/s][A
Iteration:  16%|█▌        | 481/3032 [04:16<22:30,  1.89it/s][A
Iteration:  16%|█▌        | 482/3032 [04:16<22:36,  1.88it/s][A
Iteration:  16%|█▌        | 483/3032 [04:17<22:42,  1.87it/s][A
Iteration:  16%|█▌      

{"loss": 0.05963085182575742, "learning_rate": 6.807387862796835e-06, "epoch": 1.3192612137203166, "step": 4000}



Iteration:  32%|███▏      | 969/3032 [08:30<17:17,  1.99it/s][A
Iteration:  32%|███▏      | 970/3032 [08:30<17:19,  1.98it/s][A
Iteration:  32%|███▏      | 971/3032 [08:31<17:14,  1.99it/s][A
Iteration:  32%|███▏      | 972/3032 [08:31<17:16,  1.99it/s][A
Iteration:  32%|███▏      | 973/3032 [08:32<17:30,  1.96it/s][A
Iteration:  32%|███▏      | 974/3032 [08:32<17:26,  1.97it/s][A
Iteration:  32%|███▏      | 975/3032 [08:33<17:19,  1.98it/s][A
Iteration:  32%|███▏      | 976/3032 [08:33<17:19,  1.98it/s][A
Iteration:  32%|███▏      | 977/3032 [08:34<17:17,  1.98it/s][A
Iteration:  32%|███▏      | 978/3032 [08:34<17:11,  1.99it/s][A
Iteration:  32%|███▏      | 979/3032 [08:35<17:05,  2.00it/s][A
Iteration:  32%|███▏      | 980/3032 [08:35<17:02,  2.01it/s][A
Iteration:  32%|███▏      | 981/3032 [08:36<16:59,  2.01it/s][A
Iteration:  32%|███▏      | 982/3032 [08:36<16:57,  2.01it/s][A
Iteration:  32%|███▏      | 983/3032 [08:37<16:56,  2.02it/s][A
Iteration:  32%|███▏    

In [33]:
evaluate(trainer, features_dict, epochs=5)

Test: offensive:   1%|          | 1/159 [00:00<00:16,  9.84it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.94it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:12, 12.20it/s]

Classification Report for Task offensive trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9061371841155235 0.8695150115473441 0.8874484384207424       866
               1  0.7425968109339408 0.8069306930693070 0.7734282325029657       404

        accuracy                      0.8496062992125984      1270
       macro avg  0.8243669975247321 0.8382228523083255 0.8304383354618541      1270
    weighted avg  0.8541133173711460 0.8496062992125984 0.8511774437823314      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 12.01it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 12.04it/s]

Classification Report for Task hatespeech trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9731136166522116 0.9664082687338501 0.9697493517718236      1161
               1  0.6666666666666666 0.7155963302752294 0.6902654867256638       109

        accuracy                      0.9448818897637795      1270
       macro avg  0.8198901416594391 0.8410022995045398 0.8300074192487437      1270
    weighted avg  0.9468122642518775 0.9448818897637795 0.9457621539056573      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.96it/s]

Classification Report for Task hatespeech_classes trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9699054170249355 0.9715762273901809 0.9707401032702238      1161
               1  0.6666666666666666 0.5714285714285714 0.6153846153846153        28
               2  0.0000000000000000 0.0000000000000000 0.0000000000000000         4
               3  0.5333333333333333 0.5714285714285714 0.5517241379310344        14
               4  0.0000000000000000 0.0000000000000000 0.0000000000000000         1
               5  0.1538461538461539 0.2000000000000000 0.1739130434782609        10
               6  0.6909090909090909 0.7307692307692307 0.7102803738317757        52

        accuracy                      0.9385826771653544      1270
       macro avg  0.4306658088257400 0.4350289430023649 0.4317203248422729      1270
    weighted avg  0.9367395722559195 0.9385826771653544 0.9375258873484790      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### 2 Epochs

In [18]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [19]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8371149406524729}

In [20]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.3926597611174843}

### 3 Epochs

In [21]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 4 Epochs

In [25]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8469622661091527}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 5 Epochs

In [43]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8387301587301587}

In [44]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8107379458902904}

In [45]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4355893433350976}

# Combo

## Preparing Data

In [4]:
task_names = ['offensive', 'hatespeech', 'hatespeech_classes']

In [5]:
dataset_dict = {
    "offensive": load_dataset("csv", data_files={'train': "Data/train/trainA_prepro_combo.csv", 'test': "Data/test/testA_prepro.csv" } ),
    "hatespeech": load_dataset("csv", data_files={'train': "Data/train/trainB_prepro_combo.csv", 'test': "Data/test/testB_prepro.csv" } ),
    "hatespeech_classes": load_dataset("csv", data_files={'train': "Data/train/trainC_prepro.csv", 'test': "Data/test/testC_prepro.csv" } ),
}

Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-1337ce598b95f03a/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1432.97it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 69.37it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-1337ce598b95f03a/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 571.86it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-ea8e0edd0e0f031f/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1398.57it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 72.30it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-ea8e0edd0e0f031f/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 578.56it/s]


Downloading and preparing dataset csv/default to /home/ashapiro/.cache/huggingface/datasets/csv/default-0210a04391128139/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 1934.20it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 77.70it/s]


Dataset csv downloaded and prepared to /home/ashapiro/.cache/huggingface/datasets/csv/default-0210a04391128139/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 331.76it/s]


## Setting Model

In [6]:
model_name = "/scratch/mt/ashapiro/Hate_Speech/Models/Marbertv2/"
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        "offensive": transformers.AutoModelForSequenceClassification,
        "hatespeech": transformers.AutoModelForSequenceClassification,
        "hatespeech_classes": transformers.AutoModelForSequenceClassification,
    },
    model_config_dict={
        "offensive": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "hatespeech": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "hatespeech_classes": transformers.AutoConfig.from_pretrained(model_name, num_labels=7),
    },
)

In [7]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [8]:
max_length = 512

def convert_to_features(example_batch):
    inputs = list(example_batch['text'])
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["labels"]
    return features

convert_func_dict = {
    "offensive": convert_to_features,
    "hatespeech": convert_to_features,
    "hatespeech_classes": convert_to_features,
}

In [9]:
columns_dict = {
    "offensive": ['input_ids', 'attention_mask', 'labels'],
    "hatespeech": ['input_ids', 'attention_mask', 'labels'],
    "hatespeech_classes": ['input_ids', 'attention_mask', 'labels'],
}

features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

100%|██████████| 42/42 [00:19<00:00,  2.21ba/s]


offensive train 41212 41212
offensive train 41212 41212


100%|██████████| 2/2 [00:00<00:00,  3.99ba/s]


offensive test 1270 1270
offensive test 1270 1270


100%|██████████| 33/33 [00:15<00:00,  2.13ba/s]


hatespeech train 32630 32630
hatespeech train 32630 32630


100%|██████████| 2/2 [00:00<00:00,  3.90ba/s]


hatespeech test 1270 1270
hatespeech test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.53ba/s]


hatespeech_classes train 8887 8887
hatespeech_classes train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.96ba/s]

hatespeech_classes test 1270 1270
hatespeech_classes test 1270 1270





In [10]:
eval_dataset = {
    task_name: dataset["test"] 
    for task_name, dataset in features_dict.items()
}

In [11]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}
args = transformers.TrainingArguments(
        output_dir="./models/multitask_model/3_main_tasks/combo_data/4_epochs/",
        overwrite_output_dir=True,
        learning_rate=2e-5,
        do_train=True,
        num_train_epochs=4,
        # Adjust batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=16,  
        save_steps=3000,)

trainer = MultitaskTrainer(
    model=multitask_model,
    args=args,
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)

In [None]:
trainer.train()

Epoch:   0%|          | 0/4 [00:00<?, ?it/s]
Iteration:   0%|          | 0/5172 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/5172 [00:00<1:23:36,  1.03it/s][A
Iteration:   0%|          | 2/5172 [00:01<1:01:13,  1.41it/s][A
Iteration:   0%|          | 3/5172 [00:02<54:20,  1.59it/s]  [A
Iteration:   0%|          | 4/5172 [00:02<51:16,  1.68it/s][A
Iteration:   0%|          | 5/5172 [00:03<49:22,  1.74it/s][A
Iteration:   0%|          | 6/5172 [00:03<48:16,  1.78it/s][A
Iteration:   0%|          | 7/5172 [00:04<47:36,  1.81it/s][A
Iteration:   0%|          | 8/5172 [00:04<47:05,  1.83it/s][A
Iteration:   0%|          | 9/5172 [00:05<46:40,  1.84it/s][A
Iteration:   0%|          | 10/5172 [00:05<46:18,  1.86it/s][A
Iteration:   0%|          | 11/5172 [00:06<46:23,  1.85it/s][A
Iteration:   0%|          | 12/5172 [00:06<46:34,  1.85it/s][A
Iteration:   0%|          | 13/5172 [00:07<46:26,  1.85it/s][A
Iteration:   0%|          | 14/5172 [00:07<46:19,  1.86it/s][A
Iterati

In [13]:
task_names = ['offensive','hatespeech','hatespeech_classes']

In [32]:
def evaluate(trainer, features_dict, epochs):
    from sklearn.metrics import classification_report
    preds_dict = {}
    for task_name in task_names:
        eval_dataloader = DataLoaderWithTaskname(
            task_name,
            trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["test"])
        )
        print(eval_dataloader.data_loader.collate_fn)
        preds_dict[task_name] = trainer._prediction_loop(
            eval_dataloader, 
            description=f"Test: {task_name}",
        )
        print(f"Classification Report for Task {task_name} trained for {epochs} epochs")
        print(classification_report(y_pred=np.argmax(preds_dict[task_name].predictions, axis=1),  y_true=preds_dict[task_name].label_ids, digits=16))

In [26]:
evaluate(trainer, features_dict, epochs=4)

Test: offensive:   1%|          | 1/159 [00:00<00:16,  9.79it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.99it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:13, 11.60it/s]

Classification Report for Task offensive trained for 4 epochs
              precision    recall  f1-score   support

           0       0.92      0.86      0.89       866
           1       0.73      0.83      0.78       404

    accuracy                           0.85      1270
   macro avg       0.82      0.84      0.83      1270
weighted avg       0.86      0.85      0.85      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 11.98it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 12.02it/s]

Classification Report for Task hatespeech trained for 4 epochs
              precision    recall  f1-score   support

           0       0.98      0.96      0.97      1161
           1       0.65      0.75      0.70       109

    accuracy                           0.94      1270
   macro avg       0.81      0.86      0.83      1270
weighted avg       0.95      0.94      0.95      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.93it/s]

Classification Report for Task hatespeech_classes trained for 4 epochs
              precision    recall  f1-score   support

           0       0.97      0.97      0.97      1161
           1       0.58      0.64      0.61        28
           2       0.00      0.00      0.00         4
           3       0.56      0.64      0.60        14
           4       0.00      0.00      0.00         1
           5       0.20      0.10      0.13        10
           6       0.67      0.77      0.71        52

    accuracy                           0.94      1270
   macro avg       0.43      0.45      0.43      1270
weighted avg       0.94      0.94      0.94      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


#### Lets see for one more epoch

In [27]:
trainer.args.num_train_epochs = 1

In [28]:
trainer.train()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|          | 0/2101 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/2101 [00:00<18:22,  1.90it/s][A
Iteration:   0%|          | 2/2101 [00:01<17:57,  1.95it/s][A
Iteration:   0%|          | 3/2101 [00:01<18:14,  1.92it/s][A
Iteration:   0%|          | 4/2101 [00:02<18:25,  1.90it/s][A
Iteration:   0%|          | 5/2101 [00:02<18:30,  1.89it/s][A
Iteration:   0%|          | 6/2101 [00:03<18:38,  1.87it/s][A
Iteration:   0%|          | 7/2101 [00:03<18:41,  1.87it/s][A
Iteration:   0%|          | 8/2101 [00:04<18:40,  1.87it/s][A
Iteration:   0%|          | 9/2101 [00:04<18:41,  1.86it/s][A
Iteration:   0%|          | 10/2101 [00:05<18:43,  1.86it/s][A
Iteration:   1%|          | 11/2101 [00:05<18:44,  1.86it/s][A
Iteration:   1%|          | 12/2101 [00:06<18:45,  1.86it/s][A
Iteration:   1%|          | 13/2101 [00:06<18:45,  1.85it/s][A
Iteration:   1%|          | 14/2101 [00:07<18:46,  1.85it/s][A
Iteration:   

{"loss": 0.12110608770386898, "learning_rate": 1.524036173250833e-05, "epoch": 0.23798191337458352, "step": 500}



Iteration:  24%|██▍       | 501/2101 [04:29<14:27,  1.84it/s][A
Iteration:  24%|██▍       | 502/2101 [04:30<14:25,  1.85it/s][A
Iteration:  24%|██▍       | 503/2101 [04:30<14:23,  1.85it/s][A
Iteration:  24%|██▍       | 504/2101 [04:31<14:22,  1.85it/s][A
Iteration:  24%|██▍       | 505/2101 [04:31<14:21,  1.85it/s][A
Iteration:  24%|██▍       | 506/2101 [04:32<14:23,  1.85it/s][A
Iteration:  24%|██▍       | 507/2101 [04:33<14:24,  1.84it/s][A
Iteration:  24%|██▍       | 508/2101 [04:33<14:23,  1.84it/s][A
Iteration:  24%|██▍       | 509/2101 [04:34<14:21,  1.85it/s][A
Iteration:  24%|██▍       | 510/2101 [04:34<14:23,  1.84it/s][A
Iteration:  24%|██▍       | 511/2101 [04:35<14:22,  1.84it/s][A
Iteration:  24%|██▍       | 512/2101 [04:35<14:22,  1.84it/s][A
Iteration:  24%|██▍       | 513/2101 [04:36<14:21,  1.84it/s][A
Iteration:  24%|██▍       | 514/2101 [04:36<14:20,  1.84it/s][A
Iteration:  25%|██▍       | 515/2101 [04:37<14:20,  1.84it/s][A
Iteration:  25%|██▍     

{"loss": 0.12189705061679706, "learning_rate": 1.048072346501666e-05, "epoch": 0.47596382674916704, "step": 1000}



Iteration:  48%|████▊     | 1001/2101 [08:59<09:58,  1.84it/s][A
Iteration:  48%|████▊     | 1002/2101 [09:00<09:57,  1.84it/s][A
Iteration:  48%|████▊     | 1003/2101 [09:00<09:54,  1.85it/s][A
Iteration:  48%|████▊     | 1004/2101 [09:01<09:53,  1.85it/s][A
Iteration:  48%|████▊     | 1005/2101 [09:01<09:52,  1.85it/s][A
Iteration:  48%|████▊     | 1006/2101 [09:02<09:51,  1.85it/s][A
Iteration:  48%|████▊     | 1007/2101 [09:02<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1008/2101 [09:03<09:51,  1.85it/s][A
Iteration:  48%|████▊     | 1009/2101 [09:04<09:53,  1.84it/s][A
Iteration:  48%|████▊     | 1010/2101 [09:04<09:51,  1.84it/s][A
Iteration:  48%|████▊     | 1011/2101 [09:05<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1012/2101 [09:05<09:50,  1.85it/s][A
Iteration:  48%|████▊     | 1013/2101 [09:06<09:49,  1.85it/s][A
Iteration:  48%|████▊     | 1014/2101 [09:06<09:50,  1.84it/s][A
Iteration:  48%|████▊     | 1015/2101 [09:07<09:48,  1.84it/s][A
Iteration

{"loss": 0.09711660955034312, "learning_rate": 5.721085197524988e-06, "epoch": 0.7139457401237506, "step": 1500}



Iteration:  71%|███████▏  | 1501/2101 [13:28<05:24,  1.85it/s][A
Iteration:  71%|███████▏  | 1502/2101 [13:28<05:24,  1.85it/s][A
Iteration:  72%|███████▏  | 1503/2101 [13:29<05:23,  1.85it/s][A
Iteration:  72%|███████▏  | 1504/2101 [13:29<05:23,  1.85it/s][A
Iteration:  72%|███████▏  | 1505/2101 [13:30<05:22,  1.85it/s][A
Iteration:  72%|███████▏  | 1506/2101 [13:30<05:21,  1.85it/s][A
Iteration:  72%|███████▏  | 1507/2101 [13:31<05:21,  1.85it/s][A
Iteration:  72%|███████▏  | 1508/2101 [13:31<05:19,  1.86it/s][A
Iteration:  72%|███████▏  | 1509/2101 [13:32<05:19,  1.85it/s][A
Iteration:  72%|███████▏  | 1510/2101 [13:32<05:19,  1.85it/s][A
Iteration:  72%|███████▏  | 1511/2101 [13:33<05:19,  1.85it/s][A
Iteration:  72%|███████▏  | 1512/2101 [13:34<05:18,  1.85it/s][A
Iteration:  72%|███████▏  | 1513/2101 [13:34<05:18,  1.85it/s][A
Iteration:  72%|███████▏  | 1514/2101 [13:35<05:17,  1.85it/s][A
Iteration:  72%|███████▏  | 1515/2101 [13:35<05:16,  1.85it/s][A
Iteration

{"loss": 0.10913289758379688, "learning_rate": 9.614469300333174e-07, "epoch": 0.9519276534983341, "step": 2000}



Iteration:  95%|█████████▌| 2001/2101 [17:56<00:52,  1.91it/s][A
Iteration:  95%|█████████▌| 2002/2101 [17:57<00:51,  1.90it/s][A
Iteration:  95%|█████████▌| 2003/2101 [17:58<00:51,  1.90it/s][A
Iteration:  95%|█████████▌| 2004/2101 [17:58<00:51,  1.89it/s][A
Iteration:  95%|█████████▌| 2005/2101 [17:59<00:50,  1.89it/s][A
Iteration:  95%|█████████▌| 2006/2101 [17:59<00:50,  1.88it/s][A
Iteration:  96%|█████████▌| 2007/2101 [18:00<00:50,  1.87it/s][A
Iteration:  96%|█████████▌| 2008/2101 [18:00<00:49,  1.88it/s][A
Iteration:  96%|█████████▌| 2009/2101 [18:01<00:48,  1.90it/s][A
Iteration:  96%|█████████▌| 2010/2101 [18:01<00:48,  1.89it/s][A
Iteration:  96%|█████████▌| 2011/2101 [18:02<00:47,  1.88it/s][A
Iteration:  96%|█████████▌| 2012/2101 [18:02<00:47,  1.87it/s][A
Iteration:  96%|█████████▌| 2013/2101 [18:03<00:47,  1.87it/s][A
Iteration:  96%|█████████▌| 2014/2101 [18:03<00:46,  1.86it/s][A
Iteration:  96%|█████████▌| 2015/2101 [18:04<00:46,  1.86it/s][A
Iteration

TrainOutput(global_step=2101, training_loss=0.11203017262371649)

In [33]:
evaluate(trainer, features_dict, epochs=5)

Test: offensive:   1%|          | 1/159 [00:00<00:16,  9.84it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.94it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:12, 12.20it/s]

Classification Report for Task offensive trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9061371841155235 0.8695150115473441 0.8874484384207424       866
               1  0.7425968109339408 0.8069306930693070 0.7734282325029657       404

        accuracy                      0.8496062992125984      1270
       macro avg  0.8243669975247321 0.8382228523083255 0.8304383354618541      1270
    weighted avg  0.8541133173711460 0.8496062992125984 0.8511774437823314      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 12.01it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 12.04it/s]

Classification Report for Task hatespeech trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9731136166522116 0.9664082687338501 0.9697493517718236      1161
               1  0.6666666666666666 0.7155963302752294 0.6902654867256638       109

        accuracy                      0.9448818897637795      1270
       macro avg  0.8198901416594391 0.8410022995045398 0.8300074192487437      1270
    weighted avg  0.9468122642518775 0.9448818897637795 0.9457621539056573      1270

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7fff8edc4ee0>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.96it/s]

Classification Report for Task hatespeech_classes trained for 5 epochs
                  precision    recall  f1-score   support

               0  0.9699054170249355 0.9715762273901809 0.9707401032702238      1161
               1  0.6666666666666666 0.5714285714285714 0.6153846153846153        28
               2  0.0000000000000000 0.0000000000000000 0.0000000000000000         4
               3  0.5333333333333333 0.5714285714285714 0.5517241379310344        14
               4  0.0000000000000000 0.0000000000000000 0.0000000000000000         1
               5  0.1538461538461539 0.2000000000000000 0.1739130434782609        10
               6  0.6909090909090909 0.7307692307692307 0.7102803738317757        52

        accuracy                      0.9385826771653544      1270
       macro avg  0.4306658088257400 0.4350289430023649 0.4317203248422729      1270
    weighted avg  0.9367395722559195 0.9385826771653544 0.9375258873484790      1270




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### 2 Epochs

In [18]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [19]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8371149406524729}

In [20]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.3926597611174843}

### 3 Epochs

In [21]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 4 Epochs

In [25]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8469622661091527}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 5 Epochs

In [43]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8387301587301587}

In [44]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8107379458902904}

In [45]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4355893433350976}

## Loading

In [46]:
def load_model(dot_bin_file):
    multitask_model = MultitaskModel.create(
                                            model_name=model_name,
                                            model_type_dict={
                                                "offensive": transformers.AutoModelForSequenceClassification,
                                                "hatespeech": transformers.AutoModelForSequenceClassification,
                                                "hatespeech_classes": transformers.AutoModelForSequenceClassification,
                                            },
                                            model_config_dict={
                                                "offensive": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
                                                "hatespeech": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
                                                "hatespeech_classes": transformers.AutoConfig.from_pretrained(model_name, num_labels=7),
                                            },)
    model = torch.load(dot_bin_file)
    multitask_model.load_state_dict(model)
    return multitask_model

# Without Last Task

In [1]:
%cd /scratch/mt/ashapiro/Hate_Speech/Multitask_trial/

/scratch/mt/ashapiro/Hate_Speech/Multitask_trial


In [2]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
import logging
from datasets import load_dataset
from model import * 
logging.basicConfig(level=logging.INFO)

## Preparing Data

In [3]:
task_names = ['offensive', 'hatespeech', 'hatespeech_classes']

In [4]:
dataset_dict = {
    "offensive": load_dataset("csv", data_files={'train': "Data/trainA_prepro.csv", 'test': "Data/testA_prepro.csv" } ),
    "hatespeech": load_dataset("csv", data_files={'train': "Data/trainB_prepro.csv", 'test': "Data/testB_prepro.csv" } ),
    "hatespeech_classes": load_dataset("csv", data_files={'train': "Data/trainC_prepro.csv", 'test': "Data/testC_prepro.csv" } ),
}

100%|██████████| 2/2 [00:00<00:00, 193.53it/s]
100%|██████████| 2/2 [00:00<00:00, 208.36it/s]
100%|██████████| 2/2 [00:00<00:00, 206.70it/s]


## Setting Model

In [5]:
model_name = "/scratch/mt/ashapiro/Hate_Speech/Models/Marbertv2/"
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        "offensive": transformers.AutoModelForSequenceClassification,
        "hatespeech": transformers.AutoModelForSequenceClassification,
        "hatespeech_classes": transformers.AutoModelForSequenceClassification,
    },
    model_config_dict={
        "offensive": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "hatespeech": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "hatespeech_classes": transformers.AutoConfig.from_pretrained(model_name, num_labels=7),
    },
)

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [7]:
max_length = 512

def convert_to_features(example_batch):
    inputs = list(example_batch['text'])
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["labels"]
    return features

convert_func_dict = {
    "offensive": convert_to_features,
    "hatespeech": convert_to_features,
    "hatespeech_classes": convert_to_features,
}

In [8]:
columns_dict = {
    "offensive": ['input_ids', 'attention_mask', 'labels'],
    "hatespeech": ['input_ids', 'attention_mask', 'labels'],
    "hatespeech_classes": ['input_ids', 'attention_mask', 'labels'],
}

features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

100%|██████████| 9/9 [00:03<00:00,  2.52ba/s]


offensive train 8887 8887
offensive train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.92ba/s]


offensive test 1270 1270
offensive test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.35ba/s]


hatespeech train 8887 8887
hatespeech train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.93ba/s]


hatespeech test 1270 1270
hatespeech test 1270 1270


100%|██████████| 9/9 [00:03<00:00,  2.35ba/s]


hatespeech_classes train 8887 8887
hatespeech_classes train 8887 8887


100%|██████████| 2/2 [00:00<00:00,  3.87ba/s]

hatespeech_classes test 1270 1270
hatespeech_classes test 1270 1270





In [9]:
eval_dataset = {
    task_name: dataset["test"] 
    for task_name, dataset in features_dict.items()
}

In [10]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}
args = transformers.TrainingArguments(
        output_dir="./models/multitask_model/3_epochs",
        overwrite_output_dir=True,
        learning_rate=2e-5,
        do_train=True,
        num_train_epochs=2,
        # Adjust batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=16,  
        save_steps=3000,)

trainer = MultitaskTrainer(
    model=multitask_model,
    args=args,
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)

[34m[1mwandb[0m: Currently logged in as: [33mahmadshapiro[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [11]:
trainer.train()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1668 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/1668 [00:00<17:53,  1.55it/s][A
Iteration:   0%|          | 2/1668 [00:01<16:14,  1.71it/s][A
Iteration:   0%|          | 3/1668 [00:01<15:44,  1.76it/s][A
Iteration:   0%|          | 4/1668 [00:02<15:26,  1.80it/s][A
Iteration:   0%|          | 5/1668 [00:02<15:19,  1.81it/s][A
Iteration:   0%|          | 6/1668 [00:03<15:17,  1.81it/s][A
Iteration:   0%|          | 7/1668 [00:03<15:07,  1.83it/s][A
Iteration:   0%|          | 8/1668 [00:04<15:13,  1.82it/s][A
Iteration:   1%|          | 9/1668 [00:05<15:13,  1.82it/s][A
Iteration:   1%|          | 10/1668 [00:05<15:09,  1.82it/s][A
Iteration:   1%|          | 11/1668 [00:06<15:04,  1.83it/s][A
Iteration:   1%|          | 12/1668 [00:06<15:06,  1.83it/s][A
Iteration:   1%|          | 13/1668 [00:07<15:04,  1.83it/s][A
Iteration:   1%|          | 14/1668 [00:07<15:02,  1.83it/s][A
Iteration:   

{"loss": 0.41617555819451807, "learning_rate": 1.7002398081534774e-05, "epoch": 0.2997601918465228, "step": 500}



Iteration:  30%|███       | 501/1668 [04:32<11:34,  1.68it/s][A
Iteration:  30%|███       | 502/1668 [04:32<11:18,  1.72it/s][A
Iteration:  30%|███       | 503/1668 [04:33<11:07,  1.74it/s][A
Iteration:  30%|███       | 504/1668 [04:33<10:59,  1.77it/s][A
Iteration:  30%|███       | 505/1668 [04:34<10:52,  1.78it/s][A
Iteration:  30%|███       | 506/1668 [04:34<10:47,  1.80it/s][A
Iteration:  30%|███       | 507/1668 [04:35<10:45,  1.80it/s][A
Iteration:  30%|███       | 508/1668 [04:36<10:42,  1.80it/s][A
Iteration:  31%|███       | 509/1668 [04:36<10:39,  1.81it/s][A
Iteration:  31%|███       | 510/1668 [04:37<10:37,  1.82it/s][A
Iteration:  31%|███       | 511/1668 [04:37<10:35,  1.82it/s][A
Iteration:  31%|███       | 512/1668 [04:38<10:33,  1.82it/s][A
Iteration:  31%|███       | 513/1668 [04:38<10:33,  1.82it/s][A
Iteration:  31%|███       | 514/1668 [04:39<10:33,  1.82it/s][A
Iteration:  31%|███       | 515/1668 [04:39<10:32,  1.82it/s][A
Iteration:  31%|███     

{"loss": 0.2891820684103295, "learning_rate": 1.4004796163069546e-05, "epoch": 0.5995203836930456, "step": 1000}



Iteration:  60%|██████    | 1001/1668 [09:03<06:25,  1.73it/s][A
Iteration:  60%|██████    | 1002/1668 [09:04<06:18,  1.76it/s][A
Iteration:  60%|██████    | 1003/1668 [09:04<06:13,  1.78it/s][A
Iteration:  60%|██████    | 1004/1668 [09:05<06:10,  1.79it/s][A
Iteration:  60%|██████    | 1005/1668 [09:05<06:07,  1.80it/s][A
Iteration:  60%|██████    | 1006/1668 [09:06<05:59,  1.84it/s][A
Iteration:  60%|██████    | 1007/1668 [09:06<05:56,  1.85it/s][A
Iteration:  60%|██████    | 1008/1668 [09:07<05:56,  1.85it/s][A
Iteration:  60%|██████    | 1009/1668 [09:07<05:58,  1.84it/s][A
Iteration:  61%|██████    | 1010/1668 [09:08<05:58,  1.84it/s][A
Iteration:  61%|██████    | 1011/1668 [09:08<05:58,  1.83it/s][A
Iteration:  61%|██████    | 1012/1668 [09:09<05:57,  1.83it/s][A
Iteration:  61%|██████    | 1013/1668 [09:09<05:56,  1.84it/s][A
Iteration:  61%|██████    | 1014/1668 [09:10<05:56,  1.84it/s][A
Iteration:  61%|██████    | 1015/1668 [09:11<05:55,  1.84it/s][A
Iteration

{"loss": 0.22568737766169944, "learning_rate": 1.1007194244604318e-05, "epoch": 0.8992805755395683, "step": 1500}



Iteration:  90%|████████▉ | 1501/1668 [13:34<01:36,  1.73it/s][A
Iteration:  90%|█████████ | 1502/1668 [13:35<01:34,  1.76it/s][A
Iteration:  90%|█████████ | 1503/1668 [13:35<01:32,  1.78it/s][A
Iteration:  90%|█████████ | 1504/1668 [13:36<01:31,  1.79it/s][A
Iteration:  90%|█████████ | 1505/1668 [13:36<01:30,  1.80it/s][A
Iteration:  90%|█████████ | 1506/1668 [13:37<01:29,  1.81it/s][A
Iteration:  90%|█████████ | 1507/1668 [13:37<01:28,  1.81it/s][A
Iteration:  90%|█████████ | 1508/1668 [13:38<01:28,  1.80it/s][A
Iteration:  90%|█████████ | 1509/1668 [13:38<01:27,  1.83it/s][A
Iteration:  91%|█████████ | 1510/1668 [13:39<01:25,  1.84it/s][A
Iteration:  91%|█████████ | 1511/1668 [13:39<01:25,  1.83it/s][A
Iteration:  91%|█████████ | 1512/1668 [13:40<01:25,  1.83it/s][A
Iteration:  91%|█████████ | 1513/1668 [13:41<01:23,  1.87it/s][A
Iteration:  91%|█████████ | 1514/1668 [13:41<01:22,  1.88it/s][A
Iteration:  91%|█████████ | 1515/1668 [13:42<01:21,  1.87it/s][A
Iteration

{"loss": 0.16795708821783772, "learning_rate": 8.00959232613909e-06, "epoch": 1.1990407673860912, "step": 2000}



Iteration:  20%|█▉        | 333/1668 [03:01<13:12,  1.68it/s][A
Iteration:  20%|██        | 334/1668 [03:01<12:51,  1.73it/s][A
Iteration:  20%|██        | 335/1668 [03:02<12:37,  1.76it/s][A
Iteration:  20%|██        | 336/1668 [03:02<12:27,  1.78it/s][A
Iteration:  20%|██        | 337/1668 [03:03<12:19,  1.80it/s][A
Iteration:  20%|██        | 338/1668 [03:04<12:10,  1.82it/s][A
Iteration:  20%|██        | 339/1668 [03:04<12:04,  1.83it/s][A
Iteration:  20%|██        | 340/1668 [03:05<12:05,  1.83it/s][A
Iteration:  20%|██        | 341/1668 [03:05<12:06,  1.83it/s][A
Iteration:  21%|██        | 342/1668 [03:06<12:06,  1.83it/s][A
Iteration:  21%|██        | 343/1668 [03:06<12:05,  1.83it/s][A
Iteration:  21%|██        | 344/1668 [03:07<12:03,  1.83it/s][A
Iteration:  21%|██        | 345/1668 [03:07<12:04,  1.83it/s][A
Iteration:  21%|██        | 346/1668 [03:08<12:04,  1.83it/s][A
Iteration:  21%|██        | 347/1668 [03:08<12:02,  1.83it/s][A
Iteration:  21%|██      

{"loss": 0.13271094839542638, "learning_rate": 5.011990407673861e-06, "epoch": 1.498800959232614, "step": 2500}



Iteration:  50%|████▉     | 833/1668 [07:34<07:51,  1.77it/s][A
Iteration:  50%|█████     | 834/1668 [07:34<07:45,  1.79it/s][A
Iteration:  50%|█████     | 835/1668 [07:35<07:41,  1.81it/s][A
Iteration:  50%|█████     | 836/1668 [07:36<07:37,  1.82it/s][A
Iteration:  50%|█████     | 837/1668 [07:36<07:35,  1.83it/s][A
Iteration:  50%|█████     | 838/1668 [07:37<07:33,  1.83it/s][A
Iteration:  50%|█████     | 839/1668 [07:37<07:32,  1.83it/s][A
Iteration:  50%|█████     | 840/1668 [07:38<07:32,  1.83it/s][A
Iteration:  50%|█████     | 841/1668 [07:38<07:30,  1.84it/s][A
Iteration:  50%|█████     | 842/1668 [07:39<07:30,  1.83it/s][A
Iteration:  51%|█████     | 843/1668 [07:39<07:29,  1.84it/s][A
Iteration:  51%|█████     | 844/1668 [07:40<07:27,  1.84it/s][A
Iteration:  51%|█████     | 845/1668 [07:40<07:27,  1.84it/s][A
Iteration:  51%|█████     | 846/1668 [07:41<07:27,  1.84it/s][A
Iteration:  51%|█████     | 847/1668 [07:42<07:27,  1.84it/s][A
Iteration:  51%|█████   

{"loss": 0.11166536170669134, "learning_rate": 2.0143884892086333e-06, "epoch": 1.7985611510791366, "step": 3000}



Iteration:  80%|███████▉  | 1332/1668 [12:08<08:28,  1.51s/it][A
Iteration:  80%|███████▉  | 1333/1668 [12:09<06:50,  1.22s/it][A
Iteration:  80%|███████▉  | 1334/1668 [12:09<05:40,  1.02s/it][A
Iteration:  80%|████████  | 1335/1668 [12:10<04:51,  1.14it/s][A
Iteration:  80%|████████  | 1336/1668 [12:10<04:17,  1.29it/s][A
Iteration:  80%|████████  | 1337/1668 [12:11<03:53,  1.42it/s][A
Iteration:  80%|████████  | 1338/1668 [12:11<03:36,  1.52it/s][A
Iteration:  80%|████████  | 1339/1668 [12:12<03:25,  1.60it/s][A
Iteration:  80%|████████  | 1340/1668 [12:12<03:16,  1.67it/s][A
Iteration:  80%|████████  | 1341/1668 [12:13<03:10,  1.71it/s][A
Iteration:  80%|████████  | 1342/1668 [12:13<03:06,  1.75it/s][A
Iteration:  81%|████████  | 1343/1668 [12:14<03:03,  1.77it/s][A
Iteration:  81%|████████  | 1344/1668 [12:14<03:00,  1.79it/s][A
Iteration:  81%|████████  | 1345/1668 [12:15<02:58,  1.81it/s][A
Iteration:  81%|████████  | 1346/1668 [12:16<02:57,  1.82it/s][A
Iteration

TrainOutput(global_step=3336, training_loss=0.21281198457132977)

In [12]:
task_names = ['offensive','hatespeech','hatespeech_classes']

In [13]:
import datasets

In [14]:
f1 = datasets.load_metric('f1')

In [15]:
recall = datasets.load_metric('recall')

In [16]:
precision = datasets.load_metric('precision')

In [17]:
preds_dict = {}
for task_name in task_names:
    eval_dataloader = DataLoaderWithTaskname(
        task_name,
        trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["test"])
    )
    print(eval_dataloader.data_loader.collate_fn)
    preds_dict[task_name] = trainer._prediction_loop(
        eval_dataloader, 
        description=f"Test: {task_name}",
    )

Test: offensive:   1%|▏         | 2/159 [00:00<00:15, 10.40it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffef9fbb070>>


Test: offensive: 100%|██████████| 159/159 [00:13<00:00, 11.91it/s]
Test: hatespeech:   1%|▏         | 2/159 [00:00<00:13, 11.97it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffef9fbb070>>


Test: hatespeech: 100%|██████████| 159/159 [00:13<00:00, 11.88it/s]
Test: hatespeech_classes:   1%|▏         | 2/159 [00:00<00:13, 11.92it/s]

<bound method NLPDataCollator.collate_batch of <model.NLPDataCollator object at 0x7ffef9fbb070>>


Test: hatespeech_classes: 100%|██████████| 159/159 [00:13<00:00, 11.83it/s]


### 2 Epochs

In [18]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [19]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8371149406524729}

In [20]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.3926597611174843}

### 3 Epochs

In [21]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8418723545933512}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 4 Epochs

In [25]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8469622661091527}

In [27]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8325985296056062}

In [28]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4176057674898591}

### 5 Epochs

In [43]:
f1.compute(predictions=np.argmax(preds_dict["offensive"].predictions, axis=1),  references=preds_dict["offensive"].label_ids, average='macro' )

{'f1': 0.8387301587301587}

In [44]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech"].predictions, axis=1),  references=preds_dict["hatespeech"].label_ids, average='macro' )

{'f1': 0.8107379458902904}

In [45]:
f1.compute(predictions=np.argmax(preds_dict["hatespeech_classes"].predictions, axis=1),  references=preds_dict["hatespeech_classes"].label_ids, average='macro' )

{'f1': 0.4355893433350976}

## Loading

In [46]:
def load_model(dot_bin_file):
    multitask_model = MultitaskModel.create(
                                            model_name=model_name,
                                            model_type_dict={
                                                "offensive": transformers.AutoModelForSequenceClassification,
                                                "hatespeech": transformers.AutoModelForSequenceClassification,
                                                "hatespeech_classes": transformers.AutoModelForSequenceClassification,
                                            },
                                            model_config_dict={
                                                "offensive": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
                                                "hatespeech": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
                                                "hatespeech_classes": transformers.AutoConfig.from_pretrained(model_name, num_labels=7),
                                            },)
    model = torch.load(dot_bin_file)
    multitask_model.load_state_dict(model)
    return multitask_model