In [1]:
from utils import *

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [2]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

class Trainer_w_soft_target(Trainer):
    def compute_loss(self, model, inputs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        return CEwST_loss(logits, labels)
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [3]:
MODEL_NAMES = ['bert-base-uncased'] # ['bert-base-uncased', 'xlnet-base-cased']

In [4]:
use_pretrain = False

for t in ['ORIG', 'INV', 'SIB-mix']:
    for MODEL_NAME in MODEL_NAMES:
        
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
        train_dataset, test_dataset = load_dataset('ag_news', split=['train', 'test'])
        
        if t == 'ORIG':
            checkpoint = 'pretrained/bert-base-uncased-ag_news-ORIG'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'INV':
            text = npy_load("./assets/AG_NEWS/topic/INV/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/INV/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint = 'pretrained/bert-base-uncased-ag_news-INV'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint
        if t == 'SIB-mix':
            text = npy_load("./assets/AG_NEWS/topic/SIB-mix/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/SIB-mix/label.npy")
            df = pd.DataFrame({'text': text, 'label': label.tolist()})
            df.text = df.text.astype(str)
            df.label = df.label.map(lambda y: np.array(y))
            train_dataset = Dataset.from_pandas(df)    
            checkpoint = 'pretrained/bert-base-uncased-ag_news-SIB-mix'
            if use_pretrain and os.path.exists(checkpoint):
                MODEL_NAME = checkpoint  
               
        train_dataset.shuffle()

        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4).to(device)

        # # reduce training time
        # n = 10000
        # train_dataset = Dataset.from_dict(train_dataset[:n])
                
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
        train_dataset.rename_column_('label', 'labels')
        test_dataset.rename_column_('label', 'labels')

        train_batch_size = 3
        eval_batch_size = 3
        num_epoch = 10
        max_steps = int((len(train_dataset) * num_epoch) / train_batch_size)
        training_args = TrainingArguments(
            output_dir='./pretrained/' + MODEL_NAME + '-ag_news-' + t,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=int(max_steps / 100),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=int(max_steps / 10),
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps"
        )
        
        if t == 'SIB-mix':
            trainer = Trainer_w_soft_target(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics_w_soft_target,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                data_collator=DefaultCollator(),
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )
        else: 
            trainer = Trainer(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )

        trainer.train()
        out = trainer.evaluate()
        print(out)

Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached shuffled indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-170dae68215320af.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a Bert

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-f862c39b76ce2e8a.arrow





  return torch.tensor(x, **format_kwargs)


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1000,0.687948,0.669146,0.851053,0.850284,0.86771,0.851053
2000,0.570044,0.548023,0.889737,0.89013,0.89629,0.889737
3000,0.53156,0.742537,0.862368,0.861088,0.870154,0.862368
4000,0.411257,0.557874,0.896447,0.895525,0.89808,0.896447
5000,0.339575,0.499875,0.906579,0.906591,0.908043,0.906579
6000,0.298661,0.527324,0.906447,0.906561,0.908845,0.906447
7000,0.24212,0.577542,0.903158,0.902551,0.905102,0.903158
8000,0.186214,0.543915,0.906447,0.906622,0.907528,0.906447
9000,0.144056,0.552779,0.907105,0.907226,0.907668,0.907105
10000,0.138216,0.561721,0.909079,0.909139,0.909998,0.909079


{'eval_loss': 0.4998745322227478, 'eval_accuracy': 0.906578947368421, 'eval_f1': 0.9065911926250698, 'eval_precision': 0.908043109084492, 'eval_recall': 0.906578947368421, 'epoch': 2.9994001199760048}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly iden

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-f862c39b76ce2e8a.arrow





Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1000,1.040473,0.956195,0.751184,0.75178,0.783072,0.751184
2000,0.933427,0.887286,0.6375,0.585735,0.781799,0.6375
3000,0.911823,1.527582,0.478816,0.403525,0.653639,0.478816
4000,0.884285,0.708453,0.807105,0.805916,0.821787,0.807105
5000,0.838658,0.664825,0.8175,0.816738,0.821093,0.8175
6000,0.801313,0.640657,0.825395,0.825461,0.844251,0.825395
7000,0.796616,0.713386,0.805132,0.803347,0.81758,0.805132
8000,0.717131,0.707552,0.817368,0.814853,0.827083,0.817368
9000,0.721618,0.813004,0.782368,0.781602,0.797863,0.782368
10000,0.716183,0.86974,0.766842,0.764873,0.798292,0.766842


{'eval_loss': 0.640656590461731, 'eval_accuracy': 0.8253947368421053, 'eval_f1': 0.8254610708793054, 'eval_precision': 0.8442514758331714, 'eval_recall': 0.8253947368421053, 'epoch': 2.9994001199760048}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly iden

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-f862c39b76ce2e8a.arrow





Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1000,0.946886,0.391845,0.000789,0.871247,0.87819,0.870789
2000,0.833896,0.363272,0.000789,0.880112,0.881446,0.880526
3000,0.817627,0.334284,0.000789,0.889834,0.89039,0.890395
4000,0.738126,0.351827,0.000789,0.894295,0.896684,0.893947
5000,0.715665,0.341011,0.000789,0.8985,0.89911,0.898816
6000,0.708556,0.359509,0.000789,0.884847,0.893859,0.885395
7000,0.659699,0.369993,0.000789,0.895905,0.896009,0.896316
8000,0.613615,0.378916,0.000789,0.896459,0.897417,0.896184
9000,0.617523,0.380661,0.000789,0.899301,0.9,0.899079
10000,0.603245,0.387376,0.000789,0.900952,0.901673,0.900921


{'eval_loss': 0.33428433537483215, 'eval_accuracy': 0.0007894737063907087, 'eval_f1': 0.8898335147173879, 'eval_precision': 0.8903898603624751, 'eval_recall': 0.8903947368421053, 'epoch': 2.9994001199760048}


In [5]:
EarlyStoppingCallback(early_stopping_patience=3)

transformers.trainer_callback.EarlyStoppingCallback