In [None]:
import numpy as np
import pandas as pd
from geoclip import GeoCLIP
from geopy.geocoders import Nominatim
from collections import Counter, defaultdict
import cv2
import random

In [None]:
# Show full cell contents (e.g. long file paths) instead of truncating them
pd.options.display.max_colwidth = None

In [None]:
# Training split produced by an earlier step (pickle: only load trusted, locally produced files)
df_train = pd.read_pickle(filepath_or_buffer="intermediate/df_train.pkl")
df_train.head(5)

Use the training data to determine latitude/longitude bounding boxes for each region cluster.

In [None]:
# Bounding box of each region cluster, taken from the training coordinates
region_bounds = (
    df_train
    .groupby("region_cluster")
    .agg(
        min_latitude=("latitude", "min"),
        max_latitude=("latitude", "max"),
        min_longitude=("longitude", "min"),
        max_longitude=("longitude", "max"),
    )
    .reset_index()
)

region_bounds

In [None]:
# Country labels present in training, in order of first appearance
valid_countries = list(df_train["country"].unique())
valid_countries

In [None]:
# Bounding box of every (country, region_cluster) pair seen in training;
# assign_location uses this table to classify predicted coordinates.
country_bounds = (
    df_train
    .groupby(["country", "region_cluster"])
    .agg(
        min_latitude=("latitude", "min"),
        max_latitude=("latitude", "max"),
        min_longitude=("longitude", "min"),
        max_longitude=("longitude", "max"),
    )
    .reset_index()
)

country_bounds

In [None]:
# Reverse geocoder for coordinates the training bounding boxes cannot classify.
# Nominatim's usage policy requires a User-Agent that identifies the application
# (generic placeholders may be blocked), and geopy's default timeout is short,
# so allow slower responses before treating a request as failed.
locator = Nominatim(user_agent="geoclip-country-classification", timeout=10)

In [None]:
# GeoCLIP image-to-location model; process_image_adjusted below calls
# model.predict(<image path>, top_k=...) to get candidate (lat, lon) predictions.
model = GeoCLIP()

In [None]:
def auto_crop_black_borders(img, threshold=10):
    """
    Crop near-black borders from all four sides of an image.

    A pixel counts as content when its grayscale intensity is strictly greater
    than `threshold`; the result is the tightest axis-aligned box containing
    every content pixel (top, bottom, left and right borders are all removed).

    Parameters:
        img: Input image as a NumPy array, either 2-D grayscale or 3-channel
            BGR (the channel order cv2.COLOR_BGR2GRAY assumes).
        threshold: Intensity at or below which a pixel is treated as black.

    Returns:
        A slice (view) of `img` covering the content box, or `img` itself
        unchanged when no pixel exceeds the threshold.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    content = gray > threshold
    content_rows = np.flatnonzero(content.any(axis=1))
    content_cols = np.flatnonzero(content.any(axis=0))

    if content_rows.size == 0 or content_cols.size == 0:
        return img  # entirely black: nothing sensible to crop to

    return img[content_rows[0]:content_rows[-1] + 1, content_cols[0]:content_cols[-1] + 1]

In [None]:
def equirectangular_to_perspective(equi_img, fov, theta, height, width):
    """
    Render a pinhole-camera view of an equirectangular panorama.

    Deliberately simplified: the camera can only yaw (rotate about the vertical
    axis); there is no pitch/roll, and the vertical extent of the image plane is
    fixed to [-1, 1] regardless of `fov`, so no stretch/aspect correction is done.

    Parameters:
        equi_img: Equirectangular panorama (OpenCV/NumPy image array)
        fov: Horizontal field of view in degrees
        theta: Yaw in degrees (0 = front, 90 = right, etc.)
        height, width: Size of the rendered view in pixels

    Returns:
        The rendered perspective view (output of cv2.remap)
    """
    pano_h, pano_w = equi_img.shape[:2]

    half_span = np.tan(np.deg2rad(fov) / 2)
    yaw = np.deg2rad(theta)

    # Points on the image plane z = 1; image rows run top -> bottom, hence the negated y
    plane_x, plane_y = np.meshgrid(
        np.linspace(-half_span, half_span, width),
        -np.linspace(-1, 1, height),
    )
    plane_z = np.ones_like(plane_x)

    # Unit viewing ray through each output pixel
    length = np.sqrt(plane_x**2 + plane_y**2 + plane_z**2)
    dir_x = plane_x / length
    dir_y = plane_y / length
    dir_z = plane_z / length

    # Yaw rotation about the vertical (y) axis
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    rot_x = cos_yaw * dir_x + sin_yaw * dir_z
    rot_z = -sin_yaw * dir_x + cos_yaw * dir_z

    # Ray direction -> longitude/latitude -> panorama pixel coordinates
    lon = np.arctan2(rot_x, rot_z)
    lat = np.arcsin(dir_y)
    map_x = ((lon + np.pi) / (2 * np.pi) * pano_w).astype(np.float32)
    map_y = ((np.pi / 2 - lat) / np.pi * pano_h).astype(np.float32)

    # BORDER_WRAP lets samples near the left/right seam wrap around the panorama
    return cv2.remap(
        equi_img, map_x, map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_WRAP,
    )

In [None]:
from geopy.extra.rate_limiter import RateLimiter

# Throttle Nominatim reverse lookups: at least 1 s between consecutive requests
reverse_geocode = RateLimiter(locator.reverse, min_delay_seconds=1)

# Memo: (lat, lon) rounded to 0.1 degree -> country name, filled by get_country_with_cache
cache = dict()

def get_country_with_cache(lat, lon):
    """
    Reverse-geocode (lat, lon) to an English country name, memoised per 0.1 degree cell.

    The coordinate is snapped to one decimal place and the snapped point is what
    is sent to Nominatim, so every input in the same cell gets the same answer
    regardless of which one happened to be queried first.

    Parameters:
        lat, lon: Coordinate in degrees.

    Returns:
        The country name from Nominatim's address details, or None when the
        lookup fails, returns no result, or the result has no country field.
        Successful lookups (including "no country field") are cached; failures
        and empty results are not, so a later call can retry them.
    """
    key = (round(lat, 1), round(lon, 1))
    if key in cache:
        return cache[key]

    try:
        location = reverse_geocode(f"{key[0]}, {key[1]}", language="en", addressdetails=True, zoom=3)
    except Exception as exc:
        # Best effort: report the failure instead of hiding it, and leave it uncached
        import warnings
        warnings.warn(f"Reverse geocoding failed for {key}: {exc}")
        return None

    if location is None:
        # No result. geopy's RateLimiter (swallow_exceptions=True by default) also
        # returns None after giving up on an error, so do not cache this case.
        return None

    country = location.raw.get('address', {}).get('country')
    cache[key] = country
    return country

In [None]:
def get_candidate_areas(lat, lon, bounds_df):
    """
    Find the rows of `bounds_df` whose bounding box contains (lat, lon).

    Bounds are inclusive on every side. `bounds_df` needs the columns country,
    region_cluster, min_latitude, max_latitude, min_longitude, max_longitude.

    Returns:
        (countries, region_clusters): two parallel lists with one entry per
        matching row, in bounds_df row order. A country can appear more than
        once when it has boxes in several region clusters.
    """
    inside = (
        (bounds_df["min_latitude"] <= lat) & (lat <= bounds_df["max_latitude"])
        & (bounds_df["min_longitude"] <= lon) & (lon <= bounds_df["max_longitude"])
    )
    candidates = bounds_df[inside]
    return candidates["country"].tolist(), candidates["region_cluster"].tolist()


def assign_location(lat, lon, bounds_df=None, geocode_fallback=None):
    """
    Assign a (country, region_cluster) pair to a coordinate using bounding boxes.

    Parameters:
        lat, lon: Coordinate in degrees.
        bounds_df: Per-(country, region_cluster) bounding boxes. Defaults to the
            notebook-level `country_bounds`, looked up at call time so that
            re-running that cell takes effect without re-defining this function.
        geocode_fallback: Optional callable (lat, lon) -> country name or None,
            consulted only when boxes of more than one distinct country contain
            the point.

    Returns:
        (country, region_cluster).
        country is "unknown" when no box contains the point, or when several
        countries match and there is no fallback (or it returns None).
        region_cluster is "ambiguous" when no box contains the point; otherwise
        it is the cluster of the first matching box belonging to the chosen
        country, or of the first matching box if no such box exists.
    """
    if bounds_df is None:
        bounds_df = country_bounds  # resolved at call time, not definition time

    country_candidates, region_candidates = get_candidate_areas(lat, lon, bounds_df)

    # A country split across several region clusters yields repeated rows for the
    # same country; that is not ambiguous, so only distinct countries are compared.
    distinct_countries = list(dict.fromkeys(country_candidates))

    if len(distinct_countries) == 1:
        country = distinct_countries[0]
    elif len(distinct_countries) > 1 and geocode_fallback:
        country = geocode_fallback(lat, lon) or "unknown"
    else:
        country = "unknown"

    if not region_candidates:
        return country, "ambiguous"

    # Prefer the cluster of a matching box that belongs to the chosen country
    region_cluster = next(
        (region for cand, region in zip(country_candidates, region_candidates) if cand == country),
        region_candidates[0],
    )
    return country, region_cluster

In [None]:

def process_image_adjusted(image_path, verbose=False, rng=None):
    """
    Predict [lat, lon, country, region_cluster] for one equirectangular panorama.

    The panorama is cropped of black borders and rendered as four 90 degree
    views (yaw 0/90/180/270). GeoCLIP's top-3 coordinates for each view are
    mapped to a country with assign_location, falling back to reverse geocoding
    when the bounding boxes give "unknown". The most common country wins,
    preferring countries present in `valid_countries`; one of that country's
    predicted coordinates is picked at random as the final location.

    Parameters:
        image_path: Path to the panorama image file.
        verbose: If True, print each per-view country as it is resolved.
        rng: Object with a `choice` method (e.g. random.Random(seed)) used to
            pick the final coordinate; defaults to the global `random` module.

    Returns:
        [lat, lon, country, region_cluster] (region_cluster as returned by
        assign_location for the chosen coordinate), or None if no prediction
        could be mapped to a country.

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path`.
    """
    rng = random if rng is None else rng

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising when a file is missing/unreadable
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Stay in OpenCV's BGR order: cv2.imwrite below expects BGR, so converting to
    # RGB here would write colour-swapped views for GeoCLIP to read back.
    img = auto_crop_black_borders(img)

    view_path = 'intermediate/current_geo_clip_view.jpg'
    all_countries = []
    all_coords_by_country = defaultdict(list)
    valid_countries_list = []
    valid_coords_by_country = defaultdict(list)

    # Four 90 degree views at yaw 0/90/180/270 cover the full horizontal circle
    for i in range(4):
        view = equirectangular_to_perspective(img, fov=90, theta=i * 90, height=480, width=600)
        # model.predict is given a file path, so each view goes through a scratch file
        cv2.imwrite(view_path, view)
        top_pred_coords, _ = model.predict(view_path, top_k=3)

        for lat, lon in top_pred_coords:
            lat_rounded = round(float(lat), 2)
            lon_rounded = round(float(lon), 2)

            country, _ = assign_location(lat_rounded, lon_rounded)
            if country == "unknown":
                # Bounding boxes could not decide; ask the reverse geocoder
                country = get_country_with_cache(lat_rounded, lon_rounded)
            if country is None:
                continue  # no country could be determined for this prediction

            # Normalise (lowercase, spaces -> underscores) before comparing with valid_countries
            country = country.lower().replace(' ', '_')
            coords = (lat_rounded, lon_rounded)
            all_countries.append(country)
            all_coords_by_country[country].append(coords)

            if verbose:
                print(country)

            if country in valid_countries:
                valid_countries_list.append(country)
                valid_coords_by_country[country].append(coords)

    # Vote among countries seen in training if there are any, otherwise among all
    if valid_countries_list:
        votes, coords_by_country = valid_countries_list, valid_coords_by_country
    elif all_countries:
        votes, coords_by_country = all_countries, all_coords_by_country
    else:
        return None  # no prediction could be mapped to any country

    majority_country = Counter(votes).most_common(1)[0][0]
    pred_lat, pred_lon = rng.choice(coords_by_country[majority_country])
    _, region_cluster = assign_location(pred_lat, pred_lon)

    # Four values, matching the pred_lat/pred_lon/pred_country/pred_region_cluster columns
    return [pred_lat, pred_lon, majority_country, region_cluster]

In [None]:
# Smoke test: run the full pipeline on a single panorama
process_image_adjusted(image_path='images/japan/1741695546_36.823432_139.5921591.jpg')

In [None]:
# Spot-check the bounding-box lookup for a single coordinate
assign_location(lat=42.55, lon=1.8, bounds_df=country_bounds)

In [None]:
# Held-out test split (pickle: only load trusted, locally produced files)
df_test = pd.read_pickle(filepath_or_buffer="intermediate/df_test.pkl")
df_test.head(5)

In [None]:
# Work on a copy so df_test itself stays untouched
df_test_pred = df_test.copy(deep=True)

In [None]:
# Run the GeoCLIP pipeline on every test image and spread each returned list
# across the prediction columns.
df_test_pred[
    ["pred_lat", "pred_lon", "pred_country", "pred_region_cluster"]
] = df_test_pred["full_path"].apply(process_image_adjusted).apply(pd.Series)

In [None]:
# Preview the first rows rather than dumping the whole prediction table
df_test_pred.head()

In [None]:
# Persist predictions so the evaluation below can run without re-predicting
df_test_pred.to_pickle(path="intermediate/df_test_pred_GeoCLIP.pkl")

In [None]:
# Reload the saved predictions (pickle: only load trusted, locally produced files)
df_test_pred = pd.read_pickle(filepath_or_buffer="intermediate/df_test_pred_GeoCLIP.pkl")

In [None]:
# Collapse predictions outside the training label set into a single "other" class
df_test_pred_only = df_test_pred[["country", "pred_country"]].assign(
    pred_country=lambda frame: frame["pred_country"].where(
        frame["pred_country"].isin(valid_countries), "other"
    )
)

df_test_pred_only

In [None]:
from sklearn.metrics import classification_report

In [None]:
# Per-class precision/recall/F1 as a nested dict
report_dict = classification_report(
    y_true=df_test_pred_only["country"],
    y_pred=df_test_pred_only["pred_country"],
    output_dict=True,
)

In [None]:
# One row per report entry, metrics as columns
pd.DataFrame(report_dict).T

In [None]:
# Confusion table: true country (rows) vs predicted country (columns)
pd.crosstab(index=df_test_pred_only["country"], columns=df_test_pred_only["pred_country"])