**NOTE:** 

The data-loading code of the Multimodal-Toolkit ([Repository-Link](https://github.com/georgian-io/Multimodal-Toolkit)) was modified locally for this notebook, in particular the `load_data_from_folder` function, which here takes the train/dev/test DataFrames directly. These changes are not tracked by version control, so this notebook may not run against the unmodified `multimodal-transformers` package installed below!

In [1]:
! pip install multimodal-transformers



In [2]:
import json
import re
import sys
import pandas as pd
import numpy as np

## Preprocessing

In [3]:
import dataloader

In [4]:
# Load the three splits with stemming disabled; the test split has no label file.
train_df = dataloader.load_data(
    data_file='../data/train.data.jsonl',
    label_file='../data/train.label.json',
    perform_stemming=False,
)
dev_df = dataloader.load_data(
    data_file='../data/dev.data.jsonl',
    label_file='../data/dev.label.json',
    perform_stemming=False,
)
test_df = dataloader.load_data(
    data_file='../data/test.data.jsonl',
    label_file=None,
    perform_stemming=False,
)

In [5]:
# Peek at the first test example. `.iloc` makes the positional intent explicit;
# plain `Series[0]` is a label lookup (or a deprecated positional fallback).
test_df['text'].iloc[0]

'people have been able to get out of sydney cafe during hostage situation fucking terrorists   people have been able to get out of sydney cafe during hostage situation her fingers look broken otages libres   people have been able to get out of sydney cafe definitely adding that tweet to my people have been able to get out of sydney cafe during hostage situation     cuba usa stop showing their faces  they are going to have enough trouble without forever being linked with this event i sure hope the hostage taker gets out  feet first and room temp people have been able to get out of sydney cafe during hostage situation dear mr  swat team guy  you may want to choose a long rifle for a more precise shot  this is not csi or miami vice radical islam again spreading terror  muslims are fanatical zealots p  isis syndeysiege anyone see the connection yet  how long r free and democratic nations going to allow terrorists and state sponsors of terror dictate'

## Multimodal BERT

In [6]:
from dataclasses import dataclass, field
import json
import logging
import os
from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
    Trainer,
    EvalPrediction,
    set_seed
)
from transformers.training_args import TrainingArguments

from multimodal_transformers.data import load_data_from_folder
from multimodal_transformers.model import TabularConfig
from multimodal_transformers.model import AutoModelWithTabular

# Turn off the Comet ML integration for this run.
# NOTE(review): transformers may also consult COMET_MODE when it is imported;
# if so, this assignment has to happen before the `transformers` import above.
os.environ['COMET_MODE'] = 'DISABLED'
# Show INFO-level log records (e.g. the multimodal_transformers loader messages).
logging.basicConfig(level=logging.INFO)

In [7]:
@dataclass
class ModelArguments:
    """
    Which pretrained model, config and tokenizer to start fine-tuning from.

    Only ``model_name_or_path`` is required. ``config_name`` and
    ``tokenizer_name`` are optional overrides (``None`` means "use
    ``model_name_or_path``"), and ``cache_dir`` sets the download cache.
    """
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models",
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name",
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name",
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3",
        },
    )


@dataclass
class MultimodalDataTrainingArguments:
    """
    Arguments pertaining to how we combine tabular features.

    Using `HfArgumentParser` we can turn this class into argparse arguments to
    be able to specify them on the command line.

    Column roles must be supplied either directly via ``column_info`` or as a
    JSON file via ``column_info_path``. When ``column_info`` is given, the file
    is not read.

    Raises:
        ValueError: if neither ``column_info`` nor ``column_info_path`` is set.
    """

    data_path: str = field(
        metadata={
            'help': 'the path to the csv file containing the dataset'
        })
    column_info_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'the path to the json file detailing which columns are text, categorical, numerical, and the label'
        })

    column_info: Optional[dict] = field(
        default=None,
        metadata={
            # The two literals are concatenated, so the separator must be explicit.
            'help': 'a dict referencing the text, categorical, numerical, and label columns; '
                    'its keys are text_cols, num_cols, cat_cols, and label_col'
        })

    categorical_encode_type: str = field(
        default='ohe',
        metadata={
            'help': 'sklearn encoder to use for categorical data',
            'choices': ['ohe', 'binary', 'label', 'none']
        })
    numerical_transformer_method: str = field(
        default='yeo_johnson',
        metadata={
            'help': 'sklearn numerical transformer to preprocess numerical data',
            'choices': ['yeo_johnson', 'box_cox', 'quantile_normal', 'none']
        })
    task: str = field(
        default="classification",
        metadata={
            "help": "The downstream training task",
            "choices": ["classification", "regression"]
        })

    mlp_division: int = field(
        default=4,
        metadata={
            'help': 'the ratio of the number of '
                    'hidden dims in a current layer to the next MLP layer'
        })
    combine_feat_method: str = field(
        default='individual_mlps_on_cat_and_numerical_feats_then_concat',
        metadata={
            'help': 'method to combine categorical and numerical features, '
                    'see README for all the methods'
        })
    mlp_dropout: float = field(
        default=0.1,
        metadata={
            'help': 'dropout ratio used for MLP layers'
        })
    numerical_bn: bool = field(
        default=True,
        metadata={
            'help': 'whether to use batchnorm on numerical features'
        })
    # A flag, so annotate it as bool (it was `str`, which disagreed with the
    # default and would make an argparse-style parser treat it as text).
    use_simple_classifier: bool = field(
        default=True,
        metadata={
            'help': 'whether to use single layer or MLP as final classifier'
        })
    mlp_act: str = field(
        default='relu',
        metadata={
            'help': 'the activation function to use for finetuning layers',
            'choices': ['relu', 'prelu', 'sigmoid', 'tanh', 'linear']
        })
    gating_beta: float = field(
        default=0.2,
        metadata={
            'help': "the beta hyperparameters used for gating tabular data "
                    "see https://www.aclweb.org/anthology/2020.acl-main.214.pdf"
        })

    def __post_init__(self):
        """Validate the column spec and load it from JSON when only a path is given."""
        # Explicit check instead of `assert`, so it still runs under `python -O`
        # and explains what is missing.
        if self.column_info is None and not self.column_info_path:
            raise ValueError('Either column_info or column_info_path must be provided')
        if self.column_info is None:
            with open(self.column_info_path, 'r') as f:
                self.column_info = json.load(f)

In [8]:
# --- Column roles -------------------------------------------------------------
text_cols = ['text']
cat_cols = [
    'question_mark',
    'contains_url',
    'contains_media',
    'contains_profile_background_image',
    'verified',
    'geo_enabled',
    'has_description',
]
# NOTE(review): 10 numerical columns are listed here, but the loader logs
# "9 numerical columns" for this run -- confirm which one is dropped and why.
numerical_cols = [
    'retweet_count',
    'favorite_count',
    'number_urls',
    'statuses_count',
    'listed_count',
    'reputation_score_1',
    'reputation_score_2',
    'favourites_count',
    'length_description',
    'follow_tweets',
]

column_info_dict = {
    'text_cols': text_cols,
    'num_cols': numerical_cols,
    'cat_cols': cat_cols,
    'label_col': 'label',
    'label_list': [0, 1],
}

# --- Model / data arguments ---------------------------------------------------
model_args = ModelArguments(model_name_or_path='bert-base-uncased')

data_args = MultimodalDataTrainingArguments(
    data_path='.',
    column_info=column_info_dict,
    task='classification',
    combine_feat_method='gating_on_cat_and_num_feats_then_sum',
)

# --- Training arguments -------------------------------------------------------
training_args = TrainingArguments(
    # where checkpoints and TensorBoard runs go
    output_dir="./logs/model_name",
    logging_dir="./logs/runs",
    overwrite_output_dir=True,
    # what to run
    do_train=True,
    do_eval=True,
    # optimisation schedule
    num_train_epochs=7,
    per_device_train_batch_size=4,
    dataloader_drop_last=True,
    # logging and periodic evaluation
    # NOTE(review): `evaluate_during_training` was superseded by
    # `evaluation_strategy` in later transformers releases -- check the pinned version.
    evaluate_during_training=True,
    logging_steps=25,
    eval_steps=250,
)

set_seed(training_args.seed)

In [9]:
# Use the dedicated tokenizer if one was named, else the model checkpoint's own.
tokenizer_path_or_name = model_args.tokenizer_name or model_args.model_name_or_path
print('Specified tokenizer: ', tokenizer_path_or_name)

# Tokens automatically converted to lower_case (uncased checkpoint)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_or_name, cache_dir=model_args.cache_dir)

Specified tokenizer:  bert-base-uncased


In [10]:
# Get Datasets
# Get Datasets: hand each split plus the text/categorical/numerical column
# groups to the (locally modified) multimodal_transformers loader.
# NOTE(review): this call emits pandas SettingWithCopyWarning (see output);
# the modified loader probably assigns into a DataFrame slice -- verify it copies first.
train_dataset, dev_dataset, test_dataset = load_data_from_folder(
    train_df,
    dev_df,
    test_df,
    data_args.column_info['text_cols'],
    tokenizer,
    sep_text_token_str=tokenizer.sep_token,
    label_col=data_args.column_info['label_col'],
    label_list=data_args.column_info['label_list'],
    categorical_cols=data_args.column_info['cat_cols'],
    numerical_cols=data_args.column_info['num_cols'],
)

INFO:multimodal_transformers.data.data_utils:9 numerical columns
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
INFO:multimodal_transformers.data.data_utils:20 categorical columns
INFO:multimodal_transformers.data.data_utils:9 numerical columns
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
INFO:multimodal_transformers.data.load_data:Text columns: ['text']
INFO:multimodal_transformers.data.load_data:Raw text example: how to respond to the murderous attack on charlie hebdo  every newspaper in the free world should pri

In [11]:
# Take the class count from the declared label list, so it stays correct even
# if a class happens to be absent from the training split.
num_labels = len(data_args.column_info['label_list'])
assert len(np.unique(train_dataset.labels)) <= num_labels, (
    'training split contains more distinct labels than label_list declares'
)
num_labels

2

In [12]:
# Base transformer config: a dedicated config if one was named, else the checkpoint's.
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

# Tabular-fusion settings: feature widths are read off the processed training
# set; every field of data_args is forwarded as a keyword argument.
tabular_config = TabularConfig(
    num_labels=num_labels,
    cat_feat_dim=train_dataset.cat_feats.shape[1],
    numerical_feat_dim=train_dataset.numerical_feats.shape[1],
    **vars(data_args),
)
config.tabular_config = tabular_config

In [13]:
# Load pretrained weights from the model checkpoint. Architecture and tabular
# settings come from the explicit `config`; `config_name` only names a config,
# so it must not be used as the source of the weights.
model = AutoModelWithTabular.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=model_args.cache_dir,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertWithTabular: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertWithTabular from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertWithTabular from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertWithTabular were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifi

In [14]:
import numpy as np
from scipy.special import softmax
from sklearn.metrics import (
    auc,
    precision_recall_curve,
    roc_auc_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
)

def calc_classification_metrics(p: EvalPrediction):
    """
    Compute classification metrics from a Trainer ``EvalPrediction``.

    ``p.predictions`` has one column per class (softmax is applied here, so
    presumably raw logits), with class 1 treated as the positive class.
    ``p.label_ids`` holds the true labels.

    Binary model with both classes present in ``label_ids``:
        roc_auc, pr_auc, plus precision/recall/f1 at the positive-class
        probability cut-off that maximises F1 (reported as ``threshold``).
        That cut-off is tuned on the very data being evaluated, so these three
        numbers are optimistic. The confusion-matrix counts (tn/fp/fn/tp), in
        contrast, come from the argmax predictions, i.e. a 0.5 cut-off.
    Otherwise:
        acc, f1, their mean (acc_and_f1) and Matthews correlation (mcc).
        f1 is binary for two-column predictions and macro-averaged otherwise.

    Returns:
        dict mapping metric name to value.
    """
    logits = p.predictions
    labels = p.label_ids
    num_classes = logits.shape[1]
    pred_labels = np.argmax(logits, axis=1)
    pred_scores = softmax(logits, axis=1)[:, 1]

    # ROC/PR curves need a binary model AND both classes present in the labels.
    if num_classes == 2 and len(np.unique(labels)) == 2:
        roc_auc_pred_score = roc_auc_score(labels, pred_scores)
        precisions, recalls, thresholds = precision_recall_curve(labels,
                                                                 pred_scores)
        # F1 per candidate threshold; 0/0 (precision == recall == 0) maps to 0
        # without emitting a divide-by-zero RuntimeWarning.
        denom = precisions + recalls
        fscore = np.divide(2 * precisions * recalls, denom,
                           out=np.zeros_like(denom), where=denom > 0)
        ix = np.argmax(fscore)
        threshold = thresholds[ix].item()
        pr_auc = auc(recalls, precisions)
        tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
        result = {'roc_auc': roc_auc_pred_score,
                  'threshold': threshold,
                  'pr_auc': pr_auc,
                  'recall': recalls[ix].item(),
                  'precision': precisions[ix].item(), 'f1': fscore[ix].item(),
                  'tn': tn.item(), 'fp': fp.item(), 'fn': fn.item(), 'tp': tp.item()
                  }
    else:
        acc = (pred_labels == labels).mean()
        # sklearn's default average='binary' raises for more than two classes.
        average = 'binary' if num_classes == 2 else 'macro'
        f1 = f1_score(y_true=labels, y_pred=pred_labels, average=average)
        result = {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
            "mcc": matthews_corrcoef(labels, pred_labels)
        }

    return result

In [15]:
# Wire model, datasets and the metric function into the HF Trainer.
# dev_dataset serves the periodic evaluations configured in training_args.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=calc_classification_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)

In [16]:
%%time
# Fine-tune the model. With evaluate_during_training=True, the Trainer also
# evaluates on dev_dataset every `eval_steps` (250) steps, as the log below shows.
trainer.train()

Epoch:   0%|          | 0/7 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.7431385040283203, 'learning_rate': 4.9846059113300494e-05, 'epoch': 0.021551724137931036, 'step': 25}
{'loss': 0.598366470336914, 'learning_rate': 4.9692118226600986e-05, 'epoch': 0.04310344827586207, 'step': 50}
{'loss': 0.6772987365722656, 'learning_rate': 4.9538177339901484e-05, 'epoch': 0.06465517241379311, 'step': 75}
{'loss': 0.6353677368164062, 'learning_rate': 4.938423645320197e-05, 'epoch': 0.08620689655172414, 'step': 100}
{'loss': 0.6358514404296876, 'learning_rate': 4.923029556650247e-05, 'epoch': 0.10775862068965517, 'step': 125}
{'loss': 0.6349850463867187, 'learning_rate': 4.907635467980296e-05, 'epoch': 0.12931034482758622, 'step': 150}
{'loss': 0.6305630493164063, 'learning_rate': 4.892241379310345e-05, 'epoch': 0.15086206896551724, 'step': 175}
{'loss': 0.5048910522460938, 'learning_rate': 4.876847290640394e-05, 'epoch': 0.1724137931034483, 'step': 200}
{'loss': 0.5755682373046875, 'learning_rate': 4.8614532019704434e-05, 'epoch': 0.1939655172413793, 'step'

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.5801599841150973, 'eval_roc_auc': 0.6970361180038599, 'eval_threshold': 0.2638416886329651, 'eval_pr_auc': 0.546207613836113, 'eval_recall': 0.7096774193548387, 'eval_precision': 0.4697508896797153, 'eval_f1': 0.5653104925053533, 'eval_tn': 384, 'eval_fp': 6, 'eval_fn': 166, 'eval_tp': 20, 'epoch': 0.21551724137931033, 'step': 250}
{'loss': 0.6587677001953125, 'learning_rate': 4.8306650246305424e-05, 'epoch': 0.23706896551724138, 'step': 275}
{'loss': 0.5853216552734375, 'learning_rate': 4.8152709359605915e-05, 'epoch': 0.25862068965517243, 'step': 300}
{'loss': 0.69006103515625, 'learning_rate': 4.799876847290641e-05, 'epoch': 0.2801724137931034, 'step': 325}
{'loss': 0.629112548828125, 'learning_rate': 4.78448275862069e-05, 'epoch': 0.3017241379310345, 'step': 350}
{'loss': 0.6356671142578125, 'learning_rate': 4.769088669950739e-05, 'epoch': 0.3232758620689655, 'step': 375}
{'loss': 0.5566705322265625, 'learning_rate': 4.753694581280788e-05, 'epoch': 0.344827586206896

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.7201832710868783, 'eval_roc_auc': 0.7805348773090709, 'eval_threshold': 0.04888703301548958, 'eval_pr_auc': 0.643067574554814, 'eval_recall': 0.7258064516129032, 'eval_precision': 0.5625, 'eval_f1': 0.6338028169014084, 'eval_tn': 380, 'eval_fp': 10, 'eval_fn': 131, 'eval_tp': 55, 'epoch': 0.43103448275862066, 'step': 500}
{'loss': 0.59272216796875, 'learning_rate': 4.6767241379310346e-05, 'epoch': 0.4525862068965517, 'step': 525}
{'loss': 0.469764404296875, 'learning_rate': 4.661330049261084e-05, 'epoch': 0.47413793103448276, 'step': 550}
{'loss': 0.611531982421875, 'learning_rate': 4.6459359605911336e-05, 'epoch': 0.4956896551724138, 'step': 575}
{'loss': 0.566463623046875, 'learning_rate': 4.630541871921182e-05, 'epoch': 0.5172413793103449, 'step': 600}
{'loss': 0.718939208984375, 'learning_rate': 4.615147783251232e-05, 'epoch': 0.5387931034482759, 'step': 625}
{'loss': 0.6609716796875, 'learning_rate': 4.599753694581281e-05, 'epoch': 0.5603448275862069, 'step': 650}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4863107382423348, 'eval_roc_auc': 0.8271436448855803, 'eval_threshold': 0.4326234757900238, 'eval_pr_auc': 0.7039372004297799, 'eval_recall': 0.6774193548387096, 'eval_precision': 0.711864406779661, 'eval_f1': 0.6942148760330579, 'eval_tn': 351, 'eval_fp': 39, 'eval_fn': 74, 'eval_tp': 112, 'epoch': 0.646551724137931, 'step': 750}
{'loss': 0.528765869140625, 'learning_rate': 4.522783251231527e-05, 'epoch': 0.6681034482758621, 'step': 775}
{'loss': 0.491522216796875, 'learning_rate': 4.507389162561577e-05, 'epoch': 0.6896551724137931, 'step': 800}
{'loss': 0.50287353515625, 'learning_rate': 4.491995073891626e-05, 'epoch': 0.7112068965517241, 'step': 825}
{'loss': 0.661307373046875, 'learning_rate': 4.476600985221675e-05, 'epoch': 0.7327586206896551, 'step': 850}
{'loss': 0.786220703125, 'learning_rate': 4.461206896551724e-05, 'epoch': 0.7543103448275862, 'step': 875}
{'loss': 0.54729248046875, 'learning_rate': 4.4458128078817734e-05, 'epoch': 0.7758620689655172, 'step': 

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.522536869885193, 'eval_roc_auc': 0.7777915632754342, 'eval_threshold': 0.2811976969242096, 'eval_pr_auc': 0.6582510957162593, 'eval_recall': 0.7634408602150538, 'eval_precision': 0.5182481751824818, 'eval_f1': 0.6173913043478261, 'eval_tn': 369, 'eval_fp': 21, 'eval_fn': 110, 'eval_tp': 76, 'epoch': 0.8620689655172413, 'step': 1000}
{'loss': 0.58204345703125, 'learning_rate': 4.36884236453202e-05, 'epoch': 0.8836206896551724, 'step': 1025}
{'loss': 0.523984375, 'learning_rate': 4.353448275862069e-05, 'epoch': 0.9051724137931034, 'step': 1050}
{'loss': 0.60415771484375, 'learning_rate': 4.338054187192118e-05, 'epoch': 0.9267241379310345, 'step': 1075}
{'loss': 0.577509765625, 'learning_rate': 4.3226600985221674e-05, 'epoch': 0.9482758620689655, 'step': 1100}
{'loss': 0.53253173828125, 'learning_rate': 4.307266009852217e-05, 'epoch': 0.9698275862068966, 'step': 1125}
{'loss': 0.43788818359375, 'learning_rate': 4.2918719211822664e-05, 'epoch': 0.9913793103448276, 'step': 1

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.5106982421875, 'learning_rate': 4.2764778325123155e-05, 'epoch': 1.0129310344827587, 'step': 1175}
{'loss': 0.4902001953125, 'learning_rate': 4.261083743842365e-05, 'epoch': 1.0344827586206897, 'step': 1200}
{'loss': 0.3700439453125, 'learning_rate': 4.245689655172414e-05, 'epoch': 1.0560344827586208, 'step': 1225}
{'loss': 0.67509765625, 'learning_rate': 4.230295566502464e-05, 'epoch': 1.0775862068965518, 'step': 1250}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.5373643950248758, 'eval_roc_auc': 0.8205128205128206, 'eval_threshold': 0.14392291009426117, 'eval_pr_auc': 0.692382827805128, 'eval_recall': 0.6989247311827957, 'eval_precision': 0.6842105263157895, 'eval_f1': 0.6914893617021276, 'eval_tn': 368, 'eval_fp': 22, 'eval_fn': 98, 'eval_tp': 88, 'epoch': 1.0775862068965518, 'step': 1250}
{'loss': 0.59179443359375, 'learning_rate': 4.214901477832512e-05, 'epoch': 1.0991379310344827, 'step': 1275}
{'loss': 0.594619140625, 'learning_rate': 4.199507389162562e-05, 'epoch': 1.1206896551724137, 'step': 1300}
{'loss': 0.53516357421875, 'learning_rate': 4.184113300492611e-05, 'epoch': 1.1422413793103448, 'step': 1325}
{'loss': 0.5799658203125, 'learning_rate': 4.16871921182266e-05, 'epoch': 1.1637931034482758, 'step': 1350}
{'loss': 0.51719482421875, 'learning_rate': 4.1533251231527095e-05, 'epoch': 1.1853448275862069, 'step': 1375}
{'loss': 0.52018798828125, 'learning_rate': 4.1379310344827587e-05, 'epoch': 1.206896551724138, 'step'

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5615455140877101, 'eval_roc_auc': 0.775475599669148, 'eval_threshold': 0.19506855309009552, 'eval_pr_auc': 0.6368305313124294, 'eval_recall': 0.7258064516129032, 'eval_precision': 0.5510204081632653, 'eval_f1': 0.6264501160092807, 'eval_tn': 383, 'eval_fp': 7, 'eval_fn': 157, 'eval_tp': 29, 'epoch': 1.293103448275862, 'step': 1500}
{'loss': 0.67717041015625, 'learning_rate': 4.060960591133005e-05, 'epoch': 1.3146551724137931, 'step': 1525}
{'loss': 0.5598193359375, 'learning_rate': 4.045566502463054e-05, 'epoch': 1.3362068965517242, 'step': 1550}
{'loss': 0.49600830078125, 'learning_rate': 4.0301724137931035e-05, 'epoch': 1.3577586206896552, 'step': 1575}
{'loss': 0.48645751953125, 'learning_rate': 4.014778325123153e-05, 'epoch': 1.3793103448275863, 'step': 1600}
{'loss': 0.61676513671875, 'learning_rate': 3.999384236453202e-05, 'epoch': 1.4008620689655173, 'step': 1625}
{'loss': 0.58815185546875, 'learning_rate': 3.9839901477832516e-05, 'epoch': 1.4224137931034484, 'st

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.46576257836487556, 'eval_roc_auc': 0.8349186655638268, 'eval_threshold': 0.19106513261795044, 'eval_pr_auc': 0.7211410822808698, 'eval_recall': 0.7688172043010753, 'eval_precision': 0.6682242990654206, 'eval_f1': 0.715, 'eval_tn': 367, 'eval_fp': 23, 'eval_fn': 88, 'eval_tp': 98, 'epoch': 1.5086206896551724, 'step': 1750}
{'loss': 0.6202685546875, 'learning_rate': 3.9070197044334974e-05, 'epoch': 1.5301724137931034, 'step': 1775}
{'loss': 0.582861328125, 'learning_rate': 3.891625615763547e-05, 'epoch': 1.5517241379310345, 'step': 1800}
{'loss': 0.623876953125, 'learning_rate': 3.8762315270935964e-05, 'epoch': 1.5732758620689655, 'step': 1825}
{'loss': 0.5491748046875, 'learning_rate': 3.8608374384236456e-05, 'epoch': 1.5948275862068966, 'step': 1850}
{'loss': 0.53080078125, 'learning_rate': 3.845443349753695e-05, 'epoch': 1.6163793103448276, 'step': 1875}
{'loss': 0.570478515625, 'learning_rate': 3.830049261083744e-05, 'epoch': 1.6379310344827587, 'step': 1900}
{'loss':

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.419994552516275, 'eval_roc_auc': 0.8641439205955334, 'eval_threshold': 0.3068794906139374, 'eval_pr_auc': 0.7197261611052437, 'eval_recall': 0.7688172043010753, 'eval_precision': 0.7333333333333333, 'eval_f1': 0.7506561679790027, 'eval_tn': 343, 'eval_fp': 47, 'eval_fn': 53, 'eval_tp': 133, 'epoch': 1.7241379310344827, 'step': 2000}
{'loss': 0.5343896484375, 'learning_rate': 3.7530788177339904e-05, 'epoch': 1.7456896551724137, 'step': 2025}
{'loss': 0.4818359375, 'learning_rate': 3.7376847290640395e-05, 'epoch': 1.7672413793103448, 'step': 2050}
{'loss': 0.4035546875, 'learning_rate': 3.722290640394089e-05, 'epoch': 1.7887931034482758, 'step': 2075}
{'loss': 0.5372705078125, 'learning_rate': 3.7068965517241385e-05, 'epoch': 1.8103448275862069, 'step': 2100}
{'loss': 0.5917431640625, 'learning_rate': 3.691502463054187e-05, 'epoch': 1.831896551724138, 'step': 2125}
{'loss': 0.48666015625, 'learning_rate': 3.676108374384237e-05, 'epoch': 1.853448275862069, 'step': 2150}
{'

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.464904288864798, 'eval_roc_auc': 0.8254618141714916, 'eval_threshold': 0.3514133393764496, 'eval_pr_auc': 0.6337953323465277, 'eval_recall': 0.8440860215053764, 'eval_precision': 0.5880149812734082, 'eval_f1': 0.6931567328918322, 'eval_tn': 330, 'eval_fp': 60, 'eval_fn': 91, 'eval_tp': 95, 'epoch': 1.9396551724137931, 'step': 2250}
{'loss': 0.3899951171875, 'learning_rate': 3.5991379310344833e-05, 'epoch': 1.9612068965517242, 'step': 2275}
{'loss': 0.3845556640625, 'learning_rate': 3.583743842364532e-05, 'epoch': 1.9827586206896552, 'step': 2300}


Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.4716748046875, 'learning_rate': 3.568349753694582e-05, 'epoch': 2.0043103448275863, 'step': 2325}
{'loss': 0.5469873046875, 'learning_rate': 3.552955665024631e-05, 'epoch': 2.0258620689655173, 'step': 2350}
{'loss': 0.6546142578125, 'learning_rate': 3.53756157635468e-05, 'epoch': 2.0474137931034484, 'step': 2375}
{'loss': 0.5283984375, 'learning_rate': 3.522167487684729e-05, 'epoch': 2.0689655172413794, 'step': 2400}
{'loss': 0.6766015625, 'learning_rate': 3.506773399014778e-05, 'epoch': 2.0905172413793105, 'step': 2425}
{'loss': 0.4710400390625, 'learning_rate': 3.4913793103448275e-05, 'epoch': 2.1120689655172415, 'step': 2450}
{'loss': 0.511533203125, 'learning_rate': 3.475985221674877e-05, 'epoch': 2.1336206896551726, 'step': 2475}
{'loss': 0.779091796875, 'learning_rate': 3.4605911330049265e-05, 'epoch': 2.1551724137931036, 'step': 2500}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5669361661291785, 'eval_roc_auc': 0.8469120485249517, 'eval_threshold': 0.0724063590168953, 'eval_pr_auc': 0.7148532817456328, 'eval_recall': 0.7258064516129032, 'eval_precision': 0.7219251336898396, 'eval_f1': 0.7238605898123325, 'eval_tn': 350, 'eval_fp': 40, 'eval_fn': 67, 'eval_tp': 119, 'epoch': 2.1551724137931036, 'step': 2500}
{'loss': 0.4835888671875, 'learning_rate': 3.4451970443349756e-05, 'epoch': 2.1767241379310347, 'step': 2525}
{'loss': 0.693818359375, 'learning_rate': 3.429802955665025e-05, 'epoch': 2.1982758620689653, 'step': 2550}
{'loss': 0.5179345703125, 'learning_rate': 3.414408866995074e-05, 'epoch': 2.2198275862068964, 'step': 2575}
{'loss': 0.4632421875, 'learning_rate': 3.399014778325123e-05, 'epoch': 2.2413793103448274, 'step': 2600}
{'loss': 0.367177734375, 'learning_rate': 3.383620689655172e-05, 'epoch': 2.2629310344827585, 'step': 2625}
{'loss': 0.4860595703125, 'learning_rate': 3.368226600985222e-05, 'epoch': 2.2844827586206895, 'step': 2650

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.4660772791960173, 'eval_roc_auc': 0.8566308243727598, 'eval_threshold': 0.5068467855453491, 'eval_pr_auc': 0.6979330215585184, 'eval_recall': 0.7795698924731183, 'eval_precision': 0.6651376146788991, 'eval_f1': 0.7178217821782178, 'eval_tn': 312, 'eval_fp': 78, 'eval_fn': 40, 'eval_tp': 146, 'epoch': 2.3706896551724137, 'step': 2750}
{'loss': 0.46541015625, 'learning_rate': 3.2912561576354686e-05, 'epoch': 2.3922413793103448, 'step': 2775}
{'loss': 0.5839404296875, 'learning_rate': 3.275862068965517e-05, 'epoch': 2.413793103448276, 'step': 2800}
{'loss': 0.5476025390625, 'learning_rate': 3.260467980295567e-05, 'epoch': 2.435344827586207, 'step': 2825}
{'loss': 0.4786376953125, 'learning_rate': 3.2450738916256154e-05, 'epoch': 2.456896551724138, 'step': 2850}
{'loss': 0.4753369140625, 'learning_rate': 3.229679802955665e-05, 'epoch': 2.478448275862069, 'step': 2875}
{'loss': 0.495283203125, 'learning_rate': 3.2142857142857144e-05, 'epoch': 2.5, 'step': 2900}
{'loss': 0.50

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4517242595449918, 'eval_roc_auc': 0.8659360352908739, 'eval_threshold': 0.44114598631858826, 'eval_pr_auc': 0.7035882310318083, 'eval_recall': 0.8494623655913979, 'eval_precision': 0.6556016597510373, 'eval_f1': 0.7400468384074941, 'eval_tn': 323, 'eval_fp': 67, 'eval_fn': 40, 'eval_tp': 146, 'epoch': 2.586206896551724, 'step': 3000}
{'loss': 0.4416455078125, 'learning_rate': 3.137315270935961e-05, 'epoch': 2.6077586206896552, 'step': 3025}
{'loss': 0.4362109375, 'learning_rate': 3.12192118226601e-05, 'epoch': 2.6293103448275863, 'step': 3050}
{'loss': 0.4464697265625, 'learning_rate': 3.106527093596059e-05, 'epoch': 2.6508620689655173, 'step': 3075}
{'loss': 0.5187158203125, 'learning_rate': 3.0911330049261084e-05, 'epoch': 2.6724137931034484, 'step': 3100}
{'loss': 0.5334619140625, 'learning_rate': 3.0757389162561575e-05, 'epoch': 2.6939655172413794, 'step': 3125}
{'loss': 0.4197998046875, 'learning_rate': 3.060344827586207e-05, 'epoch': 2.7155172413793105, 'step': 31

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.42338745466743904, 'eval_roc_auc': 0.871119382409705, 'eval_threshold': 0.49385929107666016, 'eval_pr_auc': 0.7083166971723365, 'eval_recall': 0.8548387096774194, 'eval_precision': 0.6411290322580645, 'eval_f1': 0.7327188940092165, 'eval_tn': 303, 'eval_fp': 87, 'eval_fn': 30, 'eval_tp': 156, 'epoch': 2.8017241379310347, 'step': 3250}
{'loss': 0.405322265625, 'learning_rate': 2.983374384236453e-05, 'epoch': 2.8232758620689653, 'step': 3275}
{'loss': 0.5442236328125, 'learning_rate': 2.9679802955665027e-05, 'epoch': 2.844827586206897, 'step': 3300}
{'loss': 0.4300927734375, 'learning_rate': 2.952586206896552e-05, 'epoch': 2.8663793103448274, 'step': 3325}
{'loss': 0.506669921875, 'learning_rate': 2.937192118226601e-05, 'epoch': 2.887931034482759, 'step': 3350}
{'loss': 0.5143017578125, 'learning_rate': 2.9217980295566505e-05, 'epoch': 2.9094827586206895, 'step': 3375}
{'loss': 0.425888671875, 'learning_rate': 2.9064039408866993e-05, 'epoch': 2.9310344827586206, 'step': 3

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.4899462890625, 'learning_rate': 2.844827586206897e-05, 'epoch': 3.0172413793103448, 'step': 3500}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5649446768220514, 'eval_roc_auc': 0.8665150261924455, 'eval_threshold': 0.03289210423827171, 'eval_pr_auc': 0.7207479559072935, 'eval_recall': 0.7849462365591398, 'eval_precision': 0.6790697674418604, 'eval_f1': 0.7281795511221945, 'eval_tn': 335, 'eval_fp': 55, 'eval_fn': 54, 'eval_tp': 132, 'epoch': 3.0172413793103448, 'step': 3500}
{'loss': 0.5022705078125, 'learning_rate': 2.8294334975369458e-05, 'epoch': 3.038793103448276, 'step': 3525}
{'loss': 0.4496142578125, 'learning_rate': 2.8140394088669953e-05, 'epoch': 3.060344827586207, 'step': 3550}
{'loss': 0.5105322265625, 'learning_rate': 2.7986453201970448e-05, 'epoch': 3.081896551724138, 'step': 3575}
{'loss': 0.54244140625, 'learning_rate': 2.7832512315270936e-05, 'epoch': 3.103448275862069, 'step': 3600}
{'loss': 0.431201171875, 'learning_rate': 2.767857142857143e-05, 'epoch': 3.125, 'step': 3625}
{'loss': 0.4260009765625, 'learning_rate': 2.752463054187192e-05, 'epoch': 3.146551724137931, 'step': 3650}
{'loss': 0

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.48641926733156043, 'eval_roc_auc': 0.8710780259167356, 'eval_threshold': 0.037127867341041565, 'eval_pr_auc': 0.7089396998945479, 'eval_recall': 0.8655913978494624, 'eval_precision': 0.6363636363636364, 'eval_f1': 0.7334851936218678, 'eval_tn': 308, 'eval_fp': 82, 'eval_fn': 34, 'eval_tp': 152, 'epoch': 3.2327586206896552, 'step': 3750}
{'loss': 0.3850390625, 'learning_rate': 2.6754926108374384e-05, 'epoch': 3.2543103448275863, 'step': 3775}
{'loss': 0.5754736328125, 'learning_rate': 2.660098522167488e-05, 'epoch': 3.2758620689655173, 'step': 3800}
{'loss': 0.529716796875, 'learning_rate': 2.6447044334975367e-05, 'epoch': 3.2974137931034484, 'step': 3825}
{'loss': 0.53396484375, 'learning_rate': 2.6293103448275862e-05, 'epoch': 3.3189655172413794, 'step': 3850}
{'loss': 0.882734375, 'learning_rate': 2.6139162561576357e-05, 'epoch': 3.3405172413793105, 'step': 3875}
{'loss': 0.597919921875, 'learning_rate': 2.598522167487685e-05, 'epoch': 3.3620689655172415, 'step': 3900

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.6189430632059358, 'eval_roc_auc': 0.8461814171491591, 'eval_threshold': 0.04708851873874664, 'eval_pr_auc': 0.7167240680016373, 'eval_recall': 0.7365591397849462, 'eval_precision': 0.6919191919191919, 'eval_f1': 0.7135416666666667, 'eval_tn': 352, 'eval_fp': 38, 'eval_fn': 79, 'eval_tp': 107, 'epoch': 3.4482758620689653, 'step': 4000}
{'loss': 0.646884765625, 'learning_rate': 2.521551724137931e-05, 'epoch': 3.469827586206897, 'step': 4025}
{'loss': 0.56064453125, 'learning_rate': 2.5061576354679805e-05, 'epoch': 3.4913793103448274, 'step': 4050}
{'loss': 0.292568359375, 'learning_rate': 2.4907635467980297e-05, 'epoch': 3.512931034482759, 'step': 4075}
{'loss': 0.558564453125, 'learning_rate': 2.475369458128079e-05, 'epoch': 3.5344827586206895, 'step': 4100}
{'loss': 0.474375, 'learning_rate': 2.4599753694581283e-05, 'epoch': 3.5560344827586206, 'step': 4125}
{'loss': 0.710673828125, 'learning_rate': 2.4445812807881775e-05, 'epoch': 3.5775862068965516, 'step': 4150}
{'lo

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5208930997177958, 'eval_roc_auc': 0.850799558864075, 'eval_threshold': 0.11902419477701187, 'eval_pr_auc': 0.7214700884890975, 'eval_recall': 0.7311827956989247, 'eval_precision': 0.6974358974358974, 'eval_f1': 0.7139107611548556, 'eval_tn': 352, 'eval_fp': 38, 'eval_fn': 75, 'eval_tp': 111, 'epoch': 3.663793103448276, 'step': 4250}
{'loss': 0.53111328125, 'learning_rate': 2.3676108374384236e-05, 'epoch': 3.685344827586207, 'step': 4275}
{'loss': 0.624677734375, 'learning_rate': 2.3522167487684728e-05, 'epoch': 3.706896551724138, 'step': 4300}
{'loss': 0.5360546875, 'learning_rate': 2.3368226600985223e-05, 'epoch': 3.728448275862069, 'step': 4325}
{'loss': 0.35998046875, 'learning_rate': 2.3214285714285715e-05, 'epoch': 3.75, 'step': 4350}
{'loss': 0.714765625, 'learning_rate': 2.306034482758621e-05, 'epoch': 3.771551724137931, 'step': 4375}
{'loss': 0.51626953125, 'learning_rate': 2.29064039408867e-05, 'epoch': 3.793103448275862, 'step': 4400}
{'loss': 0.51109375, 'lea

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4975456171151664, 'eval_roc_auc': 0.8523435346015992, 'eval_threshold': 0.5979310274124146, 'eval_pr_auc': 0.6779104040282973, 'eval_recall': 0.8440860215053764, 'eval_precision': 0.6108949416342413, 'eval_f1': 0.7088036117381489, 'eval_tn': 265, 'eval_fp': 125, 'eval_fn': 21, 'eval_tp': 165, 'epoch': 3.8793103448275863, 'step': 4500}
{'loss': 0.47353515625, 'learning_rate': 2.2136699507389163e-05, 'epoch': 3.9008620689655173, 'step': 4525}
{'loss': 0.525361328125, 'learning_rate': 2.1982758620689654e-05, 'epoch': 3.9224137931034484, 'step': 4550}
{'loss': 0.528974609375, 'learning_rate': 2.182881773399015e-05, 'epoch': 3.9439655172413794, 'step': 4575}
{'loss': 0.52814453125, 'learning_rate': 2.1674876847290644e-05, 'epoch': 3.9655172413793105, 'step': 4600}
{'loss': 0.65814453125, 'learning_rate': 2.1520935960591136e-05, 'epoch': 3.987068965517241, 'step': 4625}


Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.495263671875, 'learning_rate': 2.1366995073891627e-05, 'epoch': 4.008620689655173, 'step': 4650}
{'loss': 0.44763671875, 'learning_rate': 2.121305418719212e-05, 'epoch': 4.030172413793103, 'step': 4675}
{'loss': 0.43103515625, 'learning_rate': 2.105911330049261e-05, 'epoch': 4.051724137931035, 'step': 4700}
{'loss': 0.463740234375, 'learning_rate': 2.0905172413793102e-05, 'epoch': 4.073275862068965, 'step': 4725}
{'loss': 0.61275390625, 'learning_rate': 2.0751231527093597e-05, 'epoch': 4.094827586206897, 'step': 4750}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4556386619288888, 'eval_roc_auc': 0.8748276812792942, 'eval_threshold': 0.5757675170898438, 'eval_pr_auc': 0.7409062544913981, 'eval_recall': 0.7741935483870968, 'eval_precision': 0.6923076923076923, 'eval_f1': 0.7309644670050761, 'eval_tn': 305, 'eval_fp': 85, 'eval_fn': 38, 'eval_tp': 148, 'epoch': 4.094827586206897, 'step': 4750}
{'loss': 0.337265625, 'learning_rate': 2.059729064039409e-05, 'epoch': 4.116379310344827, 'step': 4775}
{'loss': 0.53203125, 'learning_rate': 2.0443349753694584e-05, 'epoch': 4.137931034482759, 'step': 4800}
{'loss': 0.36568359375, 'learning_rate': 2.0289408866995076e-05, 'epoch': 4.1594827586206895, 'step': 4825}
{'loss': 0.569736328125, 'learning_rate': 2.0135467980295567e-05, 'epoch': 4.181034482758621, 'step': 4850}
{'loss': 0.333271484375, 'learning_rate': 1.9981527093596062e-05, 'epoch': 4.202586206896552, 'step': 4875}
{'loss': 0.469482421875, 'learning_rate': 1.9827586206896554e-05, 'epoch': 4.224137931034483, 'step': 4900}
{'loss': 

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5169872652946247, 'eval_roc_auc': 0.8646126275158534, 'eval_threshold': 0.028015566989779472, 'eval_pr_auc': 0.704651533690468, 'eval_recall': 0.8440860215053764, 'eval_precision': 0.6624472573839663, 'eval_f1': 0.7423167848699764, 'eval_tn': 331, 'eval_fp': 59, 'eval_fn': 47, 'eval_tp': 139, 'epoch': 4.310344827586207, 'step': 5000}
{'loss': 0.639482421875, 'learning_rate': 1.9057881773399015e-05, 'epoch': 4.331896551724138, 'step': 5025}
{'loss': 0.634462890625, 'learning_rate': 1.890394088669951e-05, 'epoch': 4.353448275862069, 'step': 5050}
{'loss': 0.490771484375, 'learning_rate': 1.8750000000000002e-05, 'epoch': 4.375, 'step': 5075}
{'loss': 0.51037109375, 'learning_rate': 1.8596059113300493e-05, 'epoch': 4.396551724137931, 'step': 5100}
{'loss': 0.563798828125, 'learning_rate': 1.8442118226600985e-05, 'epoch': 4.418103448275862, 'step': 5125}
{'loss': 0.506962890625, 'learning_rate': 1.828817733990148e-05, 'epoch': 4.439655172413793, 'step': 5150}
{'loss': 0.5066

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5685216512324082, 'eval_roc_auc': 0.8497105045492142, 'eval_threshold': 0.06672349572181702, 'eval_pr_auc': 0.7382196415970713, 'eval_recall': 0.7365591397849462, 'eval_precision': 0.7098445595854922, 'eval_f1': 0.7229551451187335, 'eval_tn': 377, 'eval_fp': 13, 'eval_fn': 99, 'eval_tp': 87, 'epoch': 4.525862068965517, 'step': 5250}
{'loss': 0.5515234375, 'learning_rate': 1.751847290640394e-05, 'epoch': 4.547413793103448, 'step': 5275}
{'loss': 0.584912109375, 'learning_rate': 1.7364532019704436e-05, 'epoch': 4.568965517241379, 'step': 5300}
{'loss': 0.515625, 'learning_rate': 1.7210591133004928e-05, 'epoch': 4.5905172413793105, 'step': 5325}
{'loss': 0.365, 'learning_rate': 1.705665024630542e-05, 'epoch': 4.612068965517241, 'step': 5350}
{'loss': 0.609609375, 'learning_rate': 1.690270935960591e-05, 'epoch': 4.633620689655173, 'step': 5375}
{'loss': 0.637578125, 'learning_rate': 1.6748768472906403e-05, 'epoch': 4.655172413793103, 'step': 5400}
{'loss': 0.530361328125, '

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5758973348937515, 'eval_roc_auc': 0.8772539288668322, 'eval_threshold': 0.02787007763981819, 'eval_pr_auc': 0.7465627151681155, 'eval_recall': 0.8064516129032258, 'eval_precision': 0.6976744186046512, 'eval_f1': 0.7481296758104737, 'eval_tn': 363, 'eval_fp': 27, 'eval_fn': 68, 'eval_tp': 118, 'epoch': 4.741379310344827, 'step': 5500}
{'loss': 0.892880859375, 'learning_rate': 1.5979064039408868e-05, 'epoch': 4.762931034482759, 'step': 5525}
{'loss': 0.459482421875, 'learning_rate': 1.582512315270936e-05, 'epoch': 4.7844827586206895, 'step': 5550}
{'loss': 0.5320703125, 'learning_rate': 1.5671182266009854e-05, 'epoch': 4.806034482758621, 'step': 5575}
{'loss': 0.43990234375, 'learning_rate': 1.5517241379310346e-05, 'epoch': 4.827586206896552, 'step': 5600}
{'loss': 0.344609375, 'learning_rate': 1.5363300492610837e-05, 'epoch': 4.849137931034483, 'step': 5625}
{'loss': 0.653134765625, 'learning_rate': 1.520935960591133e-05, 'epoch': 4.870689655172414, 'step': 5650}
{'loss'

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.6749545709721537, 'eval_roc_auc': 0.856368899917287, 'eval_threshold': 0.032892435789108276, 'eval_pr_auc': 0.7490765312433156, 'eval_recall': 0.7043010752688172, 'eval_precision': 0.7485714285714286, 'eval_f1': 0.7257617728531857, 'eval_tn': 368, 'eval_fp': 22, 'eval_fn': 82, 'eval_tp': 104, 'epoch': 4.956896551724138, 'step': 5750}
{'loss': 0.625234375, 'learning_rate': 1.4439655172413794e-05, 'epoch': 4.978448275862069, 'step': 5775}
{'loss': 0.799111328125, 'learning_rate': 1.4285714285714285e-05, 'epoch': 5.0, 'step': 5800}


Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.725986328125, 'learning_rate': 1.4131773399014777e-05, 'epoch': 5.021551724137931, 'step': 5825}
{'loss': 0.55177734375, 'learning_rate': 1.3977832512315272e-05, 'epoch': 5.043103448275862, 'step': 5850}
{'loss': 0.492275390625, 'learning_rate': 1.3823891625615765e-05, 'epoch': 5.064655172413793, 'step': 5875}
{'loss': 0.45455078125, 'learning_rate': 1.3669950738916257e-05, 'epoch': 5.086206896551724, 'step': 5900}
{'loss': 0.34447265625, 'learning_rate': 1.3516009852216749e-05, 'epoch': 5.107758620689655, 'step': 5925}
{'loss': 0.47474609375, 'learning_rate': 1.336206896551724e-05, 'epoch': 5.129310344827586, 'step': 5950}
{'loss': 0.497333984375, 'learning_rate': 1.3208128078817735e-05, 'epoch': 5.150862068965517, 'step': 5975}
{'loss': 0.454658203125, 'learning_rate': 1.3054187192118228e-05, 'epoch': 5.172413793103448, 'step': 6000}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5532323764378412, 'eval_roc_auc': 0.8865315687896334, 'eval_threshold': 0.06656220555305481, 'eval_pr_auc': 0.753677047537662, 'eval_recall': 0.7903225806451613, 'eval_precision': 0.7424242424242424, 'eval_f1': 0.7656250000000001, 'eval_tn': 347, 'eval_fp': 43, 'eval_fn': 46, 'eval_tp': 140, 'epoch': 5.172413793103448, 'step': 6000}
{'loss': 0.567880859375, 'learning_rate': 1.290024630541872e-05, 'epoch': 5.193965517241379, 'step': 6025}
{'loss': 0.332763671875, 'learning_rate': 1.2746305418719212e-05, 'epoch': 5.2155172413793105, 'step': 6050}
{'loss': 0.45935546875, 'learning_rate': 1.2592364532019705e-05, 'epoch': 5.237068965517241, 'step': 6075}
{'loss': 0.58078125, 'learning_rate': 1.2438423645320198e-05, 'epoch': 5.258620689655173, 'step': 6100}
{'loss': 0.7459375, 'learning_rate': 1.228448275862069e-05, 'epoch': 5.280172413793103, 'step': 6125}
{'loss': 0.346416015625, 'learning_rate': 1.2130541871921183e-05, 'epoch': 5.301724137931035, 'step': 6150}
{'loss': 0.5

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.55595020933025, 'eval_roc_auc': 0.8748276812792941, 'eval_threshold': 0.20050129294395447, 'eval_pr_auc': 0.7497009374517275, 'eval_recall': 0.7634408602150538, 'eval_precision': 0.7675675675675676, 'eval_f1': 0.7654986522911051, 'eval_tn': 347, 'eval_fp': 43, 'eval_fn': 45, 'eval_tp': 141, 'epoch': 5.387931034482759, 'step': 6250}
{'loss': 0.671044921875, 'learning_rate': 1.1360837438423645e-05, 'epoch': 5.4094827586206895, 'step': 6275}
{'loss': 0.3534375, 'learning_rate': 1.1206896551724138e-05, 'epoch': 5.431034482758621, 'step': 6300}
{'loss': 0.47353515625, 'learning_rate': 1.1052955665024631e-05, 'epoch': 5.452586206896552, 'step': 6325}
{'loss': 0.620458984375, 'learning_rate': 1.0899014778325124e-05, 'epoch': 5.474137931034483, 'step': 6350}
{'loss': 0.71244140625, 'learning_rate': 1.0745073891625616e-05, 'epoch': 5.495689655172414, 'step': 6375}
{'loss': 0.378056640625, 'learning_rate': 1.0591133004926108e-05, 'epoch': 5.517241379310345, 'step': 6400}
{'loss':

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.48086353130121195, 'eval_roc_auc': 0.8793631100082713, 'eval_threshold': 0.06676262617111206, 'eval_pr_auc': 0.7789063346314855, 'eval_recall': 0.7365591397849462, 'eval_precision': 0.8058823529411765, 'eval_f1': 0.7696629213483146, 'eval_tn': 365, 'eval_fp': 25, 'eval_fn': 58, 'eval_tp': 128, 'epoch': 5.603448275862069, 'step': 6500}
{'loss': 0.445263671875, 'learning_rate': 9.821428571428573e-06, 'epoch': 5.625, 'step': 6525}
{'loss': 0.53009765625, 'learning_rate': 9.667487684729066e-06, 'epoch': 5.646551724137931, 'step': 6550}
{'loss': 0.302900390625, 'learning_rate': 9.513546798029557e-06, 'epoch': 5.668103448275862, 'step': 6575}
{'loss': 0.4135546875, 'learning_rate': 9.359605911330049e-06, 'epoch': 5.689655172413794, 'step': 6600}
{'loss': 0.637666015625, 'learning_rate': 9.205665024630542e-06, 'epoch': 5.711206896551724, 'step': 6625}
{'loss': 0.49921875, 'learning_rate': 9.051724137931036e-06, 'epoch': 5.732758620689655, 'step': 6650}
{'loss': 0.647216796875,

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 1.379079156451755, 'eval_roc_auc': 0.7807416597739179, 'eval_threshold': 0.9042877554893494, 'eval_pr_auc': 0.71571191769907, 'eval_recall': 0.7150537634408602, 'eval_precision': 0.7823529411764706, 'eval_f1': 0.747191011235955, 'eval_tn': 78, 'eval_fp': 312, 'eval_fn': 34, 'eval_tp': 152, 'epoch': 5.818965517241379, 'step': 6750}
{'loss': 0.389287109375, 'learning_rate': 8.282019704433499e-06, 'epoch': 5.8405172413793105, 'step': 6775}
{'loss': 0.671455078125, 'learning_rate': 8.12807881773399e-06, 'epoch': 5.862068965517241, 'step': 6800}
{'loss': 0.6459375, 'learning_rate': 7.974137931034484e-06, 'epoch': 5.883620689655173, 'step': 6825}
{'loss': 0.56134765625, 'learning_rate': 7.820197044334975e-06, 'epoch': 5.905172413793103, 'step': 6850}
{'loss': 0.492099609375, 'learning_rate': 7.666256157635469e-06, 'epoch': 5.926724137931035, 'step': 6875}
{'loss': 0.569150390625, 'learning_rate': 7.512315270935962e-06, 'epoch': 5.948275862068965, 'step': 6900}
{'loss': 0.562714

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.335634765625, 'learning_rate': 7.050492610837439e-06, 'epoch': 6.012931034482759, 'step': 6975}
{'loss': 0.507880859375, 'learning_rate': 6.896551724137932e-06, 'epoch': 6.0344827586206895, 'step': 7000}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.6492021766801676, 'eval_roc_auc': 0.8220567962503447, 'eval_threshold': 0.9369513988494873, 'eval_pr_auc': 0.7159982631344622, 'eval_recall': 0.7526881720430108, 'eval_precision': 0.7329842931937173, 'eval_f1': 0.7427055702917773, 'eval_tn': 311, 'eval_fp': 79, 'eval_fn': 31, 'eval_tp': 155, 'epoch': 6.0344827586206895, 'step': 7000}
{'loss': 0.50826171875, 'learning_rate': 6.742610837438423e-06, 'epoch': 6.056034482758621, 'step': 7025}
{'loss': 0.29529296875, 'learning_rate': 6.5886699507389166e-06, 'epoch': 6.077586206896552, 'step': 7050}
{'loss': 0.39466796875, 'learning_rate': 6.434729064039409e-06, 'epoch': 6.099137931034483, 'step': 7075}
{'loss': 0.497900390625, 'learning_rate': 6.280788177339902e-06, 'epoch': 6.120689655172414, 'step': 7100}
{'loss': 0.328203125, 'learning_rate': 6.126847290640395e-06, 'epoch': 6.142241379310345, 'step': 7125}
{'loss': 0.51177734375, 'learning_rate': 5.972906403940887e-06, 'epoch': 6.163793103448276, 'step': 7150}
{'loss': 0.5

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4470876540678243, 'eval_roc_auc': 0.8831541218637993, 'eval_threshold': 0.06987116485834122, 'eval_pr_auc': 0.6874782232058522, 'eval_recall': 0.8978494623655914, 'eval_precision': 0.6929460580912863, 'eval_f1': 0.7822014051522249, 'eval_tn': 323, 'eval_fp': 67, 'eval_fn': 31, 'eval_tp': 155, 'epoch': 6.25, 'step': 7250}
{'loss': 0.5339453125, 'learning_rate': 5.2032019704433495e-06, 'epoch': 6.271551724137931, 'step': 7275}
{'loss': 0.45201171875, 'learning_rate': 5.049261083743843e-06, 'epoch': 6.293103448275862, 'step': 7300}
{'loss': 0.488017578125, 'learning_rate': 4.895320197044335e-06, 'epoch': 6.314655172413793, 'step': 7325}
{'loss': 0.370537109375, 'learning_rate': 4.741379310344828e-06, 'epoch': 6.336206896551724, 'step': 7350}
{'loss': 0.31990234375, 'learning_rate': 4.58743842364532e-06, 'epoch': 6.357758620689655, 'step': 7375}
{'loss': 0.424130859375, 'learning_rate': 4.4334975369458135e-06, 'epoch': 6.379310344827586, 'step': 7400}
{'loss': 0.63750976562

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5060471477659626, 'eval_roc_auc': 0.8935001378549765, 'eval_threshold': 0.101682148873806, 'eval_pr_auc': 0.745817038830868, 'eval_recall': 0.8548387096774194, 'eval_precision': 0.7429906542056075, 'eval_f1': 0.795, 'eval_tn': 342, 'eval_fp': 48, 'eval_fn': 40, 'eval_tp': 146, 'epoch': 6.4655172413793105, 'step': 7500}
{'loss': 0.295908203125, 'learning_rate': 3.6637931034482757e-06, 'epoch': 6.487068965517241, 'step': 7525}
{'loss': 0.315927734375, 'learning_rate': 3.5098522167487686e-06, 'epoch': 6.508620689655173, 'step': 7550}
{'loss': 0.579580078125, 'learning_rate': 3.3559113300492615e-06, 'epoch': 6.530172413793103, 'step': 7575}
{'loss': 0.483779296875, 'learning_rate': 3.201970443349754e-06, 'epoch': 6.551724137931035, 'step': 7600}
{'loss': 0.259345703125, 'learning_rate': 3.0480295566502464e-06, 'epoch': 6.573275862068965, 'step': 7625}
{'loss': 0.44958984375, 'learning_rate': 2.894088669950739e-06, 'epoch': 6.594827586206897, 'step': 7650}
{'loss': 0.4141992

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.545825412729755, 'eval_roc_auc': 0.8860904328646263, 'eval_threshold': 0.04786735773086548, 'eval_pr_auc': 0.7358323935900215, 'eval_recall': 0.8655913978494624, 'eval_precision': 0.706140350877193, 'eval_f1': 0.7777777777777778, 'eval_tn': 341, 'eval_fp': 49, 'eval_fn': 44, 'eval_tp': 142, 'epoch': 6.681034482758621, 'step': 7750}
{'loss': 0.476474609375, 'learning_rate': 2.124384236453202e-06, 'epoch': 6.702586206896552, 'step': 7775}
{'loss': 0.49689453125, 'learning_rate': 1.970443349753695e-06, 'epoch': 6.724137931034483, 'step': 7800}
{'loss': 0.438505859375, 'learning_rate': 1.816502463054187e-06, 'epoch': 6.745689655172414, 'step': 7825}
{'loss': 0.677421875, 'learning_rate': 1.66256157635468e-06, 'epoch': 6.767241379310345, 'step': 7850}
{'loss': 0.5513671875, 'learning_rate': 1.5086206896551726e-06, 'epoch': 6.788793103448276, 'step': 7875}
{'loss': 0.315, 'learning_rate': 1.3546798029556653e-06, 'epoch': 6.810344827586206, 'step': 7900}
{'loss': 0.33255859375

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5048482806370076, 'eval_roc_auc': 0.8961538461538461, 'eval_threshold': 0.0877491757273674, 'eval_pr_auc': 0.7455978384892411, 'eval_recall': 0.8763440860215054, 'eval_precision': 0.7244444444444444, 'eval_f1': 0.7931873479318734, 'eval_tn': 337, 'eval_fp': 53, 'eval_fn': 38, 'eval_tp': 148, 'epoch': 6.896551724137931, 'step': 8000}
{'loss': 0.45513671875, 'learning_rate': 5.849753694581281e-07, 'epoch': 6.918103448275862, 'step': 8025}
{'loss': 0.2897265625, 'learning_rate': 4.3103448275862073e-07, 'epoch': 6.939655172413794, 'step': 8050}
{'loss': 0.34822265625, 'learning_rate': 2.7709359605911334e-07, 'epoch': 6.961206896551724, 'step': 8075}
{'loss': 0.6094140625, 'learning_rate': 1.2315270935960593e-07, 'epoch': 6.982758620689655, 'step': 8100}
CPU times: user 1h 35min 9s, sys: 30min 58s, total: 2h 6min 8s
Wall time: 2h 7min 54s


TrainOutput(global_step=8120, training_loss=0.5218131831126848)

In [17]:
# Load the TensorBoard notebook extension; it provides the %tensorboard magic used in the next cell
%load_ext tensorboard

In [18]:
# Launch TensorBoard inline, reading event files from ./logs/runs
# NOTE(review): presumably the logging dir configured for the training run above — confirm
%tensorboard --logdir ./logs/runs

In [19]:
# Persist the trained model to disk so it can be reloaded later without retraining.
trainer.save_model(output_dir="./multimodal_bert/multimodal_bert_v37")

### Inference

In [20]:
# Build a separate Trainer for inference, wrapping the already-trained model
# with default arguments. The training Trainer was configured with
# drop_last=True, which would discard the final incomplete batch of our test
# set; the default arguments (dataloader_drop_last=False) keep every example.
pred_trained = Trainer(model=model)

In [21]:
# Run the inference Trainer over the test set and keep only the raw model outputs (logits).
prediction_output = pred_trained.predict(test_dataset=test_dataset)
result = prediction_output.predictions

Prediction:   0%|          | 0/73 [00:00<?, ?it/s]

In [22]:
from scipy.special import softmax

# Hard predictions: position of the largest raw output in each row.
pred_labels = np.argmax(result, axis=1)

# Soft scores: normalize each row with softmax and keep the column-1 probability.
class_probabilities = softmax(result, axis=1)
pred_scores = class_probabilities[:, 1]

In [23]:
# Translate each predicted class index into its label string using the project's dataloader helper.
predicted_labels = list(map(dataloader.convert_prediction, pred_labels))

In [24]:
# Pair each test id with its predicted label for inspection and export.
output = pd.DataFrame(
    {
        'id': test_df['id'],
        'target': predicted_labels,
    }
)
output

Unnamed: 0,id,target
0,544382249178001408,rumour
1,525027317551079424,rumour
2,544273220128739329,rumour
3,499571799764770816,non-rumour
4,552844104418091008,non-rumour
...,...,...
576,553581227165642752,non-rumour
577,552816302780579840,non-rumour
578,580350000074457088,rumour
579,498584409055174656,non-rumour


In [25]:
# Write the submission file as a JSON object mapping each test id to its predicted label.
# Duplicate ids would collapse into a single dict key and silently drop predictions,
# so fail loudly instead of producing a short submission.
assert output['id'].is_unique, "duplicate ids in test set would be dropped from submission"
# Iterating a pandas Series yields Python scalars, so the keys stay JSON-serializable.
submission = dict(zip(output['id'], output['target']))
with open('test-output.json', 'w', encoding='utf-8') as f:
    json.dump(submission, f)