# A Hybrid Approach using PlanSearch and Contrastive Decoding

X: https://x.com/_aloobun

> Project 2: Generate the funniest joke imaginable with LLMs.
> Implement PlanSearch but for generating jokes, then use LLM as a judge to rate generated jokes and show top ones. Your pipeline should input a word or a context (like “penguins” or “nodejs locked in a VM”) and output top funniest jokes.Please note that doing LLM-as-a-judge correctly is a tricky thing as they have all sorts of biases in rating or position preferences. So the expectation is that you will read literature on it beforehand.Optional follow up: think and answer how would you test whether generated jokes are truly novel? What’s novelty in any case? What if LLMs are just memorizing? How would you test it?


**Prelude**: Whenever I think of language models, I think of something like open fields rather than closed roads. I named the file as The Stig’s AI cousin because he had one of the funniest introductions on the show, and one of them was, “His favourite flower is the potato”.

**Why did you pick the particular project?**

One of the main reasons I wanted to work on this is, if we’re not solving the creativity/diversity part, we are not allowing language models to progress. It’s interesting to see, when prompted directly, llm tend to produce derivative content (good memorization i guess).

In [1]:
##o
#!pip install transformer_lens
#!huggingface-cli login

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
The token `allow` has been saved to /teamspace/studios/this_studio/.cache/huggingface/stored_tokens
Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token
Login successful.
The current active token is: `allow`


# Finding the 'problem' before proposing a solution

I **hypothesize** that for any given topic, the model has a default internal pathway for simple, factual statements and a distinct, separate pathway for creative outputs (lol, i was wrong and how!!)**. We can identify and map these pathways using causal tracing. 

**And sometimes there will be overlapping circuits (i can deep dive into what overlapping circuits mean, do they memorize ??) - you'd see this in the problem finding interp experiment.

It seems like we have to somehow steer llm to be creative.

In [2]:
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)



Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


In [3]:
import torch
import torch.nn.functional as F
import pandas as pd
from transformer_lens import HookedTransformer, utils
from functools import partial
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
device = "cuda" if torch.cuda.is_available() else "cpu"

class CircuitFinder:
    def __init__(self, model: HookedTransformer):
        self.model = model
        self.model.eval()
        if self.model.tokenizer.pad_token is None:
            self.model.tokenizer.pad_token = self.model.tokenizer.eos_token

    def _tokenize_with_chat_template(self, text, system_prompt=None):
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": text}]
        else:
            messages = [{"role": "user", "content": text}]
        # Using the model's tokenizer directly
        return self.model.tokenizer.apply_chat_template(
            messages, add_generation_prompt=False, return_tensors="pt"
        ).to(device)

    def _patch_full_activation_hook(self, activation, hook, clean_activation):
        activation[:, :, :] = clean_activation[:, :, :]
        return activation

    def _patch_head_hook(self, activation, hook, clean_head_activation, head_index):
        activation[:, :, head_index, :] = clean_head_activation
        return activation

    def _calculate_target_log_prob(self, logits, target_tokens, seq_len):
        punchline_len = target_tokens.shape[1]
        start_index = seq_len - punchline_len
        end_index = seq_len
        relevant_logits = logits[0, start_index-1:end_index-1, :]
        log_probs = torch.nn.functional.log_softmax(relevant_logits, dim=-1)
        # squeeze target_tokens to be 1D
        target_indices = target_tokens.squeeze(0)
        return log_probs[torch.arange(punchline_len), target_indices].sum()

    def run_patching_experiment(self, clean_text, corrupted_text, target_punchline_text, system_prompt=None):
        clean_tokens = self._tokenize_with_chat_template(clean_text, system_prompt)
        corrupted_tokens = self._tokenize_with_chat_template(corrupted_text, system_prompt)
        target_tokens = self.model.to_tokens(target_punchline_text, prepend_bos=False).to(device)
        
        len_clean, len_corr = clean_tokens.shape[1], corrupted_tokens.shape[1]
        max_len = max(len_clean, len_corr)
        
        if len_clean < max_len:
            padding = torch.full((1, max_len - len_clean), self.model.tokenizer.pad_token_id, device=device)
            clean_tokens = torch.cat([clean_tokens, padding], dim=1)
        if len_corr < max_len:
            padding = torch.full((1, max_len - len_corr), self.model.tokenizer.pad_token_id, device=device)
            corrupted_tokens = torch.cat([corrupted_tokens, padding], dim=1)
        
        seq_len = max_len

        # get baseline and cache
        _, clean_cache = self.model.run_with_cache(clean_tokens)
        corrupted_logits = self.model(corrupted_tokens)
        baseline_log_prob = self._calculate_target_log_prob(corrupted_logits, target_tokens, seq_len)
        print(f"  Baseline log prob for target '{target_punchline_text}': {baseline_log_prob.item():.2f}")

        # patching experiments
        results = []
        for layer in range(self.model.cfg.n_layers):
            # Patch MLP Layers
            mlp_hook_name = utils.get_act_name("post", layer, "mlp")
            clean_mlp_activation = clean_cache[mlp_hook_name]
            hook_fn = partial(self._patch_full_activation_hook, clean_activation=clean_mlp_activation)
            with self.model.hooks(fwd_hooks=[(mlp_hook_name, hook_fn)]):
                patched_logits_mlp = self.model(corrupted_tokens)
            improvement_mlp = self._calculate_target_log_prob(patched_logits_mlp, target_tokens, seq_len).item() - baseline_log_prob.item()
            results.append({"component": f"MLP_{layer}", "improvement": improvement_mlp})
            
            # patch attention heads
            for head in range(self.model.cfg.n_heads):
                attn_hook_name = utils.get_act_name("z", layer)
                clean_attn_activation = clean_cache[attn_hook_name][:, :, head, :]
                hook_fn_attn = partial(self._patch_head_hook, clean_head_activation=clean_attn_activation, head_index=head)
                with self.model.hooks(fwd_hooks=[(attn_hook_name, hook_fn_attn)]):
                    patched_logits_attn = self.model(corrupted_tokens)
                improvement_attn = self._calculate_target_log_prob(patched_logits_attn, target_tokens, seq_len).item() - baseline_log_prob.item()
                results.append({"component": f"L{layer}H{head}", "improvement": improvement_attn})

        return pd.DataFrame(results)

if __name__ == "__main__":
    
    print("Finding circuits")

    circuit_finder = CircuitFinder(model=model)

    my_system_prompt = "You are a helpful assistant and you will give funny replies like a standup comedian."
    
    context_only = "What is 2+2?"
    descriptive_completion_text = "What is 2+2? 4."
    creative_completion_text = "What is 2+2? 4, but it’s a lot more complicated when you’re trying to split the bill."
    descriptive_punchline_text = " 4."
    creative_punchline_text = " 4, but it’s a lot more complicated when you’re trying to split the bill."


    print("\nFinding the 'creative' circuit")
    df_creative_circuit_full = circuit_finder.run_patching_experiment(
        clean_text=creative_completion_text, 
        corrupted_text=descriptive_completion_text, 
        target_punchline_text=creative_punchline_text,
        system_prompt=my_system_prompt
    )
    

    print("\nFinding the 'descriptive' circuit")
    df_descriptive_circuit_full = circuit_finder.run_patching_experiment(
        clean_text=descriptive_completion_text, 
        corrupted_text=creative_completion_text, 
        target_punchline_text=descriptive_punchline_text,
        system_prompt=my_system_prompt
    )

    print("\n'Creative' Circuit")
    print("Top 10 important components overall:")
    print(df_creative_circuit_full.sort_values(by="improvement", ascending=False).head(10).to_string())
    creative_circuit_components = df_creative_circuit_full[df_creative_circuit_full.component.str.startswith('MLP')].sort_values(
        by="improvement", ascending=False
    ).head(5)['component'].tolist()
    print(f"\nMLP components: {creative_circuit_components}")

    print("\n'Descriptive' Circuit")
    print("Top 10 important components overall:")
    print(df_descriptive_circuit_full.sort_values(by="improvement", ascending=False).head(10).to_string())
    descriptive_circuit_components = df_descriptive_circuit_full[df_descriptive_circuit_full.component.str.startswith('MLP')].sort_values(
        by="improvement", ascending=False
    ).head(5)['component'].tolist()
    print(f"\nMLP components: {descriptive_circuit_components}")

Finding circuits

Finding the 'creative' circuit
  Baseline log prob for target ' 4, but it’s a lot more complicated when you’re trying to split the bill.': -447.70

Finding the 'descriptive' circuit
  Baseline log prob for target ' 4.': -39.37

'Creative' Circuit
Top 10 important components overall:
    component  improvement
0       MLP_0   225.454651
495    MLP_15   196.012558
396    MLP_12   182.134460
330    MLP_10   174.880646
264     MLP_8   164.607544
363    MLP_11   158.541931
297     MLP_9   158.233582
231     MLP_7   155.855927
429    MLP_13   124.825500
462    MLP_14   124.810913

MLP components: ['MLP_0', 'MLP_15', 'MLP_12', 'MLP_10', 'MLP_8']

'Descriptive' Circuit
Top 10 important components overall:
    component  improvement
0       MLP_0    12.455952
330    MLP_10     8.141054
264     MLP_8     7.143742
363    MLP_11     6.051708
429    MLP_13     5.829998
396    MLP_12     5.126740
297     MLP_9     5.102928
198     MLP_6     3.095867
132     MLP_4     3.022987
231  

# 'Solution' 
## PlanSearch was boring so i extended it's idea with contrastive sampling.

## Explanation

We feed a detailed prompt (as explained in Appendix [here](https://arxiv.org/pdf/2409.03733)) into the model, model processes the prompt and calculates the probability for every single word, instead of using all the observations at once, `PlanSearch` creates many small, separate branches to explore different combinations of ideas, iterates each branch and the `PlanSearch` algo takes those logits (flattens it slightly) which makes less likely words a bit more likely to be chosen.

### My Extension to PlanSearch
Here the core of my extension to PlanSearch comes from paper [Dola](https://arxiv.org/abs/2309.03883), where the core idea is to take two models with same vocab and feed same context/question and instead of predicting the next token one at a time, we look difference of logits from from bigger and smaller models respectively and accotdingly we interpolate (to steer the generation process - as said above that we need to steer). The core of the algo is to calculate this:
`final_logits = (1 - alpha) * logits1 + alpha * logits2` 

Here `alpha` is a scalar, 0 means take logits from smaller model, 1 means take logits from bigger model. `logits1` is logits from smaller model and `logits2` is logits from bigger model.

As per my **hypothesis** that: for any given topic, the model has a default internal pathway for simple, factual statements and a distinct, separate pathway for creative outputs. We can identify and map these pathways using causal tracing, this approach felt better and useful than writing detailed prompts. However, initial results are okaish, and i can work on this again to make it better (i guess).

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import json
import re
from typing import List, Dict, Tuple, Optional

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class ContrastiveJokeSearch:
    def __init__(self, model1: AutoModelForCausalLM, model2: AutoModelForCausalLM, tokenizer: AutoTokenizer, system_prompt: str, alpha: float):
        self.model1 = model1
        self.model2 = model2
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt
        self.alpha = alpha #fixed alpha value that will be used for all generation steps
        
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            
    def _get_next_token_logits(self, model, input_ids, past_key_values): # probabilities (logits) for the very next token.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, return_dict=True)
        return outputs.logits[:, -1, :], outputs.past_key_values

    # generates a sequence of text by combining the "thoughts" of two different models.
    def _contrastive_generate(self, messages: List[Dict], max_new_tokens: int = 128, temperature: float = 0.7) -> str: 
        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
        
        generated_ids = []
        _, past_key_values1 = self._get_next_token_logits(self.model1, input_ids, None)
        _, past_key_values2 = self._get_next_token_logits(self.model2, input_ids, None)
        current_input_ids = input_ids[:, -1:]

        for _ in range(max_new_tokens):
            #in each loop, get the next-token logits for smaller model
            logits1, past_key_values1 = self._get_next_token_logits(self.model1, current_input_ids, past_key_values=past_key_values1)
            #do the same for larger model
            logits2, past_key_values2 = self._get_next_token_logits(self.model2, current_input_ids, past_key_values=past_key_values2)

            ## this is the core of the solution here
            # it combines the logits from both models using the class's single `alpha` value.
            final_logits = (1.0 - self.alpha) * logits1 + self.alpha * logits2

            if temperature > 0:
                next_token_id = torch.multinomial(torch.softmax(final_logits / temperature, dim=-1), num_samples=1)
            else:
                next_token_id = torch.argmax(final_logits, dim=-1, keepdim=True)
            
            if next_token_id.item() == self.tokenizer.eos_token_id: break
            generated_ids.append(next_token_id.item())
            current_input_ids = next_token_id

        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def _parse_list_from_response(self, response_text: str) -> List[str]:
        items = response_text.strip().split('\n')
        cleaned_items = [re.sub(r'^\s*\d+[\.\)]\s*|\s*-\s*', '', item).strip() for item in items]
        return [item for item in cleaned_items if item]

    #first step of the PlanSearch
    def generate_observations(self, topic: str, num_observations: int = 5) -> List[str]:
        prompt = f"""
        Analyze the topic: "{topic}". Generate a numbered list of {num_observations} humorous observations.
        Each observation should be on a new line.
        """
        messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}]
        
        logging.info(f"Generating observations with contrastive decoding (alpha={self.alpha})...")
        response_text = self._contrastive_generate(messages)
        return self._parse_list_from_response(response_text)

        #second step of the PlanSearch
    def generate_comedic_angles(self, topic: str, observations: List[str], num_angles: int = 3) -> List[str]:
        prompt = f"""
        Given these observations: {json.dumps(observations)}.
        Generate a numbered list of {num_angles} unique comedic angles for the topic "{topic}".
        Each angle should be on a new line.
        """
        messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}]
        
        logging.info(f"Generating comedic angles with contrastive decoding (alpha={self.alpha})...")
        response_text = self._contrastive_generate(messages)
        return self._parse_list_from_response(response_text)

    #final step of PlanSeach
    def write_final_joke(self, topic: str, angle: str) -> str:
        prompt = f"""
        Write a single, concise joke for the topic "{topic}" based on this angle: "{angle}".
        """
        messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}]
        
        logging.info(f"Writing final joke with contrastive decoding (alpha={self.alpha}) for angle: '{angle}'")
        return self._contrastive_generate(messages, max_new_tokens=64)

    def run_full_pipeline(self, topic: str, max_results: int = 5) -> List[Dict[str, str]]:
        logging.info(f"Starting Contrastive Pipeline for topic: '{topic}' (Strict max results: {max_results}) ---")
        
        observations = self.generate_observations(topic, num_observations=5)
        if not observations:
            return []
        
        angles = self.generate_comedic_angles(topic, observations, num_angles=max_results + 2)
        if not angles:
            return []
        
        angles_to_use = angles[:max_results]
        logging.info(f"Model generated {len(angles)} angles, but we will use the top {len(angles_to_use)} to meet the max_results limit.")
        
        results = [{"angle": angle, "joke": self.write_final_joke(topic, angle)} for angle in angles_to_use]
        return [res for res in results if res["joke"]]


if __name__ == "__main__":
    MODEL1_ID = "meta-llama/Llama-3.2-1B-Instruct" 
    MODEL2_ID = "meta-llama/Llama-3.2-3B-Instruct"

    logging.info(f"Loading smaller model: {MODEL1_ID}")
    model1 = AutoModelForCausalLM.from_pretrained(MODEL1_ID, torch_dtype=torch.bfloat16).to(DEVICE).eval()
    
    logging.info(f"Loading bigger model: {MODEL2_ID}")
    model2 = AutoModelForCausalLM.from_pretrained(MODEL2_ID, torch_dtype=torch.bfloat16).to(DEVICE).eval()
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL1_ID)

    SYSTEM_PROMPT = "You are a witty and creative comedian."
    JOKE_TOPIC = "the gym"

    ALPHA_SETTING = 1.2 # 0 is for smaller model, 1 for larger, interestgint to see what comes of when >1 is used
    MAX_JOKES_TO_GENERATE = 5

    pipeline = ContrastiveJokeSearch(
        model1=model1, 
        model2=model2, 
        tokenizer=tokenizer, 
        system_prompt=SYSTEM_PROMPT,
        alpha=ALPHA_SETTING
    )
    
    generated_jokes = pipeline.run_full_pipeline(topic=JOKE_TOPIC, max_results=MAX_JOKES_TO_GENERATE)

    print(f"\n\n--- Contrastive Generated Jokes (Using Fixed Alpha = {ALPHA_SETTING}, Count = {MAX_JOKES_TO_GENERATE}) ---")
    if generated_jokes:
        for i, result in enumerate(generated_jokes):
            print(f"\nJoke #{i+1}\n  Angle: {result['angle']}\n  Joke: {result['joke']}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



--- Contrastive Generated Jokes (Using Fixed Alpha = 1.2, Count = 5) ---

Joke #1
  Angle: "Gym Profiling: Why they know exactly what brand of socks you wear and what type of cardio machine you're too intimidated to use."
  Joke: They must have a'sole'- search algorithm to figure out why you're still on the Treadmill of Shame.

Joke #2
  Angle: "Bulkedup bullies: The gym is full of selfproclaimed 'tough guys' who can bench press a small car, but still manage to ask where the 'long' toilet seat is."
  Joke: Why did the gym-goer who could bench 500 pounds walk into a bathroom with a long toilet seat? Because even the strongest guys need to answer nature's biggest question.

Joke #3
  Angle: "Gym Etiquette 101: Don't even think about pulling out that smartphone in the squat rack – unless you're documenting your own midworkout existential crisis."
  Joke: Here's a joke based on the angle:

"I was trying to squat in peace, but then I saw someone live-streaming their inner turmoil on Insta

### Why did you pick the particular project?

One of the main reasons I wanted to work on this is, if we’re not solving the creativity/diversity part, we are not allowing language models to progress. It’s interesting to see, when prompted directly, llm tend to produce derivative content (good memorization i guess).

### What I Learned From This Project

The approach I've used(not perfect imo) isn't just a fancy way to generate textmusing glorified prompts. It's a direct method for changing the model's decision making at the most fundamental level which is the logits. Something like "creativity" as a mathematical direction in the latent space. My takeaway is that both the small model and the large model have likely learned many of the same fundamental circuits for basic grammar, concepts, and common sense, mainly becuase of the polysemantic neurons capture and compress all different type of features leading to more generic outputs.

### What surprised you the most?

There's a joke in here somewhere! 

### If you had more compute, what would you have done?
Right now, our alpha parameter amplifies everything that the larger model does differently. With more time or compute probably i would train a suite of linear probes. Each probe would be a simple classifier trained on the model's internal activations, having said that, i realise after writing that probe only shows a correlation so may be a natural progression from there would be to prove causation. Still open-ended but it would be an interesting thing to do.

### If you had to write a paper on the project, what else needs to be done?
the first step would be to prove our jokes are actually funny, which means moving beyond just my own biased opinion (human vs llm as a judge may be?? but it has itws own pros and cons w.r.t internal biases) and to make sure we're not just creating a more elaborate parrot, i'd need to provide mechanistic evidence, using causal tracing to show our jokes rely less on simple pattern-matching circuits and more on... well, that's it for now.

This was fun, thank you for making me think around diversity/creativity issue with llm. GLHF.