<a href="https://www.kaggle.com/code/aisuko/generating-text-with-contrastive-search?scriptVersionId=165099587" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

Natural language generation is one of the core tasks in natural language processing. However, generated text often contains many repeated sentences; see the end of [Supervise Fine-tuning Casual Language Model](https://www.kaggle.com/code/aisuko/fine-tuning-llama2-with-qlora/notebook) for an example. Here we introduce **Contrastive Search**, a decoding method designed to address this problem, proposed in the paper [A Contrastive Framework for Neural Text Generation](https://arxiv.org/abs/2202.06417). Decoding methods can be broadly divided into two categories: deterministic and stochastic.


# Deterministic Methods

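These methods follow a **fixed set of rules** to choose the next element in the output sequence, so the same input always produces the same output. They are **computationally efficient** and often work well when the output is tightly constrained by the input, but they are prone to repetition in open-ended generation.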


#### Greedy Search

Always selects the **most likely** next element based on the current state and the model's predictions.  
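
As a minimal sketch of this rule, the loop below always appends the most probable next token. The bigram table `TOY_BIGRAMS` and the helper `toy_next_token_probs` are hypothetical stand-ins for a real language model, chosen only to show how greedy selection can fall into a loop.

```python
def greedy_decode(next_token_probs, prompt, max_new_tokens):
    """Repeatedly append the single most probable next token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        probs = next_token_probs(tokens)      # dict: token -> probability
        best = max(probs, key=probs.get)      # argmax over the candidates
        tokens.append(best)
    return tokens


# Hypothetical toy "model": the next-token distribution depends only on the last token.
TOY_BIGRAMS = {
    "is": {"the": 0.6, "a": 0.4},
    "the": {"rich": 0.5, "famous": 0.3, "city": 0.2},
    "rich": {"and": 0.7, ".": 0.3},
    "famous": {"and": 0.6, ".": 0.4},
    "and": {"the": 0.8, "a": 0.2},
    "a": {"city": 0.9, ".": 0.1},
    "city": {".": 1.0},
    ".": {".": 1.0},
}


def toy_next_token_probs(tokens):
    return TOY_BIGRAMS[tokens[-1]]


result = greedy_decode(toy_next_token_probs, ["Melbourne", "is"], max_new_tokens=6)
print(" ".join(result))  # Melbourne is the rich and the rich and
```

Because the locally best choice never changes, the toy model cycles through "the rich and" indefinitely.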


#### Beam Search

Maintains a fixed number of **candidate sequences (beams)**. At each step it extends every beam with possible next tokens and keeps only the sequences with the highest cumulative probability.
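
The following is a small illustrative sketch of that procedure, not the `transformers` implementation. The table `TOY_MODEL` and the helper `toy_probs` are hypothetical; the probabilities are picked so that the greedy path is not the most probable sequence overall.

```python
import math


def beam_search(next_token_probs, prompt, num_beams, max_new_tokens):
    """Keep the num_beams sequences with the highest cumulative log-probability."""
    beams = [(0.0, list(prompt))]  # (cumulative log-probability, tokens)
    for _ in range(max_new_tokens):
        candidates = []
        for score, tokens in beams:
            for token, p in next_token_probs(tokens).items():
                candidates.append((score + math.log(p), tokens + [token]))
        # Rank all extensions and keep only the best num_beams of them.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:num_beams]
    return beams


# Hypothetical toy "model": the next-token distribution depends only on the last token.
TOY_MODEL = {
    "The": {"nice": 0.5, "dog": 0.4, "car": 0.1},
    "nice": {"woman": 0.4, "house": 0.3, "guy": 0.3},
    "dog": {"has": 0.9, "and": 0.05, "runs": 0.05},
    "car": {"is": 0.9, "drives": 0.1},
}


def toy_probs(tokens):
    return TOY_MODEL[tokens[-1]]


greedy_like = beam_search(toy_probs, ["The"], num_beams=1, max_new_tokens=2)
two_beams = beam_search(toy_probs, ["The"], num_beams=2, max_new_tokens=2)
print(greedy_like[0][1])  # ['The', 'nice', 'woman'], probability about 0.20
print(two_beams[0][1])    # ['The', 'dog', 'has'], probability about 0.36
```

With a single beam the search reduces to greedy search; with two beams it recovers the higher-probability sequence that starts with a less likely first token.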


# Stochastic Methods

These methods introduce **randomness** into the decoding process, leading to **more diverse outputs** compared to deterministic methods. However, they might be **less predictable and accurate** in some cases.

#### Top-k Sampling

Samples the next element from a **probability distribution** predicted by the model, considering only the **top k most likely** options.
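
A minimal sketch of this idea, assuming the model's prediction is available as a plain dictionary of probabilities. The distribution `example_probs` and the helper names are hypothetical.

```python
import random


def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalise their probabilities."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {token: p / total for token, p in top}


def sample_token(probs, rng):
    """Draw one token according to the given probabilities."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Hypothetical next-token distribution.
example_probs = {"a": 0.5, "the": 0.3, "one": 0.15, "zebra": 0.05}

filtered = top_k_filter(example_probs, k=2)  # "a" is about 0.625, "the" about 0.375
rng = random.Random(0)                       # seeded for reproducibility
next_token = sample_token(filtered, rng)
print(filtered, next_token)
```

Tokens outside the top k can never be sampled, no matter how the random draw falls.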


#### Nucleus Sampling

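Similar to top-k sampling, but instead of a fixed k it samples from the **smallest set of top-ranked tokens whose cumulative probability reaches a threshold p**. The candidate set therefore shrinks when the model is confident and grows when the distribution is flat, which allows less frequent but potentially interesting choices.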
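
The sketch below illustrates the filtering step of nucleus sampling under the common "smallest set whose cumulative probability reaches p" definition. The two example distributions are hypothetical, and real libraries may differ in edge-case handling.

```python
def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(prob for _, prob in kept)
    return {token: prob / total for token, prob in kept}


# Hypothetical distributions: one peaked, one flat.
peaked = {"a": 0.9, "the": 0.05, "one": 0.03, "zebra": 0.02}
flat = {"a": 0.4, "the": 0.3, "one": 0.2, "zebra": 0.1}

print(sorted(top_p_filter(peaked, p=0.8)))  # ['a']
print(sorted(top_p_filter(flat, p=0.8)))    # ['a', 'one', 'the']
```

The same threshold keeps one candidate for the peaked distribution and three for the flat one, which is the adaptive behaviour that a fixed k cannot provide.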


The most suitable decoding method depends on the task. For tasks requiring high accuracy and fluency (e.g. machine translation), deterministic methods like beam search might be preferred. For tasks prioritizing creativity and exploration (e.g. open-ended text generation), stochastic methods like top-k sampling or nucleus sampling might be more suitable. In terms of computational cost, greedy search and sampling both need one forward pass per generated token, while beam search is more expensive because it tracks several candidate sequences at once.

In [1]:
%%capture
!pip install transformers==4.36.2
!pip install peft==0.7.1
!pip install bitsandbytes==0.41.3

In [2]:
import os

os.environ['MODEL_NAME']='gpt2-large'

# Deterministic Methods

Deterministic methods, e.g. **greedy search** and **beam search**, generate text by selecting the continuation with the highest likelihood under the language model. However, as widely discussed in previous studies, deterministic methods often lead to model degeneration, i.e., the generated text is unnatural and contains undesirable repetitions.

In [3]:
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel, BitsAndBytesConfig
from peft import get_peft_model

tokenizer=AutoTokenizer.from_pretrained(os.getenv('MODEL_NAME'))
input_ids=tokenizer("Melbourne is", return_tensors="pt").input_ids

bnb_config=BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    # https://github.com/huggingface/transformers/issues/21151#issuecomment-1398534410
    # Half only works on GPU and should not be used on cpu
    
    llm_int8_enable_fp32_cpu_offload=True
)

model=GPT2LMHeadModel.from_pretrained(
    os.getenv('MODEL_NAME'), 
    quantization_config=bnb_config, 
    device_map="auto", 
    torch_dtype=torch.bfloat16
)

print(model.device)

cuda:0


The greedy search output below shows an obvious pattern of repetition.

In [4]:
output=model.generate(input_ids, max_length=128)
tokenizer.decode(output[0], skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


"Melbourne is a city of contrasts. It's a city of the rich and the poor, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and the famous, of the rich and"

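One rough way to quantify the repetition visible in the output above is the fraction of n-grams that are duplicates, similar in spirit to the rep-n style metrics used in the text degeneration literature. The helper name `repeated_ngram_fraction` is hypothetical, and splitting on whitespace instead of using model tokens is a simplifying assumption.

```python
def repeated_ngram_fraction(text, n):
    """Return 1 - (unique n-grams / total n-grams) over whitespace-separated words."""
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


# A short excerpt of the repeated phrase from the greedy output.
looping = "of the rich and the famous, " * 3
fresh = "Melbourne is a city of contrasts"

print(repeated_ngram_fraction(looping, n=4))  # well above zero
print(repeated_ngram_fraction(fresh, n=2))    # 0.0
```

A value near zero means almost every n-gram is new; the value grows as the text loops.
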
# Stochastic Methods

To address the issues posed by deterministic methods, stochastic methods introduce randomness into the decoding process. Two widely used stochastic methods are **top-k sampling** and **nucleus sampling** (also called **top-p sampling**). Below is an example of text generated by nucleus sampling (p=0.95).

In [5]:
torch.manual_seed(0)  # fix the seed so the sampled text is reproducible
output=model.generate(input_ids, do_sample=True, max_length=128, top_p=0.95, top_k=0)
tokenizer.decode(output[0], skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


"Melbourne is one of these listings — by the way, its relationship with these mass murderers is starting to look very national.\n\nHere's why.\n\nTHE MASS MURDERERS IN THE EXIT\n\nEveryone's been conditioned to believe they need gun control. But how many have been saved by confiscation? The forced confiscation of arms is used in Australia, Britain and New Zealand to imprison felons, and many here already participate in the semi-automatic ban which the Australian National Firearms Agreement is a remnant of the Rhodesian killer's eulogy. That means vast numbers of guns are being abolished to produce"

Although nucleus sampling can generate text free of repetitions, the semantic coherence of the generated text is not well maintained. This semantic inconsistency can be partially remedied by lowering the temperature. **However, reducing the temperature brings nucleus sampling closer to greedy search**, so the temperature acts as a trade-off between the two. In general, it is challenging to find a prompt- and model-independent temperature that avoids the pitfalls of both greedy search and nucleus sampling.
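
To illustrate why a low temperature approaches greedy search, the sketch below applies a temperature-scaled softmax to a hypothetical set of logits. The values in `example_logits` are made up for illustration.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    largest = max(scaled)
    exps = [math.exp(x - largest) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three candidate tokens.
example_logits = [2.0, 1.0, 0.5]

for temperature in (1.0, 0.5, 0.1):
    probs = softmax_with_temperature(example_logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

As the temperature drops, almost all probability mass moves to the top-ranked token, so sampling picks it nearly every time, just as greedy search would.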


# Contrastive Search

Put simply, contrastive search scores each candidate token with a formula that has two parts:
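
For reference, the selection rule from the paper can be written as follows, where $V^{(k)}$ is the set of top-k candidates at step $t$, $h_v$ is the representation of candidate $v$ given the context, and $s$ is cosine similarity:

$$
x_t = \underset{v \in V^{(k)}}{\arg\max} \Big\{ (1-\alpha) \times \underbrace{p_\theta(v \mid x_{<t})}_{\text{model confidence}} \; - \; \alpha \times \underbrace{\max \{ s(h_v, h_{x_j}) : 1 \le j \le t-1 \}}_{\text{degeneration penalty}} \Big\}
$$

The hyperparameter $\alpha \in [0, 1]$ balances the two terms; with $\alpha = 0$ the rule reduces to greedy search.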


### Model Confidence

It is the probability that the language model assigns to the candidate token v, where v is one of the top-k predictions.


### Degeneration Penalty

It measures how discriminative the candidate v is with respect to the previous context. The function s in the penalty computes the cosine similarity between token representations, so a candidate whose representation closely resembles a token already in the context receives a larger penalty.

Here we generate text with contrastive search (k=4 and $\alpha$=0.6). To fully demonstrate the capability of contrastive search, we let the language model generate a long document of up to 512 tokens:

In [6]:
model=GPT2LMHeadModel.from_pretrained(
    os.getenv('MODEL_NAME'),
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    pad_token_id=tokenizer.eos_token_id
)

model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Linear4bit(in_features=1280, out_features=3840, bias=True)
          (c_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear4bit(in_features=1280, out_features=5120, bias=True)
          (c_proj): Linear4bit(in_features=5120, out_features=1280, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, element

The arguments passed to `model.generate` are as follows:

* **top_k**: The hyperparameter k in contrastive search.
* **penalty_alpha**: The hyperparameter $\alpha$ in contrastive search.
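
To make the role of these two arguments concrete, here is a minimal sketch of a single contrastive search step on toy data. It is not the `transformers` implementation: the function `contrastive_step`, the candidate probabilities, and the three-dimensional "hidden states" are all hypothetical.

```python
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrastive_step(candidates, context_states, top_k, penalty_alpha):
    """Pick one token from (token, probability, hidden_state) candidates."""
    # Only the top_k most probable candidates are considered.
    shortlist = sorted(candidates, key=lambda c: c[1], reverse=True)[:top_k]
    best_token, best_score = None, -math.inf
    for token, prob, state in shortlist:
        # Degeneration penalty: highest similarity to any token already in the context.
        penalty = max(cosine_similarity(state, h) for h in context_states)
        score = (1 - penalty_alpha) * prob - penalty_alpha * penalty
        if score > best_score:
            best_token, best_score = token, score
    return best_token


# Hypothetical context containing one token whose hidden state is [1, 0, 0].
context_states = [[1.0, 0.0, 0.0]]
candidates = [
    ("famous", 0.5, [1.0, 0.0, 0.0]),     # most probable, but identical to the context
    ("beautiful", 0.3, [0.0, 1.0, 0.0]),  # less probable, unrelated to the context
    ("old", 0.2, [0.6, 0.8, 0.0]),        # partially similar to the context
]

print(contrastive_step(candidates, context_states, top_k=3, penalty_alpha=0.0))  # famous
print(contrastive_step(candidates, context_states, top_k=3, penalty_alpha=0.6))  # beautiful
```

With `penalty_alpha=0.0` the step behaves like greedy search; with `penalty_alpha=0.6` the candidate that would repeat the context is penalized and a less similar one wins.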

In [7]:
output=model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
tokenizer.decode(output[0], skip_special_tokens=False)

"Melbourne is a city of contrasts.\n\nIt's an old-fashioned city with a modern, cosmopolitan feel, but the suburbs of St Kilda and Fitzroy are just as vibrant and vibrant as the city itself.\n\nIn this article we take a look at what's on offer in each of Melbourne's suburbs to make your trip to the city a memorable one.\n\nMelbourne's suburbs have plenty of attractions for everyone to enjoy, whether you're looking for a day out with friends or an overnight stay in one of the city's hotels.\n\nHere's a list of the best things to do in each of Melbourne's suburbs:\n\n1. The Docklands\n\nDocklands is the quintessential 'big city' suburb in Melbourne. With over 1.5 million people living in the area, it's home to some of the city's most iconic landmarks such as the Opera House, Melbourne Cricket Ground (MCG) and the Melbourne Cricket Ground (MCG) Oval.\n\nThe city's biggest shopping centre, the Domain, is located in the heart of the suburb, and there's a wide range of shopping options withi

The output avoids the repetition loop seen with greedy search and stays more coherent than the nucleus sampling example, so its overall quality is higher than the text generated with either method above.

# Reference
* https://huggingface.co/blog/introducing-csearch
* https://www.kaggle.com/code/aisuko/supervise-fine-tune-casual-language-model?scriptVersionId=159156374