In [15]:
import os
import math
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.prepared import prep
from shapely.geometry import Point
from concurrent.futures import ThreadPoolExecutor

from wsi.mapping.iso_name import ISO_NAME
from wsi.mapping.iso_gw import ISO_GW
from wsi.mapping.iso_iso2 import ISO_ISO2
from wsi.utils import raw_data_path, processed_data_path

# Constants
EARTH_RADIUS_KM = 6371
FILE_PATTERN = "gpw_v4_population_count_adjusted_to_2015_unwpp_country_totals_rev11_2020_30_sec_{tile}.asc"


In [16]:
import logging

logging.basicConfig(
    level=logging.INFO,
    filename=processed_data_path("shocks","proximity_conflict", 'conflict_logs.log'),   # Output file path
    filemode='a',                          # Append mode
    format='%(asctime)s - %(levelname)s - %(message)s'
)

logger = logging.getLogger(__name__)


In [17]:

def read_population_count(file_path):
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None
    with open(file_path, 'r') as f:
        metadata = {}
        for _ in range(6):
            key, value = f.readline().strip().split()
            metadata[key.lower()] = float(value)
    data = np.loadtxt(file_path, skiprows=6)
    gt = (
        metadata['xllcorner'],
        metadata['cellsize'],
        0,
        metadata['yllcorner'] + metadata['nrows'] * metadata['cellsize'],
        0,
        -metadata['cellsize']
    )
    return {
        "file": file_path,
        "data": data,
        "geotransform": gt,
        "no_data_value": metadata['nodata_value']
    }


def prepare_pixel_grid(geotransform, shape):
    origin_x, pixel_w, _, origin_y, _, pixel_h = geotransform
    rows, cols = shape
    row_grid, col_grid = np.ogrid[0:rows, 0:cols]
    lat_grid = origin_y + row_grid * pixel_h
    lon_grid = origin_x + col_grid * pixel_w
    lat_grid = np.broadcast_to(lat_grid, (rows, cols))
    lon_grid = np.broadcast_to(lon_grid, (rows, cols))
    return lat_grid, lon_grid

def haversine_distance_vector(lat1, lon1, lat2, lon2):
    lat1, lon1 = math.radians(lat1), math.radians(lon1)
    lat2, lon2 = np.radians(lat2), np.radians(lon2)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2
    return EARTH_RADIUS_KM * 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))

def get_population_in_conflict_area(all_data, conflict_coords, radius_km=50):
    total_population = 0
    union_grid_points = []
    radius_deg_lat = radius_km / 111.0

    for dataset in all_data:
        data = dataset["data"]
        gt = dataset["geotransform"]
        nodata = dataset["no_data_value"]
        lat_grid, lon_grid = prepare_pixel_grid(gt, data.shape)
        mask = np.zeros(data.shape, dtype=bool)

        for lat, lon in conflict_coords:
            radius_deg_lon = radius_km / (111.0 * math.cos(math.radians(lat)))
            lat_min, lat_max = lat - radius_deg_lat, lat + radius_deg_lat
            lon_min, lon_max = lon - radius_deg_lon, lon + radius_deg_lon
            conflict_mask = (
                (lat_grid >= lat_min) & (lat_grid <= lat_max) &
                (lon_grid >= lon_min) & (lon_grid <= lon_max)
            )
            dists = haversine_distance_vector(lat, lon, lat_grid[conflict_mask], lon_grid[conflict_mask])
            tmp_mask = np.zeros_like(mask)
            tmp_mask[conflict_mask] = dists <= radius_km
            mask |= tmp_mask

        valid_mask = mask & (data != nodata)
        total_population += data[valid_mask].sum()
        if np.any(valid_mask):
            union_grid_points += np.column_stack((lat_grid[valid_mask], lon_grid[valid_mask], data[valid_mask])).tolist()

    return total_population, union_grid_points

def clip_grid_points_to_country(grid_points, country_polygon):
    prepped = prep(country_polygon)
    return [pt for pt in grid_points if prepped.contains(Point(pt[1], pt[0]))]

def filter_conflicts(df, country_code, year):
    return df[(df['year'] == year) & df['country_id'].astype(str).str.contains(str(country_code))]

def get_conflict_coordinates(df):
    return df[['latitude', 'longitude']].dropna().values.tolist()

def process_country_code(country_code, years, countries, event_csv, df_pop, all_data):
    summary_rows = []
    heatmap_points = []

    iso3 = next((iso for iso, code in ISO_GW.items() if str(code) == country_code), None)
    if not iso3:
        logger.warning(f"ISO3 code not found for country_code: {country_code}")
        return None

    iso2 = ISO_ISO2.get(iso3)
    if not iso2:
        logger.warning(f"ISO2 code not found for ISO3: {iso3} (country_code: {country_code})")
        return None
    
    country_gdf = countries[countries['ISO'] == iso2]
    if country_gdf.empty:
        logger.warning(f"Country geometry not found for ISO3: {iso3}/ ISO2: {iso2})")
        return None

    polygon = country_gdf.geometry.iloc[0]

    for yr in years:
        conflict_df = filter_conflicts(event_csv, country_code, yr)
        coords = get_conflict_coordinates(conflict_df)

        if not coords:
            pop_in_conflict = 0
            union_grid_points = []
        else:
            pop_in_conflict, union_grid_points = get_population_in_conflict_area(all_data, coords)
            union_grid_points = clip_grid_points_to_country(union_grid_points, polygon)
            pop_in_conflict = sum(pt[2] for pt in union_grid_points)
            # store grid with year tag
            # for pt in union_grid_points:
            #     heatmap_points.append({
            #         'year': yr,
            #         'latitude': pt[0],
            #         'longitude': pt[1],
            #         'population': pt[2]
            #     })

        national_pop = df_pop[(df_pop['ISO_code'] == iso3) & (df_pop['Year'] == yr)]['Population']
        if not national_pop.empty and national_pop.iloc[0] > 0:
            pct = (pop_in_conflict / national_pop.iloc[0]) * 100
        else:
            pct = None

        summary_rows.append({
            'gw_code': country_code,
            'iso3': iso3,
            'year': yr,
            'pop_in_conflict': pop_in_conflict,
            'national_pop': national_pop.iloc[0] if not national_pop.empty else None,
            'percent': pct
        })

        # Save individual files
        pd.DataFrame(summary_rows).to_csv(processed_data_path("shocks", "proximity_conflict", f"conflict_summary_{iso3}_{yr}.csv"),index=False)
    # pd.DataFrame(heatmap_points).to_csv(processed_data_path("shocks", "proximity_conflict", f"heatmap_grid_{iso3}.csv"),index=False)

    return iso3

In [18]:
# Load shared data (outside parallel scope)

## POPULATION DNESITY
all_data = []
for tile in range(1, 9):
    fp = raw_data_path("shocks", "gpw-v4", FILE_PATTERN.format(tile=tile))
    result = read_population_count(fp)
    if result:
        all_data.append(result)

## SHAPEFILE
# TODO: make secondary shapefile dataset when country not availbale in first
countries = gpd.read_file(raw_data_path("shocks", "country_shapefiles", "World_Countries_Generalized.shp")).to_crs("EPSG:4326")

## CONFLICT EVENTS
UcdpPrioConflict_csv = pd.read_csv(raw_data_path("shocks", "UcdpPrioConflict_v25_1.csv"))
event_csv = pd.read_csv(raw_data_path("shocks", "GEDEvent_v25_1.csv"))
event_csv = event_csv[event_csv['conflict_new_id'].isin(UcdpPrioConflict_csv['conflict_id'].unique())]

# fitler events, at least one fatality, also more than one death at event per country per year per dyad (i.e. exclude small conflicts)
event_csv = event_csv[event_csv['best'] > 0]

# total deaths per dyad-country-year
death_sums = (
    event_csv.groupby(['dyad_new_id', 'country_id', 'year'])['best']
    .sum()
    .reset_index(name='group_best_sum')
)

# Keep only groups where total deaths > 1
valid_groups = death_sums[death_sums['group_best_sum'] > 1]

# Merge back to filter the original event-level data
event_csv = event_csv.merge(
    valid_groups[['dyad_new_id', 'country_id', 'year']],
    on=['dyad_new_id', 'country_id', 'year'],
    how='inner'
)

## TOTAL POPULATION
from wsi.shocks.population import build_population_df
df_pop = build_population_df()

# Save all lat/long of relevant events
# Invert ISO_GW: {GW_code → ISO3}
GW_ISO = {str(v): k for k, v in ISO_GW.items()}
event_csv['ISO3'] = event_csv['country_id'].astype(str).map(GW_ISO)
all_events = event_csv[['year', 'country_id', 'conflict_name', 'dyad_name', 'best','latitude', 'longitude']].copy()
all_events.to_csv(processed_data_path("shocks", "proximity_conflict", f"event_level_coords.csv"),index=False)

  event_csv = pd.read_csv(raw_data_path("shocks", "GEDEvent_v25_1.csv"))


In [None]:
import os
import re

# Define the directory path
directory = processed_data_path("shocks", "proximity_conflict")

# Pattern to match filenames like conflict_summary_ABC.csv
#pattern = re.compile(r"conflict_summary_([A-Z]{3})\.csv")

# pattern to match filenames like conflict_summary_ABC.csv and conflict_summary_ABC_2016.csv
pattern = re.compile(r"conflict_summary_([A-Z]{3})(?:_\d{4})?\.csv")

# List all files and extract matching ISO codes
iso_codes = []
for filename in os.listdir(directory):
    match = pattern.match(filename)
    if match:
        iso_codes.append(match.group(1))

# Sort and get unique ISO codes
iso_codes_completed = set(iso_codes)
print(iso_codes_completed)


{'CIV', 'AGO', 'GRD', 'HRV', 'URY', 'DJI', 'PRT', 'PER', 'NER', 'QAT', 'GBR', 'KOR', 'UGA', 'PRY', 'TJK', 'LIE', 'MNG', 'TUN', 'EGY', 'SVN', 'KIR', 'AFG', 'GIN', 'BHS', 'MDG', 'DNK', 'VUT', 'KHM', 'SYC', 'UZB', 'DMA', 'KWT', 'MMR', 'NPL', 'KAZ', 'PAK', 'MRT', 'THA', 'ARM', 'BFA', 'NIC', 'PNG', 'SWZ', 'MLT', 'ECU', 'LBN', 'MOZ', 'ROU', 'MDA', 'IRQ', 'FRA', 'KEN', 'BWA', 'ARG', 'MYS', 'ZMB', 'SLB', 'CMR', 'BHR', 'NOR', 'IRN', 'STP', 'ALB', 'LTU', 'VEN', 'BGR', 'BRA', 'COM', 'JOR', 'CUB', 'VNM', 'CYP', 'LBY', 'LCA', 'PSE', 'RWA', 'AZE', 'SOM', 'KGZ', 'BEN', 'ESP', 'KNA', 'WSM', 'BDI', 'POL', 'GRC', 'TGO', 'FIN', 'IDN', 'SGP', 'SEN', 'ERI', 'HND', 'CZE', 'TCD', 'IRL', 'ISL', 'LAO', 'COG', 'DEU', 'PRK', 'BLZ', 'NAM', 'ITA', 'NRU', 'RUS', 'CHL', 'SRB', 'SSD', 'GTM', 'PAN', 'ZAF', 'PHL', 'SLE', 'ZWE', 'JAM', 'TON', 'GNB', 'SLV', 'DOM', 'HTI', 'TTO', 'FSM', 'ISR', 'VCT', 'LUX', 'MWI', 'BRB', 'AUS', 'BOL', 'OMN', 'GEO', 'MDV', 'NZL', 'FJI', 'HUN', 'TKM', 'GUY', 'BGD', 'MHL', 'TLS', 'BIH', 'COL'

In [20]:
gw_completed = [str(ISO_GW[iso]) for iso in iso_codes_completed if iso in ISO_GW]
all_gw_codes = GW_ISO.keys()

valid_gw_codes = list(set(all_gw_codes) - set(gw_completed)) + ['700']  
len(valid_gw_codes)

38

In [21]:
GW_ISO.keys()

dict_keys(['2', '20', '31', '40', '41', '42', '51', '52', '53', '70', '80', '89', '90', '91', '349', '93', '94', '95', '99', '100', '101', '110', '115', '130', '135', '140', '145', '150', '155', '160', '165', '200', '205', '210', '211', '212', '220', '572', '230', '235', '240', '245', '265', '267', '269', '271', '273', '275', '280', '290', '300', '900', '310', '316', '317', '325', '327', '329', '332', '335', '337', '338', '339', '712', '343', '344', '340', '345', '346', '347', '350', '352', '355', '359', '360', '365', '366', '367', '368', '369', '370', '371', '372', '373', '375', '380', '385', '390', '395', '402', '404', '411', '420', '432', '433', '434', '435', '475', '437', '438', '439', '450', '451', '452', '461', '471', '481', '482', '483', '484', '490', '500', '501', '510', '511', '516', '517', '520', '522', '530', '531', '540', '541', '551', '552', '553', '560', '563', '564', '565', '570', '571', '580', '581', '590', '600', '615', '616', '620', '625', '626', '630', '640', '645', 

In [None]:
import time

years = list(range(1995, 2025))  # 30 years

total = len(valid_gw_codes)
start_all = time.time()

for i, code in enumerate(valid_gw_codes):
    print(f"\nProcessing country {i+1}/{total}: GW Code {code} ({valid_gw_codes.index(code)+1}/{total})")
    start_country = time.time()

    iso3 = next((iso for iso, c in ISO_GW.items() if str(c) == str(code)), None)
    if not iso3:
        print(f"  Skipping: No ISO3 for {code}", flush=True)
        continue

    for yr in years:
        print(f"  Processing {iso3} / {code} for year {yr}", flush=True)
        try:
            process_country_code(code, [yr], countries, event_csv, df_pop, all_data)
        except Exception as e:
            print(f"    Error for {code}, {yr}: {e}", flush=True)

    elapsed_country = time.time() - start_country
    print(f"Finished {iso3} ({code}) in {elapsed_country/60:.2f} min", flush=True)

elapsed_total = time.time() - start_all
print(f"\nAll countries processed in {elapsed_total/60:.1f} minutes.")



Processing country 1/38: GW Code 329 (1/38)
  Processing SIC / 329 for year 1995
  Processing SIC / 329 for year 1996
  Processing SIC / 329 for year 1997
  Processing SIC / 329 for year 1998
  Processing SIC / 329 for year 1999
  Processing SIC / 329 for year 2000
  Processing SIC / 329 for year 2001
  Processing SIC / 329 for year 2002
  Processing SIC / 329 for year 2003
  Processing SIC / 329 for year 2004
  Processing SIC / 329 for year 2005
  Processing SIC / 329 for year 2006
  Processing SIC / 329 for year 2007
  Processing SIC / 329 for year 2008
  Processing SIC / 329 for year 2009
  Processing SIC / 329 for year 2010
  Processing SIC / 329 for year 2011
  Processing SIC / 329 for year 2012
  Processing SIC / 329 for year 2013
  Processing SIC / 329 for year 2014
  Processing SIC / 329 for year 2015
  Processing SIC / 329 for year 2016
  Processing SIC / 329 for year 2017
  Processing SIC / 329 for year 2018
  Processing SIC / 329 for year 2019
  Processing SIC / 329 for yea

In [None]:
# from tqdm import tqdm
# import time

# years = list(range(1995, 2025))  # 30 years
# total_country = len(valid_gw_codes)
# average_times = []

# with tqdm(total=total_country, desc="Processing", unit="country-year") as pbar:
#     for code in valid_gw_codes:
#         start = time.time()

#         process_country_code(code, years, countries, event_csv, df_pop, all_data)
#         elapsed = time.time() - start
#         average_times.append(elapsed)

#         avg_time = sum(average_times) / len(average_times)
#         remaining = avg_time * (total_country - pbar.n)

#         pbar.set_postfix({
#             "Last (s)": f"{elapsed:.1f}",
#             "Avg (s)": f"{avg_time:.1f}",
#             "ETA": f"{remaining/60:.1f} min"
#         })
#         pbar.update(1)



In [None]:
# # Parallel execution
# years = list(range(1995,2025))

# valid_gw_codes = list(valid_gw_codes - set(iso_to_gw)) + [700] #Afghanistan
# #["811", "840", "850", "900"]  # Cambodia, Phillipines, Indonesia, Australia
# #valid_gw_codes = ["900"] 
# #valid_gw_codes = GW_ISO.keys()

# from concurrent.futures import ThreadPoolExecutor, as_completed

# completed = []

# with ThreadPoolExecutor() as executor:
#     futures = {executor.submit(process_country_code, code, years, countries, event_csv, df_pop, all_data): code for code in valid_gw_codes}
#     for future in as_completed(futures):
#         result = future.result()
#         if result:
#             completed.append(result)
#             print(f"✅ Saved results for {result}")