In [1]:
from google.colab import drive
import os
# Mount Google Drive (Colab only) and make the project folder the working
# directory, so relative paths used later (e.g. 'output.json') resolve inside it.
# Order matters: the folder only exists after the mount succeeds.
drive.mount('/content/gdrive')
os.chdir("/content/gdrive/My Drive/Colab Notebooks/Linking")

Mounted at /content/gdrive


In [10]:
import json
import time
import requests
import google.generativeai as genai
from typing import Dict, List, Optional
import logging
from ratelimit import limits, sleep_and_retry
import os

# Set up logging: INFO and above, formatted as "timestamp - LEVEL - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
# Enterprise ATT&CK STIX bundle from the `master` branch of mitre/cti.
# Unpinned, so its contents can change between runs.
MITRE_ENTERPRISE_URL = "https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json"
# Call budget per 60-second window for the `ratelimit` @limits decorator in
# TechniqueMapper. NOTE(review): placeholder value - confirm against the
# actual quota of the API key in use.
CALLS_PER_MINUTE = 60  # Adjust based on Gemini API limits

class TechniqueMapper:
    """Map MITRE ATT&CK technique IDs to human-readable technique names.

    Names are resolved first from a local cache built from the official
    enterprise ATT&CK bundle (``MITRE_ENTERPRISE_URL``). IDs missing from
    that bundle fall back to a Gemini query. Only the Gemini fallback is
    rate limited, so cache hits never consume the per-minute call budget.
    """

    def __init__(self, api_key: str, model_name: str = 'gemini-pro'):
        """Initialize the mapper with API credentials and load MITRE data.

        Args:
            api_key: Google Generative AI key used for the Gemini fallback.
            model_name: Gemini model used for IDs not found in the MITRE data.

        Raises:
            Any error raised while downloading or parsing the MITRE bundle
            (logged and re-raised by ``_load_mitre_data``).
        """
        # Never hard-code the key here: a literal key in a notebook leaks with
        # every copy or commit. Any key that was ever committed must be revoked.
        self.api_key = api_key
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)
        self.technique_cache: Dict[str, str] = {}
        self._load_mitre_data()

    @staticmethod
    def _attack_id(obj: Dict) -> Optional[str]:
        """Return the ATT&CK external ID (e.g. 'T1059') of a STIX object, or None."""
        # Select the 'mitre-attack' reference explicitly instead of trusting
        # list position; also tolerates a missing or empty reference list.
        for ref in obj.get('external_references') or []:
            if ref.get('source_name') == 'mitre-attack' and ref.get('external_id'):
                return ref['external_id']
        return None

    def _load_mitre_data(self) -> None:
        """Populate ``technique_cache`` with ID -> name for every attack-pattern.

        Raises:
            Re-raises (after logging) any download, HTTP-status or parse error.
        """
        try:
            # Without a timeout a stalled connection would hang the notebook.
            response = requests.get(MITRE_ENTERPRISE_URL, timeout=60)
            response.raise_for_status()
            data = response.json()

            for obj in data['objects']:
                if obj.get('type') != 'attack-pattern':
                    continue
                technique_id = self._attack_id(obj)
                name = obj.get('name')
                if technique_id and name:
                    self.technique_cache[technique_id] = name

            logger.info(f"Loaded {len(self.technique_cache)} techniques from MITRE ATT&CK")
        except Exception as e:
            logger.error(f"Failed to load MITRE data: {str(e)}")
            raise

    @sleep_and_retry
    @limits(calls=CALLS_PER_MINUTE, period=60)
    def _query_gemini(self, technique_id: str) -> str:
        """Ask Gemini for the technique name; throttled to CALLS_PER_MINUTE per 60 s."""
        prompt = f"""You are a cybersecurity expert. What is the name of the MITRE ATT&CK technique for ID {technique_id}?
            Provide only the technique name, nothing else."""

        response = self.model.generate_content(prompt)
        return response.text.strip()

    def get_technique_name(self, technique_id: str) -> str:
        """Return the technique name for ``technique_id``.

        The MITRE-derived cache is consulted first. Only on a miss is Gemini
        queried (rate limited); a non-empty answer is cached. If the query
        fails or returns nothing, ``"Unknown Technique (<id>)"`` is returned
        and nothing is cached, so a later call can retry.
        """
        # Cache hits deliberately bypass the rate limiter.
        cached = self.technique_cache.get(technique_id)
        if cached is not None:
            return cached

        try:
            technique_name = self._query_gemini(technique_id)
        except Exception as e:
            logger.error(f"Error getting technique name for {technique_id}: {str(e)}")
            return f"Unknown Technique ({technique_id})"

        if not technique_name:
            # Do not cache an empty answer as if it were a real name.
            logger.error(f"Error getting technique name for {technique_id}: empty response")
            return f"Unknown Technique ({technique_id})"

        self.technique_cache[technique_id] = technique_name
        return technique_name

    def map_cve_to_techniques(self, cve_data: List[Dict]) -> List[Dict]:
        """Replace technique IDs in each CVE's predictions with technique names.

        Args:
            cve_data: entries read via ``.get('CVE_ID')`` and
                ``.get('Predictions')``. Predictions are a mapping of
                technique ID -> score; each score must be convertible with
                ``float``.

        Returns:
            One ``{'CVE_ID': ..., 'Techniques': {name: float_score}}`` dict per
            entry that was processed successfully. Entries that raise are
            logged and skipped. If two IDs resolve to the same name, the later
            score overwrites the earlier one.
        """
        mapped_results = []
        total = len(cve_data)

        for i, cve_entry in enumerate(cve_data, 1):
            # Bound before the try so the error log below never references an
            # unbound name or a stale ID from the previous iteration.
            cve_id = 'Unknown CVE'
            try:
                cve_id = cve_entry.get('CVE_ID', 'Unknown CVE')
                predictions = cve_entry.get('Predictions') or {}

                mapped_techniques = {
                    self.get_technique_name(technique_id): float(score)
                    for technique_id, score in predictions.items()
                }

                mapped_results.append({
                    'CVE_ID': cve_id,
                    'Techniques': mapped_techniques
                })
            except Exception as e:
                logger.error(f"Error processing CVE {cve_id}: {str(e)}")

            # Report progress every 10 entries whether or not this entry failed.
            if i % 10 == 0:
                logger.info(f"Processed {i}/{total} CVEs")

        return mapped_results

def main(input_file: str = 'output.json',
         output_file: str = 'mapped_cve_techniques.json') -> None:
    """Map CVE technique predictions in ``input_file`` to technique names.

    Reads the API key from the ``GOOGLE_API_KEY`` environment variable, loads
    the CVE predictions (expected: a JSON list of objects with ``CVE_ID`` and
    ``Predictions`` keys), maps them with ``TechniqueMapper``, and writes the
    result to ``output_file`` as indented JSON.

    Args:
        input_file: path of the prediction JSON to read.
        output_file: path of the mapped JSON to write.

    Raises:
        ValueError: the API key is missing, or the input is not a JSON list.
        Any other error is logged as fatal and re-raised.
    """
    try:
        # Load API key from the environment (never from a literal in the notebook).
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("Google API key not found in environment variables")

        mapper = TechniqueMapper(api_key)

        # Explicit UTF-8 so results don't depend on the platform default encoding.
        with open(input_file, 'r', encoding='utf-8') as f:
            cve_data = json.load(f)

        # A top-level dict would be iterated by key and fail entry-by-entry;
        # reject it up front with a clear message instead.
        if not isinstance(cve_data, list):
            raise ValueError(
                f"Expected a JSON list of CVE entries in '{input_file}', "
                f"got {type(cve_data).__name__}"
            )

        logger.info("Starting CVE to technique mapping...")
        mapped_data = mapper.map_cve_to_techniques(cve_data)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(mapped_data, f, indent=4)

        logger.info(f"Mapping complete. Results saved to '{output_file}'")

    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        raise

# Entry point. Inside a notebook kernel __name__ is also "__main__", so
# executing this cell runs the full mapping pipeline immediately.
if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'ratelimit'