In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.utils.data import Dataset
from PIL import Image

In [None]:
# dataset
class ShipDetectionDataset(Dataset):
    """Image/annotation dataset for ship detection.

    Every ``.jpg`` / ``.png`` file found in ``images_dir`` is one sample. Its
    annotations are read from a text file with the same base name and a
    ``.txt`` extension inside ``labels_dir``. Each non-blank annotation line
    must contain exactly five whitespace-separated fields::

        class_id cx cy w h

    where ``cx, cy, w, h`` are multiplied by the image width/height to obtain
    pixel coordinates (i.e. they are fractions of the image size).

    Args:
        images_dir: Directory scanned for image files.
        labels_dir: Directory containing the matching ``.txt`` files.
        transform: Optional callable applied to the loaded RGB image.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, images_dir, labels_dir, transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible between runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(images_dir) if f.endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(img, target)``. ``img`` is the RGB image (after
            ``transform`` if one was given). ``target`` is a dict with
            ``"boxes"``, a float32 tensor of shape ``(num_boxes, 4)`` holding
            ``[x_min, y_min, x_max, y_max]`` in pixels, and ``"labels"``, an
            int64 tensor of shape ``(num_boxes,)``. Both are empty when the
            label file is missing or has no annotations.

        Raises:
            ValueError: If an annotation line does not have exactly 5 fields
                or a field cannot be parsed as a number.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        # Swap only the real extension; str.replace() would also rewrite
        # ".jpg" / ".png" occurring in the middle of the file name.
        stem, _ = os.path.splitext(img_name)
        label_path = os.path.join(self.labels_dir, stem + '.txt')

        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        width, height = img.size

        boxes = []
        labels = []
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                for line_no, line in enumerate(f, start=1):
                    parts = line.split()
                    if not parts:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    if len(parts) != 5:
                        raise ValueError(
                            f"{label_path}:{line_no}: expected 5 fields "
                            f"'class_id cx cy w h', got {len(parts)}"
                        )
                    class_id = int(parts[0])
                    cx, cy, w, h = map(float, parts[1:])

                    # Normalised centre/size -> pixel corner coordinates.
                    x_center = cx * width
                    y_center = cy * height
                    box_w = w * width
                    box_h = h * height

                    boxes.append([
                        x_center - box_w / 2,
                        y_center - box_h / 2,
                        x_center + box_w / 2,
                        y_center + box_h / 2,
                    ])
                    labels.append(class_id)

        # The explicit reshape keeps the (N, 4) layout even when N == 0;
        # torch.tensor([]) alone would produce shape (0,).
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(len(boxes), 4)
        labels = torch.tensor(labels, dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        if self.transform:
            img = self.transform(img)

        return img, target

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ShipDetectorCNN(nn.Module):
    """Small CNN that regresses one bounding box and one ship logit per image.

    Four conv/ReLU/max-pool stages halve the spatial resolution each time
    (a 256x256 input becomes a 128-channel 16x16 map). The map is then pooled
    to a fixed 16x16 grid, flattened, and passed through two linear layers.

    The adaptive pooling layer has no parameters and is an exact identity for
    16x16 feature maps, so 256x256 inputs produce the same output as without
    it and existing ``state_dict`` keys are unchanged. It only makes the fixed
    ``128 * 16 * 16`` input of the first linear layer independent of the image
    size, so other resolutions no longer fail with a shape mismatch.

    Output shape is ``(batch, 5)``: columns 0-3 are the box
    ``(x_min, y_min, x_max, y_max)`` and column 4 is the ship logit.
    """

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 128x128 for a 256x256 input

            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 16x16
        )

        # Guarantees the 16x16 grid that the first Linear layer expects.
        self.pool = nn.AdaptiveAvgPool2d((16, 16))

        self.fc = nn.Sequential(
            nn.Linear(128 * 16 * 16, 512),
            nn.ReLU(),
            nn.Linear(512, 5)  # 4 for bbox (x_min, y_min, x_max, y_max) + 1 for ship probability
        )

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image tensor to ``(batch, 5)`` predictions."""
        x = self.features(x)
        x = self.pool(x)
        # flatten() copes with non-contiguous tensors, unlike view().
        x = torch.flatten(x, 1)
        return self.fc(x)


In [None]:
# Loss modules are created once at module level and reused for every batch.
bbox_loss_fn = nn.MSELoss()
class_loss_fn = nn.BCEWithLogitsLoss()


def combined_loss(prediction, target):
    """Return the unweighted sum of the box and classification losses.

    Columns 0-3 of ``prediction`` and ``target`` are compared with mean
    squared error; column 4 of ``prediction`` is treated as a raw logit and
    compared with column 4 of ``target`` using binary cross-entropy.

    Args:
        prediction: 2-D tensor whose rows are ``[box(4), logit]``.
        target: 2-D tensor laid out like ``prediction``.

    Returns:
        Scalar tensor ``mse(boxes) + bce_with_logits(class)``.
    """
    box_term = bbox_loss_fn(prediction[:, :4], target[:, :4])
    class_term = class_loss_fn(prediction[:, 4], target[:, 4])
    return box_term + class_term


In [None]:
# load images
from torchvision import transforms

# The only preprocessing step is the PIL -> tensor conversion (no resizing).
transform = transforms.Compose([transforms.ToTensor()])

# Images and their .txt label files are read from the same folder.
dataset = ShipDetectionDataset(
    './ship_dataset_v0/',
    './ship_dataset_v0/',
    transform,
)

# Fetch the first sample as a quick sanity check of the pipeline.
img, target = dataset[0]

print(img.shape)      # expected (3, 256, 256) -- assumes 256x256 source images; TODO confirm
print(target)         # dict with "boxes" and "labels"

In [None]:
# train
# NOTE(review): `device`, `num_epochs` and `train_loader` are not defined in
# any cell visible here; they must be created in an earlier cell for this
# notebook to survive Restart Kernel -> Run All.

# The model has to live on the same device the batches are moved to below.
model = ShipDetectorCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for images, targets in train_loader:
        images = images.to(device)

        # Assumes targets["boxes"] is (batch, num_boxes, 4) and
        # targets["labels"] is (batch, num_boxes) with at least one box per
        # image -- TODO confirm against the loader's collate function.
        # Only the first box of every image is used as the regression target.
        true_boxes = targets["boxes"][:, 0, :]
        # NOTE(review): the raw class id is used as the BCE target. If ships
        # are annotated as class 0 this target is always 0 -- confirm intent.
        true_labels = targets["labels"][:, 0].float()

        # Concatenate to match the model output layout (batch_size, 5).
        true_target = torch.cat([true_boxes, true_labels.unsqueeze(1)], dim=1).to(device)

        preds = model(images)
        loss = combined_loss(preds, true_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Fail with a clear message instead of a NameError on `loss`.
        raise RuntimeError("train_loader yielded no batches; nothing to train on")

    # Report the mean loss over the epoch rather than only the last batch.
    print(f"Epoch {epoch}: Loss {running_loss / num_batches:.4f}")
