In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [2]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [3]:
MODEL_NAMES = ['xlnet-base-cased'] # ['bert-base-uncased', 'xlnet-base-cased']

In [4]:
use_pretrain = False

for t in ['ORIG', 'INV', 'SIB']: 
    for MODEL_NAME in MODEL_NAMES:
        
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
        dataset = load_dataset('glue', 'sst2')['train']
        dataset.rename_column_('sentence', 'text')
        dataset = dataset.train_test_split(test_size=0.1)
        train_dataset = dataset['train']
        test_dataset = dataset['test']
        
        if t == 'ORIG':
            checkpoint = 'pretrained/bert-base-uncased-sst2-ORIG'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'INV':
            text = npy_load("./assets/SST2/sentiment/INV/text.npy")
            label = npy_load("./assets/SST2/sentiment/INV/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint = 'pretrained/bert-base-uncased-sst2-INV'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'SIB':
            text = npy_load("./assets/SST2/sentiment/SIB/text.npy")
            label = npy_load("./assets/SST2/sentiment/SIB/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint = 'pretrained/bert-base-uncased-sst2-SIB'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
               
        train_dataset.shuffle()
        
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
            
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        
        train_batch_size = 8
        eval_batch_size = 8
        num_epoch = 10
        max_steps = int((len(train_dataset) * num_epoch) / train_batch_size)

        training_args = TrainingArguments(
            output_dir='./pretrained/' + MODEL_NAME + '-sst2-' + t,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=int(max_steps / 100),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=int(max_steps / 10),
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps"
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )

        trainer.train()
        out = trainer.evaluate()
        print(out)

Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-934f1dcf99ce2ea2.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-609e7de9225e8e34.arrow
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the c

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-95211b176b12226e.arrow





  return torch.tensor(x, **format_kwargs)


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
8418,0.615646,0.863095,0.556496,0.715062,0.556496,1.0
16836,0.693281,0.690105,0.556496,0.715062,0.556496,1.0
25254,0.69049,0.688207,0.556496,0.715062,0.556496,1.0
33672,0.688121,0.687944,0.556496,0.715062,0.556496,1.0
42090,0.687353,0.687769,0.556496,0.715062,0.556496,1.0
50508,0.686856,0.689413,0.556496,0.715062,0.556496,1.0
58926,0.686216,0.68677,0.556496,0.715062,0.556496,1.0
67344,0.6861,0.694814,0.556496,0.715062,0.556496,1.0
75762,0.685653,0.693015,0.556496,0.715062,0.556496,1.0
84180,0.685217,0.692294,0.556496,0.715062,0.556496,1.0


{'eval_loss': 0.686770498752594, 'eval_accuracy': 0.5564959168522643, 'eval_f1': 0.7150624821138987, 'eval_precision': 0.5564959168522643, 'eval_recall': 1.0, 'epoch': 9.998812210476304}


Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-e21c50a49f2c027a.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-f230ca07febc706e.arrow
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the c

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-c2fe28a15bf0f3cd.arrow





Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
8418,0.638508,0.702481,0.451967,0.0,0.0,0.0
16836,0.700544,0.69533,0.451967,0.0,0.0,0.0
25254,0.697448,0.691391,0.548033,0.708038,0.548033,1.0
33672,0.695551,0.691478,0.548033,0.708038,0.548033,1.0
42090,0.694662,0.690305,0.548033,0.708038,0.548033,1.0
50508,0.694063,0.690998,0.548033,0.708038,0.548033,1.0
58926,0.693731,0.696601,0.451967,0.0,0.0,0.0
67344,0.693531,0.693609,0.451967,0.0,0.0,0.0


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.690304696559906, 'eval_accuracy': 0.5480326651818856, 'eval_f1': 0.7080375983119125, 'eval_precision': 0.5480326651818856, 'eval_recall': 1.0, 'epoch': 7.999049768381043}
