In [1]:
# MUST be first cell: set multiprocessing method for Windows
import torch.multiprocessing as mp
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass  # Already set

In [None]:
import os
import sys
from pathlib import Path

# FOR LOCAL USE THIS LINES
current = Path.cwd()
src_path = current / "src" if (current / "src").exists() else current.parent

# FOR COLAB USE THIS LINE INSTEAD
#!git clone https://github.com/MatteoCamillo-code/GeoLoc-CVCS.git
#src_path = Path("/content/GeoLoc-CVCS/src").resolve()

sys.path.insert(0, str(src_path))

from utils.paths import find_project_root

# Set working directory and sys.path properly
project_root = find_project_root(src_path)
data_dir = project_root / "data"
os.chdir(project_root)
sys.path.insert(0, str(project_root / "src"))
print("CWD:", Path.cwd())

CWD: F:\InfTech\Prodotti\Python\GeoLocGit\GeoLoc-CVCS


In [14]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim.lr_scheduler import StepLR

from dataset.osv_dataset import OSV_mini
from configs.baseline import TrainConfig

from utils.seed import seed_everything
from training.runner import fit


In [4]:
cfg = TrainConfig()
seed_everything(cfg.seed)

device = cfg.device if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [5]:
import kagglehub

path = kagglehub.dataset_download("josht000/osv-mini-129k")
path = path + "/osv5m"
print("Path to dataset files:", path)

image_root = path + "/train_images"


  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: C:\Users\camil\.cache\kagglehub\datasets\josht000\osv-mini-129k\versions\1/osv5m


In [6]:
train_val_path = data_dir / "metadata/s2-geo-cells/train_val_split_geocells.csv"
cell_centers_path = data_dir / "metadata/s2-geo-cells/cell_center_dataset.csv"

train_val_meta = pd.read_csv(train_val_path)
cell_centers_df = pd.read_csv(cell_centers_path)

print("Train/val CSV:", train_val_path)
print("Cell centers CSV:", cell_centers_path)


Train/val CSV: F:\InfTech\Prodotti\Python\GeoLocGit\GeoLoc-CVCS\data\metadata\s2-geo-cells\train_val_split_geocells.csv
Cell centers CSV: F:\InfTech\Prodotti\Python\GeoLocGit\GeoLoc-CVCS\data\metadata\s2-geo-cells\cell_center_dataset.csv


In [None]:
IMG_SIZE = 224

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0), ratio=(3/4, 4/3)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
    transforms.RandomApply([transforms.RandomRotation(10)], p=0.2),
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

In [8]:
train_dataset = OSV_mini(
    image_root=image_root,
    csv_path=train_val_path,
    transform=train_transform,
    split="train",
    scene="total",
    label_maps=None
)

val_dataset = OSV_mini(
    image_root=image_root,
    csv_path=train_val_path,
    transform=val_transform,
    split="val",
    scene="total",
    label_maps=train_dataset.label_maps
)

print("Train size:", len(train_dataset))
print("Val size:", len(val_dataset))
print("Label maps:", {k: len(v) for k,v in train_dataset.label_maps.items()})


Train size: 100863
Val size: 17803
Label maps: {'label_config_1': 4741, 'label_config_2': 2508, 'label_config_3': 1336}


In [9]:
from dataset.osv_dataset import seed_worker, fast_collate

BATCH_SIZE = cfg.batch_size
NUM_WORKERS = cfg.num_workers
PREFETCH_FACTOR = 4

# Create a generator for reproducibility with workers
g = torch.Generator()
g.manual_seed(cfg.seed)

train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH_FACTOR if NUM_WORKERS > 0 else None,
    persistent_workers=True if NUM_WORKERS > 0 else False,
    collate_fn=fast_collate,
    worker_init_fn=seed_worker,
    generator=g
)

val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=PREFETCH_FACTOR if NUM_WORKERS > 0 else None,
    persistent_workers=True if NUM_WORKERS > 0 else False,
    collate_fn=fast_collate,
    worker_init_fn=seed_worker
)

In [10]:
PARTITION_IDX = 1  # 0,1,2 corresponding to label_config_1/2/3

weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights)

# number of classes depends on partition
num_classes = [
    train_val_meta["label_config_1"].nunique(),
    train_val_meta["label_config_2"].nunique(),
    train_val_meta["label_config_3"].nunique(),
]
n_out = num_classes[PARTITION_IDX]

in_features = model.fc.in_features
model.fc = nn.Linear(in_features, n_out)

model = model.to(device)

# Optional: comment out if it causes issues on Windows/your PyTorch version
# model = torch.compile(model, backend="aot_eager")

print("Output classes:", n_out)


Output classes: 2510


In [11]:
# Freeze all parameters in the model initially
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the final fully connected layer (model.fc)
# This layer was replaced with a new nn.Linear layer in cell nYgB9PL73PAb
for param in model.fc.parameters():
    param.requires_grad = True

In [None]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=cfg.lr,
    momentum=cfg.momentum,
    weight_decay=cfg.weight_decay,
    nesterov=True
)
scheduler = StepLR(optimizer, step_size=cfg.scheduler_step_size, gamma=cfg.scheduler_gamma)
scaler = torch.amp.GradScaler(device=cfg.device, enabled=cfg.amp)
torch.backends.cudnn.benchmark = True

In [None]:
# Ensure your config has:
# output_dir="outputs", model_name="first_try.pt"
# runner.py uses root_path(cfg.output_dir, "checkpoints", cfg.model_name)

history = fit(
    cfg=cfg,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    scaler=scaler,
    label_idx=PARTITION_IDX,
    use_tqdm=cfg.use_tqdm,
    scheduler=None
)

history
