In [1]:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
__author__ = 'Author'
__email__ = 'Email'

# Detecting Contradiction at the Lexical Level
## Word Probability

In [2]:
# dependency
# # built-in
import os, math, random, string
# # public
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
import nltk
# nltk.download('words')
from nltk.corpus import words
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import Config

%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [3]:
def get_device():
    if "DEVICE" in os.environ:
        return os.environ["DEVICE"]
    if torch.cuda.is_available():
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    return "cpu"

def set_random_seed(seed: int = 42):
    """Fix random seeds for reproducibility across Python, NumPy, and PyTorch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if using multi-GPU

    # Ensures deterministic behavior where possible
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Init

In [4]:
config = Config()
for k,v in config.__dict__.items():
    print(f'{k}: {v}')

seed: 0
llm: meta-llama/Llama-3.2-3B
CURR_PATH: ./
RESOURCE_PATH: ./res
DATA_PATH: ./res/data
RESULTS_PATH: ./res/results
LLMS_PATH: ./res/llms
LLM_PATH: ./res/llms/meta-llama/Llama-3.2-3B


# Helper

In [5]:
# Global dictionary mapping model IDs (or model families) to BOW prefixes
BOW_PREFIX_MAP = {
    "meta-llama/Llama-3.2-3B": "Ġ", # LLaMA 2, 3
    "deepseek-ai": "Ġ",             # DeepSeek
    "EleutherAI/gpt-neo": "Ġ",      # GPT-Neo
    "openai-community/gpt2": "Ġ",   # GPT-2
    "facebook/opt": "Ġ",            # OPT family
    "bigscience/bloom": "▁",        # BLOOM uses SentencePiece
    "google/pegasus": "▁",          # SentencePiece
    "google-t5": "▁",               # T5 models
    "google/mt5": "▁",
    "Salesforce/codegen": "Ġ",      # CodeGen uses GPT-style
}

def get_device():
    if "DEVICE" in os.environ:
        return os.environ["DEVICE"]
    if torch.cuda.is_available():
        return "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

# I/O

In [6]:
# load tsv
raw_df = pd.read_csv('res/data/wiki/capital50.tsv', sep='\t')
raw_df.head()

Unnamed: 0,wikidata_id,country,capital,source
0,Q233,Malta,Valletta,The capital of Malta is
1,Q262,Algeria,Algiers,The capital of Algeria is
2,Q889,Afghanistan,Kabul,The capital of Afghanistan is
3,Q33,Finland,Helsinki,The capital of Finland is
4,Q736,Ecuador,Quito,The capital of Ecuador is


In [7]:
prompts = raw_df['source'].tolist()

## Llama 3

In [8]:
print(config.LLM_PATH)

./res/llms/meta-llama/Llama-3.2-3B


In [9]:
# init model and tokenizer
config.device = get_device()
print(f"Using device: {config.device}")

tokenizer = AutoTokenizer.from_pretrained(
    config.LLM_PATH
    , device_map=config.device
    )

model = AutoModelForCausalLM.from_pretrained(
    config.LLM_PATH
    , device_map=config.device
    )

Using device: mps


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Init

In [127]:
# init parameters
top_k = 10
beam_width = 5

config.bow_prefix = BOW_PREFIX_MAP.get(config.LLM_PATH, "Ġ")
config.bow_prefix_id = tokenizer.convert_tokens_to_ids(config.bow_prefix)
print(f'bow prefix: {config.bow_prefix}')
print(f'bow prefix id: {config.bow_prefix_id}')

bow prefix: Ġ
bow prefix id: 220


In [79]:
# get the max beam depth

def get_english_words():
    """
    Returns a list of English words sorted by character length (descending).
    """
    word_list = set(words.words())
    return sorted(word_list, key=len, reverse=True)


def estimate_beam_depth(vocab, tokenizer, bow_prefix):
    beam_depth = 0
    for word in vocab[:100]:
        token = tokenizer.encode(word, add_special_tokens=True)
        if len(token) > beam_depth:
            print(word, tokenizer.decode(token), len(token))
            beam_depth = max(beam_depth, len(token))
    return beam_depth


vocab = get_english_words()
print(f"Vocabulary size: {len(vocab)}")
beam_depth = estimate_beam_depth(vocab, tokenizer, config.bow_prefix)
print(f"Estimated beam depth: {beam_depth}")

Vocabulary size: 235892
pathologicopsychological <|begin_of_text|>pathologicopsychological 6
thyroparathyroidectomize <|begin_of_text|>thyroparathyroidectomize 9
formaldehydesulphoxylate <|begin_of_text|>formaldehydesulphoxylate 10
Estimated beam depth: 10


In [80]:
def lm(text_or_ids, model, tokenizer, config):
    """
    Compute the next-token probability distribution for a given input.

    Args:
        text_or_ids (str or List[int]): Input string or list of token IDs
        model: Hugging Face AutoModelForCausalLM
        tokenizer: Corresponding tokenizer
        config: Should contain `.device`

    Returns:
        probs (Tensor): Softmax probability distribution over vocabulary, shape [vocab_size]
    """
    # Decode token IDs to text if input is not a string
    if not isinstance(text_or_ids, str):
        text = tokenizer.decode(text_or_ids, skip_special_tokens=True)
    else:
        text = text_or_ids

    # Tokenize input
    x = tokenizer(text, return_tensors="pt")
    input_ids = x["input_ids"].to(config.device)
    attention_mask = x["attention_mask"].to(config.device)

    # Run model and get logits for the last token
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    logits = outputs.logits[:, -1, :]  # Last token position
    probs = F.softmax(logits, dim=-1).squeeze()  # [vocab_size]

    return probs

def is_valid_token(token_id, bow_prefix, tokenizer) -> bool:
    """
    Returns True if the token is a valid token:
    - start with the BOW prefix
    - contains only alphanumeric characters
    """
    token = tokenizer.convert_ids_to_tokens(token_id)
    token_str = tokenizer.decode(token_id)
    return token.startswith(bow_prefix) or token_str.isalpha()

def is_bow_token(token_id, bow_prefix, tokenizer) -> bool:
    """
    Returns True if the token is a valid token:
    - start with the BOW prefix
    - contains only alphanumeric characters
    """
    return tokenizer.convert_ids_to_tokens(token_id).startswith(bow_prefix)

def get_bow_token_ids(bow_prefix, tokenizer) -> list:
    """
    Returns a list of token IDs that:
    - start with the BOW prefix (e.g., Ġ), AND
    - decode to alphabetic strings (i.e., isalpha())
    """
    vocab_size = tokenizer.vocab_size
    return [i for i in range(vocab_size) if is_bow_token(i, bow_prefix, tokenizer)]

def is_mid_token(token_id, bow_prefix, tokenizer) -> bool:
    """
    Returns True if the token is a valid continuation of a word:
    - does NOT start with the BOW prefix
    - contains only English letters
    """
    token = tokenizer.convert_ids_to_tokens(token_id)
    token_str = tokenizer.decode(token_id)
    # does NOT start with the BOW prefix
    if token.startswith(bow_prefix):
        return False
    # contains only English letters
    if not token_str.isalpha():
        return False
    return True


def get_mid_token_ids(bow_prefix, tokenizer) -> list:
    """
    Returns a list of token IDs that:
    - start with the BOW prefix (e.g., Ġ), AND
    - decode to alphabetic strings (i.e., isalpha())
    """
    vocab_size = tokenizer.vocab_size
    return [i for i in range(vocab_size) if is_mid_token(i, bow_prefix, tokenizer)]

In [81]:
# get valid token ids
bow_token_ids = get_bow_token_ids(config.bow_prefix, tokenizer)
print(f"Bow token IDs: {len(bow_token_ids)}")

mid_token_ids = get_mid_token_ids(config.bow_prefix, tokenizer)
print(f"Mid token IDs: {len(mid_token_ids)}")

valid_token_ids = list(set(bow_token_ids + mid_token_ids))
print(f"Valid token IDs: {len(valid_token_ids)}")

Bow token IDs: 57875
Mid token IDs: 41943
Valid token IDs: 99818


In [116]:
def norm_probs(probs, valid_token_ids):
    """
    Normalize probs over a precomputed set of valid token IDs.

    Args:
        probs (Tensor): Raw probability distribution over vocabulary, shape [vocab_size]
        valid_token_ids (List[int]): Token IDs that are considered valid for normalization

    Returns:
        norm_probs (Tensor): New probability distribution normalized over valid_token_ids
    """
    masked_probs = torch.zeros_like(probs)
    masked_probs[valid_token_ids] = probs[valid_token_ids]
    total = masked_probs.sum()

    return masked_probs / total

def inject_eow_prob(probs, bow_token_ids, bow_prefix_id):
    """
    Inject EOW probability into the bow_prefix_id slot by reallocating 
    the total BOW mass there, and zeroing out the original BOW tokens.
    
    Returns a new probability tensor (not in-place).
    """
    probs = probs.clone()
    bow_mass = probs[bow_token_ids].sum()
    probs[bow_token_ids] = 0.0
    probs[bow_prefix_id] = bow_mass
    return probs

In [None]:
class Beam:
    # BOW prefix ID for the model
    bow_prefix_id = None
    def __init__(self, token_ids, token_probs, input_ids, parent=None):
        """
        Args:
            token_ids (List[int]): List of token IDs generated so far
            token_probs (List[float]): List of probabilities for each token in the beam
            input_ids (List[int]): Full input IDs (prompt + generated tokens)
            parent (Optional[Beam]): Parent beam for backtracking
        """
        self.token_ids = token_ids
        self.token_probs = token_probs
        self.token_log_probs = [math.log(p) if p > 0 else float('-inf') for p in token_probs]
        self.input_ids = input_ids
        self.parent = parent

    def extend(self, next_token_id, next_token_prob):
        """Return a new Beam with one more token added."""
        return Beam(
            self.token_ids + [next_token_id],
            self.token_probs + [next_token_prob],
            self.input_ids + [next_token_id],
            parent=self
        )

    def prob(self):
        """Return product of token probabilities (pseudo-probability)."""
        return math.prod(self.token_probs) if self.token_ids else .0

    def log_prob(self):
        """Return sum of log probabilities (more stable for ranking)."""
        return sum(self.token_log_probs)

    @property
    def done(self):
        """A beam is done if the last token is a BOW token (end of word)."""
        return self.token_ids and self.token_ids[-1] == Beam.bow_prefix_id

    def path(self):
        """Return a list of beam nodes from root to this beam."""
        beam, result = self, []
        while beam:
            result.append(beam)
            beam = beam.parent
        return list(reversed(result))

    def decoded(self, tokenizer):
        """Decode the beam's token sequence using a tokenizer."""
        return tokenizer.decode(self.token_ids)

    def tokens(self, tokenizer):
        """Return a list of token strings."""
        return [tokenizer.decode([t]) for t in self.token_ids]

    def __eq__(self, other):
        return isinstance(other, Beam) and self.token_ids == other.token_ids

    def __hash__(self):
        return hash(tuple(self.token_ids))

    def __repr__(self):
        return f"Beam(tokens={self.token_ids}, prob={self.prob():.8f}, log_prob={self.log_prob():.8f})"

Beam.bow_prefix_id = config.bow_prefix_id
    


In [227]:
beam_depth = 10
beam_width = 10

In [233]:
raw_candidates = []
for p in tqdm(prompts):
    # Step 1: Initialize input and beam
    input_ids = tokenizer(p, return_tensors="pt").input_ids.tolist()[0]
    beams = [Beam([], [], input_ids)]
    # Step 2: Beam search decoding
    for depth in range(beam_depth):
        new_beams = []
        for beam in beams:
            if beam.done:
                new_beams.append(beam)
                continue
            # Step 2.1: Get next-token probability distribution
            next_probs = lm(beam.input_ids, model, tokenizer, config)
            vocab_ids = bow_token_ids if depth == 0 else valid_token_ids
            next_probs = norm_probs(next_probs, vocab_ids)
            # Step 2.2: Inject end-of-word (EOW) probability at depths > 0
            if depth:
                next_probs = inject_eow_prob(next_probs, bow_token_ids, config.bow_prefix_id)
            # Step 2.3: Top-k expansion
            topk_probs, topk_ids = torch.topk(next_probs, k=beam_width)
            for topk_id, topk_prob in zip(topk_ids.tolist(), topk_probs.tolist()):
                new_beams.append(beam.extend(topk_id, topk_prob))
                
        # Step 3: Keep top-scoring beams
        beams = sorted(new_beams, key=lambda beam: -beam.prob())[:beam_width]
        if all(b.done for b in beams):
            break
    raw_candidates.append(beams)

100%|██████████| 50/50 [00:55<00:00,  1.12s/it]


In [None]:
# post-process the beams
def post_process_beams(beams, tokenizer):
    """
    Post-process the beams to extract the best candidates.
    """
    candidates = []
    for beam in beams:
        decoded = beam.decoded(tokenizer)
        prob = beam.prob()
        tokens = tokenizer.convert_ids_to_tokens(beam.token_ids)
        candidates.append((decoded, prob, path))
    return candidates

candidates = [post_process_beams(cs, tokenizer) for cs in raw_candidates]

In [244]:
candidates[0]

[(' Valletta ',
  0.3554948042315573,
  [Beam(tokens=[], prob=0.00000000, log_prob=0.00000000),
   Beam(tokens=[4196], prob=0.36004543, log_prob=-1.02152505),
   Beam(tokens=[4196, 1169], prob=0.35597174, log_prob=-1.03290394),
   Beam(tokens=[4196, 1169, 2629], prob=0.35590202, log_prob=-1.03309982),
   Beam(tokens=[4196, 1169, 2629, 220], prob=0.35549480, log_prob=-1.03424465)]),
 (' the ',
  0.12685003320938826,
  [Beam(tokens=[], prob=0.00000000, log_prob=0.00000000),
   Beam(tokens=[279], prob=0.12690452, log_prob=-2.06432031),
   Beam(tokens=[279, 220], prob=0.12685003, log_prob=-2.06474973)]),
 (' Valetta ',
  0.09286365397251106,
  [Beam(tokens=[], prob=0.00000000, log_prob=0.00000000),
   Beam(tokens=[27713], prob=0.09439309, log_prob=-2.36028741),
   Beam(tokens=[27713, 1169], prob=0.09297046, log_prob=-2.37547347),
   Beam(tokens=[27713, 1169, 2629], prob=0.09294434, log_prob=-2.37575443),
   Beam(tokens=[27713, 1169, 2629, 220], prob=0.09286365, log_prob=-2.37662295)]),
 ('