In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [2]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=250)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

class Trainer_w_soft_target(Trainer):
    def compute_loss(self, model, inputs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        return CEwST_loss(logits, labels)
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [3]:
MODEL_NAMES = ['bert-base-uncased', 'xlnet-base-cased']

In [10]:
use_pretrain = False
soft_target = False

for t in ['INVSIB']: # ['ORIG', 'INV', 'SIB', 'INVSIB']: 
    for MODEL_NAME in MODEL_NAMES:
        
        checkpoint = 'pretrained/' + MODEL_NAME
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
        dataset = load_dataset('glue', 'sst2')['train']
        dataset.rename_column_('sentence', 'text')
        dataset = dataset.train_test_split(test_size=0.1)
        train_dataset = dataset['train']
        test_dataset = dataset['test']
        
        if t == 'ORIG':
            checkpoint += '-sst2-ORIG'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'INV':
            text = npy_load("./assets/SST2/sentiment/INV/text.npy")
            label = npy_load("./assets/SST2/sentiment/INV/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-sst2-INV'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'SIB':
            text = npy_load("./assets/SST2/sentiment/SIB/text.npy")
            label = npy_load("./assets/SST2/sentiment/SIB/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-sst2-SIB'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'INVSIB':
            text = npy_load("./assets/SST2/sentiment/INVSIB/text.npy")
            label = npy_load("./assets/SST2/sentiment/INVSIB/label.npy")
            df = pd.DataFrame({'text': text, 'label': label.tolist()})
            df.text = df.text.astype(str)
            df.label = df.label.map(lambda y: np.array(y))
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-sst2-INVSIB'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
               
        train_dataset.shuffle()
        
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
            
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        train_dataset.rename_column_('label', 'labels')
        test_dataset.rename_column_('label', 'labels')
        
        if len(np.array(train_dataset['labels']).shape) > 1:
            soft_target = True
        
        train_batch_size = 8
        eval_batch_size = 8
        num_epoch = 10
        max_steps = int((len(train_dataset) * num_epoch) / train_batch_size)

        training_args = TrainingArguments(
            output_dir='./pretrained/' + MODEL_NAME + '-sst2-' + t,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=int(max_steps / 100),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=int(max_steps / 10),
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps"
        )

        if soft_target:
            trainer = Trainer_w_soft_target(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics_w_soft_target,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                data_collator=DefaultCollator(),
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )
        else: 
            trainer = Trainer(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )

        trainer.train()
        out = trainer.evaluate()
        print(out)

Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-3ed104176ee81a30.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-827cf6e879ea305a.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-447e97564ad63924.arrow


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
8418,0.629923,0.391305,0.000297,0.898779,0.897847,0.900168
16836,0.595918,0.365656,0.000297,0.889543,0.896606,0.886073
25254,0.577574,0.418366,0.000148,0.893991,0.892926,0.896098
33672,0.561308,0.396412,0.000297,0.900453,0.909196,0.89641
42090,0.540606,0.383765,0.000297,0.899164,0.89836,0.900248


{'eval_loss': 0.3656557500362396, 'eval_accuracy': 0.0002969561901409179, 'eval_f1': 0.8895427624938106, 'eval_precision': 0.8966060790239015, 'eval_recall': 0.8860733882195384, 'epoch': 4.999406105238152}


Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-e21c50a49f2c027a.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-f230ca07febc706e.arrow
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the c

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
8418,0.647453,0.493449,0.000297,0.881269,0.880883,0.884289
16836,0.620117,0.597321,0.000297,0.620493,0.754766,0.644529
25254,0.661803,0.691824,0.000297,0.354019,0.274016,0.5
33672,0.693773,0.689583,0.000297,0.354019,0.274016,0.5


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.4934491217136383, 'eval_accuracy': 0.0002969561901409179, 'eval_f1': 0.8812689034457023, 'eval_precision': 0.8808826845270987, 'eval_recall': 0.8842893855886268, 'epoch': 3.9995248841905213}
