In [9]:
# NOTE: The installed multimodal-transformers source code has been modified locally; these modifications are not tracked by version control.

In [5]:
! pip install multimodal-transformers

Collecting transformers==3.1
  Using cached transformers-3.1.0-py3-none-any.whl (884 kB)
Collecting tokenizers==0.8.1.rc2
  Using cached tokenizers-0.8.1rc2-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.10.2
    Uninstalling tokenizers-0.10.2:
      Successfully uninstalled tokenizers-0.10.2
  Attempting uninstall: transformers
    Found existing installation: transformers 4.6.0.dev0
    Uninstalling transformers-4.6.0.dev0:
      Successfully uninstalled transformers-4.6.0.dev0
Successfully installed tokenizers-0.8.1rc2 transformers-3.1.0


In [15]:
import json
import re
import sys
import pandas as pd
import numpy as np

## Preprocessing

In [1]:
import dataloader

In [2]:
# Load the train / dev / test splits with identical settings (no stemming).
# The test split is loaded without a label file.
train_df, dev_df, test_df = [
    dataloader.load_data(data_file=data_file, label_file=label_file, perform_stemming=False)
    for data_file, label_file in (
        ('../data/train.data.jsonl', '../data/train.label.json'),
        ('../data/dev.data.jsonl', '../data/dev.label.json'),
        ('../data/test.data.jsonl', None),
    )
]

In [7]:
# Sanity check: display entry 0 of the test frame's 'text' column.
test_df['text'][0]

'5 people have been able to get out of sydney cafe during hostage situation: fucking terrorists ...   ,      ,     ,    .  5 people have been able to get out of sydney cafe during hostage situation: her fingers look broken. 5 otages libres 5 people have been able to get out of sydney cafe : definitely adding that tweet to my 5 people have been able to get out of sydney cafe during hostage situation:   cuba usa stop showing their faces, they are going to have enough trouble without forever being linked with this event i sure hope the hostage taker gets out. feet first and room temp. 5 people have been able to get out of sydney cafe during hostage situation: dear mr. swat team guy, you may want to choose a long rifle for a more precise shot. this is not csi or miami vice. radical islam again spreading terror.  muslims are fanatical zealots p2 isis syndeysiege anyone see the connection yet? how long r free and democratic nations going to allow terrorists and state sponsors of terror dicta

## Multimodal BERT

In [8]:
from dataclasses import dataclass, field
import json
import logging
import os
from typing import Optional

from transformers import (
    AutoTokenizer,
    AutoConfig,
    Trainer,
    EvalPrediction,
    set_seed
)
from transformers.training_args import TrainingArguments

from multimodal_transformers.data import load_data_from_folder
from multimodal_transformers.model import TabularConfig
from multimodal_transformers.model import AutoModelWithTabular

# Surface INFO-level log records in the notebook output.
logging.basicConfig(level=logging.INFO)
# NOTE(review): presumably switches off the Comet ML integration used by `Trainer` — confirm.
os.environ.update(COMET_MODE='DISABLED')

In [9]:
@dataclass
class ModelArguments:
    """
    Selects the pretrained checkpoint to fine-tune from and, optionally, a
    separate config, a separate tokenizer and a download cache directory.
    """

    # Required: declared without a default, so it must be passed at construction time.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    # The three optional settings below all default to None.
    config_name: Optional[str] = field(
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
        default=None,
    )
    tokenizer_name: Optional[str] = field(
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
        default=None,
    )
    cache_dir: Optional[str] = field(
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
        default=None,
    )


@dataclass
class MultimodalDataTrainingArguments:
    """
    Arguments describing the dataset columns and how tabular (categorical and
    numerical) features are combined with the text model.

    Using `HfArgumentParser` this class can be turned into argparse arguments
    so the values can be specified on the command line.

    Column information must be supplied either directly through `column_info`
    or through `column_info_path`, a JSON file that is loaded into
    `column_info` by `__post_init__`.
    """

    data_path: str = field(
        metadata={
            'help': 'the path to the csv file containing the dataset'
        })
    column_info_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'the path to the json file detailing which columns are text, categorical, numerical, and the label'
        })

    column_info: Optional[dict] = field(
        default=None,
        metadata={
            'help': 'a dict referencing the text, categorical, numerical, and label columns; '
                    'its keys are text_cols, num_cols, cat_cols, and label_col'
        })

    categorical_encode_type: str = field(
        default='ohe',
        metadata={
            'help': 'sklearn encoder to use for categorical data',
            'choices': ['ohe', 'binary', 'label', 'none']
        })
    numerical_transformer_method: str = field(
        default='yeo_johnson',
        metadata={
            'help': 'sklearn numerical transformer to preprocess numerical data',
            'choices': ['yeo_johnson', 'box_cox', 'quantile_normal', 'none']
        })
    task: str = field(
        default="classification",
        metadata={
            "help": "The downstream training task",
            "choices": ["classification", "regression"]
        })

    mlp_division: int = field(
        default=4,
        metadata={
            'help': 'the ratio of the number of '
                    'hidden dims in a current layer to the next MLP layer'
        })
    combine_feat_method: str = field(
        default='individual_mlps_on_cat_and_numerical_feats_then_concat',
        metadata={
            'help': 'method to combine categorical and numerical features, '
                    'see README for all the method'
        })
    mlp_dropout: float = field(
        default=0.1,
        metadata={
            'help': 'dropout ratio used for MLP layers'
        })
    numerical_bn: bool = field(
        default=True,
        metadata={
            'help': 'whether to use batchnorm on numerical features'
        })
    # Declared as bool to match its default; annotating it as `str` would make an
    # argument parser treat the value as text, so even "False" would be truthy.
    use_simple_classifier: bool = field(
        default=True,
        metadata={
            'help': 'whether to use single layer or MLP as final classifier'
        })
    mlp_act: str = field(
        default='relu',
        metadata={
            'help': 'the activation function to use for finetuning layers',
            'choices': ['relu', 'prelu', 'sigmoid', 'tanh', 'linear']
        })
    gating_beta: float = field(
        default=0.2,
        metadata={
            'help': "the beta hyperparameters used for gating tabular data "
                    "see https://www.aclweb.org/anthology/2020.acl-main.214.pdf"
        })

    def __post_init__(self):
        """Validate the column information and load it from JSON when only a path was given.

        Raises:
            AssertionError: if `column_info` and `column_info_path` compare equal,
                which is the case when neither was provided (both default to None).
        """
        assert self.column_info != self.column_info_path, (
            'either column_info or column_info_path must be provided'
        )
        # An explicitly passed dict takes precedence; the file is read only as a fallback.
        if self.column_info is None and self.column_info_path:
            with open(self.column_info_path, 'r', encoding='utf-8') as f:
                self.column_info = json.load(f)

In [26]:
# Column roles: one free-text column plus categorical and numerical tabular features.
text_cols = ['text']
cat_cols = [
    'question_mark',
    'contains_url',
    'contains_media',
    'contains_profile_background_image',
    'verified',
    'geo_enabled',
    'has_description',
]
numerical_cols = [
    'retweet_count',
    'favorite_count',
    'number_urls',
    'statuses_count',
    'listed_count',
    'reputation_score_1',
    'reputation_score_2',
    'favourites_count',
    'length_description',
    'follow_tweets',
]

# Bundle the column roles together with the label column and the label values.
column_info_dict = dict(
    text_cols=text_cols,
    num_cols=numerical_cols,
    cat_cols=cat_cols,
    label_col='label',
    label_list=[0, 1],
)

model_args = ModelArguments(model_name_or_path='bert-base-uncased')

data_args = MultimodalDataTrainingArguments(
    task='classification',
    column_info=column_info_dict,
    combine_feat_method='gating_on_cat_and_num_feats_then_sum',
    data_path='.',
)

training_args = TrainingArguments(
    # Output locations.
    # NOTE(review): 'model_name' looks like a placeholder — confirm the intended directory.
    output_dir="./logs/model_name",
    logging_dir="./logs/runs",
    overwrite_output_dir=True,
    # Training schedule.
    do_train=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    # NOTE(review): presumably the last partial batch is also dropped during evaluation,
    # so a few dev examples may go unscored — confirm.
    dataloader_drop_last=True,
    # Logging and periodic evaluation.
    do_eval=True,
    evaluate_during_training=True,
    logging_steps=25,
    eval_steps=250,
)

# Seed the run from the training arguments.
set_seed(training_args.seed)

In [27]:
# Prefer an explicitly configured tokenizer; otherwise reuse the model identifier.
tokenizer_path_or_name = model_args.tokenizer_name or model_args.model_name_or_path
print('Specified tokenizer: ', tokenizer_path_or_name)

# Tokens are automatically converted to lower case by this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_or_name, cache_dir=model_args.cache_dir)

Specified tokenizer:  bert-base-uncased


In [28]:
# Get Datasets
# Build the train / dev / test datasets from the three DataFrames loaded earlier, using the
# tokenizer and the column roles stored in `data_args.column_info`.
# NOTE(review): DataFrames are passed directly to `load_data_from_folder`; this presumably
# depends on the locally modified library noted at the top of the notebook — confirm.
train_dataset, dev_dataset, test_dataset = load_data_from_folder(train_df, dev_df, test_df,
    data_args.column_info['text_cols'],
    tokenizer,
    label_col=data_args.column_info['label_col'],
    label_list=data_args.column_info['label_list'],
    categorical_cols=data_args.column_info['cat_cols'],
    numerical_cols=data_args.column_info['num_cols'],
    sep_text_token_str=tokenizer.sep_token
)

INFO:multimodal_transformers.data.data_utils:9 numerical columns
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
INFO:multimodal_transformers.data.data_utils:20 categorical columns
INFO:multimodal_transformers.data.data_utils:9 numerical columns
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
INFO:multimodal_transformers.data.load_data:Text columns: ['text']
INFO:multimodal_transformers.data.load_data:Raw text example: how to respond to the murderous attack on charlie hebdo? every newspaper in the free world should pri

In [29]:
# Number of distinct label values present in the training set.
num_labels = np.unique(train_dataset.labels).size
num_labels

2

In [30]:
# Tabular settings: feature widths come from the training set, everything else from `data_args`.
tabular_config = TabularConfig(
    num_labels=num_labels,
    cat_feat_dim=train_dataset.cat_feats.shape[1],
    numerical_feat_dim=train_dataset.numerical_feats.shape[1],
    **vars(data_args),
)

# Text-model configuration: an explicit config name wins over the model identifier.
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)
config.tabular_config = tabular_config

In [31]:
# Load the pretrained weights into the text + tabular model.
# The weights always come from `model_name_or_path`. `config_name` only names an alternative
# configuration (already loaded into `config`), so it must not be used as the weights source.
model = AutoModelWithTabular.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=model_args.cache_dir,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertWithTabular: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertWithTabular from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertWithTabular from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertWithTabular were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifi

In [32]:
import numpy as np
from scipy.special import softmax
from sklearn.metrics import (
    auc,
    precision_recall_curve,
    roc_auc_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
)

def calc_classification_metrics(p: EvalPrediction):
    """Compute classification metrics for one evaluation pass.

    Args:
        p: object exposing ``predictions`` (per-class scores indexed as
            ``[sample, class]``) and ``label_ids`` (the true labels).

    Returns:
        dict mapping metric name to a scalar.

        When exactly two distinct labels are present:
            ``roc_auc`` and ``pr_auc`` are computed from the softmax score of class 1.
            ``threshold``, ``precision``, ``recall`` and ``f1`` are taken at the point
            of the precision-recall curve with the highest F1.
            ``tn``, ``fp``, ``fn`` and ``tp`` are counted from the argmax predictions.
            Note that the last two groups use different decision rules, so the
            reported precision/recall need not match the confusion counts.

        Otherwise:
            ``acc``, ``f1``, ``acc_and_f1`` and ``mcc`` computed on the argmax predictions.
    """
    pred_labels = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    if len(np.unique(labels)) == 2:  # binary classification
        # Probability assigned to the positive class (column 1).
        pred_scores = softmax(p.predictions, axis=1)[:, 1]
        roc_auc_pred_score = roc_auc_score(labels, pred_scores)
        precisions, recalls, thresholds = precision_recall_curve(labels,
                                                                 pred_scores)
        # F1 per curve point; points where precision + recall == 0 get F1 = 0
        # without triggering a division warning.
        denom = precisions + recalls
        fscore = np.divide(2 * precisions * recalls, denom,
                           out=np.zeros_like(denom), where=denom > 0)
        ix = np.argmax(fscore)
        # `precisions`/`recalls` have one more entry than `thresholds`; clamp the index
        # so the final curve point can never read past the end.
        threshold = thresholds[min(ix, len(thresholds) - 1)].item()
        pr_auc = auc(recalls, precisions)
        tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
        result = {'roc_auc': roc_auc_pred_score,
                  'threshold': threshold,
                  'pr_auc': pr_auc,
                  'recall': recalls[ix].item(),
                  'precision': precisions[ix].item(), 'f1': fscore[ix].item(),
                  'tn': tn.item(), 'fp': fp.item(), 'fn': fn.item(), 'tp': tp.item()
                  }
    else:
        acc = (pred_labels == labels).mean()
        # `f1_score` defaults to average='binary', which raises on multiclass targets.
        # Keep the binary behaviour when at most two classes occur, otherwise macro-average.
        n_classes = len(np.union1d(labels, pred_labels))
        average = 'binary' if n_classes <= 2 else 'macro'
        f1 = f1_score(y_true=labels, y_pred=pred_labels, average=average)
        result = {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
            "mcc": matthews_corrcoef(labels, pred_labels)
        }

    return result

In [20]:
# Wire the model, the training arguments, both datasets and the metric function together.
# This cell has to be (re-)run after the cells that create those objects; otherwise the
# trainer keeps references to whatever they were bound to earlier.
trainer = Trainer(
    model=model, args=training_args,
    compute_metrics=calc_classification_metrics,
    train_dataset=train_dataset, eval_dataset=dev_dataset,
)

In [33]:
%%time
# Run fine-tuning with the trainer built above; the cell magic reports how long the cell took.
trainer.train()

Epoch:   0%|          | 0/6 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.4349016571044922, 'learning_rate': 3.472413793103448e-05, 'epoch': 0.021551724137931036, 'step': 25}
{'loss': 0.3521604156494141, 'learning_rate': 3.450862068965517e-05, 'epoch': 0.04310344827586207, 'step': 50}
{'loss': 0.49650787353515624, 'learning_rate': 3.4293103448275864e-05, 'epoch': 0.06465517241379311, 'step': 75}
{'loss': 0.4685893249511719, 'learning_rate': 3.4077586206896555e-05, 'epoch': 0.08620689655172414, 'step': 100}
{'loss': 0.26140182495117187, 'learning_rate': 3.386206896551724e-05, 'epoch': 0.10775862068965517, 'step': 125}
{'loss': 0.6234982299804688, 'learning_rate': 3.364655172413793e-05, 'epoch': 0.12931034482758622, 'step': 150}
{'loss': 0.41747283935546875, 'learning_rate': 3.343103448275862e-05, 'epoch': 0.15086206896551724, 'step': 175}
{'loss': 0.50378662109375, 'learning_rate': 3.321551724137931e-05, 'epoch': 0.1724137931034483, 'step': 200}
{'loss': 0.710081787109375, 'learning_rate': 3.3e-05, 'epoch': 0.1939655172413793, 'step': 225}
{'loss':

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5579873754953345, 'eval_roc_auc': 0.9078163771712159, 'eval_threshold': 0.013931883499026299, 'eval_pr_auc': 0.7808453934714581, 'eval_recall': 0.8548387096774194, 'eval_precision': 0.7464788732394366, 'eval_f1': 0.7969924812030075, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 48, 'eval_tp': 138, 'epoch': 0.21551724137931033, 'step': 250}
{'loss': 0.5334515380859375, 'learning_rate': 3.256896551724138e-05, 'epoch': 0.23706896551724138, 'step': 275}
{'loss': 0.5242626953125, 'learning_rate': 3.235344827586207e-05, 'epoch': 0.25862068965517243, 'step': 300}
{'loss': 0.562830810546875, 'learning_rate': 3.213793103448276e-05, 'epoch': 0.2801724137931034, 'step': 325}
{'loss': 0.423704833984375, 'learning_rate': 3.192241379310345e-05, 'epoch': 0.3017241379310345, 'step': 350}
{'loss': 0.4186285400390625, 'learning_rate': 3.170689655172414e-05, 'epoch': 0.3232758620689655, 'step': 375}
{'loss': 0.477371826171875, 'learning_rate': 3.149137931034483e-05, 'epoch': 0.344827586206896

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.6527061170717288, 'eval_roc_auc': 0.904618141714916, 'eval_threshold': 0.016183823347091675, 'eval_pr_auc': 0.812868634686498, 'eval_recall': 0.7526881720430108, 'eval_precision': 0.813953488372093, 'eval_f1': 0.782122905027933, 'eval_tn': 363, 'eval_fp': 27, 'eval_fn': 58, 'eval_tp': 128, 'epoch': 0.43103448275862066, 'step': 500}
{'loss': 0.6098358154296875, 'learning_rate': 3.041379310344828e-05, 'epoch': 0.4525862068965517, 'step': 525}
{'loss': 0.5145611572265625, 'learning_rate': 3.0198275862068965e-05, 'epoch': 0.47413793103448276, 'step': 550}
{'loss': 0.5330712890625, 'learning_rate': 2.9982758620689656e-05, 'epoch': 0.4956896551724138, 'step': 575}
{'loss': 0.56660888671875, 'learning_rate': 2.9767241379310347e-05, 'epoch': 0.5172413793103449, 'step': 600}
{'loss': 0.432568359375, 'learning_rate': 2.9551724137931038e-05, 'epoch': 0.5387931034482759, 'step': 625}
{'loss': 0.954951171875, 'learning_rate': 2.9336206896551725e-05, 'epoch': 0.5603448275862069, 'ste

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5144457803237148, 'eval_roc_auc': 0.9287979046043563, 'eval_threshold': 0.04807457700371742, 'eval_pr_auc': 0.8058351410936725, 'eval_recall': 0.8763440860215054, 'eval_precision': 0.7442922374429224, 'eval_f1': 0.8049382716049382, 'eval_tn': 344, 'eval_fp': 46, 'eval_fn': 32, 'eval_tp': 154, 'epoch': 0.646551724137931, 'step': 750}
{'loss': 0.442808837890625, 'learning_rate': 2.8258620689655173e-05, 'epoch': 0.6681034482758621, 'step': 775}
{'loss': 0.663048095703125, 'learning_rate': 2.8043103448275864e-05, 'epoch': 0.6896551724137931, 'step': 800}
{'loss': 0.370579833984375, 'learning_rate': 2.7827586206896555e-05, 'epoch': 0.7112068965517241, 'step': 825}
{'loss': 0.596197509765625, 'learning_rate': 2.761206896551724e-05, 'epoch': 0.7327586206896551, 'step': 850}
{'loss': 0.330992431640625, 'learning_rate': 2.739655172413793e-05, 'epoch': 0.7543103448275862, 'step': 875}
{'loss': 0.474881591796875, 'learning_rate': 2.718103448275862e-05, 'epoch': 0.7758620689655172,

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.48635242219703895, 'eval_roc_auc': 0.9232837055417701, 'eval_threshold': 0.07372444123029709, 'eval_pr_auc': 0.8051426925044557, 'eval_recall': 0.8602150537634409, 'eval_precision': 0.7920792079207921, 'eval_f1': 0.8247422680412373, 'eval_tn': 350, 'eval_fp': 40, 'eval_fn': 32, 'eval_tp': 154, 'epoch': 0.8620689655172413, 'step': 1000}
{'loss': 0.57782958984375, 'learning_rate': 2.610344827586207e-05, 'epoch': 0.8836206896551724, 'step': 1025}
{'loss': 0.692193603515625, 'learning_rate': 2.588793103448276e-05, 'epoch': 0.9051724137931034, 'step': 1050}
{'loss': 0.36170654296875, 'learning_rate': 2.567241379310345e-05, 'epoch': 0.9267241379310345, 'step': 1075}
{'loss': 0.3060400390625, 'learning_rate': 2.5456896551724142e-05, 'epoch': 0.9482758620689655, 'step': 1100}
{'loss': 0.4232568359375, 'learning_rate': 2.5241379310344833e-05, 'epoch': 0.9698275862068966, 'step': 1125}
{'loss': 0.4652880859375, 'learning_rate': 2.5025862068965517e-05, 'epoch': 0.9913793103448276,

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.4086474609375, 'learning_rate': 2.4810344827586208e-05, 'epoch': 1.0129310344827587, 'step': 1175}
{'loss': 0.456787109375, 'learning_rate': 2.45948275862069e-05, 'epoch': 1.0344827586206897, 'step': 1200}
{'loss': 0.5631103515625, 'learning_rate': 2.4379310344827587e-05, 'epoch': 1.0560344827586208, 'step': 1225}
{'loss': 0.3422412109375, 'learning_rate': 2.4163793103448278e-05, 'epoch': 1.0775862068965518, 'step': 1250}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.45161445427220315, 'eval_roc_auc': 0.9343948166528812, 'eval_threshold': 0.2393077313899994, 'eval_pr_auc': 0.8120607551461987, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8059701492537313, 'eval_f1': 0.8372093023255813, 'eval_tn': 352, 'eval_fp': 38, 'eval_fn': 26, 'eval_tp': 160, 'epoch': 1.0775862068965518, 'step': 1250}
{'loss': 0.389072265625, 'learning_rate': 2.3948275862068965e-05, 'epoch': 1.0991379310344827, 'step': 1275}
{'loss': 0.34833740234375, 'learning_rate': 2.3732758620689656e-05, 'epoch': 1.1206896551724137, 'step': 1300}
{'loss': 0.5803759765625, 'learning_rate': 2.3517241379310344e-05, 'epoch': 1.1422413793103448, 'step': 1325}
{'loss': 0.35283935546875, 'learning_rate': 2.3301724137931035e-05, 'epoch': 1.1637931034482758, 'step': 1350}
{'loss': 0.26526123046875, 'learning_rate': 2.3086206896551726e-05, 'epoch': 1.1853448275862069, 'step': 1375}
{'loss': 0.40535400390625, 'learning_rate': 2.2870689655172413e-05, 'epoch': 1.206896551724138,

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.5291692831572922, 'eval_roc_auc': 0.9332230493520817, 'eval_threshold': 0.026159482076764107, 'eval_pr_auc': 0.8201575150901894, 'eval_recall': 0.8279569892473119, 'eval_precision': 0.806282722513089, 'eval_f1': 0.8169761273209549, 'eval_tn': 361, 'eval_fp': 29, 'eval_fn': 41, 'eval_tp': 145, 'epoch': 1.293103448275862, 'step': 1500}
{'loss': 0.47438232421875, 'learning_rate': 2.1793103448275865e-05, 'epoch': 1.3146551724137931, 'step': 1525}
{'loss': 0.491162109375, 'learning_rate': 2.1577586206896552e-05, 'epoch': 1.3362068965517242, 'step': 1550}
{'loss': 0.1805615234375, 'learning_rate': 2.1362068965517243e-05, 'epoch': 1.3577586206896552, 'step': 1575}
{'loss': 0.1772265625, 'learning_rate': 2.114655172413793e-05, 'epoch': 1.3793103448275863, 'step': 1600}
{'loss': 0.43464111328125, 'learning_rate': 2.0931034482758622e-05, 'epoch': 1.4008620689655173, 'step': 1625}
{'loss': 0.40931884765625, 'learning_rate': 2.0715517241379313e-05, 'epoch': 1.4224137931034484, 'ste

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.504755151613305, 'eval_roc_auc': 0.9370829886958919, 'eval_threshold': 0.02284669503569603, 'eval_pr_auc': 0.832242815201769, 'eval_recall': 0.8440860215053764, 'eval_precision': 0.8177083333333334, 'eval_f1': 0.8306878306878307, 'eval_tn': 360, 'eval_fp': 30, 'eval_fn': 37, 'eval_tp': 149, 'epoch': 1.5086206896551724, 'step': 1750}
{'loss': 0.5189111328125, 'learning_rate': 1.963793103448276e-05, 'epoch': 1.5301724137931034, 'step': 1775}
{'loss': 0.23588623046875, 'learning_rate': 1.942241379310345e-05, 'epoch': 1.5517241379310345, 'step': 1800}
{'loss': 0.3969677734375, 'learning_rate': 1.920689655172414e-05, 'epoch': 1.5732758620689655, 'step': 1825}
{'loss': 0.3555859375, 'learning_rate': 1.8991379310344827e-05, 'epoch': 1.5948275862068966, 'step': 1850}
{'loss': 0.5004150390625, 'learning_rate': 1.8775862068965518e-05, 'epoch': 1.6163793103448276, 'step': 1875}
{'loss': 0.38471923828125, 'learning_rate': 1.8560344827586205e-05, 'epoch': 1.6379310344827587, 'step':

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4795048693453686, 'eval_roc_auc': 0.9363937138130686, 'eval_threshold': 0.6312541365623474, 'eval_pr_auc': 0.824366321198478, 'eval_recall': 0.8817204301075269, 'eval_precision': 0.8118811881188119, 'eval_f1': 0.845360824742268, 'eval_tn': 350, 'eval_fp': 40, 'eval_fn': 22, 'eval_tp': 164, 'epoch': 1.7241379310344827, 'step': 2000}
{'loss': 0.38375244140625, 'learning_rate': 1.7482758620689657e-05, 'epoch': 1.7456896551724137, 'step': 2025}
{'loss': 0.237978515625, 'learning_rate': 1.7267241379310344e-05, 'epoch': 1.7672413793103448, 'step': 2050}
{'loss': 0.20902587890625, 'learning_rate': 1.7051724137931035e-05, 'epoch': 1.7887931034482758, 'step': 2075}
{'loss': 0.4835546875, 'learning_rate': 1.6836206896551726e-05, 'epoch': 1.8103448275862069, 'step': 2100}
{'loss': 0.277939453125, 'learning_rate': 1.6620689655172414e-05, 'epoch': 1.831896551724138, 'step': 2125}
{'loss': 0.27896728515625, 'learning_rate': 1.6405172413793105e-05, 'epoch': 1.853448275862069, 'step': 

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5306890398238061, 'eval_roc_auc': 0.9374000551419907, 'eval_threshold': 0.9805845618247986, 'eval_pr_auc': 0.8255278508746254, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8140703517587939, 'eval_f1': 0.8415584415584416, 'eval_tn': 335, 'eval_fp': 55, 'eval_fn': 17, 'eval_tp': 169, 'epoch': 1.9396551724137931, 'step': 2250}
{'loss': 0.28814208984375, 'learning_rate': 1.5327586206896553e-05, 'epoch': 1.9612068965517242, 'step': 2275}
{'loss': 0.365625, 'learning_rate': 1.5112068965517242e-05, 'epoch': 1.9827586206896552, 'step': 2300}


Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.28384033203125, 'learning_rate': 1.489655172413793e-05, 'epoch': 2.0043103448275863, 'step': 2325}
{'loss': 0.15362548828125, 'learning_rate': 1.468103448275862e-05, 'epoch': 2.0258620689655173, 'step': 2350}
{'loss': 0.0581884765625, 'learning_rate': 1.4465517241379312e-05, 'epoch': 2.0474137931034484, 'step': 2375}
{'loss': 0.30900146484375, 'learning_rate': 1.4249999999999999e-05, 'epoch': 2.0689655172413794, 'step': 2400}
{'loss': 0.26433837890625, 'learning_rate': 1.403448275862069e-05, 'epoch': 2.0905172413793105, 'step': 2425}
{'loss': 0.2073779296875, 'learning_rate': 1.3818965517241381e-05, 'epoch': 2.1120689655172415, 'step': 2450}
{'loss': 0.29172119140625, 'learning_rate': 1.3603448275862069e-05, 'epoch': 2.1336206896551726, 'step': 2475}
{'loss': 0.28594482421875, 'learning_rate': 1.338793103448276e-05, 'epoch': 2.1551724137931036, 'step': 2500}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5039967370507333, 'eval_roc_auc': 0.9463606286186931, 'eval_threshold': 0.011066182516515255, 'eval_pr_auc': 0.854077613897508, 'eval_recall': 0.9032258064516129, 'eval_precision': 0.7924528301886793, 'eval_f1': 0.8442211055276382, 'eval_tn': 362, 'eval_fp': 28, 'eval_fn': 38, 'eval_tp': 148, 'epoch': 2.1551724137931036, 'step': 2500}
{'loss': 0.3893701171875, 'learning_rate': 1.317241379310345e-05, 'epoch': 2.1767241379310347, 'step': 2525}
{'loss': 0.396142578125, 'learning_rate': 1.2956896551724138e-05, 'epoch': 2.1982758620689653, 'step': 2550}
{'loss': 0.1151904296875, 'learning_rate': 1.2741379310344827e-05, 'epoch': 2.2198275862068964, 'step': 2575}
{'loss': 0.149921875, 'learning_rate': 1.2525862068965518e-05, 'epoch': 2.2413793103448274, 'step': 2600}
{'loss': 0.367587890625, 'learning_rate': 1.2310344827586208e-05, 'epoch': 2.2629310344827585, 'step': 2625}
{'loss': 0.35111328125, 'learning_rate': 1.2094827586206897e-05, 'epoch': 2.2844827586206895, 'step': 26

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.5306517001558354, 'eval_roc_auc': 0.9427626137303556, 'eval_threshold': 0.031779978424310684, 'eval_pr_auc': 0.8460264699835617, 'eval_recall': 0.8494623655913979, 'eval_precision': 0.8272251308900523, 'eval_f1': 0.8381962864721485, 'eval_tn': 360, 'eval_fp': 30, 'eval_fn': 35, 'eval_tp': 151, 'epoch': 2.3706896551724137, 'step': 2750}
{'loss': 0.0982080078125, 'learning_rate': 1.1017241379310347e-05, 'epoch': 2.3922413793103448, 'step': 2775}
{'loss': 0.189638671875, 'learning_rate': 1.0801724137931036e-05, 'epoch': 2.413793103448276, 'step': 2800}
{'loss': 0.436884765625, 'learning_rate': 1.0586206896551725e-05, 'epoch': 2.435344827586207, 'step': 2825}
{'loss': 0.1744775390625, 'learning_rate': 1.0370689655172414e-05, 'epoch': 2.456896551724138, 'step': 2850}
{'loss': 0.3594677734375, 'learning_rate': 1.0155172413793104e-05, 'epoch': 2.478448275862069, 'step': 2875}
{'loss': 0.19728515625, 'learning_rate': 9.939655172413793e-06, 'epoch': 2.5, 'step': 2900}
{'loss': 0

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4998810120741837, 'eval_roc_auc': 0.9412875654811139, 'eval_threshold': 0.023344211280345917, 'eval_pr_auc': 0.8401101348919473, 'eval_recall': 0.8763440860215054, 'eval_precision': 0.815, 'eval_f1': 0.8445595854922279, 'eval_tn': 358, 'eval_fp': 32, 'eval_fn': 30, 'eval_tp': 156, 'epoch': 2.586206896551724, 'step': 3000}
{'loss': 0.4512890625, 'learning_rate': 8.862068965517243e-06, 'epoch': 2.6077586206896552, 'step': 3025}
{'loss': 0.2943701171875, 'learning_rate': 8.646551724137932e-06, 'epoch': 2.6293103448275863, 'step': 3050}
{'loss': 0.2205517578125, 'learning_rate': 8.431034482758621e-06, 'epoch': 2.6508620689655173, 'step': 3075}
{'loss': 0.332314453125, 'learning_rate': 8.215517241379312e-06, 'epoch': 2.6724137931034484, 'step': 3100}
{'loss': 0.40576171875, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.6939655172413794, 'step': 3125}
{'loss': 0.1208642578125, 'learning_rate': 7.78448275862069e-06, 'epoch': 2.7155172413793105, 'step': 3150}
{'loss': 0.16

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.5393092060694471, 'eval_roc_auc': 0.9479183898538737, 'eval_threshold': 0.02693619206547737, 'eval_pr_auc': 0.8642727418185782, 'eval_recall': 0.8333333333333334, 'eval_precision': 0.842391304347826, 'eval_f1': 0.8378378378378378, 'eval_tn': 362, 'eval_fp': 28, 'eval_fn': 39, 'eval_tp': 147, 'epoch': 2.8017241379310347, 'step': 3250}
{'loss': 0.3656591796875, 'learning_rate': 6.706896551724139e-06, 'epoch': 2.8232758620689653, 'step': 3275}
{'loss': 0.26818359375, 'learning_rate': 6.491379310344828e-06, 'epoch': 2.844827586206897, 'step': 3300}
{'loss': 0.328828125, 'learning_rate': 6.275862068965517e-06, 'epoch': 2.8663793103448274, 'step': 3325}
{'loss': 0.1263232421875, 'learning_rate': 6.060344827586207e-06, 'epoch': 2.887931034482759, 'step': 3350}
{'loss': 0.2273095703125, 'learning_rate': 5.8448275862068965e-06, 'epoch': 2.9094827586206895, 'step': 3375}
{'loss': 0.2521923828125, 'learning_rate': 5.629310344827587e-06, 'epoch': 2.9310344827586206, 'step': 3400}
{

Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.130126953125, 'learning_rate': 4.767241379310345e-06, 'epoch': 3.0172413793103448, 'step': 3500}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.48710770395377445, 'eval_roc_auc': 0.9514750482492418, 'eval_threshold': 0.04505738243460655, 'eval_pr_auc': 0.8612163646980553, 'eval_recall': 0.8817204301075269, 'eval_precision': 0.82, 'eval_f1': 0.849740932642487, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 27, 'eval_tp': 159, 'epoch': 3.0172413793103448, 'step': 3500}
{'loss': 0.2289013671875, 'learning_rate': 4.551724137931035e-06, 'epoch': 3.038793103448276, 'step': 3525}
{'loss': 0.2280419921875, 'learning_rate': 4.336206896551724e-06, 'epoch': 3.060344827586207, 'step': 3550}
{'loss': 0.1532470703125, 'learning_rate': 4.120689655172414e-06, 'epoch': 3.081896551724138, 'step': 3575}
{'loss': 0.089111328125, 'learning_rate': 3.905172413793104e-06, 'epoch': 3.103448275862069, 'step': 3600}
{'loss': 0.1015576171875, 'learning_rate': 3.689655172413793e-06, 'epoch': 3.125, 'step': 3625}
{'loss': 0.1599072265625, 'learning_rate': 3.474137931034483e-06, 'epoch': 3.146551724137931, 'step': 3650}
{'loss': 0.19798828125, 'l

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.48731243582410066, 'eval_roc_auc': 0.9531017369727047, 'eval_threshold': 0.08701246976852417, 'eval_pr_auc': 0.8657234465180382, 'eval_recall': 0.9032258064516129, 'eval_precision': 0.8115942028985508, 'eval_f1': 0.8549618320610687, 'eval_tn': 352, 'eval_fp': 38, 'eval_fn': 20, 'eval_tp': 166, 'epoch': 3.2327586206896552, 'step': 3750}
{'loss': 0.21380859375, 'learning_rate': 2.396551724137931e-06, 'epoch': 3.2543103448275863, 'step': 3775}
{'loss': 0.1598779296875, 'learning_rate': 2.181034482758621e-06, 'epoch': 3.2758620689655173, 'step': 3800}
{'loss': 0.334951171875, 'learning_rate': 1.9655172413793105e-06, 'epoch': 3.2974137931034484, 'step': 3825}
{'loss': 0.23546875, 'learning_rate': 1.7500000000000002e-06, 'epoch': 3.3189655172413794, 'step': 3850}
{'loss': 0.2044775390625, 'learning_rate': 1.5344827586206899e-06, 'epoch': 3.3405172413793105, 'step': 3875}
{'loss': 0.249189453125, 'learning_rate': 1.3189655172413794e-06, 'epoch': 3.3620689655172415, 'step': 390

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.4822951387549337, 'eval_roc_auc': 0.9515026192445546, 'eval_threshold': 0.03403972461819649, 'eval_pr_auc': 0.8544667160516239, 'eval_recall': 0.8924731182795699, 'eval_precision': 0.8217821782178217, 'eval_f1': 0.8556701030927835, 'eval_tn': 356, 'eval_fp': 34, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 3.4482758620689653, 'step': 4000}
{'loss': 0.18541015625, 'learning_rate': 2.4137931034482764e-07, 'epoch': 3.469827586206897, 'step': 4025}
{'loss': 0.0494970703125, 'learning_rate': 2.5862068965517245e-08, 'epoch': 3.4913793103448274, 'step': 4050}
{'loss': 0.2114208984375, 'learning_rate': 0.0, 'epoch': 3.512931034482759, 'step': 4075}
{'loss': 0.1764697265625, 'learning_rate': 0.0, 'epoch': 3.5344827586206895, 'step': 4100}
{'loss': 0.102919921875, 'learning_rate': 0.0, 'epoch': 3.5560344827586206, 'step': 4125}
{'loss': 0.2780322265625, 'learning_rate': 0.0, 'epoch': 3.5775862068965516, 'step': 4150}
{'loss': 0.2889453125, 'learning_rate': 0.0, 'epoch': 3.599137931034

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.48646868881446104, 'eval_roc_auc': 0.9512544802867383, 'eval_threshold': 0.9637233018875122, 'eval_pr_auc': 0.8528667856842004, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 3.663793103448276, 'step': 4250}
{'loss': 0.14865234375, 'learning_rate': 0.0, 'epoch': 3.685344827586207, 'step': 4275}
{'loss': 0.124541015625, 'learning_rate': 0.0, 'epoch': 3.706896551724138, 'step': 4300}
{'loss': 0.2051953125, 'learning_rate': 0.0, 'epoch': 3.728448275862069, 'step': 4325}
{'loss': 0.1883349609375, 'learning_rate': 0.0, 'epoch': 3.75, 'step': 4350}
{'loss': 0.3046337890625, 'learning_rate': 0.0, 'epoch': 3.771551724137931, 'step': 4375}
{'loss': 0.143017578125, 'learning_rate': 0.0, 'epoch': 3.793103448275862, 'step': 4400}
{'loss': 0.100966796875, 'learning_rate': 0.0, 'epoch': 3.814655172413793, 'step': 4425}
{'loss': 0.1111865234375, 'learning_rat

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4858557675438028, 'eval_roc_auc': 0.9511028398125172, 'eval_threshold': 0.9628520011901855, 'eval_pr_auc': 0.8572300653752538, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 3.8793103448275863, 'step': 4500}
{'loss': 0.035517578125, 'learning_rate': 0.0, 'epoch': 3.9008620689655173, 'step': 4525}
{'loss': 0.155380859375, 'learning_rate': 0.0, 'epoch': 3.9224137931034484, 'step': 4550}
{'loss': 0.4227685546875, 'learning_rate': 0.0, 'epoch': 3.9439655172413794, 'step': 4575}
{'loss': 0.1843359375, 'learning_rate': 0.0, 'epoch': 3.9655172413793105, 'step': 4600}
{'loss': 0.1357177734375, 'learning_rate': 0.0, 'epoch': 3.987068965517241, 'step': 4625}


Iteration:   0%|          | 0/1160 [00:00<?, ?it/s]

{'loss': 0.07671875, 'learning_rate': 0.0, 'epoch': 4.008620689655173, 'step': 4650}
{'loss': 0.138974609375, 'learning_rate': 0.0, 'epoch': 4.030172413793103, 'step': 4675}
{'loss': 0.08009765625, 'learning_rate': 0.0, 'epoch': 4.051724137931035, 'step': 4700}
{'loss': 0.15888671875, 'learning_rate': 0.0, 'epoch': 4.073275862068965, 'step': 4725}
{'loss': 0.2207958984375, 'learning_rate': 0.0, 'epoch': 4.094827586206897, 'step': 4750}


Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.4840331194411394, 'eval_roc_auc': 0.9509649848359526, 'eval_threshold': 0.9612193703651428, 'eval_pr_auc': 0.8520877281199644, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 4.094827586206897, 'step': 4750}
{'loss': 0.3485888671875, 'learning_rate': 0.0, 'epoch': 4.116379310344827, 'step': 4775}
{'loss': 0.1459619140625, 'learning_rate': 0.0, 'epoch': 4.137931034482759, 'step': 4800}
{'loss': 0.377587890625, 'learning_rate': 0.0, 'epoch': 4.1594827586206895, 'step': 4825}
{'loss': 0.1661376953125, 'learning_rate': 0.0, 'epoch': 4.181034482758621, 'step': 4850}
{'loss': 0.078037109375, 'learning_rate': 0.0, 'epoch': 4.202586206896552, 'step': 4875}
{'loss': 0.221357421875, 'learning_rate': 0.0, 'epoch': 4.224137931034483, 'step': 4900}
{'loss': 0.0836474609375, 'learning_rate': 0.0, 'epoch': 4.245689655172414, 'step': 4925}
{'loss': 0.3264404296

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.48668964382765506, 'eval_roc_auc': 0.9511717673007996, 'eval_threshold': 0.9598184823989868, 'eval_pr_auc': 0.8539930875755289, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 4.310344827586207, 'step': 5000}
{'loss': 0.1801220703125, 'learning_rate': 0.0, 'epoch': 4.331896551724138, 'step': 5025}
{'loss': 0.29666015625, 'learning_rate': 0.0, 'epoch': 4.353448275862069, 'step': 5050}
{'loss': 0.0996044921875, 'learning_rate': 0.0, 'epoch': 4.375, 'step': 5075}
{'loss': 0.1072021484375, 'learning_rate': 0.0, 'epoch': 4.396551724137931, 'step': 5100}
{'loss': 0.1482421875, 'learning_rate': 0.0, 'epoch': 4.418103448275862, 'step': 5125}
{'loss': 0.2877197265625, 'learning_rate': 0.0, 'epoch': 4.439655172413793, 'step': 5150}
{'loss': 0.269189453125, 'learning_rate': 0.0, 'epoch': 4.461206896551724, 'step': 5175}
{'loss': 0.052705078125, 'learning_r

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.48333052382804453, 'eval_roc_auc': 0.951392335263303, 'eval_threshold': 0.9614226222038269, 'eval_pr_auc': 0.8608799882444906, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 4.525862068965517, 'step': 5250}
{'loss': 0.2689306640625, 'learning_rate': 0.0, 'epoch': 4.547413793103448, 'step': 5275}
{'loss': 0.1110498046875, 'learning_rate': 0.0, 'epoch': 4.568965517241379, 'step': 5300}
{'loss': 0.191845703125, 'learning_rate': 0.0, 'epoch': 4.5905172413793105, 'step': 5325}
{'loss': 0.15060546875, 'learning_rate': 0.0, 'epoch': 4.612068965517241, 'step': 5350}
{'loss': 0.14306640625, 'learning_rate': 0.0, 'epoch': 4.633620689655173, 'step': 5375}
{'loss': 0.055322265625, 'learning_rate': 0.0, 'epoch': 4.655172413793103, 'step': 5400}
{'loss': 0.2903662109375, 'learning_rate': 0.0, 'epoch': 4.676724137931035, 'step': 5425}
{'loss': 0.09947265625, 

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]

{'eval_loss': 0.4838169237870413, 'eval_roc_auc': 0.9511579818031431, 'eval_threshold': 0.9618346095085144, 'eval_pr_auc': 0.8592406871803954, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 4.741379310344827, 'step': 5500}
{'loss': 0.0588916015625, 'learning_rate': 0.0, 'epoch': 4.762931034482759, 'step': 5525}
{'loss': 0.328466796875, 'learning_rate': 0.0, 'epoch': 4.7844827586206895, 'step': 5550}
{'loss': 0.2914208984375, 'learning_rate': 0.0, 'epoch': 4.806034482758621, 'step': 5575}
{'loss': 0.2051708984375, 'learning_rate': 0.0, 'epoch': 4.827586206896552, 'step': 5600}
{'loss': 0.116826171875, 'learning_rate': 0.0, 'epoch': 4.849137931034483, 'step': 5625}
{'loss': 0.14458984375, 'learning_rate': 0.0, 'epoch': 4.870689655172414, 'step': 5650}
{'loss': 0.2078515625, 'learning_rate': 0.0, 'epoch': 4.892241379310345, 'step': 5675}
{'loss': 0.1364453125, 'l

Evaluation:   0%|          | 0/72 [00:00<?, ?it/s]



{'eval_loss': 0.48394241781271474, 'eval_roc_auc': 0.9511993382961125, 'eval_threshold': 0.9596925377845764, 'eval_pr_auc': 0.8552557144186422, 'eval_recall': 0.8709677419354839, 'eval_precision': 0.8393782383419689, 'eval_f1': 0.8548812664907651, 'eval_tn': 355, 'eval_fp': 35, 'eval_fn': 23, 'eval_tp': 163, 'epoch': 4.956896551724138, 'step': 5750}
{'loss': 0.12205078125, 'learning_rate': 0.0, 'epoch': 4.978448275862069, 'step': 5775}
{'loss': 0.297763671875, 'learning_rate': 0.0, 'epoch': 5.0, 'step': 5800}
CPU times: user 1h 7min 22s, sys: 21min 55s, total: 1h 29min 17s
Wall time: 1h 29min 44s


TrainOutput(global_step=5800, training_loss=0.29739531418372844)

In [21]:
# Load the TensorBoard notebook extension; this registers the `%tensorboard`
# line magic that the next cell uses.
%load_ext tensorboard

In [22]:
# Open TensorBoard inside the notebook, reading event files from ./logs/runs.
# NOTE(review): presumably this is where the training run wrote its logs —
# confirm against the training arguments' logging directory.
%tensorboard --logdir ./logs/runs

In [34]:
# Persist the model currently held by `trainer` to a versioned directory.
saved_model_dir = "./multimodal_bert/multimodal_bert_v34"
trainer.save_model(saved_model_dir)

### Inference

In [35]:
# Wrap the fine-tuned `model` in a fresh Trainer for inference. The trainer used
# for training had drop_last = True set, so it is not reused here
# (presumably to avoid skipping the final partial batch of test examples).
pred_trained = Trainer(model=model)

In [36]:
# Run inference over the test set and keep only the `predictions` field of the
# object returned by `predict`.
prediction_output = pred_trained.predict(test_dataset=test_dataset)
result = prediction_output.predictions

Prediction:   0%|          | 0/73 [00:00<?, ?it/s]

In [37]:
from scipy.special import softmax

# Assumes `result` is a 2-D (n_examples, n_classes) array of raw model outputs
# — TODO confirm against what `Trainer.predict` returns for this model.

# Hard label per example: index of the largest value along axis 1.
pred_labels = np.argmax(result, axis=1)

# Soft score per example: softmax over axis 1, keeping the column at index 1.
class_probabilities = softmax(result, axis=1)
pred_scores = class_probabilities[:, 1]

# NOTE(review): the evaluation logs earlier in this notebook report an
# `eval_threshold`, whereas the labels here come from a plain argmax — confirm
# that this is intended.

In [38]:
# Convert every predicted class index with `dataloader.convert_prediction`
# (the displayed submission table below shows string labels such as 'rumour').
predicted_labels = list(map(dataloader.convert_prediction, pred_labels))

In [39]:
# Two-column submission table: the test ids next to their predicted labels.
submission_columns = {'id': test_df.id, 'target': predicted_labels}
output = pd.DataFrame(submission_columns)
output

Unnamed: 0,id,target
0,544382249178001408,rumour
1,525027317551079424,rumour
2,544273220128739329,rumour
3,499571799764770816,rumour
4,552844104418091008,rumour
...,...,...
576,553581227165642752,rumour
577,552816302780579840,rumour
578,580350000074457088,rumour
579,498584409055174656,rumour


In [40]:
# Build an {id: target} mapping from the submission table and write it to
# 'test-output.json'.
labels_by_id = pd.Series(output.target.values, index=output.id)
submission = labels_by_id.to_dict()
with open('test-output.json', 'w') as out_file:
    json.dump(submission, out_file)