In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import sqlite3
import pandas as pd
import random
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import optim
from tqdm import tqdm

import util

# class definitions

In [None]:
class Classifier(nn.Module):
    """ResNet-34 backbone with a 2->3 channel input adapter and a 9-output linear head.

    The head is a plain ``nn.Linear``; no activation is applied to its output
    inside this module.
    """

    def __init__(self):
        super().__init__()

        #input to spread 2 channels out to 3
        self.input = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=1)

        #base vision model we're wrapping around with output layer removed
        # NOTE(review): newer torchvision releases prefer `weights=` over
        # `pretrained=`; kept as-is to stay compatible with the installed version.
        base_model = models.resnet34(pretrained=True)
        self.core = nn.Sequential(*list(base_model.children())[:-1])

        #custom output layer, with dimension dynamically calculated
        self.output = nn.Linear(in_features=self._probe_feature_dim(), out_features=9)

    def _probe_feature_dim(self):
        """Return the flattened per-sample size of ``self.core``'s output for a 3x600x600 input.

        The probe runs in eval mode under ``torch.no_grad()``. A training-mode
        forward pass would update BatchNorm running statistics with the probe
        input and build an autograd graph that is never used.
        """
        was_training = self.core.training
        self.core.eval()
        try:
            with torch.no_grad():
                probe_out = self.core(torch.zeros(1, 3, 600, 600))
        finally:
            # put the backbone back into whatever mode it was in before the probe
            self.core.train(was_training)
        return probe_out.view(1, -1).shape[1]

    def forward(self, x):
        """Map a batch of 2-channel images to a ``(batch, 9)`` tensor.

        Args:
            x: input tensor whose channel dimension is 2 (required by the
                1x1 adapter convolution); dim 0 is treated as the batch.

        Returns:
            Tensor with 9 values per sample.
        """
        x = self.input(x)
        x = self.core(x)
        # collapse everything except the batch dimension before the linear head
        x = torch.flatten(x, start_dim=1)
        x = self.output(x)
        return x

In [None]:
class ShipDataset(Dataset):
    """Dataset that pairs each row's image file with a 9-value float32 target."""

    def __init__(self, df):
        """Copy ``df`` and precompute the derived columns used as targets.

        ``heading`` and ``cog`` are converted from degrees to radians and
        expanded into cosine/sine columns. Afterwards every NaN left in the
        frame is replaced by 0. The caller's frame is not modified.
        """
        frame = df.copy()

        #bulk preprocessing: angle in degrees -> (cos, sin) column pair
        for angle_col, cos_col, sin_col in (
            ('heading', 'heading_x', 'heading_y'),
            ('cog', 'cog_x', 'cog_y'),
        ):
            theta = np.deg2rad(frame[angle_col])
            frame[cos_col] = np.cos(theta)
            frame[sin_col] = np.sin(theta)

        self.df = frame.fillna(0)

    def __len__(self):
        """Number of rows in the preprocessed frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for the row at position ``idx``."""
        record = self.df.iloc[idx]

        #prepare image: load from disk -> per-band normalize -> tensor -> pad
        pixels = self.normalize(np.load(record['image_path']))
        image_tensor = self.pad_to_600(torch.from_numpy(pixels))

        #prepare target
        target_fields = (
            'sog',
            'heading_x',
            'heading_y',
            'cog_x',
            'cog_y',
            'length',
            'width',
            'draft',
            'cargo',
        )
        target = torch.tensor(
            [record[field] for field in target_fields],
            dtype=torch.float32,
        )

        return image_tensor, target

    @staticmethod
    def normalize(arr):
        """Return a float32 copy of ``arr`` with each band percentile-clipped and rescaled.

        For every index along the first axis, NaNs are replaced by 0, values
        are clipped to the band's 2nd/98th percentiles, and the result is
        shifted and divided by the clipped range (plus 1e-5 to avoid a zero
        divisor).
        """
        clip_percentiles = (2, 98)
        scaled = arr.astype(np.float32)
        for band_idx in range(scaled.shape[0]):
            band = np.nan_to_num(scaled[band_idx], nan=0.0)
            vmin, vmax = np.percentile(band, clip_percentiles)
            scaled[band_idx] = (np.clip(band, vmin, vmax) - vmin) / (vmax - vmin + 1e-5)
        return scaled

    @staticmethod
    def pad_to_600(tensor):
        """Pad the last two dims of a 3-D tensor so each becomes 600.

        The padding is split evenly between both sides; when the amount is
        odd, the extra row/column goes to the bottom/right.
        """
        height, width = tensor.shape[1:]
        extra_h, extra_w = 600 - height, 600 - width
        top, left = extra_h // 2, extra_w // 2

        # F.pad lists the last dimension first: (left, right, top, bottom)
        return F.pad(tensor, (left, extra_w - left, top, extra_h - top))

In [None]:
def train_model(
    model,
    dataset,
    num_epochs=10,
    batch_size=32,
    lr=1e-4,
    val_split=0.1,
    device="cpu"
):
    """Train ``model`` on ``dataset`` with Adam and MSE loss, printing per-epoch losses.

    The dataset is randomly split into train/validation subsets after seeding
    torch's global RNG with 33. Each epoch prints the sample-weighted mean
    training loss and, when the validation subset is non-empty, the
    sample-weighted mean validation loss.

    Args:
        model: module to train; it is moved to ``device``.
        dataset: map-style dataset yielding ``(input, target)`` pairs.
        num_epochs: number of passes over the training subset.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        val_split: fraction of samples held out for validation; the held-out
            count is ``int(val_split * len(dataset))`` and may be 0.
        device: device string/object the model and batches are moved to.

    Returns:
        None. The model is updated in place.

    Raises:
        ValueError: if the split leaves no samples for training.
    """
    torch.manual_seed(33)

    # Split dataset
    n = len(dataset)
    val_size = int(val_split * n)
    train_size = n - val_size
    if train_size <= 0:
        raise ValueError(
            f"no samples left for training (len(dataset)={n}, val_split={val_split})"
        )
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # int() truncation makes the validation split empty for small datasets;
    # averaging over it would divide by zero, so validation is skipped instead.
    has_validation = val_size > 0
    if not has_validation:
        print("Validation split is empty; skipping validation.")

    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for imgs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            imgs = imgs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            preds = model(imgs)
            loss = criterion(preds, targets)
            loss.backward()
            optimizer.step()

            # weight by batch size so the short final batch is averaged correctly
            total_loss += loss.item() * imgs.size(0)

        avg_loss = total_loss / train_size
        print(f"Train Loss: {avg_loss:.4f}")

        if not has_validation:
            continue

        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for imgs, targets in val_loader:
                imgs = imgs.to(device)
                targets = targets.to(device)
                preds = model(imgs)
                val_loss += criterion(preds, targets).item() * imgs.size(0)

        avg_val = val_loss / val_size
        print(f"Val Loss: {avg_val:.4f}")


# train loop

In [None]:
# Read the whole `ais` table into memory, then release the SQLite handle.
conn = sqlite3.connect("../../data/ais.db")
try:
    df = pd.read_sql_query("SELECT * FROM ais", conn)
finally:
    # sqlite3's `with conn:` only manages transactions, so close explicitly
    conn.close()

dataset = ShipDataset(df)
model = Classifier()

In [None]:
# Launch training; use the GPU when torch can see one, otherwise the CPU.
train_model(
    model,
    dataset,
    device="cuda" if torch.cuda.is_available() else "cpu",
    val_split=0.1,
    lr=1e-4,
    batch_size=32,
    num_epochs=10,
)
