In [17]:
from langgraph.graph import StateGraph, START, END
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate
import os
from dotenv import load_dotenv
import json
import re
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Union, Literal
from pydantic import BaseModel, Field, validator

load_dotenv()

def sanitize_ascii(s: str) -> str:
    # Remove any non-ASCII characters from the string
    return ''.join(c for c in s if ord(c) < 128)

api_key = os.environ["GROQ_API_KEY"]
model = "llama3-70b-8192"

sanitized_api_key = sanitize_ascii(api_key)
llm = ChatGroq(api_key=sanitized_api_key, model=model)

class TerraformFile(BaseModel):
    path: str
    content: str
    
class TerraformComponent(BaseModel):
    name: str = Field(..., description="The name of the component.")
    main_tf: str = Field(..., description="The main.tf file content.")
    output_tf: str = Field(..., description="The output.tf file content.")
    variables_tf: str = Field(..., description="The variables.tf file content.")
    
class EnvironmentList(BaseModel):
    environments: List[TerraformComponent] = []

class ModuleList(BaseModel):
    modules: List[TerraformComponent] = []


class UserInput(BaseModel):
    """User input for the Terraform code generation agent."""
    # Core infrastructure settings
    services: List[str] = Field(
        ..., 
        description="List of AWS services to deploy (e.g., ['ec2', 's3', 'rds', 'lambda'])."
    )
    region: str = Field(
        ..., 
        description="AWS region where services will be deployed (e.g., 'us-west-2')."
    )
    
    # Networking configuration
    vpc_cidr: str = Field(
        ..., 
        description="CIDR block for the VPC (e.g., '10.0.0.0/16')."
    )
    subnet_configuration: Dict[str, List[str]] = Field(
        default_factory=lambda: {"public": [], "private": [], "database": []},
        description="CIDR blocks for subnets by type (public, private, database)."
    )
    availability_zones: List[str] = Field(
        ...,
        description="List of availability zones to use (e.g., ['us-west-2a', 'us-west-2b'])."
    )
    
    # Compute settings
    compute_type: str = Field(
        ..., 
        description="Type of compute to use (e.g., 'ec2', 'ecs', 'lambda')."
    )
    compute_instance_type: Optional[str] = Field(
        None, 
        description="EC2/RDS instance type if applicable (e.g., 't3.micro')."
    )
    
    # Database settings
    database_type: Optional[str] = Field(
        None, 
        description="Type of database to use if needed (e.g., 'mysql', 'postgres', 'dynamodb')."
    )
    database_settings: Optional[Dict[str, Union[str, int, bool]]] = Field(
        None,
        description="Additional database settings like version, size, etc."
    )
    
    # Architecture choices
    is_multi_az: bool = Field(
        ..., 
        description="Whether to deploy across multiple availability zones for high availability."
    )
    is_serverless: bool = Field(
        ..., 
        description="Whether to use serverless architecture where applicable."
    )
    
    # Additional configuration
    enable_logging: bool = Field(
        True, 
        description="Whether to enable CloudWatch logging for services."
    )
    enable_monitoring: bool = Field(
        True, 
        description="Whether to enable CloudWatch monitoring for services."
    )
    load_balancer_type: Optional[Literal["ALB", "NLB", "CLB"]] = Field(
        None,
        description="Type of load balancer to deploy if needed."
    )
    
    # Security settings
    enable_waf: bool = Field(
        False, 
        description="Whether to enable AWS WAF for web applications."
    )
    security_groups: Dict[str, List[Dict[str, Union[str, int, List[str]]]]] = Field(
        default_factory=dict,
        description="Security group rules by group name."
    )
    
    # Tagging and metadata
    tags: Dict[str, str] = Field(
        default_factory=lambda: {
            "Environment": "dev",
            "ManagedBy": "Terraform",
            "Owner": "DevOps"
        },
        description="Resource tags."
    )
    
    # Free-form requirements
    requirements: str = Field(
        ..., 
        description="Additional requirements in natural language."
    )
    
    # Advanced configuration
    custom_parameters: Dict[str, Union[str, int, bool, List, Dict]] = Field(
        default_factory=dict,
        description="Additional custom parameters for advanced configurations."
    )
    
    # Validators
    @validator('vpc_cidr')
    def validate_cidr(cls, v):
        import ipaddress
        try:
            ipaddress.IPv4Network(v)
            return v
        except ValueError:
            raise ValueError(f"Invalid CIDR block format: {v}")
    
    @validator('region')
    def validate_region(cls, v):
        valid_regions = [
            'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2',
            'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-central-1',
            'ap-northeast-1', 'ap-northeast-2', 'ap-southeast-1', 'ap-southeast-2',
            # Add more valid regions as needed
        ]
        if v not in valid_regions:
            raise ValueError(f"Invalid AWS region: {v}. Must be one of {valid_regions}")
        return v
    

class TerraformState(BaseModel):
    """State for our Terraform code generation agent."""
    modules: ModuleList = Field(default_factory=ModuleList)
    environments: EnvironmentList = Field(default_factory=EnvironmentList)
    user_input: Optional[UserInput] = None

# Prompt template
terraform_template = """
You are a senior AWS Solutions Architect and Terraform Expert responsible for creating enterprise-grade, 
production-ready infrastructure as code according to AWS best practices and the Well-Architected Framework.

Requirements: {requirements}

INFRASTRUCTURE SPECIFICATIONS:
- AWS Region: {region}
- VPC CIDR: {vpc_cidr}
- Subnet Configuration: {subnet_configuration}
- Availability Zones: {availability_zones}
- Services to deploy: {services}
- Compute type: {compute_type}
- Compute instance type: {compute_instance_type}
- Database type: {database_type}
- Database settings: {database_settings}
- Multi-AZ deployment: {is_multi_az}
- Serverless architecture: {is_serverless}
- Load balancer type: {load_balancer_type}
- Logging enabled: {enable_logging}
- Monitoring enabled: {enable_monitoring}
- WAF enabled: {enable_waf}
- Security groups: {security_groups}
- Resource tags: {tags}
- Additional parameters: {custom_parameters}

INSTRUCTIONS:
1. For each service (e.g., VPC, EC2, RDS, S3, DynamoDB, Lambda, etc.) create a dedicated module following Terraform best practices.
2. Each module should be self-contained, reusable, and include appropriate variables, outputs, and resources.
3. Create three different environment configurations (dev, stage, prod) that reference these modules.
4. Follow the principle of "infrastructure as code" where configuration is parameterized and DRY.
5. Ensure proper dependencies are established between modules (e.g., compute depends on network).

TERRAFORM ARCHITECTURE BEST PRACTICES:
- Use a modular approach where each AWS service or logical component is a separate module
- Follow Terraform recommended naming conventions and file structure
- Implement proper state management
- Use locals for computed values and code readability
- Use data sources for dynamic lookups
- Implement proper error handling with lifecycle blocks
- Ensure all resources have appropriate tags
- Use count or for_each for resource repetition

MODULE STRUCTURE:
For each AWS service or component requested in the requirements, create a dedicated module with:
- main.tf (containing resources)
- variables.tf (parameters with appropriate type constraints and validations)
- outputs.tf (useful outputs needed by other modules)

RESPONSE FORMAT INSTRUCTIONS:
Return ONLY a JSON object with this structure:

{{
  "environments": [
    {{
      "name": "dev",
      "main_tf": "# Terraform code for dev environment",
      "output_tf": "# Output variables for dev environment",
      "variables_tf": "# Input variables for dev environment"
    }},
    {{
      "name": "stage",
      "main_tf": "# Terraform code for stage environment",
      "output_tf": "# Output variables for stage environment",
      "variables_tf": "# Input variables for stage environment"
    }},
    {{
      "name": "prod",
      "main_tf": "# Terraform code for prod environment",
      "output_tf": "# Output variables for prod environment",
      "variables_tf": "# Input variables for prod environment"
    }}
  ],
  "modules": [
    /* DYNAMIC: Create a module entry for EACH AWS service or component needed */
    /* Example format for each module: */
    {{
      "name": "service-name-module",
      "main_tf": "# Terraform resources",
      "output_tf": "# Output variables",
      "variables_tf": "# Input variables"
    }}
    /* ADD ALL REQUIRED MODULES based on the services in the requirements */
  ]
}}

IMPORTANT: 
1. The modules you create should be based on the services requested in the requirements and services list.
2. Example modules: vpc, security-groups, iam, ec2, asg, alb, rds, s3, dynamodb, lambda, api-gateway, etc.
3. Create all necessary modules, not just the ones mentioned as examples.
4. Your response must be syntactically valid JSON.
"""

def process_request(state: TerraformState):
    """Process the user's input and update the state."""
    # state.user_input = user_input
    print("Processing user input...")
    print(state.user_input)
    return state


def extract_json_from_text(text):
    """Extract JSON from text using multiple methods."""
    # Try to find JSON in code blocks
    json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
    if json_match:
        return json_match.group(1).strip()
    
    # Try to find JSON between curly braces
    json_match = re.search(r'(\{[\s\S]*\})', text)
    if json_match:
        return json_match.group(1).strip()
    
    # Return the whole text as a last resort
    return text.strip()

def generate_terraform_code(state: TerraformState):
    """Generate Terraform code based on the structured user input."""
    if not state.user_input:
        raise ValueError("User input is required to generate Terraform code")
    
    # Create variables dictionary for prompt template
    variables = {
        "requirements": state.user_input.requirements,
        "region": state.user_input.region,
        "vpc_cidr": state.user_input.vpc_cidr,
        "subnet_configuration": json.dumps(state.user_input.subnet_configuration),
        "availability_zones": ", ".join(state.user_input.availability_zones),
        "services": ", ".join(state.user_input.services),
        "compute_type": state.user_input.compute_type,
        "compute_instance_type": state.user_input.compute_instance_type or "Not specified",
        "database_type": state.user_input.database_type or "None",
        "database_settings": json.dumps(state.user_input.database_settings) if state.user_input.database_settings else "None",
        "is_multi_az": str(state.user_input.is_multi_az),
        "is_serverless": str(state.user_input.is_serverless),
        "load_balancer_type": state.user_input.load_balancer_type or "None",
        "enable_logging": str(state.user_input.enable_logging),
        "enable_monitoring": str(state.user_input.enable_monitoring),
        "enable_waf": str(state.user_input.enable_waf),
        "security_groups": json.dumps(state.user_input.security_groups),
        "tags": json.dumps(state.user_input.tags),
        "custom_parameters": json.dumps(state.user_input.custom_parameters)
    }
    
    prompt = PromptTemplate.from_template(terraform_template)
    
    chain = prompt | llm
    response = chain.invoke(variables)
    
    try:
        # Extract JSON from the response
        json_str = extract_json_from_text(response.content)
        
        # Parse the JSON response
        data = json.loads(json_str)
        
        # Update environments
        for env_data in data.get("environments", []):
            component = TerraformComponent(
                name=env_data.get("name", ""),
                main_tf=env_data.get("main_tf", ""),
                output_tf=env_data.get("output_tf", ""),
                variables_tf=env_data.get("variables_tf", "")
            )
            state.environments.environments.append(component)
        
        # Update modules
        for module_data in data.get("modules", []):
            component = TerraformComponent(
                name=module_data.get("name", ""),
                main_tf=module_data.get("main_tf", ""),
                output_tf=module_data.get("output_tf", ""),
                variables_tf=module_data.get("variables_tf", "")
            )
            state.modules.modules.append(component)
            
    except Exception as e:
        print(f"Error parsing LLM response: {e}")
        print(f"Response content: {response.content}")
    
    return state

# Function to save generated Terraform files
def save_terraform_files(state: TerraformState):
    """Save the generated Terraform files to disk."""
    
    base_dir = "output/src"
     
    os.makedirs(base_dir, exist_ok=True)
    
    for env in state.environments.environments:
        
        env_dir = os.path.join(base_dir, "environments", env.name)
        os.makedirs(env_dir, exist_ok=True)
        
        with open(os.path.join(env_dir, "main.tf"), "w") as f:
            f.write(env.main_tf)
        
        with open(os.path.join(env_dir, "output.tf"), "w") as f:
            f.write(env.output_tf)
        
        with open(os.path.join(env_dir, "variables.tf"), "w") as f:
            f.write(env.variables_tf)
    
    for module in state.modules.modules:
       
        module_dir = os.path.join(base_dir, "modules", module.name)
        os.makedirs(module_dir, exist_ok=True)
        
        with open(os.path.join(module_dir, "main.tf"), "w") as f:
            f.write(module.main_tf)
        
        with open(os.path.join(module_dir, "output.tf"), "w") as f:
            f.write(module.output_tf)
        
        with open(os.path.join(module_dir, "variables.tf"), "w") as f:
            f.write(module.variables_tf)
    
    print(f"Terraform files have been saved to {base_dir}")
    
    return state

# Define the graph
graph = StateGraph(TerraformState)

# Add nodes
graph.add_node("process_request", process_request)
graph.add_node("generate_terraform_code", generate_terraform_code)
graph.add_node("save_terraform_files", save_terraform_files)

# Add edges
graph.add_edge(START, "process_request")
graph.add_edge("process_request", "generate_terraform_code")
graph.add_edge("generate_terraform_code", "save_terraform_files")
graph.add_edge("save_terraform_files", END)

# Compile the graph
terraform_app = graph.compile()


/var/folders/07/8j6bcwpn5_qfb0_tjmt2qg100000gn/T/ipykernel_13490/2367243744.py:145: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  @validator('vpc_cidr')
/var/folders/07/8j6bcwpn5_qfb0_tjmt2qg100000gn/T/ipykernel_13490/2367243744.py:154: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  @validator('region')


In [18]:
def generate_terraform(
    services: List[str],
    region: str,
    vpc_cidr: str,
    availability_zones: List[str],
    compute_type: str,
    compute_instance_type: Optional[str] = None,
    database_type: Optional[str] = None,
    database_settings: Optional[Dict[str, Union[str, int, bool]]] = None,
    subnet_configuration: Optional[Dict[str, List[str]]] = None,
    is_multi_az: bool = True,
    is_serverless: bool = False,
    enable_logging: bool = True,
    enable_monitoring: bool = True,
    enable_waf: bool = False,
    load_balancer_type: Optional[str] = None,
    security_groups: Optional[Dict[str, List[Dict[str, Union[str, int, List[str]]]]]] = None,
    tags: Optional[Dict[str, str]] = None,
    custom_parameters: Optional[Dict] = None,
    requirements: str = ""
):
    """
    Generate Terraform code based on user input parameters.
    
    Args:
        services: List of AWS services to deploy
        region: AWS region
        vpc_cidr: CIDR block for VPC
        availability_zones: List of availability zones
        compute_type: Type of compute (ec2, ecs, lambda)
        compute_instance_type: EC2/ECS instance type if applicable
        database_type: Type of database if needed
        database_settings: Additional database configuration
        subnet_configuration: CIDR blocks for different subnet types
        is_multi_az: Whether to use multiple AZs
        is_serverless: Whether to use serverless architecture
        enable_logging: Whether to enable logging
        enable_monitoring: Whether to enable monitoring
        enable_waf: Whether to enable AWS WAF
        load_balancer_type: Type of load balancer
        security_groups: Security group rules configuration
        tags: Resource tags
        custom_parameters: Additional configuration parameters
        requirements: Free-form requirements
        
    Returns:
        Generated Terraform code as a TerraformState object
    """
    if subnet_configuration is None:
        subnet_configuration = {
            "public": [],
            "private": [],
            "database": []
        }
    
    if security_groups is None:
        security_groups = {}
    
    if tags is None:
        tags = {"Environment": "dev", "ManagedBy": "Terraform", "Owner": "DevOps"}
    
    if custom_parameters is None:
        custom_parameters = {}
    
    user_input = UserInput(
        services=services,
        region=region,
        vpc_cidr=vpc_cidr,
        subnet_configuration=subnet_configuration,
        availability_zones=availability_zones,
        compute_type=compute_type,
        compute_instance_type=compute_instance_type,
        database_type=database_type,
        database_settings=database_settings,
        is_multi_az=is_multi_az,
        is_serverless=is_serverless,
        enable_logging=enable_logging,
        enable_monitoring=enable_monitoring,
        enable_waf=enable_waf,
        load_balancer_type=load_balancer_type,
        security_groups=security_groups,
        tags=tags,
        custom_parameters=custom_parameters,
        requirements=requirements
    )
    
    result = terraform_app.invoke({"user_input": user_input})
    return result

In [20]:
# Example usage with comprehensive structured input
result = generate_terraform(
    services=["ec2", "rds", "alb"],
    region="us-west-2",
    vpc_cidr="10.0.0.0/16",
    subnet_configuration={
        "public": ["10.0.1.0/24", "10.0.2.0/24"],
        "private": ["10.0.3.0/24", "10.0.4.0/24"],
        "database": ["10.0.5.0/24", "10.0.6.0/24"]
    },
    availability_zones=["us-west-2a", "us-west-2b"],
    compute_type="ec2",
    compute_instance_type="t3.medium",
    database_type="postgres",
    database_settings={
        "version": "13.4",
        "instance_type": "db.t3.medium",
        "storage_gb": 20,
        "multi_az": True
    },
    is_multi_az=True,
    is_serverless=False,
    enable_logging=True,
    enable_monitoring=True,
    enable_waf=True,
    load_balancer_type="ALB",
    security_groups={
        "web_sg": [
            {
                "type": "ingress",
                "from_port": 443,
                "to_port": 443,
                "protocol": "tcp",
                "cidr_blocks": ["0.0.0.0/0"]
            },
            {
                "type": "ingress",
                "from_port": 80,
                "to_port": 80,
                "protocol": "tcp",
                "cidr_blocks": ["0.0.0.0/0"]
            }
        ],
        "app_sg": [
            {
                "type": "ingress",
                "from_port": 8080,
                "to_port": 8080,
                "protocol": "tcp",
                "source_security_group": "web_sg"
            }
        ]
    },
    tags={
        "Environment": "dev",
        "Project": "WebApp",
        "Owner": "DevOps",
        "CostCenter": "IT-123"
    },
    custom_parameters={
        "enable_auto_scaling": True,
        "min_capacity": 2,
        "max_capacity": 10,
        "desired_capacity": 2,
        "backup_retention_period": 7
    },
    requirements="Create a highly available web application with a PostgreSQL database. Include proper security groups and implement auto-scaling for the EC2 instances."
)



Processing user input...
services=['ec2', 'rds', 'alb'] region='us-west-2' vpc_cidr='10.0.0.0/16' subnet_configuration={'public': ['10.0.1.0/24', '10.0.2.0/24'], 'private': ['10.0.3.0/24', '10.0.4.0/24'], 'database': ['10.0.5.0/24', '10.0.6.0/24']} availability_zones=['us-west-2a', 'us-west-2b'] compute_type='ec2' compute_instance_type='t3.medium' database_type='postgres' database_settings={'version': '13.4', 'instance_type': 'db.t3.medium', 'storage_gb': 20, 'multi_az': True} is_multi_az=True is_serverless=False enable_logging=True enable_monitoring=True load_balancer_type='ALB' enable_waf=True security_groups={'web_sg': [{'type': 'ingress', 'from_port': 443, 'to_port': 443, 'protocol': 'tcp', 'cidr_blocks': ['0.0.0.0/0']}, {'type': 'ingress', 'from_port': 80, 'to_port': 80, 'protocol': 'tcp', 'cidr_blocks': ['0.0.0.0/0']}], 'app_sg': [{'type': 'ingress', 'from_port': 8080, 'to_port': 8080, 'protocol': 'tcp', 'source_security_group': 'web_sg'}]} tags={'Environment': 'dev', 'Project': 