In [1]:
#!pip install -r ../../requirements.txt wandb --upgrade

In [2]:
import random
import json
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

from resnet_pytorch import ResNet
import wandb
import uuid

# load .env file
from dotenv import load_dotenv
load_dotenv()


from custom_image_dataset import CustomImageDataset
from custom_image_name_dataset import CustomImageNameDataset
from image_data_handler import ImageDataHandler


sys.path.insert(0, '../')
from data_loader import get_data_to_load, split_json_and_image_files, load_json_files, load_image_files, load_json_file, load_image_file

### Loading data

In [3]:
# set number of files to load
NUMBER_OF_FILES = 3000 # 100000
# Set to False to use non-mapped data (singleplayer distribution), has more data
USE_MAPPED = True

# get list with local data and file paths
list_files = get_data_to_load(loading_file='../3_data_preparation/04_data_cleaning/updated_data_list' if USE_MAPPED else '../3_data_preparation/04_data_cleaning/updated_data_list_non_mapped', 
                              file_location='../3_data_preparation/01_enriching/.data', image_file_location='../1_data_collection/.data', allow_new_file_creation=True, 
                              from_remote_only=True, download_link='env', limit=NUMBER_OF_FILES, shuffle_seed=42, allow_file_location_env=True, allow_json_file_location_env=True, 
                              allow_image_file_location_env=True)

json_files, image_files = split_json_and_image_files(list_files)
paired_files = list(zip(json_files, image_files))

Getting files list from remote
Got files list from remote
Parsed files list from remote
All remote files: 257130
All local files: 385096
Filtering out unpaired files
Filtered out 0 unpaired files
Relevant files: 257130
Limited files: 6000


In [4]:
def filter_corrupted_pairs(paired_files):
    non_corrupted_pairs = []
    
    for json_path, image_path in paired_files:
        try:
            with Image.open(image_path) as img:
                img.verify()  # verify that it's a readable image
            non_corrupted_pairs.append((json_path, image_path))
        except (IOError, OSError):
            print(f"Corrupted image found and skipped: {image_path}")

    return non_corrupted_pairs

# Filter the paired_files list to remove any corrupted entries
filtered_paired_files = filter_corrupted_pairs(paired_files)
print(f"Total non-corrupted pairs: {len(filtered_paired_files)}")

def split_json_and_image_files(paired_files):
    json_files = [json_file for json_file, _ in paired_files if json_file.endswith('.json')]
    image_files = [image_file for _, image_file in paired_files if image_file.endswith('.png')]  # Assuming all images are .png
    return json_files, image_files

json_files, image_files = split_json_and_image_files(filtered_paired_files)
paired_files = filtered_paired_files

Total non-corrupted pairs: 3000


In [5]:
len(json_files), len(image_files), len(paired_files)

(3000, 3000, 3000)

## Processing and loading data

In [6]:
# Define transformations
transform = transforms.Compose([
        transforms.Resize((50, 50)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

# Creating Dataloasders with the classes
data_handler = ImageDataHandler(image_files, json_files, transform)
train_dataloader = data_handler.train_loader
val_dataloader = data_handler.val_loader
test_dataloader = data_handler.test_loader

# Load the country_to_index mapping and print the count of different countries
with open("country_to_index.json", 'r') as file:
  country_to_index = json.load(file)
print(f"Count of different countries: {len(country_to_index)}")

3000


In [7]:
print("Number of train batches:", len(train_dataloader.dataset), "")

PRINT_FIRST = True

# Print first batch as an example, to see the structure
for images, coordinates, country_indices in train_dataloader:
    if PRINT_FIRST:
      print("Images batch shape:", images.shape)
      print("Coordinates batch shape:", coordinates.shape)
      print(coordinates[0])
      print("Country indices:", country_indices.shape)
      print(country_indices[0])
      PRINT_FIRST = False
    #break

Number of train batches: 2100 
Images batch shape: torch.Size([128, 3, 50, 50])
Coordinates batch shape: torch.Size([128, 2])
tensor([7.3782, 3.9229])
Country indices: torch.Size([128])
tensor(20)


## Model

In [8]:
# Load the pre-trained ResNet50 model with updated approach
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Change the output features of the last layer to 2 for binary classification
model.fc = nn.Linear(model.fc.in_features, 2)

# Initialize the new last layer with random weights
nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(model.fc.bias, 0)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/lukasstoeckli/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:14<00:00, 6.95MB/s]


Parameter containing:
tensor([0., 0.], requires_grad=True)

## Training

In [22]:
# set necessary seeds to make notebook reproducible 
def set_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

In [23]:
def coordinates_to_cartesian(lon, lat, R=6371):
    # Convert degrees to radians
    lon_rad = np.radians(lon)
    lat_rad = np.radians(lat)

    # Cartesian coordinates using numpy
    x = R * np.cos(lat_rad) * np.cos(lon_rad)
    y = R * np.cos(lat_rad) * np.sin(lon_rad)
    z = R * np.sin(lat_rad)
    return np.stack([x, y, z], axis=-1)  # ensure the output is a numpy array with the correct shape

def spherical_distance(cartesian1, cartesian2, R=6371.0):
    cartesian1 = cartesian1.to(cartesian2.device)
    dot_product = (cartesian1 * cartesian2).sum(dim=1)
    
    norms1 = cartesian1.norm(p=2, dim=1)
    norms2 = cartesian2.norm(p=2, dim=1)

    cos_theta = dot_product / (norms1 * norms2)
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
    
    theta = torch.acos(cos_theta)
    # curved distance -> "Bogenmass"
    distance = R * theta
    return distance

def mean_spherical_distance(preds, targets):
    distances = spherical_distance(preds, targets)
    return distances.mean()

In [24]:
class GeoModelTrainer:
    def __init__(self, num_classes=2, use_coordinates=True):
        self.num_classes = num_classes
        self.use_coordinates = use_coordinates
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_type = None
        self.model = None
        
    def initialize_model(self, model_type):
        self.model_type = model_type
        if self.model_type == 'resnet18':
            model = resnet18(weights=ResNet18_Weights.DEFAULT)
        elif self.model_type == 'resnet34':
            model = resnet34(weights=ResNet34_Weights.DEFAULT)
        elif self.model_type == 'resnet50':
            model = resnet50(weights=ResNet50_Weights.DEFAULT)
        elif self.model_type == 'resnet101':
            model = resnet101(weights=ResNet101_Weights.DEFAULT)
        elif self.model_type == 'resnet152':
            model = resnet152(weights=ResNet152_Weights.DEFAULT)
        else:
            raise ValueError("Unsupported model type. Supported types are: resnet18, resnet34, resnet50, resnet101, resnet152.")
        
        # Modify the final layer based on the number of classes
        model.fc = nn.Linear(model.fc.in_features, self.num_classes)
        nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
        nn.init.constant_(model.fc.bias, 0)
        return model

    def train(self):
        with wandb.init(reinit=True) as run:
            config = run.config
            set_seed(config.seed)
            
            # Set seeds, configure optimizers, losses, etc.
            best_val_distance = float('inf')
            patience_counter = 0
            patience = 30

            # Rename run name and initialize parameters in model name
            model_name = f"model_{config.model_name}_lr_{config.learning_rate}_opt_{config.optimizer}_weightDecay_{config.weight_decay}"
            run_name = model_name + f"_{uuid.uuid4()}"
            wandb.run.name = run_name

            # Initialize model, optimizer and criterion
            self.model = self.initialize_model(model_type=config.model_name).to(self.device)
            optimizer_grouped_parameters = [
                {"params": [p for n, p in self.model.named_parameters() if not n.startswith('fc')], "lr": config.learning_rate * 0.1},
                {"params": self.model.fc.parameters(), "lr": config.learning_rate}
            ]
            optimizer = optim.AdamW(optimizer_grouped_parameters, weight_decay=config.weight_decay)
            criterion = nn.MSELoss()

            for epoch in range(config.epochs):
                train_loss, train_distance = self.run_epoch(config, criterion, optimizer, is_train=True)
                val_loss, val_distance = self.run_epoch(config, criterion, optimizer, is_train=False)
                
                # Early stopping and logging
                if val_distance < best_val_distance:
                    best_val_distance = val_distance
                    torch.save(model.state_dict(), f"models/datasize_{NUMBER_OF_FILES}/best_model_checkpoint{model_name}.pth")
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Stopping early after {patience} epochs without improvement")
                        break

                # Log metrics to wandb
                wandb.log({
                    "Train Loss": train_loss,
                    "Train Distance": train_distance,
                    "Validation Loss": val_loss,
                    "Validation Distance": val_distance
                })

    def run_epoch(self, config, criterion, optimizer, is_train=True):
        if is_train:
            self.model.train()
        else:
            self.model.eval()
        
        total_loss = 0.0
        total_distance = 0.0
        data_loader = train_dataloader if is_train else val_dataloader
        
        for images, coordinates, country_indices in data_loader:
            with torch.set_grad_enabled(is_train):
                images = images.to(self.device)
                coordinates = coordinates.to(self.device) if self.use_coordinates else country_indices.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, coordinates)
                
                if is_train:
                    loss.backward()
                    optimizer.step()
                
                total_loss += loss.item() * images.size(0)
                total_distance += mean_spherical_distance(outputs, coordinates).item() * images.size(0)
        
        avg_loss = total_loss / len(data_loader.dataset)
        avg_distance = total_distance / len(data_loader.dataset)
        return avg_loss, avg_distance

In [None]:
#model_types = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
model_types = ["resnet50", "resnet101", "resnet152"]


for model_type in model_types:
    wandb.login()

    sweep_config = {
        "name": f"dspro2-basemodel-{model_type}-datasize-{NUMBER_OF_FILES}",
        "method": "grid",
        "metric": {"goal": "minimize", "name": "Validation Distance"},
        "parameters": {
            "learning_rate": {"values": [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]},
            "optimizer": {"values": ["adamW"]},
            "weight_decay": {"values": [1e-3]}, #1e-2, 
            "epochs": {"values": [500]},
            "dataset_size": {"values": [NUMBER_OF_FILES]},
            "seed": {"values": [42]},
            "model_name": {"values": [model_type]}
        },
    }
    
    sweep_id = wandb.sweep(sweep=sweep_config, project=f"dspro2-basemodel")
    trainer = GeoModelTrainer(num_classes=2, use_coordinates=True)
    wandb.agent(sweep_id, function=trainer.train)

[34m[1mwandb[0m: Currently logged in as: [33mluki-st[0m ([33mnlp_ls[0m). Use [1m`wandb login --relogin`[0m to force relogin


Create sweep with ID: 631malxm
Sweep URL: https://wandb.ai/nlp_ls/dspro2-basemodel/sweeps/631malxm


[34m[1mwandb[0m: Agent Starting Run: 0xisu2de with config:
[34m[1mwandb[0m: 	dataset_size: 100000
[34m[1mwandb[0m: 	epochs: 500
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	model_name: resnet50
[34m[1mwandb[0m: 	optimizer: adamW
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	weight_decay: 0.001
