In [None]:
import cv2
import numpy as np
import pandas as pd
import torch
from geoclip import GeoCLIP
from geopy.geocoders import Nominatim
import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm
# from io import BytesIO
# from PIL import Image

In [None]:
# Build the GeoCLIP model with its default constructor arguments; every
# `model.predict(...)` call further down in the notebook uses this instance.
model = GeoCLIP()

In [None]:
# Nominatim geocoder client; its `reverse` method is used below to turn
# predicted coordinates into places.
# NOTE(review): "abcd" looks like a placeholder user agent — consider a string
# that identifies this project; confirm against the Nominatim usage policy.
locator = Nominatim(user_agent="abcd")

In [None]:
def auto_crop_black_borders(img, threshold=10):
    """
    Trim near-black margins by cropping to the bounding box of bright pixels.

    The crop is applied on all four sides: the result spans from the first to
    the last row/column that contains at least one pixel whose grayscale
    intensity is strictly greater than `threshold`.

    Parameters:
        img: Input image as a NumPy array. A 3-dimensional array is converted
             to grayscale with cv2.COLOR_BGR2GRAY; any other array is used
             as-is for the intensity test.
        threshold: Intensities at or below this value are treated as black.

    Returns:
        The cropped region (a NumPy slice of `img`), or `img` itself when no
        pixel exceeds the threshold.
    """
    intensity = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img

    # True wherever the pixel counts as image content rather than border.
    content = intensity > threshold

    # Indices of the rows / columns that hold at least one content pixel.
    content_rows = np.flatnonzero(content.any(axis=1))
    content_cols = np.flatnonzero(content.any(axis=0))

    if content_rows.size == 0 or content_cols.size == 0:
        return img  # entirely black: nothing sensible to crop to

    top, bottom = content_rows[0], content_rows[-1]
    left, right = content_cols[0], content_cols[-1]

    # Bounds are inclusive, hence the +1 on the slice ends.
    return img[top:bottom + 1, left:right + 1]

In [None]:
def equirectangular_to_perspective(equi_img, fov, theta, height, width):
    """
    Sample a perspective (pinhole-style) view out of an equirectangular image.

    This is a deliberately simplified projection: the camera can only be
    rotated horizontally (yaw), and the vertical extent of the view plane is
    fixed to [-1, 1] regardless of `fov` or the output aspect ratio.

    Parameters:
        equi_img: Equirectangular source image (NumPy array, OpenCV layout)
        fov: Horizontal field of view in degrees
        theta: Yaw in degrees (0 looks at the horizontal centre of the source)
        height, width: Size of the output image in pixels

    Returns:
        The resampled view produced by cv2.remap (bilinear, wrapping borders)
    """
    src_h, src_w = equi_img.shape[:2]

    half_span = np.tan(np.deg2rad(fov) / 2)
    yaw = np.deg2rad(theta)
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)

    # One viewing ray per output pixel on the plane z = 1. The vertical axis
    # is negated so that the first image row looks upwards.
    ray_x, ray_y = np.meshgrid(
        np.linspace(-half_span, half_span, width),
        -np.linspace(-1, 1, height),
    )
    ray_z = np.ones_like(ray_x)

    # Scale every ray to unit length.
    length = np.sqrt(ray_x**2 + ray_y**2 + ray_z**2)
    ray_x, ray_y, ray_z = ray_x / length, ray_y / length, ray_z / length

    # Yaw: rotation about the vertical (Y) axis, which leaves ray_y untouched.
    turned_x = cos_yaw * ray_x + sin_yaw * ray_z
    turned_z = -sin_yaw * ray_x + cos_yaw * ray_z

    # Direction -> longitude in [-pi, pi] and latitude in [-pi/2, pi/2].
    longitude = np.arctan2(turned_x, turned_z)
    latitude = np.arcsin(ray_y)

    # Angles -> source pixel coordinates; cv2.remap needs float32 maps.
    map_x = ((longitude + np.pi) / (2 * np.pi) * src_w).astype(np.float32)
    map_y = ((np.pi / 2 - latitude) / np.pi * src_h).astype(np.float32)

    return cv2.remap(equi_img, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)

In [None]:
from geopy.extra.rate_limiter import RateLimiter

# Wrap the geocoder's reverse lookup in a rate limiter so that consecutive
# requests are spaced at least one second apart.
reverse_geocode = RateLimiter(locator.reverse, min_delay_seconds=1)  # limit the frequency of requests

# In-memory cache of reverse-geocoding results, filled by get_country_with_cache:
# keys are (lat, lon) tuples rounded to one decimal, values are the country
# names found for them (or None when the response carried no country).
cache = {}

def get_country_with_cache(lat, lon):
    """
    Reverse-geocode a coordinate to an English country name, with caching.

    Results are memoised in the module-level `cache` under the coordinate
    rounded to one decimal degree, so nearby points share a single lookup.

    Parameters:
        lat, lon: Latitude and longitude in decimal degrees

    Returns:
        The country name from the geocoder's address details, or None when
        the lookup fails, yields no result, or the result has no country.
    """
    import warnings

    key = (round(lat, 1), round(lon, 1))  # coarse key: tolerate small variations
    if key in cache:
        return cache[key]

    try:
        # Pass the coordinates as a tuple instead of a formatted string:
        # str() of a tiny float uses scientific notation ("1.2e-05"), which is
        # fragile to parse back into a point.
        location = reverse_geocode((lat, lon), language="en", addressdetails=True, zoom=3)
    except Exception as exc:
        # Stay best-effort (the caller gets None) but do not hide the failure.
        warnings.warn(f"Reverse geocoding failed for ({lat}, {lon}): {exc!r}")
        return None

    if location is None:
        # No match (e.g. open sea) or, presumably, a request error swallowed by
        # the rate limiter. Not cached, so a later call can try again.
        return None

    country = location.raw.get('address', {}).get('country', None)
    cache[key] = country  # store result in cache
    return country

In [None]:
# Load the pickled training DataFrame; the cells below read image file paths
# from its 'path' column.
# NOTE(review): unpickling can execute arbitrary code — only load pickle files
# that come from a trusted source.
df_train_cl = pd.read_pickle("intermediate/train_df_clean.pkl") # replace by cleaned df
display(df_train_cl)

In [None]:
# Load one sample image (row label 2111) for the exploratory cells below.
img = cv2.imread(df_train_cl['path'][2111])
# NOTE: from here on `img` is in RGB channel order (what matplotlib expects),
# whereas OpenCV functions such as cv2.imwrite assume BGR.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB
img = auto_crop_black_borders(img)  # crop to the bounding box of non-black pixels

# Preview the cropped image without axes.
plt.imshow(img)
plt.axis('off')
plt.show()

In [None]:
# Render four perspective views (yaw 0, 90, 180, 270 degrees) of the sample
# image and run GeoCLIP on each of them.
views = []
img = auto_crop_black_borders(img)  # repeated so this cell can be re-run on its own

for i in range(0, 4):
    view = equirectangular_to_perspective(img, fov=90, theta=i*90, height=480, width=600)

    # `img` was converted to RGB when it was loaded, so `view` is RGB too, but
    # cv2.imwrite interprets arrays as BGR. Convert back before writing,
    # otherwise GeoCLIP is fed an image with red and blue swapped.
    cv2.imwrite('intermediate/current_geo_clip_view.jpg', cv2.cvtColor(view, cv2.COLOR_RGB2BGR))

    top_pred_gps, top_pred_prob = model.predict('intermediate/current_geo_clip_view.jpg', top_k=3)

    views.append(view)  # kept in RGB for plotting with matplotlib

In [None]:
# Build a table with one row per (view, prediction rank) holding the
# coordinates GeoCLIP predicts for each of the four perspective views.
img = auto_crop_black_borders(img)
views = []
records = []

for i in range(4):
    view = equirectangular_to_perspective(img, fov=90, theta=i*90, height=480, width=600)
    views.append(view)  # kept in RGB for plotting with matplotlib

    # `img` (and therefore `view`) is RGB, but cv2.imwrite expects BGR:
    # convert back so the file GeoCLIP reads has the correct colours.
    cv2.imwrite('intermediate/current_geo_clip_view.jpg', cv2.cvtColor(view, cv2.COLOR_RGB2BGR))
    top_pred_gps, _ = model.predict('intermediate/current_geo_clip_view.jpg', top_k=1)

    for rank, coords in enumerate(top_pred_gps.tolist()):
        records.append({"lat": coords[0], "lon": coords[1], "view": i, "rank": rank})

# Create the frame once from the collected records instead of growing it row
# by row; this also gives the columns proper numeric dtypes.
df = pd.DataFrame(records, columns=["lat", "lon", "view", "rank"])

In [None]:
# Show the first perspective view (yaw 0) without axes.
plt.imshow(views[0])
plt.axis('off')
plt.show()

In [None]:
# Top-5 GeoCLIP predictions for the view image currently stored on disk,
# i.e. whichever view was written last by the loops above.
top_pred_gps, top_pred_prob = model.predict('intermediate/current_geo_clip_view.jpg', top_k=5)

In [None]:
# Display the predicted coordinates (one latitude/longitude pair per prediction).
top_pred_gps

In [None]:
# Resolve every predicted coordinate to a country name. Going through
# get_country_with_cache (instead of calling locator.reverse directly) keeps
# the requests rate-limited and cached, and stores the country name (or None
# when it cannot be resolved) rather than a raw geopy Location object.
df['country'] = df.apply(
    lambda row: get_country_with_cache(row['lat'], row['lon']),
    axis=1
)

In [None]:
# Show the per-view predictions together with the 'country' column added above.
display(df)

In [None]:
def process_image(image_path):
    """
    Predict the country of one panorama by voting over four perspective views.

    The image is cropped, rendered into four views 90 degrees apart, and each
    view is geolocated with GeoCLIP; the top prediction of every view is
    reverse-geocoded to a country.

    Parameters:
        image_path: Path of the image file to process

    Returns:
        The country resolved for at least two views if there is one (on a tie,
        the country seen first wins); otherwise the country of the first view
        that could be resolved. None when the image cannot be read or no view
        could be resolved to a country.

    Raises:
        OSError: if a view cannot be written to the intermediate file.
    """
    import warnings

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        # Report it and skip the image rather than aborting a whole batch run.
        warnings.warn(f"Could not read image {image_path!r}; no country assigned.")
        return None
    img = auto_crop_black_borders(img)  # Optional, remove black borders if needed

    view_path = 'intermediate/current_geo_clip_view.jpg'
    countries = []

    # Split the image into 4 slices and predict for each slice
    for i in range(4):
        view = equirectangular_to_perspective(img, fov=90, theta=i*90, height=480, width=600)

        # `img` comes straight from cv2.imread, so it is already in the BGR
        # order cv2.imwrite expects. A failed write must not go unnoticed:
        # the model would otherwise read a stale file from a previous view.
        if not cv2.imwrite(view_path, view):
            raise OSError(f"Could not write view image to {view_path!r}")

        # Predict the coordinates for the slice
        top_pred_gps, _ = model.predict(view_path, top_k=3)

        # Only the best prediction is turned into a country
        lat, lon = top_pred_gps[0].tolist()
        countries.append(get_country_with_cache(lat, lon))

    # Views whose reverse geocoding failed (None) must not take part in the
    # vote; otherwise two failed lookups could outvote real countries.
    resolved = [country for country in countries if country is not None]
    if not resolved:
        return None

    most_common_country, count = Counter(resolved).most_common(1)[0]

    # A country needs at least 2 of the 4 views; with no such country, fall
    # back to the first view that produced a country.
    return most_common_country if count >= 2 else resolved[0]


In [None]:
tqdm.pandas() # to see a progress bar

# Run the per-image pipeline over every path and store the voted country in a
# new 'majority_country' column (this adds the column to df_train_cl in place).
# NOTE: each image triggers four GeoCLIP predictions plus up to four
# rate-limited reverse-geocoding requests, so this cell can take a long time.
df_train_cl['majority_country'] = df_train_cl['path'].progress_apply(process_image)