In [1]:
import os
# Work from the project root so the relative paths used below ('data', './soundingearth/testing',
# the checkpoint path) resolve. The root defaults to the original hard-coded location and can be
# overridden per machine via the SPECTRUM4GEO_ROOT environment variable.
os.chdir(os.environ.get('SPECTRUM4GEO_ROOT', '/Coding/Spectrum4Geo/'))

import time
import torch

from dataclasses import dataclass
from geopy.distance import geodesic
from scipy.spatial import distance

from torch.utils.data import DataLoader
from spectrum4geo.dataset.soundingearth import SoundingEarthDatasetEval
from spectrum4geo.transforms import get_transforms_val_sat, get_transforms_val_spectro 
from spectrum4geo.model import TimmModel
from spectrum4geo.trainer import predict

from sklearn.metrics import DistanceMetric

import numpy as np
import pandas as pd
from tqdm import tqdm
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
@dataclass
class Configuration:
    '''Settings for evaluating a trained checkpoint on the SoundingEarth test split.

    Every setting carries a type annotation and is therefore a real dataclass field, so any of
    them can be overridden per run, e.g. ``Configuration(split_csv='val_df.csv')``.
    '''

    # Model
    model: str = 'convnext_base.fb_in22k_ft_in1k_384'

    # Override model image size
    img_size: int = 384                  # for satellite images
    patch_time_steps: int = 1024*4       # image size for spectrograms (width)
    n_mels: int = 128                    # image size for spectrograms (height)
    sr_kHz: float = 48

    # Evaluation
    batch_size_eval: int = 128
    verbose: bool = True
    gpu_ids: tuple = (0, 1, 2, 3)        # GPU ids for evaluating
    normalize_features: bool = True

    # Savepath for model eval logs
    model_path: str = './soundingearth/testing'

    # Dataset (annotated so they are dataclass fields and can be passed to the constructor)
    data_folder: str = 'data'
    split_csv: str = 'test_df.csv'

    # Checkpoint to start from; set to None to skip loading and keep the pretrained weights
    checkpoint_start: str = 'soundingearth/training/convnext_base.fb_in22k_ft_in1k_384/145835/weights_end.pth'

    # set num_workers to 0 if on Windows
    num_workers: int = 0 if os.name == 'nt' else 4

    # run on GPU if available
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

config = Configuration()

In [3]:
#-----------------------------------------------------------------------------#
# Model                                                                       #
#-----------------------------------------------------------------------------#

# Per-run output folder for evaluation logs. Double quotes inside the f-string keep this line
# valid on Python < 3.12 (re-using the outer quote type inside an f-string needs PEP 701).
model_path = f'{config.model_path}/{config.model}/{time.strftime("%H%M%S")}'
os.makedirs(model_path, exist_ok=True)


print(f'\nModel: {config.model}')

print(f'Used .csv file for evaluating: {config.split_csv}')

model = TimmModel(config.model,
                  pretrained=True,
                  img_size=config.img_size)

data_config = model.get_config()
print(data_config)
mean = data_config['mean']
std = data_config['std']
img_size = config.img_size

img_size_sat = (img_size, img_size)
img_size_spectro = (config.patch_time_steps, config.n_mels)

# Load the trained checkpoint. map_location='cpu' lets GPU-saved weights load on CPU-only
# machines; the model is moved to config.device further below anyway.
if config.checkpoint_start is not None:
    print('Start from:', config.checkpoint_start)
    model_state_dict = torch.load(config.checkpoint_start, map_location='cpu')
    # strict=False tolerates key mismatches silently. Report them: if most keys are missing,
    # the evaluation would quietly run on the timm pretrained weights instead.
    load_result = model.load_state_dict(model_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'Warning: {len(load_result.missing_keys)} missing and '
              f'{len(load_result.unexpected_keys)} unexpected keys while loading the checkpoint')

# Data parallel, restricted to the configured GPU ids that exist on this machine
print('GPUs available:', torch.cuda.device_count())
available_gpu_ids = [gpu_id for gpu_id in config.gpu_ids if gpu_id < torch.cuda.device_count()]
if torch.cuda.device_count() > 1 and len(available_gpu_ids) > 1:
    model = torch.nn.DataParallel(model, device_ids=available_gpu_ids)

# Model to device
model = model.to(config.device)

print(f'\nSpectrogram details:\n'
      f'\tSample rate: {config.sr_kHz} kHz\n'
      f'\tn_mels: {config.n_mels}\n'
      f'\tPatch width (time steps): {config.patch_time_steps}')

print('\nImage Size Sat:', img_size_sat)
print('Image Size Spectro:', img_size_spectro)
print(f'Mean: {mean}')
print(f'Std:  {std}\n')

#-----------------------------------------------------------------------------#
# DataLoader                                                                  #
#-----------------------------------------------------------------------------#

def _build_eval_split(query_type, transforms):
    '''Create the test dataset for one modality and an ordered DataLoader over it.

    Both modalities read the same ``config.split_csv`` with identical spectrogram settings;
    they differ only in ``query_type`` and ``transforms``. Returns ``(dataset, dataloader)``.
    '''
    dataset = SoundingEarthDatasetEval(data_folder=config.data_folder,
                                       split_csv=config.split_csv,
                                       query_type=query_type,
                                       transforms=transforms,
                                       patch_time_steps=config.patch_time_steps,
                                       sr_kHz=config.sr_kHz,
                                       n_mels=config.n_mels)
    # shuffle=False: the retrieval scoring later pairs items of both modalities by position.
    loader = DataLoader(dataset,
                        batch_size=config.batch_size_eval,
                        num_workers=config.num_workers,
                        shuffle=False,
                        pin_memory=True)
    return dataset, loader


# Validation transforms, normalized with the backbone's mean/std
sat_transforms_val = get_transforms_val_sat(img_size_sat, mean=mean, std=std)
spectro_transforms_val = get_transforms_val_spectro(mean=mean, std=std)

# Satellite images (reference side)
sat_dataset_test, sat_dataloader_test = _build_eval_split('sat', sat_transforms_val)

# Spectrogram ground images (query side)
spectro_dataset_test, spectro_dataloader_test = _build_eval_split('spectro', spectro_transforms_val)

print('Satalite Images Test:', len(sat_dataset_test))
print('Spectrogram Images Test:', len(spectro_dataset_test))


Model: convnext_base.fb_in22k_ft_in1k_384
Used .csv file for evaluating: test_df.csv
{'input_size': (3, 384, 384), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 1.0, 'crop_mode': 'squash'}
Start from: soundingearth/training/convnext_base.fb_in22k_ft_in1k_384/145835/weights_end.pth
GPUs available: 4

Spectrogram details:
	Sample rate: 48 kHz
	n_mels: 128
	Patch width (time steps): 4096

Image Size Sat: (384, 384)
Image Size Spectro: (4096, 128)
Mean: (0.485, 0.456, 0.406)
Std:  (0.229, 0.224, 0.225)

Satalite Images Test: 10179
Spectrogram Images Test: 10179


In [4]:
# Satellite images form the reference gallery, spectrograms are the queries.
# NOTE(review): '*_chords' presumably holds coordinates - confirm against trainer.predict.
reference_features, reference_labels, reference_chords = predict(
    config, model, sat_dataloader_test
)
query_features, query_labels, query_chords = predict(
    config, model, spectro_dataloader_test
)

100%|██████████| 80/80 [00:57<00:00,  1.38it/s]
100%|██████████| 80/80 [01:18<00:00,  1.02it/s]


In [5]:
def calculate_label_ids_until_hit(query_features, reference_features, labels, step_size=1000, query_labels=None):
    '''Rank all references for every query and collect the labels ranked above the true match.

    Args:
        query_features: 2-D tensor of query embeddings, one row per query.
        reference_features: 2-D tensor of reference embeddings, one row per reference.
        labels: tensor with the label ID of each reference row. If ``query_labels`` is None it
            is also used as the ground truth of the queries, i.e. query ``i`` is assumed to
            match the reference that carries ``labels[i]`` (same ordering on both sides).
        step_size: number of query rows multiplied with the reference matrix at once.
        query_labels: optional tensor with the ground-truth label ID of each query row; pass it
            when queries and references are not in the same order.

    Returns:
        dict mapping each query's label ID to the list of reference label IDs whose similarity
        is strictly higher than that of the true reference, sorted by descending similarity.
        An empty list means the true reference is ranked first. Duplicate query labels
        overwrite each other.
    '''
    reference_labels_np = labels.cpu().numpy()
    query_labels_np = reference_labels_np if query_labels is None else query_labels.cpu().numpy()
    ref2index = {label_id: idx for idx, label_id in enumerate(reference_labels_np)}
    Q = len(query_features)

    # Q x R similarity matrix, computed step_size query rows at a time; each chunk is moved to
    # the CPU immediately. max(Q, 1) keeps one (empty) chunk when there are no queries, so
    # torch.cat still returns a 0 x R matrix.
    chunks = [
        (query_features[start:start + step_size] @ reference_features.T).detach().cpu()
        for start in range(0, max(Q, 1), step_size)
    ]
    similarity = torch.cat(chunks, dim=0).numpy()

    label_ids_until_hit = {}
    for i in tqdm(range(Q), desc='Generate lists of label_ids until Hit'):
        row = similarity[i]
        query_label = query_labels_np[i]
        # similarity of the ground-truth reference
        gt_sim = row[ref2index[query_label]]
        # references scoring strictly higher than the ground truth (ties count in favour of the hit)
        higher_indices = np.flatnonzero(row > gt_sim)
        # best-scoring reference first
        higher_indices = higher_indices[np.argsort(-row[higher_indices])]
        label_ids_until_hit[query_label] = reference_labels_np[higher_indices].tolist()

    return label_ids_until_hit

In [6]:
# Reference (satellite) labels; the ranking helper also uses them as the queries' ground truth.
labels = reference_labels

# Working copy of the spectrogram metadata, reduced to the columns the scoring cells read and
# indexed by 'short_key' so label IDs can be looked up with .loc.
kept_columns = ['short_key', 'longitude', 'latitude', 'continent']
meta_df = (
    copy.deepcopy(spectro_dataset_test.meta)
    .pipe(lambda frame: frame.drop(columns=frame.columns.difference(kept_columns)))
    .set_index('short_key')
)

label_ids_until_hit = calculate_label_ids_until_hit(
    query_features, reference_features, labels, step_size=1000
)

Generate lists of label_ids until Hit: 100%|██████████| 10179/10179 [00:01<00:00, 5096.24it/s]


In [7]:
def calculate_scores(label_ids_until_hit, metadata_df, recall_ranks, topk_recall=True, verbose=False):
    '''Compute Recall@k, the median rank and the mean error distance of a retrieval run.

    Args:
        label_ids_until_hit: dict mapping each query label ID to the list of reference label IDs
            ranked above its true match.
        metadata_df: DataFrame indexed by label ID with 'latitude' and 'longitude' columns in
            degrees (they are converted with np.radians for the haversine metric).
        recall_ranks: iterable of k values for Recall@k.
        topk_recall: additionally report Recall@~1%, i.e. Recall@k with k = 1% of the queries,
            rounded down but at least 1.
        verbose: NOTE the inverted meaning - the result table is printed only when this is
            False; pass True to compute silently.

    Returns:
        tuple (recall_results, median_rank, mean_distance_error, topk_str):
        recall_results maps each rank (and topk_str, if topk_recall) to a percentage;
        median_rank is the median number of references ranked above the true match;
        mean_distance_error is the mean haversine distance in km between each query and the
        references ranked above its match (0 for a top-1 hit), averaged over queries;
        topk_str is the 'k/percent%' key used for Recall@~1%.
    '''
    count_until_hit = [len(wrong_ids) for wrong_ids in label_ids_until_hit.values()]
    id_count = len(label_ids_until_hit)
    # ~1% of the queries. Clamped to 1: with fewer than 100 queries (e.g. a small continent)
    # id_count // 100 would be 0 and Recall@0 is always 0.
    topk = max(1, id_count // 100)
    topk_percent = topk / id_count * 100 if id_count else 0.0
    topk_str = f'{topk}/{topk_percent:0.2f}%'

    #### Set up headers for display
    if not verbose:
        header_format = ' | '.join(['{:<13}' for _ in recall_ranks])
        headers = [f'Recall@{rank}' for rank in recall_ranks]
        if topk_recall:
            header_format += ' | {:<16}'
            headers += [f'Recall@{topk_str}']
        header_format += ' | {:<13}' + ' | {:<24}'
        headers += ['Median Rank'] + ['Mean Error Distance [km]']
        header_formated = (header_format).format(*headers)

        print('Calculate Recalls, Median Rank and Mean Error Distance!')
        print(header_formated)
        print('-' * len(header_formated))

    #### Calculating Recalls (a query counts as a hit at rank k if fewer than k references beat it)
    recall_results = {rank: np.mean([int(count < rank) for count in count_until_hit]) * 100 for rank in recall_ranks}
    if topk_recall:
        recall_results[topk_str] = np.mean([int(count < topk) for count in count_until_hit]) * 100

    #### Calculating Median Rank
    median_rank = np.median(count_until_hit)

    #### Calculating Mean Error Distance
    coordinates = metadata_df.loc[:, ['latitude', 'longitude']]
    error_distances = []
    dist = DistanceMetric.get_metric('haversine')

    for true_label_id, wrong_label_ids in label_ids_until_hit.items():
        true_coords = coordinates.loc[true_label_id].to_numpy()
        if len(wrong_label_ids) > 0:
            wrong_coords = coordinates.loc[wrong_label_ids].to_numpy()
            # Haversine distances in radians (inputs are [lat, lon] in radians)
            distances = dist.pairwise(np.radians([true_coords]), np.radians(wrong_coords)).flatten()
            error_distances.append(np.mean(distances))
        else:
            error_distances.append(0)

    mean_distance_error = np.mean(error_distances) * 6371  # radians -> km (mean Earth radius)

    #### Output the calculated metrics
    if not verbose:
        result_format = ' | '.join(['{:<13.4f}' for _ in recall_ranks])
        result_values = [recall_results[rank] for rank in recall_ranks]
        if topk_recall:
            result_format += ' | {:<16.4f}'
            result_values += [recall_results[topk_str]]
        result_format += ' | {:<13.0f}' + ' | {:<24.4f}'
        result_values += [median_rank, mean_distance_error]
        print(result_format.format(*result_values))
        print()

    return recall_results, median_rank, mean_distance_error, topk_str

In [8]:
# Full table (incl. Recall@~1%) followed by a compact one; the variables keep the second run's values.
for eval_ranks, with_topk in (([1, 5, 10, 50, 100], True), ([1, 5, 10], False)):
    recall_results, median_rank, mean_distance_error, topk = calculate_scores(
        label_ids_until_hit, meta_df, recall_ranks=eval_ranks, topk_recall=with_topk
    )

Calculate Recalls, Median Rank and Mean Error Distance!
Recall@1      | Recall@5      | Recall@10     | Recall@50     | Recall@100    | Recall@101/0.99% | Median Rank   | Mean Error Distance [km]
-------------------------------------------------------------------------------------------------------------------------------------------
19.8939       | 33.5200       | 40.4755       | 56.8327       | 64.1615       | 64.2794          | 25            | 2067.6814               

Calculate Recalls, Median Rank and Mean Error Distance!
Recall@1      | Recall@5      | Recall@10     | Median Rank   | Mean Error Distance [km]
----------------------------------------------------------------------------------------
19.8939       | 33.5200       | 40.4755       | 25            | 2067.6814               



Expected values (should be):

Recall@1: 19.8939 - Recall@5: 33.5200 - Recall@10: 40.4755 - Recall@50: 56.8327 - Recall@100: 64.1615 - Recall@101/Recall@0.99: 64.2794
Median Rank: 25.0
Mean Distance: 2067.681 km

In [9]:
def calculate_scores_continentwise(label_ids_until_hit, metadata_df, recall_ranks=(1, 5, 10, 50, 100), topk_recall=True):
    '''Evaluate the retrieval separately within every continent and print one table row each.

    For each continent, both the queries and their lists of references ranked above the true
    match are restricted to label IDs of that continent. A query therefore only competes
    against references from its own continent. The metrics come from ``calculate_scores``.

    Args:
        label_ids_until_hit: dict {query label ID: [label IDs ranked above the true match]}.
        metadata_df: DataFrame indexed by label ID with a 'continent' column, plus the
            'latitude'/'longitude' columns read by ``calculate_scores``.
        recall_ranks: k values for Recall@k.
        topk_recall: additionally report Recall@~1% of each continent's query count.

    Returns:
        dict {continent: (recall_results, median_rank, mean_distance_error, topk_str)}, i.e.
        the tuple returned by ``calculate_scores`` for that continent. Continents without any
        query in ``label_ids_until_hit`` are skipped.
    '''
    continents = sorted(set(metadata_df['continent']))
    header_format = '{:<13} | ' + '{:<13} | ' + ' | '.join(['{:<13}' for _ in recall_ranks])
    headers = ['Continent', 'used Samples'] + [f'Recall@{rank}' for rank in recall_ranks]
    if topk_recall:
        header_format += ' | {:<13}' + ' | {:<8}'
        headers += ['Recall@~1% ->', '~1%']
    header_format += ' | {:<13}' + ' | {:<24}'
    headers += ['Median Rank', 'Mean Error Distance [km]']
    header_formated = header_format.format(*headers)

    print('Calculate Recalls, Median Rank and Mean Error Distance within Continents: ' + ', '.join(continents) + '!')
    print(header_formated)
    print('-' * len(header_formated))

    continent_scores = {}
    for continent in continents:
        allowed_label_ids = set(metadata_df[metadata_df['continent'] == continent].index)
        # Keep only this continent's queries and drop foreign references from their lists.
        continent_label_ids_until_hit = {
            query_id: [label_id for label_id in wrong_ids if label_id in allowed_label_ids]
            for query_id, wrong_ids in label_ids_until_hit.items()
            if query_id in allowed_label_ids
        }
        if not continent_label_ids_until_hit:
            # No query of this continent was evaluated -> nothing to score.
            continue

        # verbose=True suppresses calculate_scores' own table (its flag is inverted).
        continent_scores[continent] = calculate_scores(continent_label_ids_until_hit, metadata_df, recall_ranks, topk_recall, verbose=True)
        recall_results, median_rank, mean_distance_error, topk_str = continent_scores[continent]

        result_format = '{:<13} | ' + '{:<13.0f} | ' + ' | '.join(['{:<13.4f}' for _ in recall_ranks])
        # "used Samples" = number of queries actually scored for this continent
        result_values = [continent, len(continent_label_ids_until_hit)] + [recall_results[rank] for rank in recall_ranks]
        if topk_recall:
            result_format += ' | {:<13.4f}' + ' | {:<8}'
            result_values += [recall_results[topk_str], topk_str]
        result_format += ' | {:<13.0f}' + ' | {:<24.4f}'
        result_values += [median_rank, mean_distance_error]

        print(result_format.format(*result_values))

    print()

    return continent_scores


In [10]:
# Continent tables with and without the Recall@~1% columns; continent_scores keeps the last run.
for with_topk in (True, False):
    continent_scores = calculate_scores_continentwise(
        label_ids_until_hit, meta_df, recall_ranks=[1, 5, 10, 50, 100], topk_recall=with_topk
    )

Calculate Recalls, Median Rank and Mean Error Distance for Continents: Africa, Asia, Australia, Europe, North America, Oceania, South America!
Continent     | used Samples  | Recall@1      | Recall@5      | Recall@10     | Recall@50     | Recall@100    | Recall@~1% -> | ~1%      | Median Rank   | Mean Error Distance [km]
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Africa        | 195           | 26.6667       | 51.2821       | 61.0256       | 84.1026       | 91.7949       | 26.6667       | 1/0.51%  | 4             | 1740.9821               
Asia          | 3104          | 17.2680       | 36.2758       | 46.9072       | 72.2938       | 80.8956       | 65.7539       | 31/1.00% | 12            | 378.2622                
Australia     | 157           | 30.5732       | 47.1338       | 62.4204       | 85.3503       | 94.9045       | 30.5732       | 1/0.64%  

In [15]:
def calculate_region_wise_recalls(label_ids_until_hit, metadata_df, calc_ranks=(1, 5, 10), print_ranks=(1, 5, 10)):
    '''Compute, per continent, how often a reference from the query's own continent is ranked high.

    For each query of a continent, count the references ranked above the first reference that
    lies in the same continent. That first reference is at the latest the true match itself.
    RegionWiseRecall@k is the percentage of that continent's queries whose count is below k.

    Args:
        label_ids_until_hit: dict {query label ID: [label IDs ranked above the true match]}.
        metadata_df: DataFrame indexed by label ID with a 'continent' column.
        calc_ranks: ranks whose recalls are returned.
        print_ranks: ranks shown in the printed table (falsy -> no table). Ranks that are not
            in calc_ranks are still computed for display, but not returned.

    Returns:
        dict {continent: {rank: region-wise recall in percent}} for the ranks in calc_ranks.
    '''
    continents = sorted(set(metadata_df['continent']))
    # Everything that is returned or printed is actually computed (no 0.0 placeholders).
    eval_ranks = sorted(set(calc_ranks) | set(print_ranks or ()))

    print('Calculate RegionWiseRecalls!')
    if print_ranks:
        header_format = '{:<15} | {:<14} | ' + ' | '.join(['{:<19}' for _ in print_ranks])
        headers = ['Continent', 'valid Samples'] + [f'RegionWiseRecall@{rank}' for rank in print_ranks]
        header_formated = (header_format).format(*headers)
        print(header_formated)
        print('-' * (len(header_formated)))

    region_wise_recalls = {}
    for continent in continents:
        allowed_label_ids = set(metadata_df[metadata_df['continent'] == continent].index)

        count_until_continent_hit = []
        for query_id, wrong_label_ids in label_ids_until_hit.items():
            if query_id not in allowed_label_ids:
                continue
            # position of the first same-continent reference; if none of the wrong ones is,
            # the true match (same continent by definition) comes right after the whole list
            count = next(
                (pos for pos, label_id in enumerate(wrong_label_ids) if label_id in allowed_label_ids),
                len(wrong_label_ids),
            )
            count_until_continent_hit.append(count)

        all_recalls = {
            rank: np.mean([int(count < rank) for count in count_until_continent_hit]) * 100
            for rank in eval_ranks
        }
        region_wise_recalls[continent] = {rank: all_recalls[rank] for rank in calc_ranks}

        if print_ranks:
            result_format = '{:<15} | {:<14} | ' + ' | '.join(['{:<19.4f}' for _ in print_ranks])
            # "valid Samples" = number of queries of this continent that were evaluated
            result_values = [continent, len(count_until_continent_hit)] + [all_recalls[rank] for rank in print_ranks]
            print(result_format.format(*result_values))
    print()

    return region_wise_recalls

In [16]:
# Region-wise recalls per continent, computed and printed for the same ranks.
region_wise_recalls = calculate_region_wise_recalls(
    label_ids_until_hit,
    meta_df,
    calc_ranks=[1, 5, 10, 25],
    print_ranks=[1, 5, 10, 25],
)


Calculate RegionWiseRecalls!
Continent       | valid Samples  | RegionWiseRecall@1  | RegionWiseRecall@5  | RegionWiseRecall@10 | RegionWiseRecall@25
------------------------------------------------------------------------------------------------------------------------
Africa          | 195            | 46.6667             | 62.5641             | 71.7949             | 82.5641            
Asia            | 3104           | 91.9137             | 94.5554             | 95.8441             | 97.4871            
Australia       | 157            | 45.2229             | 63.6943             | 69.4268             | 76.4331            
Europe          | 5634           | 89.9894             | 98.1186             | 98.9883             | 99.6273            
North America   | 880            | 56.7045             | 76.7045             | 84.3182             | 93.4091            
Oceania         | 45             | 42.2222             | 48.8889             | 57.7778             | 64.4444            
Sou

In [13]:
def calculate_blanced_continental_recalls(region_wise_recalls):
    '''Average region-wise recalls over continents, giving every continent the same weight.

    Args:
        region_wise_recalls: dict {continent: {rank: recall}}.

    Returns:
        dict {rank: balanced recall}, ranks in ascending order. For each rank this is the
        unweighted mean over all continents that report that rank; continents lacking the
        rank are skipped instead of breaking the mean.
    '''
    print('Calculate BalancedContinentalRecalls!')

    ranks = sorted({rank for per_rank in region_wise_recalls.values() for rank in per_rank})

    header_format = ' | '.join(['{:<29}' for _ in ranks])
    headers = [f'BalancedContinentalRecall@{rank}' for rank in ranks]
    header_formated = (header_format).format(*headers)
    print(header_formated)
    print('-' * (len(header_formated)))

    balanced_continental_recalls = {
        rank: np.mean([per_rank[rank] for per_rank in region_wise_recalls.values() if rank in per_rank])
        for rank in ranks
    }

    result_format = ' | '.join(['{:<29.4f}' for _ in ranks])
    result_values = [balanced_continental_recalls[rank] for rank in ranks]
    print(result_format.format(*result_values))
    print()

    return balanced_continental_recalls

In [14]:
# Continent-balanced recall: each continent weighs equally, regardless of its sample count.
blanced_continental_recalls = calculate_blanced_continental_recalls(
    region_wise_recalls
)

Calculate BalancedContinentalRecalls!
BalancedContinentalRecall@1   | BalancedContinentalRecall@5   | BalancedContinentalRecall@10  | BalancedContinentalRecall@25 
-----------------------------------------------------------------------------------------------------------------------------
58.9948                       | 71.9532                       | 77.9761                       | 84.1379                      

