In [1]:
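# Important on GPUHub: uncomment the next line to install/upgrade the project requirements
# into the kernel's environment before running the rest of the notebook.
# %pip install -r ../../requirements.txt --upgrade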

In [2]:
import wandb
import json
import sys
import os

import torch
import torch.nn as nn
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from torchvision import transforms
from PIL import Image

# load .env file
from dotenv import load_dotenv
from geo_model_trainer import GeoModelTrainer
from image_data_handler import ImageDataHandler

#torch.backends.cudnn.benchmark = False
#torch.backends.cudnn.deterministic = True

sys.path.insert(0, '../')
from data_loader import get_data_to_load, split_json_and_image_files, hash_filenames

# Debugging aid: CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so a
# CUDA error is reported at the call that triggered it instead of at a later sync point.
# NOTE(review): this slows GPU work down; presumably it should be dropped for real
# training runs - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
# Location of the optional .env file, relative to this notebook
env_path = '../../.env'

# A token that is already exported in the environment wins; the .env file is
# only consulted when no (non-empty) token is set and the file exists.
WANDB_TOKEN = os.getenv('WANDB_TOKEN')
if not WANDB_TOKEN:
    if os.path.exists(env_path):
        load_dotenv(env_path)
        WANDB_TOKEN = os.getenv('WANDB_TOKEN')

In [4]:
# Report whether PyTorch can see a CUDA device and, if so, print its key figures.
if not torch.cuda.is_available():
    print("No GPU found.")
else:
    print("GPU is available.")

    # All figures below refer to device 0; its properties are looked up once and reused.
    device_properties = torch.cuda.get_device_properties(0)

    # Device name
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

    # Memory figures, converted from bytes to GB (1e9 bytes)
    total_memory = device_properties.total_memory / 1e9
    print(f"Total Memory: {total_memory:.2f} GB")

    allocated_memory = torch.cuda.memory_allocated(0) / 1e9
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    cached_memory = torch.cuda.memory_reserved(0) / 1e9
    print(f"Cached Memory: {cached_memory:.2f} GB")

    # Remaining device properties
    print(f"CUDA Capability: {device_properties.major}.{device_properties.minor}")
    print(f"Multi-Processor Count: {device_properties.multi_processor_count}")

No GPU found.


## Loading data

In [5]:
# Number of files to load; passed to the loader as `limit`
NUMBER_OF_FILES = 0 # 100000
# True selects the mapped data list; False selects the non-mapped list
# (singleplayer distribution), which has more data
USE_MAPPED = True

# Pick the data-list file that matches the USE_MAPPED switch
if USE_MAPPED:
    loading_file = '../3_data_preparation/04_data_cleaning/updated_data_list_more'
else:
    loading_file = '../3_data_preparation/04_data_cleaning/updated_data_list_non_mapped'

# Get the list with local data and file paths
list_files = get_data_to_load(
    loading_file=loading_file,
    file_location='../3_data_preparation/01_enriching/.data',
    image_file_location='../1_data_collection/.data',
    allow_new_file_creation=False,
    from_remote_only=True,
    download_link='default',
    limit=NUMBER_OF_FILES,
    shuffle_seed=42,
    allow_file_location_env=True,
    allow_json_file_location_env=True,
    allow_image_file_location_env=True,
    allow_download_link_env=True,
)

# Separate the combined list into JSON and image paths, then pair them up positionally
json_files, image_files = split_json_and_image_files(list_files)
paired_files = list(zip(json_files, image_files))

Getting files list from remote
Got files list from remote
Parsed files list from remote
All remote files: 705681


PermissionError: [Errno 13] Permission denied: '/Volumes/LinUSB'

In [None]:
def filter_corrupted_pairs(paired_files, is_corrupted=None):
    """Return the (json_path, image_path) pairs that are not flagged as corrupted.

    Args:
        paired_files: Iterable of 2-item (json_path, image_path) entries.
        is_corrupted: Optional callable ``is_corrupted(json_path, image_path)``
            returning a truthy value for pairs that must be dropped. When it is
            omitted, NO corruption check is performed and every pair is kept,
            which matches the behaviour of the function before this parameter
            existed.

    Returns:
        A new list of (json_path, image_path) tuples, in the original order.
    """
    if is_corrupted is None:
        # Pass-through: nothing is inspected, the pairs are only copied into a new list.
        return [(json_path, image_path) for json_path, image_path in paired_files]

    return [
        (json_path, image_path)
        for json_path, image_path in paired_files
        if not is_corrupted(json_path, image_path)
    ]

# Run every (json, image) pair through the filter and report how many pairs remain
filtered_paired_files = filter_corrupted_pairs(paired_files)
non_corrupted_count = len(filtered_paired_files)
print(f"Total non-corrupted pairs: {non_corrupted_count}")

def split_json_and_image_files(paired_files):
    """Split (json_file, image_file) pairs into two index-aligned lists.

    A pair is kept only when its JSON path ends with '.json' AND its image path
    ends with '.png'. Dropping the pair as a whole keeps both returned lists the
    same length, with entry i of one list belonging to entry i of the other.
    Filtering each list separately would let them drift apart as soon as one
    side of a pair has an unexpected extension.

    NOTE: this definition replaces the helper of the same name imported from
    data_loader, which is called with the flat file list earlier in the
    notebook; after this cell has run, the name refers to this pair-based version.

    Args:
        paired_files: Iterable of (json_file, image_file) path pairs.

    Returns:
        Tuple ``(json_files, image_files)`` of two lists of equal length.
    """
    kept_pairs = [
        (json_file, image_file)
        for json_file, image_file in paired_files
        if json_file.endswith('.json') and image_file.endswith('.png')
    ]
    json_files = [json_file for json_file, _ in kept_pairs]
    image_files = [image_file for _, image_file in kept_pairs]
    return json_files, image_files

# From here on, work with the filtered pairs only and rebuild the two path lists from them
paired_files = filtered_paired_files
json_files, image_files = split_json_and_image_files(paired_files)

Corrupted image found and skipped: /home/jovyan/dspro2/dspro2/.data/geoguessr_location_singleplayer_FjbsHTUCjCEjccK3_3.png
Total non-corrupted pairs: 74999


In [None]:
# Sizes of the JSON list, the image list and the pair list, shown side by side for comparison
tuple(len(collection) for collection in (json_files, image_files, paired_files))

(74999, 74999, 74999)

## Processing and loading data

In [None]:
# Augmentation mode: "base_augmentation" (resize + normalise only) or
# "full_augmentation" (adds random rotation and colour jitter).
data_augmentation = "base_augmentation"
# Target size passed to Resize as (height, width). Default was 50, 50.
image_size = [80, 130]
# Original size is  pixelHeight: 180, pixelWidth: 320
#image_size = [180, 320]

# Steps shared by every augmentation mode
resize_step = transforms.Resize((image_size[0], image_size[1]))
tensor_steps = [
    transforms.ToTensor(),                                   # Convert the image to a tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the image
]

if data_augmentation == "base_augmentation":
    augmentation_steps = []
elif data_augmentation == "full_augmentation":
    augmentation_steps = [
        transforms.RandomRotation(10),          # Randomly rotate the image by up to 10 degrees
        transforms.ColorJitter(
            brightness=(0.5, 1.5),  # Randomly change brightness (lower limit to simulate night, upper limit for bright daylight)
            contrast=(0.5, 1.5),    # Randomly change contrast
            saturation=(0.5, 1.5),  # Randomly change saturation
            hue=(-0.1, 0.1)         # Randomly change hue
        ),
    ]
else:
    # Fail here rather than leaving `transform` undefined (or silently reusing a
    # stale one from an earlier run of this cell) when the mode is misspelled.
    raise ValueError(
        f"Unknown data_augmentation {data_augmentation!r}; "
        "expected 'base_augmentation' or 'full_augmentation'"
    )

# Order matters: resize first, then the PIL-level augmentations, then tensor conversion + normalisation
transform = transforms.Compose([resize_step, *augmentation_steps, *tensor_steps])

In [None]:
# Creating the dataloaders with the ImageDataHandler class

# Hash the files list to get a unique identifier for the data.
# The helper is imported under an alias because this cell stores its RESULT in
# `hash_filenames` (the sweep configuration reads the identifier under that name).
# Calling the bare name would work on the first run only: on a re-run the name
# would already hold the identifier and the call would raise a TypeError.
from data_loader import hash_filenames as compute_dataset_identifier

dataset_identifier = compute_dataset_identifier(list_files)
hash_filenames = dataset_identifier  # backward-compatible name used by later cells

data_handler = ImageDataHandler(image_files, json_files, transform, NUMBER_OF_FILES, batch_size=100)
train_dataloader = data_handler.train_loader
val_dataloader = data_handler.val_loader
test_dataloader = data_handler.test_loader
country_to_index = data_handler.country_to_index

# Print the configured dataset size, the data identifier and the number of different countries
print("Dataset size:", NUMBER_OF_FILES)
print("Dataset identifier:", dataset_identifier)
print(f"Count of different countries: {len(country_to_index)}")

Loading images and labels:   0%|          | 17/15000 [04:03<59:39:42, 14.34s/it]


KeyboardInterrupt: 

In [None]:
# len(<loader>.dataset) counts samples while len(<loader>) counts batches; report both
# under the right label (assumes train_dataloader is a sized DataLoader - TODO confirm).
print("Number of train samples:", len(train_dataloader.dataset))
print("Number of train batches:", len(train_dataloader))

# Set to False to skip printing the example batch
PRINT_FIRST = True

# Print first batch as an example, to see the structure.
# The loop keeps iterating after the first batch because the `break` below is
# commented out, so the whole training loader is still passed through once.
for batch_index, (images, coordinates, country_indices) in enumerate(train_dataloader):
    if PRINT_FIRST and batch_index == 0:
        print("Images batch shape:", images.shape)
        print("Coordinates batch shape:", coordinates.shape)
        print(coordinates[0])
        print("Country indices:", country_indices.shape)
        print(country_indices[0])
    #break

## Model

In [None]:
# Disabled snippet: the whole cell is one bare string literal, so none of the code
# inside it executes (as the cell's last expression the text is merely echoed).
# NOTE(review): the text calls the 2-unit head "binary classification"; presumably
# the two outputs were meant as coordinates (the training cell uses num_classes = 2
# when predict_coordinates is set) - confirm, or delete this cell.
"""# Load the pre-trained ResNet50 model with updated approach
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Change the output features of the last layer to 2 for binary classification
model.fc = nn.Linear(model.fc.in_features, 2)

# Initialize the new last layer with random weights
nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(model.fc.bias, 0)"""

## Training

In [None]:
#model_types = ["resnet18", "resnet50"]
model_types = ["mobilenet_v2"]
predict_coordinates = False

# Log in with the explicit token when one was found, otherwise call wandb.login() without a key
if WANDB_TOKEN:
    wandb.login(key=WANDB_TOKEN)
else:
    wandb.login()

# Task-dependent settings: these depend only on predict_coordinates, not on the model type
if predict_coordinates:
    project_name, num_classes = "predicting-coordinates", 2
    sweep_goal, sweep_metric_name = "minimize", "Validation Distance (km)"
else:
    project_name, num_classes = "predicting-country", len(country_to_index)
    sweep_goal, sweep_metric_name = "maximize", "Validation Accuracy Top 1"

# One grid sweep per model type
for model_type in model_types:
    # Parameters that take a single value in this sweep; each one is wrapped
    # into the {"values": [...]} form below.
    single_valued = {
        "optimizer": "adamW",
        "weight_decay": 1e-3,
        "epochs": 100,
        "dataset_size": NUMBER_OF_FILES,
        "dataset_identifier": hash_filenames,
        "seed": 42,
        "model_name": model_type,
        "input_image_size": image_size,
        "predict_coordinates": predict_coordinates,
        "mapped_data": USE_MAPPED,
        "different_countries": len(country_to_index),
        "data_augmentation": data_augmentation,
    }

    # The learning rate is the only parameter with several grid values
    sweep_parameters = {"learning_rate": {"values": [1e-2, 1e-3, 1e-4]}}
    sweep_parameters.update({name: {"values": [value]} for name, value in single_valued.items()})

    sweep_config = {
        "name": f"dspro2-basemodel-{model_type}-datasize-{NUMBER_OF_FILES}-input_imagesize-{image_size[0]}x{image_size[1]}",
        "method": "grid",
        "metric": {"goal": sweep_goal, "name": sweep_metric_name},
        "parameters": sweep_parameters,
    }

    sweep_id = wandb.sweep(sweep=sweep_config, project=f"dspro2-basemodel-{project_name}")

    # The country mapping is only handed to the trainer when countries are being predicted
    trainer = GeoModelTrainer(
        datasize=NUMBER_OF_FILES,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        num_classes=num_classes,
        predict_coordinates=predict_coordinates,
        country_to_index=None if predict_coordinates else country_to_index,
    )

    wandb.agent(sweep_id, function=trainer.train)