In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import math
from torch.distributions.categorical import Categorical

from typing import List

from algos.annealed_smc import AnnealedSMC
from utils import build_relaxed_single_token_prior, build_suffix_likelihood

In [2]:
# Run on the GPU when one is available; fall back to CPU so the notebook still
# executes (slowly) on machines without CUDA instead of failing at load time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GPT-2 (small) and its tokenizer; weights are cast to float32 on `device`.
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device, dtype=torch.float32)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# NOTE: the suffix, the relaxed prior / suffix likelihood, and log_target /
# grad_log_target are built once in the next cell. They were previously also
# built here, which duplicated the GPT-2-backed setup and defined each function twice.

In [3]:
# Suffix whose GPT-2 likelihood conditions the target (presumably scored as the
# continuation following the relaxed single token x; see utils.build_suffix_likelihood).
suffix = " went to the shop"
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

log_prior, grad_log_prior, sample_prior = build_relaxed_single_token_prior(model, tokenizer, device)
log_like, grad_log_like = build_suffix_likelihood(model, tokenizer, suffix_ids, device)

# Set to False to target the annealed prior alone, e.g. to debug the sampler
# without the GPT-2 likelihood term.
USE_LIKELIHOOD = True


def log_target(x, sigma):
    """Unnormalised log-density of the target at noise level ``sigma``.

    Returns ``log_like(x) + log_prior(x, sigma)``, or only ``log_prior(x, sigma)``
    when ``USE_LIKELIHOOD`` is False. Evaluated under ``torch.no_grad()`` because
    only the value is needed here.
    """
    with torch.no_grad():
        if USE_LIKELIHOOD:
            return log_like(x) + log_prior(x, sigma)
        return log_prior(x, sigma)


def grad_log_target(x, sigma):
    """Gradient of ``log_target`` with respect to ``x`` at noise level ``sigma``.

    Deliberately NOT wrapped in ``torch.no_grad()``: ``grad_log_like`` differentiates
    through GPT-2, which presumably relies on autograd, and no_grad disables it
    (NOTE(review): confirm against utils.build_suffix_likelihood). ``enable_grad``
    keeps autograd on even if the sampler calls this from a no-grad region. The
    result is detached so no computation graph is kept alive between MALA steps.
    """
    with torch.enable_grad():
        if USE_LIKELIHOOD:
            grad = grad_log_like(x) + grad_log_prior(x, sigma)
        else:
            grad = grad_log_prior(x, sigma)
    return grad.detach()


In [4]:
# Seed torch's RNG so that particle initialisation and the sampler's random
# draws (presumably taken from torch's RNG) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

N = 16                    # number of SMC particles
d = model.config.n_embd   # GPT-2 embedding dimension (768 for 'gpt2'), read from the model config
sigma0 = 10.0             # noise scale passed to the prior at the start of annealing
sigma_target = 0.1        # noise scale at the end of annealing


def sample_initial_particles(N, d):
    """Draw ``N`` initial particles from the relaxed prior at noise level ``sigma0``.

    ``d`` is accepted (presumably to match how the sampler calls ``init_sampler``)
    but unused: the particle dimension is determined by ``sample_prior``.
    """
    return sample_prior(N, sigma0)


def initial_logp(x):
    """Log-density of the initial distribution, i.e. the prior at ``sigma0``."""
    return log_prior(x, sigma0)


# alpha / mala_step_size / mala_steps / ess_min_frac are passed straight through
# to AnnealedSMC; see algos.annealed_smc for their exact meaning.
sampler = AnnealedSMC(
    N=N,
    x_dim=d,
    sigma_0=sigma0,
    sigma_target=sigma_target,
    alpha=0.5,
    mala_step_size=0.2,
    mala_steps=100,
    ess_min_frac=0.5,
    device=device,
)

# NOTE(review): a previous run failed on the first iteration with
# "TypeError: unsupported operand type(s) for -: 'NoneType' and 'float'".
# Its origin is not visible from this notebook; check which callable
# (init_sampler / init_logp / log_target / grad_log_target) or which
# AnnealedSMC attribute yields None.
final_particles = sampler.run(
    init_sampler=sample_initial_particles,
    init_logp=initial_logp,
    log_target=log_target,
    grad_log_target=grad_log_target,
)

Running SMC:   0%|          | 0/100 [00:00<?, ?it/s]

TypeError: unsupported operand type(s) for -: 'NoneType' and 'float'