In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

from algos.annealed_smc import AnnealedSMC
from utils import build_relaxed_single_token_prior, build_suffix_likelihood

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run everything on the CPU.
device = torch.device('cpu')

# Pretrained GPT-2 tokenizer and LM head model; the model is moved to
# `device` and cast to float32.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.to(device, dtype=torch.float32)

# Token ids of the fixed suffix, encoded without special tokens.
suffix = " went to the shop"
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

# Callables produced by the project helpers in `utils`. Going by their names
# they are presumably (log-density, gradient, sampler) for the prior and
# (log-density, gradient) for the suffix likelihood -- confirm in `utils`.
(log_prior, grad_log_prior, sample_prior) = build_relaxed_single_token_prior(
    model, tokenizer, device
)
(log_like, grad_log_like) = build_suffix_likelihood(
    model, tokenizer, suffix_ids, device
)

def log_target(x, sigma):
    """Evaluate the log-target handed to the SMC sampler.

    Only the prior term ``log_prior(x, sigma)`` is evaluated. The likelihood
    term ``log_like(x)`` is not included at the moment, so the target is the
    prior alone. The evaluation runs under ``torch.no_grad()``.

    Args:
        x: point(s) at which to evaluate; forwarded unchanged to ``log_prior``.
        sigma: scale parameter; forwarded unchanged to ``log_prior``.

    Returns:
        Whatever ``log_prior(x, sigma)`` returns.
    """
    with torch.no_grad():
        prior_term = log_prior(x, sigma)
    return prior_term

def grad_log_target(x, sigma):
    """Evaluate the gradient callable paired with ``log_target``.

    Only the prior term ``grad_log_prior(x, sigma)`` is evaluated. The
    likelihood term ``grad_log_like(x)`` is not included at the moment. The
    call is made inside a ``torch.no_grad()`` context.

    Args:
        x: point(s) at which to evaluate; forwarded unchanged to
            ``grad_log_prior``.
        sigma: scale parameter; forwarded unchanged to ``grad_log_prior``.

    Returns:
        Whatever ``grad_log_prior(x, sigma)`` returns.
    """
    with torch.no_grad():
        prior_grad = grad_log_prior(x, sigma)
    return prior_grad

In [None]:
# Seed torch's global RNG so that re-running this cell (and the sampler run
# that follows) starts from the same random state every time.
SEED = 0
torch.manual_seed(SEED)

N = 32  # number of particles (passed to AnnealedSMC as `N`)
d = 768 # gpt2 embedding dimension
sigma0 = 10.0  # starting sigma (passed to AnnealedSMC as `sigma_0`)
sigma_target = 0.1  # final sigma (passed to AnnealedSMC as `sigma_target`)

def sample_initial_particles(N, d):
    """Draw the starting particles from the prior at the initial sigma.

    Args:
        N: number of particles; forwarded to ``sample_prior``.
        d: not used by this function; it is only part of the signature.

    Returns:
        Whatever ``sample_prior(N, sigma0)`` returns, where ``sigma0`` is the
        notebook-level constant.
    """
    initial_particles = sample_prior(N, sigma0)
    return initial_particles

# Build the annealed SMC sampler from the notebook-level configuration.
# The meaning of each keyword is defined by `algos.annealed_smc.AnnealedSMC`.
sampler = AnnealedSMC(
    N=N,
    x_dim=d,
    sigma_0=sigma0,
    sigma_target=sigma_target,
    alpha=0.5,
    mala_step_size=0.2,
    mala_steps=100,
    ess_min_frac=0.5,
    device=device
)

# Initialise through the `sample_initial_particles` wrapper rather than the raw
# `sample_prior`. The wrapper always samples at `sigma0`, whereas
# `sample_prior` takes a sigma as its second positional argument, so passing
# it directly only works if the sampler happens to supply the right sigma.
# NOTE(review): assumes `AnnealedSMC.run` calls `init_sampler` with two
# positional arguments -- confirm against its implementation.
final_particles = sampler.run(
    init_sampler=sample_initial_particles,
    log_target=log_target,
    grad_log_target=grad_log_target,
)

Running SMC:   0%|          | 0/100 [00:00<?, ?it/s]

w: tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
ess at 0.1: 1.0
w: tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
ess at 5.05: 1.0
w: tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
ess at 7.525: 1.0
