In [1]:
import torch


In [None]:
# Default vocabularies for StyleController. Each list is built from a
# whitespace-separated string, so every entry must be a single word.

# Formal / academic register.
academic_words = (
    "however therefore consequently "
    "analysis methodology significant "
    "furthermore evidence demonstrates"
).split()

# Informal / conversational register.
casual_words = (
    "lol yeah gonna wanna basically "
    "like kinda sorta"
).split()

# Words that must never be generated.
banned_words = "shit".split()

In [None]:
class StyleController:
    """
    Modular linguistic/style controller for academic auto-completion.

    Works with any HuggingFace-style tokenizer & model. Steers generation
    in two independent ways:

    * ``apply_prefix`` prepends a fixed style instruction to the prompt text.
    * ``apply_logits_control`` (also reachable via ``__call__``, so an
      instance can be used as a HuggingFace logits processor) adds ``boost``
      to academic-word logits and subtracts ``penalty`` from casual-word
      logits. It also pins banned-word logits to a very large negative value.

    Only words with at least one single-token surface form can be steered
    (see ``_words_to_token_ids``).

    Args:
        tokenizer: tokenizer exposing ``tokenize`` and ``convert_tokens_to_ids``.
        academic_words: words whose tokens are boosted.
        casual_words: words whose tokens are penalised.
        banned_words: words whose tokens are effectively forbidden.
        boost: amount added to academic-word logits.
        penalty: amount subtracted from casual-word logits.
        use_prefix: whether ``apply_prefix`` prepends the style instruction.
    """

    # Logit written into banned-token positions. Kept at the original value;
    # it sits far below typical logits, so softmax gives ~0 probability.
    _BAN_LOGIT = -1e4

    def __init__(
        self,
        tokenizer,
        academic_words=academic_words,
        casual_words=casual_words,
        banned_words=banned_words,
        boost=2.0,
        penalty=4.0,
        use_prefix=True,
    ):
        self.tokenizer = tokenizer
        self.boost = boost
        self.penalty = penalty
        self.use_prefix = use_prefix

        # Copy the word lists. The defaults are module-level lists shared by
        # every instance, so mutating e.g. ``self.academic_words`` must not
        # silently change the defaults of later instances.
        self.academic_words = list(academic_words)
        self.casual_words = list(casual_words)
        self.banned_words = list(banned_words)

        # Token IDs, resolved once here. Later edits to the word lists above
        # are NOT reflected in these ID lists.
        self.academic_ids = self._words_to_token_ids(self.academic_words)
        self.casual_ids = self._words_to_token_ids(self.casual_words)
        self.banned_ids = self._words_to_token_ids(self.banned_words)

        # Prefix for academic tone
        self.prefix = (
            "Write the continuation below in a formal academic style, "
            "using precise vocabulary and objective reasoning.\n"
        )

    def _words_to_token_ids(self, words):
        """Return the vocabulary IDs of the single-token surface forms of *words*.

        Each word is looked up in four forms:
        * as given;
        * with a leading space;
        * with its first letter upper-cased;
        * upper-cased and with a leading space.

        Byte-level BPE tokenizers (e.g. GPT-2) encode a word that follows a
        space as a different token from the bare word. Sentence-initial
        words are also capitalised. Looking up only the bare form would
        therefore miss most occurrences in running text.

        Some forms are skipped:
        * forms that split into several tokens;
        * forms that resolve to the tokenizer's unknown token (which would
          otherwise steer every unknown word at once);
        * empty or whitespace-only words.

        The result has no duplicates and keeps first-seen order.
        """
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        ids = []
        seen = set()
        for word in words:
            word = word.strip()
            if not word:
                # " " alone is a real token in BPE vocabularies; never steer it.
                continue
            capitalized = word[:1].upper() + word[1:]
            # dict.fromkeys drops repeated forms (e.g. "Lol" vs "lol" is kept,
            # but an already-capitalised word yields the same form twice).
            forms = dict.fromkeys((word, " " + word, capitalized, " " + capitalized))
            for form in forms:
                tokens = self.tokenizer.tokenize(form)
                if len(tokens) != 1:
                    continue
                tid = self.tokenizer.convert_tokens_to_ids(tokens[0])
                if tid is None or tid == unk_id or tid in seen:
                    continue
                seen.add(tid)
                ids.append(tid)
        return ids

    def apply_prefix(self, text: str) -> str:
        """Return *text* with the style instruction prepended when ``use_prefix`` is set."""
        if self.use_prefix:
            return self.prefix + text
        return text

    def apply_logits_control(self, logits):
        """Apply style-based logits modification **in place** and return ``logits``.

        The vocabulary is indexed on the last dimension. So ``(vocab,)``,
        ``(batch, vocab)`` and ``(batch, seq, vocab)`` tensors are all accepted.

        The ban is applied last, so a word that also appears in the academic
        or casual list stays banned.
        """
        # Academic boost
        if self.academic_ids:
            logits[..., self.academic_ids] += self.boost

        # Penalize informal language
        if self.casual_ids:
            logits[..., self.casual_ids] -= self.penalty

        # Hard ban
        if self.banned_ids:
            logits[..., self.banned_ids] = self._BAN_LOGIT

        return logits

    def __call__(self, input_ids, scores):
        """HuggingFace ``LogitsProcessor``-style entry point.

        Allows an instance to be placed in a ``LogitsProcessorList`` passed to
        ``generate``. ``input_ids`` is unused because the adjustment does not
        depend on the tokens generated so far. ``scores`` is modified in
        place and returned.
        """
        return self.apply_logits_control(scores)
