In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Expert = the larger checkpoint, amateur = the smaller one, loaded in that order.
# A single tokenizer is reused for both models, which assumes the two checkpoints
# share one vocabulary (contrastive_decoding subtracts their logits index-wise).
expert_model, amateur_model = (
    AutoModelForCausalLM.from_pretrained(checkpoint)
    for checkpoint in ("gpt2-large", "gpt2")
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")



In [2]:
# alpha: plausibility threshold - a token survives only if its expert logit is at
#        least log(alpha) above... below the expert's best logit (i.e. its expert
#        probability is >= alpha * the top token's probability).
# beta:  contrastive strength - expert logits are weighted by (1 + beta) and
#        amateur logits by beta before subtraction.
alpha, beta = 0.1, 0.5

In [3]:
def get_logits(text, model, tokenizer):
    """Return the next-token logits a causal LM assigns after ``text``.

    Args:
        text: Prompt string to condition on (a single, unpadded sequence).
        model: Causal LM callable with the tokenizer's outputs as keyword
            arguments and returning an object with a ``.logits`` tensor.
        tokenizer: Tokenizer whose vocabulary matches ``model``.

    Returns:
        The logits for the position after the final prompt token, taken from
        batch row 0 (a 1-D, vocabulary-sized tensor for Hugging Face causal LMs).
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Tokenizer tensors are created on CPU; move them to the model's device so
    # this also works once the model is placed on a GPU. Objects without a
    # ``device`` attribute are called with the inputs unchanged.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # HF causal LMs return logits shaped (batch, seq_len, vocab): take row 0,
    # last position, i.e. the distribution over the next token.
    return outputs.logits[0, -1]

In [None]:
import torch.nn.functional as F

def contrastive_decoding(expert_logits, amateur_logits, alpha, beta):
    """Combine expert and amateur logits with contrastive decoding.

    Implements Algorithm 2 of the original contrastive-decoding paper:
    score = (1 + beta) * expert - beta * amateur, restricted to tokens the
    expert itself finds plausible.

    Args:
        expert_logits: Logits from the stronger model; the last dimension is
            the vocabulary (leading batch dimensions are handled row-wise).
        amateur_logits: Logits from the weaker model, same shape as
            ``expert_logits``.
        alpha: Plausibility threshold in [0, 1]. Tokens whose expert logit is
            below ``log(alpha) + max(expert_logits)`` are masked out. Because
            softmax shares one normalizer per row, this keeps exactly the tokens
            whose expert probability is >= alpha * the top expert probability.
            ``alpha = 0`` disables the mask.
        beta: Contrastive weight applied to the amateur logits.

    Returns:
        Tensor shaped like ``expert_logits`` with masked positions set to -inf.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] (including NaN). Values
            above 1 would mask every token, and ``argmax`` on an all -inf row
            silently returns index 0.
    """
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    # Algorithm 2 in the original paper: per-row plausibility cutoff in logit space.
    cutoff = torch.log(torch.tensor(float(alpha))) + expert_logits.max(dim=-1, keepdim=True).values

    # Compute the contrastive difference with weighted logits
    diffs = (1 + beta) * expert_logits - beta * amateur_logits

    # Apply masking to filter out tokens below the cutoff
    cd_logits = diffs.masked_fill(expert_logits < cutoff, float("-inf"))

    return cd_logits

In [5]:
prompt = "Welcome to the Data Science Institute at Vanderbilt"

# Next-token logits from both models. get_logits tokenizes the prompt itself,
# so no separate tokenization step is needed here.
expert_logits = get_logits(prompt, expert_model, tokenizer)
amateur_logits = get_logits(prompt, amateur_model, tokenizer)

# Apply Contrastive Decoding
cd_logits = contrastive_decoding(expert_logits, amateur_logits, alpha, beta)

# Greedy choice: take the token with the highest contrastive logit (no sampling).
predicted_token_id = torch.argmax(cd_logits).item()
predicted_token = tokenizer.decode(predicted_token_id)

print("Predicted token:", predicted_token)

Predicted token:  University
