In [27]:
from langgraph.graph import StateGraph,START, END
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate

import os
from dotenv import load_dotenv
load_dotenv()

True

In [28]:
def sanitize_ascii(s: str) -> str:
    # Remove any non-ASCII characters from the string
    return ''.join(c for c in s if ord(c) < 128)

In [29]:
# api_key = os.environ["GROQ_API_KEY"]
api_key = os.environ["GEMINI_API_KEY"]
# api_key = os.environ["MISTRAL_API_KEY"]
# api_key = os.environ["OPENAI_API_KEY"]

# model = "llama3-70b-8192"
model = "gemini-2.0-flash"
# model = "codestral-latest"
# model = "gpt-4.1-nano"

sanitized_api_key = sanitize_ascii(api_key)
# llm = ChatGroq(api_key=sanitized_api_key,model=model)
llm = ChatGoogleGenerativeAI(api_key=api_key,model=model)
# llm = ChatMistralAI(api_key=api_key,model=model)
# llm = ChatOpenAI(api_key=api_key,model=model)



In [30]:
result=llm.invoke("Hello")
result

AIMessage(content='Hello! How can I help you today?', additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'model_name': 'gemini-2.0-flash', 'safety_ratings': []}, id='run-3143183c-dd59-4644-a0e0-91a9192d1b71-0', usage_metadata={'input_tokens': 1, 'output_tokens': 10, 'total_tokens': 11, 'input_token_details': {'cache_read': 0}})

In [31]:
from pydantic import BaseModel, Field
from typing import Any, Optional, List, Dict

class TerraformFile(BaseModel):
    path: str
    content: str
    
class TerraformComponent(BaseModel):
    
    """Represents a component in Terraform.
    
    Attributes:
        name: The name of the component.
        main_tf: The main.tf file content.
        output_tf: The output.tf file content.
        variables_tf: The variables.tf file content.
        
    """
    
    name: str = Field(..., description="The name of the component.")
    main_tf: str = Field(..., description="The main.tf file content.")
    output_tf: str = Field(..., description="The output.tf file content.")
    variables_tf: str = Field(..., description="The variables.tf file content.")
    
class EnvironmentList(BaseModel):
    environments: List[TerraformComponent] = Field(default_factory=list)

class ModuleList(BaseModel):
    modules: List[TerraformComponent] = Field(default_factory=list)

class TerraformState(BaseModel):
    """State for our Terraform code generation agent."""
    modules: ModuleList = Field(default_factory=ModuleList)
    environments: EnvironmentList = Field(default_factory=EnvironmentList)
    user_requirements: str


In [32]:
# Define the template
terraform_template = """
SYSTEM:
You are Terraform expert, an AI agent that generates Terraform code for AWS Cloud infrastructure based on the user's requirements:

Make sure to separte the environments in below formats for dev, stage and prod:
---
**Expected Output Format (for each environment):**

name: [Name of the environment, either "dev", "stage", or "prod"],
main.tf: [Main Terraform code for the environment],
output.tf: [Output Terraform code for the environment],
variables.tf: [Variables Terraform code for the environment]


Make sure to separte the modules in below formats:
---
**Expected Output Format (for each module):**

name: [Name of the module like vpc, security-group, etc],
main.tf: [Main Terraform code for the module],
output.tf: [Output Terraform code for the module],
variables.tf: [Variables Terraform code for the module]

Constraints:
1. Output must be valid and complete.
2. Provide actual code only.
3. Do not include any extra commentary or markdown formatting.

User requirements: {requirements}
"""

In [33]:
terraform_template1 = """
SYSTEM:
You are Terraform expert, an AI agent that generates Terraform code for AWS Cloud infrastructure. Based on the user's requirements provided below, generate a JSON object that matches the following schema exactly. Do not include any extra text, markdown formatting, or function call wrappers. The JSON object must have exactly these keys:

{{
  "environments": [
      {{
          "name": string,           // Must be either "dev", "stage", or "prod"
          "main.tf": string,        // Terraform code for main.tf
          "output.tf": string,      // Terraform code for output.tf
          "variables.tf": string    // Terraform code for variables.tf
      }}
  ],
  "modules": [
      {{
          "name": string,           // Example: "vpc", "security-group", etc.
          "main.tf": string,        // Terraform code for main.tf
          "output.tf": string,      // Terraform code for output.tf
          "variables.tf": string    // Terraform code for variables.tf
      }}
  ],
  "user_requirements": string    // This should echo the input user requirements
}}

Constraints:
1. Output must be valid JSON that follows the schema above.
2. Provide actual Terraform code only.
3. Do not include any extra commentary or markdown formatting.

User requirements: {requirements}
"""

In [34]:
import json

def process_request(state: TerraformState):
    """Process the user's request and update the state."""
    return state

def generate_terraform_code(state: TerraformState):
    """Generate Terraform code based on requirements in state."""

    prompt = PromptTemplate.from_template(terraform_template)
    llm_with_structured_output = llm.with_structured_output(TerraformState)
    
    chain = prompt | llm_with_structured_output
    response = chain.invoke({"requirements": state.user_requirements})

    state.environments = response.environments
    state.modules = response.modules
    return state

# Define the graph
graph = StateGraph(TerraformState)

# Add nodes
graph.add_node("process_request", process_request)
graph.add_node("generate_terraform_code", generate_terraform_code)


# Add edges
graph.add_edge(START, "process_request")
graph.add_edge("process_request", "generate_terraform_code")
graph.add_edge("generate_terraform_code", END)

# Compile the graph
terraform_app = graph.compile()

In [35]:
result = terraform_app.invoke({
    "user_requirements": "Create a VPC with public and private subnets in AWS with appropriate security groups"
})
print(result)

# Save the files to disk
# save_terraform_files(result, "output")

{'modules': ModuleList(modules=[]), 'environments': EnvironmentList(environments=[TerraformComponent(name='dev', main_tf='resource "aws_vpc" "main" {\n  cidr_block = var.vpc_cidr\n  tags = {\n    Name = "main-vpc"\n  }\n}\n\nresource "aws_subnet" "public" {\n  count = length(var.public_subnet_cidrs)\n  vpc_id     = aws_vpc.main.id\n  cidr_block = var.public_subnet_cidrs[count.index]\n  map_public_ip_on_launch = true\n  availability_zone = data.aws_availability_zones.available.names[count.index]\n\n  tags = {\n    Name = "public-subnet-${count.index + 1}"\n  }\n}\n\nresource "aws_subnet" "private" {\n  count = length(var.private_subnet_cidrs)\n  vpc_id     = aws_vpc.main.id\n  cidr_block = var.private_subnet_cidrs[count.index]\n  availability_zone = data.aws_availability_zones.available.names[count.index]\n\n  tags = {\n    Name = "private-subnet-${count.index + 1}"\n  }\n}\n\nresource "aws_internet_gateway" "gw" {\n  vpc_id = aws_vpc.main.id\n\n  tags = {\n    Name = "main-igw"\n  }\n\

In [36]:
from langchain_core.runnables.graph import MermaidDrawMethod

img_data = terraform_app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API
            )

# Save the image to a file
graph_path = "workflow_graph.png"
with open(graph_path, "wb") as f:
    f.write(img_data) 