# Targeted SIB Training

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
    Trainer, 
    TrainingArguments, 
    TrainerCallback, 
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers.trainer_callback import TrainerControl
from datasets import load_dataset
import torch
import pandas as pd
from torch.utils.data import DataLoader
from transforms import TextMix, SentMix, WordMix

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def tokenize_fn(text):
    return tokenizer(text, padding=True, truncation=True, max_length=200, return_tensors='pt')

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=200)

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }        
        
def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
    }

class TargetedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        loss = CEwST_loss(logits, labels)
        if return_outputs:
            return loss, outputs
        return loss

class TargetedMixturesCallback(TrainerCallback):
    """
    A callback that calculates a confusion matrix on the validation
    data and returns the most confused class pairings.
    """
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device
        
    def on_evaluate(self, args, state, control, model, tokenizer, **kwargs):
        cnf_mat = self.get_confusion_matrix(model, tokenizer, self.dataloader)
        new_targets = self.get_most_confused_per_class(cnf_mat)
        print("New targets:", new_targets)
        control = TrainerControl
        control.new_targets = new_targets
        if state.global_step < state.max_steps:
            control.should_training_stop = False
        else:
            control.should_training_stop = True
        return control
        
    def get_confusion_matrix(self, model, tokenizer, dataloader, normalize=True):
        n_classes = max(dataloader.dataset['label']) + 1
        confusion_matrix = torch.zeros(n_classes, n_classes)
        with torch.no_grad():
            for batch in iter(self.dataloader):
                data, targets = batch['text'], batch['label']
                data = tokenizer(data, padding=True, truncation=True, max_length=250, return_tensors='pt')
                input_ids = data['input_ids'].to(self.device)
                attention_mask = data['attention_mask'].to(self.device)
                targets = targets.to(self.device)
                outputs = model(input_ids, attention_mask=attention_mask).logits
                preds = torch.argmax(outputs, dim=1).cpu()
                for t, p in zip(targets.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1    
            if normalize:
                confusion_matrix = confusion_matrix / confusion_matrix.sum(dim=0)
        return confusion_matrix

    def get_most_confused_per_class(self, confusion_matrix):
        idx = torch.arange(len(confusion_matrix))
        cnf = confusion_matrix.fill_diagonal_(0).max(dim=1)[1]
        return torch.stack((idx, cnf)).T.tolist()

class TargetedMixturesCollator:
    def __init__(self, tokenize_fn, transform, target_pairs=[], target_prob=1.0, num_classes=4):
        self.tokenize_fn = tokenize_fn
        self.transform = transform
        self.target_pairs = target_pairs
        self.target_prob = target_prob
        self.num_classes = num_classes
        print("TargetedMixturesCollator initialized with {}".format(transform.__class__.__name__))
        
    def __call__(self, batch):
        text = [x['text'] for x in batch]
        labels = [x['label'] for x in batch]
        batch = (text, labels)
        batch = self.transform(
            batch, 
            self.target_pairs,   
            self.target_prob,
            self.num_classes
        )
        text, labels = batch
        batch = self.tokenize_fn(text)
        batch['labels'] = torch.tensor(labels)
        return batch
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [4]:
MODEL_NAMES = ['bert-base-uncased', 'roberta-base', 'xlnet-base-cased']
ts = [TextMix(), SentMix(), WordMix()]

In [5]:
results = []

for MODEL_NAME in MODEL_NAMES:
        
    for t in ts: 
        
        t_str = t.__class__.__name__
        checkpoint = './results/' + MODEL_NAME + '-targeted-' + t_str
        
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4).to(device)

        dataset = load_dataset('ag_news', split='train') 
        dataset_dict = dataset.train_test_split(
            test_size = 0.05,
            train_size = 0.95,
            shuffle = True
        )
        train_dataset = dataset_dict['train']
        eval_dataset = dataset_dict['test']

        test_dataset = load_dataset('ag_news', split='test') 
        test_dataset.rename_column_('label', 'labels')
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        
        train_batch_size = 10
        eval_batch_size  = 32
        num_epoch = 10
        gradient_accumulation_steps = 1
        max_steps = int((len(train_dataset) * num_epoch / gradient_accumulation_steps) / train_batch_size)

        tmcb = TargetedMixturesCallback(
            dataloader=DataLoader(eval_dataset, batch_size=32),
            device=device
        )
        escb = EarlyStoppingCallback(
            early_stopping_patience=10
        )
        tmc = TargetedMixturesCollator(
            tokenize_fn=tokenize_fn, 
            transform=t,
            target_prob=0.5
        )

        training_args = TrainingArguments(\
            output_dir=checkpoint,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps, 
            warmup_steps=int(max_steps / 10),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=1000,
            logging_first_step=True,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            evaluation_strategy="steps",
            remove_unused_columns=False
        )

        trainer = TargetedTrainer(
            model=model, 
            tokenizer=tokenizer,
            args=training_args,
            compute_metrics=compute_metrics_w_soft_target,                  
            train_dataset=train_dataset,         
            eval_dataset=eval_dataset,
            data_collator=tmc,
            callbacks=[tmcb, escb]
        )

        trainer.train()

        # test with ORIG data
        trainer.eval_dataset = test_dataset
        trainer.compute_metrics = compute_metrics
        trainer.data_collator = DefaultCollator()
        trainer.remove_callback(tmcb)

        out_orig = trainer.evaluate()
        out_orig['run'] = checkpoint
        out_orig['test'] = "ORIG"
        print('ORIG for {}\n{}'.format(checkpoint, out_orig))

        results.append(out_orig)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


TargetedMixturesCollator initialized with TextMix


W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.123,0.818767,0.672051,41.7911,143.571
2000,0.8106,0.782954,0.66707,41.7731,143.633
3000,0.7665,0.783479,0.642045,41.9354,143.077
4000,0.7582,0.770588,0.658484,41.6916,143.914
5000,0.7659,0.763694,0.683334,41.7189,143.82
6000,0.7462,0.76827,0.73812,41.6282,144.133
7000,0.7639,0.735198,0.705449,41.824,143.458
8000,0.757,0.747112,0.68716,41.8891,143.236
9000,0.7499,0.75671,0.742128,41.748,143.72
10000,0.7412,0.789278,0.759883,41.5881,144.272


New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 

ORIG for ./results/bert-base-uncased-targeted-TextMix
{'eval_loss': 24.480539321899414, 'eval_accuracy': 0.9343421052631579, 'eval_f1': 0.9342633670695392, 'eval_precision': 0.9342082629365714, 'eval_recall': 0.934342105263158, 'eval_runtime': 61.8223, 'eval_samples_per_second': 122.933, 'epoch': 2.81, 'run': './results/bert-base-uncased-targeted-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1418,0.819791,0.674739,41.7843,143.595
2000,0.8194,0.804909,0.634521,41.1726,145.728
3000,0.7722,0.771695,0.671235,41.3916,144.957
4000,0.7646,0.769283,0.698796,41.136,145.858
5000,0.7523,0.806216,0.6578,41.1704,145.736
6000,0.7562,0.798659,0.720266,41.5179,144.516
7000,0.7627,0.782367,0.677025,40.9621,146.477
8000,0.7566,0.779678,0.682811,41.4263,144.836
9000,0.7717,0.775168,0.681278,41.4624,144.709
10000,0.753,0.767761,0.690635,41.5904,144.264


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 

ORIG for ./results/bert-base-uncased-targeted-SentMix
{'eval_loss': 26.8177547454834, 'eval_accuracy': 0.9328947368421052, 'eval_f1': 0.9329163788471664, 'eval_precision': 0.9337170333495173, 'eval_recall': 0.9328947368421052, 'eval_runtime': 61.8966, 'eval_samples_per_second': 122.785, 'epoch': 5.26, 'run': './results/bert-base-uncased-targeted-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1034,0.864784,0.640724,41.0237,146.257
2000,0.8587,0.840726,0.605792,41.2965,145.291
3000,0.8302,0.817732,0.564329,40.7736,147.154
4000,0.7984,0.800068,0.624055,41.1956,145.647
5000,0.8016,0.819686,0.550635,41.339,145.141
6000,0.7975,0.8165,0.585087,41.2604,145.418
7000,0.7947,0.826997,0.556226,41.3579,145.075
8000,0.7808,0.819812,0.613131,41.1618,145.766
9000,0.7837,0.835647,0.54845,41.1876,145.675
10000,0.8023,0.830718,0.530839,41.1669,145.748


New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 1], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/bert-base-uncased-targeted-WordMix
{'eval_loss': 20.79989242553711, 'eval_accuracy': 0.91, 'eval_f1': 0.9096879772493497, 'eval_precision': 0.9113614704353348, 'eval_recall': 0.9100000000000001, 'eval_runtime': 61.8091, 'eval_samples_per_second': 122.959, 'epoch': 2.11, 'run': './results/bert-base-uncased-targeted-WordMix', 'test': 'ORIG'}


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with TextMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0408,0.755991,0.689623,41.1334,145.867
2000,0.7815,0.783156,0.679361,40.8634,146.831
3000,0.7655,0.784874,0.705539,40.8656,146.823
4000,0.7631,0.848992,0.698786,40.804,147.044
5000,0.7625,0.805542,0.689215,41.0932,146.01
6000,0.7711,0.787268,0.678828,40.8594,146.845
7000,0.769,0.755123,0.724273,40.8671,146.817
8000,0.7541,0.791531,0.71621,40.5646,147.912
9000,0.7561,0.76526,0.686166,40.8262,146.964
10000,0.7593,0.797582,0.618166,40.9123,146.655


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/roberta-base-targeted-TextMix
{'eval_loss': 23.608760833740234, 'eval_accuracy': 0.93, 'eval_f1': 0.9302832264201846, 'eval_precision': 0.9312619021637497, 'eval_recall': 0.9299999999999999, 'eval_runtime': 60.8437, 'eval_samples_per_second': 124.91, 'epoch': 2.28, 'run': './results/roberta-base-targeted-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.059,0.743698,0.693887,41.0281,146.241
2000,0.7812,0.79704,0.69794,41.0994,145.988
3000,0.764,0.756848,0.678766,41.4003,144.927
4000,0.7847,0.768974,0.666649,40.8586,146.848
5000,0.7896,0.759688,0.711485,40.9844,146.397
6000,0.7602,0.75266,0.738283,41.3042,145.264
7000,0.7547,0.73892,0.734301,41.2597,145.42
8000,0.7515,0.773716,0.745563,41.0924,146.012
9000,0.7582,0.763181,0.714846,41.1865,145.679
10000,0.755,0.744711,0.678383,41.265,145.402


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/roberta-base-targeted-SentMix
{'eval_loss': 23.41080665588379, 'eval_accuracy': 0.9351315789473684, 'eval_f1': 0.9352812099451323, 'eval_precision': 0.9360632134689295, 'eval_recall': 0.9351315789473684, 'eval_runtime': 60.5309, 'eval_samples_per_second': 125.556, 'epoch': 3.16, 'run': './results/roberta-base-targeted-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0706,0.829852,0.623684,39.9854,150.055
2000,0.8491,0.828516,0.582391,40.4167,148.453
3000,0.8067,0.828665,0.590196,40.0901,149.663
4000,0.8114,0.824903,0.547974,40.5425,147.993
5000,0.8172,0.806744,0.591967,40.4965,148.161
6000,0.8158,0.805406,0.595121,40.5499,147.966
7000,0.8058,0.781646,0.614757,40.5468,147.977
8000,0.794,0.837159,0.586261,40.4141,148.463
9000,0.82,0.835895,0.570278,40.4717,148.252
10000,0.8001,0.789346,0.556848,40.3408,148.733


New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 

ORIG for ./results/roberta-base-targeted-WordMix
{'eval_loss': 21.083614349365234, 'eval_accuracy': 0.9173684210526316, 'eval_f1': 0.9174963589172196, 'eval_precision': 0.9179184898911081, 'eval_recall': 0.9173684210526316, 'eval_runtime': 61.4087, 'eval_samples_per_second': 123.761, 'epoch': 3.51, 'run': './results/roberta-base-targeted-WordMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with TextMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.06,0.799566,0.679052,94.0247,63.813
2000,0.8224,0.787916,0.666406,93.3537,64.272
3000,0.7788,0.826211,0.632343,92.8339,64.632
4000,0.783,0.773624,0.645425,93.3615,64.266
5000,0.76,0.799342,0.633472,94.5114,63.484
6000,0.7628,0.783223,0.636232,94.0316,63.808
7000,0.7441,0.791242,0.666978,93.0672,64.47
8000,0.7445,0.804915,0.650863,92.6554,64.756
9000,0.756,0.768611,0.759018,93.4813,64.184
10000,0.7578,0.777972,0.623639,92.7553,64.686


New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/xlnet-base-cased-targeted-TextMix
{'eval_loss': 29.766590118408203, 'eval_accuracy': 0.9384210526315789, 'eval_f1': 0.9382906761700585, 'eval_precision': 0.9386953278978558, 'eval_recall': 0.9384210526315789, 'eval_runtime': 138.3363, 'eval_samples_per_second': 54.939, 'epoch': 8.07, 'run': './results/xlnet-base-cased-targeted-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.06,0.814202,0.658599,94.9834,63.169
2000,0.8243,0.783878,0.653501,93.9369,63.873
3000,0.7774,0.789554,0.652811,94.666,63.381
4000,0.7628,0.775472,0.649541,95.1323,63.07
5000,0.7555,0.785967,0.690353,94.6922,63.363
6000,0.7595,0.774028,0.720262,94.132,63.74
7000,0.7607,0.749642,0.709219,93.5248,64.154
8000,0.761,0.791381,0.697819,94.5297,63.472
9000,0.759,0.752182,0.659591,93.8044,63.963
10000,0.7636,0.771678,0.653735,94.1296,63.742


New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 

ORIG for ./results/xlnet-base-cased-targeted-SentMix
{'eval_loss': 22.84684181213379, 'eval_accuracy': 0.9336842105263158, 'eval_f1': 0.9337134499138177, 'eval_precision': 0.9341871089752022, 'eval_recall': 0.9336842105263158, 'eval_runtime': 138.3086, 'eval_samples_per_second': 54.95, 'epoch': 2.46, 'run': './results/xlnet-base-cased-targeted-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0758,0.865502,0.628614,91.0053,65.93
2000,0.8585,0.839501,0.589337,91.0024,65.932
3000,0.8224,0.848991,0.554825,90.7884,66.088
4000,0.811,0.815506,0.597819,91.1643,65.815
5000,0.8,0.84352,0.603635,92.3822,64.948
6000,0.7875,0.80574,0.604512,91.6362,65.476
7000,0.7951,0.798661,0.56932,92.5207,64.85
8000,0.788,0.83862,0.586105,92.1141,65.137
9000,0.8094,0.804375,0.59434,91.7728,65.379
10000,0.8066,0.81299,0.592974,91.5195,65.56


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/xlnet-base-cased-targeted-WordMix
{'eval_loss': 19.878807067871094, 'eval_accuracy': 0.9064473684210527, 'eval_f1': 0.9063317617010482, 'eval_precision': 0.9085388074682332, 'eval_recall': 0.9064473684210527, 'eval_runtime': 138.3666, 'eval_samples_per_second': 54.927, 'epoch': 2.28, 'run': './results/xlnet-base-cased-targeted-WordMix', 'test': 'ORIG'}


In [6]:
df = pd.DataFrame(results)
df                                                

Unnamed: 0,eval_loss,eval_accuracy,eval_f1,eval_precision,eval_recall,eval_runtime,eval_samples_per_second,epoch,run,test
0,24.480539,0.934342,0.934263,0.934208,0.934342,61.8223,122.933,2.81,./results/bert-base-uncased-targeted-TextMix,ORIG
1,26.817755,0.932895,0.932916,0.933717,0.932895,61.8966,122.785,5.26,./results/bert-base-uncased-targeted-SentMix,ORIG
2,20.799892,0.91,0.909688,0.911361,0.91,61.8091,122.959,2.11,./results/bert-base-uncased-targeted-WordMix,ORIG
3,23.608761,0.93,0.930283,0.931262,0.93,60.8437,124.91,2.28,./results/roberta-base-targeted-TextMix,ORIG
4,23.410807,0.935132,0.935281,0.936063,0.935132,60.5309,125.556,3.16,./results/roberta-base-targeted-SentMix,ORIG
5,21.083614,0.917368,0.917496,0.917918,0.917368,61.4087,123.761,3.51,./results/roberta-base-targeted-WordMix,ORIG
6,29.76659,0.938421,0.938291,0.938695,0.938421,138.3363,54.939,8.07,./results/xlnet-base-cased-targeted-TextMix,ORIG
7,22.846842,0.933684,0.933713,0.934187,0.933684,138.3086,54.95,2.46,./results/xlnet-base-cased-targeted-SentMix,ORIG
8,19.878807,0.906447,0.906332,0.908539,0.906447,138.3666,54.927,2.28,./results/xlnet-base-cased-targeted-WordMix,ORIG


In [7]:
df.to_csv('train_AG_NEWS_targeted_r1.csv')

In [8]:
df.to_clipboard(excel=True)

In [9]:
# ORIG for ./results/bert-base-uncased-targeted-TextMix
# {'eval_loss': 31.2364559173584, 'eval_accuracy': 0.9381578947368421, 'eval_f1': 0.9381945850526017, 'eval_precision': 0.938240633851668, 'eval_recall': 0.9381578947368421, 'eval_runtime': 117.2622, 'eval_samples_per_second': 64.812, 'epoch': 5.0, 'run': 'TextMix', 'test': 'ORIG'}