In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [2]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [3]:
MODEL_NAMES = ['xlnet-base-cased'] # ['bert-base-uncased', 'xlnet-base-cased']

In [5]:
use_pretrain = True

for t in ['ORIG', 'INV', 'SIB']: 
    for MODEL_NAME in MODEL_NAMES:
        
        checkpoint = 'pretrained/' + MODEL_NAME
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
        dataset = load_dataset('glue', 'sst2')['train']
        dataset.rename_column_('sentence', 'text')
        dataset = dataset.train_test_split(test_size=0.1)
        train_dataset = dataset['train']
        test_dataset = dataset['test']
        
        if t == 'ORIG':
            checkpoint += '-sst2-ORIG'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'INV':
            text = npy_load("./assets/SST2/sentiment/INV/text.npy")
            label = npy_load("./assets/SST2/sentiment/INV/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-sst2-INV'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'SIB':
            text = npy_load("./assets/SST2/sentiment/SIB/text.npy")
            label = npy_load("./assets/SST2/sentiment/SIB/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-sst2-SIB'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
               
        train_dataset.shuffle()
        
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
            
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        
        train_batch_size = 8
        eval_batch_size = 8
        num_epoch = 3
        max_steps = int((len(train_dataset) * num_epoch) / train_batch_size)

        training_args = TrainingArguments(
            output_dir='./pretrained/' + MODEL_NAME + '-sst2-' + t,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=int(max_steps / 100),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=int(max_steps / 10),
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps"
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )

        trainer.train()
        out = trainer.evaluate()
        print(out)

Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




  return torch.tensor(x, **format_kwargs)


{'eval_loss': 0.15437303483486176, 'eval_accuracy': 0.9536748329621381, 'eval_f1': 0.9589581689029203, 'eval_precision': 0.9551886792452831, 'eval_recall': 0.9627575277337559}


Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-e21c50a49f2c027a.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-f230ca07febc706e.arrow
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-c2fe28a15bf0f3cd.arrow





{'eval_loss': 0.6888005137443542, 'eval_accuracy': 0.5480326651818856, 'eval_f1': 0.7080375983119125, 'eval_precision': 0.5480326651818856, 'eval_recall': 1.0}


Reusing dataset glue (C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-e21c50a49f2c027a.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-f230ca07febc706e.arrow
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\glue\sst2\1.0.0\7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4\cache-c2fe28a15bf0f3cd.arrow





{'eval_loss': 0.690304696559906, 'eval_accuracy': 0.5480326651818856, 'eval_f1': 0.7080375983119125, 'eval_precision': 0.5480326651818856, 'eval_recall': 1.0}
