# Imports

In [1]:
import csv
import numpy as np
from collections import defaultdict
from PIL import Image

# Decoding of run-length encoding

In [2]:
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length-encoded mask string into a binary 2-D array.

    mask_rle: run-length as string formatted "start length start length ...";
              starts are 1-based positions into the flattened mask, which is
              laid out column by column
    shape: (height, width) of array to return
    Returns numpy uint8 array of shape (height, width), 1 - mask, 0 - background.
    An empty or whitespace-only string decodes to an all-zero mask.
    '''
    tokens = mask_rle.split()
    # Even positions hold the run starts, odd positions the run lengths.
    starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
    lengths = np.asarray(tokens[1::2], dtype=int)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, starts + lengths):
        flat[lo:hi] = 1
    # The runs are column-major, so fill the (height, width) grid in Fortran
    # order. For square masks this equals reshape(shape).T; for height != width
    # it is the only variant that returns the requested shape.
    return flat.reshape(shape, order='F')

In [3]:
# get masks and store in hashmap
def get_masks(file):
    """Read a segmentation CSV and group the column-1 strings by column 0.

    file: path to a CSV file. The first row is treated as a header and
          skipped; for every following row, column 0 is the key and column 1
          is appended to that key's list.
    Returns a dict mapping each key to the list of its column-1 strings, in
    file order. An empty file yields an empty dict; blank lines are ignored.
    """
    masks = defaultdict(list)
    # newline='' is the mode the csv module asks for, and the with-block
    # guarantees the handle is closed when reading finishes or fails.
    with open(file, 'r', newline='') as handle:
        rows = csv.reader(handle)
        next(rows, None)  # skip the header row; tolerate an empty file
        for row in rows:
            if not row:
                continue  # csv.reader yields [] for blank lines
            masks[row[0]].append(row[1])
    # Hand back a plain dict so a later lookup of a missing key still raises.
    return dict(masks)

In [4]:
def masks_decode(masks, shape=(768, 768)):
    """Decode run-length mask strings into arrays, in place.

    masks: indexable sequence of lists; every str entry masks[i][j] is
           replaced by rle_decode(entry, shape). Entries that are not strings
           (i.e. already decoded) are left untouched, so calling this more
           than once on the same lists is safe.
    shape: (height, width) handed to rle_decode; defaults to (768, 768).
    Returns None; the inner lists themselves are mutated.
    """
    for i in range(len(masks)):
        group = masks[i]
        for j, entry in enumerate(group):
            if isinstance(entry, str):
                group[j] = rle_decode(entry, shape)

In [5]:
# def generate_segmentation_mask(img):
#     segmentation_mask = np.zeros((768,768), dtype=np.uint8)
#     og_image = Image.open(f'airbus-ship-detection/train_v2/{img}')
#     adjusted_img = og_image.rotate(90).transpose(Image.FLIP_TOP_BOTTOM)
#     img_arr = np.asarray(adjusted_img).copy()
        
#     img_arr[mask == 1] = [255,255,255]

#     segmentation_image = Image.fromarray(img_arr)
#     segmentation_image.save(f'sample_imgs/{img}_segmentation_result.png')

# Data Loader

In [6]:
from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
import torch
import os

class ShipDataset(torch.utils.data.Dataset):
    """Instance-segmentation dataset over the images in <root>/train_v2.

    root: dataset directory; images are read from
          os.path.join(root, 'train_v2', <image name>)
    size: only the first `size` entries of `masks` are used
    masks: dict mapping image file name -> list of masks for that image,
           either run-length strings or already decoded arrays
    transforms: optional callable invoked as transforms(image, target)
    train: when true, every selected mask is decoded up front with
           masks_decode. That decoding happens in place and the inner lists
           are the same objects stored in the caller's `masks` dict, so the
           dict's values end up decoded as well. When false, masks that are
           still strings are decoded on access in __getitem__.
    """

    def __init__(self, root, size, masks, transforms=None, train=True):
        self.root = root
        self.transforms = transforms
        self.imgs = list(masks.keys())[:size]
        self.masks = list(masks.values())[:size]
        if train:
            masks_decode(self.masks)

    def __getitem__(self, idx):
        """Return (image, target) for sample idx.

        target holds "boxes" (N, 4) xyxy, "masks" (N, H, W) uint8, "labels"
        (N,) all ones, "area" (N,) and "image_id" (file name without its last
        four characters). N counts the non-empty masks and may be 0.
        """
        # load image
        img_name = os.path.join(self.root, 'train_v2', self.imgs[idx])
        image = tv_tensors.Image(read_image(img_name))
        canvas_size = F.get_size(image)

        # Masks are still run-length strings if nothing decoded the shared
        # lists yet (train=False); decode those on access instead of failing.
        decoded = [
            rle_decode(m, tuple(canvas_size)) if isinstance(m, str) else m
            for m in self.masks[idx]
        ]
        if decoded:
            masks = torch.stack([torch.tensor(m, dtype=torch.uint8) for m in decoded])
        else:
            masks = torch.zeros((0, *canvas_size), dtype=torch.uint8)

        # Drop all-zero masks so that boxes, labels, masks and area always
        # describe the same N instances. The pixel count is taken with
        # Tensor.sum, which accumulates integer inputs in int64; Python's
        # sum() over a uint8 array accumulates in uint8 and can wrap to 0.
        keep = masks.flatten(1).sum(dim=1) > 0
        masks = masks[keep]
        num_objs = masks.shape[0]

        # get bounding boxes coordinates for each mask
        if num_objs > 0:
            boxes = masks_to_boxes(masks)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            area = torch.zeros((0,), dtype=torch.float32)
        # NOTE(review): a mask one pixel wide or tall presumably yields a
        # zero-extent box here — confirm the model accepts such boxes.

        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format='xyxy', canvas_size=canvas_size)
        target["masks"] = tv_tensors.Mask(masks)
        target["labels"] = labels
        target["image_id"] = self.imgs[idx][:-4]
        target["area"] = area

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        """Number of images selected for this dataset."""
        return len(self.imgs)

In [7]:
from torchvision.transforms import v2 as T

def get_transforms(train):
    """Build the transform pipeline applied to every (image, target) pair.

    train: when true, a random horizontal flip (p=0.5) is placed in front of
           the conversion steps.
    Returns a T.Compose that ends with a scaled float conversion followed by
    T.ToPureTensor().
    """
    # Augmentation first (training only), then the conversions every split needs.
    pipeline = [T.RandomHorizontalFlip(0.5)] if train else []
    pipeline += [
        T.ToDtype(torch.float, scale=True),
        T.ToPureTensor(),
    ]
    return T.Compose(pipeline)

In [8]:
import utils

masks_encoded = get_masks('airbus-ship-detection/train_ship_segmentations_v2.csv')

data_size = 1000

# Both datasets take their mask lists from masks_encoded. The training dataset
# is built first because it decodes those lists in place; dataset_test is
# created with train=False, skips that step and sees the decoded lists.
dataset = ShipDataset('airbus-ship-detection', data_size, masks_encoded, get_transforms(train=True))
dataset_test = ShipDataset('airbus-ship-detection', data_size, masks_encoded, get_transforms(train=False), train=False)

# Shuffle with a dedicated, fixed-seed generator so the train/test split is the
# same on every run and does not depend on the global torch RNG state.
split_rng = torch.Generator().manual_seed(0)
indices = torch.randperm(len(dataset), generator=split_rng).tolist()

# The last 50 shuffled indices are held out for evaluation; the two subsets
# index different dataset objects so each gets its own transforms.
data_train = torch.utils.data.Subset(dataset, indices[:-50])
data_test = torch.utils.data.Subset(dataset_test, indices[-50:])

data_loader_train = torch.utils.data.DataLoader(
    data_train,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    data_test,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=utils.collate_fn
)

# Models

In [9]:
from torchvision.models import ResNet101_Weights
from torchvision.models.detection import FasterRCNN, MaskRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# ResNet-101 + FPN backbone initialised from ImageNet weights. The keyword
# form with `weights=` replaces the deprecated positional name and
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag selected.
backbone = resnet_fpn_backbone(
    backbone_name='resnet101',
    weights=ResNet101_Weights.IMAGENET1K_V1,
)

# Two classes: background and ship.
model = MaskRCNN(backbone, num_classes=2)



In [10]:
# # For Training
# images, targets = next(iter(data_loader))
# images = list(image for image in images)
# targets = [{k: v for k, v in t.items()} for t in targets]
# output = model(images,targets)   # Returns losses and detections
# print(output)
# # loss = loss_function(output, targets)

# #For inference
# model.eval()
# x = [torch.rand(3, 786, 786),torch.rand(3, 786, 786)]
# predictions = model(x)           # Returns predictions
# print(predictions[0])

In [11]:
# selecting device: prefer CUDA, then Apple MPS, and fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [None]:
from engine import train_one_epoch, evaluate

# Put the model on the selected device before the optimizer is built from its parameters.
model.to(device)

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# The scheduler is stepped once per epoch below, so the learning rate is
# multiplied by 0.1 after every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

num_epoch = 10

for epoch in range(num_epoch):
    # One pass over the training loader, logging every 10 iterations.
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=10)
    lr_scheduler.step()
    # Score the held-out loader after every epoch.
    evaluate(model, data_loader_test, device=device)