# 1. Retrieval Augmented In-context Learning

- Please choose "Runtime Type" = GPU in Colab for running this notebook (Runtime > Change Runtime Type > T4 GPU).
- You are free to choose to use Google Colab or Kaggle to complete this notebook.

## 1.1 Contextual Embedding

In [None]:
# Step 0. Prepare the environment
!pip install InstructorEmbedding sentence-transformers datasets scikit-learn backoff evaluate

In [None]:
!mkdir -p data/classification
!wget -O data/classification/train.txt https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/train.txt
!wget -O data/classification/dev.txt https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/dev.txt
!wget -O data/classification/test-blind.txt https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/test-blind.txt

In [None]:
# Step 1. Declare the model & Example usage
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("hkunlp/instructor-base")
embeddings = model.encode(
    [
        "Dynamical Scalar Degree of Freedom in Horava-Lifshitz Gravity",
        "Comparison of Atmospheric Neutrino Flux Calculations at Low Energies",
        "Fermion Bags in the Massive Gross-Neveu Model",
        "QCD corrections to Associated t-tbar-H production at the Tevatron",
    ],
    prompt="Represent the Medicine sentence for clustering: ",
    show_progress_bar=True,
)

print(embeddings.shape)

In [None]:
# Step 2. Load data & Extract embeddings
# Encoding may take about two hours on CPU; switch to a GPU runtime for faster inference.
def load_data(file):
    r"""Load your custom data for training or evaluation.

    Args:
        file (str): the path of your data

    Returns:
        tuple: a tuple containing two elements:
           - embedding: the embeddings of your data, of shape (N, 768)
           - label: list of labels, of length N
    """
    with open(file) as f:
        data = [line.strip().split('\t') for line in f]
    sentences = [line[1] for line in data]
    X = model.encode(sentences, prompt="Represent the sentence for sentiment analysis: ", show_progress_bar=True)
    y = [int(line[0]) for line in data]
    return X, y

X_train, y_train = load_data('data/classification/train.txt')
X_val, y_val = load_data('data/classification/dev.txt')

print(f"X_train shape: {X_train.shape}, y_train shape: {len(y_train)}")
print(f"X_val shape: {X_val.shape}, y_val shape: {len(y_val)}")

Next, train a logistic regression model on the extracted embeddings.
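
The sketch below is a dependency-free illustration of what a logistic regression classifier fits; it is not the expected solution for Step 3. The toy 2-D vectors stand in for the 768-dimensional embeddings, and `train_logreg`, `predict`, the learning rate, and the epoch count are hypothetical choices made for this example. In the notebook itself, scikit-learn's `LogisticRegression` (installed in Step 0) would normally be the more practical option.

```python
import math

def sigmoid(z):
    # Numerically safe logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def train_logreg(X, y, lr=0.5, epochs=200):
    """Fit weights and bias by batch gradient descent on the log loss."""
    dim = len(X[0])
    w, b = [0.0] * dim, 0.0
    n = len(X)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for xi, yi in zip(X, y):
            err = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) - yi
            for j in range(dim):
                grad_w[j] += err * xi[j]
            grad_b += err
        w = [wj - lr * g / n for wj, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b

def predict(w, b, X):
    """Label 1 when the predicted probability is at least 0.5, else 0."""
    return [1 if sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) >= 0.5 else 0 for xi in X]

# Toy, linearly separable "embeddings" (assumed data, not from the assignment)
X_toy = [[2.0, 1.0], [1.5, 2.0], [-1.0, -1.5], [-2.0, -0.5]]
y_toy = [1, 1, 0, 0]
w, b = train_logreg(X_toy, y_toy)
accuracy = sum(p == t for p, t in zip(predict(w, b, X_toy), y_toy)) / len(y_toy)
print(accuracy)
```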

In [None]:
# Step 3: Train a logistic regression model



In [None]:
# Step 4: Evaluate the model



## 1.2 Retrieve Relevant Examples
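
This section selects, for each test question, the training questions whose embeddings are most similar to it. As a rough illustration of that idea, the sketch below ranks candidates by cosine similarity using plain Python lists. The toy 2-D vectors stand in for real sentence embeddings, and `cosine` and `top_k` are hypothetical helper names; on full embedding matrices, `sklearn.metrics.pairwise.cosine_similarity` would likely be the more efficient choice.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def top_k(query_vec, candidate_vecs, k):
    """Return indices of the k candidates most similar to the query, best first."""
    scored = [(cosine(query_vec, c), i) for i, c in enumerate(candidate_vecs)]
    # Sort by descending similarity; break ties by the lower index
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [i for _, i in scored[:k]]

# Toy vectors standing in for embeddings of training questions (assumed data)
train_vecs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
test_vec = [0.8, 0.2]
nearest = top_k(test_vec, train_vecs, k=2)
print(nearest)
```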

In [None]:
import re
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

def load_train_data():
    """Loads the GSM8k train dataset.

    Returns:
        A list of dictionaries containing questions, cot answers, and short answers.
    """
    ds = load_dataset("gsm8k", "main", split="train")
    examples = [{"question": example["question"], "answer": example["answer"]} for example in ds]
    for example in examples:
        example["short_answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"].split("####")[1].strip())
        example["cot_answer"] = re.sub(r"\<\<.*?\>\>", "", example["answer"].split("####")[0].strip()) \
            + " So the answer is " + example["short_answer"] + "."
    return examples

def load_test_data():
    """Loads the first 30 examples of the GSM8k test dataset.

    Returns:
        A list of dictionaries containing questions and answers.
    """
    ds = load_dataset("gsm8k", "main", split="test")
    examples = [{"question": example["question"], "answer": example["answer"].split("####")[1].strip()} for example in ds]
    for example in examples:
        example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
    return examples[:30]

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
def retrieve_examples():
    """Retrieve the top-20 in-context examples from the GSM8K training set for each test example.

    Returns:
        A dictionary mapping each test question to a list of its top-20 training examples.
    """
    raise NotImplementedError()

In [None]:
RETRIVED_EXAMPLES = retrieve_examples()

Note: The following retrieval-augmented generation does not require a GPU. Consider saving the examples you retrieve and downloading them from the file tab on the left, so that Colab GPU restrictions do not hold you back later.

In [None]:
with open("retrieved_examples.json", "w") as fout:
    json.dump(RETRIVED_EXAMPLES, fout)

with open("retrieved_examples.json", "r") as fin:
    RETRIVED_EXAMPLES = json.load(fin)

## 1.3 Generation with Huggingface Inference API

We will query an LLM through the Hugging Face Inference API, so the following code does not need a GPU. Please generate an access token at [hf.co/settings/tokens](hf.co/settings/tokens) and set it as the `HF_TOKEN` environment variable.
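
One possible way to avoid sending a placeholder token by accident is to build the request header through a small check. This is only a sketch: `build_auth_headers` and the `DEMO_HF_TOKEN` variable are hypothetical names used for illustration, and the dummy value is not a real token.

```python
import os

def build_auth_headers(env_var="HF_TOKEN"):
    """Build the Authorization header from an environment variable, failing early if it is unset."""
    token = os.environ.get(env_var, "").strip()
    if not token or token == "<YOUR TOKEN>":
        raise RuntimeError(f"Set the {env_var} environment variable to a valid token first.")
    return {"Authorization": f"Bearer {token}"}

# Demo with a dummy value in a separate variable, so a real HF_TOKEN is left untouched
os.environ["DEMO_HF_TOKEN"] = "hf_dummy_token_for_illustration"
headers = build_auth_headers("DEMO_HF_TOKEN")
print(headers["Authorization"].startswith("Bearer "))
```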

In [None]:
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import json
import backoff
import evaluate
import re
import time
import requests


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_TOKEN"] = "<YOUR TOKEN>"

from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

class LLM(object):
    def __init__(self, model_name="codellama/CodeLlama-7b-hf"):
        self.model_name = model_name
        self.api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        self.headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

    @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=60)
    def generate(self, prompts: List[str], **kwargs) -> List[str]:
        outputs = []

        def query(payload):
            response = requests.post(self.api_url, headers=self.headers, json=payload)
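            if response.status_code != 200:
                # Raise a RequestException subclass so that the backoff decorator retries the call
                raise requests.exceptions.HTTPError(f"Request failed with code {response.status_code}, {response.text}")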
            return response.json()

        for prompt in prompts:
            data = query(
                {
                    "inputs": prompt,
                    "parameters": { "max_new_tokens": 256, "stop": ["Question:"]},
                }
            )

            outputs.append(data[0]['generated_text'])

        return outputs

In [None]:
llm = LLM("codellama/CodeLlama-7b-hf")
print(llm.generate(["Explain the importance of low latency LLMs", "What is the capital of France?"]))

Please adapt your GSM8KCoTEvaluator to this API-based LLM, and report the performance of 8-shot chain-of-thought prompting on the first 30 examples of GSM8K.
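
As a starting point, few-shot chain-of-thought prompting and answer extraction could look roughly like the sketch below. The `Question:` / `Answer:` layout is an assumption chosen to match the `"stop": ["Question:"]` parameter used above, and `build_cot_prompt` and `extract_answer` are hypothetical helpers, not part of the provided classes. If the API returns the prompt together with the completion, the prompt may need to be stripped before extraction.

```python
import re

def build_cot_prompt(demos, question, n_shot):
    """Concatenate the first n_shot demonstrations, then the unanswered test question."""
    blocks = [f"Question: {d['question']}\nAnswer: {d['cot_answer']}" for d in demos[:n_shot]]
    blocks.append(f"Question: {question}\nAnswer:")
    return "\n\n".join(blocks)

def extract_answer(completion):
    """Take the number after the first 'answer is', dropping thousands separators."""
    match = re.search(r"answer is\s*\$?(-?[\d,]*\.?\d+)", completion)
    return match.group(1).replace(",", "") if match else ""

# Toy demonstration and model output (assumed text, not real API responses)
demos = [{"question": "If there are 3 cars and 2 more arrive, how many cars are there?",
          "cot_answer": "There are originally 3 cars. 2 more arrive. 3 + 2 = 5. So the answer is 5."}]
prompt = build_cot_prompt(demos, "Tom has 4 pens and buys 6 more. How many pens does he have?", n_shot=1)
answer = extract_answer(" Tom starts with 4 pens. 4 + 6 = 10. So the answer is 10.\n\nQuestion:")
print(prompt)
print(answer)
```

Accuracy can then be computed by comparing each extracted string with the `answer` field of the corresponding test example.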

In [None]:
class Evaluator(ABC):
    def __init__(self, llm):
        self.llm = llm

    @abstractmethod
    def load_data(self):
        pass

    @abstractmethod
    def build_prompts(self, dataset, n_shot):
        pass

    @abstractmethod
    def postprocess_output(self, output: str) -> str:
        pass

    def generate_completions(self, prompts: List[str], batch_size=2) -> List[str]:
        pass

    def evaluate(self, batch_size=4, n_shot=8, save_dir="outputs"):
        dataset = self.load_data()
        prompts = self.build_prompts(dataset, n_shot)
        outputs = self.generate_completions(prompts, batch_size=batch_size)

        predictions = []
        for i, (example, prompt, output) in enumerate(zip(dataset, prompts, outputs)):
            prediction = {
                "task_id": example.get("task_id", f"task_{i}"),
                "prompt": prompt,
                "completion": output,
                "prediction": self.postprocess_output(output),
            }
            predictions.append(prediction)

        # Save predictions to file
        os.makedirs(save_dir, exist_ok=True)
        prediction_save_path = os.path.join(save_dir, f"{type(self).__name__}_predictions.jsonl")
        with open(prediction_save_path, "w") as fout:
            for pred in predictions:
                fout.write(json.dumps(pred) + "\n")

        # Calculate metrics and print results
        results = self.calculate_metrics(predictions, dataset)
        print(f"Results for {type(self).__name__}: {results}")

    @abstractmethod
    def calculate_metrics(self, predictions, dataset):
        pass

GSM_EXAMPLARS = [
    {
        "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
        "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.",
        "pot_answer": "def solution():\n    \"\"\"There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\"\"\"\n    trees_initial = 15\n    trees_after = 21\n    trees_added = trees_after - trees_initial\n    result = trees_added\n    return result",
        "short_answer": "6"
    },
    {
        "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.",
        "pot_answer": "def solution():\n    \"\"\"If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\"\"\"\n    cars_initial = 3\n    cars_arrived = 2\n    total_cars = cars_initial + cars_arrived\n    result = total_cars\n    return result",
        "short_answer": "5"
    },
    {
        "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
        "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.",
        "pot_answer": "def solution():\n    \"\"\"Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\"\"\"\n    leah_chocolates = 32\n    sister_chocolates = 42\n    total_chocolates = leah_chocolates + sister_chocolates\n    chocolates_eaten = 35\n    chocolates_left = total_chocolates - chocolates_eaten\n    result = chocolates_left\n    return result",
        "short_answer": "39"
    },
    {
        "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
        "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.",
        "pot_answer": "def solution():\n    \"\"\"Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\"\"\"\n    jason_lollipops_initial = 20\n    jason_lollipops_after = 12\n    denny_lollipops = jason_lollipops_initial - jason_lollipops_after\n    result = denny_lollipops\n    return result",
        "short_answer": "8"
    },
    {
        "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
        "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.",
        "pot_answer": "def solution():\n    \"\"\"Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\"\"\"\n    toys_initial = 5\n    mom_toys = 2\n    dad_toys = 2\n    total_received = mom_toys + dad_toys\n    total_toys = toys_initial + total_received\n    result = total_toys\n    return result",
        "short_answer": "9"
    },
    {
        "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
        "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.",
        "pot_answer": "def solution():\n    \"\"\"There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\"\"\"\n    computers_initial = 9\n    computers_per_day = 5\n    num_days = 4\n    computers_added = computers_per_day * num_days\n    computers_total = computers_initial + computers_added\n    result = computers_total\n    return result",
        "short_answer": "29"
    },
    {
        "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
        "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.",
        "pot_answer": "def solution():\n    \"\"\"Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\"\"\"\n    golf_balls_initial = 58\n    golf_balls_lost_tuesday = 23\n    golf_balls_lost_wednesday = 2\n    golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n    result = golf_balls_left\n    return result",
        "short_answer": "33"
    },
    {
        "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
        "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. So the answer is 8.",
        "pot_answer": "def solution():\n    \"\"\"Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\"\"\"\n    money_initial = 23\n    bagels = 5\n    bagel_cost = 3\n    money_spent = bagels * bagel_cost\n    money_left = money_initial - money_spent\n    result = money_left\n    return result",
        "short_answer": "8"
    }
]

In [None]:
class GSM8KEvaluator(Evaluator):
    def load_data(self):
        ds = load_dataset("gsm8k", "main", split="test")
        examples = [{"question": example["question"], "answer": example["answer"].split("####")[1].strip()} for example in ds]
        for example in examples:
            example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
        return examples[:30]

    def build_prompts(self, dataset, n_shot=8, demos=GSM_EXAMPLARS):
        pass

    def postprocess_output(self, output: str) -> str:
        pass

    def calculate_metrics(self, predictions, dataset):
        pass

In [None]:
class GSM8KCoTEvaluator(GSM8KEvaluator):
    def build_prompts(self, dataset, n_shot=8, demos=GSM_EXAMPLARS):
        pass

In [None]:
cot_evaluator = GSM8KCoTEvaluator(llm)
cot_evaluator.evaluate(n_shot=8)

## 1.4 Impact of Quantity on Few-shot Prompting

In [None]:
cot_evaluator.evaluate(n_shot=1)
cot_evaluator.evaluate(n_shot=2)
cot_evaluator.evaluate(n_shot=4)

## 1.5 Retrieval Augmented Few-shot Prompting
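
When building prompts from retrieved examples, both the number of demonstrations and their order are open choices. The sketch below shows one way to select the `n_shot` best matches and optionally place the most relevant one last, closest to the test question. It assumes each retrieved list is stored best match first; `order_demos` and the placeholder example names are hypothetical.

```python
def order_demos(retrieved, n_shot, most_relevant_last=True):
    """Pick the n_shot most relevant examples and choose their order in the prompt.

    Assumes `retrieved` is sorted from most to least relevant.
    """
    chosen = retrieved[:n_shot]
    return chosen[::-1] if most_relevant_last else chosen

# Hypothetical retrieved list for one test question, best match first
retrieved = ["ex_best", "ex_second", "ex_third", "ex_fourth"]
descending = order_demos(retrieved, n_shot=3, most_relevant_last=False)
ascending = order_demos(retrieved, n_shot=3)
print(descending)
print(ascending)
```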

In [None]:
class GSM8KRetrievalICLEvaluator(GSM8KEvaluator):
    def build_prompts(self, dataset, n_shot=8, demos=GSM_EXAMPLARS):
        """Build prompts with RETRIVED_EXAMPLES we generated in 1.2.
        """

In [None]:
retrieval_icl_evaluator = GSM8KRetrievalICLEvaluator(llm)
retrieval_icl_evaluator.evaluate(n_shot=1)
retrieval_icl_evaluator.evaluate(n_shot=2)
retrieval_icl_evaluator.evaluate(n_shot=4)
retrieval_icl_evaluator.evaluate(n_shot=8)

In [None]:
# Rearrange the examples in ascending order of relevance.

## Discussion

- Q1.1: Regarding the contextual embedding in Section 1.1, how does sentiment classification performance compare to your word2vec results in A1? Discuss potential reasons.

- Q1.2: In Section 1.4, which discusses the impact of quantity, what trends do you notice when adding contextual examples? Do you think this trend will continue?

- Q1.3: In Section 1.5, which covers retrieval-augmented in-context learning, how does this differ from Section 1.4? Analyze the reasons.

- Q1.4: In Section 1.5, which arrangement yields better performance: in-context examples organized in descending or ascending order of relevance? Discuss the scenario.

# 2. Basic Decoding Algorithms

In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate

In [None]:
"""set device and random seeds"""

######################################################
#  The following helper functions are given to you.
######################################################

from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

def set_seed(seed=19260817):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

In [None]:
"""load datasets"""

######################################################
#  The following helper code is given to you.
######################################################

from datasets import load_dataset

dataset = load_dataset('Ximing/ROCStories')
train_data, dev_data, test_data = dataset['train'], dataset['validation'], dataset['test']

print(train_data[0])

In [None]:
"""prepare evaluation"""

######################################################
#  The following helper code is given to you.
######################################################

from evaluate import load
from transformers import RobertaForSequenceClassification, RobertaTokenizer

perplexity_scorer = load("perplexity", module_type="metric")
cola_model_name = "textattack/roberta-base-CoLA"
cola_tokenizer = RobertaTokenizer.from_pretrained(cola_model_name)
cola_model = RobertaForSequenceClassification.from_pretrained(cola_model_name).to(device)

def batchify(data, batch_size):
    assert batch_size > 0

    batch = []
    for item in data:
        # Yield next batch
        if len(batch) == batch_size:
            yield batch
            batch = []

        batch.append(item)

    # Yield last un-filled batch
    if len(batch) != 0:
        yield batch

In [None]:
"""set up evaluation metric"""

######################################################
#  The following helper code is given to you.
######################################################

def compute_perplexity(texts, model='gpt2', batch_size=8):
    score = perplexity_scorer.compute(predictions=texts, add_start_token=True, batch_size=batch_size, model_id=model)
    return score['mean_perplexity']


def compute_fluency(texts, batch_size=8):
  scores = []
  for b_texts in batchify(texts, batch_size):
    inputs = cola_tokenizer(b_texts, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
      logits = cola_model(**inputs).logits
      probs = logits.softmax(dim=-1)
      scores.extend(probs[:, 1].tolist())
  return sum(scores) / len(scores)


def compute_diversity(texts):
    unigrams, bigrams, trigrams = [], [], []
    total_words = 0
    for gen in texts:
        o = gen.split(' ')
        total_words += len(o)
        for i in range(len(o)):
            unigrams.append(o[i])
        for i in range(len(o) - 1):
            bigrams.append(o[i] + '_' + o[i + 1])
        for i in range(len(o) - 2):
            trigrams.append(o[i] + '_' + o[i + 1] + '_' + o[i + 2])
    return len(set(unigrams)) / len(unigrams), len(set(bigrams)) / len(bigrams), len(set(trigrams)) / len(trigrams)


def evaluate(generations, experiment):
  generations = [_ for _ in generations if _ != '']
  perplexity = compute_perplexity(generations)
  fluency = compute_fluency(generations)
  diversity = compute_diversity(generations)
  print(experiment)
  print(f'perplexity = {perplexity:.2f}')
  print(f'fluency = {fluency:.2f}')
  print(f'diversity = {diversity[0]:.2f}, {diversity[1]:.2f}, {diversity[2]:.2f}')
  print()

debug_sents = ["This restaurant is awesome", "My dog is cute and I love it.", "Today is sunny."]
evaluate(debug_sents, 'debugging run')

In [None]:
"""load model and tokenizer"""

######################################################
#  The following helper code is given to you.
######################################################

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
model.eval()

In this section, you will implement a few basic decoding algorithms:
1. Greedy decoding
2. Vanilla sampling
3. Temperature sampling
4. Top-p sampling

We have provided a wrapper function `decode()` that takes care of batching, controlling max length, and handling the EOS token.
You will be asked to implement the core function of each method: *given the pre-softmax logits of the next token, decide what the next token is.*

**The wrapper calls the core function of each decoding algorithm, which you will implement in the subsections below.**
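
To make the contract of a core function concrete, here is a small sketch on plain Python lists rather than tensors. It turns one row of logits into probabilities and shows that picking the largest logit and picking the largest probability agree. The logit values and the helper names `softmax` and `greedy_pick` are assumptions for illustration; the actual implementations below operate on batched `torch` tensors of size (B, V).

```python
import math

def softmax(logits):
    """Convert pre-softmax logits into probabilities, subtracting the max for stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits):
    """Index of the largest logit (softmax is monotonic, so the argmax is unchanged)."""
    return max(range(len(logits)), key=lambda i: logits[i])

# One row of toy logits over a 4-token vocabulary (assumed values)
toy_logits = [1.0, 3.0, 0.5, 2.0]
probs = softmax(toy_logits)
choice = greedy_pick(toy_logits)
print(probs, choice)
```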

In [None]:
"""decode main wrapper function"""

######################################################
#  The following helper code is given to you.
######################################################

def decode(prompts, max_len, method, **kwargs):
  encodings_dict = tokenizer(prompts, return_tensors="pt", padding=True)
  input_ids = encodings_dict['input_ids'].to(device)
  attention_mask = encodings_dict['attention_mask'].to(device)

  model_kwargs = {'attention_mask': attention_mask}
  batch_size, input_seq_len = input_ids.shape

  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

  for step in range(max_len):
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
      outputs = model(**model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False)

    if step == 0:
      last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
      next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
    else:
      next_token_logits = outputs.logits[:, -1, :]

    log_prob = F.log_softmax(next_token_logits, dim=-1)

    if method == 'greedy':
      next_tokens = greedy(next_token_logits)
    elif method == 'sample':
      next_tokens = sample(next_token_logits)
    elif method == 'temperature':
      next_tokens = temperature(next_token_logits, t=kwargs.get('t', 0.8))
    elif method == 'topk':
      next_tokens = topk(next_token_logits, k=kwargs.get('k', 20))
    elif method == 'topp':
      next_tokens = topp(next_token_logits, p=kwargs.get('p', 0.7))

    # finished sentences should have their next token be a padding token
    next_tokens = next_tokens * unfinished_sequences + tokenizer.pad_token_id * (1 - unfinished_sequences)

    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder)

    # if eos_token was found in one sentence, set sentence to finished
    unfinished_sequences = unfinished_sequences.mul((next_tokens != tokenizer.eos_token_id).long())

    if unfinished_sequences.max() == 0:
      break

  response_ids = input_ids[:, input_seq_len:]
  response_text = [tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True) for output in response_ids]

  return response_text

In [None]:
"""debug helper code"""

######################################################
#  The following helper code is given to you.
######################################################

# For debugging, we duplicate a single prompt 10 times so that we obtain 10 generations for the same prompt
dev_prompts = [dev_data[0]['prompt']] * 10

def print_generations(prompts, generations):
  for prompt, generation in zip(prompts, generations):
    print(f'{[prompt]} ==> {[generation]}')

## 2.1 Greedy Decoding

In [None]:
def greedy(next_token_logits):
  '''
  inputs:
  - next_token_logits: Tensor(size = (B, V), dtype = float)
  outputs:
  - next_tokens: Tensor(size = (B), dtype = long)
  '''

  # TODO: compute `next_tokens` from `next_token_logits`.
  # Hint: use torch.argmax()
  next_tokens =

  return next_tokens


generations = decode(dev_prompts, max_len=20, method='greedy')
print_generations(dev_prompts, generations)

## 2.2 Vanilla Sampling and Temperature Sampling
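
Both methods in this subsection draw the next token at random from a softmax distribution; temperature sampling first divides the logits by `t`. The sketch below, on plain Python lists with assumed logit values, shows how the probability of the top token changes with `t`. The helper names are hypothetical, and the exercise itself should use `torch` operations on the batched logits.

```python
import math
import random

def softmax_with_temperature(logits, t):
    """Softmax of logits / t; t must be positive."""
    scaled = [x / t for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_index(probs, rng):
    """Draw one index according to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

toy_logits = [1.0, 3.0, 0.5, 2.0]  # assumed values
rng = random.Random(0)  # fixed seed so the draw is reproducible
low_t = softmax_with_temperature(toy_logits, t=0.5)
mid_t = softmax_with_temperature(toy_logits, t=0.8)
high_t = softmax_with_temperature(toy_logits, t=2.0)
draw = sample_index(mid_t, rng)
print(low_t[1], mid_t[1], high_t[1], draw)
```

Lower temperatures concentrate probability mass on the highest-scoring token, while higher temperatures flatten the distribution.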

In [None]:
def sample(next_token_logits):
  '''
  inputs:
  - next_token_logits: Tensor(size = (B, V), dtype = float)
  outputs:
  - next_tokens: Tensor(size = (B), dtype = long)
  '''

  # TODO: compute the probabilities `probs` from the logits.
  # Hint: `probs` should have size (B, V)
  probs =

  # TODO: compute `next_tokens` from `probs`.
  # Hint: use torch.multinomial()
  next_tokens =

  return next_tokens


set_seed()
generations = decode(dev_prompts, max_len=20, method='sample')
print_generations(dev_prompts, generations)

In [None]:
def temperature(next_token_logits, t):
  '''
  inputs:
  - next_token_logits: Tensor(size = (B, V), dtype = float)
  - t: float
  outputs:
  - next_tokens: Tensor(size = (B), dtype = long)
  '''

  # TODO: compute the probabilities `probs` from the logits, with temperature applied.
  probs =

  # TODO: compute `next_tokens` from `probs`.
  next_tokens =

  return next_tokens


set_seed()
generations = decode(dev_prompts, max_len=20, method='temperature', t=0.8)
print_generations(dev_prompts, generations)

## 2.3 Top-p Sampling
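
Top-p (nucleus) sampling keeps the smallest set of most probable tokens whose cumulative probability exceeds `p`, then samples only from that set. The sketch below illustrates the filtering step on a plain list of probabilities, following the same keep rule as the shifted mask in the cell below: a token is kept while the cumulative mass of the tokens before it does not exceed `p`. The name `top_p_filter` and the toy probabilities are assumptions for illustration.

```python
def top_p_filter(probs, p):
    """Zero out tokens outside the nucleus and renormalize the rest.

    Tokens are visited from most to least probable. A token is kept if the
    cumulative mass before it is still <= p, so the top token always survives.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        if kept and cumulative > p:
            break
        kept.append(i)
        cumulative += probs[i]
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]

# Toy next-token distribution over a 4-token vocabulary (assumed values)
toy_probs = [0.5, 0.3, 0.15, 0.05]
filtered = top_p_filter(toy_probs, p=0.7)
print(filtered)
```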

In [None]:
def topp(next_token_logits, p):
  '''
  inputs:
  - next_token_logits: Tensor(size = (B, V), dtype = float)
  - p: float
  outputs:
  - next_tokens: Tensor(size = (B), dtype = long)
  '''

  # TODO: Sort the logits in descending order, and compute
  # the cumulative probabilities `cum_probs` on the sorted logits
  sorted_logits, sorted_indices =
  sorted_probs =
  cum_probs =

  # Create a mask to zero out all logits not in top-p
  sorted_indices_to_remove = cum_probs > p
  sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
  sorted_indices_to_remove[:, 0] = 0
  # Restore mask to original indices
  indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)

  # Mask the logits
  next_token_logits[indices_to_remove] = float('-inf')

  # TODO: Sample from the masked logits
  probs =
  next_tokens =

  return next_tokens


set_seed()
generations = decode(dev_prompts, max_len=20, method='topp', p=0.7)
print_generations(dev_prompts, generations)

## 2.4 Evaluation

Run the following cell to obtain the evaluation results, which you should include in your writeup.
Also don't forget to answer the questions.
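
When interpreting the results, it may help to recall the usual definition of perplexity for a single text: the exponential of the negative mean log-probability that the scoring model assigns to its tokens. The sketch below computes that quantity from assumed token log-probabilities; `perplexity` here is a hypothetical helper, separate from `compute_perplexity` above, which obtains its scores from GPT-2 and averages over texts.

```python
import math

def perplexity(token_log_probs):
    """exp of the negative mean natural-log probability of the tokens."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

# Toy natural-log probabilities for two 3-token texts (assumed values)
confident = [math.log(0.5)] * 3
unsure = [math.log(0.1)] * 3
print(perplexity(confident), perplexity(unsure))
```

Lower perplexity means the scoring model finds the text more predictable.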

In [None]:
prompts = [item['prompt'] for item in test_data][:10]
GENERATIONS_PER_PROMPT = 10
MAX_LEN = 100

for experiment in ['greedy', 'sample', 'temperature', 'topp']:
  generations = []
  for prompt in tqdm(prompts):
    generations += decode([prompt] * GENERATIONS_PER_PROMPT, max_len=MAX_LEN, method=experiment)
  evaluate(generations, experiment)

## Discussion
- Q2.1: In greedy decoding, what do you observe when generating 10 times from the test prompt?

- Q2.2: In vanilla sampling, what do you observe when generating 10 times from the test prompt?

- Q2.3: In temperature sampling, play around with the value of temperature $t$. Which value of $t$ makes it equivalent to greedy decoding? Which value of $t$ makes it equivalent to vanilla sampling?

- Q2.4: In top-$p$ sampling, play around with the value of $p$. Which value of $p$ makes it equivalent to greedy decoding? Which value of $p$ makes it equivalent to vanilla sampling?

- Q2.5: Report the evaluation metrics (perplexity, fluency, diversity) of all 4 decoding methods. Which methods have the best and worst perplexity? Fluency? Diversity?