In [12]:
from PIL import Image, UnidentifiedImageError
import os
import re
import sys
import concurrent
import torch
import json
import requests
from torchvision import transforms
from image_data_handler_test import TestImageDataHandler
from geo_model_tester import GeoModelTester
from wandb_downloader import WandbDownloader
from custom_image_dataset_test import CustomImageDatasetTest

sys.path.insert(0, '../')
from data_loader import resolve_env_variable, load_image_file_raw, get_image_files

sys.path.insert(0, './4_modeling')
from region_handler import RegionHandler

In [13]:
# Root folder of the YFCC4k test data. Set the YFCC4K_DATA_PATH environment variable to
# point the notebook at another location; without it the original local path is used.
test_data_path = os.environ.get(
    "YFCC4K_DATA_PATH",
    "C:\\Users\\yutar\\Documents\\HSLU\\Paper_data\\yfcc4k",
)
print("Test data path: ", test_data_path)

# Sub-folders that are expected below the root folder
image_path = os.path.join(test_data_path, "image_processed")
json_path = os.path.join(test_data_path, "jsons")

Test data path:  C:\Users\yutar\Documents\HSLU\Paper_data\yfcc4k


In [14]:
# Report whether a CUDA device can be used and, if so, describe device 0
if not torch.cuda.is_available():
    print("No GPU found.")
else:
    print("GPU is available.")

    # Name of the GPU
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

    # The device properties provide both the total memory and the hardware details
    device_properties = torch.cuda.get_device_properties(0)

    # All memory figures are converted from bytes to GB
    total_memory = device_properties.total_memory / 1e9
    print(f"Total Memory: {total_memory:.2f} GB")

    allocated_memory = torch.cuda.memory_allocated(0) / 1e9
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    cached_memory = torch.cuda.memory_reserved(0) / 1e9
    print(f"Cached Memory: {cached_memory:.2f} GB")

    # Remaining hardware properties
    print(f"CUDA Capability: {device_properties.major}.{device_properties.minor}")
    print(f"Multi-Processor Count: {device_properties.multi_processor_count}")

No GPU found.


In [15]:
# Collect the full paths of all regular files in the image and json folders.
# os.listdir returns the entries in arbitrary order, so the names are sorted to keep the
# sample order (and every index derived from it) reproducible across runs and platforms.
image_files = [
    os.path.join(image_path, name)
    for name in sorted(os.listdir(image_path))
    if os.path.isfile(os.path.join(image_path, name))
]
print("Number of image files: ", len(image_files))

json_files = [
    os.path.join(json_path, name)
    for name in sorted(os.listdir(json_path))
    if os.path.isfile(os.path.join(json_path, name))
]
print("Number of json files: ", len(json_files))

Number of image files:  4536
Number of json files:  4536


In [16]:
# Pair every image with its json metadata file by comparing the file names without the extension
image_file_names = {os.path.splitext(os.path.basename(f))[0] for f in image_files}
json_file_names = {os.path.splitext(os.path.basename(f))[0] for f in json_files}

# report image files that have no corresponding json file
missing_json_files = image_file_names - json_file_names
if missing_json_files:
    print("Missing json files: ", missing_json_files)

# report json files that have no corresponding image file
missing_image_files = json_file_names - image_file_names
if missing_image_files:
    print("Missing image files: ", missing_image_files)

# remove image files without a corresponding json file
image_files = [f for f in image_files if os.path.splitext(os.path.basename(f))[0] not in missing_json_files]

# remove json files without a corresponding image file
json_files = [f for f in json_files if os.path.splitext(os.path.basename(f))[0] not in missing_image_files]

# Index the json files by name once, so the pairing below is a dictionary lookup instead of
# a scan over all json files for every image. Every name maps to a list so that ambiguous
# names (more than one json file) can still be detected and skipped.
json_files_by_name = {}
for json_file in json_files:
    json_files_by_name.setdefault(os.path.splitext(os.path.basename(json_file))[0], []).append(json_file)

# Parallel lists: entry i of each list belongs to the same image. `data` additionally keeps
# the (image path, full json content) tuples.
data = []
image_list = []
coordinates_list = []
country_name_list = []
country_code_list = []
regions_list = []
is_in_region_list = []

for image_file in image_files:
    image_file_name = os.path.splitext(os.path.basename(image_file))[0]
    matching_json_files = json_files_by_name.get(image_file_name, [])
    if len(matching_json_files) != 1:
        # only images with exactly one json file are used
        continue

    json_file = matching_json_files[0]
    with open(json_file, "r", encoding="utf-8") as f:
        json_data = json.load(f)

    data.append((image_file, json_data))
    image_list.append(image_file)
    # NOTE(review): in the recorded sample output the coordinates look like [longitude, latitude];
    # confirm this is the order the downstream dataset expects.
    coordinates_list.append(json_data["coordinates"])
    country_name_list.append(json_data["country_name"])
    country_code_list.append(json_data["country_code"])
    regions_list.append(json_data["regions"])
    is_in_region_list.append(json_data["is_in_region"])

print("Number of image files with json data: ", len(data), len(image_list), len(coordinates_list), len(country_name_list), len(country_code_list), len(regions_list), len(is_in_region_list)  )

# print the first 5 entries (fewer when less data is available)
for entry in data[:5]:
    print(entry)


Number of image files with json data:  4536 4536 4536 4536 4536 4536 4536
('C:\\Users\\yutar\\Documents\\HSLU\\Paper_data\\yfcc4k\\image_processed\\10003206806.jpg', {'coordinates': [5.1962, 44.3857], 'country_name': 'France', 'country_code': 'FR', 'regions': [['France_Drôme_FRA-5288', 'POINT (5.173945 44.683623)'], ['France_Vaucluse_FRA-5352', 'POINT (5.180045 44.011459)'], ['France_Bouches-du-Rhône_FRA-5274', 'POINT (5.090141 43.549418)'], ['France_Ardèche_FRA-5267', 'POINT (4.42979 44.755635)'], ['France_Isère_FRA-5311', 'POINT (5.577282 45.265943)']], 'is_in_region': True})
('C:\\Users\\yutar\\Documents\\HSLU\\Paper_data\\yfcc4k\\image_processed\\10008911015.jpg', {'coordinates': [-73.266, 44.3543], 'country_name': 'United States', 'country_code': 'US', 'regions': [['United States_Vermont_USA-3540', 'POINT (-72.665394 44.075698)'], ['United States_New Hampshire_USA-3538', 'POINT (-71.57834 43.68971)'], ['United States_Massachusetts_USA-3513', 'POINT (-71.805476 42.257527)'], ['Unit

In [17]:
# For every image, keep only the first element (the name) of each of its region entries
region_names = [
    [entry[0] for entry in image_regions]
    for image_regions in regions_list
]

print("Number of region names: ", len(region_names))

Number of region names:  4536


In [18]:

# What the evaluated model predicts
predict_coordinates = False
predict_regions = True

# Input preprocessing
data_augmentation = "base_augmentation"  # or "full_augmentation_v2"
image_size = [80, 130]

# Hardware profile; it selects the device specific settings below
running_device = "colab_L4"

if running_device == "colab_A100":
    # Run mapped images with high image resolution
    image_size = [180, 320]
    NUMBER_OF_FILES = 79000

elif running_device == "colab_L4":
    # Run unmapped images with low image resolution
    USE_MAPPED = False

In [19]:
# Name of the prediction task, derived from the two prediction flags
if predict_regions:
    prediction_type = "regions"
elif predict_coordinates:
    prediction_type = "coordinates"
else:
    prediction_type = "countries"

# Split ratios that are recorded in the preprocessing configuration
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1

preprocessing_config = {
    'data_augmentation': data_augmentation,
    'height': image_size[0],
    'width': image_size[1],
    'train_ratio': train_ratio,
    'val_ratio': val_ratio,
    'test_ratio': test_ratio,
}

# Default pipeline: resize to the configured size first, convert to a normalized tensor last
base_transform = transforms.Compose([
    transforms.Resize((image_size[0], image_size[1])),
])
augmented_transform = None
final_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

if data_augmentation == "full_augmentation_v2":
    # With full augmentation the base transform is empty and the resizing is part of
    # the augmented transform (RandomResizedCrop)
    base_transform = transforms.Compose([])
    augmented_transform = transforms.Compose([
        # Disabled because black bars really hurt the performance at this size (only for v2)
        # transforms.RandomPerspective(distortion_scale=0.75, p=0.5),  # Randomly apply perspective transformation
        # Randomly crop the image and resize it to the original size
        transforms.RandomResizedCrop((image_size[0], image_size[1]), scale=(0.75, 1.0)),
        # Randomly rotate the image by up to 10 degrees, sadly also causes black borders
        transforms.RandomRotation(10),
        transforms.ColorJitter(
            brightness=(0.5, 1.5),  # Randomly change brightness (lower limit to simulate night, upper limit for bright daylight)
            contrast=(0.5, 1.5),    # Randomly change contrast
            saturation=(0.5, 1.5),  # Randomly change saturation
            hue=(-0.1, 0.1),        # Randomly change hue
        ),
    ])

In [20]:
# Transform for the test images: the base transform followed by the final transform
test_transform = transforms.Compose([base_transform, final_transform])

def load_image_file(file):
  """Open one image file and return it after applying `test_transform`.

  The image is converted to RGB before the transform is applied.
  """
  with Image.open(file) as source_image:
    rgb_image = source_image.convert("RGB")
    # channels, height, width is the pytorch convention
    return test_transform(rgb_image)

def load_image_files(files, num_workers=16, loader=None):
  """Load multiple image files in parallel with a thread pool.

  Args:
    files: Iterable with the paths of the files to load.
    num_workers: Maximum number of worker threads.
    loader: Callable that is applied to every path. Defaults to `load_image_file`.

  Returns:
    A list with one loaded item per entry of `files`, in the same order.
  """
  # `import concurrent` on its own does not load the `futures` submodule, so it is imported
  # explicitly here instead of relying on another library having imported it before.
  import concurrent.futures

  if loader is None:
    loader = load_image_file

  with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
    return list(executor.map(loader, files))


# Load all images; each path in image_list is passed through load_image_files and
# the results are kept in memory
images_loaded_list = load_image_files(image_list)
print("Number of images loaded: ", len(images_loaded_list))

Number of images loaded:  4536


In [21]:
# Project name suffix that belongs to the selected prediction task
if predict_coordinates:
    project_name = "predicting-coordinates"
elif predict_regions:
    project_name = "predicting-region"
else:
    project_name = "predicting-country"

# Coordinate and region models use the distance metric, country models the top 1 accuracy
if predict_coordinates or predict_regions:
    sweep_metric_name = "Validation Distance (km)"
else:
    sweep_metric_name = "Validation Accuracy Top 1"


In [22]:
# Downloader for the project of the selected prediction task, configured with the
# data augmentation and the input image size of this notebook
wandb_downloader = WandbDownloader(
    entity="nlp_ls",
    project=f"dspro2-{project_name}",
    data_augmentation=data_augmentation,
    input_image_size=image_size,
)

# File name patterns that are passed to the downloader
file_names_to_download = [".pth", ".json"]

# presumably the runs are ranked by sweep_metric_name; see WandbDownloader for the details
run_data = wandb_downloader.get_and_collect_best_runs(
    sweep_metric_name,
    file_names_to_download,
)

Found 5 matching runs.


In [23]:
# Show everything that was collected for the best runs
print(run_data)

{'Best Run 1': {'id': 'a46wv6x3', 'parameters': {'seed': 42, 'epochs': 50, 'optimizer': 'adamW', 'model_name': 'resnet50', 'mapped_data': True, 'dataset_size': 81505, 'weight_decay': 0.1, 'learning_rate': 0.001, 'predict_regions': True, 'input_image_size': [80, 130], 'data_augmentation': 'base_augmentation', 'different_regions': 4596, 'dataset_identifier': '63289b51067a4c6ede4c44c23a329d82ab4964ed43942794430a9b71ec685b5c', 'different_countries': 75, 'predict_coordinates': False}, 'metrics': {'Validation Distance (km)': 3724.32276090458, 'test_data_run_id': '3sq5pqyq', '_step': 13, 'Train Loss': 29.596859680230697, 'Train Accuracy Top 5': 0.974567507405395, 'Validation Accuracy Top 5': 0.16906938224648796, '_timestamp': 1719231807.0481946, 'Validation Loss': 68.43571333425159, 'Train Accuracy Top 3 Country': 0.957618354863022, 'Train Accuracy Top 5 Country': 0.9669254903335496, 'Validation Accuracy Top 1 Country': 0.2610882767928348, 'Validation Accuracy Top 3 Country': 0.34353720630636

In [24]:
# Pick the entry stored under the key "Best Run 1"
best_run = run_data["Best Run 1"]

print(best_run)

# Files of that run; later cells look up entries such as "country_to_index.json" and
# "best_model" here
best_run_files = best_run["files"]
print(best_run_files)

{'id': 'a46wv6x3', 'parameters': {'seed': 42, 'epochs': 50, 'optimizer': 'adamW', 'model_name': 'resnet50', 'mapped_data': True, 'dataset_size': 81505, 'weight_decay': 0.1, 'learning_rate': 0.001, 'predict_regions': True, 'input_image_size': [80, 130], 'data_augmentation': 'base_augmentation', 'different_regions': 4596, 'dataset_identifier': '63289b51067a4c6ede4c44c23a329d82ab4964ed43942794430a9b71ec685b5c', 'different_countries': 75, 'predict_coordinates': False}, 'metrics': {'Validation Distance (km)': 3724.32276090458, 'test_data_run_id': '3sq5pqyq', '_step': 13, 'Train Loss': 29.596859680230697, 'Train Accuracy Top 5': 0.974567507405395, 'Validation Accuracy Top 5': 0.16906938224648796, '_timestamp': 1719231807.0481946, 'Validation Loss': 68.43571333425159, 'Train Accuracy Top 3 Country': 0.957618354863022, 'Train Accuracy Top 5 Country': 0.9669254903335496, 'Validation Accuracy Top 1 Country': 0.2610882767928348, 'Validation Accuracy Top 3 Country': 0.3435372063063616, 'Validation

In [25]:
def download_json_file(url, timeout=30):
    """Download a json file and return its parsed content.

    Args:
        url: Address of the json file.
        timeout: Seconds to wait for the server before giving up. Without a timeout
            the request could block forever when the server does not answer.

    Returns:
        The decoded json data of the response.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not answer within `timeout` seconds.
    """
    print("Downloading json file from: ", url)
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()

In [26]:

# The country mapping is downloaded for every prediction type
country_to_index_json = download_json_file(best_run_files["country_to_index.json"])

print(len(country_to_index_json))
print(country_to_index_json)

# The region mappings stay None unless regions are predicted
region_index_to_country_index_json = region_to_index_json = region_index_to_middle_point_json = None

if predict_regions:
    # region index -> country index
    region_index_to_country_index_json = download_json_file(
        best_run_files["region_index_to_country_index.json"]
    )

    print(len(region_index_to_country_index_json))
    print(region_index_to_country_index_json)

    # region -> region index
    region_to_index_json = download_json_file(
        best_run_files["region_to_index.json"]
    )

    print(len(region_to_index_json))
    print(region_to_index_json)

    # region index -> middle point
    region_index_to_middle_point_json = download_json_file(
        best_run_files["region_index_to_middle_point.json"]
    )

    print(len(region_index_to_middle_point_json))
    print(region_index_to_middle_point_json)


Downloading json file from:  https://api.wandb.ai/files/nlp_ls/dspro2-predicting-region/a46wv6x3/run-20240624_114339-a46wv6x3/country_to_index.json
75
{'Albania': 0, 'Argentina': 1, 'Australia': 2, 'Austria': 3, 'Bangladesh': 4, 'Belgium': 5, 'Bolivia, Plurinational State of': 6, 'Botswana': 7, 'Brazil': 8, 'Bulgaria': 9, 'Cambodia': 10, 'Canada': 11, 'Chile': 12, 'Colombia': 13, 'Croatia': 14, 'Czechia': 15, 'Denmark': 16, 'Dominican Republic': 17, 'Ecuador': 18, 'Estonia': 19, 'Eswatini': 20, 'Finland': 21, 'France': 22, 'Germany': 23, 'Ghana': 24, 'Greece': 25, 'Guatemala': 26, 'Hungary': 27, 'India': 28, 'Indonesia': 29, 'Ireland': 30, 'Israel': 31, 'Italy': 32, 'Japan': 33, 'Kenya': 34, 'Korea, Republic of': 35, 'Kyrgyzstan': 36, "Lao People's Democratic Republic": 37, 'Latvia': 38, 'Lesotho': 39, 'Lithuania': 40, 'Malaysia': 41, 'Malta': 42, 'Mexico': 43, 'Montenegro': 44, 'Netherlands': 45, 'New Zealand': 46, 'Nigeria': 47, 'North Macedonia': 48, 'Norway': 49, 'Peru': 50, 'Phili

In [27]:
# Build the test dataset from the loaded images, their labels and the mappings of the best run.
# NOTE(review): region_names holds a list of region names per image; confirm that this is
# the shape CustomImageDatasetTest expects for `regions`.
dataset = CustomImageDatasetTest(
    images=images_loaded_list,
    coordinates=coordinates_list,
    countries=country_name_list,
    regions=region_names,
    region_to_index=region_to_index_json,
    country_to_index=country_to_index_json,
)

print("Dataset length: ", len(dataset))

Removing image at index 13 with country 'Viet Nam' not in the country_to_index mapping.
Removing image at index 27 with country 'Taiwan, Province of China' not in the country_to_index mapping.
Removing image at index 46 with country 'Palestine, State of' not in the country_to_index mapping.
Removing image at index 56 with country 'Morocco' not in the country_to_index mapping.
Removing image at index 58 with country 'Taiwan, Province of China' not in the country_to_index mapping.
Removing image at index 68 with country 'Taiwan, Province of China' not in the country_to_index mapping.
Removing image at index 89 with country 'Zimbabwe' not in the country_to_index mapping.
Removing image at index 93 with country 'Taiwan, Province of China' not in the country_to_index mapping.
Removing image at index 97 with country 'China' not in the country_to_index mapping.
Removing image at index 99 with country 'Kuwait' not in the country_to_index mapping.
Removing image at index 100 with country 'Taiwa

In [28]:
# Batches are served in dataset order (no shuffling)
test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=400, shuffle=False)

print("Test dataloader length: ", len(test_dataloader))

# Number of model outputs for the selected prediction task
if predict_regions:
    num_classes = len(region_to_index_json)
elif predict_coordinates:
    # presumably three values per coordinate prediction; confirm against the model definition
    num_classes = 3
else:
    num_classes = len(country_to_index_json)

geo_model_tester = GeoModelTester(
    datasize=len(dataset),
    test_dataloader=test_dataloader,
    num_classes=num_classes,
    predict_coordinates=predict_coordinates,
    predict_regions=predict_regions,
    country_to_index=country_to_index_json,
    region_to_index=region_to_index_json,
    region_index_to_middle_point=region_index_to_middle_point_json,
    region_index_to_country_index=region_index_to_country_index_json,
)

Test dataloader length:  11


In [29]:
# Evaluate the best model file of the best run with the model type recorded in its parameters.
# NOTE(review): the recorded output reports a country accuracy of exactly 0.0000 and a
# region top 1 accuracy of 0.0022; confirm that labels, mappings and the loaded weights
# line up before trusting these numbers.
geo_model_tester.test(
    model_type=best_run["parameters"]["model_name"],
    model_path=best_run_files["best_model"],
)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\yutar/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 58.7MB/s]


Test Loss: 82.0763, Test Distance: 8436.3847, Test Top 1 Accuracy: 0.0022, Test Top 3 Accuracy: 0.0063, Test Top 5 Accuracy: 0.0094
Test Top 1 Accuracy (Country): 0.0000, Test Top 3 Accuracy (Country): 0.0000, Test Top 5 Accuracy (Country): 0.0000
Test Top 1 Balanced Accuracy: 0.0008, Test Top 1 Balanced Accuracy (Country): 0.0000
