In [1]:
from argparse import ArgumentParser
from pathlib import Path

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
from transformers import AutoModel

In [3]:
from dataclasses import dataclass

@dataclass
class Args:
    """Run configuration for this notebook: data paths, checkpoint and training hyperparameters."""

    # CSV files read with pandas in the data-loading cell.
    train_file: str = "../public_data/train/track_a/sun.csv"
    test_file: str = "../public_data/dev/track_a/sun_a.csv"
    # Checkpoint name used for both the encoder (AutoModel) and the tokenizer.
    model_checkpoint: str = "LazarusNLP/NusaBERT-base"
    # NOTE(review): not referenced by the cells visible here (TrainingArguments uses "tmp/") - confirm usage.
    output_dir: str = "models"
    # Upper bound on epochs; early stopping (below) may end training sooner.
    num_train_epochs: int = 50
    optim: str = "adamw_torch"
    # Passed positionally to EarlyStoppingCallback as (patience, threshold).
    early_stopping_patience: int = 5
    early_stopping_threshold: float = 0.0
    learning_rate: float = 1e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 16
    # Weight of the correlation loss in SpanEmo; the BCE term gets (1 - alpha).
    alpha: float = 0.2
    # Mixed-precision switches forwarded to TrainingArguments.
    fp16: bool = False
    bf16: bool = False
    # Only referenced from the commented-out Hub arguments in TrainingArguments.
    hub_model_id: str = "LazarusNLP/NusaBERT-base-CASA"

# Single configuration instance used by the rest of the notebook.
args = Args()

In [4]:
# Read the labelled CSV and hold out 10% of it (fixed seed) as the validation split.
train_df, val_df = train_test_split(
    pd.read_csv(args.train_file), test_size=0.1, random_state=42
)
test_df = pd.read_csv(args.test_file)

# Every column except the id/text fields is a label; sorting fixes the label order.
labels = sorted(set(train_df.columns).difference({"id", "text"}))
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
# Set the label columns of the test frame to 0
# (presumably placeholders so all splits share one schema - TODO confirm).
test_df.loc[:, labels] = 0

# Train/validation frames are re-indexed from 0 before conversion; the test frame is used as read.
dataset = DatasetDict(
    {
        split_name: Dataset.from_pandas(frame)
        for split_name, frame in (
            ("train", train_df.reset_index(drop=True)),
            ("validation", val_df.reset_index(drop=True)),
            ("test", test_df),
        )
    }
)

In [16]:
class SpanEmo(nn.Module):
    """Encoder plus a per-token scoring head, read out at the label-word positions.

    The head (Linear -> Tanh -> Dropout -> Linear to 1) gives one scalar per token.
    The logits for the labels are the scalars at the token positions in ``label_idxs``.

    :param model_checkpoint: name or path passed to ``AutoModel.from_pretrained``
    :param output_dropout: dropout probability inside the scoring head
    :param alpha: weight of the correlation loss; the BCE loss gets ``1 - alpha``
    """

    def __init__(self, model_checkpoint, output_dropout=0.1, alpha=0.2):
        super().__init__()
        self.alpha = alpha
        self.bert = AutoModel.from_pretrained(model_checkpoint)
        hidden_size = self.bert.config.hidden_size
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(p=output_dropout),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids, token_type_ids, attention_mask, label_idxs, labels=None):
        """Score every token and read the scores at the label-word positions.

        :param input_ids: token ids forwarded to the encoder
        :param token_type_ids: segment ids forwarded to the encoder
        :param attention_mask: attention mask forwarded to the encoder
        :param label_idxs: token positions of the label words; only the first row is
            used, so every row in the batch is assumed to share the same positions
        :param labels: optional multi-hot targets, one column per label
        :return: ``{"logits": ...}``, plus ``"loss"`` when ``labels`` is given
        """
        # The first row's positions are applied to the whole batch.
        label_idxs = label_idxs[0].long()

        last_hidden_state = self.bert(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        ).last_hidden_state
        # The head yields one scalar per token; keep only the label-word positions along dim 1.
        logits = self.ffn(last_hidden_state).squeeze(-1).index_select(dim=1, index=label_idxs)

        if labels is None:
            return {"logits": logits}

        bce_loss = F.binary_cross_entropy_with_logits(logits, labels.to(torch.float32))
        corr_loss = self.corr_loss(logits, labels)
        loss = ((1 - self.alpha) * bce_loss) + (self.alpha * corr_loss)
        return {"loss": loss, "logits": logits}

    @staticmethod
    def corr_loss(y_hat, y_true, reduction="mean"):
        """
        Pairwise loss: mean of exp(p_negative - p_positive) over all
        (negative label, positive label) pairs of each example.

        Examples with no positive labels, or with no negative labels, have no pairs
        and contribute 0. Without the second check an all-positive example would
        divide by zero comparisons and turn the whole loss into NaN.

        :param y_hat: model predictions (logits), shape(batch, classes)
        :param y_true: target labels (batch, classes)
        :param reduction: "mean" averages over the batch; any other value sums
        :return: loss
        """
        loss = torch.zeros(y_true.size(0), device=y_true.device)
        for idx, (y, y_h) in enumerate(zip(y_true, y_hat.sigmoid())):
            y_z, y_o = (y == 0).nonzero(), y.nonzero()
            # No pairs to compare: leave this example's loss at 0 (avoids 0 / 0 = NaN).
            if y_o.nelement() == 0 or y_z.nelement() == 0:
                continue
            # Broadcasts to (num_positive, num_negative): p_negative - p_positive for each pair.
            output = torch.exp(torch.sub(y_h[y_z], y_h[y_o][:, None]).squeeze(-1)).sum()
            num_comparisons = y_z.size(0) * y_o.size(0)
            loss[idx] = output.div(num_comparisons)
        return loss.mean() if reduction == "mean" else loss.sum()

In [17]:
# Build the classifier from the configured checkpoint; alpha weights the correlation loss term.
model = SpanEmo(args.model_checkpoint, alpha=args.alpha)

Some weights of BertModel were not initialized from the model checkpoint at LazarusNLP/NusaBERT-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
# Tokenizer loaded from the same checkpoint name as the encoder.
tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint)

In [9]:
def preprocess_function(example):
    """Tokenize one example as a (label words, text) pair and locate the label tokens.

    The label words are the first segment and the example text is the second.
    The position of each label word in the token sequence is stored in
    ``label_idxs`` so the model can read one logit per label.

    NOTE(review): the order of the label words is assumed to match the sorted
    label columns in ``labels`` - confirm if the label set changes.

    :param example: one dataset row with a ``"text"`` field and one field per label
    :return: the tokenizer output extended with ``"labels"`` and ``"label_idxs"``
    :raises ValueError: if a label word does not appear as a single token
    """
    label_names = "marah jijik takut gembira sedih terkejut"
    tokenized_input = tokenizer(
        label_names, example["text"], truncation=True, max_length=model.bert.config.max_position_embeddings
    )
    tokenized_input["labels"] = [float(example[label]) for label in labels]
    # Convert ids to tokens once rather than once per label word.
    tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
    # list.index returns the first match; the label words come first in the sequence.
    tokenized_input["label_idxs"] = [tokens.index(name) for name in label_names.split()]
    return tokenized_input

# Apply the preprocessing to every split; the train split's column names are removed from each split.
tokenized_dataset = dataset.map(preprocess_function, remove_columns=dataset["train"].column_names)

Map: 100%|██████████| 831/831 [00:00<00:00, 1707.40 examples/s]
Map: 100%|██████████| 93/93 [00:00<00:00, 1602.41 examples/s]
Map: 100%|██████████| 199/199 [00:00<00:00, 1642.50 examples/s]


In [59]:
# inputs = preprocess_function(dataset["train"][0])
# inputs = {k: torch.tensor(v).unsqueeze(0) for k, v in inputs.items()}
# labels = inputs.pop("labels")
# model(**inputs)

In [19]:
# Batch collator that pads examples using the tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [35]:
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``, computed without overflow.

    The naive form evaluates ``exp(-x)``, which overflows for large negative ``x``
    (a RuntimeWarning, or an error under ``np.errstate(over="raise")``). This form
    only ever exponentiates a non-positive number.

    :param x: scalar or array-like of logits
    :return: probabilities with the same shape as ``x``
    """
    x = np.asarray(x)
    # exp(-|x|) lies in (0, 1], so it cannot overflow whatever the magnitude of x.
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) - the same function, rearranged.
    out = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    # Indexing with () turns a 0-d result back into a NumPy scalar and leaves arrays as they are.
    return out[()]

def compute_metrics(eval_pred):
    """Macro-averaged F1 of the predictions after a sigmoid and a fixed 0.5 cut-off.

    :param eval_pred: pair of (raw model scores, reference labels)
    :return: dict with a single ``"f1"`` entry
    """
    raw_scores, references = eval_pred
    # Scores become probabilities, then hard 0/1 decisions at 0.5.
    hard_predictions = (sigmoid(raw_scores) > 0.5).astype(int)
    macro_f1 = f1_score(y_true=references.astype(int), y_pred=hard_predictions, average="macro")
    return {"f1": macro_f1}

# Early stopping configured from Args; patience and threshold are passed positionally, in that order.
callbacks = [EarlyStoppingCallback(args.early_stopping_patience, args.early_stopping_threshold)]

In [36]:
training_args = TrainingArguments(
    # --- checkpointing and evaluation cadence ---
    output_dir="tmp/",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # the key returned by compute_metrics
    # --- optimisation ---
    optim=args.optim,
    learning_rate=args.learning_rate,
    weight_decay=args.weight_decay,
    warmup_ratio=args.warmup_ratio,
    num_train_epochs=args.num_train_epochs,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    # --- precision ---
    fp16=args.fp16,
    bf16=args.bf16,
    # --- logging and label handling ---
    report_to="tensorboard",
    label_names=["labels"],
    # --- Hub upload (disabled) ---
    # push_to_hub=True,
    # hub_model_id=args.hub_model_id,
    # hub_private_repo=True,
)

In [37]:
# Wire the model, data, metric and early-stopping callback into one Trainer.
trainer = Trainer(
    # what to train and how
    model=model,
    args=training_args,
    callbacks=callbacks,
    compute_metrics=compute_metrics,
    # data and batching
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

In [40]:
# Score the validation split and turn the raw model outputs into probabilities.
# The `predictions` field is read by name instead of star-unpacking the rest into `_`:
# binding `_` in a notebook stops IPython from updating its last-output variable.
val_scores = sigmoid(trainer.predict(tokenized_dataset["validation"]).predictions)

In [42]:
# Reference label matrix for the validation frame, with columns in the order of `labels`.
val_labels = val_df[labels].astype(int).to_numpy()

# find best threshold via validation set, apply best threshold to test set
# Candidates start at 0.0 and advance in steps of 0.01; np.arange excludes the stop value 1.0.
thresholds = np.arange(0.0, 1.0, 0.01)
# Collected as (threshold, macro-F1) pairs, one per candidate.
scores = []

for threshold in thresholds:
    # Binarize the validation probabilities at this candidate threshold.
    val_prediction = (val_scores > threshold).astype(int)
    metrics_result = f1_score(y_true=val_labels, y_pred=val_prediction, average="macro")
    scores.append((threshold, metrics_result))

In [43]:
# Display the collected (threshold, macro-F1) pairs.
scores

[(0.0, 0.2986935391117492),
 (0.01, 0.2986935391117492),
 (0.02, 0.2986935391117492),
 (0.03, 0.2986935391117492),
 (0.04, 0.2986935391117492),
 (0.05, 0.2986935391117492),
 (0.06, 0.2986935391117492),
 (0.07, 0.2986935391117492),
 (0.08, 0.2986935391117492),
 (0.09, 0.2986935391117492),
 (0.1, 0.2986935391117492),
 (0.11, 0.2986935391117492),
 (0.12, 0.2986935391117492),
 (0.13, 0.2986935391117492),
 (0.14, 0.2986935391117492),
 (0.15, 0.2986935391117492),
 (0.16, 0.2986935391117492),
 (0.17, 0.2986935391117492),
 (0.18, 0.2986935391117492),
 (0.19, 0.2986935391117492),
 (0.2, 0.2986935391117492),
 (0.21, 0.2986935391117492),
 (0.22, 0.2986935391117492),
 (0.23, 0.2986935391117492),
 (0.24, 0.2986935391117492),
 (0.25, 0.2986935391117492),
 (0.26, 0.2986935391117492),
 (0.27, 0.2986935391117492),
 (0.28, 0.2986935391117492),
 (0.29, 0.2986935391117492),
 (0.3, 0.2986935391117492),
 (0.31, 0.2986935391117492),
 (0.32, 0.2986935391117492),
 (0.33, 0.2990828772369364),
 (0.34, 0.29947970