In [37]:
import argparse
import ast
import sys
import json
import re
import os
from from_root import from_root
from transformers import AutoTokenizer, AutoModelForMaskedLM

import pandas as pd

# Ensure we can import from src/
sys.path.insert(0, str(from_root("src")))

from read_and_write_docs import read_jsonl, read_rds, read_txt
from model_loading import load_model, distinct_special_chars
from utils import apply_temp_doc_id, build_metadata_df
from n_gram_functions import (
    common_ngrams,
    filter_ngrams,
    pretty_print_common_ngrams,
    get_scored_df,
    get_scored_df_no_context
)

In [2]:
base_read_loc = "/Volumes/BCross/datasets/author_verification"
data_type = "test"
corpus = "Wiki"

known_loc = f"{base_read_loc}/{data_type}/{corpus}/known_raw.jsonl"
unknown_loc = f"{base_read_loc}/{data_type}/{corpus}/unknown_raw.jsonl"

metadata_loc = f"{base_read_loc}/{data_type}/metadata.rds"

print("Loading data")
known = read_jsonl(known_loc)
known = apply_temp_doc_id(known)
    
unknown = read_jsonl(unknown_loc)
unknown = apply_temp_doc_id(unknown)

# NOTE - Is this used?
metadata = read_rds(metadata_loc)
filtered_metadata = metadata[metadata['corpus'] == corpus]
agg_metadata = build_metadata_df(filtered_metadata, known, unknown)
    
print("Data loaded")

Loading data
Data loaded


In [None]:
model_loc = "/Volumes/BCross/models/ModernBERT/ModernBERT-base"

print("Loading model")

tokenizer = AutoTokenizer.from_pretrained(model_loc)
model = AutoModelForMaskedLM.from_pretrained(model_loc)
special_tokens = distinct_special_chars(tokenizer=tokenizer)

Loading model


In [12]:
known_doc_list = (
    read_txt("/Volumes/BCross/datasets/author_verification/test/Wiki/known_doc_list.txt")
    .strip()
    .split("\n")
)
unknown_doc_list = (
    read_txt("/Volumes/BCross/datasets/author_verification/test/Wiki/unknown_doc_list.txt")
    .strip()
    .split("\n")
)

known_doc = known_doc_list[0]
unknown_doc = unknown_doc_list[0]

In [13]:
# -----
# Get the chosen text & metadata
# -----

known_text = known[known['doc_id'] == known_doc].reset_index().loc[0, 'text'].lower()
unknown_text = unknown[unknown['doc_id'] == unknown_doc].reset_index().loc[0, 'text'].lower()
    
problem_metadata = agg_metadata[(agg_metadata['known_doc_id'] == known_doc)
                                & (agg_metadata['unknown_doc_id'] == unknown_doc)].reset_index()
problem_metadata['target'] = problem_metadata['known_author'] == problem_metadata['unknown_author']

In [15]:
# -----
# Create document dataframe
# -----
    
# This is used to display the text
docs_df = pd.DataFrame(
{
    "known":   [corpus, data_type, known_doc, known_text],
    "unknown": [corpus, data_type, unknown_doc, unknown_text],
},
index=["corpus", "data type", "doc", "text"],
)

In [24]:
# -----
# Get common n-grams
# -----
    
print("Getting common n-grams")
common = common_ngrams(known_text, unknown_text, n=2, model=model, tokenizer=tokenizer, lowercase=True)
    
# Filter to remove smaller n-grams which don't satisfy the rules
common = filter_ngrams(common, special_tokens=special_tokens)
n_gram_list = pretty_print_common_ngrams(common, tokenizer=tokenizer, return_format='flat', show_raw=True)
print(f"There are {len(n_gram_list)} n-grams in common!")

Getting common n-grams
There are 12 n-grams in common!


{2: {('Ġabout', 'Ġthis'),
  ('Ġarticles', 'Ġon'),
  ('Ġbecause', 'Ġthe'),
  ('Ġthe', 'Ġsubject'),
  ('Ġthis', 'Ġarticle')},
 3: {(',', 'Ġbut', 'Ġthis'),
  (',', 'Ġyou', 'Ġare'),
  ('Ġdo', 'Ġnot', 'Ġhave'),
  ('Ġone', 'Ġof', 'Ġthe'),
  ('Ġwelcome', 'Ġto', 'Ġimprove'),
  ('Ġyou', 'Ġdo', 'Ġnot')},
 4: {(',', 'Ġthis', 'Ġis', 'Ġnot')}}

In [25]:
n_gram_list

[(', but this', "(',', 'Ġbut', 'Ġthis')"),
 (', you are', "(',', 'Ġyou', 'Ġare')"),
 (' do not have', "('Ġdo', 'Ġnot', 'Ġhave')"),
 (' one of the', "('Ġone', 'Ġof', 'Ġthe')"),
 (' welcome to improve', "('Ġwelcome', 'Ġto', 'Ġimprove')"),
 (' you do not', "('Ġyou', 'Ġdo', 'Ġnot')"),
 (' about this', "('Ġabout', 'Ġthis')"),
 (' articles on', "('Ġarticles', 'Ġon')"),
 (' because the', "('Ġbecause', 'Ġthe')"),
 (' the subject', "('Ġthe', 'Ġsubject')"),
 (' this article', "('Ġthis', 'Ġarticle')"),
 (', this is not', "(',', 'Ġthis', 'Ġis', 'Ġnot')")]

## Phrase Masking Funcitons

In [108]:
import re
from typing import List

def mask_phrase(
    text: str,
    phrase: str,
    mask_token: str = "[MASK]"
) -> List[str]:
    """
    For each occurrence of `phrase` in `text`, create a version of
    the text where ONLY that occurrence is replaced with `mask_token`.

    Returns:
        List of masked texts.
        - If the phrase appears N times, the list has length N.
        - If it appears 0 times, the list is empty.
    """
    if not phrase:
        raise ValueError("phrase must be a non-empty string")

    pattern = re.escape(phrase)
    matches = list(re.finditer(pattern, text))

    masked_variants: List[str] = []

    for m in matches:
        start, end = m.span()
        masked_text = text[:start] + mask_token + text[end:]
        masked_variants.append(masked_text)

    return masked_variants

def mask_phrase_by_tokens(
    text: str,
    phrase_tokens: List[str],
    tokenizer: tokenizer,
    mask_token: str | None = '[MASK]',
) -> List[str]:
    """
    For each occurrence of `phrase_tokens` (a token sequence) in the tokenized
    version of `text`, create a version of the text where ONLY that occurrence
    is replaced by mask tokens (one [MASK] per token in the phrase).

    Args:
        text: Original text.
        phrase_tokens: List of token strings for the phrase (e.g., tokenizer.tokenize(phrase)).
        tokenizer: HF tokenizer used to tokenize and detokenize.
        mask_token: Mask token string to use. Defaults to tokenizer.mask_token.

    Returns:
        List of masked texts:
        - If the phrase appears N times (as a token subsequence), the list has length N.
        - If it appears 0 times, the list is empty.
    """
    if not phrase_tokens:
        raise ValueError("phrase_tokens must be a non-empty list of tokens")

    if mask_token is None:
        if tokenizer.mask_token is None:
            raise ValueError("mask_token not provided and tokenizer.mask_token is None")
        mask_token = tokenizer.mask_token

    # Tokenize text without special tokens
    tokens = tokenizer.tokenize(text)
    n = len(phrase_tokens)

    masked_variants: List[str] = []

    # Find all subsequence matches of phrase_tokens in tokens
    for start_idx in range(len(tokens) - n + 1):
        if tokens[start_idx:start_idx + n] == phrase_tokens:
            # Create a masked copy for this particular occurrence
            masked_tokens = tokens.copy()
            for j in range(n):
                masked_tokens[start_idx + j] = mask_token

            masked_text = tokenizer.convert_tokens_to_string(masked_tokens)
            masked_variants.append(masked_text)

    return masked_variants

from typing import List, Dict, Optional
from transformers import PreTrainedTokenizerBase

def mask_phrase(
    text: str,
    phrase_tokens: List[str],
    tokenizer: PreTrainedTokenizerBase,
    mask_token: Optional[str] = "[MASK]",
) -> Dict[str, List[str]]:
    """
    For each occurrence of `phrase_tokens` (a token sequence) in the tokenized
    version of `text`, create two versions of the text:

    1. single_mask: the whole phrase replaced by a **single** mask_token.
    2. multi_mask: each token in the phrase replaced by its own mask_token.

    Returns a dict:
        {
            "single_mask": [ ... ],
            "multi_mask":  [ ... ],
        }
    where each list has length == number of occurrences (can be 0).
    """
    if not phrase_tokens:
        raise ValueError("phrase_tokens must be a non-empty list of tokens")

    if mask_token is None:
        if tokenizer.mask_token is None:
            raise ValueError("mask_token not provided and tokenizer.mask_token is None")
        mask_token = tokenizer.mask_token

    # Tokenize the text (no special tokens)
    tokens = tokenizer.tokenize(text)
    n = len(phrase_tokens)

    single_mask_variants: List[str] = []
    multi_mask_variants: List[str] = []

    # Find all subsequence matches of phrase_tokens in tokens
    for start_idx in range(len(tokens) - n + 1):
        if tokens[start_idx:start_idx + n] == phrase_tokens:
            # --- multi-mask: replace each token in the phrase by mask_token ---
            multi_tokens = tokens.copy()
            for j in range(n):
                multi_tokens[start_idx + j] = mask_token
            multi_text = tokenizer.convert_tokens_to_string(multi_tokens)
            multi_mask_variants.append(multi_text)

            # --- single-mask: collapse the whole phrase into a single mask_token ---
            single_tokens = tokens[:start_idx] + [mask_token] + tokens[start_idx + n:]
            single_text = tokenizer.convert_tokens_to_string(single_tokens)
            single_mask_variants.append(single_text)

    return {
        "single_mask": single_mask_variants,
        "multi_mask": multi_mask_variants,
    }


### Paraphrasing

In [96]:
# B1 — Joint multi-mask fill with beam search (ModernBERT or any MLM)
import math
from typing import List, Dict, Tuple, Optional, Iterable
import torch
from transformers import PreTrainedTokenizerBase, PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM

@torch.no_grad()
def fill_masks_beam(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    masked_text: str,
    top_k_per_mask: int = 25,   # widen/limit local options at each mask
    beam_size: int = 100,       # how many partial hypotheses to keep
    max_candidates: int = 50,   # truncate final list
    banned_token_ids: Optional[Iterable[int]] = None,
) -> List[Dict]:
    """
    Jointly fill one or more [MASK] tokens using a simple beam search.

    Returns a list of dicts: {text, score, tokens, token_ids, mask_positions}
    where 'score' is the sum of log-probs across all masks (higher is better).
    """
    if tokenizer.mask_token_id is None:
        raise ValueError("Tokenizer must define mask_token_id (MLM required).")

    enc = tokenizer(masked_text, return_tensors="pt", add_special_tokens=True, truncation=True)
    input_ids = enc["input_ids"]
    mask_id = tokenizer.mask_token_id
    mask_positions = (input_ids[0] == mask_id).nonzero(as_tuple=False).flatten().tolist()
    if not mask_positions:
        return []

    outputs = model(**enc)
    log_probs = outputs.logits.log_softmax(-1)  # [1, seq_len, vocab]

    specials = set(getattr(tokenizer, "all_special_ids", []) or [])
    banned = set(banned_token_ids or []) | specials

    # Precompute top-k candidates at each mask position
    per_mask_cands = []
    for pos in mask_positions:
        lp = log_probs[0, pos].clone()
        if banned:
            idx = torch.tensor(sorted(banned), dtype=torch.long)
            lp.index_fill_(0, idx, float("-inf"))
        topk = torch.topk(lp, k=min(top_k_per_mask, lp.numel()))
        per_mask_cands.append((topk.indices.tolist(), topk.values.tolist()))

    # Beam over masks (left-to-right in token order)
    beam = [(0.0, [])]  # (cum_logprob, chosen_token_ids_so_far)
    for cand_ids, cand_lps in per_mask_cands:
        new_beam = []
        for cum_lp, chosen in beam:
            for tid, lp in zip(cand_ids, cand_lps):
                new_beam.append((cum_lp + float(lp), chosen + [tid]))
        new_beam.sort(key=lambda x: x[0], reverse=True)
        beam = new_beam[:beam_size]

    # Materialize and deduplicate by decoded text
    best = {}
    for cum_lp, choice_ids in beam:
        filled = input_ids.clone()
        for pos, tid in zip(mask_positions, choice_ids):
            filled[0, pos] = tid
        text_out = tokenizer.decode(filled[0], skip_special_tokens=True)
        prev = best.get(text_out)
        if (prev is None) or (cum_lp > prev["score"]):
            best[text_out] = {
                "text": text_out,
                "score": cum_lp,
                "tokens": tokenizer.convert_ids_to_tokens(choice_ids),
                "token_ids": choice_ids,
                "mask_positions": mask_positions,
            }
    return sorted(best.values(), key=lambda r: r["score"], reverse=True)[:max_candidates]

# C1 — Try multiple lengths at one masked span and rank globally
from typing import Iterable, List, Dict

def variable_length_infill(
    model, tokenizer,
    masked_template: str,              # contains ONE [MASK] span
    length_options: Iterable[int] = (1, 2, 3, 4),
    per_length_topk: int = 10,
    normalize_by_masks: bool = True,   # de-bias longer spans
    **beam_kwargs
) -> List[Dict]:
    mask_tok = tokenizer.mask_token
    assert mask_tok in masked_template, "Template must contain one [MASK] span."
    rows = []
    for L in length_options:
        expanded = masked_template.replace(mask_tok, " ".join([mask_tok]*L), 1)
        outs = fill_masks_beam(model, tokenizer, expanded, **beam_kwargs)
        for o in outs[:per_length_topk]:
            score = o["score"] / L if normalize_by_masks else o["score"]
            rows.append({"text": o["text"], "length": L, "score": score, "raw_score": o["score"], "tokens": o["tokens"]})
    return sorted(rows, key=lambda r: r["score"], reverse=True)

from typing import List, Dict, Optional
from transformers import PreTrainedTokenizerBase

def beam_outputs_to_phrases(
    outputs: List[Dict],
    tokenizer: PreTrainedTokenizerBase,
    original_phrase: Optional[str] = None,
    lowercase: bool = True,
    unique: bool = True,
) -> List[str]:
    """
    Convert MLM beam outputs (from fill_masks_beam / variable_length_infill)
    into a list of phrases, similar to parse_paraphrases.

    Assumes that `o["tokens"]` are the tokens for the masked span
    (which is true for variable_length_infill, and for single-span masks).
    """
    phrases: List[str] = []

    # Prepare comparison baseline (like parse_paraphrases)
    if original_phrase is not None and lowercase:
        original_cmp = original_phrase.lower()
    else:
        original_cmp = original_phrase

    for o in outputs:
        # Decode just the predicted tokens for the masked span
        candidate = tokenizer.convert_tokens_to_string(o["tokens"])

        # Compare and optionally lowercase
        cand_cmp = candidate.lower() if lowercase else candidate

        # Drop suggestions that are the same as the original phrase
        if original_cmp is not None and cand_cmp == original_cmp:
            continue

        phrases.append(cand_cmp if lowercase else candidate)

    # Deduplicate (like your set(), but preserve order)
    if unique:
        seen = set()
        deduped = []
        for p in phrases:
            if p not in seen:
                seen.add(p)
                deduped.append(p)
        return deduped

    return phrases


@torch.no_grad()
def fill_masks_beam_phrases(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    masked_text: str,
    original_phrase: Optional[str] = None,
    lowercase: bool = True,
    top_k_per_mask: int = 25,
    beam_size: int = 100,
    max_candidates: int = 50,
    banned_token_ids: Optional[Iterable[int]] = None,
) -> List[str]:
    """
    Convenience wrapper around fill_masks_beam that returns only the
    predicted phrases (like parse_paraphrases).
    """
    beam_outputs = fill_masks_beam(
        model=model,
        tokenizer=tokenizer,
        masked_text=masked_text,
        top_k_per_mask=top_k_per_mask,
        beam_size=beam_size,
        max_candidates=max_candidates,
        banned_token_ids=banned_token_ids,
    )

    return beam_outputs_to_phrases(
        outputs=beam_outputs,
        tokenizer=tokenizer,
        original_phrase=original_phrase,
        lowercase=lowercase,
        unique=True,
    )

def variable_length_infill_phrases(
    model,
    tokenizer,
    masked_template: str,               # contains ONE [MASK] span
    original_phrase: Optional[str] = None,
    lowercase: bool = True,
    length_options: Iterable[int] = (1, 2, 3, 4),
    per_length_topk: int = 10,
    normalize_by_masks: bool = True,
    **beam_kwargs,
) -> List[str]:
    """
    Run variable_length_infill and return only the span phrases as a list,
    similar to parse_paraphrases.
    """

    rows = variable_length_infill(
        model=model,
        tokenizer=tokenizer,
        masked_template=masked_template,
        length_options=length_options,
        per_length_topk=per_length_topk,
        normalize_by_masks=normalize_by_masks,
        **beam_kwargs,
    )
    # rows are dicts with at least: {"text", "length", "score", "raw_score", "tokens"}

    return beam_outputs_to_phrases(
        outputs=rows,        # compatible: each row has "tokens"
        tokenizer=tokenizer,
        original_phrase=original_phrase,
        lowercase=lowercase,
        unique=True,
    )


In [109]:
n_gram_dict = {}
width = len(str(len(n_gram_list)))  # e.g., 10 -> 2, 100 -> 3

for idx, (phrase_pretty, phrase_raw) in enumerate(n_gram_list, start=1):
    phrase_list = list(ast.literal_eval(phrase_raw))
    
    # Mask the phrase in the text
    masked_data = mask_phrase(known_text, phrase_list, tokenizer, "[MASK]")
    string_based_masked_list = masked_data['single_mask']
    token_based_masked_list = masked_data['multi_mask']
    
    paraphrases = []
    for i in range(0, len(string_based_masked_list)):
        
        token_paraphrases = fill_masks_beam_phrases(
            model=model,
            tokenizer=tokenizer,
            masked_text=token_based_masked_list[i],
            original_phrase=phrase_pretty,
            lowercase=True,
            top_k_per_mask=25,
            beam_size=100,
            max_candidates=50
        )
        
        paraphrases.extend(token_paraphrases)
        
        string_paraphrases = variable_length_infill_phrases(
            model,
            tokenizer,
            masked_template=string_based_masked_list[i],
            original_phrase = phrase_pretty,
            lowercase= True,
            length_options=tuple(range(2, len(phrase_list) + 1)),
            per_length_topk=10,
            normalize_by_masks=True
        )
        
        paraphrases.extend(string_paraphrases)
        
    paraphrases = list(set(paraphrases))
    key = f"phrase_{idx:0{width}d}"  # -> phrase_01, phrase_002, etc.
    n_gram_dict[key] = {"phrase": phrase_pretty, "paraphrases": paraphrases}

In [116]:
n_gram_dict

{'phrase_01': {'phrase': ', but this',
  'paraphrases': [' together it',
   ',\nthe',
   ', it',
   ' together\n it',
   '. it',
   ' together the it',
   ', thethis',
   ' in but it',
   ' because\nthis',
   ',\nhis',
   ', the the',
   ' with\n it',
   ' together but it',
   ' into but it',
   '.\n this',
   ' with this it',
   ' into it',
   '. the this',
   ' together this it',
   ' but it',
   ' because\n this',
   '.\n the',
   ', butthis',
   ', becausethe',
   ',\n this',
   ', the',
   ' and this',
   ' but this',
   ' book\n this',
   ' because this this',
   ' into the it',
   '. because this',
   ',\nthis',
   ', this this',
   '. this this',
   '. but that',
   ', this',
   ', thatthis',
   ', this it',
   ' in the it',
   ' into that it',
   ' in it',
   ', the this',
   ' with this that',
   ' with it it',
   ' with but that',
   ', but it',
   ' with something it',
   ' into because it',
   ' with that it',
   '.\n it',
   ' because the',
   ', thisthis',
   ',, it',
  

In [114]:
import torch
from typing import List, Tuple, Dict
from transformers import PreTrainedTokenizerBase, PreTrainedModel

@torch.no_grad()
def score_phrase_in_context(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    tokens: List[str],
    span: Tuple[int, int],
    candidate_tokens: List[str],
) -> Dict[str, object]:
    """
    Approximate log P(candidate_tokens | context) using pseudo log-likelihood:
    for each candidate token position, mask it, get log P(true token | context),
    and accumulate.

    Returns a dict with:
        {
          "token_logprobs": [float, ...],  # one per token in candidate_tokens
          "sum_logprob": float,
          "avg_logprob": float,
        }
    """
    start_idx, end_idx = span
    assert (end_idx - start_idx + 1) == len(candidate_tokens), "Span length mismatch."

    # Insert candidate tokens into a copy of the tokens
    base_tokens = tokens.copy()
    base_tokens[start_idx:end_idx + 1] = candidate_tokens

    # Encode once with specials
    enc = tokenizer(
        base_tokens,
        is_split_into_words=False,
        return_tensors="pt",
        add_special_tokens=True,
    )
    input_ids = enc["input_ids"]  # [1, seq_len]
    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        raise ValueError("Tokenizer must have a mask_token_id.")

    token_logprobs: List[float] = []

    # Assume [CLS] at position 0 and tokens aligned 1:1 after that
    seq_start = 1  # offset of first "real" token in input_ids

    for pos_offset, tok_str in enumerate(candidate_tokens):
        token_id = tokenizer.convert_tokens_to_ids(tok_str)

        pos = seq_start + start_idx + pos_offset

        masked_ids = input_ids.clone()
        masked_ids[0, pos] = mask_id

        outputs = model(input_ids=masked_ids)
        logits = outputs.logits  # [1, seq_len, vocab]
        log_probs = torch.log_softmax(logits[0, pos], dim=-1)

        tok_log_prob = float(log_probs[token_id].item())
        token_logprobs.append(tok_log_prob)

    sum_logprob = float(sum(token_logprobs))
    avg_logprob = sum_logprob / len(token_logprobs)

    return {
        "token_logprobs": token_logprobs,
        "sum_logprob": sum_logprob,
        "avg_logprob": avg_logprob,
    }


In [115]:
from typing import List, Tuple, Dict
from transformers import PreTrainedTokenizerBase, PreTrainedModel

def find_phrase_token_spans(
    text: str,
    phrase_tokens: List[str],
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[List[Tuple[int, int]], List[str]]:
    """
    Find all contiguous occurrences of phrase_tokens in tokenized text.

    Returns:
        spans: list of (start_idx, end_idx_inclusive) in token indices.
        tokens: the tokenized text.
    """
    tokens = tokenizer.tokenize(text)
    spans: List[Tuple[int, int]] = []
    n = len(phrase_tokens)

    for start_idx in range(len(tokens) - n + 1):
        if tokens[start_idx:start_idx + n] == phrase_tokens:
            spans.append((start_idx, start_idx + n - 1))

    return spans, tokens


def score_candidates_in_unknown(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    unknown_text: str,
    phrase_tokens: List[str],
    candidates: List[str],
) -> Dict[str, object]:
    """
    For each occurrence of phrase_tokens in unknown_text, compute MLM scores for
    each candidate phrase.

    Returns a dict:

        {
          "occurrences": [
             {
               "span": (start_idx, end_idx),
               "scores": [
                  {
                    "candidate": str,
                    "token_logprobs": [float, ...],
                    "sum_logprob": float,
                    "avg_logprob": float,
                  },
                  ...
               ]
             },
             ...
          ]
        }
    """
    spans, base_tokens = find_phrase_token_spans(unknown_text, phrase_tokens, tokenizer)

    results: Dict[str, object] = {"occurrences": []}

    # Pre-tokenize candidates
    cand_token_map: Dict[str, List[str]] = {
        cand: tokenizer.tokenize(cand) for cand in candidates
    }

    for span in spans:
        start_idx, end_idx = span
        span_len = end_idx - start_idx + 1

        scores_for_occurrence = []

        for cand, cand_tokens in cand_token_map.items():
            # Easiest: only compare candidates of same token length as the span
            if len(cand_tokens) != span_len:
                continue

            score_dict = score_phrase_in_context(
                model=model,
                tokenizer=tokenizer,
                tokens=base_tokens,
                span=span,
                candidate_tokens=cand_tokens,
            )

            scores_for_occurrence.append({
                "candidate": cand,
                "token_logprobs": score_dict["token_logprobs"],
                "sum_logprob": score_dict["sum_logprob"],
                "avg_logprob": score_dict["avg_logprob"],
            })

        # Sort candidates by avg_logprob (higher = more likely)
        scores_for_occurrence.sort(
            key=lambda x: x["avg_logprob"],
            reverse=True
        )

        results["occurrences"].append({
            "span": span,
            "scores": scores_for_occurrence,
        })

    return results


In [None]:
ref_tokens = n_gram_list[0][1]
ref_tokens = list(ast.literal_eval(ref_tokens))


[',', 'Ġbut', 'Ġthis']

In [127]:
ref_tokens = n_gram_list[0][1]
phrase_tokens = list(ast.literal_eval(ref_tokens))
phrases = n_gram_dict['phrase_01']
candidates = [phrases['phrase']] + phrases['paraphrases']

In [128]:
scores = score_candidates_in_unknown(
    model=model,
    tokenizer=tokenizer,
    unknown_text=unknown_text,
    phrase_tokens=phrase_tokens,
    candidates=candidates,
)

# Example: inspect first occurrence
occ = scores["occurrences"][0]
print("Span:", occ["span"])
for s in occ["scores"][:5]:
    print(
        f"Candidate: {s['candidate']!r}\n"
        f"  token_logprobs: {s['token_logprobs']}\n"
        f"  sum_logprob: {s['sum_logprob']:.3f}\n"
        f"  avg_logprob: {s['avg_logprob']:.3f}\n"
    )


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [None]:
#!/usr/bin/env python3

import argparse
import sys
import json
import re
import os
from from_root import from_root

import pandas as pd

# Ensure we can import from src/
sys.path.insert(0, str(from_root("src")))

from read_and_write_docs import read_jsonl, read_rds
from model_loading import load_model, distinct_special_chars
from utils import apply_temp_doc_id, build_metadata_df
from n_gram_functions import (
    common_ngrams,
    filter_ngrams,
    pretty_print_common_ngrams,
    get_scored_df,
    get_scored_df_no_context
)
from open_ai import initialise_client, llm
from excel_functions import create_excel_template

# --------------------
# Helpers
# --------------------

# remove illegal control chars (keep \t, \n, \r)
_ILLEGAL_RE = re.compile(r"[\x00-\x08\x0B-\x0C\x0E-\x1F]")

def _clean_cell(x):
    if isinstance(x, str):
        return _ILLEGAL_RE.sub("", x)
    return x

def clean_for_excel(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    obj_cols = df.select_dtypes(include=["object"]).columns
    df[obj_cols] = df[obj_cols].applymap(_clean_cell)
    return df

def create_system_prompt(prompt_loc):
    """Reads the prompt as a .txt file for better versioning"""
    with open(prompt_loc, "r", encoding="utf-8") as f:
        return f.read()
    
def create_user_prompt(known_text, phrase, raw_phrase):
    """The method of input to the LLM as described in the system prompt"""
    user_prompt = f"""
<DOC>
{known_text}
</DOC>
<RAW NGRAM>
"{raw_phrase}"
</RAW NGRAM>
<NGRAM>
"{phrase}"
</NGRAM>
"""
    
    return user_prompt

def parse_paraphrases(response, phrase, lowercase=True):
    """Extract paraphrases from OpenAI response (JSON mode)."""
    paraphrase_list = []
    for i in range(1, len(response.choices)):
        content = response.choices[i].message.content
        
        try:
            content_json = json.loads(content)
            for para in content_json['paraphrases']:
                if para != phrase:
                    if (lowercase) & (para.lower() != phrase):
                        paraphrase_list.append(para.lower())
                    else:
                        paraphrase_list.append(para)  
        except Exception:
            continue
        
    unique_list = list(set(paraphrase_list))
    
    return unique_list

def parse_args():
    ap = argparse.ArgumentParser(description="OpenAI N-gram paraphrase pipeline")
    # Paths
    ap.add_argument("--known_loc")
    ap.add_argument("--unknown_loc")
    ap.add_argument("--metadata_loc")
    ap.add_argument("--model_loc")
    ap.add_argument("--save_loc")
    ap.add_argument("--completed_loc", default=None)
    # Dataset hinting
    ap.add_argument("--corpus", default="Wiki")
    ap.add_argument("--data_type", default="training")
    ap.add_argument("--known_doc")
    ap.add_argument("--unknown_doc")
    # Env
    ap.add_argument("--credentials_loc", default=str(from_root("credentials.json")))
    ap.add_argument("--prompt_loc", default=str(from_root("prompts", "exhaustive_constrained_ngram_paraphraser_prompt_JSON_new.txt")))
    # N-gram
    ap.add_argument("--ngram_n", type=int, default=2)
    ap.add_argument("--lowercase", action="store_true")
    ap.add_argument("--no_lowercase", dest="lowercase", action="store_false")
    ap.set_defaults(lowercase=True)
    ap.add_argument("--order", default="len_desc", help="Order for pretty_print_common_ngrams")
    ap.add_argument("--score_texts", action="store_true")
    # OpenAI
    ap.add_argument("--openai_model", default="gpt-4.1")
    ap.add_argument("--max_tokens", type=int, default=5000)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--n", type=int, default=10)

    return ap.parse_args()

def main():
    
    args=parse_args()
    
    # Ensure the directory exists before beginning
    os.makedirs(args.save_loc, exist_ok=True)
    
    # -----
    # LOAD DATA & LOCAL MODEL
    # -----
    specific_problem = f"{args.known_doc} vs {args.unknown_doc}"
    save_loc = f"{args.save_loc}/{specific_problem}.xlsx"
    
    if args.completed_loc:
        completed_loc = f"{args.completed_loc}/{specific_problem}.xlsx"
        if os.path.exists(completed_loc):
            print(f"Result for {specific_problem} already exists in the completed folder. Exiting.")
            sys.exit()
    
    # Skip the problem if already exists
    if os.path.exists(save_loc):
        print(f"Path {save_loc} already exists. Exiting.")
        sys.exit()
        
    print(f"Working on problem: {specific_problem}")
    
    print("Loading model")
    tokenizer, model = load_model(args.model_loc)
    special_tokens = distinct_special_chars(tokenizer=tokenizer)
    
    print("Loading data")
    known = read_jsonl(args.known_loc)
    known = apply_temp_doc_id(known)
    
    unknown = read_jsonl(args.unknown_loc)
    unknown = apply_temp_doc_id(unknown)

    print("Data loaded")
    
    # NOTE - Is this used?
    metadata = read_rds(args.metadata_loc)
    filtered_metadata = metadata[metadata['corpus'] == args.corpus]
    agg_metadata = build_metadata_df(filtered_metadata, known, unknown)

    # -----
    # Get the chosen text & metadata
    # -----
    
    known_text = known[known['doc_id'] == args.known_doc].reset_index().loc[0, 'text'].lower()
    unknown_text = unknown[unknown['doc_id'] == args.unknown_doc].reset_index().loc[0, 'text'].lower()
    
    problem_metadata = agg_metadata[(agg_metadata['known_doc_id'] == args.known_doc)
                                    & (agg_metadata['unknown_doc_id'] == args.unknown_doc)].reset_index()
    problem_metadata['target'] = problem_metadata['known_author'] == problem_metadata['unknown_author']

    # -----
    # Create document dataframe
    # -----
    
    # This is used to display the text
    docs_df = pd.DataFrame(
    {
        "known":   [args.corpus, args.data_type, args.known_doc, known_text],
        "unknown": [args.corpus, args.data_type, args.unknown_doc, unknown_text],
    },
    index=["corpus", "data type", "doc", "text"],
    )
    
    # -----
    # Get common n-grams
    # -----
    
    print("Getting common n-grams")
    common = common_ngrams(known_text, unknown_text, args.ngram_n, model, tokenizer, lowercase=args.lowercase)
    
    # Filter to remove smaller n-grams which don't satisfy the rules
    common = filter_ngrams(common, special_tokens=special_tokens)
    n_gram_list = pretty_print_common_ngrams(common, tokenizer=tokenizer, order=args.order, return_format='flat', show_raw=True)
    print(f"There are {len(n_gram_list)} n-grams in common!")
    
    # -----
    # OpenAI bits
    # -----
    
    print("Generating paraphrases")
    client = initialise_client(args.credentials_loc)
    
    n_gram_dict = {}
    width = len(str(len(n_gram_list)))  # e.g., 10 -> 2, 100 -> 3

    for idx, (phrase_pretty, phrase_raw) in enumerate(n_gram_list, start=1):
        user_prompt = create_user_prompt(known_text, phrase_pretty, raw_phrase=phrase_raw)
        response = llm(
            create_system_prompt(args.prompt_loc),
            user_prompt,
            client,
            model=args.openai_model,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            n=args.n,
            response_format={"type": "json_object"},
        )
        paraphrases = parse_paraphrases(response, phrase_pretty)
        key = f"phrase_{idx:0{width}d}"  # -> phrase_01, phrase_002, etc.
        n_gram_dict[key] = {"phrase": phrase_pretty, "paraphrases": paraphrases}
        
    # -----
    # Score phrases
    # -----
    if args.score_texts:
        print("Scoring phrases")
        print("    Scoring known text")
        known_scored = get_scored_df(n_gram_dict, known_text, tokenizer, model)
        
        print("    Scoring unknown text")
        unknown_scored = get_scored_df(n_gram_dict, unknown_text, tokenizer, model)
        
        print("    Scoring phrases with no context")
        score_df_no_context = get_scored_df_no_context(n_gram_dict, tokenizer, model)
        
        # -----
        # Final cleaning and saving
        # -----
        
        print(f"Writing file: {specific_problem}")
        
        # Run the new Excel function
        create_excel_template(
            known=known_scored,
            unknown=unknown_scored,
            no_context=score_df_no_context,
            metadata=problem_metadata,
            docs=docs_df,
            path=save_loc,
            known_sheet="known",
            unknown_sheet="unknown",
            nc_sheet="no context",
            metadata_sheet="metadata",
            docs_sheet="docs",
            llr_sheet="LLR",
            use_xlookup=False,
            highlight_phrases=False
        )
    
    else:
        print("Not scoring texts")
        print("<<<RESULT_JSON_START>>>")
        print(json.dumps(n_gram_dict))
        print("<<<RESULT_JSON_END>>>")
    
if __name__ == "__main__":
    main()