In [None]:
import boto3

def get_r2_image_count(prefix="datasets/gmaps/world_sampling/world/world/images/"):
    """Count the objects stored in the R2 bucket under ``prefix``.

    Connects with the module-level ``R2_ACCESS_KEY`` / ``R2_SECRET_KEY`` /
    ``R2_ENDPOINT`` settings and pages through ``list_objects_v2`` on
    ``R2_BUCKET``, so prefixes spanning several result pages are fully counted.

    NOTE(review): every key under the prefix is counted regardless of type.
    This prefix holds ``.tar.gz`` shards (e.g. ``gmaps_shard_00003.tar.gz``),
    so the result is an object count, not necessarily a count of individual
    images -- confirm that is what callers expect.

    Args:
        prefix: Key prefix to list. Defaults to the gmaps world-sampling
            images folder.

    Returns:
        Number of object keys found under ``prefix``.
    """
    session = boto3.session.Session()
    s3 = session.client(
        service_name='s3',
        aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        endpoint_url=R2_ENDPOINT
    )

    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=R2_BUCKET, Prefix=prefix)
    # A page with no matching objects has no "Contents" key at all.
    return sum(len(page.get("Contents", [])) for page in pages)

# Example usage to get the length. Pass the prefix explicitly so the printed
# path is always the one that was actually counted.
prefix = "datasets/gmaps/world_sampling/world/world/images/"
len_images = get_r2_image_count(prefix)
print(f"Number of images in {R2_BUCKET}/{prefix} = {len_images}")


In [None]:
import io
import tarfile
import tempfile
import os

# Download one shard (gmaps_shard_00003.tar.gz) fully into memory and list the
# regular files it contains.
session = boto3.session.Session()
s3 = session.client(
    service_name='s3',
    aws_access_key_id=R2_ACCESS_KEY,
    aws_secret_access_key=R2_SECRET_KEY,
    endpoint_url=R2_ENDPOINT
)

prefix = "datasets/gmaps/world_sampling/world/world/images/"
filename = "gmaps_shard_00003.tar.gz"
key = prefix + filename

# Download to memory (BytesIO); the buffer is kept as a global so it can be
# re-read later without downloading again.
tar_gz_data = io.BytesIO()
s3.download_fileobj(R2_BUCKET, key, tar_gz_data)
tar_gz_data.seek(0)  # Rewind so tarfile reads from the start

# Names of regular files only (directories, links, etc. are skipped).
with tarfile.open(fileobj=tar_gz_data, mode='r:gz') as tar:
    file_list = [member.name for member in tar.getmembers() if member.isfile()]

# Both names are kept as globals in case other cells refer to either one.
files_inside = file_list

print(f"Number of files inside {filename}: {len(files_inside)}")
print(f"\nFiles inside {filename}:")
for i, file_path in enumerate(files_inside, 1):
    print(f"{i}. {file_path}")


In [None]:
import json

# Rewind the in-memory shard (downloaded earlier) so it can be read again.
tar_gz_data.seek(0)

# Find and read the first JSON file (excluding __metadata__.json).
first_json_data = None
first_json_filename = None

with tarfile.open(fileobj=tar_gz_data, mode='r:gz') as tar:
    for member in tar.getmembers():
        if member.isfile() and member.name.endswith('.json') and member.name != '__metadata__.json':
            first_json_filename = member.name
            # Close the extracted member handle once parsed.
            with tar.extractfile(member) as json_file:
                first_json_data = json.load(json_file)
            break

# Decide on the filename, not the payload: a valid but empty/falsy JSON
# document ({}, [], 0, "") would otherwise be reported as "not found".
if first_json_filename is not None:
    print(f"First JSON file: {first_json_filename}")
    print("\nJSON content:")
    print(json.dumps(first_json_data, indent=2))
else:
    print("No JSON file found (excluding __metadata__.json)")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json
import io
import tarfile
from typing import List, Tuple, Optional, Dict, Any
import threading

# Get list of all tar.gz files
def get_tar_gz_list(prefix="datasets/gmaps/world_sampling/world/world/images/"):
    """Return the sorted keys of every ``.tar.gz`` object under ``prefix``.

    Lists ``R2_BUCKET`` with the module-level ``R2_ACCESS_KEY`` /
    ``R2_SECRET_KEY`` / ``R2_ENDPOINT`` credentials, following
    ``list_objects_v2`` pagination so every page is included.

    Args:
        prefix: Key prefix to search. Defaults to the gmaps world-sampling
            images folder.

    Returns:
        Lexicographically sorted list of object keys ending in ``.tar.gz``.
    """
    session = boto3.session.Session()
    s3 = session.client(
        service_name='s3',
        aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        endpoint_url=R2_ENDPOINT
    )

    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=R2_BUCKET, Prefix=prefix)

    # Pages with no matching objects omit "Contents" entirely.
    return sorted(
        obj['Key']
        for page in pages
        for obj in page.get("Contents", [])
        if obj['Key'].endswith('.tar.gz')
    )

# Get all tar.gz files
tar_gz_list = get_tar_gz_list()
print(f"Found {len(tar_gz_list)} tar.gz files")

class TarGZDataset(Dataset):
    """Map-style dataset in which index ``i`` is the ``i``-th ``.tar.gz`` shard.

    ``dataset[i]`` downloads ``tar_gz_keys[i]`` from R2 into memory, pairs each
    ``<stem>.json`` (other than ``__metadata__.json``) that contains ``lat`` and
    ``lon`` keys with the ``<stem>.jpg`` in the same archive, and returns *all*
    of that shard's samples at once as parallel lists.

    Because the index selects the shard, ``len(dataset)`` equals the number of
    shards, every shard is visited once per epoch, and a ``DataLoader`` with
    ``num_workers > 0`` sends different indices (hence different shards) to
    different worker processes.
    """

    def __init__(self, tar_gz_keys: List[str], r2_bucket: str, r2_endpoint: str,
                 r2_access_key: str, r2_secret_key: str):
        """Store the shard keys and R2 connection settings.

        Args:
            tar_gz_keys: Object keys of the ``.tar.gz`` shards; index ``i`` of
                the dataset maps to ``tar_gz_keys[i]``.
            r2_bucket: Bucket holding the shards.
            r2_endpoint: S3-compatible endpoint URL for R2.
            r2_access_key: Access key id.
            r2_secret_key: Secret access key.
        """
        from types import SimpleNamespace

        self.tar_gz_keys = tar_gz_keys
        self.r2_bucket = r2_bucket
        self.r2_endpoint = r2_endpoint
        self.r2_access_key = r2_access_key
        self.r2_secret_key = r2_secret_key

        # Cache of the most recently loaded shard. DataLoader workers are
        # separate processes, each holding its own copy of this object, so a
        # plain namespace is already per-worker. Unlike threading.local() it is
        # picklable, which spawn/forkserver worker start methods require.
        self.worker_data = SimpleNamespace()

        # Not computed anywhere in this class; kept for attribute compatibility.
        self._total_samples = None

    def _get_worker_tar_gz(self, worker_id: int) -> Optional[str]:
        """Return ``tar_gz_keys[worker_id]``, or ``None`` if out of range.

        No longer used by ``__getitem__`` (shards are selected by index);
        retained so existing callers keep working.
        """
        if worker_id < len(self.tar_gz_keys):
            return self.tar_gz_keys[worker_id]
        return None

    def _download_tar_gz(self, tar_gz_key: str) -> io.BytesIO:
        """Download one object from R2 into a rewound in-memory buffer."""
        session = boto3.session.Session()
        s3 = session.client(
            service_name='s3',
            aws_access_key_id=self.r2_access_key,
            aws_secret_access_key=self.r2_secret_key,
            endpoint_url=self.r2_endpoint
        )
        buffer = io.BytesIO()
        s3.download_fileobj(self.r2_bucket, tar_gz_key, buffer)
        buffer.seek(0)
        return buffer

    @staticmethod
    def _parse_samples(fileobj) -> List[Dict[str, Any]]:
        """Pair JSON records carrying ``lat``/``lon`` with their same-stem ``.jpg`` bytes.

        Args:
            fileobj: Readable binary file object positioned at the start of a
                gzip-compressed tar archive.

        Returns:
            One dict per matched pair with keys ``image`` (raw JPEG bytes),
            ``lat``, ``lon`` and ``json_data`` (the full parsed JSON).
        """
        json_data_map = {}
        image_data_map = {}
        with tarfile.open(fileobj=fileobj, mode='r:gz') as tar:
            for member in tar.getmembers():
                if not member.isfile():
                    continue
                if member.name.endswith('.json') and member.name != '__metadata__.json':
                    with tar.extractfile(member) as json_file:
                        json_data_map[member.name] = json.load(json_file)
                elif member.name.endswith('.jpg'):
                    with tar.extractfile(member) as image_file:
                        image_data_map[member.name] = image_file.read()

        samples = []
        for json_name, json_data in json_data_map.items():
            if 'lat' in json_data and 'lon' in json_data:
                # Swap only the trailing extension; str.replace would also
                # rewrite any '.json' occurring earlier in the path.
                image_name = json_name[:-len('.json')] + '.jpg'
                image_bytes = image_data_map.get(image_name)
                if image_bytes is not None:
                    samples.append({
                        'image': image_bytes,
                        'lat': json_data['lat'],
                        'lon': json_data['lon'],
                        'json_data': json_data
                    })
        return samples

    def _load_tar_gz_data(self, tar_gz_key: str):
        """Download and parse ``tar_gz_key`` into ``self.worker_data`` unless already cached."""
        if getattr(self.worker_data, 'tar_gz_key', None) == tar_gz_key:
            return

        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        print(f"Worker {worker_id} loading tar.gz: {tar_gz_key}")

        samples = self._parse_samples(self._download_tar_gz(tar_gz_key))

        print(f"Worker {worker_id} loaded {len(samples)} samples from {tar_gz_key}")

        self.worker_data.tar_gz_key = tar_gz_key
        self.worker_data.samples = samples
        self.worker_data.current_idx = 0

    def __len__(self):
        """Number of shards: each index yields one whole ``.tar.gz`` file."""
        return len(self.tar_gz_keys)

    def __getitem__(self, idx):
        """Return every sample of shard ``tar_gz_keys[idx]`` as parallel lists.

        Returns:
            Dict with ``image`` (list of float tensors from ``ToTensor``),
            ``lat`` / ``lon`` (lists of 0-dim float32 tensors) and
            ``json_data`` (list of parsed JSON dicts), all in the same order.

        Raises:
            IndexError: If ``idx`` is out of range or the shard yields no
                matched JSON/JPEG pairs.
        """
        try:
            tar_gz_key = self.tar_gz_keys[idx]
        except IndexError:
            raise IndexError(
                f"Index {idx} out of range for {len(self.tar_gz_keys)} tar.gz files"
            ) from None

        self._load_tar_gz_data(tar_gz_key)

        samples = getattr(self.worker_data, 'samples', None)
        if not samples:
            raise IndexError(f"No samples in tar.gz {tar_gz_key}")

        to_tensor = transforms.ToTensor()
        all_images = []
        all_lats = []
        all_lons = []
        all_json_data = []

        for sample in samples:
            # Decode the JPEG bytes, force 3-channel RGB, then convert to a tensor.
            with Image.open(io.BytesIO(sample['image'])) as img:
                image = img.convert('RGB')
            all_images.append(to_tensor(image))
            all_lats.append(torch.tensor(sample['lat'], dtype=torch.float32))
            all_lons.append(torch.tensor(sample['lon'], dtype=torch.float32))
            all_json_data.append(sample['json_data'])

        return {
            'image': all_images,  # List of tensors
            'lat': all_lats,      # List of tensors
            'lon': all_lons,      # List of tensors
            'json_data': all_json_data  # List of dicts
        }

def custom_collate_fn(batch):
    """Merge per-shard sample lists from several dataset items into one batch.

    Each element of ``batch`` is a dict whose ``image``, ``lat``, ``lon`` and
    ``json_data`` values are lists covering every sample of one shard. The
    lists are concatenated in batch order; the tensor fields are then stacked
    with ``default_collate``, while ``json_data`` stays a plain list because
    arbitrary dicts cannot be collated.
    """
    from torch.utils.data._utils.collate import default_collate

    merged = {field: [] for field in ('image', 'lat', 'lon', 'json_data')}
    for shard in batch:
        for field, values in merged.items():
            values.extend(shard[field])

    return {
        'image': default_collate(merged['image']),
        'lat': default_collate(merged['lat']),
        'lon': default_collate(merged['lon']),
        'json_data': merged['json_data'],
    }

def worker_init_fn(worker_id):
    """No-op DataLoader worker initializer.

    All worker-specific behaviour lives in the dataset itself, so nothing has
    to happen when a worker process starts.
    """
    del worker_id  # unused; accepted only to satisfy the callback signature

# Build the dataset over every shard key discovered above.
dataset = TarGZDataset(
    tar_gz_keys=tar_gz_list,
    r2_bucket=R2_BUCKET,
    r2_endpoint=R2_ENDPOINT,
    r2_access_key=R2_ACCESS_KEY,
    r2_secret_key=R2_SECRET_KEY
)

# Cap the pool at 4 worker processes, and never exceed the number of shards.
num_workers = min(4, len(tar_gz_list))
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=1,
    collate_fn=custom_collate_fn,
    num_workers=num_workers,
    worker_init_fn=worker_init_fn
)

# Usage banner (single print; stdout is identical to one print per line).
print("\n".join([
    f"Created DataLoader with {num_workers} workers",
    "Each worker will process a different tar.gz file",
    "Each batch contains all images from the tar.gz files processed by workers",
    "\nExample usage:",
    "  for batch in dataloader:",
    "      image = batch['image']      # Tensor [B, C, H, W] - all images from all workers",
    "      lat = batch['lat']          # Tensor [B] - latitudes for all images",
    "      lon = batch['lon']          # Tensor [B] - longitudes for all images",
    "      json_data = batch['json_data']  # List of dicts (one per image)",
    f"      print(f'Batch size: {{image.shape[0]}}')  # Number of images in this batch",
]))


In [None]:
# Smoke test: pull only the first batch from the DataLoader and show the shape
# of its collated image tensor, then stop iterating.
for batch in dataloader:
    print(batch["image"].shape)
    break  # one batch is enough to confirm the pipeline works end to end