In [2]:
from llama_cpp import Llama
import os
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed  # Use ThreadPool instead of ProcessPool


In [3]:

def load_model_with_mps():
    # Enable MPS and configure for Apple Silicon
    llm = Llama.from_pretrained(
        repo_id="mradermacher/TxAgent-T1-Llama-3.1-8B-GGUF",
        filename="TxAgent-T1-Llama-3.1-8B.Q8_0.gguf",
        # n_ctx=2048,
        n_gpu_layers=16,  # Offload all possible layers to MPS (auto-detect)
        n_threads=2,      # Use CPU threads for remaining work
        n_batch=32,      # Larger batch size for MPS efficiency
        verbose=False,     # Show MPS initialization logs
        device="cpu"      # Explicitly use Metal Performance Shaders
    )
    
    # Verify MPS usage (check logs for confirmation)
    # Look for lines like: "llama_metal_init: loaded kernel" or "offloaded X layers to GPU"
    return llm

In [None]:
llm = load_model_with_mps()

In [5]:
def process_chunk(chunk, llm):
    """Process a single chunk of samples"""
    chunk_results = []
    for sample in chunk:
        try:
            prompt = f"""Please answer the question: {sample['question']}. These are the options {sample['options']}."""
            
            messages = [
                {"role": "system", "content": "You are a reasoning clinical assistant."},
                {"role": "user", "content": prompt}
            ]
            
            # Small delay to prevent MPS overload
            time.sleep(0.3)
            
            output = llm.create_chat_completion(
                messages=messages,
                # max_tokens=64,
                # temperature=0.0,
                stream=False
            )
            
            chunk_results.append({
                'id': sample['id'],
                'llm_answer': output['choices'][0]['message']['content']
            })
            print(f"Processed sample {sample['id']}")
            
        except Exception as e:
            print(f"Error in sample {sample['id']}: {str(e)}")
            chunk_results.append({'id': sample['id'], 'error': str(e)})
    
    return chunk_results

In [6]:
def split_into_chunks(data, num_chunks):
    """Split data into N equal chunks"""
    chunk_size = len(data) // num_chunks
    chunks = [data[i:i+chunk_size] for i in range(0, len(data), chunk_size)]
    # Handle any remaining samples in the last chunk
    if len(chunks) > num_chunks:
        chunks[num_chunks-1].extend(chunks[num_chunks:])
        chunks = chunks[:num_chunks]
    return chunks

In [7]:
def process_sample(sample, llm):
    """Process a single sample with the loaded model"""
    prompt = f"""Please answer the question: {sample['question']}. These are the options {sample['options']}."""
    
    messages = [
        {"role": "system", "content": "You are a reasoning clinical assistant."},
        {"role": "user", "content": prompt}
    ]
    
    # Generate answer with reduced max_tokens if possible (faster)
    output = llm.create_chat_completion(
        messages=messages,
        # max_tokens=512,  # 512 may be excessive for Q&A; test what works
        # temperature=0.0  # Deterministic output (faster than higher temps)
    )
    
    return {
        'id': sample['id'],
        'question_type': sample['question_type'],
        'question': sample['question'],
        'options': sample['options'],
        'llm_answer': output['choices'][0]['message']['content'],
        'llm_output': str(output)
    }

def main(data, output_json_path, num_chunks=2):  # num_chunks = desired parallelism
    # Load existing results
    if os.path.exists(output_json_path):
        with open(output_json_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
    else:
        results = []
    
    # Filter out already processed samples
    processed_ids = {r['id'] for r in results}
    to_process = [s for s in data if s['id'] not in processed_ids]
    print(f"Need to process {len(to_process)} samples split into {num_chunks} chunks")

    # Split data into chunks for parallel processing
    chunks = split_into_chunks(to_process, num_chunks)

    # Load model once (shared across threads)
    print("Loading model...")
    llm = load_model_with_mps()
    time.sleep(2)  # Let MPS initialize

    # Process chunks in parallel
    with ThreadPoolExecutor(max_workers=num_chunks) as executor:
        # Map each chunk to a processing task
        futures = [executor.submit(process_chunk, chunk, llm) for chunk in chunks]
        
        # Collect results from all chunks
        for future in futures:
            chunk_results = future.result()
            results.extend(chunk_results)
            
            # Save progress after each chunk is done
            with open(output_json_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=4, ensure_ascii=False)
            print(f"Completed a chunk. Total processed: {len(results)}")

    # Final save
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
    print(f"All done! Total results: {len(results)}")

In [9]:
# Load your test data here (replace with actual data loading)
with open('data/curebench_testset_phase1_mini.jsonl', 'r', encoding='utf-8') as file:
    data = [json.loads(line) for line in file]

output_path = "data/curebench_results_test.json"

In [7]:
# temp = process_sample(data[1], llm)

In [10]:
# main(data, output_path, max_workers=2) # leads to overkill
main(data, output_path, num_chunks=2)

Need to process 9 samples split into 2 chunks
Loading model...


llama_context: n_batch is less than GGML_KQ_MASK_PAD - increasing to 64
llama_context: n_ctx_per_seq (512) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
ggml_metal_init: skipping kernel_get_rows_bf16                     (not supported)
ggml_metal_init: skipping kernel_set_rows_bf16                     (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32                   (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_c4                (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_1row              (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_f32_l4                (not supported)
ggml_metal_init: skipping kernel_mul_mv_bf16_bf16                  (not supported)
ggml_metal_init: skipping kernel_mul_mv_id_bf16_f32                (not supported)
ggml_metal_init: skipping kernel_mul_mm_bf16_f32                   (not supported)
ggml_metal_init: skipping kernel_mul_mm_id_bf16_f16                (

Error in sample QjaQymRAabrS: llama_decode returned -1


: 