The first steps of our code is to install the necessary dependencies for our code, mount our drive, and set our device to GPU if available, else CPU.

In [None]:
%pip install faiss-gpu
%pip install datasets

In [2]:
from transformers import pipeline, set_seed, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel
from datasets import load_dataset
import faiss, torch
import numpy as np
from tqdm import tqdm

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# GPT2


We now load in our GPT model and tokenizer. We further define two fucntions below:

1. `get_logits`: which given a model, a tokenizer, and a set of text, will generate the models logits for next token prediction.
2. `get_probs`: given a set of logits, it will return the corresponding probability distribution across the logits by taking the softmax.

In [None]:
LM = GPT2LMHeadModel.from_pretrained("gpt2").eval()
LM_tokenizer = GPT2Tokenizer.from_pretrained("gpt2", padding_side = "left")
if LM_tokenizer.pad_token is None:
  LM_tokenizer.pad_token = LM_tokenizer.eos_token

In [6]:
# Returns the logits for a given model
def get_logits(text):
  input_ids = LM_tokenizer(text, return_tensors="pt").input_ids

  gen_tokens = LM.generate(
      input_ids,
      do_sample=True,
      temperature=0.9,
      max_length=100,
      return_dict_in_generate=True,
      output_logits=True,
      pad_token_id=LM_tokenizer.pad_token_id
  ).to_tuple()

  return gen_tokens[1]

In [7]:
# Returns the probability of every vocab word at each of the sequence positions
def get_probs(logits_tuple):
  out_pad = logits_tuple[0]
  out_pad_probs = out_pad.softmax(dim = -1)
  return out_pad_probs

# Contriever



We now load in our retriever model and tokenizer. As suggested in REPLUG, we use Facebook's Contriever model. This is a dense information retrieval model, and returns embeddings for given textual inputs

In [None]:
retriever_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
retriever = AutoModel.from_pretrained('facebook/contriever').to(device)

In [9]:
retriever.save_pretrained("/content/drive/MyDrive/retriever")

# Data (C4)

We also load in our dataset. While the original REPLUG paper uses The Pile dataset, this has since been taken down due to copyright issues. As such, we utilized AllenAI's C4 dataset. Furthermore, the original paper uses a datastore of 36M documents of 128 tokens, and a train dataset of 800K sequences of length 256 tokens. In comparison, we use 1000 sequences for our datastore and 500 sequences for training. This was due to memory constraints, we could not have larger datasets.

In [None]:
# Generate document dataset. In paper, 36M sequences of length 128 tokens.
en = load_dataset("allenai/c4", data_files="en/c4-train.0000[1]-of-01024.json.gz")["train"].select(range(1000))

Below, we define two functions:
1. `proc_sample`: Takes in a sample text, and retruns the retriever embeddings for that given piece of text.
2. `make_ds`: Takes in a dataset, and makes a FAISS datastore.

In [None]:
max_token_length = 128
def proc_sample(sample):
    out = retriever_tokenizer(sample["text"], return_tensors="pt", truncation=True, max_length=max_token_length, padding="max_length")
    input_ids = out.input_ids.to(device)
    attention_mask = out.attention_mask.to(device)
    a = (retriever(input_ids, attention_mask=attention_mask).last_hidden_state)
    a[attention_mask==0] = 0.

    return {'embeddings':a.mean(dim=-2)}

def make_ds(en):
  data_with_embeddings = en.map(proc_sample, batched=True, batch_size=75)
  data_with_embeddings.set_format("pt", columns=["embeddings"], output_all_columns=True)
  data_with_embeddings.add_faiss_index(column='embeddings', device=0)
  return data_with_embeddings

data_with_embeddings = make_ds(en)

The below function, `get_top_responses`, takes in a set of queries, a datastore, a model, that model's tokenizer, and optionally a max token length and k value. With this, it returns the top-k retrieved pieces of context from the datastore, along with the similarity scores associated.

In [28]:
# Get query embedding
def get_top_responses(queries, datastore, model, tokenizer, max_token_length=128, k=20):
  question_toks = tokenizer(queries, return_tensors="pt", truncation=True, max_length=max_token_length, padding="max_length")
  input_ids = question_toks.input_ids.to(device)
  attention_mask = question_toks.attention_mask.to(device)
  a = (model(input_ids, attention_mask=attention_mask).last_hidden_state)
  a[attention_mask==0] = 0.
  a = a.mean(dim=-2).detach().cpu().numpy()
  scores, retrieved_examples = datastore.get_nearest_examples_batch('embeddings', a, k=k)
  return retrieved_examples, scores

Below, we load in our training data.

In [None]:
train_test = load_dataset("allenai/c4", data_files="en/c4-train.0010[1]-of-01024.json.gz")["train"]
train_seqs = train_test.select(range(5000))

In [14]:
test_seqs = train_test.select(range(50000, 55000))

# Putting Together

In this section, we define three functions that link together much of what we have defined above:

1. `prob_gen`: Given a piece of context, a piece of text, a model and its tokenizer, returns the probabilities for LM's next token generation. Input to LM is context concatenated with the text itself.
2. `get_LLM_outputs`: Given a query, datastore and k value, will return the ensemble probabilities across all top-k contexts appended to the query.
3. `generate_text`: Given probabilities and a model tokenizer, returns the next token for generation.

In [15]:
def prob_gen(context, text):
  out = LM_tokenizer(context, truncation=True, max_length=128, return_tensors='pt').input_ids.to(device)
  t = LM_tokenizer(text, truncation=True, max_length=256, return_tensors="pt").input_ids.to(device)
  tot = torch.cat((out, t), dim=-1)
  with torch.no_grad():
    outputs = LM(tot)
    logits = outputs.logits[0, len(out) + 1:]
  log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
  return torch.exp(-log_probs.gather(-1, t).mean()).item()

In [16]:
# Set K to be the amount of repsonses to include
def get_LLM_outputs(query, ds, k = 1):
  top_responses, scores = get_top_responses(query, ds, k = k)

  concat = lambda context: "Context: " + context + "Query: " + query

  llm_logits = get_logits(concat(top_responses[0]))
  llm_probs = get_probs(llm_logits)

  for response in top_responses[1:]:
    llm_logit = get_logits(concat(response))
    llm_prob = get_probs(llm_logit)
    llm_probs += llm_prob

  llm_probs /= k
  return llm_probs, scores

In [17]:
def generate_text(probs):
  text_indices = torch.argmax(probs, dim = 1).item()
  out_text = LM_tokenizer.convert_ids_to_tokens(text_indices)
  if out_text[0] == 'Ġ':
    out_text = out_text[1:]
  return out_text

# Fine-Tuning

Below we provide the code for fine-tuning. Sadly, we were unable to get past the memory constraints of Colab (without buying Colab pro), but through very small-scale tests, we believe our code to be correct. Individuals with enough memory on their Colab should be able to run this example with no issues.

We define the following functions to do so:

1. `loss`: Given a score computed in our `compute_score` function and LM generation probabilites, returns the loss as defined in the original paper.
2. `compute_score`: Given a query and doc, will return the "score" denoting their similarity. In the case of our paper, this is just the cosine similarity of the embeddings.
3. `train`: This possesses the main training loop of our fine-tuning. Our function takes in our model and tokenizer, as well as retriever and tokenizer and fine-tunes our retriever.


*Note: We explicitly hardcode in the hyperparameters specified in the paper. If you would like to change these, please do so in the `loss` and `train` functions.*

In [18]:
def loss(scores, probs, gamma=0.8, beta=0.8):
    num_contexts = len(probs)
    P = torch.nn.functional.softmax(scores/gamma, dim=-1)  # retrieval likelihood
    Q = torch.nn.functional.softmax(probs/beta, dim=-1)  # LM likelihood
    return torch.mean(Q * (torch.log(Q) - torch.log(P)))

In [19]:
def compute_score(query, doc):
  out = retriever(retriever_tokenizer(query, return_tensors="pt", truncation=True, max_length=256).input_ids.to(device)).last_hidden_state[0].mean(dim=-2)
  doc_out = retriever(retriever_tokenizer(doc, return_tensors="pt", truncation=True, max_length=128).input_ids.to(device)).last_hidden_state[0].mean(dim=-2)
  return torch.nn.functional.cosine_similarity(out,doc_out, dim=-1)

In [20]:
def train():
  retriever.train()
  num_epochs = 20
  recomp_epoch = 3000
  batch_size = 64
  optimizer = torch.optim.SGD(retriever.parameters(), lr = 2e-5)
  warmup_ratio = 0.1
  scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, 1./(warmup_ratio*num_epochs), 1.0)
  # do warmup ratio
  train_size = len(train_seqs)
  losses = []
  for epoch in range(num_epochs):
    for i in range(train_size//batch_size - 1):
      i = np.random.randint(0, train_size//batch_size)
      batch_size = len(range(i, max(i+batch_size, train_size)))
      batch = train_seqs.select(range(i, max(i+batch_size, train_size)))["text"]
      optimizer.zero_grad()
      batch_loss = torch.tensor(0.).to(device)
      for j in range(batch_size):
        examples, scores = get_top_responses(batch[j], data_with_embeddings)
        examples = examples[0]
        txt = batch[j]
        scores = torch.zeros(20)
        llm_probs = torch.zeros_like(scores)
        for j, doc in enumerate(examples["text"]):
          document = doc
          context = f"Context: {document} "
          llm_probs[j] = prob_gen(context, "<|endoftext|> "+txt)
          scores[j] = compute_score(txt, document)
        batch_loss += loss(scores, llm_probs)
      batch_loss.backward()
      scheduler.step()
      losses.append(batch_loss.item())
    if epoch % 10 == 0:
      print(losses[-1])
    if epoch % recomp_epoch == 0 and epoch:
      retriever.save_pretrained(f"/content/drive/MyDrive/retriever_{epoch}")
      ds = make_ds(ds)

In [21]:
# LM.to(device)
# train()

# Evaluation

Below, we provide the code to evaluate the performance of our models. We provide four functions to aid us:

1. `eval_batch`: We take in context and text and return the LM next token generation logits for the context appended to our text.
2. `eval_gpt`: We take text and return the LM next token generation logits for the given text.
3. `bpb_gpt_only`: Takes in a set of sequences, and calculates the bits per byte for GPT.
4. `bpb_replug`:  Takes in a set of sequences, and calculates the bits per byte for GPT + REPLUG (evaluating with whichever retriever is loaded at a given moment).

In [22]:
torch.cuda.empty_cache()

In [23]:
def eval_batch(context, text):
  out = LM_tokenizer(context, truncation=True, max_length=128, return_tensors='pt').input_ids.to(device)
  t = LM_tokenizer(text, return_tensors="pt").input_ids.to(device)
  tot = torch.cat((out, t), dim=-1)
  with torch.no_grad():
    outputs = LM(tot)
    logits = outputs.logits[0, len(out[0] + 2):]
  return logits

In [24]:
def eval_gpt(text):
  t = LM_tokenizer(text, return_tensors="pt").input_ids.to(device)
  with torch.no_grad():
    outputs = LM(t, labels=t).loss
  return outputs.item(), t.shape[-1]

In [25]:
def bpb_replug(seqs, ds = None):
  batch_size = 16
  with torch.no_grad():
    a = torch.tensor(0., device=device)
    cnt=0
    num_toks = 0
    num_utf_bytes = 0
    for i in tqdm(range(0, 5000, batch_size)):
      if i+batch_size > 500:
        break
      inp = seqs.select(range(i,i+batch_size))
      examples, scores = get_top_responses(inp["text"], data_with_embeddings, retriever, retriever_tokenizer)
      scores = torch.nn.functional.softmax(torch.tensor(scores, device=device), dim=-1)
      text = inp['text']
      for j,ex in enumerate(examples):
        txt = text[j]
        txt = " ".join(txt.split(" ")[:256])
        logits = eval_batch(ex["text"][0]+ " ", "<|endoftext|>  "+txt) *scores[j][0]
        input_toks = LM_tokenizer("<|endoftext|>  "+txt, return_tensors="pt").input_ids.to(device)
        for k, doc in enumerate(ex["text"][1:]):
          logits += eval_batch(doc+ " ", "<|endoftext|>  "+txt) *scores[j][k]
        shift_toks = input_toks[:, 1:].squeeze(0)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)[:-1]
        a+= torch.nn.functional.nll_loss(log_probs, shift_toks, reduction="mean").item()
        num_toks += shift_toks.shape[-1]
        num_utf_bytes += len(txt.encode('utf-8'))
        cnt+= 1
    return num_toks/num_utf_bytes * a * np.log2(np.e)/cnt, num_toks/num_utf_bytes


In [26]:
def bpb_only_gpt(seqs, ds = None):
  batch_size = 16
  with torch.no_grad():
    a = torch.tensor(0., device=device)
    cnt=0
    num_toks = num_utf_bytes = 0
    for i in tqdm(range(0, 5000, batch_size)):
      if i+batch_size > len(seqs): break
      inp = seqs.select(range(i,i+batch_size))
      text = inp['text']
      for j in range(batch_size):
        txt = text[j]
        txt = " ".join(txt.split(" ")[:256])
        l, s = eval_gpt(txt)
        a+=l
        num_toks += s
        num_utf_bytes += len(txt.encode('utf-8'))
        cnt+=1

    return num_toks/num_utf_bytes * a * np.log2(np.e)/cnt, num_toks/num_utf_bytes

Below we actually evaluate our model's performances.

In [32]:
LM.to(device)

print("Just GPT")
final_bpb_gpt = bpb_only_gpt(test_seqs)
print("Final BPB: ", final_bpb_gpt[0].item())

print("\nRetriever")
retriever = AutoModel.from_pretrained("/content/drive/MyDrive/retriever").to(device) # REPLACE THIS PATH WITH THE RETRIEVER YOU WANT TO EVALUATE
final_bpb_replug = bpb_replug(test_seqs)
print("\nFinal BPB: ", final_bpb_replug[0].item())

Just GPT


100%|█████████▉| 312/313 [02:16<00:00,  2.28it/s]


Final BPB:  1.171903371810913
Retriever


 10%|▉         | 31/313 [05:02<45:51,  9.76s/it]


Final BPB:  1.1895300149917603



