In [1]:
import os
import csv
import asyncio
import aiohttp
import backoff
import pandas as pd
import json
from tqdm import tqdm
from abc import ABC, abstractmethod
from typing import Dict, Any, List

In [2]:
class GeocodingService(ABC):
    @abstractmethod
    async def geocode_batch(self, session: aiohttp.ClientSession, locations: List[Dict[str, float]]) -> List[Dict[str, Any]]:
        pass

In [3]:
class AzureMapsGeocoding(GeocodingService):
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.endpoint = 'https://atlas.microsoft.com/search/address/reverse/batch/json'

    async def geocode_batch(self, session: aiohttp.ClientSession, locations: List[Dict[str, float]]) -> List[Dict[str, Any]]:
        headers = {
            'Content-Type': 'application/json',
            'subscription-key': self.api_key
        }
        
        body = {
            "batchItems": [{"query": f"{loc['lat']},{loc['lon']}"} for loc in locations]
        }

        @backoff.on_exception(backoff.expo, aiohttp.ClientError, max_tries=5)
        async def make_request():
            async with session.post(self.endpoint, headers=headers, json=body) as response:
                response.raise_for_status()
                return await response.json()

        try:
            data = await make_request()
            results = []
            for i, item in enumerate(data['batchItems']):
                if 'response' in item and 'addresses' in item['response']:
                    address = item['response']['addresses'][0]
                    results.append({
                        'lon': locations[i]['lon'],
                        'lat': locations[i]['lat'],
                        'country': address.get('country', ''),
                        'municipality': address.get('municipality', ''),
                        'postalCode': address.get('postalCode', ''),
                        'freeformAddress': address.get('freeformAddress', ''),
                        'addressComponents': json.dumps(address)
                    })
                else:
                    results.append(self._get_empty_result(locations[i]['lat'], locations[i]['lon']))
            return results
        except Exception as e:
            print(f"Error in batch geocoding: {str(e)}")
            return [self._get_empty_result(loc['lat'], loc['lon']) for loc in locations]

    def _get_empty_result(self, lat: float, lon: float) -> Dict[str, Any]:
        return {
            'lon': lon,
            'lat': lat,
            'country': '',
            'municipality': '',
            'postalCode': '',
            'freeformAddress': '',
            'addressComponents': '{}'
        }

In [4]:
class OpenCageGeocoding(GeocodingService):
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.endpoint = 'https://api.opencagedata.com/geocode/v1/json'

    async def geocode_batch(self, session: aiohttp.ClientSession, locations: List[Dict[str, float]]) -> List[Dict[str, Any]]:
        tasks = [self.geocode_single(session, loc['lat'], loc['lon']) for loc in locations]
        return await asyncio.gather(*tasks)

    async def geocode_single(self, session: aiohttp.ClientSession, lat: float, lon: float) -> Dict[str, Any]:
        params = {
            'q': f'{lat},{lon}',
            'key': self.api_key,
            'no_annotations': 1
        }

        @backoff.on_exception(backoff.expo, aiohttp.ClientError, max_tries=5)
        async def make_request():
            async with session.get(self.endpoint, params=params) as response:
                response.raise_for_status()
                return await response.json()

        try:
            data = await make_request()
            result = data['results'][0]
            components = result['components']
            return {
                'lon': lon,
                'lat': lat,
                'country': components.get('country', ''),
                'municipality': components.get('city', '') or components.get('town', '') or components.get('village', ''),
                'postalCode': components.get('postcode', ''),
                'freeformAddress': result.get('formatted', ''),
                'addressComponents': json.dumps(components)
            }
        except Exception as e:
            print(f"Error geocoding {lat}, {lon}: {str(e)}")
            return self._get_empty_result(lat, lon)

    def _get_empty_result(self, lat: float, lon: float) -> Dict[str, Any]:
        return {
            'lon': lon,
            'lat': lat,
            'country': '',
            'municipality': '',
            'postalCode': '',
            'freeformAddress': '',
            'addressComponents': '{}'
        }

In [5]:
class GoogleMapsGeocoding(GeocodingService):
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.endpoint = 'https://maps.googleapis.com/maps/api/geocode/json'

    async def geocode_batch(self, session: aiohttp.ClientSession, locations: List[Dict[str, float]]) -> List[Dict[str, Any]]:
        tasks = [self.geocode_single(session, loc['lat'], loc['lon']) for loc in locations]
        return await asyncio.gather(*tasks)

    async def geocode_single(self, session: aiohttp.ClientSession, lat: float, lon: float) -> Dict[str, Any]:
        params = {
            'latlng': f'{lat},{lon}',
            'key': self.api_key
        }

        @backoff.on_exception(backoff.expo, aiohttp.ClientError, max_tries=5)
        async def make_request():
            async with session.get(self.endpoint, params=params) as response:
                response.raise_for_status()
                return await response.json()

        try:
            data = await make_request()
            if data['status'] == 'OK':
                result = data['results'][0]
                address_components = {comp['types'][0]: comp['long_name'] for comp in result['address_components']}
                return {
                    'lon': lon,
                    'lat': lat,
                    'country': address_components.get('country', ''),
                    'municipality': address_components.get('locality', '') or address_components.get('administrative_area_level_2', ''),
                    'postalCode': address_components.get('postal_code', ''),
                    'freeformAddress': result.get('formatted_address', ''),
                    'addressComponents': json.dumps(result['address_components'])
                }
            else:
                print(f"Error geocoding {lat}, {lon}: {data['status']}")
                return self._get_empty_result(lat, lon)
        except Exception as e:
            print(f"Error geocoding {lat}, {lon}: {str(e)}")
            return self._get_empty_result(lat, lon)

In [6]:
# Script settings... CHANGE THIS ACCORDINGLY
INPUT_CSV = '../data/prescila/test/input/geocoding_chunk_5_1-2.csv'  # input_locations.csv should have columns 'lat' and 'lon'
OUTPUT_CSV = '../data/prescila/test/result/geocoding_chunk_5_1-2.csv' # output CSV file with geocoded results
BATCH_SIZE = 100 # number of locations to geocode in a single batch
CHECKPOINT_FILE = 'last_processed_index.txt' # file to store the last processed index

In [7]:
async def process_batch(session: aiohttp.ClientSession, geocoding_service: GeocodingService, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    return await geocoding_service.geocode_batch(session, batch)

In [8]:
async def main(geocoding_service: GeocodingService):
    df = pd.read_csv(INPUT_CSV)
    total_rows = len(df)

    # Resume from last checkpoint if it exists
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, 'r') as f:
            start_index = int(f.read().strip())
    else:
        start_index = 0

    progress_bar = tqdm(total=total_rows, initial=start_index, desc="Processing locations")

    async with aiohttp.ClientSession() as session:
        results = []
        for i in range(start_index, total_rows, BATCH_SIZE):
            batch = df.iloc[i:i+BATCH_SIZE].to_dict('records')
            batch_results = await process_batch(session, geocoding_service, batch)
            results.extend(batch_results)
            progress_bar.update(len(batch))

            # Save results every 100 records
            if len(results) >= 100:
                save_results(results)
                results = []

            # Update checkpoint
            with open(CHECKPOINT_FILE, 'w') as f:
                f.write(str(i + BATCH_SIZE))

    # Save any remaining results
    if results:
        save_results(results)
    
    progress_bar.close()
    os.remove(CHECKPOINT_FILE)  # Remove checkpoint file after successful completion

In [9]:
def save_results(results: List[Dict[str, Any]]):
    df = pd.DataFrame(results)
    mode = 'w' if not os.path.exists(OUTPUT_CSV) else 'a'
    header = not os.path.exists(OUTPUT_CSV)
    df.to_csv(OUTPUT_CSV, mode=mode, header=header, index=False)

In [10]:
print(os.getenv('PLATFORM_API_KEY'))

AzureMaps


In [12]:
if __name__ == "__main__":
    # Choose the geocoding service based on available API keys
    #select which Key to use
    if os.getenv('PLATFORM_API_KEY') == 'AzureMaps':
        geocoding_service = AzureMapsGeocoding(os.getenv('AZURE_MAPS_KEY'))
    # elif os.getenv('PLATFORM_API_KEY') == 'GoogleMaps':
    #     geocoding_service = OpenCageGeocoding(os.getenv('GOOGLE_MAPS_KEY'))
    # elif os.getenv('PLATFORM_API_KEY') == 'OpenCage':
    #     geocoding_service = OpenCageGeocoding(os.getenv('OPENCAGE_API_KEY'))
    else:
        raise ValueError("No API key found for any supported geocoding service")

    asyncio.run(main(geocoding_service))

RuntimeError: asyncio.run() cannot be called from a running event loop