In [None]:
!pip install -r ../../requirements.txt

In [None]:
import random
import json
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

from resnet_pytorch import ResNet
import wandb
import uuid

# load .env file
from dotenv import load_dotenv
load_dotenv()

sys.path.insert(0, '../')
from data_loader import get_data_to_load, split_json_and_image_files, load_json_files, load_image_files, load_json_file, load_image_file

## Loading data

In [None]:
def count_json_files(directory):
    """Return how many entries directly inside ``directory`` end with ``.json``.

    Only the top level is inspected (no recursion). The match is a plain,
    case-sensitive ``str.endswith('.json')`` on each entry name.
    """
    return sum(1 for entry_name in os.listdir(directory) if entry_name.endswith('.json'))

# Path to the directory holding the JSON files.
# NOTE(review): absolute path tied to the JupyterHub layout; adjust when running elsewhere.
folder_path = "/home/jovyan/dspro2/dspro2/.data"

# The count is informational only, so skip it rather than abort "Run All"
# on machines where this directory does not exist.
if os.path.isdir(folder_path):
    count = count_json_files(folder_path)
    print(f"Total JSON files in the directory: {count}")
else:
    print(f"Directory not found, skipping JSON file count: {folder_path}")


In [None]:
# Number of files to load
NUMBER_OF_FILES = 10000
# Set to False to use non-mapped data (singleplayer distribution), which has more data
USE_MAPPED = False

# Get the list with local data and file paths
list_files = get_data_to_load(
    loading_file='../3_data_preparation/04_data_cleaning/updated_data_list' if USE_MAPPED else '../3_data_preparation/04_data_cleaning/updated_data_list_non_mapped',
    file_location='../3_data_preparation/01_enriching/.data',
    image_file_location='../1_data_collection/.data',
    allow_new_file_creation=False,
    from_remote_only=True,
    download_link='default',
    limit=NUMBER_OF_FILES,
    shuffle_seed=42,
    allow_file_location_env=True,
    allow_json_file_location_env=True,
    allow_image_file_location_env=True,
    allow_download_link_env=True,
)

# NOTE: this relies on the data_loader version of split_json_and_image_files.
# A later cell redefines that name for (json, image) pairs, so re-run the
# imports cell before re-running this one.
json_files, image_files = split_json_and_image_files(list_files)
# zip() silently truncates to the shorter list, which would leave files
# unpaired and hide a loader problem; fail loudly instead.
assert len(json_files) == len(image_files), (
    f"Mismatch between JSON files ({len(json_files)}) and image files ({len(image_files)})"
)
paired_files = list(zip(json_files, image_files))

In [None]:
# Count label files whose coordinates / country name are missing or empty
# (any falsy value: absent key, None, empty list or string).
coordinates_count = 0
country_names_count = 0
for json_path in json_files:
    with open(json_path, 'r') as json_handle:
        record = json.load(json_handle)
    coordinates_count += int(not record.get("coordinates"))
    country_names_count += int(not record.get("country_name"))
print(f"Missing Files with no coordinates : {coordinates_count}")
print(f"Missing Files with no country name : {country_names_count}")

In [None]:
def filter_corrupted_pairs(paired_files):
    """Drop (json_path, image_path) pairs whose image Pillow cannot verify.

    Args:
        paired_files: iterable of ``(json_path, image_path)`` tuples.

    Returns:
        List of the pairs whose image opened and passed ``Image.verify()``,
        in their original order. Each skipped image is reported with a print.
    """
    non_corrupted_pairs = []

    for json_path, image_path in paired_files:
        try:
            # verify() checks the file for corruption without decoding the pixel data.
            with Image.open(image_path) as img:
                img.verify()
        except (OSError, SyntaxError):
            # OSError covers missing/unreadable files and PIL.UnidentifiedImageError.
            # Pillow's format plugins (e.g. PNG) report broken data as SyntaxError.
            print(f"Corrupted image found and skipped: {image_path}")
            continue
        non_corrupted_pairs.append((json_path, image_path))

    return non_corrupted_pairs

# Filter the paired_files list to remove any corrupted entries
filtered_paired_files = filter_corrupted_pairs(paired_files)
print(f"Total non-corrupted pairs: {len(filtered_paired_files)}")

In [None]:
def split_json_and_image_files(paired_files, image_extension='.png'):
    """Split ``(json_path, image_path)`` pairs into two index-aligned lists.

    A pair is kept only if its JSON path ends with ``.json`` AND its image path
    ends with ``image_extension`` (a string or tuple of strings, as accepted by
    ``str.endswith``). Filtering whole pairs guarantees that ``json_files[i]``
    and ``image_files[i]`` always come from the same pair.

    NOTE: this shadows the ``split_json_and_image_files`` imported from
    ``data_loader``, which an earlier cell calls on the data-loader output.
    Re-running that earlier cell after this one calls this version instead.

    Returns:
        ``(json_files, image_files)``, two lists of equal length.
    """
    kept_pairs = [
        (json_file, image_file)
        for json_file, image_file in paired_files
        if json_file.endswith('.json') and image_file.endswith(image_extension)
    ]
    json_files = [json_file for json_file, _ in kept_pairs]
    image_files = [image_file for _, image_file in kept_pairs]
    return json_files, image_files

In [None]:
# Rebuild the JSON/image lists from the verified pairs only,
# then make paired_files point at that same verified subset.
json_files, image_files = split_json_and_image_files(filtered_paired_files)
paired_files = filtered_paired_files

In [None]:
# Sanity check: all three counts are expected to be identical.
tuple(len(seq) for seq in (json_files, image_files, paired_files))

## Example: loading and transforming one batch

In [None]:
class CustomImageNameDataset(Dataset):
    """Dataset over file *paths*: item ``idx`` is ``(image_paths[idx], json_paths[idx])``.

    ``transform`` is stored on the instance but not applied by ``__getitem__``;
    callers load and transform the files themselves.
    """

    def __init__(self, image_paths, json_paths, transform=None):
        self.image_paths, self.json_paths, self.transform = image_paths, json_paths, transform

    def __len__(self):
        # The dataset length follows the image list.
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        json_path = self.json_paths[idx]
        return image_path, json_path

# Preprocessing applied to every loaded image: resize to 50x50, convert to a
# tensor, then normalize each of the 3 channels with mean 0.5 and std 0.5.
# NOTE(review): Normalize is given 3 channel statistics, so this assumes the
# loaded images are RGB (PNG files can be RGBA) — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((50, 50)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
assert len(image_files) == len(json_files), "Mismatch in number of images and labels"

# The loader yields batches of file paths; the caller loads the images/labels.
file_name_dataset = CustomImageNameDataset(image_paths=image_files, json_paths=json_files, transform=transform)
file_name_loader = DataLoader(dataset=file_name_dataset, batch_size=64, shuffle=True, num_workers=0)

In [None]:
# First item of the path dataset is (image_path, json_path); show the image path.
file_name_loader.dataset[0][0]

In [None]:
# Fetch exactly one batch of file paths to demonstrate the loading pipeline.
# next(iter(...)) fails loudly on an empty loader instead of printing stale
# variables left over from other cells.
temp_image_files, temp_label_files = next(iter(file_name_loader))
images = load_image_files(temp_image_files)
labels = load_json_files(temp_label_files)
countries = [item['country_name'] for item in labels]
coordinates = [item['coordinates'] for item in labels]
transformed_images = [transform(image) for image in images]
print("Images:", len(images))
print("Labels:", len(labels))

## Processing and loading data

In [None]:
class CustomImageDataset(Dataset):
    """In-memory dataset of ``(image, coordinates, country_index)`` samples.

    Args:
        images: sequence of already-transformed images (returned unchanged).
        coordinates: sequence of per-sample coordinate values; each one is
            converted with ``torch.tensor(..., dtype=torch.float32)`` on access.
        countries: sequence of hashable country labels.
        country_to_index: optional mapping from country label to class index.
            Pass the same mapping to every split (train/val/test) so a country
            has the same index everywhere. When omitted, indices are assigned
            in order of ``str(label)``. This keeps them independent of set
            iteration order, which varies between processes because string
            hashing is randomized.
    """

    def __init__(self, images, coordinates, countries, country_to_index=None):
        self.images = images
        self.coordinates = coordinates
        self.countries = countries
        if country_to_index is None:
            country_to_index = {
                country: idx for idx, country in enumerate(sorted(set(countries), key=str))
            }
        self.country_to_index = dict(country_to_index)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        country_index = self.country_to_index[self.countries[idx]]
        coordinates = torch.tensor(self.coordinates[idx], dtype=torch.float32)

        return image, coordinates, country_index

class ImageDataHandler:
    """Load all images/labels into memory and build train/val/test DataLoaders.

    Images are loaded with ``load_image_files`` and transformed with the
    notebook-level ``transform``; labels are loaded with ``load_json_files``.

    Args:
        image_paths, json_paths: index-aligned lists of image and JSON paths.
        batch_size: batch size used for loading and for the resulting loaders.
        train_ratio, val_ratio, test_ratio: split fractions; must sum to 1.
        seed: seed for the split shuffle, so the same files land in the same
            split on every run. ``None`` uses the global ``random`` state.
    """

    def __init__(self, image_paths, json_paths, batch_size=128, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1, seed=42):
        self.batch_size = batch_size
        self.seed = seed

        print(len(image_paths))

        self.images = []
        self.countries = []
        self.coordinates = []
        self._load_samples(image_paths, json_paths)

        # Initialize datasets and loaders
        self.train_loader, self.val_loader, self.test_loader = self.create_loaders(train_ratio, val_ratio, test_ratio)

    def _load_samples(self, image_paths, json_paths):
        """Fill self.images / self.coordinates / self.countries, keeping them index-aligned."""
        file_name_dataset = CustomImageNameDataset(image_paths, json_paths, transform=transform)
        file_name_loader = DataLoader(file_name_dataset, batch_size=self.batch_size, shuffle=False)

        skipped = 0
        for batch_image_files, batch_label_files in file_name_loader:
            images = load_image_files(batch_image_files)
            labels = load_json_files(batch_label_files)
            for image, item in zip(images, labels):
                if not item.get('coordinates'):
                    # A placeholder value here would make torch.tensor() fail in
                    # CustomImageDataset, so drop the image and label together.
                    skipped += 1
                    continue
                if 'country_name' not in item:
                    print(item)
                self.countries.append(item.get('country_name', 'Unknown'))
                self.coordinates.append(item['coordinates'])
                self.images.append(transform(image))
        if skipped:
            print(f"Skipped {skipped} samples without coordinates")

    def create_loaders(self, train_ratio, val_ratio, test_ratio):
        """Shuffle the samples, split them by the given ratios and wrap each split in a DataLoader.

        Returns:
            ``(train_loader, val_loader, test_loader)``. Only the train loader
            reshuffles each epoch; val/test order does not affect averaged metrics.
        """
        # Compare the absolute deviation so that sums below 1 are rejected too.
        assert abs(train_ratio + val_ratio + test_ratio - 1) <= 0.001, "Ratios should sum to 1"

        combined = list(zip(self.images, self.coordinates, self.countries))
        seed = getattr(self, 'seed', None)
        if seed is None:
            random.shuffle(combined)
        else:
            # Dedicated RNG: the split is reproducible and independent of whatever
            # else has consumed the global `random` state in the notebook.
            random.Random(seed).shuffle(combined)
        total_count = len(combined)
        train_end = int(train_ratio * total_count)
        val_end = train_end + int(val_ratio * total_count)

        splits = (combined[:train_end], combined[train_end:val_end], combined[val_end:])

        # One label -> index mapping shared by all splits, so a country has the
        # same class index in train, val and test.
        country_to_index = {
            country: idx for idx, country in enumerate(sorted(set(self.countries), key=str))
        }
        train_dataset, val_dataset, test_dataset = (
            self._make_dataset(split, country_to_index) for split in splits
        )

        # Create train, val- and test dataloaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader

    @staticmethod
    def _make_dataset(samples, country_to_index):
        """Build a CustomImageDataset from (image, coordinates, country) triples; an empty split is allowed."""
        if samples:
            images, coordinates, countries = (list(column) for column in zip(*samples))
        else:
            images, coordinates, countries = [], [], []
        dataset = CustomImageDataset(images, coordinates, countries)
        # Assigned after construction, so the shared mapping is used regardless
        # of how the dataset class builds its own default mapping.
        dataset.country_to_index = country_to_index
        return dataset

In [None]:
# Create the DataLoaders via the handler class.
data_handler = ImageDataHandler(image_files, json_files)
train_dataloader, val_dataloader, test_dataloader = (
    data_handler.train_loader,
    data_handler.val_loader,
    data_handler.test_loader,
)

In [None]:
# Display the training loader (the same object as train_dataloader).
data_handler.train_loader

In [None]:
# len(loader.dataset) counts samples; len(loader) counts batches.
print("Number of train samples:", len(train_dataloader.dataset))
print("Number of train batches:", len(train_dataloader))

# Print the first batch as an example, to see the structure.
# Only one batch is fetched instead of iterating over the whole training set.
images, coordinates, country_indices = next(iter(train_dataloader))
print("Images batch shape:", images.shape)
print("Coordinates batch shape:", coordinates.shape)
print(coordinates[0])
print("Country indices:", country_indices.shape)

## Model

In [None]:
# Load the pre-trained ResNet50 model (torchvision weights-enum API).
# NOTE(review): standalone example; GeoModelTrainer trains its own `self.model`, not this one.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Replace the ImageNet classification head with a 2-output linear layer.
# The trainer below regresses 2 coordinate values with MSELoss; this is not
# binary classification.
model.fc = nn.Linear(model.fc.in_features, 2)

# Initialize the new last layer with random weights (trailing ';' hides the cell output).
nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(model.fc.bias, 0);

In [None]:
def initialize_resnet(model_type='resnet50'):
    """Build an ImageNet-pretrained torchvision ResNet with a freshly initialized head.

    Args:
        model_type: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'.

    Returns:
        The model. Its ``fc`` layer keeps the original input/output sizes but
        gets new Kaiming-normal weights and a zero bias.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Builders are lambdas so that only the requested network is constructed.
    for candidate, build in (
        ('resnet18', lambda: resnet18(weights=ResNet18_Weights.DEFAULT)),
        ('resnet34', lambda: resnet34(weights=ResNet34_Weights.DEFAULT)),
        ('resnet50', lambda: resnet50(weights=ResNet50_Weights.DEFAULT)),
        ('resnet101', lambda: resnet101(weights=ResNet101_Weights.DEFAULT)),
        ('resnet152', lambda: resnet152(weights=ResNet152_Weights.DEFAULT)),
    ):
        if model_type == candidate:
            model = build()
            break
    else:
        raise ValueError("Unsupported model type. Supported types are: resnet18, resnet34, resnet50, resnet101, resnet152.")

    print(f"Original last layer of {model_type}:", model.fc)

    # Swap in a same-shaped Linear layer, then overwrite its default init.
    old_head = model.fc
    model.fc = nn.Linear(old_head.in_features, old_head.out_features)
    nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
    nn.init.constant_(model.fc.bias, 0)

    return model

In [None]:
resnet18_model = initialize_resnet(model_type='resnet18')
resnet18_model.fc  # re-initialized head, same shape as the pretrained one

In [None]:
resnet50_model = initialize_resnet(model_type='resnet50')
resnet50_model.fc  # re-initialized head, same shape as the pretrained one

## Training

In [None]:
# Seed every RNG the notebook relies on so that runs are reproducible.
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA); make cuDNN deterministic.

    Also writes ``PYTHONHASHSEED`` to the environment. Python reads that
    variable at interpreter start-up, so it only affects processes launched
    afterwards, not string hashing in the already-running kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def coordinates_to_cartesian(lon, lat, R=6371):
    """Convert longitude/latitude in degrees to 3-D Cartesian points on a sphere of radius R.

    Inputs may be scalars or broadcastable arrays. The result has a trailing
    axis of length 3 holding (x, y, z).
    """
    lam = np.radians(lon)
    phi = np.radians(lat)
    cos_phi = np.cos(phi)
    return np.stack(
        [R * cos_phi * np.cos(lam), R * cos_phi * np.sin(lam), R * np.sin(phi)],
        axis=-1,
    )

def spherical_distance(cartesian1, cartesian2, R=6371.0):
    """Arc length between paired vectors, treated as directions from the sphere's centre.

    The angle theta between each pair of vectors (taken along dim=1) is turned
    into the arc length ``R * theta`` (radian measure). Only directions matter;
    vector lengths cancel out.

    NOTE(review): this is a geographic distance only when the inputs are 3-D
    Cartesian points (e.g. from ``coordinates_to_cartesian``). The trainer in
    this notebook passes raw 2-value model outputs/targets, for which this angle
    has no geographic meaning — confirm and convert first.

    Args:
        cartesian1, cartesian2: tensors with vectors along dim=1. ``cartesian1``
            is moved to ``cartesian2``'s device.
        R: sphere radius (default 6371.0, presumably Earth's mean radius in km).

    Returns:
        Tensor of per-pair arc lengths (dim=1 reduced away).
    """
    cartesian1 = cartesian1.to(cartesian2.device)
    dot_product = (cartesian1 * cartesian2).sum(dim=1)

    norms1 = cartesian1.norm(p=2, dim=1)
    norms2 = cartesian2.norm(p=2, dim=1)

    # Guard against zero-length vectors: 0/0 would be NaN and poison any mean
    # over the batch. Non-degenerate inputs give exactly the same result.
    denominator = (norms1 * norms2).clamp_min(torch.finfo(dot_product.dtype).tiny)
    cos_theta = dot_product / denominator
    # Rounding can push |cos| slightly above 1, where acos is undefined.
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)

    theta = torch.acos(cos_theta)
    # Arc length = radius * angle in radians.
    distance = R * theta
    return distance

def mean_spherical_distance(preds, targets, R=6371.0):
    """Mean of ``spherical_distance(preds, targets, R)`` over the batch."""
    distances = spherical_distance(preds, targets, R=R)
    return distances.mean()

In [None]:
class GeoModelTrainer:
    """Fine-tune a torchvision ResNet on coordinate targets, driven by a wandb sweep.

    Args:
        model_type: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'.
        num_classes: output size of the replacement ``fc`` layer.
        use_coordinates: if True, the coordinates are the training targets;
            otherwise the country indices are used.
            NOTE(review): with MSELoss and a ``num_classes``-wide output, the
            country-index path is not a valid classification setup.
        train_loader, val_loader: loaders to train/validate on. When omitted, the
            notebook-level ``train_dataloader`` / ``val_dataloader`` are looked up
            at call time.
    """

    def __init__(self, model_type='resnet50', num_classes=2, use_coordinates=True, train_loader=None, val_loader=None):
        self.model_type = model_type
        self.num_classes = num_classes
        self.use_coordinates = use_coordinates
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.initialize_model().to(self.device)

    def initialize_model(self):
        """Return an ImageNet-pretrained ResNet whose ``fc`` outputs ``self.num_classes`` values.

        Raises:
            ValueError: if ``self.model_type`` is not a supported ResNet name.
        """
        if self.model_type == 'resnet18':
            model = resnet18(weights=ResNet18_Weights.DEFAULT)
        elif self.model_type == 'resnet34':
            model = resnet34(weights=ResNet34_Weights.DEFAULT)
        elif self.model_type == 'resnet50':
            model = resnet50(weights=ResNet50_Weights.DEFAULT)
        elif self.model_type == 'resnet101':
            model = resnet101(weights=ResNet101_Weights.DEFAULT)
        elif self.model_type == 'resnet152':
            model = resnet152(weights=ResNet152_Weights.DEFAULT)
        else:
            raise ValueError("Unsupported model type. Supported types are: resnet18, resnet34, resnet50, resnet101, resnet152.")

        # Modify the final layer based on the number of classes
        model.fc = nn.Linear(model.fc.in_features, self.num_classes)
        return model

    def train(self):
        """Run one wandb sweep trial: train with early stopping, checkpoint the best model, log metrics.

        Reads ``seed``, ``model_name``, ``learning_rate``, ``optimizer``,
        ``weight_decay`` and ``epochs`` from ``run.config``.
        NOTE(review): ``config.optimizer`` only appears in the run name; AdamW is always used.
        """
        with wandb.init(reinit=True) as run:
            config = run.config
            set_seed(config.seed)

            # wandb.agent calls train() repeatedly on the same trainer; start every
            # run from fresh pretrained weights instead of the previous run's result.
            self.model = self.initialize_model().to(self.device)

            best_val_loss = float('inf')
            patience_counter = 0
            patience = 100

            # Rename run name and initialize parameters in model name
            model_name = f"model_{config.model_name}_lr_{config.learning_rate}_opt_{config.optimizer}_weightDecay_{config.weight_decay}"
            run_name = model_name + f"_{uuid.uuid4()}"
            wandb.run.name = run_name

            checkpoint_dir = "models"
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"{self.model_type}_best_model_checkpoint{model_name}.pth")

            # The pretrained backbone trains at 10% of the new head's learning rate.
            optimizer_grouped_parameters = [
                {"params": [p for n, p in self.model.named_parameters() if not n.startswith('fc')], "lr": config.learning_rate * 0.1},
                {"params": self.model.fc.parameters(), "lr": config.learning_rate}
            ]
            optimizer = optim.AdamW(optimizer_grouped_parameters, weight_decay=config.weight_decay)
            criterion = nn.MSELoss()

            for epoch in range(config.epochs):
                train_loss, train_distance = self.run_epoch(config, criterion, optimizer, is_train=True)
                val_loss, val_distance = self.run_epoch(config, criterion, optimizer, is_train=False)

                # Log before the early-stopping check so the final epoch is recorded too.
                wandb.log({
                    "Train Loss": train_loss,
                    "Train Distance": train_distance,
                    "Validation Loss": val_loss,
                    "Validation Distance": val_distance
                })

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # Checkpoint the model actually being trained.
                    torch.save(self.model.state_dict(), checkpoint_path)
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Stopping early after {patience} epochs without improvement")
                        break

    def _get_loader(self, is_train):
        """Return the configured train/val loader, falling back to the notebook-level loaders."""
        loader = getattr(self, 'train_loader' if is_train else 'val_loader', None)
        if loader is None:
            loader = train_dataloader if is_train else val_dataloader
        return loader

    def run_epoch(self, config, criterion, optimizer, is_train=True):
        """Run one pass over the train (with updates) or validation loader.

        Returns:
            ``(avg_loss, avg_distance)``: per-sample averages over the loader's dataset.
        """
        # nn.Module.train(False) is equivalent to eval().
        self.model.train(is_train)
        data_loader = self._get_loader(is_train)

        total_loss = 0.0
        total_distance = 0.0
        with torch.set_grad_enabled(is_train):
            for images, coordinates, country_indices in data_loader:
                images = images.to(self.device)
                targets = (coordinates if self.use_coordinates else country_indices).to(self.device)
                outputs = self.model(images)
                loss = criterion(outputs, targets)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                batch_size = images.size(0)
                total_loss += loss.item() * batch_size
                # The distance is a metric only; detach so no autograd graph is built for it.
                total_distance += mean_spherical_distance(outputs.detach(), targets).item() * batch_size

        num_samples = len(data_loader.dataset)
        return total_loss / num_samples, total_distance / num_samples

In [None]:
model_types = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]


def build_sweep_config(model_type):
    """Return the wandb grid-sweep configuration for one ResNet variant."""
    return {
        "name": f"dspro2-basemodel-{model_type}",
        "method": "grid",
        "metric": {"goal": "minimize", "name": "Validation Distance"},
        "parameters": {
            "learning_rate": {"values": [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]},
            "optimizer": {"values": ["adamW"]},
            "weight_decay": {"values": [1e-3]},  # 1e-2 was also considered
            "epochs": {"values": [500]},
            "dataset_size": {"values": [NUMBER_OF_FILES]},
            "seed": {"values": [42]},
            "model_name": {"values": [model_type]}
        },
    }


# Authenticate once; the login does not depend on the model type.
wandb.login()

for model_type in model_types:
    sweep_config = build_sweep_config(model_type)
    sweep_id = wandb.sweep(sweep=sweep_config, project="dspro2-basemodel")
    trainer = GeoModelTrainer(model_type=model_type, num_classes=2, use_coordinates=True)
    wandb.agent(sweep_id, function=trainer.train)