In [4]:
# important for gpuhub
#!pip install -r ../../requirements.txt wandb --upgrade

In [5]:
import wandb
import json
import sys

import torch.nn as nn
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from torchvision import transforms
from PIL import Image

# load .env file
from dotenv import load_dotenv
load_dotenv()

from geo_model_trainer import GeoModelTrainer
from image_data_handler import ImageDataHandler

sys.path.insert(0, '../')
from data_loader import get_data_to_load, split_json_and_image_files, load_json_files, load_image_files, load_json_file, load_image_file

### Loading data

In [6]:
# set number of files to load
NUMBER_OF_FILES = 3000 # 100000
# Set to False to use non-mapped data (singleplayer distribution), has more data
USE_MAPPED = True

# get list with local data and file paths
list_files = get_data_to_load(loading_file='../3_data_preparation/04_data_cleaning/updated_data_list' if USE_MAPPED else '../3_data_preparation/04_data_cleaning/updated_data_list_non_mapped', 
                              file_location='../3_data_preparation/01_enriching/.data', image_file_location='../1_data_collection/.data', allow_new_file_creation=False, 
                              from_remote_only=True, download_link='default', limit=NUMBER_OF_FILES, shuffle_seed=42, allow_file_location_env=True, allow_json_file_location_env=True, 
                              allow_image_file_location_env=True, allow_download_link_env=True)

json_files, image_files = split_json_and_image_files(list_files)
paired_files = list(zip(json_files, image_files))

Getting files list from remote
Got files list from remote
Parsed files list from remote
All remote files: 257130
All local files: 385096
Filtering out unpaired files
Filtered out 0 unpaired files
Relevant files: 257130
Limited files: 6000


In [7]:
def filter_corrupted_pairs(paired_files):
    non_corrupted_pairs = []
    
    for json_path, image_path in paired_files:
      with Image.open(image_path) as img:
          img.verify()  # verify that it's a readable image
      non_corrupted_pairs.append((json_path, image_path))

    return non_corrupted_pairs

# Filter the paired_files list to remove any corrupted entries
filtered_paired_files = filter_corrupted_pairs(paired_files)
print(f"Total non-corrupted pairs: {len(filtered_paired_files)}")

def split_json_and_image_files(paired_files):
    json_files = [json_file for json_file, _ in paired_files if json_file.endswith('.json')]
    image_files = [image_file for _, image_file in paired_files if image_file.endswith('.png')]  # Assuming all images are .png
    return json_files, image_files

json_files, image_files = split_json_and_image_files(filtered_paired_files)
paired_files = filtered_paired_files

Total non-corrupted pairs: 3000


In [8]:
len(json_files), len(image_files), len(paired_files)

(3000, 3000, 3000)

## Processing and loading data

In [9]:
# Define transformations
transform = transforms.Compose([
        transforms.Resize((50, 50)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

# Creating Dataloasders with the classes
data_handler = ImageDataHandler(image_files, json_files, transform, NUMBER_OF_FILES)
train_dataloader = data_handler.train_loader
val_dataloader = data_handler.val_loader
test_dataloader = data_handler.test_loader

# Load the country_to_index mapping and print the count of different countries
with open(f"models/datasize_{NUMBER_OF_FILES}_country_to_index.json", "r") as file:
  country_to_index = json.load(file)
print(f"Count of different countries: {len(country_to_index)}")

Count of different countries: 41


In [10]:
print("Number of train batches:", len(train_dataloader.dataset), "")

PRINT_FIRST = True

# Print first batch as an example, to see the structure
for images, coordinates, country_indices in train_dataloader:
    if PRINT_FIRST:
      print("Images batch shape:", images.shape)
      print("Coordinates batch shape:", coordinates.shape)
      print(coordinates[0])
      print("Country indices:", country_indices.shape)
      print(country_indices[0])
      PRINT_FIRST = False
    #break

Number of train batches: 2100 
Images batch shape: torch.Size([128, 3, 50, 50])
Coordinates batch shape: torch.Size([128, 2])
tensor([52.8881,  9.4488])
Country indices: torch.Size([128])
tensor(12)


## Model

In [11]:
# Load the pre-trained ResNet50 model with updated approach
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Change the output features of the last layer to 2 for binary classification
model.fc = nn.Linear(model.fc.in_features, 2)

# Initialize the new last layer with random weights
nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(model.fc.bias, 0)

Parameter containing:
tensor([0., 0.], requires_grad=True)

## Training

In [16]:
#model_types = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
model_type = "resnet18"
predict_coordinates=False
wandb.login()

if predict_coordinates:
    project_name = "predicting-coordinates"
    num_classes = 2
    sweep_goal = "minimize"
    sweep_metric_name = "Validation Distance (km)"
else:
    with open(f"models/datasize_{NUMBER_OF_FILES}_country_to_index.json", "r") as file:
        country_to_index = json.load(file)
    num_classes = len(country_to_index)
    project_name = "predicting-country"
    sweep_goal = "maximize"
    sweep_metric_name = "Validation Accuracy"

sweep_config = {
    "name": f"dspro2-basemodel-{model_type}-datasize-{NUMBER_OF_FILES}",
    "method": "grid",
    "metric": {"goal": sweep_goal, "name": sweep_metric_name},
    "parameters": {
        "learning_rate": {"values": [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]},
        "optimizer": {"values": ["adamW"]},
        "weight_decay": {"values": [1e-3]},
        "epochs": {"values": [500]},
        "dataset_size": {"values": [NUMBER_OF_FILES]},
        "seed": {"values": [42]},
        "model_name": {"values": [model_type]},
        "predict_coordinates": {"values": [predict_coordinates]},
    },
}

sweep_id = wandb.sweep(sweep=sweep_config, project=f"dspro2-basemodel-{project_name}")
trainer = GeoModelTrainer(datasize=NUMBER_OF_FILES, train_dataloader=train_dataloader, val_dataloader=val_dataloader, 
                          num_classes=num_classes, predict_coordinates=predict_coordinates)
wandb.agent(sweep_id, function=trainer.train)

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1771a0710>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 177930800, raw_cell="#model_types = ["resnet18", "resnet34", "resnet50".." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/lukasstoeckli/GitLabProjects/DSPRO2/dspro2/dspro2/4_modeling/training_resnet_models.ipynb#X15sZmlsZQ%3D%3D>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe



Create sweep with ID: yxkufxom
Sweep URL: https://wandb.ai/nlp_ls/dspro2-basemodel-predicting-country/sweeps/yxkufxom


[34m[1mwandb[0m: Agent Starting Run: vhma30an with config:
[34m[1mwandb[0m: 	dataset_size: 3000
[34m[1mwandb[0m: 	epochs: 500
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	model_name: resnet18
[34m[1mwandb[0m: 	optimizer: adamW
[34m[1mwandb[0m: 	predict_coordinates: False
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	weight_decay: 0.001
Traceback (most recent call last):
  File "/Users/lukasstoeckli/GitLabProjects/DSPRO2/dspro2/.venv/lib/python3.12/site-packages/wandb/sdk/wandb_init.py", line 1177, in init
    wi.setup(kwargs)
  File "/Users/lukasstoeckli/GitLabProjects/DSPRO2/dspro2/.venv/lib/python3.12/site-packages/wandb/sdk/wandb_init.py", line 220, in setup
    with telemetry.context(obj=self._init_telemetry_obj) as tel:
  File "/Users/lukasstoeckli/GitLabProjects/DSPRO2/dspro2/.venv/lib/python3.12/site-packages/wandb/sdk/lib/telemetry.py", line 42, in __exit__
    self._run._telemetry_callback(self._obj)
  File "/Users/lukasstoeckli/GitLabPro

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1771a0710>> (for post_run_cell), with arguments args (<ExecutionResult object at 1779303b0, execution_count=16 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 177930800, raw_cell="#model_types = ["resnet18", "resnet34", "resnet50".." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/lukasstoeckli/GitLabProjects/DSPRO2/dspro2/dspro2/4_modeling/training_resnet_models.ipynb#X15sZmlsZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe