In [1]:
# Reload edited modules automatically before each cell executes, so changes to
# imported project files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Restrict CUDA to GPU index 5; this is set before torch is imported further down.
%env CUDA_VISIBLE_DEVICES=5

env: CUDA_VISIBLE_DEVICES=5


In [3]:
from tqdm import tqdm
from copy import copy
import itertools
import os, sys
import yaml
import json
import gc
import argparse
import functools
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from datasets import Dataset

sys.path.append(os.path.abspath('/homes/80/anya/Documents/llm_tiny_ideas/coconut-outer/coconut'))
from utils import Config, set_seed
from coconut import Coconut
from dataset import get_dataset, get_question_latent_dataset, get_cot_latent_dataset, MyCollator


# Run on the GPU when CUDA reports one as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
# rank = int(os.environ["RANK"])
# print(rank)
# world_size = int(os.environ["WORLD_SIZE"])
# print(world_size)

  from .autonotebook import tqdm as notebook_tqdm


cuda


# Transitions dataset

In [7]:
# Sampling bounds for the transitions dataset. Both pairs are fed to
# np.random.randint, whose upper bound is exclusive.
min_transitions_size, max_transitions_size = 4, 26
min_num_steps, max_num_steps = 4, 100

# Symbol alphabet: the 26 lowercase letters followed by the 26 uppercase letters.
all_symbols = list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')

# Number of question/answer examples to generate.
dataset_size = 100_000

In [None]:
# Build the synthetic "transitions" dataset.
#
# Each example consists of:
#   * a transition map that is a random permutation of a random subset of
#     `all_symbols` (every chosen symbol maps to exactly one chosen symbol),
#   * a start symbol drawn from that subset,
#   * a number of steps to follow the map for.
# The answer is the symbol reached after following the map `num_steps` times.
dataset_transition_dicts = []
dataset_num_steps = []
dataset_start_symbols = []
dataset_answers = []
for _ in tqdm(range(dataset_size), desc='Generating dataset', total=dataset_size):
    # np.random.randint excludes the upper bound, so values lie in [min, max).
    transitions_size = np.random.randint(min_transitions_size, max_transitions_size)
    num_steps = np.random.randint(min_num_steps, max_num_steps)

    # Pick the symbols taking part in this example, then shuffle a copy of
    # them to obtain the targets: the resulting map is a permutation.
    from_symbols = np.random.choice(all_symbols, size=transitions_size, replace=False)
    to_symbols = from_symbols.copy()
    np.random.shuffle(to_symbols)
    transition_dict = dict(zip(from_symbols, to_symbols))

    # Walk the map for `num_steps` steps starting from a random symbol.
    start_symbol = np.random.choice(from_symbols)
    end_symbol = start_symbol
    for _ in range(num_steps):
        end_symbol = transition_dict[end_symbol]

    dataset_transition_dicts.append(transition_dict)
    dataset_num_steps.append(num_steps)
    dataset_start_symbols.append(start_symbol)
    dataset_answers.append(end_symbol)

# Render every example as a natural-language question.
template = "The final answer will be given after \"####\". The transitions are: <TRANSITIONS>. Let's think step-by-step and work out the symbol reached if we start at <START> and take <NUM_STEPS> steps."
dataset_questions = []
for transition_dict, start_symbol, num_steps in tqdm(
    zip(dataset_transition_dicts, dataset_start_symbols, dataset_num_steps),
    desc='Formatting questions',
    total=dataset_size,
):
    dataset_transition_string = '{' + ', '.join([f'{k}->{v}' for k, v in transition_dict.items()]) + '}'
    question = template.replace('<TRANSITIONS>', dataset_transition_string).replace('<START>', start_symbol).replace('<NUM_STEPS>', str(num_steps))
    dataset_questions.append(question)

transitions_dataset = Dataset.from_dict({'question': dataset_questions, 'answer': dataset_answers})
# Show the dataset summary and the first example as a sanity check.
print(transitions_dataset)
print(transitions_dataset[0], "\n")


Generating dataset: 100%|██████████| 100000/100000 [00:03<00:00, 30594.74it/s]
Formatting questions: 100%|██████████| 100000/100000 [00:00<00:00, 221440.12it/s]


Dataset({
    features: ['question', 'answer'],
    num_rows: 100000
})
{'question': 'The final answer will be given after "####". The transitions are: {n->b, b->n, J->J, y->y}. Let\'s think step-by-step and work out the symbol reached if we start at n and take 76 steps.', 'answer': 'n'} 



# n-digit addition dataset

In [6]:
# Build two addition datasets from the same shuffled operand pairs: one with
# the sum written normally and one with the digits of the sum reversed.
num_digits = 3

# Every integer with exactly 2*num_digits digits encodes one operand pair:
# the high num_digits digits are the first operand, the low ones the second.
all_number_pairs = np.arange(10**(2*num_digits-1), 10**(2*num_digits))
np.random.shuffle(all_number_pairs)
number1s, number2s = np.divmod(all_number_pairs, 10**num_digits)

# Format both prompt variants in a single pass over the operand pairs.
questions = []
questions_reversed = []
for lhs, rhs in zip(number1s, number2s):
    questions.append(
        f"Let's work out the answer to {lhs} + {rhs} and give the answer after \"####\"."
    )
    questions_reversed.append(
        f"Let's work out the answer to {lhs} + {rhs} and give the answer after \"####\" starting from the rightmost digit."
    )

# Answers are stored as strings; the reversed variant flips the digit order.
answers = [str(total) for total in number1s + number2s]
answers_reversed = [digits[::-1] for digits in answers]

addition_dataset = Dataset.from_dict({'question': questions, 'answer': answers})
addition_reversed_dataset = Dataset.from_dict({'question': questions_reversed, 'answer': answers_reversed})

# Print each dataset summary followed by its first example.
for built_dataset in (addition_dataset, addition_reversed_dataset):
    print(built_dataset)
    print(built_dataset[0], "\n")

Dataset({
    features: ['question', 'answer'],
    num_rows: 900000
})
{'question': 'Let\'s work out the answer to 355 + 484 and give the answer after "####".', 'answer': '839'} 

Dataset({
    features: ['question', 'answer'],
    num_rows: 900000
})
{'question': 'Let\'s work out the answer to 355 + 484 and give the answer after "####" starting from the rightmost digit.', 'answer': '938'} 

