In [None]:
#!/usr/bin/env python3


Chain-of-Thought (CoT) Generation Script
=======================================

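This script generates Chain-of-Thought (CoT) reasoning for medical chatbot training data.
It sends each sample to a hosted LLM and includes error handling, batching, concurrency limits, and retry logic.

Supported Models:
- OpenAI chat models whose names start with `gpt-` (e.g. GPT-4, GPT-3.5-turbo)
- Anthropic Claude models (via the Messages API)

Not yet implemented (any other model name fails every sample with "Unsupported model"):
- Hugging Face models (local or API)
- Ollama models (local)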

In [10]:
import subprocess
import sys

In [11]:
def install_packages(packages=None):
    """Install runtime dependencies into the interpreter running this kernel.

    Each package is installed with ``sys.executable -m pip install <package>``,
    so it lands in the same environment the kernel imports from.

    Args:
        packages: Iterable of pip requirement strings. When None, installs the
            notebook's default set: backoff, nest-asyncio, aiohttp, openai,
            transformers, torch and tqdm.

    Returns:
        list: The requirements whose ``pip install`` failed (empty when all
        succeeded). Callers can stop early instead of failing later with an
        ImportError.
    """
    if packages is None:
        packages = [
            'backoff',
            'nest-asyncio',
            'aiohttp',
            'openai',
            'transformers',
            'torch',
            'tqdm'
        ]

    failed = []
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            print(f"Installed {package}")
        except subprocess.CalledProcessError:
            print(f"Failed to install {package}")
            failed.append(package)
    return failed

In [12]:
# Install the notebook's dependencies (one `pip install` per package, via
# install_packages defined above). On a fresh runtime this must run before the
# import cell below.
install_packages()

Installed backoff
Installed nest-asyncio
Installed aiohttp
Installed openai
Installed transformers
Installed torch
Installed tqdm


In [13]:
import json
import os
import time
import asyncio
import aiohttp
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
import logging
from pathlib import Path
import backoff
from tqdm.asyncio import tqdm
import nest_asyncio
import openai

In [14]:
# For Colab compatibility.
# Jupyter/Colab kernels already run an asyncio event loop, and plain
# asyncio.run() refuses to start inside a running loop. nest_asyncio patches the
# loop so that the asyncio.run(main()) call at the end of the notebook can run.
nest_asyncio.apply()

In [15]:
# force=True (Python 3.8+) replaces any handlers already attached to the root
# logger. Notebook hosts such as Colab may pre-install one, which otherwise
# turns basicConfig into a silent no-op and hides every logger.info() call.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)

In [24]:
@dataclass
class CoTConfig:
    """Settings for CoTGenerator.

    ``model_name`` selects the provider by prefix. Names starting with ``gpt-``
    or ``text-davinci`` go to OpenAI; names starting with ``claude`` go to
    Anthropic.

    Raises:
        ValueError: If ``batch_size``, ``concurrent_requests`` or ``max_tokens``
            is less than 1. A zero-sized semaphore would make every request
            wait forever, and a zero batch size cannot be used to slice the
            dataset into batches.
    """
    model_name: str = "gpt-3.5-turbo"
    # When None, CoTGenerator.setup_apis falls back to OPENAI_API_KEY / ANTHROPIC_API_KEY.
    api_key: Optional[str] = None
    # Number of samples submitted together per asyncio.gather batch.
    batch_size: int = 5
    # NOTE(review): the retry decorators in CoTGenerator hard-code max_tries=3;
    # confirm whether these two fields are meant to drive them.
    max_retries: int = 3
    retry_delay: float = 1.0
    # Passed to the provider as the completion length cap / sampling temperature.
    max_tokens: int = 512
    temperature: float = 0.7
    # Size of the semaphore bounding in-flight requests (also the aiohttp connector limit).
    concurrent_requests: int = 3
    # Seconds slept after each successful request while still holding a semaphore slot.
    rate_limit_delay: float = 1.0
    # NOTE(review): presumably reserved for local models; no code path here reads them.
    local_model_path: Optional[str] = None
    use_gpu: bool = True

    openai_base_url: str = "https://api.openai.com/v1"
    anthropic_base_url: str = "https://api.anthropic.com/v1"

    def __post_init__(self):
        """Reject settings that would hang or crash CoTGenerator."""
        for name in ("batch_size", "concurrent_requests", "max_tokens"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")

class CoTGenerator:
    """Generate Chain-of-Thought (CoT) rationales for instruction/answer samples.

    Each input sample is a dict with ``instruction`` and ``output`` keys. The
    model is asked to explain how the answer follows from the question, and the
    text is stored under ``cot_reasoning``. On failure the value is
    ``"Error: ..."``.

    Requests are routed by model-name prefix:

    - ``gpt-`` / ``text-davinci`` go to the OpenAI SDK.
    - ``claude`` goes to the Anthropic Messages API over aiohttp.

    An asyncio semaphore caps the number of in-flight requests.
    """

    # Model-name prefixes used for routing. setup_apis() and
    # generate_single_cot() share these so the two checks cannot drift apart.
    _OPENAI_PREFIXES = ('gpt-', 'text-davinci')
    _ANTHROPIC_PREFIX = 'claude'

    def __init__(self, config: CoTConfig):
        """Store ``config``, create the concurrency semaphore and resolve API keys."""
        self.config = config
        # aiohttp session for Anthropic calls; opened and closed by
        # generate_cot_for_dataset().
        self.session = None
        # One shared AsyncOpenAI client. It is created lazily on the first
        # OpenAI call and closed by generate_cot_for_dataset().
        self._openai_client = None
        self.semaphore = asyncio.Semaphore(config.concurrent_requests)
        self.setup_apis()

    def setup_apis(self):
        """Fill ``config.api_key`` from the provider's environment variable if unset.

        Only logs a warning when no key is found; the API calls will then fail.
        """
        model_name = self.config.model_name
        if model_name.startswith(self._OPENAI_PREFIXES):
            if not self.config.api_key:
                self.config.api_key = os.getenv('OPENAI_API_KEY')
            if not self.config.api_key:
                logger.warning("OpenAI API key not found. Please set it manually.")
            else:
                openai.api_key = self.config.api_key

        elif model_name.startswith(self._ANTHROPIC_PREFIX):
            if not self.config.api_key:
                self.config.api_key = os.getenv('ANTHROPIC_API_KEY')
            if not self.config.api_key:
                logger.warning("Anthropic API key not found.")

    def create_cot_prompt(self, instruction: str, answer: str) -> str:
        """Create prompt for generating CoT reasoning"""
        prompt = f"""You are a medical expert assistant. For the given medical question and answer, provide detailed Chain-of-Thought reasoning that explains the step-by-step thinking process leading to the answer.

Your reasoning should:
1. Break down the problem systematically
2. Identify key medical concepts and symptoms
3. Consider differential diagnoses where applicable
4. Explain the logical progression of thoughts
5. Be clear, educational, and medically accurate

Question: {instruction}

Expected Answer: {answer}

Please provide the Chain-of-Thought reasoning that leads to this answer. Start your response with "Chain-of-Thought:" and then provide the detailed reasoning.

Chain-of-Thought:"""
        return prompt

    def _get_openai_client(self):
        """Return the shared ``openai.AsyncOpenAI`` client, creating it on first use."""
        if self._openai_client is None:
            self._openai_client = openai.AsyncOpenAI(api_key=self.config.api_key)
        return self._openai_client

    @backoff.on_exception(
        backoff.expo,
        # Retry only transient failures. Authentication/validation errors (bad
        # key, unknown model, ...) cannot succeed on retry, so they fail fast.
        (openai.RateLimitError, openai.APIConnectionError, openai.InternalServerError),
        max_tries=3
    )
    async def generate_cot_openai(self, prompt: str) -> str:
        """Send ``prompt`` to the OpenAI chat completions API and return the stripped reply.

        Raises:
            ValueError: If the API returns no message content.
            openai.APIError: Propagated after logging; transient ones are retried first.
        """
        try:
            response = await self._get_openai_client().chat.completions.create(
                model=self.config.model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=self.config.max_tokens,
                temperature=self.config.temperature,
                timeout=30
            )
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise

        # message.content may be None (e.g. a filtered response); calling
        # .strip() on it would raise an opaque AttributeError.
        content = response.choices[0].message.content if response.choices else None
        if not content:
            raise ValueError("OpenAI API returned an empty completion.")
        return content.strip()

    @backoff.on_exception(
        backoff.expo,
        (aiohttp.ClientError, asyncio.TimeoutError),
        max_tries=3
    )
    async def generate_cot_anthropic(self, prompt: str) -> str:
        """Send ``prompt`` to the Anthropic Messages API and return the stripped reply.

        Requires ``self.session`` to be open (generate_cot_for_dataset() does this).

        Raises:
            aiohttp.ClientError: On 429/5xx responses. The decorator retries these.
            RuntimeError: On other non-200 responses (not retried) or when no session is open.
            ValueError: If the response contains no text.
        """
        if self.session is None:
            raise RuntimeError("No aiohttp session is open; use generate_cot_for_dataset().")

        headers = {
            "Content-Type": "application/json",
            "x-api-key": self.config.api_key,
            "anthropic-version": "2023-06-01"
        }

        payload = {
            "model": self.config.model_name,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": [{"role": "user", "content": prompt}]
        }

        async with self.session.post(
            f"{self.config.anthropic_base_url}/messages",
            headers=headers,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=30)
        ) as response:
            if response.status == 200:
                result = await response.json()
                # The reply is a list of content blocks; keep the text ones.
                text = "".join(
                    block.get("text", "")
                    for block in result.get("content", [])
                    if block.get("type") == "text"
                ).strip()
                if not text:
                    raise ValueError("Anthropic API returned an empty completion.")
                return text

            error_text = await response.text()
            message = f"Anthropic API error: {response.status} - {error_text}"
            if response.status == 429 or response.status >= 500:
                # Rate limit / overload / server error: raising ClientError lets
                # the backoff decorator retry.
                raise aiohttp.ClientError(message)
            # Other 4xx responses (bad key, invalid request) will not succeed on retry.
            raise RuntimeError(message)

    async def generate_single_cot(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``sample`` with a ``cot_reasoning`` field added.

        Ordinary errors are not raised. A failure is stored as
        ``"Error: <message>"``, and a sample missing an instruction or answer
        gets ``""``, so one bad sample cannot abort its batch.
        """
        async with self.semaphore:
            try:
                instruction = sample.get("instruction", "")
                answer = sample.get("output", "")

                if not instruction or not answer:
                    logger.warning("Skipping sample with missing instruction or answer.")
                    return {**sample, "cot_reasoning": ""}

                prompt = self.create_cot_prompt(instruction, answer)

                # Route to appropriate model for CoT generation.
                model_name = self.config.model_name
                if model_name.startswith(self._OPENAI_PREFIXES):
                    cot_reasoning = await self.generate_cot_openai(prompt)
                elif model_name.startswith(self._ANTHROPIC_PREFIX):
                    cot_reasoning = await self.generate_cot_anthropic(prompt)
                else:
                    raise ValueError(f"Unsupported model: {self.config.model_name}")

                # The prompt asks the model to start with this lead-in; drop it if echoed.
                lead_in = "Chain-of-Thought:"
                if cot_reasoning.startswith(lead_in):
                    cot_reasoning = cot_reasoning[len(lead_in):].strip()

                # Keep holding the semaphore slot briefly to throttle the request rate.
                await asyncio.sleep(self.config.rate_limit_delay)

                return {**sample, "cot_reasoning": cot_reasoning}

            except Exception as e:
                logger.error(f"Error generating CoT for sample: {e}")
                return {**sample, "cot_reasoning": f"Error: {str(e)}"}

    async def process_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run generate_single_cot concurrently over ``batch``, preserving order.

        Because ``return_exceptions=True`` is used, entries may be exception
        instances instead of dicts.
        """
        tasks = [self.generate_single_cot(sample) for sample in batch]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def generate_cot_for_dataset(self, input_file: str, output_file: str, limit: Optional[int] = None):
        """Generate CoT for the samples in ``input_file`` and write them to ``output_file``.

        An intermediate snapshot is written every 10 batches, and a statistics
        file is written next to ``output_file`` at the end. The aiohttp session
        and the OpenAI client are always closed on exit.

        Args:
            input_file: JSON file containing a list of sample dicts.
            output_file: Destination JSON file for the augmented samples.
            limit: If given, only the first ``limit`` samples are processed.

        Raises:
            ValueError: If the file does not hold a JSON list, or ``limit`` is negative.
        """
        logger.info(f"Loading dataset from {input_file}")

        with open(input_file, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        if not isinstance(dataset, list):
            raise ValueError(
                f"Expected a JSON list of samples in {input_file}, got {type(dataset).__name__}"
            )

        if limit is not None:
            if limit < 0:
                raise ValueError(f"limit must be non-negative, got {limit}")
            dataset = dataset[:limit]
            logger.info(f"Limiting dataset to {len(dataset)} samples")

        logger.info(f"Loaded {len(dataset)} samples")

        batch_size = self.config.batch_size
        batches = [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]

        logger.info(f"Processing {len(batches)} batches of size {batch_size}")

        # Setting up session for API calls.
        connector = aiohttp.TCPConnector(limit=self.config.concurrent_requests)
        timeout = aiohttp.ClientTimeout(total=60)
        self.session = aiohttp.ClientSession(connector=connector, timeout=timeout)

        try:
            results = []

            for i, batch in enumerate(batches):
                print(f"Processing batch {i+1}/{len(batches)}.")

                batch_results = await self.process_batch(batch)

                for result in batch_results:
                    # gather(return_exceptions=True) can also return BaseException
                    # instances such as asyncio.CancelledError. These are not
                    # Exception subclasses and would make json.dump fail below.
                    if isinstance(result, BaseException):
                        logger.error(f"Batch processing error: {result!r}")
                        results.append({"error": str(result) or type(result).__name__})
                    else:
                        results.append(result)

                # Save results every 10 batches.
                if (i + 1) % 10 == 0:
                    self.save_intermediate_results(results, output_file)

                print(f"Completed {len(results)}/{len(dataset)} samples")

            logger.info(f"Saving final results to {output_file}")
            os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, ensure_ascii=False)

            self.generate_statistics(results, input_file, output_file)

        finally:
            if self.session is not None:
                await self.session.close()
                self.session = None
            if self._openai_client is not None:
                await self._openai_client.close()
                self._openai_client = None

    @staticmethod
    def _sibling_path(output_file: str, tag: str) -> str:
        """Return ``output_file`` with ``tag`` inserted before its extension.

        For example, 'data/train.json' becomes 'data/train<tag>.json'. Only the
        file name's own extension is touched, so directory names are never
        rewritten. The result never equals ``output_file``, so the main results
        cannot be overwritten.
        """
        root, ext = os.path.splitext(output_file)
        return f"{root}{tag}{ext or '.json'}"

    def save_intermediate_results(self, results: List[Dict[str, Any]], output_file: str):
        """Write a snapshot of ``results`` to ``<output stem>_intermediate<ext>``."""
        intermediate_file = self._sibling_path(output_file, '_intermediate')
        output_dir = os.path.dirname(intermediate_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(intermediate_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved intermediate results: {len(results)} samples.")

    @staticmethod
    def _is_valid_cot(text: Any) -> bool:
        """Return True for a non-empty CoT string that is not an ``"Error: ..."`` marker."""
        return isinstance(text, str) and bool(text) and not text.startswith("Error:")

    def generate_statistics(self, results: List[Dict[str, Any]], input_file: str, output_file: str):
        """Write generation statistics to ``<output stem>_cot_stats<ext>`` and log the success rate.

        A result counts as successful when its ``cot_reasoning`` is non-empty and
        does not start with ``"Error:"``. Everything else (skipped samples, API
        errors, ``{"error": ...}`` records) counts as failed. ``input_file`` is
        not used; it is kept for interface compatibility.
        """
        valid_cots = [r["cot_reasoning"] for r in results if self._is_valid_cot(r.get("cot_reasoning"))]
        total = len(results)
        successful = len(valid_cots)

        stats = {
            "total_samples": total,
            "successful_generations": successful,
            "failed_generations": total - successful,
            # Mean whitespace-separated word count of the successful CoTs.
            "average_cot_length": (sum(len(cot.split()) for cot in valid_cots) / successful) if successful else 0,
            "model_config": {
                "model_name": self.config.model_name,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                "batch_size": self.config.batch_size
            }
        }

        stats_file = self._sibling_path(output_file, '_cot_stats')
        output_dir = os.path.dirname(stats_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, ensure_ascii=False)

        logger.info(f"CoT generation statistics saved to {stats_file}")
        # Guard against an empty dataset (e.g. limit=0), which would divide by zero.
        rate = successful / total * 100 if total else 0.0
        logger.info(f"Success rate: {successful}/{total} ({rate:.1f}%)")


def setup_colab_environment():
    """Report whether we are inside Google Colab, mounting Google Drive if so.

    Detection is import-based: if ``google.colab`` imports, Google Drive is
    mounted at ``/content/drive``. An ImportError raised anywhere in that
    sequence makes the function report a local environment instead.

    Returns:
        bool: True for Colab, False otherwise.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to detect Colab
        print("Running in Google Colab environment.")

        from google.colab import drive
        print("Mounting Google Drive.")
        drive.mount('/content/drive')
        print("Google Drive successful.")
    except ImportError:
        print("Running in local environment.")
        return False
    return True


def find_train_json():
    """Locate train.json in the usual Colab, Google Drive and local locations.

    Candidates are checked in priority order: Google Drive first, then
    /content, then paths relative to the working directory. The first existing
    path is announced and returned.

    Returns:
        The first matching path, or None when no candidate exists.
    """
    candidates = (
        '/content/drive/MyDrive/train.json',
        '/content/drive/My Drive/train.json',
        '/content/drive/MyDrive/data/train.json',
        '/content/drive/My Drive/data/train.json',
        '/content/train.json',
        './train.json',
        './data/train.json',
    )

    match = next((candidate for candidate in candidates if os.path.exists(candidate)), None)
    if match is not None:
        print(f"Found train.json at: {match}")
    return match


# Function for CoT generation.
async def run_cot_generation():
    """Interactively configure and run CoT generation over a train.json dataset.

    Steps:

    1. Prompt for the model and API key. The key is read with getpass so it is
       not echoed into the saved notebook output.
    2. Prompt for batch size, max tokens, sample limit and output directory.
    3. Preview the dataset and print a rough cost estimate.
    4. On confirmation, run CoTGenerator. In Colab, the results and statistics
       files are then offered for download.

    Returns None in every case; problems are reported with print().
    """
    print("Chain-of-Thought Generation for our Medical Chatbot")
    print("=" * 50)

    in_colab = setup_colab_environment()

    train_file = find_train_json()

    if not train_file:
        print("Train.json file not found.")
        print("Please provide the path to the train.json file:")
        train_file = input("Train file path: ").strip()

        if not os.path.exists(train_file):
            print(f"File not found: {train_file}")
            return

    # Configuration
    print("\n📋 Configuration:")
    print("Available models:")
    print("  1. gpt-3.5-turbo (recommended, cheaper)")
    print("  2. gpt-4 (more expensive, better quality)")
    print("  3. claude-3-sonnet-20240229")

    presets = {"1": "gpt-3.5-turbo", "2": "gpt-4", "3": "claude-3-sonnet-20240229"}
    model_choice = input("Choose model (1/2/3) or enter custom name: ").strip()
    # A preset number, otherwise the custom name typed, otherwise the default.
    model_name = presets.get(model_choice, model_choice or "gpt-3.5-turbo")

    print(f"✓ Selected model: {model_name}")

    # API key. These prefixes match the ones CoTGenerator routes on, so any
    # model that gets a key here is one the generator can actually call.
    if model_name.startswith(('gpt-', 'text-davinci')):
        print("\n🔑 API Key Setup:")
        print("You need an OpenAI API key for GPT models.")
        print("Get one at: https://platform.openai.com/api-keys")

        api_key = _read_secret(
            "Enter your OpenAI API Key (starts with sk-; blank = use $OPENAI_API_KEY): ",
            'OPENAI_API_KEY',
        )
        if not api_key:
            print("❌ Error: OpenAI API key is required for GPT models!")
            return
        if not api_key.startswith('sk-'):
            print("⚠️ Warning: OpenAI API keys usually start with 'sk-'")
            if input("Continue anyway? (y/n): ").strip().lower() != 'y':
                return

    elif model_name.startswith('claude'):
        print("\n🔑 API Key Setup:")
        print("You need an Anthropic API key for Claude models.")
        print("Get one at: https://console.anthropic.com/")

        api_key = _read_secret(
            "Enter your Anthropic API Key (blank = use $ANTHROPIC_API_KEY): ",
            'ANTHROPIC_API_KEY',
        )
        if not api_key:
            print("❌ Error: Anthropic API key is required for Claude models!")
            return

    else:
        # CoTGenerator would fail every sample with "Unsupported model".
        print(f"❌ Unsupported model: {model_name}")
        print("   Use an OpenAI model ('gpt-...') or an Anthropic model ('claude...').")
        return

    print("✅ API key configured successfully!")

    # Batch configuration
    print("\n⚙️ Processing Configuration:")
    batch_size = _prompt_int("Batch size (5 recommended): ", 5, "batch size")
    max_tokens = _prompt_int("Max tokens for CoT (512 recommended): ", 512, "max tokens")
    # None means "process every sample".
    limit = _prompt_int("Number of samples to process (enter for all): ", None, "number of samples")

    # Output file
    output_dir = input("Output directory (./data): ").strip() or "./data"
    output_file = os.path.join(output_dir, "train_with_cot.json")

    # Preview the dataset
    print(f"\n👀 Previewing dataset: {train_file}")
    try:
        with open(train_file, 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        if not isinstance(dataset, list):
            print(f"❌ Expected a JSON list of samples, got {type(dataset).__name__}")
            return

        print(f"📊 Dataset size: {len(dataset)} samples")
        if dataset:
            sample = dataset[0]
            print(f"📝 Sample keys: {list(sample.keys())}")
            print(f"📏 Sample instruction length: {len(sample.get('instruction', '').split())} words")
            print(f"📏 Sample output length: {len(sample.get('output', '').split())} words")
    except Exception as e:
        print(f"❌ Error reading dataset: {e}")
        return

    # Rough cost estimate. A limit larger than the dataset processes only what exists.
    total_samples_to_process = len(dataset) if limit is None else min(limit, len(dataset))
    estimated_tokens_per_sample = 300  # Rough estimate
    total_tokens = total_samples_to_process * estimated_tokens_per_sample

    # Hard-coded USD prices per 1K tokens; these may be out of date.
    cost_per_1k = {"gpt-3.5-turbo": 0.0015, "gpt-4": 0.03}.get(model_name, 0.01)
    estimated_cost = (total_tokens / 1000) * cost_per_1k

    print(f"\n💰 Estimated cost: ${estimated_cost:.2f}")
    # Assumes roughly 2 seconds per sample.
    print(f"⏱️ Estimated time: {(total_samples_to_process * 2) / 60:.1f} minutes")

    confirm = input("\nProceed with CoT generation? (y/n): ").strip().lower()
    if confirm != 'y':
        print("❌ Generation cancelled.")
        return

    config = CoTConfig(
        model_name=model_name,
        api_key=api_key,
        batch_size=batch_size,
        max_tokens=max_tokens,
        temperature=0.7,
        concurrent_requests=3,
        rate_limit_delay=1.0
    )

    generator = CoTGenerator(config)

    try:
        print("\n🚀 Starting CoT generation...")
        await generator.generate_cot_for_dataset(train_file, output_file, limit=limit)

        print("\n🎉 CoT generation completed successfully!")
        print(f"📁 Output file: {output_file}")

        # Download in Colab
        if in_colab:
            try:
                from google.colab import files
                files.download(output_file)

                # Also download the stats file written next to the output file.
                root, ext = os.path.splitext(output_file)
                stats_file = f"{root}_cot_stats{ext or '.json'}"
                if os.path.exists(stats_file):
                    files.download(stats_file)

                print("✅ Files downloaded successfully!")
            except Exception as e:
                print(f"⚠️ Could not download files: {e}")
                print(f"Files are available at: {output_file}")

    except Exception as e:
        print(f"❌ Error during CoT generation: {e}")
        import traceback
        traceback.print_exc()


def _read_secret(prompt, env_var):
    """Read a secret without echoing it; fall back to ``env_var`` when left blank.

    getpass keeps the key out of the notebook's saved output, unlike input().
    Returns None when both the prompt and the environment variable are empty.
    """
    import getpass
    return getpass.getpass(prompt).strip() or os.getenv(env_var)


def _prompt_int(prompt, default, label):
    """Ask until the user enters a positive int; blank input returns ``default``."""
    while True:
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            print(f"❌ Please enter a valid number for {label}!")
            continue
        if value > 0:
            return value
        print(f"❌ {label.capitalize()} must be positive!")

In [25]:
# Main execution
async def main():
    """Async entry point: run the interactive CoT generation workflow."""
    await run_cot_generation()

# For Colab execution
# In an IPython kernel __name__ is "__main__", so this also runs when the cell
# executes. The kernel already has a running event loop, which plain
# asyncio.run() refuses to nest inside; the earlier nest_asyncio.apply() cell
# patches that.
if __name__ == "__main__":
    asyncio.run(main())

Chain-of-Thought Generation for our Medical Chatbot
Running in Google Colab environment.
Mounting Google Drive.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive successful.
Found train.json at: /content/train.json

📋 Configuration:
Available models:
  1. gpt-3.5-turbo (recommended, cheaper)
  2. gpt-4 (more expensive, better quality)
  3. claude-3-sonnet-20240229
Choose model (1/2/3) or enter custom name: 1
✓ Selected model: gpt-3.5-turbo

🔑 API Key Setup:
You need an OpenAI API key for GPT models.
Get one at: https://platform.openai.com/api-keys
Enter your OpenAI API Key (starts with sk-): sk-proj-[REDACTED]
✅ API key configured successfully!

⚙️ Processing Configuration:
Batch size (5 recommended): 5
Max tokens for CoT (512 recommended): 
Number of 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

✅ Files downloaded successfully!
