In [None]:
import sys
import os
from pathlib import Path

# Make the sibling ``src`` directory (``<cwd>/../src``) importable.
# Presumably this is where the local ``utils`` and ``models`` modules imported
# further down live -- TODO confirm. Note that this depends on the directory
# the kernel was started from.
path = Path(os.getcwd())
parent_directory = Path.joinpath(path.parent,'src').as_posix()
sys.path.append(str(parent_directory))

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
from datetime import datetime
import seaborn as sns
import rasterio
import itertools
from pathlib import Path

from typing import Dict
# First CUDA device when one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced by any of the cells visible below;
# the model and the batches are never moved to it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Local project modules; they must come after the sys.path.append above.
import utils as ut
from models import get_model

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(42)
torch.manual_seed(42)

#### RankerDataset, Reusable Tools

In [None]:
def get_date(image_date):
    """Parse a compact ``YYYYMMDD`` string into a ``datetime.date``.

    The first four characters are read as the year, the next two as the
    month and everything after that as the day. ``datetime.strptime``
    raises ``ValueError`` when the pieces do not form a valid date.
    """
    dashed = "{}-{}-{}".format(image_date[:4], image_date[4:6], image_date[6:])
    return datetime.strptime(dashed, "%Y-%m-%d").date()

def get_day_of_week(date):
    """Return the lowercase weekday name for a ``YYYY-MM-DD`` date string.

    The name comes from ``strftime("%A")`` and is therefore locale dependent.
    """
    parsed = datetime.strptime(date, "%Y-%m-%d")
    return parsed.strftime("%A").lower()

def ranker_inference(pairs_df, anchor_date):
    """Summarise pairwise predictions into one row per anchor date.

    Parameters
    ----------
    pairs_df : pandas.DataFrame
        One row per compared pair. Must contain the column named by
        ``anchor_date`` plus ``'probs'`` and ``'pred'``.
    anchor_date : str
        Name of the column holding the anchor date of each pair.

    Returns
    -------
    pandas.DataFrame
        Sorted by ``'date'`` with the columns ``'date'``, ``'probs'`` (mean
        of ``'probs'`` over the date's pairs), ``'raw_rank'`` (number of
        pairs with ``pred == 1``), ``'rank_scaled'`` (``raw_rank`` divided by
        the number of pairs for that date) and ``'rank'`` (dense ascending
        rank of ``raw_rank``, as int).
    """
    anchor_column = f'{anchor_date}'
    summary = {'date': [], 'probs': [], 'raw_rank': [], 'rank_scaled': []}

    for anchor_value in pairs_df[anchor_column].unique():
        subset = pairs_df[pairs_df[anchor_column] == anchor_value]
        wins = len(subset[subset['pred'] == 1])
        summary['date'].append(anchor_value)
        summary['probs'].append(subset['probs'].mean())
        summary['raw_rank'].append(wins)
        summary['rank_scaled'].append(wins / len(subset))

    summary_df = pd.DataFrame(summary).sort_values('date')

    # Dense ranking: tied win counts share a rank and no rank is skipped.
    dense_rank = summary_df['raw_rank'].rank(ascending=True, method="dense")
    summary_df['rank'] = dense_rank.astype(int)
    return summary_df

In [None]:
class PlanetRankerInfDataset(Dataset):
    """Inference dataset yielding (anchor image, paired image) tensors.

    Every row of the backing dataframe must provide the columns
    ``'anchor_image'`` and ``'anchor_image_pair'``, each holding a raster path
    that ``rasterio`` can open.
    """

    def __init__(self, im_data, transform=None, preprocess_type='none', building_polygon_path=None):
        """Store the pair table and the optional transform.

        Parameters
        ----------
        im_data : str or pandas.DataFrame
            Path to a CSV file, or a dataframe (its index is reset to
            0..n-1 on a new frame).
        transform : callable, optional
            Applied to each RGB array in ``__getitem__``; when falsy the
            arrays are wrapped with ``torch.as_tensor`` instead.
        preprocess_type : str
            Stored on the instance; not read by any method of this class.
        building_polygon_path : str, optional
            Stored on the instance; not read by any method of this class.

        Raises
        ------
        ValueError
            If ``im_data`` is neither a string nor a dataframe. Previously
            this surfaced as an ``UnboundLocalError`` a few lines later.
        """
        if isinstance(im_data, str):
            image_dataframe = pd.read_csv(im_data)
        elif isinstance(im_data, pd.DataFrame):
            image_dataframe = im_data.reset_index(drop=True)
        else:
            # Same exception type as get_inference_dataloader uses for this case.
            raise ValueError(
                f"im_data should be either a string or a dataframe, got {type(im_data).__name__}"
            )

        self.image_dataframe = image_dataframe
        self.transform = transform
        self.preprocess_type = preprocess_type
        self.image_cache = {}  # never populated by the methods of this class
        self.building_polygon_path = building_polygon_path

    def __len__(self):
        """Number of image pairs."""
        return len(self.image_dataframe)

    def stacking_rgb_images(self, image):
        """Stack red, green and blue bands of a band-first array as float32 HxWx3.

        Bands 5/3/1 are used for 8-band input and bands 2/1/0 for 4-band
        input (presumably the Planet 8-band / 4-band layouts -- TODO confirm).
        ``__getitem__`` does not call this helper; it uses
        ``ut.convert_to_rgb`` instead.

        Raises
        ------
        ValueError
            If the first axis has a size other than 4 or 8. Previously this
            surfaced as an ``UnboundLocalError``.
        """
        band_count = image.shape[0]
        if band_count == 8:
            red_idx, green_idx, blue_idx = 5, 3, 1
        elif band_count == 4:
            red_idx, green_idx, blue_idx = 2, 1, 0
        else:
            raise ValueError(f"Expected an image with 4 or 8 bands, got {band_count}")

        red = image[red_idx, :, :]
        green = image[green_idx, :, :]
        blue = image[blue_idx, :, :]
        return np.dstack((red, green, blue)).astype('float32')

    def __getitem__(self, index):
        """Load, RGB-convert and transform the image pair at ``index``."""
        row = self.image_dataframe.iloc[index]
        anchor_image_image_path = row['anchor_image']
        anchor_image_pair_image_path = row['anchor_image_pair']

        with rasterio.open(anchor_image_image_path, 'r') as src:
            anchor_img_data = src.read()

        with rasterio.open(anchor_image_pair_image_path, 'r') as src:
            anchor_pair_img_data = src.read()

        anchor_img_data = ut.convert_to_rgb(anchor_img_data)
        anchor_pair_img_data = ut.convert_to_rgb(anchor_pair_img_data)

        if self.transform:
            anchor_img_data_tensor = self.transform(anchor_img_data)
            anchor_pair_img_data_tensor = self.transform(anchor_pair_img_data)
        else:
            anchor_img_data_tensor = torch.as_tensor(anchor_img_data)
            anchor_pair_img_data_tensor = torch.as_tensor(anchor_pair_img_data)

        return anchor_img_data_tensor, anchor_pair_img_data_tensor

def get_inference_dataloader(test_data_path: str, test_augmentations: Dict, preprocess_type: str, batch_size: int, building_polygon_path: str = None):
    """Build a non-shuffling ``DataLoader`` over ``PlanetRankerInfDataset``.

    ``test_data_path`` may be a CSV path or, despite the annotation, a
    dataframe (which is copied). The transform is built from
    ``test_augmentations`` via ``ut.get_transforms``.

    Raises
    ------
    ValueError
        If ``test_data_path`` is neither a string nor a dataframe.
    """
    test_transform = ut.get_transforms(test_augmentations)

    if isinstance(test_data_path, pd.DataFrame):
        pair_table = test_data_path.copy()
    elif isinstance(test_data_path, str):
        pair_table = pd.read_csv(test_data_path)
    else:
        raise ValueError("test_data_path should be either a string or a dataframe")

    inference_dataset = PlanetRankerInfDataset(
        pair_table,
        transform=test_transform,
        preprocess_type=preprocess_type,
        building_polygon_path=building_polygon_path,
    )
    # shuffle=False keeps batches in the same order as the rows of the table.
    return DataLoader(inference_dataset, batch_size=batch_size, shuffle=False, num_workers=0)


def load_model(model_config, model_weight, map_location="cpu"):
    """Build the model described by ``model_config`` and load its weights.

    Parameters
    ----------
    model_config
        Passed unchanged to ``get_model``.
    model_weight : str
        Path to a checkpoint holding a state dict.
    map_location
        Forwarded to ``torch.load``. It defaults to ``"cpu"`` so that a
        checkpoint saved from a GPU can still be loaded on a machine without
        CUDA; without it ``torch.load`` tries to restore tensors onto the
        device they were saved from.

    Returns
    -------
    The model returned by ``get_model`` with the state dict loaded. It is not
    moved to any device and not switched to eval mode here.
    """
    model = get_model(model_config)
    print(f"Loading model from {model_weight}")
    state_dict = torch.load(model_weight, map_location=map_location)
    model.load_state_dict(state_dict)
    return model


#### Sudan Use Case

#### Assign pre- and post-war labels based on the days observed by
Guo, Zhe, et al. "Monitoring indicators of economic activities in Sudan amidst ongoing conflict using satellite data." Defence and Peace Economics 35.8 (2024): 992-1008.

In [None]:
# NOTE(review): machine-specific absolute path; move to a config cell or an
# environment variable so the notebook can run elsewhere.
PROJECT_ROOT = "C:/Users/Theo/Documents/Projects/small_things/"

# Inclusive (first day, last day) windows used to label the images.
# NOTE(review): the pre-war window spans 7 days and the post-war window 8 days
# -- confirm this asymmetry is intended.
PRE_WAR_WINDOW = (datetime(2023, 4, 1).date(), datetime(2023, 4, 7).date())
POST_WAR_WINDOW = (datetime(2023, 4, 14).date(), datetime(2023, 4, 21).date())

all_sudan_df = pd.read_csv(f"{PROJECT_ROOT}input/sudan/polygon_images/sudan_clipped/all_images.csv")
all_sudan_df['date'] = pd.to_datetime(all_sudan_df['date']).dt.date

in_pre_war = (all_sudan_df['date'] >= PRE_WAR_WINDOW[0]) & (all_sudan_df['date'] <= PRE_WAR_WINDOW[1])
in_post_war = (all_sudan_df['date'] >= POST_WAR_WINDOW[0]) & (all_sudan_df['date'] <= POST_WAR_WINDOW[1])

# .copy() makes the subsets independent frames, so adding the 'period' column
# below does not raise SettingWithCopyWarning.
pre_war_all_sudan_df = all_sudan_df[in_pre_war].copy()
post_war_all_sudan_df = all_sudan_df[in_post_war].copy()
pre_war_all_sudan_df['period'] = 'pre_war'
post_war_all_sudan_df['period'] = 'post_war'

# Rows outside both windows are dropped here.
all_sudan_df = pd.concat([pre_war_all_sudan_df, post_war_all_sudan_df]).reset_index(drop=True)

# to be removed: rewrites the relative prefix stored in the CSV to the local project root
all_sudan_df['image_path'] = all_sudan_df['image_path'].str.replace("../../", PROJECT_ROOT, regex=False)

all_sudan_df["date"] = pd.to_datetime(all_sudan_df["date"])
all_sudan_df["day_of_week"] = all_sudan_df["date"].dt.day_name()

# Number of rows per period and weekday; kept for inspection and not used by
# any later cell in this notebook.
day_ranked_df_counts = all_sudan_df.groupby(["period", "day_of_week"])["date"].count().unstack(fill_value=0)

all_sudan_df.head()

In [None]:
# Keep only the weekdays that are present in both the pre- and post-war data
shared_weekdays = ['Saturday', 'Sunday', 'Monday', 'Tuesday', 'Wednesday']
all_sudan_df = all_sudan_df.loc[all_sudan_df['day_of_week'].isin(shared_weekdays)]

In [None]:
# NOTE(review): machine-specific absolute path; move to a config cell or an
# environment variable so the notebook can run elsewhere.
PROJECT_ROOT = "C:/Users/Theo/Documents/Projects/small_things/"
BATCH_SIZE = 1
PRED_THRESHOLD = 0.5  # probabilities at or below this value are labelled 0

config = ut.load_config(f"{PROJECT_ROOT}experiments/20250215-131800_patch_pairwiseranker/inference_config.yaml")
model = load_model(config['model_config'], f"{PROJECT_ROOT}{config['model_config']['model_weight']}")

# Keep one image (the first) per (date, cluster).
all_sudan_df = all_sudan_df.groupby(['date', 'cluster']).agg({'image_path': 'first'}).reset_index()
print(all_sudan_df.shape)
all_sudan_df['date'] = pd.to_datetime(all_sudan_df['date']).dt.date.astype(str)
all_sudan_df['day_of_week'] = all_sudan_df['date'].apply(get_day_of_week)

# Every ordered pair of distinct image paths. Filtering while iterating avoids
# materialising the full n^2 product, which includes self-pairs.
sudan_anchor_pairs = [
    (anchor_path, other_path)
    for anchor_path, other_path in itertools.product(all_sudan_df['image_path'], repeat=2)
    if anchor_path != other_path
]
pairs_sudan_df = pd.DataFrame(sudan_anchor_pairs, columns=['anchor_image', 'anchor_image_pair'])

# The acquisition date is the part of the file stem before the first underscore.
anchor_dates = pairs_sudan_df['anchor_image'].apply(lambda p: get_date(Path(p).stem.split("_")[0]))
anchor_pair_dates = pairs_sudan_df['anchor_image_pair'].apply(lambda p: get_date(Path(p).stem.split("_")[0]))
pairs_sudan_df['anchor_date'] = anchor_dates.astype(str)
pairs_sudan_df['anchor_day'] = anchor_dates.apply(lambda d: d.strftime("%A")).astype(str)
pairs_sudan_df['anchor_pair_date'] = anchor_pair_dates.astype(str)

test_loader = get_inference_dataloader(pairs_sudan_df, config['data_config']['test_data']['augmentations'], 'rgb', BATCH_SIZE)
pair_wise_test_dict = {'probs': []}

# Switch to eval mode once instead of on every iteration.
model.eval()
with torch.no_grad():
    for anchor_image, anchor_image_pair in test_loader:
        output = model(anchor_image, anchor_image_pair)
        # Flatten instead of taking element [0], so no samples are dropped if
        # BATCH_SIZE is raised (assumes one score per pair -- TODO confirm the
        # model's output shape).
        pair_wise_test_dict['probs'].extend(output.cpu().numpy().reshape(-1))

assert len(pair_wise_test_dict['probs']) == len(pairs_sudan_df), "expected exactly one score per image pair"

pairs_sudan_df['probs'] = pair_wise_test_dict['probs']
pairs_sudan_df['pred'] = np.where(pairs_sudan_df['probs'] <= PRED_THRESHOLD, 0, 1)

In [None]:
# One row per anchor date, re-indexed 0..n-1 in date order.
ranked_sudan_df = ranker_inference(pairs_sudan_df, 'anchor_date').reset_index(drop=True)
ranked_sudan_df

In [None]:
# First day counted as post-war. This matches the lower bound of the post-war
# window used when the raw images were labelled; the previous `<= 2023-04-14`
# test put that day into 'pre_war' instead.
POST_WAR_START = datetime(2023, 4, 14).date()

ranked_sudan_df_agg = (
    ranked_sudan_df
    .groupby(['date'])
    .agg({'raw_rank': 'mean'})
    .reset_index()
    .sort_values('date')
)
# method='min' keeps ranks integral, so `rank - 1` is the number of dates with
# a strictly lower score. The default 'average' method gives fractional ranks
# for ties, which astype(int) silently floored.
ranked_sudan_df_agg['rank'] = ranked_sudan_df_agg['raw_rank'].rank(ascending=True, method='min').astype(int)
ranked_sudan_df_agg['day_of_week'] = ranked_sudan_df_agg['date'].apply(get_day_of_week)
ranked_sudan_df_agg['date'] = pd.to_datetime(ranked_sudan_df_agg['date']).dt.date
ranked_sudan_df_agg['period'] = np.where(ranked_sudan_df_agg['date'] < POST_WAR_START, 'pre_war', 'post_war')
ranked_sudan_df_agg['ID'] = ranked_sudan_df_agg['period'] + '_' + ranked_sudan_df_agg['day_of_week'] + "_" + ranked_sudan_df_agg['date'].astype(str)

fig, ax = plt.subplots(figsize=(15, 10))
ax.bar(ranked_sudan_df_agg['ID'], ranked_sudan_df_agg['rank'] - 1)
ax.tick_params(axis='x', labelrotation=75)
ax.grid(True)
ax.set_title('Image ranking based of all parking lots on dates: Sudan Bus Terminal')
ax.set_ylabel('Number of images beaten')
ax.set_xlabel('Date')
# Called after the labels are set so that they are included in the layout.
fig.tight_layout()
plt.show()

In [None]:
# First day counted as post-war. This matches the lower bound of the post-war
# window used when the raw images were labelled; the previous `<= 2023-04-14`
# test put that day into 'pre_war' instead.
POST_WAR_START = datetime(2023, 4, 14).date()
day_order = ['Monday', 'Tuesday', 'Wednesday', 'Saturday', 'Sunday']

ranked_sudan_df_agg = (
    ranked_sudan_df
    .groupby(['date'])
    .agg({'raw_rank': 'mean'})
    .reset_index()
    .sort_values('date')
)
# method='min' keeps ranks integral; the default 'average' method gives
# fractional ranks for ties, which astype(int) silently floored.
ranked_sudan_df_agg['rank'] = ranked_sudan_df_agg['raw_rank'].rank(ascending=True, method='min').astype(int)
# get_day_of_week returns lowercase names; capitalise them to match day_order.
ranked_sudan_df_agg['day_of_week'] = pd.Categorical(
    ranked_sudan_df_agg['date'].apply(get_day_of_week).str.capitalize(),
    categories=day_order,
    ordered=True,
)
ranked_sudan_df_agg['date'] = pd.to_datetime(ranked_sudan_df_agg['date']).dt.date
ranked_sudan_df_agg['period'] = np.where(ranked_sudan_df_agg['date'] < POST_WAR_START, 'pre_war', 'post_war')

fig, ax = plt.subplots(figsize=(15, 10))
sns.barplot(data=ranked_sudan_df_agg, x='day_of_week', y='rank', hue='period', order=day_order, ax=ax)
ax.legend(title='Period', loc='upper left', fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=20)
ax.tick_params(axis='x', labelrotation=75)
ax.set_ylabel('Image ranking', fontsize=20)
ax.set_xlabel('Day of week', fontsize=20)
# Called after the labels are set so that they are included in the layout.
fig.tight_layout()
plt.show()