In [1]:
import os
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

**Custom Dataset class for loading images and annotations**

In [2]:
class ShipDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.images = sorted(os.listdir(images_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.images_dir, self.images[idx])
        image = Image.open(img_path).convert("RGB")
        
        # Load annotations
        annotation_path = os.path.join(self.annotations_dir, self.images[idx].replace(".jpg", ".txt"))
        boxes, labels = [], []
        
        with open(annotation_path, "r") as file:
            for line in file:
                class_id, x_center, y_center, width, height = map(float, line.split())
  
                img_width, img_height = image.size
                x_center *= img_width
                y_center *= img_height
                width *= img_width
                height *= img_height
                x_min = int(x_center - width / 2)
                y_min = int(y_center - height / 2)
                x_max = int(x_center + width / 2)
                y_max = int(y_center + height / 2)
                
                boxes.append([x_min, y_min, x_max, y_max])
                labels.append(int(class_id) + 1)
        
        target = {
            "boxes": torch.tensor(boxes, dtype=torch.float32),
            "labels": torch.tensor(labels, dtype=torch.int64),
        }
        
        if self.transform:
            image = self.transform(image)
        
        return image, target


In [3]:
# Define directories
images_dir = "D:/04_Personal_Files/Python/Ship_Detection_Model/Ship_Segmentation/DATASET/train/images"
annotations_dir = "D:/04_Personal_Files/Python/Ship_Detection_Model/Ship_Segmentation/DATASET/train/labels"

In [4]:
# Transform
transform = transforms.Compose([transforms.ToTensor()])

In [5]:
# Create Dataset and DataLoader
dataset = ShipDataset(images_dir, annotations_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

**Data Preprocessing**

In [6]:
from torchvision.transforms import functional as F
import random

In [7]:
class AugmentedShipDataset(ShipDataset):
    def __getitem__(self, idx):
        image, target = super().__getitem__(idx)
        
        # Random horizontal flip
        if random.random() > 0.5:
            image = F.hflip(image)
            target["boxes"][:, [0, 2]] = image.size[0] - target["boxes"][:, [2, 0]]
        
        return image, target

**Train Model**

In [8]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [9]:
# Load pre-trained Faster R-CNN model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 4  



In [10]:
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

In [11]:
# Set up optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

**Training Loop**

In [12]:
num_epochs = 1
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0
    for images, targets in dataloader:
        images = list(img.to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        # Forward pass
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        epoch_loss += losses.item()
        
        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")


Epoch 1, Loss: 9.873405039310455


**Validation and Accuracy Checking**

In [13]:
from torchvision.ops import box_iou

In [14]:
model.eval()
validation_dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

correct_boxes = 0
total_boxes = 0

for images, targets in validation_dataloader:
    images = list(img.to(device) for img in images)
    with torch.no_grad():
        outputs = model(images)

    for target, output in zip(targets, outputs):
        target_boxes = target["boxes"].cpu()
        pred_boxes = output["boxes"].cpu()
        
        if len(pred_boxes) > 0:
            iou = box_iou(pred_boxes, target_boxes).diag()
            correct_boxes += (iou > 0.5).sum().item()
        
        total_boxes += len(target_boxes)

accuracy = correct_boxes / total_boxes
print(f"Validation Accuracy: {accuracy:.2f}")

Validation Accuracy: 0.46


**Save the Model**

In [17]:
torch.save(model.state_dict(), "Models/faster_rcnn_model.pth")
print("Model saved successfully.")

Model saved successfully.
