In [1]:
from dataclasses import dataclass, field
import json
from pathlib import Path
from typing import *

import datasets
import numpy as np
import more_itertools
import queue
import rich
import torch
import torch.nn as nn
import transformers
import wandb


class LinearEpsilonScheduler:
    """Linearly decay an epsilon value from ``epsilon`` down to 0.

    Every call advances an internal counter by one, logs the current value to
    wandb under the key ``"epsilon"`` and returns it.

    Args:
        epsilon: Starting value of the schedule.
        num_steps: Number of calls after which the value reaches 0.

    Raises:
        ValueError: If ``num_steps`` is not strictly positive.
    """

    def __init__(self, epsilon, num_steps):
        if num_steps <= 0:
            # Fail at construction time instead of with a ZeroDivisionError
            # on the first call.
            raise ValueError(f"num_steps must be > 0, got {num_steps}")
        self.epsilon = epsilon
        # Attribute names are kept for backward compatibility, although the
        # counter advances once per call rather than once per epoch.
        self.num_epochs = num_steps
        self.epoch = 0

    def __call__(self):
        """Advance the schedule by one step and return the new epsilon."""
        self.epoch += 1
        decayed = self.epsilon * (1 - self.epoch / self.num_epochs)
        # Clamp to [0, 1]. Without the lower bound the value becomes negative
        # as soon as more than ``num_steps`` calls have been made.
        epsilon = max(0.0, min(decayed, 1))
        wandb.log({"epsilon": epsilon})
        return epsilon

class ConstantEpsilonScheduler:
    """Scheduler that hands back the same epsilon on every call.

    Each call also logs the value to wandb under the key ``"epsilon"``.
    """

    def __init__(self, epsilon):
        self.epsilon = epsilon

    def __call__(self):
        """Log and return the fixed epsilon."""
        value = self.epsilon
        wandb.log(dict(epsilon=value))
        return value

In [None]:


class StupidRetriever:
    """Brute-force nearest-neighbour lookup over precomputed training vectors.

    Args:
        model: Retriever encoder; stored on the instance, not used by ``retrieve``.
        tokenizer: Tokenizer exposing ``encode``; used once per training sample
            to build ``classification_ids_to_idx``.
        device: Stored on the instance, not used by ``retrieve``.
        train_vectors: 2-D tensor with one row per training sample.
        train_samples: Sequence of training texts, aligned with ``train_vectors``.
        train_labels: Sequence of labels, aligned with ``train_vectors``.
    """

    def __init__(self, *, model, tokenizer, device, train_vectors, train_samples, train_labels):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.train_vectors = train_vectors
        self.train_samples = train_samples
        self.train_labels = train_labels
        self.classification_ids_to_idx = {}

        for i, sample in enumerate(train_samples):
            encoded = tokenizer.encode(sample, truncation=True, padding=True, return_tensors="pt")
            # Tensors hash by identity, so a tensor key can never be found again
            # from another tensor holding the same ids. Key on the token ids
            # themselves instead. Samples that tokenize identically share one
            # entry (the last index wins).
            key = tuple(torch.as_tensor(encoded).reshape(-1).tolist())
            self.classification_ids_to_idx[key] = i

    def retrieve(self, query_ids, query_index):
        """Return the closest training sample other than the query itself.

        Args:
            query_ids: Unused; kept for interface compatibility.
            query_index: Row of ``train_vectors`` to use as the query. May be a
                Python int, a numpy integer or a single-element tensor.

        Returns:
            Tuple ``(sample, label, index)`` of the best-scoring training entry
            whose index differs from ``query_index``; ``index`` is a Python int.

        Raises:
            LookupError: If no entry other than the query itself is available.
        """
        # Normalise once so indexing and the comparison below do not depend on
        # the type or device of the incoming index.
        query_index = int(query_index)
        representation = self.train_vectors[query_index]
        with torch.inference_mode():
            # Inner product of the query against every training vector.
            scores = torch.matmul(representation, self.train_vectors.t())
            # Top 2 so the query itself can be skipped; capped so a
            # single-vector index does not make topk fail.
            k = min(2, scores.shape[-1])
            candidates = torch.topk(scores, k=k, dim=-1).indices.cpu().tolist()

        for retrieved_idx in candidates:
            if retrieved_idx != query_index:
                return self.train_samples[retrieved_idx], self.train_labels[retrieved_idx], retrieved_idx

        raise LookupError(f"no neighbour other than the query itself for index {query_index}")
        

# build train vectors
def make_retrival_model_and_vectors(retriever_name, path_to_vectors, device, dataset_name):
    """Build a ``StupidRetriever`` from a directory of precomputed vectors.

    The directory is expected to have the following structure:
    - config.json         (must contain "dataset_name" and "retriever_name")
    - train_samples.json  (must contain "inputs" and "labels")
    - train_vectors.npy

    Args:
        retriever_name: Model name passed to ``from_pretrained``; must match
            the ``retriever_name`` recorded in ``config.json``.
        path_to_vectors: Directory described above, as a ``Path`` or ``str``.
        device: Device the training vectors are moved to.
        dataset_name: Must match the ``dataset_name`` recorded in ``config.json``.

    Returns:
        The constructed ``StupidRetriever``.

    Raises:
        AssertionError: If ``config.json`` disagrees with the given names.
    """
    # Accept plain strings as well as Path objects.
    path_to_vectors = Path(path_to_vectors)

    # Make sure the vectors on disk were produced for this dataset / retriever.
    config = json.loads((path_to_vectors / "config.json").read_text())
    assert dataset_name == config["dataset_name"], (dataset_name, config["dataset_name"])
    assert retriever_name == config["retriever_name"], (retriever_name, config["retriever_name"])

    retriever_model = transformers.AutoModel.from_pretrained(retriever_name)
    retriever_tokenizer = transformers.AutoTokenizer.from_pretrained(retriever_name)

    with open(path_to_vectors / "train_samples.json") as f:
        train_samples = json.load(f)
    vectors = torch.tensor(np.load(path_to_vectors / "train_vectors.npy")).to(device)

    retriever = StupidRetriever(
        model=retriever_model,
        tokenizer=retriever_tokenizer,
        device=device,
        train_vectors=vectors,
        train_samples=train_samples["inputs"],
        train_labels=train_samples["labels"],
    )

    return retriever


@dataclass(order=True)
class PrioritizedItem:
    """Priority-queue entry that is ordered by ``priority`` alone.

    ``item`` is excluded from comparisons, so the payload does not need to be
    comparable and two entries with equal priority compare as equal.
    """

    # Sort key: the only field that takes part in ==, <, <=, >, >=.
    priority: int
    # Arbitrary payload carried alongside the priority.
    item: Any = field(compare=False)


class BoostingIterator(torch.utils.data.IterableDataset):
    """Iterable training set mixing fresh samples with neighbours of high-loss samples.

    ``push_score`` stores already-trained samples in a priority queue keyed by
    their negated loss. ``__next__`` then either pops the highest-loss entry and
    yields the sample returned by ``retriever_client.retrieve`` for it, or
    yields the next sample of the shuffled dataset.

    Args:
        dataset: Dataset exposing ``map`` / ``shuffle``; an ``index`` column is
            added to it. Samples must provide a ``text`` field.
        retriever_client: Object exposing ``retrieve(ids, index)`` returning
            ``(text, label, index)``.
        classifier: Stored on the instance; not used by the iterator itself.
        seed: Seed for the dataset shuffles and for the sampling RNG.
        classification_device: Stored on the instance.
        classification_tokenizer: Callable used to tokenize each yielded text.
        retriever_device: Stored on the instance.
        epsilon_scheduler: Zero-argument callable returning the probability of
            drawing from the priority queue.
    """

    def __init__(
        self, *args, dataset, retriever_client, classifier, seed,
        classification_device, classification_tokenizer, retriever_device,
        epsilon_scheduler, **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Remember each example's original position before shuffling.
        self.dataset = dataset.map(
            lambda example, idx: {"index": idx}, with_indices=True
        ).shuffle(seed=seed)
        self.priority_queue = queue.PriorityQueue()
        self.retriever_client = retriever_client
        self.epsilon_scheduler = epsilon_scheduler
        self.randomizer = np.random.RandomState(seed)
        self.seed = seed
        self.dataset_iter = None
        self.classifier = classifier
        self.classification_tokenizer = classification_tokenizer
        self.classification_device = classification_device
        self.retriever_device = retriever_device

    def push_score(self, inputs, loss):
        """Queue every sample of a batch with its loss as (negated) priority.

        Args:
            inputs: Batch mapping with ``input_ids``, ``attention_mask`` and
                ``index`` entries, all of the same length.
            loss: Per-sample losses, one scalar per entry of the batch.
        """
        # One host copy per batch. Queued items then no longer reference the
        # device batch tensors, which would otherwise stay allocated for as
        # long as any of their rows sits in the queue.
        input_ids = torch.as_tensor(inputs["input_ids"]).detach().cpu()
        attention_mask = torch.as_tensor(inputs["attention_mask"]).detach().cpu()
        indices = torch.as_tensor(inputs["index"]).detach().cpu().tolist()
        losses = loss.detach().cpu()

        for input_, mask, loss_, index in (
            more_itertools.zip_equal(input_ids, attention_mask, losses, indices)
        ):
            assert loss_.shape == torch.Size([]), loss_.shape
            self.priority_queue.put(
                PrioritizedItem(
                    # PriorityQueue pops the smallest entry first, so negate the
                    # loss to get the highest-loss sample out first.
                    priority=-loss_.item(),
                    # Clone the rows so an item does not keep the whole batch alive.
                    item=dict(input_ids=input_.clone(), attention_mask=mask.clone(), index=index),
                )
            )

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        rich.print("[bold green]ITER[/]")
        self.dataset = self.dataset.shuffle(seed=self.seed)
        self.dataset_iter = iter(self.dataset)
        return self

    def __next__(self):
        """Return the next tokenized sample; this is where the sampling happens.

        Raises:
            StopIteration: When a dataset sample is needed and the dataset
                iterator is exhausted.
        """
        # The random draw is made on every call. The scheduler is only consulted
        # (and therefore only stepped and logged) when the queue holds something.
        queue_is_empty = self.priority_queue.empty()
        rand = self.randomizer.rand()
        if not queue_is_empty and rand < self.epsilon_scheduler():
            # Pop the highest-loss sample and fetch its neighbour.
            sample = self.priority_queue.get().item
            input_, next_label, index = self.retriever_client.retrieve(
                sample["input_ids"], sample["index"]
            )
            next_sample = dict(text=input_, label=next_label, index=index)
        else:
            # Raises StopIteration here once the dataset is exhausted.
            next_sample = next(self.dataset_iter)

        tokenized = self.classification_tokenizer(
            next_sample["text"].strip(),
            truncation=True,
            padding=True,
        )

        # Everything but the raw text is forwarded next to the tokenizer output.
        passthrough = {key: value for key, value in next_sample.items() if key != "text"}
        return dict(**tokenized, **passthrough)

class BoostingTrainer(transformers.Trainer):
    """``Trainer`` that reports per-sample losses back to the training dataset.

    The training dataset must expose ``push_score(inputs, loss_per_sample)``
    and every batch must carry an ``index`` entry next to the model inputs.
    """

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs, then push the per-sample
        losses of that batch to the training dataset.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model, plus an `index` entry.
                `index` is removed before the dictionary is fed to the model
                and put back afterwards. Targets are expected under `labels`.
        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        # compute_loss unpacks `inputs` into the model, which does not accept
        # `index`; take it out for the forward pass.
        index = inputs.pop("index")

        with self.autocast_smart_context_manager():
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

        if self.args.n_gpu > 1:
            # Mean over the per-GPU averages in multi-GPU parallel training.
            loss = loss.mean()

        # Keep the loss as it is before gradient-accumulation scaling: that is
        # the value comparable with the recomputed loss further down.
        unscaled_loss = loss.detach()

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # Deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`.
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            # `torch.cuda.amp` has no `scale_loss`; loss scaling for this branch
            # comes from NVIDIA Apex, which `transformers.trainer` imports as
            # `amp` when Apex is installed.
            with transformers.trainer.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # The loss gets scaled under gradient_accumulation_steps in deepspeed.
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        loss = loss.detach()

        # Addition for RetroBoost: recompute the loss per sample, check that it
        # is consistent with the loss the trainer computed, then push the
        # per-sample values to the dataset's priority queue.
        inputs["index"] = index
        # Cast to float32 so the comparison below does not mix dtypes under
        # mixed-precision training.
        logits = outputs.logits.detach().float()
        labels = inputs["labels"].detach()
        computed_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
        loss_per_sample = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
        assert torch.allclose(unscaled_loss.float(), computed_loss, atol=0.1), (
            unscaled_loss, computed_loss
        )
        # Use the dataset the trainer was built with directly, rather than
        # constructing a fresh DataLoader on every step just to reach it.
        self.train_dataset.push_score(inputs, loss_per_sample)

        return loss




In [None]:
# ---------------------------------------------------------------------------
# Run configuration: everything a reader might want to tune.
# ---------------------------------------------------------------------------
RETRIEVER_NAME = "facebook/contriever"
PATH_TO_VECTORS = Path("./vectors_imdb_contriever/")
DATASET_NAME = "imdb"
CLASSIFIER_NAME = "roberta-base"
EPSILON_SCHEDULER_TYPE = "constant"
EPSILON_SCHEDULER_CONFIG = dict(
    epsilon=0.9,
)

REGULAR_TRAINER = False
CLASSIFIER_DEVICE = 1
RETRIEVER_DEVICE = 2
SEED = 0


###############################################################################
# Fast setup
###############################################################################
config = dict(
    retriever_name=RETRIEVER_NAME,
    dataset_name=DATASET_NAME,
    classifier_name=CLASSIFIER_NAME,
    regular_trainer=REGULAR_TRAINER,
    epsilon=dict(
        scheduler_type=EPSILON_SCHEDULER_TYPE,
        scheduler_config=EPSILON_SCHEDULER_CONFIG,
    ),
)

wandb.init(
    config=config,
    project="RetroBoost",
    entity="julesgm",
)

# Every scheduler defined in this notebook is selectable by name.
EPSILON_SCHEDULER_TYPE_MAP = dict(
    constant=ConstantEpsilonScheduler,
    linear=LinearEpsilonScheduler,
)
assert EPSILON_SCHEDULER_TYPE in EPSILON_SCHEDULER_TYPE_MAP, EPSILON_SCHEDULER_TYPE

# Random seeds, driven by the single SEED constant above.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the dataset once; it is reused for label discovery and tokenization.
dataset = datasets.load_dataset(DATASET_NAME)
ALL_LABELS = set(dataset["train"]["label"])
NUM_LABELS = len(ALL_LABELS)
# Labels must be exactly 0..NUM_LABELS-1.
assert ALL_LABELS == set(range(NUM_LABELS))

# Lower-case aliases used by the cells further down.
classifier_name = CLASSIFIER_NAME
dataset_name = DATASET_NAME
regular_trainer = REGULAR_TRAINER


classifier = transformers.AutoModelForSequenceClassification.from_pretrained(
    classifier_name, num_labels=NUM_LABELS
)
classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(classifier_name)


def preprocess_function(examples, tokenizer):
    """Tokenize the ``text`` field of a batch of examples with ``tokenizer``."""
    return tokenizer(examples["text"], truncation=True, padding=True)


tokenized_training = dataset["train"].map(
    lambda examples: preprocess_function(examples, classifier_tokenizer),
    batched=True
)

tokenized_validation = dataset["test"].map(
    lambda examples: preprocess_function(examples, classifier_tokenizer),
    batched=True
)

training_args = transformers.TrainingArguments(
    output_dir="./results",
    learning_rate=1e-5,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
    num_train_epochs=5,
    weight_decay=0.01,
    report_to="wandb",
)

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Reusing dataset imdb (/home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

loading configuration file https://huggingface.co/roberta-base/resolve/main/config.json from cache at /home/mila/g/gagnonju/.cache/huggingface/transformers/733bade19e5f0ce98e6531021dd5180994bb2f7b8bd7e80c7968805834ba351e.35205c6cfc956461d8515139f0f8dd5d207a2f336c0c3a83b4bc8dca3518e37b
Model config RobertaConfig {
  "_name_or_path": "roberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

loading weights file https://huggingface.c

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

In [None]:
# The retriever is built in both modes; only the boosting mode consumes it.
retriever = make_retrival_model_and_vectors(
    retriever_name=RETRIEVER_NAME,
    path_to_vectors=PATH_TO_VECTORS,
    device=RETRIEVER_DEVICE,
    dataset_name=DATASET_NAME,
)
retriever_client = retriever

# Pick the trainer class; in boosting mode the training set is replaced by the
# retrieval-driven iterator built on the raw (untokenized) training split.
if regular_trainer:
    trainer_class = transformers.Trainer
else:
    trainer_class = BoostingTrainer
    scheduler_class = EPSILON_SCHEDULER_TYPE_MAP[EPSILON_SCHEDULER_TYPE]
    tokenized_training = BoostingIterator(
        dataset=dataset["train"],
        retriever_client=retriever_client,
        classifier=classifier,
        epsilon_scheduler=scheduler_class(**EPSILON_SCHEDULER_CONFIG),
        seed=SEED,
        retriever_device=RETRIEVER_DEVICE,
        classification_device=CLASSIFIER_DEVICE,
        classification_tokenizer=classifier_tokenizer,
    )

# Both trainers take exactly the same arguments.
trainer = trainer_class(
    model=classifier.to(CLASSIFIER_DEVICE),
    args=training_args,
    tokenizer=classifier_tokenizer,
    train_dataset=tokenized_training,
    eval_dataset=tokenized_validation,
    data_collator=transformers.DataCollatorWithPadding(
        tokenizer=classifier_tokenizer
    ),
)

Loading cached processed dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-a645737e8e3ed793.arrow
Loading cached shuffled indices for dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-42092c50c935d800.arrow


In [None]:
# Close the wandb run even when training is interrupted or fails, so the data
# logged so far still gets flushed.
try:
    output = trainer.train()
finally:
    wandb.finish()

***** Running training *****
  Num examples = 25000
  Num Epochs = 5
  Instantaneous batch size per device = 20
  Total train batch size (w. parallel, distributed & accumulation) = 60
  Gradient Accumulation steps = 1
  Total optimization steps = 6250
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Loading cached shuffled indices for dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-4dc2830ea21e6a12.arrow


Step,Training Loss
500,0.2666
1000,0.1809
1500,0.161
2000,0.1332
2500,0.1335


Saving model checkpoint to ./results/checkpoint-500
Configuration saved in ./results/checkpoint-500/config.json
Model weights saved in ./results/checkpoint-500/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-500/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-500/special_tokens_map.json
Saving model checkpoint to ./results/checkpoint-1000
Configuration saved in ./results/checkpoint-1000/config.json
Model weights saved in ./results/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-1000/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-1000/special_tokens_map.json


Loading cached shuffled indices for dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-1172bcf5f8186130.arrow
Saving model checkpoint to ./results/checkpoint-1500
Configuration saved in ./results/checkpoint-1500/config.json
Model weights saved in ./results/checkpoint-1500/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-1500/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-1500/special_tokens_map.json
Saving model checkpoint to ./results/checkpoint-2000
Configuration saved in ./results/checkpoint-2000/config.json
Model weights saved in ./results/checkpoint-2000/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-2000/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-2000/special_tokens_map.json
Saving model checkpoint to ./results/checkpoint-2500
Configuration saved in ./results/checkpoint-2500/config.

Loading cached shuffled indices for dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-f71e4df4d8809d6e.arrow


KeyboardInterrupt: 

In [None]:
# Show the retriever object built earlier as this cell's output (interactive sanity check).
retriever