# Ash RL Training (Colab/PyTorch Version)

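This notebook implements a reinforcement-learning loop for Bash command generation using PyTorch and Hugging Face Transformers. Each iteration samples commands from the model, scores them on TerminalBench tasks, and fine-tunes a LoRA adapter (supervised fine-tuning) on the answers that solved their task. It runs on Google Colab (Linux/CUDA) or on local NVIDIA GPUs.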

Original Code: [ash-rl-mlx](https://github.com/Krabbens/ash-rl-mlx)

In [None]:
# 1. Setup Environment
# Clone the repository (or update an existing clone) and make it the working
# directory, so that later cells can import its local modules (terminal_bench).
# NOTE(review): if this cell is re-run after the %cd below has already happened,
# "ash-rl-mlx" no longer exists relative to the CWD, so the else-branch clones a
# nested copy inside the repository.
import os
if os.path.exists("ash-rl-mlx"):
    # Existing clone: enter it and fetch the latest commits.
    %cd ash-rl-mlx
    !git pull
else:
    !git clone https://github.com/Krabbens/ash-rl-mlx.git
    %cd ash-rl-mlx

# Dependencies are installed unpinned, so argument names used later (e.g. in
# SFTConfig) may differ between trl/transformers releases.
!pip install -q transformers peft bitsandbytes trl accelerate datasets

In [None]:
# 2. Imports
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import os
import json
import re
import random
import shutil
import glob
from tqdm import tqdm
import subprocess

# The repository's local modules are imported straight from the working
# directory, which therefore has to be the clone's root (the setup cell %cd's
# into it). Probe for a known file first so that a missing clone fails with an
# actionable message rather than a bare ModuleNotFoundError.
if not os.path.exists(os.path.join(".", "terminal_bench.py")):
    raise FileNotFoundError("Please run the git clone cell first!")

from terminal_bench import TerminalBench

In [None]:
# 3. Configuration & Helpers
# Hugging Face Hub id of the base model the LoRA adapter is trained on.
MODEL_NAME: str = "Qwen/Qwen2.5-0.5B-Instruct"
# Output directory for trainer runs and saved adapter checkpoints.
OUTPUT_DIR: str = "./adapters_pt"
# Path for a JSON Lines metrics log.
METRICS_FILE: str = "metrics_pt.jsonl"
# Number of generate -> evaluate -> train rounds in the main loop.
NUM_ITERATIONS: int = 50

# Reward Function (Ported from main.py)
def check_bash_syntax(cmd):
    """Return True if ``cmd`` parses as valid bash.

    Uses ``bash -n``, which parses the script without executing it. Returns
    False for an empty/None command, when bash cannot be started, when the
    command cannot be encoded for the subprocess, or when parsing does not
    finish within one second.
    """
    if not cmd:
        return False
    try:
        res = subprocess.run(["bash", "-n"], input=cmd, text=True, capture_output=True, timeout=1)
    except (OSError, subprocess.TimeoutExpired, ValueError):
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt,
        # making it impossible to interrupt the notebook during evaluation.
        return False
    return res.returncode == 0

def calculate_reward(success, eval_meta, cmd, prompt):
    """Score one generated bash command; higher is better, never below 0.0.

    Args:
        success: Truthy when the benchmark judged the task solved (+2.0).
        eval_meta: Evaluation metadata; only ``eval_meta["exit_code"]`` is read.
            ``None`` is treated like an empty dict.
        cmd: The cleaned command that was evaluated. Empty/None scores 0.0.
        prompt: The task prompt. Each distinct ``app/<name>`` path it mentions
            that also appears as a substring of ``cmd`` earns +0.1.

    Returns:
        The reward as a float, clamped at 0.0.
    """
    if not cmd:
        return 0.0
    exit_code = (eval_meta or {}).get("exit_code")
    reward = 0.0

    # 1. Base success
    if success:
        reward += 2.0

    # 2. Syntax (parse-only check via `bash -n`)
    if check_bash_syntax(cmd):
        reward += 0.3

    # 3. Execution status. 124 is the status GNU `timeout` reports for a
    # timed-out command -- presumably how the bench signals one; confirm.
    if exit_code == 0:
        reward += 0.2
    elif exit_code == 124:
        reward -= 0.5

    # 4. Keyword matching. Deduplicated so a path mentioned several times in the
    # prompt is rewarded once, not once per mention.
    potential_files = set(re.findall(r'/?app/([\w\.\-]+)', prompt))
    for name in potential_files:
        if name in cmd:
            reward += 0.1

    # 5. Hallucinations/Bad patterns: `<<` combined with jq, dangling trailing pipe.
    if "<<" in cmd and "jq" in cmd:
        reward -= 0.5
    if cmd.strip().endswith("|"):
        reward -= 0.5

    # 6. Python/Non-bash penalty
    bad_python = ["import os", "import sys", "import json", "def main(", "if __name__ ==", "class "]
    if any(k in cmd for k in bad_python):
        reward -= 2.0

    # 7. Brevity penalty: 0.0001 per character beyond the first 100.
    reward -= max(0, (len(cmd) - 100) * 0.0001)

    return max(0.0, reward)

# Fenced code block that is untagged or tagged with a shell language; the tag
# may be followed by spaces/tabs and a CRLF or LF line ending.
_FENCED_SHELL_BLOCK = re.compile(r'```(?:bash|shell|sh|zsh)?[ \t]*\r?\n(.*?)```', re.DOTALL)
# Lines the last-line fallback skips: prose openers, comments and code fences.
_FALLBACK_SKIP_PREFIXES = ('#', 'Here', 'To', 'Step', 'Note', '```')


def extract_command(response):
    """Pull the bash command(s) out of a model response.

    Tried in order:
      1. All shell code blocks, stripped and joined with newlines.
      2. The non-empty lines after the last ``</thinking>`` tag, with any
         stray code-fence lines removed.
      3. The last line that is not prose, a comment, or a code fence; if none
         qualifies, the whole stripped response.
    """
    # 1. Code blocks
    code_blocks = _FENCED_SHELL_BLOCK.findall(response)
    if code_blocks:
        return "\n".join(block.strip() for block in code_blocks)

    # 2. Everything after the thinking section. Fence lines are dropped so a
    # block the regex did not recognise does not leak "```" into the command.
    if "</thinking>" in response:
        after = response.split("</thinking>")[-1]
        stripped = (line.strip() for line in after.split('\n'))
        return "\n".join(line for line in stripped if line and not line.startswith('```'))

    # 3. Fallback: last line that looks like a command. Skipping fence lines
    # avoids returning the closing "```" of an unrecognised block.
    for line in reversed(response.strip().split('\n')):
        candidate = line.strip()
        if candidate and not candidate.startswith(_FALLBACK_SKIP_PREFIXES):
            return candidate
    return response.strip()

# Bash metacharacters (other than newline) that end a word; a '#' right after
# one of them, or at the start of a line, begins a comment.
_SHELL_WORD_BREAKS = frozenset(" \t;&|()<>")


def _strip_shell_comment(line):
    """Return ``line`` with its trailing shell comment removed, if any.

    Follows bash's rule that ``#`` starts a comment only at the beginning of a
    word and outside quotes, so ``$#``, ``${#var}``, ``a#b``, ``"x # y"`` and
    ``'x#y'`` are kept intact. Quote state is not carried across lines.
    """
    quote = None        # currently open quote character, or None
    escaped = False     # previous character was an escaping backslash
    word_start = True   # the next character would begin a new word
    for i, ch in enumerate(line):
        if escaped:
            escaped = False
            word_start = False
        elif quote == "'":
            # Inside single quotes nothing is special except the closing quote.
            if ch == "'":
                quote = None
        elif ch == "\\":
            escaped = True
            word_start = False
        elif quote == '"':
            if ch == '"':
                quote = None
        elif ch in ("'", '"'):
            quote = ch
            word_start = False
        elif ch == "#" and word_start:
            return line[:i]
        else:
            word_start = ch in _SHELL_WORD_BREAKS
    return line


def clean_command(cmd):
    """Remove shell comments and blank lines from ``cmd``.

    Every remaining line is whitespace-stripped. Falsy input (``""``/``None``)
    is returned unchanged.
    """
    if not cmd:
        return cmd
    stripped = (_strip_shell_comment(line).strip() for line in cmd.split('\n'))
    return '\n'.join(line for line in stripped if line)


In [None]:
# 4. Initialize Model & Tokenizer
print(f"Loading {MODEL_NAME}...")

# bfloat16 halves memory compared with float32. It is native on Ampere-class
# and newer GPUs (e.g. A100); older GPUs such as the Colab T4 have no native
# bf16 support, so expect it to run noticeably slower there.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Fall back to EOS for padding only when the tokenizer defines no pad token.
# If pad == EOS, collators that mask pad ids out of the loss (e.g. transformers'
# DataCollatorForLanguageModeling) also mask every real end-of-turn token, and
# the fine-tuned model stops learning when to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding keeps each prompt flush against its continuation in batched
# generation.
tokenizer.padding_side = 'left'

# Apply LoRA Adapter to all attention and MLP projections.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
# 5. Main Training Loop
# Each iteration is one round of rejection-sampling fine-tuning: sample a batch
# of tasks, generate one answer per task, score the extracted commands with the
# benchmark and calculate_reward, then run one SFT epoch on the answers that
# solved their task.

# Setup Bench
FULL_DATASET_PATH = "terminal_bench_full.json"
if not os.path.exists(FULL_DATASET_PATH):
    print("Warning: Full dataset not found, using internal local tasks only.")
    bench = TerminalBench()
else:
    bench = TerminalBench(task_file=FULL_DATASET_PATH)

all_tasks = bench.get_tasks()
print(f"Loaded {len(all_tasks)} tasks.")
if not all_tasks:
    raise RuntimeError("TerminalBench returned no tasks; nothing to train on.")

data_dir = "data_bash_pt"
os.makedirs(data_dir, exist_ok=True)
train_file = os.path.join(data_dir, "train.jsonl")

TASKS_PER_ITERATION = 16        # tasks sampled, generated and evaluated per iteration
CHECKPOINT_EVERY = 10           # save the adapter every N iterations
SOLVED_REWARD_THRESHOLD = 1.0   # minimum reward for a successful answer to be kept
EASY_CATEGORIES = {"basic", "files", "text"}
MEDIUM_CATEGORIES = EASY_CATEGORIES | {"medium", "search", "script"}

sys_prompt = "You are a Linux terminal. CWD: /app. Plan in <thinking>. Output ONLY raw bash commands in a ```bash block. No comments."


def _task_category(task):
    """Lower-cased task category; missing or null categories become ""."""
    return (task.get("category") or "").lower()


for iteration in range(NUM_ITERATIONS):
    print(f"\n=== Iteration {iteration+1}/{NUM_ITERATIONS} ===")

    # Curriculum - strictly enforce easy tasks first. 'easy' itself is not an
    # allowed category on purpose: external tasks use it for hard things.
    if iteration < 5:
        allowed = EASY_CATEGORIES
    elif iteration < 15:
        allowed = MEDIUM_CATEGORIES
    else:
        allowed = None  # no filtering
    if allowed is None:
        filtered_tasks = all_tasks
    else:
        filtered_tasks = [t for t in all_tasks if _task_category(t) in allowed]
    # Fall back to every task if the current tier matches none.
    if not filtered_tasks:
        filtered_tasks = all_tasks

    # Sample without replacement. Unlike random.shuffle, this does not reorder
    # all_tasks in place when filtered_tasks *is* all_tasks.
    current_tasks = random.sample(filtered_tasks, k=min(TASKS_PER_ITERATION, len(filtered_tasks)))

    # --- GENERATION PHASE ---
    model.eval()
    prompts_text = [
        tokenizer.apply_chat_template(
            [{"role": "system", "content": sys_prompt}, {"role": "user", "content": task["prompt"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for task in current_tasks
    ]

    # Batch generation. Assumes the tokenizer pads on the left, so slicing at
    # the padded prompt length below leaves only the newly generated tokens.
    inputs = tokenizer(prompts_text, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    decoded_responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

    # --- EVALUATION PHASE ---
    candidates = []  # training records ({"text": ...}) built from solved tasks
    rewards = []
    solved_count = 0
    for task, response in zip(current_tasks, decoded_responses):
        prompt = task["prompt"]
        cmd = clean_command(extract_command(response))

        # NOTE(review): the model-generated command is handed to evaluate_task;
        # confirm TerminalBench executes it in a sandbox, not on the host.
        success, meta = bench.evaluate_task(task, {'command': cmd})
        reward = calculate_reward(success, meta, cmd, prompt)
        rewards.append(reward)

        # Require real task success: calculate_reward's partial-credit terms
        # (syntax, exit code, per-file keyword matches) can reach the threshold
        # on their own for a failed attempt, which would then be trained on.
        is_solved = bool(success) and reward >= SOLVED_REWARD_THRESHOLD
        status = "✅" if is_solved else "❌"
        print(f"  {status} [{task['id']}] {prompt[:40]}... -> {cmd[:50]}")

        if is_solved:
            solved_count += 1
            # Train on the FULL conversation, including the thinking and the code.
            full_messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
            ]
            candidates.append({"text": tokenizer.apply_chat_template(full_messages, tokenize=False)})

    print(f"Solved {solved_count}/{len(current_tasks)} tasks.")

    # Rewrite the training file from scratch each iteration (JSON Lines).
    with open(train_file, "w", encoding="utf-8") as f:
        for record in candidates:
            json.dump(record, f)
            f.write("\n")

    # Append this iteration's summary to the metrics log (JSON Lines).
    with open(METRICS_FILE, "a", encoding="utf-8") as f:
        json.dump(
            {
                "iteration": iteration + 1,
                "solved": solved_count,
                "total": len(current_tasks),
                "mean_reward": sum(rewards) / len(rewards),
            },
            f,
        )
        f.write("\n")

    # --- TRAINING PHASE ---
    if candidates:
        print(f"Training on {len(candidates)} samples...")
        model.train()

        dataset = load_dataset("json", data_files=train_file, split="train")

        training_args = SFTConfig(
            output_dir=OUTPUT_DIR,
            num_train_epochs=1,
            per_device_train_batch_size=TASKS_PER_ITERATION,
            # An iteration yields at most TASKS_PER_ITERATION samples -- a single
            # micro-batch -- so there is nothing to accumulate across.
            gradient_accumulation_steps=1,
            learning_rate=1e-5,
            logging_steps=1,
            save_strategy="no",
            report_to="none",
            dataset_text_field="text",
            # Must cover the chat-formatted prompt plus up to 1024 generated
            # tokens; right-truncation would cut off the end of the answer,
            # i.e. the command itself.
            # NOTE(review): newer trl releases rename this to `max_length` and the
            # install cell does not pin trl -- confirm against the installed version.
            max_seq_length=2048,
            packing=False,  # one sample per sequence, simple for now
        )

        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=training_args,
            processing_class=tokenizer,
        )
        trainer.train()
    else:
        print("No successful samples to train on.")

    # Save the adapter on a fixed schedule, whether or not this particular
    # iteration produced training data.
    if (iteration + 1) % CHECKPOINT_EVERY == 0:
        ckpt_dir = os.path.join(OUTPUT_DIR, f"ckpt-{iteration+1}")
        model.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
