In [1]:
from utils import *

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, concatenate_datasets, Dataset
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [2]:
def one_hot_encode(y, nb_classes=4):
    if not isinstance(y, np.ndarray):
        y = np.expand_dims(np.array(y), 0)
    res = np.eye(nb_classes)[np.array(y).reshape(-1)]
    return res.reshape(list(y.shape)+[nb_classes])[0]

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=250)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
    }

class Trainer_w_soft_target(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        loss = CEwST_loss(logits, labels)
        if return_outputs:
            return loss, outputs
        return loss
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [3]:
from sklearn.datasets import fetch_20newsgroups

def get_20NG_test_dataset():
    cats = [
        'talk.politics.mideast',                                # Wolrd 0
        'rec.sport.hockey', 'rec.sport.baseball',               # Sports 1
        # 'misc.forsale',                                       # Business 2
        'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', # Sci/Tech 3
    ]

    dataset = fetch_20newsgroups(
        subset='all',
        categories=cats,
        remove=('headers', 'footers', 'quotes')
    )

    df = pd.DataFrame([dataset.data, dataset.target]).T
    df.rename(columns={0:'text', 1: 'label'}, inplace=True)

    mapper = {
        0: 1,
        1: 1,
        2: 3,
        3: 3,
        4: 3,
        5: 3,
        6: 0,
    }

    df.label = df.label.map(mapper)
    df.text = df.text.replace('\n', ' ', regex=True).str.strip()

    test_dataset = Dataset.from_pandas(df)
    
    return test_dataset

In [4]:
# ['ORIG', 'INV', 'SIB', 'INVSIB', 'TextMix', 'SentMix', 'WordMix']
# ['bert-base-uncased', 'roberta-base', 'xlnet-base-cased']

In [5]:
MODEL_NAMES = ['xlnet-base-cased']

In [6]:
use_pretrain = False

results = []
for MODEL_NAME in MODEL_NAMES:
    for t in ['SIB', 'INVSIB', 'TextMix', 'SentMix', 'WordMix']: 
                        
        soft_target = False
        eval_only = False
        
        checkpoint = 'pretrained/' + MODEL_NAME + "-ag_news-ORIG+" + t 
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        
        if t == 'ORIG':
            train_dataset = load_dataset('ag_news', split='train')
        else:
            # load custom data    
            text = npy_load("./assets/AG_NEWS/" + t + "/text.npy")
            label = npy_load("./assets/AG_NEWS/" + t + "/label.npy")
            if len(label.shape) > 1:
                df = pd.DataFrame({'text': text, 'label': label.tolist()})
                df.text = df.text.astype(str)
                df.label = df.label.map(lambda y: np.array(y))
            else:
                df = pd.DataFrame({'text': text, 'label': label})
                df.text = df.text.astype(str)
                df.label = df.label.astype(object)
            train_dataset = Dataset.from_pandas(df)  
            
            # load orig data
            orig_dataset = load_dataset('ag_news', split='train')
            df = orig_dataset.to_pandas()
            df = df[df.columns[::-1]]
            df.text = df.text.astype(str)
            if len(label.shape) > 1:
                df.label = df.label.map(one_hot_encode)
            else:
                df.label = df.label.astype(object)
            orig_dataset = Dataset.from_pandas(df)
            
            # merge orig + custom data
            train_dataset = concatenate_datasets([orig_dataset, train_dataset])
            train_dataset.shuffle()
            
        if use_pretrain and os.path.exists(checkpoint):
            print('loading {}...'.format(checkpoint))
            MODEL_NAME = checkpoint
            eval_only = True
                
        dataset_dict = train_dataset.train_test_split(
            test_size = 0.05,
            train_size = 0.95,
            shuffle = True
        )
        train_dataset = dataset_dict['train']
        eval_dataset = dataset_dict['test']
        test_dataset = load_dataset('ag_news')['test']
        test_dataset_20NG = get_20NG_test_dataset()
        
        # # reduce training time
        # n = 10000
        # train_dataset = Dataset.from_dict(train_dataset[:n])
        # eval_dataset = Dataset.from_dict(eval_dataset[:n])
        # test_dataset = Dataset.from_dict(test_dataset[:n])

        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4).to(device)
                
        train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset))
        eval_dataset = eval_dataset.map(tokenize, batched=True, batch_size=len(eval_dataset))
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
        test_dataset_20NG = test_dataset_20NG.map(tokenize, batched=True, batch_size=len(test_dataset_20NG))
        train_dataset.rename_column_('label', 'labels')
        eval_dataset.rename_column_('label', 'labels')
        test_dataset.rename_column_('label', 'labels')
        test_dataset_20NG.rename_column_('label', 'labels')
        train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        eval_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        test_dataset_20NG.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        
        if len(np.array(train_dataset['labels']).shape) > 1:
            soft_target = True

        train_batch_size = 3
        eval_batch_size = 32
        num_epoch = 3
        gradient_accumulation_steps=1
        max_steps = int((len(train_dataset) * num_epoch / gradient_accumulation_steps) / train_batch_size)
        
        training_args = TrainingArguments(
            seed=1,
            # adafactor=True,
            output_dir=checkpoint,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps, 
            warmup_steps=int(max_steps / 10),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=1000,
            logging_first_step=True,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            evaluation_strategy="steps",
            run_name=checkpoint,
            label_names=['World', 'Sports', 'Business', 'Sci/Tech']
        )
        
        if soft_target:
            trainer = Trainer_w_soft_target(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics_w_soft_target,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=DefaultCollator(),
                callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
            )
        else: 
            trainer = Trainer(
                model=model,
                args=training_args,
                compute_metrics=compute_metrics,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]
            )

        if not eval_only:
            trainer.train()
            
        trainer.compute_metrics = compute_metrics
        
        # test with ORIG data
        trainer.eval_dataset = test_dataset
        out_orig = trainer.evaluate()
        out_orig['run'] = checkpoint
        out_orig['test'] = "ORIG"
        print('ORIG for {}\n{}'.format(checkpoint, out_orig))
        
        # test with 20NG data
        trainer.eval_dataset = test_dataset_20NG
        out_20NG = trainer.evaluate()
        out_20NG['run'] = checkpoint
        out_orig['test'] = "20NG"
        print('20NG for {}\n{}'.format(checkpoint, out_20NG))
        
        results.append(out_orig)
        results.append(out_20NG)
        
        # run.finish()

Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model fr

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-6943bef4b721d7aa.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

rename_column_ is deprecated and will be removed in the next major version of datasets. Use the dataset.rename_column method instead.





W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1807,0.711139,0.721985,375.0217,31.998
2000,0.6802,0.625677,0.777106,374.9872,32.001
3000,0.6283,0.658409,0.767222,374.8969,32.009
4000,0.5924,0.581758,0.780381,374.905,32.008
5000,0.5978,0.584971,0.777788,375.0688,31.994
6000,0.5885,0.57143,0.763454,374.964,32.003
7000,0.5974,0.596963,0.770752,374.697,32.026
8000,0.5598,0.627166,0.76546,374.7233,32.024
9000,0.5912,0.64238,0.767717,374.7238,32.024
10000,0.5812,0.653354,0.78219,374.7313,32.023


ORIG for pretrained/xlnet-base-cased-ag_news-ORIG+SIB
{'eval_loss': 32.779727935791016, 'eval_accuracy': 0.9113157894736842, 'eval_f1': 0.9102133583293787, 'eval_precision': 0.9143852896173227, 'eval_recall': 0.9113157894736842, 'eval_runtime': 221.629, 'eval_samples_per_second': 34.292, 'epoch': 0.26, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+SIB', 'test': 'ORIG'}


  _warn_prf(average, modifier, msg_start, len(result))


20NG for pretrained/xlnet-base-cased-ag_news-ORIG+SIB
{'eval_loss': 39.73031997680664, 'eval_accuracy': 0.8778503994190269, 'eval_f1': 0.6392209736635851, 'eval_precision': 0.6340855840685231, 'eval_recall': 0.6524435347938217, 'eval_runtime': 200.8539, 'eval_samples_per_second': 34.279, 'epoch': 0.26, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+SIB'}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model fr

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-6943bef4b721d7aa.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1282,0.56722,0.79339,373.4146,32.136
2000,0.5436,0.514997,0.830276,373.465,32.132
3000,0.5935,0.495248,0.834003,373.5338,32.126
4000,0.5678,0.652225,0.828357,373.3721,32.14
5000,0.5678,0.533003,0.840727,373.4276,32.135
6000,0.5371,0.550586,0.824994,373.369,32.14
7000,0.5554,0.53412,0.825187,373.3881,32.138
8000,0.5195,0.529481,0.83936,373.3034,32.145
9000,0.5172,0.570029,0.83959,373.2869,32.147
10000,0.5345,0.550825,0.838199,373.4297,32.135


ORIG for pretrained/xlnet-base-cased-ag_news-ORIG+INVSIB
{'eval_loss': 30.559518814086914, 'eval_accuracy': 0.9113157894736842, 'eval_f1': 0.9112269335861601, 'eval_precision': 0.9114737761531938, 'eval_recall': 0.9113157894736843, 'eval_runtime': 221.1852, 'eval_samples_per_second': 34.36, 'epoch': 0.33, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+INVSIB', 'test': 'ORIG'}


  _warn_prf(average, modifier, msg_start, len(result))


20NG for pretrained/xlnet-base-cased-ag_news-ORIG+INVSIB
{'eval_loss': 36.06204605102539, 'eval_accuracy': 0.8171387073347858, 'eval_f1': 0.5921682404189916, 'eval_precision': 0.5995609616059674, 'eval_recall': 0.596491127130416, 'eval_runtime': 200.4663, 'eval_samples_per_second': 34.345, 'epoch': 0.33, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+INVSIB'}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model fr

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-6943bef4b721d7aa.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1227,0.68318,0.729552,374.9884,32.001
2000,0.6449,0.581783,0.7821,374.9414,32.005
3000,0.61,0.626079,0.749903,374.8992,32.009
4000,0.5972,0.642624,0.727962,374.7294,32.023
5000,0.5951,0.556868,0.750251,374.8372,32.014
6000,0.5646,0.603674,0.799863,374.9371,32.005
7000,0.5996,0.57066,0.749336,374.9,32.009
8000,0.5724,0.60359,0.783597,374.8962,32.009
9000,0.5711,0.596051,0.772786,374.8143,32.016
10000,0.5366,0.667459,0.76723,374.7332,32.023


ORIG for pretrained/xlnet-base-cased-ag_news-ORIG+TextMix
{'eval_loss': 30.514461517333984, 'eval_accuracy': 0.9190789473684211, 'eval_f1': 0.9191902430325183, 'eval_precision': 0.9205755163746754, 'eval_recall': 0.9190789473684211, 'eval_runtime': 221.8023, 'eval_samples_per_second': 34.265, 'epoch': 0.28, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+TextMix', 'test': 'ORIG'}


  _warn_prf(average, modifier, msg_start, len(result))


20NG for pretrained/xlnet-base-cased-ag_news-ORIG+TextMix
{'eval_loss': 35.80183410644531, 'eval_accuracy': 0.8262890341321714, 'eval_f1': 0.6344004919505909, 'eval_precision': 0.6920195362546452, 'eval_recall': 0.5888897184378247, 'eval_runtime': 201.0678, 'eval_samples_per_second': 34.242, 'epoch': 0.28, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+TextMix'}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model fr

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached processed dataset at C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-6943bef4b721d7aa.arrow





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0359,0.706221,0.56293,560.4001,32.12
2000,0.6205,0.539962,0.622516,560.3599,32.122
3000,0.5611,0.572588,0.617082,560.2891,32.126
4000,0.5151,0.498127,0.630502,560.3408,32.123
5000,0.5187,0.495807,0.662226,560.2858,32.126
6000,0.5047,0.478219,0.638549,560.3405,32.123
7000,0.5058,0.469551,0.63219,560.261,32.128
8000,0.4992,0.445969,0.652073,560.3443,32.123
9000,0.4906,0.4793,0.665047,560.306,32.125
10000,0.4654,0.456679,0.648731,560.2668,32.128


ORIG for pretrained/xlnet-base-cased-ag_news-ORIG+SentMix
{'eval_loss': 29.355998992919922, 'eval_accuracy': 0.9214473684210527, 'eval_f1': 0.921278984511142, 'eval_precision': 0.921881965752843, 'eval_recall': 0.9214473684210527, 'eval_runtime': 221.7186, 'eval_samples_per_second': 34.278, 'epoch': 0.25, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+SentMix', 'test': 'ORIG'}


  _warn_prf(average, modifier, msg_start, len(result))


20NG for pretrained/xlnet-base-cased-ag_news-ORIG+SentMix
{'eval_loss': 35.578189849853516, 'eval_accuracy': 0.820479302832244, 'eval_f1': 0.5908476252377701, 'eval_precision': 0.5861215760026348, 'eval_recall': 0.605713668969397, 'eval_runtime': 200.9597, 'eval_samples_per_second': 34.261, 'epoch': 0.25, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+SentMix'}


Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Using custom data configuration default
Reusing dataset ag_news (C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a)
Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model fr

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Loading cached processed dataset at C:\Users\sleev\.cache\huggingface\datasets\ag_news\default\0.0.0\fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a\cache-6943bef4b721d7aa.arrow


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,0.9165,0.720237,0.426913,751.0076,31.957
2000,0.5782,0.479328,0.511681,750.4582,31.98
3000,0.4706,0.467978,0.516797,750.7578,31.968
4000,0.4475,0.442142,0.526822,750.4897,31.979
5000,0.4259,0.425813,0.52568,750.6241,31.973
6000,0.4335,0.422985,0.527042,751.1154,31.952
7000,0.4245,0.43548,0.520216,750.9167,31.961
8000,0.4463,0.429101,0.525395,750.4673,31.98
9000,0.421,0.432605,0.516039,750.7799,31.967
10000,0.4026,0.429545,0.535664,750.9212,31.961


ORIG for pretrained/xlnet-base-cased-ag_news-ORIG+WordMix
{'eval_loss': 34.021446228027344, 'eval_accuracy': 0.9228947368421052, 'eval_f1': 0.9230065178085569, 'eval_precision': 0.9235508008785167, 'eval_recall': 0.9228947368421052, 'eval_runtime': 222.0923, 'eval_samples_per_second': 34.22, 'epoch': 0.18, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+WordMix', 'test': 'ORIG'}
20NG for pretrained/xlnet-base-cased-ag_news-ORIG+WordMix
{'eval_loss': 38.64665222167969, 'eval_accuracy': 0.8171387073347858, 'eval_f1': 0.6137541683533673, 'eval_precision': 0.6271330516988347, 'eval_recall': 0.6219071690781202, 'eval_runtime': 201.3232, 'eval_samples_per_second': 34.199, 'epoch': 0.18, 'run': 'pretrained/xlnet-base-cased-ag_news-ORIG+WordMix'}


  _warn_prf(average, modifier, msg_start, len(result))


In [7]:
df = pd.DataFrame(results)
df

Unnamed: 0,eval_loss,eval_accuracy,eval_f1,eval_precision,eval_recall,eval_runtime,eval_samples_per_second,epoch,run,test
0,32.779728,0.911316,0.910213,0.914385,0.911316,221.629,34.292,0.26,pretrained/xlnet-base-cased-ag_news-ORIG+SIB,20NG
1,39.73032,0.87785,0.639221,0.634086,0.652444,200.8539,34.279,0.26,pretrained/xlnet-base-cased-ag_news-ORIG+SIB,
2,30.559519,0.911316,0.911227,0.911474,0.911316,221.1852,34.36,0.33,pretrained/xlnet-base-cased-ag_news-ORIG+INVSIB,20NG
3,36.062046,0.817139,0.592168,0.599561,0.596491,200.4663,34.345,0.33,pretrained/xlnet-base-cased-ag_news-ORIG+INVSIB,
4,30.514462,0.919079,0.91919,0.920576,0.919079,221.8023,34.265,0.28,pretrained/xlnet-base-cased-ag_news-ORIG+TextMix,20NG
5,35.801834,0.826289,0.6344,0.69202,0.58889,201.0678,34.242,0.28,pretrained/xlnet-base-cased-ag_news-ORIG+TextMix,
6,29.355999,0.921447,0.921279,0.921882,0.921447,221.7186,34.278,0.25,pretrained/xlnet-base-cased-ag_news-ORIG+SentMix,20NG
7,35.57819,0.820479,0.590848,0.586122,0.605714,200.9597,34.261,0.25,pretrained/xlnet-base-cased-ag_news-ORIG+SentMix,
8,34.021446,0.922895,0.923007,0.923551,0.922895,222.0923,34.22,0.18,pretrained/xlnet-base-cased-ag_news-ORIG+WordMix,20NG
9,38.646652,0.817139,0.613754,0.627133,0.621907,201.3232,34.199,0.18,pretrained/xlnet-base-cased-ag_news-ORIG+WordMix,


In [8]:
df.to_csv('train_AG_NEWS_r2.csv')

In [9]:
df.to_clipboard(excel=True)