In [1]:
# important for gpuhub
#!pip install -r ../../requirements.txt wandb --upgrade

In [2]:
import random
import json
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

from resnet_pytorch import ResNet
import wandb
import uuid

# load .env file
from dotenv import load_dotenv
load_dotenv()


from geo_model_trainer import GeoModelTrainer
from image_data_handler import ImageDataHandler


sys.path.insert(0, '../')
from data_loader import get_data_to_load, split_json_and_image_files, load_json_files, load_image_files, load_json_file, load_image_file

### Loading data

In [5]:
# Number of data files to load (100000 is the other size noted for larger runs).
NUMBER_OF_FILES = 3000
# True -> mapped data; False -> non-mapped data (singleplayer distribution), which has more data.
USE_MAPPED = True

# Which prepared data list to read, depending on USE_MAPPED.
mapped_data_list = '../3_data_preparation/04_data_cleaning/updated_data_list'
non_mapped_data_list = '../3_data_preparation/04_data_cleaning/updated_data_list_non_mapped'
data_list_file = mapped_data_list if USE_MAPPED else non_mapped_data_list

# Get the list of data file paths (with from_remote_only=True the list is fetched from remote).
list_files = get_data_to_load(
    loading_file=data_list_file,
    file_location='../3_data_preparation/01_enriching/.data',
    image_file_location='../1_data_collection/.data',
    allow_new_file_creation=True,
    from_remote_only=True,
    download_link='env',
    limit=NUMBER_OF_FILES,
    shuffle_seed=42,
    allow_file_location_env=True,
    allow_json_file_location_env=True,
    allow_image_file_location_env=True,
)

# Separate the JSON metadata files from the image files and pair them up by position.
json_files, image_files = split_json_and_image_files(list_files)
paired_files = list(zip(json_files, image_files))

Getting files list from remote


MaxRetryError: HTTPConnectionPool(host='49.12.197.1', port=80): Max retries exceeded with url: / (Caused by ProtocolError('Connection broken: IncompleteRead(134990 bytes read)', IncompleteRead(134990 bytes read)))

In [None]:
def filter_corrupted_pairs(paired_files):
    """Drop (json_path, image_path) pairs whose image cannot be opened and verified by PIL.

    Args:
        paired_files: iterable of (json_path, image_path) tuples.

    Returns:
        List of the pairs whose image passed ``Image.verify()``, in input order.
    """
    non_corrupted_pairs = []

    for json_path, image_path in paired_files:
        try:
            with Image.open(image_path) as img:
                img.verify()  # verify that it's a readable image
        except Exception as exc:
            # PIL does not use a single exception type for broken files: missing or
            # unidentifiable files raise OSError, but e.g. a broken PNG chunk raises
            # SyntaxError. Catch them all so one bad file is skipped instead of
            # aborting the whole filtering pass.
            print(f"Corrupted image found and skipped: {image_path} ({type(exc).__name__}: {exc})")
            continue
        non_corrupted_pairs.append((json_path, image_path))

    return non_corrupted_pairs


# Filter the paired_files list to remove any corrupted entries
filtered_paired_files = filter_corrupted_pairs(paired_files)
print(f"Total non-corrupted pairs: {len(filtered_paired_files)}")

def _unzip_valid_pairs(pairs, json_ext='.json', image_ext='.png'):
    """Split (json_path, image_path) pairs into two index-aligned lists.

    A pair is kept only when BOTH paths have the expected extension, so
    ``json_files[i]`` and ``image_files[i]`` always come from the same pair.
    Filtering the two sides independently could silently misalign them.

    Deliberately not named ``split_json_and_image_files``: that name is imported from
    ``data_loader`` and called on the raw file list in the loading cell. Shadowing it
    would break re-running that cell.

    Args:
        pairs: iterable of (json_path, image_path) string tuples.
        json_ext: required suffix of the JSON path.
        image_ext: required suffix of the image path (images are assumed to be PNG).

    Returns:
        Tuple ``(json_files, image_files)`` of equal-length lists.
    """
    kept = [(j, i) for j, i in pairs if j.endswith(json_ext) and i.endswith(image_ext)]
    return [j for j, _ in kept], [i for _, i in kept]


json_files, image_files = _unzip_valid_pairs(filtered_paired_files)
# Rebuild the pair list from the aligned lists so all three stay consistent.
paired_files = list(zip(json_files, image_files))
assert len(json_files) == len(image_files) == len(paired_files)

In [None]:
# Sizes of the JSON list, the image list and the pair list, displayed as a tuple.
tuple(len(files) for files in (json_files, image_files, paired_files))

## Processing and loading data

In [None]:
# Side length (pixels) that every image is resized to.
IMAGE_SIZE = 50

# Define transformations: resize, convert to a [0, 1] tensor, then map each of the
# 3 channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Creating DataLoaders (train / validation / test) with the data handler class.
data_handler = ImageDataHandler(image_files, json_files, transform, NUMBER_OF_FILES)
train_dataloader = data_handler.train_loader
val_dataloader = data_handler.val_loader
test_dataloader = data_handler.test_loader

# Load the country_to_index mapping and print the count of different countries.
# NOTE(review): presumably this file is written by ImageDataHandler for this dataset size — confirm.
# JSON is defined as UTF-8, so read it explicitly as UTF-8 rather than the platform default.
with open(f"models/datasize_{NUMBER_OF_FILES}_country_to_index.json", "r", encoding="utf-8") as file:
    country_to_index = json.load(file)
print(f"Count of different countries: {len(country_to_index)}")

In [None]:
# Set to False to load only the first batch instead of a full pass over the training data.
CHECK_ALL_BATCHES = True

# len(DataLoader) is the number of batches; len(DataLoader.dataset) is the number of samples.
print("Number of train batches:", len(train_dataloader))
print("Number of train samples:", len(train_dataloader.dataset))

# Print the first batch as an example, to see the structure of
# (images, coordinates, country_indices).
train_batches = iter(train_dataloader)
first_batch = next(train_batches, None)
if first_batch is None:
    print("Train dataloader is empty.")
else:
    images, coordinates, country_indices = first_batch
    print("Images batch shape:", images.shape)
    print("Coordinates batch shape:", coordinates.shape)
    print(coordinates[0])
    print("Country indices:", country_indices.shape)
    print(country_indices[0])

if CHECK_ALL_BATCHES:
    # Consume the remaining batches so every training sample is loaded once.
    # Presumably this serves as a load check for the whole training split.
    for _ in train_batches:
        pass

## Model

In [None]:
# Load ResNet50 with its default pre-trained torchvision weights.
# NOTE(review): `model` is not passed to GeoModelTrainer in the training cell (only the
# "resnet50" name goes into the sweep config), so this cell looks exploratory — confirm.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Output size of the new head. 2 matches the coordinate-prediction setup (num_classes = 2
# in the training cell, presumably latitude/longitude). It is NOT binary classification;
# for country classification this would be the number of countries.
NUM_OUTPUTS = 2

# Replace the final layer with a freshly initialised linear head of NUM_OUTPUTS units.
model.fc = nn.Linear(model.fc.in_features, NUM_OUTPUTS)
nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
# Trailing ';' suppresses the Parameter repr the notebook would otherwise display.
nn.init.zeros_(model.fc.bias);

## Training

In [None]:
# Backbone for the sweep; options: "resnet18", "resnet34", "resnet50", "resnet101", "resnet152".
model_type = "resnet50"
# True -> predict coordinates; False -> predict the country.
predict_coordinates = True
wandb.login()

if predict_coordinates:
    # NOTE(review): the leading "-" makes the W&B project "dspro2-basemodel--predicting-coordinates"
    # (double dash), unlike the country project. Left unchanged because existing runs presumably
    # already log to that project — confirm whether it is intended before renaming.
    project_name = "-predicting-coordinates"
    num_classes = 2
else:
    with open(f"models/datasize_{NUMBER_OF_FILES}_country_to_index.json", "r") as file:
        country_to_index = json.load(file)
    num_classes = len(country_to_index)
    project_name = "predicting-country"
    print(f"Number of classes: {num_classes}")

# Grid values per hyperparameter (1e-2 is currently left out of weight_decay).
sweep_grid = {
    "learning_rate": [1e-4, 1e-5, 1e-6],
    "optimizer": ["adamW"],
    "weight_decay": [1e-3],
    "epochs": [500],
    "dataset_size": [NUMBER_OF_FILES],
    "seed": [42],
    "model_name": [model_type],
    "predict_coordinates": [predict_coordinates],
}

sweep_config = {
    "name": f"dspro2-basemodel-{model_type}-datasize-{NUMBER_OF_FILES}",
    "method": "grid",
    "metric": {"goal": "minimize", "name": "Validation Distance (km)"},
    # Each swept parameter is given to W&B as {"values": [...]}.
    "parameters": {param: {"values": options} for param, options in sweep_grid.items()},
}

sweep_id = wandb.sweep(sweep=sweep_config, project=f"dspro2-basemodel-{project_name}")
trainer = GeoModelTrainer(
    datasize=NUMBER_OF_FILES,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_classes=num_classes,
    use_coordinates=predict_coordinates,
)
wandb.agent(sweep_id, function=trainer.train)