In [1]:
continuous_series_names = [
                           'brownian_motion', 
                           # 'geometric_brownian_motion',
                           # 'noisy_logistic_map',
                           # 'logistic_map',
                           # 'lorenz_system',
                           # 'uncorrelated_gaussian',
                           # 'uncorrelated_uniform'
                           ]
markov_chain_names = ['markov_chain']

import numpy as np

### Set up directory
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"
from pathlib import Path
parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

from tqdm import tqdm
import pickle
import torch
from llama import get_model_and_tokenizer
from ICL import MultiResolutionPDF, recursive_refiner, trim_kv_cache, recursive_refiner_preprompt

# Check if directory exists, if not create it
save_path = Path(parent_dir) / 'processed_series_v2'
if not os.path.exists(save_path):
    os.makedirs(save_path)
    
# Define the directory where the generated series are stored
generated_series_dir = Path(parent_dir) / 'generated_series'

In [2]:
def calculate_Markov(full_series, llama_size = '13b'):
    '''
     This function calculates the multi-resolution probability density function (PDF) for a given series.

     Parameters:
     full_series (str): The series for which the PDF is to be calculated.
     llama_size (str, optional): The size of the llama model. Defaults to '13b'.

     Returns:

    '''
    model, tokenizer = get_model_and_tokenizer(llama_size)
    states = sorted(set(full_series))
    good_tokens = [tokenizer.convert_tokens_to_ids(state) for state in states]
    batch = tokenizer(
        [full_series], 
        return_tensors="pt",
        add_special_tokens=True,        
    )
    torch.cuda.empty_cache()
    with torch.no_grad():
        out = model(batch['input_ids'].cpu())
    logit_mat = out['logits']
    logit_mat_good = logit_mat[:,:,good_tokens].cpu()

    return logit_mat_good

def calculate_multiPDF(
    full_series, prec, mode = 'neighbor', refine_depth = 1, llama_size = '13b', size_preprompt = 0,
    number_of_tokens_original = None
):
    '''
     This function calculates the multi-resolution probability density function (PDF) for a given series.

     Parameters:
     full_series (str): The series for which the PDF is to be calculated.
     prec (int): The precision of the PDF.
     mode (str, optional): The mode of calculation. Defaults to 'neighbor'.
     refine_depth (int, optional): The depth of refinement for the PDF. Defaults to 1.
     llama_size (str, optional): The size of the llama model. Defaults to '13b'.

     Returns:
     list: A list of PDFs for the series.
    '''
    # if llama_size != '13b':
    #     assert False, "Llama size must be '13b'"
    good_tokens_str = list("0123456789")
    print(f"good_tokens_str: {good_tokens_str}")
    good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
    print(f"good_tokens: {good_tokens}")
    assert refine_depth < prec, "Refine depth must be less than precision"
    refine_depth = refine_depth - prec
    curr = -prec
    batch = tokenizer(
        [full_series], 
        return_tensors="pt",
        add_special_tokens=True        
    )
    print(f"batch['input_ids']: shape | {batch['input_ids'].shape}, sample | {batch['input_ids'][0,:10]}")
    torch.cuda.empty_cache()
    with torch.no_grad():
        out = model(batch['input_ids'].cuda(), use_cache=True)
    print(f"out: {list(out.keys())}")
    logit_mat = out['logits']
    print(f"logit_mat: shape | {logit_mat.shape}, sample | {logit_mat[:10]}")
    kv_cache_main = out['past_key_values']
    logit_mat_good = logit_mat[:,:,good_tokens].clone()
    print(f"logit_mat_good: shape | {logit_mat_good.shape}, sample | {logit_mat_good[:10]}")
    probs = torch.nn.functional.softmax(logit_mat_good[:,1:,:], dim=-1)
    
    PDF_list = []
    comma_locations = np.sort(np.where(np.array(list(full_series[size_preprompt:])) == ',')[0])
    
    print(f"len coma locations: {comma_locations.shape} | sample: {comma_locations[:10]}")
    print(f"probs: {probs.shape}, type: {type(probs)}")
    # start_loop_from = 1 if use_instruct else 0
    for i in tqdm(range(len(comma_locations))):
        PDF = MultiResolutionPDF()
        # slice out the number before ith comma
        if i == 0:
            start_idx = 0
        else:
            start_idx = comma_locations[i-1]+1
        end_idx = comma_locations[i]
        # print(f"start_idx:end_idx {start_idx}:{end_idx}")
        # if end_idx <= probs.shape[1]:
        num_slice = full_series[size_preprompt:][start_idx:end_idx]
        if number_of_tokens_original:
            prob_slice = probs[:,-(number_of_tokens_original-1):][0,start_idx:end_idx].cpu().numpy()
        else:
            prob_slice = probs[0,start_idx:end_idx].cpu().numpy()
        ### Load hierarchical PDF 
        # print(f"prob_slice: {prob_slice.shape}, type: {type(prob_slice)}, sample: {prob_slice[:10]}")
        # print(f"num_slice: {num_slice}, type: {type(num_slice)}")
        PDF.load_from_num_prob(num_slice, prob_slice)

        # raise ValueError('test')
        
        ### Refine hierarchical PDF
        seq = full_series[:size_preprompt+end_idx]
        # cache and full_series are shifted from beginning, not end
        end_idx_neg = end_idx - len(full_series[size_preprompt:])
        ### kv cache contains seq[0:-1]
        kv_cache = trim_kv_cache(kv_cache_main, end_idx_neg-1)
        # recursive_refiner_preprompt(
        #     PDF, seq, curr=curr, main=True, refine_depth=refine_depth, mode=mode, 
        #     kv_cache=kv_cache, model=model, tokenizer=tokenizer, good_tokens=good_tokens,
        #     size_preprompt=size_preprompt
        # )
        recursive_refiner(
            PDF, seq, curr=curr, main=True, refine_depth=refine_depth, mode=mode, 
            kv_cache=kv_cache, model=model, tokenizer=tokenizer, good_tokens=good_tokens,
        )

        PDF_list += [PDF]

        raise ValueError("test")

        if i==10:
            print(f"start_idx: {start_idx}")
            print(f"end_idx: {end_idx}")
            print(f"num_slice: {num_slice}")
            print(f"prob_slice: {prob_slice}")
            print(f"PDF_list: shape | {len(PDF_list)}, sample | {PDF_list[:10]}")
    
    # release memory
    del logit_mat, kv_cache_main
    return PDF_list

In [3]:
model, tokenizer = get_model_and_tokenizer('7b')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Initialize dictionaries to store the data for continuous series and Markov chains
continuous_series_task = {}
markov_chain_task = {}

# Loop through each file in the directory
for file in generated_series_dir.iterdir():
    # Check if a series is already processed
    # if not (save_path / file.name).exists():\
    # Extract the series name from the file name
    series_name = file.stem.rsplit('_', 1)[0]
    # If the series is a continuous series, load the data into the continuous_series_data dictionary
    if series_name in continuous_series_names:
        continuous_series_task[file.name] = pickle.load(file.open('rb'))
    # If the series is a Markov chain, load the data into the markov_chain_data dictionary
    elif series_name in markov_chain_names:
        markov_chain_task[file.name] = pickle.load(file.open('rb'))
    # If the series name is not recognized, raise an exception
    # else:
    #     raise Exception(f"Unrecognized series name: {series_name}")

In [5]:
print(continuous_series_task.keys())
print(markov_chain_task.keys())

dict_keys(['brownian_motion_4.pkl', 'brownian_motion_14.pkl', 'brownian_motion_13.pkl', 'brownian_motion_17.pkl', 'brownian_motion_7.pkl', 'brownian_motion_11.pkl', 'brownian_motion_10.pkl', 'brownian_motion_6.pkl', 'brownian_motion_2.pkl', 'brownian_motion_5.pkl', 'brownian_motion_9.pkl', 'brownian_motion_16.pkl', 'brownian_motion_15.pkl', 'brownian_motion_3.pkl', 'brownian_motion_1.pkl', 'brownian_motion_18.pkl', 'brownian_motion_8.pkl', 'brownian_motion_19.pkl', 'brownian_motion_0.pkl', 'brownian_motion_12.pkl'])
dict_keys(['markov_chain_8.pkl', 'markov_chain_7.pkl', 'markov_chain_4.pkl', 'markov_chain_3.pkl', 'markov_chain_10.pkl', 'markov_chain_0.pkl', 'markov_chain_14.pkl', 'markov_chain_5.pkl', 'markov_chain_11.pkl', 'markov_chain_6.pkl', 'markov_chain_12.pkl', 'markov_chain_13.pkl', 'markov_chain_9.pkl', 'markov_chain_17.pkl', 'markov_chain_15.pkl', 'markov_chain_1.pkl', 'markov_chain_16.pkl', 'markov_chain_2.pkl'])


### Analyze Multi Digit series

In [10]:
# pre_prompt = "Brownian Motion,"
pre_prompt = ""
number_of_tokens_original = None
for series_name, series_dict in sorted(continuous_series_task.items()):
    print("Processing ", series_name)
    if 'brownian_motion' in series_name:
        full_series = series_dict['full_series']
        print(f"full_series: {full_series[:10]}")
        prec = series_dict['prec']
        refine_depth = series_dict['refine_depth']
        llama_size = series_dict['llama_size']
        mode = series_dict['mode']
        if len(pre_prompt) > 1:
            number_of_tokens_original = len(tokenizer(full_series)['input_ids'])
        print(f"number_of_tokens_original: {number_of_tokens_original}")
        print(f"comma token: {tokenizer(',')['input_ids']}")
        PDF_list = calculate_multiPDF(
            pre_prompt+full_series, prec, mode = mode, refine_depth = refine_depth, llama_size = llama_size, 
            size_preprompt=len(pre_prompt), number_of_tokens_original=number_of_tokens_original
        )
        series_dict['PDF_list'] = PDF_list
        save_name = os.path.join(save_path, f"{series_name.split('.')[0]}_llama3_raw.pkl")
        # save_name = os.path.join(save_path, series_name)
        with open(save_name, 'wb') as f:
            pickle.dump(series_dict, f)

Processing  brownian_motion_0.pkl
full_series: 214,223,21
number_of_tokens_original: None
comma token: [1, 1919]
good_tokens_str: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
good_tokens: [29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929]
batch['input_ids']: shape | torch.Size([1, 4002]), sample | tensor([    1, 29871, 29906, 29896, 29946, 29892, 29906, 29906, 29941, 29892])
out: ['logits', 'past_key_values']
logit_mat: shape | torch.Size([1, 4002, 32000]), sample | tensor([[[ 0.1040, -0.2216,  0.3127,  ...,  1.3271,  1.8799,  0.6436],
         [-2.3770,  7.4023,  9.1172,  ...,  4.3398, -0.2791,  3.2949],
         [-9.1719, -5.5156,  0.6323,  ..., -4.8477, -8.5547, -3.3789],
         ...,
         [-5.7305, -6.2188,  6.3438,  ..., -3.5723, -4.5273, -2.7539],
         [-2.0762, -0.5918, 10.5625,  ...,  0.8237, -1.3496, -1.8525],
         [-3.5195, -1.6318,  7.4961,  ..., -0.3696, -2.3652, -1.4277]]],
       device='cuda:0')
logit_mat_good: shape | torch.Size([1

  0%|                                                                                                                  | 0/1000 [00:02<?, ?it/s]


ValueError: test

In [11]:
%debug

> [0;32m/tmp/ipykernel_1477006/569063534.py[0m(117)[0;36mcalculate_multiPDF[0;34m()[0m
[0;32m    115 [0;31m        [0mPDF_list[0m [0;34m+=[0m [0;34m[[0m[0mPDF[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    116 [0;31m[0;34m[0m[0m
[0m[0;32m--> 117 [0;31m        [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34m"test"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m[0;34m[0m[0m
[0m[0;32m    119 [0;31m        [0;32mif[0m [0mi[0m[0;34m==[0m[0;36m10[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  PDF.bin_height_arr


array([6.94085172e-02, 6.03808161e+00, 1.28056978e+00, 4.17367299e-01,
       4.87951789e-01, 3.16278424e-01, 6.85439545e-01, 4.04526239e-01,
       3.13817108e-01, 3.11374967e-01, 2.27806796e-01, 4.96871468e-01,
       4.43656301e-01, 2.84216888e-01, 2.75472458e-01, 1.79954925e-01,
       4.10314969e-01, 2.17067074e-01, 2.00754210e-01, 1.23198901e-01,
       1.24651135e-01, 6.72958998e-01, 3.65881426e-01, 2.44683522e-01,
       1.08806169e+00, 2.03643686e-01, 5.64478968e-01, 2.70839970e-01,
       2.24534227e-01, 1.98926298e-01, 1.44970200e-01, 3.63286818e-01,
       2.53615063e-01, 1.96745568e-01, 2.26452765e-01, 1.56859246e-01,
       2.79631832e-01, 1.95978539e-01, 1.45071090e-01, 1.11229862e-01,
       1.19332227e-01, 3.63272254e-01, 2.45802300e-01, 1.89941305e-01,
       2.22063872e-01, 1.79832808e-01, 2.60635100e-01, 2.02191466e-01,
       1.47926437e-01, 1.24080911e-01, 1.28520713e-01, 1.16857122e+00,
       2.66926834e-01, 2.13645560e-01, 2.60743492e-01, 1.87806488e-01,
      

ipdb>  exit()


### Analyze Markov Series

In [22]:
for series_name, series_dict in sorted(markov_chain_task.items()):
    print("Processing ", series_name)
    full_series = series_dict['full_series']
    llama_size = series_dict['llama_size']
    logit_mat_good = calculate_Markov(full_series, llama_size = llama_size)    
    series_dict['logit_mat_good'] = logit_mat_good
    save_name = os.path.join(save_path, series_name)
    with open(save_name, 'wb') as f:
        pickle.dump(series_dict, f)
    break

Processing  markov_chain_0.pkl


Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:05<00:00,  2.89s/it]


In [12]:
np.arange(0,1000) / 100

array([0.  , 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1 ,
       0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2 , 0.21,
       0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3 , 0.31, 0.32,
       0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4 , 0.41, 0.42, 0.43,
       0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5 , 0.51, 0.52, 0.53, 0.54,
       0.55, 0.56, 0.57, 0.58, 0.59, 0.6 , 0.61, 0.62, 0.63, 0.64, 0.65,
       0.66, 0.67, 0.68, 0.69, 0.7 , 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
       0.77, 0.78, 0.79, 0.8 , 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
       0.88, 0.89, 0.9 , 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
       0.99, 1.  , 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07, 1.08, 1.09,
       1.1 , 1.11, 1.12, 1.13, 1.14, 1.15, 1.16, 1.17, 1.18, 1.19, 1.2 ,
       1.21, 1.22, 1.23, 1.24, 1.25, 1.26, 1.27, 1.28, 1.29, 1.3 , 1.31,
       1.32, 1.33, 1.34, 1.35, 1.36, 1.37, 1.38, 1.39, 1.4 , 1.41, 1.42,
       1.43, 1.44, 1.45, 1.46, 1.47, 1.48, 1.49, 1.