In [3]:
import random
import json
import os
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

from resnet_pytorch import ResNet
import wandb
import uuid

sys.path.insert(0, '../')
from data_loader import get_data_to_load, split_json_and_image_files, load_json_files, load_image_files, load_json_file, load_image_file

### Loading data

In [4]:
# Upper bound passed to get_data_to_load as `limit`.
# NOTE(review): 0 presumably means "no limit" -- confirm in data_loader.
NUMBER_OF_FILES = 0  # e.g. 10000

# Selects which of the two file lists is handed over as `loading_file`.
USE_MAPPED = True

# Resolve the list of data files (JSON labels and images) to work with.
list_files = get_data_to_load(
    loading_file=(
        '../3_data_preparation/04_data_cleaning/updated_data_list'
        if USE_MAPPED
        else '../3_data_preparation/04_data_cleaning/updated_data_list_non_mapped'
    ),
    file_location='../3_data_preparation/01_enriching/.data',
    image_file_location='../1_data_collection/.data',
    allow_new_file_creation=False,
    from_remote_only=True,
    download_link='default',
    limit=NUMBER_OF_FILES,
    shuffle_seed=42,
    allow_file_location_env=True,
    allow_json_file_location_env=True,
    allow_image_file_location_env=True,
    allow_download_link_env=True,
)

# Split the combined list into label files and image files, and keep the
# position-wise pairs as well.
json_files, image_files = split_json_and_image_files(list_files)
paired_files = list(zip(json_files, image_files))

Getting files list from remote
Got files list from remote
Parsed files list from remote
All remote files: 274796
All local files: 257130
Filtering out unpaired files
Filtered out 17666 unpaired files
Relevant files: 257130
Downloading 68320 files
Downloaded 1000 files
Downloaded 2000 files
Downloaded 3000 files
Downloaded 4000 files
Downloaded 5000 files
Downloaded 6000 files
Downloaded 7000 files
Downloaded 8000 files
Downloaded 9000 files
Downloaded 10000 files
Downloaded 11000 files
Downloaded 12000 files
Downloaded 13000 files
Downloaded 14000 files
Downloaded 15000 files
Downloaded 16000 files
Downloaded 17000 files
Downloaded 18000 files
Downloaded 19000 files
Downloaded 20000 files
Downloaded 21000 files
Downloaded 22000 files
Downloaded 23000 files
Downloaded 24000 files
Downloaded 25000 files
Downloaded 26000 files
Downloaded 27000 files
Downloaded 28000 files
Downloaded 29000 files
Downloaded 30000 files
Downloaded 31000 files
Downloaded 32000 files
Downloaded 33000 files
Dow

## Example for processing

In [None]:
class CustomImageNameDataset(Dataset):
    """Dataset over file *paths* only; no file is opened here.

    Indexing yields the image path and the JSON path stored at the same
    position. ``transform`` is kept as an attribute but is not applied by
    this class.
    """

    def __init__(self, image_paths, json_paths, transform=None):
        self.transform = transform
        self.json_paths = json_paths
        self.image_paths = image_paths

    def __len__(self):
        # The image list defines the dataset size.
        number_of_items = len(self.image_paths)
        return number_of_items

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        json_path = self.json_paths[idx]
        return image_path, json_path

# Image preprocessing used throughout the notebook: resize to 50x50,
# convert to a tensor, then normalize each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(50, 50)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Images and labels are matched by position, so both lists must be equally long.
assert len(json_files) == len(image_files), "Mismatch in number of images and labels"

# Dataset and loader that hand out file paths in shuffled batches of 64.
file_name_dataset = CustomImageNameDataset(
    image_paths=image_files,
    json_paths=json_files,
    transform=transform,
)
file_name_loader = DataLoader(
    dataset=file_name_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Peek at the first image path stored in the loader's underlying dataset.
file_name_loader.dataset.image_paths[0]

In [None]:
# Load and preprocess only the first batch as a quick check of the pipeline.
# Start from empty lists so the summary prints below still work when the
# loader yields no batch at all.
images, labels, countries, coordinates, transformed_images = [], [], [], [], []
for temp_image_files, temp_label_files in file_name_loader:
    images = load_image_files(temp_image_files)
    labels = load_json_files(temp_label_files)
    countries = [item['country_name'] for item in labels]
    coordinates = [item['coordinates'] for item in labels]
    transformed_images = [transform(image) for image in images]
    break  # After the first batch, exit the loop
print("Images:", len(images))
print("Labels:", len(labels))

## Processing and loading data

In [None]:
class CustomImageDataset(Dataset):
    """In-memory dataset of (image, coordinates, country index) samples.

    Args:
        images: Sequence of already prepared images, returned unchanged.
        coordinates: Sequence of per-sample coordinate values; each entry is
            converted to a ``float32`` tensor on access.
        countries: Sequence of per-sample country labels.
        country_to_index: Optional label -> class index mapping. Pass the
            same mapping to several datasets (e.g. train/val/test) so that a
            country gets the same index in all of them. When omitted, the
            mapping is built from ``countries``.
    """

    def __init__(self, images, coordinates, countries, country_to_index=None):
        self.images = images
        self.coordinates = coordinates
        self.countries = countries
        if country_to_index is None:
            # Iterating a bare set of strings has no stable order across
            # interpreter runs, which would assign different class indices
            # each time. Sorting makes the mapping deterministic.
            # key=str keeps this working if a label is missing (None).
            unique_countries = sorted(set(countries), key=str)
            country_to_index = {country: idx for idx, country in enumerate(unique_countries)}
        self.country_to_index = country_to_index

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, coordinates tensor, country index)`` for sample ``idx``."""
        image = self.images[idx]
        country_index = self.country_to_index[self.countries[idx]]
        coordinates = torch.tensor(self.coordinates[idx], dtype=torch.float32)

        return image, coordinates, country_index

class ImageDataHandler:
    """Loads all image/label pairs into memory and builds train/val/test loaders.

    On construction every image is loaded and passed through the notebook's
    global ``transform``; the country name and coordinates are read from the
    matching JSON file. The samples are then shuffled and split by the given
    ratios.

    Attributes:
        batch_size: Batch size used for path batching and for all loaders.
        images, countries, coordinates: Parallel lists with one entry per sample.
        train_loader, val_loader, test_loader: DataLoaders over the three splits.
    """

    def __init__(self, image_paths, json_paths, batch_size=64, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
        self.batch_size = batch_size

        print(len(image_paths))

        # Batch the file paths only; shuffle stays off so images and labels
        # are read in the order they were given.
        file_name_dataset = CustomImageNameDataset(image_paths, json_paths, transform=transform)
        file_name_loader = DataLoader(file_name_dataset, batch_size=batch_size, shuffle=False)

        self.images = []
        self.countries = []
        self.coordinates = []

        for batch_image_files, batch_label_files in file_name_loader:
            images = load_image_files(batch_image_files)
            labels = load_json_files(batch_label_files)
            self.countries.extend(item['country_name'] for item in labels)
            self.coordinates.extend(item['coordinates'] for item in labels)
            self.images.extend(transform(image) for image in images)

        # Initialize datasets and loaders
        self.train_loader, self.val_loader, self.test_loader = self.create_loaders(train_ratio, val_ratio, test_ratio)

    @staticmethod
    def _make_dataset(samples):
        """Build a CustomImageDataset from a list of (image, coordinates, country) triples."""
        # zip(*[]) produces nothing to unpack, so an empty split needs
        # explicit empty columns instead.
        images, coordinates, countries = zip(*samples) if samples else ((), (), ())
        return CustomImageDataset(images, coordinates, countries)

    def create_loaders(self, train_ratio, val_ratio, test_ratio):
        """Shuffle the samples, split them by ratio and wrap each split in a DataLoader.

        Args:
            train_ratio, val_ratio, test_ratio: Split fractions; must sum to 1
                (within 0.001). The test split receives whatever remains
                after the train and validation splits are cut.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.

        Raises:
            AssertionError: If the ratios do not sum to 1.
        """
        # abs() is required: without it any sum *below* 1 passes the check.
        assert abs(train_ratio + val_ratio + test_ratio - 1) <= 0.001, "Ratios should sum to 1"

        combined = list(zip(self.images, self.coordinates, self.countries))
        # Uses the global `random` state, so the split depends on whatever
        # seed (if any) was set before this call.
        random.shuffle(combined)
        total_count = len(combined)
        train_end = int(train_ratio * total_count)
        val_end = train_end + int(val_ratio * total_count)

        # Create train, val and test datasets
        train_dataset = self._make_dataset(combined[:train_end])
        val_dataset = self._make_dataset(combined[train_end:val_end])
        test_dataset = self._make_dataset(combined[val_end:])

        # Create train, val and test dataloaders; only training data is reshuffled.
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader

In [None]:
# Build the data handler (this loads and transforms every image) and expose
# its three DataLoaders under short names.
data_handler = ImageDataHandler(image_paths=image_files, json_paths=json_files)
train_dataloader, val_dataloader, test_dataloader = (
    data_handler.train_loader,
    data_handler.val_loader,
    data_handler.test_loader,
)

In [None]:
# len(loader.dataset) counts samples, len(loader) counts batches.
print("Number of train samples:", len(train_dataloader.dataset))
print("Number of train batches:", len(train_dataloader))

# Print the first batch as an example, to see the structure.
# Author's timing note: 7000 images need 59 sec for processing.
for images, coordinates, country_indices in train_dataloader:
    print("Images batch shape:", images.shape)
    print("Coordinates batch shape:", coordinates.shape)
    print(coordinates[0][0])
    print("Country indices:", country_indices.shape)
    break  # one batch is enough; the remaining batches would only be iterated, not used

## Model

In [None]:
# Load the pretrained ResNet-18 with a 2-output head (num_classes=2).
# This instance is only inspected in the notebook; train() below creates
# its own model for every run.
model = ResNet.from_pretrained('resnet18', num_classes=2)

In [None]:
# Show the textual representation of the loaded network.
print(model)

## Training

In [None]:
# Seed every random number generator the notebook relies on.
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    # NOTE(review): exporting PYTHONHASHSEED here presumably only affects
    # processes started afterwards -- confirm if hash order matters.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # GPU generators, only when a CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN benchmarking and request deterministic behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
def haversine_distance(lon1, lat1, lon2, lat2):
    """Great-circle distance between two points using the haversine formula.

    Works element-wise, so scalars and equally shaped NumPy arrays are both
    accepted.

    Args:
        lon1, lat1: Longitude and latitude of the first point, in degrees.
        lon2, lat2: Longitude and latitude of the second point, in degrees.

    Returns:
        Distance in kilometres on a sphere of radius 6371 km.
    """
    lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])

    dlon = lon2 - lon1
    dlat = lat2 - lat1

    # Haversine formula
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    # Mathematically a lies in [0, 1]; rounding can push it slightly outside,
    # which would turn sqrt/arcsin into NaN. Clamp before using it.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    r = 6371  # sphere radius in kilometres
    return c * r

def mean_haversine_distance(preds, targets):
    """Mean haversine distance between predicted and true coordinates.

    Args:
        preds: 2-D array; column 0 is passed as longitude and column 1 as
            latitude of the prediction.
        targets: 2-D array with the same row count and column meaning.
            NOTE(review): confirm that the label files really store
            coordinates as (lon, lat) and not (lat, lon).

    Returns:
        The average distance over all rows as a Python float, in the unit
        returned by ``haversine_distance``.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in row count.
        ZeroDivisionError: If there are no rows.
    """
    total = preds.shape[0]
    if targets.shape[0] != total:
        raise ValueError("preds and targets must contain the same number of rows")

    # haversine_distance is built from element-wise NumPy operations, so one
    # call on whole columns replaces a Python loop over the rows.
    distances = haversine_distance(preds[:, 0], preds[:, 1], targets[:, 0], targets[:, 1])

    return float(np.sum(distances)) / total

In [None]:
def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over ``dataloader`` and return (mean loss, mean haversine distance).

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it is evaluated with gradients disabled.
    Both returned values are averaged per sample over ``dataloader.dataset``.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is the same as eval()
    total_loss = 0.0
    total_mhd = 0.0

    with torch.set_grad_enabled(is_training):
        # The country index of each sample is not needed for regression.
        for images, coordinates, _ in dataloader:
            images, coordinates = images.to(device), coordinates.to(device)
            if is_training:
                optimizer.zero_grad()
            output = model(images)

            # Fail loudly instead of letting the loss broadcast mismatched shapes.
            if output.shape != coordinates.shape:
                raise ValueError("Mismatch in output and target shapes")

            loss = criterion(output, coordinates)
            if is_training:
                loss.backward()
                optimizer.step()

            # The loss and the distance are batch means; weight them by the
            # batch size so the final division yields a per-sample average.
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            preds = output.detach().cpu().numpy()
            targets = coordinates.cpu().numpy()
            total_mhd += mean_haversine_distance(preds, targets) * batch_size

    sample_count = len(dataloader.dataset)
    return total_loss / sample_count, total_mhd / sample_count


def train():
    """Train one ResNet-18 coordinate regressor for a single wandb run.

    Hyperparameters (seed, learning_rate, weight_decay, epochs, optimizer)
    come from the wandb run config; ``optimizer`` is only used in the run
    name, the optimizer itself is always AdamW. Data is read from the
    notebook-level ``train_dataloader`` and ``val_dataloader``.

    Each epoch logs train/validation MSE loss and mean haversine distance.
    The weights with the lowest validation loss are saved under ``models/``,
    and training stops early after 20 epochs without improvement.
    """
    with wandb.init(reinit=True) as run:
        config = run.config
        set_seed(config.seed)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        best_val_loss = float('inf')

        # Early stopping: number of epochs tolerated without a better val loss.
        patience = 20
        patience_counter = 0

        # Encode the hyperparameters in the model name; the run name also
        # gets a random suffix.
        model_name = f"lr_{config.learning_rate}_opt_{config.optimizer}_weightDecay_{config.weight_decay}"
        run_name = model_name + f"_{uuid.uuid4()}"
        wandb.run.name = run_name

        criterion = nn.MSELoss()
        model = ResNet.from_pretrained('resnet18', num_classes=2).to(device)

        # Two parameter groups: the pretrained layers train at a tenth of the
        # learning rate used for the final `fc` layer.
        optimizer_grouped_parameters = [
            {"params": [p for n, p in model.named_parameters() if not n.startswith('fc')], "lr": config.learning_rate * 0.1},
            {"params": model.fc.parameters(), "lr": config.learning_rate}
        ]

        # AdamW optimizer
        optimizer = optim.AdamW(optimizer_grouped_parameters, weight_decay=config.weight_decay)

        # torch.save does not create missing directories, so make sure the
        # checkpoint folder exists before the first save.
        checkpoint_path = f"models/resnet18_best_model_checkpoint{model_name}.pth"
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

        for epoch in range(config.epochs):
            train_loss, train_mhd = _run_epoch(model, train_dataloader, criterion, device, optimizer)
            val_loss, val_mhd = _run_epoch(model, val_dataloader, criterion, device)

            # Print metrics and log them to wandb
            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Train MHD: {train_mhd:.4f}, Val Loss: {val_loss:.4f}, Val MHD: {val_mhd:.4f}")
            wandb.log({
                "Train Loss (MSELoss)": train_loss,
                "Train MHD (Mean Haversine Distance)": train_mhd,
                "Val Loss (MSELoss)": val_loss,
                "Val MHD (Mean Haversine Distance)": val_mhd
            })

            # Saving model and early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), checkpoint_path)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Stopping early after {patience} epochs without improvement")
                    break

In [None]:
wandb.login()

# Grid sweep over learning rate and weight decay for the ResNet-18 base model.
sweep_config = {
    "name": "dspro2-basemodel-resnet18",
    "method": "grid",
    # The sweep metric has to name a value the run logs. train() logs
    # "Val Loss (MSELoss)", and for a loss lower is better.
    "metric": {"goal": "minimize", "name": "Val Loss (MSELoss)"},
    "parameters": {
        "learning_rate": {"values": [1e-2, 1e-3, 1e-4, 1e-5, 1e-6]},
        "optimizer": {"values": ["adamW"]},
        "weight_decay": {"values": [1e-2, 1e-3]},
        "epochs": {"values": [500]},
        "seed": {"values": [42]}
    },
}

# Register the sweep, then run train() once per grid point.
sweep_id = wandb.sweep(sweep=sweep_config, project="dspro2-basemodel-resnet18")
wandb.agent(sweep_id, function=train)