In [1]:
from dataclasses import dataclass, field
import json
from pathlib import Path
from typing import *
import warnings

from beartype import beartype
import datasets
import numpy as np
import matplotlib.pyplot as plt
import more_itertools
import queue
import rich
import torch
import torch.nn as nn
import transformers
import tqdm

In [2]:
class StupidRetriever:
    """Brute-force nearest-neighbour retriever over precomputed training vectors.

    Neighbours of a training sample are ranked by the inner product between its
    stored vector and every stored training vector.

    Args:
        model: retriever encoder; stored, not used by ``retrieve``.
        tokenizer: retriever tokenizer, used to build ``classification_ids_to_idx``.
        device: stored; not used by the methods below.
        train_vectors: tensor with one row per training sample (assumed to be
            aligned with ``train_samples`` — TODO confirm with the producer).
        train_samples: list of training texts.
        train_labels: list of training labels, aligned with ``train_samples``.
    """

    def __init__(self, model, tokenizer, device, train_vectors, train_samples, train_labels):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.train_vectors = train_vectors
        self.train_samples = train_samples
        self.train_labels = train_labels

        # Token ids of each training sample -> its position in train_samples.
        # Keys are tuples so lookups hash by value; tensor keys hash by object
        # identity, which made every lookup with a freshly encoded query miss.
        self.classification_ids_to_idx = {}
        for i, sample in enumerate(train_samples):
            encoded = tokenizer.encode(sample, truncation=True)
            self.classification_ids_to_idx[tuple(encoded)] = i

    def retrieve(self, query_ids, query_index):
        """Return the nearest training neighbour of ``train_samples[query_index]``.

        Args:
            query_ids: token ids of the query as seen by the caller. Not used for
                scoring (the query is identified by ``query_index``); kept for
                interface compatibility.
            query_index: position of the query in ``train_samples``. Python ints,
                NumPy integers and 0-d integer tensors are accepted.

        Returns:
            ``(text, label, index)`` of the highest-scoring training sample other
            than the query itself.

        Raises:
            ValueError: if fewer than two training vectors exist, so no neighbour
                distinct from the query can be returned.
        """
        # Normalise to a plain int: callers may pass a (possibly CUDA) 0-d tensor,
        # which would otherwise be used as an index / compared across devices.
        query_index = int(query_index)
        num_vectors = self.train_vectors.shape[0]
        if num_vectors < 2:
            raise ValueError(
                f"Need at least 2 training vectors to retrieve a neighbour, got {num_vectors}."
            )

        representation = self.train_vectors[query_index]
        with torch.inference_mode():
            # Inner product between the query vector and every training vector.
            scores = torch.matmul(representation, self.train_vectors.t())
            # The query usually scores highest against itself, so take two
            # candidates and skip the query's own index below.
            top_indices = torch.topk(scores, k=2, dim=-1).indices.cpu().tolist()

        for retrieved_idx in top_indices:
            if retrieved_idx != query_index:
                return (
                    self.train_samples[retrieved_idx],
                    self.train_labels[retrieved_idx],
                    retrieved_idx,
                )
        # topk returns two distinct indices, at most one of which is query_index.
        raise AssertionError("topk returned no index distinct from the query")
        

# build train vectors
def make_retrival_model_and_vectors(retriever_name, path_to_vectors, device, dataset_name):
    """Load a retriever model/tokenizer and its precomputed training vectors.

    We expect the dir to have the following structure:
    - config.json          (with "dataset_name" and "retriever_name" entries)
    - train_samples.json   (with parallel "inputs" and "labels" lists)
    - train_vectors.npy    (one row per entry of "inputs")

    Args:
        retriever_name: Hugging Face model name; must match config.json.
        path_to_vectors: the directory described above (``str`` or ``Path``).
        device: torch device the vectors are moved to.
        dataset_name: dataset name; must match config.json.

    Returns:
        A ``StupidRetriever`` over the loaded vectors, inputs and labels.

    Raises:
        AssertionError: if config.json does not match the requested names, or
            the number of vectors, inputs and labels disagree.
    """
    path_to_vectors = Path(path_to_vectors)  # also accept plain strings

    # Make sure the vectors were produced for this dataset and retriever.
    config = json.loads((path_to_vectors / "config.json").read_text(encoding="utf-8"))
    assert dataset_name == config["dataset_name"], (dataset_name, config["dataset_name"])
    assert retriever_name == config["retriever_name"], (retriever_name, config["retriever_name"])

    retriever_model = transformers.AutoModel.from_pretrained(retriever_name)
    retriever_tokenizer = transformers.AutoTokenizer.from_pretrained(retriever_name)

    # Explicit UTF-8: the review texts are not ASCII and the locale default may differ.
    with open(path_to_vectors / "train_samples.json", encoding="utf-8") as f:
        train_samples = json.load(f)
    # from_numpy avoids an extra host copy before moving to the device.
    vectors = torch.from_numpy(np.load(path_to_vectors / "train_vectors.npy")).to(device)

    # The retriever returns inputs[i] / labels[i] for vector row i, so the files
    # must be aligned; a mismatch would silently return wrong neighbours.
    num_inputs = len(train_samples["inputs"])
    num_labels = len(train_samples["labels"])
    assert vectors.shape[0] == num_inputs == num_labels, (vectors.shape[0], num_inputs, num_labels)

    retriever = StupidRetriever(
        model=retriever_model,
        tokenizer=retriever_tokenizer,
        device=device,
        train_vectors=vectors,
        train_samples=train_samples["inputs"],
        train_labels=train_samples["labels"],
    )
    return retriever

In [3]:
@dataclass(order=True)
class PrioritizedItem:
    """Priority-queue entry ordered only by ``priority`` (lower values pop first).

    ``item`` is excluded from comparisons so payloads (dicts of tensors) never
    need to be orderable.
    """
    priority: float  # negated per-sample loss, see BoostingIterator.push_score
    item: Any = field(compare=False)


class BoostingIterator(torch.utils.data.IterableDataset):
    """Training stream mixing dataset rows with retrieved neighbours of hard samples.

    Every sample the trainer scores is pushed back via ``push_score``. On each
    ``__next__``, with probability ``epsilon`` (when the queue is non-empty) the
    highest-loss queued sample is popped and its nearest neighbour is fetched
    from ``retriever_client``; otherwise the next dataset row is returned.
    """

    def __init__(
        self, *args, dataset, retriever_client, classifier, epsilon, seed, classification_device, classification_tokenizer, retriever_device, **kwargs):
        """
        Args:
            dataset: `datasets` split with "text" and "label" columns.
            retriever_client: object whose ``retrieve(ids, index)`` returns
                ``(text, label, index)``.
            classifier, classification_device, retriever_device: stored; not
                used by this class's methods.
            epsilon: probability of drawing a retrieved sample when the queue
                is non-empty.
            seed: seed for dataset shuffling and the epsilon draws.
            classification_tokenizer: tokenizer applied to every emitted text.
        """
        super().__init__(*args, **kwargs)
        # Record each row's original position so the retriever can look it up.
        self.dataset = dataset.map(lambda example, idx: {"index": idx}, with_indices=True).shuffle(seed=seed)
        self.priority_queue = queue.PriorityQueue()
        self.retriever_client = retriever_client
        self.epsilon = epsilon
        self.randomizer = np.random.RandomState(seed)
        self.seed = seed
        self.dataset_iter = None
        self.classifier = classifier
        self.classification_tokenizer = classification_tokenizer
        self.classification_device = classification_device
        self.retriever_device = retriever_device

    def push_score(self, inputs, loss):
        """Queue every sample of a batch so that the highest-loss samples pop first.

        Args:
            inputs: batch mapping with "input_ids", "attention_mask" and "index"
                entries, each holding one element per sample.
            loss: per-sample losses, one 0-d tensor per sample.

        Raises:
            ValueError: if the batch fields and ``loss`` differ in length.
        """
        columns = (inputs["input_ids"], inputs["attention_mask"], loss, inputs["index"])
        lengths = {len(column) for column in columns}
        if len(lengths) != 1:
            raise ValueError(f"push_score got batch fields of unequal lengths: {sorted(lengths)}")

        for input_, mask, loss_, index in zip(*columns):
            assert loss_.shape == torch.Size([]), loss_.shape
            # queue.PriorityQueue pops the *lowest* priority first, so the loss is
            # negated to make the hardest (highest-loss) samples come out first.
            # Every trained sample is pushed but pops only happen on the epsilon
            # branch, so entries accumulate: keep them on the CPU.
            self.priority_queue.put(PrioritizedItem(
                -float(loss_.detach()),
                dict(
                    input_ids=input_.detach().cpu(),
                    attention_mask=mask.detach().cpu(),
                    index=int(index),
                ),
            ))

    def __len__(self):
        # Number of dataset rows. An epoch yields more items than this, because
        # retrieved neighbours are emitted without consuming dataset rows.
        return len(self.dataset)

    def __iter__(self):
        rich.print("[bold green]ITER[/]")
        # Reshuffle with the stored seed at the start of every iteration.
        self.dataset = self.dataset.shuffle(seed=self.seed)
        self.dataset_iter = iter(self.dataset)
        return self

    def __next__(self):
        """Return the next tokenized sample (retrieved neighbour or dataset row).

        Raises:
            StopIteration: when the underlying dataset iterator is exhausted.
        """
        rand = self.randomizer.rand()
        if not self.priority_queue.empty() and rand < self.epsilon:
            # Pop the hardest queued sample and fetch its nearest neighbour.
            sample = self.priority_queue.get().item
            input_, next_label, index = self.retriever_client.retrieve(sample["input_ids"], sample["index"])
            next_sample = dict(text=input_, label=next_label, index=index)
        else:
            next_sample = next(self.dataset_iter)  # raises StopIteration when the dataset is exhausted

        tokenized = self.classification_tokenizer(
            next_sample["text"].strip(),
            truncation=True,
            padding=True,
        )

        del next_sample["text"]
        return dict(**tokenized, **next_sample)

class BoostingTrainer(transformers.Trainer):
    """``transformers.Trainer`` that feeds per-sample losses back to the train stream.

    After every step, the per-sample cross-entropy of the batch is passed to
    ``self.train_dataset.push_score`` (assumes the train dataset exposes that
    method, e.g. a ``BoostingIterator``).
    """

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs, then push per-sample losses
        to the training dataset's priority queue.
        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model, plus an "index" entry with
                each sample's dataset position. The dictionary (minus "index")
                will be unpacked before being fed to the model. Most models
                expect the targets under the argument `labels`.
        Return:
            `torch.Tensor`: The (detached) training loss on this batch, scaled by
            gradient accumulation as in the stock Trainer.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        # The model / compute_loss does not accept "index"; keep it aside.
        index = inputs.pop("index")

        with self.autocast_smart_context_manager():
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # Batch-mean loss before gradient-accumulation scaling; the sanity check
        # below compares against this, not against the scaled value.
        unscaled_loss = loss.detach()

        if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
            # Deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
            loss = loss / self.args.gradient_accumulation_steps
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with torch.cuda.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            # loss gets scaled under gradient_accumulation_steps in deepspeed
            loss = self.deepspeed.backward(loss)
        else:
            loss.backward()

        loss = loss.detach()

        # RetroBoost: check that our cross-entropy matches the model's loss,
        # then push the per-sample losses to the priority queue.
        inputs["index"] = index
        logits = outputs.logits.detach()
        labels = inputs["labels"].detach()
        computed_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="mean")
        loss_per_sample = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
        assert torch.allclose(unscaled_loss, computed_loss, atol=0.01), (unscaled_loss, computed_loss)
        # Use the stored dataset directly: get_train_dataloader() would build a new
        # DataLoader every step and may wrap the dataset (e.g. sharding).
        self.train_dataset.push_score(inputs, loss_per_sample)

        return loss




In [4]:
RETRIEVER_NAME = "facebook/contriever"
PATH_TO_VECTORS = Path("./vectors_imdb_contriever/")
DATASET_NAME = "imdb"
CLASSIFIER_NAME = "roberta-base"

# Load the dataset once; it is reused for the label check, tokenization and
# the boosting iterator.
dataset = datasets.load_dataset(DATASET_NAME)
ALL_LABELS = set(dataset["train"]["label"])
NUM_LABELS = len(ALL_LABELS)
# The classification head is built with NUM_LABELS outputs, so labels must be 0..NUM_LABELS-1.
assert ALL_LABELS == set(range(NUM_LABELS)), ALL_LABELS
retriever = make_retrival_model_and_vectors(RETRIEVER_NAME, PATH_TO_VECTORS, 0, DATASET_NAME)


classifier_name = CLASSIFIER_NAME
dataset_name = DATASET_NAME
regular_trainer = False
classifier_device = 1
retriever_client = retriever
# NOTE(review): the retriever above was built on device 0; this value is not
# used to place it — confirm which device is intended.
retriever_device = 2

classifier = transformers.AutoModelForSequenceClassification.from_pretrained(
    classifier_name, num_labels=NUM_LABELS
)
classifier_tokenizer = transformers.AutoTokenizer.from_pretrained(classifier_name)


def preprocess_function(examples, tokenizer):
    """Tokenize the "text" field of a batch of examples, with truncation and padding."""
    return tokenizer(examples["text"], truncation=True, padding=True)


tokenized_training = dataset["train"].map(
    lambda examples: preprocess_function(examples, classifier_tokenizer),
    batched=True,
)

tokenized_validation = dataset["test"].map(
    lambda examples: preprocess_function(examples, classifier_tokenizer),
    batched=True,
)

training_args = transformers.TrainingArguments(
    output_dir="./results",
    learning_rate=1e-5,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
    num_train_epochs=5,
    weight_decay=0.01,
)

Reusing dataset imdb (/home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

  0%|          | 0/25 [00:00<?, ?ba/s]

In [5]:

# Build either a vanilla Trainer on the pre-tokenized split, or a BoostingTrainer
# fed by the retrieval-augmented BoostingIterator. The iterator gets its own
# name so `tokenized_training` keeps referring to the tokenized split even after
# this cell has run (re-running it with `regular_trainer = True` stays correct).
if regular_trainer:
    trainer_class = transformers.Trainer
    train_dataset = tokenized_training
else:
    trainer_class = BoostingTrainer
    boosting_training = BoostingIterator(
        dataset=dataset["train"],
        retriever_client=retriever_client,
        classifier=classifier,
        epsilon=0.5,  # chance of emitting a retrieved neighbour when the queue is non-empty
        seed=0,
        retriever_device=retriever_device,
        classification_device=classifier_device,
        classification_tokenizer=classifier_tokenizer,
    )
    train_dataset = boosting_training

trainer = trainer_class(
    model=classifier.to(classifier_device),
    args=training_args,
    tokenizer=classifier_tokenizer,
    train_dataset=train_dataset,
    eval_dataset=tokenized_validation,
    data_collator=transformers.DataCollatorWithPadding(
        tokenizer=classifier_tokenizer
    ),
)

Loading cached processed dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-a645737e8e3ed793.arrow
Loading cached shuffled indices for dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-42092c50c935d800.arrow


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [6]:
# Run the training loop of the trainer built in the previous cell.
trainer.train()

***** Running training *****
  Num examples = 25000
  Num Epochs = 5
  Instantaneous batch size per device = 20
  Total train batch size (w. parallel, distributed & accumulation) = 60
  Gradient Accumulation steps = 1
  Total optimization steps = 6250
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjulesgm[0m (use `wandb login --relogin` to force relogin)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Loading cached shuffled indices for dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-4dc2830ea21e6a12.arrow


Step,Training Loss


KeyboardInterrupt: 