# Model Evaluation Notebook (AWS SageMaker S3)
This notebook evaluates the **Distance Estimation** and **Inside Prediction** models.
It is configured to run on an AWS SageMaker Notebook Instance and load data directly from the S3 bucket: `spatial-agent-data-learner-lab`.

In [None]:
# Imports
import os
import json
import re
import io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import transforms, models
import torchvision.transforms.functional as TF
from PIL import Image, ImageFile
from tqdm import tqdm
import pycocotools.mask as mask_utils
import boto3

# Let PIL load truncated image files instead of raising an error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# AWS S3 Configuration
S3_BUCKET = 'spatial-agent-data-learner-lab'
# One shared client, used by the S3 load/download helpers. No credentials are passed,
# so boto3 resolves them through its default provider chain (on SageMaker this is
# presumably the notebook instance's IAM role — confirm).
S3_CLIENT = boto3.client('s3')

# --- Paths Configuration (Relative to S3 Bucket Root) ---
VAL_JSON_KEY = 'val.json'
# Image object keys are built as f"{VAL_IMAGE_PREFIX}/{item['image']}".
VAL_IMAGE_PREFIX = 'data/val/images'
# Passed to S3DistanceDataset, which stores it but does not read any depth maps.
VAL_DEPTH_PREFIX = 'data/val/depths'

# Checkpoint Keys in S3
# Each is downloaded to local disk by the evaluation functions before loading.
DIST_CKPT_S3_KEY_EPOCH5 = 'distance_est/ckpt/epoch_5_iter_6831.pth'
DIST_CKPT_S3_KEY_3M = 'distance_est/ckpt/3m_epoch6.pth'
INSIDE_CKPT_S3_KEY = 'inside_pred/ckpt/epoch_4.pth'

## 1. Helper Functions

In [None]:
def load_s3_json(bucket, key):
    """Fetch ``s3://bucket/key`` and parse its contents as UTF-8 JSON.

    Args:
        bucket: S3 bucket name.
        key: object key inside the bucket.

    Returns:
        The decoded JSON value.
    """
    print(f"Loading JSON from s3://{bucket}/{key}...")
    body = S3_CLIENT.get_object(Bucket=bucket, Key=key)['Body']
    text = body.read().decode('utf-8')
    return json.loads(text)

def download_s3_checkpoint(bucket, key, local_path):
    """Download ``s3://bucket/key`` to ``local_path`` unless that file already exists.

    The cache check is by path only: an existing file at ``local_path`` is reused
    without comparing it against the S3 object.

    Args:
        bucket: S3 bucket name.
        key: object key of the checkpoint.
        local_path: destination file path; a bare filename (no directory) is allowed.

    Returns:
        ``local_path``.
    """
    if os.path.exists(local_path):
        print(f"Checkpoint found at {local_path}, skipping download.")
        return local_path

    print(f"Downloading checkpoint s3://{bucket}/{key} to {local_path}...")
    parent_dir = os.path.dirname(local_path)
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    S3_CLIENT.download_file(bucket, key, local_path)
    return local_path

def load_s3_image(bucket, key):
    """Return the object at ``s3://bucket/key`` opened as a PIL image.

    The whole object is read into an in-memory buffer first, so the returned
    image does not depend on the S3 response stream staying open.
    """
    raw_bytes = S3_CLIENT.get_object(Bucket=bucket, Key=key)['Body'].read()
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)

## 2. Model Definitions

In [None]:
# Distance Model
class ResNetDistanceRegressor(nn.Module):
    """torchvision ResNet with a replaced stem and a two-layer MLP head that regresses one value.

    Args:
        input_channels: number of input channels accepted by the replacement ``conv1``
            (kernel size, stride, padding and bias setting are copied from the original).
        backbone: name of a ``torchvision.models`` constructor, e.g. ``'resnet50'``.
        pretrained: accepted for signature compatibility but unused; the backbone is
            always built with ``weights=None``.
    """

    def __init__(self, input_channels=5, backbone='resnet50', pretrained=False):
        super().__init__()
        net = getattr(models, backbone)(weights=None)

        # Swap the stem so it takes `input_channels` instead of 3, keeping its geometry.
        stem = net.conv1
        net.conv1 = nn.Conv2d(
            input_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )

        # Replace the classifier with feature -> 256 -> 1 regression head.
        head_in = net.fc.in_features
        net.fc = nn.Sequential(nn.Linear(head_in, 256), nn.ReLU(), nn.Linear(256, 1))

        self.resnet = net

    def forward(self, x):
        """Return one prediction per sample, shape ``(batch,)``."""
        out = self.resnet(x)
        return out.squeeze(1)

# Inside Model
class ResNet50Binary(nn.Module):
    """ResNet-50 with a replaced stem and a single-logit head for binary classification.

    Args:
        in_channels: number of input channels accepted by the replacement ``conv1``.
    """

    def __init__(self, in_channels=5):
        super().__init__()
        net = models.resnet50(weights=None)
        net.conv1 = nn.Conv2d(
            in_channels, 64,
            kernel_size=7, stride=2, padding=3, bias=False,
        )
        net.fc = nn.Linear(net.fc.in_features, 1)
        self.resnet = net

    def forward(self, x):
        """Return raw logits of shape ``(batch,)`` (no sigmoid applied)."""
        logits = self.resnet(x)
        return logits.squeeze(1)

## 3. Data Parsing Functions

In [None]:
# Question tokenizer: a literal ``<mask>`` placeholder, a run of word characters
# (letters incl. non-ASCII, digits, underscore), or a single punctuation/symbol
# character. Whitespace is discarded. Using ``\w+`` keeps numbers such as "2"
# in "2 meters", which a letters-only pattern silently dropped.
_QUESTION_TOKEN_RE = re.compile(r'<mask>|\w+|[^\s\w]')


def replace_masks_with_region(original_question: str) -> str:
    r"""Replace each ``<mask>`` placeholder with a numbered ``[Region k]`` tag.

    The ``'<image>\n'`` marker is removed first. Masks are numbered in order of
    appearance, starting at 0. The text is re-tokenized, so tokens in the result
    are separated by single spaces and punctuation becomes its own token
    (e.g. ``"<mask>?"`` -> ``"[Region 0] ?"``).

    Args:
        original_question: raw question text containing ``<mask>`` placeholders.

    Returns:
        The space-joined token string with masks replaced by region tags.
    """
    question = original_question.replace('<image>\n', '')
    tokens = []
    mask_count = 0
    for token in _QUESTION_TOKEN_RE.findall(question):
        if token == '<mask>':
            token = f"[Region {mask_count}]"
            mask_count += 1
        tokens.append(token)
    return ' '.join(tokens)

def parse_distance_data(data):
    """Select 'distance' samples whose answer references exactly two mask regions.

    For each item with ``category == 'distance'``, the first ``'gpt'`` turn is
    scanned for ``[Region k]`` tags. The distinct indices (sorted ascending) that
    are valid positions in ``item['rle']`` select the masks, and only items that
    end up with exactly two masks are kept.

    Args:
        data: list of sample dicts as loaded from ``val.json``.

    Returns:
        List of dicts with keys ``'id'``, ``'image'``, ``'rle'`` (two RLE dicts)
        and ``'normalized_answer'``. Items raising during parsing are skipped and
        counted.
    """
    print("Parsing Distance Data...")
    # .get(): an item without 'category' is simply not a distance sample, rather
    # than a KeyError that aborts the whole parse.
    candidates = [q for q in data if q.get('category') == 'distance']
    refined_data = []
    skipped = 0

    for item in tqdm(candidates, desc="Processing Distance Data"):
        try:
            gpt_response = next(
                (conv['value'] for conv in item['conversations'] if conv['from'] == 'gpt'),
                None,
            )
            if gpt_response is None:
                continue

            region_indices = sorted({int(x) for x in re.findall(r'\[Region (\d+)\]', gpt_response)})
            rles = item['rle']
            filtered_rle = [rles[idx] for idx in region_indices if 0 <= idx < len(rles)]
            if len(filtered_rle) != 2:
                continue

            refined_data.append({
                'id': item['id'],
                'image': item['image'],
                'rle': filtered_rle,
                'normalized_answer': item['normalized_answer'],
            })
        except Exception:
            # Best-effort parse of external JSON: skip malformed items, but count them.
            skipped += 1

    print(f"Found {len(refined_data)} distance samples.")
    if skipped:
        print(f"Skipped {skipped} malformed distance items.")
    return refined_data

# Phrase separating the inside regions (before it) from the buffer region (after it).
_BUFFER_PHRASE = 'inside the buffer region'
_REGION_ID_RE = re.compile(r'\[Region (\d+)\]')


def _mentions_inside(item):
    """Return True if the item has a second conversation turn whose text contains 'inside'."""
    conversations = item.get('conversations') or []
    return len(conversations) > 1 and 'inside' in conversations[1]['value']


def _inside_pairs_for_item(item):
    """Build every labelled (buffer, object) pair for one sample.

    Returns an empty list when the answer does not have exactly one sentence
    containing 'inside' of the form "<regions> inside the buffer region <region>".
    Raises (KeyError, IndexError, ValueError, ...) on malformed items; the caller
    skips those. The full list is built before returning, so an out-of-range
    region index never leaves a partial set of pairs behind.
    """
    question = item['conversations'][0]['value']
    response = item['conversations'][1]['value']

    inside_sentences = [s for s in response.split('.') if 'inside' in s]
    if len(inside_sentences) != 1 or _BUFFER_PHRASE not in inside_sentences[0]:
        return []

    # Raises ValueError if the phrase occurs more than once in the sentence.
    before, after = inside_sentences[0].split(_BUFFER_PHRASE)
    inside_ids = [int(r) for r in _REGION_ID_RE.findall(before)]
    after_ids = _REGION_ID_RE.findall(after)
    if len(after_ids) != 1 or not inside_ids:
        return []
    buffer_id = int(after_ids[0])

    # Every other region in the question (other than the buffer) is an outside example.
    all_ids = [int(r) for r in _REGION_ID_RE.findall(replace_masks_with_region(question))]
    outside_ids = [r for r in all_ids if r not in inside_ids and r != buffer_id]

    image = item.get('image')
    rle = item.get('rle')
    buffer_rle = rle[buffer_id]
    labelled = [(1, r) for r in inside_ids] + [(0, r) for r in outside_ids]
    return [
        {'image': image, 'inside': label, 'buffer_rle': buffer_rle, 'obj_rle': rle[r]}
        for label, r in labelled
    ]


def parse_inside_data(data):
    """Turn "inside the buffer region" answers into binary inside/outside pairs.

    Candidate items have a second conversation turn mentioning 'inside'. For each,
    regions named before the phrase are labelled inside (1) relative to the single
    buffer region named after it; every other region in the question is labelled
    outside (0).

    Args:
        data: list of sample dicts as loaded from ``val.json``.

    Returns:
        List of dicts with keys ``'image'``, ``'inside'``, ``'buffer_rle'`` and
        ``'obj_rle'``. Items that do not match the expected format are skipped.
    """
    print("Parsing Inside Data...")
    potential_items = [item for item in data if _mentions_inside(item)]
    processed_data = []
    skipped = 0

    for item in tqdm(potential_items, desc="Processing Inside Data"):
        try:
            processed_data.extend(_inside_pairs_for_item(item))
        except Exception:
            # Best-effort parse of external JSON: skip malformed items, but count them.
            skipped += 1

    print(f"Found {len(processed_data)} inside/outside pairs.")
    if skipped:
        print(f"Skipped {skipped} malformed inside items.")
    return processed_data

## 4. S3-Aware Dataset Classes with Error Skipping

In [None]:
class S3DistanceDataset(Dataset):
    """Distance-regression samples whose RGB images are streamed from S3.

    Each item is ``(input_tensor, distance)``:
    - ``input_tensor`` is a 5 x H x W float32 tensor: RGB scaled to [0, 1],
      followed by the two region masks, all resized to ``resize``.
    - ``distance`` is ``item['normalized_answer']`` as a float32 scalar tensor.

    ``__getitem__`` returns ``None`` if the image or masks cannot be loaded, so
    pair this dataset with ``collate_skip_none``.

    Args:
        samples: dicts with ``'image'``, ``'rle'`` (the first two RLE dicts are used)
            and ``'normalized_answer'``.
        bucket: S3 bucket holding the images.
        image_prefix: key prefix; the image key is ``f"{image_prefix}/{item['image']}"``.
        depth_prefix: stored on the instance but not read by this class.
        transform: optional callable applied to the stacked C x H x W tensor.
        resize: target size passed to ``TF.resize`` for the image and both masks.
    """

    def __init__(self, samples, bucket, image_prefix, depth_prefix, transform=None, resize=(360, 640)):
        self.samples = samples
        self.bucket = bucket
        self.image_prefix = image_prefix
        self.depth_prefix = depth_prefix
        self.resize = resize
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def decode_mask(self, rle):
        """Decode a COCO RLE dict into a float32 mask array without modifying ``rle``."""
        counts = rle['counts']
        if isinstance(counts, str):
            # Encode into a shallow copy. The RLE dicts are shared with the parsed JSON
            # (and any other dataset built from it), and writing bytes back into them
            # would leave that data no longer JSON-serializable.
            rle = {**rle, 'counts': counts.encode('utf-8')}
        return mask_utils.decode(rle).astype(np.float32)

    def _resized_mask(self, rle):
        """Decode ``rle`` and nearest-neighbour resize it; returns an H x W x 1 float32 array."""
        mask = Image.fromarray(self.decode_mask(rle))
        mask = TF.resize(mask, self.resize, interpolation=transforms.InterpolationMode.NEAREST)
        return np.array(mask).astype(np.float32)[..., None]

    def __getitem__(self, idx):
        item = self.samples[idx]

        img_key = f"{self.image_prefix}/{item['image']}"
        try:
            rgb = load_s3_image(self.bucket, img_key).convert('RGB')
            rgb = TF.resize(rgb, self.resize)
            rgb = np.array(rgb).astype(np.float32) / 255.0  # H x W x 3 in [0, 1]
        except Exception:
            # Unloadable images are skipped silently; collate_skip_none drops the None.
            return None

        try:
            mask_a = self._resized_mask(item['rle'][0])
            mask_b = self._resized_mask(item['rle'][1])
        except Exception as e:
            print(f"Skipping item {idx}, error masks: {e}")
            return None

        # Input tensor: 3 RGB channels + 2 mask channels = 5.
        stacked = np.concatenate([rgb, mask_a, mask_b], axis=-1)  # H x W x C
        input_tensor = torch.tensor(stacked).permute(2, 0, 1)  # C x H x W

        distance = torch.tensor(item['normalized_answer'], dtype=torch.float32)

        if self.transform:
            input_tensor = self.transform(input_tensor)

        return input_tensor, distance

class S3InsideDataset(Dataset):
    """Inside/outside classification samples whose RGB images are streamed from S3.

    Each item is ``(x, y)``:
    - ``x`` is a 5 x H x W float tensor: the RGB image followed by the buffer-region
      mask and the object mask, all bilinearly resized to ``resize`` and scaled to [0, 1].
    - ``y`` is ``item['inside']`` as a float32 scalar.

    ``__getitem__`` returns ``None`` on any loading or decoding error, so pair this
    dataset with ``collate_skip_none``.

    Args:
        samples: dicts with ``'image'``, ``'buffer_rle'``, ``'obj_rle'`` and ``'inside'``.
        bucket: S3 bucket holding the images.
        image_prefix: key prefix; the image key is ``f"{image_prefix}/{item['image']}"``.
        resize: target size for ``transforms.Resize``.
    """

    def __init__(self, samples, bucket, image_prefix, resize=(360, 640)):
        self.samples = samples
        self.bucket = bucket
        self.image_prefix = image_prefix
        self.resize = resize
        self.to_tensor = transforms.ToTensor()
        self.resize_tf = transforms.Resize(resize, interpolation=transforms.InterpolationMode.BILINEAR)

    def __len__(self):
        return len(self.samples)

    def decode_rle(self, rle):
        """Decode a COCO RLE dict into a float32 mask array without modifying ``rle``."""
        counts = rle['counts']
        if isinstance(counts, str):
            # Encode into a shallow copy. The RLE dicts are shared with the parsed JSON
            # (and any other dataset built from it), and writing bytes back into them
            # would leave that data no longer JSON-serializable.
            rle = {**rle, 'counts': counts.encode('utf-8')}
        return mask_utils.decode(rle).astype(np.float32)

    def _mask_tensor(self, rle):
        """Decode ``rle`` to a 0/255 uint8 image, resize it and return a 1 x H x W tensor."""
        mask = Image.fromarray((self.decode_rle(rle) * 255).astype(np.uint8))
        # Masks use the same bilinear resize as the image, so edge pixels can be fractional.
        return self.to_tensor(self.resize_tf(mask))

    def __getitem__(self, idx):
        item = self.samples[idx]
        img_key = f"{self.image_prefix}/{item['image']}"

        try:
            img = load_s3_image(self.bucket, img_key).convert('RGB')
            img = self.to_tensor(self.resize_tf(img))
            buffer_mask = self._mask_tensor(item['buffer_rle'])
            obj_mask = self._mask_tensor(item['obj_rle'])

            x = torch.cat([img, buffer_mask, obj_mask], dim=0)
            y = torch.tensor(item['inside'], dtype=torch.float32)
        except Exception:
            # Any failure (missing image, bad RLE, ...) skips the sample;
            # collate_skip_none drops the None.
            return None
        return x, y

# Custom Collate to skip None
def collate_skip_none(batch):
    """Collate ``batch`` after dropping ``None`` samples.

    The datasets here return ``None`` for samples they failed to load.

    Returns:
        ``None`` if no samples remain, otherwise ``default_collate`` of the rest.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return default_collate(kept)

## 5. Main Execution
1. Load `val.json` from S3.
2. Parse data for each task.
3. Download checkpoints locally.
4. Run evaluation (samples whose image or masks cannot be loaded are skipped).

In [None]:
# 1. Load Data
print("Loading Validation Set from S3...")
val_data = load_s3_json(S3_BUCKET, VAL_JSON_KEY)

dist_samples = parse_distance_data(val_data)
inside_samples = parse_inside_data(val_data)

# 2. Prepare Datasets
# Both loaders share the same settings: fixed order, in-process loading, and a
# collate function that drops samples whose image/masks failed to load.
_loader_kwargs = dict(batch_size=32, shuffle=False, num_workers=0, collate_fn=collate_skip_none)

dist_dataset = S3DistanceDataset(
    dist_samples,
    S3_BUCKET,
    VAL_IMAGE_PREFIX,
    VAL_DEPTH_PREFIX,
)
dist_loader = DataLoader(dist_dataset, **_loader_kwargs)

inside_dataset = S3InsideDataset(
    inside_samples,
    S3_BUCKET,
    VAL_IMAGE_PREFIX,
)
inside_loader = DataLoader(inside_dataset, **_loader_kwargs)

print("Datasets ready.")

In [None]:
def run_eval_distance(ckpt_s3_key, label):
    """Evaluate a distance-regression checkpoint on ``dist_loader`` and print MAE/RMSE.

    The checkpoint is fetched from ``S3_BUCKET`` into ``/tmp`` (reused if already
    present), loaded as a state_dict into a fresh ``ResNetDistanceRegressor`` and
    run in eval mode without gradients. Batches that ``collate_skip_none`` reduced
    to ``None`` are skipped.

    Args:
        ckpt_s3_key: S3 key of a ``ResNetDistanceRegressor`` state_dict.
        label: name shown in the progress bar and the results header.
    """
    # Cache under a name derived from the full key, not just its basename. Two keys
    # sharing a basename (e.g. 'a/ckpt/epoch_4.pth' and 'b/ckpt/epoch_4.pth') would
    # otherwise hit the same cached file and silently evaluate the wrong weights.
    local_ckpt = os.path.join('/tmp', ckpt_s3_key.replace('/', '__'))
    download_s3_checkpoint(S3_BUCKET, ckpt_s3_key, local_ckpt)

    model = ResNetDistanceRegressor(input_channels=5, backbone='resnet50')
    model.load_state_dict(torch.load(local_ckpt, map_location=DEVICE))
    model.to(DEVICE).eval()

    total_abs_error = 0.0
    total_sq_error = 0.0
    count = 0

    with torch.no_grad():
        for batch in tqdm(dist_loader, desc=f"Eval {label}"):
            if batch is None:
                continue  # every sample in this batch failed to load

            inputs, targets = batch
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            errors = model(inputs) - targets

            total_abs_error += errors.abs().sum().item()
            total_sq_error += (errors ** 2).sum().item()
            count += targets.size(0)

    print(f"\nResults [{label}]:")
    if count > 0:
        print(f"Evaluated on {count} samples.")
        # NOTE(review): the 'm' unit assumes normalized_answer is in metres — confirm.
        print(f"MAE: {total_abs_error/count:.4f} m")
        print(f"RMSE: {np.sqrt(total_sq_error/count):.4f} m")
    else:
        print("No valid samples found.")

def run_eval_inside(ckpt_s3_key=None, threshold=0.5):
    """Evaluate an inside/outside checkpoint on ``inside_loader`` and print accuracy.

    The checkpoint is fetched from ``S3_BUCKET`` into ``/tmp`` (reused if already
    present) and loaded as a state_dict into a fresh ``ResNet50Binary``. A sample
    is predicted "inside" when ``sigmoid(logit) > threshold``. Batches that
    ``collate_skip_none`` reduced to ``None`` are skipped.

    Args:
        ckpt_s3_key: S3 key of the checkpoint; defaults to ``INSIDE_CKPT_S3_KEY``.
        threshold: probability cut-off for the positive (inside) class.
    """
    if ckpt_s3_key is None:
        ckpt_s3_key = INSIDE_CKPT_S3_KEY
    # Cache under a name derived from the full key, not just its basename, so that
    # checkpoints sharing a filename under different prefixes never collide.
    local_ckpt = os.path.join('/tmp', ckpt_s3_key.replace('/', '__'))
    download_s3_checkpoint(S3_BUCKET, ckpt_s3_key, local_ckpt)

    model = ResNet50Binary(in_channels=5)
    model.load_state_dict(torch.load(local_ckpt, map_location=DEVICE))
    model.to(DEVICE).eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(inside_loader, desc="Eval Inside"):
            if batch is None:
                continue  # every sample in this batch failed to load

            inputs, labels = batch
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            preds = (torch.sigmoid(model(inputs)) > threshold).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f"\nInside Eval Results:")
    if total > 0:
        print(f"Evaluated on {total} samples.")
        print(f"Inside Acc: {correct/total*100:.2f}%")
    else:
        print("No valid samples found.")

In [None]:
# Execute: evaluate both distance checkpoints, then the inside/outside classifier.
for _ckpt_key, _run_label in (
    (DIST_CKPT_S3_KEY_EPOCH5, "Distance (Epoch 5)"),
    (DIST_CKPT_S3_KEY_3M, "Distance (3m Epoch 6)"),
):
    run_eval_distance(_ckpt_key, _run_label)
run_eval_inside()