## TokenShift Memory benchmarking

As current TokenShift models have started to drastically surpass the Raven model, this is a variant of the benchmark that focuses on such models.

## Setup

In [None]:
# Due to the size of the CSV data, we did not include it in the repository. You can download our current CSV data from hugging face
!mkdir -p ./logs

# Experimental V5 models
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/raw/main/experiment/memory-bench/logs/BaseV5-C-Tune5-1k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/BaseV5-C-Tune5-4k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/BaseV5-C-Tune5-16k.csv"

!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/raw/main/experiment/memory-bench/logs/v5-L6-D1024-E0_1-1k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L6-D1024-E0_1-4k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L6-D1024-E0_1-16k.csv"

!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/raw/main/experiment/memory-bench/logs/v5-L6-D2048-E0_1-1k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L6-D2048-E0_1-4k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L6-D2048-E0_1-16k.csv"

!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/raw/main/experiment/memory-bench/logs/v5-L6-D4096-E0_1-1k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L6-D4096-E0_1-4k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L6-D4096-E0_1-16k.csv"

!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/raw/main/experiment/memory-bench/logs/v5-L96-D1024-E0_1-mem-ctx-8k-1k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L96-D1024-E0_1-mem-ctx-8k-4k.csv"
!cd ./logs && wget -nc "https://huggingface.co/rwkv-x-dev/rwkv-x-playground/resolve/main/experiment/memory-bench/logs/v5-L96-D1024-E0_1-mem-ctx-8k-16k.csv"


In [None]:
# Install required pip libraries
!python -m pip install matplotlib

## Loading of CSV data

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
# Raise pandas' 'display.max_rows' option to 100 so the larger summary tables below
# are rendered in full instead of being truncated.
pd.set_option('display.max_rows', 100)

# Helper: mean over the leading slice of a series
def calculate_first_n_avg(n, s):
    """Return the mean of the first `n` entries of `s`.

    Args:
        n: Number of leading entries to average. The sentinel value -1 means
            "average every entry".
        s: A pandas Series (anything exposing `.iloc` slicing and `.mean()`).

    Returns:
        The mean of the selected entries. If `s` has fewer than `n` entries,
        all of them are averaged.
    """
    leading = s.iloc[:] if n == -1 else s.iloc[:n]
    return leading.mean()

# Summarise the raw per-token rows into one row per (eval_token_count, is_random_baseline)
def group_csv_data(inCSV, modelName):
    """Aggregate a raw benchmark DataFrame into per-group summary statistics.

    Rows are grouped by ('eval_token_count', 'is_random_baseline'). For each group
    the result contains:
      * 'First N tokens average' for N in 1, 2, 5, 10, 25, 50, 100, 250, 500, 750, 1000:
        the mean of 'eval_token_pos' over the first N rows of the group,
      * 'tokens average position': the mean of 'eval_token_pos' over the whole group,
      * 'match_count': the sum of the 'matched' column,
      * 'match_percentage': match_count as a percentage of the non-null 'matched' rows.

    Args:
        inCSV: DataFrame with at least the columns 'eval_token_count',
            'is_random_baseline', 'eval_token_pos' and 'matched'.
        modelName: Label stored in the 'model' column of every output row.

    Returns:
        A DataFrame with the group keys as regular columns, the statistics above,
        and a 'model' column.
    """
    window_sizes = (1, 2, 5, 10, 25, 50, 100, 250, 500, 750, 1000)

    def summarise(group):
        positions = group['eval_token_pos']
        matched = group["matched"]

        # Insertion order defines the column order of the resulting frame
        stats = {
            f'First {size} tokens average': calculate_first_n_avg(size, positions)
            for size in window_sizes
        }
        stats['tokens average position'] = calculate_first_n_avg(-1, positions)
        stats["match_count"] = matched.sum()
        stats["match_percentage"] = matched.sum() * 100.0 / matched.count()
        return pd.Series(stats)

    per_group = inCSV.groupby(['eval_token_count', 'is_random_baseline']).apply(summarise)
    grouped_data = per_group.reset_index()
    grouped_data['model'] = modelName
    return grouped_data

# Load a benchmark CSV from disk and summarise it
def group_csv_file(filepath, modelName):
    """Read the CSV at `filepath` and return its grouped summary.

    Args:
        filepath: Path of the CSV file, passed straight to `pd.read_csv`.
        modelName: Label stored in the 'model' column of the result.

    Returns:
        The DataFrame produced by `group_csv_data` for the loaded file.
    """
    raw_rows = pd.read_csv(filepath)
    return group_csv_data(raw_rows, modelName)

# Merge the DataFrames: each model contributes three log files (-1k, -4k and -16k suffixes).
# (log file prefix, model label shown in tables and plot legends)
model_log_prefixes = [
    ("BaseV5-C-Tune5", "V5r2-Baseline 1B5 (L24-D2048)"),
    ("v5-L6-D1024-E0_1", "V5r2-L6-D1024"),
    ("v5-L6-D2048-E0_1", "V5r2-L6-D2048"),
    ("v5-L6-D4096-E0_1", "V5r2-L6-D4096"),
    ("v5-L96-D1024-E0_1-mem-ctx-8k", "V5r2-L96-D1024"),
]
log_suffixes = ["1k", "4k", "16k"]

# Model-major, suffix-minor order; the per-file indices are kept as-is (no ignore_index)
full_grouped_data = pd.concat([
    group_csv_file(f"./logs/{prefix}-{suffix}.csv", label)
    for prefix, label in model_log_prefixes
    for suffix in log_suffixes
])

# Display the combined summary table
full_grouped_data

In [None]:
# Here we look at the average predicted-token position on the randomized-baseline rows
# (is_random_baseline == True). This gives an approximate "random" score baseline, while
# accounting for the fact that the model may eventually notice patterns that make the
# data not truly random (eg. no special characters, etc), as the sample grows.

# Keep only the randomized-baseline rows
random_baseline = full_grouped_data[full_grouped_data['is_random_baseline'] == True]

# Per-model mean of the 'First 1000 tokens average' column over the baseline rows.
# (The older Raven-specific baseline, keyed on "Raven 1B5", is not computed here:
#  no model with that label is loaded in this notebook.)
baseline_first_1k_by_model = random_baseline.groupby(['model']).mean()['First 1000 tokens average']

# Use the first model (in the groupby's sorted label order) as the reference
key = baseline_first_1k_by_model.index[0]

# Randomized baseline of that reference model, and half of it
special_random_baseline_pos = baseline_first_1k_by_model.loc[key]
special_half_random_base_line_pos = special_random_baseline_pos / 2

print("special_random_baseline_pos", special_random_baseline_pos)
print("special_half_random_base_line_pos", special_half_random_base_line_pos)

# Show the randomized baseline rows
random_baseline

In [None]:
# Graphing function
def plotGrapData(sizeArr=(2,), redline=False):
    """Plot, per model, the average token position against the prompt length.

    Reads the module-level `full_grouped_data` frame (non-baseline rows only) and
    draws one line per (model, size) pair on a single matplotlib axis.

    Args:
        sizeArr: Sequence of "first N tokens" window sizes to plot. Each positive N
            selects the 'First N tokens average' column; the sentinel -1 selects the
            'tokens average position' column (average over all tokens). When the
            largest entry is positive, only rows with `eval_token_count` <= that
            value are plotted.
        redline: When not False, draw a horizontal line at the module-level
            `special_half_random_base_line_pos`.

    Raises:
        ValueError: If `sizeArr` is empty.
        KeyError: If a requested size has no matching column in `full_grouped_data`.
    """
    # Copy into a list: accepts tuples/lists alike and avoids a shared mutable default
    sizeArr = list(sizeArr)
    if not sizeArr:
        raise ValueError("sizeArr must contain at least one size")

    # Lets join the size array, to a single str for logging
    sizeStr = ", ".join([str(x) for x in sizeArr])

    # Plot the axis
    fig, ax = plt.subplots(figsize=(15,7))

    # Get the highest value for sizeArr
    max_size = max(sizeArr)

    # Filter the data accordingly: drop the random baseline rows, and (when a positive
    # size is requested) keep only prompt lengths up to the largest requested size
    filtered_data = full_grouped_data[full_grouped_data['is_random_baseline'] == False]
    if max_size > 0:
        filtered_data = filtered_data[filtered_data['eval_token_count'] <= max_size]

    def column_for(size):
        # -1 is the "all tokens" sentinel, at any position in sizeArr
        if size == -1:
            return 'tokens average position'
        return f'First {size} tokens average'

    # First size is solid, second dashed, third dotted (as before); further sizes
    # keep cycling through the styles instead of being silently skipped
    line_styles = ['solid', 'dashed', 'dotted', 'dashdot']
    single_series = len(sizeArr) == 1

    # One line per model and requested size
    for key, grp in filtered_data.groupby('model'):
        for position, size in enumerate(sizeArr):
            column = column_for(size)
            # With a single size the legend only needs the model name
            label = f'{key}' if single_series else f'{key} - {column}'
            ax = grp.plot(
                ax=ax, kind='line', x='eval_token_count',
                y=column,
                label=label,
                linestyle=line_styles[position % len(line_styles)]
            )

    # Add redline if set
    if redline != False:
        ax.axhline(y=special_half_random_base_line_pos, color='r', linestyle=':', label='50 percent of Specialized Model Randomized baseline')

    # Title and Y label depend on whether the "all tokens" view was requested first
    if( sizeArr[0] == -1):
        ax.set_title(f'Average position of all tokens in long sequence')
        ax.set_ylabel(f'Average position of all the correct tokens, in sorted probability order')
    else:
        ax.set_title(f'Recall of the first {sizeStr} tokens in long sequence')
        ax.set_ylabel(f'Average position of first {sizeStr} correct tokens, in sorted probability order')
    ax.set_xlabel(f'Prompt Length (tokens, used in long sequence)')

    # Include grid lines - with major, and minor grid
    ax.minorticks_on()
    ax.grid(which='major', linestyle='-', linewidth='0.5')
    ax.grid(which='minor', linestyle=':', linewidth='0.5')

## High scores for each model

In [None]:
# Lets extract all the "high score" for each model
df = full_grouped_data

models = df.model.unique()

results = []
for model in models:
    model_df = df[df.model == model]

    # Highest match_percentage and the rows that reach it
    match_percs = model_df.match_percentage.sort_values(ascending=False)
    max_match_perc = match_percs.values[0]
    max_match_perc_row = model_df[model_df.match_percentage == max_match_perc]

    # Highest match_count and the rows that reach it
    match_counts = model_df.match_count.sort_values(ascending=False)
    max_match_count = match_counts.values[0]
    max_match_count_row = model_df[model_df.match_count == max_match_count]

    # Largest eval_token_count whose match_percentage is still >= 90.
    # Reset for every model: otherwise a model that never reaches 90% would raise a
    # NameError (first model) or silently report the previous model's row.
    match_90_row = None
    flipped_model_df = model_df.sort_values(by=['eval_token_count'], ascending=False)
    for idx, row in flipped_model_df.iterrows():
        if row['match_percentage'] >= 90.0:
            match_90_row = row
            break
    has_90_match = match_90_row is not None
    missing = float('nan')  # shown as NaN when the model never reaches 90%

    results.append({
        'model': model,

        'max%': max_match_perc,
        # When several rows tie on the max percentage, report the last one
        'max% : input': int(max_match_perc_row.eval_token_count.values[-1]),

        '90% match: input': match_90_row.eval_token_count if has_90_match else missing,
        '90% match: match': match_90_row.match_count if has_90_match else missing,
        '90% match: %': match_90_row.match_percentage if has_90_match else missing,

        # When several rows tie on the max count, report the first one
        'matched: input': int(max_match_count_row.eval_token_count.values[0]),
        'matched: count': int(max_match_count),
        'matched: %': max_match_count_row.match_percentage.values[0]

    })

results_df = pd.DataFrame(results)
results_df

## (optional) Plotting of high level data

In [None]:
# # Lets chart too much data at too many points, so we can get a better idea of the trend
# # before narrowing it down, commented out, unless you really want it
# plotGrapData([1,2])
# plotGrapData([5,10,25])
# plotGrapData([50,100,250])
# plotGrapData([500,750,1000], redline=True)

## Plotting the data points

In [None]:
# Let's focus on the first 1k: plots the 'First 1000 tokens average' column per model.
# Because the requested size is positive, plotGrapData also limits the plotted rows to
# eval_token_count <= 1000.
plotGrapData([1000]) #, redline=True

In [None]:
# Let's chart the overall average position: -1 selects the 'tokens average position'
# column, and no eval_token_count filter is applied.
plotGrapData([-1]) #, redline=True