In [1]:
#!/usr/bin/env python3


Chain-of-Thought (CoT) Generation Script
=======================================

This script generates Chain-of-Thought reasoning for medical chatbot training data.
It supports multiple AI models and includes robust error handling, batching, and retry logic.

Supported Models:
- OpenAI GPT-4, GPT-3.5-turbo
- Anthropic Claude models (via API)
- Hugging Face models (local or API)
- Ollama models (local)

In [3]:
%pip install backoff

Collecting backoff
  Downloading backoff-2.2.1-py3-none-any.whl.metadata (14 kB)
Downloading backoff-2.2.1-py3-none-any.whl (15 kB)
Installing collected packages: backoff
Successfully installed backoff-2.2.1


In [4]:
import json
import os
import time
import asyncio
import aiohttp
import argparse
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
import logging
from pathlib import Path
import backoff
from tqdm.asyncio import tqdm
import openai
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

In [5]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Configuration class for CoT generation.
@dataclass
class CoTConfig:
    model_name: str = "gpt-3.5-turbo"
    api_key: Optional[str] = None
    batch_size: int = 10
    max_retries: int = 3
    retry_delay: float = 1.0
    max_tokens: int = 512
    temperature: float = 0.7
    concurrent_requests: int = 5
    rate_limit_delay: float = 0.1
    local_model_path: Optional[str] = None
    use_gpu: bool = True

    # API endpoints.
    openai_base_url: str = "https://api.openai.com/v1"
    anthropic_base_url: str = "https://api.anthropic.com/v1"
    huggingface_base_url: str = "https://api-inference.huggingface.co/models"
    ollama_base_url: str = "http://localhost:11434/api"

In [7]:
# Generating Chain-of-Thought reasoning.
class CoTGenerator:

    def __init__(self, config: CoTConfig):
        self.config = config
        self.session = None
        self.local_model = None
        self.local_tokenizer = None
        self.semaphore = asyncio.Semaphore(config.concurrent_requests)

        # Set up API.
        self.setup_apis()

# API clients and local models.
    def setup_apis(self):
        if self.config.model_name.startswith(('gpt-', 'text-davinci')):
            if not self.config.api_key:
                self.config.api_key = os.getenv('OPENAI_API_KEY')
            if not self.config.api_key:
                logger.warning("OpenAI API key not found.")
            else:
                openai.api_key = self.config.api_key

        elif self.config.model_name.startswith('claude'):
            if not self.config.api_key:
                self.config.api_key = os.getenv('ANTHROPIC_API_KEY')
            if not self.config.api_key:
                logger.warning("Anthropic API key not found. Set ANTHROPIC_API_KEY environment variable.")

        elif self.config.local_model_path or self.config.model_name in ['phi-2', 'phi-3']:
            self.setup_local_model()

    # Set up local model for inference.
    def setup_local_model(self):
        try:
            model_path = self.config.local_model_path or self.config.model_name

            logger.info(f"Loading local model: {model_path}")

            device = "cuda" if self.config.use_gpu and torch.cuda.is_available() else "cpu"

            self.local_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.local_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                device_map="auto" if device == "cuda" else None,
                trust_remote_code=True
            )

            if device == "cpu":
                self.local_model = self.local_model.to(device)

            logger.info(f"Local model loaded successfully on {device}")

        except Exception as e:
            logger.error(f"Failed to load local model: {e}")
            raise

# Prompt for generating CoT reasoning. Instruction: The original instruction/question. Answer: The expected answer. Returns: Formatted prompt for CoT generation.
    def create_cot_prompt(self, instruction: str, answer: str) -> str:
        prompt = f"""You are a medical expert assistant. For the given medical question and answer, provide detailed Chain-of-Thought reasoning that explains the step-by-step thinking process leading to the answer.

Your reasoning should:
1. Break down the problem systematically
2. Identify key medical concepts and symptoms
3. Consider differential diagnoses where applicable
4. Explain the logical progression of thoughts
5. Be clear, educational, and medically accurate

Question: {instruction}

Expected Answer: {answer}

Please provide the Chain-of-Thought reasoning that leads to this answer. Start your response with "Chain-of-Thought:" and then provide the detailed reasoning.

Chain-of-Thought:"""

        return prompt

    # Generate CoT using OpenAI API.
    @backoff.on_exception(
        backoff.expo,
        (aiohttp.ClientError, asyncio.TimeoutError, openai.RateLimitError),
        max_tries=3
    )
    async def generate_cot_openai(self, prompt: str) -> str:
        try:
            response = await openai.ChatCompletion.acreate(
                model=self.config.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                timeout=30
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise

    @backoff.on_exception(
        backoff.expo,
        (aiohttp.ClientError, asyncio.TimeoutError),
        max_tries=3
    )
    async def generate_cot_anthropic(self, prompt: str) -> str:
       # Anthropic API.
        headers = {
            "Content-Type": "application/json",
            "x-api-key": self.config.api_key,
            "anthropic-version": "2023-06-01"
        }

        payload = {
            "model": self.config.model_name,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": [{"role": "user", "content": prompt}]
        }

        async with self.session.post(
            f"{self.config.anthropic_base_url}/messages",
            headers=headers,
            json=payload,
            timeout=30
        ) as response:
            if response.status == 200:
                result = await response.json()
                return result["content"][0]["text"].strip()
            else:
                error_text = await response.text()
                raise aiohttp.ClientError(f"Anthropic API error: {response.status} - {error_text}")

    @backoff.on_exception(
        backoff.expo,
        (aiohttp.ClientError, asyncio.TimeoutError),
        max_tries=3
    )
    async def generate_cot_huggingface(self, prompt: str) -> str:
        # Hugging Face API.
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "return_full_text": False
            }
        }

        async with self.session.post(
            f"{self.config.huggingface_base_url}/{self.config.model_name}",
            headers=headers,
            json=payload,
            timeout=30
        ) as response:
            if response.status == 200:
                result = await response.json()
                return result[0]["generated_text"].strip()
            else:
                error_text = await response.text()
                raise aiohttp.ClientError(f"Hugging Face API error: {response.status} - {error_text}")

    @backoff.on_exception(
        backoff.expo,
        (aiohttp.ClientError, asyncio.TimeoutError),
        max_tries=3
    )
    async def generate_cot_ollama(self, prompt: str) -> str:
        # Ollama API.
        payload = {
            "model": self.config.model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": self.config.temperature,
                "num_predict": self.config.max_tokens
            }
        }

        async with self.session.post(
            f"{self.config.ollama_base_url}/generate",
            json=payload,
            timeout=60
        ) as response:
            if response.status == 200:
                result = await response.json()
                return result["response"].strip()
            else:
                error_text = await response.text()
                raise aiohttp.ClientError(f"Ollama API error: {response.status} - {error_text}")

    def generate_cot_local(self, prompt: str) -> str:
        #Local model.
        try:
            inputs = self.local_tokenizer.encode(prompt, return_tensors="pt")

            if self.config.use_gpu and torch.cuda.is_available():
                inputs = inputs.to("cuda")

            with torch.no_grad():
                outputs = self.local_model.generate(
                    inputs,
                    max_new_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    do_sample=True,
                    pad_token_id=self.local_tokenizer.eos_token_id
                )

            # Decode only the generated part.
            generated_tokens = outputs[0][inputs.shape[1]:]
            response = self.local_tokenizer.decode(generated_tokens, skip_special_tokens=True)

            return response.strip()

        except Exception as e:
            logger.error(f"Local model error: {e}")
            raise

    # CoT for a single sample.
    async def generate_single_cot(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        async with self.semaphore:
            try:
                instruction = sample.get("instruction", "")
                answer = sample.get("output", "")

                if not instruction or not answer:
                    logger.warning(f"Skipping sample with missing instruction or answer")
                    return {**sample, "cot_reasoning": ""}

                prompt = self.create_cot_prompt(instruction, answer)

                # Route to the model.
                if self.config.model_name.startswith(('gpt-', 'text-davinci')):
                    cot_reasoning = await self.generate_cot_openai(prompt)
                elif self.config.model_name.startswith('claude'):
                    cot_reasoning = await self.generate_cot_anthropic(prompt)
                elif self.config.model_name.startswith('hf-'):
                    cot_reasoning = await self.generate_cot_huggingface(prompt)
                elif self.config.model_name.startswith('ollama-'):
                    cot_reasoning = await self.generate_cot_ollama(prompt)
                else:
                    # Local model.
                    cot_reasoning = self.generate_cot_local(prompt)

                # Cleaning up the reasoning.
                if cot_reasoning.startswith("Chain-of-Thought:"):
                    cot_reasoning = cot_reasoning[len("Chain-of-Thought:"):].strip()

                # Add rate limiting delay.
                await asyncio.sleep(self.config.rate_limit_delay)

                return {**sample, "cot_reasoning": cot_reasoning}

            except Exception as e:
                logger.error(f"Error generating CoT for sample: {e}")
                return {**sample, "cot_reasoning": f"Error: {str(e)}"}

    # Process a batch of samples.
    async def process_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        if self.local_model is None:
            # Use async for API calls
            tasks = [self.generate_single_cot(sample) for sample in batch]
            return await asyncio.gather(*tasks, return_exceptions=True)
        else:
            # Use sync for local model
            results = []
            for sample in batch:
                try:
                    result = await self.generate_single_cot(sample)
                    results.append(result)
                except Exception as e:
                    logger.error(f"Error processing sample: {e}")
                    results.append({**sample, "cot_reasoning": f"Error: {str(e)}"})
            return results

    async def generate_cot_for_dataset(self, input_file: str, output_file: str):
        """Generate CoT for entire dataset."""
        logger.info(f"Loading dataset from {input_file}")

        # Load the dataset.
        with open(input_file, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        logger.info(f"Loaded {len(dataset)} samples")

        # Create batches.
        batches = [
            dataset[i:i + self.config.batch_size]
            for i in range(0, len(dataset), self.config.batch_size)
        ]

        logger.info(f"Processing {len(batches)} batches of size {self.config.batch_size}")

        # Set up session for API calls.
        if self.local_model is None:
            connector = aiohttp.TCPConnector(limit=self.config.concurrent_requests)
            timeout = aiohttp.ClientTimeout(total=60)
            self.session = aiohttp.ClientSession(connector=connector, timeout=timeout)

        try:
            results = []

            # Process batches with progress bar.
            with tqdm(total=len(dataset), desc="Generating CoT") as pbar:
                for batch in batches:
                    batch_results = await self.process_batch(batch)

                    # Handle exceptions in results.
                    for result in batch_results:
                        if isinstance(result, Exception):
                            logger.error(f"Batch processing error: {result}")
                            results.append({"error": str(result)})
                        else:
                            results.append(result)

                    pbar.update(len(batch))

                    if len(results) % (self.config.batch_size * 10) == 0:
                        self.save_intermediate_results(results, output_file)

            # Final results.
            logger.info(f"Saving final results to {output_file}")
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)

            # Statistics.
            self.generate_statistics(results, input_file, output_file)

        finally:
            if self.session:
                await self.session.close()

    # Intermediate results.
    def save_intermediate_results(self, results: List[Dict[str, Any]], output_file: str):
        intermediate_file = output_file.replace('.json', '_intermediate.json')
        with open(intermediate_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved intermediate results: {len(results)} samples")

    # Save statistics about CoT generation.
    def generate_statistics(self, results: List[Dict[str, Any]], input_file: str, output_file: str):
        stats = {
            "total_samples": len(results),
            "successful_generations": sum(1 for r in results if r.get("cot_reasoning") and not r.get("cot_reasoning", "").startswith("Error:")),
            "failed_generations": sum(1 for r in results if r.get("cot_reasoning", "").startswith("Error:") or not r.get("cot_reasoning")),
            "average_cot_length": 0,
            "model_config": {
                "model_name": self.config.model_name,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "batch_size": self.config.batch_size
            }
        }

        # Calculate the average of CoT length.
        valid_cots = [r.get("cot_reasoning", "") for r in results if r.get("cot_reasoning") and not r.get("cot_reasoning", "").startswith("Error:")]
        if valid_cots:
            stats["average_cot_length"] = sum(len(cot.split()) for cot in valid_cots) / len(valid_cots)

        # Statistics.
        stats_file = output_file.replace('.json', '_cot_stats.json')
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        logger.info(f"CoT generation statistics saved to {stats_file}")
        logger.info(f"Success rate: {stats['successful_generations']}/{stats['total_samples']} ({stats['successful_generations']/stats['total_samples']*100:.1f}%)")

In [10]:
async def main():
    parser = argparse.ArgumentParser(description='Generate Chain-of-Thought reasoning for medical chatbot training data')

    parser.add_argument('--input_file', type=str, required=True,
                       help='Path to input JSON file (e.g., train.json)')
    parser.add_argument('--output_file', type=str, required=True,
                       help='Path to output JSON file (e.g., train_with_cot.json)')
    parser.add_argument('--model', type=str, default='gpt-3.5-turbo',
                       help='Model to use for CoT generation')
    parser.add_argument('--api_key', type=str,
                       help='API key for the model (can also use environment variables)')
    parser.add_argument('--batch_size', type=int, default=10,
                       help='Batch size for processing')
    parser.add_argument('--max_tokens', type=int, default=512,
                       help='Maximum tokens for CoT generation')
    parser.add_argument('--temperature', type=float, default=0.7,
                       help='Temperature for text generation')
    parser.add_argument('--concurrent_requests', type=int, default=5,
                       help='Number of concurrent API requests')
    parser.add_argument('--local_model_path', type=str,
                       help='Path to local model for inference')
    parser.add_argument('--use_gpu', action='store_true', default=True,
                       help='Use GPU for local model inference')

    # In a Colab environment, we need to pass arguments differently
    # as argparse expects command-line arguments.
    # We will parse known arguments and ignore the rest.
    args, unknown = parser.parse_known_args()


    if not os.path.exists(args.input_file):
        logger.error(f"Input file not found: {args.input_file}")
        return

    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)

    config = CoTConfig(
        model_name=args.model,
        api_key=args.api_key,
        batch_size=args.batch_size,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        concurrent_requests=args.concurrent_requests,
        local_model_path=args.local_model_path,
        use_gpu=args.use_gpu
    )

    generator = CoTGenerator(config)

    try:
        # Generate CoT for dataset.
        await generator.generate_cot_for_dataset(args.input_file, args.output_file)

        logger.info("CoT generation completed successfully!")

    except Exception as e:
        logger.error(f"Error during CoT generation: {e}")
        raise


if __name__ == "__main__":
    # Check if there's an existing event loop (like in Colab)
    try:
        loop = asyncio.get_running_loop()
        if loop.is_running():
            # If in an environment with a running loop, use await
            import nest_asyncio
            nest_asyncio.apply()
            await main()
        else:
            # Otherwise, use asyncio.run()
            asyncio.run(main())
    except RuntimeError:
        # If get_running_loop() raises RuntimeError, it means no loop is running
        asyncio.run(main())

TypeError: main() got an unexpected keyword argument 'input_file'