# RWKV World Memory Finetune (Memory Finetune)

This takes an existing RWKV world model, and finetune them specifically for the memory repeat task of various sizes.
This test is used as an approximation of testing the model token memory size in the "worse case scenerio"

- Using randomized data, so prior learning does not help, nor is it possible to compress the data
- Using a variety of token lengths, to avoid overfitting to a single length
- Based on the pretrained model (rwkv world)
- This process does "destroy the model" but it helps quantify the model limits

In practise however, the model may show "attention range" longer then what is benchmarked, as natural text is highly compressible. Unlike the pure randomized data that was being tested here.

This runner has been optimized to run on 8 x 24GB vram nodes, you should allocate atleast 500GB disk space.

> This project assumes you have the rwkv-infctx conda env setup, and you are executing in that environment - see the main README.md for the conda env setup steps

## Configure your environment settings
(!Important: you will need to rerun the below cell, if you restart your kernel)

In [2]:
DEEPSPEED_STRAT="deepspeed_stage_1"
GPU_DEVICES="auto"
ENABLE_WANDB=True
WANDB_PREFIX="[8x4090] RWKV-v5-1B5-World"

print("DEEPSPEED_STRAT:", DEEPSPEED_STRAT)
print("ENABLE_WANDB:", ENABLE_WANDB)
print("GPU_DEVICES:", GPU_DEVICES)

if ENABLE_WANDB:
    WANDB_MODE="online"
else:
    WANDB_MODE="disabled"

# The model sizing
MODEL_NAME="RWKV-v5-1B5-world.pth"
MODEL_URL="https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth?download=true"

# Computing the notebook, and various paths
import os
NOTEBOOK_DIR=os.path.dirname(os.path.abspath("__file__"))
PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../../../"))
TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v5/"))
MEMORY_SCRIPT_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./notebook/util-scripts/memory_script"))

print("NOTEBOOK_DIR:", NOTEBOOK_DIR)
print("TRAINER_DIR:", TRAINER_DIR)
print("PROJECT_DIR:", PROJECT_DIR)

DEEPSPEED_STRAT: deepspeed_stage_1
ENABLE_WANDB: True
GPU_DEVICES: auto
NOTEBOOK_DIR: /home/recursal/RWKV-infctx-trainer/notebook/rwkv-x-exp/v5-exp/memory-test
TRAINER_DIR: /home/recursal/RWKV-infctx-trainer/RWKV-v5
PROJECT_DIR: /home/recursal/RWKV-infctx-trainer


## Download the pretrained model
(if you want to skip the the basemodel train + instruct tune)


In [5]:
# Lets wget the model files
!cd "{PROJECT_DIR}" && mkdir -p "{PROJECT_DIR}/model"
!cd "{PROJECT_DIR}/model" && \
    wget -O "{MODEL_NAME}" -nc "{MODEL_URL}"

File ‘RWKV-v5-1B5-world.pth’ already there; not retrieving.


## Finetune 1 (0 -> 2k) : Dataset preperation

Stage 1, handles total context size of 2048. Meaning it will be tuned for memory task of 1 to approximately 1024 tokens of size.

In [6]:
# Folder and eval pip setup
!cp -r "{MEMORY_SCRIPT_DIR}/" "{NOTEBOOK_DIR}/"
!python3 -m pip install rwkv asyncio aiocsv aiofiles



In [7]:
%%script bash

########################################
# Generate the required jsonl dataset
########################################

# Reset the dataset dir
mkdir -p ./dataset
rm -rf ./dataset/*.jsonl

# Generate the various datasets
echo "## Generating word reptition dataset ##"

#
# Training set for < 50 words
# This is used to fill up as much blanks as possible
#
python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/word-2-count.jsonl 2 300 &
python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/word-2-count.jsonl 4 1000 &
for i in {5..100..5} 
do
    python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/gen-word-$i-count.jsonl $i 500 & 
    python ./memory_script/shuffle_limited_prompt_completion_jsonl.py ./dataset/shuffle-word-$i-count.jsonl $i 100 & 
done

#
# Ramping up the 50+ - 400 words dataset
# 
for i in {105..200..5} 
do
    python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/gen-word-$i-count.jsonl $i 125 & 
    python ./memory_script/shuffle_limited_prompt_completion_jsonl.py ./dataset/shuffle-word-$i-count.jsonl $i 100 & 
done

#
# Ramping up the 50+ - 400 words dataset
# 
for i in {205..1500..5} 
do
    python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/gen-word-$i-count.jsonl $i 100 & 
    python ./memory_script/shuffle_limited_prompt_completion_jsonl.py ./dataset/shuffle-word-$i-count.jsonl $i 100 & 
done

wait
echo "## Done ##"

ls -alh ./dataset/

## Generating word reptition dataset ##
Generated JSONL file with - 5 max words, 500 samples - at ./dataset/gen-word-5-count.jsonl
Generated JSONL file with - 15 max words, 500 samples - at ./dataset/gen-word-15-count.jsonl
Generated JSONL file with - 2 max words, 300 samples - at ./dataset/word-2-count.jsonl
Generated JSONL file with - 10 max words, 500 samples - at ./dataset/gen-word-10-count.jsonl
Generated JSONL file with - 25 max words, 500 samples - at ./dataset/gen-word-25-count.jsonl
Generated JSONL file with - 35 max words, 500 samples - at ./dataset/gen-word-35-count.jsonl
Generated JSONL file with - 50 max words, 500 samples - at ./dataset/gen-word-50-count.jsonl
Generated JSONL file with - 125 max words, 125 samples - at ./dataset/gen-word-125-count.jsonl
Generated JSONL file with - 140 max words, 125 samples - at ./dataset/gen-word-140-count.jsonl
Generated JSONL file with - 65 max words, 500 samples - at ./dataset/gen-word-65-count.jsonl
Generated JSONL file with - 20 max

In [9]:
# Lets pre tokenize the requried dataset
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/stage-1-tune.yaml"

# Ensure the checkpoint directory exists
!cd "{TRAINER_DIR}" && mkdir -p "../checkpoint/stage-1-memory-finetune/"

Resolving data files: 100%|███████████████| 601/601 [00:00<00:00, 370647.95it/s]
Filter (num_proc=160): 100%|██| 372801/372801 [00:03<00:00, 93704.48 examples/s]
Map (num_proc=160): 100%|████| 363015/363015 [00:02<00:00, 127526.30 examples/s]
Map (num_proc=160): 100%|█████| 363015/363015 [00:07<00:00, 46066.68 examples/s]
Map (num_proc=160): 100%|███████| 87900/87900 [00:03<00:00, 27106.01 examples/s]
Saving the dataset (2/2 shards): 100%|█| 87900/87900 [00:01<00:00, 82134.79 exam
Saving the dataset (1/1 shards): 100%|█| 364/364 [00:00<00:00, 13312.35 examples


## Finetune 1 (0 -> 2k) : The actual tune!

In [10]:
# Start the finetune model training
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python3 lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/stage-1-tune.yaml" \
        --model.load_model="../model/{MODEL_NAME}" \
        --trainer.callbacks.init_args.dirpath="../checkpoint/stage-1-memory-finetune/{MODEL_NAME}/" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Mem-Finetune-1 (bs=256, train-ctx=2048, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}"  \
        --trainer.microbatch_size=8 \
        --model.ctx_len=2048

[2024-01-22 20:30:14,781] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
/home/recursal/miniconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/recursal/RWKV-infctx-trainer/notebook/rwkv-x-exp/v5-exp/memory-test/stage-1-tune.yaml', '--model.load_model=../model/RWKV-v5-1B5-world.pth', '--trainer.callbacks.init_args.dirpath=../checkpoint/stage-1-memory-finetune/RWKV-v5-1B5-world.pth/', '--trainer.logger.init_args.name=[8x4090] RWKV-v5-1B5-World - Mem-Finetune-1 (bs=256, train-ctx=2048, deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--trainer.microbatch_size=8', '--model

In [11]:
# Lets export the model from the checkpoint
!cd "{TRAINER_DIR}" && \
    python export_checkpoint.py \
        "../checkpoint/stage-1-memory-finetune/{MODEL_NAME}/last.ckpt" \
        "../model/Memory-Tune-Stage-1-{MODEL_NAME}"
!cd "{TRAINER_DIR}" && ls -alh "../model/Memory-Tune-Stage-1-{MODEL_NAME}"

[2024-01-22 21:33:02,316] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Processing zero checkpoint '../checkpoint/stage-1-memory-finetune/RWKV-v5-1B5-world.pth/last.ckpt/checkpoint'
Detected checkpoint of type zero stage ZeroStageEnum.optimizer_states, world_size: 8
Parsing checkpoint created by deepspeed==0.12.6
Reconstructed fp32 state dict with 534 params 1577754624 elements
Saving bf16 state dict to ../model/Memory-Tune-Stage-1-RWKV-v5-1B5-world.pth
-rw-rw-r-- 1 recursal recursal 3.0G Jan 22 21:33 ../model/Memory-Tune-Stage-1-RWKV-v5-1B5-world.pth


In [12]:
# # Lets do a quick dragon prompt validation
!cd "{TRAINER_DIR}" && \
    python3 dragon_test.py "../model/Memory-Tune-Stage-1-{MODEL_NAME}" "cuda fp32"

[2024-01-22 21:33:21,418] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
  return self.fget.__get__(instance, owner)()
---
[RWKV.TimeMix] Compiling CUDA kernel with HEAD_SIZE=64
Using /home/recursal/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/recursal/.cache/torch_extensions/py311_cu121/wkv5/build.ninja...
Building extension module wkv5...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module wkv5...
[RWKV.TimeMix] CUDA kernel compiled & loaded globally
---
  batch_tokens = torch.tensor(
--- DRAGON PROMPT ---
In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the

In [15]:
# Lets do a memory eval!
!python3 ./memory_script/eval_v5_memory_guided.py "{PROJECT_DIR}/model/Memory-Tune-Stage-1-{MODEL_NAME}"
!python3 ./memory_script/eval_v5_memory_guided.py "{PROJECT_DIR}/model/Memory-Tune-Stage-1-{MODEL_NAME}" "none" 1000 3000

SCRIPT_DIR:  /home/recursal/RWKV-infctx-trainer/notebook/rwkv-x-exp/v5-exp/memory-test/memory_script
PROJECT_DIR:  /home/recursal/RWKV-infctx-trainer
MODEL_CODE_DIR:  /home/recursal/RWKV-infctx-trainer/RWKV-v5
[2024-01-22 22:55:37,758] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
  return self.fget.__get__(instance, owner)()
---
[RWKV.TimeMix] Compiling CUDA kernel with HEAD_SIZE=64
Using /home/recursal/.cache/torch_extensions/py311_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/recursal/.cache/torch_extensions/py311_cu121/wkv5/build.ninja...
Building extension module wkv5...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module wkv5...
[RWKV.TimeMix] CUDA kernel compiled & loaded globally
---
  batch_toke

## Finetune 2 (2k -> 4k) - More data

In [4]:
%%script bash

########################################
# Generate the required jsonl dataset
########################################

# Reset the dataset dir
mkdir -p ./dataset
rm -rf ./dataset/*.jsonl

# Generate the various datasets
echo "## Generating word reptition dataset ##"

#
# Training set for < 100 words
# We bump this aggressively, as its used to fill in packing
#
for i in {5..100..5} 
do
    python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/gen-word-$i-count.jsonl $i 500 & 
    python ./memory_script/shuffle_limited_prompt_completion_jsonl.py ./dataset/shuffle-word-$i-count.jsonl $i 500 & 
done

#
# Ramping up the 50+ - 1500 words dataset
# This is to ensure there is ramp from the previous models
# 
for i in {105..1500..5} 
do
    python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/gen-word-$i-count.jsonl $i 50 & 
    python ./memory_script/shuffle_limited_prompt_completion_jsonl.py ./dataset/shuffle-word-$i-count.jsonl $i 50 & 
done

#
# Ramping up the 1500+ - 6000 words dataset
# 
for i in {1505..6000..5} 
do
    python ./memory_script/gen_limited_prompt_completion_jsonl.py ./dataset/gen-word-$i-count.jsonl $i 100 & 
    python ./memory_script/shuffle_limited_prompt_completion_jsonl.py ./dataset/shuffle-word-$i-count.jsonl $i 100 & 
done

wait
echo "## Done ##"

ls -alh ./dataset/

## Generating word reptition dataset ##
Generated JSONL file with - 5 max words, 500 samples - at ./dataset/gen-word-5-count.jsonl
Generated JSONL file with - 20 max words, 500 samples - at ./dataset/gen-word-20-count.jsonl
Generated JSONL file with - 35 max words, 500 samples - at ./dataset/gen-word-35-count.jsonl
Generated JSONL file with - 190 max words, 50 samples - at ./dataset/gen-word-190-count.jsonl
Generated JSONL file with - 60 max words, 500 samples - at ./dataset/gen-word-60-count.jsonl
Generated JSONL file with - 205 max words, 50 samples - at ./dataset/gen-word-205-count.jsonl
Generated JSONL file with - 10 max words, 500 samples - at ./dataset/gen-word-10-count.jsonl
Generated JSONL file with - 200 max words, 50 samples - at ./dataset/gen-word-200-count.jsonl
Generated JSONL file with - 230 max words, 50 samples - at ./dataset/gen-word-230-count.jsonl
Generated JSONL file with - 15 max words, 500 samples - at ./dataset/gen-word-15-count.jsonl
Generated JSONL file with - 

In [5]:
# Lets pre tokenize the requried dataset
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/stage-2-tune.yaml"

# Ensure the checkpoint directory exists
!cd "{TRAINER_DIR}" && mkdir -p "../checkpoint/stage-2-memory-finetune/"

Resolving data files: 100%|█████████████| 2400/2400 [00:00<00:00, 365795.62it/s]
Generating train split: 1258813 examples [00:16, 78115.78 examples/s] 
Map (num_proc=160): 100%|███| 1258813/1258813 [01:30<00:00, 13918.46 examples/s]
Filter (num_proc=160): 100%|█| 1258813/1258813 [00:49<00:00, 25674.24 examples/s
Map (num_proc=160): 100%|██| 1238639/1238639 [00:06<00:00, 189798.96 examples/s]
Map (num_proc=160): 100%|███| 1238639/1238639 [01:09<00:00, 17869.03 examples/s]
Map (num_proc=160): 100%|██████| 127252/127252 [00:15<00:00, 8260.92 examples/s]
Saving the dataset (11/11 shards): 100%|█| 127252/127252 [00:05<00:00, 21357.11 
Saving the dataset (1/1 shards): 100%|█| 6225/6225 [00:00<00:00, 62482.00 exampl


In [8]:
## Finetune 2 (2k -> 8k) : The actual tune!
# Start the finetune model training
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python3 lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/stage-2-tune.yaml" \
        --model.load_model="../model/Memory-Tune-Stage-1-{MODEL_NAME}" \
        --trainer.callbacks.init_args.dirpath="../checkpoint/stage-2-memory-finetune/{MODEL_NAME}/" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Mem-Finetune-2 (bs=256, train-ctx=2048, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}"  \
        --trainer.microbatch_size=8 \
        --model.ctx_len=2048

[2024-01-23 00:03:23,061] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
/home/recursal/miniconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/recursal/RWKV-infctx-trainer/notebook/rwkv-x-exp/v5-exp/memory-test/stage-2-tune.yaml', '--model.load_model=../model/Memory-Tune-Stage-1-RWKV-v5-1B5-world.pth', '--trainer.callbacks.init_args.dirpath=../checkpoint/stage-2-memory-finetune/RWKV-v5-1B5-world.pth/', '--trainer.logger.init_args.name=[8x4090] RWKV-v5-1B5-World - Mem-Finetune-2 (bs=256, train-ctx=2048, deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.devices=auto', '--trainer.microbat