In [2]:
import abc
from dataclasses import dataclass, field
import json
from pathlib import Path
from typing import *

from beartype import beartype
import datasets
import numpy as np
import more_itertools
import queue
import rich
import torch
import torch.nn as nn
import transformers
import wandb

TokenizerType = transformers.tokenization_utils_fast.PreTrainedTokenizerFast

class BaseEpsilonScheduler(abc.ABC):
    @abc.abstractmethod
    def __call__(self):
        pass

class LinearEpsilonScheduler(BaseEpsilonScheduler):
    def __init__(self, epsilon, num_steps):
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.epoch = 0

    def __call__(self):
        self.epoch += 1
        epsilon = min(self.epsilon * (1 - self.epoch / self.num_epochs), 1)
        wandb.log({"epsilon": epsilon})
        wandb.log({"epsilon_num_steps": self.num_steps})
        return epsilon

class ConstantEpsilonScheduler(BaseEpsilonScheduler):
    def __init__(self, epsilon):
        self.epsilon = epsilon

    def __call__(self):
        epsilon = self.epsilon
        wandb.log({"epsilon": epsilon})
        return epsilon

In [3]:
class BaseRetriever(abc.ABC):
    @abc.abstractmethod
    def retrieve(self, query_ids, query_index):
        pass


class StupidRetriever(BaseRetriever): 
    @beartype
    def __init__(
        self, *, model, tokenizer: TokenizerType, device: Union[int, str], 
        train_vectors: torch.Tensor, train_samples: List[str], train_labels: List[int]
    ):
    
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.train_vectors = train_vectors
        self.train_samples = train_samples
        self.train_labels = train_labels
        self.classification_ids_to_idx = {}

        for i, sample in enumerate(train_samples):
            encoded = tokenizer.encode(sample, truncation=True, padding=True, return_tensors="pt")
            self.classification_ids_to_idx[encoded] = i

    def retrieve(self, query_ids, query_index):
        # Get the representation
        representation = self.train_vectors[query_index]
        with torch.inference_mode():
            # Compute the inner products
            scores = torch.matmul(representation, self.train_vectors.t())
            # Get the top 2 results, to potentially exclude the sample itself.
            topk = torch.topk(scores, k=2, dim=-1)
        topk = topk.indices.cpu().numpy()
        
        for retrieved_idx in topk:
            if retrieved_idx != query_index:
                return self.train_samples[retrieved_idx], self.train_labels[retrieved_idx], retrieved_idx
        

# build train vectors
@beartype
def make_retrival_model_and_vectors(
    retriever_name: str, path_to_vectors: Union[str, Path], device: int, dataset_name: str
):
    """We expect the dir to have the following structure:
    - config.json
    - train_samples.json 
    - train_vectors.npy
    """    
    # Make some checks
    config =  json.loads((path_to_vectors / "config.json").read_text())
    assert dataset_name == config["dataset_name"], (dataset_name, config["dataset_name"])
    assert retriever_name == config["retriever_name"], (retriever_name, config["retriever_name"])

    retriever_model = transformers.AutoModel.from_pretrained(retriever_name)
    retriever_tokenizer = transformers.AutoTokenizer.from_pretrained(retriever_name)

    with open(path_to_vectors / "train_samples.json") as f:
        train_samples = json.load(f)
        
    vectors = torch.tensor(np.load(path_to_vectors / "train_vectors.npy")).to(device)
    retriever = StupidRetriever(
        model=retriever_model, 
        tokenizer=retriever_tokenizer, 
        device=device, 
        train_vectors=vectors, 
        train_samples=train_samples["inputs"],
        train_labels=train_samples["labels"],
    )
    
    return retriever


@dataclass(order=True)
class PrioritizedItem:
    priority: int
    item: Any=field(compare=False)


class BoostingIterator(torch.utils.data.IterableDataset):
    @beartype
    def __init__(
        self, 
        *args, 
        dataset: torch.utils.data.Dataset, 
        retriever_client: BaseRetriever, 
        classifier: nn.Module, seed: int, 
        classification_device: Union[int, str], 
        classification_tokenizer: TokenizerType, 
        retriever_device: Union[int, str],
        epsilon_scheduler: BaseEpsilonScheduler, 
        loss_ema_alpha: float, 
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.dataset = dataset.map(
            lambda example, idx:{"index": idx}, with_indices=True
        ).shuffle(seed=seed)
        self.priority_queue = queue.PriorityQueue()
        self.retriever_client = retriever_client
        self.epsilon_scheduler = epsilon_scheduler
        self.randomizer = np.random.RandomState(seed)
        self.seed = seed
        self.dataset_iter = None
        self.classifier = classifier
        self.classification_tokenizer = classification_tokenizer
        self.classification_device = classification_device
        self.retriever_device = retriever_device
        self.loss_moving_average = None
        self.loss_ema_alpha = loss_ema_alpha

        # assert mode in ["epsilon_priority_no_reset", "pure_sampled", "epsilon_sampled"], mode

    def push_score(self, inputs, loss):
        average_loss = loss.mean()
        if self.loss_moving_average is None:
            self.loss_moving_average = average_loss
        else:
            self.loss_moving_average = (
                self.loss_ema_alpha * self.loss_moving_average + (1 - self.loss_ema_alpha) * average_loss
            )

        for input_, mask, loss_, index in (
            more_itertools.zip_equal(inputs["input_ids"], inputs["attention_mask"], loss, inputs["index"])
        ):
            assert loss_.shape == torch.Size([]), loss_.shape
            self.priority_queue.put(
                PrioritizedItem(
                    priority= -loss_.detach().cpu().numpy() / self.loss_moving_average, 
                    item=dict(input_ids=input_, attention_mask=mask, index=index)
                    )
                )

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        rich.print("[bold green]ITER[/]")
        self.dataset = self.dataset.shuffle(seed=self.seed)
        self.dataset_iter = iter(self.dataset)
        return self
    
    def __next__(self):
        """ This is where the sampling happens.
        """

        # Test if we have a sample and if we pass the epsilon threshold
        empty = self.priority_queue.empty()
        rand = self.randomizer.rand()
        if not empty and rand < self.epsilon_scheduler():
            # pull a sample from the priority queue
            sample = self.priority_queue.get().item

            # We retrieve the next sample.
            input_, next_label, index = self.retriever_client.retrieve(
                sample["input_ids"], sample["index"]
            )
            next_sample = dict(text=input_, label=next_label, index=index)
        else:
            next_sample = next(self.dataset_iter)  # We raise here if we have no more samples in the dataset
            assert next_sample.keys() == {"text", "label", "index"}, next_sample.keys()

        tokenized = self.classification_tokenizer.encode_plus(
            next_sample["text"].strip(), 
            truncation=True, 
            padding=True,
        )

        # text is not needed anymore
        del next_sample["text"]
        assert len(tokenized.keys() & next_sample.keys()) == 0, (tokenized.keys(), next_sample.keys()) 
        return dict(**tokenized, **next_sample)


class BoostingTrainer(transformers.Trainer):
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.
        Subclass and override to inject custom behavior.
        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.
                The dictionary will be unpacked before being fed to the model. 
                Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        index = inputs["index"]
        # Compute loss doesn't work with extra arguments.
        del inputs["index"]

        with self.autocast_smart_context_manager():
            # Get the loss
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

        if self.args.n_gpu > 1:
            # Mean over per gpu averages
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # This is ignored in the priority queue computation
        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # Deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with torch.cuda.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        loss = loss.detach()

        # Addition for RetroBoost
        # Make sure the losses are similar, then push them to the priority queue
        # Put index back in

        inputs["index"] = index
        computed_loss = torch.nn.functional.cross_entropy(outputs.logits.detach(), inputs["labels"].detach(), reduction="mean")
        loss_per_sample = torch.nn.functional.cross_entropy(outputs.logits.detach(), inputs["labels"].detach(), reduction="none")
        
        assert torch.allclose(loss, computed_loss, atol=0.1)
        self.get_train_dataloader().dataset.push_score(inputs, loss_per_sample)

        return loss




    https://www.python.org/dev/peps/pep-0585
    https://www.python.org/dev/peps/pep-0585


In [4]:
RETRIEVER_NAME = "facebook/contriever"
DATASET_NAME = "ag_news"
PATH_TO_VECTORS = Path(f"./vectors_{DATASET_NAME}_{RETRIEVER_NAME.split('/')[-1]}/")
CLASSIFIER_NAME = "roberta-base"
CLASSIFIER_BATCH_SIZE = 40
EPSILON_SCHEDULER_TYPE = "constant"
EPSILON_SCHEDULER_CONFIG = dict(
    epsilon=0.5,
)
LOSS_EMA_ALPHA = 0.5

REGULAR_TRAINER = True
CLASSIFIER_DEVICE = 1
RETRIEVER_DEVICE = 2
SEED = 0

###############################################################################
# Fast setup 
###############################################################################
config = dict(
        retriever_name=RETRIEVER_NAME,
        dataset_name=DATASET_NAME,
        classifier_name=CLASSIFIER_NAME,
        regular_trainer=REGULAR_TRAINER,
        loss_ema_alpha=LOSS_EMA_ALPHA,
        epsilon=dict(
            scheduler_type=EPSILON_SCHEDULER_TYPE,
            scheduler_config=EPSILON_SCHEDULER_CONFIG,
        )
    )

wandb.init(
    config=config,
    project="RetroBoost", 
    entity="retroboost",
)

EPSILON_SCHEDULER_TYPE_MAP = dict(
    constant=ConstantEpsilonScheduler,
)

# Random seeds. 
np.random.seed(0)
torch.manual_seed(0)

dataset = datasets.load_dataset(DATASET_NAME)
ALL_LABELS = set(dataset["train"]["label"])
NUM_LABELS = len(ALL_LABELS)
assert ALL_LABELS == set(range(NUM_LABELS))

classifier_name = CLASSIFIER_NAME
dataset_name = DATASET_NAME
regular_trainer = REGULAR_TRAINER


classifier = transformers.AutoModelForSequenceClassification.from_pretrained(
    classifier_name, num_labels=NUM_LABELS
)
classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(classifier_name)

def preprocess_function(examples, tokenizer):
    return tokenizer(examples["text"], truncation=True, padding=True)

def preprocess_function(examples, tokenizer):
    return tokenizer(examples["text"], truncation=True, padding=True)

dataset = datasets.load_dataset(dataset_name)
tokenized_training = dataset["train"].map(
    lambda examples: preprocess_function(examples, classifier_tokenizer), 
    batched=True
)

tokenized_validation = dataset["test"].map(
    lambda examples: preprocess_function(examples, classifier_tokenizer), 
    batched=True
)

training_args = transformers.TrainingArguments(
    eval_steps=499,
    evaluation_strategy="steps",
    output_dir="./results",
    learning_rate=1e-5,
    per_device_train_batch_size=CLASSIFIER_BATCH_SIZE,
    per_device_eval_batch_size=int(CLASSIFIER_BATCH_SIZE * 1.5),
    num_train_epochs=5,
    weight_decay=0.01,
    report_to="wandb",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mretroboost[0m (use `wandb login --relogin` to force relogin)


Using custom data configuration default
Reusing dataset ag_news (/home/mila/g/gagnonju/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/120 [00:00<?, ?ba/s]

  0%|          | 0/8 [00:00<?, ?ba/s]

In [5]:
retriever = make_retrival_model_and_vectors(
    retriever_name=RETRIEVER_NAME, 
    path_to_vectors=PATH_TO_VECTORS, 
    device=RETRIEVER_DEVICE, 
    dataset_name=DATASET_NAME,
)
retriever_client = retriever

if regular_trainer:
    trainer = transformers.Trainer(
        model=classifier.to(CLASSIFIER_DEVICE), 
        args=training_args, 
        tokenizer=classifier_tokenizer, 
        train_dataset=tokenized_training, 
        eval_dataset=tokenized_validation,
        data_collator=transformers.DataCollatorWithPadding(
            tokenizer=classifier_tokenizer
        ),
    )
else:
    tokenized_training = BoostingIterator(
        dataset=dataset["train"], 
        retriever_client=retriever_client, 
        classifier=classifier, 
        epsilon_scheduler=EPSILON_SCHEDULER_TYPE_MAP[EPSILON_SCHEDULER_TYPE](**EPSILON_SCHEDULER_CONFIG), 
        seed=SEED,
        retriever_device=RETRIEVER_DEVICE, 
        classification_device=CLASSIFIER_DEVICE,
        classification_tokenizer=classifier_tokenizer,
        loss_ema_alpha=LOSS_EMA_ALPHA,
    )
    
    trainer = BoostingTrainer(
        model=classifier.to(CLASSIFIER_DEVICE),
        args=training_args, 
        tokenizer=classifier_tokenizer, 
        train_dataset=tokenized_training, 
        eval_dataset=tokenized_validation,
        data_collator=transformers.DataCollatorWithPadding(
            tokenizer=classifier_tokenizer
        ),
    )

In [None]:
output = trainer.train()
wandb.finish()

The following columns in the training set  don't have a corresponding argument in `RobertaForSequenceClassification.forward` and have been ignored: text.
***** Running training *****
  Num examples = 25000
  Num Epochs = 5
  Instantaneous batch size per device = 40
  Total train batch size (w. parallel, distributed & accumulation) = 120
  Gradient Accumulation steps = 1
  Total optimization steps = 1045
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss
499,No log,0.138188
998,0.189000,0.141031


The following columns in the evaluation set  don't have a corresponding argument in `RobertaForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 180
Saving model checkpoint to ./results/checkpoint-500
Configuration saved in ./results/checkpoint-500/config.json
Model weights saved in ./results/checkpoint-500/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-500/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `RobertaForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 180
Saving model checkpoint to ./results/checkpoint-1000
Configuration saved in ./results/checkpoint-1000/config.json
Model weights saved in ./results/checkpoint-1000/pytorch_model.bin
tokenizer config file saved

VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
eval/loss,▁█
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁▁▇▇█
train/global_step,▁▁▇▇█
train/learning_rate,█▁
train/loss,█▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.14103
eval/runtime,23.4115
eval/samples_per_second,1067.851
eval/steps_per_second,5.937
train/epoch,5.0
train/global_step,1045.0
train/learning_rate,0.0
train/loss,0.0903
train/total_flos,3.288888192e+16
train/train_loss,0.13665


In [None]:
retriever

<__main__.StupidRetriever at 0x7fdfcd28fac0>