In [1]:
import bisect
import itertools
import os
import random
import re
import sys

import more_itertools
import rich
import rich.table
import torch
import transformers

import solid

sys.path.append("../")
import lib_data
import lib_utils

In [2]:
#
# Anything that can be changed:
#

# Checkpoint name handed to both the tokenizer and the model loaders.
MODEL = "mistralai/Mistral-7B-v0.1"

# Seed the Python / NumPy / PyTorch RNGs up front so that any random sampling
# done by later cells is repeatable after "Restart Kernel -> Run All".
SEED = 0
transformers.set_seed(SEED)

ANSWER_ONLY = False

# The two locations below default to cluster-specific absolute paths.
# They can be overridden without editing the notebook by exporting an
# environment variable of the same name before starting the kernel.
ANSWER_ONLY_PATH = os.environ.get(
    "ANSWER_ONLY_PATH",
    "/network/scratch/g/gagnonju/saved_scratchpad_gen_outputs/chatgpt-3.5-commonsenseqa-scratchpads/not-cond-on-answers/commonsenseqa.chatgpt",
)
ARITHMETIC_DATASET_ROOT_FOLDER_DIR = os.environ.get(
    "ARITHMETIC_DATASET_ROOT_FOLDER_DIR",
    "/home/mila/g/gagnonju/Marg-Li-CoT/with_trl/libs_data/arithmetic/",
)

# Number of samples per batch produced by the DataLoader.
BATCH_SIZE = 2

DATASET_NAME = lib_data.DatasetChoices.ARITHMETIC
INPUT_MAX_LENGTH = 115
SPLIT = lib_utils.CVSets.TRAIN
USE_CURRICULUM = True
USE_FEW_SHOTS = False
# Passed to `set_proportion_difficulties`.
# NOTE(review): presumably {difficulty level: sampling proportion} -- confirm in lib_data.
CURRICULUM_PROPORTIONS = {3: 1.}

In [3]:
#
# Load the tokenizer & the model
#

# Both objects are loaded from the same MODEL checkpoint.
t = transformers.AutoTokenizer.from_pretrained(MODEL)
m = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL,
    # NOTE(review): {"": 0} presumably places the whole model on device 0 -- confirm.
    device_map={"": 0},
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
#
# Prepare the dataset and the dataloader and set the curriculum proportions
#

ds = lib_data.prep_dataset_rl(
    any_tokenizer=t,
    dataset_name=DATASET_NAME,
    split=SPLIT,
    input_max_length=INPUT_MAX_LENGTH,
    answer_only=ANSWER_ONLY,
    answer_only_path=ANSWER_ONLY_PATH,
    use_few_shots=USE_FEW_SHOTS,
    use_curriculum=USE_CURRICULUM,
    question_prefix=None,
    question_suffix=None,
    extr_arith_ignore_one_line=True,
    arithmetic_dataset_root_folder_dir=ARITHMETIC_DATASET_ROOT_FOLDER_DIR,
)

# The proportions are set right below, so the curriculum must be enabled.
# NOTE(review): presumably only the curriculum dataset exposes
# `set_proportion_difficulties` -- confirm in lib_data.
assert USE_CURRICULUM

dataloader = torch.utils.data.DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    collate_fn=lib_data.data_item_collator,
    num_workers=0,
)

# `ds` is the very object held by `dataloader.dataset`.
ds.set_proportion_difficulties(CURRICULUM_PROPORTIONS)

1.jsonl: Building DataListContainer, including Tok-detok.: 100%|██████████| 2000/2000 [00:00<00:00, 220150.33it/s]


2.jsonl: Building DataListContainer, including Tok-detok.: 100%|██████████| 2000/2000 [00:00<00:00, 220909.80it/s]
3.jsonl: Building DataListContainer, including Tok-detok.: 100%|██████████| 2000/2000 [00:00<00:00, 215468.20it/s]
4.jsonl: Building DataListContainer, including Tok-detok.: 100%|██████████| 2000/2000 [00:00<00:00, 193562.42it/s]
5.jsonl: Building DataListContainer, including Tok-detok.: 100%|██████████| 2000/2000 [00:00<00:00, 193996.62it/s]
Loading files: 100%|██████████| 5/5 [00:00<00:00, 12.04it/s]


In [5]:
# Single iterator over the dataloader; the next cell draws a batch from it.
it = iter(dataloader)

In [6]:
# Pull one batch and expose its attributes as a plain dict (attribute name -> value).
# NOTE: re-running this cell consumes another batch from `it`.
raw_sample = vars(next(it))

In [7]:
# Tokenize every string field of the batch, keeping the character offsets

# Keys of the per-field dictionaries stored in `prepped_sample`.
RAW_KEY = "raw"
TOKENIZED_KEY = "tokenized"
UNCHANGED_KEY = "unchanged"

prepped_sample = {}

for field_name, field_values in raw_sample.items():
    entry = prepped_sample[field_name] = {}

    # The type of the first element decides whether the field is text.
    if not isinstance(field_values[0], str):
        entry[UNCHANGED_KEY] = field_values
        continue

    entry[RAW_KEY] = field_values
    entry[TOKENIZED_KEY] = t(field_values, return_offsets_mapping=True)

# Summary of what was stored, in alphabetical order.
for field_name in sorted(prepped_sample):
    for entry_key in sorted(prepped_sample[field_name]):
        print(f"{field_name} - {entry_key}")


detok_ref_answer - raw
detok_ref_answer - tokenized
detok_ref_query - raw
detok_ref_query - tokenized
detok_ref_scratchpad - raw
detok_ref_scratchpad - tokenized
difficulty_level - unchanged
extra_information - unchanged


In [8]:

"""

Overall noising idea:

1. Find the substring
2. Find the token that contains the substring
    2.a) Stack those tokens in a batch list
    2.b) loop for the whole batch
3. Get the model logits for the position of the token
4. Exclude the correct token 
5. Find a new candidate according to a criterion
6. Make sure that the correct answer is not still
   contained in the new selection in a slightly different form

"""


# Single-character vocabulary entries to keep: the ten digits and "▁".
GOODS = "0123456789▁"
GOODS_IDS = {
    token_id
    for token, token_id in t.vocab.items()
    if len(token) == 1 and token in GOODS
}
# Every character of GOODS must map to exactly one vocabulary id.
assert len(GOODS) == len(GOODS_IDS), (len(GOODS), len(GOODS_IDS))



# Compose bads
# def compose_bads():
#     BADS_IDS = set(t.vocab.values()) - GOODS_IDS
#     EVERY_IDS_BUT_NUMBERS_AND_SPACES = [[x] for x in BADS_IDS]
#     MULTI_NUMBERS = []
#     for i in range(2, 10):
#         MULTI_NUMBERS += [list(x) for x in itertools.combinations_with_replacement(NUMBERS_IDS, i)]

#     print(f"{MULTI_NUMBERS = }")
#     EVERY_IDS_BUT_NUMBERS_AND_SPACES += MULTI_NUMBERS



# Vocabulary ids of the ten single-digit tokens.
NUMBERS = "0123456789"
NUMBERS_IDS = {
    token_id
    for token, token_id in t.vocab.items()
    if len(token) == 1 and token in NUMBERS
}
# Ids of the "▁," and "▁" vocabulary entries (KeyError if either is missing).
COMMA_ID = t.vocab["▁,"]
SPACE_ID = t.vocab["▁"]
# Tokens after which the constraint function forces a space.
NUMBERS_OR_COMMA = {*NUMBERS_IDS, COMMA_ID}

def prefix_allowed_tokens_fn(batch_id, idx_so_far):
    """Return the token ids allowed next, alternating digits and spaces.

    After a digit or the comma token only the space token is allowed; after
    anything else any single-digit token is allowed. The decision and the
    previous token are printed for debugging.

    Args:
        batch_id: Unused.
        idx_so_far: Sequence of token ids whose last element supports
            `.item()` -- presumably a 1-D tensor of the ids produced so far
            (TODO confirm against the caller in `solid`).

    Returns:
        A list of allowed token ids.
    """
    previous_id = idx_so_far[-1].item()
    print(f"Prev Index: {previous_id}")
    print(f"Prev Value: \"{t.decode([previous_id])}\"")

    wants_space = previous_id in NUMBERS_OR_COMMA
    allowed_ids = [SPACE_ID] if wants_space else list(NUMBERS_IDS)
    print("Generating space." if wants_space else "Generating Number.")

    print()
    return allowed_ids

# Callable used by the noising loop to rewrite a token span; it is built around
# `prefix_allowed_tokens_fn`.
# NOTE(review): judging by its name it substitutes the model's most likely
# allowed tokens -- confirm in the `solid` module.
replace_fn = solid.ReplaceWithMostLikely(prefix_allowed_tokens_fn)

In [13]:


batch_idx  = 0  # index, within the batch, of the sample being noised

# Prepare the scratchpad
scratchpad         = prepped_sample["detok_ref_scratchpad"]
full_offsetmapping = scratchpad[TOKENIZED_KEY]["offset_mapping"][batch_idx]
input_input_ids    = scratchpad[TOKENIZED_KEY]["input_ids"][batch_idx].copy()
output_input_ids   = input_input_ids.copy()
scratchpad_text    = scratchpad[RAW_KEY][batch_idx]
# Keep only the second element (the end offset) of every token's offset pair.
offset_mapping     = [x[1] for x in full_offsetmapping]


table = rich.table.Table("Reference", highlight=True)
table.add_row(t.decode(input_input_ids))
rich.print(table)

# Prepare the replacement & the search criteria.
# Should be the sub-answers, so whatever is after ","
for i, (start, end) in enumerate(solid.find_intermediate_answers(scratchpad_text)):
    print("#" * 80)

    # Convert the character span of the sub-answer into a token-index span.
    bisect_idx_start, bisect_idx_end = solid.string_span_to_input_ids_span(
        span_str           = (start, end),
        idx_offset_mapping = offset_mapping,
    )

    # Reference text of the span, decoded from the untouched input ids.
    correct_answer = t.decode(
        input_input_ids[bisect_idx_start:bisect_idx_end],
        skip_special_tokens=True,
    ).replace("\n", "").strip()

    start_len = len(output_input_ids)

    output_input_ids, replacement_answer = replace_fn(
        output_input_ids = output_input_ids,
        bisect_idx_start = bisect_idx_start,
        bisect_idx_end   = bisect_idx_end,
        tokenizer        = t,
        model            = m,
    )

    # Every span is computed from the original tokenization, so a replacement
    # must leave the sequence length unchanged for later spans to stay aligned.
    assert len(output_input_ids) == start_len, (len(output_input_ids), start_len)

    table = rich.table.Table(f"Replacement #{i}", highlight=True, show_lines=True)
    table.add_row("[bold]correct_answer:"    , correct_answer)
    table.add_row("[bold]replacement_answer:", replacement_answer)
    table.add_row("[bold]Replacement so far" , t.decode(output_input_ids, skip_special_tokens=True))
    rich.print(table)

    print("#" * 80)



################################################################################


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28783
Prev Value: "8"
Generating space.

Prev Index: 28740
Prev Value: "1"
Generating space.



################################################################################
################################################################################


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28734
Prev Value: "0"
Generating space.

Prev Index: 28734
Prev Value: "0"
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28740
Prev Value: "1"
Generating space.



################################################################################
################################################################################


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28770
Prev Value: "3"
Generating space.

Prev Index: 28734
Prev Value: "0"
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28770
Prev Value: "3"
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28734
Prev Value: "0"
Generating space.



################################################################################
################################################################################


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 1200
Prev Value: ","
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28734
Prev Value: "0"
Generating space.

Prev Index: 28734
Prev Value: "0"
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28705
Prev Value: ""
Generating Number.

Prev Index: 28740
Prev Value: "1"
Generating space.

Prev Index: 28734
Prev Value: "0"
Generating space.



################################################################################
