In [1]:
%reload_ext autoreload
%autoreload 2

import os
import cv2
import sys
import glob
import json
import time
import click
import random
import logging
import numpy as np
import pandas as pd
import ruamel.yaml as yaml
import rasterio
from rasterio.windows import Window
from rasterio.transform import from_bounds
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from tqdm import tqdm
from shapely.geometry import shape, mapping, box, Polygon
from shapely.ops import unary_union

import albumentations as A
from albumentations.pytorch import ToTensorV2

from ship_detector.scripts.train_vit import ViTShipClassifier
from ship_detector.scripts.inference_pipeline import PipelineConfig
from ship_detector.scripts.inference_pipeline import ImageDataset

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
vit_checkpoint = 'outputs/checkpoints/vit-epoch=03-val_acc=0.967.ckpt'
vit_config = 'configs/vit.yaml'
patch_size = 224
overlap = 32
batch_size = 128
confidence_threshold = 0.5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
us3_s3 = False
output_format = ['mask', 'geojson']

In [6]:
with open(vit_config, 'r') as f:
    vit_config = yaml.YAML(typ='rt').load(f)


In [7]:
model = ViTShipClassifier(vit_config)

2025-09-01 22:37:01,355 - timm.models._builder - INFO - Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-09-01 22:37:01,494 - timm.models._hub - INFO - [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-09-01 22:37:01,602 - timm.models._builder - INFO - Missing keys (head.weight, head.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


In [8]:
checkpoint = torch.load(vit_checkpoint, map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [9]:
model = model.to(device)
model.eval()

ViTShipClassifier(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
       

In [10]:
vit_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [11]:
img_paths = glob.glob('data/airbus-ship-detection/test_v2/*.jpg')

In [12]:
# randomly pick an image path
image_path = random.choice(img_paths)
image_path

'data/airbus-ship-detection/test_v2\\eddb5b066.jpg'

In [13]:
def load_image( image_path: str) -> np.ndarray:
    ext = Path(image_path).suffix.lower()
    if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']:
        image = Image.open(image_path).convert('RGB')
        image = np.array(image)
    else:
        image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

def tile_image(config: Dict[str, Any], image: np.ndarray) -> Tuple[List[Dict], Dict]:
    patches = []
    h, w = image.shape[:2]
    
    metadata = {
        'width': w,
        'height': h,
        'channels': image.shape[2] if len(image.shape) > 2 else 1,
        'dtype': str(image.dtype)
    }
    
    stride = config.patch_size - config.overlap
    patch_id = 0
    
    for row in range(0, h - config.patch_size + 1, stride):
        for col in range(0, w - config.patch_size + 1, stride):
            patch_data = image[row:row+config.patch_size, col:col+config.patch_size]
            
            if patch_data.std() < 1.0:
                continue
            
            patches.append({
                'id': patch_id,
                'image': patch_data.copy(),
                'row': row,
                'col': col,
                'coords': (row, col, row+config.patch_size, col+config.patch_size)
            })
            patch_id += 1
    
    if w % stride != 0 and w > config.patch_size:
        col = w - config.patch_size
        for row in range(0, h - config.patch_size + 1, stride):
            patch_data = image[row:row+config.patch_size, col:col+config.patch_size]
            
            if patch_data.std() >= 1.0:
                patches.append({
                    'id': patch_id,
                    'image': patch_data.copy(),
                    'row': row,
                    'col': col,
                    'coords': (row, col, row+config.patch_size, col+config.patch_size)
                })
                patch_id += 1
    
    if h % stride != 0 and h > config.patch_size:
        row = h - config.patch_size
        for col in range(0, w - config.patch_size + 1, stride):
            patch_data = image[row:row+config.patch_size, col:col+config.patch_size]
            
            if patch_data.std() >= 1.0:
                patches.append({
                    'id': patch_id,
                    'image': patch_data.copy(),
                    'row': row,
                    'col': col,
                    'coords': (row, col, row+config.patch_size, col+config.patch_size)
                })
                patch_id += 1
    
    if w % stride != 0 and h % stride != 0 and w > config.patch_size and h > config.patch_size:
        row = h - config.patch_size
        col = w - config.patch_size
        patch_data = image[row:row+config.patch_size, col:col+config.patch_size]
        
        if patch_data.std() >= 1.0:
            patches.append({
                'id': patch_id,
                'image': patch_data.copy(),
                'row': row,
                'col': col,
                'coords': (row, col, row+config.patch_size, col+config.patch_size)
            })
    return patches, metadata

def classify_patches(patches: List[Dict], config: Dict[str, Any], model: ViTShipClassifier, transforms: transforms.Compose) -> np.ndarray:
    if not patches:
        return np.array([])
    
    dataset = ImageDataset(patches, transform=transforms)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    
    probabilities = []
    
    with torch.no_grad():
        for batch, indices in tqdm(dataloader, desc='Classifying patches'):
            batch = batch.to(device)
            outputs = model(batch).squeeze()
            
            if outputs.dim() == 0:
                outputs = outputs.unsqueeze(0)
            
            probs = torch.sigmoid(outputs)
            probabilities.extend(probs.cpu().numpy())
            
    return np.array(probabilities)

In [None]:
def process_image(model: ViTShipClassifier, transforms: transforms.Compose, config: Dict[str, Any], image_path: str, output_dir: str) -> Dict:
    start_time = time.time()
    image_name = Path(image_path).stem
    
    image = load_image(image_path)
    patches, metadata = tile_image(config, image)
    
    if not patches:
        return {
            'image': image_name,
            'status': 'no_valid_patches',
            'processing_time': time.time() - start_time
        }
        
    probabilities = classify_patches(patches, config, model, transforms)
    
    ship_indices = np.where(probabilities >= config.confidence_threshold)[0]
    
    if config.save_patch_predictions:
        patch_results = pd.DataFrame({
            'patch_id': [p['id'] for p in patches],
            'row': [p['row'] for p in patches],
            'col': [p['col'] for p in patches],
            'probability': probabilities,
            'has_ship': probabilities >= config.confidence_threshold
        })
        patch_results.to_csv(
            os.path.join(output_dir, f"{image_name}_patches.csv"),
            index=False
        )
    return {
        'image': image_name,
        'status': 'processed',
        'processing_time': time.time() - start_time
    }