In [1]:
import pandas as pd
import numpy as np
import csv
import os
import logging
import random
import copy
import json
import argparse
import torch
import torch.nn as nn
from transformers import BertTokenizer,AdamW, BertConfig, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertModel, BertPreTrainedModel
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

In [2]:
label_path = '../data/label.txt'

Utils

In [9]:
logger = logging.getLogger(__name__)

ADDITIONAL_SPECIAL_TOKENS = ["<e1>", "</e1>", "<e2>", "</e2>"]


def get_label(args):
    return [label.strip() for label in open(label_path, "r", encoding="utf-8")]


def load_tokenizer(args):
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.add_special_tokens({"additional_special_tokens": ADDITIONAL_SPECIAL_TOKENS})
    return tokenizer


def write_prediction(args, output_file, preds):
    """
    For official evaluation script
    :param output_file: prediction_file_path (e.g. eval/proposed_answers.txt)
    :param preds: [0,1,0,2,18,...]
    """
    relation_labels = get_label(args)
    with open(output_file, "w", encoding="utf-8") as f:
        for idx, pred in enumerate(preds):
            f.write("{}\t{}\n".format(8001 + idx, relation_labels[pred]))


def init_logger():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.no_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)



def compute_metrics(preds, labels):
    assert len(preds) == len(labels)
    return acc(preds, labels)


def simple_accuracy(preds, labels):
    return (preds == labels).mean()


def acc(preds, labels, average="macro"):
    acc = simple_accuracy(preds, labels)
    return {
        "acc": acc}

Data_loader

In [24]:
#from utils import get_label
logger = logging.getLogger(__name__)

class InputExample(object):
    """
    A single training/test example for simple sequence classification.
    Args:
        guid: Unique id for the example.
        text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
        label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """

    def __init__(self, guid, text_a, label):
        self.guid = guid
        self.text_a = text_a
        self.label = label

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


class InputFeatures(object):
    """
    A single set of features of data.
    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            Usually  ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
    """

    def __init__(self, input_ids, attention_mask, token_type_ids, label_id, e1_mask, e2_mask):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label_id = label_id
        self.e1_mask = e1_mask
        self.e2_mask = e2_mask

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


class SemEvalProcessor(object):
    """Processor for the Semeval data set """

    def __init__(self, args):
        self.args = args
        self.relation_labels = get_label(args)

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = self.relation_labels.index(line[1])
            if i % 1000 == 0:
                logger.info(line)
            examples.append(InputExample(guid=guid, text_a=text_a, label=label))
        return examples


    def get_examples(self, mode):
        """
        Args:
            mode: train, dev, test
        """
        file_to_read = None
        if mode == "train_file":
            file_to_read = self.args.train_file
        elif mode == "eval_file":
            file_to_read = self.args.test_file
        elif mode == "test_file":
            file_to_read = self.args.test_file

        logger.info("LOOKING AT {}".format(os.path.join(self.args.data_dir, file_to_read)))
        return self._create_examples(self._read_tsv(os.path.join(self.args.data_dir, file_to_read)), mode)


processors = {"semeval": SemEvalProcessor}


def read_examples_from_file(data_dir, mode):
    file_path = os.path.join(data_dir, "{}.txt".format(mode))
    guid_index = 1
    examples = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            line = line.strip().split("\t")
            if len(line) == 2:
                text_a = line[0]
                label = line[1]
            else:
                text_a = line[0]
                label = "NONE"
            examples.append(InputExample(guid=guid_index, text_a=text_a, label=label))
            guid_index += 1

    return examples

def convert_examples_to_features(
    examples,
    max_seq_len,
    tokenizer,
    cls_token="[CLS]",
    cls_token_segment_id=0,
    sep_token="[SEP]",
    pad_token=0,
    pad_token_segment_id=0,
    sequence_a_segment_id=0,
    add_sep_token=False,
    mask_padding_with_zero=True,
):
    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 5000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        tokens_a = tokenizer.tokenize(example.text_a)
        print(tokens_a)

        e11_p = tokens_a.index("<e1>")  # the start position of entity1
        e12_p = tokens_a.index("</e1>")  # the end position of entity1
        e21_p = tokens_a.index("<e2>")  # the start position of entity2
        e22_p = tokens_a.index("</e2>")  # the end position of entity2

        # Replace the token
        tokens_a[e11_p] = "$"
        tokens_a[e12_p] = "$"
        tokens_a[e21_p] = "#"
        tokens_a[e22_p] = "#"

        # Add 1 because of the [CLS] token
        e11_p += 1
        e12_p += 1
        e21_p += 1
        e22_p += 1

        # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
        if add_sep_token:
            special_tokens_count = 2
        else:
            special_tokens_count = 1
        if len(tokens_a) > max_seq_len - special_tokens_count:
            tokens_a = tokens_a[: (max_seq_len - special_tokens_count)]

        tokens = tokens_a
        if add_sep_token:
            tokens += [sep_token]

        token_type_ids = [sequence_a_segment_id] * len(tokens)

        tokens = [cls_token] + tokens
        token_type_ids = [cls_token_segment_id] + token_type_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        # e1 mask, e2 mask
        e1_mask = [0] * len(attention_mask)
        e2_mask = [0] * len(attention_mask)

        for i in range(e11_p, e12_p + 1):
            e1_mask[i] = 1
        for i in range(e21_p, e22_p + 1):
            e2_mask[i] = 1

        assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
        assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(
            len(attention_mask), max_seq_len
        )
        assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(
            len(token_type_ids), max_seq_len
        )

        label_id = int(example.label)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % example.guid)
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label_id))
            logger.info("e1_mask: %s" % " ".join([str(x) for x in e1_mask]))
            logger.info("e2_mask: %s" % " ".join([str(x) for x in e2_mask]))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                label_id=label_id,
                e1_mask=e1_mask,
                e2_mask=e2_mask,
            )
        )

    return features


def load_and_cache_examples(args, tokenizer, mode):
    processor = processors[args.task](args)

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            mode,
            args.task,
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            args.max_seq_len,
        ),
    )

    if os.path.exists(cached_features_file):
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        if mode == "train_file":
            examples = processor.get_examples("train_file")
        elif mode == "eval_file":
            examples = processor.get_examples("eval_file")
        elif mode == "test_file":
            examples = processor.get_examples("test_file")
        else:
            raise Exception("For mode, Only train, dev, test is available")

        features = convert_examples_to_features(
            examples, args.max_seq_len, tokenizer
        )
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_e1_mask = torch.tensor([f.e1_mask for f in features], dtype=torch.long)  # add e1 mask
    all_e2_mask = torch.tensor([f.e2_mask for f in features], dtype=torch.long)  # add e2 mask

    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

    dataset = TensorDataset(
        all_input_ids,
        all_attention_mask,
        all_token_type_ids,
        all_label_ids,
        all_e1_mask,
        all_e2_mask,
    )
    return dataset

Model

In [11]:
class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
        super(FCLayer, self).__init__()
        self.use_activation = use_activation
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.dropout(x)
        if self.use_activation:
            x = self.tanh(x)
        return self.linear(x)


class RBERT(BertPreTrainedModel):
    def __init__(self, config, args):
        super(RBERT, self).__init__(config)
        self.bert = BertModel(config=config)  # Load pretrained bert

        self.num_labels = config.num_labels

        self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
        self.entity_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
        self.label_classifier = FCLayer(
            config.hidden_size * 3,
            config.num_labels,
            args.dropout_rate,
            use_activation=False,
        )

    @staticmethod
    def entity_average(hidden_output, e_mask):
        """
        Average the entity hidden state vectors (H_i ~ H_j)
        :param hidden_output: [batch_size, j-i+1, dim]
        :param e_mask: [batch_size, max_seq_len]
                e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
        :return: [batch_size, dim]
        """
        e_mask_unsqueeze = e_mask.unsqueeze(1)  # [b, 1, j-i+1]
        length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1)  # [batch_size, 1]

        # [b, 1, j-i+1] * [b, j-i+1, dim] = [b, 1, dim] -> [b, dim]
        sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
        avg_vector = sum_vector.float() / length_tensor.float()  # broadcasting
        return avg_vector

    def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask):
        outputs = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # [CLS]

        # Average
        e1_h = self.entity_average(sequence_output, e1_mask)
        e2_h = self.entity_average(sequence_output, e2_mask)

        # Dropout -> tanh -> fc_layer (Share FC layer for e1 and e2)
        pooled_output = self.cls_fc_layer(pooled_output)
        e1_h = self.entity_fc_layer(e1_h)
        e2_h = self.entity_fc_layer(e2_h)

        # Concat -> fc_layer
        concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
        logits = self.label_classifier(concat_h)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        # Softmax
        if labels is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

Trainer

In [25]:
def get_device(pred_config):
    return "cuda" if torch.cuda.is_available() and not pred_config.no_cuda else "cpu"

def convert_input_file_to_tensor_dataset(
    args,
    cls_token_segment_id=0,
    pad_token_segment_id=0,
    sequence_a_segment_id=0,
    mask_padding_with_zero=True):
    tokenizer = load_tokenizer(args)

    # Setting based on the current model type
    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    pad_token_id = tokenizer.pad_token_id

    all_input_ids = []
    all_attention_mask = []
    all_token_type_ids = []
    all_e1_mask = []
    all_e2_mask = []

    with open(args.input_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            tokens = tokenizer.tokenize(line)

            e11_p = tokens.index("<e1>")  # the start position of entity1
            e12_p = tokens.index("</e1>")  # the end position of entity1
            e21_p = tokens.index("<e2>")  # the start position of entity2
            e22_p = tokens.index("</e2>")  # the end position of entity2

            # Replace the token
            tokens[e11_p] = "$"
            tokens[e12_p] = "$"
            tokens[e21_p] = "#"
            tokens[e22_p] = "#"

            # Add 1 because of the [CLS] token
            e11_p += 1
            e12_p += 1
            e21_p += 1
            e22_p += 1

            # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
            if args.add_sep_token:
                special_tokens_count = 2
            else:
                special_tokens_count = 1
            if len(tokens) > args.max_seq_len - special_tokens_count:
                tokens = tokens[: (args.max_seq_len - special_tokens_count)]

            # Add [SEP] token
            if args.add_sep_token:
                tokens += [sep_token]
            token_type_ids = [sequence_a_segment_id] * len(tokens)

            # Add [CLS] token
            tokens = [cls_token] + tokens
            token_type_ids = [cls_token_segment_id] + token_type_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = args.max_seq_len - len(input_ids)
            input_ids = input_ids + ([pad_token_id] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

            # e1 mask, e2 mask
            e1_mask = [0] * len(attention_mask)
            e2_mask = [0] * len(attention_mask)

            for i in range(e11_p, e12_p + 1):
                e1_mask[i] = 1
            for i in range(e21_p, e22_p + 1):
                e2_mask[i] = 1

            all_input_ids.append(input_ids)
            all_attention_mask.append(attention_mask)
            all_token_type_ids.append(token_type_ids)
            all_e1_mask.append(e1_mask)
            all_e2_mask.append(e2_mask)

    # Change to Tensor
    all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
    all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
    all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)
    all_e1_mask = torch.tensor(all_e1_mask, dtype=torch.long)
    all_e2_mask = torch.tensor(all_e2_mask, dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_e1_mask, all_e2_mask)

    return dataset

In [13]:
logger = logging.getLogger(__name__)


class Trainer(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset

        self.label_lst = get_label(args)
        self.num_labels = len(self.label_lst)

        self.config = BertConfig.from_pretrained(
            args.model_name_or_path,
            num_labels=self.num_labels,
            finetuning_task=args.task,
            id2label={str(i): label for i, label in enumerate(self.label_lst)},
            label2id={label: i for i, label in enumerate(self.label_lst)},
        )
        self.model = RBERT.from_pretrained(args.model_name_or_path, config=self.config, args=args)

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

        
    def evaluate(self, mode):
        # We use test dataset because semeval doesn't have dev dataset
        if mode == "test":
            dataset = self.test_dataset
        elif mode == "dev":
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)
        # Eval!

        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                    "labels": batch[3],
                    "e1_mask": batch[4],
                    "e2_mask": batch[5],
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        #results = {"loss": eval_loss}
        preds = np.argmax(preds, axis=1)
        write_prediction(self.args, os.path.join(self.args.eval_dir, "proposed_answers_multilingual.txt"), preds)

        results = {"loss": eval_loss, 'accuracy' : accuracy_score(out_label_ids, preds), 
                   'f1_score': f1_score(out_label_ids, preds, average='weighted'),
                  'roc_auc': roc_auc_score(out_label_ids, preds)}

          #result = compute_metrics(preds, out_label_ids)
          #results.update(result)

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            logger.info("  {} = {:.4f}".format(key, results[key]))

        return results

    
    def train(self):
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(
            self.train_dataset,
            sampler=train_sampler,
            batch_size=self.args.train_batch_size,
        )

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.args.learning_rate,
            eps=self.args.adam_epsilon,
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=t_total,
        )
        # Train!

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                    "labels": batch[3],
                    "e1_mask": batch[4],
                    "e2_mask": batch[5],
                }
                outputs = self.model(**inputs)
                loss = outputs[0]

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()
                    global_step += 1

            print("\n====Evaluation====")
            print("\nEvaluation: ", self.evaluate("test"))
            
        self.save_model(self.model)

    def save_model(self, model):
        torch.save(model.state_dict(), 'model/model_multilingual_1105.bin')

In [37]:
def load_saved_model(args):
    config = BertConfig.from_pretrained(args.model_name_or_path, num_labels = args.num_labels)
    model = RBERT.from_pretrained('model/model_multilingual_1105.bin', config=config, args=args)
    model.to("cpu")
    return model


def predict(pred_config):
        device = "cpu"
        model = load_saved_model(pred_config)
        tokenizer = load_tokenizer(pred_config)

        # Convert input file to TensorDataset
        dataset = convert_input_file_to_tensor_dataset(pred_config)
        #dataset = pred_config.input_file
        # Predict
        sampler = SequentialSampler(dataset)
        data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)

        preds = None
        
        
        for batch in tqdm(data_loader, desc="Predicting"):
            batch = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                    "labels": None,
                    "e1_mask": batch[3],
                    "e2_mask": batch[4],
                }
                outputs = model(**inputs)
                logits = outputs[0]

                if preds is None:
                    preds = logits.detach().cpu().numpy()
                else:
                    preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)

        preds = np.argmax(preds, axis=1)

        # Write to output file
        label_lst = get_label(pred_config)
        with open(pred_config.output_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(label_lst[pred]))

        print('Prediction was done')
        return preds

In [51]:
def evaluate(pred_config):
    # We use test dataset because semeval doesn't have dev dataset
    device = "cpu"
    model = load_saved_model(pred_config)
    tokenizer = load_tokenizer(pred_config)
    dataset = convert_input_file_to_tensor_dataset(pred_config)

    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=pred_config.batch_size)
    # Eval!

    eval_loss = 0.0
    nb_eval_steps = 0
    preds = None
    out_label_ids = None


    for batch in tqdm(eval_dataloader, desc="Predicting"):
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": None,
                "e1_mask": batch[3],
                "e2_mask": batch[4],
            }
            outputs = model(**inputs)
            logits = outputs[0]


        nb_eval_steps += 1

        if preds is None:
            preds = logits.detach().cpu().numpy()
            #out_label_ids = inputs["labels"].detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            #out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)


    #results = {"loss": eval_loss}
    preds = np.argmax(preds, axis=1)
    
#     results = {'accuracy' : accuracy_score(out_label_ids, preds), 
#                'f1_score': f1_score(out_label_ids, preds, average='weighted'),
#               'roc_auc': roc_auc_score(out_label_ids, preds)}

      #result = compute_metrics(preds, out_label_ids)
      #results.update(result)

#     logger.info("***** Eval results *****")
#     for key in sorted(results.keys()):
#         logger.info("  {} = {:.4f}".format(key, results[key]))

    return preds


Main

In [15]:
def RBERT_re(args):
    set_seed(args)
    tokenizer = load_tokenizer(args)

    train_dataset = load_and_cache_examples(args, tokenizer, mode="train_file")
    test_dataset = load_and_cache_examples(args, tokenizer, mode="eval_file")

    trainer = Trainer(args, train_dataset=train_dataset, test_dataset=test_dataset)


    if args.do_train:
        trainer.train()

In [10]:
class Trainer_args(object):
    def __init__(self,
                model_name_or_path = 'bert-base-multilingual-cased',
                seed = 24,
                task = "semeval",
                train_file = 'train_balanced.csv', 
                test_file = 'eval_balanced.csv',
                label_file = 'label.txt',  
                dropout_rate = 0.1,
                num_labels = 2,
                learning_rate = 2e-5,
                num_train_epochs = 22,
                max_seq_len = 384,
                train_batch_size = 16,
                eval_batch_size = 16,
                adam_epsilon = 1e-8,
                gradient_accumulation_steps = 1,
                max_grad_norm = 1.0,
                logging_steps = 250,
                save_steps = 250,
                weight_decay = 0.0,
                add_sep_token = True,
                do_train = True,
                no_cuda = True,
                do_eval = True,
                max_steps = -1,
                warmup_steps = 0,
                model_dir = 'model/',
                data_dir = '../data/',
                eval_dir = '../data/'
                ):

        super(Trainer_args, self).__init__()

        self.train_file = train_file
        self.test_file = test_file
        self.dropout_rate = dropout_rate
        self.num_labels = num_labels
        self.learning_rate = learning_rate
        self.num_train_epochs = num_train_epochs
        self.max_seq_len = max_seq_len
        self.train_batch_size = train_batch_size
        self.adam_epsilon = adam_epsilon
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.logging_steps = logging_steps
        self.save_steps = save_steps
        self.weight_decay = weight_decay
        self.data_dir = data_dir
        self.model_name_or_path = model_name_or_path
        self.seed = seed
        self.task = task
        self.add_sep_token = add_sep_token
        self.do_train = do_train
        self.no_cuda = no_cuda
        self.max_steps = max_steps
        self.warmup_steps = warmup_steps
        self.model_dir = model_dir
        self.label_file = label_file
        self.eval_batch_size = eval_batch_size
        self.do_eval = do_eval
        self.eval_dir = eval_dir
        return 
args = Trainer_args()

In [3]:
train_path = '../data/train_balanced.csv'

In [None]:
#11 epochs

In [12]:
main_model = RBERT_re(args)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing RBERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing RBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RBERT were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['cls_fc_layer.linear.weight


====Evaluation====



Evaluating:  12%|████                            | 1/8 [00:40<04:44, 40.68s/it][A
Evaluating:  25%|████████                        | 2/8 [01:21<04:04, 40.67s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:02<03:23, 40.69s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:42<02:42, 40.67s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:23<02:01, 40.63s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [04:03<01:21, 40.67s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:44<00:40, 40.73s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:13<00:00, 39.17s/it][A
Epoch:   5%|█▍                              | 1/22 [23:44<8:18:31, 1424.35s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.6818543449044228, 'accuracy': 0.5853658536585366, 'f1_score': 0.6247049567269867, 'roc_auc': 0.6950757575757576}



Iteration:  11%|███▌                            | 1/9 [01:58<15:49, 118.72s/it][A
Iteration:  22%|███████                         | 2/9 [04:02<14:01, 120.27s/it][A
Iteration:  33%|██████████▋                     | 3/9 [06:09<12:13, 122.19s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [08:27<10:34, 126.86s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [10:38<08:32, 128.16s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [12:54<06:31, 130.47s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [15:07<04:22, 131.27s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [17:17<02:10, 130.84s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [18:42<00:00, 124.68s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:40<04:43, 40.45s/it][A
Evaluating:  25%|████████                        | 2/8 [01:21<04:03, 40.53s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:01<03:22, 40.55s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:42<02:41, 40.46s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:20<01:59, 39.87s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:59<01:19, 39.74s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:40<00:39, 39.93s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:08<00:00, 38.57s/it][A
Epoch:   9%|██▉                             | 2/22 [47:35<7:55:30, 1426.52s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.6204676330089569, 'accuracy': 0.6747967479674797, 'f1_score': 0.7085068411659727, 'roc_auc': 0.7506313131313131}



Iteration:  11%|███▌                            | 1/9 [01:59<15:56, 119.58s/it][A
Iteration:  22%|███████                         | 2/9 [04:03<14:05, 120.80s/it][A
Iteration:  33%|██████████▋                     | 3/9 [06:11<12:17, 122.86s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [08:28<10:34, 126.97s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [10:26<08:18, 124.55s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [12:23<06:06, 122.17s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [14:29<04:06, 123.15s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [16:37<02:04, 124.68s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [18:01<00:00, 120.16s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:39, 39.88s/it][A
Evaluating:  25%|████████                        | 2/8 [01:20<03:59, 39.97s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:00<03:21, 40.22s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:41<02:41, 40.27s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:21<02:01, 40.38s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [04:02<01:20, 40.49s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:43<00:40, 40.52s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:11<00:00, 38.97s/it][A
Epoch:  14%|████                          | 3/22 [1:10:50<7:28:40, 1416.89s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5123746395111084, 'accuracy': 0.7886178861788617, 'f1_score': 0.8057491289198606, 'roc_auc': 0.7897727272727272}



Iteration:  11%|███▌                            | 1/9 [01:58<15:50, 118.79s/it][A
Iteration:  22%|███████                         | 2/9 [03:55<13:45, 117.99s/it][A
Iteration:  33%|██████████▋                     | 3/9 [06:02<12:03, 120.66s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [08:04<10:06, 121.26s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [10:05<08:04, 121.15s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:57<05:54, 118.25s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:50<03:53, 116.74s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:39<01:54, 114.42s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [17:04<00:00, 113.82s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:38<04:32, 38.87s/it][A
Evaluating:  25%|████████                        | 2/8 [01:18<03:53, 38.96s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:56<03:14, 38.81s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:36<02:36, 39.18s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:19<02:00, 40.26s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:59<01:20, 40.24s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:39<00:40, 40.24s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:08<00:00, 38.53s/it][A
Epoch:  18%|█████▍                        | 4/22 [1:33:05<6:57:44, 1392.49s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.611476581543684, 'accuracy': 0.6747967479674797, 'f1_score': 0.7085068411659727, 'roc_auc': 0.7506313131313131}



Iteration:  11%|███▌                            | 1/9 [01:48<14:31, 108.89s/it][A
Iteration:  22%|███████                         | 2/9 [03:45<12:57, 111.07s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:42<11:17, 112.90s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:34<09:23, 112.65s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:26<07:29, 112.41s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:34<05:50, 116.98s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:25<03:50, 115.35s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:17<01:54, 114.44s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:39<00:00, 111.07s/it][A

Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A


====Evaluation====



Evaluating:  12%|████                            | 1/8 [00:40<04:43, 40.49s/it][A
Evaluating:  25%|████████                        | 2/8 [01:21<04:03, 40.62s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:02<03:23, 40.65s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:40<02:39, 39.96s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:20<01:59, 39.98s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [04:00<01:20, 40.07s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:40<00:40, 40.06s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:08<00:00, 38.58s/it][A
Epoch:  23%|██████▊                       | 5/22 [1:54:55<6:27:27, 1367.48s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.4500096905976534, 'accuracy': 0.7804878048780488, 'f1_score': 0.7999625797975648, 'roc_auc': 0.8005050505050506}



Iteration:  11%|███▌                            | 1/9 [01:51<14:53, 111.71s/it][A
Iteration:  22%|███████                         | 2/9 [03:44<13:03, 111.99s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:37<11:14, 112.38s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:31<09:23, 112.69s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:30<07:38, 114.53s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:19<05:38, 112.83s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:10<03:44, 112.46s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:01<01:51, 111.85s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:20<00:00, 108.90s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:38<04:28, 38.37s/it][A
Evaluating:  25%|████████                        | 2/8 [01:16<03:50, 38.35s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:55<03:12, 38.40s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:34<02:35, 38.81s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:13<01:56, 38.75s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:52<01:17, 38.77s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:30<00:38, 38.71s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:57<00:00, 37.18s/it][A
Epoch:  27%|████████▏                     | 6/22 [2:16:15<5:57:42, 1341.42s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.417155088391155, 'accuracy': 0.8617886178861789, 'f1_score': 0.8665177198257465, 'roc_auc': 0.8194444444444444}



Iteration:  11%|███▌                            | 1/9 [01:49<14:32, 109.05s/it][A
Iteration:  22%|███████                         | 2/9 [03:40<12:48, 109.80s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:31<11:00, 110.04s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:18<09:05, 109.13s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:12<07:23, 110.78s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:03<05:32, 110.73s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [12:54<03:41, 110.83s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:45<01:50, 110.73s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:11<00:00, 107.90s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:38, 39.81s/it][A
Evaluating:  25%|████████                        | 2/8 [01:20<03:59, 39.95s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:00<03:19, 39.98s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:40<02:40, 40.12s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:20<02:00, 40.18s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [04:00<01:20, 40.12s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:41<00:40, 40.19s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:09<00:00, 38.70s/it][A
Epoch:  32%|█████████▌                    | 7/22 [2:37:39<5:31:01, 1324.11s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5676400037482381, 'accuracy': 0.7886178861788617, 'f1_score': 0.8068736141906874, 'roc_auc': 0.8055555555555555}



Iteration:  11%|███▌                            | 1/9 [01:54<15:12, 114.10s/it][A
Iteration:  22%|███████                         | 2/9 [03:53<13:28, 115.56s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:42<11:22, 113.81s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:30<09:20, 112.09s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:21<07:26, 111.64s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:12<05:33, 111.30s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:06<03:44, 112.35s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:53<01:50, 110.75s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:12<00:00, 108.10s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:38, 39.76s/it][A
Evaluating:  25%|████████                        | 2/8 [01:17<03:55, 39.30s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:56<03:15, 39.09s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:34<02:35, 38.87s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:13<01:56, 38.73s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:51<01:17, 38.66s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:32<00:39, 39.11s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:59<00:00, 37.48s/it][A
Epoch:  36%|██████████▉                   | 8/22 [2:58:56<5:05:40, 1310.07s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5474206642247736, 'accuracy': 0.8455284552845529, 'f1_score': 0.852527832274991, 'roc_auc': 0.8093434343434344}



Iteration:  11%|███▌                            | 1/9 [01:47<14:19, 107.49s/it][A
Iteration:  22%|███████                         | 2/9 [03:39<12:41, 108.75s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:30<10:57, 109.53s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:22<09:11, 110.24s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:14<07:22, 110.74s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:06<05:33, 111.12s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [12:57<03:41, 110.88s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:48<01:51, 111.06s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:14<00:00, 108.32s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:40<04:40, 40.10s/it][A
Evaluating:  25%|████████                        | 2/8 [01:20<04:01, 40.23s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:01<03:21, 40.34s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:41<02:41, 40.34s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:20<01:59, 39.94s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [04:01<01:20, 40.13s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:41<00:40, 40.30s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:10<00:00, 38.75s/it][A
Epoch:  41%|████████████▎                 | 9/22 [3:20:23<4:42:21, 1303.16s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5264310580678284, 'accuracy': 0.8780487804878049, 'f1_score': 0.8806974947577528, 'roc_auc': 0.8295454545454546}



Iteration:  11%|███▌                            | 1/9 [01:52<15:02, 112.76s/it][A
Iteration:  22%|███████                         | 2/9 [03:42<13:02, 111.75s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:30<11:03, 110.61s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:21<09:13, 110.73s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:14<07:25, 111.48s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:05<05:34, 111.36s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:01<03:45, 112.74s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:56<01:53, 113.28s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:16<00:00, 108.53s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:38<04:29, 38.48s/it][A
Evaluating:  25%|████████                        | 2/8 [01:16<03:50, 38.47s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:55<03:12, 38.42s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:33<02:33, 38.42s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:12<01:55, 38.44s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:50<01:16, 38.40s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:28<00:38, 38.36s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:55<00:00, 36.93s/it][A
Epoch:  45%|█████████████▏               | 10/22 [3:41:39<4:19:00, 1295.02s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5962619823403656, 'accuracy': 0.8699186991869918, 'f1_score': 0.8751129177958449, 'roc_auc': 0.8402777777777777}



Iteration:  11%|███▌                            | 1/9 [01:47<14:13, 106.73s/it][A
Iteration:  22%|███████                         | 2/9 [03:57<12:58, 111.24s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:50<11:26, 114.35s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:36<09:20, 112.08s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:25<07:24, 111.11s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:17<05:33, 111.32s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:08<03:42, 111.28s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:02<01:52, 112.15s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:27<00:00, 109.71s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:37, 39.61s/it][A
Evaluating:  25%|████████                        | 2/8 [01:19<03:57, 39.61s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:58<03:17, 39.57s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:38<02:38, 39.69s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:18<01:59, 39.70s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:57<01:18, 39.45s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:37<00:39, 39.72s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:05<00:00, 38.21s/it][A
Epoch:  50%|██████████████▌              | 11/22 [4:03:15<3:57:28, 1295.35s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5117664849385619, 'accuracy': 0.8943089430894309, 'f1_score': 0.8951191718485091, 'roc_auc': 0.8396464646464646}



Iteration:  11%|███▌                            | 1/9 [01:48<14:26, 108.27s/it][A
Iteration:  22%|███████                         | 2/9 [03:40<12:45, 109.33s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:41<11:16, 112.82s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:30<09:19, 111.84s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:22<07:26, 111.72s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:14<05:36, 112.03s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:09<03:45, 112.82s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:02<01:52, 112.85s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:27<00:00, 109.78s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:40<04:40, 40.06s/it][A
Evaluating:  25%|████████                        | 2/8 [01:20<04:01, 40.25s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:00<03:19, 39.97s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:38<02:37, 39.45s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:16<01:57, 39.07s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:55<01:17, 38.92s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:33<00:38, 38.75s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:59<00:00, 37.47s/it][A
Epoch:  55%|███████████████▊             | 12/22 [4:24:46<3:35:38, 1293.89s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5787388291209936, 'accuracy': 0.8699186991869918, 'f1_score': 0.8751129177958449, 'roc_auc': 0.8402777777777777}



Iteration:  11%|███▌                            | 1/9 [01:46<14:12, 106.55s/it][A
Iteration:  22%|███████                         | 2/9 [03:37<12:34, 107.81s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:27<10:50, 108.43s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:18<09:06, 109.24s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:07<07:17, 109.28s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:06<05:34, 111.35s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:05<03:48, 114.33s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:55<01:53, 113.06s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:18<00:00, 108.67s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:40<04:43, 40.56s/it][A
Evaluating:  25%|████████                        | 2/8 [01:22<04:05, 40.84s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:01<03:21, 40.34s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:41<02:40, 40.23s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:21<02:00, 40.18s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [04:01<01:20, 40.22s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:42<00:40, 40.28s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:10<00:00, 38.78s/it][A
Epoch:  59%|█████████████████▏           | 13/22 [4:46:18<3:14:00, 1293.36s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.7072811368852854, 'accuracy': 0.8292682926829268, 'f1_score': 0.8425962948154072, 'roc_auc': 0.8465909090909092}



Iteration:  11%|███▌                            | 1/9 [01:49<14:36, 109.59s/it][A
Iteration:  22%|███████                         | 2/9 [03:41<12:50, 110.14s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:32<11:02, 110.41s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:23<09:13, 110.72s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:12<07:21, 110.26s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:04<05:31, 110.57s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [12:55<03:41, 110.76s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:47<01:51, 111.19s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:12<00:00, 108.09s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:37, 39.64s/it][A
Evaluating:  25%|████████                        | 2/8 [01:19<03:58, 39.79s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:59<03:19, 39.91s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:38<02:38, 39.58s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:18<01:59, 39.71s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:59<01:19, 39.94s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:38<00:39, 39.70s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:06<00:00, 38.34s/it][A
Epoch:  64%|██████████████████▍          | 14/22 [5:07:38<2:51:55, 1289.49s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.553994054440409, 'accuracy': 0.8699186991869918, 'f1_score': 0.8735807818030599, 'roc_auc': 0.8244949494949495}



Iteration:  11%|███▌                            | 1/9 [01:54<15:15, 114.39s/it][A
Iteration:  22%|███████                         | 2/9 [03:44<13:10, 112.99s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:35<11:15, 112.62s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:28<09:22, 112.56s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:15<07:23, 110.86s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:11<05:37, 112.60s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [12:58<03:41, 110.95s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:59<01:53, 113.84s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:19<00:00, 108.85s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:38, 39.80s/it][A
Evaluating:  25%|████████                        | 2/8 [01:19<03:58, 39.79s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:57<03:16, 39.28s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:36<02:36, 39.01s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:15<01:57, 39.03s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:54<01:18, 39.26s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:33<00:39, 39.01s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:59<00:00, 37.47s/it][A
Epoch:  68%|███████████████████▊         | 15/22 [5:29:05<2:30:19, 1288.50s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5448995255865157, 'accuracy': 0.8780487804878049, 'f1_score': 0.8806974947577528, 'roc_auc': 0.8295454545454546}



Iteration:  11%|███▌                            | 1/9 [01:46<14:14, 106.80s/it][A
Iteration:  22%|███████                         | 2/9 [03:37<12:36, 108.03s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:35<11:05, 110.98s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:29<09:18, 111.65s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:25<07:32, 113.19s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:12<05:34, 111.43s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:02<03:41, 110.92s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:52<01:50, 110.72s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:17<00:00, 108.65s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:39<04:38, 39.80s/it][A
Evaluating:  25%|████████                        | 2/8 [01:19<03:59, 39.88s/it][A
Evaluating:  38%|████████████                    | 3/8 [02:00<03:19, 40.00s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:38<02:38, 39.59s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:19<01:59, 39.84s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:58<01:19, 39.55s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:36<00:39, 39.16s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:02<00:00, 37.83s/it][A
Epoch:  73%|█████████████████████        | 16/22 [5:50:28<2:08:42, 1287.03s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5458987597376108, 'accuracy': 0.8780487804878049, 'f1_score': 0.8806974947577528, 'roc_auc': 0.8295454545454546}



Iteration:  11%|███▌                            | 1/9 [01:49<14:36, 109.55s/it][A
Iteration:  22%|███████                         | 2/9 [03:41<12:52, 110.33s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:33<11:04, 110.68s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:25<09:16, 111.20s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:19<07:27, 111.92s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:10<05:35, 111.79s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:02<03:43, 111.82s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:03<01:54, 114.43s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:22<00:00, 109.22s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:38<04:26, 38.06s/it][A
Evaluating:  25%|████████                        | 2/8 [01:16<03:48, 38.06s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:54<03:10, 38.11s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:32<02:32, 38.14s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:10<01:54, 38.16s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:49<01:16, 38.26s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:27<00:38, 38.25s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:53<00:00, 36.73s/it][A
Epoch:  77%|██████████████████████▍      | 17/22 [6:11:47<1:47:03, 1284.65s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5674653183668852, 'accuracy': 0.8780487804878049, 'f1_score': 0.8822215174933057, 'roc_auc': 0.8453282828282828}



Iteration:  11%|███▌                            | 1/9 [01:50<14:46, 110.78s/it][A
Iteration:  22%|███████                         | 2/9 [03:42<12:55, 110.73s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:37<11:12, 112.14s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:24<09:13, 110.72s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:15<07:23, 110.77s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:08<05:34, 111.61s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [12:56<03:40, 110.48s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:44<01:49, 109.76s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:03<00:00, 107.08s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:38<04:27, 38.21s/it][A
Evaluating:  25%|████████                        | 2/8 [01:16<03:48, 38.10s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:56<03:13, 38.69s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:34<02:34, 38.58s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:12<01:55, 38.53s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:54<01:19, 39.56s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:33<00:39, 39.19s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:59<00:00, 37.43s/it][A
Epoch:  82%|███████████████████████▋     | 18/22 [6:32:57<1:25:20, 1280.22s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.584462080616504, 'accuracy': 0.8780487804878049, 'f1_score': 0.8822215174933057, 'roc_auc': 0.8453282828282828}



Iteration:  11%|███▌                            | 1/9 [01:48<14:30, 108.86s/it][A
Iteration:  22%|███████                         | 2/9 [03:40<12:46, 109.56s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:41<11:17, 112.95s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:33<09:23, 112.79s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:25<07:30, 112.72s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:18<05:37, 112.57s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:08<03:43, 111.81s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:59<01:51, 111.73s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:22<00:00, 109.18s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:38<04:29, 38.48s/it][A
Evaluating:  25%|████████                        | 2/8 [01:18<03:53, 38.84s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:57<03:14, 38.92s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:36<02:36, 39.06s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:17<01:58, 39.48s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:57<01:19, 39.75s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:37<00:39, 39.89s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [05:04<00:00, 38.07s/it][A
Epoch:  86%|█████████████████████████    | 19/22 [6:54:27<1:04:09, 1283.06s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5873709497973323, 'accuracy': 0.8780487804878049, 'f1_score': 0.8822215174933057, 'roc_auc': 0.8453282828282828}



Iteration:  11%|███▌                            | 1/9 [01:51<14:51, 111.48s/it][A
Iteration:  22%|███████                         | 2/9 [03:43<13:02, 111.74s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:35<11:09, 111.59s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:38<09:34, 114.96s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:23<07:28, 112.15s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:19<05:38, 112.88s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:14<03:47, 113.90s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:03<01:52, 112.44s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:21<00:00, 109.08s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:37<04:24, 37.86s/it][A
Evaluating:  25%|████████                        | 2/8 [01:15<03:47, 37.92s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:54<03:09, 37.97s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:32<02:32, 38.03s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:11<01:55, 38.51s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:50<01:17, 38.58s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:29<00:38, 38.61s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:55<00:00, 36.95s/it][A
Epoch:  91%|████████████████████████████▏  | 20/22 [7:15:48<42:44, 1282.50s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5836742650717497, 'accuracy': 0.8780487804878049, 'f1_score': 0.8822215174933057, 'roc_auc': 0.8453282828282828}



Iteration:  11%|███▌                            | 1/9 [01:46<14:10, 106.35s/it][A
Iteration:  22%|███████                         | 2/9 [03:36<12:32, 107.48s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:38<11:09, 111.64s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:29<09:17, 111.57s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:33<07:41, 115.30s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:24<05:42, 114.11s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [13:17<03:47, 113.66s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [15:08<01:53, 113.04s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:36<00:00, 110.68s/it][A

Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A


====Evaluation====



Evaluating:  12%|████                            | 1/8 [00:38<04:29, 38.56s/it][A
Evaluating:  25%|████████                        | 2/8 [01:16<03:50, 38.39s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:54<03:11, 38.31s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:32<02:32, 38.21s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:10<01:54, 38.24s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:48<01:16, 38.18s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:27<00:38, 38.20s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:53<00:00, 36.70s/it][A
Epoch:  95%|█████████████████████████████▌ | 21/22 [7:37:18<21:24, 1284.84s/it]
Iteration:   0%|                                         | 0/9 [00:00<?, ?it/s][A


Evaluation:  {'loss': 0.5862396312877536, 'accuracy': 0.8780487804878049, 'f1_score': 0.8822215174933057, 'roc_auc': 0.8453282828282828}



Iteration:  11%|███▌                            | 1/9 [01:51<14:51, 111.44s/it][A
Iteration:  22%|███████                         | 2/9 [03:43<13:00, 111.49s/it][A
Iteration:  33%|██████████▋                     | 3/9 [05:33<11:07, 111.30s/it][A
Iteration:  44%|██████████████▏                 | 4/9 [07:29<09:22, 112.58s/it][A
Iteration:  56%|█████████████████▊              | 5/9 [09:19<07:26, 111.73s/it][A
Iteration:  67%|█████████████████████▎          | 6/9 [11:08<05:32, 111.00s/it][A
Iteration:  78%|████████████████████████▉       | 7/9 [12:58<03:41, 110.62s/it][A
Iteration:  89%|████████████████████████████▍   | 8/9 [14:48<01:50, 110.57s/it][A
Iteration: 100%|████████████████████████████████| 9/9 [16:14<00:00, 108.29s/it][A



====Evaluation====



Evaluating:   0%|                                        | 0/8 [00:00<?, ?it/s][A
Evaluating:  12%|████                            | 1/8 [00:40<04:41, 40.26s/it][A
Evaluating:  25%|████████                        | 2/8 [01:20<04:00, 40.13s/it][A
Evaluating:  38%|████████████                    | 3/8 [01:59<03:19, 39.93s/it][A
Evaluating:  50%|████████████████                | 4/8 [02:37<02:37, 39.39s/it][A
Evaluating:  62%|████████████████████            | 5/8 [03:16<01:57, 39.09s/it][A
Evaluating:  75%|████████████████████████        | 6/8 [03:54<01:17, 38.79s/it][A
Evaluating:  88%|████████████████████████████    | 7/8 [04:32<00:38, 38.68s/it][A
Evaluating: 100%|████████████████████████████████| 8/8 [04:58<00:00, 37.36s/it][A
Epoch: 100%|███████████████████████████████| 22/22 [7:58:36<00:00, 1305.31s/it]


Evaluation:  {'loss': 0.5873853485099971, 'accuracy': 0.8780487804878049, 'f1_score': 0.8822215174933057, 'roc_auc': 0.8453282828282828}





Prediction

In [34]:
class Predict_args(object):
      def __init__(self,
                   input_file = '../data/test_file.csv',
                   output_file = 'sample_pred_out_multilingual.txt',
                   model_dir = 'model/',
                   model_name_or_path = 'bert-base-multilingual-cased',
                   num_labels = 2,
                   add_sep_token = True,
                   dropout_rate = 0.1,
                   max_seq_len = 384,
                   batch_size = 16,
                   no_cuda = True):
  
        super(Predict_args, self).__init__()
        self.input_file = input_file
        self.output_file = output_file
        self.model_dir = model_dir
        self.batch_size = batch_size
        self.no_cuda = no_cuda
        self.model_name_or_path = model_name_or_path
        self.add_sep_token = add_sep_token
        self.max_seq_len = max_seq_len
        self.dropout_rate = dropout_rate
        self.num_labels = num_labels
        return

pred_config = Predict_args()

In [35]:
examples = []
with open('../data/test_label.csv', "r") as f:
    for line in f.readlines():
        line = line.strip().split("\t")
        examples.append(line)

In [5]:
y_true = [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0]

In [44]:
prediction = predict(pred_config)

KeyboardInterrupt: 

In [52]:
prediction = evaluate(pred_config)

Predicting: 100%|████████████████████████████████| 5/5 [03:13<00:00, 38.67s/it]


In [36]:
prediction

array([1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1,
       1, 0, 1, 0, 0, 1, 1, 1], dtype=int64)

In [18]:
pred_tr = []
for i in prediction:
    if i==1:
        a = 0
    if i==0:
        a = 1
    pred_tr.append(a)        

In [20]:
pred_np = np.array(pred_tr)
pred_np = np.reshape(pred_np,(74,1))
y_true = np.array(y_true)
y_true = np.reshape(y_true,(74,1))

In [21]:
roc_auc_score(y_true, pred_np)

0.8852813852813852

In [22]:
f1_score(y_true, pred_np, average='weighted')

0.9336310223266745

In [25]:
# pred_np = np.array(prediction)
# y_true = np.array(test_label)

In [17]:
prediction

array([1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1,
       1, 0, 1, 0, 0, 0, 1, 1], dtype=int64)

In [18]:
pr = [0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 0, 1, 1, 1, 0, 0]