# RWKV v5 WaveNet C - memory finetune

We continue with additional stages beyond Tune 4 to push the model further toward its limits.

> This project assumes you have set up the rwkv-infctx conda env and are executing in that environment. See the main README.md for the conda env setup steps.

## Optional: Download the pretrained model
(if you want to skip the base model training + instruct tune)


In [None]:
# Init required dirs
!mkdir -p ../../../../model/
!mkdir -p ../../../../datapath/
!mkdir -p ../../../../checkpoint/

# # Download the pretrained file
# !cd ../../../../model/ && wget -nc https://huggingface.co/picocreator/memory-size-experiment-for-rwkv/resolve/main/v5-Wave/WaveV5-C-Tune4.pth
# !ls -alh ../../../../model/

## Configure your environment settings
(Important: you will need to rerun the cell below if you restart your kernel)

In [None]:
DEEPSPEED_STRAT="deepspeed_stage_1"
GPU_DEVICES="auto"
ENABLE_WANDB=True
WANDB_PREFIX="WaveV5-C"

print("DEEPSPEED_STRAT:", DEEPSPEED_STRAT)
print("ENABLE_WANDB:", ENABLE_WANDB)
print("GPU_DEVICES:", GPU_DEVICES)

if ENABLE_WANDB:
    WANDB_MODE="online"
else:
    WANDB_MODE="disabled"

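# Compute the notebook directory and the various project paths
# Note: "__file__" is a literal string here, so abspath() resolves it against
# the current working directory; NOTEBOOK_DIR is the kernel's working directory
import os
NOTEBOOK_DIR=os.path.dirname(os.path.abspath("__file__"))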
PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../../../"))
TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v5wavenet/"))
INFERENCE_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v5wavenet/"))

print("NOTEBOOK_DIR:", NOTEBOOK_DIR)
print("INFERENCE_DIR:", INFERENCE_DIR)
print("TRAINER_DIR:", TRAINER_DIR)
print("PROJECT_DIR:", PROJECT_DIR)
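Since every later cell relies on these derived paths, a quick existence check can catch a wrong working directory early. The following is a minimal sketch, not part of the original notebook; the helper name `missing_dirs` is hypothetical.

```python
import os

def missing_dirs(paths):
    """Return the entries of `paths` (name -> directory) that do not exist."""
    return {name: p for name, p in paths.items() if not os.path.isdir(p)}

# Example with the current working directory, which always exists
print(missing_dirs({"CWD": os.getcwd()}))
```

In the notebook this could be called as `missing_dirs({"TRAINER_DIR": TRAINER_DIR, "PROJECT_DIR": PROJECT_DIR})`; an empty dict means both directories were found.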

## Tune 5 : Ramping up the ctx size (8192), memory training

- Tune 5: We are now training into the 8k range (4k inputs)

Let's go!

In [None]:
%%script bash

########################################
# Generate the required jsonl dataset
########################################

# Reset the dataset dir
mkdir -p ../dataset
rm -rf ../dataset/*.jsonl

# Generate the various datasets
echo "## Generating word repetition dataset ##"

#
# Some data sample for low word count
#
python ../memory_script/gen_limited_prompt_completion_jsonl.py ../dataset/word-2-count.jsonl 2 500 &
python ../memory_script/gen_limited_prompt_completion_jsonl.py ../dataset/word-5-count.jsonl 5 500 &

#
# Distributed dataset from 10 - 2000
# 
for i in {10..2000..10} 
do
    python ../memory_script/gen_limited_prompt_completion_jsonl.py ../dataset/gen-word-$i-count.jsonl $i 500 & 
    python ../memory_script/shuffle_limited_prompt_completion_jsonl.py ../dataset/shuffle-word-$i-count.jsonl $i 5 & 
done

#
# Distributed dataset from 2010 - 4000
# 
for i in {2010..4000..10} 
do
    python ../memory_script/gen_limited_prompt_completion_jsonl.py ../dataset/gen-word-$i-count.jsonl $i 1000 & 
    python ../memory_script/shuffle_limited_prompt_completion_jsonl.py ../dataset/shuffle-word-$i-count.jsonl $i 10 & 
done

wait
echo "## Done ##"

ls -alh ../dataset/
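
To check the `ls` output of the cell above, it can help to know which files to expect. The sketch below only enumerates the file names implied by the two loops and the two low-word-count calls; the patterns and ranges are copied from that cell, while the function name `planned_dataset_files` is hypothetical and nothing is generated or executed.

```python
def planned_dataset_files():
    """List the jsonl file names the dataset generation cell should produce."""
    files = ["word-2-count.jsonl", "word-5-count.jsonl"]
    # bash {10..2000..10} and {2010..4000..10} are inclusive at both ends
    for start, stop in ((10, 2000), (2010, 4000)):
        for i in range(start, stop + 1, 10):
            files.append(f"gen-word-{i}-count.jsonl")
            files.append(f"shuffle-word-{i}-count.jsonl")
    return files

files = planned_dataset_files()
print(len(files), "files expected")
```

This gives 802 files (2 + 2 x 200 + 2 x 200). Every generator call in the cell ends in `&`, so the same number of background processes is started before `wait`. Comparing this list against `os.listdir("../dataset")` is one way to spot generators that failed.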

In [None]:
# Let's pre-tokenize the required dataset
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/WaveV5-C-mem-finetune-5.yaml"

# Ensure the checkpoint directory exists
!cd "{TRAINER_DIR}" && mkdir -p "../checkpoint/WaveV5-C-mem-finetune-5/"

In [None]:
# Start the finetune model training
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/WaveV5-C-mem-finetune-5.yaml" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Mem-Finetune-5 (bs=256, train-ctx=8192, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}"  \
        --model.ctx_len=4096 \
        --model.bptt_learning_range=2
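
The run name above says `train-ctx=8192` while the command passes `--model.ctx_len=4096` and `--model.bptt_learning_range=2`. One reading is that 8192 = 4096 x 2. The sketch below works through that arithmetic under a loudly stated assumption: that the trainer cuts each sample into `ctx_len`-sized segments and backpropagates only through the last `bptt_learning_range` of them. That behavior is not confirmed by this notebook, and the helper `bptt_segments` is hypothetical.

```python
import math

def bptt_segments(sample_tokens, ctx_len=4096, bptt_learning_range=2):
    """Return (total_segments, learned_segments, learned_tokens).

    ASSUMPTION: the sample is split into ctx_len-sized segments and gradients
    flow through at most the last `bptt_learning_range` segments; a value
    <= 0 is treated here as "all segments".
    """
    total = max(1, math.ceil(sample_tokens / ctx_len))
    learned = total if bptt_learning_range <= 0 else min(total, bptt_learning_range)
    # Tokens that fall inside the learned tail segments
    first_learned = total - learned
    learned_tokens = sample_tokens - first_learned * ctx_len
    return total, learned, learned_tokens

for n in (1000, 8192, 12000):
    print(n, bptt_segments(n))
```

Under this assumption an 8192-token sample is covered entirely by two segments, whereas a 12000-token sample would have its first segment excluded from the gradient.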

In [None]:
# Let's export the model from the checkpoint
!cd "{TRAINER_DIR}" && \
    python export_checkpoint.py \
        "../checkpoint/WaveV5-C-mem-finetune-5/last.ckpt" \
        "../model/WaveV5-C-Tune5.pth"
!cd "{TRAINER_DIR}" && ls -alh ../model/WaveV5-C-Tune5.pth

In [None]:
# # Let's do a memory eval
# !python3 ../memory_script/eval_memory_guided.py "{PROJECT_DIR}/model/WaveV5-C-Tune5.pth"