In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


# Put the parent of the notebook's working directory first on the import path
# (presumably the repository root that provides `math_rag` — confirm).
sys.path.insert(0, os.path.abspath('..'))
# nest_asyncio patches asyncio so an already-running event loop can be re-entered
# (per the library's documentation); the notebook kernel already runs a loop.
nest_asyncio.apply()

In [2]:
import logging


# Configure the root logger: INFO and above, each record rendered as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

### Prepare input file

In [None]:
from decouple import config
from huggingface_hub import AsyncInferenceClient


# Access token looked up through python-decouple's `config`, so the secret is
# never written into the notebook itself.
HUGGINGFACE_TOKEN = config('HUGGINGFACE_TOKEN')

# Hub id of the chat model the client is bound to.
MODEL_HUB_ID = 'microsoft/Phi-3-mini-4k-instruct'

# Async inference client bound to the model above.
# NOTE(review): timeout=None presumably disables the client-side timeout — confirm.
client = AsyncInferenceClient(
    api_key=HUGGINGFACE_TOKEN,
    model=MODEL_HUB_ID,
    provider='hf-inference',
    timeout=None,
)

In [10]:
from math_rag.application.base.assistants import BaseAssistantInput, BaseAssistantOutput


# Placeholder assistant input for this notebook: declares no fields of its own.
# Only `some_input.id` is read later, so `id` is expected to come from
# BaseAssistantInput (defined outside this notebook).
class SomeInput(BaseAssistantInput):
    pass


# Structured output type passed as `response_type` in the requests of this
# notebook; it declares a single integer field, `result`.
# NOTE(review): documented with `#` comments rather than a docstring because a
# class docstring may end up in a generated schema — confirm against the base class.
class SomeOutput(BaseAssistantOutput):
    result: int

In [11]:
import json

from pathlib import Path

from math_rag.application.models.inference import (
    LLMBatchRequest,
    LLMBatchResult,
    LLMConversation,
    LLMMessage,
    LLMParams,
    LLMRequest,
)
from math_rag.infrastructure.mappings.inference.huggingface import (
    LLMRequestMapping,
    LLMResponseListMapping,
)


# Hub id of the model every request in the batch targets.
MODEL_HUB_ID = 'microsoft/Phi-3-mini-4k-instruct'

# Number of requests generated for the demo batch.
NUM_REQUESTS = 50

# A single input instance; its id is attached to the metadata of every request,
# so all requests in this batch carry the same `input_id`.
some_input = SomeInput()

# One request per integer i in [0, NUM_REQUESTS), each asking the model for "i+2"
# and expecting a `SomeOutput`-shaped answer.
_requests = [
    LLMRequest(
        conversation=LLMConversation(
            messages=[
                LLMMessage(role='system', content='You are a helpful assistant.'),
                LLMMessage(role='user', content=f'what is {i}+2'),
            ]
        ),
        params=LLMParams(
            model=MODEL_HUB_ID,
            temperature=0,
            response_type=SomeOutput,
            max_completion_tokens=10,
            metadata={'input_id': str(some_input.id)},
        ),
    )
    for i in range(NUM_REQUESTS)
]

# Bundle the requests into the batch object consumed by the cells below.
batch_request: LLMBatchRequest = LLMBatchRequest(requests=_requests)

In [5]:
# Convert every domain-level request into its provider payload.
requests = list(map(LLMRequestMapping.to_target, batch_request.requests))

# Serialize each payload as compact JSON (no whitespace after separators),
# one document per entry.
lines = [json.dumps(payload, separators=(',', ':')) for payload in requests]

# Newline-joined JSONL text and its UTF-8 encoded form.
jsonl_str = '\n'.join(lines)
jsonl_bytes = bytes(jsonl_str, encoding='utf-8')

In [7]:
# Parse the first serialized JSONL line back into a dict; it is used below as
# the keyword arguments of a single chat-completion call.
request_dict = json.loads(lines[0])

In [8]:
# Send the first request through the async client; the entries of
# `request_dict` are passed as keyword arguments of `chat_completion`.
result = await client.chat_completion(**request_dict)
# Bare last expression so the notebook displays the raw completion object.
result

ChatCompletionOutput(choices=[ChatCompletionOutputComplete(finish_reason='stop', index=0, message=ChatCompletionOutputMessage(role='assistant', content='{ "result": 4 }', tool_call_id=None, tool_calls=None), logprobs=None)], created=1744125157, id='', model='microsoft/Phi-3-mini-4k-instruct', system_fingerprint='3.2.1-native', usage=ChatCompletionOutputUsage(completion_tokens=9, prompt_tokens=17, total_tokens=26), object='chat.completion')

In [40]:
# `result` was produced from `lines[0]`, i.e. from the first request of the
# batch. Fetch that request explicitly: no cell assigns a global `request`
# (comprehension variables do not leak in Python 3), so relying on one raises
# NameError on a fresh kernel.
first_request = batch_request.requests[0]

# Map the provider response back to the domain-level response list, tying it to
# the originating request and to the input id carried in the request payload.
response_list = LLMResponseListMapping.to_source(
    result,
    request_id=first_request.id,
    input_id=request_dict['extra_body']['input_id'],
    response_type=SomeOutput,
)
response_list

LLMResponseList(id=UUID('f7fd3b29-2772-468d-831f-b160bf1dfaa0'), request_id=UUID('cb02d5a3-bed8-4109-b06c-04d585081ef3'), responses=[LLMResponse(id=UUID('3976268b-adbf-494a-b88b-fda24781d477'), content=BoundAssistantOutput(id=UUID('bd7fdedd-f32a-4d7f-845d-0a21413828b6'), input_id=UUID('c882baa3-2047-4f65-b8df-bd176bba3b2c'), result=4), logprobs=None)])

In [None]:
# Batch input file, named after the batch request id.
input_file_path = Path(f'.tmp/input_{batch_request.id}.jsonl')
# Create `.tmp/` on first use so a fresh checkout does not fail with
# FileNotFoundError when the file is opened for writing.
input_file_path.parent.mkdir(parents=True, exist_ok=True)

# Write one serialized request per line (JSONL). The encoding is set explicitly
# to match the UTF-8 used for `jsonl_bytes` instead of the platform default.
with open(input_file_path, 'w', encoding='utf-8') as input_file:
    for line in lines:
        input_file.write(line + '\n')

### Apptainer

In [3]:
from math_rag.infrastructure.containers import InfrastructureContainer


# Build the infrastructure container and initialise its resources.
infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

# Resolve the TGI batch LLM from the container; used by the cells below.
tgi_batch_llm = infrastructure_container.tgi_batch_llm()

* 'allow_population_by_field_name' has been renamed to 'validate_by_name'


In [7]:
# Asynchronously initialise the batch LLM's own resources. In the recorded run
# this opens an SSH session and executes `mkdir -p tgi_default_root` remotely.
await tgi_batch_llm.init_resources()

2025-05-08 14:51:17,631 - INFO - Host canonicalization disabled
2025-05-08 14:51:17,631 - INFO - Opening SSH connection to login-gpu.hpc.srce.hr, port 22


2025-05-08 14:51:17,646 - INFO - [conn=7] Connected to SSH server at login-gpu.hpc.srce.hr, port 22
2025-05-08 14:51:17,646 - INFO - [conn=7]   Local address: 172.18.0.6, port 44886
2025-05-08 14:51:17,646 - INFO - [conn=7]   Peer address: 161.53.2.37, port 22
2025-05-08 14:51:17,679 - INFO - [conn=7] Beginning auth for user lpanic
2025-05-08 14:51:17,949 - INFO - [conn=7] Auth for user lpanic succeeded
2025-05-08 14:51:17,951 - INFO - [conn=7, chan=0] Requesting new SSH session
2025-05-08 14:51:17,990 - INFO - [conn=7, chan=0]   Command: mkdir -p tgi_default_root
2025-05-08 14:51:18,137 - INFO - [conn=7, chan=0] Received exit status 0
2025-05-08 14:51:18,139 - INFO - [conn=7, chan=0] Received channel close
2025-05-08 14:51:18,140 - INFO - [conn=7, chan=0] Channel closed
2025-05-08 14:51:18,141 - INFO - Command `mkdir -p tgi_default_root` in `run` returned stdout: 
2025-05-08 14:51:18,142 - INFO - [conn=7] Closing connection
2025-05-08 14:51:18,143 - INFO - [conn=7] Sending disconnect:

In [12]:
# Submit the whole batch and wait for the result.
# NOTE(review): poll_interval is presumably in seconds (3 * 60 = 3 minutes) — confirm.
# No daily token cap is passed (None) and retries are disabled (0).
res = await tgi_batch_llm.batch_generate(
    batch_request=batch_request,
    response_type=SomeOutput,
    poll_interval=3 * 60,
    max_tokens_per_day=None,
    max_num_retries=0,
)
# Bare last expression so the notebook displays the batch result.
res

2025-05-08 15:00:25,922 - INFO - Host canonicalization disabled
2025-05-08 15:00:25,923 - INFO - Opening SSH connection to login-gpu.hpc.srce.hr, port 22
2025-05-08 15:00:25,936 - INFO - [conn=47] Connected to SSH server at login-gpu.hpc.srce.hr, port 22
2025-05-08 15:00:25,936 - INFO - [conn=47]   Local address: 172.18.0.6, port 52258
2025-05-08 15:00:25,937 - INFO - [conn=47]   Peer address: 161.53.2.37, port 22
2025-05-08 15:00:25,972 - INFO - [conn=47] Beginning auth for user lpanic
2025-05-08 15:00:26,241 - INFO - [conn=47] Auth for user lpanic succeeded
2025-05-08 15:00:26,244 - INFO - [conn=47, chan=0] Requesting new SSH session
2025-05-08 15:00:26,279 - INFO - [conn=47, chan=0]   Subsystem: sftp
2025-05-08 15:00:26,287 - INFO - [conn=47, chan=0] Starting SFTP client
2025-05-08 15:00:26,477 - INFO - [conn=47, chan=0] Received exit status 0
2025-05-08 15:00:26,478 - INFO - [conn=47, chan=0] Received channel close
2025-05-08 15:00:26,478 - INFO - [conn=47, chan=0] SFTP client exit

LLMBatchResult(id=UUID('375028c7-7120-4f0f-ba49-12245ebf6485'), batch_request_id=UUID('2189fcf7-0b5d-4d18-b9ee-14824b3d18b1'), response_lists=[LLMResponseList(id=UUID('a78e5a38-a166-4ace-8952-dd933c0b474b'), request_id=UUID('34b13a03-02a6-46da-94f3-28f544c99f4f'), responses=[LLMResponse(id=UUID('32ea832a-eb99-4cf1-8a35-1d885f452966'), content=BoundAssistantOutput(id=UUID('0863a0b2-4403-4325-bff9-5bbd4a59dfbc'), input_id=UUID('b6b2029d-c3ae-444b-abbf-812e615f134e'), result=5), logprobs=None)]), LLMResponseList(id=UUID('f74d1919-f994-418c-8537-79a05043af83'), request_id=UUID('53a2b23e-dfff-4e93-9545-5a504210b53f'), responses=[LLMResponse(id=UUID('549177e0-95a5-401a-b5b5-a651e33dc335'), content=BoundAssistantOutput(id=UUID('6614e76c-2e98-41db-96b2-b575d8ee7590'), input_id=UUID('b6b2029d-c3ae-444b-abbf-812e615f134e'), result=3), logprobs=None)]), LLMResponseList(id=UUID('6918f6a7-b59a-47f8-9126-dcab5e4c87ba'), request_id=UUID('0f1e97e9-1440-43c6-bc93-4da7cd5651b5'), responses=[LLMResponse(

In [None]:
from os import environ

from decouple import config


# Environment variables set before the client script is executed; presumably
# read by tgi_client.py — confirm against the script.
environ['PBS_O_WORKDIR'] = '../.tmp'
# The token comes from python-decouple's `config`, so it is not hardcoded here.
environ['TGI_API_KEY'] = config('HUGGINGFACE_TOKEN')
environ['MODEL_HUB_ID'] = 'microsoft/Phi-3-mini-4k-instruct'

# IPython magic: execute the client script inside the notebook session.
%run ../assets/hpc/hf/tgi/tgi_client.py

In [13]:
import asyncio
import threading


# NOTE: prometheus_snapshot_loader_service.load will run as a FastAPI background task; here it is started on a helper thread instead.
def run_async_in_thread(coro):
    """Run ``coro`` to completion on a private event loop in a new thread.

    Args:
        coro: Coroutine (or other awaitable accepted by
            ``loop.run_until_complete``) to execute.

    Returns:
        The already-started ``threading.Thread``; call ``join()`` on it to wait
        for the coroutine to finish.

    Notes:
        An exception raised by ``coro`` propagates out of the worker thread and
        is reported through ``threading.excepthook``. The event loop is closed
        in every case.
    """

    def runner():
        # Each worker thread gets its own loop, registered as that thread's current loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro)
        finally:
            # Always release the loop's resources, even when ``coro`` raises.
            loop.close()

    thread = threading.Thread(target=runner)
    thread.start()

    return thread


# Resolve the snapshot loader service from the infrastructure container.
prometheus_snapshot_loader_service = (
    infrastructure_container.prometheus_snapshot_loader_service()
)
# Start its `load()` coroutine on a background thread; the returned Thread is
# the cell's displayed value.
run_async_in_thread(prometheus_snapshot_loader_service.load())

<Thread(Thread-6 (runner), started 281472322564480)>

2025-05-08 15:08:12,898 - INFO - Host canonicalization disabled
2025-05-08 15:08:12,898 - INFO - Opening SSH connection to login-gpu.hpc.srce.hr, port 22


2025-05-08 15:08:12,913 - INFO - [conn=62] Connected to SSH server at login-gpu.hpc.srce.hr, port 22
2025-05-08 15:08:12,913 - INFO - [conn=62]   Local address: 172.18.0.6, port 47236
2025-05-08 15:08:12,913 - INFO - [conn=62]   Peer address: 161.53.2.37, port 22
2025-05-08 15:08:12,946 - INFO - [conn=62] Beginning auth for user lpanic
2025-05-08 15:08:13,213 - INFO - [conn=62] Auth for user lpanic succeeded
2025-05-08 15:08:13,217 - INFO - [conn=62, chan=0] Requesting new SSH session
2025-05-08 15:08:13,249 - INFO - [conn=62, chan=0]   Command: find tgi_default_root -name "snapshot_*.json" -maxdepth 1 -print -quit
2025-05-08 15:08:13,398 - INFO - [conn=62, chan=0] Received exit status 0
2025-05-08 15:08:13,400 - INFO - [conn=62, chan=0] Received channel close
2025-05-08 15:08:13,402 - INFO - [conn=62, chan=0] Channel closed
2025-05-08 15:08:13,403 - INFO - Command `find tgi_default_root -name "snapshot_*.json" -maxdepth 1 -print -quit` in `run` returned stdout: tgi_default_root/snapsh