In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import math
from torch.distributions.categorical import Categorical

from typing import List

from algos.annealed_smc import AnnealedSMC
from utils import build_relaxed_single_token_prior, build_suffix_likelihood

In [2]:
# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so that anything drawing from it is repeatable
# across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# GPT-2 (small) in float32; eval() switches off training-only behaviour such as dropout.
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device, dtype=torch.float32)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# The suffix the likelihood is built from, tokenised without special tokens.
suffix = " quick dog jumped over the lazy fox"
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

# Prior helpers (log-density, its gradient, a sampler) and likelihood helpers
# (log-likelihood, its gradient) supplied by `utils`.
log_prior, grad_log_prior, sample_prior = build_relaxed_single_token_prior(model, tokenizer, device)
log_like, grad_log_like = build_suffix_likelihood(model, tokenizer, suffix_ids, device)

@torch.no_grad()
def log_target(x, sigma):
    """Return `log_like(x) + log_prior(x, sigma)` with autograd recording disabled."""
    likelihood_term = log_like(x)
    prior_term = log_prior(x, sigma)
    return likelihood_term + prior_term

@torch.no_grad()
def grad_log_target(x, sigma):
    """Return `grad_log_like(x) + grad_log_prior(x, sigma)` with autograd recording disabled.

    NOTE(review): both gradient helpers are called inside a no-grad context, so
    they presumably manage their own differentiation internally - confirm in `utils`.
    """
    likelihood_grad = grad_log_like(x)
    prior_grad = grad_log_prior(x, sigma)
    return likelihood_grad + prior_grad

In [3]:
# `suffix`, `suffix_ids`, the prior helpers (`log_prior`, `grad_log_prior`,
# `sample_prior`) and the likelihood helpers (`log_like`, `grad_log_like`) are
# already built in the previous cell from exactly the same inputs, so they are
# not rebuilt here; this cell only adds the reference solution.

# Reference single-token prefix and its GPT-2 input embedding, used by the
# debug logger to print the solution's likelihood/prior next to the particles.
solution = "The"
solution_ids = tokenizer.encode(solution, add_special_tokens=False)
# Look the token ids up in the input embedding table; detach so the latent
# carries no autograd history.
solution_latents = model.transformer.wte(torch.tensor(solution_ids, device=device)).detach()

@torch.no_grad()
def log_target(x, sigma):
    """Return `log_like(x) + log_prior(x, sigma)` with autograd recording disabled.

    For a prior-only run, return just `log_prior(x, sigma)` instead.
    """
    likelihood_term = log_like(x)
    prior_term = log_prior(x, sigma)
    return likelihood_term + prior_term

@torch.no_grad()
def grad_log_target(x, sigma):
    """Return `grad_log_like(x) + grad_log_prior(x, sigma)` with autograd recording disabled.

    For a prior-only run, return just `grad_log_prior(x, sigma)` instead.
    """
    likelihood_grad = grad_log_like(x)
    prior_grad = grad_log_prior(x, sigma)
    return likelihood_grad + prior_grad


In [None]:
N = 16  # number of particles (passed to AnnealedSMC as `N`)
# Latent dimension: read it from the model's input embedding table instead of
# hard-coding it, so it stays correct if the checkpoint changes (768 for 'gpt2').
d = model.transformer.wte.embedding_dim
sigma0 = 8.0  # starting sigma: used to draw and score the initial particles
sigma_target = 1.0  # sigma the sampler is asked to reach (`sigma_target` argument)

def sample_initial_particles(N, d):
    """Draw `N` initial particles from the prior at the starting scale `sigma0`.

    `d` is not used here; only `N` is forwarded to `sample_prior`.
    """
    initial_particles = sample_prior(N, sigma0)
    return initial_particles

def initial_logp(x):
    """Score `x` under the prior at the starting scale `sigma0`."""
    starting_log_density = log_prior(x, sigma0)
    return starting_log_density

def project_to_vocab(
    z: torch.Tensor,
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    device: torch.device,
) -> List[str]:
    """Map each latent vector to the token whose input embedding is nearest.

    Args:
        z: 2-D tensor of latents, one row per particle, on the same device and
            with the same dtype as the model's embedding table.
        model: GPT-2 model whose `transformer.wte` table defines the vocabulary embeddings.
        tokenizer: tokenizer used to turn the nearest token ids into token strings.
        device: unused; kept so existing call sites keep working.

    Returns:
        One token string per row of `z`.
    """
    embedding_matrix = model.transformer.wte.weight.detach()
    distances = torch.cdist(z, embedding_matrix)          # (N,V)
    # `convert_ids_to_tokens` is documented for int / List[int]; hand it plain
    # Python ints rather than a (possibly CUDA) tensor.
    nearest_ids = distances.argmin(dim=1).tolist()        # N ints
    return tokenizer.convert_ids_to_tokens(nearest_ids)

@torch.no_grad()
def debug_log_tokens(z, sigma):
    """Print the likelihood and prior of the reference solution and of every particle.

    The first line reports `log_like` / `log_prior` for `solution_latents`.
    Then one line per particle gives the nearest vocabulary token (cut or
    padded to 10 characters), its log-likelihood and its log-prior at `sigma`.

    Runs under `torch.no_grad()`, like `log_target`: this is pure logging, so
    there is no reason to record an autograd graph for the model forward passes.
    """
    solution_likelihood = log_like(solution_latents)
    solution_prior = log_prior(solution_latents, sigma)
    print(f"solution: {solution_likelihood.item():.4e}, {solution_prior.item():.4e}")

    vocabs = project_to_vocab(z, model, tokenizer, device)
    likelihoods = log_like(z)
    priors = log_prior(z, sigma)
    for vocab, likelihood, prior in zip(vocabs, likelihoods, priors):
        # `<10` left-aligns and space-pads the truncated token to a fixed width.
        print(f"    {vocab[:10]:<10}: {likelihood:.4e}, {prior:.4e}")

# Sampler hyperparameters gathered in one place before construction.
smc_settings = dict(
    N=N,
    x_dim=d,
    sigma_0=sigma0,
    sigma_target=sigma_target,
    alpha=0.5,
    # Step size is half the current sigma, capped at 0.1.
    mala_step_size=lambda sigma: min(0.1, 0.5 * sigma),
    mala_steps=128,
    ess_min_frac=0.8,
    device=device,
)
sampler = AnnealedSMC(**smc_settings)

# Callables the sampler needs: how to initialise and score the particles, the
# target density and its gradient, and the per-step debug printer.
run_callbacks = dict(
    init_sampler=sample_initial_particles,
    init_logp=initial_logp,
    log_target=log_target,
    grad_log_target=grad_log_target,
    debug_logger=debug_log_tokens,
)
final_particles = sampler.run(**run_callbacks)

Running SMC (ESS = 16):   0%|          | 0/100 [00:00<?, ?it/s]

sigma_0: 10.0
new_index: 0, curr_index: 0
solution: -4.5873e+01, -2.4742e+03
    soDelivery: -4.8051e+01, -2.8574e+03
    ãĤ´ãĥ³    : -5.3297e+01, -2.8350e+03
    soDelivery: -4.8220e+01, -2.8556e+03
    li        : -5.1999e+01, -2.8492e+03
    soDelivery: -4.8080e+01, -2.8617e+03
    english   : -4.8220e+01, -2.8569e+03
    soDelivery: -4.9497e+01, -2.8548e+03
    english   : -4.8578e+01, -2.8634e+03
    soDelivery: -4.8697e+01, -2.8584e+03
    soDelivery: -4.7277e+01, -2.8604e+03
    soDelivery: -4.8196e+01, -2.8591e+03
    english   : -4.7442e+01, -2.8619e+03
    soDelivery: -4.7986e+01, -2.8570e+03
    soDelivery: -4.9360e+01, -2.8627e+03
    soDelivery: -4.7202e+01, -2.8559e+03
    urous     : -5.2685e+01, -2.8555e+03
sigma_1: 9.75390625
new_index: 2, curr_index: 0
solution: -4.5873e+01, -2.4551e+03
    soDelivery: -4.8146e+01, -2.8588e+03
    ãĤ´ãĥ³    : -5.1677e+01, -2.8355e+03
    soDelivery: -4.7760e+01, -2.8562e+03
    li        : -5.1847e+01, -2.8495e+03
    Sham      : -4.8