In [1]:
from utils import *

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [2]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=250)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
    }

class Trainer_w_soft_target(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        loss = CEwST_loss(logits, labels)
        if return_outputs:
            return loss, outputs
        return loss
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [3]:
MODEL_NAMES = ['bert-base-uncased', 'xlnet-base-cased']

In [4]:
use_pretrain = False
soft_target = False

for t in ['INVSIB']:
    for MODEL_NAME in MODEL_NAMES:
        
        checkpoint = 'pretrained/' + MODEL_NAME
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
        train_dataset, test_dataset = load_dataset('ag_news', split=['train', 'test'])
        
        if t == 'ORIG':
            checkpoint += '-ag_news-ORIG'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'INV':
            text = npy_load("./assets/AG_NEWS/topic/INV/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/INV/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-ag_news-INV'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'SIB-mix':
            text = npy_load("./assets/AG_NEWS/topic/SIB-mix/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/SIB-mix/label.npy")
            df = pd.DataFrame({'text': text, 'label': label.tolist()})
            df.text = df.text.astype(str)
            df.label = df.label.map(lambda y: np.array(y))
            train_dataset = Dataset.from_pandas(df)    
            checkpoint += '-ag_news-SIB-mix'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint  
        if t == 'INVSIB':
            text = npy_load("./assets/AG_NEWS/topic/INVSIB/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/INVSIB/label.npy")
            df = pd.DataFrame({'text': text, 'label': label.tolist()})
            df.text = df.text.astype(str)
            df.label = df.label.map(lambda y: np.array(y))
            train_dataset = Dataset.from_pandas(df)    
            checkpoint += '-ag_news-INVSIB'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint 
               
        train_dataset.shuffle()

        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4).to(device)

        # # reduce training time
        # n = 10000
        # train_dataset = Dataset.from_dict(train_dataset[:n])
                
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        train_dataset.rename_column_('label', 'labels')
        test_dataset.rename_column_('label', 'labels')
        
        if len(np.array(train_dataset['labels']).shape) > 1:
            soft_target = True

        train_batch_size = 3
        eval_batch_size = 3
        num_epoch = 10
        max_steps = int((len(train_dataset) * num_epoch) / train_batch_size)
        
        training_args = TrainingArguments(
            output_dir='./pretrained/' + MODEL_NAME + '-ag_news-' + t,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=int(max_steps / 100),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=int(max_steps / 10),
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps"
        )
        
        if soft_target:
            trainer = Trainer_w_soft_target(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics_w_soft_target,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                data_collator=DefaultCollator(),
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )
        else: 
            trainer = Trainer(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )

        trainer.train()
        out = trainer.evaluate()
        print(out)

Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly iden

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-85de0f35a38cd764.arrow
  return torch.tensor(x, **format_kwargs)


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
40000,0.685116,0.361895,0.000789,0.901296,0.901646,0.901184
80000,1.022712,1.389488,0.000395,0.1,0.0625,0.25
120000,1.390764,1.388991,0.000395,0.1,0.0625,0.25
160000,1.38697,1.387249,0.000395,0.1,0.0625,0.25


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.36189547181129456, 'eval_accuracy': 0.0007894737063907087, 'eval_f1': 0.9012957548842362, 'eval_precision': 0.9016463408425521, 'eval_recall': 0.9011842105263158, 'epoch': 4.0}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-573c1426115d7013.arrow


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
40000,1.1514,1.397956,0.0,0.099958,0.062475,0.249868
80000,1.397819,1.389904,0.000395,0.1,0.0625,0.25
120000,1.39593,1.391146,0.000395,0.1,0.0625,0.25
160000,1.394065,1.389629,0.000395,0.1,0.0625,0.25
200000,1.39235,1.387429,0.000395,0.1,0.0625,0.25
240000,1.391227,1.38935,0.000395,0.1,0.0625,0.25
280000,1.38981,1.390773,0.000395,0.1,0.0625,0.25
320000,1.38905,1.386396,0.000395,0.1,0.0625,0.25
360000,1.387786,1.387384,0.000395,0.1,0.0625,0.25
400000,1.38712,1.387332,0.000395,0.1,0.0625,0.25


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 1.3863961696624756, 'eval_accuracy': 0.00039473685319535434, 'eval_f1': 0.1, 'eval_precision': 0.0625, 'eval_recall': 0.25, 'epoch': 10.0}


  _warn_prf(average, modifier, msg_start, len(result))
