In [1]:
# Re-import edited modules automatically so changes to the local project code
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Expose only one physical GPU to this kernel. The index is machine-specific;
# NOTE(review): adjust it for the host this notebook runs on.
%env CUDA_VISIBLE_DEVICES=5

env: CUDA_VISIBLE_DEVICES=5


In [36]:
from tqdm import tqdm
from copy import copy
import itertools
import os, sys
import yaml
import json
import gc
import argparse
import functools
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from datasets import Dataset

sys.path.append(os.path.abspath('/homes/80/anya/Documents/llm_tiny_ideas/coconut-outer/coconut'))
from utils import Config, set_seed
from coconut import Coconut
from dataset import get_dataset, get_question_latent_dataset, get_cot_latent_dataset, MyCollator


# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
# Distributed-launch metadata, left disabled in this notebook:
# rank = int(os.environ["RANK"])
# print(rank)
# world_size = int(os.environ["WORLD_SIZE"])
# print(world_size)

cuda


# Transitions dataset

In [None]:
# Sampling ranges. NOTE: np.random.randint excludes its upper bound, so the
# largest values actually drawn are max_transitions_size - 1 and max_num_steps - 1.
min_transitions_size = 4
max_transitions_size = 26
min_num_steps = 4
max_num_steps = 100
all_symbols = [s for s in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ']
dataset_size = int(1e5)

# Parallel lists: entry k of each list describes example k.
dataset_transition_dicts = []
dataset_num_steps = []
dataset_start_symbols = []
dataset_answers = []
for _ in tqdm(range(dataset_size), desc='Generating dataset', total=dataset_size):
    # Draw the size of the transition table once and actually use it
    # (previously a first draw was discarded and a second one was used).
    transitions_size = np.random.randint(min_transitions_size, max_transitions_size)
    num_steps = np.random.randint(min_num_steps, max_num_steps)

    # The table is a random permutation of a random subset of the symbols,
    # so every symbol that appears has exactly one successor inside the subset.
    from_symbols = np.random.choice(all_symbols, size=transitions_size, replace=False)
    to_symbols = from_symbols.copy()
    np.random.shuffle(to_symbols)
    # Store plain Python strings rather than NumPy string scalars.
    transition_dict = {str(src): str(dst) for src, dst in zip(from_symbols, to_symbols)}

    # Follow the table num_steps times from a random starting symbol.
    start_symbol = str(np.random.choice(from_symbols))
    end_symbol = start_symbol
    for _step in range(num_steps):
        end_symbol = transition_dict[end_symbol]

    dataset_transition_dicts.append(transition_dict)
    dataset_num_steps.append(num_steps)
    dataset_start_symbols.append(start_symbol)
    dataset_answers.append(end_symbol)

# Prompt skeleton; each <...> placeholder is substituted per example below.
template = "The final answer will be given after \"####\". The transitions are: <TRANSITIONS>. Let's think step-by-step and work out the symbol reached if we start at <START> and take <NUM_STEPS> steps."

dataset_questions = []
for idx in tqdm(range(dataset_size), desc='Formatting questions', total=dataset_size):
    # Render the transition table as "{a->b, c->d, ...}".
    arrows = [f'{src}->{dst}' for src, dst in dataset_transition_dicts[idx].items()]
    substitutions = {
        '<TRANSITIONS>': '{' + ', '.join(arrows) + '}',
        '<START>': dataset_start_symbols[idx],
        '<NUM_STEPS>': str(dataset_num_steps[idx]),
    }
    # Apply the substitutions in the order listed above.
    question = template
    for placeholder, value in substitutions.items():
        question = question.replace(placeholder, value)
    dataset_questions.append(question)

# Pair every question with its answer and show one sample.
transitions_columns = {'question': dataset_questions, 'answer': dataset_answers}
transitions_dataset = Dataset.from_dict(transitions_columns)
print(transitions_dataset)
print(transitions_dataset[0], "\n")


Generating dataset: 100%|██████████| 100000/100000 [00:03<00:00, 29632.21it/s]
Formatting questions: 100%|██████████| 100000/100000 [00:00<00:00, 225843.23it/s]


Dataset({
    features: ['question', 'answer'],
    num_rows: 100000
})
{'question': 'The final answer will be given after "####". The transitions are: {h->Z, I->h, e->I, K->V, V->Y, F->C, v->v, C->F, W->W, Z->e, Y->K}. Let\'s think step-by-step and work out the symbol reached if we start at K and take 93 steps.', 'answer': 'K'} 



# n-ary addition dataset

In [4]:
num_digits = 3
# Each integer in this range is split below into a high part (first operand)
# and a low part (second operand) of num_digits decimal digits.
all_number_pairs = np.arange(10**(2*num_digits-1), 10**(2*num_digits))
np.random.shuffle(all_number_pairs)

# Quotient and remainder by 10**num_digits give the two operands.
number1s, number2s = np.divmod(all_number_pairs, 10**num_digits)
sums = number1s + number2s

# Build the plain and the "rightmost digit first" prompts in a single pass.
questions = []
questions_reversed = []
for number1, number2 in zip(number1s, number2s):
    questions.append(
        f"Let's work out the answer to {number1} + {number2} and give the answer after \"####\"."
    )
    questions_reversed.append(
        f"Let's work out the answer to {number1} + {number2} and give the answer after \"####\" starting from the rightmost digit."
    )

# Answers are stored as text; the reversed variant lists the digits right to left.
answers = [str(total) for total in sums]
answers_reversed = [digits[::-1] for digits in answers]

addition_dataset = Dataset.from_dict({'question': questions, 'answer': answers})
addition_reversed_dataset = Dataset.from_dict({'question': questions_reversed, 'answer': answers_reversed})
for shown in (addition_dataset, addition_reversed_dataset):
    print(shown)
    print(shown[0], "\n")

Dataset({
    features: ['question', 'answer'],
    num_rows: 900000
})
{'question': 'Let\'s work out the answer to 907 + 741 and give the answer after "####".', 'answer': '1648'} 

Dataset({
    features: ['question', 'answer'],
    num_rows: 900000
})
{'question': 'Let\'s work out the answer to 907 + 741 and give the answer after "####" starting from the rightmost digit.', 'answer': '8461'} 



# p-hop

In [25]:
# Sampling ranges. NOTE: np.random.randint excludes its upper bound, so the hop
# count actually drawn ranges over min_num_hops .. max_num_hops - 1.
min_num_hops = 1
max_num_hops = 4
sequence_length = 20
alphabet_size = 6
# Upper-case letters only (the earlier mixed-case list was immediately overwritten).
all_symbols = [s for s in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ']
dataset_size = int(1e5)

# Parallel lists: entry k of each list describes example k.
dataset_sequences = []
dataset_letters = []
dataset_num_hops = []
dataset_answers = []
i = 0  # number of accepted examples so far
# Rejection sampling with a budget of 2 * dataset_size attempts.
for _ in tqdm(range(dataset_size*2), desc='Generating dataset', total=dataset_size*2):
    if i >= dataset_size:
        print("Dataset size reached.")
        break
    alphabet = np.random.choice(all_symbols, size=alphabet_size, replace=False)
    sequence = np.random.choice(alphabet, size=sequence_length, replace=True)
    letter = np.random.choice(alphabet)
    num_hops = np.random.randint(min_num_hops, max_num_hops)
    letter_idxs = np.where(sequence == letter)[0]
    # Reject when the letter occurs fewer than num_hops times, or when the
    # num_hops-th last occurrence is the final element and so has no successor.
    # (Exactly num_hops occurrences is valid: the target is then the first one.)
    if len(letter_idxs) < num_hops or letter_idxs[-num_hops] == sequence_length-1:
        continue
    # The answer is the symbol right after the num_hops-th last occurrence.
    answer = sequence[letter_idxs[-num_hops]+1]
    dataset_sequences.append(sequence)
    dataset_letters.append(str(letter))
    dataset_num_hops.append(num_hops)
    dataset_answers.append(str(answer))
    i += 1

# Fails if the attempt budget ran out before enough examples were accepted.
assert len(dataset_sequences) == dataset_size, f"Dataset size is {len(dataset_sequences)}"

# Text inserted before "last": empty for 1, then " 2nd", " 3rd", " 4th" ... " 10th".
int_to_pos = {1: '', 2: ' 2nd', 3: ' 3rd', **{n: f' {n}th' for n in range(4, 11)}}
# Prompt skeleton; each <...> placeholder is substituted per example below.
template = "The final answer will be given after \"####\". The sequence is: <SEQUENCE>. Let's think step-by-step and find what letter comes after the<POS> last \"<LETTER>\"."

dataset_questions = []
for idx in tqdm(range(dataset_size), desc='Formatting questions', total=dataset_size):
    substitutions = {
        '<SEQUENCE>': "[" + ', '.join(dataset_sequences[idx]) + "]",
        '<POS>': int_to_pos[dataset_num_hops[idx]],
        '<LETTER>': dataset_letters[idx],
    }
    # Apply the substitutions in the order listed above.
    question = template
    for placeholder, value in substitutions.items():
        question = question.replace(placeholder, value)
    dataset_questions.append(question)

# Pair every question with its answer and show one sample.
phop_columns = {'question': dataset_questions, 'answer': dataset_answers}
phop_dataset = Dataset.from_dict(phop_columns)
print(phop_dataset)
print(phop_dataset[0], "\n")


Generating dataset:  83%|████████▎ | 165085/200000 [00:04<00:00, 38806.01it/s]


Dataset size reached.


Formatting questions: 100%|██████████| 100000/100000 [00:00<00:00, 282249.51it/s]


Dataset({
    features: ['question', 'answer'],
    num_rows: 100000
})
{'question': 'The final answer will be given after "####". The sequence is: [B, E, C, E, C, B, B, B, C, B, B, G, B, B, V, G, Y, E, V, V]. Let\'s think step-by-step and find what letter comes after the 3rd last "B".', 'answer': 'G'} 



# i-gsm

In [38]:
all_symbols = [s for s in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ']
def generate_problem(depth=4, modulo=7):
    """
    Generate one i-GSM style problem.

    The problem is a series of assignments (a DAG) built as follows:
    - Two variables are first assigned constants between 0 and 6 (inclusive).
    - Each of the following 2 * depth assignments defines a new variable by
      applying a single operation (+, - or *), left to right, to one to three
      operands. Every operand is either a previously defined variable or a
      constant between 0 and 6 (inclusive).
    - Values are reduced modulo `modulo`; pass `modulo=None` to use plain
      integer arithmetic.
    - The last variable defined is the target of the question.

    Every variable name is used exactly once within a problem.

    Args:
        depth: controls the problem size; 2 * depth derived assignments are
            created after the two constant ones.
        modulo: modulus applied to every stored value, or None to disable it.

    Returns:
        A tuple (problem_str, answer). problem_str has the form
        "A := 1, B := A + 3, ..., X?" and answer is the integer value of the
        target variable (between 0 and modulo - 1 when modulo is given).

    Raises:
        ValueError: if the problem would need more variable names than
            `all_symbols` provides.
    """
    num_initial = 2          # number of constant assignments at the start
    max_constant = 6         # constants are drawn from 0..max_constant inclusive
    max_operands = 3         # each expression uses 1..max_operands operands
    num_assignments = max(depth * 2, 0)  # arbitrarily, total nodes ~ 2*depth

    total_vars = num_initial + num_assignments
    if total_vars > len(all_symbols):
        raise ValueError(
            f"Need {total_vars} distinct variable names but only "
            f"{len(all_symbols)} symbols are available"
        )
    # Draw all names up front without replacement so that no variable is ever
    # defined twice (this also covers the initial constant assignments).
    var_names = [
        str(name)
        for name in np.random.choice(all_symbols, size=total_vars, replace=False)
    ]

    def _reduce(value):
        """Apply the modulus when one is configured."""
        return value % modulo if modulo is not None else value

    # List of (var_name, expression, value) tuples, in definition order.
    assignments = []

    # Step 1. Initial assignments with constant values.
    for var in var_names[:num_initial]:
        # randint excludes its upper bound, hence the + 1.
        const_val = np.random.randint(0, max_constant + 1)
        assignments.append((var, f"{const_val}", int(_reduce(const_val))))

    # Step 2. Derived assignments; each one only uses variables defined earlier.
    for var in var_names[num_initial:]:
        op = str(np.random.choice(["+", "-", "*"]))
        num_operands = np.random.randint(1, max_operands + 1)
        operands = []
        operand_values = []
        for _ in range(num_operands):
            if np.random.random() < 0.7:
                # Use an existing variable as operand.
                rand_idx = np.random.randint(0, len(assignments))
                prev_var, _prev_expr, prev_val = assignments[rand_idx]
                operands.append(prev_var)
                operand_values.append(prev_val)
            else:
                # Or use a random constant.
                const_val = np.random.randint(0, max_constant + 1)
                operands.append(str(const_val))
                operand_values.append(const_val)

        expr_str = f" {op} ".join(operands)
        # Evaluate left to right (this matters for subtraction).
        result = operand_values[0]
        for val in operand_values[1:]:
            if op == "+":
                result = result + val
            elif op == "-":
                result = result - val
            elif op == "*":
                result = result * val
        # Reduce once at the end so single-operand expressions are reduced too;
        # for +, - and * this equals reducing after every step.
        result = int(_reduce(result))

        assignments.append((var, expr_str, result))

    # Step 3. The target is the last variable defined.
    target_var, _, target_val = assignments[-1]

    # Each assignment is rendered as "var := expression"; the final item
    # "target_var?" asks for the target's value.
    problem_lines = [f"{var} := {expr}" for var, expr, _ in assignments]
    problem_lines.append(f"{target_var}?")

    problem_str = ", ".join(problem_lines)

    return problem_str, target_val

dataset_size = int(1e3)

# Generate every (question, answer) pair first, then split into two columns.
generated_problems = [
    generate_problem(depth=4, modulo=None)
    for _ in tqdm(range(dataset_size), desc='Generating dataset', total=dataset_size)
]
dataset_questions = [question for question, _ in generated_problems]
dataset_answers = [answer for _, answer in generated_problems]

# Pair every question with its answer and show one sample.
igsm_columns = {'question': dataset_questions, 'answer': dataset_answers}
igsm_dataset = Dataset.from_dict(igsm_columns)
print(igsm_dataset)
print(igsm_dataset[0], "\n")

Generating dataset: 100%|██████████| 1000/1000 [00:00<00:00, 5240.16it/s]

Dataset({
    features: ['question', 'answer'],
    num_rows: 1000
})
{'question': 'J := 1, F := 3, U := J + J, L := U * 3, I := F + J, X := L, G := J - X, B := U, Z := B - L, W := B + I, W?', 'answer': 6} 






# countdown

In [None]:
# NOTE: this cell is work in progress; example generation is not implemented
# and the loop below raises NotImplementedError on its first iteration.
min_num = 1
max_num = 100
num_operands = 2
max_num_extras = 1
dataset_size = int(1e5)

dataset_numbers = []
dataset_target = []
dataset_answer = []
for _ in tqdm(range(dataset_size), desc='Generating dataset', total=dataset_size):
    # The "+ 1" makes the upper bounds inclusive (np.random.randint excludes them).
    num_extras = np.random.randint(0, max_num_extras+1)
    numbers = np.random.randint(min_num, max_num+1, size=num_operands+1)
    extras = np.random.randint(min_num, max_num+1, size=num_extras)
    operands = np.random.choice(['+', '-', '*', '/'], size=num_operands)
    # True when all drawn operators come from the same group (+/- or * and /).
    # TODO: presumably such expressions need no parentheses; the handling of
    # both cases still has to be written.
    same_operator_group = (
        all(op in ['+', '-'] for op in operands)
        or all(op in ['*', '/'] for op in operands)
    )
    raise NotImplementedError("countdown example generation is not implemented yet")

# Check this cell's own list (previously the p-hop list `dataset_sequences`
# was checked here by mistake).
assert len(dataset_numbers) == dataset_size, f"Dataset size is {len(dataset_numbers)}"

# Plain string: there are no f-string fields, only <...> placeholders filled below.
template = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: Using the numbers <NUMBERS>, create an equation that equals <TARGET>. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work and then return the final answer after \"####\", for example \"#### (1 + 2) / 3 ".
Assistant: Let me solve this step by step.
<think>"""

dataset_questions = []
for i in tqdm(range(dataset_size), desc='Formatting questions', total=dataset_size):
    # Numbers and targets are numeric, so convert them before joining/replacing.
    numbers_as_string = "[" + ', '.join(str(number) for number in dataset_numbers[i]) + "]"
    question = template.replace('<NUMBERS>', numbers_as_string).replace('<TARGET>', str(dataset_target[i]))
    dataset_questions.append(question)


In [5]:
def careful_repeat(tensor, num_repeats):
    """Repeat each entry along dim 0 `num_repeats` times, keeping copies adjacent.

    For an input whose first-dimension entries are [a, b], the result is
    [a, a, ..., b, b, ...] with first dimension `tensor.shape[0] * num_repeats`;
    all remaining dimensions are unchanged. Inputs of any rank >= 1 are accepted.

    Args:
        tensor: a torch.Tensor with at least one dimension.
        num_repeats: how many consecutive copies of each entry to produce.

    Returns:
        A new tensor of shape (tensor.shape[0] * num_repeats, *tensor.shape[1:]).

    Raises:
        AssertionError: if `tensor` is not a torch.Tensor.
        ValueError: if `tensor` is 0-dimensional (there is no dim 0 to repeat).
    """
    assert isinstance(tensor, torch.Tensor)
    if tensor.ndim == 0:
        raise ValueError(f"Invalid ndim: {tensor.ndim}")
    # repeat_interleave along dim 0 yields the same layout as the
    # unsqueeze/repeat/reshape sequence, without a separate branch per rank.
    return tensor.repeat_interleave(num_repeats, dim=0)

In [10]:
# Sanity check: repeat each row three times, then confirm that grouping the
# per-row means back into (2, 3) gives one constant row per original row.
x = torch.arange(1, 9, dtype=torch.float32).reshape(2, 4)
y = careful_repeat(x, 3)
row_means = y.mean(-1)
z = row_means.reshape(2, 3)
for shown in (x, y, z):
    print(shown)

tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [5., 6., 7., 8.],
        [5., 6., 7., 8.],
        [5., 6., 7., 8.]])
tensor([[2.5000, 2.5000, 2.5000],
        [6.5000, 6.5000, 6.5000]])
