In [248]:
import pandas as pd
import math
import numpy as np

STATIONS_BANDS = ["SO2","C6H6","NO2","O3","PM10","PM25","CO"]
R = 6373.0

def create_data(data_path: str, legend_path: str):
    data = {}
    data_path_df = pd.read_csv(data_path)
    legend_path_df = pd.read_csv(legend_path, sep=";")
    legend_path_df['Location'] = legend_path_df['Location'].str.split(', ')
    legend_dict = dict(zip(legend_path_df['id_amat'], legend_path_df['Location']))
    
    for _, row in data_path_df.iterrows():
        date = row['date']
        if date not in data:
            data[date] = {}
        latlon = legend_dict.get(row['station_id'])
        if latlon:
            latlon = latlon[1][:-1] + ' ' + latlon[0][1:]
            data[date][latlon] = dict(row[4:])
            
    return data
    
def get_closest_dist_per_band(data, date, latlon):
    """
    Given latlon of a pixel find for each pollutant the distance of the closest GoldenStation
    NPArray aligned with STATION_BANDS
    """
    data_single_date = data[date]
    distances = {}
    for i, band in enumerate(STATIONS_BANDS):
        for j, latlon_data in enumerate(list(data_single_date.keys())):
            latlon_data_list = latlon_data.split()
            if not np.isnan(data_single_date[latlon_data][band]):
                lat1 = math.radians(latlon[0])
                lat2 = math.radians(float(latlon_data_list[0]))
                lon1 = math.radians(latlon[1])
                lon2 = math.radians(float(latlon_data_list[1]))
                diff_lon = lon2 - lon1
                diff_lat = lat2 -lat1 
                a = (math.sin(diff_lat/2))**2 + math.cos(lon1) * math.cos(float(lat2)) * (math.sin(diff_lon/2))**2
                c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
                dist = R * c
                if band in distances.keys():
                    if dist < distances[band]:
                        distances[band] = dist
                else:
                    distances[band] = dist
    return list(distances.values())

In [225]:
data = create_data("/Users/luca/Downloads/weak_labels_demo/stations_data_2023-12-18.csv", "/Users/luca/Downloads/weak_labels_demo/qaria_stazione.csv")

In [226]:
print(data['2016-01-07'])

{'9.18218994140625 45.432300567627': {'SO2': nan, 'C6H6': nan, 'NO2': 49.0, 'O3': nan, 'PM10': nan, 'PM25': nan, 'CO': nan}, '9.23478031158447 45.4740982055664': {'SO2': 9.0, 'C6H6': 2.3, 'NO2': 68.0, 'O3': 26.0, 'PM10': 41.0, 'PM25': 38.0, 'CO': nan}, '9.16944026947021 45.4441986083984': {'SO2': nan, 'C6H6': nan, 'NO2': 73.0, 'O3': nan, 'PM10': nan, 'PM25': nan, 'CO': 2.1}, '9.19083976745605 45.4962997436523': {'SO2': nan, 'C6H6': 1.1, 'NO2': 70.0, 'O3': nan, 'PM10': nan, 'PM25': nan, 'CO': 2.3}, '9.24730014801025 45.4995994567871': {'SO2': nan, 'C6H6': nan, 'NO2': nan, 'O3': 20.0, 'PM10': nan, 'PM25': nan, 'CO': nan}, '9.19791984558105 45.4705009460449': {'SO2': nan, 'C6H6': nan, 'NO2': 75.0, 'O3': nan, 'PM10': 38.0, 'PM25': 32.0, 'CO': 1.4}, '9.19534015655518 45.4635009765625': {'SO2': nan, 'C6H6': nan, 'NO2': 49.0, 'O3': 20.0, 'PM10': 33.0, 'PM25': nan, 'CO': nan}, '9.141770362854 45.4761009216309': {'SO2': nan, 'C6H6': 1.9, 'NO2': 75.0, 'O3': nan, 'PM10': nan, 'PM25': nan, 'CO': 1

In [249]:
get_closest_dist_per_band(data, '2016-01-07', [9.18218994140625, 45.432300567627])

[7.013521220762758,
 6.002515540487363,
 0.0,
 3.2376869330026183,
 3.2376869330026183,
 3.9456139626043534,
 1.7956884291618183]

In [250]:
def get_loss_factor(date, latlon):
    closest_dist_per_band = get_closest_dist_per_band(data, date, latlon)
    loss_factors = np.ndarray(len(STATIONS_BANDS))
    for i in range(len(closest_dist_per_band)):
        if np.isclose( closest_dist_per_band[i] , 0 ):
            loss_factors[i] = 1
        else:
            loss_factors[i] = 0.3

    return loss_factors

get_loss_factor('2016-01-07', [9.18218994140625, 45.432300567627])

array([0.3, 0.3, 1. , 0.3, 0.3, 0.3, 0.3])