# Targeted SIB Training

In [1]:
%reload_ext autoreload
%autoreload 2

In [15]:
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
    Trainer, 
    TrainingArguments, 
    TrainerCallback, 
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers.trainer_callback import TrainerControl
from datasets import load_dataset
import torch
import pandas as pd
from torch.utils.data import DataLoader
from transforms import TextMix, SentMix, WordMix

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def tokenize_fn(text):
    return tokenizer(text, padding=True, truncation=True, max_length=200, return_tensors='pt')

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=200)

def acc_at_k(y_true, y_pred, k=2):
    y_true = torch.tensor(y_true) if type(y_true) != torch.Tensor else y_true
    y_pred = torch.tensor(y_pred) if type(y_pred) != torch.Tensor else y_pred
    total = len(y_true)
    y_weights, y_idx = torch.topk(y_true, k=k, dim=-1)
    out_weights, out_idx = torch.topk(y_pred, k=k, dim=-1)
    correct = torch.sum(torch.eq(y_idx, out_idx) * y_weights)
    acc = correct / total
    return acc.item()

def CEwST_loss(logits, target, reduction='mean'):
    """
    Cross Entropy with Soft Target (CEwST) Loss
    :param logits: (batch, *)
    :param target: (batch, *) same shape as logits, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(logits.view(logits.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None)
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1.mean(),
        'precision': precision.mean(),
        'recall': recall.mean()
    }        
        
def compute_metrics_w_soft_target(pred):
    labels = pred.label_ids
    preds = pred.predictions
    acc = acc_at_k(labels, preds, k=2)
    return {
        'accuracy': acc,
    }

class TargetedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs[0]
        loss = CEwST_loss(logits, labels)
        if return_outputs:
            return loss, outputs
        return loss

class TargetedMixturesCallback(TrainerCallback):
    """
    A callback that calculates a confusion matrix on the validation
    data and returns the most confused class pairings.
    """
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device
        
    def on_evaluate(self, args, state, control, model, tokenizer, **kwargs):
        cnf_mat = self.get_confusion_matrix(model, tokenizer, self.dataloader)
        new_targets = self.get_most_confused_per_class(cnf_mat)
        print("New targets:", new_targets)
        control = TrainerControl
        control.new_targets = new_targets
        if state.global_step < state.max_steps:
            control.should_training_stop = False
        else:
            control.should_training_stop = True
        return control
        
    def get_confusion_matrix(self, model, tokenizer, dataloader, normalize=True):
        n_classes = max(dataloader.dataset['label']) + 1
        confusion_matrix = torch.zeros(n_classes, n_classes)
        with torch.no_grad():
            for batch in iter(self.dataloader):
                data, targets = batch['text'], batch['label']
                data = tokenizer(data, padding=True, truncation=True, max_length=250, return_tensors='pt')
                input_ids = data['input_ids'].to(self.device)
                attention_mask = data['attention_mask'].to(self.device)
                targets = targets.to(self.device)
                outputs = model(input_ids, attention_mask=attention_mask).logits
                preds = torch.argmax(outputs, dim=1).cpu()
                for t, p in zip(targets.view(-1), preds.view(-1)):
                    confusion_matrix[t.long(), p.long()] += 1    
            if normalize:
                confusion_matrix = confusion_matrix / confusion_matrix.sum(dim=0)
        return confusion_matrix

    def get_most_confused_per_class(self, confusion_matrix):
        idx = torch.arange(len(confusion_matrix))
        cnf = confusion_matrix.fill_diagonal_(0).max(dim=1)[1]
        return torch.stack((idx, cnf)).T.tolist()

class TargetedMixturesCollator:
    def __init__(self, tokenize_fn, transform, target_pairs=[], target_prob=1.0, num_classes=4):
        self.tokenize_fn = tokenize_fn
        self.transform = transform
        self.target_pairs = target_pairs
        self.target_prob = target_prob
        self.num_classes = num_classes
        print("TargetedMixturesCollator initialized with {}".format(transform.__class__.__name__))
        
    def __call__(self, batch):
        text = [x['text'] for x in batch]
        labels = [x['label'] for x in batch]
        batch = (text, labels)
        batch = self.transform(
            batch, 
            self.target_pairs,   
            self.target_prob,
            self.num_classes
        )
        text, labels = batch
        batch = self.tokenize_fn(text)
        batch['labels'] = torch.tensor(labels)
        return batch
    
class DefaultCollator:
    def __init__(self):
        pass
    def __call__(self, batch):
        return torch.utils.data.dataloader.default_collate(batch)

In [4]:
MODEL_NAMES = ['bert-base-uncased', 'roberta-base', 'xlnet-base-cased']
ts = [TextMix(), SentMix(), WordMix()]

In [5]:
results = []

for MODEL_NAME in MODEL_NAMES:
        
    for t in ts: 
        
        t_str = t.__class__.__name__
        checkpoint = './results/' + MODEL_NAME + '-targeted-' + t_str
        
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4).to(device)

        dataset = load_dataset('ag_news', split='train') 
        dataset_dict = dataset.train_test_split(
            test_size = 0.05,
            train_size = 0.95,
            shuffle = True
        )
        train_dataset = dataset_dict['train']
        eval_dataset = dataset_dict['test']

        test_dataset = load_dataset('ag_news', split='test') 
        test_dataset.rename_column_('label', 'labels')
        test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
        test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
        
        train_batch_size = 10
        eval_batch_size  = 32
        num_epoch = 10
        gradient_accumulation_steps = 1
        max_steps = int((len(train_dataset) * num_epoch / gradient_accumulation_steps) / train_batch_size)

        tmcb = TargetedMixturesCallback(
            dataloader=DataLoader(eval_dataset, batch_size=32),
            device=device
        )
        escb = EarlyStoppingCallback(
            early_stopping_patience=10
        )
        tmc = TargetedMixturesCollator(
            tokenize_fn=tokenize_fn, 
            transform=t,
            target_prob=0.5
        )

        training_args = TrainingArguments(\
            output_dir=checkpoint,
            overwrite_output_dir=True,
            max_steps=max_steps,
            save_steps=int(max_steps / 10),
            save_total_limit=1,
            per_device_train_batch_size=train_batch_size,
            per_device_eval_batch_size=eval_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps, 
            warmup_steps=int(max_steps / 10),
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=1000,
            logging_first_step=True,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            greater_is_better=True,
            evaluation_strategy="steps",
            remove_unused_columns=False
        )

        trainer = TargetedTrainer(
            model=model, 
            tokenizer=tokenizer,
            args=training_args,
            compute_metrics=compute_metrics_w_soft_target,                  
            train_dataset=train_dataset,         
            eval_dataset=eval_dataset,
            data_collator=tmc,
            callbacks=[tmcb, escb]
        )

        trainer.train()

        # test with ORIG data
        trainer.eval_dataset = test_dataset
        trainer.compute_metrics = compute_metrics
        trainer.data_collator = DefaultCollator()
        trainer.remove_callback(tmcb)

        out_orig = trainer.evaluate()
        out_orig['run'] = checkpoint
        out_orig['test'] = "ORIG"
        print('ORIG for {}\n{}'.format(checkpoint, out_orig))

        results.append(out_orig)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


TargetedMixturesCollator initialized with TextMix


W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1396,0.808837,0.672657,41.8092,143.509
2000,0.8156,0.789588,0.679618,41.3296,145.175
3000,0.7665,0.772603,0.704306,41.7923,143.567
4000,0.7611,0.774343,0.661262,41.7152,143.832
5000,0.7626,0.776974,0.7222,41.9252,143.112
6000,0.7562,0.760593,0.715502,41.8154,143.488
7000,0.7502,0.749505,0.706589,41.6941,143.905
8000,0.7592,0.749131,0.683196,41.6247,144.145
9000,0.7551,0.763973,0.725802,41.8193,143.475
10000,0.7367,0.834224,0.701059,41.859,143.338


New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/bert-base-uncased-targeted-TextMix
{'eval_loss': 27.071096420288086, 'eval_accuracy': 0.9306578947368421, 'eval_f1': 0.930685905086747, 'eval_precision': 0.9309863143174393, 'eval_recall': 0.9306578947368422, 'eval_runtime': 61.7877, 'eval_samples_per_second': 123.002, 'epoch': 5.26, 'run': './results/bert-base-uncased-targeted-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1114,0.811199,0.675503,41.4916,144.607
2000,0.8077,0.787932,0.674835,41.9932,142.88
3000,0.7843,0.820006,0.67677,41.4439,144.774
4000,0.7746,0.788888,0.648939,41.9008,143.195
5000,0.7734,0.790702,0.652741,41.7423,143.739
6000,0.7476,0.779163,0.672822,41.5969,144.242
7000,0.7523,0.765362,0.67389,41.5349,144.457
8000,0.7529,0.800721,0.648623,41.7676,143.652
9000,0.7522,0.784195,0.697498,41.676,143.968
10000,0.7356,0.789944,0.652224,41.8095,143.508


New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 

ORIG for ./results/bert-base-uncased-targeted-SentMix
{'eval_loss': 24.861095428466797, 'eval_accuracy': 0.9298684210526316, 'eval_f1': 0.9302086197215405, 'eval_precision': 0.9314223893874708, 'eval_recall': 0.9298684210526316, 'eval_runtime': 61.8363, 'eval_samples_per_second': 122.905, 'epoch': 3.68, 'run': './results/bert-base-uncased-targeted-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.1321,0.856758,0.630224,41.9902,142.89
2000,0.8598,0.843506,0.604995,41.8456,143.384
3000,0.8311,0.800334,0.618334,41.6706,143.986
4000,0.8078,0.801563,0.602642,41.9982,142.863
5000,0.7827,0.785071,0.606398,41.8304,143.436
6000,0.7853,0.810397,0.548979,41.7991,143.544
7000,0.8055,0.819937,0.597269,41.7296,143.783
8000,0.7901,0.775075,0.634653,41.8514,143.364
9000,0.8043,0.822465,0.554362,41.7101,143.85
10000,0.8046,0.844121,0.569214,42.0731,142.609


New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 1], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 1], [1, 2], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 1], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/bert-base-uncased-targeted-WordMix
{'eval_loss': 21.50132179260254, 'eval_accuracy': 0.9085526315789474, 'eval_f1': 0.9085825387020183, 'eval_precision': 0.9086716755290601, 'eval_recall': 0.9085526315789474, 'eval_runtime': 61.6814, 'eval_samples_per_second': 123.214, 'epoch': 1.93, 'run': './results/bert-base-uncased-targeted-WordMix', 'test': 'ORIG'}


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with TextMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0703,0.728975,0.733501,41.2023,145.623
2000,0.785,0.766644,0.66437,41.2764,145.362
3000,0.7739,0.756887,0.719969,41.1009,145.982
4000,0.7683,0.792227,0.696743,41.4647,144.701
5000,0.7624,0.775964,0.715147,41.3206,145.206
6000,0.7623,0.726631,0.649587,41.274,145.37
7000,0.7533,0.766364,0.766213,41.4208,144.855
8000,0.7582,0.754376,0.743606,40.9838,146.399
9000,0.7419,0.761356,0.737333,41.5331,144.463
10000,0.7548,0.770594,0.702068,41.4152,144.874


New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 

ORIG for ./results/roberta-base-targeted-TextMix
{'eval_loss': 25.243261337280273, 'eval_accuracy': 0.9360526315789474, 'eval_f1': 0.9359662658623298, 'eval_precision': 0.9361662118240563, 'eval_recall': 0.9360526315789472, 'eval_runtime': 61.1577, 'eval_samples_per_second': 124.269, 'epoch': 5.18, 'run': './results/roberta-base-targeted-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0407,0.751379,0.702937,41.2679,145.391
2000,0.7811,0.76934,0.704486,41.1307,145.876
3000,0.7711,0.796822,0.606337,41.1965,145.643
4000,0.7791,0.785128,0.6426,41.3639,145.054
5000,0.7888,0.844962,0.642662,41.083,146.046
6000,0.7752,0.797398,0.65418,41.2901,145.313
7000,0.7652,0.771777,0.736489,41.2595,145.421
8000,0.7505,0.776426,0.707972,41.197,145.642
9000,0.7473,0.77354,0.663444,41.3974,144.937
10000,0.7565,0.762717,0.644465,41.5558,144.384


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 

ORIG for ./results/roberta-base-targeted-SentMix
{'eval_loss': 25.53358268737793, 'eval_accuracy': 0.9375, 'eval_f1': 0.9376240867980787, 'eval_precision': 0.9383416662662304, 'eval_recall': 0.9375, 'eval_runtime': 61.2393, 'eval_samples_per_second': 124.103, 'epoch': 4.3, 'run': './results/roberta-base-targeted-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0498,0.809248,0.649325,40.5421,147.994
2000,0.8373,0.869969,0.543037,40.4886,148.19
3000,0.8154,0.80961,0.586436,40.4849,148.203
4000,0.7992,0.840861,0.599084,40.148,149.447
5000,0.8004,0.834741,0.572262,40.3315,148.767
6000,0.803,0.815939,0.5952,40.4436,148.355
7000,0.801,0.825424,0.591453,40.0703,149.737
8000,0.8069,0.839419,0.581111,40.1008,149.623
9000,0.8332,0.849409,0.570176,40.2725,148.985
10000,0.8235,0.839813,0.604064,40.1887,149.296


New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 1], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 1], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 1], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 1], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 

ORIG for ./results/roberta-base-targeted-WordMix
{'eval_loss': 21.70138931274414, 'eval_accuracy': 0.9088157894736842, 'eval_f1': 0.9087583705618749, 'eval_precision': 0.9105993282375267, 'eval_recall': 0.9088157894736842, 'eval_runtime': 61.5478, 'eval_samples_per_second': 123.481, 'epoch': 3.25, 'run': './results/roberta-base-targeted-WordMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with TextMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0745,0.792972,0.669984,94.707,63.353
2000,0.8049,0.796526,0.654298,94.8519,63.257
3000,0.7838,0.775121,0.682193,93.9159,63.887
4000,0.7639,0.764956,0.704449,94.2961,63.629
5000,0.7516,0.74801,0.693294,94.6262,63.407
6000,0.7525,0.739099,0.70189,94.4175,63.548
7000,0.7522,0.812666,0.719719,94.485,63.502
8000,0.7468,0.775083,0.622422,94.6393,63.399
9000,0.7471,0.745748,0.729307,94.2597,63.654
10000,0.7573,0.715481,0.740632,94.4735,63.51


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 0]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 2], [2, 

ORIG for ./results/xlnet-base-cased-targeted-TextMix
{'eval_loss': 27.15379524230957, 'eval_accuracy': 0.9328947368421052, 'eval_f1': 0.9329292708245343, 'eval_precision': 0.9333931286860029, 'eval_recall': 0.9328947368421052, 'eval_runtime': 138.3675, 'eval_samples_per_second': 54.926, 'epoch': 6.58, 'run': './results/xlnet-base-cased-targeted-TextMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with SentMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0682,0.800256,0.669279,94.8811,63.237
2000,0.8146,0.793195,0.609697,93.5357,64.147
3000,0.7802,0.810916,0.604511,94.6252,63.408
4000,0.7679,0.81531,0.698414,93.7145,64.024
5000,0.7635,0.790875,0.669827,93.0423,64.487
6000,0.7584,0.782297,0.553055,93.6268,64.084
7000,0.7559,0.777668,0.579981,93.7384,64.008
8000,0.7418,0.751674,0.680772,93.4228,64.224
9000,0.7514,0.782886,0.723713,93.5709,64.122
10000,0.7675,0.81181,0.661829,93.73,64.014


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 

ORIG for ./results/xlnet-base-cased-targeted-SentMix
{'eval_loss': 24.25064468383789, 'eval_accuracy': 0.935921052631579, 'eval_f1': 0.9360067471173967, 'eval_precision': 0.936320568584647, 'eval_recall': 0.935921052631579, 'eval_runtime': 138.5276, 'eval_samples_per_second': 54.863, 'epoch': 4.12, 'run': './results/xlnet-base-cased-targeted-SentMix', 'test': 'ORIG'}


Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['sequence_summary.summary.weight', 'sequence_summary.summary.bias', 'logits_proj.weight', 'logits_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.



TargetedMixturesCollator initialized with WordMix


Step,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1000,1.0694,0.839141,0.656574,93.4444,64.209
2000,0.8547,0.847419,0.555724,93.3804,64.253
3000,0.8222,0.811653,0.611738,92.9555,64.547
4000,0.7964,0.822057,0.609598,93.6295,64.082
5000,0.8022,0.781116,0.589005,92.9446,64.555
6000,0.7909,0.803952,0.584095,93.7036,64.032
7000,0.7843,0.85145,0.554407,93.0067,64.511
8000,0.7881,0.891525,0.539363,92.9684,64.538
9000,0.7842,0.796675,0.605546,93.657,64.064
10000,0.7931,0.804632,0.613533,92.5264,64.846


New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 0], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 3], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 3], [3, 2]]
New targets: [[0, 3], [1, 0], [2, 3], [3, 2]]
New targets: [[0, 2], [1, 2], [2, 

ORIG for ./results/xlnet-base-cased-targeted-WordMix
{'eval_loss': 20.25948143005371, 'eval_accuracy': 0.9161842105263158, 'eval_f1': 0.9163314377461362, 'eval_precision': 0.9169453092152762, 'eval_recall': 0.9161842105263158, 'eval_runtime': 138.4677, 'eval_samples_per_second': 54.886, 'epoch': 2.81, 'run': './results/xlnet-base-cased-targeted-WordMix', 'test': 'ORIG'}


In [16]:
pd.DataFrame(results)

ValueError: Only callable can be used as callback

ValueError: Only callable can be used as callback

In [12]:
df = pd.DataFrame(results)
                                                

In [7]:
df.to_csv('train_AG_NEWS_targeted_r2.csv')

In [8]:
df.to_clipboard(excel=True)

In [9]:
# ORIG for ./results/bert-base-uncased-targeted-TextMix
# {'eval_loss': 31.2364559173584, 'eval_accuracy': 0.9381578947368421, 'eval_f1': 0.9381945850526017, 'eval_precision': 0.938240633851668, 'eval_recall': 0.9381578947368421, 'eval_runtime': 117.2622, 'eval_samples_per_second': 64.812, 'epoch': 5.0, 'run': 'TextMix', 'test': 'ORIG'}