In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, udf, split
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
import requests, pygeohash, os, configparser

spark = SparkSession.builder.appName("RestaurantDataETL").getOrCreate()

I kept credentials (paths and API keys) in config file:

In [4]:
config = configparser.ConfigParser()
config.read('config.ini')
api_key = config['api_keys']['opencage_api_key']
restaurant_path = config['paths']['restaurant_path']
weather_path = config['paths']['weather_path']

In [27]:
restaurant_df = spark.read.option("header", True).csv(restaurant_path)
restaurant_df.show(3)

+------------+------------+----------------+-----------------------+-------+----------+------+--------+
|          id|franchise_id|  franchise_name|restaurant_franchise_id|country|      city|   lat|     lng|
+------------+------------+----------------+-----------------------+-------+----------+------+--------+
|197568495625|          10|The Golden Spoon|                  24784|     US|   Decatur|34.578| -87.021|
| 17179869242|          59|     Azalea Cafe|                  10902|     FR|     Paris|48.861|   2.368|
|214748364826|          27| The Corner Cafe|                  92040|     US|Rapid City|44.080|-103.250|
+------------+------------+----------------+-----------------------+-------+----------+------+--------+
only showing top 3 rows



Next, I created a udf function to get coordinates using opencage api:

In [None]:
api_key = config['api_keys']['opencage_api_key']
def fetch_coordinates(franchise_name, city, country):
    """Fetch latitude and longitude using OpenCage API."""
    if not franchise_name or not city or not country:
        return None
    query = f"{franchise_name}, {city}, {country}"
    url = f"https://api.opencagedata.com/geocode/v1/json?q={query}&key={api_key}"
    try:
        response = requests.get(url)
        if response.status_code == 200:
            results = response.json().get('results', [])
            if results:
                lat = results[0]['geometry']['lat']
                lng = results[0]['geometry']['lng']
                return f"{lat},{lng}"
    except Exception as e:
        print(f"Error fetching coordinates for {query}: {e}")
    return None

fetch_coordinates_udf = udf(fetch_coordinates, StringType())

Next, I fulfill missing coordinates and convert them into float

In [29]:
# Combine latitude and longitude into a single column
restaurant_df = restaurant_df.withColumn(
    "lat_lng",
    when(col("lat").isNull() | col("lng").isNull(),
         fetch_coordinates_udf(col("franchise_name"), col("city"), col("country")))
    .otherwise(None)
)

In [30]:
# Split lat_lng into latitude and longitude
restaurant_df = restaurant_df.withColumn(
    "lat",
    when(col("lat").isNull(), split(col("lat_lng"), ",").getItem(0)).otherwise(col("lat"))
)

restaurant_df = restaurant_df.withColumn(
    "lng",
    when(col("lng").isNull(), split(col("lat_lng"), ",").getItem(1)).otherwise(col("lng"))
)
# Drop the temporary lat_lng column
restaurant_df = restaurant_df.drop("lat_lng")

In [31]:
restaurant_df = restaurant_df.withColumn("lat", col("lat").cast("float"))
restaurant_df = restaurant_df.withColumn("lng", col("lng").cast("float"))

Then, I create another function to hash coordinates. and added column 'geohash'

In [None]:
def generate_geohash(lat, lng):
    return pygeohash.encode(lat, lng, precision=4) #dr5r

geohash_udf = udf(generate_geohash, StringType())

restaurant_df = restaurant_df.withColumn("geohash", geohash_udf(col("lat"), col("lng")))

In [34]:
restaurant_df.show(3)

+------------+------------+----------------+-----------------------+-------+----------+------+-------+-------+
|          id|franchise_id|  franchise_name|restaurant_franchise_id|country|      city|   lat|    lng|geohash|
+------------+------------+----------------+-----------------------+-------+----------+------+-------+-------+
|197568495625|          10|The Golden Spoon|                  24784|     US|   Decatur|34.578|-87.021|   dn4h|
| 17179869242|          59|     Azalea Cafe|                  10902|     FR|     Paris|48.861|  2.368|   u09t|
|214748364826|          27| The Corner Cafe|                  92040|     US|Rapid City| 44.08|-103.25|   9xyd|
+------------+------------+----------------+-----------------------+-------+----------+------+-------+-------+
only showing top 3 rows



Next step is to create dataframe for weather data

In [None]:
# here i used a function to get paths of all parquet files:
def find_parquet_files(directory):
    files = []
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            if filename.endswith(".parquet"):
                files.append(os.path.join(root, filename))
    return files

parquet_files = find_parquet_files(weather_path)
parquet_files[0]

'/content/weather/year=2016/month=10/day=29/part-00018-44bd3411-fbe4-4e16-b667-7ec0fc3ad489.c000.snappy.parquet'

In [36]:
len(parquet_files)

93

In [None]:
weather_df = spark.read.format("parquet").load(parquet_files)
weather_df.show(5)

+--------+-------+----------+----------+----------+
|     lng|    lat|avg_tmpr_f|avg_tmpr_c| wthr_date|
+--------+-------+----------+----------+----------+
|-111.202|18.7496|      82.7|      28.2|2016-10-12|
|-111.155| 18.755|      82.7|      28.2|2016-10-12|
|-111.107|18.7604|      82.7|      28.2|2016-10-12|
|-111.059|18.7657|      82.5|      28.1|2016-10-12|
|-111.012|18.7711|      82.5|      28.1|2016-10-12|
+--------+-------+----------+----------+----------+
only showing top 5 rows



In [38]:
weather_df.count()

37333145

In [39]:
weather_df = weather_df.withColumn("geohash", geohash_udf(col("lat"), col("lng")))

It's more efficient to aggregate by geohash before joining with restaurant data, as this reduces the amount of data processed

In [None]:
weather_aggregated_df = (
    weather_df.groupBy("geohash", "wthr_date")
    .agg(
        F.avg("avg_tmpr_f").alias("avg_tmpr_f"),
        F.avg("avg_tmpr_c").alias("avg_tmpr_c"),
        F.first("lng").alias("lng"),
        F.first("lat").alias("lat")
    )
)

In [44]:
weather_aggregated_df.show(5)

+-------+----------+-----------------+------------------+--------+-------+
|geohash| wthr_date|       avg_tmpr_f|        avg_tmpr_c|     lng|    lat|
+-------+----------+-----------------+------------------+--------+-------+
|   9eqz|2016-10-12|75.34444444444445|24.074074074074073|-102.966|19.5366|
|   9gej|2016-10-12|82.63333333333334| 28.14444444444445|-96.9953|20.6083|
|   d7qb|2016-10-12|81.21034482758618|27.341379310344827| -69.236|18.2997|
|   d7d9|2016-10-12|80.71428571428574|27.064285714285713|-75.1868|19.8997|
|   d7mz|2016-10-12|77.85714285714288|25.467857142857135|-70.6554|19.5211|
+-------+----------+-----------------+------------------+--------+-------+
only showing top 5 rows



Since it was written to keep all columns from both dataset, lat and lng of both tables are kept.

In [45]:
joined_df = restaurant_df.join(weather_aggregated_df, on="geohash", how="left")
joined_df.show(5)

+-------+------------+------------+----------------+-----------------------+-------+-------+------+-------+----------+-----------------+------------------+--------+-------+
|geohash|          id|franchise_id|  franchise_name|restaurant_franchise_id|country|   city|   lat|    lng| wthr_date|       avg_tmpr_f|        avg_tmpr_c|     lng|    lat|
+-------+------------+------------+----------------+-----------------------+-------+-------+------+-------+----------+-----------------+------------------+--------+-------+
|   dn4h|197568495625|          10|The Golden Spoon|                  24784|     US|Decatur|34.578|-87.021|2016-10-03|68.58076923076922|20.319230769230767|-87.0083|34.6264|
|   dn4h|197568495625|          10|The Golden Spoon|                  24784|     US|Decatur|34.578|-87.021|2016-10-06|72.58846153846153| 22.54615384615385|-87.0083|34.6264|
|   dn4h|197568495625|          10|The Golden Spoon|                  24784|     US|Decatur|34.578|-87.021|2016-10-13| 68.3423076923077

Finally, unit tests are conducted:

In [1]:
import unittest
from unittest.mock import patch, MagicMock

In [None]:
class TestFetchCoordinates(unittest.TestCase):

    @patch('requests.get')
    def test_fetch_coordinates_success(self, mock_get):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            'results': [{
                'geometry': {'lat': 43.25667, 'lng': 76.92861}
            }]
        }
        mock_get.return_value = mock_response

        franchise_name = "Pizza Hut"
        city = "Almaty"
        country = "Kazakhstan"
        result = fetch_coordinates(franchise_name, city, country)
        
        self.assertEqual(result, "43.25667,76.92861")
        mock_get.assert_called_once_with(
            "https://api.opencagedata.com/geocode/v1/json?q=Pizza Hut, Almaty, Kazakhstan&key="
            + config['api_keys']['opencage_api_key']
        )

    @patch('requests.get')
    def test_fetch_coordinates_no_results(self, mock_get):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {'results': []}
        mock_get.return_value = mock_response

        franchise_name = "Nonexistent Place"
        city = "Unknown City"
        country = "Unknown Country"
        result = fetch_coordinates(franchise_name, city, country)
        
        self.assertIsNone(result)

    @patch('requests.get')
    def test_fetch_coordinates_invalid_status(self, mock_get):
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_get.return_value = mock_response

        franchise_name = "Pizza Hut"
        city = "Almaty"
        country = "Kazakhstan"
        result = fetch_coordinates(franchise_name, city, country)
        
        self.assertIsNone(result)

    def test_fetch_coordinates_missing_input(self):
        self.assertIsNone(fetch_coordinates("", "City", "Country"))
        self.assertIsNone(fetch_coordinates("Franchise", "", "Country"))
        self.assertIsNone(fetch_coordinates("Franchise", "City", ""))

unittest.main(argv=[''], verbosity=2, exit=False)

test_fetch_coordinates_invalid_status (__main__.TestFetchCoordinates.test_fetch_coordinates_invalid_status) ... ok
test_fetch_coordinates_missing_input (__main__.TestFetchCoordinates.test_fetch_coordinates_missing_input) ... ok
test_fetch_coordinates_no_results (__main__.TestFetchCoordinates.test_fetch_coordinates_no_results) ... ok
test_fetch_coordinates_success (__main__.TestFetchCoordinates.test_fetch_coordinates_success) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.009s

OK


<unittest.main.TestProgram at 0x2727a6de750>