In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import math
from torch.distributions.categorical import Categorical

from typing import List

from algos.annealed_smc import AnnealedSMC
from utils import build_relaxed_single_token_prior, build_suffix_likelihood

In [2]:
def build_relaxed_prior_factory(
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    device: torch.device,
):
    """Build the σ-relaxed single-token prior over embedding space.

    The prior is a Gaussian mixture with one isotropic component per
    vocabulary entry::

        p_σ(z) = Σ_v  π_v · N(z; E[v], σ² I)

    where ``E`` is ``model.transformer.wte.weight`` and ``π`` is the model's
    next-token distribution after the single token ``tokenizer.eos_token_id``
    (used here as the BOS marker).

    Args:
        model: Language model; must be callable on a ``(1, 1)`` id tensor and
            expose ``.logits`` on the result, plus ``transformer.wte.weight``.
        tokenizer: Only ``eos_token_id`` is read.
        device: Device on which the BOS id tensor is created.

    Returns:
        A tuple ``(logp, grad, sample)`` of functions sharing the static
        objects computed here:

        * ``logp(z, σ)``    -> ``(N,)`` log-density of each row of ``z`` (N,d).
        * ``grad(z, σ)``    -> ``(N,d)`` gradient of ``logp(z, σ).sum()`` w.r.t. ``z``.
        * ``sample(num, σ)`` -> ``(num,d)`` draws from the mixture.
    """

    # Static objects shared across all σ
    with torch.no_grad():
        bos = torch.tensor([[tokenizer.eos_token_id]], device=device)
        logits = model(bos).logits[:, -1]                    # (1,V)
        prior_probs = logits.softmax(-1).squeeze(0)          # (V,)
        # Mixture log-weights, computed once (log_softmax avoids log(softmax)).
        log_prior_probs = logits.log_softmax(-1)             # (1,V)
    E: torch.Tensor = model.transformer.wte.weight.detach()  # (V,d)
    d = E.shape[1]
    E_sq_norm = E.square().sum(-1)                           # (V,)
    token_dist = Categorical(probs=prior_probs)

    def logp(z: torch.Tensor, σ: float) -> torch.Tensor:
        var = σ * σ
        const = -0.5 * d * math.log(2 * math.pi * var)
        # Matmul needs matching dtypes; promote the same way broadcasting would.
        dtype = torch.promote_types(z.dtype, E.dtype)
        z_, E_ = z.to(dtype), E.to(dtype)
        # ||z - E_v||² = ||z||² - 2 z·E_v + ||E_v||².  This keeps the working
        # set at (N,V) instead of materialising the (N,V,d) difference tensor.
        # Rounding can make a near-zero distance slightly negative; it is left
        # unclamped so the gradient stays smooth.
        sq_dist = (
            z_.square().sum(-1, keepdim=True)                # (N,1)
            - 2.0 * (z_ @ E_.T)                              # (N,V)
            + E_sq_norm                                      # (V,)
        )
        log_gauss = const - (0.5 / var) * sq_dist            # (N,V)
        return torch.logsumexp(log_gauss + log_prior_probs, dim=-1)  # (N,)

    def grad(z: torch.Tensor, σ: float) -> torch.Tensor:
        return torch.func.grad(lambda x: logp(x, σ).sum())(z)

    def sample(num: int, σ: float) -> torch.Tensor:
        tokens = token_dist.sample((num,))                   # (num,)
        base = E[tokens]                                     # (num,d)
        return base + torch.randn_like(base) * σ

    return logp, grad, sample

def build_suffix_likelihood(
    model: GPT2LMHeadModel,
    tokenizer: GPT2Tokenizer,
    suffix_ids: List[int],
    device: torch.device,
):
    """Build the suffix log-likelihood as a function of a first-token embedding.

    For each row ``z`` of an ``(N,d)`` batch the model is run on the embedding
    sequence ``[BOS, z, suffix...]`` and the log-probabilities it assigns to
    the suffix tokens are summed.

    Note: this notebook-local definition shadows the ``build_suffix_likelihood``
    imported from ``utils`` at the top of the notebook.

    Args:
        model: Language model; must accept ``inputs_embeds=`` and expose
            ``.logits`` on the result, plus the embedding ``transformer.wte``.
        tokenizer: Only ``bos_token_id`` (falling back to ``eos_token_id``
            when it is ``None``) is read.
        suffix_ids: Token ids of the suffix to condition on. May be empty, in
            which case the log-likelihood is identically zero.
        device: Device on which the id tensors are created.

    Returns:
        A tuple ``(logp, grad, suffix_ids)``:

        * ``logp(z)`` -> ``(N,)`` summed suffix log-probability per row.
        * ``grad(z)`` -> ``(N,d)`` gradient of ``logp(z).sum()`` w.r.t. ``z``.
        * ``suffix_ids`` is the list that was passed in, returned unchanged.
    """

    # Explicit integer dtype: without it an empty list becomes a float tensor
    # and the embedding lookup below rejects it.
    ids_tensor = torch.tensor(suffix_ids, dtype=torch.long, device=device)  # (L,)
    # The prior builder in this notebook conditions on eos_token_id; fall back
    # to it for tokenizers that define no separate BOS token.
    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        bos_id = tokenizer.eos_token_id
    bos_tensor = torch.tensor([bos_id], device=device)
    with torch.no_grad():
        suffix_embeds = model.transformer.wte(ids_tensor)  # (L,d)
        bos_embeds = model.transformer.wte(bos_tensor)     # (1,d)
    L = len(suffix_ids)
    gather_index = ids_tensor.view(1, L, 1)                # (1,L,1)

    def logp(z: torch.Tensor) -> torch.Tensor:
        N = z.shape[0]
        z_seq = z.unsqueeze(1)                                          # (N,1,d)
        suffix_expand = suffix_embeds.unsqueeze(0).expand(N, -1, -1)    # (N,L,d)
        bos_expand = bos_embeds.unsqueeze(0).expand(N, -1, -1)          # (N,1,d)
        inputs = torch.cat([bos_expand, z_seq, suffix_expand], dim=1)   # (N,2+L,d)
        logits = model(inputs_embeds=inputs).logits                     # (N,2+L,V)

        # Suffix token j sits at input position 2+j and is predicted by the
        # logits at position 1+j, so positions 1..L cover the whole suffix.
        step_log_probs = torch.log_softmax(logits[:, 1:1 + L, :], dim=-1)  # (N,L,V)
        token_log_probs = step_log_probs.gather(
            -1, gather_index.expand(N, L, 1)
        ).squeeze(-1)                                                   # (N,L)
        return token_log_probs.sum(dim=1)                               # (N,)

    def grad(z: torch.Tensor) -> torch.Tensor:
        return torch.func.grad(lambda x: logp(x).sum())(z)

    return logp, grad, suffix_ids

In [3]:
# All tensors and the model live on the CPU in float32.
device = torch.device("cpu")

# Pretrained GPT-2 tokenizer and language model; the prior, likelihood and
# target functions are built from them in the next cell.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device, dtype=torch.float32)

In [4]:
# Fixed continuation whose likelihood, as a function of the first-token
# embedding, defines the likelihood term of the target.
suffix = " went to the shop"
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

# Prior handles: log-density, its gradient, and a sampler (each takes sigma).
log_prior, grad_log_prior, sample_prior = build_relaxed_prior_factory(
    model=model,
    tokenizer=tokenizer,
    device=device,
)
# Likelihood handles; the builder hands back the same suffix_ids it was given.
log_like, grad_log_like, suffix_ids = build_suffix_likelihood(
    model=model,
    tokenizer=tokenizer,
    suffix_ids=suffix_ids,
    device=device,
)

@torch.no_grad()
def log_target(x, sigma):
    """Unnormalised log-target: suffix log-likelihood plus relaxed log-prior at `sigma`.

    Evaluated with autograd recording disabled.
    """
    likelihood_term = log_like(x)
    prior_term = log_prior(x, sigma)
    return likelihood_term + prior_term

@torch.no_grad()
def grad_log_target(x, sigma):
    """Gradient of the log-target w.r.t. `x`: likelihood gradient plus prior gradient at `sigma`.

    The two gradient handles run their own `torch.func.grad` transform; the
    surrounding no_grad only keeps the outer call from recording a graph.
    """
    likelihood_grad = grad_log_like(x)
    prior_grad = grad_log_prior(x, sigma)
    return likelihood_grad + prior_grad


In [None]:
N = 32  # number of particles
# Read the embedding width from the loaded model instead of hard-coding it
# (768 for the 'gpt2' checkpoint), so this stays correct if the checkpoint changes.
d = model.transformer.wte.weight.shape[1]
sigma0 = 10.0        # relaxation scale used to draw the initial particles
sigma_target = 0.1   # final relaxation scale handed to the sampler

# Seed the global torch RNG so a fresh Restart & Run All reproduces the run.
# sample_prior draws from it (Categorical.sample / torch.randn_like);
# NOTE(review): assumes AnnealedSMC also uses the global torch RNG — confirm.
SEED = 0
torch.manual_seed(SEED)

def sample_initial_particles(N, d):
    """Draw `N` starting particles from the relaxed prior at scale `sigma0`.

    `d` is accepted to match the sampler's `init_sampler` call but is not
    used here; the particle width comes from the prior's embedding matrix.
    """
    initial_particles = sample_prior(N, sigma0)
    return initial_particles

# Configure the annealed SMC sampler; hyper-parameters are passed through
# unchanged (see algos.annealed_smc for their meaning).
sampler = AnnealedSMC(
    N=N,
    x_dim=d,
    device=device,
    sigma_0=sigma0,
    sigma_target=sigma_target,
    alpha=0.5,
    ess_min_frac=0.5,
    mala_steps=3,
    mala_step_size=0.2,
)

# Run the sampler from prior draws towards the suffix-conditioned target.
final_particles = sampler.run(
    init_sampler=sample_initial_particles,
    log_target=log_target,
    grad_log_target=grad_log_target,
)