In [1]:
!pip install -q azure-ai-inference

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.9/124.9 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.7/210.7 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Author: -
# 2025-06-01

import os
import time
import logging
import asyncio
import concurrent.futures
from typing import List, Dict, Tuple, Any, Optional
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from logging import FileHandler, StreamHandler, Logger


import re
import pandas as pd
from tqdm import tqdm
from pydantic import BaseModel
from openai import AzureOpenAI, OpenAI, AsyncAzureOpenAI
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential



class RawData(BaseModel):
    """Input record for one inference task, as validated by Doraemon.build_tasks."""
    query: str  # question text; copied into the RP result
    prompt: List[Dict[str, str]]  # chat messages sent to the model as `messages`
    context: Optional[str] = None  # optional context; copied into the RP result
    g_t: Any  # ground_truth
    t_p: Optional[float]=1.0  # temperature passed to Doraemon.inference by process_data

class RP(BaseModel):
    """Result record for one inference call; dumped to a dict by the Doraemon runners."""
    query: str  # question text copied from the task
    context: Optional[str] = None  # optional context copied from the task
    g_t: Any  # ground_truth
    r_s: Optional[str] = None  # raw response text returned by the model
    answer: Any  # value extracted from r_s by Doraemon.get_answer
    temperature: float = 1.0  # temperature recorded for the call
    tokens: float = 0  # completion token count reported for the call
    score: Optional[Any] = None  # set to 0.0 where results are built in this file; presumably filled in by later evaluation

class Doraemon:
    """
    Author: -
    Date: 2025-06-01
    Unified client for AzureOpenAI and OpenAI.
    """
    _client: Any = None
    _async_client: Any = None
    _provider: str = 'gpt3'  # or 'llama3'

    # Cache for secrets
    _secrets_cache: Dict[str, str] = {
        # The Kaggle secrets web service performs poorly for large-scale tasks; values filled in here are served from this cache instead.
        "KSUG_AZURE_ENDPOINT_URL":"",
        "KSUG_AZURE_OPENAI_API_KEY":"",
        "KSUG_AZURE_API_VERSION":"",
        "KSUG_GPT35_TURBO":"",
        "AZURE_ENDPOINT": "",
        "AZURE_API_KEY": "",
        "AZURE_API_VERSION": "",
        "LLAMA3": "",
        "MINISTRAL3B_ENDPOINT": "",
        "MINISTRAL3B_KEY": "",
        "MINISTRAL3B_API_VERSION": "",
        "MINISTRAL3B_MODEL": "",
    }

    _user_secrets = None  # class-level singleton for UserSecretsClient

    @classmethod
    def set_provider(cls, provider: str):
        provider = provider.lower()
        cls._provider = provider
        cls._client = None  # Reset client to force re-init
        cls._async_client = None

    @classmethod
    def _get_secret(cls, secret_label: str) -> str:
        """
        Retrieve secret only once and cache it.
        Uses singleton-style UserSecretsClient.
        """
        if cls._user_secrets is None:
            from kaggle_secrets import UserSecretsClient
            cls._user_secrets = UserSecretsClient()
        try:
            if secret_label not in cls._secrets_cache:
                cls._secrets_cache[secret_label] = cls._user_secrets.get_secret(secret_label)
        except Exception as e:
            pass
        return cls._secrets_cache[secret_label]

    @classmethod
    def _initialize_client(cls, logger: Logger) -> Any:
        if cls._client is not None:
            return cls._client
        if cls._provider == 'gpt3':
            cls._client = AzureOpenAI(
                azure_endpoint=cls._get_secret("KSUG_AZURE_ENDPOINT_URL"),
                api_key=cls._get_secret("KSUG_AZURE_OPENAI_API_KEY"),
                api_version=cls._get_secret("KSUG_AZURE_API_VERSION")
            )
        elif cls._provider == 'llama3':
            cls._client = AzureOpenAI(
                azure_endpoint=cls._get_secret("AZURE_ENDPOINT"),
                api_key=cls._get_secret("AZURE_API_KEY"),
                api_version=cls._get_secret("AZURE_API_VERSION")                
            )
        elif cls._provider == 'ministral3b':
            from azure.ai.inference import ChatCompletionsClient
            from azure.core.credentials import AzureKeyCredential

            cls._client = ChatCompletionsClient(
                endpoint=cls._get_secret("MINISTRAL3B_ENDPOINT"),
                credential=AzureKeyCredential(cls._get_secret("MINISTRAL3B_KEY")),
                api_version=cls._get_secret("MINISTRAL3B_API_VERSION")
            )
        elif cls._provider == 'cloud_flare':
            cls._client = OpenAI(
                api_key=cls._get_secret("WORKERS_AI_API_TOKEN"),
                base_url=cls._get_secret("CLOUD_FLARE_CLIENT_URL")
            )
        else:
            raise RuntimeError("Provider must be set to either 'azure' or 'cloud_flare'")
        return cls._client

    @classmethod
    def _initialize_aync_client(cls, logger=None) -> Any:
        """
        Initialize and return the async client for the configured provider/model.
        Sets cls._client if not already initialized.
        """
        if cls._async_client is not None:
            return cls._async_client
    
        if cls._provider == 'gpt3':
            from openai import AsyncAzureOpenAI
            cls._async_client = AsyncAzureOpenAI(
                azure_endpoint=cls._get_secret("KSUG_AZURE_ENDPOINT_URL"),
                api_key=cls._get_secret("KSUG_AZURE_OPENAI_API_KEY"),
                api_version=cls._get_secret("KSUG_AZURE_API_VERSION")
            )
        elif cls._provider == 'llama3':
            from openai import AsyncAzureOpenAI
            cls._async_client = AsyncAzureOpenAI(
                azure_endpoint=cls._get_secret("AZURE_ENDPOINT"),
                api_key=cls._get_secret("AZURE_API_KEY"),
                api_version=cls._get_secret("AZURE_API_VERSION")
            )
        elif cls._provider == 'cloud_flare':
            from openai import AsyncOpenAI
            cls._async_client = AsyncOpenAI(
                api_key=cls._get_secret("WORKERS_AI_API_TOKEN"),
                base_url=cls._get_secret("CLOUD_FLARE_CLIENT_URL")
            )
        else:
            raise RuntimeError("Provider must be set to a known value")
    
        if logger:
            logger.info(f"Initialized client for provider {cls._provider}")
        return cls._async_client


    @classmethod
    def inference(
        cls,
        logger: Logger,
        messages: List[Dict] = [{"role": "system", "content": 'You are a helpful AI assistant'}, {"role": "user", "content": 'hi'}],
        model: str = None,
        temperature: float = 1.0,
        max_tokens: int = 512,
        top_p: float = 1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stop: str = None,
        stream: bool = False
    ) -> Tuple[str, int]:
        """
        """
        client = cls._initialize_client(logger)

        if cls._provider == 'ministral3b':
            use_model = model or cls._get_secret("MINISTRAL3B_MODEL")
            response = client.complete(
                stream=stream,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                model=use_model
            )
            content = response.choices[0].message.content
            tokens = getattr(response.usage, 'completion_tokens', 0)
            return content, tokens

        if cls._provider == 'gpt3':
            deployment = cls._get_secret("KSUG_GPT35_TURBO")
        elif cls._provider == 'llama3':
            deployment = cls._get_secret("LLAMA3")
        else:
            deployment = model  # Model name should be explicit for OpenAI
            if deployment is None:
                raise ValueError("Model name must be provided.")

        response = client.chat.completions.create(
            model=deployment,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
            stream=stream
        )
        content = response.choices[0].message.content
        tokens = response.usage.completion_tokens
        return content, tokens

    @classmethod
    def token_log_prob(
        cls,
        logger: Logger,
        messages: List[Dict] = [{"role": "system", "content": 'You are a helpful AI assistant'}, {"role": "user", "content": 'hi'}],
        model: str = None,
        temperature: float = 1.0,
        max_tokens: int = 512,
        top_p: float = 0.9,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stop: str = None,
        stream: bool = False,
        logprobs: bool = True
    ) -> Tuple[Dict[str, int], List[Dict[str, float]]]:
        """
        Get output tokens and logprobs for each generated token.
        Returns:
            - Dict[content, tokens_used]
            - List[{'token': ..., 'logprob': ...}]
        """
        client = cls._initialize_client(logger)
        
        if cls._provider == 'gpt3':
            deployment = cls._get_secret("KSUG_GPT35_TURBO")
        elif cls._provider == 'llama3':
            deployment = cls._get_secret("LLAMA3")
        else:
            deployment = model
            if deployment is None:
                raise ValueError("Model name must be provided.")
    
        response = client.chat.completions.create(
            model=deployment,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
            stream=stream,
            logprobs=logprobs
        )
    
        choice = response.choices[0]
        content = choice.message.content
        tokens = response.usage.completion_tokens
    
        token_info = []
        # Defensive: check logprobs exists and content is not None
        if hasattr(choice, 'logprobs') and getattr(choice.logprobs, 'content', None) is not None:
            for t in choice.logprobs.content:
                token_info.append({
                    "token": t.token,
                    "logprob": float(getattr(t, "logprob", float('nan')))
                })
    
        return {content: tokens}, token_info

    @classmethod
    def get_logger(
        cls,
        name: str = __name__,
        level: int = logging.INFO,
        fmt: str = "%(asctime)s %(levelname)s %(message)s",
        logfile: str = "doraemon.log",
        file_mode: str = "a"
    ) -> Logger:
        logger = logging.getLogger(name)
        logger.setLevel(level)
        if not logger.handlers:
            console_h = StreamHandler()
            console_h.setFormatter(logging.Formatter(fmt))
            console_h.setLevel(level)
            logger.addHandler(console_h)
            file_h = FileHandler(logfile, mode=file_mode, encoding="utf-8", delay=False)
            file_h.setFormatter(logging.Formatter(fmt))
            file_h.setLevel(level)
            logger.addHandler(file_h)
        return logger

    @classmethod
    def get_answer(cls, raw_answer: str) -> str:
        match = re.search(r'\\boxed{([^}]+)}', raw_answer, flags=re.IGNORECASE)
        if match:
            return match.group(1).strip()
        text = str(raw_answer).strip()
        return text

    @classmethod
    def process_data(cls, args: Dict[str, Any], logger:Logger) -> Optional[Dict[str, Any]]:
        """
        Run inference for one task dict and build its result record.

        Args:
            args: Task dict. Reads 'prompt' (chat messages), 't_p'
                (temperature; a missing or None value means 1.0), 'query',
                'context' and 'g_t'.
            logger: Logger used for the inference call and for error reports.

        Returns:
            The ``RP`` record as a dict (``model_dump()``), or None when any
            step fails; the failure is logged with its traceback.
        """
        try:
            # Resolve the temperature before the request so a bad or missing
            # 't_p' cannot fail only after the model call has completed.
            raw_temperature = args.get('t_p')
            temperature = 1.0 if raw_temperature is None else float(raw_temperature)

            r_s, tokens = cls.inference(
                logger=logger,
                messages=args.get('prompt'),
                temperature=temperature
            )
            result = RP(
                query=args.get('query'),
                context=args.get('context'),
                r_s=r_s,
                # The model may return no text; leave the answer empty then.
                answer=cls.get_answer(r_s) if r_s is not None else None,
                g_t=args.get('g_t'),
                temperature=temperature,
                tokens=int(tokens or 0),
                score=0.0
            )
            return result.model_dump()
        except Exception as e:
            # Best effort: one failing task must not abort the whole batch.
            logger.exception(f"Error processing question {args.get('query')} at temperature {args.get('t_p')} with exception {e}")
            return None

    @classmethod
    def run_parallel_inference(cls, tasks: List[Dict[str, Any]], logger:Logger, max_workers: int = 1) -> List[Optional[Dict[str, Any]]]:
        """
        Apply process_data to every task through a ProcessPoolExecutor.

        Results come back in task order, with a tqdm progress bar shown while
        they are collected. Each entry is whatever process_data returned for
        that task.
        """
        worker = partial(cls.process_data, logger=logger)
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            # Collect inside the `with` so all work finishes before shutdown.
            progress = tqdm(pool.map(worker, tasks), total=len(tasks))
            return list(progress)

    @classmethod
    def build_tasks(cls, df: pd.DataFrame) -> List[Dict[str, Any]]:
        """
        Convert expanded DataFrame to a list of model-valid dicts for RawData.
        """
        return [RawData.model_validate(item).model_dump() for item in df.to_dict(orient='records')]

    @classmethod
    async def async_inference(
        cls,
        logger: Logger,
        prompts: List[List[Dict]],
        temperatures: List[float] = None,
        max_tokens: int = 512,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stop: str = None,
        stream: bool = False,
        max_concurrent: int = 20,
        max_retries: int = 5,  # Set a max retry limit
        backoff_factor: float = 2  # Exponential backoff factor
    ) -> List[Tuple[str, int]]:
        """
        Concurrent async inference using AsyncAzureOpenAI with retry logic.
        Each prompt is a list of messages (List[Dict]).
        Accepts a temperatures list, one per prompt.
        """
        client = cls._initialize_aync_client(logger)
        if cls._provider == 'gpt3':
            deployment = cls._get_secret("KSUG_GPT35_TURBO")
        elif cls._provider == 'llama3':
            deployment = cls._get_secret("LLAMA3")
        else:
            deployment = model  # Model name should be explicit for OpenAI
            if deployment is None:
                raise ValueError("Model name must be provided.")
        
        semaphore = asyncio.Semaphore(max_concurrent)
    
        n = len(prompts)
        results: List[Tuple[str, int]] = [("", 0)] * n
    
        if temperatures is None:
            temperatures = [1.0] * n
    
        async def infer_single(idx: int, messages: List[Dict], temp: float) -> None:
            async with semaphore:
                retries = 0
                while retries < max_retries:
                    try:
                        resp = await client.chat.completions.create(
                            model=deployment,
                            messages=messages,
                            max_tokens=max_tokens,
                            temperature=temp,
                            top_p=top_p,
                            frequency_penalty=frequency_penalty,
                            presence_penalty=presence_penalty,
                            stop=stop,
                            stream=stream
                        )
                        content = resp.choices[0].message.content
                        tokens = getattr(resp.usage, "completion_tokens", 0)
                        results[idx] = (content, tokens)
                        break  # Exit retry loop on success
                    except Exception as e:
                        if "429" in str(e):  # Check if error is due to rate limit
                            retries += 1
                            backoff_time = backoff_factor ** retries
                            logger.error(f"Inference failed for index {idx}, retrying in {backoff_time}s: {e}")
                            await asyncio.sleep(backoff_time)  # Wait before retrying
                        else:
                            logger.error(f"Inference failed for index {idx}: {e}")
                            results[idx] = (None, 0)
                            break  # Exit retry loop on non-429 error
    
        tasks = [infer_single(i, msg, temperatures[i]) for i, msg in enumerate(prompts)]
        
        await asyncio.gather(*tasks)
        return results


    @classmethod
    async def run_parallel_inference_async(
        cls,
        tasks: List[Dict[str, Any]],
        logger: Logger,
        max_concurrent: int = 50,
        max_tokens: int = 512,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        stop: str = None,
        stream: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Run async_inference over all tasks and build one RP record per task.

        Args:
            tasks: Task dicts. Each must have a 'prompt' key holding a
                List[Dict]. The temperature is read from 'temperature' or, if
                that is absent or None, from 't_p' (the key used by RawData
                and process_data); it defaults to 1.0. 'query', 'context' and
                'g_t' are copied into the result.
            logger: Logger passed through to async_inference.
            max_concurrent: Maximum number of requests in flight at once.
            max_tokens, top_p, frequency_penalty, presence_penalty, stop,
            stream: Forwarded to async_inference.

        Returns:
            A list of ``RP.model_dump()`` dicts, one per task, in task order.
            For a request that returned no text, ``r_s`` and ``answer`` are None.
        """
        def task_temperature(task: Dict[str, Any]) -> float:
            # Accept both spellings so tasks from build_tasks keep their
            # per-record temperature.
            for key in ('temperature', 't_p'):
                value = task.get(key)
                if value is not None:
                    return float(value)
            return 1.0

        prompts = [task['prompt'] for task in tasks]
        temperatures = [task_temperature(task) for task in tasks]
        results = await cls.async_inference(
            logger=logger,
            prompts=prompts,
            temperatures=temperatures,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
            stream=stream,
            max_concurrent=max_concurrent,
        )

        output = []
        for task, temperature, (r_s, tokens) in zip(tasks, temperatures, results):
            result = RP(
                query=task.get('query'),
                context=task.get('context'),
                r_s=r_s,
                # Failed requests come back with r_s=None; there is nothing
                # to extract an answer from in that case.
                answer=cls.get_answer(r_s) if r_s is not None else None,
                g_t=task.get('g_t'),
                temperature=temperature,
                tokens=int(tokens or 0),
                score=0.0
            )
            output.append(result.model_dump())
        return output