In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from utils import *

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [3]:
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=250)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
    }

class Trainer_w_soft_target(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        loss = CEwST_loss(logits, labels)
        if return_outputs:
            return loss, outputs
        return loss
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [4]:
MODEL_NAMES = ['bert-base-uncased'] #, 'xlnet-base-cased']

In [5]:
from sklearn.datasets import fetch_20newsgroups

def get_20NG_test_dataset():
    cats = [
        'talk.politics.mideast',                                # Wolrd 0
        'rec.sport.hockey', 'rec.sport.baseball',               # Sports 1
        # 'misc.forsale',                                       # Business 2
        'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', # Sci/Tech 3
    ]

    dataset = fetch_20newsgroups(
        subset='all',
        categories=cats,
        remove=('headers', 'footers', 'quotes')
    )

    df = pd.DataFrame([dataset.data, dataset.target]).T
    df.rename(columns={0:'text', 1: 'label'}, inplace=True)

    mapper = {
        0: 1,
        1: 1,
        2: 3,
        3: 3,
        4: 3,
        5: 3,
        6: 0,
    }

    df.label = df.label.map(mapper)
    df.text = df.text.replace('\n', ' ', regex=True).str.strip()

    test_dataset = Dataset.from_pandas(df)
    
    return test_dataset

In [6]:
use_pretrain = False
soft_target = False

for t in ['ORIG', 'INV', 'SIB-mix', 'INVSIB']:
    for MODEL_NAME in MODEL_NAMES:
        
        eval_only = False
        
        checkpoint = 'pretrained/' + MODEL_NAME
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)    
        
        if t == 'ORIG':
            train_dataset = load_dataset('ag_news')['train']
            checkpoint += '-ag_news-ORIG'
            if use_pretrain and os.path.exists(checkpoint):
                print('loading {}...'.format(checkpoint))
                MODEL_NAME = checkpoint
                eval_only = True
        if t == 'INV':
            text = npy_load("./assets/AG_NEWS/topic/INV/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/INV/label.npy")
            df = pd.DataFrame({'text': text, 'label': label})
            df.text = df.text.astype(str)
            df.label = df.label.astype(int)
            train_dataset = Dataset.from_pandas(df)
            checkpoint += '-ag_news-INV'
            if use_pretrain and os.path.exists(checkpoint):
                print('loading {}...'.format(checkpoint))
                MODEL_NAME = checkpoint
                eval_only = True
        if t == 'SIB-mix':
            text = npy_load("./assets/AG_NEWS/topic/SIB-mix/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/SIB-mix/label.npy")
            df = pd.DataFrame({'text': text, 'label': label.tolist()})
            df.text = df.text.astype(str)
            df.label = df.label.map(lambda y: np.array(y))
            train_dataset = Dataset.from_pandas(df)    
            checkpoint += '-ag_news-SIB-mix'
            if use_pretrain and os.path.exists(checkpoint):
                print('loading {}...'.format(checkpoint))
                MODEL_NAME = checkpoint  
                eval_only = True
        if t == 'INVSIB':
            text = npy_load("./assets/AG_NEWS/topic/INVSIB/text.npy")
            label = npy_load("./assets/AG_NEWS/topic/INVSIB/label.npy")
            df = pd.DataFrame({'text': text, 'label': label.tolist()})
            df.text = df.text.astype(str)
            df.label = df.label.map(lambda y: np.array(y))
            train_dataset = Dataset.from_pandas(df)    
            checkpoint += '-ag_news-INVSIB'
            if use_pretrain and os.path.exists(checkpoint):
                print('loading {}...'.format(checkpoint))
                MODEL_NAME = checkpoint 
                eval_only = True
                
        dataset_dict = train_dataset.train_test_split(
            test_size = 0.1,
            train_size = 0.9,
            shuffle = True
        )
        train_dataset = dataset_dict['train']
        eval_dataset = dataset_dict['test']
        test_dataset = load_dataset('ag_news')['test']
        test_dataset_20NG = get_20NG_test_dataset()
        
        # # reduce training time
        # n = 100
        # train_dataset = Dataset.from_dict(train_dataset[:n])
        # eval_dataset = Dataset.from_dict(eval_dataset[:n])
        # test_dataset = Dataset.from_dict(test_dataset[:n])

        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4).to(device)
                
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        eval_dataset = eval_dataset.map(tokenize, batched=True, batch_size=len(eval_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
        test_dataset_20NG = test_dataset_20NG.map(tokenize, batched=True, batch_size=len(test_dataset_20NG))
        train_dataset.rename_column_('label', 'labels')
        eval_dataset.rename_column_('label', 'labels')
        test_dataset.rename_column_('label', 'labels')
        test_dataset_20NG.rename_column_('label', 'labels')
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        eval_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        test_dataset_20NG.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        
        if len(np.array(train_dataset['labels']).shape) > 1:
            soft_target = True

        train_batch_size = 3
        eval_batch_size = 32
        num_epoch = 3
        max_steps = int((len(train_dataset) * num_epoch) / train_batch_size)
        
        training_args = TrainingArguments(
            output_dir=checkpoint,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            warmup_steps=int(max_steps / 10),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=int(max_steps / 10),
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps",
            label_names=['World', 'Sports', 'Business', 'Sci/Tech']
        )
        
        if soft_target:
            trainer = Trainer_w_soft_target(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics_w_soft_target,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=DefaultCollator(),
                callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
            )
        else: 
            trainer = Trainer(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
            )

        if not eval_only:
            trainer.train()
            
        trainer.compute_metrics = compute_metrics
        
        # test with ORIG data
        trainer.eval_dataset = test_dataset
        out_orig = trainer.evaluate()
        print('ORIG \n', out_orig)
        
        # test with 20NG data
        trainer.eval_dataset = test_dataset_20NG
        out_20NG = trainer.evaluate()
        print('20NG \n', out_20NG)

Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Loading cached split indices for dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-d83824437fa96733.arrow and C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-19c3de4609a5743e.arrow
Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.p

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




  return torch.tensor(x, **format_kwargs)


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,Runtime,Samples Per Second
10800,0.517,0.408373,0.9075,0.907583,0.908603,0.9075,116.6641,65.144
21600,0.6899,1.391227,0.25,0.1,0.0625,0.25,113.7611,66.807
32400,1.405,1.39738,0.25,0.1,0.0625,0.25,114.0158,66.657
43200,1.4037,1.387325,0.25,0.1,0.0625,0.25,114.1798,66.562
54000,1.4006,1.39004,0.25,0.1,0.0625,0.25,114.4282,66.417
64800,1.3993,1.397552,0.25,0.1,0.0625,0.25,115.8551,65.599
75600,1.3978,1.387144,0.25,0.1,0.0625,0.25,115.2058,65.969
86400,1.3968,1.38922,0.25,0.1,0.0625,0.25,115.2679,65.933
97200,1.3947,1.388414,0.25,0.1,0.0625,0.25,115.1866,65.98
108000,1.3932,1.386345,0.25,0.1,0.0625,0.25,114.0903,66.614


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


ORIG 
 {'eval_loss': 0.40837329626083374, 'eval_accuracy': 0.9075, 'eval_f1': 0.9075834882269251, 'eval_precision': 0.9086028296035125, 'eval_recall': 0.9075, 'eval_runtime': 115.9768, 'eval_samples_per_second': 65.53, 'epoch': 3.0}


  _warn_prf(average, modifier, msg_start, len(result))


20NG 
 {'eval_loss': 0.8831258416175842, 'eval_accuracy': 0.8366013071895425, 'eval_f1': 0.5896020623758839, 'eval_precision': 0.614011064224578, 'eval_recall': 0.5733620441756762, 'eval_runtime': 105.1776, 'eval_samples_per_second': 65.461, 'epoch': 3.0}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly iden

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-4a398149fbe9330d.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,Runtime,Samples Per Second
10800,0.6412,0.528157,0.903684,0.903464,0.905755,0.903684,115.4384,65.836
21600,0.7082,0.52979,0.889474,0.889066,0.89262,0.889474,114.2746,66.506
32400,1.2517,1.388617,0.25,0.1,0.0625,0.25,113.5283,66.944
43200,1.3909,1.386972,0.25,0.1,0.0625,0.25,113.6785,66.855
54000,1.3894,1.388428,0.25,0.1,0.0625,0.25,113.6307,66.883
64800,1.3887,1.389081,0.25,0.1,0.0625,0.25,113.6401,66.878
75600,1.3882,1.387259,0.25,0.1,0.0625,0.25,113.3475,67.05
86400,1.3875,1.386879,0.25,0.1,0.0625,0.25,113.958,66.691
97200,1.3871,1.386346,0.25,0.1,0.0625,0.25,113.2582,67.103
108000,1.3867,1.386322,0.25,0.1,0.0625,0.25,113.888,66.732


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


ORIG 
 {'eval_loss': 0.5281568169593811, 'eval_accuracy': 0.9036842105263158, 'eval_f1': 0.9034637749475052, 'eval_precision': 0.9057550005298666, 'eval_recall': 0.9036842105263158, 'eval_runtime': 115.5963, 'eval_samples_per_second': 65.746, 'epoch': 3.0}


  _warn_prf(average, modifier, msg_start, len(result))


20NG 
 {'eval_loss': 0.6143170595169067, 'eval_accuracy': 0.8313725490196079, 'eval_f1': 0.632582006099234, 'eval_precision': 0.6509007977911079, 'eval_recall': 0.6198595426669283, 'eval_runtime': 105.0463, 'eval_samples_per_second': 65.543, 'epoch': 3.0}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly iden

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-4a398149fbe9330d.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
10800,0.845,0.790434,0.59531,182.5256,65.744
21600,0.8044,0.762829,0.674607,182.5442,65.738
32400,0.8217,0.770077,0.663257,182.4258,65.78
43200,0.83,0.808683,0.665396,182.1084,65.895
54000,0.7524,0.783045,0.656759,182.525,65.744
64800,0.7329,0.746983,0.694861,181.5481,66.098
75600,0.7149,0.763103,0.692742,181.6046,66.078
86400,0.6924,0.744872,0.720389,181.61,66.076
97200,0.6752,0.736482,0.737491,182.0156,65.928
108000,0.6647,0.725972,0.743105,182.3042,65.824


ORIG 
 {'eval_loss': 26.19097328186035, 'eval_accuracy': 0.9248684210526316, 'eval_f1': 0.9247582743943782, 'eval_precision': 0.9246898441895176, 'eval_recall': 0.9248684210526316, 'eval_runtime': 115.0765, 'eval_samples_per_second': 66.043, 'epoch': 3.0}


  _warn_prf(average, modifier, msg_start, len(result))


20NG 
 {'eval_loss': 30.680707931518555, 'eval_accuracy': 0.818881626724764, 'eval_f1': 0.579928560471963, 'eval_precision': 0.6123676733268875, 'eval_recall': 0.5599780863384063, 'eval_runtime': 104.7146, 'eval_samples_per_second': 65.75, 'epoch': 3.0}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly iden

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\Fabrice\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-4a398149fbe9330d.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
10800,0.7086,0.846066,0.740012,182.0185,65.927
21600,0.6669,0.617437,0.764425,181.5905,66.083
32400,0.6418,0.642007,0.752128,182.1848,65.867
43200,0.59,0.621222,0.774508,183.0753,65.547
54000,0.5571,0.592399,0.785612,182.8125,65.641
64800,0.536,0.614263,0.7882,182.8392,65.631
75600,0.5021,0.608628,0.797116,183.4277,65.421
86400,0.4513,0.571547,0.803165,182.5303,65.743
97200,0.4369,0.544704,0.82289,182.1391,65.884
108000,0.4313,0.538101,0.825614,185.1445,64.814


ORIG 
 {'eval_loss': 32.375675201416016, 'eval_accuracy': 0.9225, 'eval_f1': 0.9224934067850175, 'eval_precision': 0.9226543311502808, 'eval_recall': 0.9225000000000001, 'eval_runtime': 115.7574, 'eval_samples_per_second': 65.655, 'epoch': 3.0}
20NG 
 {'eval_loss': 35.82275390625, 'eval_accuracy': 0.7785039941902687, 'eval_f1': 0.5826307778180877, 'eval_precision': 0.5987674268764859, 'eval_recall': 0.5690547589694692, 'eval_runtime': 106.3557, 'eval_samples_per_second': 64.736, 'epoch': 3.0}


  _warn_prf(average, modifier, msg_start, len(result))
