In [None]:
from glob import glob
import requests
import os
import json
import re
import subprocess
import signal
import time
from tqdm import tqdm
from threading import Thread
from queue import Queue
from multiprocess import Pool
import itertools

def chunks(l, n):
    """
    Lazily split a sliceable sequence into consecutive pieces of length `n`.

    Yields `(piece, position)` tuples, where `piece` is `l[start: start + n]`
    and `position` is the 0-based ordinal of that piece. The last piece may be
    shorter than `n` when `len(l)` is not a multiple of `n`.
    """
    starts = range(0, len(l), n)
    for position, start in enumerate(starts):
        piece = l[start: start + n]
        yield (piece, position)

def multiprocessing(strings, function, cores=6, returned=True):
    """
    Split `strings` into chunks and map `function` over them in a process pool.

    Parameters
    ----------
    strings : sequence
        Items to distribute; must support `len()` and slicing.
    function : callable
        Called once per chunk with a `(chunk, chunk_index)` tuple, as yielded
        by `chunks`. When `returned` is true it must return an iterable.
    cores : int
        Number of worker processes. The chunk size is `len(strings) // cores`
        (at least 1), so when the length is not a multiple of `cores` there
        are more chunks than workers and chunk indices exceed `cores - 1`.
    returned : bool
        When true, the per-chunk results are concatenated into one flat list
        and returned; otherwise the function returns None.
    """
    # Integer division gives 0 when there are fewer items than cores (or no
    # items at all), and a step of 0 makes range() inside chunks() raise.
    chunk_size = max(1, len(strings) // cores)
    df_split = chunks(strings, chunk_size)
    pool = Pool(cores)
    try:
        pooled = pool.map(function, df_split)
    finally:
        # Always release the worker processes, even when a task raised.
        pool.close()
        pool.join()

    if returned:
        return list(itertools.chain(*pooled))

In [None]:
def loop(rows):
    """
    Serve each model of a batch with vLLM and record its boxed MalayMMLU answers.

    Parameters
    ----------
    rows : tuple
        A `(batch, index)` pair in the shape yielded by `chunks`. `batch` is a
        list of `[model, port, folder]` entries and `index` is exported as
        `CUDA_VISIBLE_DEVICES` before any server is launched.

    For every entry the function:
      1. starts `vllm serve <model>` on `port`,
      2. polls `http://localhost:<port>/docs` until it answers 200,
      3. reads the prompts from `MalayMMLU_0shot.json`,
      4. asks each prompt up to 5 times with 30 worker threads and stores every
         single boxed answer in `<folder>/<question>-<k>.json`,
      5. stops the server.

    Files that already exist and parse as JSON are skipped, so an interrupted
    run can be resumed by calling the function again.

    Raises
    ------
    RuntimeError
        If the vLLM process exits or never becomes healthy, or if worker
        threads stopped while questions were still queued.
    """
    # Local import: the file-level import only brings in `Queue`.
    from queue import Empty

    rows, index = rows

    # Inherited by the vLLM child processes started below.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(index)

    for row in rows:
        model, port, folder = row
        os.makedirs(folder, exist_ok = True)
        cmd = f'/root/.venv/bin/vllm serve {model} --gpu-memory-utilization 0.95 --port {port}'
        # stderr is never read here; leaving it on a PIPE lets the child block
        # once the OS pipe buffer fills up, so discard it like stdout.
        process = subprocess.Popen(cmd.split(), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        try:
            url = f'http://localhost:{port}/docs'
            pbar = tqdm(range(1000))
            for i in pbar:
                # No point in polling a server that has already exited.
                if process.poll() is not None:
                    raise RuntimeError(
                        f'vLLM for {model} exited with code {process.returncode} before becoming healthy'
                    )
                try:
                    r = requests.get(url, timeout=5.0)
                    if r.status_code == 200:
                        break
                except Exception:
                    # Connection refused / timeout while the server is still starting.
                    pass
                pbar.set_description(f"checking health {url} for {i} times")
                time.sleep(5.0)
            else:
                # Every probe failed; continuing would only retry requests forever.
                raise RuntimeError(f'vLLM for {model} did not become healthy at {url}')
            pbar.close()

            with open('MalayMMLU_0shot.json') as fopen:
                malaymmlu = json.load(fopen)

            questions = [(i, item['prompt']) for i, item in enumerate(malaymmlu)]

            def generate_answer(row, repeat = 5):
                """Write `repeat` answer files for one `(question_no, prompt)` pair."""
                no, q = row
                for k in range(repeat):
                    filename = os.path.join(folder, f'{no}-{k}.json')
                    try:
                        # A readable, valid JSON file means this sample is already done.
                        with open(filename) as fopen:
                            json.load(fopen)
                        continue
                    except Exception:
                        pass

                    system = 'First, you try to think step-by-step in {{lang}}, after that, put your final answer within $\\boxed{}$.'
                    messages = [
                        {"role": "system", "content": system.replace('{{lang}}', 'malay')},
                        {"role": "user", "content": q},
                    ]

                    json_data = {
                        'model': model,
                        'messages': messages,
                        'max_tokens': 24000,
                    }

                    # Retry until the response contains exactly one boxed answer.
                    while True:
                        # A dead server can never answer; fail instead of spinning.
                        if process.poll() is not None:
                            raise RuntimeError(
                                f'vLLM for {model} exited with code {process.returncode} during generation'
                            )
                        try:
                            response = requests.post(f'http://localhost:{port}/v1/chat/completions', json=json_data)
                            r = response.json()['choices'][0]['message']['reasoning_content'].strip()
                            # The prompt asks for $\boxed{}$, so accept the form with
                            # the backslash as well as the bare $boxed{}$ form.
                            answers = re.findall(r"\$\\?boxed\{(.*?)\}\$", r)
                            if len(answers) == 1:
                                a = answers[0]
                                with open(filename, 'w') as fopen:
                                    json.dump(a, fopen)
                                break
                        except Exception:
                            # Request or parsing failed: back off briefly, then retry.
                            time.sleep(1.0)

            def consumer(queue, name):
                """Drain `queue`, answering every question taken from it."""
                while True:
                    try:
                        # Non-blocking get: checking qsize() first and then calling
                        # get() can block forever if another thread takes the last item.
                        item = queue.get_nowait()
                    except Empty:
                        break
                    generate_answer(item)
                print(f'consumer {name} done')

            if questions:
                # Run one question in this thread first, before fanning out.
                generate_answer(questions[0])
            pending = Queue()
            for u in questions:
                pending.put(u)

            ori_size = pending.qsize()
            max_worker = 30
            consumers = [Thread(target=consumer, args=(pending, i)) for i in range(max_worker)]
            for worker in consumers:
                worker.start()

            pbar = tqdm(total=ori_size)
            last_size = 0
            while True:
                # Sample liveness first so the final pass still records the last dequeues.
                still_running = any(worker.is_alive() for worker in consumers)
                left = ori_size - pending.qsize()
                minus = left - last_size
                if minus > 0:
                    pbar.update(minus)
                    last_size += minus
                if not still_running:
                    break
                # Sleep instead of busy-spinning against the worker threads.
                time.sleep(1.0)

            pbar.close()

            # The server must stay up until every in-flight request has finished.
            for worker in consumers:
                worker.join()

            if not pending.empty():
                raise RuntimeError(
                    f'{pending.qsize()} questions were left unprocessed for {model}'
                )

        finally:
            if process.poll() is None:
                print("Cleaning up vLLM process...")
                process.send_signal(signal.SIGINT)
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    print("Force killing vLLM process...")
                    process.kill()
                    # Reap the killed child so it does not linger as a zombie.
                    process.wait()

In [None]:
# One [model, port, folder] entry per LoRA rank; ports count up from 8000
# in the order the ranks are listed.
rows = [
    [f'malaysian-reasoning-20b-lora-r{rank}-merged', 8000 + offset, f'malaymmlu-20b-r{rank}']
    for offset, rank in enumerate([16, 32, 64, 128, 256, 512])
]

In [None]:
# With as many workers as rows, every chunk holds exactly one row and its
# chunk index (0, 1, 2, ...) becomes that worker's CUDA_VISIBLE_DEVICES value.
multiprocessing(rows, loop, cores=len(rows), returned=False)