# CrossProbe Code Transfer

This notebook translates code between PyTorch and TensorFlow, using the API-mapping knowledge built during the *alignment* process.

**Workflow:**
1. Parse source code files
2. Extract API calls with context
3. Find matched APIs from documentation DB
4. Generate target framework code with GPT-4o
5. Save translated code

In [None]:
# Setup and Imports
%pip install openai

import os
import openai
import pandas as pd

# Documentation DB saved by an earlier step of the workflow
DOC_DB = pd.read_csv("api_documentation_db.csv")

# Pick the LLM backend: the hosted OpenAI endpoint when OPENAI_API_KEY is set,
# otherwise a local Ollama server reached through its OpenAI-compatible API.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    # The OpenAI client requires some key; this placeholder is what gets sent
    # to the local server.
    client = openai.OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
    model = "gemma3:27b"
else:
    client = openai.OpenAI(api_key=api_key)
    model = "gpt-4o"

## 2. Code Analysis Utilities

In [2]:
import ast
from typing import List, Optional


class APIExtractor(ast.NodeVisitor):
    """AST visitor that collects dotted framework API calls.

    Every call whose callee is an attribute chain (``a.b.c(...)``) is turned
    into a dotted path; paths that start with one of the framework's module
    roots are appended to ``api_calls`` in visiting order (duplicates kept).

    Args:
        framework: Either a framework display name (``"PyTorch"``,
            ``"TensorFlow"``, matched case-insensitively) or a literal module
            root such as ``"torch"`` or ``"tf"``. Display names are mapped to
            the module roots used in source code; any other value is used
            verbatim as the required path prefix.

    Attributes:
        framework: The ``framework`` argument, stored unchanged.
        api_calls: Dotted paths of the matching calls, as written in the code.
    """

    # Lower-cased display name -> module roots under which that framework's
    # APIs are written in code. Without this mapping a display name such as
    # "PyTorch" would be compared against paths like "torch.nn.Linear" and
    # never match. "tf" is included as the conventional TensorFlow alias.
    _MODULE_ROOTS = {
        "pytorch": ("torch",),
        "tensorflow": ("tensorflow", "tf"),
    }

    def __init__(self, framework: str):
        self.framework = framework
        self.api_calls = []
        roots = self._MODULE_ROOTS.get(framework.lower(), (framework,))
        # Trailing dot so that root "torch" does not match e.g. "torchvision.x".
        self._prefixes = tuple(f"{root}." for root in roots)

    def visit_Call(self, node):
        """Record the call if its dotted callee path belongs to the framework."""
        if isinstance(node.func, ast.Attribute):
            api_path = self._get_full_path(node.func)
            # str.startswith accepts a tuple of alternatives.
            if api_path.startswith(self._prefixes):
                self.api_calls.append(api_path)
        # Keep descending: arguments and chained receivers may contain calls too.
        self.generic_visit(node)

    def _get_full_path(self, node):
        """Return the dotted path for an ``ast.Attribute`` chain.

        Chains rooted in a plain name yield ``name.attr1.attr2``. When the
        innermost value is neither a name nor an attribute (for example the
        result of another call), only the attribute names after it are
        returned.
        """
        if isinstance(node.value, ast.Name):
            return f"{node.value.id}.{node.attr}"
        elif isinstance(node.value, ast.Attribute):
            return f"{self._get_full_path(node.value)}.{node.attr}"
        return node.attr


def extract_apis(code: str, framework: str) -> List[str]:
    """Extract the distinct framework API calls that appear in ``code``.

    Args:
        code: Python source text; it is parsed with ``ast.parse``.
        framework: Framework identifier handed to ``APIExtractor``.

    Returns:
        Each matching dotted API path once, in order of first appearance.
        The order is deterministic so that the prompt built from this list
        is the same on every run (set iteration order for strings can change
        between interpreter sessions).

    Raises:
        SyntaxError: If ``code`` is not valid Python (raised by ``ast.parse``).
    """
    tree = ast.parse(code)
    extractor = APIExtractor(framework)
    extractor.visit(tree)
    # dict preserves insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(extractor.api_calls))

## 3. Documentation Context Builder

In [3]:
def build_context(source_apis: List[str], target_framework: str) -> str:
    """Build the documentation section of the translation prompt.

    For every source API the function looks up its row in ``DOC_DB``, asks
    ``find_top_match`` for the matched API in ``target_framework``, and emits
    one text block holding the first 500 characters of each side's
    ``processed_text``. APIs without documentation or without a match still
    produce a block, with the missing parts left empty.

    Args:
        source_apis: Dotted API paths found in the source code.
        target_framework: Framework name passed through to ``find_top_match``
            and printed in the block header.

    Returns:
        The blocks joined by newlines; an empty string when ``source_apis``
        is empty.
    """

    def lookup(label) -> dict:
        # Always return a dict so the .get() below is valid even when the
        # label is unknown (the previous "" fallback raised AttributeError).
        if label is None or label not in DOC_DB.index:
            return {}
        return DOC_DB.loc[label].to_dict()

    def excerpt(doc: dict) -> str:
        text = doc.get("processed_text", "")
        # Non-string cells (e.g. NaN for an empty CSV field) cannot be sliced.
        return text[:500] if isinstance(text, str) else ""

    context = []

    for api in source_apis:
        source_doc = lookup(api)

        try:
            target_api = find_top_match(api, target_framework)
        except IndexError:
            # Raised when the match lookup indexes the first row of an empty
            # result, i.e. no target API is recorded for this source API.
            target_api = None
        # Compare against None explicitly: a valid index label such as 0 is falsy.
        target_doc = lookup(target_api)
        target_label = target_api if target_api is not None else "(no match found)"

        context.append(
            f"Source API ({'PyTorch' if 'torch' in api else 'TensorFlow'}): {api}\n"
            f"Documentation: {excerpt(source_doc)}\n\n"
            f"Target API ({target_framework}): {target_label}\n"
            f"Documentation: {excerpt(target_doc)}\n"
            "----------------------------------------"
        )

    return "\n".join(context)


def find_top_match(api: str, target_framework: str) -> Optional[str]:
    """Find the target-framework API recorded as the match for ``api``.

    Selects the ``DOC_DB`` rows whose ``framework`` column equals
    ``target_framework`` and whose ``similar_api`` column equals ``api``.

    Args:
        api: Source API path to look up in the ``similar_api`` column.
        target_framework: Value to match in the ``framework`` column.

    Returns:
        The index label of the first matching row, or ``None`` when no row
        matches. NOTE(review): the label is only an API name if ``DOC_DB`` is
        indexed by API name -- confirm against how the CSV is loaded.
    """
    # The lookup is the same regardless of the source framework, so no
    # branching on the API name is needed.
    matches = DOC_DB[
        (DOC_DB.framework == target_framework) & (DOC_DB.similar_api == api)
    ]
    if matches.empty:
        # The caller tests the result for truthiness, so signal "no match"
        # with None instead of letting positional indexing raise IndexError.
        return None
    return matches.index[0]

## 4. GPT-4o Translation Engine
You can replace the endpoint/model with any other OpenAI-compatible API (e.g., Ollama); the setup cell already falls back to a local Ollama server when `OPENAI_API_KEY` is not set.

In [7]:
def _strip_code_fence(reply: Optional[str]) -> str:
    """Return the model reply without a surrounding Markdown code fence.

    A reply counts as fenced when, ignoring leading and trailing whitespace,
    it starts with three backticks (optionally followed by a language tag on
    the same line) and ends with three backticks. Anything else is returned
    unchanged; a missing reply becomes an empty string.
    """
    if reply is None:
        return ""
    text = reply.strip()
    opening_end = text.find("\n")
    is_fenced = (
        text.startswith("```") and text.endswith("```") and opening_end != -1
    )
    if not is_fenced:
        return reply
    # Drop the opening fence line (including any language tag) and the
    # closing three backticks.
    return text[opening_end + 1 : -3].strip()


def translate_code(code: str, source_framework: str, target_framework: str) -> str:
    """Translate code between frameworks using the configured chat model.

    The source APIs are extracted with ``extract_apis``, their documentation
    context is assembled with ``build_context``, and both are sent together
    with the code to the module-level ``client`` / ``model``.

    Args:
        code: Source code to translate.
        source_framework: Framework the code is written in.
        target_framework: Framework to translate to.

    Returns:
        The model's reply with a surrounding Markdown code fence removed, or
        an empty string if the reply carried no text content.
    """
    apis = extract_apis(code, source_framework)
    context = build_context(apis, target_framework)

    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": f"""
                You are an expert AI code translator specializing in {source_framework} to {target_framework} 
                conversions. Use the provided API documentation context to make accurate translations.
                Maintain original functionality and code structure.
            """,
            },
            {
                "role": "user",
                "content": f"""
                Documentation Context:
                {context}

                Source Code to Translate:
                ```python
                {code}
                ```

                Requirements:
                1. Output only valid {target_framework} code
                2. Preserve comments and code structure
                3. Add conversion comments where non-trivial
                4. Include necessary imports
            """,
            },
        ],
        temperature=0.2,
        max_tokens=2000,
    )

    # Replies frequently end with a newline after the closing fence or use a
    # different language tag; an exact "```python" ... "```" comparison would
    # leave the fence in the saved file.
    return _strip_code_fence(response.choices[0].message.content)

## 5. Batch Translation Pipeline

In [8]:
from tqdm import tqdm


def process_directory(source_dir: str, target_dir: str, source_framework: str):
    """Translate every ``.py`` file under ``source_dir`` into ``target_dir``.

    The directory layout below ``source_dir`` is mirrored below
    ``target_dir``. The target framework is the opposite of the source one:
    ``"TensorFlow"`` when ``source_framework`` contains "torch" in any
    letter case, otherwise ``"PyTorch"``.

    Args:
        source_dir: Root directory that is walked recursively.
        target_dir: Root directory for translated files; created if missing.
        source_framework: Name of the framework the sources are written in;
            forwarded to ``translate_code`` and shown in the progress bar.
    """
    os.makedirs(target_dir, exist_ok=True)

    # "PyTorch" is spelled with a capital T, so a case-sensitive search for
    # "torch" fails and would pick PyTorch as its own translation target.
    target_framework = (
        "TensorFlow" if "torch" in source_framework.lower() else "PyTorch"
    )

    # Collect the work list up front: this gives one progress bar over the
    # files that are actually translated, and files written during the run
    # can never be picked up by the walk.
    source_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(source_dir)
        for file in files
        if file.endswith(".py")
    )

    for source_path in tqdm(source_paths, desc="Transferring from " + source_framework):
        rel_path = os.path.relpath(source_path, source_dir)
        target_path = os.path.join(target_dir, rel_path)

        # Python source is UTF-8 by default; do not depend on the platform's
        # locale encoding for reading or writing it.
        with open(source_path, "r", encoding="utf-8") as f:
            code = f.read()

        translated = translate_code(code, source_framework, target_framework)

        os.makedirs(os.path.dirname(target_path), exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as f:
            f.write(translated)


# Example usage: run the batch pipeline over both sample directories.
# Arguments are (source_dir, target_dir, source_framework).
process_directory("pytorch", "tensorflow-test", "PyTorch")
process_directory("tensorflow", "pytorch-test", "TensorFlow")

Transferring from PyTorch: 100%|██████████| 18/18 [03:18<00:00, 11.00s/it]
Transferring from TensorFlow: 100%|██████████| 8/8 [02:12<00:00, 16.59s/it]
