In [4]:
%load_ext autoreload
%autoreload 2

In [59]:
import os
import argparse
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Generator
import re
import toml

Useful link: https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean4.ipynb

### High-level pre-processing logic

In [17]:
def main():
    """Run the whole pre-processing pipeline described by a TOML config file.

    The config path is read from the ``--config_path`` command-line argument.
    Three entries of the ``[dataset]`` table are used:

    * ``parent_path``          -- prefix of the input ``train.jsonl``,
      ``test.jsonl`` and ``valid.jsonl`` files;
    * ``parent_save_path_our`` -- prefix of every output file; the directory
      is created when it does not exist yet;
    * ``all_data_path``        -- prefix of ``corpus.jsonl``,
      ``statement.jsonl`` and ``module_id_mapping.json``.

    All three values are used as plain string prefixes (``prefix + "name"``),
    so each of them has to end with a path separator.

    NOTE(review): ``_load_corpus_update``, ``_get_dataset_path`` and
    ``expand_data`` are not defined in the visible part of this notebook; they
    must be in scope before ``main`` is called.
    """
    parser_config = argparse.ArgumentParser()
    parser_config.add_argument(
        "--config_path", required=True, type=str, help="Path to the config file"
    )
    args_config = parser_config.parse_args()
    config = toml.load(args_config.config_path)
    all_path_ori = config['dataset']['parent_path']
    all_path_save = config['dataset']['parent_save_path_our']
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(all_path_save, exist_ok=True)

    # Inputs (note: the validation split is read from "valid.jsonl" but
    # written as "val.jsonl").
    train_path = all_path_ori + "train.jsonl"
    test_path = all_path_ori + "test.jsonl"
    val_path = all_path_ori + "valid.jsonl"

    # Outputs.
    train_path_save = all_path_save + 'train.jsonl'
    test_path_save = all_path_save + 'test.jsonl'
    val_path_save = all_path_save + 'val.jsonl'
    test_path_save_increase = all_path_save + 'test_increase.jsonl'
    val_path_save_increase = all_path_save + 'val_increase.jsonl'
    train_expand_path = all_path_save + "train_expand_premise.jsonl"

    # Mapping handed to every later step to turn annotated premises into ids.
    all_data_path = config['dataset']['all_data_path']
    modeul_id = _load_corpus_update(
        all_data_path + "corpus.jsonl",
        all_data_path + 'statement.jsonl',
        all_data_path + 'module_id_mapping.json',
    )

    for source_path, target_path in (
        (train_path, train_path_save),
        (test_path, test_path_save),
        (val_path, val_path_save),
    ):
        _get_dataset_path(source_path, target_path, modeul_id)

    expand_data(train_path_save, train_expand_path)

    # One record per tactic (no merging) for the evaluation splits.
    for source_path, target_path in (
        (test_path, test_path_save_increase),
        (val_path, val_path_save_increase),
    ):
        _get_dataset_path_test(source_path, target_path, is_train=False, modeul_id=modeul_id)

In [12]:
@dataclass(eq=True, unsafe_hash=True)
class Pos:
    """A (line, column) location inside a source file.

    Both numbers are 1-indexed, matching what editors such as Visual Studio
    Code display.
    """

    line_nb: int
    """1-indexed line number."""

    column_nb: int
    """1-indexed column number."""

    @classmethod
    def from_str(cls, s: str) -> "Pos":
        """Parse the textual form of a position, e.g. :code:`"(323, 1109)"`."""
        well_formed = s.startswith("(") and s.endswith(")")
        assert well_formed, f"Invalid string representation of a position: {s}"
        raw_line, raw_column = s[1:-1].split(",")
        return cls(int(raw_line), int(raw_column))

    def __iter__(self) -> Generator[int, None, None]:
        # Lets a Pos be unpacked like a tuple: ``line, column = pos``.
        yield from (self.line_nb, self.column_nb)

    def __repr__(self) -> str:
        return repr((self.line_nb, self.column_nb))

    def __lt__(self, other):
        # Lexicographic order: line first, column as the tie-breaker.
        return (self.line_nb, self.column_nb) < (other.line_nb, other.column_nb)

    def __le__(self, other):
        return self == other or self < other

@dataclass(unsafe_hash=True)
class Premise:
    """A premise, i.e. one "document" of the retrieval setup.

    Only ``path`` and ``full_name`` take part in equality and hashing; the two
    positions are declared with ``compare=False`` and are also left out of the
    ``repr``.
    """

    path: str
    """``*.lean`` file the premise comes from."""

    full_name: str
    """Fully qualified name of the premise."""

    start: Pos = field(repr=False, compare=False)
    """Where the premise's definition starts in the ``*.lean`` file."""

    end: Pos = field(repr=False, compare=False)
    """Where the premise's definition ends in the ``*.lean`` file."""

    # Unlike Premise_corpus, this class carries no ``code`` field.
    def __post_init__(self) -> None:
        # Validate field types, then that the span is not reversed.
        for text_value in (self.path, self.full_name):
            assert isinstance(text_value, str)
        for position in (self.start, self.end):
            assert isinstance(position, Pos)
        assert self.start <= self.end


@dataclass(unsafe_hash=True)
class Premise_corpus:
    """A premise together with its source code.

    Only ``path`` and ``full_name`` take part in equality and hashing; the
    positions and the code are declared with ``compare=False``.
    """

    path: str
    """``*.lean`` file the premise comes from."""

    full_name: str
    """Fully qualified name of the premise."""

    start: Pos = field(repr=False, compare=False)
    """Where the premise's definition starts in the ``*.lean`` file."""

    end: Pos = field(repr=False, compare=False)
    """Where the premise's definition ends in the ``*.lean`` file."""

    code: str = field(compare=False)
    """Source text of the premise; must be a non-empty string."""

    def __post_init__(self) -> None:
        # Validate field types, that the span is not reversed, and that the
        # code is present.
        for text_value in (self.path, self.full_name):
            assert isinstance(text_value, str)
        for position in (self.start, self.end):
            assert isinstance(position, Pos)
        assert self.start <= self.end
        assert isinstance(self.code, str) and self.code != ""


### Exploring the corpus structure

The goal of this part is to iterate over the full corpus of `.lean` files. For each file, we take every premise listed for that file, wrap it in a `Premise` object, and record it in a mapping from `Premise` object to an integer id.

In [9]:
def get_code_without_doc_string(code):
    """Return ``code`` with every ``/-- ... -/`` doc comment removed.

    A doc comment is only removed when its closing ``-/`` is immediately
    followed by a newline; that newline is removed as well. Comments may span
    several lines.
    """
    doc_comment = re.compile(r"/--(.*?)-/\n", re.DOTALL)
    return doc_comment.sub("", code)

In [31]:
DST_DIR = Path("/home/ex-anastasia/Premise-Retrieval/mathlib_handler_benchmark_410/")
corpus_path = DST_DIR / "corpus.jsonl"
# Read all JSONL records (one string per line, newline included). The `with`
# block closes the file handle, and the explicit UTF-8 encoding matches how
# the rest of the pipeline opens its files.
with corpus_path.open(encoding="utf-8") as corpus_file:
    lines = list(corpus_file)

In [32]:
# Walk through the corpus logic for ONE premise of ONE corpus record.
# Each corpus line is a JSON object; this cell reads its 'path' and 'premises'.
one_line = json.loads(lines[2000])

def_path = one_line['path']
print(f'def_path: {def_path}\n')

# In the full pipeline these are shared across all corpus lines (original author's notes):
corpus = [] # this will get written into config['dataset']['all_data_path'] + 'statement.jsonl',
moudel_id = {} # this will get returned; maps Premise -> id
path_premise = {} # this is written into config['dataset']['all_data_path'] + 'module_id_mapping.json',
count = 4 # used below both as the index into one_line['premises'] and as the premise id

# Stand-in for one iteration of the loop over the premises of this line
premise_num = []
premise = one_line['premises'][count]
print(f"is_thm: {premise['is_thm']}\n") # original note: if this is false, the real loop continues

temp_dict = {}
premise_state = {}
# item[1:-1] drops the first and last character of every arg -- presumably the surrounding binder brackets; TODO confirm
premise_state['context'] = [item[1:-1] for item in premise['args']]
premise_state['goal'] = premise['goal']
# The record is mutated in place: it gains a 'state' entry and its 'code' loses doc comments
premise['state'] = premise_state
premise['code'] = get_code_without_doc_string(premise['code'])
print(f"premise: \n") # has code and state; state consists of goal and context
print(f"premise_code: {premise['code']}\n")
print(f"premise_goal: {premise_state['goal']}\n")
print(f"premise_context: {premise_state['context']}\n")
temp_dict['id'] = count
temp_dict['premise'] = premise
temp_dict['def_path'] = def_path

# Premise hashes/compares on (path, full_name) only, so it can serve as a dict key
now_premise = Premise(def_path, premise['full_name'],
                        Pos(premise['start'][0], premise['start'][1]),
                        Pos(premise['end'][0], premise['end'][1]))
p = now_premise
full_name = p.full_name
# NOTE: in this walkthrough the check only prints; nothing is skipped and the lines below still run.
# The `full_name is None` arm cannot trigger here, because Premise(...) above already asserts it is a str.
if full_name is None or "user__.n" in full_name or premise['code'] == "" or (full_name.startswith("[") and full_name.endswith("]")):
    print('Error with full_name; skipping this one')

# A premise that is already registered is treated as an error in this cell
if now_premise in moudel_id:
    raise ValueError
    # (original note: `continue`)
moudel_id[now_premise] = count # maps the now_premise object to a count
print(f"moudel_id: {moudel_id}\n")
premise_num.append(count)
corpus.append(temp_dict)
# All premise ids collected for this file, keyed by the file path
path_premise[def_path] = premise_num

def_path: Mathlib/Analysis/SpecialFunctions/Trigonometric/ArctanDeriv.lean

is_thm: True

premise: 

premise_code: theorem continuousAt_tan {x : ℝ} : ContinuousAt tan x ↔ cos x ≠ 0 

premise_goal: ContinuousAt Real.tan x ↔ Real.cos x ≠ 0

premise_context: ['x : ℝ']

moudel_id: {Premise(path='Mathlib/Analysis/SpecialFunctions/Trigonometric/ArctanDeriv.lean', full_name='Real.continuousAt_tan'): 4}



### Exploring the pre-processing

The goal of this part is to iterate over the train dataset (or val, test) from Lean-Dojo. Each train element has a list of tactics with state_before, tactics, state_after. We extract these first and clean them up a bit into a dictionary with 'context' and 'goal', where the context is a list.  

Then, we break the relation between (state, tactic, state), and just create a dictionary: 

`data_list[state]: {'context': [...], 'goal': '...'}`

`data_list[premise]: [1,2,3]`

Finally, the context and goal are combined using the `<VAR>` and `<GOAL>` tokens, and the resulting string is what BERT processes.

In [None]:
# Load the training split (one JSON record per line). The `with` block closes
# the file handle, and UTF-8 is explicit because the states contain Lean's
# non-ASCII symbols.
DST_DIR = Path("/home/ex-anastasia/Premise-Retrieval/mathlib_handler_benchmark_410/random/")
train_path = DST_DIR / "train.jsonl"
with train_path.open(encoding="utf-8") as train_file:
    train_lines = list(train_file)

# this replicates _get_dataset_path
data_list = []   # collected {state, premise, module, state_str} records
not_in = set()   # premises that could not be found in moudel_id

# Stand-in for the loop over all train elements: pick a single record
one_line_train = json.loads(train_lines[2000])

traced_tactics = one_line_train['tactics'] # a list of tactics
def_path = one_line_train['file_path']
print(f"def_path: {def_path}")

# each train element has a list of tactics, with state_before and state_after and tactic applied
i = 0
now_tactic = traced_tactics[i]
annotated_tactic = now_tactic['premises']
state_before = now_tactic['state_before']
state_after = now_tactic['state_after']


print(f"now_tactic: {now_tactic}\n")
print(f"annotated_tactic: {annotated_tactic}\n")

def_path: Mathlib/Topology/Compactness/Compact.lean
now_tactic: {'state_before': 'X : Type u\nY : Type v\nι : Type u_1\ninst✝¹ : TopologicalSpace X\ninst✝ : TopologicalSpace Y\ns t : Set X\nl : Filter X\nhs : IsCompact s\n⊢ Disjoint (𝓝ˢ s) l ↔ ∀ x ∈ s, Disjoint (𝓝 x) l', 'state_after': 'X : Type u\nY : Type v\nι : Type u_1\ninst✝¹ : TopologicalSpace X\ninst✝ : TopologicalSpace Y\ns t : Set X\nl : Filter X\nhs : IsCompact s\nH : ∀ x ∈ s, Disjoint (𝓝 x) l\n⊢ Disjoint (𝓝ˢ s) l', 'tactic': 'refine ⟨fun h x hx => h.mono_left <| nhds_le_nhdsSet hx, fun H => ?_⟩', 'premises': [{'full_name': 'Disjoint.mono_left', 'def_path': 'Mathlib/Order/Disjoint.lean', 'def_pos': [66, 8], 'def_end_pos': [66, 26]}, {'full_name': 'Iff.intro', 'def_path': '.lake/packages/lean4/src/lean/Init/Core.lean', 'def_pos': [116, 2], 'def_end_pos': [116, 7]}, {'full_name': 'nhds_le_nhdsSet', 'def_path': 'Mathlib/Topology/NhdsSet.lean', 'def_pos': [130, 8], 'def_end_pos': [130, 23]}]}

annotated_tactic: [{'full_name': 'Di

In [43]:
def process_state(state):
    '''
    Parse a pretty-printed proof state into a list of ``{"context", "goal"}`` dicts.

    The state string may hold several goals separated by blank lines. For each goal:

    * ``"no goals"`` becomes ``{"context": [], "goal": "no goals"}``;
    * a chunk without ``⊢`` is ignored;
    * otherwise the text before ``⊢`` is split into one context entry per
      line (only lines containing ``:`` are kept) and the text after ``⊢`` is
      the goal. Indented continuation lines are folded into the line above.

    A chunk that does not split into exactly two parts around ``⊢`` is
    printed for inspection and skipped.

    Returns the list of parsed goals, possibly empty.
    '''
    # Named goals are preceded by a "case <tag>" header line, which is neither
    # context nor goal, so it is dropped. Requiring a blank after "case" keeps
    # hypotheses whose names merely start with "case" (e.g. "cases_h : P").
    processed_state = re.sub(r"^case[ \t].*(?:\n|$)", "", state, flags=re.MULTILINE)
    proofstates = []
    for s in processed_state.split("\n\n"):
        s = s.strip()
        if s == "no goals":  # finished goal: empty context, literal "no goals"
            proofstates.append({"context": [], "goal": "no goals"})
            continue
        if "⊢" not in s:  # nothing to parse in this chunk
            continue
        try:
            context_str, goal = s.split("⊢")
        except ValueError:
            # More than one "⊢" in the chunk: report it and move on. Only the
            # unpacking error is caught, so unrelated failures are not hidden.
            print(state, s)
            continue
        # Fold indented continuation lines into the line they belong to.
        context_str = re.sub(r"\n\s+", " ", context_str).strip()
        goal = re.sub(r"\n\s+", " ", goal).strip()
        # Every remaining line that contains ":" is one context entry.
        context = [entry for entry in context_str.split("\n") if ":" in entry]
        proofstates.append({"context": context, "goal": goal})
    return proofstates


# Show the raw state strings next to their parsed form: process_state returns
# a list with one {"context": [...], "goal": "..."} dict per goal it finds.
all_state_before = process_state(state_before)
all_state_after = process_state(state_after)
print(f"state_before: {state_before}\n")
print(f"all_state_before: {all_state_before}\n")

print(f"state_after: {state_after}\n")
print(f"all_state_after: {all_state_after}\n")

state_before: X : Type u
Y : Type v
ι : Type u_1
inst✝¹ : TopologicalSpace X
inst✝ : TopologicalSpace Y
s t : Set X
l : Filter X
hs : IsCompact s
⊢ Disjoint (𝓝ˢ s) l ↔ ∀ x ∈ s, Disjoint (𝓝 x) l

all_state_before: [{'context': ['X : Type u', 'Y : Type v', 'ι : Type u_1', 'inst✝¹ : TopologicalSpace X', 'inst✝ : TopologicalSpace Y', 's t : Set X', 'l : Filter X', 'hs : IsCompact s'], 'goal': 'Disjoint (𝓝ˢ s) l ↔ ∀ x ∈ s, Disjoint (𝓝 x) l'}]

state_after: X : Type u
Y : Type v
ι : Type u_1
inst✝¹ : TopologicalSpace X
inst✝ : TopologicalSpace Y
s t : Set X
l : Filter X
hs : IsCompact s
H : ∀ x ∈ s, Disjoint (𝓝 x) l
⊢ Disjoint (𝓝ˢ s) l

all_state_after: [{'context': ['X : Type u', 'Y : Type v', 'ι : Type u_1', 'inst✝¹ : TopologicalSpace X', 'inst✝ : TopologicalSpace Y', 's t : Set X', 'l : Filter X', 'hs : IsCompact s', 'H : ∀ x ∈ s, Disjoint (𝓝 x) l'], 'goal': 'Disjoint (𝓝ˢ s) l'}]



In [45]:
# Not executed in this walkthrough, because the full moudel_id mapping was not computed.
# Translates the premises annotated on the current tactic into their corpus ids.
premises_list = []
for now_premise in annotated_tactic:
    # Rebuild the same kind of key that was used when the corpus was indexed
    premise = Premise(
        now_premise['def_path'],
        now_premise['full_name'],
        Pos(now_premise['def_pos'][0], now_premise['def_pos'][1]),
        Pos(now_premise['def_end_pos'][0], now_premise['def_end_pos'][1]),
    )
    if premise in moudel_id:
        # known premise: remember its id, so premises_list holds the ids used by this tactic
        premises_list.append(moudel_id[premise])
    else:
        # premise that was not extracted from the corpus: only keep track of it
        not_in.add(premise)

The following code is the piece that 'destroys' the relation between state before, tactic, and state after, as it just appends everything to `data_list`.

Specifically, the `data_list` looks like: 

`data_list[state]: {'context': [...], 'goal': '...'}`

`data_list[premise]: [1,2,3]`

where we thus have the context (variables), the goal of this state, and the list of premises applied. This facilitates learning which premises can be applied given a particular context and a particular goal.

In [52]:
# De-duplicate the premise ids of this tactic and put them in a stable order.
premises_list = sorted(set(premises_list))

# Both the state before and the state after the tactic are paired with the
# tactic's premise ids. A state string can hold several goals, hence the inner
# loop; "no goals" entries carry nothing to learn from and are skipped.
for parsed_states, raw_state in (
    (all_state_before, state_before),
    (all_state_after, state_after),
):
    for now_state in parsed_states:
        if now_state['goal'] == "no goals":
            continue
        temp_dict = {
            "state": now_state,
            # Each record gets its own copy, so that merging one record later
            # cannot change the premise list of another one.
            "premise": list(premises_list),
            "module": [def_path],
            "state_str": raw_state,
        }
        data_list.append(temp_dict)
        # Printing happens only for records that were actually added, so a
        # skipped goal can never print a stale temp_dict.
        print(f"data_list[state]: {temp_dict['state']}\n")
        print(f"data_list[premise]: {temp_dict['premise']}\n")
        print(f"data_list[module]: {temp_dict['module']}\n")
        print(f"data_list[state_str]: {temp_dict['state_str']}\n")

# at the end of these steps, we have data_list which is a list of dictionaries, each dictionary has a state, premise, module, and state_str
# what is the module? it is def_path, but why do we need it? it refers to the mathlib file from which we extracted this. I guess we want it for bookkeeping purposes. 


data_list[state]: {'context': ['X : Type u', 'Y : Type v', 'ι : Type u_1', 'inst✝¹ : TopologicalSpace X', 'inst✝ : TopologicalSpace Y', 's t : Set X', 'l : Filter X', 'hs : IsCompact s'], 'goal': 'Disjoint (𝓝ˢ s) l ↔ ∀ x ∈ s, Disjoint (𝓝 x) l'}

data_list[premise]: []

data_list[module]: ['Mathlib/Topology/Compactness/Compact.lean']

data_list[state_str]: X : Type u
Y : Type v
ι : Type u_1
inst✝¹ : TopologicalSpace X
inst✝ : TopologicalSpace Y
s t : Set X
l : Filter X
hs : IsCompact s
⊢ Disjoint (𝓝ˢ s) l ↔ ∀ x ∈ s, Disjoint (𝓝 x) l

data_list[state]: {'context': ['X : Type u', 'Y : Type v', 'ι : Type u_1', 'inst✝¹ : TopologicalSpace X', 'inst✝ : TopologicalSpace Y', 's t : Set X', 'l : Filter X', 'hs : IsCompact s', 'H : ∀ x ∈ s, Disjoint (𝓝 x) l'], 'goal': 'Disjoint (𝓝ˢ s) l'}

data_list[premise]: []

data_list[module]: ['Mathlib/Topology/Compactness/Compact.lean']

data_list[state_str]: X : Type u
Y : Type v
ι : Type u_1
inst✝¹ : TopologicalSpace X
inst✝ : TopologicalSpace Y
s t : Se

In [None]:
def make_hashable(data):
    """Recursively turn dicts and lists into hashable equivalents.

    Lists become tuples and dicts become frozensets of ``(key, value)`` pairs;
    every other value is returned unchanged.
    """
    if isinstance(data, list):
        return tuple(map(make_hashable, data))
    if not isinstance(data, dict):
        return data
    return frozenset(
        (key, make_hashable(value)) for key, value in data.items()
    )

# Merge all records that describe the same proof state: their premise ids and
# modules are unioned, so every distinct state appears exactly once.
state_hashes = {}
for temp_dict in data_list:
    state = temp_dict["state"]
    premise = temp_dict["premise"]
    module = temp_dict["module"]

    # The hashable form of the state is used as the key directly. Keying on
    # hash(...) would merge two different states whenever their integer
    # hashes collide; a dict key is additionally compared for equality.
    state_key = make_hashable(state)

    if state_key in state_hashes:
        matched_dict = state_hashes[state_key]
        # Build new lists instead of extending in place: the stored lists may
        # be shared with other records. sorted() keeps the result
        # reproducible across runs.
        matched_dict["premise"] = sorted(set(matched_dict["premise"]) | set(premise))
        matched_dict["module"] = sorted(set(matched_dict["module"]) | set(module))
    else:
        state_hashes[state_key] = {
            "state": state,
            # Copies, so later merges never write into data_list's records.
            "premise": list(premise),
            "module": list(module),
        }

merged_data_list = list(state_hashes.values())
# writes merged_data_list to random_our train file

I don't believe we use this expand logic. We only use: 

`train_file_path = "./mathlib_handler_benchmark_410/random/random_our/pretrain_train_data.jsonl"`

`eval_file_path = "./mathlib_handler_benchmark_410/random/random_our/pretrain_eval_data.jsonl"`

In [None]:
# Replicates the logic of expand_data on the first merged record.
# record['premise'] is a list of premise ids. One new record is produced per id:
# 'premise' is replaced by that single id and 'all_premises' keeps the full list.
expanded_data = []

record = merged_data_list[0]
all_premise_ids = record["premise"]

expanded_data.extend(
    {**record, "premise": premise_id, "all_premises": all_premise_ids}
    for premise_id in all_premise_ids
)
# (original note: the result is written into train_expand_premise.jsonl)

{'state': {'context': ['X : Type u', 'Y : Type v', 'ι : Type u_1', 'inst✝¹ : TopologicalSpace X', 'inst✝ : TopologicalSpace Y', 's t : Set X', 'l : Filter X', 'hs : IsCompact s'], 'goal': 'Disjoint (𝓝ˢ s) l ↔ ∀ x ∈ s, Disjoint (𝓝 x) l'}, 'premise': [], 'module': ['Mathlib/Topology/Compactness/Compact.lean']}


In [None]:
# i skip this logic for now as its more or less the same as the _get_dataset_path part 
def _get_dataset_path_test(path, save_path, is_train, modeul_id):
    """Build one record per tactic from a JSONL split and write them as JSONL.

    Every line of ``path`` is a JSON object with a ``'tactics'`` list. For each
    tactic the annotated premises are translated to ids through ``modeul_id``
    (premises missing from the mapping are ignored). A record
    ``{'state', 'premise', 'state_str'}`` is emitted, where ``'state'`` is the
    whole parsed ``state_before`` (a list of goals), ``'premise'`` the sorted
    unique ids and ``'state_str'`` the raw ``state_before`` string.

    Tactics without any known premise are skipped, and so are tactics whose
    ``state_before`` yields no parsed goal.

    Args:
        path: input JSONL file.
        save_path: output JSONL file (overwritten).
        is_train: when true, records with an identical parsed state are merged
            into one record whose premise list is the sorted union of ids.
        modeul_id: mapping from ``Premise`` to integer id.

    Raises:
        ValueError: if a tactic with known premises starts from a "no goals" state.
    """
    data_list = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            for now_tactic in data['tactics']:
                state_str = now_tactic['state_before']
                state_before = process_state(state_str)

                premise_ids = set()
                for now_premise in now_tactic['premises']:
                    premise = Premise(now_premise['def_path'], now_premise['full_name'],
                                      Pos(now_premise['def_pos'][0], now_premise['def_pos'][1]),
                                      Pos(now_premise['def_end_pos'][0], now_premise['def_end_pos'][1]))
                    if premise in modeul_id:
                        premise_ids.add(modeul_id[premise])
                if not premise_ids:
                    continue
                if not state_before:
                    # Nothing could be parsed from the state; indexing [0]
                    # below would raise IndexError, so skip the tactic.
                    continue
                if state_before[0]['goal'] == "no goals":
                    raise ValueError(
                        f"Tactic with premises starts from a finished state: {state_str!r}"
                    )
                data_list.append({
                    'state': state_before,
                    'premise': sorted(premise_ids),
                    'state_str': state_str,
                })

    if is_train:
        # Keyed on the hashable state itself rather than on hash(...), so two
        # different states can never be merged by a hash collision.
        merged_by_state = {}
        for temp_dict in data_list:
            state_key = make_hashable(temp_dict["state"])
            matched_dict = merged_by_state.get(state_key)
            if matched_dict is None:
                merged_by_state[state_key] = {
                    "state": temp_dict["state"],
                    "premise": list(temp_dict["premise"]),
                    'state_str': temp_dict['state_str'],
                }
            else:
                # Union instead of extend: no duplicate ids, stable order.
                matched_dict["premise"] = sorted(
                    set(matched_dict["premise"]) | set(temp_dict["premise"])
                )
        merged_data_list = list(merged_by_state.values())
    else:
        merged_data_list = data_list

    with open(save_path, 'w', encoding='utf-8') as f:
        for item in merged_data_list:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')
    print(len(merged_data_list))
    print(f"Data written to {save_path}")

### The final strings 

So we ended with a data_list with: 

`data_list[state]: {'context': [...], 'goal': '...'}`

`data_list[premise]: [1,2,3]`

For the first pre-train phase, we will simply translate `data_list[state]` into: 

`<VAR>c_1<VAR>c_2<VAR>...<GOAL>goal`
and this is the string we will tokenise. 

Hence: in the pretrain phase, we do not use `data_list[premise]`. I presume we will use this in the RAG style finetuning. 

In [60]:
import datasets
import random
import re
from transformers import BertConfig, BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling

In [61]:
class data_transform(object):
    """Random augmentation of a list of context entries.

    With probability ``shuffle_prob`` the entries are shuffled. For samples of
    at least 15 entries, with probability ``remove_prob`` each entry is then
    kept only when a fresh random draw is above 0.2. The input list itself is
    never modified.
    """

    def __init__(self, shuffle_prob, remove_prob):
        self.shuffle_prob = shuffle_prob
        self.remove_prob = remove_prob

    def __call__(self, sample):
        augmented = sample.copy()

        if random.random() < self.shuffle_prob:
            random.shuffle(augmented)

        # Short samples are never thinned out; for them no second draw is made.
        if len(sample) < 15 or not (random.random() < self.remove_prob):
            return augmented

        return [entry for entry in augmented if random.random() > 0.2]
    

def process_strings(string_list):
    """Concatenate context entries into one string, each prefixed by ``<VAR>``.

    A newline followed by whitespace inside an entry is collapsed to a single
    space. An empty list gives the empty string.
    """
    return ''.join(
        re.sub(r'\n\s+', ' ', "<VAR>" + entry) for entry in string_list
    )

# With both probabilities at 0 the transform never shuffles and never drops
# entries (random.random() is never < 0); it just returns a copy of its input.
transform = data_transform(shuffle_prob=0, remove_prob=0)

In [62]:
# Build a cased BERT tokenizer (do_lower_case=False) from the project's own vocab.txt.
tokenizer_dir = '/home/ex-anastasia/Premise-Retrieval/mathlib_handler_benchmark_410/random/random_our/tokenizer/'
tokenizer = BertTokenizer.from_pretrained(os.path.join(tokenizer_dir, 'vocab.txt'), do_lower_case=False)
# Register the two structural markers used when a state is serialised.
# NOTE(review): in the sample run further down, <VAR> and <GOAL> are encoded as the low ids 4 and 5,
# which suggests they are already present in vocab.txt; if so, this call adds nothing -- confirm.
special_tokens = ["<VAR>", "<GOAL>"]
tokenizer.add_tokens(special_tokens)

# Total vocabulary size, added tokens included
print(len(tokenizer))

30521




In [63]:
# Load the pre-training file and turn ONE record into the exact string/tensor the model sees.
file_path_train = "/home/ex-anastasia/Premise-Retrieval/mathlib_handler_benchmark_410/random/random_our/pretrain_train_data.jsonl"
data = datasets.load_dataset('json', data_files=file_path_train, split='train')

item = 0  # index of the record inspected in this cell

# Stand-in for the loop over records: take the context list and the goal of this one state
context = data[item]['state']['context']
goal = data[item]['state']['goal']

is_train = True
print(f'Context raw: {context}\n')
print(f'Goal raw: {goal}\n')
# Augmentation is applied to training samples only (a plain copy here, as `transform` uses probabilities 0)
if is_train:
    context = transform(context)
# "<VAR>"-prefix every context entry and join them into one string
context = process_strings(context)
print(f'Context cleaned: {context}\n')
# Same whitespace folding as for the context: newline + indentation -> single space
goal = re.sub(r'\n\s+', ' ', goal)
print(f'Goal cleaned: {goal}\n')
# Final model input: <VAR>c_1<VAR>c_2...<GOAL>goal
combine = context + '<GOAL>' + goal
print(f'Combine: {combine}\n')
# NOTE: `input` shadows the builtin input(); the name is kept because the following cells read it.
# The sequence is truncated/padded to exactly 512 tokens.
input = tokenizer(combine, truncation=True, padding='max_length', max_length=512, return_special_tokens_mask=True)
print(f'Tokenised: {input}\n')

Context raw: ['C : Type u', 'inst✝⁴ : Category.{v, u} C', 'inst✝³ : ConcreteCategory C', 'inst✝² : HasLimits C', 'inst✝¹ : ConcreteCategory.forget.ReflectsIsomorphisms', 'inst✝ : PreservesLimits ConcreteCategory.forget', 'X : TopCat', 'F : Sheaf C X', 'ι : Type v', 'U : ι → Opens ↑X', 's t : (CategoryTheory.forget C).obj (F.val.obj (op (iSup U)))', 'h : ∀ (i : ι), (F.val.map (leSupr U i).op) s = (F.val.map (leSupr U i).op) t', 'sf : (i : ι) → (CategoryTheory.forget C).obj (F.val.obj (op (U i))) := fun i => (F.val.map (leSupr U i).op) s', 'sf_compatible : IsCompatible F.val U sf']

Goal raw: s = t

Context cleaned: <VAR>C : Type u<VAR>inst✝⁴ : Category.{v, u} C<VAR>inst✝³ : ConcreteCategory C<VAR>inst✝² : HasLimits C<VAR>inst✝¹ : ConcreteCategory.forget.ReflectsIsomorphisms<VAR>inst✝ : PreservesLimits ConcreteCategory.forget<VAR>X : TopCat<VAR>F : Sheaf C X<VAR>ι : Type v<VAR>U : ι → Opens ↑X<VAR>s t : (CategoryTheory.forget C).obj (F.val.obj (op (iSup U)))<VAR>h : ∀ (i : ι), (F.val.map

In [64]:
# Detokenize (convert back to text) to check what the tokenizer preserved.
# `input` is the tokenizer output created in the previous cell, not the builtin input().
detokenized_text = tokenizer.decode(input.input_ids)
print("Detokenized Text:", detokenized_text)

Detokenized Text: [CLS] <VAR> C : Type u <VAR> inst✝⁴ : Category. { v, u } C <VAR> inst✝³ : ConcreteCategory C <VAR> inst✝² : HasLimits C <VAR> inst✝¹ : ConcreteCategory. forget. ReflectsIsomorphisms <VAR> inst✝ : PreservesLimits ConcreteCategory. forget <VAR> X : TopCat <VAR> F : Sheaf C X <VAR> ι : Type v <VAR> U : ι → Opens ↑X <VAR> s t : ( CategoryTheory. forget C ). obj ( F. val. obj ( op ( iSup U ) ) ) <VAR> h : ∀ ( i : ι ), ( F. val. map ( leSupr U i ). op ) s = ( F. val. map ( leSupr U i ). op ) t <VAR> sf : ( i : ι ) → ( CategoryTheory. forget C ). obj ( F. val. obj ( op ( U i ) ) ) : = fun i = > ( F. val. map ( leSupr U i ). op ) s <VAR> sf _ compatible : IsCompatible F. val U sf <GOAL> s = t [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 

### Exploring the masking

When we feed the input into the model, we use the standard BERT masked-language-modelling masking:

`data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)`

which masks as follows: 

In [None]:
# Apply the collator to a batch holding the single tokenised example from above
# (`input` is that tokenizer output, not the builtin input()).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
batch = [input]
masked_batch = data_collator(batch)

# original note: 30520 is the id of the <MASK> (NOTE(review): not checkable from the truncated output -- confirm)
print("Original Input IDs:", batch[0]["input_ids"])
print("Masked Input IDs:", masked_batch["input_ids"])
print("Labels (original tokens for masked positions):", masked_batch["labels"])
# -100 gets ignored by the loss function; so we only care about the loss for the masked positions

Original Input IDs: [0, 4, 40, 31, 614, 89, 4, 743, 31, 652, 19, 95, 90, 17, 89, 97, 40, 4, 720, 31, 3119, 40, 4, 695, 31, 5571, 40, 4, 660, 31, 3119, 19, 1456, 19, 6180, 4, 619, 31, 5257, 3119, 19, 1456, 4, 61, 31, 1956, 4, 43, 31, 2422, 40, 61, 4, 129, 31, 614, 90, 4, 58, 31, 129, 229, 1321, 2291, 4, 87, 88, 31, 13, 677, 19, 1456, 40, 14, 19, 783, 13, 43, 19, 978, 19, 783, 13, 914, 13, 3432, 58, 14, 14, 14, 4, 76, 31, 241, 13, 77, 31, 129, 14, 17, 13, 43, 19, 978, 19, 712, 13, 13210, 58, 77, 14, 19, 914, 14, 87, 34, 13, 43, 19, 978, 19, 712, 13, 13210, 58, 77, 14, 19, 914, 14, 88, 4, 4641, 31, 13, 77, 31, 129, 14, 229, 13, 677, 19, 1456, 40, 14, 19, 783, 13, 43, 19, 978, 19, 783, 13, 914, 13, 58, 77, 14, 14, 14, 31, 34, 678, 77, 34, 35, 13, 43, 19, 978, 19, 712, 13, 13210, 58, 77, 14, 19, 914, 14, 87, 4, 4641, 68, 22480, 31, 7830, 43, 19, 978, 58, 4641, 5, 87, 34, 88, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 