# Import

In [None]:
# PyTorch
import torch
from torch import Tensor, no_grad, hstack
device = torch.device("cuda:0")

# Transformers
import transformers
from transformers import TextStreamer
from transformers.modeling_outputs import SequenceClassifierOutput

# Other Modules
from importlib import reload
from typing import List, Union
import time
from functools import cache

# Reloading CBF-LLM Modules
import torch_utils
import language_constraint_functions
import filter
import cbf_filter
import blacklist_filter
import just_topk_filter
import normalizers
import token_predictors
# Re-execute each CBF-LLM module in place so that edits to its source file take
# effect without restarting the kernel. Modules are reloaded in the order listed.
# NOTE(review): a module that did `from X import Y` may keep a stale `Y` if `X`
# is reloaded after it; confirm the order matches the import dependencies.
_CBF_LLM_MODULES = (
    torch_utils,
    language_constraint_functions,
    filter,
    cbf_filter,
    blacklist_filter,
    just_topk_filter,
    normalizers,
    token_predictors,
)
for _module in _CBF_LLM_MODULES:
    reload(_module)

In [None]:
# CBF-LLM Classes
from just_topk_filter import JustTopkFilter
from blacklist_filter import BlacklistFilter
from cbf_filter import CBFFilter
from filter import Filter, FilterResult
from language_constraint_functions import LanguageCF
from normalizers import Normalizer, MinJSDNormalizer, Min2NormNormalizer
from torch_utils import *
from token_predictors import distributionify

# Utilities

In [None]:
# Execution-time measurement helpers.
# `time.perf_counter()` is a monotonic, high-resolution clock meant for measuring
# elapsed time; `time.time()` is wall-clock time and can jump (e.g. clock sync).
_tic_time = time.perf_counter()


def tic():
    """Start (or restart) the timer read by `toc`."""
    global _tic_time
    _tic_time = time.perf_counter()


def toc(print_time: bool = True) -> Union[None, float]:
    """Report the seconds elapsed since the last `tic()` call.

    Args:
        print_time: If True, print the elapsed time with 4 decimals followed by
            the unit "秒" (seconds) and return None. If False, return it instead.

    Returns:
        The elapsed time in seconds when `print_time` is False, otherwise None.
    """
    elapsed = time.perf_counter() - _tic_time
    if print_time:
        print(f"{elapsed:.04f} 秒")
        return None
    return elapsed

# Models

制約言語関数$h:\mathcal X \to \mathbb R$を作る．

In [None]:
# HuggingFace cardiffnlp/twitter-roberta-base-sentiment-latest
# https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment-latest
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Download the model above beforehand and enter the path it was downloaded to.
# The value is read by the `from_pretrained` calls that build `lcf` below.
# NOTE(review): `input()` blocks an unattended "Restart & Run All"; consider
# reading the path from a config constant or an environment variable instead.
name = input("Name for cardiffnlp/twitter-roberta-base-sentiment-latest?")


def _mapper(output: SequenceClassifierOutput) -> List[float]:
    """Convert the sentiment RoBERTa classifier output into constraint values h.

    The class logits are turned into probabilities with a softmax over dim 1,
    and for each row h = p(positive) - max(p(negative), p(neutral)) is computed.
    h is positive exactly when "positive" is strictly the most probable class.

    Args:
        output: Classifier output. ``output[0]`` is used as the logits, with the
            three class columns ordered (negative, neutral, positive). Assumed
            to hold one row per input text.

    Returns:
        The h values as a list of floats, one per row.
    """
    probs = output[0].softmax(dim=1)
    p_neg, p_neu, p_pos = probs.unbind(dim=1)
    return tolist(p_pos - torch.maximum(p_neg, p_neu))


# Constraint language function h, backed by the sentiment classifier at `name`.
# The model is loaded first, then its tokenizer.
_sentiment_model = AutoModelForSequenceClassification.from_pretrained(name)
_sentiment_tokenizer = AutoTokenizer.from_pretrained(name)
lcf = LanguageCF(
    model=_sentiment_model,
    tokenizer=_sentiment_tokenizer,
    mapper=_mapper,
    name="cardiffnlp/twitter-roberta-base-sentiment-latest",
)

LLMを起動する．

In [None]:
# HuggingFace meta-llama/Meta-Llama-3-8B
# https://huggingface.co/meta-llama/Meta-Llama-3-8B
from transformers import LlamaForCausalLM, AutoTokenizer
name = input("Name for meta-llama/Meta-Llama-3-8B?")

# Causal LM used for generation: loaded in half precision, then moved to `device`.
Gm = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16)
Gm = Gm.to(device)
# Tokenizer belonging to the same checkpoint.
Gt = AutoTokenizer.from_pretrained(name)

In [None]:
# Streamer built on the LLM tokenizer; tokens are fed to it during generation.
streamer = TextStreamer(Gt)
# Token string -> token id.
vocab = Gt.get_vocab()
# Token id -> decoded text of that single token, with newlines removed.
ivocab = {
    token_id: Gt.decode([token_id]).replace("\n", "")
    for token_id in vocab.values()
}
VOCAB_SIZE = len(vocab)


@cache
def Gtokenize(xstr: str) -> Tensor:
    """Tokenize `xstr` with the LLM tokenizer, without adding special tokens.

    Returns the token ids of the single input string (row 0 of `input_ids`),
    placed on `device`. Results are memoized per string, so repeated calls with
    the same text return the same tensor object. Do not modify it in place;
    clone it first.
    """
    encoded = Gt(xstr, return_tensors="pt", add_special_tokens=False)
    return encoded.to(device).input_ids[0]

# Models Check

In [None]:
test_texts = [
    "You have a good place",
    "I must be ill!"
]
# For efficiency, `LanguageCF` caches the value computed for each text.
# The cache size is printed before and after the call to show the new entries.
# Presumably, re-running this cell hits the cache, so the measured time drops.
print("#Cache:", len(lcf.cache))
tic()
h_list = lcf.get_for_texts(test_texts)
toc()
print("#Cache:", len(lcf.cache))
print(h_list)

# Generation Methods

In [None]:
@no_grad()
def generate(
        x0: Tensor,
        max_new_tokens: int,
        temperature: float,
        filter: Filter,
        normalizer: Normalizer,
        stream: bool = True,
        generator: Union[torch.Generator, None] = None,
):
    """Sample up to `max_new_tokens` tokens from the LLM `Gm`, restricted by `filter`.

    At every step:
      1. The next-token distribution P is computed from the last-position logits
         of the whole current sequence.
      2. `filter.scan` decides which tokens are allowed.
      3. `normalizer` turns P into a distribution Q over the allowed tokens.
      4. One token is sampled from Q.
    Generation stops early when the EOS token is sampled; the EOS token itself
    is not appended.

    Args:
        x0: Prompt token ids as a 1-D tensor (e.g. from `Gtokenize`). Not modified.
        max_new_tokens: Maximum number of sampling steps.
        temperature: Sampling temperature passed to `distributionify`.
        filter: Filter whose `scan(x, P)` result gives the allowed/disallowed tokens.
        normalizer: Callable mapping (P, allowed) to the sampling distribution Q.
        stream: If True, feed each sampled token to the global `streamer`.
        generator: Optional `torch.Generator` for reproducible sampling. It must
            be on the same device as Q. None uses the global RNG, as before.

    Returns:
        dict with keys:
          - "disallowed_tokens_history": `filter_result.disallowed` per step,
          - "clf_mapping_history": `filter_result.clf_mapping` per step,
          - "xf": the final sequence (prompt followed by the generated tokens).
    """
    R = {"disallowed_tokens_history": [], "clf_mapping_history": []}
    x = x0.clone()
    eos_token_id = Gt.eos_token_id  # loop-invariant; look it up once

    try:
        for _ in range(max_new_tokens):
            # Full forward pass over the whole sequence (no KV cache); only the
            # logits at the last position are used.
            logit = Gm(x[None]).logits[0][-1]
            P = distributionify(logit, temperature=temperature)
            filter_result = filter.scan(x, P)
            R["disallowed_tokens_history"].append(filter_result.disallowed)
            R["clf_mapping_history"].append(filter_result.clf_mapping)
            Q = normalizer(P, filter_result.allowed)
            iast = Q.multinomial(num_samples=1, generator=generator)  # shape (1,)
            if iast == eos_token_id:
                break
            x = hstack((x, iast))
            if stream:
                streamer.put(iast)
    finally:
        # Always end the shared streamer, even if generation is interrupted
        # (e.g. a KeyboardInterrupt in the notebook). This flushes any buffered
        # text and resets its state before the next streamed generation.
        if stream:
            streamer.end()

    R["xf"] = x
    return R

# Generation

In [None]:
x0str = "Everyone says you will be a good researcher in the future, but"
x0 = Gtokenize(x0str)  # prompt token ids
h0 = lcf.get_for_text(x0str)  # constraint value h of the prompt itself
print(f"{x0=}")
print(f"{h0=}")

TOPK = 30  # top_k handed to every filter below
TEMPERATURE = 1
normalizer = MinJSDNormalizer()
MAX_NEW_TOKENS = 30

# Token sampling is stochastic. Fix the RNG seed so that "Restart & Run All"
# reproduces the generations below (torch.manual_seed seeds CPU and CUDA).
SEED = 0
torch.manual_seed(SEED);

In [None]:
# NoControl (plain Llama 3 output)
# `JustTopkFilter` only applies top-k truncation and performs no token selection,
# so this run is the uncontrolled baseline.
topk_only_filter = JustTopkFilter(top_k=TOPK)
R = generate(
    x0=x0,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    filter=topk_only_filter,
    normalizer=normalizer,
)

In [None]:
# CBF (alpha=0.3)
# `CBFFilter` enables the CBF filter. It is driven by the constraint language
# function `lcf` and is given the LLM tokenizer.
cbf_alpha03_filter = CBFFilter(
    top_k=TOPK,
    alpha=0.3,
    tokenizer=Gt,
    lcf=lcf,
)
R = generate(
    x0=x0,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    filter=cbf_alpha03_filter,
    normalizer=normalizer,
)