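# Geolocation demo: live predictions for new images

This notebook fetches the best matching trained model from Weights & Biases. It then watches a local `.data` directory and, for every image file created there, shows the predicted country (or region / coordinates, depending on the configuration). The prediction is shown as a macOS notification, or printed if notifications are unavailable.

## Importing libraries and tokens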

In [1]:
import os
import sys
import time
import threading
import signal
import queue

import wandb
import torch
from torchvision import transforms
from mac_notifications import client  # macOS desktop notifications, used by handle_notifications()
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

# load .env file
from dotenv import load_dotenv

# Project-local helpers live in sibling folders of the repository. These sys.path
# entries are relative, so they resolve against the current working directory:
# the notebook has to be started from its own folder for the imports to work.
sys.path.insert(0, "../6_deployment")
from geo_model_deployer import GeoModelDeployer
from image_data_handler_deploy import DeployImageDataHandler

sys.path.insert(0, "../5_evaluation")
from wandb_downloader import WandbDownloader

sys.path.insert(0, "../")
from data_loader import resolve_env_variable

# Debugging aid: makes CUDA kernel launches synchronous so errors are reported at
# the failing call. It slows GPU execution and has no effect on CPU-only runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
# Prefer a W&B token already present in the environment; only when it is missing,
# fall back to the repository-level .env file (two folders up), if that file exists.
env_path = "../../.env"
WANDB_TOKEN = os.getenv("WANDB_TOKEN")
if not WANDB_TOKEN:
    if os.path.exists(env_path):
        load_dotenv(env_path)
        WANDB_TOKEN = os.getenv("WANDB_TOKEN")

In [3]:
# Check if GPU is available
# Report whether PyTorch can see a CUDA GPU and, if so, describe device 0.
if not torch.cuda.is_available():
    print("No GPU found.")
else:
    print("GPU is available.")
    device_properties = torch.cuda.get_device_properties(0)

    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

    # All memory figures are converted from bytes to (decimal) GB.
    total_memory = device_properties.total_memory / 1e9
    print(f"Total Memory: {total_memory:.2f} GB")

    allocated_memory = torch.cuda.memory_allocated(0) / 1e9
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    # "Cached" = memory reserved by PyTorch's caching allocator.
    cached_memory = torch.cuda.memory_reserved(0) / 1e9
    print(f"Cached Memory: {cached_memory:.2f} GB")

    print(f"CUDA Capability: {device_properties.major}.{device_properties.minor}")
    print(f"Multi-Processor Count: {device_properties.multi_processor_count}")

No GPU found.


## Loading files from wandb

In [4]:
# Authenticate with W&B. With no token (None), wandb uses its own login flow (stored credentials or a prompt).
wandb.login(key=WANDB_TOKEN or None)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkillusions[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Resolve the W&B entity. "nlp_ls" is the authors' team; it is presumably returned
# when WANDB_ENTITY is not configured (confirm in data_loader.resolve_env_variable).
entity = resolve_env_variable("nlp_ls", "WANDB_ENTITY", True)  # Please provide your own entity
if entity == "nlp_ls":
    print("Please provide your own wandb entity if you are not part of our project, add WANDB_ENTITY to your .env.")

# Task selection: enable at most one flag; both False means country classification.
predict_coordinates = False
predict_regions = False
if predict_coordinates and predict_regions:
    # Both True would pick the region project but a coordinate metric/head further down.
    raise ValueError("predict_coordinates and predict_regions are mutually exclusive; set at most one to True.")

run_id = None  # Only get specific run if needed
if predict_regions:
    project = "dspro2-predicting-region"
elif predict_coordinates:
    project = "dspro2-predicting-coordinates"
else:
    project = "dspro2-predicting-country"
# Coordinate models are ranked by distance error, classifiers by top-1 accuracy.
metric_name = "Best Validation Distance (km)" if predict_coordinates else "Best Validation Accuracy Top 1"
data_augmentation = "base_augmentation"  # or "full_augmentation_v2"
datasize = 332786  # Replace with the desired datasize
file_names_to_download = [".pth", ".json"]
# Passed to the downloader (presumably to match runs by input size) and reused for the Resize transform.
image_size = [80, 130]

downloader = WandbDownloader(entity, project, data_augmentation, datasize, image_size)
try:
    run_data = downloader.get_and_collect_best_runs(metric_name, file_names_to_download, run_id=run_id)
except Exception as e:
    if entity == "nlp_ls":
        raise ConnectionError("Using our wandb entity publicly is not supported, please either provide your own entity (add WANDB_ENTITY to your .env), download the files manually or correctly authenticate.") from e
    raise  # re-raise unchanged, keeping the original traceback

# Fail with a clear message instead of a KeyError when nothing matched.
if not run_data:
    raise LookupError(f"No runs found in {project} for datasize {datasize} and {data_augmentation}.")

print(run_data["Best Run 1"]["parameters"])

Please provide your own wandb entity if you are not part of our project, add WANDB_ENTITY to your .env.
dspro2-predicting-country: Found 3 matching runs for datasize 332786 and base_augmentation.
{'seed': 42, 'epochs': 50, 'optimizer': 'adamW', 'batch_size': 400, 'model_name': 'efficientnet_b1', 'mapped_data': False, 'dataset_size': 332786, 'weight_decay': 0.01, 'learning_rate': 0.01, 'predict_regions': False, 'input_image_size': [80, 130], 'data_augmentation': 'base_augmentation', 'different_regions': 4596, 'dataset_identifier': '22a493044dbe99c1d431b9ee4656792efbb09ece4182274670ba5faec505d9cf', 'different_countries': 138, 'predict_coordinates': False}


## Loading the label mappings and creating the data handler

In [6]:
# Walk the best runs in rank order and keep the first one that actually has a
# downloaded "best_model" file; a run without it cannot be deployed.
MAX_RUNS_TO_TRY = 5
run = None
for i in range(1, min(len(run_data), MAX_RUNS_TO_TRY) + 1):
    candidate = run_data[f"Best Run {i}"]
    if candidate["files"].get("best_model"):
        run = candidate
        break
    print(f"Run {i} does not contain the necessary files. Trying the next run...")

if run is None:
    raise RuntimeError("No run with the necessary files found.")

# Inference-time preprocessing only: resize to the configured input size, convert
# to a tensor in [0, 1] and map each of the three channels to [-1, 1].
augmented_transform = None  # Never used for test data
base_transform = transforms.Compose([
    transforms.Resize((image_size[0], image_size[1])),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Label mappings downloaded with the run (None when a file is missing; presumably
# local paths to the JSON files — confirm in WandbDownloader).
files = run["files"]
country_to_index = files.get("country_to_index.json")
region_to_index = files.get("region_to_index.json")
region_index_to_middle_point = files.get("region_index_to_middle_point.json")
region_index_to_country_index = files.get("region_index_to_country_index.json")

data_handler = DeployImageDataHandler(country_to_index, region_to_index, region_index_to_middle_point, region_index_to_country_index, base_transform, join_to_current_dir="../7_demo")
# Replace the raw inputs with the handler's loaded mappings.
country_to_index = data_handler.country_to_index
region_to_index = data_handler.region_to_index
region_index_to_middle_point = data_handler.region_index_to_middle_point
region_index_to_country_index = data_handler.region_index_to_country_index

num_regions = data_handler.num_regions
num_countries = data_handler.num_countries

Loaded 138 countries.
Loaded 4596 regions.
Loaded 4596 region middle points.
Loaded 3595 region to country index mappings.


## Demoing the model

In [7]:
# Output size of the model head: 3 values in coordinate mode (presumably a
# Cartesian x/y/z point — confirm in GeoModelDeployer), else one per class.
if predict_coordinates:
    num_classes = 3
elif predict_regions:
    num_classes = num_regions
else:
    num_classes = num_countries

if num_classes == 0:
    raise ValueError("No classes detected. Please check the data.")

geo_model = GeoModelDeployer(
    num_classes=num_classes,
    predict_coordinates=predict_coordinates,
    country_to_index=country_to_index,
    region_to_index=region_to_index,
    region_index_to_middle_point=region_index_to_middle_point,
    region_index_to_country_index=region_index_to_country_index,
    # The region flag is only forwarded for classification; coordinate mode passes None.
    predict_regions=None if predict_coordinates else predict_regions,
)

In [8]:
# Build the architecture recorded for the chosen run and load its downloaded
# "best_model" file (presumably the .pth checkpoint).
model_name, pretrained_weights = run["parameters"]["model_name"], run["files"]["best_model"]
geo_model.prepare(model_type=model_name, model_path=pretrained_weights)

In [9]:
# State shared by the directory watcher, the prediction workers and the notification loop.
shutdown_event = threading.Event()  # set once everything should stop
notification_queue = queue.Queue()  # prediction messages waiting to be shown


def setup_signal_handling():
    """Install one handler for SIGINT and SIGTERM that requests a shutdown.

    The handler only prints a notice and sets ``shutdown_event``; the loops that
    watch the event do the actual stopping. Not called from the visible cells.
    """

    def _request_shutdown(signum, frame):
        print("Shutdown signal received...")
        shutdown_event.set()

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _request_shutdown)


def file_event_handler(event, settle_seconds=1):
    """Start a background prediction for each newly created file (watchdog ``on_created`` callback).

    Directory events and events other than "created" are ignored. The settle
    delay that lets the writer finish the file now runs on the worker thread.
    Previously it ran here, which blocked watchdog's event-dispatch thread for a
    full second per file and delayed every following event.

    Args:
        event: watchdog file-system event; ``is_directory``, ``event_type`` and
            ``src_path`` are read.
        settle_seconds: delay before the file is read so it is (hopefully) fully
            written. Defaults to 1 second, as before.
    """
    if event.is_directory or event.event_type != "created":
        return

    def _settle_then_predict(image_file_path):
        # Wait to ensure the file is not still being written, then predict.
        time.sleep(settle_seconds)
        predict_from_file(image_file_path)

    # Daemon worker: a pending prediction must not keep the process alive after shutdown.
    threading.Thread(target=_settle_then_predict, args=(event.src_path,), daemon=True).start()


def monitor_directory(path):
    """Watch ``path`` (non-recursively) for new files until shutdown is requested.

    Meant to run on its own thread; each created file is passed to
    ``file_event_handler``. Blocks on ``shutdown_event`` instead of polling it every
    second, so the observer stops as soon as the event is set. The observer is
    always stopped and joined on the way out.
    """
    observer = Observer()
    handler = FileSystemEventHandler()
    # Only file creations are of interest; route them to the shared callback.
    handler.on_created = file_event_handler
    observer.schedule(handler, path, recursive=False)
    observer.start()
    try:
        shutdown_event.wait()
    finally:
        observer.stop()
        observer.join()
        print("File observer has been stopped.")


def notify_or_print(message):
    """Queue ``message`` (a string or a list of strings) for ``handle_notifications``.

    Messages produced after shutdown has been requested are silently dropped.
    """
    if shutdown_event.is_set():
        return
    notification_queue.put(message)


def handle_notifications():
    """Show queued prediction messages until shutdown is requested and the queue is drained.

    Runs in the calling thread. Each queue item is a string or a list of strings.
    It is shown as one macOS notification (lines joined by newlines) or, if that
    fails (presumably on systems without macOS notification support), printed
    line by line.

    ``shutdown_event`` is always set on exit. Interrupting the loop (e.g. Jupyter's
    "Interrupt kernel", which raises KeyboardInterrupt here) therefore also stops the
    directory watcher instead of leaving it running.
    """
    try:
        while not shutdown_event.is_set() or not notification_queue.empty():
            try:
                messages = notification_queue.get(timeout=1)
            except queue.Empty:
                continue

            lines = [messages] if isinstance(messages, str) else list(messages)
            try:
                client.create_notification(title="Geoguessr location found", subtitle="\n".join(lines), snooze_button_str="Hide")
            except Exception:
                # Notifications unavailable: fall back to console output.
                # (Only Exception is caught, so KeyboardInterrupt still stops the loop.)
                print()
                for line in lines:
                    print(line)
    finally:
        shutdown_event.set()


def predict_from_file(image_file_path):
    """Run the model on one image file and queue the resulting messages.

    Coordinate mode queues a single string. Classification queues a list of lines:
    in region mode the top-3 regions (each with its country), then in both
    classification modes the top-3 countries. Any exception is printed instead of
    raised, since this is called from a background worker thread.
    """
    try:
        image = data_handler.load_single_image(image_file_path)

        if predict_coordinates:
            coordinates, cartesian = geo_model.predict_single(image, top_n=5)
            notify_or_print(f"Predicted Coordinates: {coordinates} (Cartesian: {cartesian})")
            return

        if predict_regions:
            (regions, region_indices, region_probabilities,
             countries, country_indices, country_probabilities,
             corresponding_countries, corresponding_country_indices) = geo_model.predict_single(image, top_n=3)

            # Top 3 regions, each with the country it lies in
            notify_or_print([
                f"Region {rank}: {region} with prob.: {probability*100:.3f}, in: {country}"
                for rank, (region, _, probability, country, _) in enumerate(
                    zip(regions, region_indices, region_probabilities, corresponding_countries, corresponding_country_indices),
                    start=1,
                )
            ])
        else:
            countries, country_indices, country_probabilities = geo_model.predict_single(image, top_n=3)

        # Top 3 countries
        notify_or_print([
            f"Country {rank}: {country} with prob.: {probability*100:.2f}"
            for rank, (country, _, probability) in enumerate(zip(countries, country_indices, country_probabilities), start=1)
        ])

    except Exception as e:
        print(f"An error occurred while predicting the location of the image: {e}")

In [10]:
# Watch this directory for new files
# Watch this directory for new files
path = "./.data"

path = data_handler.path_from_current_dir(path)

# Create the directory if it does not exist (no-op, and race-free, when it does)
os.makedirs(path, exist_ok=True)

# Re-arm the shared stop flag. After a previous session it stays set, and re-running
# this cell would make both loops below exit immediately.
shutdown_event.clear()

print(f"Monitoring directory: {os.path.basename(path)}")
monitor_thread = threading.Thread(target=monitor_directory, args=(path,), daemon=True)
monitor_thread.start()
try:
    # Blocks this cell until shutdown; use "Interrupt kernel" to stop the demo.
    handle_notifications()
finally:
    # Stop and wait for the watcher even if the loop above was interrupted.
    shutdown_event.set()
    monitor_thread.join()

Monitoring directory: .data
