In [9]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import sqlite3



In [None]:
class ShipResNet(nn.Module):
    """ResNet-based regressor for two-channel (VV + VH) image chips.

    A learnable 1x1 convolution maps the 2 input channels to the 3 channels a
    torchvision backbone expects; the backbone's last child layer (the
    classifier, ``fc`` for ResNets) is dropped and replaced by a linear
    regression head.

    Args:
        base_model_fn: torchvision model constructor (e.g. ``models.resnet18``).
            It must accept a ``pretrained`` keyword, and its last child must be
            an ``nn.Linear`` classifier exposing ``in_features``.
        pretrained: forwarded to ``base_model_fn``. NOTE(review): torchvision
            >= 0.13 deprecates ``pretrained`` in favour of ``weights=``; confirm
            the installed version before relying on it.
        num_outputs: number of values emitted by the regression head. Defaults
            to 8 (the previous hard-coded size); it must match the length of
            the training target vector.
    """

    def __init__(self, base_model_fn=models.resnet18, pretrained=True, num_outputs=8):
        super().__init__()

        # Learnable projection: 2 channels (VV+VH) -> 3 channels for the backbone
        self.input_proj = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=1)

        base_model = base_model_fn(pretrained=pretrained)
        layers = list(base_model.children())
        # The last child is the classifier being removed. Read its input width
        # instead of hard-coding 512, so deeper variants (e.g. resnet50 with
        # 2048 features) also work.
        in_features = layers[-1].in_features
        self.feature_extractor = nn.Sequential(*layers[:-1])

        # Regression head
        self.head = nn.Linear(in_features=in_features, out_features=num_outputs)

    def forward(self, x):
        """Map a (B, 2, H, W) batch to (B, num_outputs) regression outputs."""
        x = self.input_proj(x)         # (B, 2, H, W) -> (B, 3, H, W)
        x = self.feature_extractor(x)  # (B, C, 1, 1) for ResNet backbones
        x = torch.flatten(x, 1)        # (B, C)
        return self.head(x)

In [4]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class ShipDataset(Dataset):
    """Pairs two-channel image chips with per-row regression targets.

    Each row of ``df`` must provide ``image_path`` (a ``.npy`` file holding a
    3-D array, expected ``(2, H, W)``) and the columns ``sog``, ``length``,
    ``width``, ``draft`` and ``heading`` (degrees).

    ``__getitem__`` returns ``(image, target)``:
      * ``image``: float32 tensor; each channel has NaNs set to 0, is clipped to
        ``clip_percentiles`` and is rescaled to approximately [0, 1].
      * ``target``: float32 tensor
        ``[sog, length, width, draft, cos(heading), sin(heading)]``.
        Missing scalar values become 0.0, and a missing heading becomes
        (0, 0), so that NaNs never reach the loss.
        NOTE(review): zero-filling is crude for sog/length; consider
        masking missing targets in the loss instead.
    """

    def __init__(self, df, clip_percentiles=(2, 98)):
        self.df = df.reset_index(drop=True)
        self.clip_percentiles = clip_percentiles

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _value_or_zero(value):
        """Return ``value`` as a float, or 0.0 when it is missing (None/NaN)."""
        return float(value) if pd.notnull(value) else 0.0

    def _load_image(self, path):
        """Load a ``.npy`` chip and rescale each channel to roughly [0, 1].

        Raises:
            ValueError: if the stored array is not 3-D.
        """
        img = np.load(path)  # expected shape (2, H, W)
        if img.ndim != 3:
            raise ValueError(f"Unexpected shape {img.shape} in {path}")

        img = img.astype(np.float32)
        for c in range(img.shape[0]):
            # NaN -> 0 so that np.percentile returns finite bounds
            band = np.nan_to_num(img[c], nan=0.0)
            vmin, vmax = np.percentile(band, self.clip_percentiles)
            band = np.clip(band, vmin, vmax)
            # The epsilon avoids division by zero on constant bands
            img[c] = (band - vmin) / (vmax - vmin + 1e-5)
        return img

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_tensor = torch.from_numpy(self._load_image(row['image_path']))

        heading_deg = row['heading']
        if pd.notnull(heading_deg):
            heading_rad = np.deg2rad(float(heading_deg))
            heading_x = float(np.cos(heading_rad))
            heading_y = float(np.sin(heading_rad))
        else:
            # Encode "unknown heading" as the zero vector. (cos 0, sin 0) = (1, 0)
            # would be indistinguishable from a genuine 0-degree heading.
            heading_x, heading_y = 0.0, 0.0

        target = torch.tensor([
            self._value_or_zero(row['sog']),
            self._value_or_zero(row['length']),
            self._value_or_zero(row['width']),
            self._value_or_zero(row['draft']),
            heading_x,
            heading_y,
        ], dtype=torch.float32)

        return image_tensor, target


In [6]:
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from tqdm import tqdm

def _batch_loss(model, criterion, imgs, targets, device):
    """Move one batch to ``device``, run ``model`` and return ``(loss, batch_size)``.

    Raises:
        ValueError: if predictions and targets differ in shape. Without this
            check, MSELoss would broadcast silently or fail with an opaque error.
    """
    imgs = imgs.to(device)
    targets = targets.to(device)
    preds = model(imgs)
    if preds.shape != targets.shape:
        raise ValueError(
            f"Model output shape {tuple(preds.shape)} does not match target shape "
            f"{tuple(targets.shape)}; check the model's number of outputs."
        )
    return criterion(preds, targets), imgs.size(0)


def train_model(
    model, dataset, num_epochs=10, batch_size=32, lr=1e-4, val_split=0.1, device="cuda"
):
    """Train ``model`` on ``dataset`` with MSE loss and Adam, printing per-epoch losses.

    A fraction ``val_split`` of ``dataset`` is held out through a seeded
    ``random_split`` and evaluated after every training epoch. With
    ``val_split == 0``, the validation step is skipped.

    Args:
        model: ``nn.Module`` whose outputs must match the target shape.
        dataset: map-style dataset yielding ``(inputs, targets)`` pairs.
        num_epochs: number of passes over the training split.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        val_split: fraction of samples held out for validation.
        device: device the model and batches are moved to.

    Returns:
        dict with per-epoch lists ``"train_loss"`` and ``"val_loss"``. Entries in
        ``"val_loss"`` are ``None`` when there is no validation split.

    Raises:
        ValueError: if the split leaves no training samples, or if model
            outputs and targets differ in shape.
    """
    torch.manual_seed(0)  # reproducible split and shuffling

    n = len(dataset)
    val_size = int(val_split * n)
    train_size = n - val_size
    if train_size <= 0:
        raise ValueError(f"val_split={val_split} leaves no training samples out of {n}")
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = model.to(device)
    # NOTE(review): if target columns have very different scales, a plain MSE
    # will be dominated by the largest one; consider normalizing targets.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for imgs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            loss, count = _batch_loss(model, criterion, imgs, targets, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample
            total_loss += loss.item() * count

        avg_loss = total_loss / train_size
        history["train_loss"].append(avg_loss)
        print(f"Train Loss: {avg_loss:.4f}")

        if val_size == 0:
            # No held-out samples: avoid dividing by zero below
            history["val_loss"].append(None)
            print("Val Loss: n/a (no validation split)")
            continue

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, targets in val_loader:
                loss, count = _batch_loss(model, criterion, imgs, targets, device)
                val_loss += loss.item() * count

        avg_val = val_loss / val_size
        history["val_loss"].append(avg_val)
        print(f"Val Loss: {avg_val:.4f}")

    return history


In [10]:

import torch

# Load the AIS table into a DataFrame
import pandas as pd

DB_PATH = "../../data/ais.db"

# sqlite3's `with conn:` only manages transactions and does not close the
# connection, so close it explicitly even if the query raises.
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query("SELECT * FROM ais", conn)
finally:
    conn.close()

# Optional: subsample for a quick first test
# df = df.sample(frac=0.1, random_state=0).reset_index(drop=True)

# Create dataset
dataset = ShipDataset(df)

# The regression head must emit exactly as many values as ShipDataset's
# target vector (currently 6: sog, length, width, draft, cos/sin heading).
# Derive the size from one sample instead of hard-coding it.
num_targets = dataset[0][1].numel()
model = ShipResNet()
model.head = nn.Linear(model.head.in_features, num_targets)

# Kick off training
train_model(
    model=model,
    dataset=dataset,
    num_epochs=10,
    batch_size=32,
    lr=1e-4,
    val_split=0.1,
    device="cuda" if torch.cuda.is_available() else "cpu"
)
