In [44]:
import numpy as np
import aisuite as ai
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# from classification.classification_models.vit import ClassificationModel, ModelConfig
# from classification.classification_models.vit import ImageWoofDataset
# from classification.classification_metrics.metrics import ClassificationMetrics
from pathlib import Path
import albumentations as A
from tqdm import tqdm
from albumentations.pytorch import ToTensorV2
import json
import re
from typing import Optional, Tuple

In [54]:
class CodeGenerationAgent:
    def __init__(self, model_name: str, file_type: str = "python"):
        self.client = ai.Client()
        self.model_name = model_name
        self.installed_packages = [line.strip() for line in open("../requirements.txt", "r").readlines()]
        self.system_prompt = f"""
        You are an expert computer vision engineer.
        You are given a user prompt and a file type: {file_type}.
        The file packages you have access to are as follows: 
        {self.installed_packages}
        You need to generate a code based on the instruction, and the packages you have access to.
        The code should be enclosed in <improvement_code> and </improvement_code> tags
        """

    def generate_code(self, improvement_prompt: str, project_context: str) -> str: 
        """
        Generate code on the improvement prompt.
        """
        prompt = f"""
        You are tasked to write a code to implement the following improvement: {improvement_prompt}.
        This code is expected to be used in PyTorch / Lightning collate_fn to process the batch.
        The code should be enclosed in <improvement_code> and </improvement_code> tags.
        
        The project context is as follows:
        {project_context}

        Please provide:
        1. Complete, runnable code with proper imports.
        2. Clear documentation and type hints.
        3. Integration instructions.

        <improvement_code> YOUR CODE HERE </improvement_code>
        <integration_instructions> YOUR INTEGRATION INSTRUCTIONS HERE </integration_instructions>
        """
        file_name = improvement_prompt.replace(" ", "_").lower()
        path = Path(f"../classification/improvements/")
        path.mkdir(parents=True, exist_ok=True)
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt}
        ]
        results = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=0.0,
            max_tokens=4000,
        )

        success, message = self.save_generated_code(results.choices[0].message.content, file_name, path)

    def parse_code(self, content: str, tag: str) -> Optional[str]:
        """
        Parse code from XML-like tags with better error handling
        
        Args:
            content: Full LLM response
            tag: Tag name (e.g., 'improvement_code', 'integration_instructions')
        
        Returns:
            Extracted code or None if not found
        """
        # Use DOTALL flag to match across newlines
        pattern = f"<{tag}>(.*?)</{tag}>"
        match = re.search(pattern, content, re.DOTALL)
        
        if match:
            return match.group(1).strip()
        
        # Fallback: try to find markdown code blocks
        code_block_pattern = r"\n(.*?)\n```"
        match = re.search(code_block_pattern, content, re.DOTALL)
        
        if match:
            return match.group(1).strip()
        
        return None

    def save_generated_code(
        self,
        response_content: str, 
        file_name: str, 
        path: str = ".",
        save_instructions: bool = True
    ) -> Tuple[bool, str]:
        """
        Save generated code and optionally integration instructions
        
        Returns:
            (success, message)
        """
        # Parse code
        improvement_code = self.parse_code(response_content, "improvement_code")
        
        if not improvement_code:
            return False, "‚ùå Failed to parse improvement code from response"
        
        # Create directory if it doesn't exist
        Path(path).mkdir(parents=True, exist_ok=True)
        
        # Save the main code file
        code_file = Path(path) / f"{file_name}.py"
        try:
            with open(code_file, "w") as f:
                f.write(improvement_code)
            print(f"‚úÖ Code saved to: {code_file}")
        except Exception as e:
            return False, f"‚ùå Failed to save code: {e}"
        
        # Save integration instructions if present
        if save_instructions:
            integration_instructions = self.parse_code(response_content, "integration_instructions")
            
            if integration_instructions:
                instructions_file = Path(path) / f"{file_name}_INSTRUCTIONS.md"
                try:
                    with open(instructions_file, "w") as f:
                        f.write(integration_instructions)
                    print(f"üìÑ Instructions saved to: {instructions_file}")
                except Exception as e:
                    print(f"‚ö†Ô∏è Warning: Failed to save instructions: {e}")
        
        # Also save the raw response for debugging
        raw_file = Path(path) / f"{file_name}.py"
        try:
            with open(raw_file, "w") as f:
                f.write(response_content)
            print(f"üíæ Raw response saved to: {raw_file}")
        except Exception as e:
            print(f"‚ö†Ô∏è Warning: Failed to save raw response: {e}")
        
        return True, f"Successfully saved {file_name}.py"










In [55]:
# Test
import re
improvement_string = f"""
<improvement>Mixup</improvement>
<improvement>CutMix</improvement>
<improvement>RandAugment</improvement>
<improvement>AutoAugment</improvement>
<improvement>Canny edge detection</improvement>
<improvement>Sobel filter concatenation</improvement>
<improvement>Gabor filter concatenation</improvement>
<improvement>Histogram equalization</improvement>
<improvement>Test time augmentation</improvement>
<improvement>ImageNet pretraining</improvement>
<improvement>Local binary patterns</improvement>
<improvement>Laplacian filter concatenation</improvement>
<improvement>Color jitter</improvement>
<improvement>Erasing</improvement>
<improvement>HOG feature concatenation</improvement>

"""
improvement_list = re.findall(r'<improvement>(.*?)</improvement>', improvement_string)
print(improvement_list)


['Mixup', 'CutMix', 'RandAugment', 'AutoAugment', 'Canny edge detection', 'Sobel filter concatenation', 'Gabor filter concatenation', 'Histogram equalization', 'Test time augmentation', 'ImageNet pretraining', 'Local binary patterns', 'Laplacian filter concatenation', 'Color jitter', 'Erasing', 'HOG feature concatenation']


# Unit test

In [56]:
coding_agent = CodeGenerationAgent(model_name="ollama:gemini-3-flash-preview", file_type="python")
results = coding_agent.generate_code(improvement_prompt="Mixup", project_context="")

‚úÖ Code saved to: ../classification/improvements/mixup.py
üìÑ Instructions saved to: ../classification/improvements/mixup_INSTRUCTIONS.md
üíæ Raw response saved to: ../classification/improvements/mixup.py


In [24]:
path = Path(f"../code_generation_agent/classification/improvements/mixup.py")
path.mkdir(parents=True, exist_ok=True)