In [49]:
# important for gpuhub
# !pip install -r ../../requirements.txt --upgrade

## Importing Libraries and tokens

In [50]:
import os

import wandb
import torch

# load .env file
from dotenv import load_dotenv

from geo_model_tester import GeoModelTester
from image_data_handler_test import TestImageDataHandler
from best_run_loader import BestRunLoader
from wandb_downloader import WandbDownloader

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [51]:
WANDB_TOKEN = os.getenv("WANDB_TOKEN")
# Define where to run
env_path = "../../.env"
if not WANDB_TOKEN and os.path.exists(env_path):
    load_dotenv(env_path)
    WANDB_TOKEN = os.getenv("WANDB_TOKEN")

In [52]:
# Check if GPU is available
if torch.cuda.is_available():
    print("GPU is available.")

    # Print the name of the GPU
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

    # Print the total and available memory
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # Convert bytes to GB
    print(f"Total Memory: {total_memory:.2f} GB")

    allocated_memory = torch.cuda.memory_allocated(0) / 1e9  # Convert bytes to GB
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    cached_memory = torch.cuda.memory_reserved(0) / 1e9  # Convert bytes to GB
    print(f"Cached Memory: {cached_memory:.2f} GB")

    # Print other properties
    device_properties = torch.cuda.get_device_properties(0)
    print(f"CUDA Capability: {device_properties.major}.{device_properties.minor}")
    print(f"Multi-Processor Count: {device_properties.multi_processor_count}")
else:
    print("No GPU found.")

No GPU found.


## Loading files from wandb

In [53]:
wandb.login(key=WANDB_TOKEN) if WANDB_TOKEN else wandb.login()

True

In [54]:
# Setting for the right models to test
entity = "nlp_ls"
predict_coordinates = False
predict_regions = False
datasize = 332786  # 79000, 81505, 332786
data_augmentation = "full_augmentation_v2"  # or "base_augmentation", "full_augmentation_v2"

# Automatic settings
project = "dspro2-predicting-region" if predict_regions else ("dspro2-predicting-coordinates" if predict_coordinates else "dspro2-predicting-country")
metric_name = "Best Validation Accuracy Top 1" if not predict_coordinates else "Best Validation Distance (km)"
metric_ascending = False if not predict_coordinates else True
file_names_to_download = [".pth", ".json"]
image_size = [180, 320] if datasize == 79000 else [80, 130]

downloader = WandbDownloader(entity, project, data_augmentation, datasize, image_size)
run_data = downloader.get_and_collect_best_runs(metric_name, file_names_to_download, metric_ascending=metric_ascending)

dspro2-predicting-country: Found 5 matching runs for datasize 332786 and full_augmentation_v2.


In [55]:
# Print the validation accuracy for the top 1, 3, and 5 predictions
for j in range(1, min(len(run_data), 6)):
    print(f"\nBest Run {j}: {run_data[f'Best Run {j}']['id']}")
    if predict_coordinates:
        print(f"Best Validation Distance (km): ", run_data[f"Best Run {j}"]["metrics"]["Best Validation Distance (km)"])
    else:
        for i in [1, 3, 5]:
            try:
                print(f"Best Validation Accuracy Top {i}: ", run_data[f"Best Run {j}"]["metrics"][f"Best Validation Accuracy Top {i}"])
            except KeyError:
                print("No validation accuracy found.")


Best Run 1: bc69qzqh
Best Validation Accuracy Top 1:  0.6489625433838665
Best Validation Accuracy Top 3:  0.8262391634238322
Best Validation Accuracy Top 5:  0.8834382559310066

Best Run 2: kth52fnv
Best Validation Accuracy Top 1:  0.6403684060279159
Best Validation Accuracy Top 3:  0.8207401174932765
Best Validation Accuracy Top 5:  0.8790510389590876

Best Run 3: 8pqe6jmh
Best Validation Accuracy Top 1:  0.6072839821506378
Best Validation Accuracy Top 3:  0.8052646603663026
Best Validation Accuracy Top 5:  0.8728608561082981

Best Run 4: elpusqol
Best Validation Accuracy Top 1:  0.605360818546509
Best Validation Accuracy Top 3:  0.8047237706026413
Best Validation Accuracy Top 5:  0.8710578902294274


In [56]:
run_data["Best Run 1"]["parameters"]

{'seed': 42,
 'epochs': 25,
 'optimizer': 'adamW',
 'batch_size': 200,
 'model_name': 'efficientnet_b1',
 'mapped_data': False,
 'dataset_size': 332786,
 'weight_decay': 0.01,
 'learning_rate': 0.01,
 'predict_regions': False,
 'input_image_size': [80, 130],
 'data_augmentation': 'full_augmentation_v2',
 'different_regions': 4596,
 'dataset_identifier': '22a493044dbe99c1d431b9ee4656792efbb09ece4182274670ba5faec505d9cf',
 'different_countries': 138,
 'predict_coordinates': False}

In [57]:
run_data["Best Run 1"]["files"]

{'wandb_manifest.json': 'https://api.wandb.ai/files/nlp_ls/dspro2-predicting-country/bc69qzqh/artifact/931421173/wandb_manifest.json',
 'best_model': 'https://api.wandb.ai/files/nlp_ls/dspro2-predicting-country/bc69qzqh/best_model_checkpointmodel_efficientnet_b1_lr_0.01_opt_adamW_weightDecay_0.01_imgSize_[80, 130]_predict_coordinates_False.pth',
 'country_to_index.json': 'https://api.wandb.ai/files/nlp_ls/dspro2-predicting-country/bc69qzqh/run-20240627_133143-bc69qzqh/country_to_index.json',
 'region_index_to_country_index.json': 'https://api.wandb.ai/files/nlp_ls/dspro2-predicting-country/bc69qzqh/run-20240627_133143-bc69qzqh/region_index_to_country_index.json',
 'region_index_to_middle_point.json': 'https://api.wandb.ai/files/nlp_ls/dspro2-predicting-country/bc69qzqh/run-20240627_133143-bc69qzqh/region_index_to_middle_point.json',
 'region_to_index.json': 'https://api.wandb.ai/files/nlp_ls/dspro2-predicting-country/bc69qzqh/run-20240627_133143-bc69qzqh/region_to_index.json',
 'test_d

## Loading data and creating data loader

In [58]:
cache = True

run = None

for i in range(min(len(run_data), 5)):
    run = run_data[f"Best Run {i+1}"]
    if run["files"].get("test_data", None) and run["files"].get("best_model", None):
        break
    else:
        run = None
        print(f"Run {i+1} does not contain the necessary files. Trying the next run...")

if run is None:
    raise Exception("No run with the necessary files found.")

# Creating Dataloaders with the classes
test_dataset = run["files"]["test_data"]
files = run["files"]
country_to_index = files.get("country_to_index.json", None)
region_to_index = files.get("region_to_index.json", None)
region_index_to_middle_point = files.get("region_index_to_middle_point.json", None)
region_index_to_country_index = files.get("region_index_to_country_index.json", None)

data_handler = TestImageDataHandler(test_dataset, country_to_index, region_to_index, region_index_to_middle_point, region_index_to_country_index, cache=cache)
test_dataloader = data_handler.test_loader
country_to_index = data_handler.country_to_index
region_to_index = data_handler.region_to_index
region_index_to_middle_point = data_handler.region_index_to_middle_point
region_index_to_country_index = data_handler.region_index_to_country_index

num_regions = data_handler.num_regions
num_countries = data_handler.num_countries

Loaded 138 countries.
Loaded 4596 regions.
Loaded 4596 region middle points.
Loaded 3595 region to country index mappings.
Loading test data from test_data.pth


Downloading...: 100%|██████████| 4.17G/4.17G [06:47<00:00, 10.2MB/s] 


Test data loaded.
Caching test data to run-20240627_133143-bc69qzqh/test_data.pth
Test data cached.


## Evaluating the model

In [59]:
num_classes = 3 if predict_coordinates else (num_regions if predict_regions else num_countries)

if num_classes == 0:
    raise ValueError("No classes detected. Please check the data.")

geo_model_tester = GeoModelTester(test_dataloader=test_dataloader, num_classes=num_classes, predict_coordinates=predict_coordinates, country_to_index=country_to_index, region_to_index=region_to_index, region_index_to_middle_point=region_index_to_middle_point, region_index_to_country_index=region_index_to_country_index, predict_regions=predict_regions if not predict_coordinates else None)

In [60]:
# TODO: Test the model from best runs
# TODO: Show the different models with the best results (also do it for different data sizes and mapped/non-mapped data)
model_name = run["parameters"]["model_name"]
pretrained_weights = run["files"]["best_model"]

# Countries from 81k more mapped dataset, keep in sync with evaluating_models.ipynb
countries_only = ["Albania", "Argentina", "Australia", "Austria", "Bangladesh", "Belgium", "Bolivia, Plurinational State of", "Botswana", "Brazil", "Bulgaria", "Cambodia", "Canada", "Chile", "Colombia", "Croatia", "Czechia", "Denmark", "Dominican Republic", "Ecuador", "Estonia", "Eswatini", "Finland", "France", "Germany", "Ghana", "Greece", "Guatemala", "Hungary", "India", "Indonesia", "Ireland", "Israel", "Italy", "Japan", "Kenya", "Korea, Republic of", "Kyrgyzstan", "Lao People's Democratic Republic", "Latvia", "Lesotho", "Lithuania", "Malaysia", "Malta", "Mexico", "Montenegro", "Netherlands", "New Zealand", "Nigeria", "North Macedonia", "Norway", "Peru", "Philippines", "Poland", "Portugal", "Romania", "Russian Federation", "Rwanda", "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa", "Spain", "Sri Lanka", "Sweden", "Switzerland", "Thailand", "T\u00fcrkiye", "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay"]

geo_model_tester.test(model_type=model_name, model_path=pretrained_weights, balanced_on_countries_only=countries_only, accuracy_per_country=False)

if not predict_coordinates:
    # And over all countries
    geo_model_tester.test(model_type=model_name, model_path=pretrained_weights, balanced_on_countries_only=None, accuracy_per_country=True)



Network-Model: efficientnet_b1
Project Name: country
Run ID: bc69qzqh
Test Loss: 2.0048, Test Top 1 Accuracy: 0.6496, Test Top 3 Accuracy: 0.8300, Test Top 5 Accuracy: 0.8839
Test Top 1 Balanced Accuracy: 0.5104
Network-Model: efficientnet_b1
Project Name: country
Run ID: bc69qzqh
Test Loss: 2.0048, Test Top 1 Accuracy: 0.6496, Test Top 3 Accuracy: 0.8300, Test Top 5 Accuracy: 0.8839
Test Top 1 Balanced Accuracy: 0.3973
Accuracy per country:
Country India: 0.91007
Country Japan: 0.84041
Country Faroe Islands: 0.83333
Country Qatar: 0.83333
Country United States: 0.83326
Country Bangladesh: 0.80982
Country Guatemala: 0.80208
Country Indonesia: 0.78987
Country Rwanda: 0.78947
Country United Kingdom: 0.78022
Country Senegal: 0.76471
Country Nigeria: 0.76404
Country Ghana: 0.74684
Country Kenya: 0.73958
Country Russian Federation: 0.73866
Country Kyrgyzstan: 0.72414
Country Lesotho: 0.72034
Country Australia: 0.70933
Country Finland: 0.70253
Country Germany: 0.68736
Country Korea, Republic